In [1]:
from huggingface_hub import login
login()

In [2]:
pip install bitsandbytes



In [None]:
!pip install datasets

In [32]:
!pip install transformers[sentencepiece]




In [2]:
from transformers import LlavaForConditionalGeneration, LlavaProcessor

processor = LlavaProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf", use_fast=True)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [3]:
def get_prompt(clinical_data_dict):
    clinical_data_str = ""

    for key, value in clinical_data_dict.items():
        if value:
            clinical_data_str += f"{key}: {value}\n"

    prompt = f"""
        Given the image of a Saggital MRI scan below, and the clinical details of patient, classify whether the patient is AD/MCI/CN.

        Full forms of the classes are as follows:
        AD: Alzheimer's Disease
        MCI: Mild Cognitive Impairment
        CN: Cognitively Normal

        Also, provide the reasoning behind why you classified the patient into the class you chose using following template:
        '''
            - Image features:
            - Clinical details:
        '''

        The clinical details of the patient are as follows:
        '''
            {clinical_data_str}
        '''

        Rules for output:
        - Classification should be one of the following: AD, MCI, CN.
        - Reasoning should be provided in the template mentioned.
    """
    return prompt

In [4]:
import pandas as pd
df = pd.read_csv('ADNI1_Final_With_Biomarkers.csv')

In [5]:
FEATURE_COLS = ["Image Data ID", "Age", "GENOTYPE", "CDGLOBAL", "CDRSB", "MMSCORE", "HMSCORE", "NPISCORE", "GDTOTAL", "Group"]

In [6]:
from sklearn.preprocessing import LabelEncoder
df = df[FEATURE_COLS]

In [7]:
df.head()

Unnamed: 0,Image Data ID,Age,GENOTYPE,CDGLOBAL,CDRSB,MMSCORE,HMSCORE,NPISCORE,GDTOTAL,Group
0,I97327,69,4/4,0.5,2.5,29.0,1.0,,1.0,MCI
1,I112538,70,4/4,1.0,5.5,27.0,,4.0,3.0,MCI
2,I97341,70,4/4,0.5,3.5,27.0,,3.0,,MCI
3,I63874,78,3/3,0.0,0.0,28.0,0.0,,0.0,CN
4,I75150,78,3/3,0.0,0.5,30.0,,2.0,,CN


In [8]:
img_paths = []
prompts = []

for i in range(len(df)):
    image_data_id = df.iloc[i]["Image Data ID"]
    img_paths.append(f"preprocessed_images_3/{image_data_id}.png")
    clinical_data_dict = {}

    for col, val in df.iloc[i].items():
        if pd.notna(val):
            clinical_data_dict[col] = val
    prompts.append(get_prompt(clinical_data_dict))

In [9]:
df["image_path"] = img_paths
df["prompt"] = prompts

In [10]:
# split
from sklearn.model_selection import train_test_split
train_df, test_df = train_test_split(df, test_size=0.3, random_state=42, stratify=df["Group"])
val_df, test_df = train_test_split(test_df, test_size=0.5, random_state=42, stratify=test_df["Group"])

In [11]:
print(f"Train dataset size: {len(train_df)}")
print(f"Test dataset size: {len(test_df)}")
print(f"Validation dataset size: {len(val_df)}")

Train dataset size: 1605
Test dataset size: 345
Validation dataset size: 344


In [12]:
import torch
from datasets import Dataset
from PIL import Image

label2id = {"CN": 0, "MCI": 1, "AD": 2}

def preprocess(example):
    image = Image.open(example['image_path']).convert("RGB")
    inputs = processor(example["prompt"], image, return_tensors="pt", padding="max_length", truncation=True, max_length=256)
    return {
        "input_ids": inputs["input_ids"].squeeze(0),
        "attention_mask": inputs["attention_mask"].squeeze(0),
        "pixel_values": inputs["pixel_values"].squeeze(0),
        "labels": torch.tensor(label2id[example["Group"]], dtype=torch.long)
    }

In [13]:
train_data_dict = {
    'image_path': train_df['image_path'].tolist(),
    'prompt': train_df['prompt'].tolist(),
    'Group': train_df['Group'].tolist()
}
val_data_dict = {
    'image_path': val_df['image_path'].tolist(),
    'prompt': val_df['prompt'].tolist(),
    'Group': val_df['Group'].tolist()
}

In [14]:
train_dataset = Dataset.from_dict(train_data_dict)
train_dataset = train_dataset.map(preprocess)

val_dataset = Dataset.from_dict(val_data_dict)
val_dataset = val_dataset.map(preprocess)

Map:   0%|          | 0/1605 [00:00<?, ? examples/s]

You may have used the wrong order for inputs. `images` should be passed before `text`. The `images` and `text` inputs will be swapped. This behavior will be deprecated in transformers v4.47.


Map:   0%|          | 0/344 [00:00<?, ? examples/s]

In [15]:
import torch
import torch.nn as nn
from transformers import LlavaForConditionalGeneration
from peft import get_peft_model, LoraConfig, TaskType

class LlavaForClassificationWithQLoRA(nn.Module):
    def __init__(self, model_name, num_classes):
        super().__init__()

        # Load LLaVA with 4-bit quantization
        self.base_model = LlavaForConditionalGeneration.from_pretrained(
            model_name,
            load_in_8bit=True,
            device_map="auto"
        )

        # QLoRA config for attention projection layers
        lora_config = LoraConfig(
            r=8,
            lora_alpha=16,
            target_modules=["q_proj", "v_proj"],
            lora_dropout=0.1,
            bias="none",
            task_type=TaskType.CAUSAL_LM
        )

        # Inject LoRA adapters into the base model
        self.base_model = get_peft_model(self.base_model, lora_config)
        self.base_model.print_trainable_parameters()

        # Dimensions
        self.image_hidden_size = self.base_model.vision_tower.config.hidden_size  # typically 1024
        self.text_hidden_size = self.base_model.language_model.config.hidden_size  # typically 4096
        self.proj_dim = 1024  # unify both to 1024

        # Project image and text embeddings to same dim
        self.vision_proj = nn.Linear(self.image_hidden_size, self.proj_dim)
        self.text_proj = nn.Linear(self.text_hidden_size, self.proj_dim)

        # Final classifier head
        self.classifier = nn.Linear(self.proj_dim, num_classes)

    def forward(self, input_ids=None, attention_mask=None, pixel_values=None, labels=None):
        # Extract image embeddings (ViT CLS token)
        vision_outputs = self.base_model.vision_tower(pixel_values=pixel_values)
        image_embeds = vision_outputs.last_hidden_state[:, 0]  # shape: [B, 1024]

        # Text embeddings (mean of token embeddings)
        text_embeds = self.base_model.language_model.model.embed_tokens(input_ids).mean(dim=1)  # shape: [B, 4096]

        # Project to common dim
        image_feat = self.vision_proj(image_embeds)  # [B, 1024]
        text_feat = self.text_proj(text_embeds)      # [B, 1024]

        # Fuse
        fused = image_feat + text_feat               # [B, 1024]
        logits = self.classifier(fused)              # [B, num_classes]

        # Optional loss
        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(logits, labels)

        return {"loss": loss, "logits": logits}

In [None]:
from transformers import TrainingArguments, Trainer, EarlyStoppingCallback

cls_model = LlavaForClassificationWithQLoRA("llava-hf/llava-1.5-7b-hf", num_classes=3)

training_args = TrainingArguments(
    output_dir="./llava-qlora-classifier",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    learning_rate=3e-4,
    num_train_epochs=20,
    fp16=True,
    save_strategy="epoch",
    evaluation_strategy="epoch",
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False
)

trainer = Trainer(
    model=cls_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=processor.tokenizer,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
)

trainer.train()

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

trainable params: 4,980,736 || all params: 7,068,407,808 || trainable%: 0.0705


  trainer = Trainer(
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33msuj00rit20[0m ([33mjaggery[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin




Epoch,Training Loss,Validation Loss
1,0.7652,0.955786
2,0.5973,0.560444
3,0.6639,0.540988
4,0.5018,0.504734
5,0.7233,0.552047
6,0.6742,0.478134
7,0.5184,0.447283
8,0.4657,0.41786
9,0.3239,0.347361
10,0.299,0.301934




In [None]:
from PIL import Image
import torch

def predict_class(image_path, prompt, model, processor, id2label=None):
    # Load and preprocess
    image = Image.open(image_path)
    inputs = processor(prompt, image, return_tensors="pt").to("cuda")

    with torch.no_grad():
        outputs = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            pixel_values=inputs["pixel_values"]
        )
        logits = outputs["logits"]  # [B, num_classes]
        probs = torch.softmax(logits, dim=-1)
        pred_id = torch.argmax(probs, dim=-1).item()

    if id2label:
        return id2label[pred_id]
    return pred_id

In [None]:
id2label = {0: "CN", 1: "MCI", 2: "AD"}

In [None]:
test_df.iloc[0]

In [None]:
predict_class(test_df.iloc[0]["image_path"], test_df.iloc[0]["prompt"], cls_model, processor, id2label)

In [None]:
predictions = []

for i in range(len(test_df)):
  predictions.append(predict_class(test_df.iloc[i]["image_path"], test_df.iloc[i]["prompt"], cls_model, processor, id2label))

In [None]:
from sklearn.metrics import classification_report

print(classification_report(test_df["Group"], predictions))