In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pickle as pkl
from collections import defaultdict,deque,Counter,OrderedDict
from torch.utils.data import DataLoader

In [2]:
with open("train.p","rb") as f:
    traindict = pkl.load(f)
with open("val.p","rb") as f:
    valdict = pkl.load(f)
with open("test.p","rb") as f:
    testdict = pkl.load(f)

In [3]:
with open('tagset.txt') as f:
    alltags = f.read()

alltags = list(map(lambda strline: strline.split('\t')[1], alltags.split('\n')))
alltags = set(alltags)

In [4]:
word_freq = Counter([word[0] if not word[0].isnumeric() and word[1] in alltags else 'UNK' for word in traindict['tagged_words']])
generic_vocab = ['SOS','EOS','PAD']+list([w for w in word_freq if word_freq[w]>5])
generic_word2id = {}
generic_id2word = {}
for i,word in enumerate(generic_vocab):
    generic_word2id[word] = i
    generic_id2word[i] = word

In [5]:
knowledgebase = defaultdict(deque)
for (word,tag) in traindict['tagged_words']:
    if tag in alltags:
        if word not in knowledgebase[tag]:
            if word in word_freq and word_freq[word]>5:
                knowledgebase[tag].append(word)

In [6]:
for tag in knowledgebase.keys():
    knowledgebase[tag].appendleft('UNK')
    knowledgebase[tag].appendleft('PAD')
    knowledgebase[tag].appendleft('EOS')
    knowledgebase[tag].appendleft('SOS')

In [7]:
tag2id = defaultdict(int)
id2tag = defaultdict(str)
for i, tag in enumerate(alltags):
    tag2id[tag] = i
    id2tag[i] = tag

In [8]:
class PTBDataset(object):
    def __init__(self, instanceDict, word2id):
        self.root = instanceDict['tagged_sents']
        self.word2id = word2id
        self.sents = [[s[0] for s in sentences] for sentences in self.root]
        self.sents.sort(key=lambda x:len(x))
    
    def __len__(self):
        return len(self.root)
    
    def __getitem__(self,idx):
        target_sent = [self.word2id[word] if word in self.word2id else self.word2id['UNK'] for word in self.sents[idx]]
        input_sent = [self.word2id['SOS']] + target_sent
        target_sent.append(self.word2id['EOS'])
        return (torch.as_tensor([input_sent], dtype=torch.long), torch.as_tensor([target_sent], dtype=torch.long))
        

In [9]:
class POSDataset(object):
    def __init__(self, instanceDict, word2id, tag2id):
        self.root = instanceDict['tagged_sents']
        self.tag2id = tag2id
        self.word2id = word2id
        self.root.sort(key=lambda x:len(x))
        self.sents = [[s[0] for s in sentences] for sentences in self.root]
        self.tags = [[s[1] for s in sentences] for sentences in self.root]
    
    def __len__(self):
        return len(self.root)
    
    def __getitem__(self,idx):
        target_labels = [self.tag2id[tag] for tag in self.tags[idx]]
        target_sent = [self.word2id[word] if word in self.word2id else self.word2id['UNK'] for word in self.sents[idx]]
        input_sent = [self.word2id['SOS']]+[self.word2id[word] if word in self.word2id else self.word2id['UNK'] for word in self.sents[idx]]
        target_sent.append(self.word2id['EOS'])
        return (torch.as_tensor([input_sent], dtype=torch.long), torch.as_tensor([target_sent], dtype=torch.long), torch.as_tensor([target_labels], dtype=torch.long))

In [10]:
train_dataset = PTBDataset(traindict,generic_word2id)
val_dataset = POSDataset(valdict,generic_word2id,tag2id)

In [11]:
def pad_list_of_tensors(list_of_tensors, pad_token):
    max_length = max([t.size(-1) for t in list_of_tensors])
    padded_list = []
    for t in list_of_tensors:
        padded_tensor = torch.cat([t, torch.tensor([[pad_token]*(max_length - t.size(-1))], dtype=torch.long)], dim = -1)
        padded_list.append(padded_tensor)
    padded_tensor = torch.cat(padded_list, dim=0)
    return padded_tensor
def pad_collate_fn_lm(batch):
    # batch is a list of sample tuples
    input_list = [s[0] for s in batch]
    target_list = [s[1] for s in batch]
    pad_token = 2    
    input_tensor = pad_list_of_tensors(input_list, pad_token)
    target_tensor = pad_list_of_tensors(target_list, pad_token)
    return input_tensor, target_tensor
def pad_collate_fn_pos(batch):
    # batch is a list of sample tuples
    input_list = [s[0] for s in batch]
    target_list = [s[1] for s in batch]
    target_labels = [s[2] for s in batch]
    pad_token_input = 2 
    pad_token_tags = 37
    input_tensor = pad_list_of_tensors(input_list, pad_token_input)
    target_tensor = pad_list_of_tensors(target_list, pad_token_input)
    target_labels = pad_list_of_tensors(target_labels, pad_token_tags)
    return input_tensor, target_tensor, target_labels

In [12]:
len(train_dataset)

38219

In [13]:
train_loader = DataLoader(train_dataset, batch_size=16, collate_fn=pad_collate_fn_lm, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=16, collate_fn=pad_collate_fn_pos,pin_memory=True)

In [14]:
class LM(nn.Module):
    def __init__(self,vocab_size,hidden_size,token_embedding_size,tag_embedding_size,tags,mask,device):

        super(LM, self).__init__()

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.token_embedding_size = token_embedding_size
        self.tag_embedding_size = tag_embedding_size
        self.num_tags = len(tags)+1
        self.tags = tags
        self.mask = mask
        self.device=device
        
        self.token_embedding = nn.Embedding(self.vocab_size,self.token_embedding_size)
        self.lstm = nn.LSTM(self.token_embedding_size, self.hidden_size, num_layers = 1, batch_first = True, bias=False)
        self.tag_linear = nn.Linear(self.tag_embedding_size,self.num_tags,bias=False)
        self.lower_hidden = nn.Linear(self.hidden_size,self.tag_embedding_size,bias=False)
        self.tag_projections = nn.ModuleList([nn.Linear(self.hidden_size,self.vocab_size,bias=False) for i in range(self.num_tags)])
        
        
        
    def forward(self,input_seq):
        if self.training:
            batch_size,sent_len = input_seq.shape[0],input_seq.shape[1]
            h = torch.zeros((1,batch_size,self.hidden_size),device=self.device)
            c = torch.zeros((1,batch_size,self.hidden_size),device=self.device)
            embeddings = self.token_embedding(input_seq) #batch_size, sent_len, embed_size
            outputs = torch.zeros((batch_size,sent_len,self.vocab_size),device=device)
            for idx in range(sent_len):
                embedding_input = embeddings[:,idx,:].view(batch_size,1,self.token_embedding_size)
                _,(h,c) = self.lstm(embedding_input,(h,c))
                h_lower = self.lower_hidden(h.transpose(0,1).view(batch_size,-1)) #batch_size,100
                tag_weights = F.softmax(self.tag_linear(h_lower),dim=-1) #batch_size,num_tags
                #word_distributions = torch.cat([(F.softmax((self.tag_projections[i](h.squeeze(0)))*self.mask[i],dim=-1)).unsqueeze(1) for i in range(self.num_tags)],dim=1)#batch_size,num_tags,vocab_size
                #print(word_distributions[0][:2])
                word_distributions_logits = torch.cat([((self.tag_projections[i](h.squeeze(0)))*self.mask[i]).unsqueeze(1) for i in range(self.num_tags)],dim=1)
                word_distributions = F.softmax(word_distributions_logits,dim=-1)
#                 print(word_distributions_logits[0][:2])
#                 print(torch.max(word_distributions_logits[0][0]))
#                 print(word_distributions[0][:2])
                attended_words = torch.bmm(tag_weights.unsqueeze(1),word_distributions)
                outputs[:,idx,:] = attended_words.squeeze(1)
            return torch.log(outputs)
        elif self.eval:
            with torch.no_grad():
                batch_size,sent_len = input_seq.shape[0],input_seq.shape[1]
                h = torch.zeros((1,input_seq.shape[0],self.hidden_size),device=self.device)
                c = torch.zeros((1,input_seq.shape[0],self.hidden_size),device=self.device)
                embeddings = self.token_embedding(input_seq) #batch_size, sent_len, embed_size
                pred_tag = torch.zeros((batch_size,sent_len),device=device)
                pred_word = torch.zeros((batch_size,sent_len,self.vocab_size),device=device)
                for idx in range(sent_len):
                    embedding_input = embeddings[:,idx,:].view(batch_size,1,self.token_embedding_size)
                    _,(h,c) = self.lstm(embedding_input,(h,c))
                    h_lower = self.lower_hidden(h.transpose(0,1).view(batch_size,-1)) #batch_size,100
                    tag_weights = F.softmax(self.tag_linear(h_lower),dim=-1) #batch_size,num_tags
                    pred_tag[:,idx] = torch.argmax(tag_weights[:,:-1],dim=-1) 
                    
                    word_distributions = torch.cat([(F.softmax((self.tag_projections[i](h.squeeze(0)))*self.mask[i],dim=-1)).unsqueeze(1) for i in range(self.num_tags)],dim=1)#batch_size,tag_vocab_size
                    attended_words = torch.bmm(tag_weights.unsqueeze(1),word_distributions)
                    pred_word[:,idx,:] = attended_words.squeeze(1)
            
                return pred_tag,torch.log(pred_word)

In [15]:
VOCAB_SIZE = len(generic_vocab)
HIDDEN_SIZE = 512
EMBEDDING_SIZE = 256
TAG_EMBEDDING_SIZE = 128
device = torch.device('cuda:2')
mask = torch.zeros(len(knowledgebase.keys())+1,VOCAB_SIZE,device=device)
for i,tag in enumerate(knowledgebase.keys()):
    idx = [generic_word2id[word] for word in knowledgebase[tag]]
    mask[i,idx] = 1
mask[-1] = 1
mask[mask==0] = -10000
lang_model = LM(vocab_size=VOCAB_SIZE, 
                hidden_size=HIDDEN_SIZE,
                token_embedding_size=EMBEDDING_SIZE,
                tag_embedding_size=TAG_EMBEDDING_SIZE,
                tags = knowledgebase.keys(), 
                mask = mask,
               device = device).to(device)
criterion = nn.NLLLoss(ignore_index=2)
optimizer = optim.Adam(lang_model.parameters())

In [16]:
def evaluate(model,val_loader,criterion,device):
    model.eval()
    token_acc = 0
    total_tokens = 0
    sent_acc = 0
    total_sent = 0
    val_nll = 0
    for batch,(input_seq,target_seq,target_labels) in enumerate(val_loader):
        input_seq = input_seq.to(device)
        target_labels = target_labels.to(device)
        target_seq = target_seq.to(device)
        pred_labels,pred_words = model(input_seq)
        batch_size,sent_len = input_seq.shape[0],input_seq.shape[1]
        for i in range(batch_size):
            word_correct = 0
            total_sent+=1
            sent_len = 0
            for j in range(input_seq.shape[1]-1):
                if target_labels[i,j]!=37:
                    #print(target_labels[i,j])
                    sent_len+=1
                    if pred_labels[i,j].long()==target_labels[i,j]:
                        word_correct+=1
            total_tokens+=sent_len
            token_acc+=word_correct
            if sent_len==word_correct:
                sent_acc+=1
        loss = 0
        for i in range(input_seq.shape[1]):
            loss+= criterion(pred_words[:,i,:],target_seq[:,i])
        val_nll+=loss.item()
    token_acc = float(token_acc)/float(total_tokens)
    sent_acc = float(sent_acc)/float(total_sent)
    val_nll/=total_tokens
    return token_acc,sent_acc,val_nll

In [17]:
def train(model,train_loader,val_loader,optimizer,criterion,device,num_epochs=10):
    c_point = 238*2
    for epoch in range(num_epochs):
        train_loss = 0
        total_tokens = 0
        for batch,(input_seq,target_seq) in enumerate(train_loader):
            model.train()
            input_seq = input_seq.to(device)
            target_seq = target_seq.to(device)
            optimizer.zero_grad()
            pred_seq = model(input_seq)
            loss = 0
            for i in range(input_seq.shape[1]):
                loss+= criterion(pred_seq[:,i,:],target_seq[:,i])
            #print(loss)
            train_loss+=loss.item()
            total_tokens+=(input_seq.shape[1]*input_seq.shape[0])
            loss/=(input_seq.shape[1]*input_seq.shape[0])
            loss.backward()
            optimizer.step()
            if (batch+1)%c_point==0:
                token_acc,sent_acc, val_loss = evaluate(model,val_loader,criterion,device)
                print('epoch: {} | step: {}/{} | train loss: {} | token acc: {} | sent acc: {} | val loss: {}'.format(epoch+1,
                                                                                                                      (batch+1)//c_point,
                                                                                                                      len(train_loader)//c_point,
                                                                                                                      round(train_loss/total_tokens,3),
                                                                                                                      round(token_acc,3),
                                                                                                                      round(sent_acc,3),
                                                                                                                      round(val_loss,3)))
                train_loss = 0
                total_tokens = 0
    

In [None]:
train(lang_model,train_loader,val_loader,optimizer,criterion,device)

epoch: 1 | step: 1/5 | train loss: 0.319 | token acc: 0.072 | sent acc: 0.0 | val loss: 0.522
epoch: 1 | step: 2/5 | train loss: 0.304 | token acc: 0.001 | sent acc: 0.0 | val loss: 0.417
epoch: 1 | step: 3/5 | train loss: 0.294 | token acc: 0.008 | sent acc: 0.0 | val loss: 0.36
epoch: 1 | step: 4/5 | train loss: 0.284 | token acc: 0.012 | sent acc: 0.0 | val loss: 0.309


In [23]:
torch.tensor([[1,2],[3,4]])*torch.tensor([1,2])

tensor([[1, 4],
        [3, 8]])

In [18]:
len(train_loader)

2389