In [11]:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import SeamlessM4Tv2Model, AutoProcessor

from src.tokenize import AggregatedTokenizer
from src.generate import AggregatedGenerator
from src.evaluate import SimilarityChecker



In [None]:
device = torch.device("cuda:0") 

tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-3.3B")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-3.3B").to(device)

processor = AutoProcessor.from_pretrained("facebook/seamless-m4t-v2-large")
model2 = SeamlessM4Tv2Model.from_pretrained("facebook/seamless-m4t-v2-large").to(device)

tokenizer3 = AutoTokenizer.from_pretrained("google/madlad400-3b-mt")
model3 = AutoModelForSeq2SeqLM.from_pretrained("google/madlad400-3b-mt").to(device)

In [None]:
def seamless_tokenizer_postprocessing(decoder_inputs):
    input_ids_data = decoder_inputs['input_ids']
    decoder_inputs['input_ids'] = input_ids_data[(input_ids_data != 0) & (input_ids_data != 3)].unsqueeze(0) 


agg_tokenizer = AggregatedTokenizer(
    tokenizers=[
        tokenizer, 
        processor.tokenizer,
        tokenizer3,
    ],
    tokenization_kwargs=[
        dict(),
        dict(src_lang="eng", tgt_lang="rus"),
        dict(),
    ],
    decoder_tokenization_postprocessing=[
        None,
        seamless_tokenizer_postprocessing,
        None
    ]
)

In [None]:
agg_generator = AggregatedGenerator(
    models=[
        model, 
        model2,
        model3,
    ],
    generation_kwargs=[
        dict(),
        dict(generate_speech=False),
        dict(),

    ],
    agg_tokenizer=agg_tokenizer,
    decoder_prompts = [
        "rus_Cyrl",
        "__rus__",
        None,
    ],
    encoder_prompts = [
        None,
        None,
        "<2ru>"
    ]
)

In [None]:
similarity_checker = SimilarityChecker(score_names=['bertscore'])

In [None]:
encoder_input_text = """
Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, 
is an Australian-based airline. It is the largest airline by fleet size 
to use the Virgin brand. It commenced services on 31 August 2000 as 
Virgin Blue, with two aircraft on a single route. It suddenly found 
itself as a major airline in Australia's domestic market after the 
collapse of Ansett Australia in September 2001. The airline has since 
grown to directly serve 32 cities in Australia, from hubs in Brisbane, 
Melbourne and Sydney."
"""

In [None]:
all_single_translations = agg_generator.generate_all_single(encoder_input_text, device=device)
all_single_translations

In [None]:
ensemble_tranlation = agg_generator.generate_agg(encoder_input_text, num_beams=3, max_new_tokens=256, device=device)
ensemble_tranlation

In [5]:
similarity_checker.check_similarity(text=all_single_translations)

{'bertscore': {'precision': 0.7491934100786845,
  'recall': 0.6914637486139933,
  'f1': 0.7189777493476868}}