In [None]:
from google.colab import drive

# Make Google Drive visible under /content/drive; the archive and data paths
# used later in this notebook live below that mount point.
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [None]:
import zipfile

# Archive on the mounted Drive and the folder its members are unpacked into.
zip_path = "/content/drive/MyDrive/FinalProjectHealthCare/archive (8).zip"
extract_path = "/content/drive/MyDrive/FinalProjectHealthCare/data_image(1)"

# Unpack every member; the context manager closes the archive when done.
with zipfile.ZipFile(zip_path, mode='r') as archive:
    archive.extractall(path=extract_path)


In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from sklearn.metrics import f1_score, precision_score, recall_score, confusion_matrix
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


# Finding names used as classification targets. The list order fixes the
# position of each class in every label / prediction vector built from it.
disease_labels = [
    'Atelectasis',
    'Cardiomegaly',
    'Effusion',
    'Infiltration',
    'Mass',
    'Nodule',
    'Pneumonia',
    'Pneumothorax',
    'Consolidation',
    'Edema',
    'Emphysema',
    'Fibrosis',
    'Pleural_Thickening',
    'Hernia',
]

class ImprovedChestXRayDataset(Dataset):
    """Dataset pairing processor-prepared chest X-ray images with multi-hot labels.

    Every row of ``dataframe`` is expected to provide:

    * ``'Image Index'``    - file name of the image inside ``image_folder``;
    * ``'Finding Labels'`` - ``'|'``-separated finding names for that image.
    """

    def __init__(self, dataframe, image_folder, clip_processor, disease_labels):
        """Store the data sources; nothing is loaded until a sample is requested.

        Args:
            dataframe: table with the ``'Image Index'`` / ``'Finding Labels'`` columns.
            image_folder: directory that contains the image files.
            clip_processor: callable invoked as
                ``clip_processor(images=image, return_tensors="pt")`` that returns a
                mapping of tensors carrying a leading batch dimension of 1.
            disease_labels: ordered class names; defines the label-vector layout.
        """
        self.df = dataframe
        self.image_folder = image_folder
        self.processor = clip_processor
        self.disease_labels = disease_labels

    def __len__(self):
        """Return the number of rows in the backing dataframe."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load and preprocess the sample at positional index ``idx``.

        Returns:
            ``(inputs, labels, image_name, image)`` where ``inputs`` is the
            processor output with the batch dimension removed, ``labels`` is a
            float32 multi-hot tensor ordered like ``self.disease_labels``,
            ``image_name`` is the row's ``'Image Index'`` value and ``image`` is
            the RGB image object. Returns ``None`` (after printing the error)
            when any step fails, so a collate function can drop the sample.
        """
        row = self.df.iloc[idx]
        image_path = os.path.join(self.image_folder, row['Image Index'])

        try:
            # Close the file handle deterministically instead of relying on
            # garbage collection; convert() hands back an independent copy.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert("RGB")

            inputs = self.processor(images=image, return_tensors="pt")
            inputs = {k: v.squeeze(0) for k, v in inputs.items()}

            # Split the finding string once instead of once per class.
            findings = set(row['Finding Labels'].split('|'))
            label_values = [1 if label in findings else 0 for label in self.disease_labels]
            labels = torch.tensor(label_values).float()

            return inputs, labels, row['Image Index'], image
        except Exception as e:
            # Deliberate best-effort: a bad sample is reported and skipped.
            print(f"Error loading image {image_path}: {e}")
            return None

def collate_fn_ignore_none(batch):
    """Collate dataset samples into a batch, discarding samples that are ``None``.

    Args:
        batch: list of ``(inputs, labels, image_name, image)`` tuples, where
            failed samples appear as ``None``.

    Returns:
        ``(inputs, labels, image_names, images)`` with tensors stacked along a
        new leading dimension and names / images kept as plain lists, or
        ``None`` when no usable sample remains.
    """
    usable = [sample for sample in batch if sample is not None]
    if len(usable) == 0:
        return None

    # The first usable sample decides which input keys get stacked.
    stacked_inputs = {}
    for key in usable[0][0]:
        stacked_inputs[key] = torch.stack([sample[0][key] for sample in usable])

    stacked_labels = torch.stack([sample[1] for sample in usable])
    names = [sample[2] for sample in usable]
    pictures = [sample[3] for sample in usable]
    return stacked_inputs, stacked_labels, names, pictures

class CustomClassifier(nn.Module):
    """MLP head that maps feature vectors to one raw logit per class.

    Layout: ``input_dim -> 1024 -> 512 -> num_classes``; each hidden stage is
    Linear, ReLU, BatchNorm1d, Dropout (p=0.3 for the first, p=0.2 for the second).
    """

    def __init__(self, input_dim, num_classes):
        """Build the layer stack.

        Args:
            input_dim: size of each incoming feature vector.
            num_classes: number of output logits.
        """
        super().__init__()

        # (hidden width, dropout probability) for each hidden stage, in order.
        hidden_stages = ((1024, 0.3), (512, 0.2))

        layers = []
        in_features = input_dim
        for width, drop_p in hidden_stages:
            layers.append(nn.Linear(in_features, width))
            layers.append(nn.ReLU())
            layers.append(nn.BatchNorm1d(width))
            layers.append(nn.Dropout(drop_p))
            in_features = width
        layers.append(nn.Linear(in_features, num_classes))

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw (pre-sigmoid) logits for the batch ``x``."""
        logits = self.net(x)
        return logits

def zero_shot_predict(clip_model, processor, image, device, disease_labels):
    """Rank ``disease_labels`` for a single image using CLIP zero-shot matching.

    Each label is turned into the prompt ``"A chest X-ray showing <label>"``;
    image and prompt embeddings are L2-normalised and compared by cosine
    similarity. The similarities are multiplied by the model's learned
    temperature (``clip_model.logit_scale.exp()``) before the softmax; without
    that factor the softmax over values in [-1, 1] is close to uniform and the
    reported probabilities carry almost no information. The scaling is
    monotonic, so it never changes which labels rank highest.

    Args:
        clip_model: model exposing ``get_image_features`` / ``get_text_features``
            and, optionally, a ``logit_scale`` tensor (log-temperature).
        processor: callable producing model inputs for ``text=`` and ``images=``.
        image: one image accepted by ``processor``.
        device: device the processor outputs are moved to.
        disease_labels: candidate label names.

    Returns:
        Up to two ``(label, probability)`` pairs, most probable first (a single
        pair when only one label is supplied).
    """
    prompts = [f"A chest X-ray showing {label}" for label in disease_labels]
    text_inputs = processor(text=prompts, return_tensors="pt", padding=True).to(device)
    image_inputs = processor(images=image, return_tensors="pt").to(device)

    with torch.no_grad():
        image_features = clip_model.get_image_features(**image_inputs)
        text_features = clip_model.get_text_features(**text_inputs)

        # Out-of-place normalisation: never mutate tensors handed back by the model.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        logits_per_image = image_features @ text_features.T

        # Apply CLIP's temperature when the model provides one.
        logit_scale = getattr(clip_model, "logit_scale", None)
        if logit_scale is not None:
            logits_per_image = logits_per_image * logit_scale.exp()

        # squeeze(0) drops only the image-batch axis, so a one-label query
        # still yields a 1-D probability vector.
        probs = logits_per_image.softmax(dim=-1).squeeze(0).cpu().numpy()

    top2_indices = np.argsort(probs)[-2:]
    top2_labels = [(disease_labels[i], probs[i]) for i in top2_indices[::-1]]
    return top2_labels

def _top2_multi_hot(probs):
    """Return an int array shaped like ``probs`` with 1 at each row's two highest scores."""
    top2_indices = np.argsort(probs, axis=1)[:, -2:]
    multi_hot = np.zeros_like(probs, dtype=int)
    multi_hot[np.arange(probs.shape[0])[:, None], top2_indices] = 1
    return multi_hot


def _report_sample(image_name, image, true_row, prob_row, zs_labels, label_names):
    """Print and plot true, fine-tuned top-2 and zero-shot top-2 labels for one image."""
    true_labels = [label_names[j] for j, val in enumerate(true_row) if val == 1]
    if not true_labels:
        true_labels = ['No disease']

    pred_labels = [(label_names[j], prob_row[j]) for j in np.argsort(prob_row)[-2:][::-1]]

    print(f"Image: {image_name}")
    print(f"  True labels: {true_labels}")
    print(f"  Fine-tuned Predicted labels (Top 2):")
    for label, prob in pred_labels:
        print(f"    - {label}: {prob:.4f}")

    print(f"  Zero-Shot Predicted labels (Top 2):")
    for label, prob in zs_labels:
        print(f"    - {label}: {prob:.4f}")

    plt.imshow(image)
    plt.title(f"{image_name}\nFT: {', '.join([lbl for lbl, _ in pred_labels])}\nZS: {', '.join([lbl for lbl, _ in zs_labels])}")
    plt.axis('off')
    plt.show()


def _plot_confusion_matrices(all_labels, all_preds, label_names, epoch_number):
    """Draw one 2x2 confusion-matrix heatmap per class for the finished epoch."""
    print(f"Epoch {epoch_number} Confusion Matrices per Disease:")
    for i, disease in enumerate(label_names):
        cm = confusion_matrix(all_labels[:, i], all_preds[:, i])
        # The matrix is smaller than 2x2 when a class has only one value in
        # both the targets and the predictions; there is nothing to plot then.
        if cm.shape == (2, 2):
            plt.figure(figsize=(4,3))
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                        xticklabels=['Pred Neg', 'Pred Pos'],
                        yticklabels=['True Neg', 'True Pos'])
            plt.title(f'Confusion Matrix for {disease} (Epoch {epoch_number})')
            plt.ylabel('True label')
            plt.xlabel('Predicted label')
            plt.show()
        else:
            print(f"Not enough data to plot confusion matrix for {disease} in this epoch.")


def train_model(clip_model, processor, device, data_path,
                num_epochs=15, batch_size=4, learning_rate=1e-4, max_samples=10):
    """Train a classifier head on frozen CLIP image features and report progress.

    Reads ``Data_Entry_2017.csv`` from ``data_path`` and images from
    ``<data_path>/images-224/images-224``. CLIP itself is never updated: its
    features are computed under ``torch.no_grad()`` and only the
    ``CustomClassifier`` parameters are handed to the optimizer. For every
    processed image the true labels, the head's top-2 predictions and the CLIP
    zero-shot top-2 predictions are printed and plotted; each epoch ends with
    macro precision / recall / F1 and per-class confusion matrices.

    Args:
        clip_model: CLIP model used as a frozen feature extractor.
        processor: matching CLIP processor.
        device: device for model inputs and the classifier head.
        data_path: dataset root directory.
        num_epochs: number of passes over the data (default 15).
        batch_size: DataLoader batch size (default 4). Must be at least 2 for
            any training to happen, see the single-sample note below.
        learning_rate: Adam learning rate (default 1e-4).
        max_samples: use only the first ``max_samples`` CSV rows (default 10,
            the original quick-run setting); ``None`` uses every row.

    Returns:
        The trained ``CustomClassifier``.
    """
    df = pd.read_csv(os.path.join(data_path, "Data_Entry_2017.csv"))
    if max_samples is not None:
        df = df.head(max_samples)
    image_folder = os.path.join(data_path, "images-224/images-224")

    train_dataset = ImprovedChestXRayDataset(df, image_folder, processor, disease_labels)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              collate_fn=collate_fn_ignore_none, num_workers=2)

    # Probe the feature width with a blank image so the head matches the model.
    with torch.no_grad():
        dummy_inputs = processor(images=Image.new('RGB', (224, 224)), return_tensors="pt")
        dummy_inputs = {k: v.to(device) for k, v in dummy_inputs.items()}
        dummy_features = clip_model.get_image_features(**dummy_inputs)
        feature_dim = dummy_features.shape[-1]

    classifier = CustomClassifier(feature_dim, len(disease_labels)).to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(classifier.parameters(), lr=learning_rate)

    classifier.train()

    # CLIP is frozen, so an image's zero-shot result is the same in every
    # epoch; compute it once per image name and reuse it.
    zero_shot_cache = {}

    for epoch in range(num_epochs):
        total_loss = 0.0
        num_batches = 0
        all_labels, all_preds = [], []

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            if batch is None:
                continue
            inputs, labels, image_names, images = batch

            # BatchNorm1d cannot compute batch statistics from one sample in
            # training mode and raises; this happens when failed image loads
            # shrink a batch, so skip it instead of aborting the run.
            if labels.size(0) < 2:
                print(f"Skipping batch with a single usable sample: {image_names}")
                continue

            inputs = {k: v.to(device) for k, v in inputs.items()}
            labels = labels.to(device)

            with torch.no_grad():
                image_features = clip_model.get_image_features(**inputs)

            outputs = classifier(image_features)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

            probs = torch.sigmoid(outputs).detach().cpu().numpy()
            labels_np = labels.cpu().numpy().astype(int)

            all_preds.extend(_top2_multi_hot(probs))
            all_labels.extend(labels_np)

            for i, image_name in enumerate(image_names):
                if image_name not in zero_shot_cache:
                    zero_shot_cache[image_name] = zero_shot_predict(
                        clip_model, processor, images[i], device, disease_labels)
                _report_sample(image_name, images[i], labels_np[i], probs[i],
                               zero_shot_cache[image_name], disease_labels)

        if num_batches == 0:
            # Nothing was trained on this epoch; the metric code below would
            # fail on empty arrays, so report and move on.
            print(f"Epoch {epoch+1}: no usable batches, skipping metrics.")
            continue

        # Average over the batches actually trained on, not over all batches
        # the loader yielded (some may have been skipped above).
        avg_loss = total_loss / num_batches
        all_labels = np.array(all_labels)
        all_preds = np.array(all_preds)

        precision = precision_score(all_labels, all_preds, average='macro', zero_division=0)
        recall = recall_score(all_labels, all_preds, average='macro', zero_division=0)
        f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)

        print(f"Epoch {epoch+1} Summary:")
        print(f"  Average Loss: {avg_loss:.4f}")
        print(f"  Precision (macro): {precision:.4f}")
        print(f"  Recall (macro): {recall:.4f}")
        print(f"  F1 Score (macro): {f1:.4f}")
        print("-" * 40)

        _plot_confusion_matrices(all_labels, all_preds, disease_labels, epoch + 1)

    return classifier

if __name__ == "__main__":
    # Use the GPU when PyTorch can see one, otherwise run on the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    print("Using device:", device)

    # Folder the dataset archive was extracted into.
    data_path = "/content/drive/MyDrive/FinalProjectHealthCare/data_image(1)"

    # Model and processor are loaded from the same pretrained checkpoint.
    clip_checkpoint = "openai/clip-vit-base-patch32"
    clip_model = CLIPModel.from_pretrained(clip_checkpoint).to(device)
    processor = CLIPProcessor.from_pretrained(clip_checkpoint)

    trained_model = train_model(clip_model, processor, device, data_path)
