In [1]:
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms, utils

import numpy as np
import json

import torch
import sys

In [2]:
class itemDataset(Dataset):
    """Relation-classification dataset read from a JSON-lines file.

    Each non-blank line of ``file_name`` is a JSON object with:
      * ``tokens``: list of words (lower-cased before lookup),
      * ``nodes``:  list of ``[span, type]`` entries,
      * ``edges``:  list of ``[span_a, span_b, type]`` entries.
    Every edge becomes one dataset item, so ``len(dataset)`` is the total
    number of edges, not the number of sentences.

    Vocabularies live in ``self.token[name]`` (``name`` in nodes/edges/tokens)
    as ``{symbol: id}`` dicts:
      * ``mode='train'``: each starts as ``{'pad': 0}``, grows while the file is
        read, and is then written to ``./token/<name>`` (line number == id).
      * ``mode='test'``: loaded from those files. Symbols missing from them are
        still appended with fresh ids. NOTE(review): such ids can exceed an
        embedding table sized from the saved vocabulary.

    Args:
        file_name: path of the JSON-lines file.
        mode: ``'train'`` or ``'test'``.
        transform: optional callable applied to each sample dict.

    Raises:
        ValueError: if ``mode`` is neither ``'train'`` nor ``'test'``.
    """

    _VOCAB_NAMES = ('nodes', 'edges', 'tokens')

    def __init__(self, file_name, mode='train', transform=None):
        if mode == 'test':
            self.token = self._load_vocab()
        elif mode == 'train':
            self.token = {name: {'pad': 0} for name in self._VOCAB_NAMES}
        else:
            raise ValueError("mode must be 'train' or 'test', got {0!r}".format(mode))

        self.read_json(file_name)

        if mode == 'train':
            self._save_vocab()
        self.transform = transform

    def _load_vocab(self):
        """Read ``./token/<name>`` files; line ``i`` holds the symbol with id ``i``."""
        token = {}
        for name in self._VOCAB_NAMES:
            with open('./token/{0}'.format(name), encoding='utf-8') as f:
                token[name] = {line.strip(): i for i, line in enumerate(f)}
        return token

    def _save_vocab(self):
        """Write each vocabulary to ``./token/<name>``, one symbol per line.

        Ids are assigned as ``len(vocab)`` at insertion time, so dict insertion
        order equals id order and line number equals id.
        """
        for name in self._VOCAB_NAMES:
            with open('./token/{0}'.format(name), 'w', encoding='utf-8') as f:
                for symbol in self.token[name]:
                    f.write("{0}\n".format(symbol))

    def read_json(self, file_name):
        """Parse ``file_name`` into ``self.data`` (edges) and ``self.sent`` (sentences).

        ``self.data`` entries are ``[span_a, span_b, edge_type_id, sent_idx]``.
        ``self.sent[sent_idx]`` is ``{'tokens': [word ids], 'nodes': [[span, type_id], ...]}``.
        Unknown symbols are added to the vocabularies with the next free id.
        Blank lines are skipped.
        """
        def type2id(data, dtype):
            # NOTE(review): only the FIRST element of `data` is mapped. If the
            # JSON type field is a plain string, that is its first character;
            # confirm the intended schema. Returns None when `data` is empty.
            vocab = self.token[dtype]
            for name in data:
                # setdefault evaluates len(vocab) before inserting, i.e. the next free id.
                return vocab.setdefault(name, len(vocab))

        def word2id(data):
            vocab = self.token['tokens']
            return [vocab.setdefault(word.lower(), len(vocab)) for word in data]

        self.data = []
        self.sent = []
        with open(file_name, encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue
                temp = json.loads(line)
                # Vocabulary growth order (nodes, then edges, then tokens) is kept stable.
                nodes = [[node[0], type2id(node[1], 'nodes')] for node in temp['nodes']]
                edges = [[edge[0], edge[1], type2id(edge[2], 'edges')] for edge in temp['edges']]
                # Index into self.sent (not the raw line number) so skipped lines stay aligned.
                sent_idx = len(self.sent)
                for edge in edges:
                    self.data.append(edge + [sent_idx])
                self.sent.append({'tokens': word2id(temp['tokens']), 'nodes': nodes})

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``{'edge', 'sent', 'nodes', 'label'}`` for edge ``idx``, then apply the transform."""
        span_a, span_b, label, sent_idx = self.data[idx]
        sentence = self.sent[sent_idx]
        sample = {
            'edge': [span_a, span_b],
            'sent': sentence['tokens'],
            'nodes': sentence['nodes'],
            'label': label,
        }
        # Bug fix: the original tested the torchvision `transforms` module
        # (always truthy), so transform=None crashed here.
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

class ToTensor:
    """Turn a sample's token ids and label into ``torch.long`` tensors.

    The incoming dict is updated in place: ``'sent'`` and ``'label'`` are
    replaced by tensors and ``'sent_len'`` (number of tokens) is added. A
    shallow copy of the updated dict is returned.
    """

    def __call__(self, sample):
        token_ids = torch.tensor(sample['sent'], dtype=torch.long)
        sample['sent'] = token_ids
        sample['sent_len'] = len(token_ids)
        sample['label'] = torch.tensor(sample['label'], dtype=torch.long)
        return dict(sample)

In [16]:
# Load the training split (one item per edge). In 'train' mode the constructor
# also writes the nodes/edges/tokens vocabularies to ./token/.
traindata = itemDataset('./train.json',transform=transforms.Compose([ToTensor()]))

In [28]:
def collate_fn(sample):
    """Merge a list of ``ToTensor``-processed items into one padded batch.

    Args:
        sample: list of dicts with 'sent' (1-D long tensor), 'sent_len' (int),
            'label', 'nodes' (``[[start, end], type_id]`` entries) and 'edge'
            (a list of ``[start, end]`` spans).

    Returns:
        dict with, in this order:
          'sent_len' (batch,) and 'label' (batch,) long tensors;
          'sent'  (batch, max_len) token ids right-padded with 0;
          'node'  (batch, max_len) node-type id per position, 0 elsewhere;
          'edge'  (batch, max_len) 1-based index of the edge span covering the
                  position, 0 elsewhere.
        Spans are half-open ``[start, end)``; later spans overwrite earlier ones.
    """
    batch = {
        key: torch.tensor([item[key] for item in sample], dtype=torch.long)
        for key in ('sent_len', 'label')
    }

    num_items = len(batch['sent_len'])
    max_len = batch['sent_len'].max().item()

    padded_rows = []
    for item in sample:
        padding = torch.zeros(max_len - item['sent_len'], dtype=torch.long)
        padded_rows.append(torch.cat([item['sent'], padding]))
    batch['sent'] = torch.stack(padded_rows)

    node_ids = torch.zeros(num_items, max_len, dtype=torch.long)
    for row, item in enumerate(sample):
        for entry in item['nodes']:
            start, end = entry[0][0], entry[0][1]
            for pos in range(start, end):
                node_ids[row, pos] = entry[1]
    batch['node'] = node_ids

    pointer_ids = torch.zeros(num_items, max_len, dtype=torch.long)
    for row, item in enumerate(sample):
        for pointer, span in enumerate(item['edge'], start=1):
            for pos in range(span[0], span[1]):
                pointer_ids[row, pos] = pointer
    batch['edge'] = pointer_ids

    return batch

In [29]:
# Smoke test: collate two items (indices 0 and 30) into one padded batch.
# Index 30 assumes the training set has at least 31 edges.
data = collate_fn([traindata[0],traindata[30]])

In [19]:
import torch.nn as nn
import torch.nn.functional as F

In [20]:
# Print the size of each vocabulary (nodes, edges, tokens).
for name, vocab in traindata.token.items():
    print(name, len(vocab))

nodes 16
edges 4
tokens 5096


In [12]:
class Temp:
    """Empty attribute bag used as a lightweight stand-in for parsed CLI ``args``."""

In [13]:
# Hyperparameters consumed by RNN.__init__.
args = Temp()
args.hidden_dim = 128  # embedding size for words/nodes/edge pointers and LSTM hidden size

args.num_layer = 1  # number of stacked LSTM layers
args.bidirectional = True  # run the LSTM in both directions
args.batch_first = True  # tensors are (batch, seq, ...), matching collate_fn's stacked output
args.dropout=0  # LSTM inter-layer dropout; has no effect with a single layer

In [30]:
# Inspect the collated batch: sent_len, label, sent, node and edge tensors.
data

{'sent_len': tensor([47, 58]),
 'label': tensor([1, 2]),
 'sent': tensor([[ 1,  2,  3,  4,  5,  6,  7,  4,  8,  4,  9, 10, 11, 12, 13, 14, 15, 16,
          17,  1, 18,  4, 19, 20, 21, 11, 22, 13, 23, 24, 25, 26, 27, 28, 29, 30,
          21, 11, 31, 13, 32, 33, 28, 29, 34, 35, 36,  0,  0,  0,  0,  0,  0,  0,
           0,  0,  0,  0],
         [37, 38, 39, 30, 40,  4, 41, 42, 43, 44, 21, 45, 46, 47, 48,  4, 49,  4,
          50, 30, 51, 52, 53, 54, 55, 56, 57,  4, 58,  4, 59, 50,  4, 60, 54, 61,
          62, 63,  4, 64,  4, 59, 50,  4, 65, 54, 30, 66, 67, 68,  4, 69,  4, 59,
          50,  4, 70, 36]]),
 'node': tensor([[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 2, 3, 3, 3, 0, 0, 4, 4, 4, 4, 0, 0, 0,
          0, 3, 3, 3, 0, 1, 1, 0, 0, 0, 2, 0, 0, 3, 3, 3, 1, 0, 0, 2, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 5, 0, 4, 4, 0, 3, 0, 6, 6, 6, 6, 6, 0,
          4, 4, 4, 0, 3, 0, 6, 6, 6, 6, 0, 4, 4, 4, 0, 3, 0, 6, 6, 6, 6, 0, 0, 4,
          

In [62]:
class RNN(nn.Module):
    """LSTM relation classifier over a sentence with entity and edge-argument markers.

    Each position's input is the sum of three embeddings: the word id, the
    node-type id, and an edge-pointer id (0 = outside the queried edge,
    1 = first span, 2 = second span, as built by ``collate_fn`` in this
    notebook). The sentence is summarized by the forward LSTM output at the
    last real token, concatenated (when bidirectional) with the backward output
    at position 0. A two-layer MLP then maps this to one logit per edge label.

    Args:
        token: vocabulary dict with 'tokens', 'nodes' and 'edges' sub-dicts;
            only their sizes are used.
        args: object with ``hidden_dim``, ``num_layer``, ``bidirectional``,
            ``batch_first`` and ``dropout`` attributes.
    """

    # Edge-pointer ids: 0 = padding / outside, 1 = first span, 2 = second span.
    NUM_EDGE_POINTERS = 3

    def __init__(self, token, args):
        super().__init__()

        self.word_emb = nn.Embedding(len(token['tokens']), args.hidden_dim, padding_idx=0)
        self.ner_emb = nn.Embedding(len(token['nodes']), args.hidden_dim, padding_idx=0)
        self.edge_emb = nn.Embedding(self.NUM_EDGE_POINTERS, args.hidden_dim, padding_idx=0)

        self.rnn = nn.LSTM(
            input_size=args.hidden_dim,
            hidden_size=args.hidden_dim,
            num_layers=args.num_layer,
            batch_first=args.batch_first,
            # nn.LSTM applies dropout only between stacked layers and warns when
            # it is non-zero with a single layer, so pass 0 in that case.
            dropout=args.dropout if args.num_layer > 1 else 0,
            bidirectional=args.bidirectional,
        )

        self.hidden_size = args.hidden_dim
        self.num_layer = args.num_layer
        self.batch_first = args.batch_first
        self.dropout = args.dropout
        self.bidirectional = args.bidirectional

        # Bug fix: the feature size depends on the direction count; the original
        # hard-coded 2 * hidden_dim and crashed when bidirectional=False.
        num_directions = 2 if args.bidirectional else 1
        self.dense_1 = nn.Linear(num_directions * args.hidden_dim, 32)
        self.act_1 = nn.ReLU()
        self.dense_2 = nn.Linear(32, len(token['edges']))

    def _pack(self, seq, seq_length):
        """Sort the batch by length (descending) and pack it for the LSTM.

        Returns the packed sequence and the permutation that restores the
        original batch order.
        """
        sorted_lengths, order = torch.sort(seq_length, descending=True)
        _, restore_order = torch.sort(order)
        seq = seq[order] if self.batch_first else seq[:, order]
        packed = nn.utils.rnn.pack_padded_sequence(
            seq, sorted_lengths.cpu().numpy(), batch_first=self.batch_first)
        return packed, restore_order

    def _unpack(self, packed_out, restore_order):
        """Pad the packed LSTM output and put examples back in their original order."""
        padded, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=self.batch_first)
        return padded[restore_order] if self.batch_first else padded[:, restore_order]

    def _extract_features(self, output, length):
        """Forward output at the last real token, plus backward output at position 0 if bidirectional."""
        if not self.batch_first:
            output = output.transpose(0, 1)  # -> (batch, seq, features)
        rows = torch.arange(output.size(0), device=output.device)
        last = (length.long() - 1).to(output.device)
        forward_last = output[rows, last, :self.hidden_size]
        if not self.bidirectional:
            return forward_last
        # The backward direction has consumed the whole sentence by position 0.
        backward_first = output[:, 0, self.hidden_size:]
        return torch.cat([forward_last, backward_first], dim=-1)

    def forward(self, data, data_len, data_ner, data_point):
        """Compute edge-label logits.

        Args:
            data: token ids, (batch, seq) if ``batch_first`` else (seq, batch); 0 = pad.
            data_len: 1-D tensor with the true length of each example.
            data_ner: node-type ids, same shape as ``data``.
            data_point: edge-pointer ids in [0, NUM_EDGE_POINTERS), same shape as ``data``.

        Returns:
            Logits of shape (batch, len(token['edges'])).
        """
        word = self.word_emb(data) + self.ner_emb(data_ner) + self.edge_emb(data_point)

        packed, restore_order = self._pack(word, data_len)
        packed_out, _ = self.rnn(packed)
        seq_out = self._unpack(packed_out, restore_order)

        features = self._extract_features(seq_out, data_len)
        return self.dense_2(self.act_1(self.dense_1(features)))

In [63]:
# Instantiate the classifier; embedding and output sizes come from the vocabulary sizes.
model = RNN(traindata.token,args)

In [65]:
# Forward pass on the two-example batch -> logits, one row per example,
# one column per edge-label id.
output = model(data['sent'],data['sent_len'],data['node'],data['edge'])

In [66]:
# Cross-entropy summed (not averaged) over the batch.
criterion = nn.CrossEntropyLoss(reduction='sum')

In [67]:
# Gold edge-label ids for the batch.
data['label']

tensor([1, 2])

In [68]:
# Loss of the logits against the gold labels for this batch.
loss = criterion(output,data['label'])