## Translation

This notebook is for anyone who wants to use back-translation for data augmentation or other NLP tasks.  

Anyone is welcome to try it!




It provides an offline translation function built on the pretrained facebook/mbart-large-50-many-to-many-mmt model. Long texts (e.g. max_len > 500, or any length) are preprocessed by splitting them into pieces of at most `piece_len` tokens. To run faster, the pieces are translated in batches, so each call to the model processes several pieces at once.
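The splitting and padding step can be sketched in plain Python, independent of the model. This is a minimal sketch that mirrors the notebook's logic; `split_into_pieces` is a hypothetical helper, and the ids in the usage example (start 100, end 2, pad 1) are placeholders, not the real mBART-50 vocabulary ids.

```python
from typing import List, Tuple


def split_into_pieces(
    ids: List[int], piece_len: int, start_id: int, end_id: int, pad_id: int
) -> Tuple[List[List[int]], List[List[int]]]:
    """Split content token ids into pieces of at most piece_len tokens.

    Each piece is wrapped as [start_id] + piece + [end_id]. When there is more than
    one piece, every piece is right-padded with pad_id to piece_len + 2 so the pieces
    can be stacked into one rectangular batch; the attention mask marks padding with 0.
    """
    input_ids, masks = [], []
    single = len(ids) <= piece_len  # short text: one piece, no padding needed
    for i in range(0, len(ids), piece_len):
        piece = [start_id] + ids[i:i + piece_len] + [end_id]
        n_pad = 0 if single else (piece_len + 2) - len(piece)
        input_ids.append(piece + [pad_id] * n_pad)
        masks.append([1] * len(piece) + [0] * n_pad)
    return input_ids, masks


# Usage with placeholder ids: 5 content tokens, pieces of at most 2 tokens.
ids, masks = split_into_pieces([11, 12, 13, 14, 15], piece_len=2, start_id=100, end_id=2, pad_id=1)
print(ids)    # [[100, 11, 12, 2], [100, 13, 14, 2], [100, 15, 2, 1]]
print(masks)  # [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 0]]
```

Only the last piece ever needs padding, because every earlier piece is already full.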

Feel free to use it, and leave any questions in the comments!  


More info:  

Pretrained Model:  
https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt/tree/main  
Data Augmentation:   
[Unsupervised Data Augmentation for Consistency Training](https://arxiv.org/pdf/1904.12848.pdf)


## Download model

In [None]:
import torch 
import pandas as pd
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
df_data=pd.read_csv('../input/jigsaw-toxic-severity-rating/comments_to_score.csv')

In [None]:
df_data.sample(2)

In [None]:
!git clone https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt

In [None]:
!rm ./mbart-large-50-many-to-many-mmt/pytorch_model.bin

In [None]:
!wget https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt/resolve/main/pytorch_model.bin -P ./mbart-large-50-many-to-many-mmt/

In [None]:
!rm ./mbart-large-50-many-to-many-mmt/sentencepiece.bpe.model

In [None]:
!wget https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt/resolve/main/sentencepiece.bpe.model -P ./mbart-large-50-many-to-many-mmt/

In [None]:
!ls -lh ./mbart-large-50-many-to-many-mmt
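The `rm`/`wget` cells above are presumably needed because cloning without Git LFS installed leaves small text pointer files in place of the large binaries. Below is a sketch for checking which files are still pointers. It relies only on the Git LFS pointer format, whose first line is `version https://git-lfs.github.com/spec/v1`. `is_lfs_pointer` and `find_lfs_pointers` are hypothetical helpers, not part of any library.

```python
from pathlib import Path

# First line of every Git LFS pointer file (pointer spec v1).
LFS_HEADER = b"version https://git-lfs.github.com/spec/v1"


def is_lfs_pointer(path) -> bool:
    """Return True if the file at `path` starts like a Git LFS pointer, not real content."""
    with open(path, "rb") as f:
        head = f.read(len(LFS_HEADER))
    return head == LFS_HEADER


def find_lfs_pointers(model_dir) -> list:
    """Return the sorted names of top-level files in model_dir that are still LFS pointers."""
    return sorted(
        p.name for p in Path(model_dir).iterdir() if p.is_file() and is_lfs_pointer(p)
    )


# In the notebook: find_lfs_pointers("./mbart-large-50-many-to-many-mmt")
# should return [] once the real weights and tokenizer model have been downloaded.
```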

In [None]:
ml2en_tokenizer = MBart50TokenizerFast.from_pretrained("./mbart-large-50-many-to-many-mmt")

ml2en_model = MBartForConditionalGeneration.from_pretrained("./mbart-large-50-many-to-many-mmt").to(device)

## Translation function

In [None]:

def trans_module(text, source_language, target_language, piece_len=256, max_batch=8):
    '''
    Translate `text` from source_language to target_language (mBART-50 codes, e.g. "en_XX", "zh_CN").

    piece_len: max number of content tokens per input piece (long texts are split into pieces)
    max_batch: number of pieces translated per generate() call
    '''
    ml2en_tokenizer.src_lang = source_language

    input_id = ml2en_tokenizer.encode(text)

    # mBART-50 wraps the text as [src_lang_code] ... [eos]; keep both special ids
    start_id = [input_id[0]]
    end_id = [input_id[-1]]
    input_id = input_id[1:-1]

    pad_id = ml2en_tokenizer.pad_token_id

    input_id_list = []
    attention_mask_list = []

    # split into pieces of at most piece_len tokens, re-adding the special ids to each
    for i in range(0, len(input_id), piece_len):
        tmp_id = start_id + input_id[i:i + piece_len] + end_id
        if len(input_id) < piece_len:
            # only one piece: no padding needed
            input_id_list.append(tmp_id)
            attention_mask_list.append([1] * len(tmp_id))
            break
        # pad every piece to piece_len + 2 so the batch is rectangular
        n_pad = (piece_len + 2) - len(tmp_id)
        input_id_list.append(tmp_id + [pad_id] * n_pad)
        attention_mask_list.append([1] * len(tmp_id) + [0] * n_pad)

    # translate the pieces, max_batch at a time
    res_pieces = []
    for i in range(0, len(input_id_list), max_batch):
        input_dict = {
            'input_ids': torch.LongTensor(input_id_list[i:i + max_batch]).to(device),
            'attention_mask': torch.LongTensor(attention_mask_list[i:i + max_batch]).to(device),
        }
        # output length follows the model's generation config; pass e.g.
        # max_new_tokens=... here if long pieces come back truncated
        generated_tokens = ml2en_model.generate(
            **input_dict,
            # force the first generated token to be the target language code
            forced_bos_token_id=ml2en_tokenizer.convert_tokens_to_ids(target_language),
        )
        res_pieces += ml2en_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

    # join all pieces once, so pieces from different batches are also space-separated
    return ' '.join(res_pieces)
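The batching-and-joining loop, stripped of the model, can be sketched as follows. `translate_pieces` and the `fake_translate_batch` stand-in (which just upper-cases each piece) are hypothetical. The pieces are plain strings here for readability, whereas the notebook passes token-id tensors to `ml2en_model.generate` and then calls `batch_decode`.

```python
from typing import Callable, List


def translate_pieces(
    pieces: List[str],
    translate_batch: Callable[[List[str]], List[str]],
    max_batch: int = 8,
) -> str:
    """Translate pieces max_batch at a time and join all results with single spaces."""
    results: List[str] = []
    for i in range(0, len(pieces), max_batch):
        # one call per batch of at most max_batch pieces
        results.extend(translate_batch(pieces[i:i + max_batch]))
    # join once at the end so the boundary between two batches also gets a space
    return " ".join(results)


calls = []


def fake_translate_batch(batch: List[str]) -> List[str]:
    """Stand-in for a model call: records the batch size and upper-cases each piece."""
    calls.append(len(batch))
    return [p.upper() for p in batch]


print(translate_pieces(["a b", "c", "d e", "f", "g"], fake_translate_batch, max_batch=2))
# A B C D E F G
print(calls)  # [2, 2, 1]
```

Joining once at the end, rather than once per batch, keeps a space between the last piece of one batch and the first piece of the next.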

## Example

In [None]:
for text in df_data['text'][:2].values:
    print('*'*20)
    print(text)

In [None]:
for text in df_data['text'][:2].values:
    chinese_translated_res=trans_module(text,"en_XX",'zh_CN')
    print('chinese_translated_res:',chinese_translated_res)
    english_translated_res=trans_module(chinese_translated_res,"zh_CN","en_XX")
    print('english_translated_res:',english_translated_res)
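To turn the round trip into augmented training data, each original text can be paired with its back-translation under the same label (if the task has labels). Below is a minimal sketch. It assumes a `translate(text, src, tgt)` callable with the same argument order as `trans_module`. `back_translate_augment` and the lookup-table `toy_translate` are hypothetical, and exist only so the example runs without the model.

```python
from typing import Callable, List, Tuple


def back_translate_augment(
    samples: List[Tuple[str, int]],
    translate: Callable[[str, str, str], str],
    src: str = "en_XX",
    pivot: str = "zh_CN",
) -> List[Tuple[str, int]]:
    """Return the original samples plus one back-translated copy of each, keeping labels."""
    augmented = list(samples)
    for text, label in samples:
        # src -> pivot -> src round trip, same argument order as trans_module
        paraphrase = translate(translate(text, src, pivot), pivot, src)
        if paraphrase and paraphrase != text:  # skip empty or unchanged round trips
            augmented.append((paraphrase, label))
    return augmented


# Toy translator standing in for trans_module: a tiny lookup table.
TABLE = {
    ("en_XX", "zh_CN", "you are rude"): "你很粗鲁",
    ("zh_CN", "en_XX", "你很粗鲁"): "you are impolite",
}


def toy_translate(text: str, src: str, tgt: str) -> str:
    return TABLE.get((src, tgt, text), text)


print(back_translate_augment([("you are rude", 1), ("hello", 0)], toy_translate))
# [('you are rude', 1), ('hello', 0), ('you are impolite', 1)]
```

With the real model, `trans_module` can be passed directly as `translate`.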