In [1]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [2]:
!pip install opencv-python h5py -q

In [3]:
import os
import h5py
import cv2
import random
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.models import ResNet50_Weights

warnings.filterwarnings("ignore")

SEED = 42
random.seed(SEED)
os.environ['PYTHONHASHSEED'] = str(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [4]:
# Paths to CSV files
TRAIN_CSV = "/content/drive/MyDrive/splits/train.csv"
VAL_CSV = "/content/drive/MyDrive/splits/val.csv"
TEST_CSV = "/content/drive/MyDrive/splits/test.csv"

CACHE_DIR = "/content/drive/MyDrive/features_cache_paper"
NUM_FRAMES = 50
NUM_CLASSES = 3
FEATURE_DIM = 512  # After first 3 blocks of ResNet50

# Loss weights
ALPHA = 0.1
BETA = 0.001

# Alternative reduced weights if training is unstable:
# ALPHA = 0.01   # Reduced spatial constraint
# BETA = 0.0001  # Reduced ranking constraint

In [5]:
def replace_batchnorm_with_groupnorm(module, num_groups=32):
    """
    Replace all BatchNorm layers with GroupNorm to avoid interactions between instances.
    """
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            num_channels = child.num_features
            # Use num_groups=32 or adjust based on num_channels
            actual_groups = min(num_groups, num_channels)
            if num_channels % actual_groups != 0:
                actual_groups = 1  # Fallback to LayerNorm equivalent
            setattr(module, name, nn.GroupNorm(actual_groups, num_channels))
        else:
            replace_batchnorm_with_groupnorm(child, num_groups)

def extract_video_frames(video_path, num_frames=50):
    """Extract uniformly sampled frames from video"""
    cap = cv2.VideoCapture(video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    if total_frames == 0:
        cap.release()
        return None

    # Uniform sampling
    indices = np.linspace(0, total_frames-1, min(num_frames, total_frames), dtype=int)

    frames = []
    for idx in indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        ret, frame = cap.read()
        if ret:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(frame)

    cap.release()
    return frames if frames else None

def extract_features_with_cache(video_path, cache_dir, feature_extractor, device, transform, overwrite=False):
    """Extract ResNet50 (first 3 blocks) features with H5 caching"""
    os.makedirs(cache_dir, exist_ok=True)

    # Cache path
    video_name = Path(video_path).stem
    class_name = Path(video_path).parent.name
    cache_subdir = os.path.join(cache_dir, class_name)
    os.makedirs(cache_subdir, exist_ok=True)
    cache_path = os.path.join(cache_subdir, f"{video_name}.h5")

    # Load from cache
    if not overwrite and os.path.exists(cache_path):
        try:
            with h5py.File(cache_path, 'r') as f:
                features = torch.from_numpy(f['features'][:])
                coords = torch.from_numpy(f['coords'][:])
                nearest = eval(f['nearest'][()])
            return features, coords, nearest
        except:
            pass

    # Extract frames
    frames = extract_video_frames(video_path, num_frames=NUM_FRAMES)
    if frames is None:
        return None, None, None

    # Extract features
    features_list = []
    batch_size = 10

    with torch.no_grad():
        for i in range(0, len(frames), batch_size):
            batch = [transform(f) for f in frames[i:i+batch_size]]
            batch = torch.stack(batch).to(device)
            feat = feature_extractor(batch)  # [B, 512, H, W]
            feat = F.adaptive_avg_pool2d(feat, (1, 1)).squeeze(-1).squeeze(-1)  # [B, 512]
            features_list.append(feat.cpu())

    features = torch.cat(features_list, dim=0)  # [num_frames, 512]

    # Create temporal coordinates
    num_frames_actual = features.shape[0]
    coords = torch.zeros(num_frames_actual, 2)
    coords[:, 0] = torch.arange(num_frames_actual)

    # Nearest temporal neighbors
    nearest = []
    for i in range(num_frames_actual):
        neighbors = []
        if i > 0:
            neighbors.append(i - 1)
        neighbors.append(i)
        if i < num_frames_actual - 1:
            neighbors.append(i + 1)
        nearest.append(neighbors)

    # Save to cache
    try:
        with h5py.File(cache_path, 'w') as f:
            f.create_dataset('features', data=features.numpy(), compression='gzip', compression_opts=4)
            f.create_dataset('coords', data=coords.numpy(), compression='gzip', compression_opts=4)
            f.create_dataset('nearest', data=str(nearest))
    except:
        pass

    return features, coords, nearest

In [6]:
class VideoMILDataset(Dataset):
    def __init__(self, video_paths, labels, cache_dir, feature_extractor, device, transform):
        self.video_paths = video_paths
        self.labels = labels
        self.cache_dir = cache_dir
        self.feature_extractor = feature_extractor
        self.device = device
        self.transform = transform

    def __len__(self):
        return len(self.video_paths)

    def __getitem__(self, idx):
        video_path = self.video_paths[idx]
        label = self.labels[idx]

        features, coords, nearest = extract_features_with_cache(
            video_path, self.cache_dir, self.feature_extractor,
            self.device, self.transform
        )

        if features is None:
            features = torch.zeros(NUM_FRAMES, FEATURE_DIM)
            coords = torch.zeros(NUM_FRAMES, 2)
            nearest = [[i] for i in range(NUM_FRAMES)]

        return features, label, coords, nearest

In [7]:
# Constraint functions from AttriMIL paper
def spatial_constraint(attribute_score, coords, nearest_indices, n_classes):
    """
    Spatial constraint loss to enforce smoothness in attribute scores.
    From paper: Encourages similar attribute scores for neighboring instances.
    FIXED: Added numerical stability and proper normalization
    """
    if attribute_score.shape[2] == 0:
        return torch.tensor(0.0, device=attribute_score.device)

    loss = 0.0
    num_instances = attribute_score.shape[2]
    count = 0

    for i in range(num_instances):
        neighbors = nearest_indices[i]
        for neighbor_idx in neighbors:
            if neighbor_idx != i and neighbor_idx < num_instances:
                for c in range(n_classes):
                    # L2 distance between attribute scores
                    diff = attribute_score[0, c, i] - attribute_score[0, c, neighbor_idx]
                    loss += diff * diff
                    count += 1

    # Normalize by actual number of comparisons
    if count > 0:
        return loss / count
    return torch.tensor(0.0, device=attribute_score.device)

def ranking_constraint(attribute_score, label, n_classes):
    """
    Ranking constraint loss to ensure correct class has higher attribute scores.
    From paper: Encourages attribute scores of positive class > negative classes.
    FIXED: Use bag-level aggregation instead of instance-level, reduced margin
    """
    if attribute_score.shape[2] == 0:
        return torch.tensor(0.0, device=attribute_score.device)

    target_class = label.item()

    # Aggregate attribute scores to bag level (sum over instances)
    bag_scores = torch.sum(attribute_score[0], dim=1)  # Shape: [n_classes]

    loss = 0.0
    # Encourage target class bag score > other class bag scores
    for c in range(n_classes):
        if c != target_class:
            # Smaller margin for stability
            margin = 0.1
            loss += F.relu(bag_scores[c] - bag_scores[target_class] + margin)

    # Normalize by number of negative classes
    return loss / (n_classes - 1)

In [8]:
class Attn_Net_Gated(nn.Module):
    def __init__(self, L=1024, D=256, dropout=False, n_classes=1):
        super(Attn_Net_Gated, self).__init__()
        self.attention_a = [
            nn.Linear(L, D),
            nn.Tanh()]
        self.attention_b = [nn.Linear(L, D), nn.Sigmoid()]
        if dropout:
            self.attention_a.append(nn.Dropout(0.25))
            self.attention_b.append(nn.Dropout(0.25))
        self.attention_a = nn.Sequential(*self.attention_a)
        self.attention_b = nn.Sequential(*self.attention_b)
        self.attention_c = nn.Linear(D, n_classes)

    def forward(self, x):
        a = self.attention_a(x)
        b = self.attention_b(x)
        A = a.mul(b)
        A = self.attention_c(A)
        return A, x

class AttriMIL(nn.Module):
    def __init__(self, n_classes=3, dim=512):
        super().__init__()
        self.adaptor = nn.Sequential(
            nn.Linear(dim, dim//2),
            nn.ReLU(),
            nn.Linear(dim//2, dim)
        )

        attention = []
        classifiers = []
        for i in range(n_classes):
            attention.append(Attn_Net_Gated(L=dim, D=dim//2))
            classifiers.append(nn.Linear(dim, 1))

        self.attention_nets = nn.ModuleList(attention)
        self.classifiers = nn.ModuleList(classifiers)
        self.n_classes = n_classes
        self.bias = nn.Parameter(torch.zeros(n_classes), requires_grad=True)

    def forward(self, h):
        h = h + self.adaptor(h)
        A_raw = torch.empty(self.n_classes, h.size(0)).to(h.device)
        instance_score = torch.empty(1, self.n_classes, h.size(0)).float().to(h.device)

        for c in range(self.n_classes):
            A, h_out = self.attention_nets[c](h)
            A = torch.transpose(A, 1, 0)
            A_raw[c] = A
            instance_score[0, c] = self.classifiers[c](h)[:, 0]


        attribute_score = torch.empty(1, self.n_classes, h.size(0)).float().to(h.device)
        logits = torch.empty(1, self.n_classes).float().to(h.device)

        for c in range(self.n_classes):
            # Clamp attention scores to prevent explosion
            A_clamped = torch.clamp(A_raw[c], min=-10, max=10)
            exp_A = torch.exp(A_clamped)

            # Compute attribute scores with clamped values
            attribute_score[0, c] = instance_score[0, c] * exp_A

            # Add epsilon to prevent division by zero
            eps = 1e-8
            sum_exp_A = torch.sum(exp_A, dim=-1) + eps

            # Compute logits
            logits[0, c] = torch.sum(attribute_score[0, c], keepdim=True, dim=-1) / sum_exp_A + self.bias[c]

        Y_hat = torch.topk(logits, 1, dim=1)[1]
        Y_prob = F.softmax(logits, dim=1)

        return logits, Y_prob, Y_hat, attribute_score

In [9]:
# Setup feature extractor
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

weights = ResNet50_Weights.IMAGENET1K_V1
resnet = models.resnet50(weights=weights)

# Extract first 3 blocks (conv1, bn1, relu, maxpool, layer1, layer2, layer3)
# Output: 512 channels after layer3
feature_extractor = nn.Sequential(
    resnet.conv1,
    resnet.bn1,
    resnet.relu,
    resnet.maxpool,
    resnet.layer1,
    resnet.layer2,
    resnet.layer3
)

# Replace BatchNorm with GroupNorm
replace_batchnorm_with_groupnorm(feature_extractor, num_groups=32)

feature_extractor.eval()
feature_extractor = feature_extractor.to(device)

transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

print("âœ“ Feature extractor ready (First 3 blocks, BatchNormâ†’GroupNorm)")
print(f"âœ“ Output feature dim: {FEATURE_DIM}")

Using device: cpu
Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 97.8M/97.8M [00:02<00:00, 46.5MB/s]


âœ“ Feature extractor ready (First 3 blocks, BatchNormâ†’GroupNorm)
âœ“ Output feature dim: 512


In [10]:
# Load data from CSV files
def load_split_from_csv(csv_path, label_col='label', filename_col='filename'):
    """
    Load video paths and labels from CSV file.
    CSV should have columns: filename, label
    """
    df = pd.read_csv(csv_path)
    video_paths = df[filename_col].tolist()
    labels = df[label_col].tolist()
    return video_paths, labels

# Load train/val/test splits from CSV
print("ðŸ“‚ Loading data splits from CSV files...")
train_paths, train_labels = load_split_from_csv(TRAIN_CSV)
val_paths, val_labels = load_split_from_csv(VAL_CSV)
test_paths, test_labels = load_split_from_csv(TEST_CSV)

print(f"\nTotal videos: {len(train_paths) + len(val_paths) + len(test_paths)}")
print(f"  Train: {len(train_paths)}")
print(f"  Val:   {len(val_paths)}")
print(f"  Test:  {len(test_paths)}")

# Count class distribution
label_names = {0: 'Normal', 1: 'Adenoma', 2: 'Malignant'}
print(f"\nClass distribution:")
print(f"Train - Normal: {train_labels.count(0)}, Adenoma: {train_labels.count(1)}, Malignant: {train_labels.count(2)}")
print(f"Val   - Normal: {val_labels.count(0)}, Adenoma: {val_labels.count(1)}, Malignant: {val_labels.count(2)}")
print(f"Test  - Normal: {test_labels.count(0)}, Adenoma: {test_labels.count(1)}, Malignant: {test_labels.count(2)}")

ðŸ“‚ Loading data splits from CSV files...


FileNotFoundError: [Errno 2] No such file or directory: '/content/drive/MyDrive/splits/train.csv'

In [None]:
# Pre extract all features
print("\n Pre-extracting features for all videos...")
print("First 3 ResNet blocks with GroupNorm")


all_paths = train_paths + val_paths + test_paths
success_count = 0
failed_videos = []

for video_path in tqdm(all_paths, desc="Extracting features"):
    result = extract_features_with_cache(video_path, CACHE_DIR, feature_extractor, device, transform)
    if result[0] is not None:
        success_count += 1
    else:
        failed_videos.append(video_path)

print(f"\nâœ“ Feature extraction complete!")
print(f"  Successfully cached: {success_count}/{len(all_paths)} videos")
if failed_videos:
    print(f"  Failed: {len(failed_videos)} videos")
print(f"\nFeatures saved to: {CACHE_DIR}")

In [None]:
# Create datasets and loaders
train_dataset = VideoMILDataset(train_paths, train_labels, CACHE_DIR, feature_extractor, device, transform)
val_dataset = VideoMILDataset(val_paths, val_labels, CACHE_DIR, feature_extractor, device, transform)

# batch_size=1
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=0)

print(f"Train: {len(train_dataset)} | Val: {len(val_dataset)}")

In [None]:
# Initialize model
model = AttriMIL(n_classes=NUM_CLASSES, dim=FEATURE_DIM).to(device)

# Paper hyperparameters: lr=2e-4, weight_decay from AdamW
optimizer = optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-5)
criterion = nn.CrossEntropyLoss()
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=150, eta_min=1e-6)

print("âœ“ Model initialized")
print(f"âœ“ Loss weights: Î±={ALPHA} (spatial), Î²={BETA} (ranking)")

In [None]:
# Training loop with spatial and ranking constraints
train_losses, val_losses, train_accs, val_accs, train_f1s, val_f1s = [], [], [], [], [], []
train_bag_losses, train_spa_losses, train_rank_losses = [], [], []
best_acc = 0

for epoch in range(150):
    # Train
    model.train()
    train_loss = 0
    train_bag_loss = 0
    train_spa_loss = 0
    train_rank_loss = 0
    train_preds = []
    train_true = []
    nan_count = 0

    for features, label, coords, nearest in tqdm(train_loader, desc=f"Epoch {epoch+1}/150 [Train]"):
        features = features.squeeze(0).to(device)  # Remove batch dim: [1, T, D] -> [T, D]
        label = label.to(device)
        coords = coords.squeeze(0)  # [T, 2]
        nearest = nearest[0]  # List of neighbor indices

        optimizer.zero_grad()

        logits, probs, pred, attr_scores = model(features)

        # Bag-level loss (standard cross-entropy)
        loss_bag = criterion(logits, label)

        # Spatial constraint loss (smoothness between neighbors)
        loss_spa = spatial_constraint(attr_scores, coords, nearest, NUM_CLASSES)

        # Ranking constraint loss (positive class > negative classes)
        loss_rank = ranking_constraint(attr_scores, label, NUM_CLASSES)

        # Total loss with paper's weights: Î±=0.1, Î²=0.001
        loss = loss_bag + ALPHA * loss_spa + BETA * loss_rank


        if torch.isnan(loss) or torch.isinf(loss):
            nan_count += 1
            print(f"  Warning: NaN/Inf loss detected (bag={loss_bag.item():.4f}, spa={loss_spa.item():.4f}, rank={loss_rank.item():.4f})")
            continue

        loss.backward()

        #Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        train_loss += loss.item()
        train_bag_loss += loss_bag.item()
        train_spa_loss += loss_spa.item()
        train_rank_loss += loss_rank.item()
        train_preds.append(pred.item())
        train_true.append(label.item())

    # Handle division by zero if all batches had NaN
    num_valid_batches = len(train_loader) - nan_count
    if num_valid_batches == 0:
        print(f"Epoch {epoch+1}: All batches had NaN losses! Stopping training.")
        break

    train_loss /= num_valid_batches
    train_bag_loss /= num_valid_batches
    train_spa_loss /= num_valid_batches
    train_rank_loss /= num_valid_batches
    train_acc = accuracy_score(train_true, train_preds)
    train_f1 = f1_score(train_true, train_preds, average='macro')

    # Validation (no constraints during evaluation)
    model.eval()
    val_loss = 0
    val_preds = []
    val_true = []

    with torch.no_grad():
        for features, label, coords, nearest in tqdm(val_loader, desc=f"Epoch {epoch+1}/150 [Val]"):
            features = features.squeeze(0).to(device)
            label = label.to(device)

            logits, probs, pred, attr_scores = model(features)
            loss = criterion(logits, label)

            # Skip NaN losses in validation too
            if not (torch.isnan(loss) or torch.isinf(loss)):
                val_loss += loss.item()
                val_preds.append(pred.item())
                val_true.append(label.item())

    val_loss /= len(val_loader)
    val_acc = accuracy_score(val_true, val_preds)
    val_f1 = f1_score(val_true, val_preds, average='macro')

    scheduler.step()

    print(f"Epoch {epoch+1}: Train Acc={train_acc:.4f}, F1={train_f1:.4f}, Loss={train_loss:.4f} (Bag={train_bag_loss:.4f}, Spa={train_spa_loss:.4f}, Rank={train_rank_loss:.4f}) | Val Acc={val_acc:.4f}, F1={val_f1:.4f}, Loss={val_loss:.4f}")
    if nan_count > 0:
        print(f"  (Skipped {nan_count} batches due to NaN)")

    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), "/content/drive/MyDrive/attrimil_paper_best_model.pth")
        print(f"  âœ“ Best model saved (Val Acc: {best_acc:.4f})")

    torch.save(model.state_dict(), "/content/drive/MyDrive/attrimil_paper_current_model.pth")

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accs.append(train_acc)
    val_accs.append(val_acc)
    train_f1s.append(train_f1)
    val_f1s.append(val_f1)
    train_bag_losses.append(train_bag_loss)
    train_spa_losses.append(train_spa_loss)
    train_rank_losses.append(train_rank_loss)

print(f"\nBest validation accuracy: {best_acc:.4f}")

In [None]:
print(f"Best Val Acc: {np.array(val_accs).max():.4f}")

In [None]:
epochs = range(len(train_losses))

plt.figure(figsize=(20, 4))

# Plotting Losses
plt.subplot(1, 4, 1)
plt.plot(epochs, train_losses, 'r', label='Training Loss (Total)')
plt.plot(epochs, val_losses, 'b', label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()

# Plotting Loss Components
plt.subplot(1, 4, 2)
plt.plot(epochs, train_bag_losses, 'g', label='Bag Loss')
plt.plot(epochs, train_spa_losses, 'orange', label=f'Spatial Loss (Î±={ALPHA})')
plt.plot(epochs, train_rank_losses, 'purple', label=f'Ranking Loss (Î²={BETA})')
plt.title('Training Loss Components')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()

# Plotting Accuracies
plt.subplot(1, 4, 3)
plt.plot(epochs, train_accs, 'r', label='Training Accuracy')
plt.plot(epochs, val_accs, 'b', label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()

# Plotting F1 Scores
plt.subplot(1, 4, 4)
plt.plot(epochs, train_f1s, 'r', label='Training F1 Score')
plt.plot(epochs, val_f1s, 'b', label='Validation F1 Score')
plt.title('Training and Validation F1 Score')
plt.xlabel('Epochs')
plt.ylabel('F1 Score')
plt.legend()

plt.tight_layout()
plt.savefig('/content/drive/MyDrive/attrimil_paper_results.png', dpi=300, bbox_inches='tight')
plt.show()