In [1]:
import json
from nltk import word_tokenize

# Read the whole training set from its JSON file into memory.
with open("data/train_full.json") as train_file:
    dataset = json.load(train_file)

In [2]:
import torch
import torch.nn as nn
from torch.autograd import Variable

In [274]:
class Model(nn.Module):
    """GRU classifier that consumes a dialog one sentence at a time.

    Every token is represented by the concatenation of a word embedding and a
    user/bot embedding. The GRU hidden state is stored on the instance
    (``self.hidden``) and carried over between ``forward`` calls, so the
    sentences of one dialog can be fed sequentially; the caller resets it with
    ``model.hidden = model.init_hidden()`` before starting a new dialog.

    Args:
        vocab_size: number of rows in the word embedding table.
        word_emb_dim: size of a word embedding.
        user_bot_vocab_size: number of rows in the user/bot embedding table.
        user_bot_emb_dim: size of a user/bot embedding.
        hidden_size: size of the GRU hidden state.
        num_classes: number of output classes.

    The defaults reproduce the sizes that were previously hard-coded.
    """

    def __init__(self, vocab_size=10035, word_emb_dim=50, user_bot_vocab_size=4,
                 user_bot_emb_dim=10, hidden_size=128, num_classes=3):
        super(Model, self).__init__()

        self.hidden_size = hidden_size
        self.num_classes = num_classes

        # token id -> word_emb_dim vector
        self.word_embeddings = nn.Embedding(vocab_size, word_emb_dim)
        # user/bot id -> user_bot_emb_dim vector
        self.user_bot_embeddings = nn.Embedding(user_bot_vocab_size, user_bot_emb_dim)
        # The GRU input is the concatenation of both embeddings.
        self.rnn = nn.GRU(word_emb_dim + user_bot_emb_dim, hidden_size, 1)
        self.linear = nn.Linear(hidden_size, num_classes)
        # The logits are reshaped to (1, num_classes) before this layer is
        # applied, so the class axis is dim 1. Naming it explicitly avoids the
        # deprecated implicit-dim behaviour.
        self.softmax = nn.LogSoftmax(dim=1)

        self.hidden = self.init_hidden()

    def init_hidden(self):
        """Return a zero hidden state of shape (num_layers=1, batch=1, hidden_size)."""
        return Variable(torch.zeros(1, 1, self.hidden_size))

    def forward(self, input, calc_softmax=False):
        """Feed one sentence through the GRU, continuing from ``self.hidden``.

        Args:
            input: LongTensor of shape (1, 2, N) where N is the sentence
                length; ``input[:, 0, :]`` holds word ids and
                ``input[:, 1, :]`` holds user/bot ids.
            calc_softmax: when True, return log-probabilities instead of raw
                scores. Log-softmax is only needed for the last sentence of a
                dialog, right before the loss is computed.

        Returns:
            Tuple ``(hidden, output)``: the updated hidden state of shape
            (1, 1, hidden_size) and class scores of shape (1, num_classes).
        """
        word_emb = self.word_embeddings(input[:, 0, :])
        user_bot_emb = self.user_bot_embeddings(input[:, 1, :])
        input_combined = torch.cat((word_emb, user_bot_emb), 2)
        # No batching for now: the single sentence (1, N, E) is reshaped to the
        # (seq_len=N, batch=1, E) layout that nn.GRU expects.
        input_combined = input_combined.view(input_combined.size()[1], 1, input_combined.size()[-1])

        _, self.hidden = self.rnn(input_combined, self.hidden)
        output = self.linear(self.hidden).view(1, self.num_classes)

        # Return the model's own state. The previous code returned a bare
        # `hidden`, which resolved to whatever global of that name existed.
        if calc_softmax:
            return self.hidden, self.softmax(output)
        return self.hidden, output

In [275]:
# Build a fresh model plus a tiny hand-written input for a shape sanity check.
model = Model()
# One "sentence" of two tokens: row 0 holds word ids, row 1 holds user/bot ids.
input = Variable(torch.LongTensor([[[1, 10], [1, 1]]]))
hidden = model.init_hidden()

# Display both shapes as the cell output.
input.size(), hidden.size()

(torch.Size([1, 2, 2]), torch.Size([1, 1, 128]))

In [276]:
# hidden, output = model.forward(input)
# hidden

In [277]:
def mockup_data(vocab_words_size=100, vocab_user_bot_size=2):
    """Generate random dialogs for smoke-testing the model.

    Produces 5 dialogs; each dialog holds 5 sentences whose lengths are
    5, 6, 7, 8 and 9 tokens. A sentence is a LongTensor of shape
    (1, 2, length): row 0 is filled with random ids below
    ``vocab_words_size`` and row 1 with random ids below
    ``vocab_user_bot_size``.

    Returns:
        A list of dialogs, each a list of sentence tensors.
    """
    def random_sentence(length):
        # Draw the word ids first and the user/bot ids second.
        word_ids = torch.LongTensor(length).random_(vocab_words_size)
        user_bot_ids = torch.LongTensor(length).random_(vocab_user_bot_size)
        return torch.cat((word_ids, user_bot_ids)).view(1, 2, length)

    return [
        [random_sentence(length) for length in range(5, 10)]
        for _ in range(5)
    ]
    

In [278]:
import pickle

def load_dialogs_and_labels(filename):
    """Load vectorised dialogs and their labels from a pickle file.

    The pickle must contain a pair ``(dialogs_vecs, labels)``. Every sentence
    vector inside every dialog is converted to a LongTensor and reshaped to
    (1, 2, -1); the labels are wrapped in a single LongTensor Variable.

    NOTE(review): pickle.load executes arbitrary code from the file, so only
    use this on trusted data.

    Returns:
        Tuple ``(dialogs, labels)``: a list of dialogs (each a list of
        sentence tensors) and the labels Variable.
    """
    with open(filename, 'rb') as pickle_file:
        dialogs_vecs, raw_labels = pickle.load(pickle_file)

    label_var = Variable(torch.LongTensor(raw_labels))
    dialog_tensors = [
        [torch.LongTensor(sent_vec).view(1, 2, -1) for sent_vec in dialog_vec]
        for dialog_vec in dialogs_vecs
    ]
    return dialog_tensors, label_var

In [279]:
# Load the pre-vectorised dialogs and their labels from the pickle file.
# NOTE(review): the file name is spelled "dilogs" - presumably it matches the
# file on disk; confirm before renaming.
dialogs, labels = load_dialogs_and_labels('data/dilogs_and_labels.pickle')

In [280]:
# dialogs = mockup_data()
# labels = Variable(torch.LongTensor(len(dialogs)).random_(3))
# labels, dialogs[0][0].size()
# labels

In [281]:
# Negative log-likelihood loss; the training loop feeds it the output that
# the model returns when called with calc_softmax=True.
loss_function = nn.NLLLoss()
# Plain SGD over all model parameters with a fixed learning rate of 0.1.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

In [282]:
# Train for 10 epochs, one dialog per optimisation step (no batching).
for _ in range(10):
    # Accumulate the epoch loss as a plain Python float. Summing the loss
    # tensors themselves would chain every step's autograd history together
    # and keep it alive for the whole epoch.
    total_loss = 0.0
    for ind, dialog in enumerate(dialogs):
        model.zero_grad()
        # Each dialog starts from a fresh hidden state.
        model.hidden = model.init_hidden()

        # Feed every sentence but the last only to advance the hidden state.
        # The sentences are already LongTensors, so no re-wrapping is needed.
        for sent in dialog[:-1]:
            hidden, out = model(Variable(sent))
        # Log-probabilities are required only for the final sentence.
        hidden, out = model(Variable(dialog[-1]), True)

        # `out` has shape (1, 3), so the target must have shape (1,).
        # view(1) guarantees that whether indexing yields a 1-element or a
        # 0-dim tensor.
        loss = loss_function(out, labels[ind].view(1))
        total_loss += float(loss)
        loss.backward()
        optimizer.step()
    print("Loss: {}".format(total_loss / len(dialogs)))

Loss: Variable containing:
 0.7655
[torch.FloatTensor of size 1]

Loss: Variable containing:
 0.7139
[torch.FloatTensor of size 1]

Loss: Variable containing:
 0.6981
[torch.FloatTensor of size 1]



KeyboardInterrupt: 

In [288]:
# Reference on initialising nn.Embedding from pre-trained word vectors
# (not applied in the cells shown here):
# https://discuss.pytorch.org/t/can-we-use-pre-trained-word-embeddings-for-weight-initialization-in-nn-embedding/1222/11
# Display every label as a plain list - presumably to eyeball the class balance.
list(labels.data)

[0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 2,
 0,
 0,
 2,
 0,
 0,
 2,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 1,
 0,
 0,
 0,
 0,
 2,
 1,
 1,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 1,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 1,
 0,
 1,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 1,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 2,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 1,
 0,
 0,
 2,
 0,
 0,
 0,
 0,
 0,
 2,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 2,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 2,
 0,
 0,
 0,
 0,
 0,
 2,
 0,
 0,
 0,
 0,
 1,
 0,
 1,
 0,
 0,
 2,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 2,
 0,
 1,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 2,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 2,
 0,
 1,
 0,
 1,
 1,
 2,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 2,
 1,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 2,
 0,
 0,
 1,
 2,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
