In [1]:
import time
from torch.utils.tensorboard import SummaryWriter
import torchtext.data as data
from torchtext.data import BucketIterator
import torch
from torch_struct import HMM, LinearChainCRF
import matplotlib.pyplot as plt
from torch_struct.data import ConllXDatasetPOS

# start_time = time.time()
device='cpu'


In [2]:
WORD = data.Field(init_token='<bos>', pad_token=None, eos_token='<eos>') #init_token='<bos>', 
POS = data.Field(init_token='<bos>', include_lengths=True, pad_token=None, eos_token='<eos>') 

fields = (('word', WORD), ('pos', POS), (None, None))
train = ConllXDatasetPOS('data/wsj.train0.conllx', fields, 
                filter_pred=lambda x: len(x.word) < 50) #en_ewt-ud-train.conllu
test = ConllXDatasetPOS('data/wsj.test0.conllx', fields)
print('total train sentences', len(train))
print('total test sentences', len(test))

WORD.build_vocab(train) # 
POS.build_vocab(train, min_freq = 5, max_size=7)
train_iter = BucketIterator(train, batch_size=20, device=device, shuffle=False)
test_iter = BucketIterator(test, batch_size=20, device=device, shuffle=False)

C = len(POS.vocab)
V = len(WORD.vocab)

# t0 = time.time() - start_time
# print(t0)

total train sentences 1174
total test sentences 45


In [3]:
# counts for mle's 
tags = [] # prior
bigrams = [] # transition
word_tag_counts = [] # emission
for ex in train_iter:
#    print(ex.pos)
    words = ex.word
    label, lengths = ex.pos
    for batch in range(label.shape[1]):
    #    print(' '.join([WORD.vocab.itos[i] for i in words[:lengths[batch], batch]]))        
        tags.append(label[:lengths[batch], batch])
        bigrams.append(label[:lengths[batch], batch].unfold(0, 2, 1)) #dimension, size, step      
        for i, t in enumerate(label[:lengths[batch], batch]):
            word_tag_counts.append(torch.tensor((t.item(), words[i, batch].item())))
tags = torch.cat(tags, 0)
bigrams = torch.cat(bigrams, 0)
word_tag_counts = torch.stack(word_tag_counts)


In [6]:
# prior
init = torch.ones(C).long() # add-1 smoothing
init.index_put_((tags,), torch.tensor(1), accumulate=True)
assert init[POS.vocab.stoi['<eos>']] == len(train)+1
init = init.float() / init.sum()
assert torch.isclose(init.sum(), torch.tensor(1.))# \sum_C p_c = 1
init = init.log()


torch.LongTensor


In [5]:
init.type()

'torch.FloatTensor'

In [5]:
# transition
transition = torch.ones((C, C)).long() # p(. | eos) = 1/C
transition.index_put_((bigrams[:, 0], bigrams[:, 1]), torch.tensor(1), accumulate=True)
transition = (transition.float() / transition.sum(-1, keepdim=True)).transpose(0, 1) 
assert torch.isclose(transition.sum(0, keepdim=True).sum(), \
        torch.tensor(C, dtype=torch.float)) # should be for x in C-{eos}, sum_C  p(c, x) = 1?
transition = transition.log()


In [6]:
# emission 
emission = torch.ones((C, V)).long()
emission.index_put_((word_tag_counts[:, 0], word_tag_counts[:, 1]), torch.tensor(1), accumulate=True)
emission = (emission.float() / emission.sum(-1, keepdim=True)).transpose(0, 1)
assert torch.isclose(emission.sum(0, keepdim=True).sum(), \
        torch.tensor(C, dtype=torch.float)) # sum_V p(v | c) = 1
emission = emission.log()


In [68]:
transition.shape

torch.Size([10, 10])

In [49]:
emission.shape

torch.Size([5743, 10])

In [51]:
init.shape

torch.Size([10])

In [72]:
observations = torch.LongTensor(next(iter(test_iter)).word).transpose(0, 1).contiguous()    
observations.shape

torch.Size([20, 44])

In [73]:
scores = torch.zeros(20, 43, C, C).type_as(emission)
scores

tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         ...,

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 

In [74]:
scores += transition.view(1, 1, C, C)
scores.shape

torch.Size([20, 43, 10, 10])

In [79]:
obs = emission[observations.view(20*44), : ]
obs.shape

torch.Size([880, 10])

In [83]:
obs.view(20, 44, C, 1)[:, 1:].shape 

torch.Size([20, 43, 10, 1])

In [89]:
obs

tensor([[-9.7067, -1.7727, -8.8417,  ..., -8.9465, -8.9160, -8.8736],
        [-9.7067, -8.8417, -8.8417,  ..., -8.9465, -8.9160, -8.8736],
        [-9.7067, -8.8417, -8.8417,  ..., -8.9465, -8.9160, -1.6308],
        ...,
        [-9.7067, -8.8417, -8.8417,  ..., -8.9465, -8.9160, -8.8736],
        [-9.7067, -8.8417, -8.8417,  ..., -8.9465, -8.9160, -8.8736],
        [-9.7067, -8.8417, -8.8417,  ..., -8.9465, -8.9160, -8.8736]])

In [96]:
obs.view(20, 44, 1, C)[:, 0].shape 

torch.Size([20, 1, 10])

In [95]:
torch.tensor(3)

torch.Size([1, 1, 1])

In [7]:
def show_chain(chain):
    plt.imshow(chain.detach().sum(-1).transpose(0, 1))

# print('t1', time.time() - start_time)

In [8]:
def test(iters):
    losses = []
    total = 0
    incorrect_edges = 0 
    #model.eval()
    for i, ex in enumerate(iters):      
        observations = torch.LongTensor(ex.word).transpose(0, 1).contiguous()            
        label, lengths = ex.pos
        
#         print(transition)
        dist = HMM(transition, emission, init, observations, lengths=lengths) 
        labels = LinearChainCRF.struct.to_parts(label.transpose(0, 1) \
                .type(torch.LongTensor), C, lengths=lengths).type(torch.FloatTensor)    
        # print(HMM.struct.from_parts(dist.argmax)[0][0])
        # print('label', labels.shape)  
        # print(dist.argmax.shape)
        # show_chain(dist.argmax[0])  
        # plt.show()

        loglik = dist.log_prob(labels).sum()
        # print(loglik, label.shape[1])
        losses.append(loglik.detach()/label.shape[1])

        incorrect_edges += (dist.argmax.sum(-1) - labels.sum(-1)).abs().sum() / 2.0
        total += dist.argmax.sum()         

    print(total, incorrect_edges)
    print('inaccurate', incorrect_edges / total) 
    return torch.tensor(losses).mean()

# print('train-log-lik', test(train_iter))
print('test-log-lik', test(test_iter))

# print("--- %s seconds ---" % (time.time() - start_time))

print(transition,"\n", emission,"\n", init )

tensor(1053.) tensor(170.)
inaccurate tensor(0.1614)
test-log-lik tensor(-11.7505)
tensor([[-0.7926, -1.0316, -2.3026, -0.8810, -1.4414, -0.9779, -2.3528, -2.1235,
         -0.5721, -0.5894],
        [-9.2774, -7.0767, -2.3026, -8.2537, -7.9561, -7.8312, -7.8038, -7.5746,
         -7.4483, -7.2499],
        [-2.2195, -7.0767, -2.3026, -7.5606, -6.5698, -5.6340, -7.1107, -7.5746,
         -7.4483, -7.2499],
        [-2.5240, -3.6427, -2.3026, -2.0989, -2.3142, -3.0354, -0.7686, -0.8246,
         -4.1161, -3.1391],
        [-2.4063, -2.2325, -2.3026, -1.3795, -4.1949, -3.2261, -4.6258, -2.9206,
         -1.4395, -2.5050],
        [-3.2277, -1.6876, -2.3026, -4.6162, -1.8603, -1.0299, -2.2587, -3.1087,
         -5.1457, -2.3083],
        [-2.4542, -1.4102, -2.3026, -5.5457, -1.0615, -5.8853, -6.4175, -5.7828,
         -4.1525, -1.9769],
        [-2.7221, -3.2925, -2.3026, -4.4471, -2.3954, -4.8355, -1.4153, -2.4098,
         -4.0143, -2.9459],
        [-3.0529, -3.1848, -2.3026, -2.7403, 

  'with `validate_args=False` to turn off validation.')


In [28]:
m = torch.tensor([[1, 2], [3, 4]])
m

tensor([[1, 2],
        [3, 4]])

In [35]:
m.shape

torch.Size([2, 2])

In [22]:
torch.tensor(((1, 2), (3, 4))).unsqueeze(-1)

tensor([[[1],
         [2]],

        [[3],
         [4]]])

In [45]:

r = torch.randint(0, 9, (2, 3, 4, 4))
r

tensor([[[[5, 6, 2, 3],
          [2, 3, 8, 8],
          [4, 7, 8, 8],
          [0, 8, 8, 3]],

         [[1, 2, 8, 6],
          [7, 4, 5, 6],
          [3, 7, 3, 0],
          [1, 4, 5, 2]],

         [[7, 3, 2, 4],
          [2, 6, 6, 6],
          [8, 6, 4, 7],
          [1, 5, 3, 3]]],


        [[[8, 2, 2, 5],
          [8, 1, 8, 1],
          [1, 6, 5, 1],
          [1, 7, 4, 7]],

         [[1, 4, 6, 8],
          [5, 5, 7, 5],
          [4, 3, 8, 1],
          [4, 1, 2, 5]],

         [[4, 1, 4, 7],
          [6, 5, 1, 7],
          [6, 7, 3, 8],
          [8, 7, 0, 2]]]])

In [46]:
r[:, 1:, :, :]

tensor([[[[1, 2, 8, 6],
          [7, 4, 5, 6],
          [3, 7, 3, 0],
          [1, 4, 5, 2]],

         [[7, 3, 2, 4],
          [2, 6, 6, 6],
          [8, 6, 4, 7],
          [1, 5, 3, 3]]],


        [[[1, 4, 6, 8],
          [5, 5, 7, 5],
          [4, 3, 8, 1],
          [4, 1, 2, 5]],

         [[4, 1, 4, 7],
          [6, 5, 1, 7],
          [6, 7, 3, 8],
          [8, 7, 0, 2]]]])