In [None]:
import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.nn.functional as F
from torch import optim

import nltk
from tqdm.notebook import tqdm
import re
from pathlib import Path
from collections import Counter
import pandas as pd
from xml.dom import minidom
from bs4 import BeautifulSoup
import string
import matplotlib.pyplot as plt
import numpy as np
import operator
import gensim

import time
import datetime

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
device

In [None]:
class Encoder(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, pretrained_emb = None):
        super(Encoder, self).__init__()

        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # input embedding
        if pretrained_emb is None:
            self.embedding = nn.Embedding(input_size, hidden_size, padding_idx = 0)
        else:
            self.embedding = nn.Embedding.from_pretrained(pretrained_emb)
        # encoder lstms
        self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers, batch_first = True, dropout = 0.1, bidirectional = True)

    def forward(self, x, hidden, cell):
        # obtain embedding of input word
        embedding = self.embedding(x)
        # get the encoder outputs
        output, (hidden, cell) = self.lstm(embedding, (hidden, cell))

        return output, hidden, cell

    def init_hidden(self, batch_size):
        # init with zeros
        return torch.zeros(2 * self.num_layers, batch_size, self.hidden_size, device = device)


In [None]:
class Decoder(nn.Module):
    def __init__(self, hidden_size, output_size, num_layers, pretrained_emb = None):
        super(Decoder, self).__init__()

        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # embedding of output words
        if pretrained_emb is None:
            self.embedding = nn.Embedding(output_size, hidden_size, padding_idx = 0)
        else:
            self.embedding = nn.Embedding.from_pretrained(pretrained_emb)
        # the attention matrix which is applied to the encoder outputs
        self.enc_attn_w = nn.Linear(2 * hidden_size, hidden_size)
        # the attention matrix which is applied to the current decoder state
        self.dec_attn_w = nn.Linear(hidden_size, hidden_size)
        # the attention matrix which is applied to the sum of the projected encoder outputs and the current decoder state
        self.comb_attn_w = nn.Linear(hidden_size, 1)
        self.attn_softmax = nn.Softmax(dim = -1)
        # the encoder outputs are double the size of the decoder states because it uses bidirectional layers
        # therefore the weighted context vector need to be scaled to half its size before combining it with the decoder state
        self.scale_enc_hidden = nn.Linear(2 * hidden_size, hidden_size)
        # the lstm network which is applied to the combination of context vector and decoder state
        self.lstm = nn.LSTM(hidden_size * 3, hidden_size, num_layers, batch_first = True, dropout = 0.1)
        # last layer which projects decoder state to the size of the output vocabulary
        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, y, enc_outputs, enc_hidden, hidden, cell):
        # obtain embedding from lastly predicted symbol
        embedding = self.embedding(y)
        # obtain unnormalized attention energies by combining the current decoder state with all encoder outputs
        attn_energies = torch.tanh(self.dec_attn_w(embedding[:, :, None]) + self.enc_attn_w(enc_outputs[:, None]))
        attn_energies = self.comb_attn_w(attn_energies)
        attn_energies = torch.squeeze(attn_energies, dim = -1)
        # obtain weights by normalizing energies
        attn_weights = F.softmax(attn_energies, dim = -1)
        # apply the weights to the encoder outputs
        context = attn_weights[:, :, :, None] * enc_outputs[:, None]
        # obtain weighted sum
        context = torch.sum(context, dim = -2)
        # combine context and decoder state
        combined = torch.cat((embedding, context), dim = -1)
        enc_hidden = self.scale_enc_hidden(enc_hidden)
        # obtain output by applying lstm network and projecting to vocab size
        out, (hidden, cell) = self.lstm(combined, (hidden + enc_hidden, cell))
        out = self.linear(out)
        out = torch.squeeze(out, dim = 1)

        return out, hidden, cell


    def init_hidden(self, batch_size):
        return torch.zeros(self.num_layers, batch_size, self.hidden_size, device = device)

In [None]:
hidden_size = 64
num_batches = 2
seq_len = 5
vocab_size = 10

num_enc_lay = 4
num_dec_lay = 4

enc = Encoder(vocab_size, hidden_size, num_enc_lay).to(device)
dec = Decoder(hidden_size, vocab_size, num_dec_lay).to(device)

x = torch.randint(0, vocab_size, (num_batches, seq_len)).to(device)
y = torch.randint(0, vocab_size, (num_batches, seq_len)).to(device)

enc_hidden = enc.init_hidden(num_batches)
dec_hidden = dec.init_hidden(num_batches)

enc_outputs, enc_hidden, enc_cell = enc(x, enc_hidden, enc_hidden)

enc_hidden = enc_hidden.view(enc.num_layers, 2, num_batches, 64)
enc_hidden = torch.cat((enc_hidden[-1, 0], enc_hidden[-1, 1]), dim = 1).view(1, num_batches, hidden_size * 2)

out, hidden, cell = dec(y, enc_outputs, enc_hidden, dec_hidden, dec_hidden)

In [None]:
def get_pretrained_emb(dataset, emb_size, window = 5, min_count = 5, iter = 300):
    # get all sequences from dataset
    sequences = dataset.get_all_sequences()
    # train word2vec model
    word2vec = gensim.models.Word2Vec(sequences, size = emb_size, window = window, min_count = min_count, iter = iter)
    # set vocab of dataset
    dataset.set_vocab_from_gensim(word2vec.wv)
    # return pretrained embeddings
    return torch.cat([torch.zeros(5, emb_size), torch.FloatTensor(word2vec.wv.vectors)])

In [None]:
def get_pretrained_models(src_dataset, trt_dataset, hidden_size = 128, num_enc_lay = 3, num_dec_lay = 3, epochs = 200, pretrain_emb = False):

    # pretrain src embeddings
    if pretrain_emb:
        pretrained_emb = get_pretrained_emb(src_dataset, hidden_size)

    # set src vocab size
    src_vocab_size = len(src_dataset.vocab)

    # train src models
    src_encoder = Encoder(src_vocab_size, hidden_size, num_enc_lay, pretrained_emb if pretrain_emb else None).to(device)
    src_decoder = Decoder(hidden_size, src_vocab_size, num_dec_lay, pretrained_emb if pretrain_emb else None).to(device)
    train(src_encoder, src_decoder, src_dataset, epochs, batch_size = 512)

    # pretrain trt embeddings
    if pretrain_emb:
        pretrained_emb = get_pretrained_emb(trt_dataset, hidden_size)

    # set trt vocab size
    trt_vocab_size = len(trt_dataset.vocab)

    # define trt models
    trt_encoder = Encoder(trt_vocab_size, hidden_size, num_enc_lay, pretrained_emb if pretrain_emb else None).to(device)
    trt_decoder = Decoder(hidden_size, trt_vocab_size, num_dec_lay, pretrained_emb if pretrain_emb else None).to(device)

    # get pretrained params
    encoder_pre = src_encoder.state_dict()
    encoder_pre.pop("embedding.weight")

    decoder_pre = src_decoder.state_dict()
    decoder_pre.pop("embedding.weight")
    decoder_pre.pop("linear.weight")
    decoder_pre.pop("linear.bias")

    # load pretrained params into trt models
    trt_encoder.load_state_dict(encoder_pre, strict = False)
    trt_decoder.load_state_dict(decoder_pre, strict = False)

    return trt_encoder, trt_decoder

In [81]:
encoder, decoder = get_pretrained_models(eng_dataset, wa_dataset, 256, 5, 2, epochs = 200, pretrain_emb=True)

HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))

Epoch 0 - Loss: 5.479780033656529
Epoch 1 - Loss: 3.4351761136736187
Epoch 2 - Loss: 3.078288813999721
Epoch 3 - Loss: 2.985598727634975
Epoch 4 - Loss: 2.9357080459594727
Epoch 5 - Loss: 2.8832288197108675
Epoch 6 - Loss: 2.8433937072753905
Epoch 7 - Loss: 2.7933183670043946
Epoch 8 - Loss: 2.7637302943638393
Epoch 9 - Loss: 2.7268248694283623
Epoch 10 - Loss: 2.683944375174386
Epoch 11 - Loss: 2.6491590772356304
Epoch 12 - Loss: 2.608696392604283
Epoch 13 - Loss: 2.571050453186035
Epoch 14 - Loss: 2.5340819222586495
Epoch 15 - Loss: 2.50437137058803
Epoch 16 - Loss: 2.5150723321097237
Epoch 17 - Loss: 2.4680127007620674
Epoch 18 - Loss: 2.425076457432338
Epoch 19 - Loss: 2.4018927710396905
Epoch 20 - Loss: 2.3645123617989676
Epoch 21 - Loss: 2.3436673300606863
Epoch 22 - Loss: 2.309556606837681
Epoch 23 - Loss: 2.2911182403564454
Epoch 24 - Loss: 2.2692659105573383
Epoch 25 - Loss: 2.2409066336495536
Epoch 26 - Loss: 2.221466827392578
Epoch 27 - Loss: 2.1968146460396905
Epoch 28 - Lo

In [None]:
vocab_size = len(wa_dataset.vocab)
encoder = Encoder(vocab_size, hidden_size=128, num_layers = 3).to(device)
decoder = Decoder(hidden_size=128, output_size=vocab_size, num_layers=3).to(device)

In [82]:
train(encoder, decoder, wa_dataset, epochs = 100, batch_size = 512)

HBox(children=(FloatProgress(value=0.0), HTML(value='')))

Epoch 0 - Loss: 5.038227717081706
Epoch 1 - Loss: 3.134919595718384
Epoch 2 - Loss: 3.2895942846934
Epoch 3 - Loss: 3.025574048360189
Epoch 4 - Loss: 2.692671791712443
Epoch 5 - Loss: 2.6088014761606853
Epoch 6 - Loss: 2.542850399017334
Epoch 7 - Loss: 2.501156949996948
Epoch 8 - Loss: 2.469832499821981
Epoch 9 - Loss: 2.4492687861124676
Epoch 10 - Loss: 2.42799498240153
Epoch 11 - Loss: 2.4038827737172443
Epoch 12 - Loss: 2.3836277961730956
Epoch 13 - Loss: 2.368703317642212
Epoch 14 - Loss: 2.34832878112793
Epoch 15 - Loss: 2.306018352508545
Epoch 16 - Loss: 2.2913071473439532
Epoch 17 - Loss: 2.2605347633361816
Epoch 18 - Loss: 2.232298008600871
Epoch 19 - Loss: 2.200734758377075
Epoch 20 - Loss: 2.187978251775106
Epoch 21 - Loss: 2.169781271616618
Epoch 22 - Loss: 2.133491611480713
Epoch 23 - Loss: 2.1216053485870363
Epoch 24 - Loss: 2.0955012321472166
Epoch 25 - Loss: 2.084765911102295
Epoch 26 - Loss: 2.06378067334493
Epoch 27 - Loss: 2.034674135843913
Epoch 28 - Loss: 2.00702681

In [None]:
def train(encoder, decoder, dataset, epochs = 50, batch_size = 32):

    trainloader = DataLoader(dataset, batch_size = batch_size, shuffle = True, drop_last = True)

    criterion = torch.nn.CrossEntropyLoss()
    encoder_opt = torch.optim.Adam(encoder.parameters())
    decoder_opt = torch.optim.Adam(decoder.parameters())

    vocab_size = len(dataset.vocab)
    hidden_size = encoder.hidden_size

    encoder.train()
    decoder.train()

    for epoch in tqdm(range(epochs)):
        running_loss = 0
        for i, data in enumerate(trainloader):
            # get data and move it to the device
            input_tensor, output_tensor = data
            input_tensor = input_tensor.to(device)
            output_tensor = output_tensor.to(device)

            encoder_opt.zero_grad()
            decoder_opt.zero_grad()
            
            loss = 0

            # init encoder hidden state and cell
            enc_hidden = encoder.init_hidden(batch_size)
            enc_cell = encoder.init_hidden(batch_size)
            # obtain encoder outputs
            enc_outputs, enc_hidden, enc_cell = encoder(input_tensor, enc_hidden, enc_cell)
            # concatenate the last hidden states from both directions
            enc_hidden = enc_hidden.view(encoder.num_layers, 2, batch_size, hidden_size)
            enc_hidden = torch.cat((enc_hidden[-1, 0], enc_hidden[-1, 1]), dim = 1).view(1, batch_size, hidden_size * 2)

            # init decoder hidden and cell state
            dec_hidden = decoder.init_hidden(batch_size)
            dec_cell = decoder.init_hidden(batch_size)

            for i in range(output_tensor.size(1) - 1):
                dec_in = output_tensor[:, i].view(-1, 1)
                target = output_tensor[:, i + 1].view(-1, 1)

                dec_out, dec_hidden, dec_cell = decoder(dec_in, enc_outputs, enc_hidden, dec_hidden, dec_cell)

                loss += criterion(torch.reshape(dec_out, (-1, vocab_size)), torch.reshape(target, (-1,)))

            # dec_in = output_tensor
            # target = output_tensor

            # # obtain decoder outputs
            # dec_out, dec_hidden, dec_cell = decoder(dec_in, enc_outputs, enc_hidden, dec_hidden, dec_cell)
            # # decoder out is of shape (batch, seq_len, vocab) and we reshape it by concatenating all sequences of all batches
            # # we shift the target to the right to avoid the decoder learning to mimic the target sequence
            # # target is of shape (batch, seq_len) and we reshape it like the decoder output
            # loss = criterion(torch.reshape(dec_out[:, :-1, :], (-1, vocab_size)), torch.reshape(target[:, 1:], (-1,)))

            loss.backward()
            encoder_opt.step()
            decoder_opt.step()

            running_loss += loss.item()

        print("Epoch {} - Loss: {}".format(epoch, running_loss / ((len(dataset) // batch_size) * dataset.max_length)))

    torch.save(encoder.state_dict(), "encoder.pt")
    torch.save(decoder.state_dict(), "decoder.pt")



In [None]:
def gen_input(message: str, dataset):
    tokens = dataset.get_tokens(message)
    tokens = [dataset.vocab[token] if token in dataset.vocab else dataset.vocab["<unk>"] for token in tokens]
    inp = tokens[:dataset.max_length - 1] + [dataset.vocab["<new>"]] + [dataset.vocab["<pad>"]] * max(dataset.max_length - len(tokens) - 1, 0)

    return torch.tensor(inp)

In [None]:
def decode_beam(inp, encoder, decoder, dataset, beam_width):
    encoder.eval()
    decoder.eval()

    hidden_size = encoder.hidden_size
    batch_size = 1

    top_picks = []

    with torch.no_grad():
        # init encoder inputs
        enc_hidden = encoder.init_hidden(batch_size)
        enc_cell = encoder.init_hidden(batch_size)
        enc_in = inp.view(batch_size, -1).to(device)

        # obtain encoder outputs
        enc_outputs, enc_hidden, enc_cell = encoder(enc_in, enc_hidden, enc_cell)

        #prepare hidden encoder state for decoder
        enc_hidden = enc_hidden.view(3, 2, batch_size, hidden_size)
        enc_hidden = torch.cat((enc_hidden[:, 0], enc_hidden[:, 1]), dim = 1).view(3, batch_size, hidden_size * 2)

        # init decoder inputs
        dec_hidden = decoder.init_hidden(batch_size)
        dec_cell = decoder.init_hidden(batch_size)
        dec_in = torch.tensor(dataset.vocab["<start>"]).view(1, 1).to(device)

        # obtain first decoder outputs
        dec_out, dec_hidden, dec_cell = decoder(dec_in, enc_outputs, enc_hidden, dec_hidden, dec_cell)

        # get first top predictions
        top_k = torch.topk(F.softmax(dec_out, 1), beam_width, 1)

        # save parameters of the first top picks
        top_picks = [{
            "seq": [token],
            "prob": np.log(prob.item()),
            "hid": dec_hidden,
            "cell": dec_cell,
        } for prob, token in zip(top_k[0][0], top_k[1][0])]

        # do 10 decoding steps
        for i in range(10):
            hypotheses = []

            # go through every current top pick
            for pick in top_picks:
                # the lastly predicted symbol is the next input
                dec_in = pick["seq"][-1].view(1, 1)

                # get the next outputs
                dec_out, dec_hidden, dec_cell = decoder(dec_in, enc_outputs, enc_hidden, pick["hid"], pick["cell"])

                # get next top picks of the current hypothesis
                top_k = torch.topk(F.softmax(dec_out, 1), beam_width, 1)

                # store parameters of the top picks
                picks = [{
                    "seq": pick["seq"] + [token],
                    "prob": np.log(prob.item()) + pick["prob"],
                    "hid": dec_hidden,
                    "cell": dec_cell,
                } for prob, token in zip(top_k[0][0], top_k[1][0])]

                # add to current hypothesis
                hypotheses += picks

            # sort after probability
            hypotheses = sorted(hypotheses, key = operator.itemgetter("prob"), reverse = True)

            # get top k hyptheses
            top_picks = hypotheses[:beam_width]

    for pick in top_picks:
        print(np.exp(pick["prob"].item()), [dataset.inverse_vocab[token.item()] for token in pick["seq"]])





In [83]:
inp = gen_input("halo", wa_dataset)
decode_beam(inp, encoder, decoder, wa_dataset, 10)

RuntimeError: ignored

In [None]:
class ChatDataset(Dataset):
    def __init__(self, directory = "chats", max_length = 10, word_list = None, spell_check_path= None):
        # list of words to filter chats on
        self.spell_check = {}
        if not spell_check_path is None:
            with open(spell_check_path) as f:
                self.spell_check = eval(f.read())

        self.chats = self.get_data(directory)
        self.chats, token_counts = self.filter_tokens(self.chats, word_list)
        self.vocab = self.get_vocab(token_counts)
        self.inverse_vocab = {val: key for key, val in self.vocab.items()}
        self.sequence_pairs = self.get_sequence_pairs(self.chats, max_length)
        self.max_length = max_length

    def get_all_sequences(self):
        return [seq for pair in self.sequence_pairs for group in pair for seq in group]

    def set_vocab_from_gensim(self, wv):
        if "<unk>" in wv.index2word:
          wv.index2word.remove("<unk>")
        vocab = {token: i for i, token in enumerate(wv.index2word, start = 5)}
        vocab["<pad>"] = 0
        vocab["<unk>"] = 1
        vocab["<start>"] = 2
        vocab["<stop>"] = 3
        vocab["<new>"] = 4

        self.vocab = vocab
        self.inverse_vocab = {val: key for key, val in vocab.items()}

    def get_sequence_pairs(self, chats, max_length):
        sequence_pairs = []
        for chat in chats:       
            if len(chat) % 2 == 1:
                chat.pop()

            # iterate over all messages
            chat_iter = iter(chat)
            for p1 in chat_iter:
                p2 = next(chat_iter)

                pair = {
                    "Person1": [[token for token in msg["tokens"]] for msg in p1["messages"]],
                    "Person2": [[token for token in msg["tokens"]] for msg in p2["messages"]]
                }

                sequence_pairs.append(tuple(sequence for _, sequence in pair.items()))

        return sequence_pairs
        
    def get_vocab(self, token_counts):
        token_counts.pop("<unk>", None)

        greater_ten = Counter({token: count for token, count in token_counts.items() if count > 10})
        most_common = greater_ten.most_common()
        vocab = {token: index + 5 for index, (token, _) in enumerate(most_common)}
        vocab["<pad>"] = 0
        vocab["<unk>"] = 1
        vocab["<start>"] = 2
        vocab["<stop>"] = 3
        vocab["<new>"] = 4
        
        return vocab

    def filter_tokens(self, chats, word_list):
        new_chats = []
        token_counts = Counter()
        for chat in chats:
            data = {"group_id": [], "message_id": [], "author": [], "token": []}

            for group in chat:
                for message in group["messages"]:
                    for token in message["tokens"]:
                        data["group_id"].append(group["id"])
                        data["message_id"].append(message["id"])
                        data["author"].append(group["author"])
                        data["token"].append(token)


            df = pd.DataFrame(data)

            if word_list is not None:
                df_unknown = df.loc[~(df["token"].isin(word_list["word"])
                                    | df["token"].str.fullmatch(u"[\U00010000-\U0010ffff(\u2764\uFE0F)]")
                                    | df["token"].isin(list(string.punctuation)))]
                # print(df_unknown.groupby(["token"]).size().sort_values(ascending = False).head(50))

                df.loc[~df["token"].isin(word_list["word"]), "token"] = "<unk>"
                df = df.loc[df["token"].isin(word_list["word"])
                            | (df["token"] == "<unk>")
                            | df["token"].str.fullmatch(u"[\U00010000-\U0010ffff(\u2764\uFE0F)]")
                            | df["token"].isin(list(string.punctuation))]

                new_chat = []
                group_id = 0
                msg_id = 0

                for (_, group) in df.groupby(["group_id"]):
                    new_group = {"id": group_id, "author": group["author"].iloc[0], "messages": []}
                    group_id += 1
                    msg_id = 0

                    for (_, message) in group.groupby(["message_id"]):
                        new_message = {
                            "id": msg_id,
                            "tokens": message["token"].tolist()}
                        new_group["messages"].append(new_message)
                        msg_id += 1

                    new_chat.append(new_group)

                new_chats.append(new_chat)
            
            else:
                new_chats.append(chat)

            counts = df.groupby(["token"]).size()
            token_counts += Counter(counts.to_dict())
        
        return new_chats, token_counts

    
    def get_data(self):
        pass


    def get_sorted_sequences(self):
        pass


    def get_sequences(self):
        pass

    def get_tokens(self, message):
        emoji_pattern = re.compile(r"(\S*?)" u"([\U00010000-\U0010ffff(\u2764\uFE0F)])")
        message = re.sub(emoji_pattern, r"\1 \2", message)
        emoji_pattern = re.compile(u"([\U00010000-\U0010ffff(\u2764\uFE0F)])" r"(\S+)")
        message = re.sub(emoji_pattern, r"\1 \2", message)

        sequence = nltk.word_tokenize(message)
        tokens = [token.lower() for token in sequence]

        new_tokens = []
        for token in tokens:
            new = token
            multichars = re.findall(r'((\w)\2{2,})', token)
            if multichars:
                for char in multichars:
                    new = new.replace(char[0], char[1] * 1)
    
            new_tokens.append(new)

        new_tokens = [self.spell_check[token] if token in self.spell_check else token for token in new_tokens]
        new_tokens = ["<number>" if re.fullmatch(r"\d+", token) else token for token in new_tokens]
 
        return new_tokens


    def __len__(self):
        return len(self.sequence_pairs)

    def __getitem__(self, index):
        seq1, seq2 = self.sequence_pairs[index]
        seq1, seq2 = [[token for msg in seq for token in msg + ["<new>"]] for seq in [seq1, seq2]]
        seq1, seq2 = [[self.vocab[token] if token in self.vocab else self.vocab["<unk>"] for token in seq] for seq in [seq1, seq2]]

        seq1 = seq1[:self.max_length] + [self.vocab["<pad>"]] * max(self.max_length - len(seq1), 0)
        seq2 = [self.vocab["<start>"]] + seq2[:self.max_length - 2] + [self.vocab["<stop>"]] + [self.vocab["<pad>"]] * max(self.max_length - len(seq2) - 2, 0)  

        # return seq1, seq2
        return (torch.tensor(seq1), torch.tensor(seq2))

In [None]:
class WhatsAppDataset(ChatDataset):
    def get_data(self, directory):
        chats = []
        token_counts = Counter()
        for filename in tqdm(Path(directory).glob("*.txt")):
            sequences = self.get_sequences(filename)
            chats.append(sequences)

        return chats

    def get_date_from_msg(self, msg):
        # print(msg)
        date = re.findall(r"(\d+/\d+/\d+, \d+:\d+) -", msg)

        if date:
            return time.mktime(datetime.datetime.strptime(date[0], r"%m/%d/%y, %H:%M").timetuple())
        else:
            return 0

    def get_sorted_sequences(self, filename):
        f = open(filename, "r", encoding = "utf-8")
        messages = []
        index = 0
        for line in f:
            if re.match(r"\d+/\d+/\d+, \d+:\d+", line):
                messages.append(line)
                index += 1
            else:
                # print(line)
                messages[index - 1] += line
                # print(messages[index - 1], index - 1)

        messages.sort(key = self.get_date_from_msg)
        # print(messages)

        return messages

    def get_sequences(self, filename):
        messages = self.get_sorted_sequences(filename)
        print(len(messages))
        sequences = []
        group = {}
        prev_author = ""
        msg_id = 0
        group_id = 0
        num_cont = 0

        for i, line in enumerate(tqdm(messages[:])):
            # matches the author and the message
            match = re.findall(r".+? - (.+?): (.+)", line, flags = re.DOTALL)
            if not match:
                continue      
            # get author and message
            author, msg = match[0]

            # skip media files and missed calls
            if re.match("(<Media omitted>)|(Missed voice call)|(Missed video call)|This message was|You deleted", msg):
                continue

            # get tokens
            tokens = self.get_tokens(msg)
            if not tokens:
                continue

            author = "Person1" if author != "Robin" else "Person2"
            if author != prev_author:
                prev_author = author
                if group:
                    sequences.append(group)

                group_id += 1
                msg_id = 0

                group = {
                    "id": group_id,
                    "author": author,
                    "messages": []
                }

            group["messages"].append({
                "id": msg_id,
                "tokens": tokens
            })

            msg_id += 1

        print(len(sequences))
        return sequences

In [None]:
class EngDialogueDataset(ChatDataset):
    def get_data(self, directory):
        chats = []

        for filename in Path(directory).glob("*.txt"):
            sequences = self.get_sequences(filename)
            chats.append(sequences)
        
        return chats

    def get_sequences(self, filename):
        sequences = []

        with open(filename, "r", encoding = "utf-8") as f:
            for i, line in enumerate(tqdm(f)):
                try:
                    question, answer = re.findall("(.+?)\t(.+?)\n", line)[0]
                except:
                    continue

                for j, text in enumerate([question, answer]):
                    group = {
                        "id": i * 2 + j * 1,
                        "author": "Person" + str(j + 1),
                        "messages": [
                            {
                                "id": 0,
                                "tokens": nltk.word_tokenize(text)
                            }
                        ]
                    }

                    sequences.append(group)

            return sequences

In [None]:
!unzip data.zip

In [None]:
nltk.download('punkt')

In [None]:
# german word list

words = pd.read_csv("german_words/words.csv", header = None, usecols = [1], delimiter=",", names = ["word"])
words = words.loc[words["word"].str.len() > 1]
words = words[:100000]
words = words["word"].str.lower()

counter = Counter()

for filename in Path("german_words/chatkorpus").glob("*.html"):
    with open(filename, 'r') as f:
        soup = BeautifulSoup(f, 'html.parser')

    msgs = soup.find_all("font")
    msgs = [msg.getText() for msg in msgs]
    for msg in tqdm(msgs):
        tokens = nltk.word_tokenize(msg)
        tokens = [token.lower() for token in tokens if len(token) > 1 and not re.search(r"\W|[0-9]", token)]
        tokens = [token.replace("ae", "ä") for token in tokens]
        tokens = [token.replace("oe", "ö") for token in tokens]
        tokens = [token.replace("ue", "ü") for token in tokens]
        counter.update(tokens)

words2 = [word.lower() for word, _ in counter.most_common(20000)]
words2 = pd.DataFrame(words2, columns = ["word"])
words2 = words2.merge(words, how = 'outer', on = "word", indicator=True).loc[lambda x : x['_merge']=='left_only']["word"]

words = pd.concat([words, words2], ignore_index = True)


special_words = open("german_words/special_words.txt", "r", encoding = "utf-8")
words2 = [word.strip() for word in special_words]
words2 = pd.Series(words2)

words = pd.concat([words, words2], ignore_index = True)
words.to_csv("german_words/top_words.csv", header = ["word"])

In [None]:
eng_dataset = EngDialogueDataset(max_length = 20, directory = "eng_dialogues")

In [None]:
wa_dataset = WhatsAppDataset(max_length = 20, word_list = pd.read_csv("german_words/top_words.csv"), spell_check_path="word_correct.txt")

In [None]:
print(list(eng_dataset.inverse_vocab[token.item()] for token in eng_dataset[8][0]))
print(list(eng_dataset.inverse_vocab[token.item()] for token in eng_dataset[8][1]))

In [None]:
import pickle
with open("vocab.pkl", "rb") as f:
    vocab = pickle.load(f)

In [None]:
for key, val in dataset.vocab.items():
    if not key in vocab:
        print(key, val)

In [None]:
dataset.vocab

In [None]:
emoji_pattern = re.compile(r"(\S*?)" u"([\U00010000-\U0010ffff])")
string = "Ich weis🥺😞was geht"
print(string)
print(nltk.word_tokenize(string), "\n")
string = re.sub(emoji_pattern, r"\1 \2", string)
emoji_pattern = re.compile(u"([\U00010000-\U0010ffff])" r"(\S+)")
string = re.sub(emoji_pattern, r"\1 \2", string)
# re.findall(emoji_pattern, string)
# print(string)
# new = string
# for match in matches:
print(string)
print(nltk.word_tokenize(string), "\n")


In [None]:
a = pd.DataFrame(["hi", "hallo", "was", "was"])
b = pd.DataFrame(["wie", "geht", "es"])

c = a.merge(b, how = 'outer' ,indicator=True).loc[lambda x : x['_merge']=='left_only']
c