In [23]:
import random

import numpy as np
import torch

# Seed the Python, NumPy and PyTorch RNGs up front so repeated runs of the
# notebook start from the same random state. This does not by itself guarantee
# bit-identical results if a library draws randomness from another source.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Every model and evaluator below is given this device explicitly.
device = torch.device("cpu")

In [24]:
from spuco.datasets import SpuCoMNIST, SpuriousFeatureDifficulty

# Five target classes, each made of two consecutive labels:
# [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]].
classes = [[2 * k, 2 * k + 1] for k in range(5)]
difficulty = SpuriousFeatureDifficulty.MAGNITUDE_LARGE

# Keyword arguments that the train and test splits have in common.
shared_dataset_kwargs = dict(
    root="../data/mnist/",
    spurious_feature_difficulty=difficulty,
    classes=classes,
)

# Only the training split sets an explicit spurious correlation strength;
# the test split relies on the constructor's default for that argument.
trainset = SpuCoMNIST(
    spurious_correlation_strength=0.995,
    split="train",
    **shared_dataset_kwargs,
)
trainset.initialize()

testset = SpuCoMNIST(split="test", **shared_dataset_kwargs)
testset.initialize()

100%|██████████| 48004/48004 [00:07<00:00, 6817.83it/s]
100%|██████████| 10000/10000 [00:01<00:00, 6904.47it/s]


In [25]:
from spuco.models import model_factory 
from spuco.utils import Trainer
from torch.optim import SGD

# Size the "lenet" model from the first training sample's input shape and the
# dataset's class count, then move it to the configured device.
input_shape = trainset[0][0].shape
model = model_factory("lenet", input_shape, trainset.num_classes).to(device)

# SGD with Nesterov momentum and weight decay over all model parameters.
erm_optimizer = SGD(
    model.parameters(),
    lr=1e-3,
    weight_decay=5e-4,
    momentum=0.9,
    nesterov=True,
)

trainer = Trainer(
    trainset=trainset,
    model=model,
    batch_size=64,
    optimizer=erm_optimizer,
    device=device,
    verbose=True,
)

# A single pass over the training set (the progress log shows only "Epoch 0").
trainer.train(1)

Epoch 0:   0%|          | 0/751 [00:00<?, ?batch/s]

Epoch 0: 100%|██████████| 751/751 [00:11<00:00, 67.98batch/s, accuracy=100.0%, loss=0.00642] 


In [26]:
from spuco.evaluate import Evaluator 

# Group-wise evaluation of the freshly trained model on the test split.
# Groups come from the test set's own partition; the weights passed in are the
# training set's group weights.
test_eval_kwargs = dict(
    testset=testset,
    group_partition=testset.group_partition,
    group_weights=trainset.group_weights,
    batch_size=64,
    model=model,
    device=device,
    verbose=True,
)
evaluator = Evaluator(**test_eval_kwargs)

# Left as the final expression so the per-group accuracy dict is displayed.
evaluator.evaluate()

Evaluating group-wise accuracy:   4%|▍         | 1/25 [00:00<00:03,  6.56it/s]

Group (0, 0) Accuracy: 100.0
Group (0, 1) Accuracy: 0.0


Evaluating group-wise accuracy:  20%|██        | 5/25 [00:00<00:01, 10.16it/s]

Group (0, 2) Accuracy: 0.0
Group (0, 3) Accuracy: 0.0
Group (0, 4) Accuracy: 0.0


Evaluating group-wise accuracy:  28%|██▊       | 7/25 [00:00<00:01,  9.28it/s]

Group (1, 0) Accuracy: 0.0
Group (1, 1) Accuracy: 100.0
Group (1, 2) Accuracy: 0.0


Evaluating group-wise accuracy:  40%|████      | 10/25 [00:00<00:01, 11.98it/s]

Group (1, 3) Accuracy: 0.0
Group (1, 4) Accuracy: 0.0
Group (2, 0) Accuracy: 0.0


Evaluating group-wise accuracy:  60%|██████    | 15/25 [00:01<00:00, 14.36it/s]

Group (2, 1) Accuracy: 0.0
Group (2, 2) Accuracy: 100.0
Group (2, 3) Accuracy: 0.0
Group (2, 4) Accuracy: 0.0
Group (3, 0) Accuracy: 0.0


Evaluating group-wise accuracy:  76%|███████▌  | 19/25 [00:01<00:00, 13.27it/s]

Group (3, 1) Accuracy: 0.0
Group (3, 2) Accuracy: 0.0
Group (3, 3) Accuracy: 100.0


Evaluating group-wise accuracy:  88%|████████▊ | 22/25 [00:01<00:00, 14.28it/s]

Group (3, 4) Accuracy: 0.0
Group (4, 0) Accuracy: 0.0
Group (4, 1) Accuracy: 0.0


Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:02<00:00, 12.41it/s]

Group (4, 2) Accuracy: 0.0
Group (4, 3) Accuracy: 0.0
Group (4, 4) Accuracy: 100.0





{(0, 0): 100.0,
 (0, 1): 0.0,
 (0, 2): 0.0,
 (0, 3): 0.0,
 (0, 4): 0.0,
 (1, 0): 0.0,
 (1, 1): 100.0,
 (1, 2): 0.0,
 (1, 3): 0.0,
 (1, 4): 0.0,
 (2, 0): 0.0,
 (2, 1): 0.0,
 (2, 2): 100.0,
 (2, 3): 0.0,
 (2, 4): 0.0,
 (3, 0): 0.0,
 (3, 1): 0.0,
 (3, 2): 0.0,
 (3, 3): 100.0,
 (3, 4): 0.0,
 (4, 0): 0.0,
 (4, 1): 0.0,
 (4, 2): 0.0,
 (4, 3): 0.0,
 (4, 4): 100.0}

In [27]:
# Single aggregate accuracy from the evaluation above; presumably the per-group
# accuracies combined using the `group_weights` given to the Evaluator — confirm.
evaluator.average_accuracy

99.49379218398467

In [28]:
# Score for how well the spurious attribute can be predicted from the model
# (a percentage, judging by the output) — exact procedure is defined in spuco; confirm there.
evaluator.evaluate_spurious_attribute_prediction()

100.0

In [29]:
from spuco.group_inference import Cluster, ClusterAlg

# Collect the trained model's outputs for every training example; these are
# the features handed to the clustering step.
logits = trainer.get_trainset_outputs()

# K-means with two clusters, using the training labels as class labels
# (the progress log reports clustering "class-wise").
cluster_kwargs = dict(
    Z=logits,
    class_labels=trainset.labels,
    cluster_alg=ClusterAlg.KMEANS,
    num_clusters=2,
    device=device,
    verbose=True,
)
cluster = Cluster(**cluster_kwargs)

# Inferred partition, reused by later cells for evaluation and retraining.
group_partition = cluster.infer_groups()

Getting Trainset Outputs: 100%|██████████| 751/751 [00:04<00:00, 174.07batch/s]
Clustering class-wise: 100%|██████████| 5/5 [00:00<00:00, 19.58it/s]


In [30]:
# Evaluate the current model on the *training* set, grouped by the inferred
# clusters rather than by the dataset's own groups.
#
# The weights must describe the same groups as `group_partition`. The dataset's
# `trainset.group_weights` belongs to its own partition (the test-set
# evaluation reports 25 groups, while the inferred partition has 10), so each
# inferred group's weight is derived from its share of the assigned examples.
# Assumes each partition value is a sized collection of sample indices — confirm.
num_assigned = sum(len(indices) for indices in group_partition.values())
inferred_group_weights = {
    group: len(indices) / num_assigned
    for group, indices in group_partition.items()
}

evaluator = Evaluator(
    testset=trainset,
    group_partition=group_partition,
    group_weights=inferred_group_weights,
    batch_size=64,
    model=model,
    device=device,
    verbose=True
)

# Left as the final expression so the per-group accuracy dict is displayed.
evaluator.evaluate()

Evaluating group-wise accuracy:  10%|█         | 1/10 [00:00<00:05,  1.76it/s]

Group (0, 0) Accuracy: 99.9843137254902


Evaluating group-wise accuracy:  20%|██        | 2/10 [00:00<00:03,  2.33it/s]

Group (0, 1) Accuracy: 98.66950505588079


Evaluating group-wise accuracy:  30%|███       | 3/10 [00:01<00:04,  1.63it/s]

Group (1, 0) Accuracy: 99.8132973757909
Group (1, 1) Accuracy: 0.0


Evaluating group-wise accuracy:  50%|█████     | 5/10 [00:02<00:01,  2.61it/s]

Group (2, 0) Accuracy: 99.61290322580645


Evaluating group-wise accuracy:  60%|██████    | 6/10 [00:02<00:01,  2.61it/s]

Group (2, 1) Accuracy: 99.35794542536115


Evaluating group-wise accuracy:  70%|███████   | 7/10 [00:03<00:01,  1.74it/s]

Group (3, 0) Accuracy: 99.7839283876942
Group (3, 1) Accuracy: 0.0


Evaluating group-wise accuracy: 100%|██████████| 10/10 [00:04<00:00,  2.23it/s]

Group (4, 0) Accuracy: 99.89365096245879
Group (4, 1) Accuracy: 0.0





{(0, 0): 99.9843137254902,
 (0, 1): 98.66950505588079,
 (1, 0): 99.8132973757909,
 (1, 1): 0.0,
 (2, 0): 99.61290322580645,
 (2, 1): 99.35794542536115,
 (3, 0): 99.7839283876942,
 (3, 1): 0.0,
 (4, 0): 99.89365096245879,
 (4, 1): 0.0}

In [31]:
from spuco.robust_train import GroupBalanceBatchERM, ClassBalanceBatchERM

# Fresh optimizer (same hyperparameters as the initial training run) over the
# parameters of the existing `model`.
# NOTE(review): `model` is the instance already trained earlier in the
# notebook, so this continues training it in place rather than starting from a
# new initialization — confirm that is intended.
balanced_optimizer = SGD(
    model.parameters(),
    lr=1e-3,
    weight_decay=5e-4,
    momentum=0.9,
    nesterov=True,
)

# Train for five epochs using the cluster-inferred group partition.
group_balance_erm = GroupBalanceBatchERM(
    model=model,
    num_epochs=5,
    trainset=trainset,
    group_partition=group_partition,
    batch_size=64,
    optimizer=balanced_optimizer,
    device=device,
    verbose=True,
)
group_balance_erm.train()

Epoch 0: 100%|██████████| 751/751 [00:12<00:00, 58.89batch/s, accuracy=75.0%, loss=0.279]   
Epoch 1: 100%|██████████| 751/751 [00:11<00:00, 67.05batch/s, accuracy=75.0%, loss=0.214]   
Epoch 2: 100%|██████████| 751/751 [00:11<00:00, 64.53batch/s, accuracy=100.0%, loss=0.103]   
Epoch 3: 100%|██████████| 751/751 [00:11<00:00, 66.66batch/s, accuracy=100.0%, loss=0.0124]  
Epoch 4: 100%|██████████| 751/751 [00:10<00:00, 71.14batch/s, accuracy=100.0%, loss=0.0534]  


In [32]:
# Re-run the test-split evaluation now that `model` has been trained further,
# with the same grouping and weights as the first test-set evaluation.
retrained_eval_kwargs = dict(
    testset=testset,
    group_partition=testset.group_partition,
    group_weights=trainset.group_weights,
    batch_size=64,
    model=model,
    device=device,
    verbose=True,
)
evaluator = Evaluator(**retrained_eval_kwargs)

# Left as the final expression so the per-group accuracy dict is displayed.
evaluator.evaluate()

Evaluating group-wise accuracy:  24%|██▍       | 6/25 [00:00<00:00, 27.45it/s]

Group (0, 0) Accuracy: 99.76359338061465
Group (0, 1) Accuracy: 17.494089834515368
Group (0, 2) Accuracy: 0.0
Group (0, 3) Accuracy: 0.2364066193853428
Group (0, 4) Accuracy: 2.127659574468085
Group (1, 0) Accuracy: 19.559902200489


Evaluating group-wise accuracy:  48%|████▊     | 12/25 [00:00<00:00, 27.64it/s]

Group (1, 1) Accuracy: 99.26650366748166
Group (1, 2) Accuracy: 82.3529411764706
Group (1, 3) Accuracy: 87.25490196078431
Group (1, 4) Accuracy: 76.7156862745098
Group (2, 0) Accuracy: 0.0
Group (2, 1) Accuracy: 0.5333333333333333


Evaluating group-wise accuracy:  76%|███████▌  | 19/25 [00:00<00:00, 29.07it/s]

Group (2, 2) Accuracy: 97.06666666666666
Group (2, 3) Accuracy: 1.8666666666666667
Group (2, 4) Accuracy: 0.0
Group (3, 0) Accuracy: 66.58291457286433
Group (3, 1) Accuracy: 27.455919395465994
Group (3, 2) Accuracy: 1.7632241813602014
Group (3, 3) Accuracy: 98.48866498740554


Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:00<00:00, 28.55it/s]

Group (3, 4) Accuracy: 0.7556675062972292
Group (4, 0) Accuracy: 64.23173803526448
Group (4, 1) Accuracy: 66.49874055415617
Group (4, 2) Accuracy: 57.68261964735516
Group (4, 3) Accuracy: 1.0101010101010102
Group (4, 4) Accuracy: 94.6969696969697





{(0, 0): 99.76359338061465,
 (0, 1): 17.494089834515368,
 (0, 2): 0.0,
 (0, 3): 0.2364066193853428,
 (0, 4): 2.127659574468085,
 (1, 0): 19.559902200489,
 (1, 1): 99.26650366748166,
 (1, 2): 82.3529411764706,
 (1, 3): 87.25490196078431,
 (1, 4): 76.7156862745098,
 (2, 0): 0.0,
 (2, 1): 0.5333333333333333,
 (2, 2): 97.06666666666666,
 (2, 3): 1.8666666666666667,
 (2, 4): 0.0,
 (3, 0): 66.58291457286433,
 (3, 1): 27.455919395465994,
 (3, 2): 1.7632241813602014,
 (3, 3): 98.48866498740554,
 (3, 4): 0.7556675062972292,
 (4, 0): 64.23173803526448,
 (4, 1): 66.49874055415617,
 (4, 2): 57.68261964735516,
 (4, 3): 1.0101010101010102,
 (4, 4): 94.6969696969697}

In [33]:
# Aggregate accuracy for the most recent evaluation (test split, after the
# group-balanced training); presumably weighted by the `group_weights` passed in — confirm.
evaluator.average_accuracy

97.55321861168294

In [34]:
# Spurious-attribute prediction score for the further-trained model, for
# comparison with the value reported before the group-balanced training.
evaluator.evaluate_spurious_attribute_prediction()

59.13