#Installations

In [5]:
%pip install spuco



In [6]:
from spuco.utils import set_seed
set_seed(0)

In [7]:
import torch
device = torch.device("cpu")

#Creating Data


In [8]:
from spuco.datasets import SpuCoMNIST, SpuriousFeatureDifficulty
import torchvision.transforms as T

classes = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]

difficulty = SpuriousFeatureDifficulty.MAGNITUDE_LARGE

trainset = SpuCoMNIST(
    root="/data/mnist/",
    spurious_feature_difficulty=difficulty,
    spurious_correlation_strength=0.995,
    classes=classes,
    split="train"
)
trainset.initialize()

testset = SpuCoMNIST(
    root="/data/mnist/",
    spurious_feature_difficulty=difficulty,
    classes=classes,
    split="test"
)
testset.initialize()

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to /data/mnist/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:02<00:00, 3774350.38it/s]


Extracting /data/mnist/MNIST/raw/train-images-idx3-ubyte.gz to /data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to /data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 694370.40it/s]


Extracting /data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz to /data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to /data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 5485968.28it/s]


Extracting /data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to /data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to /data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 1327933.14it/s]

Extracting /data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to /data/mnist/MNIST/raw




100%|██████████| 48004/48004 [00:08<00:00, 5435.83it/s]
100%|██████████| 10000/10000 [00:01<00:00, 5785.63it/s]


# 1) Train a model using ERM


In [9]:
from spuco.models import model_factory
from spuco.robust_train import ERM
from torch.optim import SGD

model = model_factory('lenet',trainset[0][0].shape,trainset.num_classes)
erm = ERM(
    model=model,
    num_epochs=1,
    trainset=trainset,
    batch_size=64,
    optimizer=SGD(model.parameters(), lr=1e-2, momentum=0.9, nesterov=True),
    device=device,
    verbose=True
)
erm.train()

Epoch 0: 100%|██████████| 751/751 [00:30<00:00, 24.79batch/s, accuracy=100.0%, loss=0.0111]


In [10]:
from spuco.evaluate import Evaluator

evaluator = Evaluator(
    testset=testset,
    group_partition=testset.group_partition,
    group_weights=trainset.group_weights,
    batch_size=64,
    model=model,
    device=device,
    verbose=True
)
evaluator.evaluate()

Evaluating group-wise accuracy:   4%|▍         | 1/25 [00:00<00:11,  2.14it/s]

Group (0, 0) Accuracy: 100.0


Evaluating group-wise accuracy:   8%|▊         | 2/25 [00:00<00:09,  2.47it/s]

Group (0, 1) Accuracy: 0.0


Evaluating group-wise accuracy:  12%|█▏        | 3/25 [00:01<00:07,  2.76it/s]

Group (0, 2) Accuracy: 0.0


Evaluating group-wise accuracy:  16%|█▌        | 4/25 [00:01<00:06,  3.00it/s]

Group (0, 3) Accuracy: 0.0


Evaluating group-wise accuracy:  20%|██        | 5/25 [00:01<00:06,  3.16it/s]

Group (0, 4) Accuracy: 0.0


Evaluating group-wise accuracy:  24%|██▍       | 6/25 [00:02<00:05,  3.21it/s]

Group (1, 0) Accuracy: 0.0


Evaluating group-wise accuracy:  28%|██▊       | 7/25 [00:02<00:05,  3.25it/s]

Group (1, 1) Accuracy: 100.0


Evaluating group-wise accuracy:  32%|███▏      | 8/25 [00:02<00:05,  3.32it/s]

Group (1, 2) Accuracy: 0.0


Evaluating group-wise accuracy:  36%|███▌      | 9/25 [00:02<00:04,  3.39it/s]

Group (1, 3) Accuracy: 0.0


Evaluating group-wise accuracy:  40%|████      | 10/25 [00:03<00:04,  3.35it/s]

Group (1, 4) Accuracy: 0.0


Evaluating group-wise accuracy:  44%|████▍     | 11/25 [00:03<00:04,  3.27it/s]

Group (2, 0) Accuracy: 0.0


Evaluating group-wise accuracy:  48%|████▊     | 12/25 [00:03<00:03,  3.37it/s]

Group (2, 1) Accuracy: 0.0


Evaluating group-wise accuracy:  52%|█████▏    | 13/25 [00:04<00:03,  3.34it/s]

Group (2, 2) Accuracy: 100.0


Evaluating group-wise accuracy:  56%|█████▌    | 14/25 [00:04<00:03,  3.39it/s]

Group (2, 3) Accuracy: 0.0


Evaluating group-wise accuracy:  60%|██████    | 15/25 [00:04<00:02,  3.42it/s]

Group (2, 4) Accuracy: 0.0


Evaluating group-wise accuracy:  64%|██████▍   | 16/25 [00:04<00:02,  3.43it/s]

Group (3, 0) Accuracy: 0.0


Evaluating group-wise accuracy:  68%|██████▊   | 17/25 [00:05<00:02,  3.32it/s]

Group (3, 1) Accuracy: 0.0


Evaluating group-wise accuracy:  72%|███████▏  | 18/25 [00:05<00:02,  3.36it/s]

Group (3, 2) Accuracy: 0.0


Evaluating group-wise accuracy:  76%|███████▌  | 19/25 [00:05<00:01,  3.39it/s]

Group (3, 3) Accuracy: 100.0


Evaluating group-wise accuracy:  80%|████████  | 20/25 [00:06<00:01,  3.38it/s]

Group (3, 4) Accuracy: 0.0


Evaluating group-wise accuracy:  84%|████████▍ | 21/25 [00:06<00:01,  3.37it/s]

Group (4, 0) Accuracy: 0.0


Evaluating group-wise accuracy:  88%|████████▊ | 22/25 [00:06<00:00,  3.40it/s]

Group (4, 1) Accuracy: 0.0


Evaluating group-wise accuracy:  92%|█████████▏| 23/25 [00:07<00:00,  3.42it/s]

Group (4, 2) Accuracy: 0.0


Evaluating group-wise accuracy:  96%|█████████▌| 24/25 [00:07<00:00,  3.36it/s]

Group (4, 3) Accuracy: 0.0


Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:07<00:00,  3.27it/s]

Group (4, 4) Accuracy: 100.0





{(0, 0): 100.0,
 (0, 1): 0.0,
 (0, 2): 0.0,
 (0, 3): 0.0,
 (0, 4): 0.0,
 (1, 0): 0.0,
 (1, 1): 100.0,
 (1, 2): 0.0,
 (1, 3): 0.0,
 (1, 4): 0.0,
 (2, 0): 0.0,
 (2, 1): 0.0,
 (2, 2): 100.0,
 (2, 3): 0.0,
 (2, 4): 0.0,
 (3, 0): 0.0,
 (3, 1): 0.0,
 (3, 2): 0.0,
 (3, 3): 100.0,
 (3, 4): 0.0,
 (4, 0): 0.0,
 (4, 1): 0.0,
 (4, 2): 0.0,
 (4, 3): 0.0,
 (4, 4): 100.0}

# 2) Cluster inputs based on the output they produce for ERM

In [11]:
from spuco.group_inference import Cluster, ClusterAlg

logits = erm.trainer.get_trainset_outputs()
cluster = Cluster(
    Z=logits,
    class_labels=trainset.labels,
    cluster_alg=ClusterAlg.KMEANS,
    num_clusters=2,
    device=device,
    verbose=True
)
group_partition = cluster.infer_groups()

Getting Trainset Outputs: 100%|██████████| 751/751 [00:08<00:00, 84.36batch/s] 
Clustering class-wise: 100%|██████████| 5/5 [00:00<00:00, 16.96it/s]


In [12]:
for key in sorted(group_partition.keys()):
    print(key, len(group_partition[key]))

(0, 0) 10082
(0, 1) 51
(1, 0) 9623
(1, 1) 49
(2, 0) 8965
(2, 1) 46
(3, 0) 9698
(3, 1) 49
(4, 0) 9393
(4, 1) 48


In [13]:
from spuco.evaluate import Evaluator

evaluator = Evaluator(
    testset=trainset,
    group_partition=group_partition,
    group_weights=trainset.group_weights,
    batch_size=64,
    model=model,
    device=device,
    verbose=True
)
evaluator.evaluate()

Evaluating group-wise accuracy:  10%|█         | 1/10 [00:01<00:16,  1.80s/it]

Group (0, 0) Accuracy: 100.0


Evaluating group-wise accuracy:  20%|██        | 2/10 [00:02<00:07,  1.14it/s]

Group (0, 1) Accuracy: 0.0


Evaluating group-wise accuracy:  30%|███       | 3/10 [00:03<00:08,  1.25s/it]

Group (1, 0) Accuracy: 100.0


Evaluating group-wise accuracy:  40%|████      | 4/10 [00:03<00:05,  1.17it/s]

Group (1, 1) Accuracy: 0.0


Evaluating group-wise accuracy:  50%|█████     | 5/10 [00:05<00:05,  1.11s/it]

Group (2, 0) Accuracy: 100.0


Evaluating group-wise accuracy:  60%|██████    | 6/10 [00:05<00:03,  1.23it/s]

Group (2, 1) Accuracy: 0.0


Evaluating group-wise accuracy:  70%|███████   | 7/10 [00:07<00:03,  1.10s/it]

Group (3, 0) Accuracy: 100.0


Evaluating group-wise accuracy:  80%|████████  | 8/10 [00:07<00:01,  1.22it/s]

Group (3, 1) Accuracy: 0.0


Evaluating group-wise accuracy:  90%|█████████ | 9/10 [00:10<00:01,  1.33s/it]

Group (4, 0) Accuracy: 100.0


Evaluating group-wise accuracy: 100%|██████████| 10/10 [00:10<00:00,  1.05s/it]

Group (4, 1) Accuracy: 0.0





{(0, 0): 100.0,
 (0, 1): 0.0,
 (1, 0): 100.0,
 (1, 1): 0.0,
 (2, 0): 100.0,
 (2, 1): 0.0,
 (3, 0): 100.0,
 (3, 1): 0.0,
 (4, 0): 100.0,
 (4, 1): 0.0}

# 3) Retrain using "Group-Balancing" to ensure in each batch each group appears equally.


In [14]:
from spuco.robust_train import GroupBalanceBatchERM

group_balance = GroupBalanceBatchERM(
    model=model,
    num_epochs=5,
    trainset=trainset,
    group_partition=trainset.group_partition,
    batch_size=64,
    optimizer=torch.optim.SGD(model.parameters(), lr=1e-2, weight_decay = 5e-4, momentum=0.9, nesterov=True),
    device=device,
    verbose=True
)
group_balance.train()

evaluator = Evaluator(
    testset=testset,
    group_partition=testset.group_partition,
    group_weights=trainset.group_weights,
    batch_size=64,
    model=model,
    device=device,
    verbose=True
)
evaluator.evaluate()

Epoch 0: 100%|██████████| 751/751 [00:17<00:00, 43.64batch/s, accuracy=100.0%, loss=0.0252]
Epoch 1: 100%|██████████| 751/751 [00:17<00:00, 43.96batch/s, accuracy=100.0%, loss=0.00357]
Epoch 2: 100%|██████████| 751/751 [00:18<00:00, 39.58batch/s, accuracy=100.0%, loss=0.00198]
Epoch 3: 100%|██████████| 751/751 [00:17<00:00, 43.03batch/s, accuracy=100.0%, loss=0.00036]
Epoch 4: 100%|██████████| 751/751 [00:18<00:00, 40.83batch/s, accuracy=100.0%, loss=0.000361]
Evaluating group-wise accuracy:   4%|▍         | 1/25 [00:00<00:11,  2.05it/s]

Group (0, 0) Accuracy: 99.52718676122932


Evaluating group-wise accuracy:   8%|▊         | 2/25 [00:00<00:08,  2.61it/s]

Group (0, 1) Accuracy: 81.7966903073286


Evaluating group-wise accuracy:  12%|█▏        | 3/25 [00:01<00:07,  2.82it/s]

Group (0, 2) Accuracy: 79.43262411347517


Evaluating group-wise accuracy:  16%|█▌        | 4/25 [00:01<00:07,  2.99it/s]

Group (0, 3) Accuracy: 64.30260047281324


Evaluating group-wise accuracy:  20%|██        | 5/25 [00:01<00:06,  3.08it/s]

Group (0, 4) Accuracy: 87.94326241134752


Evaluating group-wise accuracy:  24%|██▍       | 6/25 [00:02<00:06,  3.15it/s]

Group (1, 0) Accuracy: 86.79706601466992


Evaluating group-wise accuracy:  28%|██▊       | 7/25 [00:02<00:05,  3.18it/s]

Group (1, 1) Accuracy: 99.02200488997555


Evaluating group-wise accuracy:  32%|███▏      | 8/25 [00:02<00:05,  3.22it/s]

Group (1, 2) Accuracy: 83.82352941176471


Evaluating group-wise accuracy:  36%|███▌      | 9/25 [00:02<00:04,  3.24it/s]

Group (1, 3) Accuracy: 82.59803921568627


Evaluating group-wise accuracy:  40%|████      | 10/25 [00:03<00:04,  3.20it/s]

Group (1, 4) Accuracy: 76.22549019607843


Evaluating group-wise accuracy:  44%|████▍     | 11/25 [00:03<00:04,  3.25it/s]

Group (2, 0) Accuracy: 74.66666666666667


Evaluating group-wise accuracy:  48%|████▊     | 12/25 [00:03<00:03,  3.28it/s]

Group (2, 1) Accuracy: 67.73333333333333


Evaluating group-wise accuracy:  52%|█████▏    | 13/25 [00:04<00:03,  3.24it/s]

Group (2, 2) Accuracy: 99.46666666666667


Evaluating group-wise accuracy:  56%|█████▌    | 14/25 [00:04<00:03,  3.28it/s]

Group (2, 3) Accuracy: 40.8


Evaluating group-wise accuracy:  60%|██████    | 15/25 [00:04<00:03,  3.28it/s]

Group (2, 4) Accuracy: 43.04812834224599


Evaluating group-wise accuracy:  64%|██████▍   | 16/25 [00:05<00:02,  3.29it/s]

Group (3, 0) Accuracy: 84.42211055276383


Evaluating group-wise accuracy:  68%|██████▊   | 17/25 [00:05<00:02,  3.23it/s]

Group (3, 1) Accuracy: 82.87153652392946


Evaluating group-wise accuracy:  72%|███████▏  | 18/25 [00:05<00:02,  3.25it/s]

Group (3, 2) Accuracy: 88.66498740554157


Evaluating group-wise accuracy:  76%|███████▌  | 19/25 [00:06<00:01,  3.26it/s]

Group (3, 3) Accuracy: 98.48866498740554


Evaluating group-wise accuracy:  80%|████████  | 20/25 [00:06<00:01,  3.24it/s]

Group (3, 4) Accuracy: 85.64231738035265


Evaluating group-wise accuracy:  84%|████████▍ | 21/25 [00:06<00:01,  3.24it/s]

Group (4, 0) Accuracy: 80.10075566750629


Evaluating group-wise accuracy:  88%|████████▊ | 22/25 [00:06<00:00,  3.24it/s]

Group (4, 1) Accuracy: 54.659949622166245


Evaluating group-wise accuracy:  92%|█████████▏| 23/25 [00:07<00:00,  3.10it/s]

Group (4, 2) Accuracy: 53.65239294710327


Evaluating group-wise accuracy:  96%|█████████▌| 24/25 [00:07<00:00,  3.18it/s]

Group (4, 3) Accuracy: 41.91919191919192


Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:07<00:00,  3.16it/s]

Group (4, 4) Accuracy: 99.24242424242425





{(0, 0): 99.52718676122932,
 (0, 1): 81.7966903073286,
 (0, 2): 79.43262411347517,
 (0, 3): 64.30260047281324,
 (0, 4): 87.94326241134752,
 (1, 0): 86.79706601466992,
 (1, 1): 99.02200488997555,
 (1, 2): 83.82352941176471,
 (1, 3): 82.59803921568627,
 (1, 4): 76.22549019607843,
 (2, 0): 74.66666666666667,
 (2, 1): 67.73333333333333,
 (2, 2): 99.46666666666667,
 (2, 3): 40.8,
 (2, 4): 43.04812834224599,
 (3, 0): 84.42211055276383,
 (3, 1): 82.87153652392946,
 (3, 2): 88.66498740554157,
 (3, 3): 98.48866498740554,
 (3, 4): 85.64231738035265,
 (4, 0): 80.10075566750629,
 (4, 1): 54.659949622166245,
 (4, 2): 53.65239294710327,
 (4, 3): 41.91919191919192,
 (4, 4): 99.24242424242425}

In [15]:
evaluator.worst_group_accuracy
# ((2, 4), 35.026737967914436)


((2, 3), 40.8)

In [16]:
evaluator.average_accuracy
# 99.66139282632827\\

99.0135106567585

In [17]:

evaluator.evaluate_spurious_attribute_prediction()

36.93