In [1]:
import wilds
print(wilds.__version__)
from wilds.datasets.camelyon17_dataset import Camelyon17Dataset
from wilds.common.data_loaders import get_train_loader, get_eval_loader
from torchvision.transforms import Compose, ToTensor, Resize
from torchvision.models import resnet50, ResNet50_Weights
from torch.utils.data import DataLoader

import torch
import torch.nn as nn
import torch.nn.functional as F 
import torch.optim as optim

# Enable cuDNN's benchmark mode (per the PyTorch docs, cuDNN then benchmarks the
# available convolution algorithms and selects the fastest one).
torch.backends.cudnn.benchmark = True
from tqdm.notebook import tqdm
# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
import wandb


2.0.0


In [2]:
!nvidia-smi

Fri Apr 11 04:52:01 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.127.08             Driver Version: 550.127.08     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA H100 80GB HBM3          On  |   00000000:9B:00.0 Off |                    0 |
| N/A   27C    P0             72W /  700W |       4MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA H100 80GB HBM3          On  |   00

In [3]:
# Hyperparameters and run metadata recorded alongside the run.
run_config = dict(
    learning_rate=0.001,
    architecture="resnet50",
    dataset="Camelyon17",
    epochs=50,
)

# Start a Weights & Biases run:
#   entity  -> the account/team the project is logged under
#   project -> the project that collects this run
run = wandb.init(entity="opent03-team", project="wilds_dpddm", config=run_config)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mopent03[0m ([33mopent03-team[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
# Open Camelyon17 from the existing local copy; download=False means nothing is fetched.
dataset = Camelyon17Dataset(
    download=False,
    root_dir='/h/300/viet/bayesian_dpddm/data/',
)

In [5]:
# Local subset name -> name of the dataset split it is drawn from.
# Note that 'valid' and 'dpddm_id' are both backed by the same 'val' split.
splits = dict(
    train='train',
    valid='val',
    dpddm_train='id_val',
    dpddm_id='val',
    dpddm_ood='test',
)

# Filled in later: local subset name -> dataset subset object.
dataset_dict = dict()

In [6]:
# Build one subset per entry in `splits`. Every subset gets its own transform that
# resizes images to 224x224 and converts them to tensors.
for split, source_split in splits.items():
    resize_to_tensor = Compose([Resize((224, 224)), ToTensor()])
    dataset_dict[split] = dataset.get_subset(split=source_split, transform=resize_to_tensor)

In [7]:
# Training hyperparameters.
BATCH_SIZE = 512        # samples per batch, shared by the train and validation loaders
EPOCHS = 50             # number of passes over the training set
LEARNING_RATE = 0.001   # step size handed to the optimizer

In [8]:
# Plain torch DataLoaders are used here in place of the WILDS get_train_loader /
# get_eval_loader helpers.
#
# The training loader must request shuffling explicitly: DataLoader defaults to
# shuffle=False, which would present the samples in the dataset's stored order on
# every epoch instead of drawing a fresh random order each time.
trainloader = DataLoader(
    dataset_dict['train'],
    batch_size=BATCH_SIZE,
    shuffle=True,
    pin_memory=True,
    num_workers=10,
)

# Accuracy does not depend on sample order, so the validation loader stays sequential.
valloader = DataLoader(
    dataset_dict['valid'],
    batch_size=BATCH_SIZE,
    pin_memory=True,
    num_workers=10,
)


In [9]:
# ResNet-50 (constructed with no weights argument) whose classification head is
# swapped for a single-output linear layer, i.e. one logit per image.
model = resnet50()
model.fc = nn.Linear(2048, 1)

# Binary cross-entropy computed directly on logits, optimised with AdamW.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)

# Wrap the network in DataParallel and move it onto the selected device.
model = nn.DataParallel(model).to(device)

In [10]:
def train(model, loader, optimizer, criterion):
    """Run one optimisation pass over ``loader`` and return the mean per-sample loss.

    Args:
        model: Network to optimise; it is switched to training mode.
        loader: Iterable yielding ``(images, labels, extra)`` triples; the third
            element is ignored.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss callable invoked as ``criterion(output, labels)``. The
            running total weights each batch loss by the batch size, which assumes
            the criterion returns a per-batch mean.

    Returns:
        float: Sum of batch-size-weighted batch losses divided by the number of
        samples seen, so the value does not scale with the size of the dataset.
        Returns ``0.0`` when the loader yields no batches.
    """
    model.train()
    loss_sum = 0.0
    num_samples = 0
    for images, labels, _ in tqdm(loader, leave=False):
        optimizer.zero_grad()
        # Labels get a trailing dimension and a float dtype before being passed
        # to the criterion together with the model output.
        images, labels = images.to(device), labels.unsqueeze(1).float().to(device)
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

        # Weight by batch size so a smaller final batch is not over-represented.
        batch_size = len(labels)
        loss_sum += loss.item() * batch_size
        num_samples += batch_size

    # Normalise by the sample count; guard the empty-loader case against 0/0.
    return loss_sum / num_samples if num_samples else 0.0
        
def evaluate(model, loader):
    """Compute binary classification accuracy of ``model`` over ``loader``.

    The model output is treated as a raw logit per sample (the network is trained
    with ``BCEWithLogitsLoss``, which applies the sigmoid internally). A sample is
    predicted positive when its logit is above 0, which is equivalent to a
    predicted probability ``sigmoid(logit)`` above 0.5.

    Args:
        model: Network to evaluate; it is switched to evaluation mode.
        loader: Iterable yielding ``(images, labels, extra)`` triples; the third
            element is ignored.

    Returns:
        float: Fraction of samples whose thresholded prediction equals the label.

    Raises:
        ZeroDivisionError: If the loader yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels, _ in tqdm(loader, leave=False):
            images, labels = images.to(device), labels.unsqueeze(1).to(device)
            logits = model(images)
            # Threshold in logit space: sigmoid(z) > 0.5  <=>  z > 0.
            # Comparing the raw logits against 0.5 would shift the decision
            # boundary to a probability of about 0.62.
            preds = logits > 0
            correct += torch.sum(preds == labels).item()
            total += len(labels)
    return correct / total

In [None]:
# Per-epoch history kept in memory, in addition to what is sent to wandb.
train_losses, train_accs, valid_accs = [], [], []

for epoch in tqdm(range(EPOCHS)):
    # Dict values are evaluated top to bottom: first the training pass, then
    # accuracy on the training loader, then accuracy on the validation loader.
    epoch_metrics = {
        'train_loss': train(model, trainloader, optimizer, criterion),
        'train_acc': evaluate(model, trainloader),
        'valid_acc': evaluate(model, valloader),
    }

    train_losses.append(epoch_metrics['train_loss'])
    train_accs.append(epoch_metrics['train_acc'])
    valid_accs.append(epoch_metrics['valid_acc'])

    # One wandb log entry per epoch.
    wandb.log(epoch_metrics)

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/591 [00:00<?, ?it/s]

In [8]:
!nvidia-smi

Thu Apr 10 23:25:08 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.127.05             Driver Version: 550.127.05     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA RTX A5000               Off |   00000000:3B:00.0 Off |                  Off |
| 30%   19C    P8             17W /  230W |       4MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA RTX A6000               Off |   00