In [53]:
from transformers import MarianTokenizer, MarianMTModel
import torch
import numpy as np
from tqdm import tqdm

In [54]:
model_name = "Helsinki-NLP/opus-mt-en-hi"
tokenizer = MarianTokenizer.from_pretrained(model_name)
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
model = MarianMTModel.from_pretrained(model_name).to(device)

def translate_batch(sentences):
    inputs = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True, max_length=128).to(model.device)
    outputs = model.generate(**inputs)
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)

def translate_batch_with_progress(sentences):
    batch_size = 10
    total_sentences = len(sentences)
    translations = []
    for i in tqdm(range(0, total_sentences, batch_size), desc="Decoding Progress", total=total_sentences // batch_size):
        batch = sentences[i:i+batch_size]
        batch_translations = translate_batch(batch)
        translations.extend(batch_translations)

    return translations



In [55]:
def load_texts(eng_file, hin_file):
    with open(eng_file, "r", encoding="utf-8") as f:
        eng_sentences = f.readlines()
    with open(hin_file, "r", encoding="utf-8") as f:
        hin_sentences = f.readlines()
    return [s.strip() for s in eng_sentences], [s.strip() for s in hin_sentences]

english_sentences, hindi_sentences = load_texts("/Users/vishalsankarram/Desktop/github/document-level-mt-project/PM India en hi/pmindia.en-hi.en", "/Users/vishalsankarram/Desktop/github/document-level-mt-project/PM India en hi/pmindia.en-hi.hi")

In [56]:
np.random.seed(42)
sample_indices = np.random.choice(len(english_sentences), 100, replace=False)
english_sample = [english_sentences[i].strip() for i in sample_indices]
hindi_sample = [hindi_sentences[i].strip() for i in sample_indices]

In [57]:
import jiwer
transform = jiwer.Compose([
    jiwer.ToLowerCase(),
    jiwer.RemovePunctuation(),
    jiwer.Strip(),
    # jiwer.ReduceToListOfWords()
])
translations=translate_batch_with_progress(english_sample)

Decoding Progress: 100%|██████████| 10/10 [00:48<00:00,  4.85s/it]


In [58]:
with open("output.txt", "w") as file:
    for item in translations:
        file.write(item + "\n")
with open("ref.txt", "w") as file:
    for item in hindi_sample:
        file.write(item + "\n")

In [59]:
wer = jiwer.wer(hindi_sample, translations)
print(f"WER: {wer:.4f}")

WER: 0.9376


In [60]:
def calculate_ser(reference, hypothesis):
    assert len(reference) == len(hypothesis), "Reference and hypothesis lists must have the same length"
    error_count = 0
    for ref, hyp in zip(reference, hypothesis):
        if ref.strip().lower() != hyp.strip().lower():
            error_count += 1
    ser = error_count / len(reference)
    return ser

ser = calculate_ser(hindi_sample, translations)

print(f"SER: {ser:.4f}")

SER: 1.0000


In [61]:
from transformers import MarianMTModel, MarianTokenizer, DataCollatorForSeq2Seq, Trainer, TrainingArguments
import torch

if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using MPS device")
else:
    device = torch.device("cpu")
    print("MPS device not found, using CPU")


model_name = "Helsinki-NLP/opus-mt-en-hi"
model = MarianMTModel.from_pretrained(model_name).to(device)
tokenizer = MarianTokenizer.from_pretrained(model_name)

train_encodings = tokenizer(english_sample, max_length=128, truncation=True, padding="max_length")
train_labels = tokenizer(hindi_sample, max_length=128, truncation=True, padding="max_length")

class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels
        
    def __getitem__(self, idx):
        item = {key: val[idx] for key, val in self.encodings.items()}
        item["labels"] = self.labels["input_ids"][idx]
        return item
    
    def __len__(self):
        return len(self.encodings["input_ids"])

train_dataset = CustomDataset(train_encodings, train_labels)
eval_dataset = CustomDataset(train_encodings, train_labels)

Using MPS device


In [62]:
training_args = TrainingArguments(
    output_dir="./finetuned-mt-en-hi",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    learning_rate=5e-5,
    weight_decay=0.01,
    save_total_limit=2,
    num_train_epochs=40,
    fp16=False,
    bf16=True,
    push_to_hub=False,
    logging_dir="./logs",
    use_mps_device=True
)

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

trainer.train()

  trainer = Trainer(


Epoch,Training Loss,Validation Loss
1,No log,2.618104
2,No log,2.243296
3,No log,2.05162
4,No log,1.91591
5,No log,1.759223
6,No log,1.61989
7,No log,1.501233
8,No log,1.390108
9,No log,1.283342
10,No log,1.177199




TrainOutput(global_step=520, training_loss=1.0707148808699387, metrics={'train_runtime': 498.5488, 'train_samples_per_second': 8.023, 'train_steps_per_second': 1.043, 'total_flos': 135593459712000.0, 'train_loss': 1.0707148808699387, 'epoch': 40.0})

In [63]:
translations=translate_batch_with_progress(english_sample)

Decoding Progress: 100%|██████████| 10/10 [01:07<00:00,  6.71s/it]


In [64]:
with open("output_finetune.txt", "w") as file:
    for item in translations:
        file.write(item + "\n")

In [65]:
wer = jiwer.wer(hindi_sample, translations)
print(f"WER: {wer:.4f}")
ser = calculate_ser(hindi_sample, translations)
print(f"SER: {ser:.4f}")

WER: 0.5745
SER: 0.9900
