In [1]:
import os
import json
import torch
from torch.utils.data import Dataset
import transformers
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from transformers import DataCollatorForWholeWordMask
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Read the preprocessed corpus from disk into `data`.
# The encoding is given explicitly so the file is decoded as UTF-8 instead of
# with the platform's locale-dependent default.
with open('processed_data.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

In [3]:
class LMDataset(Dataset):
    """Map-style dataset that tokenizes one stored text entry per access.

    Tokenization is lazy: nothing is encoded until an item is requested.
    """

    def __init__(self, list_data, tokenizer, max_seq_len):
        """Keep references to the raw entries and the tokenization settings.

        Args:
            list_data: Indexable collection whose elements are passed to the
                tokenizer as ``text``.
            tokenizer: Object exposing an ``encode_plus`` method.
            max_seq_len: Value forwarded to the tokenizer as ``max_length``.
        """
        self.max_len = max_seq_len
        self.tokenizer = tokenizer
        self.data = list_data

    def __len__(self):
        """Return the number of stored entries."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize and return the entry stored at position ``idx``.

        The tokenizer is asked to truncate and to pad with
        ``padding='max_length'``; its return value is handed back unchanged.
        """
        encode_kwargs = dict(
            text=self.data[idx],
            max_length=self.max_len,
            truncation=True,
            padding='max_length',
        )
        return self.tokenizer.encode_plus(**encode_kwargs)

class LMDataloader():
    """Build train/validation ``DataLoader`` objects for masked-LM batches.

    Everything happens in ``__init__``:
      1. ``group_sentences`` concatenates the ``'TEXT'`` fields of
         ``dict_data`` into longer sequences.
      2. ``split_data`` randomly holds out a validation subset.
      3. Each subset is wrapped in an ``LMDataset``.
      4. Each dataset is wrapped in a ``DataLoader`` that uses a
         ``DataCollatorForWholeWordMask`` as ``collate_fn``.
    """

    def __init__(
        self,
        dict_data,
        tokenizer,
        mlm_probability,
        max_seq_len,
        batch_size=8,
        validation_size=0.05,
        fixed_seed_val=0,
        max_items=1000,
    ):
        """Group, split and wrap the data, then create the loaders.

        Args:
            dict_data: Mapping whose values are mappings with a ``'TEXT'``
                string entry.
            tokenizer: Tokenizer handed to ``LMDataset`` and to the collator.
            mlm_probability: Forwarded to the collator.
            max_seq_len: Grouping budget and tokenizer ``max_length``.
            batch_size: Batch size of both loaders.
            validation_size: Fraction of grouped sequences held out.
            fixed_seed_val: Seed for the train/validation split.
            max_items: Maximum number of ``dict_data`` entries consumed while
                grouping; ``None`` consumes everything. The default keeps the
                cap this class has always applied.
        """
        # convert dictionary of sentences into a list of grouped sentences
        list_data = self.group_sentences(
            dict_data=dict_data,
            max_len=max_seq_len,
            max_items=max_items,
        )

        # split the data
        train, val = self.split_data(
            list_data=list_data,
            validation_size=validation_size,
            fixed_seed_val=fixed_seed_val,
        )

        # split name -> LMDataset
        self.dataset = {
            'train': LMDataset(
                list_data=train,
                tokenizer=tokenizer,
                max_seq_len=max_seq_len,
            ),
            'validation': LMDataset(
                list_data=val,
                tokenizer=tokenizer,
                max_seq_len=max_seq_len,
            ),
        }

        # NOTE(review): per the transformers documentation this collator
        # relies on BERT-style '##' sub-word prefixes; for tokenizers with a
        # different scheme it behaves roughly like plain token-level masking.
        # Confirm this matches the tokenizer that is passed in.
        collator_obj = DataCollatorForWholeWordMask(
            tokenizer=tokenizer,
            mlm=True,
            mlm_probability=mlm_probability,
            return_tensors='pt',
        )

        # split name -> DataLoader (neither split is shuffled)
        self.dataloader = {
            split: DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=False,
                collate_fn=collator_obj,
            )
            for split, dataset in self.dataset.items()
        }

    def group_sentences(
        self,
        dict_data,
        max_len=128,
        sentence_split_ratio=1.3,
        max_items=1000,
    ):
        """Concatenate consecutive sentences into sequences of bounded size.

        A sequence's size is estimated as the number of space-separated
        pieces times ``sentence_split_ratio``. A sentence that would push the
        estimate to ``max_len`` or beyond starts a new sequence instead.

        Args:
            dict_data: Mapping whose values expose a ``'TEXT'`` string.
            max_len: Size budget for one grouped sequence.
            sentence_split_ratio: Multiplier applied to the piece count.
            max_items: Consume at most this many entries (``None`` = all).

        Returns:
            List of grouped sequences, each a concatenation of sentences
            followed by a single space. The final, partially filled sequence
            is included, and blank sequences are never emitted.
        """
        print('grouping sentences...')
        list_out = []
        cur_sequence = ''
        for idx, v_ in enumerate(tqdm(dict_data.values())):
            if max_items is not None and idx >= max_items:
                break
            sentence = v_['TEXT'] + ' '
            new_sequence = cur_sequence + sentence
            estimated_size = len(new_sequence.split(' ')) * sentence_split_ratio
            if estimated_size >= max_len:
                # Flush what has been collected so far; skip the flush when
                # nothing was collected (the first sentence is over budget).
                if cur_sequence.strip():
                    list_out.append(cur_sequence)
                cur_sequence = sentence
            else:
                cur_sequence = new_sequence
        # Keep the sequence still being built when the input ran out.
        if cur_sequence.strip():
            list_out.append(cur_sequence)
        print('done')

        return list_out

    def split_data(
        self,
        list_data,
        validation_size,
        fixed_seed_val,
    ):
        """Randomly split ``list_data`` into training and validation lists.

        Args:
            list_data: Sequence of items to split.
            validation_size: Fraction of items assigned to validation; the
                count is ``int(validation_size * len(list_data))``.
            fixed_seed_val: Seed of the private random generator, so the
                split is reproducible and the global NumPy random state is
                left untouched.

        Returns:
            ``(train, validation)``. Training items keep their original
            order; validation items are in the order they were drawn.
        """
        print('splitting data...')

        n_total = len(list_data)
        n_val = int(validation_size * n_total)

        # A private legacy generator draws the same indices as seeding the
        # global one would, without clobbering state other code relies on.
        rng = np.random.RandomState(fixed_seed_val)
        if n_val:
            val_indices = [
                int(i) for i in rng.choice(n_total, size=n_val, replace=False)
            ]
        else:
            val_indices = []
        held_out = set(val_indices)

        print(n_total)

        val_ = [list_data[i] for i in val_indices]
        train_ = [item for i, item in enumerate(list_data) if i not in held_out]

        # Every item must land in exactly one of the two splits.
        assert (len(val_) + len(train_)) == n_total
        assert len(held_out) == len(val_indices)

        print('done...')

        return train_, val_

In [None]:

# Load the tokenizer published under the 'roberta-base' checkpoint name.
tokenizer = AutoTokenizer.from_pretrained('roberta-base')
#tokenizer.add_tokens(['<sent-sep>'], special_tokens=True)

# Gather the loader settings in one place, then build the train/validation
# loaders from the corpus loaded earlier.
loader_settings = dict(
    mlm_probability=0.15,
    max_seq_len=128,
    batch_size=8,
    validation_size=0.05,
    fixed_seed_val=0,
)
d_ = LMDataloader(dict_data=data, tokenizer=tokenizer, **loader_settings)

In [None]:
# Sanity check: for every batch of both splits, show the first example's
# input tokens followed by each token kept as a prediction target.
for split in ('train', 'validation'):
    for batch in d_.dataloader[split]:
        first_input_ids = batch['input_ids'][0]
        first_labels = batch['labels'][0]
        print(tokenizer.batch_decode(first_input_ids))
        for label_id in first_labels:
            # -100 is presumably the collator's marker for positions that
            # are not prediction targets -- skip those.
            if label_id.item() == -100:
                continue
            print(tokenizer.decode(label_id))