In [1]:
from transformers import MarianMTModel, MarianTokenizer
from transformers import pipeline
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from copy import deepcopy
import numpy as np
from tqdm import tqdm
import os
import random
import gc

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def get_model():
    model_name = "Helsinki-NLP/opus-mt-en-hi"
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name)

    # Freezing all but the last two decoder blocks for fine tuning
    for name, param in model.named_parameters():
        if name.startswith("model.decoder.layers.4") or name.startswith("model.decoder.layers.5"):
            param.requires_grad = True
        else:
            param.requires_grad = False

    device = torch.device("mps" if torch.mps.is_available() else "cpu")
    model.to(device)

    return model, tokenizer

In [3]:
device = torch.device("mps" if torch.mps.is_available() else "cpu")
# model.to(device)

In [4]:
def self_training(model, source_doc, tokenizer, lr=5e-3, decay_lambda=0.7, num_steps=2, passes=2):
    model.train()
    optimizer = AdamW(model.parameters(), lr=lr)
    
    # Store original parameters for decay regularization
    og_params = [p.detach().clone() for p in model.parameters()]

    for pass_idx in range(passes):
        for sent_idx, sentence in enumerate(source_doc):
            
            source_tokens = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)

            # Generate translation
            with torch.no_grad():
                generated_tokens = model.generate(**source_tokens, max_length=128)
            translations = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

            # Tokenize generated translation as target
            targets = tokenizer(text_target=translations, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)

            for step in range(num_steps):
                optimizer.zero_grad()
                outputs = model(**source_tokens, labels=targets["input_ids"])
                loss = outputs.loss
                loss.backward()

                # Optional: gradient clipping
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

                # Decay regularization
                for p_current, p_orig in zip(model.parameters(), og_params):
                    if p_current.grad is not None:
                        p_current.grad += decay_lambda * (p_current - p_orig)

                optimizer.step()

                # print(f"[Pass {pass_idx} | Sentence {sent_idx} | Step {step}] Loss: {loss.item():.4f}")

    return model

In [5]:
def print_translations(model, tokenizer, sentences, title=""):
    print(f"\n--- {title} ---")
    model.eval()
    with torch.no_grad():
        inputs = tokenizer(sentences, return_tensors="pt", padding=True).to(model.device)
        outputs = model.generate(**inputs, max_length=128)
        translations = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        for en, hi in zip(sentences, translations):
            print(f"EN: {en}")
            print(f"HI: {hi}")
            print()

In [6]:
def read_documents_from_dir(directory_path):
    documents = []

    for filename in sorted(os.listdir(directory_path)):
        if filename.endswith(".txt"):
            file_path = os.path.join(directory_path, filename)
            with open(file_path, "r", encoding="utf-8") as f:
                # Each line is a sentence; remove blank lines
                sentences = [line.strip() for line in f.readlines() if line.strip()]
                if sentences:
                    documents.append(sentences)
    return documents

dir_path = "./dataset/test1"
test_docs = read_documents_from_dir(dir_path)

In [7]:
def load_from_pm(dir_path, num_docs=250, seed=42):
    random.seed(seed)
    all_docs = [f for f in os.listdir(dir_path) if f.endswith(".txt")]

    selected_docs = random.sample(all_docs, k=num_docs)

    documents = []

    for filename in selected_docs:
        file_path = os.path.join(dir_path, filename)
        with open(file_path, "r", encoding='utf-8') as fin:
            sentences = [line.strip() for line in fin.readlines() if line.strip()]
            if sentences:
                documents.append(sentences)
    return documents

dir_path = "./dataset/split"
source_docs = load_from_pm(dir_path, num_docs=50)

In [13]:
model, tokenizer = get_model()

In [14]:
for idx, doc in enumerate(test_docs):
    print_translations(model, tokenizer, doc, title=f"Doc {idx+1}: Before Self-Training")


--- Doc 1: Before Self-Training ---
EN: Prime Minister, Shri Narendra Modi, has greeted the people of Telangana and Andhra Pradesh.
HI: प्रधानमंत्री, शैरी नाही मोडी ने तेलहना और आशश के लोगों को नमस्कार किया है ।

EN: “My best wishes to the people of Telangana on the occasion of their Statehood Day.
HI: “ मेरी सबसे अच्छी ख्वाहिश है कि मैं अपने राष्ट्र - समारोह के अवसर पर तेलगाना के लोगों से दूर हो जाऊँ ।

EN: My best wishes for the State’s development journey.
HI: मैं सरकार के विकास यात्रा के लिए पूरी इच्छा रखता हूँ.

EN: Greetings & good wishes to my sisters & brothers of Andhra Pradesh in the development journey of this hardworking State,” the Prime Minister said.
HI: इस मेहनती राज्य के विकास यात्रा में मेरी बहनों और उनके भाइयों के लिए नमस्कार अच्छा चाहता है.


--- Doc 2: Before Self-Training ---
EN: Prime Minister, Shri Narendra Modi has conveyed his condolences to the families of those on Flight QZ8501.
HI: प्रधानमंत्री, शैरी नाही मोडी ने अपने परिवार को QZ8501 पर उड़ान भरने वाले पर

In [None]:
for doc in tqdm(test_docs, desc="Self Training..."):
    model = self_training(
        model=model,
        source_doc=doc,
        tokenizer=tokenizer,
        lr=1e-4,
        decay_lambda=0.9,
        num_steps=2,
        passes=3
    )

    if torch.backends.mps.is_available():
        gc.collect()
        torch.mps.empty_cache()

Self Training...: 100%|██████████| 50/50 [41:13<00:00, 49.48s/it]  


In [12]:
for idx, doc in enumerate(test_docs):
    print_translations(model, tokenizer, doc, title=f"Doc {idx+1}: After Self-Training")


--- Doc 1: After Self-Training ---
EN: Prime Minister, Shri Narendra Modi, has greeted the people of Telangana and Andhra Pradesh.
HI: उन्होंने यह भी बताया कि उन्होंने क्या - क्या किया ।

EN: “My best wishes to the people of Telangana on the occasion of their Statehood Day.
HI: उन्होंने यह भी बताया कि उन्होंने क्या - क्या किया ।

EN: My best wishes for the State’s development journey.
HI: उन्होंने यह भी बताया कि उन्होंने क्या - क्या किया ।

EN: Greetings & good wishes to my sisters & brothers of Andhra Pradesh in the development journey of this hardworking State,” the Prime Minister said.
HI: उन्होंने यह भी बताया कि उन्होंने क्या - क्या किया ।


--- Doc 2: After Self-Training ---
EN: Prime Minister, Shri Narendra Modi has conveyed his condolences to the families of those on Flight QZ8501.
HI: उन्होंने यह भी बताया कि उन्होंने क्या - क्या किया ।

EN: “Our thoughts are with the families of those on Flight QZ8501.
HI: उन्होंने यह भी बताया कि उन्होंने क्या - क्या किया ।

EN: We offer our c