In [0]:
!pip install -U -q PyDrive
!pip install sentencepiece fastBPE regex requests sacremoses subword_nmt
!pip install fairseq
!pip install transformers

from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials
from googleapiclient.http import MediaIoBaseDownload
from googleapiclient.discovery import build

import os, re
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from fairseq.sequence_generator import SequenceGenerator

Collecting sentencepiece
[?25l  Downloading https://files.pythonhosted.org/packages/74/f4/2d5214cbf13d06e7cb2c20d84115ca25b53ea76fa1f0ade0e3c9749de214/sentencepiece-0.1.85-cp36-cp36m-manylinux1_x86_64.whl (1.0MB)
[K     |████████████████████████████████| 1.0MB 3.3MB/s 
[?25hCollecting fastBPE
  Downloading https://files.pythonhosted.org/packages/e1/37/f97181428a5d151501b90b2cebedf97c81b034ace753606a3cda5ad4e6e2/fastBPE-0.1.0.tar.gz
Collecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/a6/b4/7a41d630547a4afd58143597d5a49e07bfd4c42914d8335b2a5657efc14b/sacremoses-0.0.38.tar.gz (860kB)
[K     |████████████████████████████████| 870kB 18.8MB/s 
[?25hCollecting subword_nmt
  Downloading https://files.pythonhosted.org/packages/74/60/6600a7bc09e7ab38bc53a48a20d8cae49b837f93f5842a41fe513a694912/subword_nmt-0.3.7-py2.py3-none-any.whl
Building wheels for collected packages: fastBPE, sacremoses
  Building wheel for fastBPE (setup.py) ... [?25l[?25hdone
  Created 

In [0]:
%%bash

#XLM-R model
# wget https://dl.fbaipublicfiles.com/fairseq/models/xlmr.base.tar.gz
# tar -zxvf xlmr.base.tar.gz

#Vanilla Transformer as in Vaswani et. al
wget https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2
tar -xvf wmt14.en-fr.joined-dict.transformer.tar.bz2

wmt14.en-fr.joined-dict.transformer/
wmt14.en-fr.joined-dict.transformer/model.pt
wmt14.en-fr.joined-dict.transformer/dict.en.txt
wmt14.en-fr.joined-dict.transformer/dict.fr.txt
wmt14.en-fr.joined-dict.transformer/bpecodes


IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



In [0]:
%%bash

echo 'Cloning Moses github repository (for tokenization scripts)...'
git clone https://github.com/moses-smt/mosesdecoder.git

echo 'Cloning Subword NMT repository (for BPE pre-processing)...'
git clone https://github.com/rsennrich/subword-nmt.git

SCRIPTS=mosesdecoder/scripts
TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
CLEAN=$SCRIPTS/training/clean-corpus-n.perl
NORM_PUNC=$SCRIPTS/tokenizer/normalize-punctuation.perl
REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl
BPEROOT=subword-nmt/subword_nmt
BPE_TOKENS=40000

URLS=(
    #"http://statmt.org/wmt13/training-parallel-europarl-v7.tgz"
    #"http://statmt.org/wmt13/training-parallel-commoncrawl.tgz"
    #"http://statmt.org/wmt13/training-parallel-un.tgz"
    "http://statmt.org/wmt14/training-parallel-nc-v9.tgz"
    #"http://statmt.org/wmt10/training-giga-fren.tar"
    #"http://statmt.org/wmt14/test-full.tgz"
)
FILES=(
    #"training-parallel-europarl-v7.tgz"
    #"training-parallel-commoncrawl.tgz"
    #"training-parallel-un.tgz"
    "training-parallel-nc-v9.tgz"
    #"training-giga-fren.tar"
    #"test-full.tgz"
)
CORPORA=(
    #"training/europarl-v7.fr-en"
    #"commoncrawl.fr-en"
    #"un/undoc.2000.fr-en"
    "training/news-commentary-v9.fr-en"
    #"giga-fren.release2.fixed"
)

if [ ! -d "$SCRIPTS" ]; then
    echo "Please set SCRIPTS variable correctly to point to Moses scripts."
    exit
fi

src=en
tgt=fr
lang=en-fr
prep=wmt14_en_fr
tmp=$prep/tmp
orig=orig

mkdir -p $orig $tmp $prep
cd $orig

for ((i=0;i<${file=${FILES[i]}
    if [ -f $file ]; then
        echo "$file already exists, skipping download"
    else
        url=${URLS[i]}
        wget "$url"
        if [ -f $file ]; then
            echo "$url successfully downloaded."
        else
            echo "$url not successfully downloaded."
            exit -1
        fi
        if [ ${file: -4} == ".tgz" ]; then
            tar zxvf $file
        elif [ ${file: -4} == ".tar" ]; then
            tar xvf $file
        fi
    fi
done

gunzip giga-fren.release2.fixed.*.gz
cd ..

echo "pre-processing train data..."
for l in $src $tgt; do
    rm $tmp/train.tags.$lang.tok.$l
    for f in "${CORPORA[@]}"; do
        cat $orig/$f.$l | \
            perl $NORM_PUNC $l | \
            perl $REM_NON_PRINT_CHAR | \
            perl $TOKENIZER -threads 8 -a -l $l >> $tmp/train.tags.$lang.tok.$l
    done
done

echo "pre-processing test data..."
for l in $src $tgt; do
    if [ "$l" == "$src" ]; then
        t="src"
    else
        t="ref"
    fi
    grep '<seg id' $orig/test-full/newstest2014-fren-$t.$l.sgm | \
        sed -e 's/<seg id="[0-9]*">\s*//g' | \
        sed -e 's/\s*<\/seg>\s*//g' | \
        sed -e "s/\’/\'/g" | \
    perl $TOKENIZER -threads 8 -a -l $l > $tmp/test.$l
    echo ""
done

echo "splitting train and valid..."
for l in $src $tgt; do
    awk '{if (NR%1333 == 0)  print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/valid.$l
    awk '{if (NR%1333 != 0)  print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/train.$l
done

TRAIN=$tmp/train.fr-en
BPE_CODE=$prep/code
rm -f $TRAIN
for l in $src $tgt; do
    cat $tmp/train.$l >> $TRAIN
done

echo "learn_bpe.py on ${TRAIN}..."
python $BPEROOT/learn_bpe.py -s $BPE_TOKENS < $TRAIN > $BPE_CODE

for L in $src $tgt; do
    for f in train.$L valid.$L test.$L; do
        echo "apply_bpe.py to ${f}..."
        python $BPEROOT/apply_bpe.py -c $BPE_CODE < $tmp/$f > $tmp/bpe.$f
    done
done

perl $CLEAN -ratio 1.5 $tmp/bpe.train $src $tgt $prep/train 1 250
perl $CLEAN -ratio 1.5 $tmp/bpe.valid $src $tgt $prep/valid 1 250

for L in $src $tgt; do
    cp $tmp/bpe.test.$L $prep/test.$L
done

rm -rf mosesdecoder subword-nmt

In [2]:
from transformers import XLMRobertaModel, XLMRobertaTokenizer, PreTrainedEncoderDecoder
from fairseq.models.transformer import TransformerModel

encoder = XLMRobertaModel.from_pretrained('xlm-roberta-base')
en2fr = TransformerModel.from_pretrained('wmt14.en-fr.joined-dict.transformer/',
                                         checkpoint_file='model.pt',
                                         bpe='subword_nmt',
                                         bpe_codes='wmt14.en-fr.joined-dict.transformer/bpecodes')
decoder = [model for name, model in en2fr.named_modules() if name == 'models.0.decoder'][0]

hi2fr = PreTrainedEncoderDecoder(encoder, decoder)
hi2fr.encoder.eval()
hi2fr.decoder.eval()

ModuleNotFoundError: ignored

In [0]:
def load_data(file_id, file_name):
  
  auth.authenticate_user()
  gauth = GoogleAuth()
  gauth.credentials = GoogleCredentials.get_application_default()
  drive = GoogleDrive(gauth)

  if not os.path.exists("./data"):
    os.mkdir("data")

  handle = drive.CreateFile({'id': file_id})
  handle.GetContentFile('data/' + file_name)

  import tarfile
  tar = tarfile.open('data/' + file_name)
  tar.extractall(path='data/')
  tar.close()

load_data('152gnwdwcgvUm8ADDiR-yEky7GTROJB9r', 'parallel.tgz')

In [0]:
tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-base')
text = 'अपने अनुप्रयोग को पहुंचनीयता व्यायाम का लाभ दें'
tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
print(tokenized_text)
hi2fr(tokenized_text)

ERROR! Session/line number was not unique in database. History logging moved to new session 59
[5564, 24949, 7475, 35418, 629, 77291, 80903, 1480, 218438, 641, 28904, 50486]


TypeError: ignored

In [0]:
from fairseq.data import Dictionary
from fairseq.models import FairseqEncoderDecoderModel
from fairseq.models.roberta import XLMRModel

class TranslationModel(FairseqEncoderDecoderModel):

  def __init__(self, encoder, decoder):
    super().__init__(encoder, decoder)

tgt_dict = Dictionary()
tgt_dict.load('wmt14.en-fr.joined-dict.transformer/dict.fr.txt')

<fairseq.data.dictionary.Dictionary at 0x7f8c99b645c0>

In [0]:
xlmr = XLMRModel.from_pretrained('xlmr.base/', checkpoint_file='model.pt')
encoder = [model for name, model in encoder.named_modules() if name == 'model.decoder.sentence_encoder'][0]

en2fr = TransformerModel.from_pretrained('wmt14.en-fr.joined-dict.transformer/',
                                         checkpoint_file='model.pt',
                                         bpe='subword_nmt',
                                         bpe_codes='wmt14.en-fr.joined-dict.transformer/bpecodes')
decoder = [model for name, model in en2fr.named_modules() if name == 'models.0.decoder'][0]

model = TranslationModel(encoder, decoder)
#hi2fr = SequenceGenerator(tgt_dict=tgt_dict, beam_size=5)
#hi2fr.generate([encoder, decoder], source_sentences[0])

loading archive file xlmr.base/
| dictionary: 250001 types
loading archive file wmt14.en-fr.joined-dict.transformer/
| [en] dictionary: 44512 types
| [fr] dictionary: 44512 types
Namespace(activation_dropout=0.0, activation_fn='relu', adam_betas='(0.9, 0.98)', adam_eps=1e-08, adaptive_input=False, adaptive_softmax_cutoff=None, adaptive_softmax_dropout=0, arch='transformer_vaswani_wmt_en_de_big', attention_dropout=0.0, bpe='subword_nmt', bpe_codes='wmt14.en-fr.joined-dict.transformer/bpecodes', bpe_separator='@@', clip_norm=0.0, criterion='label_smoothed_cross_entropy', cross_self_attention=False, data='/content/wmt14.en-fr.joined-dict.transformer', decoder_attention_heads=16, decoder_embed_dim=1024, decoder_embed_path=None, decoder_ffn_embed_dim=4096, decoder_input_dim=1024, decoder_layerdrop=0, decoder_layers=6, decoder_layers_to_keep=None, decoder_learned_pos=False, decoder_normalize_before=False, decoder_output_dim=1024, device_id=0, distributed_backend='nccl', distributed_init_meth

AssertionError: ignored

In [0]:
source = 'data/parallel/IITB.en-hi.hi'
#pivot = 'data/parallel/IITB.en-hi.en'

source_handle = open(source, 'r')
#pivot_handle = open(pivot, 'r')

source_sentences = []
#pivot_sentences = []

# for s, p in zip(source_handle.readlines(), pivot_handle.readlines()):
#   source_sentences.append(s.strip())
#   pivot_sentences.append(p.strip())

for s in source_handle.readlines():
  source_sentences.append(s.strip())

In [0]:
indices = np.random.choice(np.arange(len(source_sentences)), 10)

source_samples = [source_sentences[s] for s in indices]
#pivot_samples = [pivot_sentences[s] for s in indices]

In [0]:
source_tensors = []
#pivot_tensors = []

#for s, p in zip(source_samples, pivot_samples):
for s in source_samples:
  features = encoder.extract_features(encoder.encode(s))
  decoder(features)
  #pivot_tensors.append(xlmr.extract_features(xlmr.encode(p)).squeeze().mean(axis=0))

# source_tensors = torch.stack(source_tensors).cuda().detach().cpu().numpy()
# pivot_tensors = torch.stack(pivot_tensors).cuda().detach().cpu().numpy()

# print(source_tensors.shape, pivot_tensors.shape)

RuntimeError: ignored

In [0]:
np.argmax(np.dot(source_tensors, pivot_tensors.T), axis=1)

array([0, 0, 7, 0, 0, 0, 0, 0, 0, 7])

In [0]:
distance

tensor([[945.2126, 934.2817, 943.2759, 937.4709, 935.9614, 936.4699, 932.5482,
         944.3632, 940.1902, 942.5416],
        [935.0580, 929.8993, 933.1376, 931.8377, 929.7444, 930.6298, 926.9304,
         934.3476, 930.1325, 933.9056],
        [958.2134, 937.9309, 958.1385, 941.0500, 939.2224, 940.2402, 934.7866,
         958.5428, 952.8789, 953.1746],
        [940.7936, 935.0615, 938.5024, 939.7257, 936.7359, 937.9119, 934.5394,
         940.0174, 935.3506, 939.4633],
        [941.1672, 934.5040, 938.8140, 938.1094, 936.5791, 936.7876, 933.4310,
         940.2220, 936.4614, 939.3473],
        [937.8596, 932.8284, 935.6934, 937.1517, 934.4254, 936.2974, 932.2597,
         937.0421, 932.4427, 936.9318],
        [938.9979, 933.3583, 936.6390, 937.5145, 935.0083, 935.7106, 933.7283,
         938.0032, 933.5188, 937.6682],
        [945.4584, 933.4980, 943.9000, 936.2550, 934.6656, 935.3681, 930.7615,
         945.4000, 940.8194, 942.2335],
        [949.7205, 937.3250, 948.4127, 940.5886,