In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pickle as pkl
from collections import defaultdict,deque,Counter,OrderedDict
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader, Dataset
from torch.optim import lr_scheduler
import os
import time
import copy
import random

from models import LM_latent, LM_latent_type_rep
from vocab import Vocabulary
from datasets import POSDataset
from utils import pad_list_of_tensors, pad_collate_fn_pos, log_sum_exp
from tqdm import tqdm

from evaluate import *

In [2]:
torch.manual_seed(9090)
cudnn.benchmark = True
batch_size = 64
    
num_gpus = torch.cuda.device_count()
if num_gpus > 0:
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
model_left_to_right = '/misc/vlgscratch4/BrunaGroup/rj1408/nlu/ptb_wsj_pos/models/base_lstm_defpar/a/'

#Load forward model
forward_model_params = torch.load(os.path.join(model_left_to_right, 'net_epoch_0.pth'), map_location=device)
forward_model_weights = torch.load(os.path.join(model_left_to_right, 'net_best_weights.pth'), map_location=device)
forward_tag2id = forward_model_params["hyperparams"]["tagtoid"]
forward_id2tag = defaultdict(str)
for tag in forward_tag2id:
    forward_id2tag[forward_tag2id[tag]] = tag
forward_vocab = forward_model_params["hyperparams"]["vocab"]
hidden_size = forward_model_params["hyperparams"]["hidden_size"]
tok_emb_size = forward_model_params["hyperparams"]["token_embedding"]
tag_emb_size = forward_model_params["hyperparams"]["tag_emb_size"]
lstm_layers = forward_model_params["hyperparams"]["lstmLayers"]
dropout_p =  0.1

tag_wise_vocabsize = dict([(forward_tag2id[tup[0]], tup[1][2]) for tup in forward_vocab.tag_specific_vocab.items()])
forward_model = LM_latent(forward_vocab.vocab_size, tag_wise_vocabsize, hidden_size, tok_emb_size, tag_emb_size, lstm_layers, dropout_p)
forward_model.load_state_dict(forward_model_weights)
forward_model = forward_model.to(device)

In [4]:
model_right_to_left = '/misc/vlgscratch4/BrunaGroup/rj1408/nlu/ptb_wsj_pos/models/base_lstm_defpar_reverse/a/'

#Load backward model
backward_model_params = torch.load(os.path.join(model_right_to_left, 'net_best_weights.pth'), map_location=device)
backward_tag2id = backward_model_params["hyperparams"]["tagtoid"]
backward_id2tag = defaultdict(str)
for tag in backward_tag2id:
    backward_id2tag[backward_tag2id[tag]] = tag
backward_vocab = backward_model_params["hyperparams"]["vocab"]
hidden_size = backward_model_params["hyperparams"]["hidden_size"]
tok_emb_size = backward_model_params["hyperparams"]["token_embedding"]
tag_emb_size = backward_model_params["hyperparams"]["tag_emb_size"]
lstm_layers = backward_model_params["hyperparams"]["lstmLayers"]
dropout_p =  0.1

tag_wise_vocabsize = dict([(backward_tag2id[tup[0]], tup[1][2]) for tup in backward_vocab.tag_specific_vocab.items()])
backward_model = LM_latent(backward_vocab.vocab_size, tag_wise_vocabsize, hidden_size, tok_emb_size, tag_emb_size, lstm_layers, dropout_p)
backward_model.load_state_dict(backward_model_params["model_state_dict"])
backward_model = backward_model.to(device)

In [5]:
FOR_UNKNOWN_TAG = forward_tag2id['UNKNOWN']
BACK_UNKNOWN_TAG = backward_tag2id['UNKNOWN']
PAD_TAG_ID = -51

In [6]:
test_pickle_file = '/misc/vlgscratch4/BrunaGroup/rj1408/nlu/ptb_wsj_pos/val.p'

with open(test_pickle_file,"rb") as a:
    testdict = pkl.load(a)

forward_dataset = POSDataset(testdict, forward_vocab, forward_tag2id, forward_id2tag, None, False)
backward_dataset = POSDataset(testdict, backward_vocab, backward_tag2id, backward_id2tag, None, True)
forward_loader = DataLoader(forward_dataset, batch_size=batch_size, shuffle=False, collate_fn=pad_collate_fn_pos, pin_memory=True)
backward_loader = DataLoader(backward_dataset, batch_size=batch_size, shuffle=False, collate_fn=pad_collate_fn_pos, pin_memory=True)

In [7]:
backid_to_forid = {}
for backid in backward_id2tag:
    backtag = backward_id2tag[backid]
    forid = forward_tag2id[backtag]
    backid_to_forid[backid] = forid

In [8]:
def reshapeTensor(backward_prob):
    ans = torch.zeros(backward_prob.shape, dtype=torch.float, device=device, requires_grad=False)
    for j,prob in enumerate(backward_prob):
        ans[backid_to_forid[j]] = prob
    return ans

In [9]:
def convertTensTotag(tens, forward = True):
    ans = []
    for ele in tens:
        tagid = ele.item()
        if forward:
            ans.append(forward_id2tag[tagid])
        else:
            ans.append(backward_id2tag[tagid])
    return ans

In [10]:
def getTokenAccuracy(for_taglogits, for_labels, back_taglogits, back_labels):
    #tag_logits ->  btchsize x sentlen x numtags
    #labels -> btchsize x sentlen
    btchsize = for_labels.shape[0]
    numtags = for_taglogits.shape[2]
    
    #print(convertTensTotag(for_labels[0], True), convertTensTotag(back_labels[0], False))
    
    forward_prob = F.softmax(for_taglogits, dim=-1)
    backward_prob = F.softmax(back_taglogits, dim =-1)
    num = 0
    
    for i,forlabTens in enumerate(for_labels):
        forwardmask = ((forlabTens != PAD_TAG_ID) & (forlabTens != FOR_UNKNOWN_TAG))
        paddingRemovedForward = forlabTens[forwardmask]
        expandedmask = forwardmask.unsqueeze(-1).expand(for_taglogits[0].shape)
        forward_probs_flat = forward_prob[i][expandedmask]
        forward_probs = forward_probs_flat.view(-1, numtags)
        
        backwardmask = ((back_labels[i] != PAD_TAG_ID) & (back_labels[i] != BACK_UNKNOWN_TAG))
        paddingRemovedBackward = back_labels[i][backwardmask]
        expandedmask = backwardmask.unsqueeze(-1).expand(back_taglogits[0].shape)
        backward_probs_flat = backward_prob[i][expandedmask]
        backward_probs = backward_probs_flat.view(-1, numtags)
        
        if paddingRemovedBackward.shape != paddingRemovedForward.shape:
            print(paddingRemovedForward.shape, paddingRemovedBackward.shape)
            print(convertTensTotag(paddingRemovedForward, True), convertTensTotag(paddingRemovedBackward, False))
            print(for_labels.shape, back_labels.shape)
        assert paddingRemovedForward.shape == paddingRemovedBackward.shape
        
        validTokens = paddingRemovedForward.shape[0]
        for j in range(validTokens):
            lab = paddingRemovedForward[j]
            forward_prob_token = forward_probs[j]
            backward_prob_token = reshapeTensor(backward_probs[validTokens - j - 1])
            final_prob_token = (forward_prob_token + backward_prob_token)/2
            num += ((torch.max(final_prob_token, dim=0).indices).item() == lab.item())*1
            
    mask = ((for_labels != FOR_UNKNOWN_TAG) & (for_labels != PAD_TAG_ID))
    den = for_labels[mask].shape[0]

    return num, den

In [11]:
def getSentenceAccuracy(for_taglogits, for_labels, back_taglogits, back_labels):
    #tag_logits ->  btchsize x sentlen x numtags
    #labels -> btchsize x sentlen
    btchsize = for_labels.shape[0]
    numtags = for_taglogits.shape[2]
    
    forward_prob = F.softmax(for_taglogits, dim=-1)
    backward_prob = F.softmax(back_taglogits, dim =-1)
    sentCount = 0
    
    for i,forlabTens in enumerate(for_labels):
        forwardmask = ((forlabTens != PAD_TAG_ID) & (forlabTens != FOR_UNKNOWN_TAG))
        paddingRemovedForward = forlabTens[forwardmask]
        expandedmask = forwardmask.unsqueeze(-1).expand(for_taglogits[0].shape)
        forward_probs_flat = forward_prob[i][expandedmask]
        forward_probs = forward_probs_flat.view(-1, numtags)

        backwardmask = ((back_labels[i] != PAD_TAG_ID) & (back_labels[i] != BACK_UNKNOWN_TAG))
        paddingRemovedBackward = back_labels[i][backwardmask]
        expandedmask = backwardmask.unsqueeze(-1).expand(back_taglogits[0].shape)
        backward_probs_flat = backward_prob[i][expandedmask]
        backward_probs = backward_probs_flat.view(-1, numtags)
        
        validTokens = paddingRemovedForward.shape[0]
        prob_tens_list = []
        
        for j in range(validTokens):
            forward_prob_token = forward_probs[j]
            backward_prob_token = reshapeTensor(backward_probs[validTokens - j - 1])
            final_prob_token = (forward_prob_token + backward_prob_token)/2
            
            prob_tens_list.append(final_prob_token)
        
        predictions = torch.stack(prob_tens_list, dim=0)
        predictions = torch.max(predictions, dim=-1).indices
        result = torch.equal(predictions, paddingRemovedForward)
        sentCount += result*1
        
    mask = ((for_labels != FOR_UNKNOWN_TAG) & (for_labels != PAD_TAG_ID))
    den = for_labels[mask].shape[0]

    return sentCount, den

In [12]:
forward_model.eval()   # Set model to evaluate mode
backward_model.eval()

running_word = 0
running_sent = 0
total_words = 0
total_sents = 0
n_samples = 0

backward_loader = iter(backward_loader)
# Iterate over data.
for batch_num, (for_inputs, for_targets, for_labels) in tqdm(enumerate(forward_loader)):
    
    back_inputs, back_targets, back_labels = next(backward_loader)
    back_inputs = back_inputs.to(device)
    back_targets = back_targets.to(device)
    back_labels = back_labels.to(device)
    
    for_inputs = for_inputs.to(device)
    for_targets = for_targets.to(device)
    for_labels = for_labels.to(device)
    
    if for_labels.shape != back_labels.shape:
        print(batch_num)
        print(for_labels.shape, back_labels.shape)
        print(convertTensTotag(for_labels[0], True), convertTensTotag(back_labels[0], False))
        print(for_inputs.shape, back_inputs.shape)
    assert for_labels.shape == back_labels.shape
    
    batchSize = for_inputs.size(0)
    n_samples += batchSize

    for_outputs = forward_model(for_inputs)
    back_outputs = backward_model(back_inputs)
    
    # statistics
    num, den = getTokenAccuracy(for_outputs[0], for_labels, back_outputs[0], back_labels)
    running_word += num
    total_words += den
    num, den = getSentenceAccuracy(for_outputs[0], for_labels, back_outputs[0], back_labels)
    running_sent += num
    total_sents += den
    
    if batch_num%5==0:
        print("current acc: ", running_word/total_words)
        print("current sent acc: ", running_sent/total_sents)

# Metrics
tokenaccuracy = running_word/total_words
sentaccuracy = running_sent/total_sents

1it [00:06,  6.24s/it]

current acc:  0.5796178343949044
current sent acc:  0.0


6it [00:43,  7.08s/it]

current acc:  0.5967413441955194
current sent acc:  0.0


11it [01:16,  6.62s/it]

current acc:  0.6086408735366686
current sent acc:  5.249619402593312e-05


16it [01:46,  6.19s/it]

current acc:  0.6083345915748151
current sent acc:  0.0002264834667069304


21it [02:18,  6.36s/it]

current acc:  0.6215887363114897
current sent acc:  0.0002027927458137783


26it [02:52,  6.64s/it]

current acc:  0.624073642507574
current sent acc:  0.0002563505010487066


31it [03:24,  6.53s/it]

current acc:  0.6245431322460209
current sent acc:  0.00031440361564157986


36it [03:55,  5.99s/it]

current acc:  0.6178140282480079
current sent acc:  0.00027358845456721727


41it [04:28,  6.62s/it]

current acc:  0.616851932197304
current sent acc:  0.00023937403689352344


46it [05:00,  6.62s/it]

current acc:  0.6173336902690403
current sent acc:  0.0002409316021951546


51it [05:33,  6.55s/it]

current acc:  0.6144979074436945
current sent acc:  0.00021772262138036142


56it [06:05,  6.56s/it]

current acc:  0.6146484891199013
current sent acc:  0.00020923266672539864


61it [06:39,  6.68s/it]

current acc:  0.6163556309581516
current sent acc:  0.00020197123929552432


66it [07:12,  6.50s/it]

current acc:  0.6144792588109316
current sent acc:  0.0002141626705153871


71it [07:44,  6.33s/it]

current acc:  0.6123359853015097
current sent acc:  0.000199330941362037


76it [08:16,  6.29s/it]

current acc:  0.6133946606312818
current sent acc:  0.00019468983475700276


81it [08:48,  6.72s/it]

current acc:  0.6118890959570988
current sent acc:  0.0002053778572243563


86it [09:23,  7.05s/it]

current acc:  0.6126982084387824
current sent acc:  0.00020017443772430262


87it [09:25,  5.75s/it]
