# Morphological Reinflection with Encoder-Decoder

Morphological reinflection is the task of generating a target form given a source form, a source tag and a target tag.

In this lab, we address this task, which is defined as:
```
Given an inflected form, its current tag, and a target tag, generate the target inflected form.

English example:

Source tag: Past 
Source form: ran 
Target tag: Present participle

Output: running
```

## Morphological Reinflection with Seq2Seq

In this lab, we cover one of the breakthrough methods for this task: the attention-based Seq2Seq network
used by [Kann et al.](http://anthology.aclweb.org/P16-2090) in 2016 to win the SIGMORPHON 2016 shared task.

In subsequent years, multiple variants have been built on top of the vanilla Seq2Seq model.

In [3]:
import numpy as np
import time
import os.path

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Loading Data

The dataset can be downloaded from [here](http://TODO). For this lab, we'll work with the German dataset, where each row holds four whitespace-separated columns: source tags, source form, target tags, and target form:
```
pos=ADJ,case=ACC,comp=CMPR,gen=FEM,num=SG	aerodynamischere	pos=ADJ,case=ACC,comp=SPRL,gen=NEUT,num=SG	aerodynamischstes
pos=ADJ,case=ACC,comp=CMPR,gen=FEM,num=SG	aktivere	pos=ADJ,case=NOM,comp=SPRL,gen=NEUT,num=SG	aktivstes
pos=ADJ,case=ACC,comp=CMPR,gen=FEM,num=SG	ambitioniertere	pos=ADJ,case=GEN,gen=FEM,num=SG	ambitionierter
pos=ADJ,case=ACC,comp=CMPR,gen=FEM,num=SG	aufnahmefähigere	pos=ADJ,case=DAT,comp=CMPR,gen=FEM,num=SG	aufnahmefähigerer
```

The first three columns are fed to the encoder, and the decoder outputs the word in the last column. Words are tokenized at the character level, while each tag is kept as a single token.
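As a rough illustration of the encoder input described above, the sketch below turns one row into a token sequence without NumPy. The helper name `build_encoder_input` is hypothetical and not part of the lab code; the `IN=`/`OUT=` prefixes follow the tagging scheme used later in `edit_tags`.

```python
def build_encoder_input(row):
    """Turn one data row into (encoder tokens, target characters).

    Hypothetical helper for illustration only; the lab's own pipeline
    does the same thing in several NumPy-based steps.
    """
    src_tags, src_form, trg_tags, trg_form = row.split()
    tokens = ["IN=" + tag for tag in src_tags.split(",")]    # source tags, one token each
    tokens += ["OUT=" + tag for tag in trg_tags.split(",")]  # target tags, one token each
    tokens += list(src_form.lower())                         # source word, one token per character
    return tokens, list(trg_form.lower())


row = ("pos=ADJ,case=ACC,comp=CMPR,gen=FEM,num=SG\taktivere\t"
       "pos=ADJ,case=NOM,comp=SPRL,gen=NEUT,num=SG\taktivstes")
tokens, target = build_encoder_input(row)
print(tokens)
print(target)
```

The tags come first and the characters of the source word last, so the encoder sees the full reinflection request as one flat sequence.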

In [4]:
PAD_TOKEN = 0
START_TOKEN = 1
END_TOKEN = 2
UNK_TOKEN = 3
PAD_TAG = "<pad>"
START_TAG = "<w>"
END_TAG = "</w>"
UNKNOWN_TAG = "<unk>"

MAX_LENGTH = 10

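def load_dataset(file_name):
    inputs = []
    outputs = []
    # The German data contains umlauts; the file is assumed to be UTF-8 encoded
    with open(file_name, encoding="utf-8") as f:
        for line in f:
            example = line.split()
            if len(example) != 4:
                continue  # skip blank or malformed lines
            example[1] = example[1].lower()
            example[3] = example[3].lower()
            example[0] = example[0].split(",")
            example[2] = example[2].split(",")
            
            if len(example[3]) <= MAX_LENGTH:
                inputs.append(example[:3])
                outputs.append(example[3])
    # dtype=object keeps the variable-length tag lists intact in an (N, 3) array
    return np.array(inputs, dtype=object), np.array(outputs)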

# Augment dataset by reversing the order, source->target becomes target->source
def enhance_dataset(inputs, outputs):
    # TODO implement
    return inputs, outputs

# We create 2 different vocab sets for source and target
# because the source contains the set of morphological tags
# which we do not require on the target end
def preprocess_data(inputs, outputs, train=True):
    if train:
        inputs, outputs = enhance_dataset(inputs, outputs)
    inputs = edit_tags(inputs)
    inputs[:, [1, 2]] = inputs[:, [2, 1]]
    inputs = transform_to_sequences(inputs)

    input_vocab = get_vocab(inputs)
    output_vocab = get_vocab(outputs)

    return inputs, outputs, input_vocab, output_vocab

# Specify input/output information in tags (from paper http://anthology.aclweb.org/P16-2090)
def edit_tags(inputs):
    for i in range(0, inputs.shape[0]):
        inputs[i, 0] = np.array(["IN=" + x for x in inputs[i, 0]])
        inputs[i, 2] = np.array(["OUT=" + x for x in inputs[i, 2]])
    return inputs

# tokenize words to characters
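def transform_to_sequences(inputs):
    # Sequences differ in length, so store them in a 1-D object array
    input_seq = np.empty(inputs.shape[0], dtype=object)
    for i in range(inputs.shape[0]):
        input_seq[i] = np.concatenate((inputs[i, 0], inputs[i, 1], list(inputs[i, 2])))
    return input_seq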

# create vocab map from data
def get_vocab(data):
    idx_to_char = {0: PAD_TAG, 1: START_TAG, 2: END_TAG, 3: UNKNOWN_TAG}
    char_to_idx = {PAD_TAG: 0, START_TAG: 1, END_TAG: 2, UNKNOWN_TAG: 3}
    char_set = set([])
    for i in range(0, data.shape[0]):
        char_set.update(data[i])
    char_set = sorted(char_set)
    for i in range(0, len(char_set)):
        idx_to_char[i+4] = char_set[i]
        char_to_idx[char_set[i]] = i+4
    return idx_to_char, char_to_idx

# return token indices, mapping tokens missing from the vocab to <unk>
def get_indices(seq, vocab):
    return [vocab.get(ch, vocab[UNKNOWN_TAG]) for ch in seq] + [vocab[END_TAG]]

## Dataset class
We'll go ahead and define a Dataset class that reads the data file and performs all the preprocessing ops defined above.

We also define a collate function that pads the sequences when serving batches. We don't define a maximum length for the input sequence, since it consists only of a single word and a finite sequence of tags.

In [5]:
class MEDDataset(Dataset):

    def __init__(self, file_name, train=True):
        inputs, outputs = load_dataset(file_name)
        inputs, outputs, in_vocab, out_vocab = preprocess_data(inputs, outputs, train)
        self.inputs = inputs
        self.outputs = outputs
        self.in_vocab = in_vocab
        self.out_vocab = out_vocab

    def __len__(self):
        return self.inputs.shape[0]

    def __getitem__(self, index):
        
        src = get_indices(self.inputs[index], self.in_vocab[1])
        trg = get_indices(self.outputs[index], self.out_vocab[1])
        return src, trg

def med_collate_fn(data):

    def _pad_sequences(seqs):
        lens = [len(seq) for seq in seqs]
        padded_seqs = torch.zeros(len(seqs), max(lens)).long()
        for i, seq in enumerate(seqs):
            end = lens[i]
            padded_seqs[i, :end] = torch.LongTensor(seq[:end])
        return padded_seqs, lens

    data.sort(key=lambda x: len(x[0]), reverse=True)
    src_seqs, trg_seqs = zip(*data)
    src_seqs, src_lens = _pad_sequences(src_seqs)
    trg_seqs, trg_lens = _pad_sequences(trg_seqs)

    #(batch, seq_len) => (seq_len, batch)
    src_seqs = src_seqs.transpose(0,1)
    trg_seqs = trg_seqs.transpose(0,1)

    return src_seqs, src_lens, trg_seqs, trg_lens

###  Verifying the dataset

To make sure our ops are correct, we'll initialise the dataset and do a sanity check.

In [6]:
# Initialize dataset
dataset = MEDDataset("data/german-task2-train.txt")

print("Number of source-target pairs:", len(dataset))
print("Input: ", [dataset.in_vocab[0][_] for _ in dataset[0][0]])
print("\n")
print("Output: ", [dataset.out_vocab[0][_] for _ in dataset[0][1]])

Number of source-target pairs: 1513
Input:  ['IN=pos=ADJ', 'IN=case=ACC', 'IN=comp=CMPR', 'IN=gen=FEM', 'IN=num=SG', 'OUT=pos=ADJ', 'OUT=case=NOM', 'OUT=comp=SPRL', 'OUT=gen=NEUT', 'OUT=num=SG', 'a', 'k', 't', 'i', 'v', 'e', 'r', 'e', '</w>']


Output:  ['a', 'k', 't', 'i', 'v', 's', 't', 'e', 's', '</w>']


## The Encoder

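The encoder is a bidirectional GRU. The output states of the two directions are summed, so the rest of the model sees vectors of size `hidden_size`. Note that `nn.GRU` applies dropout only between stacked layers, so with `n_layers=1` the `dropout` argument has no effect.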

In [9]:
class EncoderRNN(nn.Module):
    def __init__(self, input_size, embed_size, hidden_size, n_layers=1, dropout=0.1):
        super(EncoderRNN, self).__init__()
        
        self.input_size = input_size
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.dropout = dropout
        
        self.embedding = nn.Embedding(input_size, embed_size, padding_idx=PAD_TOKEN)
        self.gru = nn.GRU(embed_size, hidden_size, n_layers, dropout=self.dropout, bidirectional=True)
        
    def forward(self, input_seqs, input_lengths, hidden=None):
        # Note: we run this all at once over a whole batch of padded sequences
        embedded = self.embedding(input_seqs)
        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths)
        outputs, hidden = self.gru(packed, hidden)
        outputs, output_lengths = torch.nn.utils.rnn.pad_packed_sequence(outputs) # unpack (back to padded)
        outputs = outputs[:, :, :self.hidden_size] + outputs[:, : ,self.hidden_size:] # Sum bidirectional outputs
        return outputs, hidden
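To make the summation of the two directions concrete, here is a small shape check on a standalone bidirectional `nn.GRU`. The sizes are arbitrary values chosen for illustration and are not the lab's configuration.

```python
import torch
import torch.nn as nn

S, B, E, H = 6, 3, 8, 5  # sequence length, batch size, embedding size, hidden size
gru = nn.GRU(E, H, num_layers=1, bidirectional=True)
embedded = torch.randn(S, B, E)  # stands in for the embedded input batch

outputs, hidden = gru(embedded)
print(outputs.shape)  # last dimension is 2 * H: forward and backward states side by side
print(hidden.shape)   # first dimension is 2: one final state per direction

# Sum the forward half and the backward half, as EncoderRNN.forward does
summed = outputs[:, :, :H] + outputs[:, :, H:]
print(summed.shape)   # last dimension is back to H
```

Summing instead of concatenating keeps the encoder outputs the same size as the decoder hidden state, which is what the dot-product attention score needs.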

## The Decoder

[Effective Approaches to Attention-based Neural Machine Translation by Luong et al.](https://arxiv.org/pdf/1508.04025.pdf) describes several attention models that offer improvements and simplifications over earlier attention mechanisms. Among them are a few "global attention" models, which differ in how the attention scores are calculated.

The general form of the attention calculation relies on the target (decoder) side hidden state and corresponding source (encoder) side state, normalized over all states to get values summing to 1.

The specific "score" function that compares two states is one of: `dot`, a simple dot product between the states; `general`, a dot product between the decoder hidden state and a linear transform of the encoder state; or `concat`, a dot product between a new parameter $v_a$ and a linear transform of the two states concatenated together.
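Written out in the notation of Luong et al., with $h_t$ the current decoder hidden state, $\bar{h}_s$ an encoder state, and $W_a$, $v_a$ learned parameters, the attention weights and the three scoring functions are:

$$
a_t(s) = \frac{\exp\left(\mathrm{score}(h_t, \bar{h}_s)\right)}{\sum_{s'} \exp\left(\mathrm{score}(h_t, \bar{h}_{s'})\right)}
$$

$$
\mathrm{score}(h_t, \bar{h}_s) =
\begin{cases}
h_t^\top \bar{h}_s & \text{dot} \\
h_t^\top W_a \bar{h}_s & \text{general} \\
v_a^\top \tanh\left(W_a [h_t; \bar{h}_s]\right) & \text{concat}
\end{cases}
$$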

The modular definition of these scoring functions lets us build a single attention module that can switch between the different scoring methods. The input to this module is always the hidden state of the decoder RNN and the set of encoder outputs.
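The following standalone sketch shows one way the three scores and the resulting context vector could be computed on batched tensors. It is only an illustration under assumed shapes (encoder outputs as `S x B x N`, one decoder step as `1 x B x N`), not the lab's reference solution for the `Attn` exercise; the function names `dot_score`, `general_score`, `concat_score` and `attend` are hypothetical.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
S, B, N = 5, 2, 4                        # source length, batch size, hidden size
encoder_outputs = torch.randn(S, B, N)   # assumed layout: S x B x N
decoder_state = torch.randn(1, B, N)     # one decoder step: 1 x B x N

W_a = torch.nn.Linear(N, N, bias=False)      # used by the "general" score
W_c = torch.nn.Linear(2 * N, N, bias=False)  # used by the "concat" score
v_a = torch.randn(N)                         # used by the "concat" score

def dot_score(h_t, h_s):
    # Broadcast (1, B, N) against (S, B, N), then sum over N -> (S, B)
    return (h_t * h_s).sum(dim=2)

def general_score(h_t, h_s):
    # h_t^T W_a h_s for every source position -> (S, B)
    return (h_t * W_a(h_s)).sum(dim=2)

def concat_score(h_t, h_s):
    h_t_expanded = h_t.expand(h_s.size(0), -1, -1)                    # (S, B, N)
    energy = torch.tanh(W_c(torch.cat((h_t_expanded, h_s), dim=2)))   # (S, B, N)
    return energy @ v_a                                               # (S, B)

def attend(score_fn, h_t, h_s):
    energies = score_fn(h_t, h_s).t()                         # (B, S)
    weights = F.softmax(energies, dim=1)                      # each row sums to 1
    context = weights.unsqueeze(1).bmm(h_s.transpose(0, 1))   # (B, 1, S) x (B, S, N) -> (B, 1, N)
    return context, weights

for fn in (dot_score, general_score, concat_score):
    context, weights = attend(fn, decoder_state, encoder_outputs)
    print(fn.__name__, tuple(context.shape), tuple(weights.shape))
```

All three variants produce weights of shape `B x S` and a context of shape `B x 1 x N`, so the decoder can stay unchanged when the scoring method is swapped.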

## The Attention Module

In [10]:
class Attn(nn.Module):
    def __init__(self, method, hidden_size):
        super(Attn, self).__init__()
        
        self.method = method
        self.hidden_size = hidden_size
        
        if self.method == 'general':
            self.attn = nn.Linear(self.hidden_size, hidden_size)

        elif self.method == 'concat':
            self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
            self.v = nn.Parameter(torch.FloatTensor(1, hidden_size))

    def forward(self, hidden, encoder_outputs):

        # Create variable to store attention energies

        # For each batch of encoder outputs
        # Calculate energy for each encoder output
        
        # Normalize energies to weights in range 0 to 1, resize to 1 x B x S
        
        # Return context vectors
        return None
    
    def score(self, hidden, encoder_output):
        
        if self.method == 'dot':
            energy = None
            ## TODO implement
            return energy
        
        elif self.method == 'general':
            energy = None
            ## TODO implement 
            return energy
        
        elif self.method == 'concat':
            energy = None
            ## TODO implement 
            return energy

## Luong et al. Decoder model

Now we can build a decoder that plugs this Attn module in after the RNN to calculate attention weights, and apply those weights to the encoder outputs to get a context vector.

In [40]:
class LuongAttnDecoderRNN(nn.Module):
    def __init__(self, attn_model, hidden_size, output_size, n_layers=1, dropout=0.1):
        super(LuongAttnDecoderRNN, self).__init__()

        # Keep for reference
        self.attn_model = attn_model
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout = dropout

        # Define layers
        self.embedding = nn.Embedding(output_size, hidden_size, padding_idx=PAD_TOKEN)
        self.embedding_dropout = nn.Dropout(dropout)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=dropout)
        self.concat = nn.Linear(hidden_size * 2, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        
        # Choose attention model
        if attn_model != 'none':
            self.attn = Attn(attn_model, hidden_size)

    def forward(self, input_seq, last_hidden, encoder_outputs):
        # Note: we run this one step at a time

        # Get the embedding of the current input word (last output word)
        batch_size = input_seq.size(0)
        embedded = self.embedding(input_seq)
        embedded = self.embedding_dropout(embedded)
        embedded = embedded.view(1, batch_size, self.hidden_size) # S=1 x B x N

        # Get current hidden state from input word and last hidden state
        rnn_output, hidden = self.gru(embedded, last_hidden)

        # Calculate attention from current RNN state and all encoder outputs;
        # the Attn module returns the context vector (weighted average of encoder outputs)
        context = self.attn(rnn_output, encoder_outputs)

        # Attentional vector using the RNN hidden state and context vector
        # concatenated together (Luong eq. 5)
        rnn_output = rnn_output.squeeze(0) # S=1 x B x N -> B x N
        context = context.squeeze(1)       # B x S=1 x N -> B x N
        concat_input = torch.cat((rnn_output, context), 1)
        concat_output = torch.tanh(self.concat(concat_input))

        # Finally predict next token (Luong eq. 6, without softmax)
        output = self.out(concat_output)

        # Return final output and hidden state. Attn does not expose the attention
        # weights yet (see Exercise 1), so None is returned as a placeholder.
        attn_weights = None
        return output, hidden, attn_weights
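
For reference, the equations cited in the comments of this decoder (Luong et al., eq. 5 and 6), together with the context vector they rely on, are:

$$
c_t = \sum_{s} a_t(s)\, \bar{h}_s
$$

$$
\tilde{h}_t = \tanh\left(W_c [c_t; h_t]\right)
$$

$$
p(y_t \mid y_{<t}, x) = \mathrm{softmax}\left(W_s \tilde{h}_t\right)
$$

In the code, `self.concat` plays the role of $W_c$ and `self.out` the role of $W_s$. The code concatenates the RNN output before the context, which only reorders the columns of $W_c$, and both `nn.Linear` layers add a bias term that the equations omit. The softmax is left out because `nn.CrossEntropyLoss` applies it internally.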

## Training

In [41]:
def train_step(src_batch, src_lens, trg_batch, trg_lens, encoder, decoder, 
               encoder_optimizer, decoder_optimizer, criterion):
    
    # Zero gradients of both optimizers
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()
    loss, em_accuracy, edit_distance = 0.0, 0.0, 0.0

    # Move the batch to the same device as the models; shapes are (seq_len, batch)
    src_batch = src_batch.to(device)
    trg_batch = trg_batch.to(device)
    batch_size = src_batch.size(1)

    # Run words through encoder
    encoder_outputs, encoder_hidden = encoder(src_batch, src_lens, None)
    
    # Prepare input and output variables
    decoder_input = torch.LongTensor([START_TOKEN] * batch_size).to(device)
    decoder_hidden = encoder_hidden[:decoder.n_layers] # Use last (forward) hidden state from encoder

    max_trg_len = max(trg_lens)

    # Run through decoder one time step at a time using TEACHER FORCING=1.0
    for t in range(max_trg_len):
        decoder_output, decoder_hidden, decoder_attn = decoder(
            decoder_input, decoder_hidden, encoder_outputs
        )
        loss += criterion(decoder_output, trg_batch[t])
        decoder_input = trg_batch[t] # Teacher forcing: the next input is the gold token

    # TODO implement accuracy
    # TODO implement Levenshtein/edit distance
        
    loss = loss / max_trg_len
    loss.backward()
    
    # Clip gradient norms (`clip` is the global set in the configuration cell)
    enc_grads = torch.nn.utils.clip_grad_norm_(encoder.parameters(), clip)
    dec_grads = torch.nn.utils.clip_grad_norm_(decoder.parameters(), clip)

    # Update parameters with optimizers
    encoder_optimizer.step()
    decoder_optimizer.step()
    
    return loss.item(), em_accuracy, edit_distance #, enc_grads, dec_grads
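
The two metrics left as TODOs in `train_step` can be computed on decoded strings. The sketch below is one possible pure-Python version, assuming predictions and references have already been converted from token indices back to strings; the function names `levenshtein` and `exact_match_accuracy` are hypothetical and not part of the lab code.

```python
def levenshtein(a, b):
    """Edit distance between two sequences; insert, delete and substitute each cost 1."""
    previous = list(range(len(b) + 1))  # distances from the empty prefix of a
    for i, x in enumerate(a, start=1):
        current = [i]
        for j, y in enumerate(b, start=1):
            current.append(min(
                previous[j] + 1,             # delete x
                current[j - 1] + 1,          # insert y
                previous[j - 1] + (x != y),  # substitute, or match at no cost
            ))
        previous = current
    return previous[-1]


def exact_match_accuracy(predictions, references):
    """Fraction of predictions that equal their reference exactly."""
    matches = sum(p == r for p, r in zip(predictions, references))
    return matches / len(references)


print(levenshtein("aktivere", "aktivstes"))
print(exact_match_accuracy(["aktivstes", "aktiver"], ["aktivstes", "aktivere"]))
```

Exact match is the stricter metric, while edit distance gives partial credit to predictions that are off by only a character or two.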

In [42]:
def save_checkpoint(encoder, decoder, checkpoint_dir):
    # Create the directory if needed and use one timestamp so both files pair up
    os.makedirs(checkpoint_dir, exist_ok=True)
    timestamp = time.strftime("%d%m%y-%H%M%S")
    enc_filename = "{}/enc-{}.pth".format(checkpoint_dir, timestamp)
    dec_filename = "{}/dec-{}.pth".format(checkpoint_dir, timestamp)
    torch.save(encoder.state_dict(), enc_filename)
    torch.save(decoder.state_dict(), dec_filename)
    print("Model saved.")

def train(dataset, batch_size, n_epochs, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, 
          checkpoint_dir=None, save_every=500):
    train_iter = DataLoader(dataset=dataset,
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=4,
                            collate_fn=med_collate_fn,
                            drop_last=True)
    for i in range(n_epochs):
        # time.clock() was removed in Python 3.8; perf_counter() is its replacement
        tick = time.perf_counter()
        print("Epoch {}/{}".format(i+1, n_epochs))
        losses, accs, eds = [], [], []
        for batch_idx, batch in enumerate(train_iter):
            input_batch, input_lengths, target_batch, target_lengths = batch
            loss, accuracy, edit_distance = train_step(input_batch, input_lengths, target_batch, target_lengths, 
                                 encoder, decoder, encoder_optimizer, decoder_optimizer, criterion)
            losses.append(loss)
            accs.append(accuracy)
            eds.append(edit_distance)            
            
            if batch_idx % 100 == 0:
                print("batch: {}, loss: {}, accuracy: {}, edit distance: {}".format(batch_idx, loss, accuracy, 
                                                                                   edit_distance))
            if checkpoint_dir:
                if batch_idx % save_every == 0:
                    save_checkpoint(encoder, decoder, checkpoint_dir)
        tock = time.perf_counter()
        print("Time: {} Avg loss: {} Avg acc: {} Edit Dist.: {}".format(
            tock-tick, np.mean(losses), np.mean(accs), np.mean(eds)))
    
    if checkpoint_dir:
        save_checkpoint(encoder, decoder, checkpoint_dir)

## Configuring and Initializing Models

In [43]:
# Configure models
attn_model = 'dot'
hidden_size = 100
embed_size = 300
n_layers = 1
dropout = 0.1
batch_size = 20
checkpoint_dir = "checkpoints"

# Configure training/optimization
clip = 50.0
learning_rate = 0.0001
decoder_learning_ratio = 5.0
n_epochs = 20

In [44]:
# Initialize models
encoder = EncoderRNN(len(dataset.in_vocab[0]), embed_size, hidden_size, n_layers, dropout=dropout).to(device)
decoder = LuongAttnDecoderRNN(attn_model, hidden_size, len(dataset.out_vocab[0]), n_layers, dropout=dropout).to(device)

# Initialize optimizers and criterion
encoder_optimizer = optim.Adadelta(encoder.parameters())
decoder_optimizer = optim.Adadelta(decoder.parameters())
# Ignore padded target positions so they do not contribute to the loss
criterion = nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)


In [None]:
train(dataset, 
      batch_size, 
      n_epochs, 
      encoder, 
      decoder, 
      encoder_optimizer, 
      decoder_optimizer, 
      criterion, 
      checkpoint_dir)

### Exercises

1. Return attention weights and visualize them.
2. Train the model on the full data (remove the `MAX_LENGTH` constraint).
3. Train the model with the Attention+Decoder that we learned in the previous lab.
4. Validate the model on the [test data](https://github.com/ryancotterell/sigmorphon2016/blob/master/data/german-task2-test).

### Take Home Questions

1. How would you use morphological reinflection (MRI) to improve language modeling (LM) and neural machine translation (NMT)? (Think about using this in your project.)
2. Where else can MRI be applied?
3. What kind of improvements can be made to the current model? (Hint: one option is to learn only the changes that need to be made to the lemma for a given form type.)