In [1]:
import os
import json
from typing import Optional, Tuple, Union, List, Dict, Any
import random
import copy
from tqdm import tqdm
import numpy as np

from datasets import load_dataset, Dataset, DatasetDict, load_metric
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast, MBart50Tokenizer, MBartTokenizer
from transformers import T5Tokenizer, T5ForConditionalGeneration, MT5ForConditionalGeneration
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer, EarlyStoppingCallback
from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput
from transformers import EvalPrediction

import evaluate


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Reproducibility: pin every random number generator this notebook may touch.
seed = 42

# NOTE(review): PYTHONHASHSEED is normally read at interpreter start-up, so
# setting it here presumably only matters for child processes -- confirm.
os.environ['PYTHONHASHSEED'] = str(seed)

# Python standard library and NumPy (legacy global state plus a Generator).
random.seed(seed)
np.random.seed(seed)
_numpy_rng = np.random.default_rng(seed)

# PyTorch on CPU, and on every visible GPU when CUDA is present.
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Ask cuDNN for deterministic kernels and switch off its auto-tuner.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Strict deterministic-algorithm enforcement is explicitly left off.
torch.use_deterministic_algorithms(False)

In [3]:
# Set the WANDB_DISABLED switch so Weights & Biases reporting stays off for
# this session (presumably read by the transformers/W&B integration -- confirm).
os.environ["WANDB_DISABLED"] = "true"


In [4]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


In [5]:
# Inference configuration consumed by `load_params`.
args = dict(
    # Where the model lives and where the checkpoint tensors are mapped to.
    device=device,
    # Length bounds passed to generation; max_seq_len is also the padded input length.
    min_seq_len=2,
    max_seq_len=128,
    # Beam-search width used during generation.
    num_beams=4,
    # Whether the tokenizer truncates inputs longer than max_seq_len.
    truncation=True,
    # Fine-tuned weights to restore on top of the base model.
    checkpoint='./models/mmt_africa/checkpoint/mmt_translation.pt',
)

In [30]:
def load_params(args: dict) -> dict:
    """
    Build the parameter bundle that `translate` consumes.

    Loads the `google/mt5-base` tokenizer and model, registers the language
    tags as additional special tokens, resizes the model's token embeddings
    to the enlarged vocabulary, restores the weights stored under
    `model_state_dict` in the checkpoint file, and moves the model to the
    requested device.

    Args:
        args: configuration dict. `device` and `checkpoint` are required
            (a missing key raises KeyError). Optional keys and their
            fallbacks: `max_seq_len` (50), `min_seq_len` (2),
            `num_beams` (4) and `truncation` (True).

    Returns:
        dict with the keys `model`, `device`, `max_seq_len`, `min_seq_len`,
        `tokenizer`, `num_beams`, `lang_token` and `truncation`.
    """
    base_repo = 'google/mt5-base'

    # Language code -> tag that is prepended to the input text in `translate`.
    lang_tags = {
        'ig': '<ig>',
        'fon': '<fon>',
        'en': '<en>',
        'fr': '<fr>',
        'rw': '<rw>',
        'yo': '<yo>',
        'xh': '<xh>',
        'sw': '<sw>',
    }

    tokenizer = AutoTokenizer.from_pretrained(base_repo)
    model = AutoModelForSeq2SeqLM.from_pretrained(base_repo)

    # Register the language tags, then grow the embeddings to the new
    # vocabulary size before the checkpoint weights are loaded.
    tokenizer.add_special_tokens(
        {'additional_special_tokens': list(lang_tags.values())}
    )
    model.resize_token_embeddings(len(tokenizer))

    # NOTE(review): torch.load unpickles the file, so `checkpoint` should only
    # ever point at a trusted file.
    checkpoint = torch.load(args['checkpoint'], map_location=args['device'])
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(args['device'])

    # Required settings are read directly; optional ones fall back to defaults.
    return {
        'model': model,
        'device': args['device'],
        'max_seq_len': args.get('max_seq_len', 50),
        'min_seq_len': args.get('min_seq_len', 2),
        'tokenizer': tokenizer,
        'num_beams': args.get('num_beams', 4),
        'lang_token': lang_tags,
        'truncation': args.get('truncation', True),
    }


In [7]:
def translate(
    params: dict,
    sentence: str,
    source_lang: str,
    target_lang: str
):
    """
    Translate `sentence` into the language identified by `target_lang`.

    The target language tag from `params['lang_token']` is prepended to the
    sentence, the text is encoded and padded to `params['max_seq_len']`, and
    the model generates the translation with beam search.

    Args:
        params: bundle produced by `load_params` (model, tokenizer, device,
            sequence-length bounds, beam count, truncation flag, language tags).
        sentence: text to translate.
        source_lang: language code of `sentence`. It is only echoed in the log
            line and checked for being non-empty; it does not affect the
            model input.
        target_lang: language code to translate into; must be a key of
            `params['lang_token']`, otherwise a KeyError is raised.

    Returns:
        The decoded translation, or None when either language code is an
        empty string.
    """
    print(f'src: {source_lang} and tgt: {target_lang}')

    # Nothing to do unless both language codes are provided.
    if source_lang == '' or target_lang == '':
        return None

    tokenizer = params['tokenizer']
    max_len = params['max_seq_len']

    lang_tag = params['lang_token'][target_lang]
    print(f'target_lang_token: {lang_tag}')

    # The language tag is glued directly in front of the sentence text.
    encoded = tokenizer.encode(
        text=str(lang_tag) + str(sentence),
        return_tensors='pt',
        padding='max_length',
        truncation=params['truncation'],
        max_length=max_len,
    )
    # Take the single encoded row, restore the batch dimension, and move it
    # to the configured device.
    batch = encoded[0].unsqueeze(0).to(params['device'])

    generated = params['model'].generate(
        batch,
        num_beams=params['num_beams'],
        num_return_sequences=1,
        max_length=params['max_seq_len'],
        min_length=params['min_seq_len'],
    )

    # Decode the first (only) returned sequence, dropping special tokens.
    decoded = [
        params['tokenizer'].decode(candidates[0], skip_special_tokens=True)
        for candidates in tqdm([generated])
    ]
    return decoded[0]
 

In [31]:
# Load the tokenizer and fine-tuned model once; the resulting bundle is reused
# by every `translate` call below.
params = load_params(args=args)

In [36]:
# English input for the en -> sw direction.
# Alternative example sentence: 'My name is vignesh'
source_text = 'How was your weekend?'

In [37]:
# Translate the English sentence into Swahili.
output = translate(
    params=params,
    sentence=source_text,
    source_lang='en',
    target_lang='sw',
)

src: en and tgt: sw
target_lang_token: <sw>


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 4940.29it/s]


In [35]:
# Display the most recent en -> sw translation result.
output

'Jina langu ni vignesh.'

In [38]:
# Display the most recent en -> sw translation result.
output

'Jinsi ilikuwa mwishoni mwa wiki yako'

In [42]:
# Swahili input for the reverse (sw -> en) direction.
# Alternative example sentence: 'Jina langu ni vignesh.'
source_text_2 = 'Jinsi ilikuwa mwishoni mwa wiki yako'

In [43]:
# Translate the Swahili sentence back into English.
# The source code is 'sw' (Swahili), matching the language codes the model's
# tags are defined for; the previous value 'sh' was a typo.
output_2 = translate(
    params=params,
    sentence=source_text_2,
    source_lang='sw',
    target_lang='en'
)


src: sh and tgt: en
target_lang_token: <en>


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 5497.12it/s]


In [41]:
# Display the most recent sw -> en translation result.
output_2

'My name is vignesh.'

In [44]:
# Display the most recent sw -> en translation result.
output_2

'How it was that weekend'