In [7]:
from transformers import MarianTokenizer, MarianMTModel
import torch
import numpy as np
from tqdm import tqdm

In [8]:
model_name = "Helsinki-NLP/opus-mt-en-hi"
tokenizer = MarianTokenizer.from_pretrained(model_name)
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
model = MarianMTModel.from_pretrained(model_name).to(device)

def translate_batch(sentences):
    inputs = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True, max_length=128).to(model.device)
    outputs = model.generate(**inputs)
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)

def translate_batch_with_progress(sentences):
    batch_size = 10
    total_sentences = len(sentences)
    translations = []
    for i in tqdm(range(0, total_sentences, batch_size), desc="Decoding Progress", total=total_sentences // batch_size):
        batch = sentences[i:i+batch_size]
        batch_translations = translate_batch(batch)
        translations.extend(batch_translations)

    return translations



In [9]:
# def load_texts(eng_file, hin_file):
#     with open(eng_file, "r", encoding="utf-8") as f:
#         eng_sentences = f.readlines()
#     with open(hin_file, "r", encoding="utf-8") as f:
#         hin_sentences = f.readlines()
#     return [s.strip() for s in eng_sentences], [s.strip() for s in hin_sentences]

# english_sentences, hindi_sentences = load_texts("/Users/vishalsankarram/Desktop/github/document-level-mt-project/PM India en hi/pmindia.en-hi.en", "/Users/vishalsankarram/Desktop/github/document-level-mt-project/PM India en hi/pmindia.en-hi.hi")

In [10]:
import json
with open('/Users/vishalsankarram/Desktop/github/document-level-mt-project/output_with_train_split.json', 'r', encoding='utf-8') as file:
    data = json.load(file)
print(data[0])
english_sentences = []
hindi_sentences = []
english_test_sentences=[]
hindi_test_sentences=[]
for document in data:
    if document['is_train'] == False:
        english_test_sentence = [sentence['english'] for sentence in document['sentences']]
        hindi_test_sentence = [sentence['hindi'] for sentence in document['sentences']]
        english_test_sentences.extend(english_test_sentence)
        hindi_test_sentences.extend(hindi_test_sentence)
    else:
        english_sentence = [sentence['english'] for sentence in document['sentences']]
        hindi_sentence = [sentence['hindi'] for sentence in document['sentences']]
        english_sentences.extend(english_sentence)
        hindi_sentences.extend(hindi_sentence)
print(len(english_sentences))
print(len(hindi_sentences))
print(len(english_test_sentences))
print(len(hindi_test_sentences))

{'doc_id': '1', 'doc_name': 'pm-to-visit-varanasi-on-september-17-and-18-2018.txt', 'sentences': [{'english': 'The Prime Minister, Shri Narendra Modi will be on a visit to his Parliamentary Constituency, Varanasi, on September 17 and 18, 2018.', 'hindi': 'प्रधानमंत्री श्री नरेन्द्र मोदी 17-18, 2018 सितंबर को अपने संसदीय क्षेत्र वाराणसी का दौरा करेंगे।'}, {'english': 'He will arrive in the city on the afternoon of 17th September.', 'hindi': 'वह शहर में 17 सितंबर की दोपहर को पहुंचेंगे।'}, {'english': 'He will head directly for Narur village, where he will interact with children of a primary school who are being aided by the non-profit organisation “Room to Read.”', 'hindi': 'वह सीधे नरुर गांव के लिए रवाना हो जाएंगे जहां वह एक प्राथमिक विद्यालय के छात्रों से मिलेंगे जो एक गैर-लाभकारी संगठन ‘रुम टू रीड‘ की सहायता से चल रहा है।'}, {'english': 'Later, at DLW campus, the Prime Minister will interact with students of Kashi Vidyapeeth, and children assisted by them.', 'hindi': 'बाद में, डीएलडब्

In [11]:
np.random.seed(42)
sample_indices = np.random.choice(len(english_sentences), len(english_sentences), replace=False)
english_sample = [english_sentences[i].strip() for i in sample_indices]
hindi_sample = [hindi_sentences[i].strip() for i in sample_indices]

In [12]:
import jiwer
transform = jiwer.Compose([
    jiwer.ToLowerCase(),
    jiwer.RemovePunctuation(),
    jiwer.Strip(),
    # jiwer.ReduceToListOfWords()
])
translations=translate_batch_with_progress(english_test_sentences)

Decoding Progress:   3%|▎         | 13/381 [01:24<40:03,  6.53s/it] 


KeyboardInterrupt: 

In [7]:
with open("output.txt", "w") as file:
    for item in translations:
        file.write(item + "\n")
with open("ref.txt", "w") as file:
    for item in hindi_sample:
        file.write(item + "\n")

In [8]:
wer = jiwer.wer(hindi_sample, translations)
print(f"WER: {wer:.4f}")

WER: 0.9376


In [13]:
def calculate_ser(reference, hypothesis):
    assert len(reference) == len(hypothesis), "Reference and hypothesis lists must have the same length"
    error_count = 0
    for ref, hyp in zip(reference, hypothesis):
        if ref.strip().lower() != hyp.strip().lower():
            error_count += 1
    ser = error_count / len(reference)
    return ser

# ser = calculate_ser(hindi_sample, translations)

# print(f"SER: {ser:.4f}")

In [14]:
from transformers import MarianMTModel, MarianTokenizer, DataCollatorForSeq2Seq, Trainer, TrainingArguments
import torch

if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using MPS device")
else:
    device = torch.device("cpu")
    print("MPS device not found, using CPU")


model_name = "Helsinki-NLP/opus-mt-en-hi"
model = MarianMTModel.from_pretrained(model_name).to(device)
tokenizer = MarianTokenizer.from_pretrained(model_name)

train_encodings = tokenizer(english_sample, max_length=128, truncation=True, padding="max_length")
train_labels = tokenizer(hindi_sample, max_length=128, truncation=True, padding="max_length")

class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels
        
    def __getitem__(self, idx):
        item = {key: val[idx] for key, val in self.encodings.items()}
        item["labels"] = self.labels["input_ids"][idx]
        return item
    
    def __len__(self):
        return len(self.encodings["input_ids"])

train_dataset = CustomDataset(train_encodings, train_labels)
eval_dataset = CustomDataset(train_encodings, train_labels)

Using MPS device


In [None]:
training_args = TrainingArguments(
    output_dir="./finetuned-mt-en-hi",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    learning_rate=1e-4,
    weight_decay=0.01,
    save_total_limit=2,
    num_train_epochs=10,
    fp16=False,
    bf16=True,
    push_to_hub=False,
    logging_dir="./logs",
    use_mps_device=True
)

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

trainer.train()

  trainer = Trainer(


Epoch,Training Loss,Validation Loss


In [12]:
translations=translate_batch_with_progress(english_test_sentences)

Decoding Progress: 100%|██████████| 10/10 [01:45<00:00, 10.54s/it]


In [13]:
with open("output_finetune.txt", "w") as file:
    for item in translations:
        file.write(item + "\n")

In [14]:
wer = jiwer.wer(hindi_sample, translations)
print(f"WER: {wer:.4f}")
ser = calculate_ser(hindi_sample, translations)
print(f"SER: {ser:.4f}")

WER: 0.3370
SER: 0.9600
