In [None]:
import os
import csv
import shutil
import zipfile
import gzip
import pickle
import itertools
import urllib.parse
import urllib.request
from collections import Counter
import functools

import numpy as np
import torch
import torch.utils.data
from nltk import word_tokenize
import matplotlib.pyplot as plt
import ipywidgets as widgets
from tqdm import tqdm_notebook

In [None]:
plt.style.use('fivethirtyeight')

# Fix the random seeds: numpy drives the random embedding initialisation and the teacher-forcing
# decisions, torch drives the weight initialisation and the DataLoader shuffling.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

## Introduction

In this tutorial, we will train a simple recurrent sequence-to-sequence dialog model using the [OpenSubtitles](http://opus.nlpl.eu/OpenSubtitles.php) dataset. This dataset contains already tokenized subtitles collected from http://www.opensubtitles.org/.

## Dataset

As in the previous tutorial, we are going to use a `Vocab` (vocabulary) class and a subclass of the `torch.utils.data.Dataset` class.

In [None]:
def maybe_download_and_unzip_file(file_url, file_name=None, extract_dir='.'):
    """
    Download a remote file (and extract it if it is a zip archive) unless it already exists locally.

    The download is first written to a temporary ``<file_name>.part`` file and only renamed to
    ``file_name`` once it has completed. An interrupted download is therefore never mistaken
    for a finished one on the next run.

    :param file_url: URL of the file to download
    :param file_name:  (Default value = None) Local path to save the content to; defaults to the basename of the URL
    :param extract_dir:  (Default value = '.') Directory the content of a ``.zip`` archive is extracted to

    """
    if file_name is None:
        file_name = os.path.basename(file_url)

    if os.path.exists(file_name):
        print(f'Exists: {file_name}')
        return

    print(f'Downloading: {file_name}')

    partial_file_name = file_name + '.part'
    try:
        with urllib.request.urlopen(file_url) as response, open(partial_file_name, 'wb') as target_file:
            shutil.copyfileobj(response, target_file)
        # Atomic rename: file_name only ever appears with its complete content
        os.replace(partial_file_name, file_name)
    finally:
        # Clean up after a failed download (after a successful os.replace this file no longer exists)
        if os.path.exists(partial_file_name):
            os.remove(partial_file_name)

    print(f'Downloaded: {file_name}')

    file_extension = os.path.splitext(file_name)[1]
    if file_extension == '.zip':
        print(f'Extracting zip: {file_name}')
        with zipfile.ZipFile(file_name, 'r') as zip_file:
            zip_file.extractall(extract_dir)

In [None]:
# Source URL of the gzip-compressed OpenSubtitles text and the local file name it is saved under
dataset_url = 'http://opus.nlpl.eu/download.php?f=OpenSubtitles/v2018/mono/OpenSubtitles.en.gz'
dataset_filename = 'OpenSubtitles.en.gz'

In [None]:
# Downloads only if the file is not present yet. The '.gz' file is not extracted here;
# the dataset class reads it directly with gzip.
maybe_download_and_unzip_file(dataset_url, dataset_filename)

In [None]:
class Vocab(object):
    """ Vocabulary class to provide token to id correspondence """
    END_TOKEN = '<end>'
    START_TOKEN = '<start>'
    PAD_TOKEN = '<pad>'
    UNK_TOKEN = '<unk>'

    def __init__(self, special_tokens=None):
        """
        Initialize the vocabulary class

        :param special_tokens:  (Default value = None) A list of special tokens. The PAD token should be the first in the list, if used.

        """
        super().__init__()

        self.special_tokens = special_tokens

        self.token2id = {}
        self.id2token = {}

        self.token_counts = Counter()

        if self.special_tokens is not None:
            self.add_document(self.special_tokens)

    def add_document(self, document, rebuild=True):
        """
        Process the document and add its tokens to the vocabulary

        :param document: A list of tokens in the document
        :param rebuild:  (Default value = True) Whether to rebuild the id -> token correspondence or not

        """
        for token in document:
            self.token_counts[token] += 1

            if token not in self.token2id:
                # New tokens get the next free id
                self.token2id[token] = len(self.token2id)

        if rebuild:
            self._rebuild_id2token()

    def add_documents(self, documents):
        """
        Process a list of documents and add their tokens to the vocabulary

        :param documents: A list of documents, where each document is a list of tokens

        """
        for doc in documents:
            self.add_document(doc, rebuild=False)

        # Rebuild once at the end instead of after every document
        self._rebuild_id2token()

    def _rebuild_id2token(self):
        """ Rebuild the id to token correspondence from token2id """
        self.id2token = {i: t for t, i in self.token2id.items()}

    def prune_vocab(self, max_vocab_size):
        """
        Keep only the special tokens and the most frequent other tokens, and re-index them.

        Special tokens keep their order at the beginning of the vocabulary, so the PAD token keeps
        id 0 if it is the first special token. The remaining slots are filled with the most common
        non-special tokens. Ties are broken by first occurrence. Counts of removed tokens are dropped.

        :param max_vocab_size: Maximum total number of tokens to keep, special tokens included. Special tokens are always kept. None keeps every token.

        """
        if max_vocab_size is None:
            return

        # dict.fromkeys de-duplicates while preserving order
        special_tokens = list(dict.fromkeys(self.special_tokens or []))
        special_set = set(special_tokens)
        nb_regular = max(max_vocab_size - len(special_tokens), 0)

        # most_common() sorts by count (stable, so ties keep insertion order)
        regular_tokens = [t for t, _ in self.token_counts.most_common() if t not in special_set][:nb_regular]

        kept_tokens = special_tokens + regular_tokens
        self.token2id = {t: i for i, t in enumerate(kept_tokens)}
        self.token_counts = Counter({t: self.token_counts[t] for t in kept_tokens})
        self._rebuild_id2token()

    def save(self, filename):
        """
        Save the vocabulary to a CSV file with a ``token,id,count`` header, one row per token ordered by id.

        Missing parent directories are created.

        :param filename: Path of the CSV file to write

        """
        dirname = os.path.dirname(filename)
        if dirname:
            os.makedirs(dirname, exist_ok=True)

        with open(filename, 'w', encoding='utf-8', newline='') as csv_file:
            writer = csv.writer(csv_file)
            writer.writerow(['token', 'id', 'count'])
            for token, token_id in sorted(self.token2id.items(), key=lambda item: item[1]):
                writer.writerow([token, token_id, self.token_counts[token]])

    @classmethod
    def load(cls, filename, special_tokens=None):
        """
        Create a vocabulary from a CSV file written by ``save``.

        :param filename: Path of the CSV file to read
        :param special_tokens:  (Default value = None) The list of special tokens of the saved vocabulary (not stored in the file)
        :return: A new vocabulary with the saved ids and counts

        """
        vocab = cls()
        vocab.special_tokens = list(special_tokens) if special_tokens is not None else None

        with open(filename, 'r', encoding='utf-8', newline='') as csv_file:
            for row in csv.DictReader(csv_file):
                vocab.token2id[row['token']] = int(row['id'])
                vocab.token_counts[row['token']] = int(row['count'])

        vocab._rebuild_id2token()
        return vocab

    def get(self, item, default=None):
        """
        Given a token, return the corresponding id

        :param item: A token
        :param default:  (Default value = None) Default value to return if token is not present in the vocabulary

        """
        return self.token2id.get(item, default)

    def __getitem__(self, item):
        """
        Given a token, return the corresponding id

        :param item: A token

        """
        return self.token2id[item]

    def __contains__(self, item):
        """
        Check if a token is present in the vocabulary

        :param item: A token

        """
        return item in self.token2id

    def __len__(self):
        """ Return the length of the vocabulary """
        return len(self.token2id)

    def __str__(self):
        """ Get a string representation of the vocabulary """
        return f'{len(self)} tokens'

There are two important differences from the previous tutorial. First, notice how we form `(query, response)` pairs from consecutive subtitle lines: each line is a query, and the line that follows it is its response. Second, since the vocabulary can be quite large, we prune it to contain only the 50,000 most common tokens.

In [None]:
class SubtitlesDialogDataset(torch.utils.data.Dataset):
    """ A conversational dialog dataset with query-response pairs built from consecutive subtitle lines """
    def __init__(self, filename, vocab=None, max_lines = 1000, max_len=50, max_vocab_size=50000):
        """
        Initialize a conversational dialog dataset with query-response pairs

        :param filename: Path to the gzipped OpenSubtitles dataset (one subtitle per line)
        :param vocab:  (Default value = None) Vocabulary, will be created (and pruned) from the read lines if None
        :param max_lines:  (Default value = 1000) Limit the number of lines to read from the dataset file; None reads the whole file
        :param max_len:  (Default value = 50) Length every sentence is cut/padded to, including the END token; computed from the data if None
        :param max_vocab_size:  (Default value = 50000) Maximum size of the vocabulary (only used when ``vocab`` is None)
        :raises ValueError: If ``max_len`` is smaller than 1

        """
        self.lines = []
        # Text mode with errors='replace': a stray invalid byte does not abort reading the whole file
        with gzip.open(filename, 'rt', encoding='utf-8', errors='replace') as f:
            for line in itertools.islice(f, max_lines):
                self.lines.append(word_tokenize(line))

        # Number of lines actually read (the file may be shorter than max_lines)
        self.max_lines = len(self.lines)

        if vocab is None:
            vocab = Vocab(special_tokens=[Vocab.PAD_TOKEN, Vocab.START_TOKEN, Vocab.END_TOKEN, Vocab.UNK_TOKEN])
            vocab.add_documents(self.lines)
            vocab.prune_vocab(max_vocab_size)

            print(f'Created vocab: {vocab}')

        if max_len is None:
            # +1 leaves room for the END token appended to every sentence
            max_len = max((len(tokens) for tokens in self.lines), default=0) + 1
            print(f'Calculed max len: {max_len}')

        if max_len < 1:
            raise ValueError(f'max_len must be at least 1 (room for the END token), got {max_len}')

        self.vocab = vocab
        self.max_len = max_len

    def _pad_sentence(self, sent):
        """
        Cut the sentence to ``max_len - 1`` tokens, append the END token and pad it with PAD tokens to ``max_len``

        :param sent: The input sentence, a sequence of tokens
        :return: A list of exactly ``max_len`` tokens

        """
        sent = list(sent[:self.max_len - 1]) + [Vocab.END_TOKEN]
        return sent + [Vocab.PAD_TOKEN] * (self.max_len - len(sent))

    # Backward-compatible alias for the original (misspelled) method name
    _pad_sentnece = _pad_sentence

    def _process_sent(self, sent):
        """
        Cut, pad, and convert the sentence from tokens to indices using the vocabulary

        :param sent: The input sentence
        :return: A 1d int64 numpy array of length ``max_len``; out-of-vocabulary tokens map to the UNK id

        """
        sent = self._pad_sentence(sent)
        sent = [self.vocab[t] if t in self.vocab else self.vocab[Vocab.UNK_TOKEN] for t in sent]

        # np.long was removed from NumPy; int64 is what torch expects for token ids
        return np.array(sent, dtype=np.int64)

    def __getitem__(self, index):
        """
        Create a query-response pair from two consecutive lines of the dataset and return it

        :param index: Index of the query line. The response is the next line.
        :return: Tuple ``(query, response)`` of int64 arrays of length ``max_len``

        """
        query = self._process_sent(self.lines[index])
        response = self._process_sent(self.lines[index + 1])

        return query, response

    def __len__(self):
        """ Return the number of query-response pairs (one less than the number of lines, never negative) """
        return max(self.max_lines - 1, 0)

In [None]:
# Reads and tokenizes the first 1,000,000 subtitle lines and builds the vocabulary from them.
# NOTE(review): tokenizing this many lines with word_tokenize may take a while.
dataset = SubtitlesDialogDataset(dataset_filename, max_lines=1000000)

In [None]:
# Number of subtitle lines read from the file
len(dataset.lines)

In [None]:
# Number of (query, response) pairs: one less than the number of lines read
len(dataset)

### Save the vocab

In [None]:
# Path (relative to the working directory) where the vocabulary is saved
vocab_filename = 'tmp/seq2seq_dialog.vocab.csv'

In [None]:
# Persist the vocabulary built from the dataset
dataset.vocab.save(vocab_filename)

## Word embeddings

We are going to use the same word embeddings as in the previous tutorial.

In [None]:
# Pickled pretrained word embeddings (presumably the fastText crawl-300d-2M vectors, judging by the name)
embeddings_url = 'https://mednli.blob.core.windows.net/shared/word_embeddings/crawl-300d-2M.pickled'
embeddings_filename = 'crawl-300d-2M.pickled'

In [None]:
# The file name does not end in '.zip', so it is only downloaded, not extracted
maybe_download_and_unzip_file(embeddings_url, embeddings_filename)

In [None]:
# Expected to hold a dict mapping token -> numpy embedding vector.
# NOTE(review): pickle.load can execute arbitrary code. This file comes from the network,
# so only load it if you trust its source.
with open(embeddings_filename, 'rb') as pkl_file:
    word_embeddings = pickle.load(pkl_file)

### Embedding matrix

In [None]:
def create_embeddings_matrix(word_embeddings, vocab):
    """
    Build the embeddings matrix for a vocabulary: row ``i`` holds the embedding of the token with id ``i``.

    - The PAD token gets a zero vector.
    - The UNK token, and every token missing from ``word_embeddings``, gets a random vector drawn uniformly from [-0.3, 0.3).
    - Every other token gets its pretrained embedding.

    :param word_embeddings: Word embeddings dictionary, token -> 1d numpy array (all of the same size)
    :param vocab: Vocabulary (needs ``__len__`` and an ``id2token`` dict)
    :return: A float32 numpy array of shape (len(vocab), embedding_size)
    :raises ValueError: If ``word_embeddings`` is empty

    """
    # Read the size from any entry instead of materialising a list of all (possibly millions of) keys
    first_embedding = next(iter(word_embeddings.values()), None)
    if first_embedding is None:
        raise ValueError('word_embeddings is empty: cannot infer the embedding size')
    embedding_size = first_embedding.shape[0]

    W_emb = np.zeros((len(vocab), embedding_size), dtype=np.float32)

    special_tokens = {
        Vocab.UNK_TOKEN: np.random.uniform(-0.3, 0.3, (embedding_size,)),
        Vocab.PAD_TOKEN: np.zeros((embedding_size,)),
    }

    nb_unk = 0
    for token_id, token in vocab.id2token.items():
        if token in special_tokens:
            W_emb[token_id] = special_tokens[token]
        elif token in word_embeddings:
            W_emb[token_id] = word_embeddings[token]
        else:
            # No pretrained vector (e.g. START/END, unless word_embeddings happens to contain them)
            W_emb[token_id] = np.random.uniform(-0.3, 0.3, embedding_size)
            nb_unk += 1

    print(f'Nb unk: {nb_unk}')

    return W_emb

In [None]:
# Row i of W_emb is the initial embedding of the token with id i in the dataset vocabulary
W_emb = create_embeddings_matrix(word_embeddings, dataset.vocab)

## Model

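We are going to use a standard seq2seq model. Given an input query (sentence), the model produces a response. Although this model has no access to the dialog context beyond the current query, it provides a good starting point. We define two variants: a plain encoder-decoder, and one that attends over all encoder states at every decoding step. The attention variant is the one we train below.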

In [None]:
class Seq2SeqModel(torch.nn.Module):
    """ A simple GRU-based sequence-to-sequence model without attention """
    def __init__(self, vocab_size, embedding_size, hidden_size, teacher_forcing,
                 max_len,trainable_embeddings, start_index, end_index, pad_index, W_emb=None):
        """
        Initialize the model

        :param vocab_size: The size of the vocabulary
        :param embedding_size: Dimension of the embeddings
        :param hidden_size: The size of the hidden layers, including GRU
        :param teacher_forcing: The probability of teacher forcing
        :param max_len: Maximum length of the sequences
        :param trainable_embeddings: Whether the embedding layer will be trainable or frozen
        :param start_index: Index of the START token in the vocabulary
        :param end_index: Index of the END token in the vocabulary
        :param pad_index: Index of the PAD token in the vocabulary
        :param W_emb:  (Default value = None) Initial values of the embedding layer, a numpy array

        """

        super().__init__()

        self.teacher_forcing = teacher_forcing
        self.max_len = max_len
        self.start_index = start_index
        self.end_index = end_index
        self.pad_index = pad_index

        self.embedding = torch.nn.Embedding(vocab_size, embedding_size, padding_idx=pad_index)
        if W_emb is not None:
            self.embedding.weight.data.copy_(torch.from_numpy(W_emb))
        if not trainable_embeddings:
            self.embedding.weight.requires_grad = False

        self.encoder = torch.nn.GRU(embedding_size, hidden_size, batch_first=True)
        self.decoder = torch.nn.GRUCell(embedding_size, hidden_size)
        self.decoder_projection = torch.nn.Linear(hidden_size, vocab_size)

    def encode(self, inputs):
        """
        Encode the input sentence and return the GRU output at the last non-PAD position of each row

        :param inputs: Batch of input token ids, shape (batch_size, seq_len), PAD tokens at the end
        :return: Tensor of shape (batch_size, hidden_size)

        """
        inputs_lengths = (inputs != self.pad_index).sum(dim=1)

        inputs_emb = self.embedding(inputs)
        outputs, _ = self.encoder(inputs_emb)

        # The GRU is unidirectional and padding comes last, so the output at position length - 1
        # has seen exactly the real tokens. Index with a tensor on the same device as the outputs.
        batch_indices = torch.arange(inputs.size(0), device=outputs.device)
        return outputs[batch_indices, inputs_lengths - 1]

    def decode(self, decoder_hidden, targets=None):
        """
        Decode the response given the initial hidden state of the decoder

        :param decoder_hidden: Initial hidden state of the decoder, shape (batch_size, hidden_size)
        :param targets:  (Default value = None) True decoding targets to be used for teacher forcing, shape (batch_size, >= max_len)
        :return: Logits of shape (batch_size, max_len, vocab_size)

        """
        batch_size = decoder_hidden.size(0)

        outputs_logits = []
        decoder_inputs = torch.full((batch_size,), self.start_index, dtype=torch.long, device=decoder_hidden.device)
        for i in range(self.max_len):
            decoder_inputs_emb = self.embedding(decoder_inputs)

            decoder_hidden = self.decoder(decoder_inputs_emb, decoder_hidden)

            decoder_output_logit = self.decoder_projection(decoder_hidden)

            # With probability teacher_forcing feed the true token, otherwise the model's own greedy prediction
            if np.random.rand() < self.teacher_forcing and targets is not None:
                decoder_inputs = targets[:, i]
            else:
                decoder_inputs = decoder_output_logit.argmax(dim=1).long()

            outputs_logits.append(decoder_output_logit)

        return torch.stack(outputs_logits, dim=1)

    def forward(self, inputs, targets=None):
        """
        Encode the input query and decode the response

        :param inputs: The input sentence
        :param targets:  (Default value = None) True decoding targets

        """
        decoder_hidden = self.encode(inputs)
        outputs_logits = self.decode(decoder_hidden, targets)

        return outputs_logits

In [None]:
def softmax_masked(inputs, mask, dim=1, epsilon=0.000001):
    """
    Numerically stable softmax over a batch of masked sequences of different lengths.

    Positions where ``mask`` is 0 get probability 0. The remaining positions along ``dim`` sum to
    (almost exactly) 1. A row whose mask is all zeros yields all zeros instead of NaN.

    :param inputs: Input scores, e.g. a 2d tensor of shape (batch_size, max_seq_len)
    :param mask: Mask of the same shape as ``inputs``, an array of 1 (valid) and 0 (masked)
    :param dim:  (Default value = 1) Dimension of the softmax operation
    :param epsilon:  (Default value = 0.000001) Added to the denominator to avoid division by zero
    :return: Tensor of the same shape as ``inputs`` holding the masked softmax probabilities

    """
    valid = mask != 0

    # Subtract the maximum over the valid positions before exponentiating: large scores would
    # otherwise overflow to inf and turn the division into NaN. Masked positions become -inf -> exp 0.
    masked_inputs = inputs.masked_fill(~valid, float('-inf'))
    row_max = masked_inputs.max(dim=dim, keepdim=True)[0]
    # Rows without any valid position have a max of -inf; shift them by 0 to avoid inf - inf
    row_max = torch.where(torch.isfinite(row_max), row_max, torch.zeros_like(row_max)).detach()

    inputs_exp = torch.exp(masked_inputs - row_max)
    inputs_exp_sum = inputs_exp.sum(dim=dim, keepdim=True)

    return inputs_exp / (inputs_exp_sum + epsilon)

In [None]:
class Seq2SeqAttentionModel(torch.nn.Module):
    """ A more advanced GRU-based sequence-to-sequence model with attention over the encoder states """
    def __init__(self, vocab_size, embedding_size, hidden_size, teacher_forcing,
                 max_len,trainable_embeddings, start_index, end_index, pad_index, W_emb=None):
        """
        Initialize the model

        :param vocab_size: The size of the vocabulary
        :param embedding_size: Dimension of the embeddings
        :param hidden_size: The size of the hidden layers, including GRU
        :param teacher_forcing: The probability of teacher forcing
        :param max_len: Maximum length of the sequences
        :param trainable_embeddings: Whether the embedding layer will be trainable or frozen
        :param start_index: Index of the START token in the vocabulary
        :param end_index: Index of the END token in the vocabulary
        :param pad_index: Index of the PAD token in the vocabulary
        :param W_emb:  (Default value = None) Initial values of the embedding layer, a numpy array

        """

        super().__init__()

        self.teacher_forcing = teacher_forcing
        self.max_len = max_len
        self.start_index = start_index
        self.end_index = end_index
        self.pad_index = pad_index

        self.embedding = torch.nn.Embedding(vocab_size, embedding_size, padding_idx=pad_index)
        if W_emb is not None:
            self.embedding.weight.data.copy_(torch.from_numpy(W_emb))
        if not trainable_embeddings:
            self.embedding.weight.requires_grad = False

        self.encoder = torch.nn.GRU(embedding_size, hidden_size, batch_first=True)
        self.decoder = torch.nn.GRUCell(embedding_size, hidden_size)

        # Additive attention: score_j = w^T tanh(W_enc h_j + W_dec s)
        self.attention_decoder = torch.nn.Linear(hidden_size, hidden_size)
        self.attention_encoder = torch.nn.Linear(hidden_size, hidden_size)
        self.attention_reduce = torch.nn.Linear(hidden_size, 1, bias=False)
        # Mixes the previous decoder state with the attention context into the GRUCell's input state
        self.decoder_hidden_combine = torch.nn.Linear(hidden_size * 2, hidden_size)

        self.decoder_projection = torch.nn.Linear(hidden_size, vocab_size)

    def encode(self, inputs):
        """
        Encode the input sentence and return all hidden states of the GRU together with the input mask

        :param inputs: Batch of input token ids, shape (batch_size, seq_len)
        :return: Tuple ``(outputs, inputs_mask)``. ``outputs`` are the encoder outputs, shape (batch_size, seq_len, hidden_size). ``inputs_mask`` is a long tensor (1 for non-PAD tokens, 0 for PAD).

        """
        inputs_mask = (inputs != self.pad_index).long()

        inputs_emb = self.embedding(inputs)
        outputs, _ = self.encoder(inputs_emb)

        return outputs, inputs_mask

    def decode(self, encoder_hiddens, inputs_mask, targets=None):
        """
        Decode the response, attending over all hidden states of the encoder at every step

        :param encoder_hiddens: Hidden states of the encoder, shape (batch_size, seq_len, hidden_size)
        :param inputs_mask: Input mask (non-zero for real tokens), shape (batch_size, seq_len)
        :param targets:  (Default value = None) True decoding targets to be used for teacher forcing, shape (batch_size, >= max_len)
        :return: Logits of shape (batch_size, max_len, vocab_size)

        """
        batch_size = encoder_hiddens.size(0)

        # The encoder-side attention projection does not depend on the decoding step: compute it once
        att_enc = self.attention_encoder(encoder_hiddens)

        outputs_logits = []
        # The decoder starts from a zero state; information about the input comes through attention
        decoder_hidden = torch.zeros_like(encoder_hiddens[:, 0, :])
        decoder_inputs = torch.full((batch_size,), self.start_index, dtype=torch.long, device=encoder_hiddens.device)
        for i in range(self.max_len):
            decoder_inputs_emb = self.embedding(decoder_inputs)

            att_dec = self.attention_decoder(decoder_hidden)
            att_scores = self.attention_reduce(torch.tanh(att_enc + att_dec.unsqueeze(1))).squeeze(-1)
            # PAD positions of the input get zero attention weight
            att_weights = softmax_masked(att_scores, inputs_mask)

            context = torch.sum(encoder_hiddens * att_weights.unsqueeze(-1), dim=1)
            decoder_hidden_combined = self.decoder_hidden_combine(torch.cat([decoder_hidden, context], dim=-1))

            decoder_hidden = self.decoder(decoder_inputs_emb, decoder_hidden_combined)

            decoder_output_logit = self.decoder_projection(decoder_hidden)

            # With probability teacher_forcing feed the true token, otherwise the model's own greedy prediction
            if np.random.rand() < self.teacher_forcing and targets is not None:
                decoder_inputs = targets[:, i]
            else:
                decoder_inputs = decoder_output_logit.argmax(dim=1).long()

            outputs_logits.append(decoder_output_logit)

        return torch.stack(outputs_logits, dim=1)

    def forward(self, inputs, targets=None):
        """
        Encode the input query and decode the response

        :param inputs: The input sentence
        :param targets:  (Default value = None) True decoding targets

        """
        encoder_hiddens, inputs_mask = self.encode(inputs)
        outputs_logits = self.decode(encoder_hiddens, inputs_mask, targets)

        return outputs_logits

Below are some helper functions to save and load the weights of the model. Feel free to use them in your projects!

In [None]:
def load_model(model_class, filename):
    """
    Create the model of the given class and load the checkpoint from the given file

    :param model_class: Model class; it is instantiated with the saved ``model_params`` as keyword arguments
    :param filename: Path to the checkpoint (as written by ``save_model``)
    :return: The model with the saved weights loaded

    """
    # Models trained on a GPU are mapped to the CPU when no GPU is available.
    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
    map_location = None if torch.cuda.is_available() else 'cpu'

    state = torch.load(str(filename), map_location=map_location)

    model = model_class(**state['model_params'])
    model.load_state_dict(state['model_state'])

    return model


def save_model(model, filename, model_params=None):
    """
    Save the model configuration parameters and the weights to the file

    Missing parent directories of ``filename`` are created.

    :param model: A trained model (a ``torch.nn.DataParallel`` wrapper is unwrapped first)
    :param filename: Path to the checkpoint
    :param model_params:  (Default value = None) A dictionary of model configuration parameters

    """
    if isinstance(model, torch.nn.DataParallel):
        model = model.module

    state = {
        'model_params': model_params or {},
        'model_state': model.state_dict(),
    }

    dirname = os.path.dirname(str(filename))
    if dirname:
        os.makedirs(dirname, exist_ok=True)

    torch.save(state, str(filename))

In [None]:
# Size of the GRU hidden states (encoder and decoder)
hidden_size = 256
# Fine-tune the pretrained embeddings during training instead of freezing them
trainable_embeddings = True
# Probability, per decoding step, of feeding the ground-truth token instead of the model's own prediction
teacher_forcing = 0.5

In [None]:
# Constructor arguments of the model. They are stored in every checkpoint so that load_model can
# re-create the model. W_emb is not included: the (trained) embedding weights are part of the state dict.
model_params = dict(
    hidden_size=hidden_size,
    trainable_embeddings=trainable_embeddings,
    teacher_forcing=teacher_forcing,

    vocab_size = len(dataset.vocab),
    embedding_size=W_emb.shape[1],
    max_len=dataset.max_len,
    start_index=dataset.vocab[Vocab.START_TOKEN],
    end_index=dataset.vocab[Vocab.END_TOKEN], 
    pad_index=dataset.vocab[Vocab.PAD_TOKEN],
)

In [None]:
# The attention variant is the one trained below; W_emb initialises its embedding layer
model = Seq2SeqAttentionModel(**model_params, W_emb=W_emb)

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU
model = model.to('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Display the module structure of the model
model

## Training

In [None]:
# Checkpoint path; it is overwritten after every training epoch
model_filename = 'tmp/seq2seq_dialog_att.pt'

In [None]:
# Number of (query, response) pairs per optimisation step
batch_size=128
# NOTE(review): each epoch goes over the whole dataset; 500 epochs over ~1M pairs is a very long run
nb_epochs = 500
# Adam learning rate and L2 weight decay
learning_rate=0.001
weight_decay = 0.000001

In [None]:
# Batches of (query, response) id tensors, drawn in a new random order every epoch
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=batch_size)

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,weight_decay=weight_decay)
# No ignore_index: PAD positions of the response also count in the loss, so the model learns to emit PAD after END
criterion = torch.nn.CrossEntropyLoss()

In [None]:
# Train on whatever device the model's parameters live on (GPU or CPU)
device = next(model.parameters()).device
losses = []

for epoch in tqdm_notebook(range(nb_epochs), desc='Epochs'):
    model.train()
    epoch_losses = []
    for query, response in tqdm_notebook(dataloader, desc='Iteration', leave=False):
        optimizer.zero_grad()

        query = query.to(device)
        response = response.to(device)

        # Passing the true response enables teacher forcing in the decoder
        response_logits = model(query, response)

        # Flatten (batch, max_len, vocab) logits and (batch, max_len) targets for the cross entropy loss
        loss = criterion(response_logits.reshape(-1, response_logits.size(-1)), response.reshape(-1))

        loss.backward()
        optimizer.step()

        epoch_losses.append(loss.item())

    epoch_loss = np.mean(epoch_losses)

    losses.append(epoch_loss)
    print('Epoch {}, loss {}'.format(epoch, epoch_loss))

    # Checkpoint after every epoch so that training can be interrupted at any point
    save_model(model, model_filename, model_params=model_params)

### Plot the training loss

In [None]:
fig, ax = plt.subplots()

# One mean training loss per epoch
ax.plot(np.arange(len(losses)), losses)

ax.set_title('Training loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')

fig.tight_layout()

## Try some inputs

In [None]:
# The checkpoint at model_filename was written while training Seq2SeqAttentionModel; its state dict
# contains the attention layers, so it must be loaded with that same class.
model = load_model(Seq2SeqAttentionModel, model_filename)
model = model.to('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
def generate_response(query, model, dataset):
    """
    Generate a response from the model for a given query (greedy decoding)

    :param query: Query to generate the response to: a raw string (tokenized with ``word_tokenize``) or a list of tokens
    :param model: A trained model
    :param dataset: The dataset object for pre-processing (provides ``_process_sent`` and ``vocab``)
    :return: The response as a space-separated string, cut at the first END token and without special tokens

    """
    if not isinstance(query, list):
        query = word_tokenize(query)

    query_ids = dataset._process_sent(query)
    # Run on the device of the model's parameters instead of assuming a GPU
    device = next(model.parameters()).device
    query_tensor = torch.as_tensor(query_ids, device=device).view(1, -1)

    # Inference without teacher forcing or gradient tracking; the previous train/eval mode is restored
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            response_logits = model(query_tensor).squeeze(0)
    finally:
        model.train(was_training)

    response_indices = response_logits.argmax(dim=-1).cpu().tolist()

    # Anything generated after the first END token is not part of the answer
    end_index = getattr(model, 'end_index', None)
    if end_index in response_indices:
        response_indices = response_indices[:response_indices.index(end_index)]

    special_tokens = set(dataset.vocab.special_tokens or [])
    response_tokens = [dataset.vocab.id2token[idx] for idx in response_indices]
    return ' '.join(t for t in response_tokens if t not in special_tokens)

In [None]:
# Example query for the trained model
query = 'How are you?'

In [None]:
# Greedy (argmax) response of the model to the example query
generate_response(query, model, dataset)

In [None]:
def print_response(text_widget):
    """
    Widget callback: answer the text currently typed in a Jupyter text widget and print ``query -> response``

    :param text_widget: Jupyter text input widget whose ``value`` holds the query

    """
    user_query = text_widget.value
    print(f'{user_query} -> {generate_response(user_query, model, dataset)}')

In [None]:
# Text box whose submitted text is answered by the model through print_response
text_input = widgets.Text(value='How are you?')
text_input.on_submit(print_response)

In [None]:
# Display the widget
text_input