In [1]:
import os
import re
from collections import Counter
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

In [2]:
# Build the path from components. The original backslash-separated literal
# (r"..\data\...") is only a valid relative path on Windows; elsewhere the
# backslashes become part of a single file name and open() fails.
# NOTE(review): the path is still relative to the kernel's working directory —
# confirm the notebook is started from its own folder.
DIALOGUES_TRAIN_PATH = os.path.join("..", "data", "EMNLP_dataset", "train", "dialogues_train.txt")

with open(DIALOGUES_TRAIN_PATH, "r", encoding="UTF-8") as f:
    raw_dialog_lines = f.readlines()

In [3]:
# One dialogue per line; utterances are delimited by the "__eou__" marker.
dialogs = [raw_line.strip().split("__eou__") for raw_line in raw_dialog_lines]
# Strip every utterance and drop pieces that are empty after stripping
# (such as a trailing piece when a line ends with the marker).
dialogs_cleaned = [[utt for utt in map(str.strip, dialog) if utt] for dialog in dialogs]

In [4]:
def preprocess_text(text):
    """Normalise an utterance for whitespace tokenisation.

    Lower-cases ``text``, deletes every character that is not an ASCII letter,
    digit or whitespace (so "don't" becomes "dont"), then collapses runs of
    whitespace into single spaces and trims both ends.

    Returns:
        The cleaned string (possibly empty).
    """
    lowered = text.lower()
    # Punctuation is deleted outright, not replaced by a space.
    alnum_only = re.sub(r"[^a-zA-Z0-9\s]", "", lowered)
    return re.sub(r"\s+", " ", alnum_only).strip()

In [5]:
# Turn every dialogue into (context, response) pairs: for each turn after the
# first, the context is all preceding turns joined with spaces and the
# response is that turn.
pairs = []
for turns in dialogs_cleaned:
    for pos, reply in enumerate(turns[1:], start=1):
        history = " ".join(turns[:pos]).strip()
        target = reply.strip()
        if history and target:
            pairs.append({"context": history, "response": target})

df_pairs = pd.DataFrame(pairs)
df_pairs.head()

Unnamed: 0,context,response
0,"Say , Jim , how about going for a few beers af...",You know that is tempting but is really not go...
1,"Say , Jim , how about going for a few beers af...",What do you mean ? It will help us to relax .
2,"Say , Jim , how about going for a few beers af...",Do you really think so ? I don't . It will jus...
3,"Say , Jim , how about going for a few beers af...",I guess you are right.But what shall we do ? I...
4,"Say , Jim , how about going for a few beers af...",I suggest a walk over to the gym where we can ...


In [6]:
# Normalise both sides of every pair with the same cleaning function.
df_pairs["context_clean"] = df_pairs["context"].map(preprocess_text)
df_pairs["response_clean"] = df_pairs["response"].map(preprocess_text)

In [7]:
# Whitespace tokenisation of the cleaned text (str.split with no separator).
df_pairs["context_tokens"] = df_pairs["context_clean"].apply(str.split)
df_pairs["response_tokens"] = df_pairs["response_clean"].apply(str.split)

In [8]:
# Flatten every token of every context, then every response, into one list.
all_tokens = [
    token
    for column in ("context_tokens", "response_tokens")
    for token_list in df_pairs[column]
    for token in token_list
]

# Number words by descending frequency starting at 2; ids 0 and 1 are reserved
# for the "<PAD>" and "<UNK>" specials, which are added after the words.
# NOTE(review): contexts are cumulative dialogue histories, so an early turn is
# counted once per later turn in its dialogue. This only skews the ordering,
# because no frequency cutoff or size limit is applied here.
word_counts = Counter(all_tokens)
vocab = {word: index for index, (word, _) in enumerate(word_counts.most_common(), start=2)}
vocab.update({"<PAD>": 0, "<UNK>": 1})

In [9]:
def tokens_to_indices(tokens, vocab):
    """Map each token in ``tokens`` to its integer id in ``vocab``.

    Tokens missing from ``vocab`` map to ``vocab["<UNK>"]``; that key must be
    present whenever ``tokens`` is non-empty.

    Returns:
        A new list of ids, one per token, in input order.
    """
    indices = []
    for token in tokens:
        indices.append(vocab.get(token, vocab["<UNK>"]))
    return indices

# Encode both token columns with the shared vocabulary.
for tokens_col, idx_col in (("context_tokens", "context_idx"), ("response_tokens", "response_idx")):
    df_pairs[idx_col] = df_pairs[tokens_col].apply(lambda toks: tokens_to_indices(toks, vocab))

In [10]:
def pad_seq(seq, max_len, pad_value=0):
    """Return a new list of exactly ``max_len`` items built from ``seq``.

    Shorter sequences are right-padded with ``pad_value``; longer ones are cut
    down to their FIRST ``max_len`` items. The input list is not modified.
    """
    shortfall = max_len - len(seq)
    if shortfall > 0:
        return seq + [pad_value] * shortfall
    return seq[:max_len]

max_len_context = 40
max_len_response = 40

# Each context is the entire dialogue history joined together, so the turn the
# response actually answers sits at the END of the sequence. pad_seq keeps the
# FIRST max_len tokens, which for long dialogues would reduce every context to
# the dialogue's opening lines and drop the most recent turn. Slice to the most
# recent tokens first; shorter contexts are unaffected by the slice.
df_pairs["context_idx_padded"] = df_pairs["context_idx"].apply(
    lambda x: pad_seq(x[-max_len_context:], max_len_context)
)
# Responses are truncated at the end (keep their beginning), as before.
df_pairs["response_idx_padded"] = df_pairs["response_idx"].apply(
    lambda x: pad_seq(x, max_len_response)
)

In [11]:
class Seq2SeqDataset(Dataset):
    """Map-style dataset yielding ``(context, response)`` index tensors.

    Reads the pre-padded id lists from ``df["context_idx_padded"]`` and
    ``df["response_idx_padded"]``; each item is converted to a pair of
    ``torch.long`` tensors when it is accessed.
    """

    def __init__(self, df):
        # Materialise both columns as plain Python lists once, up front.
        self.contexts, self.responses = (
            df[column].tolist() for column in ("context_idx_padded", "response_idx_padded")
        )

    def __len__(self):
        # One example per context row.
        return len(self.contexts)

    def __getitem__(self, idx):
        ctx_ids = self.contexts[idx]
        rsp_ids = self.responses[idx]
        return (
            torch.tensor(ctx_ids, dtype=torch.long),
            torch.tensor(rsp_ids, dtype=torch.long),
        )

In [12]:
# Shuffling is stochastic; a fixed, dedicated generator makes the batch order
# reproducible across Restart & Run All without touching torch's global RNG.
DATALOADER_SEED = 42
BATCH_SIZE = 32

dataset = Seq2SeqDataset(df_pairs)
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(DATALOADER_SEED),
)

In [13]:
# Sanity check: number of pairs and the first (context, response) tensor pair.
(len(dataset), dataset[0])

(76053,
 (tensor([ 153,  878,   28,   31,   66,   13,    6,  183, 3257,  180,  273,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0]),
  tensor([   2,   41,   14,    8, 3919,   27,    8,   52,   42,   39,   13,   70,
          1676,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0])))