In [34]:
#!nvidia-smi

In [1]:
import csv
import re
import numpy as np
import pickle
import torch
from transformers import AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def categorical(prob, n_samples):
    """
    sample a categorical distribution from a vect of probabilities
    """
    prob = prob.unsqueeze(0).repeat(n_samples, 1)
    cum_prob = torch.cumsum(prob, dim=-1)
    r = torch.rand(n_samples, 1)
    # argmax finds the index of the first True value in the last axis.
    samples = torch.argmax((cum_prob > r).int(), dim=-1)
    return samples.numpy()

In [3]:
def clean_data(data):
    if (len(data) > 0) and (data[-1] == ' '):
        return data[:-1]
    return data

In [4]:
def remove_punctuation(word):
    #print(word, len(word))
    if len(word) == 0:
        return []
    if len(word) == 1:
        return [word]
    
    if word[0] in [',', ':', '!', '(', ')', ';', '.']:
        if not word[-1] in [',', ':', '!', '(', ')', ';', '.']:
            return [word[0]] + remove_punctuation(word[1:])
        else:
            return [word[0]] + remove_punctuation(word[1:-1]) + [word[-1]]
    else:
        if not word[-1] in [',', ':', '!', '(', ')', ';', '.']:
            return [word]
        else:
            return remove_punctuation(word[:-1]) + [word[-1]]

In [5]:
with open("NLP_data/2020.06.03_CHUSJ_Data_PatientID.csv", 'r', encoding = 'ISO-8859-1') as file: #latin-1 delimiter = '\t'
    csvreader = csv.reader((line.replace('\0','').replace('\t-', '').replace('\t', '').replace(' \x19', "'") for line in file))
    notes = []
    for row in csvreader:
        notes.append(row)

In [7]:
"""
with open("NLP_data/2020.07.06_data_visualization_wnumeric.csv", 'r', encoding='cp1250') as file:
    csvreader = csv.reader((line.replace('\0','') for line in file))
    count = 0
    for row in csvreader:
        if count < 7:
            print(row)
            count += 1
        else:
            break
"""

'\nwith open("NLP_data/2020.07.06_data_visualization_wnumeric.csv", \'r\', encoding=\'cp1250\') as file:\n    csvreader = csv.reader((line.replace(\'\x00\',\'\') for line in file))\n    count = 0\n    for row in csvreader:\n        if count < 7:\n            print(row)\n            count += 1\n        else:\n            break\n'

In [6]:
with open("NLP_data/Labelling Le - 0 to 100.csv", 'r', encoding = 'utf-8') as file:
    csvreader = csv.reader((line.replace('\0','') for line in file))
    labels = []
    for row in csvreader:
        row[2] = clean_data(row[2])
        labels.append(row)

In [7]:
len(labels), len(notes)

(918, 11508)

In [11]:
notes[204]

['101',
 '2463011',
 'T21 avec CAV complet opéré ce jour Défaillance cardiaque sous lasix   Pancreas annulaire - opéré 21/06/2012: laparotomie et duodenoduodenostomie Reflux gastro-oesophagiené.   2/09/2012: Admis 1 jour pour Intoxication lanoxin.  Résolution rapide avec DigiBind x 1  27/1/2013: Admis 1 jour pour tableau de Gastroentérite  Réhydratation',
 'yes']

# Collecting new labels

In [11]:
labels[5]

['1',
 '85',
 'Valeur Saturation Pulsée en Oxygène',
 'Value Saturation Pulsed in Oxygen',
 'saturation habituelle',
 '%']

In [13]:
texts = []
for i in range(100):
    texts.append(notes[2*i+2][2].split())

In [14]:
for i in range(100):
    clean_text = []
    for word in texts[i]:
        cleaned_words = remove_punctuation(word)
        clean_text.append(cleaned_words)
    texts[i] = sum(clean_text, [])

In [15]:
attr_to_class = {"Fraction d'éjection":"Fraction d'éjection",
                "Valeur de la fraction d'éjection en Simson":"Fraction d'éjection",
                "Fraction de raccourcissement": "Fraction de raccourcissement",
                'Fréquence cardiaque : bradycardie':'Fréquence cardiaque',
                'Diamètre Artère Pulmonaire Droite distale':'Diamètre Artère Pulmonaire',
                'Diamètre Artère Pulmonaire Droite proximale': 'Diamètre Artère Pulmonaire',
                'Diamètre Artère Pulmonaire Gauche proximale': 'Diamètre Artère Pulmonaire',
                'Diamètre Artère Pulmonaire Principale': 'Diamètre Artère Pulmonaire',
                'Diamètre Artère Pulmonaire Droite': 'Diamètre Artère Pulmonaire',
                'Diamètre Artère Pulmonaire Gauche': 'Diamètre Artère Pulmonaire',
                'diamètre Artère Pulmonaire': 'Diamètre Artère Pulmonaire',
                'diamètre Artère Pulmonaire Droite': 'Diamètre Artère Pulmonaire',
                'Saturation pulsée en oxygène': 'Saturation en oxygène',
                'Valeur Saturation Pulsée en Oxygène': 'Saturation en oxygène',
                'saturation veineuse en oxygène': 'Saturation en oxygène',
                'Valeur de la Saturation Pulsée en oxygène': 'Saturation en oxygène',
                'saturation artérielle en oxygène': 'Saturation en oxygène',
                'Objectif cible de Saturation en Oxygène': 'Saturation en oxygène',
                'score apgar à 1 minute': 'apgar',
                'score apgar à 10 minutes': 'apgar',
                'score apgar à 5 minutes': 'apgar',
                "score d'apgar (à une minute et cinq minutes)": 'apgar',
                "score d'apgar (à une minute, cinq minutes et 10 minutes)": 'apgar',
                } 

In [16]:
class_to_index = {"Fraction d'éjection": 1,
                  "Fraction de raccourcissement": 2,
                  'Fréquence cardiaque': 3,
                  'Diamètre Artère Pulmonaire': 4,
                  'Saturation en oxygène': 5,
                  'apgar': 6
}

In [17]:
#fixing some splitted tokens
labels[439][1] = '65%'
labels[478][1] = '46%'
labels[815][1] = '27%'
labels[284][1] = '31.3%'
labels[310][1] = '40,7%'
labels[480][1] = '29%'
labels[100][1] = '4,5mm'
labels[175][1] = '5.5mm'
labels[176][1] = '6.6mm'
labels[381][1] = '2.8mm'
labels[382][1] = '2.3mm'
labels[486][1] = '4.9mm'
labels[487][1] = '5.7mm'
labels[655][1] = '18mm'
labels[4][1] = '80-85%'
labels[19][1] = '87%'
labels[165][1] = '85-88%'
labels[166][1] = '75%'
labels[224][1] = '50-65%'
labels[228][1] = '70-75%'
labels[394][1] = '25%'
labels[401][1] = '96%'
labels[529][1] = '80-85%'
labels[543][1] = '85-90%'
labels[550][1] = '65%'
labels[601][1] = '65-85%'
labels[603][1] = '75%'
labels[738][1] = '92%'
labels[13][1] = '8-9-9'
labels[24][1] = '8-9-9'
labels[328][1] = '9-9-10'
labels[340][1] = '8-9'
labels[396][1] = '8-9'
labels[417][1] = '1-2-3'
labels[448][1] = '9-9-9-'
labels[545][1] = '7-9-10'
labels[563][1] = '8-9-9'
labels[572][1] = '7-8-10'
labels[583][1] = '7-8-8'
labels[644][1] = '8-9-9'
labels[674][1] = '9-9-9'
labels[724][1] = '8-9-9'
labels[739][1] = '8/8/9'
labels[752][1] = '5.5.7'
labels[764][1] = '6/7/9'
labels[777][1] = '9/9/9'
labels[845][1] = '6/8/8'
labels[913][1] = '9-9-9'
labels[6][1] = '6,5'
labels[7][1] = '7,1'
labels[538][1] = '1,6'
labels[886][1] = '60-68'

In [18]:
lines_to_ignore = [5, 26, 116, 121, 131, 144, 160, 341, 356, 357, 379, 397, 405, 406, 418, 419, 449, 450, 522, 530, 537, 544, 546, 547, 564, 565, 573, 574, 584, 585, 605, 606, 607, 608, 645, 646, 648, 675, 676, 704, 725, 726, 740, 741, 753, 754, 765, 766, 778, 779, 795, 801, 846, 847, 887, 914, 915]

In [19]:
classes = {k:[] for k in range(100)}
for i in range(len(labels)):
    if labels[i][2] in attr_to_class:
        classe = attr_to_class[labels[i][2]]
        if not (i in lines_to_ignore):
            patient_id = int(labels[i][0])
            classes[patient_id].append((labels[i][1], classe))

In [20]:
pos_classes = {k:np.zeros(len(texts[k]), dtype=int) for k in range(100)}
print("list of doublons")
for i in range(100):
    for (value, classe) in classes[i]:
        if not value == '9 -9-10': #special case to deal with later
            index_class = class_to_index[classe]
            index = texts[i].index(value)
            indices = [k for k in range(len(texts[i])) if texts[i][k]==value]
            if len(indices)> 1:
                print(i, 'value=', value, 'class=', index_class, 'pos=', indices)
            pos_classes[i][index] = index_class
        else:
            index_class = class_to_index[classe]
            index = texts[i].index('-9-10')
            pos_classes[i][index] = index_class
            pos_classes[i][index-1] = index_class

list of doublons
43 value= 7 class= 6 pos= [23, 25]
43 value= 7 class= 6 pos= [23, 25]
52 value= 24 class= 1 pos= [143, 224]
52 value= 23 class= 2 pos= [135, 216]
69 value= 11 class= 4 pos= [136, 182]
69 value= 11 class= 4 pos= [136, 182]
69 value= 10 class= 4 pos= [144, 187]
70 value= 5.5 class= 4 pos= [100, 106]


In [21]:
#dealing with numerical values doublons in single notes
pos_classes[43][23] = 6
pos_classes[43][25] = 6
pos_classes[52][143] = 1
pos_classes[52][224] = 1
pos_classes[52][135] = 2
pos_classes[52][216] = 2
pos_classes[69][136] = 4
pos_classes[69][182] = 4
pos_classes[69][144] = 4
pos_classes[70][100] = 0
pos_classes[70][106] = 4

In [22]:
pos_classes[52]

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 2, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0,
       0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [23]:
pos_classes[52]

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 2, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0,
       0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [21]:
#texts[2]

In [24]:
dataset = []
minimal_sentence_size = 7
for i in range(100):
    text = texts[i]
    class_indexes = pos_classes[i]
    #listing the breaking points while avoiding breaks like "Dr. Fournier."
    breaks = [-1]
    for j in range(len(text)):
        if (text[j] == '.') and (j > breaks[-1] + minimal_sentence_size):
            breaks.append(j)
    if breaks[-1]!= len(text)-1:
        breaks.append(len(text)-1)
        
    for j in range(len(breaks)-1):
        sample = {'tokens': text[breaks[j]+1: breaks[j+1]+1], 
                  'classes': class_indexes[breaks[j]+1: breaks[j+1]+1],
                  'extracted_from': i}
        dataset.append(sample)

In [25]:
len(dataset)

451

# Splitting the dataset

In [26]:
class_to_sample = {1:[], 2:[], 3:[], 4:[], 5:[], 6:[]}
for i in range(len(dataset)):
    sample = dataset[i]
    for class_idx in sample['classes']:
        if class_idx != 0:
            class_to_sample[class_idx].append(i)

In [27]:
class_to_sample= {k:np.array(v) for k,v in class_to_sample.items()}
class_to_sample

{1: array([232, 249, 250, 254, 413]),
 2: array([163, 177, 249, 249, 253]),
 3: array([432, 433]),
 4: array([ 12,  13,  55, 104, 105, 211, 211, 261, 261, 286, 288, 330, 330,
        333, 333, 342, 424, 424]),
 5: array([  7,   8,  90,  91, 129, 207, 218, 270, 278, 305, 321, 321, 373]),
 6: array([  8,  21,  46,  49,  60,  67,  87,  98, 129, 143, 143, 158, 179,
        186, 205, 205, 205, 214, 216, 216, 216, 222, 241, 278, 307, 311,
        316, 322, 338, 359, 372, 375, 385, 400, 425, 450])}

In [28]:
mask_arr = categorical(torch.tensor([.15, .15, .7]), len(dataset))

In [29]:
mask_arr

array([1, 2, 2, 2, 2, 2, 2, 0, 0, 2, 2, 2, 2, 1, 0, 2, 0, 2, 0, 2, 2, 2,
       1, 1, 2, 2, 2, 2, 2, 0, 2, 2, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       1, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 0, 1, 2, 1, 2, 0, 2,
       2, 1, 0, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2,
       2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 0, 0,
       2, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, 0, 2, 0, 2, 2, 2, 2, 2, 2, 0,
       2, 0, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 1, 2, 2,
       2, 2, 2, 1, 2, 0, 0, 2, 0, 0, 2, 0, 0, 1, 1, 0, 2, 0, 2, 2, 2, 2,
       0, 2, 2, 2, 2, 0, 0, 0, 1, 2, 0, 1, 2, 0, 2, 2, 2, 0, 2, 2, 0, 2,
       2, 2, 2, 2, 2, 1, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 1, 2, 0,
       2, 0, 2, 2, 1, 2, 2, 0, 2, 2, 2, 0, 0, 2, 2, 2, 0, 0, 0, 2, 2, 2,
       2, 1, 2, 1, 2, 1, 2, 0, 2, 2, 1, 2, 1, 1, 1, 2, 0, 2, 2, 1, 2, 0,
       0, 1, 1, 2, 2, 2, 2, 2, 1, 2, 1, 2, 2, 1, 2, 2, 2, 2, 2, 2, 0, 1,
       2, 1, 2, 2, 0, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2,

In [39]:
mask_arr[class_to_sample[6]]

array([0, 2, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, 0, 1, 1, 1, 1, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 0])

In [40]:
#balancing the datasets (test, val, train)
mask_arr[232] = 2
mask_arr[253] = 1
mask_arr[163] = 2
mask_arr[261] = 0
mask_arr[424] = 0
mask_arr[91] = 1
mask_arr[129] = 1
mask_arr[87] = 0
mask_arr[98] = 0
mask_arr[359] = 0

In [41]:
train_ds = []
val_ds = []
test_ds = []
for i in range(len(dataset)):
    if mask_arr[i]==0:
        test_ds.append(dataset[i])
    elif mask_arr[i]==1:
        val_ds.append(dataset[i])
    else:
        train_ds.append(dataset[i])

In [42]:
print(len(train_ds), len(val_ds), len(test_ds))

306 64 81


In [43]:
with open("test", "wb") as fp:   #Pickling
   pickle.dump(test_ds, fp)
 
#with open("test", "rb") as fp:   # Unpickling
#   test_ds = pickle.load(fp)

with open("val", "wb") as fp:   #Pickling
   pickle.dump(val_ds, fp)
 
#with open("val", "rb") as fp:   # Unpickling
#   val_ds = pickle.load(fp)

with open("train", "wb") as fp:   #Pickling
   pickle.dump(train_ds, fp)
 
#with open("train", "rb") as fp:   # Unpickling
#   train_ds = pickle.load(fp)

# Blind dataset

In [44]:
tokenizer = AutoTokenizer.from_pretrained('camembert-bio-model')

In [45]:
with open("test", "rb") as fp:   # Unpickling
    test_ds = pickle.load(fp)
 
with open("val", "rb") as fp:   # Unpickling
    val_ds = pickle.load(fp)
 
with open("train", "rb") as fp:   # Unpickling
    train_ds = pickle.load(fp)

blind_test_ds = []
blind_val_ds = []
blind_train_ds = []

for sample in test_ds:
    blind_sample = {k:v for (k,v) in sample.items()}
    for i in range(len(sample['tokens'])):
        if sample['classes'][i] != 0:
            blind_sample['tokens'][i] = 'nombre'
    blind_test_ds.append(blind_sample)
            
for sample in val_ds:
    blind_sample = {k:v for (k,v) in sample.items()}
    for i in range(len(sample['tokens'])):
        if sample['classes'][i] != 0:
            blind_sample['tokens'][i] = 'nombre'
    blind_val_ds.append(blind_sample)
    
for sample in train_ds:
    blind_sample = {k:v for (k,v) in sample.items()}
    for i in range(len(sample['tokens'])):
        if sample['classes'][i] != 0:
            blind_sample['tokens'][i] = 'nombre' 
    blind_train_ds.append(blind_sample)

In [46]:
with open("blind_test", "wb") as fp:   #Pickling
   pickle.dump(blind_test_ds, fp)
 
#with open("blind_test", "rb") as fp:   # Unpickling
#   blind_test_ds = pickle.load(fp)

with open("blind_val", "wb") as fp:   #Pickling
   pickle.dump(blind_val_ds, fp)
 
#with open("blind_val", "rb") as fp:   # Unpickling
#   blind_val_ds = pickle.load(fp)

with open("blind_train", "wb") as fp:   #Pickling
   pickle.dump(blind_train_ds, fp)
 
#with open("blind_train", "rb") as fp:   # Unpickling
#   blind_train_ds = pickle.load(fp)

In [47]:
print(len(blind_train_ds), len(blind_val_ds), len(blind_test_ds))

306 64 81


# FOR LATER, OTHER ATTRIBUTES TO ADD

In [None]:
"""
attr_to_class = {'2ème geste et 2ème pare': 'geste et pare',
                '3ème geste et 3ème pare': 'geste et pare',
                'Abbréviation I pour 1': 'chiffre romain',
                'Abbréviation deux dimensions': 'dimensions',
                'Age (heures)': 'age',
                'Age (semaines)': 'age',
                'Age (jours)': 'age',
                'Date (année)': 'date',
                }
            
#fixing some splitted tokens
labels[213][1] = 'J1'
labels[214][1] = 'J2'
labels[407][1] = 'J7'
labels[421][1] = '8j'
labels[705][1] = 'J6'
labels[732][1] = 'J2'
labels[771][1] = 'J4'
labels[521][1] = '07/2014'
labels[522][1] = '2-3'
labels[531][1] = '2013-04-03'
labels[532][1] = '2013-03-29'
labels[new][1] = '2013-05-15'
labels[540][2] = 'Age (semaines)'

# need space before 2 à 4mois in notes[148] (should not alter note though)
# pay attention to parenthesis when extracting data
# add 2-3 mois as a 'duration' to note[118]
# pay attention to -2010, ... in notes[106]
# missing label for the third date in notes[122]
# remove space in 9 -9-10 to get 9-9-10 in notes[58]
"""

# Creation of the unlabeled dataset for MLM task

In [28]:
texts = []
for i in range(5753):
    texts.append(notes[2*i+2][2].split())

for i in range(5753):
    clean_text = []
    for word in texts[i]:
        cleaned_words = remove_punctuation(word)
        clean_text.append(cleaned_words)
    texts[i] = sum(clean_text, [])
    
dataset = []
minimal_sentence_size = 7
for i in range(5753):
    text = texts[i]
    #listing the breaking points while avoiding breaks like "Dr. Fournier."
    breaks = [-1]
    for j in range(len(text)):
        if (text[j] == '.') and (j > breaks[-1] + minimal_sentence_size):
            breaks.append(j)
    if breaks[-1]!= len(text)-1:
        breaks.append(len(text)-1)
        
    for j in range(len(breaks)-1):
        sample = {'tokens': text[breaks[j]+1: breaks[j+1]+1], 
                  'extracted_from': i}
        dataset.append(sample)

In [29]:
len(dataset)

26166

In [17]:
rand = np.random.rand(len(dataset))
mask_arr = (rand < 0.15)
mlm_train_val_ds = []
mlm_test_ds = []
for i in range(len(dataset)):
    if mask_arr[i]:
        mlm_test_ds.append(dataset[i])
    else:
        mlm_train_val_ds.append(dataset[i])

In [18]:
rand = np.random.rand(len(mlm_train_val_ds))
mask_arr = (rand < 0.15)
mlm_train_ds = []
mlm_val_ds = []
for i in range(len(mlm_train_val_ds)):
    if mask_arr[i]:
        mlm_val_ds.append(mlm_train_val_ds[i])
    else:
        mlm_train_ds.append(mlm_train_val_ds[i])

In [19]:
print(len(mlm_train_ds), len(mlm_val_ds), len(mlm_test_ds))

18868 3339 3959


In [20]:
with open("mlm_test", "wb") as fp:   #Pickling
   pickle.dump(mlm_test_ds, fp)
 
#with open("mlm_test", "rb") as fp:   # Unpickling
#   mlm_test_ds = pickle.load(fp)

with open("mlm_val", "wb") as fp:   #Pickling
   pickle.dump(mlm_val_ds, fp)
 
#with open("mlm_val", "rb") as fp:   # Unpickling
#   mlm_val_ds = pickle.load(fp)

with open("mlm_train", "wb") as fp:   #Pickling
   pickle.dump(mlm_train_ds, fp)
 
#with open("mlm_train", "rb") as fp:   # Unpickling
#   mlm_train_ds = pickle.load(fp)