In [2]:
%pip install transformers datasets accelerate torch

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting datasets
  Downloading datasets-4.2.0-py3-none-any.whl.metadata (18 kB)
Collecting pyarrow>=21.0.0 (from datasets)
  Downloading pyarrow-21.0.0-cp313-cp313-win_amd64.whl.metadata (3.4 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.6.0-cp313-cp313-win_amd64.whl.metadata (13 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py312-none-any.whl.metadata (7.2 kB)
Downloading datasets-4.2.0-py3-none-any.whl (506 kB)
Downloading multiprocess-0.70.16-py312-none-any.whl (146 kB)
Downloading pyarrow-21.0.0-cp313-cp313-win_amd64.whl (26.1 MB)
   ---------------------------------------- 0.0/26.1 MB ? eta -:--:--
   --------------- ------------------------ 10.0/26.1 MB 47.4 MB/s eta 0:00:01
   ---------------------------- ----------- 18.9/26.1 MB 45.9 MB/s eta 0:00:01
   ---------------------------------------- 26.1/26.1 MB 44.8 MB/s eta 0:00:00
Downloading xxhash-3

In [5]:
# pip install transformers datasets accelerate torch pandas
from transformers import MT5ForConditionalGeneration, MT5Tokenizer
from transformers import Trainer, TrainingArguments
from torch.nn import functional as F
import torch
from datasets import load_dataset

# --------------------------
# 1️⃣ Configuration
# --------------------------
csv_path = "healthcare_dataset.csv"   # your dataset filename
model_name = "google/mt5-base"
max_input_len = 256   # max tokens kept from the source (instruction [+ input]) text
max_target_len = 128  # max tokens kept from the target (output) text
alpha = 0.6          # weight of the CE loss; (1 - alpha) weights the KL loss
temperature = 3.0     # distillation temperature
num_epochs = 3
batch_size = 4        # reduced by default to be safer on memory
# Expected columns in CSV: instruction, input, output
required_columns = ("instruction", "input", "output")

# --------------------------
# 2️⃣ Load tokenizer and models
# --------------------------
tokenizer = MT5Tokenizer.from_pretrained(model_name)

# Self-distillation: teacher and student start from the same checkpoint;
# only the student is handed to the Trainer for training.
teacher = MT5ForConditionalGeneration.from_pretrained(model_name)
student = MT5ForConditionalGeneration.from_pretrained(model_name)

# Freeze all teacher parameters and put the teacher in eval mode.
teacher.requires_grad_(False)
teacher.eval()

device = "cuda" if torch.cuda.is_available() else "cpu"
teacher.to(device)
student.to(device)

# --------------------------
# 3️⃣ Load and preprocess dataset
# --------------------------
raw_dataset = load_dataset("csv", data_files={"full": csv_path})["full"]

# Fail fast with a readable message instead of a KeyError inside dataset.map().
missing_columns = [col for col in required_columns if col not in raw_dataset.column_names]
if missing_columns:
    raise ValueError(
        f"{csv_path} is missing required column(s) {missing_columns}; "
        f"found columns {raw_dataset.column_names}"
    )

# Split 90/10 into train and validation ("train" / "test" keys), seeded for reproducibility.
dataset = raw_dataset.train_test_split(test_size=0.1, seed=42)

def _cell_to_text(value):
    """Return one CSV cell as a string; a missing cell (None) becomes ""."""
    # NOTE(review): assumes empty CSV cells arrive as None. The tokenizer rejects
    # None entries, so they are mapped to "" here.
    return "" if value is None else str(value)


def combine_columns(batch):
    """Tokenize a batch of instruction/input/output rows for seq2seq training.

    The source text is ``"<instruction> : <input>"``, or just the instruction
    when ``input`` is blank; ``output`` is the target. Both sides are truncated
    and padded to fixed lengths (``max_input_len`` / ``max_target_len``), and
    padded label positions are replaced by -100 so the loss ignores them.

    Args:
        batch: Batched example dict (as passed by ``dataset.map(..., batched=True)``)
            with list-valued ``"instruction"``, ``"input"`` and ``"output"`` keys.

    Returns:
        The tokenizer's encoding of the source texts with an added ``"labels"``
        entry: one list of target token ids per row, padding replaced by -100.
    """
    src_texts = []
    tgt_texts = []
    for instr, inp, out in zip(batch["instruction"], batch["input"], batch["output"]):
        instr_text = _cell_to_text(instr)
        inp_text = _cell_to_text(inp)
        # Append the input only when it has non-whitespace content.
        if inp_text.strip():
            src_texts.append(f"{instr_text} : {inp_text}")
        else:
            src_texts.append(instr_text)
        tgt_texts.append(_cell_to_text(out))

    model_inputs = tokenizer(
        src_texts,
        max_length=max_input_len,
        truncation=True,
        padding="max_length",
    )
    label_ids = tokenizer(
        tgt_texts,
        max_length=max_target_len,
        truncation=True,
        padding="max_length",
    )["input_ids"]

    # Replace pad token ids in the labels by -100 so they are ignored by the loss.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [token if token != pad_id else -100 for token in label_seq]
        for label_seq in label_ids
    ]
    return model_inputs

# Tokenize every split in batches: `batched=True` hands combine_columns lists of
# rows, and the raw text columns are dropped so only tokenized features remain.
tokenized_datasets = dataset.map(combine_columns, batched=True, remove_columns=dataset["train"].column_names)

# Features stay as Python lists here; the Trainer converts each batch to torch
# tensors when collating, so no explicit tensor conversion is done.
# --------------------------
# 4️⃣ Custom Trainer for Self-Distillation
# --------------------------
class SelfDistillationTrainer(Trainer):
    """Trainer that adds a teacher-distillation term to the student's CE loss.

    Per batch the loss is::

        alpha * CE(student, labels)
        + (1 - alpha) * T**2 * KL(softmax(teacher / T) || softmax(student / T))

    The KL term is averaged over target positions whose label is not -100
    (i.e. not padding), the same positions the cross-entropy loss uses, so both
    terms are per-token means and ``alpha`` actually balances them.

    Args:
        teacher_model: Model providing soft targets. It is only run under
            ``torch.no_grad()`` and is switched to eval mode here.
        temperature: Softening temperature ``T`` applied to both sets of logits.
        alpha: Weight of the cross-entropy term; ``1 - alpha`` weights the KL term.
        **kwargs: Passed unchanged to ``transformers.Trainer``.
    """

    def __init__(self, teacher_model, temperature, alpha, **kwargs):
        super().__init__(**kwargs)
        self.teacher = teacher_model
        # Soft targets should not depend on dropout noise.
        self.teacher.eval()
        self.temperature = temperature
        self.alpha = alpha

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None, **kwargs):
        """Return the combined CE + distillation loss for one batch.

        Args:
            model: The student model being trained.
            inputs: Batch dict from the data collator; must contain ``"labels"``
                (padding marked with -100) plus the encoder inputs.
            return_outputs: If True, also return the student's model outputs.
            num_items_in_batch: Passed by newer ``Trainer`` versions. Accepted for
                signature compatibility and not used, because both loss terms are
                already per-token means. NOTE(review): with
                gradient_accumulation_steps > 1, check how the installed
                transformers version rescales the returned loss.
            **kwargs: Ignored; absorbs further arguments newer Trainers may pass.

        Returns:
            ``loss``, or ``(loss, outputs)`` when ``return_outputs`` is True.
        """
        labels = inputs["labels"]
        # Copy without labels instead of popping, so the caller's dict is untouched.
        model_inputs = {k: v for k, v in inputs.items() if k != "labels"}

        # Teacher forward pass (no grad). An encoder-decoder cannot run without
        # decoder inputs; giving the teacher the same labels lets it derive its
        # decoder inputs exactly as the student does below.
        with torch.no_grad():
            teacher_device = self.teacher.device
            teacher_inputs = {k: v.to(teacher_device) for k, v in model_inputs.items()}
            teacher_outputs = self.teacher(**teacher_inputs, labels=labels.to(teacher_device))
            teacher_logits = teacher_outputs.logits  # (B, S, V)

        # Student forward pass; outputs.loss is the hard-label CE (ignores -100).
        outputs = model(**model_inputs, labels=labels)
        loss_ce = outputs.loss
        student_logits = outputs.logits  # (B, S, V)

        # Keep only real target tokens -> (N, V), so padding never enters the KL.
        valid = labels != -100
        if valid.any():
            # Upcast to float32 before softmax for numerical stability (e.g. under fp16).
            s_log_probs = F.log_softmax(student_logits[valid].float() / self.temperature, dim=-1)
            t_log_probs = F.log_softmax(
                teacher_logits.to(student_logits.device)[valid].float() / self.temperature,
                dim=-1,
            )
            # "batchmean" over (N, V) = mean per-token KL, on the same scale as CE.
            kl_loss = F.kl_div(s_log_probs, t_log_probs, reduction="batchmean", log_target=True)
            # T**2 keeps the soft-target gradient magnitude comparable across temperatures.
            kl_loss = kl_loss * (self.temperature ** 2)
        else:
            # No target tokens in this batch: contribute nothing instead of NaN.
            kl_loss = student_logits.new_zeros(())

        # Combined loss
        loss = self.alpha * loss_ce + (1.0 - self.alpha) * kl_loss

        return (loss, outputs) if return_outputs else loss

# --------------------------
# 5️⃣ Training Arguments
# --------------------------
training_args = TrainingArguments(
    output_dir="./mt5_sdft_healthcare",
    # Named `evaluation_strategy` before transformers 4.41; current releases
    # only accept `eval_strategy`.
    eval_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=num_epochs,
    weight_decay=0.01,
    save_total_limit=2,
    # No `predict_with_generate`: it is a Seq2SeqTrainingArguments option that
    # TrainingArguments rejects, and evaluation here reports the loss only
    # (no compute_metrics is passed to the trainer).
    logging_steps=200,
    fp16=False,  # start with False; enable only after confirming correctness
    report_to="none",
    dataloader_pin_memory=True,
)

# --------------------------
# 6️⃣ Initialize Trainer
# --------------------------
trainer = SelfDistillationTrainer(
    teacher_model=teacher,
    model=student,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    # `processing_class` replaces the deprecated `tokenizer` argument (transformers >= 4.46).
    processing_class=tokenizer,
    temperature=temperature,
    alpha=alpha,
)

# --------------------------
# 7️⃣ Train and Save
# --------------------------
final_model_dir = "./mt5_self_distilled_healthcare_student"

trainer.train()
trainer.save_model(final_model_dir)
tokenizer.save_pretrained(final_model_dir)

print(f"✅ Training complete! Fine-tuned model saved to '{final_model_dir}'")

TypeError: unsupported operand type(s) for |: 'type' and 'NoneType'

In [4]:
# Redundant: `datasets` is also installed by the `%pip install transformers datasets ...` cell.
%pip install datasets


Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Note: you may need to restart the kernel to use updated packages.


In [6]:
%pip install -U transformers datasets accelerate torch

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting datasets
  Downloading datasets-4.2.0-py3-none-any.whl.metadata (18 kB)
Collecting pyarrow>=21.0.0 (from datasets)
  Downloading pyarrow-21.0.0-cp39-cp39-win_amd64.whl.metadata (3.4 kB)
Collecting dill<0.4.1,>=0.3.0 (from datasets)
  Downloading dill-0.4.0-py3-none-any.whl.metadata (10 kB)
Collecting httpx<1.0.0 (from datasets)
  Downloading httpx-0.28.1-py3-none-any.whl.metadata (7.1 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.6.0-cp39-cp39-win_amd64.whl.metadata (13 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py39-none-any.whl.metadata (7.2 kB)
Collecting aiohttp!=4.0.0a0,!=4.0.0a1 (from fsspec[http]<=2025.9.0,>=2023.1.0->datasets)
  Downloading aiohttp-3.13.0-cp39-cp39-win_amd64.whl.metadata (8.4 kB)
Collecting anyio (from httpx<1.0.0->datasets)
  Downloading anyio-4.11.0-py3-none-any.whl.metadata (4.1 kB)
Collecting httpcore==1.* (from http

In [7]:
!pip install --upgrade transformers


Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com


In [2]:
# NOTE(review): the visible training code uses only PyTorch. tf_keras is
# presumably installed to satisfy transformers' TensorFlow/Keras import check
# when TensorFlow is present in this environment. Confirm it is still needed.
%pip install tf_keras


Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting tf_keras
  Downloading tf_keras-2.20.1-py3-none-any.whl.metadata (1.8 kB)
Downloading tf_keras-2.20.1-py3-none-any.whl (1.7 MB)
   ---------------------------------------- 0.0/1.7 MB ? eta -:--:--
   ---------------------------------------- 1.7/1.7 MB 30.5 MB/s  0:00:00
Installing collected packages: tf_keras
Successfully installed tf_keras-2.20.1
Note: you may need to restart the kernel to use updated packages.
