In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
import torch
import logging
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
from torch.utils.data import DataLoader

from utils import io
from utils import translate
from model import m2m100_dataset

In [3]:
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')

## Read data

In [4]:
data_dir = 'data/train'

In [5]:
train_file = f'{data_dir}/train.tsv'
data, vocab = io.load_xnli_dataset(train_file)

In [52]:
data_dict = translate.create_languagewise_dict(data)

In [53]:
data.head()

Unnamed: 0,gold_label,premise,hypothesis,language
0,neutral,"At ground level, the asymmetrical cathedral is...",It's hard to find a dramatic view of the cathe...,en
1,contradiction,Hanuman is a beneficent deity predating classi...,Hanuman declared that all the lemurs here need...,en
2,contradiction,All other spending as well as federal revenue ...,None of the federal spending is assumed to grow,en
3,neutral,uh-huh that's interesting well it sounds as th...,That information about graduation rates is int...,en
4,neutral,Some kind of instant recognition on his father...,Did his father recognize him?,en


## Model

In [10]:
model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M")
tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")

model = model.to(device)

## Dataloader

In [11]:
lang_code = 'en'

In [20]:
dataset = m2m100_dataset.M2M100Dataset(data_dict, lang_code, tokenizer, device)

In [23]:
batch_size = 16

In [24]:
dataloader = DataLoader(
        dataset, 
        batch_size=batch_size,
        drop_last=False,
        num_workers=0,
        shuffle=False,
    )

In [25]:
batch = next(iter(dataloader))

In [30]:
batch['label'].tolist()

[1, 0, 0, 1, 1, 2, 1, 1, 1, 0, 0, 0, 0, 0, 1, 2]

In [62]:
tar_code = 'hi'
outputs = model.generate(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'], 
                         forced_bos_token_id=tokenizer.get_lang_id(tar_code), max_new_tokens=500)

In [70]:
tar_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)

In [72]:
#tar_text

## Translation

In [62]:
class TranslateText():
    
    def __init__(self, model, tokenizer, data_dict, vocab, batch_size, device):
        self.model = model.to(device)
        self.tokenizer = tokenizer
        
        self.vocab = vocab
        self.data_dict = data_dict
        
        self.batch_size = batch_size
        
        self.device = device
        
    def run(self, lang_list, save_file):
        
        for src_code in tqdm(lang_list):
            for tar_code in lang_list:
                
                if src_code != tar_code:
                    logging.info(f'Translating {src_code} to {tar_code}')
                    output = self.translate(src_code, tar_code)
                    df = pd.DataFrame(output)
                    df.to_csv(save_file.format(f'{src_code}_{tar_code}'))
        
    
    def translate(self, src_code, tar_code, max_new_tokens=500):
        dataset = m2m100_dataset.M2M100Dataset(self.data_dict, src_code, 
                                               self.tokenizer, self.device)
        
        dataloader = DataLoader(
            dataset, 
            batch_size=batch_size,
            drop_last=False,
            num_workers=0,
            shuffle=False,
        )
        
        tar_texts = {}
        for idx, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
            outputs = model.generate(**batch['premise'], forced_bos_token_id=tokenizer.get_lang_id(tar_code), 
                                      max_new_tokens=max_new_tokens)
            tar_premise = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            l = tar_texts.setdefault('premise', [])
            l.extend(tar_premise)
            
            outputs = model.generate(**batch['hypothesis'], forced_bos_token_id=tokenizer.get_lang_id(tar_code), 
                                      max_new_tokens=max_new_tokens)
            tar_hypothesis = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            l = tar_texts.setdefault('hypothesis', [])
            l.extend(tar_hypothesis)
            
            l = tar_texts.setdefault('label', [])
            l.extend(np.array(self.vocab)[batch['label'].tolist()].tolist())

        return tar_texts
            

In [64]:
log_path = f"{data_dir}/log.txt"
logger = io.set_logger(log_path)

NameError: name 'set_logger' is not defined

In [63]:
trans_text = TranslateText(model, tokenizer, data_dict, vocab, batch_size=256, device=device)

lang_list = ['en', 'hi', 'sw', 'es', 'zh']
save_file = data_dir + '/extended_train_{}.csv'

trans_text.run(lang_list, save_file)

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

## Extended train

In [187]:
train_extended = pd.read_csv(save_file)

In [188]:
train_extended.head()

Unnamed: 0,premise,hypothesis,gold_label,language
0,"At ground level, the asymmetrical cathedral is...",It's hard to find a dramatic view of the cathe...,1,en
1,Hanuman is a beneficent deity predating classi...,Hanuman declared that all the lemurs here need...,0,en
2,All other spending as well as federal revenue ...,None of the federal spending is assumed to grow,0,en
3,uh-huh that's interesting well it sounds as th...,That information about graduation rates is int...,1,en
4,Some kind of instant recognition on his father...,Did his father recognize him?,1,en
