<a href="https://colab.research.google.com/github/tennisvish/NASA_ML_Error_Detection_Sp25/blob/main/AICervicalFractureV1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import numpy as np
import pandas as pd
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from sklearn.metrics import roc_auc_score
from tqdm import tqdm

# Fix random seeds
torch.manual_seed(42)
np.random.seed(42)



In [43]:
class RobustRSNADataset(Dataset):
    def __init__(self, csv_path, img_root, transform=None):
        self.df = pd.read_csv(csv_path)
        self.img_root = img_root
        self.transform = transform

        # Clean IDs (remove trailing slashes/quotes)
        self.df["StudyInstanceUID"] = self.df["StudyInstanceUID"].astype(str).str.strip().str.strip('"')
        folder_ids = set(os.listdir(img_root))

        # Filter to only existing folders
        self.df = self.df[self.df["StudyInstanceUID"].isin(folder_ids)]

        if len(self.df) == 0:
            raise ValueError("No matching folders found. Verify CSV IDs match folder names.")

    def __getitem__(self, idx):
        study_uid = self.df.iloc[idx]["StudyInstanceUID"]
        study_path = os.path.join(self.img_root, study_uid)
        # ... (rest of your DICOM loading code)
    def __init__(self, csv_path, img_root, transform=None):
        self.df = pd.read_csv(csv_path)
        self.img_root = img_root.rstrip("/")  # Remove trailing slash if present
        self.transform = transform

        # Get available folders
        available_folders = set(os.listdir(self.img_root))
        self.df = self.df[self.df["StudyInstanceUID"].isin(available_folders)]

        if len(self.df) == 0:
            raise ValueError(f"No matching folders found in {self.img_root}")

    def __getitem__(self, idx):
        study_uid = self.df.iloc[idx]["StudyInstanceUID"]
        study_path = os.path.join(self.img_root, study_uid)  # Auto-handles slashes
        slices = [f for f in os.listdir(study_path) if f.endswith(".dcm")]
    def __init__(self, csv_path, img_root, transform=None, target_type="patient"):
        """
        Args:
            csv_path: Path to train.csv
            img_root: Path to train_images/
            transform: Torchvision transforms
            target_type: "patient" (patient_overall) or "vertebra" (C1-C7)
        """
        self.df = pd.read_csv(csv_path)
        self.img_root = img_root
        self.transform = transform
        self.target_type = target_type

        # Auto-filter: Only keep patients with existing folders
        available_patients = set(os.listdir(img_root))
        self.df = self.df[self.df["StudyInstanceUID"].isin(available_patients)].copy()
        print(f"Using {len(self.df)}/{len(available_patients)} available patients")

        if target_type == "vertebra":
            self.labels = self.df[["C1", "C2", "C3", "C4", "C5", "C6", "C7"]].values

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        study_uid = self.df.iloc[idx]["StudyInstanceUID"]
        study_path = os.path.join(self.img_root, study_uid)

        # Load random slice
        slices = [f for f in os.listdir(study_path) if f.endswith(".dcm")]
        slice_path = os.path.join(study_path, np.random.choice(slices))

        # Read and preprocess DICOM
        dicom = pydicom.dcmread(slice_path)
        img = apply_voi_lut(dicom.pixel_array, dicom)
        img = (img - img.min()) / (img.max() - img.min())  # Normalize to [0,1]
        img = np.stack([img]*3, axis=-1)  # Convert to 3-channel
        img = Image.fromarray((img * 255).astype(np.uint8))

        if self.transform:
            img = self.transform(img)

        # Get label(s)
        if self.target_type == "patient":
            label = self.df.iloc[idx]["patient_overall"]
            return img, torch.tensor(label, dtype=torch.float32)
        else:
            labels = self.labels[idx]
            return img, torch.tensor(labels, dtype=torch.float32)

In [47]:
class SpineFractureModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = models.resnet18(pretrained=True)

        # Replace the final fully connected layer
        num_features = self.backbone.fc.in_features  # This should be 512 for resnet18
        self.backbone.fc = nn.Identity()  # Remove the original FC layer

        # Patient head
        self.patient_head = nn.Sequential(
            nn.Linear(num_features, 1),  # 512 → 1
            nn.Sigmoid()
        )

        # Vertebrae heads (C1-C7)
        self.vertebrae_heads = nn.Sequential(
            nn.Linear(num_features, 7),  # 512 → 7 (one output per vertebra)
            nn.Sigmoid()
        )

    def forward(self, x):
        features = self.backbone(x)  # Output shape: [batch_size, 512]
        patient_out = self.patient_head(features)  # Shape: [batch_size, 1]
        vertebrae_out = self.vertebrae_heads(features)  # Shape: [batch_size, 7]
        return patient_out, vertebrae_out

In [54]:
def train_model(model, train_loader, val_loader, epochs=5, lr=1e-4):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    patient_criterion = nn.BCELoss()
    vertebra_criterion = nn.BCELoss()

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0

        for imgs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            imgs = imgs.to(device)
            # Separate patient and vertebrae labels based on target_type
            if len(labels.shape) == 1:  # patient_overall
                patient_labels = labels.to(device).float()  # Ensure float type for BCELoss
                vertebra_labels = torch.zeros(labels.shape[0], 7, device=device)  # Dummy labels for vertebrae
            else:  # C1-C7
                patient_labels = torch.zeros(labels.shape[0], device=device).float()  # Dummy labels for patient_overall
                vertebra_labels = labels.to(device).float()  # Ensure float type for BCELoss

            optimizer.zero_grad()
            patient_preds, vertebra_preds = model(imgs)
            # Ensure correct dimensions for loss calculation:
            loss = patient_criterion(patient_preds.squeeze(1), patient_labels) + \
                   vertebra_criterion(vertebra_preds, vertebra_labels)

            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # Validation
        model.eval()
        val_loss = 0.0
        all_preds, all_labels = [], []

        with torch.no_grad():
            for imgs, labels in val_loader:  # Changed here
                imgs = imgs.to(device)
                # Separate patient and vertebrae labels based on target_type
                if len(labels.shape) == 1:  # patient_overall
                    patient_labels = labels.to(device).float()  # Ensure float type for BCELoss
                    vertebra_labels = torch.zeros(labels.shape[0], 7, device=device)  # Dummy labels for vertebrae
                else:  # C1-C7
                    patient_labels = torch.zeros(labels.shape[0], device=device).float()  # Dummy labels for patient_overall
                    vertebra_labels = labels.to(device).float()  # Ensure float type for BCELoss


                patient_preds, vertebra_preds = model(imgs)
                loss = patient_criterion(patient_preds.squeeze(1), patient_labels) + \
                       vertebra_criterion(vertebra_preds, vertebra_labels)
                val_loss += loss.item()

                all_preds.append(torch.cat([patient_preds, vertebra_preds], dim=1).cpu())
                all_labels.append(torch.cat([patient_labels.unsqueeze(1), vertebra_labels], dim=1).cpu())

        # Calculate AUC
        all_preds = torch.cat(all_preds).numpy()
        all_labels = torch.cat(all_labels).numpy()
        auc_scores = [roc_auc_score(all_labels[:, i], all_preds[:, i]) for i in range(8)]
        mean_auc = np.mean(auc_scores)

        print(f"Epoch {epoch+1} | Train Loss: {train_loss/len(train_loader):.4f} | "
              f"Val Loss: {val_loss/len(val_loader):.4f} | Mean AUC: {mean_auc:.4f}")

In [55]:
# Transforms
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Initialize dataset (auto-filters to available patients)
train_dataset = FilteredRSNADataset(
    csv_path="/content/train.csv",
    img_root="/content/train_images/",
    transform=transform,
    target_type="vertebra"
)

# Train/val split
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_set, val_set = torch.utils.data.random_split(train_dataset, [train_size, val_size])

# DataLoaders
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
val_loader = DataLoader(val_set, batch_size=32)

# Initialize and train model
model = SpineFractureModel()
train_model(model, train_loader, val_loader, epochs=5)

# Sample prediction
sample_img, sample_label = train_dataset[0]
with torch.no_grad():
    patient_pred, vertebra_preds = model(sample_img.unsqueeze(0))
    print(f"\nSample Prediction:")
    print(f"Patient fracture prob: {patient_pred.item():.2%}")
    print(f"Vertebrae probs: {[f'{v.item():.2%}' for v in vertebra_preds[0]]}")



Using 3/4 available patients


Epoch 1: 100%|██████████| 1/1 [00:00<00:00,  1.87it/s]


Epoch 1 | Train Loss: 1.5763 | Val Loss: 1.7298 | Mean AUC: nan


Epoch 2: 100%|██████████| 1/1 [00:00<00:00,  1.88it/s]


Epoch 2 | Train Loss: 1.4067 | Val Loss: 1.5239 | Mean AUC: nan


Epoch 3: 100%|██████████| 1/1 [00:00<00:00,  1.82it/s]


Epoch 3 | Train Loss: 1.3939 | Val Loss: 1.3779 | Mean AUC: nan


Epoch 4: 100%|██████████| 1/1 [00:00<00:00,  2.55it/s]


Epoch 4 | Train Loss: 1.3673 | Val Loss: 1.2732 | Mean AUC: nan


Epoch 5: 100%|██████████| 1/1 [00:00<00:00,  2.60it/s]

Epoch 5 | Train Loss: 1.1918 | Val Loss: 1.3767 | Mean AUC: nan

Sample Prediction:
Patient fracture prob: 74.14%
Vertebrae probs: ['4.40%', '5.54%', '32.94%', '43.26%', '39.03%', '56.69%', '37.40%']





In [66]:
def predict(image_path, model, transform):
    """Predict for a single DICOM image"""
    # Define device here
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Load and preprocess image (same as dataset __getitem__)
    dicom = pydicom.dcmread(image_path)
    img = apply_voi_lut(dicom.pixel_array, dicom)
    img = (img - img.min()) / (img.max() - img.min())
    img = np.stack([img]*3, axis=-1)
    img = Image.fromarray((img * 255).astype(np.uint8))
    img = transform(img).unsqueeze(0)

    # Predict
    with torch.no_grad():
        patient_prob, vertebrae_probs = model(img.to(device))

    return {
        "patient_fracture_prob": patient_prob.item(),
        "vertebrae_probs": vertebrae_probs.squeeze().tolist()
    }

# Example usage
pred = predict("test_images/10.dcm", model, transform)
print(f"Patient fracture probability: {pred['patient_fracture_prob']:.2%}")
for i, prob in enumerate(pred['vertebrae_probs'], 1):
    print(f"C{i} fracture probability: {prob:.2%}")

Patient fracture probability: 57.37%
C1 fracture probability: 14.17%
C2 fracture probability: 12.07%
C3 fracture probability: 37.66%
C4 fracture probability: 29.42%
C5 fracture probability: 30.58%
C6 fracture probability: 66.82%
C7 fracture probability: 62.64%


In [None]:
!pip install pydicom pylibjpeg pylibjpeg-libjpeg gdcm

In [None]:
import os
import pandas as pd

# Load CSV and get StudyInstanceUIDs
df = pd.read_csv("train.csv")
csv_ids = set(df["StudyInstanceUID"].astype(str))  # Force string type

# Get folder names in train_images
folder_ids = set(os.listdir("train_images"))

# Find mismatches
mismatched = csv_ids - folder_ids
print(f"{len(mismatched)} CSV IDs missing in folders. Examples:")
print(list(mismatched)[:3])  # Print first 3 problematic IDs

In [59]:
# Recommended save method (saves both architecture and weights)
torch.save(model, "spine_fracture_model.pth")

In [60]:
# Alternative (saves just weights - requires model class definition when loading)
torch.save(model.state_dict(), "spine_fracture_weights.pth")

In [67]:
# Try saving with different protocol
torch.save(model.state_dict(), "new_model.pt", pickle_protocol=4)