In [None]:
import os
# Root every derived path at the absolute path of the current working directory.
os.environ['PROJECT_PATH'] = os.path.abspath(os.curdir)

**Mount Google Drive**

Note: it currently looks impossible to run this notebook on Google Colab. The notebook relies on the `torchtext` package, and `Field.build_vocab` is broken in the latest `torchtext` version that supports Python 3.6, whereas everything works under Python 3.7.

In [None]:
from google.colab import drive
# Colab only: point PROJECT_PATH at the project folder inside the mounted Drive.
# The path contains a space ("My Drive"), so shell cells must quote anything derived from it.
os.environ['PROJECT_PATH']='/content/ydrive/My Drive/Study/UNMT'
# Mount Google Drive under /content/ydrive/.
drive.mount('/content/ydrive/')

In [None]:
def _export_pipeline_settings(env=os.environ):
    """Publish directory layout and corpus names as environment variables.

    Everything is exported through ``env`` (``os.environ`` by default) so that the
    ``%%bash`` cells can read the same values as the Python cells.
    ``PROJECT_PATH`` must already be set.
    """
    root = env['PROJECT_PATH']
    resources = os.path.join(root, 'resources')
    data_dir = os.path.join(resources, 'data')

    env['TOOLS'] = os.path.join(root, 'tools')
    env['RESOURCES'] = resources
    env['DATA'] = data_dir
    env['MODELS'] = os.path.join(root, 'models')

    # DATA VARIABLES
    vocab_size = "32000"
    l1, l2 = 'ba', 'ru'
    l1_data, l2_data = "ba.sentesized", "news.2016.ru.shuffled"

    env['VOCAB_SIZE'] = vocab_size
    env['L1'] = l1
    env['L2'] = l2
    env['L1_DATA'] = l1_data
    env['L2_DATA'] = l2_data
    # Prepared (BPE-segmented) monolingual files carry the vocabulary size as a suffix.
    env['L1_DATA_PREPARED'] = os.path.join(data_dir, "{}.{}".format(l1_data, vocab_size))
    env['L2_DATA_PREPARED'] = os.path.join(data_dir, "{}.{}".format(l2_data, vocab_size))

    env["EMBEDDINGS_DIR"] = os.path.join(resources, "embeddings")
    env["BPE_EMBEDDINGS"] = "{}-{}-bpe".format(l1, l2)

    # Parallel data: raw download -> sentence-split -> prepared (prefix + language code).
    env['L1_DATA_PARALLEL_RAW'] = "raw.parallel.{}".format(l1)
    env['L2_DATA_PARALLEL_RAW'] = "raw.parallel.{}".format(l2)
    env['L1_DATA_PARALLEL'] = "parallel.{}".format(l1)
    env['L2_DATA_PARALLEL'] = "parallel.{}".format(l2)
    parallel_prefix = os.path.join(data_dir, "parallel.{}.".format(vocab_size))
    env['PARALLEL_PREFIX'] = parallel_prefix
    env['L1_DATA_PARALLEL_PREPARED'] = '{}{}'.format(parallel_prefix, l1)
    env['L2_DATA_PARALLEL_PREPARED'] = '{}{}'.format(parallel_prefix, l2)


_export_pipeline_settings()

# Download data

## Parallel


In [None]:
%%bash
# Download the raw parallel files for L1 and L2 from Google Drive into $DATA.
# Every path is quoted because $DATA may contain spaces (e.g. ".../My Drive/..." on Colab).
mkdir -p "$DATA"
wget "https://docs.google.com/uc?export=download&id=1CQnKby8igxidqC3DqC0RRBJ0mfdyQ-T1" --output-document="$DATA/$L1_DATA_PARALLEL_RAW"
wget "https://docs.google.com/uc?export=download&id=1ikD6di7XiR3pWO72aVFbn9HK4-q0Iskb" --output-document="$DATA/$L2_DATA_PARALLEL_RAW"

## Bashkir Language (source)

In [None]:
%%bash
BASHKIR="$DATA/bashkir"

# 1) Bashkir corpus from GitHub: concatenate every .txt file into "$BASHKIR/ba".
git clone https://github.com/nevmenandr/bashkir-corpus "$BASHKIR-corpus"
# -p creates parent and child in order; a backgrounded "mkdir parent &" would race the child mkdir.
mkdir -p "$BASHKIR/raw"
find "$BASHKIR-corpus" -name "*.txt" -print0 | xargs -0 -I file cat file > "$BASHKIR/ba"
# rm -rf -d  "$BASHKIR-corpus"

# 2) Bashkir Wikipedia dump: download once, then extract it to JSON with WikiExtractor.
WIKIEXTRACTOR="$TOOLS/wikiextractor"
git clone https://github.com/ptakopysk/wikiextractor "$WIKIEXTRACTOR"
[ -f "$BASHKIR/bawiki-latest-pages-articles.xml.bz2" ] || wget http://download.wikimedia.org/bawiki/latest/bawiki-latest-pages-articles.xml.bz2 -P "$BASHKIR"
python3 "$WIKIEXTRACTOR/WikiExtractor.py"  --json -o "$BASHKIR/ba_wiki" "$BASHKIR/bawiki-latest-pages-articles.xml.bz2"
# rm "$BASHKIR/bawiki-latest-pages-articles.xml.bz2"

In [None]:
import json

input_folder = os.path.join(os.environ['DATA'], 'bashkir', 'ba_wiki')
output_path = os.path.join(os.environ['DATA'], 'bashkir', 'ba')

# Append the "text" field of every JSON line found under input_folder to the raw Bashkir corpus.
# Context managers guarantee that both files are closed, even if a line fails to parse.
# NOTE: the output is opened in append mode, so re-running this cell appends the wiki text again.
with open(output_path, "a+", encoding='utf-8') as output_file:
    for path, subdirs, files in os.walk(input_folder):
        for name in files:
            with open(os.path.join(path, name), 'r', encoding='utf-8') as wiki_file:
                # Iterate lazily instead of loading the whole file with readlines().
                for line in wiki_file:
                    dump = json.loads(line)
                    # Skip records whose text is empty or consists of newlines only.
                    if dump["text"].strip('\n'):
                        output_file.write("%s\n" % dump["text"])

# !rm -rf -d "$DATA/bashkir/ba_wiki"

## Russian Language

In [None]:
%%bash
# Fetch the Russian monolingual corpus into $DATA and decompress it in place.
ARCHIVE="news.2016.ru.shuffled.gz"
wget "http://data.statmt.org/wmt17/translation-task/${ARCHIVE}" -P "$DATA"
gzip -d "$DATA/${ARCHIVE}"

# Preprocessing

In [None]:
!pip install razdel
from razdel import sentenize

# Input: raw Bashkir monolingual text; output: the sentence-split file written by sentenize_raw_data below.
mono_raw_data_path = os.path.join(os.environ['DATA'], 'bashkir', 'ba')
mono_sentenized_data_path = os.path.join(os.environ['DATA'], 'ba.sentesized')

def sentenize_raw_data(raw_data_path, sentenized_data_path):
    """Split a raw text file into sentences and write them one per line.

    :param raw_data_path: UTF-8 input file; every line is passed to ``sentenize``
    :param sentenized_data_path: UTF-8 output file (truncated if it already exists)
    """
    # Both handles live in one ``with`` block so the output is flushed and closed on return or on error.
    with open(raw_data_path, 'r', encoding='utf-8') as raw_data, \
            open(sentenized_data_path, 'w+', encoding='utf-8') as sentenized_data:
        for line in raw_data:
            # Whitespace-only sentences are skipped.
            sentenized_data.writelines(
                "%s\n" % sentence.text
                for sentence in sentenize(line)
                if sentence.text.strip()
            )

# Sentence-split the Bashkir monolingual corpus first, then the raw parallel file of each language.
sentenize_raw_data(mono_raw_data_path, mono_sentenized_data_path)
for raw_key, split_key in (
    ('L1_DATA_PARALLEL_RAW', 'L1_DATA_PARALLEL'),
    ('L2_DATA_PARALLEL_RAW', 'L2_DATA_PARALLEL'),
):
    sentenize_raw_data(
        os.path.join(os.environ['DATA'], os.environ[raw_key]),
        os.path.join(os.environ['DATA'], os.environ[split_key]),
    )

## Text cleaning and tokenization

In [None]:
!pip install -U sacremoses
from sacremoses import MosesPunctNormalizer, MosesTokenizer

def preprocess_file(filepath, language):
    """Normalize punctuation and tokenize ``filepath`` line by line into ``<filepath>.cleaned``.

    :param filepath: UTF-8 text file to clean; the result is written next to it with a ``.cleaned`` suffix
    :param language: language code handed to ``MosesPunctNormalizer`` and ``MosesTokenizer``
    """
    normalizer = MosesPunctNormalizer(language, pre_replace_unicode_punct=True, post_remove_control_chars=True)
    tokenizer = MosesTokenizer(language)

    # Both handles are managed by ``with`` so the cleaned file is flushed and closed on return or on error.
    with open(filepath, 'r', encoding='utf-8') as input_file, \
            open('%s.cleaned' % filepath, 'w+', encoding='utf-8') as output_file:
        for line in input_file:
            line = normalizer.normalize(line)
            # str.replace returns a new string; without the assignment the removal has no effect.
            # NOTE(review): if the tokenizer HTML-escapes quotes, "&quot;" can reappear in the
            # tokens — confirm whether it should be stripped after tokenizing as well.
            line = line.replace("&quot;", '')
            tokens = tokenizer.tokenize(line)
            # Lines that produce no tokens are dropped from the output.
            if tokens:
                output_file.write("{}\n".format(' '.join(tokens)))

# Clean the two monolingual corpora first, then the two sentence-split parallel files.
for data_key, lang_key in (
    ('L1_DATA', 'L1'),
    ('L2_DATA', 'L2'),
    ('L1_DATA_PARALLEL', 'L1'),
    ('L2_DATA_PARALLEL', 'L2'),
):
    preprocess_file(os.path.join(os.environ['DATA'], os.environ[data_key]), os.environ[lang_key])

## BPE codes

In [None]:
%%bash
# Build fastBPE, learn joint BPE codes on both cleaned monolingual corpora,
# then segment the monolingual and parallel files with those codes.
FASTBPE="$TOOLS/fastBPE"
FAST="$FASTBPE/fast"
CODES="$DATA/BPE_codes"

git clone https://github.com/glample/fastBPE "$FASTBPE"
g++ -std=c++11 -pthread -O3 "$FASTBPE/fastBPE/main.cc" -IfastBPE -o "$FAST"

"$FAST" learnbpe $VOCAB_SIZE "$DATA/${L1_DATA}.cleaned" "$DATA/${L2_DATA}.cleaned" > "$CODES"

# apply_bpe OUTPUT INPUT: segment INPUT with the learned codes and write the result to OUTPUT.
apply_bpe() {
    "$FAST" applybpe "$1" "$2" "$CODES"
}

apply_bpe "${L1_DATA_PREPARED}" "$DATA/${L1_DATA}.cleaned"
apply_bpe "${L2_DATA_PREPARED}" "$DATA/${L2_DATA}.cleaned"
apply_bpe "${L1_DATA_PARALLEL_PREPARED}" "$DATA/${L1_DATA_PARALLEL}.cleaned"
apply_bpe "${L2_DATA_PARALLEL_PREPARED}" "$DATA/${L2_DATA_PARALLEL}.cleaned"

## Cross-lingual Embeddings

In [None]:
%%bash
# Build fastText (once) and train skip-gram embeddings on the shuffled concatenation of both BPE corpora.
FASTTEXT_DIR="$TOOLS/fastText"
FASTTEXT="$FASTTEXT_DIR/fasttext"
git clone https://github.com/facebookresearch/fastText.git "$FASTTEXT_DIR"
# Build inside a subshell: the rest of the cell keeps its working directory, and a failed
# cd can no longer run make somewhere else.
[ -f "$FASTTEXT" ] || (cd "$FASTTEXT_DIR" && make) || exit 1

CONCAT_BPE="$DATA/concatenated.$VOCAB_SIZE"
N_THREADS=$(grep -c ^processor /proc/cpuinfo)
echo "$N_THREADS"
cat "${L1_DATA_PREPARED}" "${L2_DATA_PREPARED}" | shuf > "$CONCAT_BPE"
chmod +x "$FASTTEXT"
# Make sure the output directory exists before training writes into it.
mkdir -p "$EMBEDDINGS_DIR"
"$FASTTEXT" skipgram -dim 256 -thread "$N_THREADS" -input "$CONCAT_BPE" -output "$EMBEDDINGS_DIR/$BPE_EMBEDDINGS"

# Model implementation

## Tools

In [None]:
!pip install torch
import copy
import math
import torch
from torch import nn

#src https://pytorch.org/tutorials/beginner/transformer_tutorial.html
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a sequence-first input, then apply dropout.

    :param d_model: size of the last (feature) dimension of the input; odd sizes are supported
    :param dropout: dropout probability applied after the positions are added
    :param max_len: number of positions precomputed in the ``pe`` buffer
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For an odd d_model there is one cosine column fewer than sine columns.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        # Stored as (max_len, 1, d_model) so it broadcasts over dim 1 of the input in forward().
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Return ``dropout(x + pe[:x.size(0)])``; positions run along dim 0 of ``x``."""
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

def get_module_clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` deep copies of ``module``."""
    clones = nn.ModuleList()
    for _ in range(N):
        clones.append(copy.deepcopy(module))
    return clones

def get_mask(inputs, pad_mask): #[2]
    """Compare position indices against lengths derived from ``pad_mask``.

    ``inputs`` must be 2-D; its first dimension is taken as the sequence length ``slen``.
    Lengths are ``slen - pad_mask.sum(0)``; the result has one row per length, and entry
    ``j`` of a row is True when ``j < length``.
    NOTE(review): summing over dim 0 presumably expects ``pad_mask`` to be laid out like
    ``inputs`` (sequence first) — confirm against the callers.
    """
    seq_len, _ = inputs.size()
    lengths = seq_len - pad_mask.sum(0)
    positions = torch.arange(seq_len, dtype=torch.long, device=lengths.device)
    return positions < lengths[:, None]

## Encoder

In [None]:
import torch.nn.functional as F
from torch.nn import TransformerDecoder, TransformerDecoderLayer, \
                     TransformerEncoder, TransformerEncoderLayer


class Encoder(nn.Module):
    """Transformer encoder built on pretrained embeddings.

    :param field: field whose ``vocab`` supplies the vocabulary size and the pretrained ``vectors``
    :param d_model: model width passed to the positional encoding and every encoder layer
    :param nlayers: number of encoder layers
    :param nheads: number of attention heads per layer
    :param dropout: dropout probability used inside the layers and on the embedded input
    :param freeze_embs: when True the pretrained embeddings are not updated during training
    """

    def __init__(self, field, d_model=256, nlayers=4, nheads=8, dropout=0.1, freeze_embs=False):
        super().__init__()

        self.voc_size = len(field.vocab)

        self.d_model = d_model
        self.dropout = dropout
        # from_pretrained is a classmethod: call it on the class instead of first building a
        # randomly initialised embedding that is thrown away immediately.
        self.embeddings = nn.Embedding.from_pretrained(field.vocab.vectors, freeze=freeze_embs)
        self.pos_encoder = PositionalEncoding(d_model)
        encoder_layer = TransformerEncoderLayer(d_model, nheads, dim_feedforward=4*d_model, dropout=dropout, activation='gelu')
        self.layers = get_module_clones(encoder_layer, nlayers)

    def forward(self, src, pad_mask):
        """Encode ``src``.

        :param src: token indices; embedded and handed sequence-first to the layers
        :param pad_mask: key-padding mask passed to every layer as ``src_key_padding_mask``
        :return: output of the last encoder layer
        """
        x = self.embeddings(src)
        x = self.pos_encoder(x)
        x = F.dropout(x, self.dropout, training=self.training)

        for layer in self.layers:
            # Only padding has to be hidden in the encoder, and the key-padding mask does that.
            # The square matrix from get_mask() is not passed as an attention mask any more: it is
            # derived from pad counts across the batch, so it does not describe which positions
            # may attend to which.
            x = layer(x, src_key_padding_mask=pad_mask)

        return x


In [None]:
class Decoder(nn.Module):
    """Transformer decoder with one layer stack and one output projection per language.

    The first ``shared_nlayers`` layers are the same module objects in both stacks; the
    remaining layers belong to one language only. Token embeddings and the positional
    encoding are taken from the encoder.
    """

    def __init__(self, fields, encoder, d_model=256, nlayers=4, nheads=8, dropout=0.1, shared_nlayers=2):
        """
        :param fields: list of fields for 0: L1, 1: L2, 2: both
        :param encoder: encoder whose ``embeddings`` and ``pos_encoder`` are reused
        :param d_model: model width of every decoder layer
        :param nlayers: number of layers per language stack (shared + specific)
        :param nheads: number of attention heads per layer
        :param dropout: dropout probability inside the layers and on the embedded input
        :param shared_nlayers: number of bottom layers shared by both languages
        """
        super().__init__()
        assert len(fields) == 3
        self.sos_idx = [field.vocab.stoi['<sos>'] for field in fields]
        self.eos_idx = [field.vocab.stoi['<eos>'] for field in fields]
        self.pad_idx = [field.vocab.stoi['<pad>'] for field in fields]
        self.fields = list(fields)
        self.d_model = d_model
        self.dropout = dropout
        self.embeddings = encoder.embeddings
        self.pos_encoder = encoder.pos_encoder

        decoder_layer = TransformerDecoderLayer(d_model, nheads, dim_feedforward=4*d_model, dropout=dropout, activation='gelu')
        shared_layers = get_module_clones(decoder_layer, shared_nlayers)
        # Every language gets its OWN ModuleList: the shared bottom layers (same objects in both
        # lists) followed by freshly cloned language-specific layers. Appending one ModuleList
        # twice and extending it in place would leave both languages with the very same stack.
        self.layers = nn.ModuleList([
            nn.ModuleList(list(shared_layers) + list(get_module_clones(decoder_layer, nlayers - shared_nlayers)))
            for _ in range(2)
        ])

        proj_layers = [nn.Linear(self.embeddings.embedding_dim, len(field.vocab)) for field in fields[:2]]
        self.proj_layers = nn.ModuleList(proj_layers)

    def forward(self, previous_tokens, encoded, enc_pad_mask, lang_id):
        """Score the next token for every position of ``previous_tokens``.

        :param previous_tokens: token indices fed through the shared embeddings, sequence first
        :param encoded: encoder output used as attention memory
        :param enc_pad_mask: passed to every layer as ``memory_key_padding_mask``
        :param lang_id: selects the layer stack and the output projection (0 or 1)
        :return: output of ``proj_layers[lang_id]``, one score vector per position
        """
        x = self.embeddings(previous_tokens)
        x = self.pos_encoder(x)
        x = F.dropout(x, self.dropout, training=self.training)

        # Causal mask: position t may only attend to positions <= t, so that training sees the
        # same information as step-by-step generation in generate_sequence().
        tgt_len = x.size(0)
        causal_mask = torch.triu(
            torch.full((tgt_len, tgt_len), float('-inf'), dtype=x.dtype, device=x.device),
            diagonal=1,
        )

        for layer in self.layers[lang_id]:
            x = layer(x, encoded, tgt_mask=causal_mask, memory_key_padding_mask=enc_pad_mask)

        return self.proj_layers[lang_id](x)

    def generate_sequence(self, encoded, enc_pad_mask, lang_id, sequence_len=128):
        """Greedily decode up to ``sequence_len`` tokens for every batch element.

        :return: ``(decoded, decoded_shared)`` — the same sequences as indices of the
                 language-specific vocabulary and of the shared vocabulary (``fields[2]``);
                 both have ``sequence_len`` rows and are padded after ``<eos>``
        """
        bs = encoded.size(1)
        device = encoded.device
        decoded = torch.full((sequence_len, bs), self.pad_idx[lang_id], dtype=torch.long, device=device)
        # The shared-vocabulary copy is padded with the SHARED pad index.
        decoded_shared = torch.full((sequence_len, bs), self.pad_idx[2], dtype=torch.long, device=device)
        decoded[0] = self.sos_idx[lang_id]
        decoded_shared[0] = self.sos_idx[2]
        # 1 while a sentence is still being generated, 0 once it has produced <eos>.
        unfinished_sents = torch.ones(bs, dtype=torch.long, device=device)

        cur_len = 1
        while cur_len < sequence_len:
            scores = self.forward(decoded_shared[:cur_len], encoded, enc_pad_mask, lang_id)
            # Only the scores of the last position are needed for the next token.
            next_words = scores[-1].argmax(dim=-1)
            assert next_words.size() == (bs,)

            # Finished sentences receive padding instead of the predicted token.
            decoded[cur_len] = next_words*unfinished_sents + self.pad_idx[lang_id]*(1-unfinished_sents)
            decoded_shared[cur_len] = self.specific_vocab_2_encoder(decoded[cur_len], lang_id)
            unfinished_sents.mul_(next_words.ne(self.eos_idx[lang_id]).long())
            cur_len += 1

            if unfinished_sents.max() == 0:
                break

        # Sentences that ran out of room are terminated with <eos> in the last row.
        if cur_len == sequence_len:
            decoded[sequence_len - 1].masked_fill_(unfinished_sents.bool(), self.eos_idx[lang_id])
            decoded_shared[sequence_len - 1].masked_fill_(unfinished_sents.bool(), self.eos_idx[2])

        return decoded, decoded_shared

    def specific_vocab_2_encoder(self, trg, lang_id):
        """Map indices of the ``lang_id`` vocabulary to shared-vocabulary indices via the token strings."""
        specific_itos = self.fields[lang_id].vocab.itos
        shared_stoi = self.fields[2].vocab.stoi

        def replace_index(index):
            return shared_stoi[specific_itos[index]]

        # apply_ only works on CPU tensors, hence the round trip.
        return trg.clone().detach().cpu().apply_(replace_index).to(trg.device)

# Training

1) https://github.com/pytorch/fairseq/blob/7b3df95f287bc0d844f64fe45717123d06dacb97/fairseq/data/noising.py

In [None]:
import numpy as np
# [1]
class Noising:
    """Word-level noise (local shuffling, word dropout, word blanking) for BPE token batches.

    Batches are 2-D index tensors with positions along dim 0 and sentences along dim 1.
    A word is a run of tokens ending at the first token that does not end with ``@@``.
    """

    def __init__(self, vocab):
        """
        Vocab of encoder input
        """
        self.vocab = vocab
        # True for tokens that end a word, i.e. that do not carry the BPE continuation marker.
        self.bpe_ends_mask = np.array([not vocab.itos[i].endswith('@@') for i in range(len(vocab))])

        self.pad_idx = vocab.stoi['<pad>']
        self.sos_idx = vocab.stoi['<sos>']
        self.eos_idx = vocab.stoi['<eos>']
        self.sep_idx = vocab.stoi['<sep>']
        self.mask_idx = vocab.stoi['<mask>']

    def noise(self, inp):
        """Return a noised CPU copy of ``inp``: shuffle, then drop words, then blank words with ``<mask>``."""
        x = inp.cpu()
        pad_mask = x.eq(self.pad_idx)
        lengths = x.size(0) - pad_mask.sum(0)

        x = self.shuffle(x, lengths)
        x, lengths = self.dropout(x, lengths)
        x, lengths = self.dropout(x, lengths, blank_idx=self.mask_idx)
        return x

    def get_word_idx(self, x):
        """Return a numpy array shaped like ``x`` holding, per column, the 0-based word index of every token."""
        bpe_end = self.bpe_ends_mask[x.numpy()]
        # Counting word ends from the back and flipping the count gives increasing word indices.
        word_idx = bpe_end[::-1].cumsum(0)[::-1]
        word_idx = word_idx.max(0)[None, :] - word_idx
        return word_idx

    def dropout(self, x, lengths, dropout_rate=0.1, blank_idx=None):
        """Drop whole words with probability ``dropout_rate``, or replace their tokens by ``blank_idx``.

        :param x: token indices, one sentence per column
        :param lengths: number of leading tokens of each column that belong to the sentence
        :param blank_idx: when None dropped tokens are removed, otherwise they are replaced by this index
        :return: ``(modified_x, modified_lengths)``; ``modified_x`` has the shape of ``x`` and is
                 padded with ``<pad>``
        """
        word_idx = self.get_word_idx(x)
        # <sos> and <eos> must survive the noise.
        protected = (x.eq(self.sos_idx) | x.eq(self.eos_idx)).numpy()

        sentences = []
        for i in range(lengths.size(0)):
            num_words = int(word_idx[:, i].max()) + 1
            keep = np.random.rand(num_words) >= dropout_rate
            keep[0] = True  # the first word (the <sos> symbol) is never dropped
            keep[word_idx[protected[:, i], i]] = True
            words = x[:int(lengths[i]), i].tolist()
            new_s = [
                w if keep[word_idx[j, i]] else blank_idx
                for j, w in enumerate(words)
            ]
            sentences.append([w for w in new_s if w is not None])

        # re-construct input
        modified_lengths = torch.LongTensor([len(sentence) for sentence in sentences])
        modified_x = torch.full((x.size(0), x.size(1)), self.pad_idx, dtype=torch.long)
        for i, sentence in enumerate(sentences):
            modified_x[:len(sentence), i] = torch.tensor(sentence, dtype=torch.long)

        return modified_x, modified_lengths

    def shuffle(self, x, lengths, max_shuffle_distance=3):
        """Locally permute the words of every sentence; ``<sos>`` stays first and ``<eos>`` is excluded.

        ``lengths`` is NOT modified: callers keep using it afterwards and it must still count
        the ``<eos>`` token.
        """
        if max_shuffle_distance == 0:
            return x
        # Out-of-place subtraction; an in-place "-=" would shorten the caller's lengths as well.
        content_lengths = lengths - x.eq(self.eos_idx).sum(0)

        noise = np.random.uniform(
            0,
            max_shuffle_distance,
            size=(x.size(0), x.size(1)),
        )
        noise[x.eq(self.sos_idx).numpy()] = -1 # do not move <sos> symbols
        word_idx = self.get_word_idx(x)

        x2 = x.clone()
        for i in range(content_lengths.size(0)):
            length = int(content_lengths[i])
            # Tokens of one word share the word's noise, so a word always moves as a unit.
            scores = word_idx[:length, i] + noise[word_idx[:length, i], i]
            scores += 1e-6 * np.arange(length)  # keeps the original order inside a word
            permutation = torch.from_numpy(scores.argsort())
            x2[:length, i] = x[:length, i][permutation]
        return x2

In [None]:
from collections import OrderedDict
from torch.optim import Adam

class Trainer:
    """Runs denoising and back-translation updates for an encoder/decoder pair and saves checkpoints."""

    def __init__(self, encoder, decoder, fields, params, logger, clip=1.0, lr=0.0001):
        """
        :param encoder: encoder module, moved to ``params.device``
        :param decoder: decoder module, moved to ``params.device``
        :param fields: list of fields for 0: L1, 1: L2, 2: both
        :param params: object providing ``device``
        :param logger: object with a ``log(message)`` method
        :param clip: maximum gradient norm applied separately to encoder and decoder
        :param lr: learning rate of both Adam optimizers
        """
        self.encoder = encoder.to(params.device)
        self.decoder = decoder.to(params.device)

        self.pad_idx = [field.vocab.stoi['<pad>'] for field in fields]
        self.vocab_sizes = [len(field.vocab) for field in fields]
        self.fields = fields

        # One loss per language; padding of the language-specific vocabulary is ignored.
        self.criterion = [nn.CrossEntropyLoss(ignore_index=pad_idx) for pad_idx in self.pad_idx[:2]]
        self.clip = clip

        self.noising = Noising(fields[2].vocab)

        # NOTE(review): the decoder reuses the encoder's embedding module, so those parameters
        # are presumably registered in both optimizers — confirm this is intended.
        self.enc_optimizer = Adam(encoder.parameters(), lr=lr)
        self.dec_optimizer = Adam(decoder.parameters(), lr=lr)

        self.n_total_iter = 0
        self.epoch = 0

        self.device = params.device
        self.logger = logger

    def get_denoising_loss_weight(self, init_weight=1, decrease_slower_iter=10**5, weight_slower=0.1, set_to_zero_iter=3*10**5):
        """Return the weight of the denoising loss for the current ``n_total_iter``.

        The weight falls linearly from ``init_weight`` to ``weight_slower`` during the first
        ``decrease_slower_iter`` iterations, then linearly to 0 at ``set_to_zero_iter``, and
        stays at 0 afterwards.
        """
        if self.n_total_iter < decrease_slower_iter:
            return init_weight - ((init_weight-weight_slower)/decrease_slower_iter)*self.n_total_iter

        weight = weight_slower - (weight_slower/(set_to_zero_iter - decrease_slower_iter))*(self.n_total_iter - decrease_slower_iter)
        # Without the clamp the weight turns negative past set_to_zero_iter, which would
        # reward a HIGHER denoising loss.
        return max(0.0, weight)

    def backprop(self, loss):
        """Run one optimisation step on ``loss`` with gradient-norm clipping for both modules."""
        self.enc_optimizer.zero_grad()
        self.dec_optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.encoder.parameters(), self.clip)
        nn.utils.clip_grad_norm_(self.decoder.parameters(), self.clip)
        self.enc_optimizer.step()
        self.dec_optimizer.step()

    def denoising_step(self, inp, specific_vocab_inp, lang_id):
        """Train on reconstructing a sentence from its noised version.

        :param inp: batch in the shared vocabulary
        :param specific_vocab_inp: the same batch in the vocabulary of ``lang_id`` (loss target)
        :param lang_id: language of the batch
        :return: ``OrderedDict`` describing the step (the reported loss is the weighted loss)
        """
        x = self.noising.noise(inp).to(self.device)
        pad_mask = x.eq(self.pad_idx[-1]).transpose_(0, 1)
        self.encoder.train()
        self.decoder.train()
        encoded = self.encoder(x, pad_mask)
        # NOTE(review): the decoder history is the NOISED input, whose positions do not line up
        # with the clean targets after word dropout — confirm whether inp[:-1] was intended.
        scores = self.decoder(x[:-1], encoded, pad_mask, lang_id)
        loss = self.criterion[lang_id](scores.view(-1, self.vocab_sizes[lang_id]), specific_vocab_inp[1:].view(-1))

        loss = self.get_denoising_loss_weight() * loss

        self.backprop(loss)

        progress_state = OrderedDict(
            step_type='denoising',
            loss=loss.item(),
            sentences=inp.size(1),
            n_total_iter=self.n_total_iter,
            epoch=self.epoch,
            lang_id=lang_id
            )

        return progress_state

    def backtranslation_step(self, src, specific_vocab_src, trg, src_lang_id, trg_lang_id):
        """Train on reconstructing ``src`` from its generated translation ``trg``.

        :param src: original batch in the shared vocabulary
        :param specific_vocab_src: the same batch in the vocabulary of ``src_lang_id`` (loss target)
        :param trg: generated translation in the shared vocabulary; it is the encoder input here
        :return: ``OrderedDict`` describing the step
        """
        # src -> trg -> src
        self.encoder.train()
        self.decoder.train()

        trg_pad_mask = trg.eq(self.pad_idx[-1]).transpose_(0, 1)
        encoded = self.encoder(trg, trg_pad_mask)

        # The attention memory is the encoded translation, so its padding mask is the one of
        # trg, not the one of src.
        scores = self.decoder(src[:-1], encoded, trg_pad_mask, src_lang_id)
        loss = self.criterion[src_lang_id](scores.view(-1, self.vocab_sizes[src_lang_id]), specific_vocab_src[1:].view(-1))

        self.backprop(loss)

        progress_state = OrderedDict(
            step_type='backtranslation',
            loss=loss.item(),
            sentences=src.size(1),
            n_total_iter=self.n_total_iter,
            epoch=self.epoch,
            backtranslation_direction='{}->{}->{}'.format(src_lang_id, trg_lang_id, src_lang_id)
            )

        return progress_state

    def generate_translation(self, src, lang1_id, lang2_id, train=True):
        """Translate ``src`` with the decoder stack of ``lang2_id``.

        :param src: batch in the shared vocabulary
        :param lang1_id: source language id (not used by the computation)
        :param lang2_id: target language id
        :param train: when True the modules keep their current train/eval mode; when False
                      they are switched to eval mode for the duration of the call
        :return: ``(trg, trg_shared)`` as returned by ``decoder.generate_sequence``
        """
        pad_mask = src.eq(self.pad_idx[-1]).transpose_(0, 1)

        encoder_mode, decoder_mode = self.encoder.training, self.decoder.training
        if not train:
            self.encoder.eval()
            self.decoder.eval()
        try:
            # The result consists of token indices only, so no gradient can flow through it;
            # skipping the autograd graph saves memory in both modes.
            with torch.no_grad():
                encoded = self.encoder(src, pad_mask)
                trg, trg_shared = self.decoder.generate_sequence(encoded, pad_mask, lang2_id)
        finally:
            self.encoder.train(encoder_mode)
            self.decoder.train(decoder_mode)

        return trg, trg_shared

    def save_model(self, dump_dir, name):
        """Save modules, optimizers and counters to ``<dump_dir>/<name>.pth``."""
        os.makedirs(dump_dir, exist_ok=True)
        path = os.path.join(dump_dir, '%s.pth' % name)
        self.logger.log('Saving model to %s ...' % path)
        # Whole objects (not state dicts) are pickled, so loading needs the same class definitions.
        torch.save({
            'encoder': self.encoder,
            'decoder': self.decoder,
            'enc_optimizer': self.enc_optimizer,
            'dec_optimizer': self.dec_optimizer,
            'epoch': self.epoch,
            'n_total_iter': self.n_total_iter,
            'criterion': self.criterion
        }, path)

        
def reload_checkpoint(dump_dir, name, fields, params, logger):
    """Restore encoder, decoder and trainer state from a checkpoint written by ``Trainer.save_model``.

    :param dump_dir: directory containing the checkpoint
    :param name: checkpoint file name; the ``.pth`` suffix may be omitted
    :param fields: list of fields for 0: L1, 1: L2, 2: both (forwarded to ``Trainer``)
    :param params: object providing ``device``
    :param logger: object with a ``log(message)`` method
    :return: ``(encoder, decoder, trainer)``
    :raises FileNotFoundError: if the checkpoint does not exist
    """
    checkpoint_path = os.path.join(dump_dir, name)
    # save_model() appends ".pth" to the name it is given; accept the bare name as well.
    if not os.path.isfile(checkpoint_path) and os.path.isfile(checkpoint_path + '.pth'):
        checkpoint_path += '.pth'
    if not os.path.isfile(checkpoint_path):
        # Returning None here would only fail later, when the caller unpacks the result.
        raise FileNotFoundError('Checkpoint %s does not exist' % checkpoint_path)

    logger.log('Reloading checkpoint from %s ...' % checkpoint_path)
    # NOTE: the checkpoint holds pickled objects; only load files from a trusted source.
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
    checkpoint_data = torch.load(checkpoint_path, map_location=params.device)
    encoder = checkpoint_data['encoder']
    decoder = checkpoint_data['decoder']
    trainer = Trainer(encoder, decoder, fields, params, logger)
    # Replace the freshly created optimizers and counters with the saved ones.
    trainer.enc_optimizer = checkpoint_data['enc_optimizer']
    trainer.dec_optimizer = checkpoint_data['dec_optimizer']
    trainer.epoch = checkpoint_data['epoch']
    trainer.n_total_iter = checkpoint_data['n_total_iter'] + 1
    trainer.criterion = checkpoint_data['criterion']

    logger.log('Checkpoint reloaded. Resuming at epoch %i ...' % trainer.epoch)
    print('Checkpoint reloaded. Resuming at epoch %i ...' % trainer.epoch)

    return encoder, decoder, trainer

In [None]:
!pip install torchtext
from torchtext import data

class CustomDataset(data.Dataset):
    """Monolingual dataset exposing every sentence under two fields.

    Each example carries the same token list as ``encoder_text`` (``fields[0]``) and as
    ``specific_text`` (``fields[1]``). Lines that are empty after removing newlines are skipped.
    """

    def __init__(self, path, fields, newline_eos=True,
                 encoding='utf-8', **kwargs):
        encoder_field, specific_field = fields[0], fields[1]
        named_fields = [('encoder_text', encoder_field), ('specific_text', specific_field)]

        examples = []
        with open(path, encoding=encoding) as corpus:
            for line in corpus:
                if not line.strip('\n'):
                    continue
                tokens = encoder_field.preprocess(line)
                examples.append(data.Example.fromlist([tokens, tokens], named_fields))

        super().__init__(examples, named_fields, **kwargs)

In [None]:
from torchtext.data import BucketIterator, metrics
from torchtext.datasets import TranslationDataset

class Evaluation:
    """BLEU evaluation on a parallel corpus plus translation of single sentences."""

    SPECIAL_TOKENS = ['<pad>', '<eos>', '<sos>']

    def __init__(self, path, exts, fields, trainer, params):
        """
        :param path: Common prefix of paths to the data files for both languages
        :param exts: A tuple containing the extension to path for each language
        :param fields: 0: L1, 1: L2, 2: both
        :param trainer: object providing ``generate_translation``
        :param params: object providing ``device``
        """
        self.fields = fields
        # Both sides are numericalised with the shared field.
        self.dataset = TranslationDataset(path, exts, (fields[2], fields[2]))
        self.iter = BucketIterator(dataset=self.dataset, batch_size=32)
        self.trainer = trainer
        self.device = params.device

    def calculate_score(self, src_lang=0, trg_lang=1):
        """Translate the parallel corpus and score it with BLEU.

        :return: ``(bleu, translations, reference)``; the last two are lists of token lists
                 with special tokens removed (tokens are still BPE units)
        """
        translations = []
        reference = []
        for batch in self.iter:
            # train=False: evaluate without gradients instead of using the training path.
            results, _ = self.trainer.generate_translation(batch.src.to(self.device), src_lang, trg_lang, train=False)
            translations += [[self.fields[trg_lang].vocab.itos[token] for token in sentence] for sentence in results.T.tolist()]
            reference += [[self.fields[2].vocab.itos[token] for token in sentence] for sentence in batch.trg.T.tolist()]

        translations = [[token for token in sentence if token not in self.SPECIAL_TOKENS] for sentence in translations]
        reference = [[token for token in sentence if token not in self.SPECIAL_TOKENS] for sentence in reference]

        # bleu_score expects, for every candidate, a LIST of reference translations.
        score = metrics.bleu_score(translations, [[sentence] for sentence in reference])
        return score, translations, reference

    def generate_translation(self, sentence, src_lang, trg_lang):
        """Translate one sentence and return it as a string with BPE pieces merged.

        :param sentence: source text
                         (NOTE(review): presumably already BPE-segmented like the training data — confirm)
        :return: translated tokens joined by spaces; every completed token is followed by a space
        """
        # The trainer's encoder is fed indices of the shared vocabulary (fields[2]), as in
        # calculate_score, so the input is numericalised with the shared field.
        shared_field = self.fields[2]
        x = shared_field.preprocess(sentence)
        x = shared_field.process([x]).to(self.device)
        translation, _ = self.trainer.generate_translation(x, src_lang, trg_lang, train=False)
        itos = self.fields[trg_lang].vocab.itos
        tokens = [itos[index] for index in translation[:, 0].tolist()]

        translation_str = ''
        for token in tokens:
            if token in self.SPECIAL_TOKENS:
                continue
            if token.endswith("@@"):
                # BPE continuation: glue to the next piece without a space.
                translation_str += token[:-2]
                continue
            translation_str += token + ' '

        return translation_str

    def shared_2_specific(self, trg, lang_id):
        """Map shared-vocabulary indices to indices of the ``lang_id`` vocabulary via the token strings."""
        replace_index = lambda x: self.fields[lang_id].vocab.stoi[self.fields[2].vocab.itos[x]]
        return trg.clone().detach().cpu().apply_(replace_index).to(trg.device)

In [None]:
!pip install dill
import dill
from itertools import zip_longest
from torchtext.datasets import LanguageModelingDataset
from torchtext.vocab import Vectors

class Logger:
    """Minimal logger that appends to a file, or prints to stdout when no path is given."""

    def __init__(self, path=None):
        """
        :param path: log file opened in append mode; its directory is created if missing.
                     With a falsy path every message is printed to stdout.
        """
        if path:
            parent = os.path.dirname(path)
            if parent:
                os.makedirs(parent, exist_ok=True)
            # Explicit UTF-8 so non-ASCII text can be logged regardless of the locale.
            self.log_file = open(path, 'a+', encoding='utf-8')
        else:
            self.log_file = None

    def log(self, info, *extra):
        """Write ``info`` to the log.

        A dict (e.g. ``OrderedDict``) is written as one ``key:`` / value pair per entry followed
        by blank lines. Anything else is converted with ``str``; additional positional
        arguments are appended to the same message, separated by spaces.
        """
        if isinstance(info, dict):
            for k, v in info.items():
                print("%s: " % str(k), file=self.log_file)
                print("%s\n" % str(v), file=self.log_file)

            print('\n\n\n', file=self.log_file)
            for part in extra:
                self.log(part)
        else:
            message = ' '.join(str(part) for part in (info,) + extra)
            print("%s\n" % message, file=self.log_file)

        # Flush so the log is complete even if the run is interrupted.
        if self.log_file:
            self.log_file.flush()

    def close(self):
        """Close the log file, if any; later messages fall back to stdout."""
        if self.log_file:
            self.log_file.close()
            self.log_file = None


def main(params):
    """Build or reload fields, data and models, then train (``params.train``) or only evaluate.

    :param params: configuration object; the attributes read here are ``log_file``,
                   ``field_path``, ``l1``, ``l2``, ``train``, ``embs_file``, ``embs_dir``,
                   ``sequence_length``, ``l1_data_path``, ``l2_data_path``, ``batch_size``,
                   ``device``, ``checkpoint``, ``dump_dir``, ``n_epoch``,
                   ``save_every_ith_iter`` and, when present, ``clip`` and ``lr``
    """
    logger = Logger(params.log_file)

    # try/finally guarantees that the log file is closed on every exit path.
    try:
        if os.path.isfile(params.field_path):
            # Reuse previously built fields (and their vocabularies).
            with open(params.field_path, "rb") as f:
                TEXT = dill.load(f)
            with open('{}.{}'.format(params.field_path, params.l1), "rb") as f:
                TEXT_L1 = dill.load(f)
            with open('{}.{}'.format(params.field_path, params.l2), "rb") as f:
                TEXT_L2 = dill.load(f)
        else:
            # Vocabularies can only be built from the training data.
            assert params.train
            vectors = Vectors(name=params.embs_file, cache=params.embs_dir)
            logger.log("Loaded Vectors")

            def new_field():
                return data.Field(
                    init_token='<sos>',
                    eos_token='<eos>',
                    fix_length=params.sequence_length
                )

            TEXT, TEXT_L1, TEXT_L2 = new_field(), new_field(), new_field()
            logger.log("Loaded Field")

        # Field order used by every component: 0: L1, 1: L2, 2: both.
        fields = [TEXT_L1, TEXT_L2, TEXT]

        if params.train:
            train_l1_dataset = CustomDataset(params.l1_data_path, (TEXT, TEXT_L1))
            logger.log("Loaded L1 Dataset")
            train_l2_dataset = CustomDataset(params.l2_data_path, (TEXT, TEXT_L2))
            logger.log("Loaded L2 Dataset")

            if not os.path.isfile(params.field_path):
                TEXT.build_vocab(
                    train_l1_dataset,
                    train_l2_dataset,
                    specials=['<sep>', '<mask>'],
                    vectors=vectors
                )

                TEXT_L1.build_vocab(
                    train_l1_dataset,
                    specials=['<sep>', '<mask>']
                )

                TEXT_L2.build_vocab(
                    train_l2_dataset,
                    specials=['<sep>', '<mask>']
                )

                with open(params.field_path, "wb") as f:
                    dill.dump(TEXT, f)
                with open('{}.{}'.format(params.field_path, params.l1), "wb") as f:
                    dill.dump(TEXT_L1, f)
                with open('{}.{}'.format(params.field_path, params.l2), "wb") as f:
                    dill.dump(TEXT_L2, f)

            l1_iter = data.BucketIterator(
                dataset=train_l1_dataset,
                batch_size=params.batch_size,
                shuffle=True,
                device=params.device
            )
            logger.log("Created L1 Iterator")

            l2_iter = data.BucketIterator(
                dataset=train_l2_dataset,
                batch_size=params.batch_size,
                shuffle=True,
                device=params.device
            )
            logger.log("Created L2 Iterator")

        if params.checkpoint:
            encoder, decoder, trainer = reload_checkpoint(params.dump_dir, params.checkpoint, fields, params, logger)
        else:
            encoder = Encoder(TEXT)
            logger.log("Created Encoder")
            decoder = Decoder(fields, encoder)
            logger.log("Created Decoder")

            # Forward the configured optimisation settings instead of silently using Trainer's defaults.
            trainer = Trainer(
                encoder, decoder, fields, params, logger,
                clip=getattr(params, 'clip', 1.0),
                lr=getattr(params, 'lr', 0.0001),
            )
            logger.log("Created Trainer")

            if params.train:
                encoder.train()
                decoder.train()

        languages = {params.l1: 0, params.l2: 1}

        evaluation = Evaluation(os.environ["PARALLEL_PREFIX"], languages.keys(), fields, trainer, params)

        if not params.train:
            # Evaluation only: the training iterators do not exist, so stop here.
            score, _, _ = evaluation.calculate_score()
            logger.log("BLEU score: {}".format(score))
            return

        logger.log("===================TRAINING STARTED===================")
        # NOTE(review): "<=" runs n_epoch + 1 epochs when starting from epoch 0 — confirm this is intended.
        while trainer.epoch <= params.n_epoch:
            logger.log("===================EPOCH%d===================" % trainer.epoch)
            # zip_longest pads the shorter corpus with None once it is exhausted.
            for batches in zip_longest(l1_iter, l2_iter, fillvalue=None):
                print(f'\rIteration {trainer.n_total_iter}')
                for src_id in languages.values():
                    batch = batches[src_id]
                    if batch is None: # there are no batches left for this language
                        continue
                    trg_id = 0 if src_id == 1 else 1
                    src_text = batch.encoder_text # using field with joint vocab
                    src_specific_text = batch.specific_text # using field with specific vocabulary
                    denoising_progress_state = trainer.denoising_step(src_text, src_specific_text, src_id)
                    _, translation_shared = trainer.generate_translation(src_text, src_id, trg_id)
                    bt_progress_state = trainer.backtranslation_step(src_text, src_specific_text, translation_shared, src_id, trg_id)

                if trainer.n_total_iter and trainer.n_total_iter % params.save_every_ith_iter == 0:
                    trainer.save_model(params.dump_dir, 'checkpoint-{}-{}'.format(trainer.epoch, trainer.n_total_iter))
                    logger.log(denoising_progress_state)
                    logger.log(bt_progress_state)
                    # calculate_score also returns all translations; only the score is logged.
                    score, _, _ = evaluation.calculate_score()
                    logger.log("Epoch {} BLEU score: {}".format(trainer.epoch, score))
                trainer.n_total_iter += 1

            trainer.epoch += 1
    finally:
        logger.close()


In [None]:
class Parameters:
    """Run configuration: paths are read from environment variables, the rest is fixed here."""

    def __init__(self):
        env = os.environ
        models_dir = env['MODELS']

        # Embeddings
        self.embs_file = "%s.vec" % env["BPE_EMBEDDINGS"]
        self.embs_dir = env["EMBEDDINGS_DIR"]

        # Dataset
        self.l1 = env['L1']
        self.l2 = env['L2']

        self.l1_parallel_path = env["L1_DATA_PARALLEL_PREPARED"]
        self.l2_parallel_path = env["L2_DATA_PARALLEL_PREPARED"]

        # Monolingual training data: the prepared files with a ".cut" suffix.
        self.l1_data_path = env["L1_DATA_PREPARED"] + '.cut'
        self.l2_data_path = env["L2_DATA_PREPARED"] + '.cut'

        # Training
        self.sequence_length = 128
        self.batch_size = 48
        self.lr = 0.0001
        self.clip = 1.0

        self.n_epoch = 40

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Logs, checkpoints and pickled fields all live in the models directory.
        self.log_file = os.path.join(models_dir, 'log_1.txt')
        self.dump_dir = models_dir
        self.field_path = os.path.join(models_dir, 'TEXT.field')

        self.vocab_size = env['VOCAB_SIZE']

        self.save_every_ith_iter = 1000
        self.checkpoint = None

        self.train = True

# Build the default configuration (Parameters.train is True, so this starts training) and run it.
params = Parameters()
main(params)

### 