In [None]:
import os

import hydra
import torch
import torch.nn as nn
import torchmetrics
from dotenv import load_dotenv
from hydra.core.global_hydra import GlobalHydra
from hydra.utils import instantiate

load_dotenv()

# Work from the parent of the directory the kernel started in (this cell
# assumes that is the project root, so the `src` package below resolves).
# Remember the starting directory in the notebook namespace so re-running this
# cell does not keep climbing the directory tree.
NOTEBOOK_DIR = globals().get("NOTEBOOK_DIR", os.getcwd())
PROJECT_ROOT = os.path.dirname(NOTEBOOK_DIR)
os.chdir(PROJECT_ROOT)

from src.utils.custom_metrics import MinorityMajorityAccuracy

In [None]:
def train_and_pred_linear_classifier(x_train, y_train, x_test, n_epochs=1000, lr=0.01, seed=None):
    """Fit a linear (logistic-regression) probe on one context and predict one query label.

    A fresh ``nn.Linear(d, 1)`` is trained with full-batch SGD on
    ``BCEWithLogitsLoss`` over ``(x_train, y_train)`` and then applied to ``x_test``.

    Args:
        x_train: Floating-point tensor of shape ``(n, d)``, one feature row per
            context example.
        y_train: Tensor of shape ``(n,)`` holding 0/1 labels; cast to the
            feature dtype for the loss.
        x_test: Tensor of shape ``(d,)``, a single query example.
        n_epochs: Number of full-batch gradient steps.
        lr: SGD learning rate.
        seed: Optional seed for the weight initialisation. When given, the
            prediction is reproducible and the global CPU RNG state is left
            untouched. ``None`` keeps the unseeded behaviour.

    Returns:
        The predicted label as a Python float, ``0.0`` or ``1.0``.
    """
    device, dtype = x_train.device, x_train.dtype

    # Fork the CPU RNG only when seeding, so a seeded init does not clobber the
    # notebook's global random state.
    with torch.random.fork_rng(devices=[], enabled=seed is not None):
        if seed is not None:
            torch.default_generator.manual_seed(seed)
        model = nn.Linear(x_train.shape[1], 1)
    # Match the inputs so CUDA or float64 features work too.
    model = model.to(device=device, dtype=dtype)

    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    # Loop-invariant: reshape targets to (n, 1) to line up with the logits.
    targets = y_train.unsqueeze(1).to(device=device, dtype=dtype)

    # Full-batch gradient descent: each epoch is one step on all of x_train.
    model.train()
    for _ in range(n_epochs):
        optimizer.zero_grad()
        loss = criterion(model(x_train), targets)
        loss.backward()
        optimizer.step()

    model.eval()
    with torch.no_grad():
        logit = model(x_test.unsqueeze(0).to(device=device, dtype=dtype))
        # sigmoid + round thresholds the probability at 0.5
        # (torch.round sends exactly 0.5 to 0).
        return torch.sigmoid(logit).round().item()

In [None]:
# `hydra.initialize` raises if Hydra is already initialized in this kernel, so
# clear any previous instance first; this keeps the cell safe to re-run.
GlobalHydra.instance().clear()
hydra.initialize(config_path="../configs", version_base="1.2")

# Compose the `train` config (returned as an OmegaConf DictConfig) with the
# `spurious_setting=separate_token` override applied.
config = hydra.compose(config_name="train", overrides=["spurious_setting=separate_token"])

In [None]:
# Build the datamodule object described by the `datamodule` node of the composed config.
datamodule_cfg = config.datamodule
datamodule = instantiate(datamodule_cfg)

In [None]:
# Prepare the datamodule before `val_dataloader()` is requested.
# NOTE(review): called with no arguments, so this assumes the datamodule's
# `setup()` has a default for any `stage` parameter and builds the validation
# data — confirm in the datamodule implementation.
datamodule.setup()

In [None]:
def _token_slice(offset):
    """Return a selector for tokens at positions offset, offset + 3, ... along dim 1."""
    return lambda input_seq: input_seq[:, offset::3]


_x_tokens = _token_slice(0)
_c_tokens = _token_slice(1)

# Feature extractors for the linear probe, keyed by which token view(s) to use.
# Tokens are taken every third position along dim 1: offset 0 is the "x" view,
# offset 1 the "c" view. "xc" concatenates both views along dim 2, "x+c" sums them.
train_on = {
    "x": _x_tokens,
    "c": _c_tokens,
    "xc": lambda input_seq: torch.cat([_x_tokens(input_seq), _c_tokens(input_seq)], dim=2),
    "x+c": lambda input_seq: _x_tokens(input_seq) + _c_tokens(input_seq),
}

In [None]:
# For each feature view, fit a fresh linear probe per validation sequence on its
# context examples and score the prediction for the sequence's last (query) example.
for mode, extract_features in train_on.items():
    accuracy = torchmetrics.Accuracy(task="binary")
    accuracy_minority = MinorityMajorityAccuracy(group_type="minority")
    accuracy_majority = MinorityMajorityAccuracy(group_type="majority")

    for batch, batch_idx, dataloader_idx in datamodule.val_dataloader():
        # Only the inner validation set (loader 0) is scored; stop at the first
        # batch coming from any other loader.
        if dataloader_idx != 0:
            break

        # batch = (input_seq, spurious_labels, class_labels, image_indices)
        input_seq, spurious_labels, class_labels, _ = batch
        features = extract_features(input_seq)

        # All positions but the last are context; the last one is the query.
        context_x = features[:, :-1]
        query_x = features[:, -1]
        context_y = class_labels[:, :-1]

        for x_ctx, y_ctx, x_query, y_seq, c_seq in zip(
            context_x, context_y, query_x, class_labels, spurious_labels
        ):
            pred = train_and_pred_linear_classifier(x_ctx, y_ctx, x_query)
            accuracy.update(torch.tensor([pred]), torch.tensor([y_seq[-1]]))
            accuracy_minority.update(torch.tensor([[pred]]), y_seq.unsqueeze(0), c_seq.unsqueeze(0))
            accuracy_majority.update(torch.tensor([[pred]]), y_seq.unsqueeze(0), c_seq.unsqueeze(0))

    print(f"'{mode}' acc={accuracy.compute()} min_acc={accuracy_minority.compute()} maj_acc={accuracy_majority.compute()}")
