In [6]:
import torch
from transformers import MarianMTModel, MarianTokenizer

In [7]:
class BacktranslationMachine:
    """Paraphrase text by translating it to a pivot language and back again.

    Two Helsinki-NLP OPUS-MT (MarianMT) checkpoints are loaded at construction
    time: ``opus-mt-{src}-{tgt}`` for the forward pass and
    ``opus-mt-{tgt}-{src}`` for the return pass.

    Args:
        src: Language code of the original text (used in the checkpoint name).
        tgt: Language code of the pivot language (used in the checkpoint name).
        device: Optional ``torch.device`` or device string on which to run both
            models. ``None`` (the default) keeps everything on the CPU.
    """

    def __init__(self, src="en", tgt="de", device=None):
        # Language codes: https://developers.google.com/admin-sdk/directory/v1/languages
        # Background: https://towardsdatascience.com/data-augmentation-in-nlp-using-back-translation-with-marianmt-a8939dfea50a
        self.src = src
        self.tgt = tgt

        forward_checkpoint = f"Helsinki-NLP/opus-mt-{src}-{tgt}"
        backward_checkpoint = f"Helsinki-NLP/opus-mt-{tgt}-{src}"

        self.tokenizer1 = MarianTokenizer.from_pretrained(forward_checkpoint)
        self.model1 = MarianMTModel.from_pretrained(forward_checkpoint)

        self.tokenizer2 = MarianTokenizer.from_pretrained(backward_checkpoint)
        self.model2 = MarianMTModel.from_pretrained(backward_checkpoint)

        self.device = torch.device("cpu" if device is None else device)
        self.model1 = self.model1.to(self.device)
        self.model2 = self.model2.to(self.device)

    def process_text(self, lang_code, text):
        """Prefix every sentence with the ``>>lang_code<<`` target-language marker.

        Args:
            lang_code: Language code placed inside the ``>>...<<`` marker.
            text: An iterable of sentences, or a single sentence as ``str``.

        Returns:
            A list with one prefixed string per input sentence.
        """
        # A bare string would otherwise be iterated character by character.
        if isinstance(text, str):
            text = [text]
        # NOTE(review): the ``>>xx<<`` marker is described for checkpoints with
        # several target languages; confirm it is harmless for single-pair ones.
        return [f">>{lang_code}<< {t}" for t in text]

    def _translate(self, tokenizer, model, lang_code, input_sentence):
        """Translate a batch with one tokenizer/model pair and decode to text."""
        formatted_text = self.process_text(lang_code, input_sentence)
        if not formatted_text:
            # Nothing to translate; skip the tokenizer and model entirely.
            return []
        # truncation=True clips over-long inputs to the tokenizer's maximum
        # length instead of passing them to the model unclipped.
        encoded = tokenizer(
            formatted_text, padding=True, truncation=True, return_tensors="pt"
        ).to(self.device)  # inputs must live on the same device as the model
        generated = model.generate(**encoded)
        return tokenizer.batch_decode(
            generated, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )

    def translation1(self, input_sentence):
        """Translate ``input_sentence`` (list of str, or one str) from src to tgt.

        Returns:
            A list of translated strings, one per input sentence.
        """
        return self._translate(self.tokenizer1, self.model1, self.tgt, input_sentence)

    def translation2(self, input_sentence):
        """Translate ``input_sentence`` (list of str, or one str) from tgt back to src.

        Returns:
            A list of translated strings, one per input sentence.
        """
        return self._translate(self.tokenizer2, self.model2, self.src, input_sentence)

    def backtranslation(self, input_sentence):
        """Round-trip ``input_sentence`` through the pivot language.

        Args:
            input_sentence: A list of sentences, or a single sentence as ``str``.

        Returns:
            A list of back-translated strings, one per input sentence.
        """
        translated_sentence = self.translation1(input_sentence)
        backtranslated_sentence = self.translation2(translated_sentence)
        return backtranslated_sentence

In [8]:
# Language pair for the round trip: English source, Finnish pivot.
# Other pivot codes noted by the author: de, zh.
src, tgt = "en", "fi"
bm = BacktranslationMachine(src=src, tgt=tgt)

In [15]:
# Sample input: a one-element batch (a list holding a single multi-sentence
# string). Adjacent literals are concatenated into that one string.
input_sentence = [
    "Sometimes it can be difficult to think of a creative domain that you'd like to pursue, "
    "but it's all right and don't worry! "
    "We'll figure it out for you. "
    "Here are a few exercises you can do to reconnect to your childhood "
    "and ignite a creative domain."
]

In [16]:
# Run the full round trip (src -> tgt -> src) on the sample batch and print
# the returned back-translated sentences.
bm_generated_sentence = bm.backtranslation(input_sentence)
print(bm_generated_sentence)

['Sometimes it may be difficult to think about the creative area that you want to pursue, but it’s all right! We’ll figure it out for you. Here are some exercises you can do to get back to childhood and set the creative area on fire.']
