In [1]:
import os

os.environ["HF_HOME"] = "/home/shared/.cache/huggingface"
os.environ["HUGGINGFACE_HUB_CACHE"] = "/home/shared/.cache/huggingface/hub"

In [2]:
import json
import torch
import torch.nn.functional as F
from dotenv import load_dotenv
from datasets import Dataset
from torch.optim import AdamW
from transformers import (
    BertTokenizer,
    BertForSequenceClassification,
    default_data_collator
)
from torch.utils.data import DataLoader

# ✅ Step 1: Load soft labels from file
with open("soft_labels_finetuned_biogpt.json", "r") as f:
    soft_dataset = json.load(f)


In [3]:
import os
from huggingface_hub import HfApi
from huggingface_hub import login

load_dotenv()
hf_token = os.getenv("hf_token")
login(token=hf_token)

In [9]:
# ✅ Step 2: Load PubMedBERT tokenizer and model with manual BERT specification

from transformers import AutoTokenizer, AutoModelForSequenceClassification

student_model_id = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext"
student_tokenizer = AutoTokenizer.from_pretrained(student_model_id)
student_model = AutoModelForSequenceClassification.from_pretrained(student_model_id, num_labels=3)

ValueError: Unrecognized model in microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext. Should have a `model_type` key in its config.json, or contain one of the following strings in its name: albert, align, altclip, aria, aria_text, audio-spectrogram-transformer, autoformer, aya_vision, bamba, bark, bart, beit, bert, bert-generation, big_bird, bigbird_pegasus, biogpt, bit, blenderbot, blenderbot-small, blip, blip-2, bloom, bridgetower, bros, camembert, canine, chameleon, chinese_clip, chinese_clip_vision_model, clap, clip, clip_text_model, clip_vision_model, clipseg, clvp, code_llama, codegen, cohere, cohere2, colpali, conditional_detr, convbert, convnext, convnextv2, cpmant, ctrl, cvt, dab-detr, dac, data2vec-audio, data2vec-text, data2vec-vision, dbrx, deberta, deberta-v2, decision_transformer, deepseek_v3, deformable_detr, deit, depth_anything, depth_pro, deta, detr, diffllama, dinat, dinov2, dinov2_with_registers, distilbert, donut-swin, dpr, dpt, efficientformer, efficientnet, electra, emu3, encodec, encoder-decoder, ernie, ernie_m, esm, falcon, falcon_mamba, fastspeech2_conformer, flaubert, flava, fnet, focalnet, fsmt, funnel, fuyu, gemma, gemma2, gemma3, gemma3_text, git, glm, glm4, glpn, got_ocr2, gpt-sw3, gpt2, gpt_bigcode, gpt_neo, gpt_neox, gpt_neox_japanese, gptj, gptsan-japanese, granite, granitemoe, granitemoeshared, granitevision, graphormer, grounding-dino, groupvit, helium, hiera, hubert, ibert, idefics, idefics2, idefics3, idefics3_vision, ijepa, imagegpt, informer, instructblip, instructblipvideo, jamba, jetmoe, jukebox, kosmos-2, layoutlm, layoutlmv2, layoutlmv3, led, levit, lilt, llama, llama4, llama4_text, llava, llava_next, llava_next_video, llava_onevision, longformer, longt5, luke, lxmert, m2m_100, mamba, mamba2, marian, markuplm, mask2former, maskformer, maskformer-swin, mbart, mctct, mega, megatron-bert, mgp-str, mimi, mistral, mistral3, mixtral, mllama, mobilebert, mobilenet_v1, mobilenet_v2, mobilevit, mobilevitv2, modernbert, moonshine, moshi, mpnet, mpt, mra, mt5, musicgen, musicgen_melody, mvp, nat, nemotron, nezha, nllb-moe, nougat, nystromformer, olmo, olmo2, olmoe, omdet-turbo, oneformer, open-llama, openai-gpt, opt, owlv2, owlvit, paligemma, patchtsmixer, patchtst, pegasus, pegasus_x, perceiver, persimmon, phi, phi3, phi4_multimodal, phimoe, pix2struct, pixtral, plbart, poolformer, pop2piano, prompt_depth_anything, prophetnet, pvt, pvt_v2, qdqbert, qwen2, qwen2_5_vl, qwen2_audio, qwen2_audio_encoder, qwen2_moe, qwen2_vl, qwen3, qwen3_moe, rag, realm, recurrent_gemma, reformer, regnet, rembert, resnet, retribert, roberta, roberta-prelayernorm, roc_bert, roformer, rt_detr, rt_detr_resnet, rt_detr_v2, rwkv, sam, sam_vision_model, seamless_m4t, seamless_m4t_v2, segformer, seggpt, sew, sew-d, shieldgemma2, siglip, siglip2, siglip_vision_model, smolvlm, smolvlm_vision, speech-encoder-decoder, speech_to_text, speech_to_text_2, speecht5, splinter, squeezebert, stablelm, starcoder2, superglue, superpoint, swiftformer, swin, swin2sr, swinv2, switch_transformers, t5, table-transformer, tapas, textnet, time_series_transformer, timesformer, timm_backbone, timm_wrapper, trajectory_transformer, transfo-xl, trocr, tvlt, tvp, udop, umt5, unispeech, unispeech-sat, univnet, upernet, van, video_llava, videomae, vilt, vipllava, vision-encoder-decoder, vision-text-dual-encoder, visual_bert, vit, vit_hybrid, vit_mae, vit_msn, vitdet, vitmatte, vitpose, vitpose_backbone, vits, vivit, wav2vec2, wav2vec2-bert, wav2vec2-conformer, wavlm, whisper, xclip, xglm, xlm, xlm-prophetnet, xlm-roberta, xlm-roberta-xl, xlnet, xmod, yolos, yoso, zamba, zamba2, zoedepth

In [None]:
# ✅ Step 3: Convert soft dataset to HF Dataset and tokenize
hf_dataset = Dataset.from_list(soft_dataset)

def tokenize_function(example):
    tokens = student_tokenizer(
        example["input_text"],
        padding="max_length",
        truncation=True,
        max_length=512
    )
    tokens["labels"] = torch.tensor(example["soft_label"], dtype=torch.float)
    return tokens

tokenized_dataset = hf_dataset.map(tokenize_function, remove_columns=["input_text", "soft_label", "gold_index"])

In [None]:
# ✅ Step 4: Build DataLoader
train_loader = DataLoader(
    tokenized_dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=default_data_collator
)

In [None]:
# ✅ Step 5: Training config
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
student_model.to(device)

optimizer = AdamW(student_model.parameters(), lr=5e-5)
loss_fn = torch.nn.KLDivLoss(reduction="batchmean")

In [None]:
# ✅ Step 6: Distillation training loop
epochs = 3
student_model.train()

for epoch in range(epochs):
    total_loss = 0
    for batch in train_loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        outputs = student_model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

        loss = loss_fn(log_probs, labels)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        total_loss += loss.item()

    print(f"✅ Epoch {epoch+1} - Avg Distillation Loss: {total_loss / len(train_loader):.4f}")