In [1]:
import os
import json
from typing import Optional, Tuple, Union, List, Dict, Any
import random
import copy
from tqdm import tqdm
import numpy as np

from datasets import load_dataset, Dataset, DatasetDict, load_metric
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast, MBart50Tokenizer, MBartTokenizer
from transformers import T5Tokenizer, T5ForConditionalGeneration, MT5ForConditionalGeneration
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer, EarlyStoppingCallback
from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput
from transformers import EvalPrediction

import evaluate


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
seed = 42
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
_numpy_rng = np.random.default_rng(seed)
random.seed(seed)
np.random.seed(seed)
torch.use_deterministic_algorithms(False)
os.environ['PYTHONHASHSEED'] = str(seed)

In [3]:
os.environ["WANDB_DISABLED"] = "true"


In [4]:
# Pick the compute device: CUDA when a GPU is visible, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


In [5]:
# Tags used to build the output directory for predictions and metrics.
model_name, experiment = 'mmt_africa', 'zero-shot'

In [6]:
# Inference configuration consumed by `load_params`.
args = dict(
    device=device,
    min_seq_len=2,
    max_seq_len=512,
    num_beams=4,
    truncation=True,
    checkpoint='./models/mmt_africa/checkpoint/mmt_translation.pt',
)

In [7]:
def load_params(args: dict) -> dict:
    """
    Build the inference configuration consumed by `translate`.

    Loads the base seq2seq model and tokenizer, registers the language tags as
    additional special tokens, resizes the embedding matrix to the new vocabulary
    size, restores the fine-tuned weights from ``args['checkpoint']`` and moves the
    model to ``args['device']``.

    Args:
        args: configuration dict.
            Required keys:
              ``'device'`` (torch device) and ``'checkpoint'`` (path to a file
              written with ``torch.save`` whose dict holds ``'model_state_dict'``).
            Optional keys, with defaults in brackets:
              ``'max_seq_len'`` [50], ``'min_seq_len'`` [2], ``'num_beams'`` [4],
              ``'truncation'`` [True], ``'model_repo'`` ['google/mt5-base'] and
              ``'lang_token'`` (language code -> tag token) [the built-in
              mapping below].

    Returns:
        dict with keys ``model``, ``device``, ``max_seq_len``, ``min_seq_len``,
        ``tokenizer``, ``num_beams``, ``lang_token`` and ``truncation``.
    """
    model_repo = args.get('model_repo', 'google/mt5-base')
    # The insertion order of this mapping decides the ids assigned to the new tokens,
    # so it must match the mapping used when the checkpoint was trained.
    lang_token_mapping = args.get('lang_token', {
        'ig': '<ig>',
        'fon': '<fon>',
        'en': '<en>',
        'fr': '<fr>',
        'rw': '<rw>',
        'yo': '<yo>',
        'xh': '<xh>',
        'sw': '<sw>',
    })
    device = args['device']

    tokenizer = AutoTokenizer.from_pretrained(model_repo)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_repo)

    # Register the language tags as special tokens, then grow the embedding matrix
    # so the checkpoint's state dict (which includes the extra rows) fits.
    tokenizer.add_special_tokens({'additional_special_tokens': list(lang_token_mapping.values())})
    model.resize_token_embeddings(len(tokenizer))

    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
    checkpoint = torch.load(args['checkpoint'], map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()  # inference only: make the dropout-free mode explicit

    return {
        'model': model,
        'device': device,
        'max_seq_len': args.get('max_seq_len', 50),
        'min_seq_len': args.get('min_seq_len', 2),
        'tokenizer': tokenizer,
        'num_beams': args.get('num_beams', 4),
        'lang_token': lang_token_mapping,
        'truncation': args.get('truncation', True),
    }


In [8]:
def translate(
    params: dict,
    sentence: str,
    source_lang: str,
    target_lang: str
):
    """
    Translate one sentence into ``target_lang`` with the model held in ``params``.

    The sentence is prefixed with the target-language tag (e.g. ``'<ig>'``) and
    tokenized. It is truncated to ``params['max_seq_len']`` tokens when
    ``params['truncation']`` is set. It is then decoded with beam search.

    Args:
        params: dict produced by `load_params` (keys ``model``, ``tokenizer``,
            ``device``, ``lang_token``, ``max_seq_len``, ``min_seq_len``,
            ``num_beams``, ``truncation``).
        sentence: source text.
        source_lang: source language code. It is only checked for being non-empty;
            the model is conditioned on the target tag alone.
        target_lang: target language code; must be a key of ``params['lang_token']``.

    Returns:
        The decoded translation with special tokens removed, or ``None`` when
        either language code is empty.

    Raises:
        KeyError: if ``target_lang`` has no entry in ``params['lang_token']``.
    """
    if source_lang == '' or target_lang == '':
        return None

    tokenizer = params['tokenizer']
    device = params['device']
    target_lang_token = params['lang_token'][target_lang]

    # A single sentence needs no padding. Padding every input to max_seq_len only
    # made the encoder process hundreds of pad positions per call. The attention mask
    # is passed explicitly instead of relying on generate() to infer it from pad ids.
    encoded = tokenizer(
        str(target_lang_token) + str(sentence),
        return_tensors='pt',
        truncation=params['truncation'],
        max_length=params['max_seq_len'],
    )
    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded['attention_mask'].to(device)

    output_ids = params['model'].generate(
        input_ids,
        attention_mask=attention_mask,
        num_beams=params['num_beams'],
        num_return_sequences=1,
        max_length=params['max_seq_len'],
        min_length=params['min_seq_len'],
    )
    # No per-sentence progress bar here: callers already track corpus-level progress.
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
 

In [9]:
class ReadFiles():
    """
    One translation direction plus its FLORES devtest split.

    ``source_lang``/``target_lang`` are the short codes used in the data directory
    name (``<data_dir>/<source_lang>-<target_lang>/``). The FLORES file prefixes
    (e.g. ``eng_Latn``) are passed to `read_test_flores_data`.
    """

    def __init__(self, source_lang, target_lang):
        self.src_lang = source_lang
        self.tgt_lang = target_lang

    def read_test_flores_data(self, src, tgt, data_dir='./data'):
        """
        Load the aligned FLORES devtest source/target files for this direction.

        Reads ``<data_dir>/<src_lang>-<tgt_lang>/<src>.devtest`` and
        ``.../<tgt>.devtest`` as UTF-8 and stores:

        - ``self.test_src_flores``: list of stripped source sentences;
        - ``self.test_tgt_flores``: list of one-element lists, each holding the
          stripped reference (one reference per prediction, as passed to sacreBLEU);
        - ``self.src_lang_flores`` / ``self.tgt_lang_flores``: the file prefixes.

        Args:
            src: FLORES file prefix of the source side (e.g. ``'eng_Latn'``).
            tgt: FLORES file prefix of the target side (e.g. ``'ibo_Latn'``).
            data_dir: root data directory; defaults to ``'./data'``.

        Raises:
            ValueError: if the two files have different line counts. Previously
                ``zip`` silently dropped the extra lines.
        """
        self.src_lang_flores = src
        self.tgt_lang_flores = tgt
        self.test_src_flores = []
        self.test_tgt_flores = []

        pair_dir = os.path.join(data_dir, f'{self.src_lang}-{self.tgt_lang}')
        # The text contains non-ASCII characters (e.g. Igbo dotted vowels), so never
        # depend on the platform's default encoding.
        with open(os.path.join(pair_dir, f'{src}.devtest'), encoding='utf-8') as src_reader:
            src_lines = src_reader.readlines()
        with open(os.path.join(pair_dir, f'{tgt}.devtest'), encoding='utf-8') as tgt_reader:
            tgt_lines = tgt_reader.readlines()

        if len(src_lines) != len(tgt_lines):
            raise ValueError(
                f'{pair_dir}: {src}.devtest has {len(src_lines)} lines but '
                f'{tgt}.devtest has {len(tgt_lines)}; the files are not aligned'
            )

        self.test_src_flores = [line.strip() for line in src_lines]
        self.test_tgt_flores = [[line.strip()] for line in tgt_lines]
  

In [10]:
def _load_flores_pair(source_lang, target_lang, src_flores, tgt_flores):
    """Create a `ReadFiles` for one direction and load its FLORES devtest split."""
    reader = ReadFiles(source_lang, target_lang)
    reader.read_test_flores_data(src_flores, tgt_flores)
    return reader


# English -> African language
en_ig = _load_flores_pair('en', 'ig', 'eng_Latn', 'ibo_Latn')
en_sw = _load_flores_pair('en', 'sw', 'eng_Latn', 'swh_Latn')
en_yo = _load_flores_pair('en', 'yo', 'eng_Latn', 'yor_Latn')

# African language -> English
ig_en = _load_flores_pair('ig', 'en', 'ibo_Latn', 'eng_Latn')
sw_en = _load_flores_pair('sw', 'en', 'swh_Latn', 'eng_Latn')
yo_en = _load_flores_pair('yo', 'en', 'yor_Latn', 'eng_Latn')


In [11]:
params = load_params(args)  # model, tokenizer and decoding settings used by `translate`



In [12]:
sacrebleu = evaluate.load("sacrebleu")

# EN - IG

In [13]:
# Translate every FLORES devtest source sentence for EN -> IG, one sentence at a time.
translations = [
    translate(
        params=params,
        sentence=src_sentence,
        source_lang=en_ig.src_lang,
        target_lang=en_ig.tgt_lang,
    )
    for src_sentence in tqdm(en_ig.test_src_flores)
]

  0%|                                                                                                                                                                              | 0/1012 [00:00<?, ?it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 2034.10it/s][A
  0%|▏                                                                                                                                                                     | 1/1012 [00:01<19:55,  1.18s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 2116.20it/s][A
  0%|▎                                                                                                                                                                     | 2

In [14]:
# Preview only the first hypotheses; displaying all ~1000 lines floods the notebook.
translations[:10]

['Anyị ugbu a nwere 4-ọnwa-erughị ọnwa nkịta na-abụghị-diabetic na na na-eji ịbụ-diabetes,ka ọ',
 'Dr.Ehud Ur,prọfesọ nke ọgwụ na Dalhousie University na Halifax, Nova Scotia na oche nke ahụ ike na nkà mmụta sayensị nke Canadian Diabetes',
 'Dị ka ụfọdụ ndị ọzọ ndị ọkachamara,ọ na-atụ egwu banyere ma ọ bụrụ na ọrịa shuga nwere ike',
 "On Monday, Sara Danius,isi odeakwụkwọ nke Nobel Committee maka Literature na Swedish Academy, publicly mara ọkwa n'oge a redio usoro na Sveriges Radio na Sweden ",
 'Danius kwuru, Ugbu a,anyị na-eme ihe ọ bụla.M na-akpọ ma ziga ozi ịntanetị ka ya na-akpakọrịta na-enweta ',
 'Tupu mgbe ahụ, Ring si CEO, Jamie Siminoff, kwuru na ụlọ ọrụ malitere mgbe ya doorbell abụghị ekwu okwu si ya shop na',
 'O wuru a WiFi ọnụ ụzọ ntụ, o kwuru.',
 'Siminoff kwuru na ahịa boosted mgbe ya 2013 anya na a Shark Tank episode ebe na-egosi paneelụ na-akwụsị ego na startup.',
 'Na ọgwụgwụ 2017, Siminoff pụtara na ahịa telivishọn ụlọ ọrụ ụlọ ọrụ QVC.',
 'Ahịa nwekwara kpatakwara

In [15]:
# Corpus sacreBLEU for EN -> IG; each reference entry is a one-element list.
bleu_metrics = sacrebleu.compute(
    references=en_ig.test_tgt_flores,
    predictions=translations,
)


In [16]:
# EN -> IG sacreBLEU result: score, n-gram counts/totals/precisions, brevity penalty, lengths.
bleu_metrics

{'score': 6.740920635382656,
 'counts': [8669, 2687, 1052, 440],
 'totals': [20509, 19497, 18485, 17473],
 'precisions': [42.26924764737432,
  13.781607426783609,
  5.691100892615634,
  2.5181708922337322],
 'bp': 0.7052088351826793,
 'sys_len': 20509,
 'ref_len': 27672}

In [17]:
# dump data
# Write the EN -> IG hypotheses (one per line, aligned with the devtest source) and the
# sacreBLEU metrics. The output directory is created if missing. Both files are UTF-8
# and are closed deterministically.
# NOTE(review): the metrics file says "prediction" while the hypotheses file says
# "predictions"; names are kept unchanged for downstream consumers.
assert len(translations) == len(en_ig.test_src_flores), 'translations are stale or incomplete'
_pair = f'{en_ig.src_lang}-{en_ig.tgt_lang}'
_out_dir = f'./data/{_pair}/{model_name}/{experiment}'
os.makedirs(_out_dir, exist_ok=True)

with open(f'{_out_dir}/flores.predictions.{_pair}.{en_ig.tgt_lang}', 'w', encoding='utf-8') as fp:
    for translation in translations:
        fp.write(translation + '\n')

with open(f'{_out_dir}/flores.prediction.{_pair}.{en_ig.tgt_lang}.metrics', 'w', encoding='utf-8') as fp:
    json.dump(bleu_metrics, fp)


# EN-SW

In [18]:
# Translate every FLORES devtest source sentence for EN -> SW, one sentence at a time.
translations = [
    translate(
        params=params,
        sentence=src_sentence,
        source_lang=en_sw.src_lang,
        target_lang=en_sw.tgt_lang,
    )
    for src_sentence in tqdm(en_sw.test_src_flores)
]

  0%|                                                                                                                                                                              | 0/1012 [00:00<?, ?it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1395.78it/s][A
  0%|▏                                                                                                                                                                     | 1/1012 [00:00<09:22,  1.80it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1451.32it/s][A
  0%|▎                                                                                                                                                                     | 2

In [19]:
# Corpus sacreBLEU for EN -> SW; each reference entry is a one-element list.
bleu_metrics = sacrebleu.compute(
    references=en_sw.test_tgt_flores,
    predictions=translations,
)



In [20]:
# EN -> SW sacreBLEU result: score, n-gram counts/totals/precisions, brevity penalty, lengths.
bleu_metrics

{'score': 19.919016279633468,
 'counts': [11835, 5818, 3187, 1792],
 'totals': [22591, 21579, 20567, 19555],
 'precisions': [52.38811916249834,
  26.96139765512767,
  15.495696990324307,
  9.163896701610842],
 'bp': 0.9412419161641801,
 'sys_len': 22591,
 'ref_len': 23959}

In [21]:
# dump data
# Write the EN -> SW hypotheses (one per line, aligned with the devtest source) and the
# sacreBLEU metrics. The output directory is created if missing. Both files are UTF-8
# and are closed deterministically.
# NOTE(review): the metrics file says "prediction" while the hypotheses file says
# "predictions"; names are kept unchanged for downstream consumers.
assert len(translations) == len(en_sw.test_src_flores), 'translations are stale or incomplete'
_pair = f'{en_sw.src_lang}-{en_sw.tgt_lang}'
_out_dir = f'./data/{_pair}/{model_name}/{experiment}'
os.makedirs(_out_dir, exist_ok=True)

with open(f'{_out_dir}/flores.predictions.{_pair}.{en_sw.tgt_lang}', 'w', encoding='utf-8') as fp:
    for translation in translations:
        fp.write(translation + '\n')

with open(f'{_out_dir}/flores.prediction.{_pair}.{en_sw.tgt_lang}.metrics', 'w', encoding='utf-8') as fp:
    json.dump(bleu_metrics, fp)


# EN-YO

In [22]:
# Translate every FLORES devtest source sentence for EN -> YO, one sentence at a time.
translations = [
    translate(
        params=params,
        sentence=src_sentence,
        source_lang=en_yo.src_lang,
        target_lang=en_yo.tgt_lang,
    )
    for src_sentence in tqdm(en_yo.test_src_flores)
]

  0%|                                                                                                                                                                              | 0/1012 [00:00<?, ?it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 2584.29it/s][A
  0%|▏                                                                                                                                                                     | 1/1012 [00:00<12:31,  1.35it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 2610.02it/s][A
  0%|▎                                                                                                                                                                     | 2

In [23]:
# Corpus sacreBLEU for EN -> YO; each reference entry is a one-element list.
bleu_metrics = sacrebleu.compute(
    references=en_yo.test_tgt_flores,
    predictions=translations,
)



In [24]:
# EN -> YO sacreBLEU result: score, n-gram counts/totals/precisions, brevity penalty, lengths.
bleu_metrics

{'score': 2.3839838609685664,
 'counts': [4475, 1123, 372, 134],
 'totals': [16353, 15341, 14329, 13317],
 'precisions': [27.365009478383172,
  7.320252917019751,
  2.5961337148440227,
  1.0062326349778479],
 'bp': 0.4984441746532194,
 'sys_len': 16353,
 'ref_len': 27739}

In [25]:
# dump data
# Write the EN -> YO hypotheses (one per line, aligned with the devtest source) and the
# sacreBLEU metrics. The output directory is created if missing. Both files are UTF-8
# and are closed deterministically.
# NOTE(review): the metrics file says "prediction" while the hypotheses file says
# "predictions"; names are kept unchanged for downstream consumers.
assert len(translations) == len(en_yo.test_src_flores), 'translations are stale or incomplete'
_pair = f'{en_yo.src_lang}-{en_yo.tgt_lang}'
_out_dir = f'./data/{_pair}/{model_name}/{experiment}'
os.makedirs(_out_dir, exist_ok=True)

with open(f'{_out_dir}/flores.predictions.{_pair}.{en_yo.tgt_lang}', 'w', encoding='utf-8') as fp:
    for translation in translations:
        fp.write(translation + '\n')

with open(f'{_out_dir}/flores.prediction.{_pair}.{en_yo.tgt_lang}.metrics', 'w', encoding='utf-8') as fp:
    json.dump(bleu_metrics, fp)


# IG-EN

In [26]:
# Translate every FLORES devtest source sentence for IG -> EN, one sentence at a time.
translations = [
    translate(
        params=params,
        sentence=src_sentence,
        source_lang=ig_en.src_lang,
        target_lang=ig_en.tgt_lang,
    )
    for src_sentence in tqdm(ig_en.test_src_flores)
]

  0%|                                                                                                                                                                              | 0/1012 [00:00<?, ?it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1779.51it/s][A
  0%|▏                                                                                                                                                                     | 1/1012 [00:00<05:57,  2.83it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1673.04it/s][A
  0%|▎                                                                                                                                                                     | 2

In [27]:
# Corpus sacreBLEU for IG -> EN; each reference entry is a one-element list.
bleu_metrics = sacrebleu.compute(
    references=ig_en.test_tgt_flores,
    predictions=translations,
)



In [28]:
# IG -> EN sacreBLEU result: score, n-gram counts/totals/precisions, brevity penalty, lengths.
bleu_metrics

{'score': 14.864954597247952,
 'counts': [11194, 4542, 2277, 1213],
 'totals': [24425, 23413, 22401, 21389],
 'precisions': [45.83009211873081,
  19.39947892196643,
  10.164724789071917,
  5.671139370704568],
 'bp': 0.9879544052726602,
 'sys_len': 24425,
 'ref_len': 24721}

In [29]:
# dump data
# Write the IG -> EN hypotheses (one per line, aligned with the devtest source) and the
# sacreBLEU metrics. The output directory is created if missing. Both files are UTF-8
# and are closed deterministically.
# NOTE(review): the metrics file says "prediction" while the hypotheses file says
# "predictions"; names are kept unchanged for downstream consumers.
assert len(translations) == len(ig_en.test_src_flores), 'translations are stale or incomplete'
_pair = f'{ig_en.src_lang}-{ig_en.tgt_lang}'
_out_dir = f'./data/{_pair}/{model_name}/{experiment}'
os.makedirs(_out_dir, exist_ok=True)

with open(f'{_out_dir}/flores.predictions.{_pair}.{ig_en.tgt_lang}', 'w', encoding='utf-8') as fp:
    for translation in translations:
        fp.write(translation + '\n')

with open(f'{_out_dir}/flores.prediction.{_pair}.{ig_en.tgt_lang}.metrics', 'w', encoding='utf-8') as fp:
    json.dump(bleu_metrics, fp)


# SW-EN

In [30]:
# Translate every FLORES devtest source sentence for SW -> EN, one sentence at a time.
translations = [
    translate(
        params=params,
        sentence=src_sentence,
        source_lang=sw_en.src_lang,
        target_lang=sw_en.tgt_lang,
    )
    for src_sentence in tqdm(sw_en.test_src_flores)
]

  0%|                                                                                                                                                                              | 0/1012 [00:00<?, ?it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 3084.05it/s][A
  0%|▏                                                                                                                                                                     | 1/1012 [00:00<06:14,  2.70it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 2680.07it/s][A
  0%|▎                                                                                                                                                                     | 2

In [31]:
# Corpus sacreBLEU for SW -> EN; each reference entry is a one-element list.
bleu_metrics = sacrebleu.compute(
    references=sw_en.test_tgt_flores,
    predictions=translations,
)



In [32]:
# SW -> EN sacreBLEU result: score, n-gram counts/totals/precisions, brevity penalty, lengths.
bleu_metrics

{'score': 27.19745686918656,
 'counts': [14557, 7886, 4728, 2896],
 'totals': [24293, 23281, 22269, 21257],
 'precisions': [59.92261145185856,
  33.87311541600447,
  21.231308096456956,
  13.62374747142118],
 'bp': 0.9825360498637394,
 'sys_len': 24293,
 'ref_len': 24721}

In [33]:
# dump data
# Write the SW -> EN hypotheses (one per line, aligned with the devtest source) and the
# sacreBLEU metrics. The output directory is created if missing. Both files are UTF-8
# and are closed deterministically.
# NOTE(review): the metrics file says "prediction" while the hypotheses file says
# "predictions"; names are kept unchanged for downstream consumers.
assert len(translations) == len(sw_en.test_src_flores), 'translations are stale or incomplete'
_pair = f'{sw_en.src_lang}-{sw_en.tgt_lang}'
_out_dir = f'./data/{_pair}/{model_name}/{experiment}'
os.makedirs(_out_dir, exist_ok=True)

with open(f'{_out_dir}/flores.predictions.{_pair}.{sw_en.tgt_lang}', 'w', encoding='utf-8') as fp:
    for translation in translations:
        fp.write(translation + '\n')

with open(f'{_out_dir}/flores.prediction.{_pair}.{sw_en.tgt_lang}.metrics', 'w', encoding='utf-8') as fp:
    json.dump(bleu_metrics, fp)


# YO-EN

In [34]:
# Translate every FLORES devtest source sentence for YO -> EN, one sentence at a time.
translations = [
    translate(
        params=params,
        sentence=src_sentence,
        source_lang=yo_en.src_lang,
        target_lang=yo_en.tgt_lang,
    )
    for src_sentence in tqdm(yo_en.test_src_flores)
]

  0%|                                                                                                                                                                              | 0/1012 [00:00<?, ?it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 2693.84it/s][A
  0%|▏                                                                                                                                                                     | 1/1012 [00:00<08:26,  2.00it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 2818.75it/s][A
  0%|▎                                                                                                                                                                     | 2

In [35]:
# Corpus sacreBLEU for YO -> EN; each reference entry is a one-element list.
bleu_metrics = sacrebleu.compute(
    references=yo_en.test_tgt_flores,
    predictions=translations,
)



In [36]:
# YO -> EN sacreBLEU result: score, n-gram counts/totals/precisions, brevity penalty, lengths.
bleu_metrics

{'score': 8.431320882381804,
 'counts': [8576, 2716, 1150, 539],
 'totals': [22853, 21841, 20829, 19817],
 'precisions': [37.52680173281407,
  12.435328052744838,
  5.5211483988669645,
  2.719886965736489],
 'bp': 0.9215116907122233,
 'sys_len': 22853,
 'ref_len': 24721}

In [37]:
# dump data
# Write the YO -> EN hypotheses (one per line, aligned with the devtest source) and the
# sacreBLEU metrics. The output directory is created if missing. Both files are UTF-8
# and are closed deterministically.
# NOTE(review): the metrics file says "prediction" while the hypotheses file says
# "predictions"; names are kept unchanged for downstream consumers.
assert len(translations) == len(yo_en.test_src_flores), 'translations are stale or incomplete'
_pair = f'{yo_en.src_lang}-{yo_en.tgt_lang}'
_out_dir = f'./data/{_pair}/{model_name}/{experiment}'
os.makedirs(_out_dir, exist_ok=True)

with open(f'{_out_dir}/flores.predictions.{_pair}.{yo_en.tgt_lang}', 'w', encoding='utf-8') as fp:
    for translation in translations:
        fp.write(translation + '\n')

with open(f'{_out_dir}/flores.prediction.{_pair}.{yo_en.tgt_lang}.metrics', 'w', encoding='utf-8') as fp:
    json.dump(bleu_metrics, fp)
