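# SAM 3 Linear Probing

Trains a per-pixel linear probe (a 1×1 convolution) on frozen SAM 3 backbone features for binary brain-tumor segmentation of BraTS 2020 T1ce slices.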

## Setup

In [1]:
from google.colab import drive

# Mount Google Drive: the SAM 3 clone and the BraTS data used below live under MyDrive.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import os
import subprocess

# Clone SAM 3 into Drive only once: a bare `git clone` fails on every re-run
# ("destination path already exists"), which breaks Restart & Run All.
SAM3_REPO_URL = "https://github.com/facebookresearch/sam3.git"
SAM3_REPO_DIR = "/content/drive/MyDrive/MSc_Thesis_Neuroimaging/SAM 3/sam3_repo"

if os.path.isdir(SAM3_REPO_DIR):
    print(f"SAM 3 repo already present at {SAM3_REPO_DIR}")
else:
    # List-form arguments bypass the shell, so the space in "SAM 3" needs no escaping.
    subprocess.run(["git", "clone", SAM3_REPO_URL, SAM3_REPO_DIR], check=True)


In [4]:
# Editable install of the cloned SAM 3 repo with its "notebooks" extras.
# ".[notebooks]" is resolved against the current directory, hence the %cd into the repo and back to /content.
%cd /content/drive/MyDrive/MSc_Thesis_Neuroimaging/SAM\ 3/sam3_repo
!pip install -e ".[notebooks]"
%cd /content
# Additional packages; stdout is discarded to keep the cell output short.
# NOTE(review): versions are unpinned — consider pinning (and using %pip) for reproducibility.
!pip install -q supervision jupyter_bbox_widget  > /dev/null
!pip install triton decord  > /dev/null

/content/drive/MyDrive/MSc_Thesis_Neuroimaging/SAM 3/sam3_repo
Obtaining file:///content/drive/MyDrive/MSc_Thesis_Neuroimaging/SAM%203/sam3_repo
  Installing build dependencies ... [?25l[?25hdone
  Checking if build backend supports build_editable ... [?25l[?25hdone
  Getting requirements to build editable ... [?25l[?25hdone
  Preparing editable metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: sam3
  Building editable for sam3 (pyproject.toml) ... [?25l[?25hdone
  Created wheel for sam3: filename=sam3-0.1.0-0.editable-py3-none-any.whl size=15413 sha256=29004d4c38cf2f061338d2c10ec541667978190abdfada7395569635d161d2c1
  Stored in directory: /tmp/pip-ephem-wheel-cache-s3psnu20/wheels/dd/85/0d/50f71564a220f942d76bb9b370d0bd2a76e6e4fe8108c3cf67
Successfully built sam3
Installing collected packages: sam3
  Attempting uninstall: sam3
    Found existing installation: sam3 0.1.0
    Uninstalling sam3-0.1.0:
      Successfully uninstalled sam3-0.1.0


In [1]:
import os

from google.colab import userdata

# Export the Hugging Face token stored in Colab secrets as an environment variable
# (presumably needed to fetch the SAM 3 checkpoint from the Hub — confirm).
os.environ.update(HF_TOKEN=userdata.get('HF_TOKEN'))

In [7]:
import os
import sys

# Make the Drive clone of SAM 3 importable from this kernel.
SAM3_REPO_DIR = '/content/drive/MyDrive/MSc_Thesis_Neuroimaging/SAM 3/sam3_repo'
sys.path.insert(0, SAM3_REPO_DIR)

from sam3 import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor

# BPE vocabulary shipped with the repo and passed to the model builder
# (presumably for the text tokenizer — confirm).
bpe_path = os.path.join(SAM3_REPO_DIR, "sam3", "assets", "bpe_simple_vocab_16e6.txt.gz")

# NOTE(review): `model` / `processor` are not used by the linear-probe pipeline below, and the
# "Load model" section builds another SAM 3 instance; holding both on the GPU may contribute
# to the CUDA out-of-memory error recorded at the end of this notebook.
model = build_sam3_image_model(bpe_path=bpe_path)
processor = Sam3Processor(model, confidence_threshold=0.5)

AssertionError: Torch not compiled with CUDA enabled

In [6]:
import os
import glob
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import nibabel as nib
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score, jaccard_score
import matplotlib.pyplot as plt

## Tumor Dataset

In [None]:
class BraTSDataset(Dataset):
    """
    BraTS 2020 dataset of 2D slices for linear probing.

    Each item is one slice along the third array axis of one patient's volume,
    returned as ``(image, mask)``:
      * ``image``: ``[3, img_size, img_size]`` float32 tensor — the grey slice
        repeated over 3 channels (SAM expects 3-channel input);
      * ``mask``: ``[img_size, img_size]`` int64 tensor — 1 where the BraTS label
        is > 0 (any tumor class), 0 for background.
    """

    def __init__(self, data_root, modality='t1ce', slice_range=(50, 130),
                 normalize=True, img_size=1008):
        """
        Args:
            data_root: Directory containing the ``BraTS20_Training_*`` patient folders.
            modality: MRI modality file suffix to load ('flair', 't1', 't1ce', 't2').
            slice_range: Half-open ``(start, stop)`` range of slice indices taken from
                each volume. ``stop`` is clamped to the depth of the patient's image and
                segmentation volumes, so a short volume cannot raise IndexError mid-training.
            normalize: Whether to min-max normalize each slice to [0, 1].
            img_size: Side length images and masks are resized to (SAM3 expects 1008x1008).
        """
        self.data_root = data_root
        self.modality = modality
        self.slice_range = slice_range
        self.normalize = normalize
        self.img_size = img_size

        self.patient_dirs = sorted(glob.glob(os.path.join(data_root, "BraTS20_Training_*")))
        print(f"Found {len(self.patient_dirs)} patients")

        # Flat index of (patient_dir, patient_id, slice_idx) for every usable slice.
        self.slice_index = []
        for patient_dir in self.patient_dirs:
            patient_id = os.path.basename(patient_dir)
            img_path, seg_path = self._paths(patient_dir, patient_id)
            if not (os.path.exists(img_path) and os.path.exists(seg_path)):
                continue  # patients missing the modality or the segmentation are skipped

            # nib.load only parses the header here; voxel data is not read yet.
            depth = min(nib.load(img_path).shape[2], nib.load(seg_path).shape[2])
            for slice_idx in range(slice_range[0], min(slice_range[1], depth)):
                self.slice_index.append((patient_dir, patient_id, slice_idx))

        print(f"Total slices: {len(self.slice_index)}")

    def _paths(self, patient_dir, patient_id):
        """Return the (image, segmentation) NIfTI paths for one patient."""
        img_path = os.path.join(patient_dir, f"{patient_id}_{self.modality}.nii")
        seg_path = os.path.join(patient_dir, f"{patient_id}_seg.nii")
        return img_path, seg_path

    def __len__(self):
        return len(self.slice_index)

    def __getitem__(self, idx):
        patient_dir, patient_id, slice_idx = self.slice_index[idx]
        img_path, seg_path = self._paths(patient_dir, patient_id)

        # Slice the lazy data proxies so only this 2D slice is read from disk,
        # rather than loading both full 3D volumes for every item.
        img_slice = np.asarray(nib.load(img_path).dataobj[:, :, slice_idx], dtype=np.float64)
        seg_slice = np.asarray(nib.load(seg_path).dataobj[:, :, slice_idx], dtype=np.float64)

        if self.normalize:
            # Per-slice min-max; the epsilon maps constant (e.g. empty) slices to 0.
            img_slice = (img_slice - img_slice.min()) / (img_slice.max() - img_slice.min() + 1e-8)

        # Grey -> 3 identical channels (SAM expects 3 channels)
        img_slice_rgb = np.stack([img_slice, img_slice, img_slice], axis=-1)

        # Binary segmentation: any tumor (label > 0) vs background (label == 0)
        seg_binary = (seg_slice > 0).astype(np.float32)

        img_tensor = torch.from_numpy(img_slice_rgb).float().permute(2, 0, 1)  # C, H, W
        seg_tensor = torch.from_numpy(seg_binary).float()  # H, W

        img_tensor = torch.nn.functional.interpolate(
            img_tensor.unsqueeze(0),
            size=(self.img_size, self.img_size),
            mode='bilinear',
            align_corners=False
        ).squeeze(0)

        # Nearest-neighbour keeps the mask strictly binary.
        seg_tensor = torch.nn.functional.interpolate(
            seg_tensor.unsqueeze(0).unsqueeze(0),
            size=(self.img_size, self.img_size),
            mode='nearest'
        ).squeeze(0).squeeze(0).long()

        return img_tensor, seg_tensor

## Linear-probe classifier on frozen SAM 3 features

In [None]:
class LinearProbe(nn.Module):
    """
    Per-pixel linear classifier on top of frozen SAM3 feature maps.

    A 1x1 convolution applies the same linear map (feature_dim -> num_classes)
    independently at every spatial location, i.e. a linear probe per pixel.
    """

    def __init__(self, feature_dim=256, num_classes=2, feature_spatial_size=(72, 72)):
        """
        Args:
            feature_dim: Channel count C of the incoming feature maps (256 by default).
            num_classes: Number of output classes (2 for tumor vs background).
            feature_spatial_size: Expected (H, W) of the feature maps; stored for
                reference only — the layer accepts any spatial size.
        """
        super().__init__()
        self.feature_dim, self.num_classes = feature_dim, num_classes
        self.feature_spatial_size = feature_spatial_size
        # kernel_size=1: strictly linear per pixel, no spatial context.
        self.classifier = nn.Conv2d(in_channels=feature_dim, out_channels=num_classes, kernel_size=1)

    def forward(self, features):
        """Map ``[B, feature_dim, H, W]`` features to ``[B, num_classes, H, W]`` logits."""
        return self.classifier(features)

## Feature Extractor from SAM 3

In [None]:
class SAM3FeatureExtractor:
    """
    Wraps a SAM3 model and exposes its frozen backbone as a feature extractor.
    """

    # Keys tried, in order, when the backbone returns a dict of outputs.
    FEATURE_KEYS = ('vision_features', 'image_features', 'visual_features', 'features')

    def __init__(self, sam3_model, device='cuda', trainable_prefixes=("segmentation_head",)):
        """
        Args:
            sam3_model: SAM3 model exposing a ``backbone`` submodule.
            device: Device the model is moved to and input images are sent to.
            trainable_prefixes: Parameter-name prefix (or tuple of prefixes) left with
                ``requires_grad=True``; every other parameter is frozen. Pass ``()`` to
                freeze the whole model. (The linear probe's optimiser only updates the
                probe, so these flags do not affect probe training.)

        Raises:
            AttributeError: If ``sam3_model`` has no ``backbone`` attribute.
        """
        if not hasattr(sam3_model, 'backbone'):
            raise AttributeError("SAM3 model doesn't have 'backbone' attribute")

        self.model = sam3_model
        self.device = device

        if isinstance(trainable_prefixes, str):
            trainable_prefixes = (trainable_prefixes,)
        prefixes = tuple(trainable_prefixes)
        # Single pass: freeze everything (backbone included) except whitelisted prefixes.
        # str.startswith(()) is False, so an empty tuple freezes all parameters.
        for name, param in sam3_model.named_parameters():
            param.requires_grad = name.startswith(prefixes)
        print("Froze SAM3 backbone parameters")

        self.model.eval()
        self.model.to(device)

    @torch.no_grad()
    def extract_features(self, images, captions=None):
        """
        Run the frozen SAM3 backbone and return its visual feature map.

        Args:
            images: ``[B, 3, H, W]`` image batch; moved to ``self.device`` here.
            captions: One text prompt per image — SAM3's backbone takes text as well
                as images. Defaults to ``"brain tumor"`` for every image.

        Returns:
            If the backbone returns a dict: the first entry of ``FEATURE_KEYS`` present,
            otherwise the first 4-D tensor value. Any non-dict output is returned as is.

        Raises:
            KeyError: If the backbone returns a dict without any of ``FEATURE_KEYS`` and
                without a 4-D tensor, instead of handing the dict on to the probe.
        """
        images = images.to(self.device)
        if captions is None:
            captions = ["brain tumor"] * images.shape[0]

        backbone_out = self.model.backbone(images, captions)
        if not isinstance(backbone_out, dict):
            return backbone_out

        for key in self.FEATURE_KEYS:
            if key in backbone_out:
                return backbone_out[key]
        # Fallback: first [B, C, H, W]-shaped tensor, the layout the probe consumes.
        for val in backbone_out.values():
            if isinstance(val, torch.Tensor) and val.dim() == 4:
                return val

        raise KeyError(
            f"No 4-D visual feature tensor in SAM3 backbone output; keys: {list(backbone_out)}"
        )

## Train

In [None]:
def train_linear_probe(
    train_loader,
    val_loader,
    feature_extractor,
    probe,
    num_epochs=20,
    lr=0.001,
    device='cuda'
):
    """
    Train a linear probe on frozen SAM3 features with Adam and cross-entropy.

    Each epoch makes one pass over ``train_loader`` and then evaluates on
    ``val_loader`` via ``evaluate_probe``.

    Args:
        train_loader: Sized iterable of ``(images, masks)`` batches; masks hold
            integer class indices per pixel.
        val_loader: Validation batches in the same format.
        feature_extractor: Object whose ``extract_features(images)`` returns feature maps.
        probe: ``nn.Module`` mapping features to per-pixel logits; only its parameters
            are optimised.
        num_epochs: Number of training epochs.
        lr: Adam learning rate.
        device: Device the probe and each batch are moved to.

    Returns:
        Dict of per-epoch lists: ``'train_losses'`` (mean of batch losses),
        ``'val_accuracies'`` and ``'val_ious'``.

    Raises:
        ValueError: If ``train_loader`` is empty (the epoch mean would divide by zero).
    """
    if len(train_loader) == 0:
        raise ValueError("train_loader is empty; cannot train the linear probe")

    probe.to(device)
    optimizer = optim.Adam(probe.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    history = {'train_losses': [], 'val_accuracies': [], 'val_ious': []}

    for epoch in range(num_epochs):
        probe.train()
        epoch_loss = 0.0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
        for images, masks in pbar:
            images = images.to(device)
            masks = masks.to(device)

            # Backbone is frozen: build no autograd graph through it.
            with torch.no_grad():
                features = feature_extractor.extract_features(images)

            logits = probe(features)

            # Feature maps are coarser than the masks (e.g. 72x72 vs 1008x1008 here),
            # so bring the logits to the mask resolution before the loss.
            if logits.shape[-2:] != masks.shape[-2:]:
                logits = torch.nn.functional.interpolate(
                    logits, size=masks.shape[-2:], mode='bilinear', align_corners=False
                )

            loss = criterion(logits, masks)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # One .item() per batch: a single device sync instead of two.
            batch_loss = loss.item()
            epoch_loss += batch_loss
            pbar.set_postfix({'loss': batch_loss})

        avg_loss = epoch_loss / len(train_loader)
        history['train_losses'].append(avg_loss)

        val_acc, val_iou = evaluate_probe(val_loader, feature_extractor, probe, device)
        history['val_accuracies'].append(val_acc)
        history['val_ious'].append(val_iou)

        print(f"Epoch {epoch+1}: Loss={avg_loss:.4f}, Val Acc={val_acc:.4f}, Val IoU={val_iou:.4f}")

    return history

## Evaluate

In [None]:
def evaluate_probe(data_loader, feature_extractor, probe, device='cuda'):
    """
    Evaluate a binary linear probe: pixel accuracy and IoU of label 1 (tumor).

    Predictions are reduced to confusion counts batch by batch rather than collected
    and concatenated: at 1008x1008 every slice contributes ~1M pixels, so keeping
    flattened int64 arrays for thousands of validation slices needs tens of GB of RAM.

    Args:
        data_loader: Iterable of ``(images, masks)``; masks hold 0/1 labels per pixel.
        feature_extractor: Object whose ``extract_features(images)`` returns feature maps.
        probe: Module mapping features to ``[B, num_classes, h, w]`` logits.
        device: Device each batch is moved to.

    Returns:
        ``(accuracy, iou)`` as floats. Accuracy is the fraction of pixels whose argmax
        prediction equals the mask label. IoU is TP / (TP + FP + FN) for label 1, and
        0.0 when neither predictions nor targets contain label 1 (the value sklearn's
        ``jaccard_score`` returns in that zero-division case).

    Raises:
        ValueError: If ``data_loader`` yields no pixels.
    """
    probe.eval()
    total = correct = tp = fp = fn = 0

    with torch.no_grad():
        for images, masks in tqdm(data_loader, desc="Evaluating"):
            images = images.to(device)
            masks = masks.to(device)

            features = feature_extractor.extract_features(images)
            logits = probe(features)

            # Bring logits to the mask resolution before taking the argmax.
            if logits.shape[-2:] != masks.shape[-2:]:
                logits = torch.nn.functional.interpolate(
                    logits, size=masks.shape[-2:], mode='bilinear', align_corners=False
                )

            preds = torch.argmax(logits, dim=1)
            pred_pos = preds == 1
            true_pos = masks == 1

            total += masks.numel()
            correct += (preds == masks).sum().item()
            tp += (pred_pos & true_pos).sum().item()
            fp += (pred_pos & ~true_pos).sum().item()
            fn += (~pred_pos & true_pos).sum().item()

    if total == 0:
        raise ValueError("data_loader yielded no pixels to evaluate")

    accuracy = correct / total
    union = tp + fp + fn
    iou = tp / union if union else 0.0
    return accuracy, iou

## Plots

In [None]:
def plot_results(history, save_path='linear_probe_results.png'):
    """
    Plot per-epoch training loss, validation accuracy and validation IoU side by
    side, then save the figure.

    Args:
        history: Dict with 'train_losses', 'val_accuracies' and 'val_ious' lists,
            as returned by ``train_linear_probe``.
        save_path: Output image path.
    """
    # (history key, panel title, y-axis label) for each subplot, left to right.
    panels = [
        ('train_losses', 'Training Loss', 'Loss'),
        ('val_accuracies', 'Validation Accuracy', 'Accuracy'),
        ('val_ious', 'Validation IoU', 'IoU'),
    ]
    fig, axes = plt.subplots(1, len(panels), figsize=(15, 4))

    for ax, (key, title, ylabel) in zip(axes, panels):
        ax.plot(history[key])
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.grid(True)

    plt.tight_layout()
    plt.savefig(save_path)
    print(f"Results saved to {save_path}")

### Parameters

In [None]:
# Data and tokenizer paths on Google Drive
data_root = "/content/drive/MyDrive/MSc_Thesis_Neuroimaging/SAM 3/data/tumor/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData"
bpe_path = "/content/drive/MyDrive/MSc_Thesis_Neuroimaging/SAM 3/sam3_repo/sam3/assets/bpe_simple_vocab_16e6.txt.gz"

# Training hyper-parameters
batch_size = 8
num_epochs = 1
lr = 0.001

# Seed torch's global RNG (all devices) so probe initialisation and DataLoader
# shuffling are reproducible across runs.
seed = 42
torch.manual_seed(seed)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

print(f"Using device: {device}")

In [None]:
import gc

import torch

# Release GPU memory held by unreachable tensors. Order matters: gc.collect() first
# drops dead Python references so their blocks return to PyTorch's caching allocator,
# and only then can empty_cache() hand them back to the driver.
# NOTE(review): tensors still referenced by live globals (e.g. a previously built model)
# are not freed by this cell.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    # CUDA-only API; skipped on CPU runtimes, where it would raise.
    torch.cuda.reset_peak_memory_stats()


### Load model

In [None]:
# Build a SAM 3 image model and wrap it as a frozen feature extractor on `device`.
# NOTE(review): the Setup section already builds a separate SAM 3 instance (`model`);
# keeping both alive roughly doubles GPU memory.
print("Loading SAM3 model...")
sam3_model = build_sam3_image_model(bpe_path=bpe_path)
feature_extractor = SAM3FeatureExtractor(sam3_model=sam3_model, device=device)
print("SAM3 encoder frozen!")

In [None]:
# Number of SAM 3 parameters still flagged trainable after freezing (displayed as the cell output).
trainable_param_count = sum(
    param.numel()
    for param in feature_extractor.model.parameters()
    if param.requires_grad
)
trainable_param_count

### Load and prepare dataset

In [None]:
print("Loading datasets...")
# Index T1ce slices for every patient; slice range and resize size use the class defaults.
full_dataset = BraTSDataset(data_root=data_root, modality='t1ce')

In [None]:
# Split by PATIENT, not by slice: neighbouring slices of one volume are nearly identical,
# so a slice-level random_split puts the same patients in train and val and inflates the
# validation scores. The permutation uses its own seeded generator, so the split is reproducible.
split_seed = 42
train_fraction = 0.8

slices_by_patient = {}
for idx, (patient_dir, _, _) in enumerate(full_dataset.slice_index):
    slices_by_patient.setdefault(patient_dir, []).append(idx)

patients = sorted(slices_by_patient)
order = torch.randperm(len(patients), generator=torch.Generator().manual_seed(split_seed)).tolist()
n_train_patients = int(train_fraction * len(patients))
train_patients = [patients[i] for i in order[:n_train_patients]]
val_patients = [patients[i] for i in order[n_train_patients:]]

train_dataset = torch.utils.data.Subset(
    full_dataset, [i for p in train_patients for i in slices_by_patient[p]]
)
val_dataset = torch.utils.data.Subset(
    full_dataset, [i for p in val_patients for i in slices_by_patient[p]]
)
train_size, val_size = len(train_dataset), len(val_dataset)

# Invariants: no patient in both splits, and every slice lands in exactly one split.
assert not set(train_patients) & set(val_patients), "patient leakage between splits"
assert train_size + val_size == len(full_dataset)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

print(f"Train: {len(train_dataset)} slices, Val: {len(val_dataset)} slices")
print(f"Train: {len(train_patients)} patients, Val: {len(val_patients)} patients")

In [None]:
# Channel width of the SAM 3 features fed to the probe.
# NOTE(review): 256 is assumed here; the feature-dimension check at the end of the
# notebook runs only after training — confirm it agrees.
feature_dim = 256
probe = LinearProbe(num_classes=2, feature_dim=feature_dim)

### Main training loop

In [None]:
# Train the probe on frozen SAM 3 features; `history` holds per-epoch loss / val accuracy / val IoU.
training_config = dict(num_epochs=num_epochs, lr=lr, device=device)
history = train_linear_probe(
    train_loader, val_loader, feature_extractor, probe, **training_config
)

### Results

In [3]:
# Plot the training curves and write them to linear_probe_results.png.
plot_results(history, save_path='linear_probe_results.png')

NameError: name 'plot_results' is not defined

In [None]:
# NOTE(review): a relative path saves into the kernel's working directory (set to /content by
# the install cell), which may not persist after the Colab runtime is recycled — consider a Drive path.
probe_checkpoint = 'linear_probe_brats.pth'
torch.save(probe.state_dict(), probe_checkpoint)
print("Linear probe saved!")

In [None]:
# Sanity check: list one patient's NIfTI files. BraTSDataset expects the names
# "<patient_id>_<modality>.nii" and "<patient_id>_seg.nii" in each patient folder.
!ls -l /content/drive/MyDrive/MSc_Thesis_Neuroimaging/SAM\ 3/data/tumor/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/BraTS20_Training_001

total 78484
-rw------- 1 root root 17858880 Jan 26 15:36 BraTS20_Training_001_flair.nii
-rw------- 1 root root  8930976 Jan 26 15:36 BraTS20_Training_001_seg.nii
-rw------- 1 root root 17858880 Jan 26 15:36 BraTS20_Training_001_t1ce.nii
-rw------- 1 root root 17858880 Jan 26 15:36 BraTS20_Training_001_t1.nii
-rw------- 1 root root 17858880 Jan 26 15:36 BraTS20_Training_001_t2.nii


In [None]:
def _pick_feature_tensor(backbone_out):
    """
    Choose the visual feature tensor from a SAM3 backbone output dict.

    Preference order: the known feature keys, then the first 4-D tensor (the
    [B, C, H, W] layout the linear probe consumes), then any tensor at all.

    Returns:
        ``(key, tensor)``, or ``(None, None)`` when the dict holds no tensors.
    """
    possible_keys = ['vision_features', 'image_features', 'visual_features', 'features', 'vision_embedding']
    for key in possible_keys:
        if isinstance(backbone_out.get(key), torch.Tensor):
            return key, backbone_out[key]
    tensors = [(key, val) for key, val in backbone_out.items() if isinstance(val, torch.Tensor)]
    for key, val in tensors:
        if val.dim() == 4:
            return key, val
    return tensors[0] if tensors else (None, None)


def _report_feature_dim(tensor):
    """
    Print the layout of ``tensor`` and return its channel dimension.

    Handles [B, C, H, W] feature maps and [B, N, C] token sequences; returns None
    for any other rank so the caller does not silently keep a stale value.
    """
    if tensor.dim() == 4:
        dim = tensor.shape[1]
        print(f"\n✓ Feature dimension: {dim}")
        print(f"✓ Feature spatial size: {tensor.shape[2]}x{tensor.shape[3]}")
        return dim
    if tensor.dim() == 3:
        dim = tensor.shape[2]
        print(f"\n✓ Feature dimension: {dim}")
        print(f"✓ Sequence length: {tensor.shape[1]}")
        return dim
    return None


print("\nDetermining SAM3 feature dimensions...")
dummy_image = torch.randn(1, 3, 1008, 1008, device=device)
dummy_captions = ["tumor"]

sam3_model.eval()
with torch.no_grad():
    features = sam3_model.backbone(dummy_image, dummy_captions)
print(f"Feature output type: {type(features)}")

detected_dim = None
feature_key, feature_tensor = None, None
if isinstance(features, dict):
    print(f"Feature keys: {list(features.keys())}")
    for key, val in features.items():
        if isinstance(val, torch.Tensor):
            print(f"  {key}: shape {val.shape}")

    feature_key, feature_tensor = _pick_feature_tensor(features)
    if feature_tensor is not None:
        print(f"\nUsing '{feature_key}' as main features: {feature_tensor.shape}")
        detected_dim = _report_feature_dim(feature_tensor)
elif isinstance(features, torch.Tensor):
    print(f"Feature shape: {features.shape}")
    if features.dim() == 4:
        detected_dim = _report_feature_dim(features)

if detected_dim is None:
    print("\n⚠ Could not determine the feature dimension; feature_dim left unchanged")
else:
    feature_dim = detected_dim
print(f"\nSet feature_dim = {feature_dim}")

# Drop references to the dummy forward's GPU tensors and return cached blocks to the
# driver; this cell previously ran out of GPU memory.
dummy_image = features = feature_tensor = None
if torch.cuda.is_available():
    torch.cuda.empty_cache()


Determining SAM3 feature dimensions...


OutOfMemoryError: CUDA out of memory. Tried to allocate 12.00 MiB. GPU 0 has a total capacity of 14.74 GiB of which 12.12 MiB is free. Process 239277 has 14.73 GiB memory in use. Of the allocated memory 14.29 GiB is allocated by PyTorch, and 318.92 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)