In [13]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import pickle as pkl
from collections import defaultdict,deque,Counter
from torch.utils.data import DataLoader

In [2]:
# Load the three pickled data splits (train / validation / test).
# NOTE(review): unpickling can execute arbitrary code, so these files must come
# from a trusted source.
with open("/scratch/rj1408/pos_lm/ptb_wsj_pos/train.p", "rb") as train_file:
    traindict = pkl.load(train_file)

with open("/scratch/rj1408/pos_lm/ptb_wsj_pos/val.p", "rb") as val_file:
    valdict = pkl.load(val_file)

with open("/scratch/rj1408/pos_lm/ptb_wsj_pos/test.p", "rb") as test_file:
    testdict = pkl.load(test_file)

In [3]:
# Number of entries in the training split's 'tagged_words' list.
len(traindict['tagged_words'])

974254

In [4]:
# Frequency of every training token. Tokens for which str.isnumeric() is true
# are all counted under the single key 'UNK'.
word_freq = Counter(
    'UNK' if pair[0].isnumeric() else pair[0]
    for pair in traindict['tagged_words']
)

# Vocabulary: the three reserved symbols come first (so 'PAD' gets id 2),
# followed by every counted key seen more than 5 times, in first-seen order.
generic_vocab = ['SOS', 'EOS', 'PAD'] + [w for w, count in word_freq.items() if count > 5]

# Forward (word -> id) and reverse (id -> word) tables over vocabulary positions.
generic_word2id = {w: idx for idx, w in enumerate(generic_vocab)}
generic_id2word = dict(enumerate(generic_vocab))

In [5]:
# Per-tag vocabularies. Ids 0-3 are left unused here; within each tag, words
# are numbered from 4 upward in order of first appearance.
tagged_word2id = defaultdict(dict)   # tag -> {word: id}
tagged_id2word = defaultdict(dict)   # tag -> {id: word}
knowledgebase = defaultdict(deque)   # tag -> words, in id order
pos_sizes = defaultdict(int)         # tag -> number of words registered so far
for token, pos in traindict['tagged_words']:
    seen_for_tag = tagged_word2id[pos]
    # Skip words already registered under this tag, and words that are not a
    # key of word_freq.
    if token in seen_for_tag or token not in word_freq:
        continue
    new_id = pos_sizes[pos] + 4
    seen_for_tag[token] = new_id
    tagged_id2word[pos][new_id] = token
    knowledgebase[pos].append(token)
    pos_sizes[pos] += 1

In [6]:
# Give every tag vocabulary the four reserved symbols at ids 0-3 and place the
# same symbols, in id order, at the front of that tag's word deque.
for pos in knowledgebase.keys():
    for reserved_id, symbol in enumerate(('SOS', 'EOS', 'PAD', 'UNK')):
        tagged_word2id[pos][symbol] = reserved_id
        tagged_id2word[pos][reserved_id] = symbol
    # extendleft pushes its items one at a time onto the left end, so feeding
    # the symbols in reverse leaves SOS, EOS, PAD, UNK at the head.
    knowledgebase[pos].extendleft(('UNK', 'PAD', 'EOS', 'SOS'))

In [31]:
# Number the tags in the key order of tagged_word2id and keep the reverse
# mapping too. Both remain defaultdicts: indexing a missing tag yields 0 and a
# missing id yields '' (and the lookup stores that entry).
tag2id = defaultdict(int, {pos: idx for idx, pos in enumerate(tagged_word2id)})
id2tag = defaultdict(str, enumerate(tagged_word2id))

In [7]:
# Show the word -> id table stored under the '.' tag.
tagged_word2id['.']

{'.': 4, '?': 5, '!': 6, 'SOS': 0, 'EOS': 1, 'PAD': 2, 'UNK': 3}

In [22]:
class PTBDataset(object):
    """Language-modelling dataset over the words of tagged sentences.

    Sentences are held in order of increasing length. Item ``idx`` is a pair of
    ``torch.long`` tensors, each of shape ``(1, n + 1)`` for a sentence of
    ``n`` words: the input is ``SOS`` followed by the word ids, the target is
    the word ids followed by ``EOS``.
    """

    def __init__(self, instanceDict, word2id, id2word):
        """Extract and length-sort the word sequences.

        Args:
            instanceDict: mapping whose ``'tagged_sents'`` entry is a list of
                sentences; the first element of each sentence item is taken
                as the word.
            word2id: word -> id mapping; ``'SOS'`` and ``'EOS'`` are looked up
                for every item, ``'UNK'`` whenever a word is missing.
            id2word: reverse mapping; stored but not used by this class.
        """
        self.root = instanceDict['tagged_sents']
        self.word2id = word2id
        self.id2word = id2word
        words_per_sentence = []
        for sentence in self.root:
            words_per_sentence.append([entry[0] for entry in sentence])
        self.sents = sorted(words_per_sentence, key=len)

    def __len__(self):
        """Return the number of sentences."""
        return len(self.root)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` id tensors for sentence ``idx``."""
        word_ids = []
        for word in self.sents[idx]:
            if word in self.word2id:
                word_ids.append(self.word2id[word])
            else:
                # Words missing from the vocabulary share the 'UNK' id.
                word_ids.append(self.word2id['UNK'])
        source_ids = [self.word2id['SOS']] + word_ids
        target_ids = word_ids + [self.word2id['EOS']]
        # The extra list level gives each tensor a leading dimension of size 1.
        source = torch.as_tensor([source_ids], dtype=torch.long)
        target = torch.as_tensor([target_ids], dtype=torch.long)
        return (source, target)
        

In [33]:
# Peek at the first training sentence (a list of (word, tag) pairs).
traindict['tagged_sents'][0]

[('Pierre', 'NNP'),
 ('Vinken', 'NNP'),
 (',', ','),
 ('61', 'CD'),
 ('years', 'NNS'),
 ('old', 'JJ'),
 (',', ','),
 ('will', 'MD'),
 ('join', 'VB'),
 ('the', 'DT'),
 ('board', 'NN'),
 ('as', 'IN'),
 ('a', 'DT'),
 ('nonexecutive', 'JJ'),
 ('director', 'NN'),
 ('Nov.', 'NNP'),
 ('29', 'CD'),
 ('.', '.')]

In [56]:
class POSDataset(object):
    """POS-tagging dataset: word ids as input, tag ids as labels.

    Sentences are held in order of increasing length. Item ``idx`` is a pair of
    ``torch.long`` tensors, each of shape ``(1, n)`` for a sentence of ``n``
    words: the word ids and the matching tag ids.
    """

    def __init__(self, instanceDict, word2id, tag2id):
        """Extract length-sorted word and tag sequences.

        Args:
            instanceDict: mapping whose ``'tagged_sents'`` entry is a list of
                sentences, each a sequence of ``(word, tag)`` items.
            word2id: word -> id mapping; ``'UNK'`` is looked up whenever a
                word is missing.
            tag2id: tag -> id mapping, indexed directly with each tag.
        """
        # sorted() builds a new list, so the caller's 'tagged_sents' list keeps
        # its original order (an in-place sort would reorder the caller's data).
        self.root = sorted(instanceDict['tagged_sents'], key=len)
        self.tag2id = tag2id
        self.word2id = word2id
        self.sents = [[item[0] for item in sentence] for sentence in self.root]
        self.tags = [[item[1] for item in sentence] for sentence in self.root]

    def __len__(self):
        """Return the number of sentences."""
        return len(self.root)

    def __getitem__(self, idx):
        """Return the ``(word_ids, tag_ids)`` tensors for sentence ``idx``."""
        # Tags are indexed directly: a plain dict raises KeyError for an
        # unknown tag, while a defaultdict returns (and stores) its default.
        target_labels = [self.tag2id[tag] for tag in self.tags[idx]]
        # Words missing from the vocabulary share the 'UNK' id.
        input_sent = [
            self.word2id[word] if word in self.word2id else self.word2id['UNK']
            for word in self.sents[idx]
        ]
        # The extra list level gives each tensor a leading dimension of size 1.
        return (
            torch.as_tensor([input_sent], dtype=torch.long),
            torch.as_tensor([target_labels], dtype=torch.long),
        )

In [57]:
# Language-modelling dataset over the training split (generic vocabulary).
train_dataset = PTBDataset(
    instanceDict=traindict,
    word2id=generic_word2id,
    id2word=generic_id2word,
)
# POS-tagging dataset over the validation split (generic vocabulary + tag ids).
val_dataset = POSDataset(
    instanceDict=valdict,
    word2id=generic_word2id,
    tag2id=tag2id,
)

In [59]:
def pad_list_of_tensors(list_of_tensors, pad_token):
    """Right-pad tensors to a common last-dimension length and stack them.

    Args:
        list_of_tensors: non-empty list of tensors; the padding built here has
            a leading dimension of 1, so each tensor is expected to be
            ``(1, length)``.
        pad_token: value used to fill the missing positions.

    Returns:
        The padded tensors concatenated along dimension 0.
    """
    longest = max(tensor.size(-1) for tensor in list_of_tensors)
    rows = []
    for tensor in list_of_tensors:
        missing = longest - tensor.size(-1)
        filler = torch.tensor([[pad_token] * missing], dtype=torch.long)
        rows.append(torch.cat([tensor, filler], dim=-1))
    return torch.cat(rows, dim=0)
def pad_collate_fn_lm(batch, pad_token=2):
    """Collate language-modelling samples into padded input/target tensors.

    Args:
        batch: list of ``(input_tensor, target_tensor)`` sample tuples.
        pad_token: id used to pad both inputs and targets. The default of 2
            keeps the previously hard-coded value; use
            ``functools.partial(pad_collate_fn_lm, pad_token=...)`` to pass a
            different id through a ``DataLoader``.

    Returns:
        Tuple ``(input_tensor, target_tensor)``, each padded by
        ``pad_list_of_tensors``.
    """
    # batch is a list of sample tuples
    input_list = [sample[0] for sample in batch]
    target_list = [sample[1] for sample in batch]
    input_tensor = pad_list_of_tensors(input_list, pad_token)
    target_tensor = pad_list_of_tensors(target_list, pad_token)
    return input_tensor, target_tensor
def pad_collate_fn_pos(batch, pad_token_input=2, pad_token_tags=46):
    """Collate POS-tagging samples into padded word-id and tag-id tensors.

    Args:
        batch: list of ``(input_tensor, target_tensor)`` sample tuples.
        pad_token_input: id used to pad the word-id inputs. The default of 2
            keeps the previously hard-coded value.
        pad_token_tags: id used to pad the tag labels. The default of 46 keeps
            the previously hard-coded value.
            NOTE(review): presumably chosen to lie outside the real tag ids;
            confirm against the size of the tag set.

    Returns:
        Tuple ``(input_tensor, target_tensor)``, each padded by
        ``pad_list_of_tensors``.
    """
    # batch is a list of sample tuples
    input_list = [sample[0] for sample in batch]
    target_list = [sample[1] for sample in batch]
    input_tensor = pad_list_of_tensors(input_list, pad_token_input)
    target_tensor = pad_list_of_tensors(target_list, pad_token_tags)
    return input_tensor, target_tensor

In [60]:
# Batch both datasets in groups of 32, padding each batch with the matching
# collate function. No shuffle argument is given to either loader.
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    collate_fn=pad_collate_fn_lm,
    pin_memory=True,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=32,
    collate_fn=pad_collate_fn_pos,
    pin_memory=True,
)

In [None]:
class LM(nn.Module):
    def __init__(self,vocab_size,hidden_size):
        