In [1]:
import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.nn.functional as F
from torch import optim

import nltk
from tqdm import tqdm
import re
from pathlib import Path
from collections import Counter

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
device

device(type='cpu')

In [None]:
line = "10/20/19, 14:50 - Robin: Gleich erstmal bei dem Wetter zum Bus laufen 😂"
match = re.findall(r".+? - (.+?): (.+)", line)
match

[('Robin', 'Gleich erstmal bei dem Wetter zum Bus laufen 😂')]

In [None]:
v1 = torch.tensor([1,2,3])
v2 = torch.tensor([5,45,80])

(v1[:, None] + v2[None, :])[0, 2]

tensor(81)

In [None]:
vectors[0, 1]

tensor(2)

In [3]:
class Encoder(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(Encoder, self).__init__()

        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # input embedding
        self.embedding = nn.Embedding(input_size, hidden_size, padding_idx = 0)
        # encoder lstms
        self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers, batch_first = True, dropout = 0.1, bidirectional = True)

    def forward(self, x, hidden, cell):
        # obtain embedding of input word
        embedding = self.embedding(x)
        # get the encoder outputs
        output, (hidden, cell) = self.lstm(embedding, (hidden, cell))

        return output, hidden, cell

    def init_hidden(self, batch_size):
        # init with zeros
        return torch.zeros(2 * self.num_layers, batch_size, self.hidden_size, device = device)


In [1]:
class Decoder(nn.Module):
    def __init__(self, hidden_size, output_size, max_length, num_layers1, num_layers2):
        super(Decoder, self).__init__()

        self.hidden_size = hidden_size
        self.num_layers1 = num_layers1
        self.num_layers2 = num_layers2

        # embedding of output words
        self.embedding = nn.Embedding(output_size, hidden_size, padding_idx = 0)
        # the attention matrix which is applied to the encoder outputs
        self.enc_attn_w = nn.Linear(2 * hidden_size, hidden_size)
        # the attention matrix which is applied to the current decoder state
        self.dec_attn_w = nn.Linear(hidden_size, hidden_size)
        # the attention matrix which is applied to the sum of the projected encoder outputs and the current decoder state
        self.comb_attn_w = nn.Linear(hidden_size, 1)
        self.attn_softmax = nn.Softmax(dim = -1)
        # the encoder outputs are double the size of the decoder states because it uses bidirectional layers
        # therefore the weighted context vector need to be scaled to half its size before combining it with the decoder state
        self.scale_context = nn.Linear(2 * hidden_size, hidden_size)
        # the lstm network which is applied to the combination of context vector and decoder state
        self.lstm_after_att = nn.LSTM(hidden_size * 3, hidden_size * 3, num_layers2, batch_first = True, dropout = 0.1)
        # last layer which projects decoder state to the size of the output vocabulary
        self.linear = nn.Linear(hidden_size * 3, output_size)

    def forward(self, y, enc_outputs, enc_hidden, hidden1, cell1, hidden2, cell2):
        # obtain embedding from lastly predicted symbol
        embedding = self.embedding(y)
        # obtain unnormalized attention energies by combining the current decoder state with all encoder outputs
        attn_energies = torch.tanh(self.dec_attn_w(embedding[:, :, None]) + self.enc_attn_w(enc_outputs[:, None]))
        # print("Attention sums: ", attn_energies.size())
        attn_energies = self.comb_attn_w(attn_energies)
        attn_energies = torch.squeeze(attn_energies, dim = -1)
        # print("Attention energies: ", attn_energies.size())
        # obtain weights by normalizing energies
        attn_weights = F.softmax(attn_energies, dim = -1)
        # apply the weights to the encoder outputs
        context = attn_weights[:, :, :, None] * enc_outputs[:, None]
        # print("Context before sum: ", context.size())
        # obtain weighted sum
        context = torch.sum(context, dim = -2)
        # print("Context after sum: ", context.size())
        # scale context down to match decoder state size
        #context = self.scale_context(context)
        # print("Context after scaling: ", context.size())
        # combine context and decoder state
        combined = torch.cat((embedding, context), dim = -1)
        # obtain output by applying lstm network and projecting to vocab size
        out, (hidden2, cell2) = self.lstm_after_att(combined, (torch.cat((hidden2, enc_hidden), dim = -1), torch.cat((cell2, enc_hidden), dim = -1)))
        # print("lstm out: ", out.size())
        out = self.linear(out)
        out = torch.squeeze(out, dim = 1)

        return out, hidden1, cell1, hidden2, cell2


    def init_hidden1(self, batch_size):
        return torch.zeros(self.num_layers1, batch_size, self.hidden_size, device = device)

        
    def init_hidden2(self, batch_size):
        return torch.zeros(self.num_layers2, batch_size, self.hidden_size, device = device)

NameError: name 'nn' is not defined

In [72]:
enc = Encoder(10, 64, 4).to(device)
dec = Decoder(64, 10, 20, 4, 4).to(device)

x = torch.randint(0, 10, (2, 5)).to(device)
y = torch.randint(0, 10, (2, 5)).to(device)
enc_hidden = enc.init_hidden(2)
dec_hidden1 = dec.init_hidden1(2)
dec_hidden2 = dec.init_hidden2(2)
enc_outputs, enc_hidden, enc_cell = enc(x, enc_hidden, enc_hidden)
enc_hidden = enc_hidden.view(4, 2, 2, 64)
enc_hidden = torch.cat((enc_hidden[:, 0], enc_hidden[:, 1]), dim = 1).view(4, 2, 128)

#print(enc_hidden.size())


out, hidden1, cell1, hidden2, cell2 = dec(y, enc_outputs, enc_hidden, dec_hidden1, dec_hidden1, dec_hidden2, dec_hidden2)

In [73]:
out.size()

torch.Size([2, 5, 10])

In [15]:
vocab_size = len(dataset.vocab)
encoder = Encoder(vocab_size, 128, 3).to(device)
decoder = Decoder(128, vocab_size, dataset.max_length, 3, 3).to(device)

In [230]:
train(encoder, decoder, dataset, epochs = 500, batch_size = 512, hidden_size = 128)






  0%|          | 0/500 [00:00<?, ?it/s]

['nein', 'nein', 'nein', 'ich', 'weis', 'nicht', 'was', 'da', 'falsch', 'ist']
['<start>', 'ja', 'und', 'der', '<unk>', '?', 'also', 'die', '4', '<stop>']
['ja', 'und', 'der', '<unk>', '?', 'also', 'die', '4', '<stop>']





In [229]:
def train(encoder, decoder, dataset, epochs = 50, batch_size = 32, hidden_size = 128):

    trainloader = DataLoader(dataset, batch_size = batch_size, shuffle = True, drop_last = True)

    criterion = torch.nn.CrossEntropyLoss()
    encoder_opt = torch.optim.Adam(encoder.parameters())
    decoder_opt = torch.optim.Adam(decoder.parameters())

    output_length = dataset.max_length

    encoder.train()
    decoder.train()

    for epoch in tqdm(range(epochs)):
        running_loss = 0
        for i, data in enumerate(trainloader):
            input_tensor, output_tensor = data
            input_tensor = input_tensor.to(device)
            output_tensor = output_tensor.to(device)

            encoder_opt.zero_grad()
            decoder_opt.zero_grad()
            
            loss = 0

            enc_hidden = encoder.init_hidden(batch_size)
            enc_cell = encoder.init_hidden(batch_size)

            enc_outputs, enc_hidden, enc_cell = encoder(input_tensor, enc_hidden, enc_cell)
            enc_hidden = enc_hidden.view(3, 2, batch_size, hidden_size)
            enc_hidden = torch.cat((enc_hidden[:, 0], enc_hidden[:, 1]), dim = 1).view(3, batch_size, hidden_size * 2)

            dec_hidden1 = decoder.init_hidden1(batch_size)
            dec_cell1 = decoder.init_hidden1(batch_size)
            dec_hidden2 = decoder.init_hidden2(batch_size)
            dec_cell2 = decoder.init_hidden2(batch_size)

            dec_in = output_tensor
            target = output_tensor

            # print(enc_outputs[0, 0, :20])

            dec_out, dec_hidden1, dec_cell1, dec_hidden2, dec_cell2 = decoder(dec_in, 
                                                                                enc_outputs, 
                                                                                enc_hidden, 
                                                                                dec_hidden1, 
                                                                                dec_cell1, 
                                                                                dec_hidden2, 
                                                                                dec_cell2)

            loss = criterion(torch.reshape(dec_out[:, :-1, :], (-1, 1054)), torch.reshape(target[:, 1:], (-1,)))

            print([dataset.inverse_vocab[t.item()] for t in input_tensor[0, :]])
            print([dataset.inverse_vocab[t.item()] for t in target[0, :]])
            print([dataset.inverse_vocab[t.item()] for t in torch.argmax(dec_out[0, :-1], dim = -1)])
            return
            

            # for di in range(output_length - 1):
            #     dec_in = output_tensor[:, di].view(batch_size, 1)
            #     target = output_tensor[:, di]

            #     dec_out, dec_hidden1, dec_cell1, dec_hidden2, dec_cell2 = decoder(dec_in, 
            #                                                                     enc_outputs, 
            #                                                                     enc_hidden, 
            #                                                                     dec_hidden1, 
            #                                                                     dec_cell1, 
            #                                                                     dec_hidden2, 
            #                                                                     dec_cell2)

            #     loss += criterion(dec_out, output_tensor[:, di + 1])

            loss.backward()
            encoder_opt.step()
            decoder_opt.step()

            running_loss += loss.item()

            if i % 10 == 9:
              print("Epoch {} - Batch {} - Loss: {}".format(epoch, i, running_loss / 10))
              running_loss = 0

    torch.save(encoder.state_dict(), "encoder.pt")
    torch.save(decoder.state_dict(), "decoder.pt")



In [5]:
def gen_input(message: str, dataset):
    tokens = dataset.get_tokens(message)
    tokens = [dataset.vocab[token] if token in dataset.vocab else dataset.vocab["<unk>"] for token in tokens]
    inp = tokens[:dataset.max_length] + [dataset.vocab["<pad>"]] * max(dataset.max_length - len(tokens), 0)

    return torch.tensor(inp)

In [6]:
def decode_greedy(inp, encoder, decoder, dataset):
    encoder.eval()
    decoder.eval()

    hidden_size = encoder.hidden_size
    output_length = inp.size(0)
    batch_size = 1                                                               

    with torch.no_grad():
        enc_hidden = encoder.init_hidden(batch_size)
        enc_cell = encoder.init_hidden(batch_size)
        enc_in = inp.view(batch_size, -1).to(device)
        print(enc_in)

        enc_outputs, enc_hidden, enc_cell = encoder(enc_in, enc_hidden, enc_cell)
        enc_hidden = enc_hidden.view(3, 2, batch_size, hidden_size)
        enc_hidden = torch.cat((enc_hidden[:, 0], enc_hidden[:, 1]), dim = 1).view(3, batch_size, hidden_size * 2)

        dec_hidden1 = decoder.init_hidden1(batch_size)
        dec_cell1 = decoder.init_hidden1(batch_size)
        dec_hidden2 = decoder.init_hidden2(batch_size)
        dec_cell2 = decoder.init_hidden2(batch_size)

        dec_in = torch.tensor(dataset.vocab["<start>"]).view(1, 1).to(device)
        prediction = []

        print(enc_outputs[0, 0, :10])

        dec_out, dec_hidden1, dec_cell1, dec_hidden2, dec_cell2 = decoder(dec_in, 
                                                                            enc_outputs, 
                                                                            enc_hidden, 
                                                                            dec_hidden1, 
                                                                            dec_cell1, 
                                                                            dec_hidden2, 
                                                                            dec_cell2)
        
        pred = torch.argmax(dec_out, dim = -1)
        prediction.append(dataset.inverse_vocab[pred.item()])

        # print(torch.topk(dec_out, 5, 1))

        dec_in = pred.view(1, 1)

        for di in range(output_length-1):
            dec_out, dec_hidden1, dec_cell1, dec_hidden2, dec_cell2 = decoder(dec_in, 
                                                                            enc_outputs, 
                                                                            torch.empty(3, batch_size, 0).to(device), 
                                                                            dec_hidden1, 
                                                                            dec_cell1, 
                                                                            dec_hidden2, 
                                                                            dec_cell2)
            #halo
            pred = torch.argmax(dec_out, dim = -1)
            prediction.append(dataset.inverse_vocab[pred.item()])

            # print(torch.topk(dec_out, 5, -1))

            dec_in = pred.view(1, 1)

        return prediction

In [18]:
encoder = Encoder(1054, 128, 3)
encoder.load_state_dict(torch.load("baseline_encoder.pt", map_location=torch.device('cpu')), strict = True)
decoder = Decoder(128, 1054, 10, 3, 3)
decoder.load_state_dict(torch.load("baseline_decoder.pt", map_location=torch.device('cpu')), strict = True)

<All keys matched successfully>

In [23]:
inp = gen_input("wie geht es dir?", dataset)
decode_greedy(inp, encoder, decoder, dataset)

tensor([[33, 80, 13, 22,  8,  0,  0,  0,  0,  0]])
tensor([-0.0158, -0.3362, -0.0093,  0.0240, -0.0389, -0.0253, -0.3156,  0.3089,
        -0.0141, -0.0667])


['oh',
 '<unk>',
 'ich',
 'bin',
 'gerade',
 'bei',
 '<unk>',
 '<stop>',
 '<pad>',
 '<pad>']

In [8]:
class ChatDataset(Dataset):
    def __init__(self, directory = "chats", max_length = 10):
        self.chats, token_counts = self.get_data(directory)
        self.vocab = self.get_vocab(token_counts)
        self.inverse_vocab = {val: key for key, val in self.vocab.items()}
        self.sequence_pairs = self.get_sequence_pairs(self.chats, max_length)
        self.max_length = max_length

    def get_sequence_pairs(self, chats, max_length):
        sequence_pairs = []
        for chat in chats:
            # get the number of messages
            num_msgs = len(chat["Robin"]) + len(chat["Other"])

            start = 0
            # iterate start to the first message of "Other"
            while start in chat["Robin"]:
                start += 1
            # iterate over all messages
            while start < num_msgs:
                pair = [[], []]
                # add all messages from "Other" as first element of pair
                while start in chat["Other"]:
                    pair[0] += chat["Other"][start]
                    start += 1
                # add all messages from "Robin" as second element of pair
                while start in chat["Robin"]:
                    pair[1] += chat["Robin"][start]
                    start += 1

                pair = [[self.vocab[token] if token in self.vocab else self.vocab["<unk>"] for token in sequence] for sequence in pair]
                pair[0] = pair[0][:max_length] + [self.vocab["<pad>"]] * max(max_length - len(pair[0]), 0)
                pair[1] = [self.vocab["<start>"]] + pair[1][:max_length - 2] + [self.vocab["<stop>"]] + [self.vocab["<pad>"]] * max(max_length - len(pair[1]) - 2, 0)


                # append the pair to the sequence pairs
                sequence_pairs.append(tuple(pair))

        return sequence_pairs
        
    def get_vocab(self, token_counts):
        greater_ten = Counter({token: count for token, count in token_counts.items() if count > 10})
        most_common = greater_ten.most_common()
        vocab = {token: index + 4 for index, (token, _) in enumerate(most_common)}
        vocab["<pad>"] = 0
        vocab["<unk>"] = 1
        vocab["<start>"] = 2
        vocab["<stop>"] = 3
        
        return vocab
    
    def get_data(self, directory: str):
        chats = []
        token_counts = Counter()
        for filename in tqdm(Path(directory).glob("*.txt")):
            chat, counts = self.get_chat_data(filename)
            chats.append(chat)
            token_counts += counts
        return chats, token_counts


    def get_chat_data(self, filename):
        f = open(filename, "r", encoding = "utf-8")
        sequences = {}
        token_counts = Counter()
        index = 0

        for line in tqdm(f):
            # matches the author and the message
            match = re.findall(r".+? - (.+?): (.+)", line)
            if not match:
                continue      
            # get author and message
            author, msg = match[0]
            author = "Other" if author != "Robin" else "Robin"
            # skip media files and missed calls
            if msg in ["<Media omitted>", "Missed voice call", "Missed video call"]:
                continue
            # get tokens
            tokens = self.get_tokens(msg)

            # get token counts
            token_counts.update(tokens)

            # add entry for author in dictionary
            if not author in sequences:
                sequences[author] = {}
            # add message to dict
            sequences[author][index] = tokens
            index += 1

        return sequences, token_counts

    def get_tokens(self, message):
        sequence = nltk.word_tokenize(message)
        tokens = [token.lower() for token in sequence]
        return tokens


    def __len__(self):
        return len(self.sequence_pairs)

    def __getitem__(self, index):
        seq1, seq2 = self.sequence_pairs[index]

        return (torch.tensor(seq1), torch.tensor(seq2))

In [8]:
!unzip chats.zip

Archive:  chats.zip
   creating: chats/
  inflating: chats/sophie.txt        


In [9]:
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

In [9]:
dataset = ChatDataset(max_length = 10)

0it [00:00, ?it/s]
0it [00:00, ?it/s][A
1382it [00:00, 13205.02it/s][A
2771it [00:00, 13528.15it/s][A
4162it [00:00, 13641.50it/s][A
5583it [00:00, 13813.47it/s][A
6965it [00:00, 13621.86it/s][A
8328it [00:00, 13538.09it/s][A
9736it [00:00, 13671.21it/s][A
11214it [00:00, 13979.00it/s][A
12620it [00:00, 13962.94it/s][A
14017it [00:01, 13904.08it/s][A
15455it [00:01, 14007.55it/s][A
16856it [00:01, 13926.89it/s][A
18249it [00:01, 13925.75it/s][A
19656it [00:01, 13603.90it/s][A
21019it [00:01, 13575.08it/s][A
22378it [00:01, 13342.84it/s][A
24534it [00:01, 13669.83it/s]
1it [00:01,  1.81s/it]


In [None]:
dataset[10]

(tensor([  9,   5, 178,   4,  75, 372, 157, 161,  75,  11,  77, 167, 381,  37,
          25, 472, 694,   0,   0,   0]),
 tensor([ 2,  1, 21,  1, 33,  1,  3,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0]))