In [1]:
import wilds
print(wilds.__version__)
from wilds.datasets.camelyon17_dataset import Camelyon17Dataset
from torchvision.transforms import Compose, ToTensor, Resize
from tqdm.auto import tqdm

2.0.0


In [9]:
# Shell out to nvidia-smi to see which GPUs are visible and how loaded they are.
!nvidia-smi

Mon Apr 14 15:25:10 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.90.07              Driver Version: 550.90.07      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA RTX A6000               Off |   00000000:31:00.0 Off |                  Off |
| 30%   35C    P0             71W /  300W |       1MiB /  49140MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA RTX A6000               Off |   00

In [3]:
# Load the Camelyon17 WILDS dataset rooted at the shared dataset directory,
# with download=True passed to the WILDS loader.
# NOTE(review): root_dir is a machine-specific absolute path, and the training
# cells further down point at a different directory — consider one configurable
# data root for the whole notebook.
dataset = Camelyon17Dataset(root_dir='/voyager/datasets', download=True)

In [4]:
""" For Bayesian D-PDDM
- train         (used to train base model)
- valid         (used to validate base model)
- dpddm_train   (used to train dpddm's Phi)
- dpddm_id      (used to validate FPR)
- dpddm_ood     (used to validate TPR)
"""

# Map each Bayesian D-PDDM role (key) to the WILDS split name (value) that is
# later passed to dataset.get_subset.
# NOTE(review): 'valid' and 'dpddm_id' both resolve to the WILDS 'val' split,
# so they are backed by the same data — confirm this overlap is intended.
splits = dict(
    train='train',
    valid='val',
    dpddm_train='id_val',
    dpddm_id='val',
    dpddm_ood='test',
)

In [7]:
# Sanity check: print how many examples back each D-PDDM role.
# The transform is identical for every split, so build it once outside the loop.
size_check_transform = Compose([Resize((224, 224)), ToTensor()])

for split, wilds_split in splits.items():
    try:
        ds = dataset.get_subset(wilds_split, transform=size_check_transform)
    except Exception as err:
        # Stay best-effort (one bad split must not hide the others), but say
        # what went wrong instead of silently dropping the split. Catching
        # Exception rather than using a bare `except` also lets
        # KeyboardInterrupt stop the cell.
        print(f'{split}: skipped ({type(err).__name__}: {err})')
    else:
        print(split, len(ds))

train 302436
valid 34904
dpddm_train 33560
dpddm_id 34904
dpddm_ood 85054


In [8]:
# Keep a handle on the base-model training subset, with images resized to
# 224x224 and converted to tensors. This rebinds `ds`, which the size-check
# loop in the previous cell also assigned.
ds = dataset.get_subset(splits['train'], transform=Compose([Resize((224, 224)), ToTensor()]))

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet50
from wilds.datasets.camelyon17_dataset import Camelyon17Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
import wandb 

# Start a Weights & Biases run for the plain (non-Bayesian) baseline.
# The config must describe what this cell really trains: the model built below
# is torchvision's resnet50 and the training loop runs for num_epochs = 10, so
# those are the values recorded here (previously "resnet18" / 50).
run = wandb.init(
    # Set the wandb entity where your project will be logged (generally your team name).
    entity="opent03-team",
    # Set the wandb project where this run will be logged.
    project="wilds_dpddm",
    # Track hyperparameters and run metadata.
    config={
        "learning_rate": 0.001,
        "architecture": "resnet50",
        "dataset": "Camelyon17",
        "epochs": 10,
    },
)


# --- Reproducibility -------------------------------------------------------
# This cell has three sources of randomness: the frac=0.1 subsample drawn by
# get_subset, DataLoader shuffling, and the random augmentations. Seed them so
# a Restart & Run All reproduces the same subsets and batches.
# NOTE(review): WILDS is understood to draw the `frac` subsample from NumPy's
# global RNG, hence np.random.seed — confirm against the installed version.
import numpy as np  # local to this cell: only needed for seeding

SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize WILDS dataset
# NOTE(review): machine-specific absolute path; adjust for your environment.
dataset = Camelyon17Dataset(root_dir='/h/300/viet/bayesian_dpddm/data', download=True)

# Training-time augmentation followed by ImageNet-style normalisation.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(degrees=90),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# For validation/test (no augmentations)
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Get 10% subsamples of the train, validation, and test splits. The transform
# is handed to get_subset directly (the same way the exploration cells above
# do) instead of being patched onto the subset afterwards.
train_data = dataset.get_subset('train', frac=0.1, transform=train_transform)
val_data = dataset.get_subset('val', frac=0.1, transform=val_transform)
test_data = dataset.get_subset('test', frac=0.1, transform=val_transform)

# Create data loaders (only the training loader shuffles)
batch_size = 32
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_data, batch_size=batch_size, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_data, batch_size=batch_size, num_workers=4, pin_memory=True)

# Initialize a ResNet-50 from scratch: pretrained=False means NO pretrained
# weights are loaded (the earlier comments said "ResNet18" / "pretrained").
model = resnet50(pretrained=False)

# Replace the final fully connected layer with a 2-way head
# (Camelyon17 is binary: tumor vs normal).
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, 2)

# Move the model to the device before building the optimizer.
model = model.to(device)

# Loss function and optimizer. The learning rate is read from the wandb run
# config (0.001) so the logged hyperparameter and the one actually used cannot
# drift apart; the VBLL cell below does the same.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=wandb.config.learning_rate)

def evaluate_classifier(model, loader, criterion, device):
    """Evaluate `model` on every batch of `loader` without tracking gradients.

    Batches are expected to be (x, y, metadata) triples, as yielded by the
    WILDS subsets used in this notebook.

    Returns:
        (mean_loss, accuracy): the per-batch loss averaged over the number of
        batches, and the accuracy in percent over all examples seen.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for x, y, _metadata in loader:
            x, y = x.to(device), y.to(device)
            outputs = model(x)
            loss_sum += criterion(outputs, y).item()
            n_correct += (outputs.argmax(dim=1) == y).sum().item()
            n_seen += y.size(0)
    return loss_sum / len(loader), 100 * n_correct / n_seen


# Training loop
num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    # Training phase
    for x, y, _metadata in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        outputs = model(x)
        loss = criterion(outputs, y)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # detach() replaces the legacy `.data` access; predictions need no graph.
        predicted = outputs.detach().argmax(dim=1)
        total += y.size(0)
        correct += (predicted == y).sum().item()

    train_loss = running_loss / len(train_loader)
    train_acc = 100 * correct / total

    # Validation and test phases share one helper instead of two copied loops.
    # NOTE(review): test accuracy is tracked every epoch for monitoring only;
    # do not use it for model selection.
    val_loss, val_acc = evaluate_classifier(model, val_loader, criterion, device)
    _, test_acc = evaluate_classifier(model, test_loader, criterion, device)

    print(f'Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, '
          f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
    print(f'Test Accuracy: {test_acc:.2f}%')

    # Log everything for this epoch in ONE call. By default each wandb.log call
    # advances the run's step, so three separate calls per epoch placed the
    # train/val/test metrics of the same epoch on three different steps.
    wandb.log({
        'epoch': epoch + 1,
        'train_loss': train_loss,
        'train_acc': train_acc,
        'val_loss': val_loss,
        'val_acc': val_acc,
        'test_acc': test_acc,
    })

# Close the wandb run so its summary is finalised here rather than by the next
# wandb.init call (the VBLL cell below closes its run the same way).
wandb.finish()

0,1
train_acc,▁▆█
train_loss,█▃▁
val_acc,▃█▁
val_loss,▁▅█

0,1
train_acc,94.67332
train_loss,0.14796
val_acc,77.47851
val_loss,0.93549


Epoch 1/10: 100%|█████████████████████████████████████████████████████████████████████████████████████| 946/946 [00:41<00:00, 22.84it/s]


Epoch 1: Train Loss: 0.3420, Train Acc: 86.40%, Val Loss: 0.4509, Val Acc: 82.87%
Test Accuracy: 85.83%


Epoch 2/10: 100%|█████████████████████████████████████████████████████████████████████████████████████| 946/946 [00:40<00:00, 23.52it/s]


Epoch 2: Train Loss: 0.2180, Train Acc: 91.75%, Val Loss: 0.4549, Val Acc: 87.22%
Test Accuracy: 88.09%


Epoch 3/10: 100%|█████████████████████████████████████████████████████████████████████████████████████| 946/946 [00:39<00:00, 23.67it/s]


Epoch 3: Train Loss: 0.1801, Train Acc: 93.30%, Val Loss: 0.5520, Val Acc: 83.98%
Test Accuracy: 75.77%


Epoch 4/10: 100%|█████████████████████████████████████████████████████████████████████████████████████| 946/946 [00:40<00:00, 23.57it/s]


Epoch 4: Train Loss: 0.1697, Train Acc: 93.62%, Val Loss: 0.3734, Val Acc: 88.02%
Test Accuracy: 80.34%


Epoch 5/10: 100%|█████████████████████████████████████████████████████████████████████████████████████| 946/946 [00:40<00:00, 23.60it/s]


Epoch 5: Train Loss: 0.1589, Train Acc: 94.12%, Val Loss: 0.3795, Val Acc: 87.85%
Test Accuracy: 83.22%


Epoch 6/10:   5%|███▉                                                                                  | 43/946 [00:02<00:37, 23.82it/s]

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.models import resnet50
from wilds.datasets.camelyon17_dataset import Camelyon17Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
import wandb
import math

class VBLLayer(nn.Module):
    """Variational Bayesian linear layer with a factorised Gaussian posterior.

    Every weight and bias has a learnable posterior mean (`*_mu`) and log
    variance (`*_logvar`). The prior is a zero-mean isotropic Gaussian with
    precision `prior_precision` (i.e. variance 1 / prior_precision).

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        prior_precision: precision of the Gaussian prior; must be positive.

    Raises:
        ValueError: if `prior_precision` is not strictly positive.
    """

    def __init__(self, in_features, out_features, prior_precision=1.0):
        super().__init__()
        if prior_precision <= 0:
            raise ValueError(
                f"prior_precision must be positive, got {prior_precision}"
            )
        self.in_features = in_features
        self.out_features = out_features

        # Mean parameters (torch.empty: values are set in reset_parameters)
        self.weight_mu = nn.Parameter(torch.empty(out_features, in_features))
        self.bias_mu = nn.Parameter(torch.empty(out_features))

        # Log variance parameters (for numerical stability)
        self.weight_logvar = nn.Parameter(torch.empty(out_features, in_features))
        self.bias_logvar = nn.Parameter(torch.empty(out_features))

        # Prior precision
        self.prior_precision = prior_precision

        # Initialize parameters
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise means with the same scheme as the kaiming/uniform calls
        below, and start every log variance at -6 (a very narrow posterior)."""
        # Initialize means
        nn.init.kaiming_uniform_(self.weight_mu, a=math.sqrt(5))
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight_mu)
        bound = 1 / math.sqrt(fan_in)
        nn.init.uniform_(self.bias_mu, -bound, bound)

        # Initialize logvars
        nn.init.constant_(self.weight_logvar, -6)
        nn.init.constant_(self.bias_logvar, -6)

    def forward(self, x):
        """Apply a linear map with weights freshly sampled from the posterior.

        One weight/bias sample is drawn per call and shared by the whole batch.
        Sampling is not gated on `self.training`, so outputs are stochastic in
        eval mode as well.
        """
        # Sample weights using reparameterization trick: w = mu + sigma * eps
        weight_eps = torch.randn_like(self.weight_logvar)
        bias_eps = torch.randn_like(self.bias_logvar)

        weight = self.weight_mu + torch.exp(0.5 * self.weight_logvar) * weight_eps
        bias = self.bias_mu + torch.exp(0.5 * self.bias_logvar) * bias_eps

        return F.linear(x, weight, bias)

    def _gaussian_kl(self, mu, logvar):
        """KL( N(mu, exp(logvar)) || N(0, 1/prior_precision) ), summed over elements."""
        p = self.prior_precision
        # Per element: 0.5 * (p*var + p*mu^2 - 1 - log(p*var)).
        # The log(p) part was previously missing, which made the value wrong
        # (and possibly negative) whenever prior_precision != 1.
        return 0.5 * (
            p * torch.exp(logvar).sum()
            + p * (mu ** 2).sum()
            - logvar.sum()
            - mu.numel() * (1.0 + math.log(p))
        )

    def kl_divergence(self):
        """Return the KL divergence between posterior and prior as a scalar
        tensor, summed over all weights and biases."""
        weight_kl = self._gaussian_kl(self.weight_mu, self.weight_logvar)
        bias_kl = self._gaussian_kl(self.bias_mu, self.bias_logvar)
        return weight_kl + bias_kl

# Hyperparameters for the ResNet-50 + VBLL run; the rest of this cell reads
# learning_rate, epochs, prior_precision and kl_weight back from wandb.config.
vbll_run_config = {
    "learning_rate": 0.001,
    "architecture": "resnet50_vbll",
    "dataset": "Camelyon17",
    "epochs": 50,
    "prior_precision": 1.0,
    "kl_weight": 0.1,
}

# Initialize wandb
run = wandb.init(
    entity="opent03-team",
    project="wilds_dpddm",
    config=vbll_run_config,
)

# --- Reproducibility -------------------------------------------------------
# This cell has several sources of randomness: the frac=0.1 subsample drawn by
# get_subset, DataLoader shuffling, the random augmentations, and (later in
# the cell) model initialisation and weight sampling. Seed them so a
# Restart & Run All reproduces the same run.
# NOTE(review): WILDS is understood to draw the `frac` subsample from NumPy's
# global RNG, hence np.random.seed — confirm against the installed version.
import numpy as np  # local to this cell: only needed for seeding

SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize WILDS dataset
# NOTE(review): machine-specific absolute path; adjust for your environment.
dataset = Camelyon17Dataset(root_dir='/h/300/viet/bayesian_dpddm/data', download=True)

# Training-time augmentation followed by ImageNet-style normalisation.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(degrees=90),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# For validation/test (no augmentations)
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Get 10% subsamples of the train, validation, and test splits, handing the
# transform to get_subset directly instead of patching it on afterwards.
train_data = dataset.get_subset('train', frac=0.1, transform=train_transform)
val_data = dataset.get_subset('val', frac=0.1, transform=val_transform)
test_data = dataset.get_subset('test', frac=0.1, transform=val_transform)

# Create data loaders (only the training loader shuffles)
batch_size = 32
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_data, batch_size=batch_size, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_data, batch_size=batch_size, num_workers=4, pin_memory=True)

# Build a ResNet-50 without pretrained weights; its classifier head is swapped
# for a variational last layer right below.
model = resnet50(pretrained=False)
num_features = model.fc.in_features

# 2-way VBLL head whose prior precision comes from the wandb run config.
model.fc = VBLLayer(
    in_features=num_features,
    out_features=2,
    prior_precision=wandb.config.prior_precision,
)
model = model.to(device)

# Cross-entropy supplies the NLL term of the objective; Adam optimises the
# backbone and the VBLL parameters together.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=wandb.config.learning_rate)

def evaluate_vbll(model, loader, criterion, kl_weight, device):
    """Evaluate a model whose `fc` layer exposes `kl_divergence()`.

    Batches are expected to be (x, y, metadata) triples, as yielded by the
    WILDS subsets used in this notebook.

    Returns:
        (loss, nll, kl, accuracy) where `nll` is the per-batch criterion value
        averaged over batches, `kl` is the layer's KL divergence,
        `loss = nll + kl_weight * kl`, and `accuracy` is in percent.
    """
    model.eval()
    nll_sum = 0.0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        # Parameters do not change during evaluation, so the KL term is the
        # same for every batch: compute it once instead of once per batch.
        kl = model.fc.kl_divergence().item()
        for x, y, _metadata in loader:
            x, y = x.to(device), y.to(device)
            outputs = model(x)
            nll_sum += criterion(outputs, y).item()
            n_correct += (outputs.argmax(dim=1) == y).sum().item()
            n_seen += y.size(0)
    nll = nll_sum / len(loader)
    return nll + kl_weight * kl, nll, kl, 100 * n_correct / n_seen


# Training loop
num_epochs = wandb.config.epochs
kl_weight = wandb.config.kl_weight  # Weight for KL term in ELBO

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    running_nll = 0.0
    running_kl = 0.0
    correct = 0
    total = 0

    # Training phase
    for x, y, _metadata in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        outputs = model(x)

        # Negative log likelihood
        nll = criterion(outputs, y)

        # KL divergence (only from the VBLL layer)
        kl = model.fc.kl_divergence()

        # ELBO-style loss: batch-mean NLL plus the weighted KL term
        loss = nll + kl_weight * kl

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        running_nll += nll.item()
        running_kl += kl.item()

        # detach() replaces the legacy `.data` access; predictions need no graph.
        predicted = outputs.detach().argmax(dim=1)
        total += y.size(0)
        correct += (predicted == y).sum().item()

    train_loss = running_loss / len(train_loader)
    train_nll = running_nll / len(train_loader)
    train_kl = running_kl / len(train_loader)
    train_acc = 100 * correct / total

    # Validation and test phases share one helper instead of two copied loops.
    # NOTE(review): test metrics are tracked every epoch for monitoring only;
    # do not use them for model selection.
    val_loss, val_nll, val_kl, val_acc = evaluate_vbll(
        model, val_loader, criterion, kl_weight, device)
    test_loss, test_nll, test_kl, test_acc = evaluate_vbll(
        model, test_loader, criterion, kl_weight, device)

    print(f'Epoch {epoch+1}: '
          f'Train Loss: {train_loss:.4f} (NLL: {train_nll:.4f}, KL: {train_kl:.4f}), '
          f'Train Acc: {train_acc:.2f}%, '
          f'Val Loss: {val_loss:.4f} (NLL: {val_nll:.4f}, KL: {val_kl:.4f}), '
          f'Val Acc: {val_acc:.2f}%')

    print(f'Test Accuracy: {test_acc:.2f}%, '
          f'Test Loss: {test_loss:.4f} (NLL: {test_nll:.4f}, KL: {test_kl:.4f})')

    # Log everything for this epoch in ONE call. By default each wandb.log call
    # advances the run's step, so three separate calls per epoch placed the
    # train/val/test metrics of the same epoch on three different steps.
    wandb.log({
        'epoch': epoch + 1,
        'train_loss': train_loss,
        'train_nll': train_nll,
        'train_kl': train_kl,
        'train_acc': train_acc,
        'val_loss': val_loss,
        'val_nll': val_nll,
        'val_kl': val_kl,
        'val_acc': val_acc,
        'test_loss': test_loss,
        'test_nll': test_nll,
        'test_kl': test_kl,
        'test_acc': test_acc,
    })

# Close wandb run
wandb.finish()