In [None]:
import pandas as pd

# Spanish -> English corpus: semicolon-separated UTF-8 CSV whose
# 'input_text' column holds the Spanish sentences to translate.
df = pd.read_csv(
    '../datasets/cleaned_es_en_dataset.csv',
    sep=';',
    encoding='utf-8',
)
inputs = df.loc[:, 'input_text'].tolist()

In [None]:
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

# Many-to-many mBART-50 checkpoint plus its matching fast tokenizer
model_name = 'facebook/mbart-large-50-many-to-many-mmt'
tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
model = MBartForConditionalGeneration.from_pretrained(model_name)

# NOTE(review): scratch memory readings from earlier runs of this cell.
# Units are presumably GB. Which memory was measured (process vs. system RAM)
# and when each reading was taken are not recorded; TODO confirm.
# In every row, delta = second reading - first reading.
#   10.2 -> 11.7 -> 8.6   delta ~1.5 GB
#    9.1 -> 10.7 -> 6.8   delta ~1.6 GB
#    9.0 -> 10.3 -> 8.6   delta ~1.3 GB
#    9.5 -> 11.2          delta ~1.7 GB

In [None]:
# Translation direction expressed as mBART-50 language codes: Spanish -> English
source_lang, target_lang = 'es_XX', 'en_XX'

# Inputs will be tokenized as Spanish text
tokenizer.src_lang = source_lang

In [None]:
import time

# Translate every input sentence one at a time. For each sentence, record the
# decoded English translation, the wall time spent in model.generate(), and
# the number of *input* (Spanish) tokens. Output tokens are not counted.
# NOTE(review): runtime depends on hardware and dataset size. Measure it
# (e.g. with %%time) rather than leaving a placeholder estimate here.

# The target-language id does not change between iterations, so resolve it once.
# convert_tokens_to_ids avoids relying on the tokenizer exposing a
# lang_code_to_id mapping, which is version-dependent in transformers.
target_lang_id = tokenizer.convert_tokens_to_ids(target_lang)

translations = []
times = []
token_counts = []
for text in inputs:
    encoded_input = tokenizer(text, return_tensors="pt")

    # perf_counter is monotonic and high-resolution; time.time() can jump
    # with system clock adjustments and is not meant for measuring durations.
    start_time = time.perf_counter()
    translated_tokens = model.generate(
        **encoded_input,
        forced_bos_token_id=target_lang_id,
    )
    elapsed = time.perf_counter() - start_time

    translated_text = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
    translations.append(translated_text)
    times.append(elapsed)
    # Batch size is 1, so the sequence length of row 0 is the input token count
    token_counts.append(len(encoded_input.input_ids[0]))

total_time = sum(times)
total_tokens = sum(token_counts)
# Generation time averaged over *input* tokens. An empty dataset yields NaN
# instead of raising ZeroDivisionError.
average_time_per_token = total_time / total_tokens if total_tokens else float('nan')

In [None]:
# Aggregate timing summary, followed by a separator rule
summary_lines = (
    f'Total time taken: {total_time:.6f} seconds',
    f'Total number of tokens: {total_tokens}',
    f'Average time per token: {average_time_per_token:.6f} seconds',
    '=================================================================',
)
print('\n'.join(summary_lines))

# One paragraph per sentence. The trailing '\n' plus print's own newline
# leave a blank line between entries.
for src, tgt, secs, n_tokens in zip(inputs, translations, times, token_counts):
    print(
        f'Spanish text: {src}\n'
        f'English translation: {tgt}\n'
        f'Time taken: {secs:.6f} seconds\n'
        f'Number of tokens: {n_tokens}\n'
    )

In [None]:
from pathlib import Path

# Save the translations (one row per input sentence, in the same order as
# `inputs`) as a semicolon-separated CSV with a single 'translated_text' column.
output_path = Path('../translated-datasets/mbart-large-50-many-to-many-mmt-translated_es_en_dataset.csv')
# to_csv does not create missing directories. Make sure the target folder
# exists so a fresh checkout does not fail after the slow translation step.
output_path.parent.mkdir(parents=True, exist_ok=True)

tr_df = pd.DataFrame(translations, columns=['translated_text'])
tr_df.to_csv(output_path, sep=';', index=False)