<a href="https://colab.research.google.com/github/upriser72/Bridges-Distribution-Gap-in-Language-Model-Fine-Tuning/blob/main/mT5_Vanilla_Fine_Tuning_Script.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q --upgrade transformers datasets pandas torch sentencepiece accelerate

In [None]:
import torch
from datasets import Dataset
import pandas as pd
from transformers import (
    MT5ForConditionalGeneration,
    AutoTokenizer,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq,
)

def _load_training_frame(csv_path, required_columns=("input", "output")):
    """Read the training CSV and return a cleaned DataFrame, or ``None``.

    Every failure (file missing, file empty, required columns absent, no
    usable rows) is reported with ``print`` and signalled by returning
    ``None`` so the caller can stop without a traceback.

    Args:
        csv_path: Path of the CSV file to read.
        required_columns: Column names that must exist and be non-null.

    Returns:
        A DataFrame restricted to rows where every required column has a
        value, or ``None`` if the data cannot be used.
    """
    # engine='python' together with on_bad_lines='skip' tolerates malformed
    # rows instead of aborting the whole load.
    try:
        df = pd.read_csv(csv_path, engine='python', on_bad_lines='skip')
    except FileNotFoundError:
        print(f"Error: '{csv_path}' not found.")
        print("Please make sure the dataset file is uploaded to your Colab environment and the name matches exactly.")
        return None
    except pd.errors.EmptyDataError:
        print(f"Error: '{csv_path}' is empty.")
        return None

    missing = [col for col in required_columns if col not in df.columns]
    if missing:
        print(f"Error: '{csv_path}' is missing required column(s): {', '.join(missing)}")
        return None

    # Rows without a source or a target cannot be trained on: a NaN target
    # is rejected by the tokenizer and a NaN input would become the text "nan".
    df = df.dropna(subset=list(required_columns)).reset_index(drop=True)
    if df.empty:
        print(f"Error: '{csv_path}' contains no rows with all of: {', '.join(required_columns)}")
        return None
    return df


def _tokenize_dataset(dataset, tokenizer, prefix, max_input_length, max_target_length):
    """Tokenize the 'input'/'output' columns of *dataset* for seq2seq training.

    No padding is applied here; sequences are only truncated. Padding is left
    to the data collator so that label padding can be masked out of the loss.

    Args:
        dataset: Hugging Face ``Dataset`` with 'input' and 'output' columns.
        tokenizer: Tokenizer used for both source and target text.
        prefix: Task prefix prepended to every source text.
        max_input_length: Truncation length for source sequences.
        max_target_length: Truncation length for target sequences.

    Returns:
        A dataset containing only the tokenized model inputs and ``labels``.
    """

    def preprocess_function(examples):
        """Tokenize one batch of examples."""
        sources = [prefix + str(inp) for inp in examples["input"]]
        targets = [str(out) for out in examples["output"]]

        model_inputs = tokenizer(sources, max_length=max_input_length, truncation=True)
        # `text_target` replaces the deprecated `as_target_tokenizer()` context.
        labels = tokenizer(text_target=targets, max_length=max_target_length, truncation=True)

        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    # Drop the raw text columns; only tokenized features are needed downstream.
    return dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=dataset.column_names,
    )


def main():
    """
    Main function to run the fine-tuning process.

    Loads 'healthcare_dataset.csv', fine-tunes google/mt5-small to generate
    the 'output' column from the prefixed 'input' column, saves the model and
    tokenizer to './fine_tuned_mt5_medical', then reloads them and prints the
    generation for one sample query. Returns ``None``; if the dataset cannot
    be loaded an error is printed and the function returns early.
    """
    import inspect

    # --- 1. Load and Prepare the Dataset ---
    df = _load_training_frame("healthcare_dataset.csv")
    if df is None:
        return

    # Convert the pandas DataFrame to a Hugging Face Dataset
    dataset = Dataset.from_pandas(df, preserve_index=False)

    # --- 2. Load Tokenizer and Model ---
    model_name = "google/mt5-small"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = MT5ForConditionalGeneration.from_pretrained(model_name)

    # --- 3. Preprocess the Data ---
    # The model is trained to generate the 'output' column from the prefixed
    # 'input' column. The same prefix is reused at inference time below.
    prefix = "diagnose: "
    max_input_length = 512
    max_target_length = 256

    tokenized_dataset = _tokenize_dataset(
        dataset, tokenizer, prefix, max_input_length, max_target_length
    )

    # --- 4. Set Up Training ---
    training_args = Seq2SeqTrainingArguments(
        output_dir="./results_mt5_medical",    # Directory to save the model and results
        num_train_epochs=50,                   # Total number of training epochs (increase for better performance on small datasets)
        per_device_train_batch_size=2,         # Batch size per device during training
        per_device_eval_batch_size=2,          # Batch size for evaluation (if used)
        warmup_steps=50,                       # Number of warmup steps for learning rate scheduler
        weight_decay=0.01,                     # Strength of weight decay
        logging_dir='./logs',                  # Directory for storing logs
        logging_steps=10,
        save_total_limit=2,                    # Only keep the last 2 saved models
        predict_with_generate=True,            # Whether to use generate to calculate generative metrics
        report_to="none",                      # Disable integration with Weights & Biases
    )

    # The collator pads each batch dynamically. Label padding uses -100 so
    # that padded target positions are ignored by the loss.
    data_collator = DataCollatorForSeq2Seq(
        tokenizer,
        model=model,
        label_pad_token_id=-100,
    )

    # Newer transformers releases rename the trainer's `tokenizer` argument
    # to `processing_class`; pick whichever the installed version accepts.
    trainer_params = inspect.signature(Seq2SeqTrainer.__init__).parameters
    tokenizer_kwarg = "processing_class" if "processing_class" in trainer_params else "tokenizer"

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=data_collator,
        **{tokenizer_kwarg: tokenizer},
    )

    # --- 5. Start Fine-Tuning ---
    print("Starting the fine-tuning process...")
    trainer.train()
    print("Fine-tuning complete.")

    # --- 6. Save the Fine-Tuned Model ---
    final_model_path = "./fine_tuned_mt5_medical"
    trainer.save_model(final_model_path)
    tokenizer.save_pretrained(final_model_path)
    print(f"Model saved to {final_model_path}")

    # --- 7. Inference Example ---
    print("\n--- Running Inference with the Fine-Tuned Model ---")

    # Reload from disk so the example exercises the saved artifacts.
    trained_model = MT5ForConditionalGeneration.from_pretrained(final_model_path)
    trained_tokenizer = AutoTokenizer.from_pretrained(final_model_path)

    # Define a new patient query
    patient_input = "I have a persistent dry cough for the last 3 days and a slight fever. I feel tired and have a mild headache. What could this be?"

    # Build the prompt with the same prefix and length limit used in training.
    prompt = f"{prefix}{patient_input}"
    encoded = trained_tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=max_input_length,
    )

    # Generate the output
    print("Generating diagnosis...")
    outputs = trained_model.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        max_length=200,
        num_beams=5,
        early_stopping=True
    )

    # Decode and print the result
    generated_text = trained_tokenizer.decode(outputs[0], skip_special_tokens=True)
    print("\nPatient Query:")
    print(patient_input)
    print("\nGenerated Medical Advice:")
    print(generated_text)


# Run the fine-tuning pipeline only when this code is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

