In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import pickle
from torch.nn.utils.rnn import pad_sequence
import numpy as np

In [18]:
class NERset(Dataset):
    """Dataset of pre-tokenised NER examples read from ``textData.pkl``.

    Every record is indexed with the keys ``'name'``, ``'char_input'``,
    ``'word_input'`` and, when ``mode == 'train'``, ``'char_tag'``.
    The token inputs are assumed to be 1-D torch tensors and ``'char_tag'`` a
    ``(seq_len, NUM_LABELS)`` float tensor -- TODO confirm against the code
    that writes the pickle.

    Args:
        mode: ``'train'`` to load and return labels; any other value makes
            both ``__getitem__`` and ``create_mini_batch`` return ``None``
            in the label slot.
    """

    SEP_ID = 3        # token id used as the separator between the two segments
    MAX_LEN = 512     # sequences of this length or longer are cut to exactly MAX_LEN
    NUM_LABELS = 21   # width of one label row

    def __init__(self, mode):
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # this at a file you produced yourself.
        with open("textData.pkl", "rb") as f:
            self.data = pickle.load(f)
        self.train_or_test = mode

    @classmethod
    def _other_label_rows(cls, count):
        """Return ``count`` filler label rows: 1.0 in column 0, 0.0 elsewhere."""
        rows = torch.zeros((count, cls.NUM_LABELS), dtype=torch.float)
        rows[:, 0] = 1.0
        return rows

    def __getitem__(self, index):
        """Return ``(name, char_ids, word_ids, segment_ids, label_ids)``.

        ``segment_ids`` is 0 up to and including the first ``SEP_ID`` token
        and 1 afterwards.  Sequences with ``MAX_LEN`` or more tokens are cut
        to ``MAX_LEN - 1`` tokens and closed with a ``SEP_ID`` token; in
        training mode the label of that final position is a filler row.
        ``label_ids`` is ``None`` outside training mode.

        Raises:
            IndexError: if the character sequence contains no ``SEP_ID``.
        """
        record = self.data[index]
        name = record['name']
        char_tokens_tensors = record['char_input']
        word_tokens_tensors = record['word_input']
        label_ids = record['char_tag'] if self.train_or_test == 'train' else None

        # Locate the first separator with torch itself instead of routing the
        # tensor through numpy.
        sep_positions = torch.nonzero(torch.as_tensor(char_tokens_tensors) == self.SEP_ID)
        if sep_positions.numel() == 0:
            raise IndexError(
                "no separator token (id %d) in sample %r" % (self.SEP_ID, index)
            )
        thesis_or_context = sep_positions[0][0].item() + 1

        context_len = char_tokens_tensors.shape[0]
        segments_tensor = torch.tensor(
            [0] * thesis_or_context + [1] * (context_len - thesis_or_context)
        )

        if context_len >= self.MAX_LEN:
            keep = self.MAX_LEN - 1
            sep_tensor = torch.tensor([self.SEP_ID])
            char_tokens_tensors = torch.cat((char_tokens_tensors[:keep], sep_tensor), 0)
            word_tokens_tensors = torch.cat((word_tokens_tensors[:keep], sep_tensor), 0)
            segments_tensor = segments_tensor[:self.MAX_LEN]
            # Labels only exist in training mode; slicing None used to crash here.
            if label_ids is not None:
                label_ids = torch.cat((label_ids[:keep], self._other_label_rows(1)), 0)
        return (name, char_tokens_tensors, word_tokens_tensors, segments_tensor, label_ids)

    def __len__(self):
        """Number of records loaded from the pickle."""
        return len(self.data)

    def create_mini_batch(self, samples):
        """Collate ``__getitem__`` tuples into zero-padded batch tensors.

        Sample layout: ``s[0]`` name, ``s[1]`` char ids, ``s[2]`` word ids,
        ``s[3]`` segment ids, ``s[4]`` labels (``None`` outside training).

        Returns:
            ``(names, char_ids, word_ids, segment_ids, masks, labels)`` where
            ``masks`` is 1 wherever ``char_ids`` is non-zero and 0 elsewhere,
            and ``labels`` is the stacked label tensor in training mode
            (padded positions hold filler rows) or ``None`` otherwise.
        """
        name = [s[0] for s in samples]
        char_tokens_tensors = [s[1] for s in samples]
        word_tokens_tensors = [s[2] for s in samples]
        segments_tensors = [s[3] for s in samples]

        # Lengths must be recorded before padding erases them.
        before_pad_length = [chars.shape[0] for chars in char_tokens_tensors]

        char_tokens_tensors = pad_sequence(char_tokens_tensors, batch_first=True)
        word_tokens_tensors = pad_sequence(word_tokens_tensors, batch_first=True)
        segments_tensors = pad_sequence(segments_tensors, batch_first=True)
        max_pad_length = char_tokens_tensors.shape[1]

        # Padding uses id 0, so every non-zero position is a real token.
        masks_tensors = (char_tokens_tensors != 0).long()

        if self.train_or_test == 'train':
            padded_labels = []
            for sample, length in zip(samples, before_pad_length):
                labels = sample[4]
                missing = max_pad_length - length
                if missing > 0:
                    # One concatenation per sample instead of one per padded row.
                    labels = torch.cat((labels, self._other_label_rows(missing)), 0)
                padded_labels.append(labels)
            r_label_tensors = torch.stack(padded_labels)
        else:
            r_label_tensors = None

        return name, char_tokens_tensors, word_tokens_tensors, segments_tensors, masks_tensors, r_label_tensors

In [19]:
if __name__ == "__main__":
    # Smoke check: print the number of batches, then show the second sample of
    # the first batch whose padded length reaches 512 tokens.
    dataset = NERset('train')
    dataloader = DataLoader(
        dataset,
        batch_size=2,
        shuffle=False,
        collate_fn=dataset.create_mini_batch,
    )
    print(len(dataloader))
    for i, data in enumerate(dataloader):
        # data[1] holds the padded character ids, data[3] the segment ids.
        if data[1].shape[1] < 512:
            continue
        print(data[1][1])
        print(data[3][1])
        break
        
        

2228
tensor([   2,  164, 1137,  240,  683,    3,    1,    1,  607,  544,    8,  393,
         279,   46,   11,  141,    8,  623,  306,   24,  587,  477,    8,  218,
          46,   11,  149, 1003,  183,  231,  210,   18,    6,  570,  573,   15,
         790,    9,   46,   42,    6,   14,   46,   11,   10,    1,    1,    1,
         576,  169,   91,  271,  233,   76,  484, 1164,  411,  779,  230,   28,
          34,  352,  425, 1328,  214,  587,  777,  201,  464,    6,   58,  607,
         544,  393,  279,  141,    6,  587,  477,    8,  218,   46,   11,  240,
         514,   59,    6,   58,    1,  607,  544,    8,  393,  279,   46,   11,
          44,   14,   20,   17,   74,   24,   19,  141,   59,    8, 1508,  191,
          16,   24,   19,  141,   17,   36,   11,   44,   14,   10,    1,    1,
           1,   80,  172,    6,  233,   76, 1588,   12,  694,   76,    8,   94,
          19,   18,   12,    7,  730,  255,    1,    1,   34,    1,    1,   34,
           1,    1,   25,  253,    