In [2]:
from transformers import MarianMTModel, MarianTokenizer
from transformers import pipeline
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from copy import deepcopy
import numpy as np
from tqdm import tqdm
import os
import random
import gc
import json

In [3]:
def get_model(model_name="Helsinki-NLP/opus-mt-en-hi", num_trainable_decoder_layers=2):
    """Load a Marian MT model + tokenizer with all but the last decoder blocks frozen.

    Args:
        model_name: Hugging Face hub id of the Marian checkpoint to load.
        num_trainable_decoder_layers: how many of the *last* decoder blocks keep
            ``requires_grad=True``; every other parameter is frozen. For a 6-layer
            decoder the default (2) selects exactly layers 4 and 5, the original
            hard-coded choice.

    Returns:
        ``(model, tokenizer)``; the model has been moved to MPS when available,
        otherwise CPU.
    """
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name)

    # Derive the trainable block indices from the config instead of hard-coding
    # "4"/"5", so the freeze policy follows the checkpoint's actual depth.
    n_layers = model.config.decoder_layers
    first_trainable = max(n_layers - num_trainable_decoder_layers, 0)
    # Trailing "." so that e.g. layer 1 cannot also prefix-match layers 10, 11, ...
    trainable_prefixes = tuple(
        f"model.decoder.layers.{i}." for i in range(first_trainable, n_layers)
    )
    for name, param in model.named_parameters():
        param.requires_grad = name.startswith(trainable_prefixes)

    # Same availability check the training cell uses (documented torch.backends API).
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    model.to(device)

    return model, tokenizer

In [4]:
# Notebook-level device handle, chosen with the documented torch.backends.mps check
# (the same check the training cell uses). get_model() moves the model itself.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

In [5]:
def self_training(model, source_doc, target_doc, tokenizer, lr=5e-3, decay_lambda=0.7, hybrid_alpha=0.4, num_steps=2, passes=2):
    """Adapt ``model`` to one document with a hybrid pseudo-label / gold loss.

    For every (source, target) sentence pair, repeated ``passes`` times:
      1. translate the source with the current model, in eval mode so dropout does
         not perturb the pseudo-label;
      2. take ``num_steps`` AdamW steps on
         ``hybrid_alpha * loss(pseudo) + (1 - hybrid_alpha) * loss(gold)``,
         clipping the gradient norm to 1.0 and then adding
         ``decay_lambda * (theta - theta_0)`` to each gradient, which pulls the
         weights back toward their values at the start of this call.

    Only parameters with ``requires_grad=True`` are optimised, regularised and
    snapshotted.

    Args:
        model: seq2seq model exposing ``generate``, ``__call__(..., labels=...)``
            returning an object with ``.loss``, a ``.device`` attribute and
            ``train()`` / ``eval()``.
        source_doc: list of source-language sentences.
        target_doc: list of reference translations aligned with ``source_doc``
            (pairs are formed with ``zip``, so surplus items are ignored).
        tokenizer: tokenizer matching ``model``.
        lr: AdamW learning rate. A fresh optimizer is created on every call.
        decay_lambda: strength of the pull toward the call-start parameters.
        hybrid_alpha: weight of the pseudo-label loss; gold gets ``1 - hybrid_alpha``.
        num_steps: optimizer steps per sentence per pass.
        passes: number of passes over the document.

    Returns:
        The same ``model`` object, left in train mode.
    """
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = AdamW(trainable, lr=lr)

    # Anchor for the decay regularisation. NOTE: this is the state at the start of
    # *this call*; when called once per document the anchor moves along with training.
    anchors = [p.detach().clone() for p in trainable]

    model.train()
    for _ in range(passes):
        for en_sent, hi_sent in zip(source_doc, target_doc):
            # Single sentence -> batch of 1, so the label tensors below contain no padding.
            source_tokens = tokenizer(
                en_sent,
                return_tensors="pt",
                padding=True,
                truncation=True
            ).to(model.device)

            # generate() does not switch the model to eval mode by itself; without this
            # the pseudo-label would be sampled with dropout active.
            model.eval()
            with torch.no_grad():
                generated_tokens = model.generate(**source_tokens)
            model.train()
            pseudo_translations = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

            # Re-tokenize the pseudo translation and the gold reference as decoder labels.
            pseudo_targets = tokenizer(text_target=pseudo_translations, return_tensors="pt", padding=True, truncation=True).to(model.device)
            gold_targets = tokenizer(text_target=[hi_sent], return_tensors="pt", padding=True, truncation=True).to(model.device)

            for _ in range(num_steps):
                optimizer.zero_grad()

                pseudo_loss = model(**source_tokens, labels=pseudo_targets["input_ids"]).loss
                gold_loss = model(**source_tokens, labels=gold_targets["input_ids"]).loss
                total_loss = hybrid_alpha * pseudo_loss + (1 - hybrid_alpha) * gold_loss
                total_loss.backward()

                torch.nn.utils.clip_grad_norm_(trainable, max_norm=1.0)

                # Decay regularisation (added after clipping, as before). Done under
                # no_grad so the gradient edit is not itself recorded by autograd.
                with torch.no_grad():
                    for param, anchor in zip(trainable, anchors):
                        if param.grad is not None:
                            param.grad.add_(param - anchor, alpha=decay_lambda)

                optimizer.step()

    return model

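Load the document-aligned English–Hindi corpus from the `.json` file and split it into train and test sets using each document's `is_train` flag.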

In [6]:
# Load the document-aligned EN/HI corpus and split it by each document's `is_train` flag.
with open('../output_with_train_split.json', 'r', encoding='utf-8') as file:
    data = json.load(file)

en_train = []
hi_train = []
en_test = []
hi_test = []
hi_test_names = []

for document in data:
    doc_sentences = document['sentences']
    en_doc = [sent['english'] for sent in doc_sentences]
    hi_doc = [sent['hindi'] for sent in doc_sentences]

    if document['is_train']:
        en_train.append(en_doc)
        hi_train.append(hi_doc)
    else:
        en_test.append(en_doc)
        hi_test.append(hi_doc)
        # Strip only a trailing ".txt"; str.replace would also remove it mid-name.
        doc_name = document['doc_name']
        if doc_name.endswith('.txt'):
            doc_name = doc_name[:-len('.txt')]
        hi_test_names.append(doc_name)

print(len(en_train))
print(len(hi_train))
print(len(en_test))
print(len(hi_test))

1277
1277
450
450


In [7]:
def _translate(model, tokenizer, sentences, max_length=128):
    """Translate ``sentences`` in one batch and return the decoded strings.

    Side effect: puts ``model`` in eval mode. Inputs are tokenized with
    ``truncation=True`` so over-long sentences are cut to the tokenizer's limit
    instead of being passed through at full length; outputs are capped at
    ``max_length`` tokens. An empty ``sentences`` list returns ``[]``.
    """
    model.eval()
    if not sentences:
        return []
    with torch.no_grad():
        inputs = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True).to(model.device)
        outputs = model.generate(**inputs, max_length=max_length)
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)


def print_translations(model, tokenizer, sentences, title=""):
    """Print a ``--- title ---`` header, then each source sentence with its translation."""
    print(f"\n--- {title} ---")
    translations = _translate(model, tokenizer, sentences)
    for en, hi in zip(sentences, translations):
        print(f"EN: {en}")
        print(f"HI: {hi}")
        print()


def write_translations(model, tokenizer, sentences, docName="", out_dir="5c"):
    """Translate ``sentences`` and write one translation per line.

    Output goes to ``<out_dir>/translations_<docName>.txt`` (UTF-8, overwriting any
    existing file). ``out_dir`` is created if missing; its default keeps the
    original ``5c/`` location.
    """
    translations = _translate(model, tokenizer, sentences)
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, f"translations_{docName}.txt")
    # Explicit UTF-8: the Devanagari output must not depend on the platform locale.
    with open(out_path, "w", encoding="utf-8") as f:
        f.writelines(f"{hi}\n" for hi in translations)

In [19]:
def load_from_pm(dir_path, num_docs=250, seed=42):
    """Load a reproducible random sample of ``.txt`` documents from ``dir_path``.

    Args:
        dir_path: directory containing one document per ``.txt`` file.
        num_docs: maximum number of files to sample; clamped to the number of
            ``.txt`` files available.
        seed: seed for a private RNG, so the global ``random`` state is untouched.

    Returns:
        A list of documents, each a list of stripped, non-empty lines. Files with
        no non-empty lines are skipped, so fewer than ``num_docs`` may be returned.
    """
    # os.listdir order is filesystem-dependent; sort so the seeded sample is
    # reproducible across machines.
    all_docs = sorted(f for f in os.listdir(dir_path) if f.endswith(".txt"))
    rng = random.Random(seed)
    selected_docs = rng.sample(all_docs, k=min(num_docs, len(all_docs)))

    documents = []
    for filename in selected_docs:
        file_path = os.path.join(dir_path, filename)
        with open(file_path, "r", encoding='utf-8') as fin:
            sentences = [stripped for stripped in (line.strip() for line in fin) if stripped]
        if sentences:
            documents.append(sentences)
    return documents

# Sample 50 `.txt` documents (seed 42 by default) from the split directory via load_from_pm.
# NOTE(review): `source_docs` is not used by any cell visible here, and os.listdir
# raises if this relative path does not exist on a fresh kernel — confirm it is needed.
dir_path = "../../dataset/split"
source_docs = load_from_pm(dir_path, num_docs=50)

In [8]:
# Load the pretrained Marian EN->HI model and tokenizer; get_model() leaves only decoder
# layers 4 and 5 trainable and moves the model to MPS/CPU.
model, tokenizer = get_model()

In [9]:
# Baseline: show translations of the first two test documents before any adaptation.
for doc_no, test_doc in enumerate(en_test[:2], start=1):
    print_translations(model, tokenizer, test_doc, title=f"Doc {doc_no}: Before Self-Training")


--- Doc 1: Before Self-Training ---
EN: The Prime Minister, Shri Narendra Modi will be on a visit to his Parliamentary Constituency, Varanasi, on September 17 and 18, 2018.
HI: प्रधानमंत्री, शैरी नाही मोडी, सितंबर 17 और 1818 के सितंबर में अपने कॉन्वेंटीसी के लिए एक भेंट पर होगा ।

EN: He will arrive in the city on the afternoon of 17th September.
HI: वह 17 सितंबर की दोपहर को शहर आएगा ।

EN: He will head directly for Narur village, where he will interact with children of a primary school who are being aided by the non-profit organisation “Room to Read.”
HI: वह सीधे मोकर गाँव का मुखिया होगा, जहाँ वो एक प्राथमिक स्कूल के बच्चों के साथ व्यवहार करेगा जो गैर-कानूनी संगठन द्वारा मदद की जा रही हैं " पढ़ें"।

EN: Later, at DLW campus, the Prime Minister will interact with students of Kashi Vidyapeeth, and children assisted by them.
HI: बाद में, DLWWA में, प्रधानमंत्री कोशीशी विटस्‌ के विद्यार्थियों के साथ व्यवहार करेंगे, और बच्चों ने उनकी सहायता की ।

EN: On the 18th, at BHU Amphitheatre, the 

Self-training loop with a hybrid loss (weighted pseudo-label + gold-reference loss)

In [None]:
# Self-train on every training document in turn. Each self_training call builds a fresh
# AdamW optimizer and anchors its decay term to the model state at the start of that
# document. `total=` gives tqdm a real progress bar / ETA for this long-running loop.
try:
    for src_doc, tar_doc in tqdm(zip(en_train, hi_train), total=len(en_train), desc='Self Training...'):
        model = self_training(
            model=model,
            source_doc=src_doc,
            target_doc=tar_doc,
            tokenizer=tokenizer,
            lr=1e-4,
            decay_lambda=0.9,
            hybrid_alpha=0.25,
            num_steps=2,
            passes=3
        )

        # Release cached MPS memory between documents.
        if torch.backends.mps.is_available():
            gc.collect()
            torch.mps.empty_cache()
except KeyboardInterrupt:
    # Keep the progress made so far instead of losing it. Use a separate file so an
    # interrupted run never overwrites a completed checkpoint; then re-raise to stop the cell.
    torch.save(model.state_dict(), "./5c_partial.pth")
    raise

torch.save(model.state_dict(), "./5c.pth")

Self Training...: 1it [00:31, 31.39s/it]


KeyboardInterrupt: 

In [43]:
# Same two test documents as the baseline cell, now after self-training.
for doc_no, test_doc in enumerate(en_test[:2], start=1):
    print_translations(model, tokenizer, test_doc, title=f"Doc {doc_no}: After Self-Training")


--- Doc 1: After Self-Training ---
EN: The Prime Minister, Shri Narendra Modi will be on a visit to his Parliamentary Constituency, Varanasi, on September 17 and 18, 2018.
HI: प्रधानमंत्री नरेन्‍न्‍द्र मोदी ने 17 और 1818 के सितंबर को अपनी सरकार से मिलने के लिए कहा ।

EN: He will arrive in the city on the afternoon of 17th September.
HI: वह 17 सितंबर की दोपहर को शहर आएगा ।

EN: He will head directly for Narur village, where he will interact with children of a primary school who are being aided by the non-profit organisation “Room to Read.”
HI: वह सीधे मोयर गांव के लिए सिर होगा, जहां एक प्राथमिक स्कूल के बच्चों के साथ बातचीत की जा रही है जो गैर-कानूनी संगठन द्वारा मदद की जा रही है " पढ़ें"।

EN: Later, at DLW campus, the Prime Minister will interact with students of Kashi Vidyapeeth, and children assisted by them.
HI: बाद में, DLWWA में प्रधानमंत्री ने नीशी वदवस्‍त के विद्यार्थियों के साथ व्यवहार किया, और बच्चों ने उनकी सहायता की ।

EN: On the 18th, at BHU Amphitheatre, the Prime Minist

In [None]:
# en_test and hi_test_names are appended in lockstep by the data-loading cell,
# so pairing them positionally gives each document its output file name.
for test_doc, doc_name in zip(en_test, hi_test_names):
    write_translations(model, tokenizer, test_doc, docName=doc_name)