In [1]:
import os, re, math, json, random
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, get_linear_schedule_with_warmup
from tqdm.auto import tqdm

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)


  from .autonotebook import tqdm as notebook_tqdm


Device: cpu


  return torch._C._cuda_getDeviceCount() > 0


In [None]:
BASE_MODEL_DIR = "./mT5_multilingual_XLSum"     
TRAIN_CSV      = "summarization_dataset.csv"    
OUT_DIR        = "model-langanchor-finetuned"

os.makedirs(OUT_DIR, exist_ok=True)

EPOCHS            = 3
BATCH_SIZE        = 8         
LEARNING_RATE     = 2e-5
WEIGHT_DECAY      = 0.01
WARMUP_RATIO      = 0.06
GRAD_ACCUM_STEPS  = 1         
MAX_INPUT_LENGTH  = 512
MAX_TARGET_LENGTH = 84
FP16              = True      
GRAD_CHKPT        = True      

LAMBDA_ANCHOR     = 1.0       
ANCHOR_BATCH_EVERY= 1        
SAVE_EVERY_STEPS  = 1000      
SEED              = 42
random.seed(SEED); torch.manual_seed(SEED)


<torch._C.Generator at 0x23404aee810>

In [3]:
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_DIR)

model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL_DIR)
if GRAD_CHKPT:
    model.gradient_checkpointing_enable()
model.to(DEVICE)
model.train()

orig_model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL_DIR)
orig_model.eval()
orig_model.to(DEVICE)
for p in orig_model.parameters():
    p.requires_grad = False

print("Models loaded.")


You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


Models loaded.


In [4]:
WHITESPACE = lambda s: re.sub(r'\s+', ' ', re.sub(r'\n+', ' ', s.strip()))

class SummDataset(Dataset):
    def __init__(self, df, tokenizer, max_in=512, max_out=84):
        self.inputs  = [WHITESPACE(x) for x in df["article"].tolist()]
        self.targets = [WHITESPACE(x) for x in df["highlights"].tolist()]
        self.tok = tokenizer
        self.max_in = max_in
        self.max_out= max_out
    def __len__(self): return len(self.inputs)
    def __getitem__(self, i):
        enc = self.tok(self.inputs[i], truncation=True, padding="max_length",
                       max_length=self.max_in, return_tensors="pt")
        with self.tok.as_target_tokenizer():
            lab = self.tok(self.targets[i], truncation=True, padding="max_length",
                           max_length=self.max_out, return_tensors="pt")
        labels = lab["input_ids"].squeeze(0)
        labels[labels == self.tok.pad_token_id] = -100
        item = {
            "input_ids": enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
            "labels": labels
        }
        return item

df = pd.read_csv(TRAIN_CSV)
train_ds = SummDataset(df, tokenizer, MAX_INPUT_LENGTH, MAX_TARGET_LENGTH)
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
len(train_ds), len(train_dl)


(10000, 1250)

In [5]:
ANCHOR_PROMPTS = {
    "en": "This is a generic test sentence to anchor the language features.",
    "hi": "यह भाषा की विशेषताओं को स्थिर करने के लिए एक सामान्य वाक्य है।",
    "bn": "এটি ভাষার বৈশিষ্ট্য স্থির করতে একটি সাধারণ বাক্য।",
    "ur": "یہ زبان کی خصوصیات کو برقرار رکھنے کے لیے ایک عمومی جملہ ہے۔",
    "gu": "ભાષાની વિશેષતાઓને સ્થિર કરવા માટે આ એક સામાન્ય વાક્ય છે.",
    "mr": "भाषेच्या वैशिष्ट्यांना स्थिर ठेवण्यासाठी हे एक सामान्य वाक्य आहे.",
    "ta": "மொழியின் பண்புகளை உறுதிப்படுத்த இந்த ஒரு பொதுவான வாக்கியம்.",
    "te": "భాష లక్షణాలను నిలుపుకునేందుకు ఇది ఒక సాధారణ వాక్యం.",
    "kn": "ಭಾಷಾ ಲಕ್ಷಣಗಳನ್ನು ಸ್ಥಿರಗೊಳಿಸಲು ಇದು ಒಂದು ಸಾಮಾನ್ಯ ವಾಕ್ಯ.",
    "ml": "ഭാഷയുടെ സവിശേഷതകളെ നിലനിർത്താൻ ഇത് ഒരു പൊതുവായ വാക്ക്യമാണ്.",
    "pa": "ਭਾਸ਼ਾ ਦੇ ਗੁਣਾਂ ਨੂੰ ਕਾਇਮ ਰੱਖਣ ਲਈ ਇਹ ਇੱਕ ਸਧਾਰਣ ਵਾਕ ਹੈ।",
    "ar": "هذه جملة عامة لتثبيت خصائص اللغة.",
    "fa": "این یک جملهٔ عمومی برای پایدار کردن ویژگی‌های زبان است.",
    "tr": "Dil özelliklerini sabitlemek için bu genel bir cümledir.",
    "ru": "Это обычное предложение для закрепления языковых особенностей.",
    "uk": "Це загальне речення для закріплення мовних особливостей.",
    "fr": "Ceci est une phrase générique pour ancrer les caractéristiques de la langue.",
    "es": "Esta es una frase genérica para anclar las características del idioma.",
    "de": "Dies ist ein allgemeiner Satz, um die Spracheigenschaften zu verankern.",
    "it": "Questa è una frase generica per ancorare le caratteristiche della lingua.",
    "pt": "Esta é uma frase genérica para ancorar as características do idioma.",
    "sw": "Hii ni sentensi ya jumla ya kuimarisha sifa za lugha.",
    "yo": "Eyi jẹ gbolohun apapọ lati di awọn ẹya ede mọ.",
    "am": "የቋንቋውን ባህሪያት ለማስቀመጥ ይህ አጠቃላይ አረፍተ ነገር ነው።",
    "zh": "这是一个用于固定语言特征的通用句子。",
    "ja": "言語の特徴を固定するための一般的な文です。",
    "ko": "언어 특징을 고정하기 위한 일반적인 문장입니다.",
}

anchor_texts = list(ANCHOR_PROMPTS.values())
anchor_batch = tokenizer(anchor_texts, return_tensors="pt", padding=True, truncation=True, max_length=64)
anchor_ids   = anchor_batch["input_ids"].to(DEVICE)
anchor_mask  = anchor_batch["attention_mask"].to(DEVICE)

with torch.no_grad():
    orig_enc = orig_model.get_encoder()(input_ids=anchor_ids, attention_mask=anchor_mask)
    orig_hidden = orig_enc.last_hidden_state.detach()  


In [6]:
mse = nn.MSELoss(reduction="sum")

def anchor_loss(new_hidden, ref_hidden, mask):
    m = mask.unsqueeze(-1).float()
    diff2 = (new_hidden - ref_hidden) ** 2
    num = (diff2 * m).sum()
    den = m.sum().clamp(min=1.0)
    return num / den


In [7]:
no_decay = ["bias", "LayerNorm.weight"]
params = [
    {"params": [p for n,p in model.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": WEIGHT_DECAY},
    {"params": [p for n,p in model.named_parameters() if     any(nd in n for nd in no_decay)], "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(params, lr=LEARNING_RATE)
total_steps = math.ceil(len(train_dl) / GRAD_ACCUM_STEPS) * EPOCHS
warmup_steps = int(WARMUP_RATIO * total_steps)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
scaler = torch.cuda.amp.GradScaler(enabled=FP16)


  scaler = torch.cuda.amp.GradScaler(enabled=FP16)


In [8]:
global_step = 0
model.train()

for epoch in range(1, EPOCHS+1):
    pbar = tqdm(train_dl, desc=f"Epoch {epoch}/{EPOCHS}")
    running = 0.0
    for step, batch in enumerate(pbar, 1):
        batch = {k: v.to(DEVICE) for k,v in batch.items()}

        with torch.cuda.amp.autocast(enabled=FP16):
           
            out = model(**batch)
            loss_main = out.loss

            
            if (global_step % ANCHOR_BATCH_EVERY) == 0:
                new_enc = model.get_encoder()(input_ids=anchor_ids, attention_mask=anchor_mask)
                new_hidden = new_enc.last_hidden_state
                loss_anchor = anchor_loss(new_hidden, orig_hidden, anchor_mask)
            else:
                loss_anchor = torch.tensor(0.0, device=DEVICE)

            loss = loss_main + LAMBDA_ANCHOR * loss_anchor

        scaler.scale(loss / GRAD_ACCUM_STEPS).backward()
        if (step % GRAD_ACCUM_STEPS) == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            scheduler.step()
            global_step += 1

        running += loss.item()
        if step % 50 == 0:
            pbar.set_postfix({
                "loss": f"{running/step:.4f}",
                "main": f"{loss_main.item():.4f}",
                "anchor": f"{loss_anchor.item():.4f}"
            })

        if global_step % SAVE_EVERY_STEPS == 0:
            ckpt_dir = os.path.join(OUT_DIR, f"checkpoint-{global_step}")
            os.makedirs(ckpt_dir, exist_ok=True)
            model.save_pretrained(ckpt_dir)
            tokenizer.save_pretrained(ckpt_dir)

   
    model.save_pretrained(os.path.join(OUT_DIR, f"epoch-{epoch}"))
    tokenizer.save_pretrained(os.path.join(OUT_DIR, f"epoch-{epoch}"))

model.save_pretrained(OUT_DIR)
tokenizer.save_pretrained(OUT_DIR)
print("LangAnchor fine-tuning complete; saved to:", OUT_DIR)


  with torch.cuda.amp.autocast(enabled=FP16):
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
Epoch 1/3: 100%|██████████| 1250/1250 [12:42:21<00:00, 36.59s/it, loss=4.6027, main=1.7270, anchor=2.0255]  
Epoch 2/3: 100%|██████████| 1250/1250 [11:59:59<00:00, 34.56s/it, loss=4.0399, main=1.9877, anchor=1.9100]  
Epoch 3/3: 100%|██████████| 1250/1250 [11:51:57<00:00, 34.17s/it, loss=3.9201, main=1.9727, anchor=1.8414]  


✅ LangAnchor fine-tuning complete; saved to: model-langanchor-finetuned
