In [None]:
%%capture
!pip install wilds
!apt-get install libmagickwand-dev
!pip install wand

In [None]:
from google.colab import drive

# Mount Google Drive; later cells read the dataset archive and the project
# sources from under this mount point and write results back to it.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [None]:
%%capture
!mkdir /content/data/
!mkdir /content/data/camelyon17_v1.0/
!tar -xvzf /content/drive/MyDrive/clear-vae/data/camelyon17_v1.0/archive.tar.gz -C /content/data/camelyon17_v1.0/

In [None]:
import sys
import json

# Make the project's `src` package (stored on Drive) importable.
# The membership check keeps `sys.path` free of duplicate entries when this
# cell is executed more than once in the same kernel.
PROJECT_ROOT = '/content/drive/MyDrive/clear-vae'
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [None]:
from wilds import get_dataset
from wilds.common.data_loaders import get_train_loader
# Load Camelyon17 from the locally extracted copy (no download).
# `root_dir` is given explicitly as the absolute directory the archive was
# extracted into, instead of relying on the cwd-relative default.
dataset = get_dataset(dataset="camelyon17", root_dir="/content/data", download=False)

In [None]:
import torch
import torchvision.transforms as transforms
from torch.utils.data import random_split
from torch.utils.data import DataLoader

# Prefer the GPU when one is visible to PyTorch; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Single preprocessing pipeline shared by every split: resize each image to
# 64x64 and convert it to a tensor.
image_transform = transforms.Compose(
    [transforms.Resize((64, 64)), transforms.ToTensor()]
)

# Hold out 20% of the official "train" subset for validation.
# A dedicated, explicitly seeded generator makes the partition reproducible;
# previously the split was drawn from the global RNG before any seed was set,
# so it changed on every kernel start.
SPLIT_SEED = 0
train_data, valid_data = random_split(
    dataset.get_subset("train", transform=image_transform),
    [0.8, 0.2],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

test_data = dataset.get_subset("test", transform=image_transform)

def collate_fn(batch):
    """Collate ``(x, y, metadata)`` samples into three batched tensors.

    Args:
        batch: iterable of 3-tuples of tensors ``(x, y, metadata)``.

    Returns:
        A tuple ``(xs, ys, groups)`` where ``xs`` and ``ys`` are the stacked
        inputs and labels, and ``groups`` is column 0 of the stacked metadata.
        NOTE(review): column 0 is presumably the domain/hospital id for
        Camelyon17 — confirm against the dataset's metadata fields.
    """
    inputs, labels, metadata_rows = [], [], []
    for sample_x, sample_y, sample_meta in batch:
        inputs.append(sample_x)
        labels.append(sample_y)
        metadata_rows.append(sample_meta)

    stacked_metadata = torch.stack(metadata_rows)
    return torch.stack(inputs), torch.stack(labels), stacked_metadata[:, 0]

# Settings common to all three loaders; only the dataset and the shuffle flag
# differ between them.
_loader_options = dict(batch_size=64, collate_fn=collate_fn)

train_loader = DataLoader(train_data, shuffle=True, **_loader_options)
valid_loader = DataLoader(valid_data, shuffle=True, **_loader_options)
test_loader = DataLoader(test_data, shuffle=False, **_loader_options)

In [None]:
import torch.nn as nn
import numpy as np
from src.trainer import DownstreamMLPTrainer, SimpleCNNTrainer
from src.utils.trainer_utils import *
import os

In [None]:
def init_weights(module):
    """Re-initialise a single layer in place.

    Only ``nn.Conv2d``, ``nn.ConvTranspose2d`` and ``nn.Linear`` layers are
    touched: their weight is drawn with He (Kaiming) uniform initialisation
    (``mode='fan_in'``, ``nonlinearity='relu'``) and their bias, when present,
    is zeroed. Every other module type is left unchanged.
    """
    initialised_types = (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)
    if not isinstance(module, initialised_types):
        return

    # He uniform initialization
    nn.init.kaiming_uniform_(module.weight, mode='fan_in', nonlinearity='relu')
    bias = module.bias
    if bias is not None:
        nn.init.zeros_(bias)

In [None]:
def experiment_helper(
    dataloaders,
    vae_trainer,
    epochs,
    n_class=2,
    mlp_epochs=1,
):
    """Train a VAE, freeze it, then train and test an MLP classifier on top.

    Args:
        dataloaders: dict with the keys ``"train_loader"``, ``"valid_loader"``
            and ``"test_loader"``.
        vae_trainer: trainer object exposing ``fit``, ``model`` (which has a
            ``z_dim`` attribute) and ``device``.
        epochs: number of epochs passed to ``vae_trainer.fit``.
        n_class: number of output units of the downstream MLP head. Defaults
            to 2 to match the ``n_class`` used by the CNN trainers in this
            notebook (the head was previously hard-coded to 10 outputs).
        mlp_epochs: number of epochs used to fit the downstream MLP.

    Returns:
        ``(aupr_scores, auroc_scores, acc)`` as produced by
        ``DownstreamMLPTrainer.evaluate`` on the test loader.
    """
    train_loader = dataloaders["train_loader"]
    valid_loader = dataloaders["valid_loader"]
    test_loader = dataloaders["test_loader"]

    # Train VAE
    vae_trainer.fit(epochs, train_loader, valid_loader)
    vae_trainer.model.eval()

    z_dim = vae_trainer.model.z_dim  # which is 1/2 * total_z_dim
    device = vae_trainer.device

    # Freeze VAE parameters so only the MLP head receives gradient updates.
    for p in vae_trainer.model.parameters():
        p.requires_grad = False

    # Create and train MLP classifier
    mlp = nn.Sequential(
        nn.Linear(z_dim, 256),
        nn.BatchNorm1d(256),
        nn.ReLU(),
        nn.Linear(256, n_class),
    ).to(device)
    optimizer = torch.optim.Adam(mlp.parameters(), lr=3e-4)
    criterion = nn.CrossEntropyLoss()

    # NOTE(review): the positional `1` is presumably the trainer's verbose
    # period — confirm against DownstreamMLPTrainer's signature.
    trainer = DownstreamMLPTrainer(
        vae_trainer.model, mlp, optimizer, criterion, 1, device
    )
    trainer.fit(mlp_epochs, train_loader, valid_loader)

    # Evaluate on test set
    # NOTE(review): the meaning of the trailing `False, 0` arguments is not
    # visible here — confirm against DownstreamMLPTrainer.evaluate.
    (aupr_scores, auroc_scores), acc = trainer.evaluate(test_loader, False, 0)

    return aupr_scores, auroc_scores, acc


def _cnn_trainer_kwargs(trainer_kwargs, cnn_arch, **extra):
    """Keyword arguments for a CNN classifier trainer factory."""
    return {
        "n_class": 2,
        **extra,
        "device": trainer_kwargs["device"],
        "cnn_arch": cnn_arch,
        "in_channel": trainer_kwargs["in_channel"],
        "verbose_period": trainer_kwargs["verbose_period"],
    }


def _mim_trainer_kwargs(trainer_kwargs, mi_estimator):
    """Keyword arguments for a clear-mim trainer using ``mi_estimator``."""
    # Entries of `trainer_kwargs` come last so they override the defaults above
    # them on a key clash, as in the original literal dictionaries.
    return {
        "mi_estimator": mi_estimator,
        "la": 3,
        "mi_estimator_lr": 2e-3,
        **trainer_kwargs,
    }


def _hierarchical_vae_kwargs(trainer_kwargs, group_mode):
    """Keyword arguments for the hierarchical (GVAE / MLVAE) trainer factory."""
    return {
        "beta": trainer_kwargs["beta"],
        # Fixed value, deliberately not read from trainer_kwargs["vae_lr"].
        "vae_lr": 1e-4,
        "vae_arch": trainer_kwargs["vae_arch"],
        "z_dim": trainer_kwargs["z_dim"],
        "group_mode": group_mode,
        "device": trainer_kwargs["device"],
        "in_channel": trainer_kwargs["in_channel"],
        "verbose_period": trainer_kwargs["verbose_period"],
    }


def _build_model_specs(trainer_kwargs):
    """Map each model name to ``(trainer_factory, factory_kwargs)`` in run order."""
    # clear-tc uses its own temperature; the caller's dict is left untouched.
    tc_trainer_kwargs = {**trainer_kwargs, "temperature": 0.1}

    return {
        "baseline": (
            get_cnn_trainer,
            _cnn_trainer_kwargs(trainer_kwargs, "SimpleCNN64Classifier"),
        ),
        "clear": (
            get_clearvae_trainer,
            {"ps": True, **trainer_kwargs},
        ),
        "clear-mim (L1OutUB)": (
            get_clearmimvae_trainer,
            _mim_trainer_kwargs(trainer_kwargs, "L1OutUB"),
        ),
        "clear-mim (CLUB-S)": (
            get_clearmimvae_trainer,
            _mim_trainer_kwargs(trainer_kwargs, "CLUBSample"),
        ),
        "clear-tc": (
            get_cleartcvae_trainer,
            {"la": 1, "factor_cls_lr": 5e-4, **tc_trainer_kwargs},
        ),
        "lamcnn": (
            get_lamcnn_trainer,
            _cnn_trainer_kwargs(trainer_kwargs, "LAMCNN64Classifier", lam_coef=0.001),
        ),
        "gvae": (
            get_hierarchical_vae_trainer,
            _hierarchical_vae_kwargs(trainer_kwargs, "GVAE"),
        ),
        "mlvae": (
            get_hierarchical_vae_trainer,
            _hierarchical_vae_kwargs(trainer_kwargs, "MLVAE"),
        ),
    }


def _summarize_scores(scores):
    """Turn a ``{class: score}`` mapping into the JSON record stored per metric."""
    # Cast to built-in floats so the record serialises whatever numpy scalar
    # type the evaluator returned.
    return {
        "overall": round(float(np.mean(list(scores.values()))), 3),
        "stratified": {label: float(value) for label, value in scores.items()},
    }


def experiment(dataloaders, trainer_kwargs, epochs, run_index,
               save_root="/content/drive/MyDrive/clear-vae/expr_output/camelyon17/result",
               cnn_epochs=6):
    """Train and test every model variant, saving results after each one.

    Args:
        dataloaders: dict with the keys ``"train_loader"``, ``"valid_loader"``
            and ``"test_loader"``.
        trainer_kwargs: shared trainer settings; must contain at least
            ``device``, ``in_channel``, ``verbose_period``, ``beta``,
            ``vae_arch`` and ``z_dim``.
        epochs: number of VAE training epochs, forwarded to
            ``experiment_helper`` for trainers that are not CNN classifiers.
        run_index: suffix identifying this run; results are written to
            ``f"{save_root}_{run_index}.json"``.
        save_root: path prefix of the results file.
        cnn_epochs: number of epochs for trainers that are
            ``SimpleCNNTrainer`` instances (previously hard-coded to 6).

    Returns:
        dict mapping model name to ``{"acc", "pr", "roc"}`` records. Entries
        already present in an existing results file are kept unless a model
        of the same name is retrained in this call.
    """
    train_loader = dataloaders["train_loader"]
    valid_loader = dataloaders["valid_loader"]
    test_loader = dataloaders["test_loader"]

    models = _build_model_specs(trainer_kwargs)

    filename = f"{save_root}_{run_index}.json"
    # Create the output directory up front: failing on the first write would
    # throw away a fully trained model.
    output_dir = os.path.dirname(filename)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    if os.path.exists(filename):
        with open(filename, "r") as infile:
            results = json.load(infile)  # Load existing data
    else:
        results = {}

    for model_name, (trainer_func, params) in models.items():
        print(f"\nTraining {model_name}:")
        trainer = trainer_func(**params)
        trainer.model.apply(init_weights)

        if isinstance(trainer, SimpleCNNTrainer):
            # End-to-end classifiers are trained and evaluated directly.
            trainer.fit(
                epochs=cnn_epochs, train_loader=train_loader, valid_loader=valid_loader
            )
            (aupr_scores, auroc_scores), acc = trainer.evaluate(test_loader, False, 0)
        else:
            # VAE-style trainers: fit the VAE, then a downstream MLP on top.
            aupr_scores, auroc_scores, acc = experiment_helper(
                dataloaders, trainer, epochs
            )

        results[model_name] = {
            "acc": round(float(acc), 3),
            "pr": _summarize_scores(aupr_scores),
            "roc": _summarize_scores(auroc_scores),
        }

        # Persist after every model so an interrupted session keeps the
        # results obtained so far.
        with open(filename, "w") as outfile:
            json.dump(results, outfile, indent=4)

    return results

In [None]:
# The three loaders, keyed by the names the experiment functions look up.
dataloaders = dict(
    train_loader=train_loader,
    valid_loader=valid_loader,
    test_loader=test_loader,
)

# Settings shared by the VAE-style trainers; individual model entries in
# `experiment` add to or override these.
trainer_kwargs = dict(
    in_channel=3,
    vae_arch="VAE64",
    beta=1 / 32,
    vae_lr=1e-4,
    z_dim=64,
    alpha=100,
    temperature=0.3,
    device=device,
    verbose_period=2,
)

In [15]:
# Number of independent runs and the epoch budget passed to `experiment`.
NUM_RUNS = 1
VAE_EPOCHS = 7

for _run in range(NUM_RUNS):
    # Draw a seed, apply it to NumPy and PyTorch, and reuse it as the run
    # index so the results file name records which seed produced it.
    s = np.random.randint(1000)
    for seed_everything in (np.random.seed, torch.manual_seed):
        seed_everything(s)
    rlt = experiment(dataloaders, trainer_kwargs, VAE_EPOCHS, run_index=s)



Training baseline:


epoch 0: 100%|██████████| 3781/3781 [04:32<00:00, 13.90batch/s, loss=0.0921]
val-epoch 0: 100%|██████████| 946/946 [01:03<00:00, 14.82it/s]


val_aupr: {0: np.float64(0.99), 1: np.float64(0.992)}
0.991
val_auroc: {0: np.float64(0.991), 1: np.float64(0.991)}
0.991
val_acc: 0.958


epoch 2: 100%|██████████| 3781/3781 [04:31<00:00, 13.91batch/s, loss=0.271]
val-epoch 2: 100%|██████████| 946/946 [01:03<00:00, 14.88it/s]


val_aupr: {0: np.float64(0.994), 1: np.float64(0.996)}
0.995
val_auroc: {0: np.float64(0.995), 1: np.float64(0.995)}
0.995
val_acc: 0.968


epoch 4: 100%|██████████| 3781/3781 [04:30<00:00, 13.98batch/s, loss=0.0305]
val-epoch 4: 100%|██████████| 946/946 [01:03<00:00, 14.89it/s]


val_aupr: {0: np.float64(0.992), 1: np.float64(0.995)}
0.994
val_auroc: {0: np.float64(0.994), 1: np.float64(0.994)}
0.994
val_acc: 0.968

Training clear:


Epoch 0: 100%|██████████| 3781/3781 [05:13<00:00, 12.08batch/s, c_loss=0.0917, kl_c=64.1, kl_s=78.1, recontr_loss=287, s_loss=0.69]
val-epoch 0: 100%|██████████| 946/946 [01:07<00:00, 14.09it/s]


val_recontr_loss=283.632, val_kl_c=65.766, val_kl_s=80.368, val_c_loss=0.193, val_s_loss=0.705
gMIG: 0.316; mse: 283.632


Epoch 2: 100%|██████████| 3781/3781 [05:17<00:00, 11.90batch/s, c_loss=0.257, kl_c=73.8, kl_s=75.8, recontr_loss=200, s_loss=0.804]
val-epoch 2: 100%|██████████| 946/946 [01:07<00:00, 14.02it/s]


val_recontr_loss=139.627, val_kl_c=78.536, val_kl_s=75.521, val_c_loss=0.145, val_s_loss=0.718
gMIG: 0.468; mse: 139.627


Epoch 4: 100%|██████████| 3781/3781 [05:17<00:00, 11.91batch/s, c_loss=0.111, kl_c=81.5, kl_s=75.4, recontr_loss=147, s_loss=0.67]
val-epoch 4: 100%|██████████| 946/946 [01:07<00:00, 13.98it/s]


val_recontr_loss=126.443, val_kl_c=82.323, val_kl_s=77.141, val_c_loss=0.101, val_s_loss=0.714
gMIG: 0.478; mse: 126.443


Epoch 6: 100%|██████████| 3781/3781 [05:17<00:00, 11.91batch/s, c_loss=0.281, kl_c=92.1, kl_s=76.5, recontr_loss=130, s_loss=0.736]
val-epoch 6: 100%|██████████| 946/946 [01:07<00:00, 13.99it/s]


val_recontr_loss=122.867, val_kl_c=87.889, val_kl_s=76.984, val_c_loss=0.103, val_s_loss=0.715
gMIG: 0.509; mse: 122.867


epoch 0: 100%|██████████| 3781/3781 [04:26<00:00, 14.18batch/s, loss=0.0668]
val-epoch 0: 100%|██████████| 946/946 [01:04<00:00, 14.63it/s]


val_aupr: {0: np.float64(0.995), 1: np.float64(0.996)}
0.996
val_auroc: {0: np.float64(0.996), 1: np.float64(0.996)}
0.996
val_acc: 0.974

Training clear-mim (L1OutUB):


Epoch 0: 100%|██████████| 3781/3781 [06:51<00:00,  9.18batch/s, c_loss=0.331, kl_c=55.6, kl_s=63.3, mi_loss=0.746, recontr_loss=338]
val-epoch 0: 100%|██████████| 946/946 [01:07<00:00, 13.96it/s]


val_recontr_loss=343.189, val_kl_c=64.877, val_kl_s=64.825, val_c_loss=0.177, val_mi_loss=1.037
gMIG: 0.38; mse: 343.189


Epoch 2: 100%|██████████| 3781/3781 [06:51<00:00,  9.19batch/s, c_loss=0.0437, kl_c=73.6, kl_s=75.5, mi_loss=0.0477, recontr_loss=161]
val-epoch 2: 100%|██████████| 946/946 [01:07<00:00, 14.01it/s]


val_recontr_loss=165.260, val_kl_c=79.039, val_kl_s=75.270, val_c_loss=0.114, val_mi_loss=0.601
gMIG: 0.495; mse: 165.26


Epoch 4: 100%|██████████| 3781/3781 [06:51<00:00,  9.18batch/s, c_loss=0.199, kl_c=78.6, kl_s=78.4, mi_loss=0.51, recontr_loss=123]
val-epoch 4: 100%|██████████| 946/946 [01:07<00:00, 13.97it/s]


val_recontr_loss=125.490, val_kl_c=80.862, val_kl_s=79.044, val_c_loss=0.094, val_mi_loss=0.645
gMIG: 0.428; mse: 125.49


Epoch 6: 100%|██████████| 3781/3781 [06:52<00:00,  9.18batch/s, c_loss=0.0227, kl_c=79.2, kl_s=76.1, mi_loss=0.0903, recontr_loss=110]
val-epoch 6: 100%|██████████| 946/946 [01:07<00:00, 13.98it/s]


val_recontr_loss=123.564, val_kl_c=82.031, val_kl_s=75.552, val_c_loss=0.097, val_mi_loss=0.051
gMIG: 0.402; mse: 123.564


epoch 0: 100%|██████████| 3781/3781 [04:26<00:00, 14.19batch/s, loss=0.0135]
val-epoch 0: 100%|██████████| 946/946 [01:03<00:00, 14.83it/s]


val_aupr: {0: np.float64(0.994), 1: np.float64(0.996)}
0.995
val_auroc: {0: np.float64(0.995), 1: np.float64(0.995)}
0.995
val_acc: 0.972

Training clear-mim (CLUB-S):


Epoch 0: 100%|██████████| 3781/3781 [06:48<00:00,  9.26batch/s, c_loss=0.324, kl_c=38.9, kl_s=45.9, mi_loss=0.81, recontr_loss=395]
val-epoch 0: 100%|██████████| 946/946 [01:07<00:00, 14.05it/s]


val_recontr_loss=375.601, val_kl_c=44.715, val_kl_s=50.622, val_c_loss=0.305, val_mi_loss=1.529
gMIG: 0.176; mse: 375.601


Epoch 2: 100%|██████████| 3781/3781 [06:48<00:00,  9.26batch/s, c_loss=0.215, kl_c=42.8, kl_s=59.4, mi_loss=0.279, recontr_loss=159]
val-epoch 2: 100%|██████████| 946/946 [01:07<00:00, 14.11it/s]


val_recontr_loss=175.860, val_kl_c=46.510, val_kl_s=61.529, val_c_loss=0.218, val_mi_loss=0.742
gMIG: 0.289; mse: 175.86


Epoch 4: 100%|██████████| 3781/3781 [06:48<00:00,  9.26batch/s, c_loss=0.268, kl_c=54.6, kl_s=68.8, mi_loss=0.732, recontr_loss=143]
val-epoch 4: 100%|██████████| 946/946 [01:07<00:00, 14.03it/s]


val_recontr_loss=148.325, val_kl_c=55.558, val_kl_s=68.469, val_c_loss=0.168, val_mi_loss=0.601
gMIG: 0.375; mse: 148.325


Epoch 6: 100%|██████████| 3781/3781 [06:43<00:00,  9.37batch/s, c_loss=0.0755, kl_c=77.6, kl_s=74.9, mi_loss=-0.0814, recontr_loss=136]
val-epoch 6: 100%|██████████| 946/946 [01:06<00:00, 14.16it/s]


val_recontr_loss=128.278, val_kl_c=81.469, val_kl_s=72.009, val_c_loss=0.114, val_mi_loss=0.064
gMIG: 0.516; mse: 128.278


epoch 0: 100%|██████████| 3781/3781 [04:24<00:00, 14.31batch/s, loss=0.111]
val-epoch 0: 100%|██████████| 946/946 [01:03<00:00, 14.87it/s]


val_aupr: {0: np.float64(0.993), 1: np.float64(0.994)}
0.994
val_auroc: {0: np.float64(0.994), 1: np.float64(0.994)}
0.994
val_acc: 0.966

Training lamcnn:


epoch 0: 100%|██████████| 3781/3781 [04:54<00:00, 12.82batch/s, ce_loss=0.157, lam_loss=0.709]
val-epoch 0: 100%|██████████| 946/946 [01:03<00:00, 15.00it/s]


val_aupr: {0: np.float64(0.992), 1: np.float64(0.993)}
0.992
val_auroc: {0: np.float64(0.993), 1: np.float64(0.993)}
0.993
val_acc: 0.962


epoch 2: 100%|██████████| 3781/3781 [04:54<00:00, 12.85batch/s, ce_loss=0.058, lam_loss=0.636]
val-epoch 2: 100%|██████████| 946/946 [01:03<00:00, 15.01it/s]


val_aupr: {0: np.float64(0.995), 1: np.float64(0.996)}
0.996
val_auroc: {0: np.float64(0.996), 1: np.float64(0.996)}
0.996
val_acc: 0.969


epoch 4: 100%|██████████| 3781/3781 [04:54<00:00, 12.85batch/s, ce_loss=0.331, lam_loss=0.967]
val-epoch 4: 100%|██████████| 946/946 [01:02<00:00, 15.03it/s]


val_aupr: {0: np.float64(0.992), 1: np.float64(0.993)}
0.992
val_auroc: {0: np.float64(0.993), 1: np.float64(0.993)}
0.993
val_acc: 0.96

Training gvae:


epoch 0: 100%|██████████| 3781/3781 [05:03<00:00, 12.44batch/s, kl_c=60.6, kl_s=1.15e+3, reconstr_loss=4.19e+3]
val-epoch 0: 100%|██████████| 946/946 [01:04<00:00, 14.65it/s]


val_recontr_loss=278.514, val_kl_c=84.839, val_kl_s=80.459
gMIG: -0.001; mse: 278.514


epoch 2: 100%|██████████| 3781/3781 [05:04<00:00, 12.43batch/s, kl_c=67.1, kl_s=1.1e+3, reconstr_loss=2.04e+3]
val-epoch 2: 100%|██████████| 946/946 [01:05<00:00, 14.52it/s]


val_recontr_loss=149.989, val_kl_c=76.369, val_kl_s=81.860
gMIG: 0.051; mse: 149.989


epoch 4: 100%|██████████| 3781/3781 [05:04<00:00, 12.42batch/s, kl_c=70.6, kl_s=1.14e+3, reconstr_loss=2.03e+3]
val-epoch 4: 100%|██████████| 946/946 [01:05<00:00, 14.54it/s]


val_recontr_loss=136.634, val_kl_c=77.745, val_kl_s=81.057
gMIG: 0.125; mse: 136.634


epoch 6: 100%|██████████| 3781/3781 [05:03<00:00, 12.45batch/s, kl_c=69.5, kl_s=1.14e+3, reconstr_loss=2.07e+3]
val-epoch 6: 100%|██████████| 946/946 [01:04<00:00, 14.55it/s]


val_recontr_loss=136.613, val_kl_c=81.123, val_kl_s=85.059
gMIG: 0.128; mse: 136.613


epoch 0: 100%|██████████| 3781/3781 [04:24<00:00, 14.29batch/s, loss=0.38]
val-epoch 0: 100%|██████████| 946/946 [01:03<00:00, 14.83it/s]


val_aupr: {0: np.float64(0.928), 1: np.float64(0.931)}
0.93
val_auroc: {0: np.float64(0.931), 1: np.float64(0.931)}
0.931
val_acc: 0.857

Training mlvae:


epoch 0: 100%|██████████| 3781/3781 [05:09<00:00, 12.23batch/s, kl_c=50.2, kl_s=1.17e+3, reconstr_loss=4.06e+3]
val-epoch 0: 100%|██████████| 946/946 [01:05<00:00, 14.55it/s]


val_recontr_loss=331.781, val_kl_c=6349.901, val_kl_s=82.411
gMIG: -0.01; mse: 331.781


epoch 2: 100%|██████████| 3781/3781 [05:08<00:00, 12.24batch/s, kl_c=58, kl_s=1.25e+3, reconstr_loss=2.57e+3]
val-epoch 2: 100%|██████████| 946/946 [01:04<00:00, 14.56it/s]


val_recontr_loss=176.505, val_kl_c=4789594.500, val_kl_s=84.460
gMIG: -0.006; mse: 176.505


epoch 4: 100%|██████████| 3781/3781 [05:04<00:00, 12.40batch/s, kl_c=47.7, kl_s=1.13e+3, reconstr_loss=2.1e+3]
val-epoch 4: 100%|██████████| 946/946 [01:04<00:00, 14.68it/s]


val_recontr_loss=182.324, val_kl_c=333151200.000, val_kl_s=81.579
gMIG: 0.018; mse: 182.324


epoch 6: 100%|██████████| 3781/3781 [05:03<00:00, 12.46batch/s, kl_c=50.9, kl_s=1.09e+3, reconstr_loss=2.02e+3]
val-epoch 6: 100%|██████████| 946/946 [01:04<00:00, 14.72it/s]


val_recontr_loss=172.421, val_kl_c=50106556.000, val_kl_s=78.458
gMIG: 0.025; mse: 172.421


epoch 0: 100%|██████████| 3781/3781 [04:20<00:00, 14.49batch/s, loss=0.31]
val-epoch 0: 100%|██████████| 946/946 [01:02<00:00, 15.02it/s]


val_aupr: {0: np.float64(0.933), 1: np.float64(0.932)}
0.933
val_auroc: {0: np.float64(0.933), 1: np.float64(0.933)}
0.933
val_acc: 0.859


In [None]:
from tqdm import tqdm
import time

# Short visible countdown before the runtime is released.
# NOTE(review): presumably a grace period in which the cell can still be
# interrupted — confirm the intent.
for i in tqdm(range(9)):
    time.sleep(1)

from google.colab import drive, runtime

# Make sure everything written under the Drive mount (e.g. the results JSON)
# has been synced before the VM goes away, instead of relying on the sleep.
drive.flush_and_unmount()

# Release the Colab VM through the supported API rather than sending a kill
# signal to every process id on the machine.
runtime.unassign()

100%|██████████| 9/9 [00:09<00:00,  1.00s/it]