In [None]:
!pip install bitsandbytes datasets transformers[sentencepiece] --upgrade

Collecting bitsandbytes
  Downloading bitsandbytes-0.45.4-py3-none-manylinux_2_24_x86_64.whl.metadata (5.0 kB)
Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting transformers[sentencepiece]
  Downloading transformers-4.50.3-py3-none-any.whl.metadata (39 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch<3,>=2.0->bitsandbytes)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Col

In [1]:
from huggingface_hub import login
login()

In [None]:
from transformers import LlavaForConditionalGeneration, LlavaProcessor

processor = LlavaProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf", use_fast=True)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [None]:
def get_prompt(clinical_data_dict):
    clinical_data_str = ""

    for key, value in clinical_data_dict.items():
        if value:
            clinical_data_str += f"{key}: {value}\n"

    prompt = f"""
        Given the image of a Saggital MRI scan below, and the clinical details of patient, classify whether the patient is AD/MCI/CN.

        Full forms of the classes are as follows:
        AD: Alzheimer's Disease
        MCI: Mild Cognitive Impairment
        CN: Cognitively Normal

        Also, provide the reasoning behind why you classified the patient into the class you chose using following template:
        '''
            - Image features:
            - Clinical details:
        '''

        The clinical details of the patient are as follows:
        '''
            {clinical_data_str}
        '''

        Rules for output:
        - Classification should be one of the following: AD, MCI, CN.
        - Reasoning should be provided in the template mentioned.
    """
    return prompt

In [None]:
import pandas as pd
df = pd.read_csv('ADNI1_Final_With_Biomarkers.csv', dtype={"GENOTYPE": str})

In [None]:
df["GENOTYPE"]

Unnamed: 0,GENOTYPE
0,4-4
1,4-4
2,4-4
3,3-3
4,3-3
...,...
2289,3-4
2290,3-4
2291,3-4
2292,3-4


In [None]:
df.head()

Unnamed: 0,Image Data ID,Subject,Group,Sex,Age,Visit,Modality,Description,Type,Acq Date,...,VISCODE_y.1,HMSCORE,VISCODE_x.2,NPISCORE,VISCODE_y.2,GDTOTAL,VISCODE2,ABETA42,TAU,PTAU
0,I97327,941_S_1311,MCI,M,69,sc,MRI,MPR; GradWarp; B1 Correction; N3; Scaled,Processed,03-02-2007,...,sc,1.0,,,sc,1.0,,,,
1,I112538,941_S_1311,MCI,M,70,m12,MRI,MPR; GradWarp; B1 Correction; N3; Scaled,Processed,06-01-2008,...,,,m12,4.0,m12,3.0,,,,
2,I97341,941_S_1311,MCI,M,70,m06,MRI,MPR-R; GradWarp; B1 Correction; N3; Scaled,Processed,9-27-2007,...,,,m06,3.0,,,,,,
3,I63874,941_S_1202,CN,M,78,sc,MRI,MPR-R; GradWarp; B1 Correction; N3; Scaled,Processed,1-30-2007,...,sc,0.0,,,sc,0.0,,,,
4,I75150,941_S_1202,CN,M,78,m06,MRI,MPR; GradWarp; B1 Correction; N3; Scaled,Processed,8-24-2007,...,,,m06,2.0,,,,,,


In [None]:
df.columns

Index(['Image Data ID', 'Subject', 'Group', 'Sex', 'Age', 'Visit', 'Modality',
       'Description', 'Type', 'Acq Date', 'Format', 'Downloaded', 'GENOTYPE',
       'VISCODE_x', 'CDGLOBAL', 'CDRSB', 'VISCODE_y', 'MMSCORE', 'VISCODE_x.1',
       'TOTAL11', 'TOTALMOD', 'VISCODE_y.1', 'HMSCORE', 'VISCODE_x.2',
       'NPISCORE', 'VISCODE_y.2', 'GDTOTAL', 'VISCODE2', 'ABETA42', 'TAU',
       'PTAU'],
      dtype='object')

In [None]:
FEATURE_COLS = ["Image Data ID", "Age", "GENOTYPE", "CDGLOBAL", "CDRSB", "HMSCORE", "NPISCORE", "MMSCORE", "GDTOTAL", "Group"]

In [None]:
from sklearn.preprocessing import LabelEncoder
df = df[FEATURE_COLS]

In [None]:
df.head()

Unnamed: 0,Image Data ID,Age,GENOTYPE,CDGLOBAL,CDRSB,HMSCORE,NPISCORE,MMSCORE,GDTOTAL,Group
0,I97327,69,4-4,0.5,2.5,1.0,,29.0,1.0,MCI
1,I112538,70,4-4,1.0,5.5,,4.0,27.0,3.0,MCI
2,I97341,70,4-4,0.5,3.5,,3.0,27.0,,MCI
3,I63874,78,3-3,0.0,0.0,0.0,,28.0,0.0,CN
4,I75150,78,3-3,0.0,0.5,,2.0,30.0,,CN


In [None]:
img_paths = []
prompts = []

for i in range(len(df)):
    image_data_id = df.iloc[i]["Image Data ID"]
    img_paths.append(f"preprocessed_images_3/{image_data_id}.png")
    clinical_data_dict = {}

    for col, val in df.iloc[i].items():
        if pd.notna(val):
            clinical_data_dict[col] = val
    prompts.append(get_prompt(clinical_data_dict))

In [None]:
df["image_path"] = img_paths
df["prompt"] = prompts

In [None]:
# split
from sklearn.model_selection import train_test_split
train_df, test_df = train_test_split(df, test_size=0.3, random_state=42, stratify=df["Group"])
val_df, test_df = train_test_split(test_df, test_size=0.5, random_state=42, stratify=test_df["Group"])

In [None]:
print(f"Train dataset size: {len(train_df)}")
print(f"Test dataset size: {len(test_df)}")
print(f"Validation dataset size: {len(val_df)}")

Train dataset size: 1605
Test dataset size: 345
Validation dataset size: 344


In [None]:
import torch
from datasets import Dataset
from PIL import Image

label2id = {"CN": 0, "MCI": 1, "AD": 2}

def preprocess(example):
    image = Image.open(example['image_path']).convert("RGB")
    inputs = processor(example["prompt"], image, return_tensors="pt", padding="max_length", truncation=True, max_length=256)
    return {
        "input_ids": inputs["input_ids"].squeeze(0),
        "attention_mask": inputs["attention_mask"].squeeze(0),
        "pixel_values": inputs["pixel_values"].squeeze(0),
        "labels": torch.tensor(label2id[example["Group"]], dtype=torch.long)
    }

In [None]:
train_data_dict = {
    'image_path': train_df['image_path'].tolist(),
    'prompt': train_df['prompt'].tolist(),
    'Group': train_df['Group'].tolist()
}
val_data_dict = {
    'image_path': val_df['image_path'].tolist(),
    'prompt': val_df['prompt'].tolist(),
    'Group': val_df['Group'].tolist()
}

In [None]:
train_dataset = Dataset.from_dict(train_data_dict)
train_dataset = train_dataset.map(preprocess)

val_dataset = Dataset.from_dict(val_data_dict)
val_dataset = val_dataset.map(preprocess)

Map:   0%|          | 0/1605 [00:00<?, ? examples/s]

You may have used the wrong order for inputs. `images` should be passed before `text`. The `images` and `text` inputs will be swapped. This behavior will be deprecated in transformers v4.47.


Map:   0%|          | 0/344 [00:00<?, ? examples/s]

In [None]:
import torch
import torch.nn as nn
from transformers import LlavaForConditionalGeneration
from peft import get_peft_model, LoraConfig, TaskType

class LlavaForClassificationWithQLoRA(nn.Module):
    def __init__(self, model_name, num_classes):
        super().__init__()

        # Load LLaVA with 4-bit quantization
        self.base_model = LlavaForConditionalGeneration.from_pretrained(
            model_name,
            load_in_4bit=True,
            device_map="auto"
        )

        # QLoRA config for attention projection layers
        lora_config = LoraConfig(
            r=8,
            lora_alpha=16,
            target_modules=["q_proj", "v_proj"],
            lora_dropout=0.1,
            bias="none",
            task_type=TaskType.CAUSAL_LM
        )

        # Inject LoRA adapters into the base model
        self.base_model = get_peft_model(self.base_model, lora_config)
        self.base_model.print_trainable_parameters()

        # Dimensions
        self.image_hidden_size = self.base_model.vision_tower.config.hidden_size  # 1024
        self.text_hidden_size = self.base_model.language_model.config.hidden_size  # 4096
        self.proj_dim = 1024  # unify both to 1024

        # Project image and text embeddings to same dim
        self.vision_proj = nn.Linear(self.image_hidden_size, self.proj_dim)
        self.text_proj = nn.Linear(self.text_hidden_size, self.proj_dim)

        # Final classifier head
        self.classifier = nn.Linear(self.proj_dim, num_classes)

    def forward(self, input_ids=None, attention_mask=None, pixel_values=None, labels=None):
        # Extract image embeddings (ViT CLS token)
        vision_outputs = self.base_model.vision_tower(pixel_values=pixel_values)
        image_embeds = vision_outputs.last_hidden_state[:, 0]  # shape: [B, 1024]

        # Text embeddings (mean of token embeddings)
        text_embeds = self.base_model.language_model.model.embed_tokens(input_ids).mean(dim=1)  # shape: [B, 4096]

        # Project to common dim
        image_feat = self.vision_proj(image_embeds)  # [B, 1024]
        text_feat = self.text_proj(text_embeds)      # [B, 1024]

        # Fuse
        fused = image_feat + text_feat               # [B, 1024]
        logits = self.classifier(fused)              # [B, num_classes]

        # Optional loss
        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(logits, labels)

        return {"loss": loss, "logits": logits}

In [20]:
from transformers import TrainingArguments, Trainer, EarlyStoppingCallback

cls_model = LlavaForClassificationWithQLoRA("llava-hf/llava-1.5-7b-hf", num_classes=3)

training_args = TrainingArguments(
    output_dir="./llava-qlora-classifier",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    learning_rate=3e-4,
    num_train_epochs=20,
    fp16=True,
    save_strategy="epoch",
    evaluation_strategy="epoch",
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False
)

trainer = Trainer(
    model=cls_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=processor.tokenizer,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5)]
)

trainer.train()

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

trainable params: 4,980,736 || all params: 7,068,407,808 || trainable%: 0.0705


  trainer = Trainer(
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33msuj00rit20[0m ([33mjaggery[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin




Epoch,Training Loss,Validation Loss
1,0.8064,1.267674
2,1.0665,0.708969
3,0.6589,0.632309
4,0.4185,0.517686
5,0.6729,0.498477
6,0.5623,0.485094
7,0.4565,0.372575
8,0.3549,0.316272
9,0.2754,0.286266
10,0.1785,0.213087


Epoch,Training Loss,Validation Loss
1,0.8064,1.267674
2,1.0665,0.708969
3,0.6589,0.632309
4,0.4185,0.517686
5,0.6729,0.498477
6,0.5623,0.485094
7,0.4565,0.372575
8,0.3549,0.316272
9,0.2754,0.286266
10,0.1785,0.213087


There were unexpected keys in the checkpoint model loaded: ['base_model.base_model.model.vision_tower.vision_model.encoder.layers.0.self_attn.k_proj.weight.absmax', 'base_model.base_model.model.vision_tower.vision_model.encoder.layers.0.self_attn.k_proj.weight.quant_map', 'base_model.base_model.model.vision_tower.vision_model.encoder.layers.0.self_attn.k_proj.weight.quant_state.bitsandbytes__fp4', 'base_model.base_model.model.vision_tower.vision_model.encoder.layers.0.self_attn.v_proj.base_layer.weight.absmax', 'base_model.base_model.model.vision_tower.vision_model.encoder.layers.0.self_attn.v_proj.base_layer.weight.quant_map', 'base_model.base_model.model.vision_tower.vision_model.encoder.layers.0.self_attn.v_proj.base_layer.weight.quant_state.bitsandbytes__fp4', 'base_model.base_model.model.vision_tower.vision_model.encoder.layers.0.self_attn.q_proj.base_layer.weight.absmax', 'base_model.base_model.model.vision_tower.vision_model.encoder.layers.0.self_attn.q_proj.base_layer.weight.qu

TrainOutput(global_step=8020, training_loss=0.34922429349580014, metrics={'train_runtime': 12576.9582, 'train_samples_per_second': 2.552, 'train_steps_per_second': 0.638, 'total_flos': 0.0, 'train_loss': 0.34922429349580014, 'epoch': 19.95202492211838})

In [21]:
from PIL import Image
import torch

def predict_class(image_path, prompt, model, processor, id2label=None):
    # Load and preprocess
    image = Image.open(image_path)
    inputs = processor(prompt, image, return_tensors="pt").to("cuda")

    with torch.no_grad():
        outputs = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            pixel_values=inputs["pixel_values"]
        )
        logits = outputs["logits"]  # [B, num_classes]
        probs = torch.softmax(logits, dim=-1)
        pred_id = torch.argmax(probs, dim=-1).item()

    if id2label:
        return id2label[pred_id]
    return pred_id

In [22]:
id2label = {0: "CN", 1: "MCI", 2: "AD"}

In [23]:
test_df.iloc[0]

Unnamed: 0,2264
Image Data ID,I60465
Age,71
GENOTYPE,4-4
CDGLOBAL,1.0
CDRSB,5.5
HMSCORE,
NPISCORE,7.0
MMSCORE,26.0
GDTOTAL,
Group,AD


In [24]:
predict_class(test_df.iloc[0]["image_path"], test_df.iloc[0]["prompt"], cls_model, processor, id2label)

'AD'

In [25]:
predictions = []

for i in range(len(test_df)):
  predictions.append(predict_class(test_df.iloc[i]["image_path"], test_df.iloc[i]["prompt"], cls_model, processor, id2label))

In [26]:
from sklearn.metrics import classification_report

print(classification_report(test_df["Group"], predictions))

              precision    recall  f1-score   support

          AD       0.99      0.93      0.96        72
          CN       0.99      0.94      0.97       106
         MCI       0.94      0.99      0.97       167

    accuracy                           0.97       345
   macro avg       0.97      0.96      0.96       345
weighted avg       0.97      0.97      0.97       345



# Saving the Fine-tuned model to HF

In [27]:
from huggingface_hub import HfApi, HfFolder, Repository
import os
import torch

hf_repo = "sujayrittikar/adni_llava_qlora"
save_dir = "./llava_classifier_save"

# 1. Save LoRA-augmented base model
cls_model.base_model.save_pretrained(save_dir)

# 2. Save classifier weights separately
torch.save({
    "vision_proj": cls_model.vision_proj.state_dict(),
    "text_proj": cls_model.text_proj.state_dict(),
    "classifier": cls_model.classifier.state_dict()
}, os.path.join(save_dir, "classifier_head.pt"))

In [28]:
# 3. Save any config info needed to reload
import json
json.dump({"num_classes": 3}, open(os.path.join(save_dir, "config.json"), "w"))

# 4. Push to HF
from huggingface_hub import create_repo, upload_folder

create_repo(hf_repo, exist_ok=True)
upload_folder(
    repo_id=hf_repo,
    folder_path=save_dir,
    commit_message="Uploaded LLaVA classification model with QLoRA"
)

adapter_model.safetensors:   0%|          | 0.00/20.0M [00:00<?, ?B/s]

classifier_head.pt:   0%|          | 0.00/21.0M [00:00<?, ?B/s]

Upload 2 LFS files:   0%|          | 0/2 [00:00<?, ?it/s]

CommitInfo(commit_url='https://huggingface.co/sujayrittikar/adni_llava_qlora/commit/867ee0996faa7e0b85c89b608b330609866d2f1b', commit_message='Uploaded LLaVA classification model with QLoRA', commit_description='', oid='867ee0996faa7e0b85c89b608b330609866d2f1b', pr_url=None, repo_url=RepoUrl('https://huggingface.co/sujayrittikar/adni_llava_qlora', endpoint='https://huggingface.co', repo_type='model', repo_id='sujayrittikar/adni_llava_qlora'), pr_revision=None, pr_num=None)

# Loading the HF saved model

In [9]:
from peft import PeftModel, PeftConfig
from transformers import LlavaForConditionalGeneration
import torch
import torch.nn as nn
import json

hf_repo = "sujayrittikar/adni_llava_qlora"

class LlavaForClassificationWithQLoRA(nn.Module):
    def __init__(self, hf_repo, num_classes):
        super().__init__()

        # Step 1: Load PEFT config to get base model name
        peft_config = PeftConfig.from_pretrained(hf_repo)

        # Step 2: Load base model
        base_model = LlavaForConditionalGeneration.from_pretrained(
            peft_config.base_model_name_or_path,
            load_in_4bit=True,
            device_map="auto"
        )

        # Step 3: Inject LoRA adapters
        self.base_model = PeftModel.from_pretrained(base_model, hf_repo)

        # Dimensions
        self.image_hidden_size = self.base_model.vision_tower.config.hidden_size
        self.text_hidden_size = self.base_model.language_model.config.hidden_size
        self.proj_dim = 1024

        # Classifier head
        self.vision_proj = nn.Linear(self.image_hidden_size, self.proj_dim)
        self.text_proj = nn.Linear(self.text_hidden_size, self.proj_dim)
        self.classifier = nn.Linear(self.proj_dim, num_classes)

    def forward(self, input_ids=None, attention_mask=None, pixel_values=None, labels=None):
        vision_outputs = self.base_model.vision_tower(pixel_values=pixel_values)
        image_embeds = vision_outputs.last_hidden_state[:, 0]

        text_embeds = self.base_model.language_model.model.embed_tokens(input_ids).mean(dim=1)

        image_feat = self.vision_proj(image_embeds)
        text_feat = self.text_proj(text_embeds)

        fused = image_feat + text_feat
        logits = self.classifier(fused)

        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(logits, labels)

        return {"loss": loss, "logits": logits}

# Load your model
num_classes = 3
model = LlavaForClassificationWithQLoRA(hf_repo, num_classes)

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [10]:
from huggingface_hub import hf_hub_download
import torch

# Download classifier_head.pt
state_dict_path = hf_hub_download(
    repo_id=hf_repo,
    filename="classifier_head.pt"
)

# Load head weights
state_dict = torch.load(state_dict_path, map_location="cpu")
model.vision_proj.load_state_dict(state_dict["vision_proj"])
model.text_proj.load_state_dict(state_dict["text_proj"])
model.classifier.load_state_dict(state_dict["classifier"])

<All keys matched successfully>