In [1]:
import random

import numpy as np
import torch
from torch.utils.data import DataLoader

from dataset import MNISTDataset
from model_trainer import deep_feature_reweighting, train
from models import SimpleModel
from spurious_features import Position, spurious_square

In [2]:
# Seed every RNG before anything stochastic runs (spurious-feature sampling,
# DataLoader shuffling, weight initialisation) so Restart & Run All
# reproduces the logged metrics.
# NOTE(review): random, numpy and torch are all seeded because it is not
# visible here which RNG MNISTDataset uses to apply `probabilities` — confirm.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Digits used by both datasets in this cell.
LABELS = [9, 7, 6]


def add_corner_square(img):
    """Stamp the spurious square (size=6, Position.LEFT_TOP) onto ``img``."""
    return spurious_square(img, pos=Position.LEFT_TOP, size=6)


# Training data with an injected spurious correlation: the same square is
# attached to 9s with probability 0.99 and to 7s with probability 0.01
# (presumably per image), and never to 6s, so it almost always signals "9".
train_dataset = MNISTDataset(
    train=True,
    labels=LABELS,
    spurious_features={
        9: add_corner_square,
        7: add_corner_square,
    },
    probabilities={9: 0.99, 7: 0.01},
)

# Evaluation split built without spurious_features: a model that keys on the
# square instead of digit shape should lose accuracy here.
validation_dataset = MNISTDataset(train=False, labels=LABELS)

In [3]:
# Both loaders use the same batch size. Only the training data is shuffled;
# validation is iterated in a fixed order.
BATCH_SIZE = 64
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE)
validation_loader = DataLoader(dataset=validation_dataset, shuffle=False, batch_size=BATCH_SIZE)

In [4]:
# Baseline run: train SimpleModel on the spuriously correlated training set,
# monitoring the validation split without spurious features.
base_epochs = 10
model = SimpleModel(num_classes=3)  # presumably one output per digit in [9, 7, 6]

# train() returns two paths reused by later cells (by their names, presumably
# the saved checkpoint and the TensorBoard run directory).
model_path, tensorboard_path = train(
    model=model,
    train_loader=train_loader,
    validation_loader=validation_loader,
    num_epochs=base_epochs,
)

Epochs: 100%|██████████| 10/10 [00:22<00:00,  2.23s/it, Train Loss=0.0000, Valid Loss=0.1290, Train Accuracy=99.95%, Valid Accuracy=97.30%, Worst Group Accuracy=91.97%]


In [5]:
# Reweighting data for DFR: the same three digits from the training split,
# built without `spurious_features` (so, presumably, no square is drawn).
# NOTE(review): these are the very images the base model already fit; DFR as
# published re-trains on held-out data — confirm reusing the train split is intended.
dfr_train_set = MNISTDataset(labels=[9, 7, 6], train=True)
dfr_loader = DataLoader(dataset=dfr_train_set, shuffle=True, batch_size=64)

In [6]:
# Deep Feature Reweighting: start from the baseline checkpoint saved by
# train(), run reweighting epochs on dfr_loader, and keep evaluating on the
# validation split.
# The returned checkpoint is bound to `dfr_model_path` instead of overwriting
# `model_path`. Otherwise re-running this cell would pass the already
# reweighted checkpoint back in as `path_to_model` and compound the effect.
# NOTE(review): this assumes deep_feature_reweighting (re)loads `model` from
# `path_to_model`; if it only continues from `model`'s in-memory state,
# re-running still compounds — confirm.
# `tensorboard_path` is reassigned exactly as before (presumably the same run
# directory that is passed in).
dfr_model_path, tensorboard_path = deep_feature_reweighting(
    path_to_model=model_path,
    path_to_tensorboard_run=tensorboard_path,
    model=model,
    num_epochs=2,
    validation_loader=validation_loader,
    train_loader=dfr_loader,
)

Reweighting Epochs: 100%|██████████| 2/2 [00:04<00:00,  2.37s/it, Train Loss=0.0000, Valid Loss=0.0246, Train Accuracy=99.35%, Valid Accuracy=99.17%, Worst Group Accuracy=98.32%]


In [7]:
# Show the TensorBoard run directory returned by the training calls above
# (it lives under `runs/`; open it with `tensorboard --logdir runs`).
tensorboard_path

'runs/spurious_trainer_15_11_2024_1602'