In [None]:
!pip install wandb

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.nn.utils.rnn import pad_sequence

from nltk.tokenize import word_tokenize
from nltk.translate.bleu_score import corpus_bleu

import numpy as np
import pandas as pd

import sys
import os
import random
import math

from collections import defaultdict
from tqdm.notebook import tqdm

import wandb

In [None]:
import nltk
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

In [None]:
from google.colab import drive
# drive.flush_and_unmount()
drive.mount('/content/drive')

Mounted at /content/drive


# Translation

In [None]:
!tar -xvf '/content/drive/MyDrive/fr-en.tgz' -C '.'

europarl-v7.fr-en.en
europarl-v7.fr-en.fr


In [None]:
!cp '/content/drive/MyDrive/train.en.txt' '.'

In [None]:
!cp '/content/drive/MyDrive/train.de.txt' '.'

In [None]:
!cp '/content/drive/MyDrive/dict.en-de.txt' '.'

In [None]:
!cp '/content/drive/MyDrive/newstest2014.en.txt' '.'

In [None]:
!cp '/content/drive/MyDrive/newstest2014.de.txt' '.'

In [None]:
with open('/content/europarl-v7.fr-en.en', 'r') as f:
    print(f.readlines()[:2])
with open('/content/europarl-v7.fr-en.fr', 'r') as f:
    print(f.readlines()[:2])

with open('/content/train.de.txt', 'r') as f:
    print(f.readlines()[:2])
with open('/content/train.en.txt', 'r') as f:
    print(f.readlines()[:2])
with open('/content/dict.en-de.txt', 'r') as f:
    print(''.join(f.readlines()[:100]))
with open('/content/newstest2014.en.txt', 'r') as f:
    print(f.readlines()[:2])
with open('/content/newstest2014.de.txt', 'r') as f:
    print(f.readlines()[:2])

In [None]:
class Vocab:
    def __init__(self, lang='english', capacity=20000, maxsize=10000000):
        self.t2i = {'<PAD>': 0, '<UNK>': 1, '<CTC>': 2, '<SOS>': 3, '<EOS>': 4}
        self.i2t = {0: '<PAD>', 1: '<UNK>', 2: '<CTC>', 3: '<SOS>', 4: '<EOS>'}
        self.lang = lang
        self.capacity = capacity
        self.maxsize = maxsize

    def read(self, path, save_path='./vocab'):
        counts = defaultdict(int)
        with open(path, 'r') as f:
            for line in tqdm(f.readlines()[:self.maxsize]):
                content = line.strip().lower()
                tokenized_content = word_tokenize(content, language=self.lang)
                for word in tokenized_content:
                    counts[word] += 1

        picked_words = sorted(counts.keys(), key=lambda x: counts[x], reverse=True)[:self.capacity]
        with open(os.path.join(save_path, self.lang + '.txt'), 'w') as f:
            for word in picked_words:
                f.write(word + '\n')
                self.t2i[word] = len(self.t2i)
                self.i2t[self.t2i[word]] = word
        return self

    def read_file(self, path, save_path='./vocab'):
        counts = defaultdict(int)
        with open(path, 'r') as f:
            with open(os.path.join(save_path, self.lang + '.txt'), 'w') as fv:
                for line in tqdm(f.readlines()):
                    content = line.strip().lower().split(' ')
                    if self.lang == 'english':
                        word = content[0].strip().lower()
                    else:
                        word = content[1].strip().lower()
                    if word not in self.t2i:
                        fv.write(word + '\n')
                        self.t2i[word] = len(self.t2i)
                        self.i2t[self.t2i[word]] = word
        return self

    def load(self, path):
        with open(path, 'r') as f:
            for line in tqdm(f.readlines()):
                word = line.strip().lower()
                self.t2i[word] = len(self.t2i)
                self.i2t[self.t2i[word]] = word
        return self
                
class TrDataset(Dataset):
    def __init__(self, paths, vocabs, langs, maxsize=4000000, maxlen=20):
        self.vocabs = dict([(lang, vocab) for lang, vocab in zip(langs, vocabs)])
        self.paths = dict([(lang, path) for lang, path in zip(langs, paths)])
        self.data = dict([(lang, []) for lang in langs])
        self.langs = langs
        with open(self.paths[langs[0]], 'r') as f1:
            with open(self.paths[langs[1]], 'r') as f2:
                for lines in tqdm(zip(f1.readlines()[:maxsize], f2.readlines()[:maxsize])):
                    if len(lines[0].split(' ')) < maxlen \
                    and len(lines[1].split(' ')) < maxlen \
                    and len(lines[0].split(' ')) > 1 \
                    and len(lines[1].split(' ')) > 1:
                        # remove_line = False
                        for line, lang in zip(lines, langs):
                            content = line.strip().lower()
                            tokenized_content = word_tokenize(content, language=lang)
                            content_tokens = [self.vocabs[lang].t2i['<SOS>']]
                            for word in tokenized_content:
                                if word in self.vocabs[lang].t2i:
                                    content_tokens.append(self.vocabs[lang].t2i[word])
                                else:
                                    # remove_line = True
                                    content_tokens.append(self.vocabs[lang].t2i['<UNK>'])
                            content_tokens.append(self.vocabs[lang].t2i['<EOS>'])
                            self.data[lang].append(content_tokens)
                        # if remove_line:
                        #     self.data[langs[0]].pop()
                        #     self.data[langs[1]].pop()
                        
        
    def __getitem__(self, idx):
        return [torch.LongTensor(self.data[lang][idx]) for lang in self.langs]

    def __len__(self):
        return len(self.data[self.langs[0]])

    def collate_fn(self, batch):
        l1, l2 = list(zip(*batch))
        l1 = pad_sequence(l1, batch_first=True)
        l2 = pad_sequence(l2, batch_first=True)
        return l1, l2

In [None]:
# vocabs = [Vocab(lang='french').read('/content/europarl-v7.fr-en.fr'),
#           Vocab(lang='english').read('/content/europarl-v7.fr-en.en')]
# langs = ['french', 'english']
# dataset = TrDataset(['/content/europarl-v7.fr-en.fr', '/content/europarl-v7.fr-en.en'], vocabs, langs)

vocabs = [Vocab(lang='german').read('/content/train.de.txt').read_file('/content/dict.en-de.txt'),
          Vocab(lang='english').read('/content/train.en.txt').read_file('/content/dict.en-de.txt')]
langs = ['german', 'english']
dataset = TrDataset(['/content/train.de.txt', '/content/train.en.txt'], vocabs, langs)

  0%|          | 0/4468840 [00:00<?, ?it/s]

  0%|          | 0/344780 [00:00<?, ?it/s]

  0%|          | 0/4468840 [00:00<?, ?it/s]

  0%|          | 0/344780 [00:00<?, ?it/s]

0it [00:00, ?it/s]

In [None]:
len(dataset)

854649

In [None]:
# train_dataset, test_dataset = torch.utils.data.random_split(dataset, [680000, len(dataset) - 680000])
train_dataset = dataset
test_dataset = TrDataset(['/content/newstest2014.de.txt', '/content/newstest2014.en.txt'], vocabs, langs)

0it [00:00, ?it/s]

In [None]:
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, collate_fn=dataset.collate_fn)
val_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, collate_fn=dataset.collate_fn)

In [None]:
for batch in train_loader:
    print(batch[0].shape)
    break

torch.Size([64, 28])


In [72]:
class Encoder(nn.Module):
    def __init__(self, index_dim, embedding_dim=256, hidden_dim=256, dropout=0.2, layers=2):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(index_dim, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=layers, dropout=dropout, bidirectional=True, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.directions = 2
        self.layers = layers

    def forward(self, x):
        mask = (x != 0).to(torch.long)
        lengths = mask.sum(dim=1).to('cpu')

        emb = self.embedding(x)
        emb = pack_padded_sequence(emb, lengths, batch_first=True, enforce_sorted=False)
        out, hidden = self.lstm(emb)
        out, _ = pad_packed_sequence(out, batch_first=True)
        return self.dropout(out), hidden

class Attention(nn.Module):
    def __init__(self, hidden_dim, attention_dim, directions=2):
        super(Attention, self).__init__()
        self.W1 = nn.Linear(hidden_dim * directions, attention_dim)
        self.W2 = nn.Linear(hidden_dim * directions, attention_dim)
        self.V = nn.Linear(attention_dim, 1)
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=1)    

    def forward(self, hidden, enc_out):
        score = self.V(self.tanh(self.W1(enc_out) + self.W2(hidden.view(hidden.shape[1], -1).unsqueeze(1)))).squeeze(-1)
        attention_weights = self.softmax(score)
        assert len(attention_weights.shape) == 2
        
        context_vector = attention_weights.unsqueeze(-1) * enc_out
        context_vector = context_vector.sum(axis=1)
        return context_vector

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout = 0.1, max_len = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)


# class DuplicationModule(nn.Module):
#     def __init__(self, input_dim):
#         super().__init__()
#         self.linear = nn.Linear(input_dim, 1)
#         self.ReLU = nn.ReLU()

#     def forward(self, x):
#         # x = torch.repeat_interleave(x.permute(1, 0, 2), 3, dim=1)
#         return self.ReLU(self.linear(x).squeeze(-1))

# class PermutationModule(nn.Module):
#     def __init__(self, input_dim):
#         super().__init__()
#         self.linear = nn.Linear(input_dim, 1)
#         self.ReLU = nn.ReLU()

#     def forward(self, x):
#         # x = torch.repeat_interleave(x.permute(1, 0, 2), 3, dim=1)
#         return self.ReLU(self.linear(x).squeeze(-1))


class CTCDecoder(nn.Module):
    def __init__(self, index_dim, hidden_dim=512):
        super(CTCDecoder, self).__init__()
        self.ctc_linear = nn.Linear(hidden_dim, index_dim)
        self.softmax = nn.LogSoftmax(dim=2)

    def forward(self, x):
        # x = torch.repeat_interleave(x.permute(1, 0, 2), 1, dim=1)
        return self.softmax(self.ctc_linear(x.permute(1, 0, 2)))


class CTCTransformerDecoder(nn.Module):
    def __init__(self, index_dim, hidden_dim=512, heads=4):
        super(CTCTransformerDecoder, self).__init__()
        self.embedding = nn.Embedding(index_dim, hidden_dim, padding_idx=0)
        self.decoder = nn.TransformerDecoderLayer(d_model=hidden_dim, nhead=heads, dim_feedforward=hidden_dim)
        self.ctc_linear = nn.Linear(hidden_dim, index_dim)
        self.softmax = nn.LogSoftmax(dim=2)

    def forward(self, x):
        # x = torch.repeat_interleave(x.permute(1, 0, 2), 1, dim=1)
        y = self.embedding(self.softmax(self.ctc_linear(x)).argmax(-1))
        return self.softmax(self.ctc_linear(self.decoder(y.detach(), x))).permute(1, 0, 2)


class Seq2CTC(nn.Module):
    def __init__(self, encoder_index_dim, decoder_index_dim, hidden_dim=256):
        super(Seq2CTC, self).__init__()
        self.encoder = Encoder(encoder_index_dim, hidden_dim=hidden_dim)
        self.ctc = CTCDecoder(decoder_index_dim)
        # self.output_mask = OutputMask(2 * hidden_dim)
        self.encoder_index_dim = encoder_index_dim

    def forward(self, x, y=None):
        batch_size = x.shape[0]
        device = x.device

        input_length = x.shape[1]

        enc_out, _ = self.encoder(x)

        enc_out = nn.functional.pad(enc_out, (0, 0, 0, enc_out.shape[1]), "constant", 0)
        
        # mask = self.output_mask(enc_out.permute(1, 0, 2))
        # ctc = torch.mul(mask.unsqueeze(-1), self.ctc(enc_out.permute(1, 0, 2)))

        return self.ctc(enc_out.permute(1, 0, 2))

class Transformer2CTC(nn.Module):
    def __init__(self, encoder_index_dim, decoder_index_dim, embedding_dim=256, 
                 heads=4, hidden_dim=256, layers=4, dropout = 0.3):
        super().__init__()
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEncoding(embedding_dim, dropout)
        encoder_layers = nn.TransformerEncoderLayer(embedding_dim, heads, hidden_dim, dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, layers)
        self.encoder = nn.Embedding(encoder_index_dim, embedding_dim)
        self.embedding_dim = embedding_dim
        self.decoder = CTCDecoder(decoder_index_dim, hidden_dim)
        # self.output_mask = OutputMask(hidden_dim=hidden_dim)

        self.init_weights()

    def init_weights(self) -> None:
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, x, mask):
        """
        Args:
            src: Tensor, shape [seq_len, batch_size]
            src_mask: Tensor, shape [seq_len, seq_len]

        Returns:
            output Tensor of shape [seq_len, batch_size, ntoken]
        """
        # x = nn.functional.pad(x, (0, 0, 0, x.shape[0]), "constant", 0)
        x = self.encoder(x) * math.sqrt(self.embedding_dim)
        x = self.pos_encoder(x)
        output = self.transformer_encoder(x, mask)
        # out_mask = self.output_mask(output)
        # output = torch.mul(out_mask.unsqueeze(-1), self.decoder(output))
        return self.decoder(output)


class Transformer2TransformerCTC(nn.Module):
    def __init__(self, encoder_index_dim, decoder_index_dim, embedding_dim=256, 
                 heads=4, hidden_dim=256, layers=4, dropout = 0.3):
        super().__init__()
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEncoding(embedding_dim, dropout)
        encoder_layers = nn.TransformerEncoderLayer(embedding_dim, heads, hidden_dim, dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, layers)
        self.encoder = nn.Embedding(encoder_index_dim, embedding_dim)
        self.embedding_dim = embedding_dim
        self.decoder = CTCTransformerDecoder(decoder_index_dim, hidden_dim)

        self.init_weights()

    def init_weights(self) -> None:
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, x, mask):
        """
        Args:
            src: Tensor, shape [seq_len, batch_size]
            src_mask: Tensor, shape [seq_len, seq_len]

        Returns:
            output Tensor of shape [seq_len, batch_size, ntoken]
        """
        # x = nn.functional.pad(x, (0, 0, 0, x.shape[0]), "constant", 0)
        x = self.encoder(x) * math.sqrt(self.embedding_dim)
        x = self.pos_encoder(x)
        output = self.transformer_encoder(x, mask)
        # out_mask = self.output_mask(output)
        # output = torch.mul(out_mask.unsqueeze(-1), self.decoder(output))
        return self.decoder(output)


def generate_square_subsequent_mask(sz, cuda=True):
    """Generates an upper-triangular matrix of -inf, with zeros on diag."""
    mask = torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)
    if cuda:
        return mask.cuda()
    return mask.cpu()

In [None]:
def decode_ctc(sent):
    if not sent:
        return sent
    last_token = sent[0]
    real_sent = []
    for token in sent[1:]:
        if last_token != '<CTC>' and token != last_token:
            real_sent.append(last_token)
        last_token = token
    return real_sent

In [63]:
class TrLSTMTCTCrainer:
    def __init__(self, model: nn.Module, train_loader, val_loader, tf=0.25, lr=3e-4, betas=(0.9, 0.999),
                 project="ctc_translation", name='ctc_model', save_every=None, save_path='./'):
        self.ctc_criterion = nn.CTCLoss(blank=vocabs[0].t2i['<CTC>'], zero_infinity=True)
        self.optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=betas)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.model = model
        self.save_path = save_path
        self.save_every = save_every
        self.name = name
        self.project = project
        self.tf = tf
        wandb.init(project=project, name=name)

    def train_epoch(self, cuda=True, clip=1):
        if cuda:
            self.model.cuda()
        else:
            self.model.cpu()
        self.model.train()
        total_loss = 0
        total_ctc_loss = 0
        for batch_idx, (tokens, targets) in enumerate(self.train_loader):
            self.optimizer.zero_grad()
            if cuda:
                tokens = tokens.cuda()
                targets = targets.cuda()
            # input_lengths = torch.full(size=(tokens.shape[0],), 
            #                 fill_value=tokens.shape[1], 
            #                 dtype=torch.long)
            target_lengths = (targets != 0).sum(dim=1)
            ctc = self.model(tokens.to(dtype=torch.long), targets.to(dtype=torch.long))
            input_lengths = (tokens != 0).sum(dim=1)
            loss = 0
            ctc_loss = self.ctc_criterion(ctc.permute(1, 0, 2).to(dtype=torch.float), targets.to(dtype=torch.long), 
                                          input_lengths=input_lengths, target_lengths=target_lengths)
            total_ctc_loss += ctc_loss.item()
            loss += ctc_loss
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), clip)
            self.optimizer.step()

            total_loss += loss.item()

            self.log_train(total_loss / (batch_idx + 1))
            print('\rTrain loss: %4f, Batch: %d of %d' % (
                total_loss / (batch_idx + 1), batch_idx + 1, len(self.train_loader)), end='')
        print()
        loss = total_loss / len(self.train_loader)
        ctc_loss = total_ctc_loss / len(self.train_loader)
        return loss

    def test_epoch(self, cuda=True):
        if cuda:
            self.model.cuda()
        else:
            self.model.cpu()
        with torch.no_grad():
            self.model.eval()
            total_loss = 0
            total_ctc_loss = 0
            for batch_idx, (tokens, targets) in enumerate(self.val_loader):
                if cuda:
                    tokens = tokens.cuda()
                    targets = targets.cuda()
                # input_lengths = torch.full(size=(tokens.shape[0],), 
                #                 fill_value=tokens.shape[1], 
                #                 dtype=torch.long)
                target_lengths = (targets != 0).sum(dim=1)
                ctc = self.model(tokens.to(dtype=torch.long), targets.to(dtype=torch.long))
                input_lengths = (tokens != 0).sum(dim=1)

                loss = 0 
                ctc_loss = self.ctc_criterion(ctc.permute(1, 0, 2).to(dtype=torch.float), targets.to(dtype=torch.long), 
                                              input_lengths=input_lengths, target_lengths=target_lengths)
                total_ctc_loss += ctc_loss.item()
                loss += ctc_loss
                total_loss += loss.item()

                self.log_test(total_loss / (batch_idx + 1))
                print('\rVal loss: %4f, Batch: %d of %d' % (
                    total_loss / (batch_idx + 1), batch_idx + 1, len(self.val_loader)), end='')
            print()
            loss = total_loss / len(self.val_loader)
            ctc_loss = total_ctc_loss / len(self.val_loader)
            return loss

    def output(self, cuda=True):
        self.model.eval()
        if cuda:
            self.model.cuda()
        else:
            self.model.cpu()
        tokens, targets = next(iter(self.val_loader))
        tokens = tokens[1:2].to(dtype=torch.long)
        targets = targets[1:2].to(dtype=torch.long)
        batch_size = tokens.shape[0]
        if cuda:
            tokens = tokens.cuda()
            targets = targets.cuda()
        ctc = self.model(tokens, targets)
        ctc = ctc.argmax(dim=-1)
        summ = '<SOS>'
        ctc_sent = []
        for di in range(1, targets.shape[1]):
            summ += vocabs[1].i2t[targets[0, di].cpu().detach().squeeze().item()] + ' '
        for di in range(ctc.shape[1]):
            ctc_sent.append(vocabs[1].i2t[ctc[0, di].cpu().detach().squeeze().item()])

        print(summ[:-1])
        print(' '.join(decode_ctc(ctc_sent)))

    @staticmethod
    def log(epoch, train_loss, test_loss):
        wandb.log({
            'train': {
                'loss': train_loss
            },
            'val': {
                'loss': test_loss
            },
            'epoch': epoch
        })

    @staticmethod
    def log_train(train_loss):
        wandb.log({
            'train': {
                'loss': train_loss
            }
        })

    @staticmethod
    def log_test(test_loss):
        wandb.log({
            'test': {
                'loss': test_loss
            }
        })

    def checkpoint(self, epoch):
        torch.save(self.model.state_dict(), os.path.join(self.save_path, self.name + str(epoch) + '.ckpt'))

    def fit(self, max_epochs: int = 11, cuda=True, clip=1, log=False):
        for epoch in range(max_epochs):
            if self.save_every and epoch % self.save_every == 0:
                self.checkpoint(epoch)
            print('\rEpoch: %d' % epoch)
            self.output(cuda=cuda)
            train_loss = self.train_epoch(cuda=cuda, clip=clip)
            test_loss = self.test_epoch(cuda=cuda)
            if log:
                self.log(epoch, train_loss, test_loss)

In [74]:
class TrTransformerCTCrainer:
    def __init__(self, model: nn.Module, train_loader, val_loader, tf=0.25, lr=3e-4, betas=(0.9, 0.999),
                 project="ctc_translation", name='ctc_model', save_every=None, save_path='./'):
        self.ctc_criterion = nn.CTCLoss(blank=vocabs[0].t2i['<CTC>'], zero_infinity=True)
        self.optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=betas)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.model = model
        self.save_path = save_path
        self.save_every = save_every
        self.name = name
        self.project = project
        self.tf = tf
        wandb.init(project=project, name=name)

    def train_epoch(self, cuda=True, clip=1):
        if cuda:
            self.model.cuda()
        else:
            self.model.cpu()
        self.model.train()
        total_loss = 0
        total_ctc_loss = 0
        for batch_idx, (tokens, targets) in enumerate(self.train_loader):
            self.optimizer.zero_grad()
            if cuda:
                tokens = tokens.cuda()
                targets = targets.cuda()
            # input_lengths = torch.full(size=(tokens.shape[0],), 
            #                            fill_value=tokens.shape[1], 
            #                            dtype=torch.long)
            target_lengths = (targets != 0).sum(dim=1)
            ctc = self.model(tokens.to(dtype=torch.long).permute(1, 0), 
                             generate_square_subsequent_mask(tokens.shape[1], cuda))
            input_lengths = (tokens != 0).sum(dim=1)
            loss = 0
            ctc_loss = self.ctc_criterion(ctc.permute(1, 0, 2).to(dtype=torch.float), targets.to(dtype=torch.long), 
                                          input_lengths=input_lengths, target_lengths=target_lengths)
            total_ctc_loss += ctc_loss.item()
            loss += ctc_loss
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), clip)
            self.optimizer.step()

            total_loss += loss.item()

            self.log_train(total_loss / (batch_idx + 1))
            print('\rTrain loss: %4f, Batch: %d of %d' % (
                total_loss / (batch_idx + 1), batch_idx + 1, len(self.train_loader)), end='')
        print()
        loss = total_loss / len(self.train_loader)
        ctc_loss = total_ctc_loss / len(self.train_loader)
        return loss

    def test_epoch(self, cuda=True):
        if cuda:
            self.model.cuda()
        else:
            self.model.cpu()
        with torch.no_grad():
            self.model.eval()
            total_loss = 0
            total_ctc_loss = 0
            for batch_idx, (tokens, targets) in enumerate(self.val_loader):
                if cuda:
                    tokens = tokens.cuda()
                    targets = targets.cuda()
                # input_lengths = torch.full(size=(tokens.shape[0],), 
                #                            fill_value=tokens.shape[1], 
                #                            dtype=torch.long)
                target_lengths = (targets != 0).sum(dim=1)
                ctc = self.model(tokens.to(dtype=torch.long).permute(1, 0), 
                                 generate_square_subsequent_mask(tokens.shape[1], cuda))
                input_lengths = (tokens != 0).sum(dim=1)

                loss = 0 
                ctc_loss = self.ctc_criterion(ctc.permute(1, 0, 2).to(dtype=torch.float), targets.to(dtype=torch.long), 
                                              input_lengths=input_lengths, target_lengths=target_lengths)
                total_ctc_loss += ctc_loss.item()
                loss += ctc_loss
                total_loss += loss.item()
            
                self.log_test(total_loss / (batch_idx + 1))
                print('\rVal loss: %4f, Batch: %d of %d' % (
                    total_loss / (batch_idx + 1), batch_idx + 1, len(self.val_loader)), end='')
            print()
            loss = total_loss / len(self.val_loader)
            ctc_loss = total_ctc_loss / len(self.val_loader)
            return loss

    def output(self, cuda=True):
        self.model.eval()
        if cuda:
            self.model.cuda()
        else:
            self.model.cpu()
        tokens, targets = next(iter(self.val_loader))
        tokens = tokens[1:2].to(dtype=torch.long)
        targets = targets[1:2].to(dtype=torch.long)
        batch_size = tokens.shape[0]
        if cuda:
            tokens = tokens.cuda()
            targets = targets.cuda()
        ctc = self.model(tokens.permute(1, 0), 
                         generate_square_subsequent_mask(tokens.shape[1], cuda))
        ctc = ctc.argmax(dim=-1)
        summ = '<SOS>'
        ctc_sent = []
        for di in range(1, targets.shape[1]):
            summ += vocabs[1].i2t[targets[0, di].cpu().detach().squeeze().item()] + ' '
        for di in range(ctc.shape[1]):
            ctc_sent.append(vocabs[1].i2t[ctc[0, di].cpu().detach().squeeze().item()])

        print(summ[:-1])
        print(' '.join(decode_ctc(ctc_sent)))

    @staticmethod
    def log(epoch, train_loss, test_loss):
        wandb.log({
            'train': {
                'loss': train_loss
            },
            'val': {
                'loss': test_loss
            },
            'epoch': epoch
        })

    @staticmethod
    def log_train(train_loss):
        wandb.log({
            'train': {
                'loss': train_loss
            }
        })

    @staticmethod
    def log_test(test_loss):
        wandb.log({
            'test': {
                'loss': test_loss
            }
        })

    def checkpoint(self, epoch):
        torch.save(self.model.state_dict(), os.path.join(self.save_path, self.name + str(epoch) + '.ckpt'))

    def fit(self, max_epochs: int = 11, cuda=True, clip=1, log=False):
        for epoch in range(max_epochs):
            if self.save_every and epoch % self.save_every == 0:
                self.checkpoint(epoch)
            print('\rEpoch: %d' % epoch)
            self.output(cuda=cuda)
            train_loss = self.train_epoch(cuda=cuda, clip=clip)
            test_loss = self.test_epoch(cuda=cuda)
            if log:
                self.log(epoch, train_loss, test_loss)

In [66]:
import gc
gc.collect()

0

In [None]:
ctc_only_model = Seq2CTC(len(vocabs[0].t2i), len(vocabs[1].t2i))
# ctc_only_model.load_state_dict(torch.load('/content/drive/MyDrive/CTC_in_NLP/checkpoints/bilstm_ctc_only_25610.ckpt'))
ctc_only_trainer = TrLSTMTCTCrainer(ctc_only_model, train_loader, val_loader, 
                                  lr=3e-4, save_every=1, name='bilstm_ctc_only_256_de_en', 
                                  save_path=f'/content/drive/MyDrive/CTC_in_NLP/checkpoints')
ctc_only_trainer.fit(log=True, max_epochs=5)

In [None]:
ctc_only_transformer_model = Transformer2CTC(len(vocabs[0].t2i), len(vocabs[1].t2i))
# encoder_index_dim, decoder_index_dim, heads, hidden_dim, layers
# ctc_only_transformer_model.load_state_dict(torch.load('/content/drive/MyDrive/CTC_in_NLP/checkpoints/transformer_ctc_only_N5.ckpt'))
ctc_only_trainer = TrTransformerCTCrainer(ctc_only_transformer_model, train_loader, val_loader, 
                                  lr=3e-4, save_every=1, name='transformer_ctc_only_de_en', 
                                  save_path=f'/content/drive/MyDrive/CTC_in_NLP/checkpoints')
ctc_only_trainer.fit(log=True, max_epochs=5)

In [None]:
ctc_only_t2t_model = Transformer2TransformerCTC(len(vocabs[0].t2i), len(vocabs[1].t2i))
# encoder_index_dim, decoder_index_dim, heads, hidden_dim, layers
# ctc_only_transformer_model.load_state_dict(torch.load('/content/drive/MyDrive/CTC_in_NLP/checkpoints/transformer_ctc_only_N5.ckpt'))
ctc_only_trainer = TrTransformerCTCrainer(ctc_only_t2t_model, train_loader, val_loader, 
                                  lr=3e-4, save_every=1, name='transformer2transformer_ctc_only_de_en', 
                                  save_path=f'/content/drive/MyDrive/CTC_in_NLP/checkpoints')
ctc_only_trainer.fit(log=True, max_epochs=5)

In [87]:
def cut_on_eos(tokens_iter):
    for token in tokens_iter:
        if token == '<EOS>':
            break
        yield token

def remove_tech_tokens(tokens_iter, tokens_to_remove=['<SOS>', '<UNK>', '<PAD>']):
    return [x for x in tokens_iter if x not in tokens_to_remove]

def generate_translation(src, trg, model):
    model.eval()

    ctc = model(src, trg)
    output = ctc[:, 1:].argmax(-1)

    source = remove_tech_tokens(cut_on_eos([vocabs[0].i2t[x] for x in list(src[0, :].cpu().detach().numpy())]))
    original = remove_tech_tokens(cut_on_eos([vocabs[1].i2t[x] for x in list(trg[0, :].cpu().detach().numpy())]))
    generated = decode_ctc(remove_tech_tokens(cut_on_eos([vocabs[1].i2t[x] for x in list(output[0, :].cpu().detach().squeeze().numpy())])))
    
    print('Source: {}'.format(' '.join(source)))
    print('Original: {}'.format(' '.join(original)))
    print('Generated: {}'.format(' '.join(generated)))
    print()

def generate_trf_translation(src, trg, model):
    model.eval()

    ctc = model(src.permute(1, 0).cpu(), 
                generate_square_subsequent_mask(src.shape[1], False))
    output = ctc[:, 1:].argmax(-1)

    source = remove_tech_tokens(cut_on_eos([vocabs[0].i2t[x] for x in list(src[0, :].cpu().detach().numpy())]))
    original = remove_tech_tokens(cut_on_eos([vocabs[1].i2t[x] for x in list(trg[0, :].cpu().detach().numpy())]))
    generated = decode_ctc(remove_tech_tokens(cut_on_eos([vocabs[1].i2t[x] for x in list(output[0, :].cpu().detach().squeeze().numpy())])))
    
    print('Source: {}'.format(' '.join(source)))
    print('Original: {}'.format(' '.join(original)))
    print('Generated: {}'.format(' '.join(generated)))
    print()

def get_text(x):
     generated = remove_tech_tokens(cut_on_eos([vocabs[1].i2t[elem] for elem in list(x)]))
     return generated


def calculate_ctc_bleu(model, loader):
    model.eval()
    generated = []
    original = []
    for src, trg in loader:
        ctc = model(src.cuda(), trg.cuda())
        output = ctc[:, 1:].argmax(-1)

        original.extend([decode_ctc(get_text(x)) for x in trg.cpu().numpy()])
        generated.extend([decode_ctc(get_text(x)) for x in output.detach().cpu().squeeze().numpy()])
    
    return corpus_bleu([[text] for text in original], generated) * 100


def calculate_ctc_trf_bleu(model, loader):
    model.eval()
    generated = []
    original = []
    for src, trg in loader:
        ctc = model(src.permute(1, 0).cuda(), 
                    generate_square_subsequent_mask(src.shape[1], True))
        output = ctc[:, 1:].argmax(-1)

        original.extend([decode_ctc(get_text(x)) for x in trg.cpu().numpy()])
        generated.extend([decode_ctc(get_text(x)) for x in output.detach().cpu().squeeze().numpy()])
    
    return corpus_bleu([[text] for text in original], generated) * 100

In [77]:
%%time
calculate_ctc_bleu(ctc_only_model.cuda(), val_loader)

CPU times: user 852 ms, sys: 36 ms, total: 888 ms
Wall time: 892 ms


20.20625366376728

In [78]:
%%time
calculate_ctc_trf_bleu(ctc_only_transformer_model.eval().cuda(), val_loader)

CPU times: user 444 ms, sys: 15 ms, total: 459 ms
Wall time: 462 ms


21.900361510441595

In [86]:
%%time
calculate_ctc_trf_bleu(ctc_only_t2t_model.eval().cuda(), val_loader)

CPU times: user 616 ms, sys: 7.94 ms, total: 624 ms
Wall time: 699 ms


18.34613496451572

In [44]:
%%time
for batch in val_loader:
    generate_translation(batch[0][7:8].cpu(), batch[1][7:8].cpu(), ctc_only_model.cpu())
    break

Source: sie kennen ihr geschäft besser als jeder andere .
Original: they know their business better than anyone else .
Generated: you know your business better than else

CPU times: user 73.8 ms, sys: 2.99 ms, total: 76.8 ms
Wall time: 92.6 ms


In [45]:
%%time
for batch in val_loader:
    generate_trf_translation(batch[0][7:8].cpu(), batch[1][7:8].cpu(), ctc_only_transformer_model.cpu())
    break

Source: sie kennen ihr geschäft besser als jeder andere .
Original: they know their business better than anyone else .
Generated: they know your business better than every else

CPU times: user 141 ms, sys: 105 ms, total: 246 ms
Wall time: 306 ms


In [84]:
%%time
for batch in val_loader:
    generate_trf_translation(batch[0][7:8].cpu(), batch[1][7:8].cpu(), ctc_only_t2t_model.cpu())
    break

Source: sie kennen ihr geschäft besser als jeder andere .
Original: they know their business better than anyone else .
Generated: know business

CPU times: user 202 ms, sys: 12.9 ms, total: 215 ms
Wall time: 220 ms
