In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.nn.utils.rnn import pad_sequence

from nltk.tokenize import word_tokenize
from nltk.translate.bleu_score import corpus_bleu

import numpy as np
import pandas as pd

import sys
import os
import random

from collections import defaultdict
from tqdm.notebook import tqdm

import wandb

In [3]:
import nltk
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [5]:
!tar -xvf '/content/drive/MyDrive/fr-en.tgz' -C '.'

europarl-v7.fr-en.en
europarl-v7.fr-en.fr


In [6]:
with open('/content/europarl-v7.fr-en.en', 'r') as f:
    print(f.readlines()[:2])
with open('/content/europarl-v7.fr-en.fr', 'r') as f:
    print(f.readlines()[:2])

['Resumption of the session\n', 'I declare resumed the session of the European Parliament adjourned on Friday 17 December 1999, and I would like once again to wish you a happy new year in the hope that you enjoyed a pleasant festive period.\n']
['Reprise de la session\n', 'Je déclare reprise la session du Parlement européen qui avait été interrompue le vendredi 17 décembre dernier et je vous renouvelle tous mes vux en espérant que vous avez passé de bonnes vacances.\n']


In [18]:
class Vocab:
    def __init__(self, path, lang='english', capacity=50000, maxsize=10000000):
        self.t2i = {'<PAD>': 0, '<UNK>': 1, '<CTC>': 2, '<SOS>': 3, '<EOS>': 4}
        self.i2t = {0: '<PAD>', 1: '<UNK>', 2: '<CTC>', 3: '<SOS>', 4: '<EOS>'}
        self.lang = lang
        self.capacity = capacity
        counts = defaultdict(int)
        content = ''
        with open(path, 'r') as f:
            for line in tqdm(f.readlines()[:maxsize]):
                content = line.strip().lower()
                tokenized_content = word_tokenize(content, language=lang)
                for word in tokenized_content:
                    counts[word] += 1

        picked_words = sorted(counts.keys(), key=lambda x: counts[x], reverse=True)[:capacity]
        for word in picked_words:
            self.t2i[word] = len(self.t2i)
            self.i2t[self.t2i[word]] = word
                
class TrDataset(Dataset):
    def __init__(self, paths, vocabs, langs, maxsize=10000000, maxlen=20):
        self.vocabs = dict([(lang, vocab) for lang, vocab in zip(langs, vocabs)])
        self.paths = dict([(lang, path) for lang, path in zip(langs, paths)])
        self.data = dict([(lang, []) for lang in langs])
        self.langs = langs
        with open(self.paths[langs[0]], 'r') as f1:
            with open(self.paths[langs[1]], 'r') as f2:
                for lines in tqdm(zip(f1.readlines()[:maxsize], f2.readlines()[:maxsize])):
                    if len(lines[0].split(' ')) < maxlen and len(lines[1].split(' ')) < maxlen and len(lines[0].split(' ')) > 1 and len(lines[1].split(' ')) > 1:
                        for line, lang in zip(lines, langs):
                            content = line.strip().lower()
                            tokenized_content = word_tokenize(content)
                            content_tokens = [self.vocabs[lang].t2i['<SOS>']]
                            for word in tokenized_content:
                                if word in self.vocabs[lang].t2i:
                                    content_tokens.append(self.vocabs[lang].t2i[word])
                                else:
                                    content_tokens.append(self.vocabs[lang].t2i['<UNK>'])
                            content_tokens.append(self.vocabs[lang].t2i['<EOS>'])
                            self.data[lang].append(content_tokens)
        
    def __getitem__(self, idx):
        return [torch.LongTensor(self.data[lang][idx]) for lang in self.langs]

    def __len__(self):
        return len(self.data[self.langs[0]])

    def collate_fn(self, batch):
        l1, l2 = list(zip(*batch))
        l1 = pad_sequence(l1, batch_first=True)
        l2 = pad_sequence(l2, batch_first=True)
        return l1, l2

In [19]:
vocabs = [Vocab('/content/europarl-v7.fr-en.fr'), Vocab('/content/europarl-v7.fr-en.en')]
langs = ['french', 'english']
dataset = TrDataset(['/content/europarl-v7.fr-en.fr', '/content/europarl-v7.fr-en.en'], vocabs, langs)

  0%|          | 0/2007723 [00:00<?, ?it/s]

  0%|          | 0/2007723 [00:00<?, ?it/s]

0it [00:00, ?it/s]

In [20]:
len(dataset)

698950

In [21]:
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [300000, len(dataset) - 300000])

In [22]:
train_loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=dataset.collate_fn)
val_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=dataset.collate_fn)

In [23]:
for batch in train_loader:
    print(batch[0].shape)
    break

torch.Size([32, 24])


In [24]:
class Encoder(nn.Module):
    def __init__(self, index_dim, embedding_dim=256, hidden_dim=256, dropout=0.2, layers=2):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(index_dim, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=layers, dropout=dropout, bidirectional=True, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.directions = 2
        self.layers = layers

    def forward(self, x):
        mask = (x != 0).to(torch.long)
        lengths = mask.sum(dim=1).to('cpu')

        emb = self.embedding(x)
        emb = pack_padded_sequence(emb, lengths, batch_first=True, enforce_sorted=False)
        out, hidden = self.lstm(emb)
        out, _ = pad_packed_sequence(out, batch_first=True)
        return self.dropout(out), hidden

class Attention(nn.Module):
    def __init__(self, hidden_dim, attention_dim, directions=2):
        super(Attention, self).__init__()
        self.W1 = nn.Linear(hidden_dim * directions, attention_dim)
        self.W2 = nn.Linear(hidden_dim * directions, attention_dim)
        self.V = nn.Linear(attention_dim, 1)
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, hidden, enc_out):
        score = self.V(self.tanh(self.W1(enc_out) + self.W2(hidden.view(hidden.shape[1], -1).unsqueeze(1)))).squeeze(-1)
        attention_weights = self.softmax(score)
        assert len(attention_weights.shape) == 2
        
        context_vector = attention_weights.unsqueeze(-1) * enc_out
        context_vector = context_vector.sum(axis=1)
        return context_vector

class Decoder(nn.Module):
    def __init__(self, index_dim, embedding_dim=256, hidden_dim=256, attention_dim=128, enc_directions=2):
        super(Decoder, self).__init__()
        self.attention = Attention(hidden_dim, attention_dim)
        self.embedding = nn.Embedding(index_dim, embedding_dim)
        self.linear = nn.Linear(hidden_dim * enc_directions, index_dim)
        self.lstm = nn.LSTM(embedding_dim + hidden_dim * enc_directions, hidden_dim * enc_directions, batch_first=True)
        self.relu = nn.ReLU()

    def forward(self, x, hidden, enc_out):
        context_vector = self.attention(hidden[0], enc_out)
        emb = self.relu(self.embedding(x))
        emb = torch.cat([context_vector.unsqueeze(1), emb], dim=-1)
        out, hidden = self.lstm(emb, hidden)
        out = self.linear(out)
        return out, hidden

class CTCDecoder(nn.Module):
    def __init__(self, index_dim, hidden_dim=512):
        super(CTCDecoder, self).__init__()
        self.ctc_linear = nn.Linear(hidden_dim, index_dim)
        self.softmax = nn.LogSoftmax(dim=2)

    def forward(self, x):
        return self.softmax(self.ctc_linear(x))
    

class Seq2Seq(nn.Module):
    def __init__(self, encoder_index_dim, decoder_index_dim):
        super(Seq2Seq, self).__init__()
        self.encoder = Encoder(encoder_index_dim)
        self.ctc = CTCDecoder(decoder_index_dim)
        self.decoder = Decoder(decoder_index_dim)

        self.encoder_index_dim = encoder_index_dim
        self.decoder_index_dim = decoder_index_dim

    def forward(self, x, y, use_teacher_forcing):
        assert x.shape[0] == y.shape[0]
        assert x.device == y.device

        batch_size = x.shape[0]
        device = x.device

        input_length = x.shape[1]
        target_length = y.shape[1]

        enc_out, hidden = self.encoder(x)

        hidden = [hidden[0].view(self.encoder.layers, batch_size, -1).mean(dim=0).unsqueeze(0), 
                  hidden[1].view(self.encoder.layers, batch_size, -1).mean(dim=0).unsqueeze(0)]
        ctc = self.ctc(enc_out)
        out = y[:, 0].unsqueeze(1)

        outputs = torch.zeros(batch_size, target_length, self.decoder_index_dim).to(device)

        if use_teacher_forcing:
            for di in range(1, y.shape[1]):
                out, hidden = self.decoder(out.to(dtype=torch.long), hidden, enc_out)
                outputs[:, di:di + 1] = out

                out = y[:, di].unsqueeze(1)
        else:
            for di in range(1, y.shape[1]):
                out, hidden = self.decoder(out.to(dtype=torch.long), hidden, enc_out)
                outputs[:, di:di + 1] = out

                out = out.argmax(dim=-1)

        return ctc, outputs

In [25]:
def decode_ctc(sent):
    if not sent:
        return sent
    last_token = sent[0]
    real_sent = []
    for token in sent[1:]:
        if token != '<CTC>' and token != last_token:
            real_sent.append(last_token)
        last_token = token
    return real_sent

In [26]:
class TrLSTMTCTCrainer:
    def __init__(self, model: nn.Module, train_loader, val_loader, tf=0.25, lr=3e-4, betas=(0.9, 0.999),
                 project="ctc_translation", name='ctc_model', save_every=None, save_path='./', ctc=True, gen=True):
        self.ctc_criterion = nn.CTCLoss(blank=vocabs[0].t2i['<CTC>'], zero_infinity=True)
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=betas)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.model = model
        self.save_path = save_path
        self.save_every = save_every
        self.name = name
        self.project = project
        self.tf = tf
        self.ctc = ctc
        self.gen = gen
        wandb.init(project=project, name=name)

    def train_epoch(self, cuda=True, clip=1):
        if cuda:
            self.model.cuda()
        else:
            self.model.cpu()
        self.model.train()
        total_loss = 0
        total_ctc_loss = 0
        total_gen_loss = 0
        for batch_idx, (tokens, targets) in enumerate(self.train_loader):
            self.optimizer.zero_grad()
            if cuda:
                tokens = tokens.cuda()
                targets = targets.cuda()
            use_teacher_forcing = True if random.random() < self.tf else False
            input_lengths = (tokens != 0).sum(dim=1)
            target_lengths = (targets != 0).sum(dim=1)
            ctc, outputs = self.model(tokens.to(dtype=torch.long), targets.to(dtype=torch.long), use_teacher_forcing)

            loss = 0

            gen_loss = self.criterion(outputs[:, 1:].reshape(-1, self.model.decoder_index_dim), targets[:, 1:].flatten())
            total_gen_loss += gen_loss.item()
            if self.gen:
                loss += gen_loss
            ctc_loss = self.ctc_criterion(ctc.permute(1, 0, 2).to(dtype=torch.float), targets.to(dtype=torch.long), 
                                          input_lengths=input_lengths, target_lengths=target_lengths)
            total_ctc_loss += ctc_loss.item()
            if self.ctc:
                loss += ctc_loss
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), clip)
            self.optimizer.step()

            total_loss += loss.item()

            print('\rTrain loss: %4f, gen_loss: %4f, ctc_loss: %4f, Batch: %d of %d' % (
                total_loss / (batch_idx + 1), total_ctc_loss / (batch_idx + 1), total_gen_loss / (batch_idx + 1), batch_idx + 1, len(self.train_loader)), end='')
        print()
        loss = total_loss / len(self.train_loader)
        ctc_loss = total_ctc_loss / len(self.train_loader)
        gen_loss = total_gen_loss / len(self.train_loader)
        return loss, ctc_loss, gen_loss

    def test_epoch(self, cuda=True):
        if cuda:
            self.model.cuda()
        else:
            self.model.cpu()
        with torch.no_grad():
            self.model.eval()
            total_loss = 0
            total_ctc_loss = 0
            total_gen_loss = 0
            for batch_idx, (tokens, targets) in enumerate(self.val_loader):
                if cuda:
                    tokens = tokens.cuda()
                    targets = targets.cuda()
                use_teacher_forcing = False
                input_lengths = (tokens != 0).sum(dim=1)
                target_lengths = (targets != 0).sum(dim=1)
                ctc, outputs = self.model(tokens.to(dtype=torch.long), targets.to(dtype=torch.long), use_teacher_forcing)

                loss = self.criterion(outputs[:, 1:].reshape(-1, self.model.decoder_index_dim), targets[:, 1:].flatten())
                ctc_loss = self.ctc_criterion(ctc.permute(1, 0, 2).to(dtype=torch.float), targets.to(dtype=torch.long), 
                                              input_lengths=input_lengths, target_lengths=target_lengths)
                total_ctc_loss += ctc_loss.item()
                total_gen_loss += loss.item()
                if self.ctc:
                    loss += ctc_loss
                total_loss += loss.item()
            

                print('\rVal loss: %4f, gen_loss: %4f, ctc_loss: %4f, Batch: %d of %d' % (
                    total_loss / (batch_idx + 1), total_ctc_loss / (batch_idx + 1), total_gen_loss / (batch_idx + 1), batch_idx + 1, len(self.val_loader)), end='')
            print()
            loss = total_loss / len(self.val_loader)
            ctc_loss = total_ctc_loss / len(self.val_loader)
            gen_loss = total_gen_loss / len(self.val_loader)
            return loss, ctc_loss, gen_loss

    def output(self, cuda=True):
        self.model.eval()
        if cuda:
            self.model.cuda()
        else:
            self.model.cpu()
        tokens, targets = next(iter(self.val_loader))
        tokens = tokens[1:2].to(dtype=torch.long)
        targets = targets[1:2].to(dtype=torch.long)
        batch_size = tokens.shape[0]
        if cuda:
            tokens = tokens.cuda()
            targets = targets.cuda()
        ctc, outputs = self.model(tokens, targets, False)
        outputs = outputs.argmax(dim=-1)
        ctc = ctc.argmax(dim=-1)
        sent = '<SOS>'
        summ = '<SOS>'
        ctc_sent = []
        for di in range(1, targets.shape[1]):
            sent += vocabs[1].i2t[outputs[0, di].cpu().detach().squeeze().item()] + ' '
            summ += vocabs[1].i2t[targets[0, di].cpu().detach().squeeze().item()] + ' '
        for di in range(ctc.shape[1]):
            ctc_sent.append(vocabs[1].i2t[ctc[0, di].cpu().detach().squeeze().item()])

        print(sent[:-1])
        print(summ[:-1])
        print(' '.join(decode_ctc(ctc_sent)))

    @staticmethod
    def log(epoch, 
            train_loss, train_ctc_loss, train_gen_loss, 
            test_loss, test_ctc_loss, test_gen_loss):
        wandb.log({
            'train': {
                'loss': train_loss,
                'ctc loss': train_ctc_loss,
                'cross entropy loss': train_gen_loss,
            },
            'val': {
                'loss': test_loss,
                'ctc loss': test_ctc_loss,
                'cross entropy loss': test_gen_loss,
            },
            'epoch': epoch
        })

    def checkpoint(self, epoch):
        torch.save(self.model.state_dict(), os.path.join(self.save_path, self.name + str(epoch) + '.ckpt'))

    def fit(self, max_epochs: int = 11, cuda=True, clip=1, log=False):
        for epoch in range(max_epochs):
            if self.save_every and epoch % self.save_every == 0:
                self.checkpoint(epoch)
            print('\rEpoch: %d' % epoch)
            self.output(cuda=cuda)
            train_loss, train_ctc_loss, train_gen_loss = self.train_epoch(cuda=cuda, clip=clip)
            test_loss, test_ctc_loss, test_gen_loss = self.test_epoch(cuda=cuda)
            if log:
                self.log(epoch, 
                         train_loss, train_ctc_loss, train_gen_loss, 
                         test_loss, test_ctc_loss, test_gen_loss)

In [None]:
ctc_model = Seq2Seq(len(vocabs[0].t2i), len(vocabs[1].t2i))
# ctc_model.load_state_dict(torch.load('/content/drive/MyDrive/CTC_in_NLP/checkpoints/bilstm_with_ctc_2568.ckpt'))
ctc_trainer = TrLSTMTCTCrainer(ctc_model, train_loader, val_loader, 
                               lr=3e-4, save_every=1, name='bilstm_with_ctc_256', 
                               save_path=f'/content/drive/MyDrive/CTC_in_NLP/checkpoints')
ctc_trainer.fit(log=True)
normal_model = Seq2Seq(len(vocabs[0].t2i), len(vocabs[1].t2i))
# normal_model.load_state_dict(torch.load('/content/drive/MyDrive/CTC_in_NLP/checkpoints/bilstm_without_ctc_2568.ckpt'))
normal_trainer = TrLSTMTCTCrainer(normal_model, train_loader, val_loader, 
                                  lr=3e-4, save_every=1, ctc=False, name='bilstm_without_ctc_256', 
                                  save_path=f'/content/drive/MyDrive/CTC_in_NLP/checkpoints')
normal_trainer.fit(log=True)
ctc_only_model = Seq2Seq(len(vocabs[0].t2i), len(vocabs[1].t2i))
# ctc_only_model.load_state_dict(torch.load('/content/drive/MyDrive/CTC_in_NLP/checkpoints/bilstm_ctc_only_2568.ckpt'))
ctc_only_trainer = TrLSTMTCTCrainer(ctc_only_model, train_loader, val_loader, 
                                  lr=3e-4, save_every=1, gen=False, name='bilstm_ctc_only_256', 
                                  save_path=f'/content/drive/MyDrive/CTC_in_NLP/checkpoints')
ctc_only_trainer.fit(log=True)

In [46]:
def cut_on_eos(tokens_iter):
    for token in tokens_iter:
        if token == '<EOS>':
            break
        yield token

def remove_tech_tokens(tokens_iter, tokens_to_remove=['<SOS>', '<UNK>', '<PAD>']):
    return [x for x in tokens_iter if x not in tokens_to_remove]

def generate_translation(src, trg, model):
    model.eval()

    ctc, output = model(src, trg, False)
    output = output[:, 1:].argmax(-1)

    original = remove_tech_tokens(cut_on_eos([vocabs[1].i2t[x] for x in list(trg[0, :].cpu().detach().numpy())]))
    generated = remove_tech_tokens(cut_on_eos([vocabs[1].i2t[x] for x in list(output[0, :].cpu().detach().numpy())]))
    
    print('Original: {}'.format(' '.join(original)))
    print('Generated: {}'.format(' '.join(generated)))
    print()

def get_text(x):
     generated = remove_tech_tokens(cut_on_eos([vocabs[1].i2t[elem] for elem in list(x)]))
     return generated

def calculate_bleu(model, loader):
    generated = []
    original = []
    for src, trg in loader:
        ctc, output = model(src, trg, False)
        output = output[:, 1:].argmax(-1)

        original.extend([get_text(x) for x in trg.cpu().numpy()])
        generated.extend([get_text(x) for x in output.detach().cpu().numpy()])
    
    return corpus_bleu([[text] for text in original], generated) * 100


def calculate_ctc_bleu(model, loader):
    generated = []
    original = []
    for src, trg in loader:
        ctc, output = model(src, trg, False)
        output = output[:, 1:].argmax(-1)

        original.extend([decode_ctc(get_text(x)) for x in trg.cpu().numpy()])
        generated.extend([decode_ctc(get_text(x)) for x in output.detach().cpu().numpy()])
    
    return corpus_bleu([[text] for text in original], generated) * 100

In [21]:
calculate_bleu(normal_model, val_loader), calculate_bleu(ctc_model, val_loader)

(23.010634515617348, 22.102665120303463)

In [24]:
calculate_ctc_bleu(ctc_model, val_loader)

21.94119592362224

In [47]:
calculate_ctc_bleu(ctc_only_model, val_loader)

Corpus/Sentence contains 0 counts of 2-gram overlaps.
BLEU scores might be undesirable; use SmoothingFunction().


10.334502197607303