<a href="https://colab.research.google.com/github/saladpalad/bigML-assessment/blob/main/bigML.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Setup

In [None]:
!pip install spuco



In [None]:
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
import tqdm as tqdm
import spuco.datasets

### Initialize dataset

In [None]:
from spuco.datasets import SpuCoMNIST, SpuriousFeatureDifficulty
import torchvision.transforms as T

# Class groupings: consecutive digit pairs [0, 1], [2, 3], ..., [8, 9].
classes = [[2 * pair, 2 * pair + 1] for pair in range(5)]
difficulty = SpuriousFeatureDifficulty.MAGNITUDE_LARGE

# Constructor arguments that the train and test splits have in common.
shared_dataset_args = dict(
    root="/data/mnist/",
    spurious_feature_difficulty=difficulty,
    classes=classes,
)

# Only the training split is given an explicit spurious_correlation_strength;
# the test split is built without that argument.
trainset = SpuCoMNIST(
    split="train",
    spurious_correlation_strength=0.995,
    **shared_dataset_args,
)
trainset.initialize()

testset = SpuCoMNIST(split="test", **shared_dataset_args)
testset.initialize()

100%|██████████| 48004/48004 [00:10<00:00, 4495.11it/s]
100%|██████████| 10000/10000 [00:01<00:00, 6779.01it/s]


### Train ERM Model (ResNet-18)

In [None]:
from spuco.models import model_factory
from spuco.utils import Trainer
from torch.optim import SGD

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# The model's input shape is read from the first training example.
input_shape = trainset[0][0].shape
model = model_factory("resnet18", input_shape, trainset.num_classes)
model = model.to(device)

# NOTE(review): no RNG seed is set in the visible cells, so results may differ
# from the recorded outputs when this is re-run.
erm_optimizer = SGD(
    model.parameters(),
    lr=1e-3,
    momentum=0.9,
    nesterov=True,
    weight_decay=5e-4,
)
erm = Trainer(
    model=model,
    trainset=trainset,
    optimizer=erm_optimizer,
    batch_size=64,
    device=device,
    verbose=True,
)

# The recorded output for this call shows a single pass ("Epoch 0").
erm.train(1)

cuda


Epoch 0: 100%|██████████| 751/751 [00:23<00:00, 32.22batch/s, accuracy=75.0%, loss=0.371]


In [None]:
from spuco.evaluate import Evaluator

# Evaluate the ERM model on the test split, group by group. The groups come
# from the test set's own partition; the group weights come from the train set.
erm_eval_args = dict(
    model=model,
    testset=testset,
    group_partition=testset.group_partition,
    group_weights=trainset.group_weights,
    batch_size=64,
    device=device,
    verbose=True,
)
evaluator = Evaluator(**erm_eval_args)

# Kept as the cell's last expression so the returned per-group accuracy
# mapping is displayed as the cell output.
evaluator.evaluate()

Evaluating group-wise accuracy:   4%|▍         | 1/25 [00:00<00:13,  1.75it/s]

Group (0, 0) Accuracy: 100.0


Evaluating group-wise accuracy:   8%|▊         | 2/25 [00:01<00:11,  1.96it/s]

Group (0, 1) Accuracy: 31.914893617021278


Evaluating group-wise accuracy:  12%|█▏        | 3/25 [00:01<00:10,  2.08it/s]

Group (0, 2) Accuracy: 37.35224586288416


Evaluating group-wise accuracy:  16%|█▌        | 4/25 [00:01<00:09,  2.10it/s]

Group (0, 3) Accuracy: 30.73286052009456


Evaluating group-wise accuracy:  20%|██        | 5/25 [00:02<00:09,  2.14it/s]

Group (0, 4) Accuracy: 27.659574468085108


Evaluating group-wise accuracy:  24%|██▍       | 6/25 [00:02<00:08,  2.12it/s]

Group (1, 0) Accuracy: 20.78239608801956


Evaluating group-wise accuracy:  28%|██▊       | 7/25 [00:03<00:08,  2.15it/s]

Group (1, 1) Accuracy: 100.0


Evaluating group-wise accuracy:  32%|███▏      | 8/25 [00:03<00:07,  2.15it/s]

Group (1, 2) Accuracy: 0.0


Evaluating group-wise accuracy:  36%|███▌      | 9/25 [00:04<00:07,  2.12it/s]

Group (1, 3) Accuracy: 0.0


Evaluating group-wise accuracy:  40%|████      | 10/25 [00:04<00:07,  2.10it/s]

Group (1, 4) Accuracy: 0.9803921568627451


Evaluating group-wise accuracy:  44%|████▍     | 11/25 [00:05<00:06,  2.12it/s]

Group (2, 0) Accuracy: 24.266666666666666


Evaluating group-wise accuracy:  48%|████▊     | 12/25 [00:05<00:06,  2.14it/s]

Group (2, 1) Accuracy: 0.0


Evaluating group-wise accuracy:  52%|█████▏    | 13/25 [00:06<00:05,  2.13it/s]

Group (2, 2) Accuracy: 100.0


Evaluating group-wise accuracy:  56%|█████▌    | 14/25 [00:06<00:05,  2.13it/s]

Group (2, 3) Accuracy: 0.26666666666666666


Evaluating group-wise accuracy:  60%|██████    | 15/25 [00:07<00:04,  2.14it/s]

Group (2, 4) Accuracy: 0.0


Evaluating group-wise accuracy:  64%|██████▍   | 16/25 [00:07<00:04,  1.94it/s]

Group (3, 0) Accuracy: 8.793969849246231


Evaluating group-wise accuracy:  68%|██████▊   | 17/25 [00:08<00:04,  1.73it/s]

Group (3, 1) Accuracy: 1.0075566750629723


Evaluating group-wise accuracy:  72%|███████▏  | 18/25 [00:09<00:04,  1.61it/s]

Group (3, 2) Accuracy: 9.06801007556675


Evaluating group-wise accuracy:  76%|███████▌  | 19/25 [00:09<00:03,  1.70it/s]

Group (3, 3) Accuracy: 100.0


Evaluating group-wise accuracy:  80%|████████  | 20/25 [00:10<00:02,  1.81it/s]

Group (3, 4) Accuracy: 11.335012594458439


Evaluating group-wise accuracy:  84%|████████▍ | 21/25 [00:10<00:02,  1.90it/s]

Group (4, 0) Accuracy: 58.69017632241814


Evaluating group-wise accuracy:  88%|████████▊ | 22/25 [00:11<00:01,  1.98it/s]

Group (4, 1) Accuracy: 0.2518891687657431


Evaluating group-wise accuracy:  92%|█████████▏| 23/25 [00:11<00:00,  2.03it/s]

Group (4, 2) Accuracy: 0.0


Evaluating group-wise accuracy:  96%|█████████▌| 24/25 [00:11<00:00,  2.08it/s]

Group (4, 3) Accuracy: 6.0606060606060606


Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:12<00:00,  2.01it/s]

Group (4, 4) Accuracy: 100.0





{(0, 0): 100.0,
 (0, 1): 31.914893617021278,
 (0, 2): 37.35224586288416,
 (0, 3): 30.73286052009456,
 (0, 4): 27.659574468085108,
 (1, 0): 20.78239608801956,
 (1, 1): 100.0,
 (1, 2): 0.0,
 (1, 3): 0.0,
 (1, 4): 0.9803921568627451,
 (2, 0): 24.266666666666666,
 (2, 1): 0.0,
 (2, 2): 100.0,
 (2, 3): 0.26666666666666666,
 (2, 4): 0.0,
 (3, 0): 8.793969849246231,
 (3, 1): 1.0075566750629723,
 (3, 2): 9.06801007556675,
 (3, 3): 100.0,
 (3, 4): 11.335012594458439,
 (4, 0): 58.69017632241814,
 (4, 1): 0.2518891687657431,
 (4, 2): 0.0,
 (4, 3): 6.0606060606060606,
 (4, 4): 100.0}

### Cluster the training-set outputs of the ERM-trained model

In [None]:
from spuco.group_inference import Cluster, ClusterAlg

# Outputs of the ERM trainer's model on the training set; these are the points
# that get clustered.
logits = erm.get_trainset_outputs()

# K-means with two clusters; the progress output reads "Clustering class-wise".
cluster = Cluster(
    cluster_alg=ClusterAlg.KMEANS,
    num_clusters=2,
    Z=logits,
    class_labels=trainset.labels,
    device=device,
    verbose=True,
)

# The inferred partition stands in for ground-truth group labels when retraining.
group_partition = cluster.infer_groups()

Getting Trainset Outputs: 100%|██████████| 751/751 [00:05<00:00, 133.08batch/s]
Clustering class-wise: 100%|██████████| 5/5 [00:00<00:00, 10.49it/s]


### Retrain using "Group-Balancing"

In [None]:
from torch.optim import SGD
from spuco.robust_train import GroupBalanceBatchERM, ClassBalanceBatchERM
from spuco.models import model_factory

# Build a fresh ResNet-18 for the retraining run. This rebinds the name
# `model`, so any later cell that passes `model` uses this new network.
input_shape = trainset[0][0].shape
model = model_factory("resnet18", input_shape, trainset.num_classes)
model = model.to(device)

balanced_optimizer = SGD(
    model.parameters(),
    lr=1e-3,
    momentum=0.9,
    nesterov=True,
    weight_decay=5e-4,
)

# Retrain with the inferred group partition; the recorded output shows
# five epochs (Epoch 0 through Epoch 4).
group_balance_erm = GroupBalanceBatchERM(
    model=model,
    trainset=trainset,
    group_partition=group_partition,
    optimizer=balanced_optimizer,
    batch_size=64,
    num_epochs=5,
    device=device,
    verbose=True,
)
group_balance_erm.train()

Epoch 0: 100%|██████████| 751/751 [00:23<00:00, 32.63batch/s, accuracy=75.0%, loss=1.05]
Epoch 1: 100%|██████████| 751/751 [00:23<00:00, 32.01batch/s, accuracy=100.0%, loss=0.168]
Epoch 2: 100%|██████████| 751/751 [00:23<00:00, 32.61batch/s, accuracy=100.0%, loss=0.155]
Epoch 3: 100%|██████████| 751/751 [00:23<00:00, 32.55batch/s, accuracy=100.0%, loss=0.0594]
Epoch 4: 100%|██████████| 751/751 [00:24<00:00, 30.72batch/s, accuracy=100.0%, loss=0.126]


In [None]:
from spuco.evaluate import Evaluator

# Evaluate the retrained model on the test split, group by group. The groups
# come from the test set's own partition; the weights come from the train set.
evaluator = Evaluator(
    testset=testset,
    group_partition=testset.group_partition,
    group_weights=trainset.group_weights,
    batch_size=64,
    model=model,
    device=device,
    verbose=True
)
evaluator.evaluate()

print("Worst group accuracy:", evaluator.worst_group_accuracy)
# average_accuracy is weighted by the group_weights passed above (the training
# set's), so it is not the plain mean of the per-group accuracies: the recorded
# run reports ~99.7 while most groups score far lower. The label says so.
print("Average accuracy (weighted by training group weights):", evaluator.average_accuracy)
print("Spurious attribute accuracy", evaluator.evaluate_spurious_attribute_prediction())

Evaluating group-wise accuracy:   4%|▍         | 1/25 [00:00<00:10,  2.21it/s]

Group (0, 0) Accuracy: 100.0


Evaluating group-wise accuracy:   8%|▊         | 2/25 [00:00<00:10,  2.20it/s]

Group (0, 1) Accuracy: 55.319148936170215


Evaluating group-wise accuracy:  12%|█▏        | 3/25 [00:01<00:09,  2.22it/s]

Group (0, 2) Accuracy: 65.2482269503546


Evaluating group-wise accuracy:  16%|█▌        | 4/25 [00:01<00:09,  2.22it/s]

Group (0, 3) Accuracy: 64.53900709219859


Evaluating group-wise accuracy:  20%|██        | 5/25 [00:02<00:09,  2.20it/s]

Group (0, 4) Accuracy: 65.72104018912529


Evaluating group-wise accuracy:  24%|██▍       | 6/25 [00:02<00:08,  2.17it/s]

Group (1, 0) Accuracy: 36.919315403422985


Evaluating group-wise accuracy:  28%|██▊       | 7/25 [00:03<00:08,  2.13it/s]

Group (1, 1) Accuracy: 100.0


Evaluating group-wise accuracy:  32%|███▏      | 8/25 [00:03<00:07,  2.16it/s]

Group (1, 2) Accuracy: 29.41176470588235


Evaluating group-wise accuracy:  36%|███▌      | 9/25 [00:04<00:07,  2.16it/s]

Group (1, 3) Accuracy: 30.147058823529413


Evaluating group-wise accuracy:  40%|████      | 10/25 [00:04<00:06,  2.18it/s]

Group (1, 4) Accuracy: 50.490196078431374


Evaluating group-wise accuracy:  44%|████▍     | 11/25 [00:05<00:06,  2.18it/s]

Group (2, 0) Accuracy: 74.66666666666667


Evaluating group-wise accuracy:  48%|████▊     | 12/25 [00:05<00:05,  2.18it/s]

Group (2, 1) Accuracy: 31.733333333333334


Evaluating group-wise accuracy:  52%|█████▏    | 13/25 [00:06<00:06,  1.89it/s]

Group (2, 2) Accuracy: 100.0


Evaluating group-wise accuracy:  56%|█████▌    | 14/25 [00:06<00:06,  1.74it/s]

Group (2, 3) Accuracy: 33.86666666666667


Evaluating group-wise accuracy:  60%|██████    | 15/25 [00:07<00:05,  1.68it/s]

Group (2, 4) Accuracy: 6.149732620320855


Evaluating group-wise accuracy:  64%|██████▍   | 16/25 [00:07<00:04,  1.81it/s]

Group (3, 0) Accuracy: 20.85427135678392


Evaluating group-wise accuracy:  68%|██████▊   | 17/25 [00:08<00:04,  1.91it/s]

Group (3, 1) Accuracy: 28.71536523929471


Evaluating group-wise accuracy:  72%|███████▏  | 18/25 [00:08<00:03,  1.98it/s]

Group (3, 2) Accuracy: 24.43324937027708


Evaluating group-wise accuracy:  76%|███████▌  | 19/25 [00:09<00:02,  2.00it/s]

Group (3, 3) Accuracy: 100.0


Evaluating group-wise accuracy:  80%|████████  | 20/25 [00:09<00:02,  2.06it/s]

Group (3, 4) Accuracy: 44.584382871536526


Evaluating group-wise accuracy:  84%|████████▍ | 21/25 [00:10<00:01,  2.09it/s]

Group (4, 0) Accuracy: 50.377833753148614


Evaluating group-wise accuracy:  88%|████████▊ | 22/25 [00:10<00:01,  2.13it/s]

Group (4, 1) Accuracy: 6.801007556675063


Evaluating group-wise accuracy:  92%|█████████▏| 23/25 [00:11<00:00,  2.16it/s]

Group (4, 2) Accuracy: 9.571788413098236


Evaluating group-wise accuracy:  96%|█████████▌| 24/25 [00:11<00:00,  2.16it/s]

Group (4, 3) Accuracy: 49.24242424242424


Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:12<00:00,  2.03it/s]

Group (4, 4) Accuracy: 100.0
Worst group accuracy: ((2, 4), 6.149732620320855)
Avg group accuracy: 99.70374144981233





Spurious attribute accuracy 63.83
