In [18]:
from vocab import Vocabulary
import os
from torch.utils.data import DataLoader,Dataset
import pickle as pkl
from collections import defaultdict,deque,Counter,OrderedDict
from models import LM_latent
import torch
import pandas as pd

In [2]:
# Prefer the first GPU, but fall back to CPU so the notebook still runs
# (slowly) on machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [3]:
# Checkpoint files for the hidden=512 / embedding=64 run.
# BEST_PATH is loaded into the model as its state dict.
BEST_PATH = './Models/hid_512_emb_64/net_best_weights.pth'
# PATH is the epoch-0 checkpoint; the 'hyperparams' entry is read from it.
PATH = './Models/hid_512_emb_64/net_epoch_0.pth'

In [4]:
# Only the 'hyperparams' entry of this checkpoint is read in this notebook,
# so load it onto the CPU: that works on machines without CUDA and avoids
# putting unused tensors on the GPU.
# NOTE(review): the checkpoint stores pickled Python objects (e.g. the vocab);
# recent PyTorch releases may need weights_only=False here -- confirm against
# the installed version.
checkpoint = torch.load(PATH, map_location='cpu')

In [5]:
# Pull the model/vocabulary configuration saved alongside the checkpoint.
hyperparams = checkpoint['hyperparams']

tag2id = hyperparams['tagtoid']
vocab = hyperparams['vocab']
hidden_size = hyperparams['hidden_size']
token_embedding_size = hyperparams['token_embedding']
tag_embedding_size = hyperparams['tag_emb_size']
lstm_layers = hyperparams['lstmLayers']

# Batch size used by the data loaders in this notebook (not read from the checkpoint).
batch_size = 10

In [6]:
# Directory holding the pickled train/val/test splits. Defaults to the
# original cluster location; set the POS_LM_DATA_DIR environment variable to
# point the notebook at a different copy of the data.
DATA_DIR = os.environ.get('POS_LM_DATA_DIR', '/scratch/rj1408/pos_lm/ptb_wsj_pos')

train_pickle_file = os.path.join(DATA_DIR, 'train.p')
val_pickle_file = os.path.join(DATA_DIR, 'val.p')
test_pickle_file = os.path.join(DATA_DIR, 'test.p')

# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open(train_pickle_file,"rb") as a:
    traindict = pkl.load(a)
with open(val_pickle_file,"rb") as a:
    valdict = pkl.load(a)
with open(test_pickle_file,"rb") as a:
    testdict = pkl.load(a)

# Inverse mapping: tag id -> tag string.
# Caution: because this is a defaultdict(str), indexing it with an id that is
# not present returns '' AND inserts that id as a new key. Use
# id2tag.get(tag_id, '') for read-only lookups.
id2tag = defaultdict(str)
for tag, i in tag2id.items():
    id2tag[i] = tag

UNKNOWN_TAG = tag2id['UNKNOWN']
# Value used to pad the tag-label sequences when batching (see pad_collate_fn_pos).
PAD_TAG_ID = -51

In [7]:
class POSDataset(Dataset):
    """Dataset of POS-tagged sentences encoded for the latent-tag language model.

    Each item is a tuple ``(input_tensor, output_tensor, target_tensor)``:

    * ``input_tensor``  -- ``vocab.encode_token_seq`` applied to
      ``[BOS] + tokens`` (the sentence without the trailing EOS).
    * ``output_tensor`` -- one row per entry of ``id2tag`` (in its iteration
      order), each row being ``vocab.encode_token_seq_tag`` applied to
      ``tokens + [EOS]`` for that tag.
    * ``target_tensor`` -- tag ids of ``tokens + [EOS]``; EOS and any tag
      missing from ``tag2id`` are mapped to ``tag2id['UNKNOWN']``.

    Args:
        instanceDict: mapping whose ``'tagged_sents'`` entry is a sequence of
            sentences, each a list of ``(token, tag)`` pairs.
        vocab: object providing ``encode_token_seq`` and
            ``encode_token_seq_tag``.
        tag2id: mapping from tag string to tag id; must contain ``'UNKNOWN'``.
        id2tag: mapping from tag id to tag string.
        max_sent_len: sentences are truncated to this many tokens before
            BOS/EOS are added; ``None`` disables truncation.
    """

    def __init__(self, instanceDict, vocab, tag2id, id2tag, max_sent_len=60):
        self.root = instanceDict['tagged_sents']
        self.vocab = vocab
        self.tag2id = tag2id
        self.id2tag = id2tag

        self.sents = [[pair[0] for pair in sentence] for sentence in self.root]
        self.input_sents = []
        self.output_sents = []
        self.tags = []

        # The list of tag names does not change between sentences, so build it
        # once instead of once per sample.
        tag_names = [self.id2tag[tagid] for tagid in self.id2tag]

        for sentence, tokens in zip(self.root, self.sents):
            # Slicing with None keeps the whole sentence, so one expression
            # covers both the truncating and the non-truncating case.
            framed = [Vocabulary.BOS] + tokens[:max_sent_len] + [Vocabulary.EOS]

            # Inputs drop the final EOS; outputs drop the leading BOS, i.e. the
            # output at each position is the next token.
            self.input_sents.append(self.vocab.encode_token_seq(framed[:-1]))
            self.output_sents.append(
                [self.vocab.encode_token_seq_tag(framed[1:], name) for name in tag_names]
            )

            # Gold tags aligned with the outputs; the appended EOS is labelled 'UNKNOWN'.
            tag_seq = [pair[1] for pair in sentence[:max_sent_len]] + ['UNKNOWN']
            self.tags.append(
                [self.tag2id[tag] if tag in self.tag2id else self.tag2id['UNKNOWN']
                 for tag in tag_seq]
            )

    def __len__(self):
        """Return the number of sentences."""
        return len(self.root)

    def __getitem__(self, idx):
        """Return ``(input_tensor, output_tensor, target_tensor)`` as long tensors."""
        target_tensor = torch.as_tensor(self.tags[idx], dtype=torch.long)
        input_tensor = torch.as_tensor(self.input_sents[idx], dtype=torch.long)
        output_tensor = torch.as_tensor(self.output_sents[idx], dtype=torch.long)
        return (input_tensor, output_tensor, target_tensor)

In [8]:
def pad_list_of_tensors(list_of_tensors, pad_token):
    """Right-pad tensors along their last dimension and stack them.

    Args:
        list_of_tensors: non-empty list of tensors that agree in every
            dimension except possibly the last.
        pad_token: integer fill value for the padded positions.

    Returns:
        A tensor with one more leading dimension than the inputs (the list
        index), whose last dimension is the longest input length.

    Raises:
        ValueError: if ``list_of_tensors`` is empty.
    """
    if not list_of_tensors:
        raise ValueError("pad_list_of_tensors() needs at least one tensor")

    max_length = max(t.size(-1) for t in list_of_tensors)
    padded_list = []
    for t in list_of_tensors:
        pad_shape = list(t.shape[:-1]) + [max_length - t.size(-1)]
        # Allocate the padding on the same device as the tensor it is
        # concatenated with, so non-CPU inputs work as well.
        padding = torch.full(pad_shape, pad_token, dtype=torch.long, device=t.device)
        padded_list.append(torch.cat([t, padding], dim=-1))
    return torch.stack(padded_list)

def pad_collate_fn_pos(batch):
    """Collate a list of ``(input, output, tag_labels)`` samples into padded batches.

    Each of the three fields is padded along its last dimension with its own
    fill value and stacked via ``pad_list_of_tensors``.

    Returns:
        Tuple ``(input_tensor, target_tensor, target_labels)``.
    """
    # Fill value per sample field, in sample order: inputs, outputs, tag labels.
    pad_values = (2, Vocabulary.PADTOKEN_FOR_TAGVOCAB, PAD_TAG_ID)
    return tuple(
        pad_list_of_tensors([sample[position] for sample in batch], pad_value)
        for position, pad_value in enumerate(pad_values)
    )


In [9]:
# Map tag id -> the third field of that tag's entry in vocab.tag_specific_vocab
# (passed to the model as the per-tag vocabulary sizes).
tag_wise_vocabsize = {
    tag2id[tag_name]: tag_entry[2]
    for tag_name, tag_entry in vocab.tag_specific_vocab.items()
}

# The training split uses the default sentence-length cap; validation and
# test keep full-length sentences (max_sent_len=None).
datasets = {
    "train": POSDataset(traindict, vocab, tag2id, id2tag),
    "valid": POSDataset(valdict, vocab, tag2id, id2tag, None),
    "test": POSDataset(testdict, vocab, tag2id, id2tag, None),
}

# Only the training loader shuffles.
dataloaders = {
    split: DataLoader(
        datasets[split],
        batch_size=batch_size,
        shuffle=(split == "train"),
        collate_fn=pad_collate_fn_pos,
        pin_memory=True,
    )
    for split in ("train", "valid", "test")
}

In [11]:
# Build the network from the hyperparameters read out of the checkpoint and
# move it to the selected device. Weights are loaded in a later cell.
model = LM_latent(vocab.vocab_size, tag_wise_vocabsize, hidden_size, token_embedding_size, tag_embedding_size, lstm_layers).to(device)

In [12]:
# map_location remaps the saved tensors onto the device in use, so loading
# does not depend on the device the weights were saved from.
model.load_state_dict(torch.load(BEST_PATH, map_location=device))
# Switch to evaluation mode; as the cell's last expression this also displays
# the model structure.
model.eval()

LM_latent(
  (token_embedding): Embedding(10004, 64)
  (lstm): LSTM(64, 512, num_layers=3, batch_first=True)
  (tag_linear): Linear(in_features=256, out_features=47, bias=True)
  (lower_hidden): Linear(in_features=512, out_features=256, bias=True)
  (tag_projections): ModuleList(
    (0): Linear(in_features=512, out_features=3, bias=True)
    (1): Linear(in_features=512, out_features=2, bias=True)
    (2): Linear(in_features=512, out_features=9, bias=True)
    (3): Linear(in_features=512, out_features=563, bias=True)
    (4): Linear(in_features=512, out_features=36, bias=True)
    (5): Linear(in_features=512, out_features=28, bias=True)
    (6): Linear(in_features=512, out_features=47, bias=True)
    (7): Linear(in_features=512, out_features=4873, bias=True)
    (8): Linear(in_features=512, out_features=3, bias=True)
    (9): Linear(in_features=512, out_features=9, bias=True)
    (10): Linear(in_features=512, out_features=1988, bias=True)
    (11): Linear(in_features=512, out_features=

In [13]:
def getTagPredictions(tag_logits):
    """Return the index of the highest-scoring entry along the last dimension.

    Args:
        tag_logits: tensor of scores whose last dimension ranges over the
            candidate tags (the caller passes the first element of the model
            output).

    Returns:
        Long tensor of indices with the last dimension removed.
    """
    return torch.max(tag_logits, dim=-1).indices

In [24]:
# Run the model on the first validation batch only. This is pure inference,
# so disable autograd to avoid building a graph and holding activations.
with torch.no_grad():
    for batch_num, (inputs, targets, labels) in enumerate(dataloaders["valid"]):
        inputs = inputs.to(device)
        outputs = model(inputs)
        # The first element of the model output is used as the tag logits.
        predictions = getTagPredictions(outputs[0])
        break

# Print gold vs. predicted tags for every sentence in that batch.
# id2tag is a defaultdict: plain indexing would insert ids it does not know
# (e.g. the pad label) into the shared mapping, so use .get() for a read-only
# lookup. Unknown ids still display as '' exactly as before.
for i in range(len(predictions)):
    df = pd.DataFrame({'Actual':[id2tag.get(x, '') for x in labels[i].tolist()], 'Predicted':[id2tag.get(x, '') for x in predictions[i].cpu().tolist()]})
    print(df)
    print()

     Actual Predicted
0        DT       NNP
1       NNP        NN
2       NNP        NN
3       NNP        NN
4       VBD       VBD
5        DT        DT
6        CD        JJ
7        NN        NN
8        NN        NN
9        NN        IN
10       IN        IN
11      NNP        DT
12      NNP       NNP
13      NNP       NNP
14      NNP       NNP
15        ,         ,
16       RB        DT
17      JJR        IN
18       IN        IN
19   -NONE-        DT
20      VBN       VBG
21   -NONE-        DT
22       JJ        RP
23       NN        NN
24       IN         .
25       DT        DT
26       NN        NN
27       NN        IN
28       NN        IN
29       CC         .
30       RB        DT
31      PDT       VBN
32       DT        DT
33       NN        NN
34      VBN        IN
35   -NONE-    -NONE-
36       IN        RP
37       DT        DT
38       NN        NN
39        .         .
40  UNKNOWN   UNKNOWN
41            UNKNOWN

     Actual Predicted
0        DT       NNP
1        