In [1]:
import os
import csv
import shutil
import zipfile
import gzip
import pickle
import itertools
import urllib.parse
import urllib.request
from collections import Counter
import functools

import numpy as np
import torch
import torch.utils.data
from nltk import word_tokenize
import matplotlib.pyplot as plt
import ipywidgets as widgets
from tqdm import tqdm_notebook

In [2]:
plt.style.use('fivethirtyeight')

## Introduction

In this tutorial, we will train a simple recurrent sequence-to-sequence dialog model using the [OpenSubtitles](http://opus.nlpl.eu/OpenSubtitles.php) dataset. This dataset contains already tokenized subtitles collected from http://www.opensubtitles.org/.

## Dataset

As in the previous tutorial, we are going to use a `Vocab` (vocabulary) class, and a subclass of the `torch.utils.data.Dataset` class.

In [3]:
def maybe_download_and_unzip_file(file_url, file_name=None):
    """Download ``file_url`` to ``file_name`` unless that file already exists.

    Args:
        file_url: URL to fetch (anything ``urllib.request.urlopen`` accepts).
        file_name: Local target path. Defaults to the basename of ``file_url``.

    Side effects:
        Prints progress messages. After a fresh download, a file whose
        extension is ``.zip`` is extracted into the current directory. Other
        archive types (e.g. ``.gz``) are left untouched.

    Raises:
        Whatever ``urlopen`` / file I/O raises. In that case no file is left at
        ``file_name``, so the next call retries the download.
    """
    if file_name is None:
        file_name = os.path.basename(file_url)

    if os.path.exists(file_name):
        print(f'Exists: {file_name}')
        return

    print(f'Downloading: {file_name}')

    # Stream into a temporary sibling file and rename only on success, so an
    # interrupted download is never mistaken for a complete one on the next run.
    partial_name = f'{file_name}.part'
    try:
        with urllib.request.urlopen(file_url) as response, open(partial_name, 'wb') as target_file:
            shutil.copyfileobj(response, target_file)
        os.replace(partial_name, file_name)
    finally:
        # Only still present when the download failed before the rename.
        if os.path.exists(partial_name):
            os.remove(partial_name)

    print(f'Downloaded: {file_name}')

    file_extension = os.path.splitext(file_name)[1]
    if file_extension == '.zip':
        print(f'Extracting zip: {file_name}')
        with zipfile.ZipFile(file_name, 'r') as zip_file:
            zip_file.extractall('.')

In [4]:
# Download endpoint for the English monolingual OpenSubtitles v2018 dump (gzip),
# and the local file name it is stored under.
dataset_url = 'http://opus.nlpl.eu/download.php?f=OpenSubtitles/v2018/mono/OpenSubtitles.en.gz'
dataset_filename = 'OpenSubtitles.en.gz'

In [5]:
maybe_download_and_unzip_file(dataset_url, dataset_filename)

Exists: OpenSubtitles.en.gz


In [6]:
class Vocab(object):
    """Token <-> integer id vocabulary with per-token frequency counts.

    Ids are handed out in first-seen order. Special tokens passed to the
    constructor are registered first (each with a count of 1) and are always
    kept by ``prune_vocab``.

    Attributes:
        special_tokens: The sequence given to the constructor (may be None).
        token2id: dict mapping token -> id.
        id2token: dict mapping id -> token (inverse of ``token2id``).
        token_counts: ``collections.Counter`` of token occurrences.
    """
    END_TOKEN = '<end>'
    START_TOKEN = '<start>'
    PAD_TOKEN = '<pad>'
    UNK_TOKEN = '<unk>'

    def __init__(self, special_tokens=None):
        super().__init__()

        self.special_tokens = special_tokens

        self.token2id = {}
        self.id2token = {}

        self.token_counts = Counter()

        if self.special_tokens is not None:
            self.add_document(self.special_tokens)

    def add_document(self, document, rebuild=True):
        """Count every token of ``document`` and assign ids to unseen tokens.

        Args:
            document: Iterable of tokens.
            rebuild: When True, refresh ``id2token`` afterwards. Pass False when
                adding many documents and rebuild once at the end.
        """
        for token in document:
            self.token_counts[token] += 1

            if token not in self.token2id:
                self.token2id[token] = len(self.token2id)

        if rebuild:
            self._rebuild_id2token()

    def add_documents(self, documents):
        """Add several documents, rebuilding ``id2token`` only once."""
        for doc in documents:
            self.add_document(doc, rebuild=False)

        self._rebuild_id2token()

    def prune_vocab(self, max_size):
        """Keep the special tokens plus the most frequent regular tokens.

        The retained regular tokens are those among the ``max_size`` most common
        entries of ``token_counts`` that are not special tokens. Ids are
        reassigned: special tokens first (in their original order), then the
        retained tokens in decreasing frequency. The ordering is deterministic,
        so a vocabulary rebuilt from the same data gets the same ids on every run.

        Args:
            max_size: Number of most common tokens to consider.
        """
        nb_tokens_before = len(self.token2id)

        # Tolerate a vocabulary created without special tokens.
        special_tokens = list(self.special_tokens or [])
        special_set = set(special_tokens)

        kept_tokens = [
            token
            for token, _ in self.token_counts.most_common(max_size)
            if token not in special_set
        ]
        kept_set = special_set.union(kept_tokens)

        # Drop the counts of every token that did not survive.
        for token in list(self.token_counts):
            if token not in kept_set:
                del self.token_counts[token]

        self.token2id = {}
        for token in special_tokens + kept_tokens:
            if token not in self.token2id:
                self.token2id[token] = len(self.token2id)

        self._rebuild_id2token()

        nb_tokens_after = len(self.token2id)

        print(f'Vocab pruned: {nb_tokens_before} -> {nb_tokens_after}')

    def _rebuild_id2token(self):
        """Recompute ``id2token`` as the inverse of ``token2id``."""
        self.id2token = {i: t for t, i in self.token2id.items()}

    def get(self, item, default=None):
        """Return the id of ``item``, or ``default`` when it is unknown."""
        return self.token2id.get(item, default)

    def __getitem__(self, item):
        return self.token2id[item]

    def __contains__(self, item):
        return item in self.token2id

    def __len__(self):
        return len(self.token2id)

    def __str__(self):
        return f'{len(self)} tokens'

    def save(self, filename):
        """Write the vocabulary to a CSV file, one row per token in id order.

        Columns: ``token``, ``counts``, ``is_special`` (1 or 0). Missing parent
        directories are created.
        """
        special_set = set(self.special_tokens or [])

        directory = os.path.dirname(filename)
        if directory:
            os.makedirs(directory, exist_ok=True)

        # newline='' is what the csv module requires for correct quoting and
        # line endings; utf-8 keeps non-ASCII tokens portable across platforms.
        with open(filename, 'w', newline='', encoding='utf-8') as csv_file:
            writer = csv.DictWriter(csv_file, fieldnames=['token', 'counts', 'is_special'])
            writer.writeheader()
            for idx in range(len(self.token2id)):
                token = self.id2token[idx]
                is_special = 1 if token in special_set else 0
                writer.writerow({'token': token, 'counts': self.token_counts[token], 'is_special': is_special})

    @staticmethod
    def load(filename):
        """Rebuild a ``Vocab`` from a CSV file produced by ``save``.

        Row order defines the ids, so the loaded ids match the saved ones.
        """
        token2id = {}
        tokens_counts = {}
        special_tokens = []

        with open(filename, 'r', newline='', encoding='utf-8') as csv_file:
            reader = csv.DictReader(csv_file)
            for i, row in enumerate(reader):
                token2id[row['token']] = i
                tokens_counts[row['token']] = int(row['counts'])
                if bool(int(row['is_special'])):
                    special_tokens.append(row['token'])

        # Start from an empty vocab so the special tokens are not counted again.
        vocab = Vocab()
        vocab.token2id = token2id
        vocab.token_counts = Counter(tokens_counts)
        vocab.special_tokens = special_tokens
        vocab._rebuild_id2token()

        return vocab

There are two important differences with the previous tutorial. First, notice how we form `<query>,<response>` pairs from a sequence of subtitles. Next, since the vocabulary can be quite large, we prune it to contain only the top 50,000 most common tokens.

In [7]:
class SubtitlesDialogDataset(torch.utils.data.Dataset):
    """Dataset of ``(query, response)`` pairs built from consecutive subtitle lines.

    Item ``i`` is the pair ``(line i, line i + 1)``. Each sentence is truncated,
    terminated with ``Vocab.END_TOKEN``, padded with ``Vocab.PAD_TOKEN`` to
    ``max_len`` and returned as an ``int64`` NumPy array of token ids.

    Args:
        filename: Path to a gzip file with one UTF-8 encoded subtitle per line.
        vocab: Existing vocabulary to reuse. When None, one is built from the
            loaded lines and pruned to ``max_vocab_size``.
        max_lines: Maximum number of lines to read (None reads the whole file).
        max_len: Fixed length of every encoded sentence, END token included.
            When None it is set to the longest line plus one, so the END token
            always fits without truncating any line.
        max_vocab_size: Size passed to ``Vocab.prune_vocab`` for a new vocab.
    """

    def __init__(self, filename, vocab=None, max_lines = 1000, max_len=50, max_vocab_size=50000):

        self.lines = []
        with gzip.open(filename, 'rb') as f:
            for i, line in enumerate(f):
                if max_lines is not None and i >= max_lines:
                    break

                tokens = word_tokenize(line.decode('utf-8'))
                self.lines.append(tokens)

        # Number of lines actually available (the file may be shorter than max_lines).
        self.max_lines = len(self.lines)

        if vocab is None:
            vocab = Vocab(special_tokens=[Vocab.PAD_TOKEN, Vocab.START_TOKEN, Vocab.END_TOKEN, Vocab.UNK_TOKEN])
            vocab.add_documents(self.lines)
            vocab.prune_vocab(max_vocab_size)

            print(f'Created vocab: {vocab}')

        if max_len is None:
            # +1 leaves room for the END token appended to every sentence.
            max_len = max((len(tokens) for tokens in self.lines), default=0) + 1
            print(f'Calculed max len: {max_len}')

        self.vocab = vocab
        self.max_len = max_len

    def _pad_sentence(self, sent):
        """Truncate to ``max_len - 1`` tokens, append END, then pad to ``max_len``."""
        sent = sent[:self.max_len - 1] + [Vocab.END_TOKEN,]

        nb_pad = self.max_len - len(sent)
        sent = sent + [Vocab.PAD_TOKEN,] * nb_pad

        return sent

    # Original (misspelled) name kept so existing callers keep working.
    _pad_sentnece = _pad_sentence

    def _process_sent(self, sent):
        """Pad a token list and map it to an ``int64`` array of vocabulary ids."""
        sent = self._pad_sentence(sent)
        unk_id = self.vocab[Vocab.UNK_TOKEN]
        sent = [self.vocab[t] if t in self.vocab else unk_id for t in sent]

        # Explicit int64: `np.long` is not available in every NumPy release, and
        # int64 is what torch converts to a LongTensor.
        return np.array(sent, dtype=np.int64)

    def __getitem__(self, index):
        nb_pairs = len(self)
        # Support negative indices without wrapping to a non-adjacent pair.
        if index < 0:
            index += nb_pairs
        if not 0 <= index < nb_pairs:
            raise IndexError('dataset index out of range')

        query = self._process_sent(self.lines[index])
        response = self._process_sent(self.lines[index + 1])

        return query, response

    def __len__(self):
        # N lines give N - 1 consecutive pairs; never report a negative length.
        return max(self.max_lines - 1, 0)

In [8]:
dataset = SubtitlesDialogDataset(dataset_filename, max_lines=1000000)

Vocab pruned: 57727 -> 50000
Created vocab: 50000 tokens


In [9]:
len(dataset.lines)

1000000

In [10]:
len(dataset)

999999

### Save the vocab

In [11]:
vocab_filename = 'tmp/seq2seq_dialog.vocab.csv'

In [12]:
dataset.vocab.save(vocab_filename)

## Word embeddings

We are going to use the same word embeddings as in the previous tutorial.

In [13]:
embeddings_url = 'https://mednli.blob.core.windows.net/shared/word_embeddings/crawl-300d-2M.pickled'
embeddings_filename = 'crawl-300d-2M.pickled'

In [14]:
maybe_download_and_unzip_file(embeddings_url, embeddings_filename)

Exists: crawl-300d-2M.pickled


In [15]:
# NOTE: unpickling can execute arbitrary code embedded in the file, so only
# load an embeddings file obtained from a trusted source.
with open(embeddings_filename, 'rb') as pkl_file:
    word_embeddings = pickle.load(pkl_file)

### Embedding matrix

In [16]:
def create_embeddings_matrix(word_embeddings, vocab):
    """Build the initial embedding matrix for ``vocab``.

    Row ``i`` holds the vector for ``vocab.id2token[i]``:

    * the PAD token gets an all-zero vector;
    * the UNK token gets a random vector drawn uniformly from [-0.3, 0.3);
    * tokens found in ``word_embeddings`` get their pretrained vector;
    * every other token (including any remaining special token) gets its own
      random vector from the same range and is counted as unknown.

    Args:
        word_embeddings: Mapping token -> 1-D vector; all vectors are expected
            to share the same length.
        vocab: Object exposing ``id2token`` and ``__len__``.

    Returns:
        ``float32`` array of shape ``(len(vocab), embedding_size)``. The number
        of unknown tokens is printed.
    """
    # Peek at a single entry for the dimensionality instead of materialising a
    # list of every key, which is costly for multi-million-word embeddings.
    first_token = next(iter(word_embeddings.keys()))
    embedding_size = word_embeddings[first_token].shape[0]

    W_emb = np.zeros((len(vocab), embedding_size), dtype=np.float32)

    special_tokens = {
        t: np.random.uniform(-0.3, 0.3, (embedding_size,))
        for t in (Vocab.UNK_TOKEN, )
    }
    special_tokens[Vocab.PAD_TOKEN] = np.zeros((embedding_size,))

    nb_unk = 0
    for i, t in vocab.id2token.items():
        if t in special_tokens:
            W_emb[i] = special_tokens[t]
        elif t in word_embeddings:
            W_emb[i] = word_embeddings[t]
        else:
            W_emb[i] = np.random.uniform(-0.3, 0.3, embedding_size)
            nb_unk += 1

    print(f'Nb unk: {nb_unk}')

    return W_emb

In [17]:
W_emb = create_embeddings_matrix(word_embeddings, dataset.vocab)

Nb unk: 5182


## Model

We are going to use a standard seq2seq model. Given an input query (sentence), the model produces a response. Although this model does not have any context information, it provides a good starting point. 

In [18]:
class Seq2SeqModel(torch.nn.Module):
    """GRU encoder / GRUCell decoder sequence-to-sequence model without attention.

    The encoder state at the last non-padding position of each input becomes the
    initial decoder state. The decoder always runs for ``max_len`` steps and
    returns the logits of every step.

    Args:
        vocab_size: Number of rows of the embedding and size of the output layer.
        embedding_size: Dimensionality of the token embeddings.
        hidden_size: Hidden size of the encoder GRU and the decoder GRUCell.
        teacher_forcing: Per-step probability of feeding the target token
            (when targets are given) instead of the model's own prediction.
        max_len: Number of decoding steps.
        trainable_embeddings: When False, the embedding weights are frozen.
        start_index: Id of the token fed to the decoder at the first step.
        end_index: Id of the end token (stored, not used by this class).
        pad_index: Id of the padding token.
        W_emb: Optional NumPy array copied into the embedding weights.
    """

    def __init__(self, vocab_size, embedding_size, hidden_size, teacher_forcing,
                 max_len, trainable_embeddings, start_index, end_index, pad_index, W_emb=None):

        super().__init__()

        # Decoding settings and special token ids.
        self.teacher_forcing = teacher_forcing
        self.max_len = max_len
        self.start_index = start_index
        self.end_index = end_index
        self.pad_index = pad_index

        # Embedding table shared by the encoder and the decoder.
        self.embedding = torch.nn.Embedding(vocab_size, embedding_size, padding_idx=pad_index)
        if W_emb is not None:
            pretrained = torch.from_numpy(W_emb)
            self.embedding.weight.data.copy_(pretrained)
        if not trainable_embeddings:
            self.embedding.weight.requires_grad = False

        self.encoder = torch.nn.GRU(embedding_size, hidden_size, batch_first=True)
        self.decoder = torch.nn.GRUCell(embedding_size, hidden_size)
        self.decoder_projection = torch.nn.Linear(hidden_size, vocab_size)

    def encode(self, inputs):
        """Return the encoder output at the last non-padding step of each row."""
        nb_rows = inputs.size(0)
        not_padding = inputs != self.pad_index
        lengths = torch.sum(not_padding, dim=1).long()

        encoder_outputs, _ = self.encoder(self.embedding(inputs))

        # One (row, last valid position) pair per batch element.
        return encoder_outputs[np.arange(nb_rows), lengths - 1]

    def decode(self, decoder_hidden, targets=None):
        """Unroll the decoder for ``max_len`` steps and stack the step logits."""
        step_logits = []
        previous_tokens = torch.full_like(decoder_hidden[:, 0], self.start_index).long()

        for step in range(self.max_len):
            decoder_hidden = self.decoder(self.embedding(previous_tokens), decoder_hidden)
            logits = self.decoder_projection(decoder_hidden)

            # A random draw happens on every step, even without targets.
            feed_target = np.random.rand() < self.teacher_forcing and targets is not None
            if feed_target:
                previous_tokens = targets[:, step]
            else:
                previous_tokens = logits.argmax(dim=1).long()

            step_logits.append(logits)

        return torch.stack(step_logits, dim=1)

    def forward(self, inputs, targets=None):
        """Encode ``inputs`` and decode a response; see ``decode`` for ``targets``."""
        return self.decode(self.encode(inputs), targets)

In [19]:
def softmax_masked(inputs, mask, dim=1, epsilon=0.000001):
    """Softmax over ``dim`` restricted to the positions where ``mask`` is non-zero.

    Positions with a zero mask get weight 0. A slice that is masked out
    completely yields all zeros (``epsilon`` prevents a division by zero).

    The scores are shifted by the per-slice maximum of the unmasked entries
    before exponentiation. Without the shift, large scores overflow to ``inf``
    and the result becomes NaN.

    Args:
        inputs: Tensor of unnormalised scores.
        mask: Tensor broadcastable to ``inputs``; non-zero marks valid positions
            and its values also scale the exponentials, as a weight.
        dim: Dimension to normalise over.
        epsilon: Small constant added to the denominator.

    Returns:
        Tensor with the shape of ``inputs`` holding the masked softmax weights.
    """
    masked_out = mask == 0
    neg_inf = float('-inf')

    # Masked positions become -inf so they neither affect the max nor overflow.
    masked_inputs = inputs.masked_fill(masked_out, neg_inf)

    max_values = masked_inputs.max(dim=dim, keepdim=True)[0]
    # Fully masked slices have max == -inf; shift by 0 there to avoid inf - inf.
    max_values = torch.where(max_values == neg_inf, torch.zeros_like(max_values), max_values)

    inputs_exp = torch.exp(masked_inputs - max_values.detach())
    inputs_exp = inputs_exp * mask.float()
    inputs_exp_sum = inputs_exp.sum(dim=dim, keepdim=True)
    inputs_attention = inputs_exp / (inputs_exp_sum + epsilon)

    return inputs_attention

In [20]:
class Seq2SeqAttentionModel(torch.nn.Module):
    """GRU encoder / GRUCell decoder seq2seq model with additive attention.

    At every decoding step the previous decoder state attends over all encoder
    outputs (padding positions are masked). The attended context is concatenated
    with the decoder state, projected back to ``hidden_size`` and used as the
    hidden state fed to the decoder cell. The decoder starts from a zero state
    and always runs for ``max_len`` steps.

    Args:
        vocab_size: Number of rows of the embedding and size of the output layer.
        embedding_size: Dimensionality of the token embeddings.
        hidden_size: Hidden size of the encoder, decoder and attention layers.
        teacher_forcing: Per-step probability of feeding the target token
            (when targets are given) instead of the model's own prediction.
        max_len: Number of decoding steps.
        trainable_embeddings: When False, the embedding weights are frozen.
        start_index: Id of the token fed to the decoder at the first step.
        end_index: Id of the end token (stored, not used by this class).
        pad_index: Id of the padding token.
        W_emb: Optional NumPy array copied into the embedding weights.
    """

    def __init__(self, vocab_size, embedding_size, hidden_size, teacher_forcing,
                 max_len,trainable_embeddings, start_index, end_index, pad_index, W_emb=None):

        super().__init__()

        self.teacher_forcing = teacher_forcing
        self.max_len = max_len
        self.start_index = start_index
        self.end_index = end_index
        self.pad_index = pad_index

        self.embedding = torch.nn.Embedding(vocab_size, embedding_size, padding_idx=pad_index)
        if W_emb is not None:
            self.embedding.weight.data.copy_(torch.from_numpy(W_emb))
        if not trainable_embeddings:
            self.embedding.weight.requires_grad = False

        self.encoder = torch.nn.GRU(embedding_size, hidden_size, batch_first=True)
        self.decoder = torch.nn.GRUCell(embedding_size, hidden_size)

        # Additive attention: score = reduce(tanh(W_enc * h_enc + W_dec * h_dec)).
        self.attention_decoder = torch.nn.Linear(hidden_size, hidden_size)
        self.attention_encoder = torch.nn.Linear(hidden_size, hidden_size)
        self.attention_reduce = torch.nn.Linear(hidden_size, 1, bias=False)
        self.decoder_hidden_combine = torch.nn.Linear(hidden_size * 2, hidden_size)

        self.decoder_projection = torch.nn.Linear(hidden_size, vocab_size)

    def encode(self, inputs):
        """Run the encoder.

        Returns:
            Tuple ``(outputs, inputs_mask)``: the encoder outputs for every
            position, and a long tensor that is 1 at non-padding positions of
            ``inputs`` and 0 elsewhere.
        """
        inputs_mask = (inputs != self.pad_index).long()

        inputs_emb = self.embedding(inputs)
        outputs, _ = self.encoder(inputs_emb)

        return outputs, inputs_mask

    def decode(self, encoder_hiddens, inputs_mask, targets=None):
        """Unroll the attentive decoder for ``max_len`` steps.

        Args:
            encoder_hiddens: Encoder outputs as returned by ``encode``.
            inputs_mask: Mask of valid encoder positions as returned by ``encode``.
            targets: Optional target ids used for teacher forcing.

        Returns:
            Logits of every step stacked along dimension 1.
        """
        outputs_logits = []
        decoder_hidden = torch.zeros_like(encoder_hiddens[:, 0, :])
        decoder_inputs = torch.full_like(decoder_hidden[:, 0], self.start_index).long()

        # The encoder side of the attention score does not depend on the
        # decoding step, so project it once instead of on every iteration.
        att_enc = self.attention_encoder(encoder_hiddens)

        for i in range(self.max_len):
            decoder_inputs_emb = self.embedding(decoder_inputs)

            att_dec = self.attention_decoder(decoder_hidden)
            att = torch.tanh(att_enc + att_dec.unsqueeze(1))
            att_reduced = self.attention_reduce(att).squeeze(-1)
            att_normalized = softmax_masked(att_reduced, inputs_mask)

            # Attention-weighted sum of the encoder outputs.
            decoder_hidden_att = torch.sum(encoder_hiddens * att_normalized.unsqueeze(-1), dim=1)
            decoder_hidden_combined = self.decoder_hidden_combine(
                torch.cat([decoder_hidden, decoder_hidden_att], dim=-1)
            )

            decoder_hidden = self.decoder(decoder_inputs_emb, decoder_hidden_combined)

            decoder_output_logit = self.decoder_projection(decoder_hidden)

            # A random draw happens on every step, even without targets.
            if np.random.rand() < self.teacher_forcing and targets is not None:
                decoder_inputs = targets[:, i]
            else:
                decoder_inputs = decoder_output_logit.argmax(dim=1).long()

            outputs_logits.append(decoder_output_logit)

        outputs_logits = torch.stack(outputs_logits, dim=1)

        return outputs_logits

    def forward(self, inputs, targets=None):
        """Encode ``inputs`` and decode a response; see ``decode`` for ``targets``."""
        encoder_hiddens, inputs_mask = self.encode(inputs)
        outputs_logits = self.decode(encoder_hiddens, inputs_mask, targets)

        return outputs_logits

Below are some helper functions to save and load the weights of the model. Feel free to use them in your projects!

In [21]:
def load_model(model_class, filename):
    """Recreate a model from a checkpoint written by ``save_model``.

    The checkpoint is expected to hold ``'model_params'`` (keyword arguments
    for ``model_class``) and ``'model_state'`` (its state dict).

    Args:
        model_class: Class to instantiate with the stored parameters.
        filename: Path of the checkpoint file.

    Returns:
        A ``model_class`` instance with the stored weights loaded.
    """
    # Without CUDA, keep every storage where it was deserialised (the CPU) so
    # that checkpoints trained on a GPU can still be loaded.
    if torch.cuda.is_available():
        map_location = None
    else:
        map_location = lambda storage, location: storage

    checkpoint = torch.load(str(filename), map_location=map_location)

    restored = model_class(**checkpoint['model_params'])
    restored.load_state_dict(checkpoint['model_state'])

    return restored


def save_model(model, filename, model_params=None):
    """Save a model's weights together with its constructor arguments.

    The checkpoint is a dict with ``'model_params'`` and ``'model_state'``,
    which is the layout ``load_model`` reads.

    Args:
        model: Module to save. A ``torch.nn.DataParallel`` wrapper is unwrapped
            first so the stored keys carry no ``module.`` prefix.
        filename: Target path. Missing parent directories are created, so a
            long training run does not fail on its first checkpoint.
        model_params: Keyword arguments needed to rebuild the model
            (stored as ``{}`` when omitted).
    """
    if isinstance(model, torch.nn.DataParallel):
        model = model.module

    state = {
        'model_params': model_params or {},
        'model_state': model.state_dict(),
    }

    target = str(filename)
    directory = os.path.dirname(target)
    if directory:
        os.makedirs(directory, exist_ok=True)

    torch.save(state, target)

In [22]:
# Model hyper-parameters.
hidden_size = 256  # hidden size of the encoder GRU, decoder GRUCell and attention layers
trainable_embeddings = True  # False freezes the (pretrained) embedding weights
teacher_forcing = 0.5  # per-step probability of feeding the target token to the decoder

In [23]:
# Keyword arguments for the model constructor. They are kept in a dict so they
# can be stored in the checkpoint by `save_model` and reused by `load_model`.
# `W_emb` is not part of it: it is passed separately when the model is created.
model_params = dict(
    hidden_size=hidden_size,
    trainable_embeddings=trainable_embeddings,
    teacher_forcing=teacher_forcing,

    vocab_size = len(dataset.vocab),
    embedding_size=W_emb.shape[1],
    max_len=dataset.max_len,
    start_index=dataset.vocab[Vocab.START_TOKEN],
    end_index=dataset.vocab[Vocab.END_TOKEN], 
    pad_index=dataset.vocab[Vocab.PAD_TOKEN],
)

In [24]:
model = Seq2SeqAttentionModel(**model_params, W_emb=W_emb)

In [25]:
model = model.to('cuda')

In [26]:
model

Seq2SeqAttentionModel(
  (embedding): Embedding(50000, 300, padding_idx=0)
  (encoder): GRU(300, 256, batch_first=True)
  (decoder): GRUCell(300, 256)
  (attention_decoder): Linear(in_features=256, out_features=256, bias=True)
  (attention_encoder): Linear(in_features=256, out_features=256, bias=True)
  (attention_reduce): Linear(in_features=256, out_features=1, bias=False)
  (decoder_hidden_combine): Linear(in_features=512, out_features=256, bias=True)
  (decoder_projection): Linear(in_features=256, out_features=50000, bias=True)
)

## Training

In [27]:
model_filename = 'tmp/seq2seq_dialog_att.pt'

In [28]:
# Optimisation settings used by the DataLoader, the Adam optimizer and the
# training loop below.
batch_size=128
nb_epochs = 500
learning_rate=0.001
weight_decay = 0.000001

In [29]:
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=batch_size)

In [30]:
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,weight_decay=weight_decay)
# Created with default arguments (no `ignore_index`), so every target position,
# including the `<pad>` positions of the padded responses, contributes to the loss.
criterion = torch.nn.CrossEntropyLoss()

In [None]:
# Mean training loss of every completed epoch (plotted further down).
losses = []

# Train on whichever device the model already lives on instead of assuming CUDA.
device = next(model.parameters()).device

for epoch in tqdm_notebook(range(nb_epochs), desc='Epochs'):
    # Training mode only needs to be set once per epoch, not once per batch.
    model.train()

    epoch_losses = []
    for query, response in tqdm_notebook(dataloader, desc='Iteration', leave=False):
        optimizer.zero_grad()

        query = query.to(device)
        response = response.to(device)

        # Passing the targets enables teacher forcing inside the decoder.
        response_logits = model(query, response)

        # reshape for the cross entropy loss: (batch * steps, vocab) vs (batch * steps,)
        response_logits = response_logits.view(-1, response_logits.size(2))
        response = response.view(-1)
        loss = criterion(response_logits, response)

        loss.backward()
        optimizer.step()

        epoch_losses.append(loss.item())

    epoch_loss = np.mean(epoch_losses)

    losses.append(epoch_loss)
    print('Epoch {}, loss {}'.format(epoch, epoch_loss))

    # Checkpoint after every epoch so an interrupted run keeps its progress.
    save_model(model, model_filename, model_params=model_params)

HBox(children=(IntProgress(value=0, description='Epochs', max=500, style=ProgressStyle(description_width='init…

HBox(children=(IntProgress(value=0, description='Iteration', max=7813, style=ProgressStyle(description_width='…

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Epoch 0, loss 0.8906596688551068


HBox(children=(IntProgress(value=0, description='Iteration', max=7813, style=ProgressStyle(description_width='…

### Plot the training loss

In [None]:
# Training-loss curve: one point per completed epoch.
fig, ax = plt.subplots()

ax.plot(np.arange(len(losses)), losses)
ax.set(xlabel='Epoch', ylabel='Loss')

fig.tight_layout()

## Try some inputs

In [None]:
# The checkpoint at `model_filename` was written from a `Seq2SeqAttentionModel`,
# so it has to be restored with that class: its state dict contains attention
# layers that `Seq2SeqModel` does not define.
model = load_model(Seq2SeqAttentionModel, model_filename)
# Use the GPU when there is one, otherwise stay on the CPU.
model = model.to('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
def generate_response(query, model, dataset):
    """Generate a response to ``query`` by greedy decoding.

    Args:
        query: Raw text (tokenised with ``word_tokenize``) or a list of tokens.
        model: Seq2seq model; called with a batch holding the single query.
        dataset: Provides ``_process_sent`` to encode the query and ``vocab``
            to map the predicted ids back to tokens.

    Returns:
        The predicted tokens joined by spaces. Decoding output is cut at the
        first end token and the remaining special tokens are dropped.
    """
    if not isinstance(query, list):
        query = word_tokenize(query)

    # Run on the device the model lives on instead of assuming CUDA.
    first_parameter = next(model.parameters(), None)
    device = first_parameter.device if first_parameter is not None else torch.device('cpu')

    query = dataset._process_sent(query)
    query = torch.tensor(query).to(device)

    # Inference only: no need to build the autograd graph.
    with torch.no_grad():
        response_logits = model(query.view(1, -1)).squeeze(0)
    response_indices = response_logits.argmax(dim=-1).cpu().numpy()

    vocab = dataset.vocab
    special_tokens = set(vocab.special_tokens or ())
    end_token = getattr(vocab, 'END_TOKEN', None)

    response = []
    for idx in response_indices:
        token = vocab.id2token[int(idx)]
        # Everything predicted after the end token is not part of the answer.
        if end_token is not None and token == end_token:
            break
        if token not in special_tokens:
            response.append(token)

    return ' '.join(response)

In [None]:
query = 'How are you?'

In [None]:
generate_response(query, model, dataset)

In [None]:
def print_response(text_widget):
    """Widget callback: answer the widget's current text and print the exchange.

    Uses the notebook-level ``model`` and ``dataset``.
    """
    user_query = text_widget.value
    answer = generate_response(user_query, model, dataset)
    print(f'{user_query} -> {answer}')

In [None]:
# Text box pre-filled with an example query; submitting it calls `print_response`.
text_input = widgets.Text(value='How are you?')
# NOTE(review): `on_submit` appears to be deprecated in recent ipywidgets
# releases — confirm against the installed version.
text_input.on_submit(print_response)

In [None]:
text_input