In [17]:
import torch

device = torch.device("cuda")

In [18]:
from spuco.utils import set_seed

set_seed(0)

In [19]:
from spuco.robust_train import ERM
from spuco.datasets import SpuCoMNIST, SpuriousFeatureDifficulty
import torchvision.transforms as T

classes = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
difficulty = SpuriousFeatureDifficulty.MAGNITUDE_LARGE
root = "data/mnist/"

trainset = SpuCoMNIST(
    root=root,
    spurious_feature_difficulty=difficulty,
    spurious_correlation_strength=0.95,
    classes=classes,
    split="train",
    label_noise=0.001
)
trainset.initialize()

testset = SpuCoMNIST(
    root=root,
    spurious_feature_difficulty=difficulty,
    classes=classes,
    split="test"
)
testset.initialize()

from spuco.datasets import SpuCoMNIST, SpuriousFeatureDifficulty

classes = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]

valset = SpuCoMNIST(
    root=root,
    spurious_feature_difficulty=difficulty,
    classes=classes,
    split="val"
)
valset.initialize()



100%|██████████| 48004/48004 [00:05<00:00, 9511.94it/s]
100%|██████████| 10000/10000 [00:00<00:00, 10191.57it/s]
100%|██████████| 11996/11996 [00:01<00:00, 10320.57it/s]


In [20]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Subset, DataLoader
from spuco.robust_train import ERM
from spuco.datasets import SpuCoMNIST, SpuriousFeatureDifficulty
import torchvision.transforms as T
import numpy as np
from spuco.models import model_factory


train_loader = DataLoader(trainset, batch_size=64, shuffle=True)
not_shuffle_train_loader = DataLoader(trainset, batch_size=64, shuffle=False)
test_loader = DataLoader(testset, batch_size=64, shuffle=False)

# model = model_factory("lenet", trainset[0][0].shape, trainset.num_classes).to(device)
model = model_factory("mlp", trainset[0][0].shape, trainset.num_classes).to(device)

from spuco.evaluate import Evaluator

# Initialize the model and optimizer
# optimizer = optim.Adam(model.parameters(), lr=0.001)
optimizer = optim.SGD(model.parameters(), lr=1e-2, momentum=0.9, nesterov=False)

# Training loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
criterion = nn.CrossEntropyLoss()
num_epochs = 5
# train_predictions_history = [[] for _ in range(len(trainset))]
# test_predictions_history = [[] for _ in range(len(testset))]
train_acc_history_lst = []
test_acc_history_lst = []

# Training function
def train(model, train_loader, optimizer, criterion, device):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

# Evaluation function
def evaluate(model, data_loader, criterion, device):
    model.eval()
    correct = 0
    total = 0
    predictions = []
    acc_history = []
    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
            acc_history.extend((predicted == target).cpu().numpy())
            predictions.extend(predicted.cpu().numpy())
    return correct / total, acc_history

# Calculate forgetability scores
def calculate_forgetability(predictions_history):
    forgetability_scores = []
    for sample_predictions in predictions_history:
        correct_to_incorrect = 0
        for i in range(1, len(sample_predictions)):
            if sample_predictions[i-1] == 1 and sample_predictions[i] == 0:
                correct_to_incorrect += 1

        if sum(sample_predictions) == 0:
            correct_to_incorrect = (num_epochs // 2) + 1
        forgetability_scores.append(correct_to_incorrect)
    return forgetability_scores



for epoch in range(num_epochs):
    train(model, train_loader, optimizer, criterion, device)
    train_acc, train_acc_history = evaluate(model, not_shuffle_train_loader, criterion, device)
    test_acc, test_acc_history = evaluate(model, test_loader, criterion, device)

    print(f"Epoch {epoch+1}/{num_epochs}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}")

    train_acc_history_lst.append(train_acc_history)
    test_acc_history_lst.append(test_acc_history)


Epoch 1/5, Train Acc: 0.9508, Test Acc: 0.2309
Epoch 2/5, Train Acc: 0.9524, Test Acc: 0.2683
Epoch 3/5, Train Acc: 0.9533, Test Acc: 0.2742
Epoch 4/5, Train Acc: 0.9566, Test Acc: 0.3309
Epoch 5/5, Train Acc: 0.9574, Test Acc: 0.3469


In [21]:
train_acc_history_lst = np.array(train_acc_history_lst).T
train_forgetability_scores = calculate_forgetability(train_acc_history_lst)


In [23]:
# group accuracy for the last epoch
print("Trainset accuracy for each group")
for group, members in trainset.group_partition.items():
    print(group, end=' ')
    for epoch in range(num_epochs):
        group_score = train_acc_history_lst[members, epoch]
        print(f"{group_score.mean():.4f}", end=' ')
    print()

Trainset accuracy for each group
(2, 2) 0.9999 0.9998 0.9989 0.9986 0.9981 
(0, 0) 0.9990 0.9943 0.9993 0.9985 0.9960 
(4, 4) 1.0000 1.0000 1.0000 0.9999 0.9997 
(1, 1) 1.0000 1.0000 0.9998 0.9971 1.0000 
(3, 3) 1.0000 0.9997 0.9990 0.9988 0.9991 
(0, 3) 0.0000 0.1127 0.0141 0.2042 0.3028 
(4, 0) 0.5520 0.6960 0.4240 0.7840 0.3120 
(3, 1) 0.0000 0.0000 0.0000 0.0000 0.0000 
(1, 2) 0.1261 0.0721 0.3243 0.2342 0.4865 
(0, 2) 0.0000 0.0000 0.0620 0.2481 0.1860 
(2, 0) 0.0000 0.0000 0.1639 0.0246 0.0246 
(0, 4) 0.0000 0.0853 0.0310 0.1783 0.3876 
(1, 3) 0.0000 0.0000 0.0076 0.0227 0.1439 
(3, 2) 0.0000 0.0000 0.1176 0.1324 0.2721 
(0, 1) 0.0000 0.0000 0.0342 0.3333 0.0000 
(3, 0) 0.0000 0.0160 0.0480 0.1920 0.3440 
(2, 4) 0.0000 0.0000 0.0000 0.0000 0.0000 
(4, 1) 0.0000 0.0000 0.0000 0.0000 0.0000 
(4, 3) 0.0233 0.0930 0.2326 0.2171 0.1550 
(3, 4) 0.0000 0.0000 0.0000 0.0328 0.1885 
(1, 0) 0.1181 0.6614 0.3071 0.4724 0.8110 
(2, 3) 0.0000 0.0160 0.0000 0.0640 0.0240 
(2, 1) 0.0000 0.0091 

In [24]:
print("Trainset forgetability score each group")
train_forgetability_scores = np.array(train_forgetability_scores)
count = 0
for group, members in trainset.group_partition.items():
    group_score = train_forgetability_scores[members]
    print(group, group_score.mean())

Trainset forgetability score each group
(2, 2) 0.0024558531165945504
(0, 0) 0.01019028803161069
(4, 4) 0.00033500837520938025
(1, 1) 0.002940856115891515
(3, 3) 0.0016216216216216215
(0, 3) 2.1901408450704225
(4, 0) 1.376
(3, 1) 3.0
(1, 2) 1.6846846846846846
(0, 2) 2.302325581395349
(2, 0) 2.5655737704918034
(0, 4) 1.8914728682170543
(1, 3) 2.553030303030303
(3, 2) 2.176470588235294
(0, 1) 2.3333333333333335
(3, 0) 1.968
(2, 4) 3.0
(4, 1) 3.0
(4, 3) 2.37984496124031
(3, 4) 2.4344262295081966
(1, 0) 0.9291338582677166
(2, 3) 2.864
(2, 1) 2.3545454545454545
(1, 4) 2.8333333333333335
(4, 2) 3.0


In [25]:
# group correct into (0,0), and incorrect into (1,1)

jtt_group_partition = {(0,0): [], (0,1): []}
for i in range(train_acc_history_lst.shape[0]):
    if train_acc_history_lst[i, -1]: # correct
        jtt_group_partition[(0,0)].append(i)
    else: # incorrect
        jtt_group_partition[(0,1)].append(i)

In [26]:
for group, members in jtt_group_partition.items():
    print(group, len(members), members[:10])

(0, 0) 45957 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
(0, 1) 2047 [21, 56, 89, 111, 113, 156, 158, 172, 177, 190]


In [33]:
fgt_group_partition = {(0,0): [], (0,1): []}
for i in range(train_acc_history_lst.shape[0]):
    if train_forgetability_scores[i] == 0: # correct
        fgt_group_partition[(0,0)].append(i)
    else: # incorrect
        fgt_group_partition[(0,1)].append(i)

In [34]:
for group, members in fgt_group_partition.items():
    print(group, len(members), members[:10])

(0, 0) 45805 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
(0, 1) 2199 [21, 56, 89, 111, 113, 156, 158, 172, 177, 190]


In [27]:
# from spuco.group_inference import JTTInference
# from spuco.utils import Trainer

# trainer = Trainer(
#     trainset=trainset,
#     model=model,
#     batch_size=64,
#     optimizer=optim.SGD(model.parameters(), lr=1e-2, momentum=0.9, nesterov=True),
#     device=device,
#     verbose=True
# )

# # trainer.train(1)

# predictions = torch.argmax(trainer.get_trainset_outputs(), dim=-1).detach().cpu().tolist()
# jtt = JTTInference(
#     predictions=predictions,
#     class_labels=trainset.labels
# )
# group_partition = jtt.infer_groups()

Getting Trainset Outputs: 100%|██████████| 751/751 [00:01<00:00, 411.86batch/s]


In [28]:
# for group, members in group_partition.items():
#     print(group, len(members), members[:10])

(0, 0) 45957 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
(0, 1) 2047 [21, 56, 89, 111, 113, 156, 158, 172, 177, 190]


In [14]:
jtt_model = model_factory("mlp", trainset[0][0].shape, trainset.num_classes).to(device)
jtt_model.load_state_dict(model.state_dict())


<All keys matched successfully>

In [32]:
fgt_model = model_factory("mlp", trainset[0][0].shape, trainset.num_classes).to(device)
fgt_model.load_state_dict(model.state_dict())

<All keys matched successfully>

In [29]:
from torch.optim import SGD
from spuco.robust_train import UpSampleERM, DownSampleERM, CustomSampleERM

jtt_train = UpSampleERM(
    model=jtt_model,
    num_epochs=5,
    trainset=trainset,
    batch_size=64,
    group_partition=jtt_group_partition,
    optimizer=SGD(model.parameters(), lr=1e-2, momentum=0.9, nesterov=True),
    device=device,
    verbose=True
)
jtt_train.train()

Epoch 0: 100%|██████████| 1437/1437 [00:06<00:00, 212.49batch/s, accuracy=40.0%, loss=2.29]    
Epoch 1: 100%|██████████| 1437/1437 [00:06<00:00, 216.88batch/s, accuracy=80.0%, loss=0.741]   
Epoch 2: 100%|██████████| 1437/1437 [00:10<00:00, 137.46batch/s, accuracy=70.0%, loss=1.07]    
Epoch 3: 100%|██████████| 1437/1437 [00:08<00:00, 176.66batch/s, accuracy=80.0%, loss=0.587]   
Epoch 4: 100%|██████████| 1437/1437 [00:06<00:00, 214.66batch/s, accuracy=60.0%, loss=1.62]    


In [16]:
# from torch.optim import SGD
# from spuco.robust_train import UpSampleERM, DownSampleERM, CustomSampleERM

# jtt_train = UpSampleERM(
#     model=model,
#     num_epochs=5,
#     trainset=trainset,
#     batch_size=64,
#     group_partition=group_partition,
#     optimizer=SGD(model.parameters(), lr=1e-2, momentum=0.9, nesterov=True),
#     device=device,
#     verbose=True
# )
# jtt_train.train()

Epoch 0: 100%|██████████| 1444/1444 [00:06<00:00, 211.43batch/s, accuracy=62.5%, loss=0.996]   
Epoch 1: 100%|██████████| 1444/1444 [00:06<00:00, 212.88batch/s, accuracy=79.16666666666667%, loss=0.704]
Epoch 2: 100%|██████████| 1444/1444 [00:06<00:00, 214.77batch/s, accuracy=33.333333333333336%, loss=1.34]
Epoch 3: 100%|██████████| 1444/1444 [00:06<00:00, 215.06batch/s, accuracy=62.5%, loss=1.03]    
Epoch 4: 100%|██████████| 1444/1444 [00:06<00:00, 215.99batch/s, accuracy=79.16666666666667%, loss=0.675]


In [30]:
from spuco.evaluate import Evaluator

evaluator = Evaluator(
    testset=testset,
    group_partition=testset.group_partition,
    group_weights=trainset.group_weights,
    batch_size=64,
    model=jtt_model,
    device=device,
    verbose=True
)
evaluator.evaluate()

Evaluating group-wise accuracy:   4%|▍         | 1/25 [00:00<00:10,  2.28it/s]

Group (0, 0) Accuracy: 99.76359338061465


Evaluating group-wise accuracy:   8%|▊         | 2/25 [00:00<00:10,  2.19it/s]

Group (0, 1) Accuracy: 15.130023640661939


Evaluating group-wise accuracy:  12%|█▏        | 3/25 [00:01<00:09,  2.22it/s]

Group (0, 2) Accuracy: 59.810874704491724


Evaluating group-wise accuracy:  16%|█▌        | 4/25 [00:01<00:09,  2.18it/s]

Group (0, 3) Accuracy: 16.548463356973997


Evaluating group-wise accuracy:  20%|██        | 5/25 [00:02<00:09,  2.18it/s]

Group (0, 4) Accuracy: 56.973995271867615


Evaluating group-wise accuracy:  24%|██▍       | 6/25 [00:02<00:08,  2.16it/s]

Group (1, 0) Accuracy: 37.89731051344743


Evaluating group-wise accuracy:  28%|██▊       | 7/25 [00:03<00:08,  2.15it/s]

Group (1, 1) Accuracy: 100.0


Evaluating group-wise accuracy:  32%|███▏      | 8/25 [00:03<00:07,  2.19it/s]

Group (1, 2) Accuracy: 46.3235294117647


Evaluating group-wise accuracy:  36%|███▌      | 9/25 [00:04<00:07,  2.16it/s]

Group (1, 3) Accuracy: 9.558823529411764


Evaluating group-wise accuracy:  40%|████      | 10/25 [00:04<00:06,  2.17it/s]

Group (1, 4) Accuracy: 1.9607843137254901


Evaluating group-wise accuracy:  44%|████▍     | 11/25 [00:05<00:06,  2.17it/s]

Group (2, 0) Accuracy: 28.533333333333335


Evaluating group-wise accuracy:  48%|████▊     | 12/25 [00:05<00:05,  2.18it/s]

Group (2, 1) Accuracy: 20.8


Evaluating group-wise accuracy:  52%|█████▏    | 13/25 [00:05<00:05,  2.17it/s]

Group (2, 2) Accuracy: 99.46666666666667


Evaluating group-wise accuracy:  56%|█████▌    | 14/25 [00:06<00:05,  2.14it/s]

Group (2, 3) Accuracy: 7.466666666666667


Evaluating group-wise accuracy:  60%|██████    | 15/25 [00:06<00:04,  2.14it/s]

Group (2, 4) Accuracy: 0.0


Evaluating group-wise accuracy:  64%|██████▍   | 16/25 [00:07<00:04,  2.11it/s]

Group (3, 0) Accuracy: 59.79899497487437


Evaluating group-wise accuracy:  68%|██████▊   | 17/25 [00:07<00:03,  2.12it/s]

Group (3, 1) Accuracy: 17.884130982367758


Evaluating group-wise accuracy:  72%|███████▏  | 18/25 [00:08<00:03,  2.12it/s]

Group (3, 2) Accuracy: 56.92695214105793


Evaluating group-wise accuracy:  76%|███████▌  | 19/25 [00:08<00:02,  2.11it/s]

Group (3, 3) Accuracy: 100.0


Evaluating group-wise accuracy:  80%|████████  | 20/25 [00:09<00:02,  2.13it/s]

Group (3, 4) Accuracy: 19.143576826196472


Evaluating group-wise accuracy:  84%|████████▍ | 21/25 [00:09<00:01,  2.13it/s]

Group (4, 0) Accuracy: 65.74307304785894


Evaluating group-wise accuracy:  88%|████████▊ | 22/25 [00:10<00:01,  2.13it/s]

Group (4, 1) Accuracy: 3.27455919395466


Evaluating group-wise accuracy:  92%|█████████▏| 23/25 [00:10<00:00,  2.13it/s]

Group (4, 2) Accuracy: 0.0


Evaluating group-wise accuracy:  96%|█████████▌| 24/25 [00:11<00:00,  2.14it/s]

Group (4, 3) Accuracy: 16.161616161616163


Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:11<00:00,  2.15it/s]

Group (4, 4) Accuracy: 100.0





{(0, 0): 99.76359338061465,
 (0, 1): 15.130023640661939,
 (0, 2): 59.810874704491724,
 (0, 3): 16.548463356973997,
 (0, 4): 56.973995271867615,
 (1, 0): 37.89731051344743,
 (1, 1): 100.0,
 (1, 2): 46.3235294117647,
 (1, 3): 9.558823529411764,
 (1, 4): 1.9607843137254901,
 (2, 0): 28.533333333333335,
 (2, 1): 20.8,
 (2, 2): 99.46666666666667,
 (2, 3): 7.466666666666667,
 (2, 4): 0.0,
 (3, 0): 59.79899497487437,
 (3, 1): 17.884130982367758,
 (3, 2): 56.92695214105793,
 (3, 3): 100.0,
 (3, 4): 19.143576826196472,
 (4, 0): 65.74307304785894,
 (4, 1): 3.27455919395466,
 (4, 2): 0.0,
 (4, 3): 16.161616161616163,
 (4, 4): 100.0}

In [31]:
from spuco.evaluate import Evaluator

evaluator = Evaluator(
    testset=testset,
    group_partition=testset.group_partition,
    group_weights=trainset.group_weights,
    batch_size=64,
    model=model,
    device=device,
    verbose=True
)
evaluator.evaluate()

Evaluating group-wise accuracy:   4%|▍         | 1/25 [00:00<00:10,  2.19it/s]

Group (0, 0) Accuracy: 99.52718676122932


Evaluating group-wise accuracy:   8%|▊         | 2/25 [00:00<00:10,  2.15it/s]

Group (0, 1) Accuracy: 0.0


Evaluating group-wise accuracy:  12%|█▏        | 3/25 [00:01<00:10,  2.15it/s]

Group (0, 2) Accuracy: 17.96690307328605


Evaluating group-wise accuracy:  16%|█▌        | 4/25 [00:01<00:09,  2.14it/s]

Group (0, 3) Accuracy: 22.22222222222222


Evaluating group-wise accuracy:  20%|██        | 5/25 [00:02<00:09,  2.11it/s]

Group (0, 4) Accuracy: 37.35224586288416


Evaluating group-wise accuracy:  24%|██▍       | 6/25 [00:02<00:08,  2.11it/s]

Group (1, 0) Accuracy: 76.28361858190709


Evaluating group-wise accuracy:  28%|██▊       | 7/25 [00:03<00:08,  2.12it/s]

Group (1, 1) Accuracy: 100.0


Evaluating group-wise accuracy:  32%|███▏      | 8/25 [00:03<00:07,  2.14it/s]

Group (1, 2) Accuracy: 50.490196078431374


Evaluating group-wise accuracy:  36%|███▌      | 9/25 [00:04<00:07,  2.16it/s]

Group (1, 3) Accuracy: 23.284313725490197


Evaluating group-wise accuracy:  40%|████      | 10/25 [00:04<00:06,  2.18it/s]

Group (1, 4) Accuracy: 1.7156862745098038


Evaluating group-wise accuracy:  44%|████▍     | 11/25 [00:05<00:06,  2.17it/s]

Group (2, 0) Accuracy: 0.8


Evaluating group-wise accuracy:  48%|████▊     | 12/25 [00:05<00:05,  2.19it/s]

Group (2, 1) Accuracy: 0.5333333333333333


Evaluating group-wise accuracy:  52%|█████▏    | 13/25 [00:06<00:05,  2.17it/s]

Group (2, 2) Accuracy: 100.0


Evaluating group-wise accuracy:  56%|█████▌    | 14/25 [00:06<00:05,  2.15it/s]

Group (2, 3) Accuracy: 1.6


Evaluating group-wise accuracy:  60%|██████    | 15/25 [00:06<00:04,  2.14it/s]

Group (2, 4) Accuracy: 0.0


Evaluating group-wise accuracy:  64%|██████▍   | 16/25 [00:07<00:04,  2.10it/s]

Group (3, 0) Accuracy: 32.663316582914575


Evaluating group-wise accuracy:  68%|██████▊   | 17/25 [00:07<00:03,  2.12it/s]

Group (3, 1) Accuracy: 0.0


Evaluating group-wise accuracy:  72%|███████▏  | 18/25 [00:08<00:03,  2.14it/s]

Group (3, 2) Accuracy: 28.96725440806045


Evaluating group-wise accuracy:  76%|███████▌  | 19/25 [00:08<00:02,  2.14it/s]

Group (3, 3) Accuracy: 99.49622166246851


Evaluating group-wise accuracy:  80%|████████  | 20/25 [00:09<00:02,  2.14it/s]

Group (3, 4) Accuracy: 21.15869017632242


Evaluating group-wise accuracy:  84%|████████▍ | 21/25 [00:09<00:01,  2.14it/s]

Group (4, 0) Accuracy: 33.249370277078086


Evaluating group-wise accuracy:  88%|████████▊ | 22/25 [00:10<00:01,  2.16it/s]

Group (4, 1) Accuracy: 0.0


Evaluating group-wise accuracy:  92%|█████████▏| 23/25 [00:10<00:00,  2.14it/s]

Group (4, 2) Accuracy: 0.0


Evaluating group-wise accuracy:  96%|█████████▌| 24/25 [00:11<00:00,  2.15it/s]

Group (4, 3) Accuracy: 13.383838383838384


Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:11<00:00,  2.14it/s]

Group (4, 4) Accuracy: 100.0





{(0, 0): 99.52718676122932,
 (0, 1): 0.0,
 (0, 2): 17.96690307328605,
 (0, 3): 22.22222222222222,
 (0, 4): 37.35224586288416,
 (1, 0): 76.28361858190709,
 (1, 1): 100.0,
 (1, 2): 50.490196078431374,
 (1, 3): 23.284313725490197,
 (1, 4): 1.7156862745098038,
 (2, 0): 0.8,
 (2, 1): 0.5333333333333333,
 (2, 2): 100.0,
 (2, 3): 1.6,
 (2, 4): 0.0,
 (3, 0): 32.663316582914575,
 (3, 1): 0.0,
 (3, 2): 28.96725440806045,
 (3, 3): 99.49622166246851,
 (3, 4): 21.15869017632242,
 (4, 0): 33.249370277078086,
 (4, 1): 0.0,
 (4, 2): 0.0,
 (4, 3): 13.383838383838384,
 (4, 4): 100.0}