In [None]:
import re
import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer, AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
)

from peft import get_peft_model, LoraConfig, TaskType


# Pretrained checkpoint the LoRA adapter is trained on top of. Tokenizer and
# model are loaded from the same checkpoint so their vocabularies agree.
# NOTE(review): the inference cell of this notebook attaches the saved adapter
# to "./mT5_multilingual_XLSum", not to this checkpoint. Loading evidently
# succeeds there (presumably both are mT5-base-sized), but the adapter deltas
# are learned against *these* weights -- confirm which base model is intended
# and use the same one for training and inference.
model_name = "google/mt5-base"
tokenizer, base_model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForSeq2SeqLM.from_pretrained(model_name),
)

# Single-split CSV; selecting the split inside load_dataset yields the
# "train" Dataset directly instead of indexing a DatasetDict.
data_files = {"train": "summarization_dataset.csv"}
raw_dataset = load_dataset("csv", data_files=data_files, split="train")

MAX_INPUT_LENGTH = 512   # token budget for source articles (longer input is truncated)
MAX_TARGET_LENGTH = 84   # token budget for reference summaries (labels)


def WHITESPACE_HANDLER(text):
    """Collapse every whitespace run (newlines included) to one space and trim both ends.

    The upper-case name is kept from the original module-level lambda so that
    existing callers continue to work unchanged.

    Args:
        text: Raw article or summary string.

    Returns:
        The normalized string.
    """
    # One pass with \s+ covers newlines too, so a separate newline substitution
    # is unnecessary; strip() then removes the single spaces left at the ends.
    return re.sub(r"\s+", " ", text).strip()

def preprocess_batch(batch):
    """Tokenize a batch of article/summary pairs for seq2seq training.

    Args:
        batch: Columnar batch as passed by ``Dataset.map(batched=True)``, with
            string columns ``"article"`` (source text) and ``"highlights"``
            (reference summary).

    Returns:
        The tokenizer's encoding of the whitespace-normalized articles,
        truncated/padded to ``MAX_INPUT_LENGTH``, plus a ``"labels"`` entry:
        the summary token ids truncated/padded to ``MAX_TARGET_LENGTH`` with
        every pad position replaced by -100.
    """
    inputs = [WHITESPACE_HANDLER(x) for x in batch["article"]]
    targets = [WHITESPACE_HANDLER(x) for x in batch["highlights"]]

    # NOTE(review): both sides are padded to a fixed maximum here; the
    # DataCollatorForSeq2Seq used for training could pad per batch instead.
    model_inputs = tokenizer(
        inputs, max_length=MAX_INPUT_LENGTH, truncation=True, padding="max_length"
    )
    # `text_target=` replaces the deprecated `as_target_tokenizer()` context manager.
    labels = tokenizer(
        text_target=targets,
        max_length=MAX_TARGET_LENGTH,
        truncation=True,
        padding="max_length",
    )

    # -100 is the label value Hugging Face seq2seq losses ignore, so padding
    # positions do not contribute to training.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [-100 if token_id == pad_id else token_id for token_id in label]
        for label in labels["input_ids"]
    ]
    return model_inputs

# Tokenize the whole split in batches. The raw text columns are dropped so only
# the token-id columns produced by preprocess_batch remain.
tokenized_train = raw_dataset.map(
    preprocess_batch,
    batched=True,
    remove_columns=raw_dataset.column_names,
)


# LoRA adapter: rank-16 updates (scaled via lora_alpha=32) on the modules named
# "q" and "v" (mT5's attention query/value projections); no bias terms trained.
lora_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    target_modules=["q", "v"],
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)

# Wrap the base model so only the adapter weights are trainable, then report
# the trainable vs. total parameter counts.
model_lora = get_peft_model(base_model, lora_config)
model_lora.print_trainable_parameters()


# Collator that pads/stacks features into batch tensors; the model is passed so
# it can prepare decoder inputs from the labels when the model supports that.
data_collator = DataCollatorForSeq2Seq(model=model_lora, tokenizer=tokenizer)


training_args = Seq2SeqTrainingArguments(
    # Output location and run identity.
    output_dir="model-lora-finetuned",
    run_name="lora-finetune",
    # Optimisation schedule.
    # NOTE(review): 2e-5 is low for LoRA adapters, which are often trained
    # around 1e-4..1e-3 -- confirm this is intentional.
    num_train_epochs=3,
    per_device_train_batch_size=4,
    learning_rate=2e-5,
    weight_decay=0.01,
    # Checkpoint every 100 steps, keeping only the two most recent.
    save_steps=100,
    save_total_limit=2,
    # Log every 100 steps into logs-lora; no external experiment trackers.
    logging_steps=100,
    logging_dir="logs-lora",
    report_to="none",
)


# `processing_class` supersedes the deprecated `tokenizer=` argument of
# Trainer/Seq2SeqTrainer (transformers >= 4.46), which warns at this call.
trainer = Seq2SeqTrainer(
    model=model_lora,
    args=training_args,
    train_dataset=tokenized_train,
    processing_class=tokenizer,
    data_collator=data_collator,
)


# To continue from the most recent checkpoint in `training_args.output_dir`,
# replace the call below with `trainer.train(resume_from_checkpoint=True)`.
trainer.train()

# Save the trained (adapter) model and the tokenizer into the same directory
# so it can be reloaded on its own. Reusing `output_dir` keeps the path in one place.
trainer.save_model(training_args.output_dir)
tokenizer.save_pretrained(training_args.output_dir)


trainable params: 1,769,472 || all params: 584,170,752 || trainable%: 0.3029


  trainer = Seq2SeqTrainer(


Step,Training Loss
100,2.9092
200,2.7548
300,2.6688
400,2.5099
500,2.4112
600,2.381
700,2.3798
800,2.3118
900,2.337
1000,2.2685


('model-lora-finetuned\\tokenizer_config.json',
 'model-lora-finetuned\\special_tokens_map.json',
 'model-lora-finetuned\\spiece.model',
 'model-lora-finetuned\\added_tokens.json',
 'model-lora-finetuned\\tokenizer.json')

In [13]:
import torch
# Hand cached-but-unused CUDA memory back to the driver before another model is
# loaded. Memory still held by live objects (e.g. `trainer`, `model_lora`) is
# not released by this call.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel

# Base checkpoint to attach the adapter to, and the adapter directory written by
# the training cell.
# NOTE(review): training used "google/mt5-base" as the base model, but here the
# adapter is attached to "./mT5_multilingual_XLSum". Loading works (presumably
# both are mT5-base-sized), yet the LoRA deltas were learned against the other
# weights -- confirm which base is intended and use the same one in both cells.
base_model_name = "./mT5_multilingual_XLSum"
lora_model_path = "./model-lora-finetuned"

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

tokenizer = AutoTokenizer.from_pretrained(base_model_name)
base_model = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)

# Attach the trained LoRA weights, move the model to the chosen device and switch
# it to inference mode (disables dropout).
model = PeftModel.from_pretrained(base_model, lora_model_path)
model.to(device).eval()

print("LoRA model loaded successfully on", device)

# Sample articles to summarize, keyed by language code (en, hi, bn, sw).
# In this copy of the notebook the Hindi and Bengali samples appeared as
# mojibake -- their UTF-8 bytes decoded as Mac Roman (e.g. "‡§≠" instead of
# "भ") -- which would feed the model meaningless symbol sequences. They are
# restored here to the intended Devanagari / Bengali text.
texts = {
    "en": """YouTube has announced a new policy to remove videos spreading false information 
    about vaccines. The company says it will take down content that claims vaccines cause 
    autism or infertility, and accounts of anti-vaccine influencers may be terminated.""",

    "hi": """भारत सरकार ने नई नीति के तहत स्वच्छ ऊर्जा को बढ़ावा देने के लिए 
    2025 तक 500 गीगावॉट नवीकरणीय ऊर्जा क्षमता का लक्ष्य रखा है।""",

    "bn": """বাংলাদেশে সম্প্রতি কৃষি প্রযুক্তিতে নতুন উদ্ভাবনের ফলে ধান উৎপাদন 
    আগের চেয়ে অনেক বেশি বেড়েছে।""",

    "sw": """Serikali ya Kenya imetangaza mpango mpya wa kusaidia wakulima wadogo 
    kuongeza uzalishaji wa chakula kupitia teknolojia za kisasa."""
}

def summarize(text, lang, *, max_input_length=512, max_summary_length=80, num_beams=4):
    """Summarize one article with the loaded LoRA model, print it, and return the summary.

    Args:
        text: Article text to summarize.
        lang: Label shown (upper-cased) in the printed header only.
        max_input_length: Token budget for the article; longer input is truncated.
        max_summary_length: Passed to ``model.generate`` as ``max_length``.
        num_beams: Beam-search width used by ``model.generate``.

    Returns:
        The decoded summary string with special tokens removed.
    """
    # Training normalized every article by collapsing all whitespace runs into
    # single spaces and trimming the ends. Do the same here so inference input
    # has the same shape; the sample texts contain newlines plus indentation.
    normalized = " ".join(text.split())

    inputs = tokenizer(
        normalized,
        return_tensors="pt",
        truncation=True,
        max_length=max_input_length,
    ).to(device)
    with torch.no_grad():
        summary_ids = model.generate(
            **inputs,
            max_length=max_summary_length,
            num_beams=num_beams,
            no_repeat_ngram_size=2,
        )
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    print(f"\n [{lang.upper()}] Input:\n{text}\n\n Summary:\n{summary}\n" + "-"*80)
    return summary

# Summarize every sample article, in the dict's insertion order.
for code, article in texts.items():
    summarize(text=article, lang=code)




✅ LoRA model loaded successfully on cuda

🌍 [EN] Input:
YouTube has announced a new policy to remove videos spreading false information 
    about vaccines. The company says it will take down content that claims vaccines cause 
    autism or infertility, and accounts of anti-vaccine influencers may be terminated.

🧾 Summary:
YouTube says it will remove videos spreading false information about vaccines.
--------------------------------------------------------------------------------

🌍 [HI] Input:
भारत सरकार ने नई नीति के तहत स्वच्छ ऊर्जा को बढ़ावा देने के लिए 
    2025 तक 500 गीगावॉट नवीकरणीय ऊर्जा क्षमता का लक्ष्य रखा है।

🧾 Summary:
भारत सरकार ने स्वच्छ ऊर्जा को बढ़ावा देने के लिए नई नीति की घ