In [10]:
import torchtext.data as data
from torchtext.data import BucketIterator

import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim

from torch_struct import LinearChainCRF
#import matplotlib
import matplotlib.pyplot as plt

class ConllXDataset(data.Dataset):
    """Sentence-level dataset read from a tab-separated CoNLL-style file.

    Every non-blank line is one token; blank lines separate sentences. Only
    input column 1 (word form) and column 3 (POS tag) are kept, so each
    sentence becomes one ``data.Example`` built from ``[words, tags]`` and
    matched positionally against ``fields``.

    Args:
        path: path of the file to read.
        fields: sequence of ``(name, Field)`` pairs passed to
            ``data.Example.fromlist`` and to the base ``data.Dataset``.
        encoding: text encoding of the file.
        separator: column separator within a token line.
        **kwargs: forwarded to ``data.Dataset.__init__``.
    """

    def __init__(self, path, fields, encoding='utf-8', separator='\t', **kwargs):
        examples = []
        columns = [[], []]
        # input column index -> slot in ``columns`` (0: words, 1: POS tags)
        column_map = {1: 0, 3: 1}
        with open(path, encoding=encoding) as input_file:
            for line in input_file:
                line = line.strip()
                if line.startswith('#'):
                    # Comment line (e.g. CoNLL-U "# sent_id = ..."), not a token.
                    continue
                if line == '':
                    # Sentence boundary. Only emit non-empty sentences so a
                    # trailing or repeated blank line doesn't add an empty Example.
                    if any(columns):
                        examples.append(data.Example.fromlist(columns, fields))
                    columns = [[], []]
                    continue
                for i, column in enumerate(line.split(separator)):
                    if i in column_map:
                        columns[column_map[i]].append(column)
        # The last sentence may not be followed by a blank line.
        if any(columns):
            examples.append(data.Example.fromlist(columns, fields))
        super(ConllXDataset, self).__init__(examples, fields, **kwargs)

# Fields: WORD has no dedicated pad token (pad_token=None); POS yields
# (tags, lengths) so the CRF can mask positions beyond each sentence.
# TODO: add sentence boundary tokens (init_token='<bos>', eos_token='<eos>').
WORD = data.Field(pad_token=None)
POS = data.Field(is_target=True, include_lengths=True)

# Matches the per-sentence [words, tags] lists built by ConllXDataset.
fields = (('word', WORD), ('pos', POS), (None, None))
train, test = (ConllXDataset(conll_path, fields)
               for conll_path in ('samIam.conllu', 'test.conllu'))

# Both vocabularies are built from the training split only.
WORD.build_vocab(train)
POS.build_vocab(train)

train_iter, test_iter = (
    BucketIterator(split, batch_size=2, device='cpu', shuffle=False)
    for split in (train, test)
)

# C: number of tag types (including specials); V: word vocabulary size.
C, V = len(POS.vocab.itos), len(WORD.vocab.itos)
C, V

(6, 6)

In [11]:
print(POS.vocab.__dict__)  # inspect the tag vocabulary's attributes (freqs, itos, stoi, ...)

{'freqs': Counter({'PRON': 7, 'AUX': 7, 'PUNCT': 7, 'PROPN': 4, '<unk>': 2}), 'itos': ['<unk>', '<pad>', 'AUX', 'PRON', 'PUNCT', 'PROPN'], 'unk_index': 0, 'stoi': defaultdict(<bound method Vocab._default_unk_index of <torchtext.vocab.Vocab object at 0x127e82210>>, {'<unk>': 0, '<pad>': 1, 'AUX': 2, 'PRON': 3, 'PUNCT': 4, 'PROPN': 5}), 'vectors': None}


In [12]:
class Model(nn.Module):
    """Map word-index sequences to linear-chain CRF edge log-potentials.

    Words are encoded as fixed one-hot vectors (a frozen identity embedding),
    projected to per-tag emission scores, and combined with a learned C x C
    transition matrix into potentials of shape (batch, N-1, C, C).

    Args:
        voc_size: number of word types V (length of the one-hot vectors).
        num_pos_tags: number of tag types C.
    """

    def __init__(self, voc_size, num_pos_tags):
        super().__init__()
        # A frozen identity embedding table is exactly a one-hot encoding.
        identity = torch.eye(voc_size).type(torch.FloatTensor)
        self.embedding = nn.Embedding.from_pretrained(identity, freeze=True)
        # Emission scores: (batch, N, V) -> (batch, N, C).
        self.linear = nn.Linear(voc_size, num_pos_tags)
        # Only ``.weight`` (C x C) is read in forward(); the bias goes unused.
        self.transition = nn.Linear(num_pos_tags, num_pos_tags)

    def forward(self, words):
        """Return edge log-potentials for a batch of word-index sequences.

        Args:
            words: integer tensor of shape (batch, N).

        Returns:
            Tensor of shape (batch, N-1, C, C) whose entry [b, n, i, j] is
            emission(n+1, i) + transition[i, j]; the first edge (n == 0)
            additionally includes emission(0, j). Presumably indexed as
            [tag at n+1, tag at n] to match ``LinearChainCRF`` -- confirm
            against the torch_struct docs.
        """
        emissions = self.linear(self.embedding(words))
        _batch, _length, n_tags = emissions.shape
        transitions = self.transition.weight.view(1, 1, n_tags, n_tags)
        # Emission of the later token of each edge, broadcast over the previous tag.
        potentials = emissions[:, 1:].unsqueeze(3) + transitions
        # Fold the first token's emission into the first edge (broadcast over the
        # next tag) so every emission is counted exactly once.
        potentials[:, 0] += emissions[:, 0].unsqueeze(1)
        return potentials

In [13]:
# Seed the global RNG so the randomly initialised linear layers (and hence the
# printed evaluation curve) are reproducible after a kernel restart.
SEED = 0
LEARNING_RATE = 0.01

torch.manual_seed(SEED)
model = Model(V, C)
opt = optim.SGD(model.parameters(), lr=LEARNING_RATE)

def _batch_log_likelihood(batch):
    """Return the summed log-likelihood of the gold tag sequences in ``batch``.

    Runs the notebook-level ``model`` on the batch's words, builds a
    ``LinearChainCRF`` over the resulting log-potentials (masked by the POS
    lengths), converts the gold tags into CRF parts with ``C`` tag types, and
    sums ``log_prob`` over the batch.
    """
    # Transpose to batch-first for the model (torchtext Fields here are not
    # batch_first, so tensors presumably arrive as (N, batch)).
    sents = batch.word.transpose(0, 1)
    label, lengths = batch.pos
    dist = LinearChainCRF(model(sents), lengths=lengths)
    # Gold tags b x N -> one-hot edge parts b x (N-1) x C x C.
    labels = LinearChainCRF.struct.to_parts(
        label.transpose(0, 1).type(torch.LongTensor), C, lengths=lengths
    ).type(torch.FloatTensor)
    return dist.log_prob(labels).sum()


def trn(train_iter, eval_iter=None, epochs=100):
    """Train the notebook-level ``model`` with ``opt`` and report held-out fit.

    Each epoch maximises the CRF log-likelihood over ``train_iter`` (by
    minimising its negation with ``opt``), then prints the mean over
    evaluation batches of the per-batch summed log-likelihood on ``eval_iter``.

    Args:
        train_iter: iterable of training batches.
        eval_iter: iterable of evaluation batches; defaults to the
            notebook-level ``test_iter`` (the previously hard-wired choice).
        epochs: number of passes over ``train_iter``.
    """
    if eval_iter is None:
        eval_iter = test_iter
    for _ in range(epochs):
        model.train()
        for batch in train_iter:
            opt.zero_grad()
            log_likelihood = _batch_log_likelihood(batch)
            (-log_likelihood).backward()
            opt.step()

        # Evaluation needs no gradients; no_grad avoids building an autograd
        # graph for every evaluation batch.
        model.eval()
        with torch.no_grad():
            eval_log_likelihoods = [_batch_log_likelihood(batch) for batch in eval_iter]
        print(torch.tensor(eval_log_likelihoods).mean())

trn(train_iter=train_iter)  # train, printing the mean held-out log-likelihood each epoch

tensor(-13.8836)
tensor(-13.4600)
tensor(-13.0552)
tensor(-12.6687)
tensor(-12.2995)
tensor(-11.9471)
tensor(-11.6107)
tensor(-11.2897)
tensor(-10.9835)
tensor(-10.6913)
tensor(-10.4127)
tensor(-10.1470)
tensor(-9.8936)
tensor(-9.6519)
tensor(-9.4213)
tensor(-9.2012)
tensor(-8.9912)
tensor(-8.7905)
tensor(-8.5988)
tensor(-8.4155)
tensor(-8.2401)
tensor(-8.0723)
tensor(-7.9114)
tensor(-7.7572)
tensor(-7.6092)
tensor(-7.4671)
tensor(-7.3305)
tensor(-7.1991)
tensor(-7.0727)
tensor(-6.9509)
tensor(-6.8335)
tensor(-6.7202)
tensor(-6.6108)
tensor(-6.5051)
tensor(-6.4029)
tensor(-6.3041)
tensor(-6.2084)
tensor(-6.1158)
tensor(-6.0259)
tensor(-5.9389)
tensor(-5.8544)
tensor(-5.7724)
tensor(-5.6928)
tensor(-5.6154)
tensor(-5.5402)
tensor(-5.4672)
tensor(-5.3961)
tensor(-5.3270)
tensor(-5.2597)
tensor(-5.1942)
tensor(-5.1304)
tensor(-5.0683)
tensor(-5.0079)
tensor(-4.9489)
tensor(-4.8915)
tensor(-4.8355)
tensor(-4.7810)
tensor(-4.7278)
tensor(-4.6759)
tensor(-4.6253)
tensor(-4.5759)
tensor(-4.52