In [1]:
import time
from torch.utils.tensorboard import SummaryWriter
import torchtext.data as data
from torchtext.data import BucketIterator
import torch
import torch.nn as nn
from torch_struct import HMM, LinearChainCRF
import matplotlib.pyplot as plt
# from torch_struct.data import ConllXDatasetPOS

# start_time = time.time()
# Device handed to both BucketIterators, so every batch tensor is created on it.
device='cpu'


In [2]:
class ConllXDataset(data.Dataset):
    """Sequence-tagging dataset read from a separator-delimited, CoNLL-X style file.

    Each non-blank line describes one token; a blank line ends a sentence.
    Only two columns of every line are kept: column 1 goes to the first
    field and column 3 to the second field. A kept column whose value is
    exactly ``','`` is dropped.

    Args:
        path: Location of the file to read.
        fields: Field specification forwarded to ``data.Example.fromlist``
            and to ``data.Dataset``.
        encoding: Text encoding used to open ``path``.
        separator: Column delimiter within a line.
        **kwargs: Forwarded unchanged to ``data.Dataset.__init__``.
    """

    def __init__(self, path, fields, encoding="utf-8", separator="\t", **kwargs):
        # file column index -> slot in the per-sentence buffer
        column_map = {1: 0, 3: 1}
        examples = []
        columns = [[], []]
        with open(path, encoding=encoding) as input_file:
            for line in input_file:
                line = line.strip()
                if line == "":
                    # `columns` is a list of lists and therefore always truthy;
                    # test its contents so that repeated blank lines do not
                    # produce empty examples.
                    if any(columns):
                        examples.append(data.Example.fromlist(columns, fields))
                        columns = [[], []]
                    continue
                for i, column in enumerate(line.split(separator)):
                    if i in column_map and column != ',':
                        columns[column_map[i]].append(column)
        # Flush the last sentence only if the file did not already end with a
        # blank line (otherwise a spurious empty example would be added).
        if any(columns):
            examples.append(data.Example.fromlist(columns, fields))
        super(ConllXDataset, self).__init__(examples, fields, **kwargs)


In [3]:
# Both fields wrap every sentence in <bos> ... <eos>; POS also yields per-example lengths.
WORD = data.Field(init_token='<bos>', pad_token=None, eos_token='<eos>')
POS = data.Field(init_token='<bos>', include_lengths=True, pad_token=None, eos_token='<eos>')

# NOTE(review): absolute, machine-specific directory -- point DATA_DIR at your copy of the data.
DATA_DIR = '/Users/sofia/nlp-pytorch-struct/data'
MAX_TRAIN_LEN = 50  # training examples are kept only if len(x.word) < MAX_TRAIN_LEN
BATCH_SIZE = 20

fields = (('word', WORD), ('pos', POS), (None, None))
train = ConllXDataset(DATA_DIR + '/wsj.train0.conllx', fields,
                filter_pred=lambda x: len(x.word) < MAX_TRAIN_LEN) #en_ewt-ud-train.conllu
test = ConllXDataset(DATA_DIR + '/wsj.test0.conllx', fields)
print('total train sentences', len(train))
print('total test sentences', len(test))

# Vocabularies are built from the training split only.
WORD.build_vocab(train)
POS.build_vocab(train, min_freq = 5, max_size=7)
train_iter = BucketIterator(train, batch_size=BATCH_SIZE, device=device, shuffle=False)
test_iter = BucketIterator(test, batch_size=BATCH_SIZE, device=device, shuffle=False)

C = len(POS.vocab)   # number of tag ids
V = len(WORD.vocab)  # number of word ids
train


total train sentences 1186
total test sentences 45


<__main__.ConllXDataset at 0x12e1e8690>

In [None]:
# vars(train).keys()

In [None]:
# vars(train.examples[0])

In [None]:
# Frequency of the ',' token. ConllXDataset skips kept columns equal to ',',
# so this is presumably 0 -- verify.
WORD.vocab.freqs[',']

In [None]:
# Integer id that the word vocabulary assigns to the '.' token.
WORD.vocab.stoi['.']

In [None]:
# Ten most frequent word types, most frequent first (ties keep their counter order).
[word for word, _count in WORD.vocab.freqs.most_common(10)]

In [None]:
# Token string stored at id 4 of the word vocabulary.
WORD.vocab.itos[4]

In [None]:
# vars(WORD.vocab).keys()

In [None]:
# type(set(WORD.vocab.itos))

In [None]:
# set(WORD.vocab.itos)

In [None]:
# vars(train_iter)

In [None]:
# batch = next(iter(train_iter))
# batch

In [None]:
# batch.word

In [None]:
# label, lengths = batch.pos
# lengths

In [None]:
# for b in range(batch.word.shape[1]):
#     print(batch.word[:lengths[b], b], '\n')

In [5]:
# cbow_data = []
# for ex in train_iter:
# #    print(ex.pos)
#     words = ex.word
#     label, lengths = ex.pos
    
#     for b in range(ex.word.shape[1]):
#         for i in range(2, lengths[b]-2):
#             context = torch.stack((ex.word[i-2, b], ex.word[i-1, b], ex.word[i+1, b], ex.word[i+2, b]))
#             target = ex.word[i, b]
#             cbow_data.append((context, target))
# cbow_data

In [None]:
# Width of the CBOW word embeddings. NOTE: the name is misspelled ("EMDEDDING");
# it is kept because the CBOW instantiation below uses the same spelling.
EMDEDDING_DIM = 100

class CBOW(torch.nn.Module):
    """Continuous bag-of-words model.

    The embeddings of the context words are summed, passed through a
    128-unit ReLU layer and projected to log-probabilities over the
    vocabulary.

    Args:
        vocab_size: Number of rows of the embedding table and size of the
            output distribution.
        embedding_dim: Width of each embedding vector.
    """

    def __init__(self, vocab_size, embedding_dim):
        super().__init__()

        # out: 1 x embedding_dim
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.linear1 = nn.Linear(embedding_dim, 128)
        self.activation_function1 = nn.ReLU()
        # out: 1 x vocab_size
        self.linear2 = nn.Linear(128, vocab_size)
        self.activation_function2 = nn.LogSoftmax(dim=-1)

    def forward(self, inputs):
        """Return log-probabilities of shape ``(1, vocab_size)`` for one context.

        Args:
            inputs: Tensor of context word ids; assumed to be 1-D, as built by
                the CBOW data cell -- TODO confirm for other callers.
        """
        # Sum over the context axis inside torch instead of Python's builtin
        # ``sum``, which iterates over the rows one by one.
        embeds = self.embeddings(inputs).sum(dim=0).view(1, -1)
        hidden = self.activation_function1(self.linear1(embeds))
        return self.activation_function2(self.linear2(hidden))

# One CBOW model over the word vocabulary, trained with NLL loss and plain SGD.
model = CBOW(vocab_size=V, embedding_dim=EMDEDDING_DIM)

loss_function = nn.NLLLoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.1)

In [None]:
# Build (context, target) pairs: for every position i with two neighbours on
# each side, the context is the four surrounding word ids and the target is
# the word id at i (kept as a 1-element tensor for NLLLoss).
cbow_data = []
for ex in train_iter:
    lengths = ex.pos[1]  # per-sentence lengths are returned alongside the POS tensor
    for b in range(ex.word.shape[1]):
        for i in range(2, lengths[b]-2):
            context = torch.stack((ex.word[i-2, b], ex.word[i-1, b], ex.word[i+1, b], ex.word[i+2, b]))
            target = ex.word[i, b].unsqueeze(0)
            cbow_data.append((context, target))

N_EPOCHS = 1
for epoch in range(N_EPOCHS):
    total_loss = 0

    # The loss is accumulated over the whole data set before a single update.
    for context, target in cbow_data:
        log_probs = model(context)
        total_loss += loss_function(log_probs, target)

    #optimize at the end of each epoch
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

In [None]:
# Inspect the embedding table of the CBOW model (displays the full parameter).
model.embeddings.weight

In [None]:
# Second (context, target) pair of the CBOW data.
cbow_data[1]

In [None]:
# Model outputs (log-probabilities over the vocabulary) for the contexts of
# two hand-picked CBOW pairs, numbers 19 and 40.
a=model(cbow_data[19][0])
b=model(cbow_data[40][0])


In [None]:
# Word whose log-probability is highest in `b` (the output for pair 40's context).
WORD.vocab.itos[torch.argmax(b[0]).item()]

In [None]:
# Print pair 40 as text: its four context words, then its target word in a list.
context_ids, target_id = cbow_data[40]
print(' '.join(WORD.vocab.itos[i] for i in context_ids))
print([WORD.vocab.itos[target_id]])

In [None]:
# Collect the raw events needed for the maximum-likelihood estimates.
tags = []             # prior: every tag id of every sentence
bigrams = []          # transition: (tag_t, tag_{t+1}) pairs
word_tag_counts = []  # emission: (tag id, word id) pairs
for ex in train_iter:
    words = ex.word
    label, lengths = ex.pos
    for b in range(label.shape[1]):
        n = lengths[b]
        tag_seq = label[:n, b]
        tags.append(tag_seq)
        bigrams.append(tag_seq.unfold(0, 2, 1)) #dimension, size, step
        # Pair each tag with the word at the same position in one vectorised
        # step instead of creating a small tensor per token.
        word_tag_counts.append(torch.stack((tag_seq, words[:n, b]), dim=1))
tags = torch.cat(tags, 0)
bigrams = torch.cat(bigrams, 0)
word_tag_counts = torch.cat(word_tag_counts, 0)


In [None]:
# prior
# NOTE(review): counts every tag occurrence in `tags`, not only sentence-initial
# tags -- confirm that a unigram distribution is what the HMM `init` should be.
tag_counts = torch.ones(C, dtype=torch.long)  # add-1 smoothing
tag_counts.index_put_((tags,), torch.tensor(1), accumulate=True)
# Each training sentence ends in exactly one <eos>, plus the smoothing count.
assert tag_counts[POS.vocab.stoi['<eos>']] == len(train)+1
init = tag_counts.float() / tag_counts.sum()
assert torch.isclose(init.sum(), torch.tensor(1.))  # \sum_C p_c = 1
init = init.log()


In [None]:
# Tensor type string of the log-prior.
init.type()

In [None]:
# transition
# Add-one smoothed bigram counts, indexed [previous tag, next tag].
bigram_counts = torch.ones((C, C), dtype=torch.long)  # p(. | eos) = 1/C
bigram_counts.index_put_((bigrams[:, 0], bigrams[:, 1]), torch.tensor(1), accumulate=True)
row_totals = bigram_counts.sum(-1, keepdim=True)
# Normalise over the next tag, then transpose to [next tag, previous tag].
transition = (bigram_counts.float() / row_totals).transpose(0, 1)
# Every column is a distribution over next tags, so the grand total is C.
total_mass = transition.sum(0, keepdim=True).sum()
assert torch.isclose(total_mass, torch.tensor(C, dtype=torch.float))
transition = transition.log()


In [None]:
# emission
# Add-one smoothed (tag, word) counts, indexed [tag, word].
pair_counts = torch.ones((C, V), dtype=torch.long)
pair_counts.index_put_((word_tag_counts[:, 0], word_tag_counts[:, 1]), torch.tensor(1), accumulate=True)
tag_totals = pair_counts.sum(-1, keepdim=True)
# Normalise over words, then transpose to [word, tag].
emission = (pair_counts.float() / tag_totals).transpose(0, 1)
# sum_V p(v | c) = 1 for each of the C tags, so the grand total is C.
total_mass = emission.sum(0, keepdim=True).sum()
assert torch.isclose(total_mass, torch.tensor(C, dtype=torch.float))
emission = emission.log()


In [None]:
# Built from a (C, C) count matrix, so this is (C, C).
transition.shape

In [None]:
# Built as (C, V) counts and then transposed, so this is (V, C).
emission.shape

In [None]:
# One log-probability per tag id: (C,).
init.shape

In [None]:
# Word ids of the first test batch, transposed so that axis 0 indexes the
# sentences of the batch and axis 1 the positions within a sentence.
first_test_batch = next(iter(test_iter))
observations = torch.LongTensor(first_test_batch.word).transpose(0, 1).contiguous()
observations.shape

In [None]:
# One (C, C) score matrix per pair of consecutive positions of each sequence.
# The sizes are read from `observations` rather than hard-coded as 20 and 43,
# so the cell keeps working if the batch size or sentence length changes.
n_batch, n_steps = observations.shape
scores = torch.zeros(n_batch, n_steps - 1, C, C).type_as(emission)
scores

In [None]:
# Broadcast the log transition matrix into every (sequence, position) slot.
# The tensor is rebuilt from zeros instead of being updated in place, so
# re-running this cell does not add the transition scores a second time.
scores = torch.zeros_like(scores) + transition.view(1, 1, C, C)
scores.shape

In [None]:
# One row of `emission` (log p(word | tag) for every tag) per observed token.
# Flattening with -1 avoids hard-coding the batch shape as 20*44.
obs = emission[observations.view(-1), :]
obs.shape

In [None]:
# Emission scores for positions 1.. of each sequence, as a column per tag.
# Sizes come from `observations` instead of the literals 20 and 44.
n_batch, n_steps = observations.shape
obs.view(n_batch, n_steps, C, 1)[:, 1:].shape

In [None]:
# Display the gathered emission rows (prints the whole tensor).
obs

In [None]:
# Emission scores for position 0 of each sequence, as a row per tag.
# Sizes come from `observations` instead of the literals 20 and 44.
n_batch, n_steps = observations.shape
obs.view(n_batch, n_steps, 1, C)[:, 0].shape

In [None]:
# Scratch: a 0-dimensional integer tensor.
torch.tensor(3)

In [None]:
def show_chain(chain):
    """Draw ``chain`` as an image after summing out its last axis.

    The summed tensor is detached from autograd and its first two axes are
    swapped before it is handed to ``plt.imshow``.
    """
    collapsed = chain.detach().sum(-1)
    plt.imshow(collapsed.transpose(0, 1))

# print('t1', time.time() - start_time)

In [None]:
def test(iters):
    """Evaluate the count-based HMM on every batch produced by ``iters``.

    Reads the module-level ``transition``, ``emission``, ``init`` and ``C``.
    For each batch it builds an ``HMM`` distribution over the batch's word
    ids, scores the gold tag sequence and compares the distribution's
    ``argmax`` with the gold parts.

    NOTE: defining this function rebinds the module-level name ``test``,
    which the data-loading cell uses for the test dataset.

    Args:
        iters: Iterable of batches exposing ``.word`` and ``.pos``, where
            ``.pos`` is a ``(labels, lengths)`` pair.

    Returns:
        0-dim tensor: the mean over batches of (summed gold log-probability
        divided by ``label.shape[1]``); NaN when ``iters`` yields no batch.

    Side effects:
        Prints the two running totals and their ratio under the label
        ``'inaccurate'``.
    """
    losses = []
    total = 0
    incorrect_edges = 0
    #model.eval()
    for ex in iters:
        # Transposed so that axis 0 indexes the sentences of the batch.
        observations = ex.word.long().transpose(0, 1).contiguous()
        label, lengths = ex.pos

        dist = HMM(transition, emission, init, observations, lengths=lengths)
        # Gold tags in the same "parts" representation as dist.argmax
        # (presumably one indicator per edge -- confirm against torch_struct).
        labels = LinearChainCRF.struct.to_parts(
            label.transpose(0, 1).long(), C, lengths=lengths).float()

        loglik = dist.log_prob(labels).sum()
        losses.append(loglik.detach()/label.shape[1])

        predicted = dist.argmax
        incorrect_edges += (predicted.sum(-1) - labels.sum(-1)).abs().sum() / 2.0
        total += predicted.sum()

    print(total, incorrect_edges)
    # With no batches both counters are still the int 0; report NaN instead of
    # raising ZeroDivisionError.
    error_rate = incorrect_edges / total if losses else float('nan')
    print('inaccurate', error_rate)
    if not losses:
        return torch.tensor(float('nan'))
    return torch.tensor(losses).mean()

# print('train-log-lik', test(train_iter))
test_log_lik = test(test_iter)
print('test-log-lik', test_log_lik)

# print("--- %s seconds ---" % (time.time() - start_time))

# Dump the three log-parameter tensors, one after another.
print(transition, emission, init, sep=" \n ")

In [None]:
# Scratch: the 2x2 integer matrix [[1, 2], [3, 4]].
m = torch.arange(1, 5).reshape(2, 2)
m

In [None]:
# Scratch: shape of `m`.
m.shape

In [None]:
# Scratch: unsqueeze(-1) appends a trailing axis of size 1, giving shape (2, 2, 1).
torch.tensor(((1, 2), (3, 4))).unsqueeze(-1)

In [None]:

# Scratch: random integers in [0, 9) with shape (2, 3, 4, 4); no seed is set.
r = torch.randint(low=0, high=9, size=(2, 3, 4, 4))
r

In [None]:
# Scratch: drop index 0 along axis 1, leaving shape (2, 2, 4, 4).
r[:, 1:, :, :]