#### Preliminary Testing - Phase 3 - Model Training - Multi-class classification

This notebook attempts to implement the training and evaluation script specific to CLIP

In [1]:
import os
import pandas as pd
from PIL import Image
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
from transformers import CLIPModel, CLIPProcessor
from tqdm import tqdm
from sklearn.model_selection import train_test_split

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Sample Questions and Classes
questions = [
    "What limb is injured?",
    "Is the patient intubated?",
    "Where is the catheter inserted?",
    "Is there bleeding?",
    "Has the bleeding stopped?",
    "Is the patient moving?",
    "Is the patient breathing?",
    "Is there a tourniquet?",
    "Is there a chest tube?",
    "Are the patient and instruments secured?",
    "If a limb is missing which one?",
    "Is there mechanical ventilation?",
    "What is the position of the injury?"
]

classes = [
    ['no limb is injured', 'left leg', 'left arm', 'right leg', 'right arm'],
    ["can't identify", 'no', 'yes'],
    ['no catheter is used', 'lower limb'],
    ['no', 'yes'],
    ['there is no bleeding', 'no', 'yes'],
    ["can't identify", 'yes', 'no'],
    ["can't identify", 'no', 'yes'],
    ['no', 'yes'],
    ['no', 'yes'],
    ['no', 'yes', "can't identify"],
    ['none', 'left arm', 'left leg', 'right leg'],
    ["can't identify", 'no', 'yes'],
    ['thorax', 'throat', "can't identify", 'lower limb', 'abdomen', 'upper limb']
]

# Load CLIP processor
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")



Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [3]:
# Custom dataset
class ClassificationVQADataset(Dataset):
    def __init__(self, dataframe, image_dir, processor, classes):
        self.data = dataframe
        self.image_dir = image_dir
        self.processor = processor
        self.qa_columns = dataframe.columns[2:]
        self.label_encoders = [LabelEncoder().fit(cls) for cls in classes]

    def __len__(self):
        return len(self.data) * len(self.qa_columns)

    def __getitem__(self, idx):
        row_idx = idx // len(self.qa_columns)
        q_idx = idx % len(self.qa_columns)
        row = self.data.iloc[row_idx]
        question = self.qa_columns[q_idx]
        answer = row[question]
        if pd.isna(answer):
            next_idx = (idx + 1) % len(self)
            return self.__getitem__(next_idx)
        image_path = os.path.join(self.image_dir, row['video_id'], f"{row['video_id']}_frame{row['frame']}.jpg")
        image = Image.open(image_path).convert("RGB")
        label = self.label_encoders[q_idx].transform([str(answer).strip()])[0]
        return {
            "text": question.strip(),
            "image": image,
            "label": torch.tensor(label, dtype=torch.long),
            "question_idx": q_idx
        }


In [4]:
# Custom collate function for batching
def classification_collate(batch):
    texts = [item["text"] for item in batch]
    images = [item["image"] for item in batch]
    labels = torch.stack([item["label"] for item in batch])
    question_idxs = [item["question_idx"] for item in batch]
    processed = processor(text=texts, images=images, return_tensors="pt", padding=True, truncation=True)
    return {
        "input_ids": processed["input_ids"],
        "attention_mask": processed["attention_mask"],
        "pixel_values": processed["pixel_values"],
        "labels": labels,
        "question_idxs": question_idxs
    }


In [5]:
# CLIP-based multi-class classifier
class CLIPClassifier(nn.Module):
    def __init__(self, num_classes_per_question):
        super().__init__()
        self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.classifiers = nn.ModuleList([
            nn.Linear(self.clip.config.projection_dim, num_classes) for num_classes in num_classes_per_question
        ])
    
    def forward(self, input_ids, attention_mask, pixel_values, question_idx):
        outputs = self.clip(input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values)
        pooled_output = outputs.text_embeds + outputs.image_embeds
        logits = self.classifiers[question_idx](pooled_output)
        return logits


In [6]:
# Load your data
train_df = pd.read_csv("data/train_data.csv")
test_df = pd.read_csv("data/test_data.csv")

# Build dataset and dataloader
train_dataset = ClassificationVQADataset(train_df, image_dir="frames", processor=processor, classes=classes)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=classification_collate)

# Initialize model
model = CLIPClassifier([len(c) for c in classes]).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
criterion = nn.CrossEntropyLoss()


In [7]:
# Training loop
for epoch in range(8):
    model.train()
    total_loss = 0
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)
        question_idxs = batch["question_idxs"]

        losses = []
        for i in range(len(input_ids)):
            logits = model(
                input_ids[i].unsqueeze(0),
                attention_mask[i].unsqueeze(0),
                pixel_values[i].unsqueeze(0),
                question_idx=question_idxs[i]
            )
            loss = criterion(logits, labels[i].unsqueeze(0))
            losses.append(loss)

        batch_loss = sum(losses) / len(losses)
        batch_loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        total_loss += batch_loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"✅ Epoch {epoch+1} Average Loss: {avg_loss:.4f}")


Epoch 1: 100%|██████████| 5191/5191 [1:40:44<00:00,  1.16s/it]


✅ Epoch 1 Average Loss: 0.6270


Epoch 2: 100%|██████████| 5191/5191 [2:48:41<00:00,  1.95s/it]   


✅ Epoch 2 Average Loss: 0.4815


Epoch 3: 100%|██████████| 5191/5191 [2:30:09<00:00,  1.74s/it]  


✅ Epoch 3 Average Loss: 0.3977


Epoch 4: 100%|██████████| 5191/5191 [2:30:03<00:00,  1.73s/it]  


✅ Epoch 4 Average Loss: 0.3429


Epoch 5: 100%|██████████| 5191/5191 [2:27:25<00:00,  1.70s/it]  


✅ Epoch 5 Average Loss: 0.3050


Epoch 6: 100%|██████████| 5191/5191 [2:22:53<00:00,  1.65s/it]  


✅ Epoch 6 Average Loss: 0.2729


Epoch 7: 100%|██████████| 5191/5191 [1:27:50<00:00,  1.02s/it]


✅ Epoch 7 Average Loss: 0.2494


Epoch 8: 100%|██████████| 5191/5191 [1:10:44<00:00,  1.22it/s]

✅ Epoch 8 Average Loss: 0.2261





In [8]:
torch.save(model.state_dict(), "clip_classifier.pt")

In [9]:
import os
import pandas as pd
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score

# Load processor and model
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPClassifier([len(cls) for cls in classes])
model.load_state_dict(torch.load("clip_classifier.pt"))  # <-- load saved model
model.to(device)
model.eval()

# Load data
# _, test_df = train_test_split(df, test_size=0.2, random_state=42)

# Create test dataset & loader
test_dataset = ClassificationVQADataset(test_df, image_dir="frames_sample", processor=processor, classes=classes)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, collate_fn=classification_collate)

# Set up label encoders (must match training!)
label_encoders = [LabelEncoder().fit(cls) for cls in classes]

# Initialize results tracker
all_preds = [[] for _ in questions]
all_truths = [[] for _ in questions]


In [None]:
results = []

with torch.no_grad():
    for batch_idx, batch in enumerate(test_loader):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)
        question_idxs = batch["question_idxs"]

        for i in range(len(input_ids)):
            q_idx = question_idxs[i]

            logits = model(
                input_ids[i].unsqueeze(0),
                attention_mask[i].unsqueeze(0),
                pixel_values[i].unsqueeze(0),
                question_idx=q_idx
            )

            pred_idx = logits.argmax(dim=1).item()
            pred_label = label_encoders[q_idx].inverse_transform([pred_idx])[0]
            true_label = label_encoders[q_idx].inverse_transform([labels[i].item()])[0]

            # Get original row in DataFrame
            dataset_idx = batch_idx * test_loader.batch_size + i
            frame_row_idx = dataset_idx // len(questions)  # one row generates len(questions) samples
            question = questions[q_idx]

            video_id = test_dataset.data.iloc[frame_row_idx]["video_id"]
            frame = test_dataset.data.iloc[frame_row_idx]["frame"]

            results.append({
                "video_id": video_id,
                "frame": frame,
                "question": question,
                "true_answer": true_label,
                "predicted_answer": pred_label
            })

# Build DataFrame
results_df = pd.DataFrame(results)
results_df.to_csv("clip_predictions.csv", index=False)

  video_id  frame                         question           true_answer  \
0   Cric20      1            What limb is injured?    no limb is injured   
1   Cric20      1        Is the patient intubated?        can't identify   
2   Cric20      1  Where is the catheter inserted?   no catheter is used   
3   Cric20      1               Is there bleeding?                    no   
4   Cric20      1        Has the bleeding stopped?  there is no bleeding   

       predicted_answer  
0    no limb is injured  
1        can't identify  
2   no catheter is used  
3                    no  
4  there is no bleeding  


In [79]:
# Print accuracy per question
print("\nAccuracy per question:\n")
for i, q in enumerate(questions):
    le = label_encoders[i]
    if all_truths[i]:
        acc = accuracy_score(all_truths[i], all_preds[i])
        print(f"{q}: {acc:.2%}")
        # Show first 3 predictions
        for j in range(min(3, len(all_preds[i]))):
            pred_label = le.inverse_transform([all_preds[i][j]])[0]
            true_label = le.inverse_transform([all_truths[i][j]])[0]
            print(f"Pred: {pred_label} | GT: {true_label}")
        print()



Accuracy per question:

