In [6]:
import gc
import json
import os
import random
from copy import deepcopy

import numpy as np
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from tqdm import tqdm
from transformers import MarianMTModel, MarianTokenizer
from transformers import pipeline

In [7]:
def get_model(model_name="Helsinki-NLP/opus-mt-en-hi", trainable_prefixes=None):
    """Load a pretrained Marian translation model together with its tokenizer.

    Args:
        model_name: Hugging Face model identifier handed to ``from_pretrained``.
            Defaults to the English-to-Hindi OPUS-MT checkpoint.
        trainable_prefixes: Optional parameter-name prefix, or iterable of
            prefixes (for example ``("model.decoder.layers.4",
            "model.decoder.layers.5")``). When given, only parameters whose
            name starts with one of the prefixes remain trainable and all
            others are frozen. When ``None`` (the default) every parameter is
            trainable.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is moved to the MPS device
        when that backend is available, otherwise to the CPU.
    """
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name)

    if trainable_prefixes is None:
        prefixes = None
    elif isinstance(trainable_prefixes, str):
        # A bare string would otherwise be split into single characters.
        prefixes = (trainable_prefixes,)
    else:
        prefixes = tuple(trainable_prefixes)

    # Full fine-tuning by default; optionally restrict training to the
    # parameters selected by ``trainable_prefixes``.
    for name, param in model.named_parameters():
        param.requires_grad = prefixes is None or name.startswith(prefixes)

    # torch.backends.mps.is_available() is the same check used by the
    # self-training loop further down in this notebook.
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    model.to(device)

    return model, tokenizer

In [8]:
# Prefer Apple's Metal (MPS) backend when it is available; otherwise use the CPU.
# get_model() already moves the model to the device it selects, so nothing is
# moved here.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

In [9]:
def self_training(model, source_doc, tokenizer, lr=5e-3, decay_lambda=0.7, num_steps=2, passes=2):
    """Adapt ``model`` to one document by training on its own translations.

    For every sentence the current model first produces a translation, which is
    then used as the target (pseudo-label) for ``num_steps`` gradient updates.
    Each update also pulls the weights back towards the values they had when
    this function was called, scaled by ``decay_lambda``.

    Args:
        model: Sequence-to-sequence model exposing ``generate``, ``device`` and
            a forward pass whose result has a ``loss`` attribute when
            ``labels`` are supplied. It is updated in place.
        source_doc: Iterable of source-language sentences.
        tokenizer: Tokenizer matching ``model``; must accept ``text_target``.
        lr: Learning rate of the AdamW optimizer.
        decay_lambda: Weight of the pull towards the starting parameters.
        num_steps: Number of gradient updates per sentence.
        passes: Number of times the whole document is traversed.

    Returns:
        The same ``model`` object, left in training mode.
    """
    model.train()
    optimizer = AdamW(model.parameters(), lr=lr)

    # Snapshot of the trainable parameters used as the anchor of the decay
    # regularizer. Frozen parameters never receive a gradient, so they need no
    # snapshot.
    anchors = [(p, p.detach().clone()) for p in model.parameters() if p.requires_grad]

    for _ in range(passes):
        for sentence in source_doc:
            source_tokens = tokenizer(
                sentence, return_tensors="pt", padding=True, truncation=True, max_length=512
            ).to(model.device)

            # Generate the pseudo-label in eval mode so dropout does not add
            # noise to the translation, then switch back for the updates.
            model.eval()
            with torch.no_grad():
                generated_tokens = model.generate(**source_tokens, max_length=128)
            model.train()
            translations = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

            # Tokenize the generated translation as the target.
            targets = tokenizer(
                text_target=translations, return_tensors="pt", padding=True, truncation=True, max_length=512
            ).to(model.device)
            # Padded positions must not contribute to the loss; -100 is the
            # index ignored by the cross-entropy loss.
            labels = targets["input_ids"].masked_fill(targets["attention_mask"] == 0, -100)

            for _ in range(num_steps):
                optimizer.zero_grad()
                loss = model(**source_tokens, labels=labels).loss
                loss.backward()

                # Clip the task gradient only; the decay term below is added
                # after clipping and is therefore not clipped itself.
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

                # Decay regularization: add lambda * (theta - theta_0) to the
                # gradient. Done under no_grad so no autograd graph is attached
                # to the gradient tensors.
                with torch.no_grad():
                    for param, anchor in anchors:
                        if param.grad is not None:
                            param.grad.add_(param - anchor, alpha=decay_lambda)

                optimizer.step()

    return model

In [31]:
def print_translations(model, tokenizer, sentences, docName="", batch_size=None, output_dir="5a"):
    """Translate ``sentences`` and write one translation per line to a text file.

    The file is written to ``<output_dir>/translations_<docName>.txt`` in UTF-8;
    the directory is created when it does not exist yet.

    Args:
        model: Translation model exposing ``eval``, ``generate`` and ``device``.
        tokenizer: Tokenizer matching ``model``.
        sentences: A sentence or a sequence of sentences to translate.
        docName: Name inserted into the output file name.
        batch_size: Number of sentences translated per ``generate`` call.
            ``None`` (default) translates the whole input as a single batch.
        output_dir: Directory that receives the output file.

    Raises:
        ValueError: If ``batch_size`` is given and smaller than 1.
    """
    if batch_size is not None and batch_size < 1:
        raise ValueError("batch_size must be a positive integer or None")

    sentences = [sentences] if isinstance(sentences, str) else list(sentences)
    if batch_size is None:
        # An empty document yields no batch at all, so the tokenizer is never
        # called with an empty list.
        batches = [sentences] if sentences else []
    else:
        batches = [sentences[i:i + batch_size] for i in range(0, len(sentences), batch_size)]

    translations = []
    model.eval()
    with torch.no_grad():
        for batch in batches:
            # Truncate over-long inputs with the same limit self_training uses.
            inputs = tokenizer(
                batch, return_tensors="pt", padding=True, truncation=True, max_length=512
            ).to(model.device)
            outputs = model.generate(**inputs, max_length=128)
            translations.extend(tokenizer.batch_decode(outputs, skip_special_tokens=True))

    os.makedirs(output_dir, exist_ok=True)
    out_path = os.path.join(output_dir, f"translations_{docName}.txt")
    # Explicit UTF-8: the platform default encoding may not handle Devanagari.
    with open(out_path, "w", encoding="utf-8") as f:
        for hi in translations:
            f.write(f"{hi}\n")

In [11]:
# Load the pretrained model and its tokenizer via get_model(); the returned
# model has already been moved to the device chosen inside get_model().
model, tokenizer = get_model()

In [26]:
# Location of the parallel corpus. Set the NLP_DATA_PATH environment variable to
# point at the file on another machine instead of editing this cell.
DATA_PATH = os.environ.get(
    'NLP_DATA_PATH',
    '/Users/anishkumariyer/Documents/NLP/NLP Project/output_with_train_split.json',
)

with open(DATA_PATH, 'r', encoding='utf-8') as file:
    data = json.load(file)
print(data[0])

# One entry per document; each entry is that document's list of sentences.
english_sentences = []
hindi_sentences = []
english_test_sentences = []
hindi_test_sentences = []
document_names_test = []
for document in data:
    english = [sentence['english'] for sentence in document['sentences']]
    hindi = [sentence['hindi'] for sentence in document['sentences']]
    if not document['is_train']:
        english_test_sentences.append(english)
        hindi_test_sentences.append(hindi)
        # Strip only a trailing ".txt"; str.replace would also remove the
        # substring from the middle of a name.
        doc_name = document['doc_name']
        if doc_name.endswith('.txt'):
            doc_name = doc_name[:-len('.txt')]
        document_names_test.append(doc_name)
    else:
        english_sentences.append(english)
        hindi_sentences.append(hindi)
print(len(english_sentences))
print(len(hindi_sentences))
print(len(english_test_sentences))
print(len(hindi_test_sentences))
print(len(document_names_test))
# Guarded so a corpus without test documents does not raise IndexError.
print(document_names_test[0] if document_names_test else 'no test documents')

{'doc_id': '1', 'doc_name': 'pm-to-visit-varanasi-on-september-17-and-18-2018.txt', 'sentences': [{'english': 'The Prime Minister, Shri Narendra Modi will be on a visit to his Parliamentary Constituency, Varanasi, on September 17 and 18, 2018.', 'hindi': 'प्रधानमंत्री श्री नरेन्द्र मोदी 17-18, 2018 सितंबर को अपने संसदीय क्षेत्र वाराणसी का दौरा करेंगे।'}, {'english': 'He will arrive in the city on the afternoon of 17th September.', 'hindi': 'वह शहर में 17 सितंबर की दोपहर को पहुंचेंगे।'}, {'english': 'He will head directly for Narur village, where he will interact with children of a primary school who are being aided by the non-profit organisation “Room to Read.”', 'hindi': 'वह सीधे नरुर गांव के लिए रवाना हो जाएंगे जहां वह एक प्राथमिक विद्यालय के छात्रों से मिलेंगे जो एक गैर-लाभकारी संगठन ‘रुम टू रीड‘ की सहायता से चल रहा है।'}, {'english': 'Later, at DLW campus, the Prime Minister will interact with students of Kashi Vidyapeeth, and children assisted by them.', 'hindi': 'बाद में, डीएलडब्

In [None]:
# Self-train on every training document in sequence. The model returned for one
# document is passed into the next call, so adaptation accumulates.
for doc in tqdm(english_sentences, desc="Self Training..."):
    model = self_training(model, doc, tokenizer, lr=1e-4, decay_lambda=0.9, num_steps=2, passes=3)

    # On MPS runs, collect garbage and release cached memory after each
    # document.
    if torch.backends.mps.is_available():
        gc.collect()
        torch.mps.empty_cache()

# Save the weights once all documents have been processed.
torch.save(model.state_dict(), "5a.pth")

Self Training...:   0%|          | 0/1277 [00:31<?, ?it/s]


KeyboardInterrupt: 

In [32]:
# Translate each held-out document and write the result to a file named after it.
for doc_name, doc in zip(document_names_test, english_test_sentences):
    print_translations(model, tokenizer, doc, doc_name)

KeyboardInterrupt: 