In [None]:
import numpy as np
import torchtext
import random
import pandas as pd
import seaborn as sns
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import torch
import torchvision
from collections import defaultdict
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms, utils, datasets
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import Dataset, DataLoader, random_split, SubsetRandomSampler, WeightedRandomSampler

In [None]:
# Toy corpus of short text lines used as stand-in "reviews" for the
# padding / bucketing experiments below.
reviews = [
    'No man is an island',
    'Entire of itself',
    'Every man is a piece of the continent',
    'part of the main',
    'If a clod be washed away by the sea',
    'Europe is the less',
    'As well as if a promontory were',
    'As well as if a manor of thy friend',
    'Or of thine own were',
    # The apostrophe here was previously stored as the mojibake sequence
    # produced by decoding a UTF-8 right single quote (U+2019) as cp1252.
    'Any man’s death diminishes me',
    'Because I am involved in mankind',
    'And therefore never send to know for whom the bell tolls',
    'It tolls for thee',
]

# Placeholder binary labels, one per review. The count follows `reviews`
# instead of a hard-coded 13, and a dedicated seeded generator keeps the labels
# identical across kernel restarts without touching the global `random` state.
label_rng = random.Random(0)
labels = [label_rng.randint(0, 1) for _ in reviews]

# (text, label) pairs.
dataset = list(zip(reviews, labels))

In [None]:
class CLRPDataset(Dataset):
    """Map-style dataset that tokenizes one text per item, without padding.

    Items are plain dicts holding the tokenizer's ``input_ids``,
    ``attention_mask`` and ``token_type_ids`` for a single text. Sequences are
    neither padded nor truncated here; that is left to the ``collate_fn``.

    Args:
        texts: Indexable collection of strings to serve (e.g. a list or a
            ``df.review.to_numpy()`` array). When omitted, the notebook-level
            ``reviews`` list is used, which is the original behavior.
        labels: Optional collection of targets aligned with ``texts``. When
            given, every item additionally carries a ``'target'`` entry.
        tokenizer: Optional ready-made tokenizer callable. Supplying one lets
            several datasets share a tokenizer instead of reloading it.
        model_name: Checkpoint name handed to ``AutoTokenizer.from_pretrained``
            when no ``tokenizer`` is supplied.

    Raises:
        ValueError: If ``labels`` is given and its length differs from the
            number of texts.
    """

    def __init__(self, texts=None, labels=None, tokenizer=None,
                 model_name='roberta-base'):
        # Fall back to the notebook global only when the caller passes nothing,
        # so `CLRPDataset()` keeps working exactly as before.
        self.text = reviews if texts is None else texts
        self.label = labels

        if labels is not None and len(labels) != len(self.text):
            raise ValueError(
                'labels and texts must have the same length, got '
                f'{len(labels)} and {len(self.text)}'
            )

        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.tokenizer = tokenizer

    def __len__(self):
        """Return the number of texts in the dataset."""
        return len(self.text)

    def __getitem__(self, idx):
        """Tokenize the text at ``idx`` and return its encoding as a dict."""
        encode = self.tokenizer(
            self.text[idx],
            truncation=False,   # keep the full sequence; the collate step sizes it
            padding=False,      # padding is applied per batch by the collate step
            return_attention_mask=True,
            return_token_type_ids=True,
        )

        item = {
            'input_ids': encode['input_ids'],
            'attention_mask': encode['attention_mask'],
            'token_type_ids': encode['token_type_ids'],
        }
        if self.label is not None:
            item['target'] = self.label[idx]
        return item
        

In [None]:
class collate:
    """Callable ``collate_fn`` that pads a batch of tokenized items to one length.

    Each batch item is a dict of Python lists (``input_ids``,
    ``attention_mask``, ``token_type_ids`` and optionally ``target``). The
    three token fields are padded/truncated to a common length and returned as
    ``torch.long`` tensors; ``target`` is returned as the plain list of
    collected values (empty when items carry no target).

    Config keys:
        bucket: If truthy, pad to the longest ``input_ids`` in the batch.
            Otherwise pad/truncate every sequence to ``max_len``.
        max_len: Fixed sequence length; only read when ``bucket`` is falsy.
        pad_token_id: Optional id used to pad ``input_ids``. Defaults to 1,
            the value that was previously hard-coded.

    Raises:
        KeyError: If an item contains a key other than the four listed above,
            or if a required config key is missing.
    """

    def __init__(self, config):
        self.config = config
        # Bookkeeping containers kept for callers that inspect them; nothing
        # in this class currently writes to them.
        self.seq_dic = defaultdict(int)
        self.batch_record = defaultdict(list)
        self.bn = 0

    def __call__(self, batch):
        """Collate ``batch`` (a list of item dicts) into a dict of tensors."""
        out = {
            'input_ids': [],
            'attention_mask': [],
            'token_type_ids': [],
            'target': [],
        }

        # Transpose the list of dicts into a dict of lists.
        for item in batch:
            for key, value in item.items():
                out[key].append(value)

        if self.config['bucket']:
            # Dynamic padding: the longest sequence in this batch sets the
            # width (0 for an empty batch).
            max_pad = max((len(ids) for ids in out['input_ids']), default=0)
        else:
            max_pad = self.config['max_len']

        # Value used to fill each field up to `max_pad`.
        fill_values = {
            'input_ids': self.config.get('pad_token_id', 1),
            'attention_mask': 0,
            'token_type_ids': 0,
        }

        for key, fill in fill_values.items():
            # A non-positive pad count yields an empty list, so over-long
            # sequences are only affected by the slice. Note that the slice
            # cuts from the right, which drops the trailing token(s) of any
            # sequence longer than `max_pad` in fixed-length mode.
            padded = [
                (seq + [fill] * (max_pad - len(seq)))[:max_pad]
                for seq in out[key]
            ]
            out[key] = torch.tensor(padded, dtype=torch.long)

        return out

In [None]:
# Fixed-length mode: with 'bucket' disabled, the collate function pads or
# truncates every sequence to exactly 'max_len' tokens.
config = {
    'bucket': False,
    'max_len': 12,
    'batch_size': 4,
}

train_ds = CLRPDataset()
sequence = collate(config)

train_dataloader = DataLoader(
    train_ds,
    batch_size=config['batch_size'],
    collate_fn=sequence,
    shuffle=False,
)

# Print the padded length of every row, with a separator line after each batch.
for data in train_dataloader:
    for row in data['input_ids']:
        # Every row must come out at the configured fixed length.
        assert len(row) == config['max_len'], 'row was not padded/truncated to max_len'
        print(len(row))
    print('**********')

In [None]:
# Dynamic-padding mode: with 'bucket' enabled, the collate function pads each
# batch only up to its own longest sequence, so 'max_len' is not needed.
config = {
    'bucket': True,
    'batch_size': 4,
}

train_ds = CLRPDataset()
sequence = collate(config)

train_dataloader = DataLoader(
    train_ds,
    batch_size=config['batch_size'],
    collate_fn=sequence,
    shuffle=True,
    # Seeded generator so the shuffled batch order is the same on every run.
    generator=torch.Generator().manual_seed(0),
)

# Show each collated batch (dict of padded tensors).
for data in train_dataloader:
    print(data)
    