In [53]:
from train import run_epoch
from torch.utils.data import DataLoader
import torch
import wandb
import os
from utils import set_seed, Logger, CSVBatchLogger, log_args, get_model, hinge_loss, split_data, check_args, get_subsampled_indices
import numpy as np

# Run selection. Every path below is derived from these values, so changing
# one of them switches all artefacts (both checkpoints and the saved data
# split) consistently.
p = '0.7'
dataset = 'CelebA'
p2wd = '2e-05' if dataset == 'CUB' else '0.0001'
seed = 0
model_name = 'best_wg_acc_model'
main_dir = f"/home/thien/research/pseudogroups/{dataset}/splitpgl_sweep_logs/" \
                  f"p{p}_wd0.0001_lr0.0001"
# The part-2 run directory is keyed by the same `p` and `seed` as the part-1
# artefacts; interpolate them instead of repeating the literals so the two
# checkpoints can never come from different runs.
best_model_path = f"{main_dir}/part2_oll-1rw_rgl_group_dro_p{p}_wd{p2wd}_lr0.0001_s{seed}/{model_name}.pth"
data_path = f"{main_dir}/part1_s{seed}/part1and2_data_p{p}"

best_model_path1 = f"{main_dir}/part1_s{seed}/best_model.pth"
device = 'cuda:0'

# load data splits
# NOTE: torch.load unpickles the file, so only point it at files you trust.
data = torch.load(data_path)
part1_data, part2_data = data['part1'], data['part2']
batch_size = 32

# shuffle=False keeps the iteration order fixed, so outputs collected from
# several passes over the same loader line up row by row.
part1_loader = DataLoader(part1_data, shuffle=False, batch_size=batch_size, pin_memory=True)
part2_loader = DataLoader(part2_data, shuffle=False, batch_size=batch_size, pin_memory=True)


In [54]:
# Load both checkpoints, move each to the target device and switch it to
# eval mode. `model` comes from best_model_path, `model1` from
# best_model_path1. Relies on .to() and .eval() returning the module itself.
model, model1 = (
    torch.load(checkpoint_path).to(device).eval()
    for checkpoint_path in (best_model_path, best_model_path1)
)

# Consumed by torch.set_grad_enabled(...) in the inference cells.
is_training = False

In [55]:
# Run both models over one data split and collect per-example outputs.
#
# Produces, for `model` (no suffix) and `model1` (suffix "1"):
#   acc_y_pred*  argmax over axis 1 of the model output
#   acc_y_true*  labels yielded by the loader
#   acc_g_true*  group ids yielded by the loader
#   indices*     dataset indices yielded by the loader
#   probs*       raw model outputs (no softmax is applied here)
from tqdm.notebook import tqdm

loader = part1_loader

runs = []
with torch.set_grad_enabled(is_training):  # no autograd bookkeeping during evaluation
    for net in (model, model1):
        # Accumulate per-batch arrays in lists and concatenate once at the
        # end; concatenating inside the loop re-copies everything each batch.
        output_chunks, label_chunks, group_chunks, index_chunks = [], [], [], []
        for batch in tqdm(loader):  # tqdm picks up len(loader) for the total
            x, y, g, data_idx = (t.to(device) for t in batch)
            outputs = net(x)

            output_chunks.append(outputs.detach().cpu().numpy())
            label_chunks.append(y.detach().cpu().numpy())
            group_chunks.append(g.detach().cpu().numpy())
            index_chunks.append(data_idx.detach().cpu().numpy())

        # np.concatenate raises on an empty list, so an empty loader fails
        # loudly here instead of leaving stale arrays from an earlier run.
        run_outputs = np.concatenate(output_chunks, axis=0)
        run_indices = np.concatenate(index_chunks)
        assert run_outputs.shape[0] == run_indices.shape[0]

        runs.append((
            np.argmax(run_outputs, axis=1),
            np.concatenate(label_chunks),
            np.concatenate(group_chunks),
            run_indices,
            run_outputs,
        ))

(acc_y_pred, acc_y_true, acc_g_true, indices, probs), \
    (acc_y_pred1, acc_y_true1, acc_g_true1, indices1, probs1) = runs

3561it [05:01, 11.81it/s]
3561it [05:02, 11.77it/s]


In [56]:
# Overall and per-group accuracy of both models on the split evaluated above.
# Variables with the "1" suffix belong to model1 (printed as "ERM model");
# the unsuffixed ones belong to model (printed as "Retrained model").
pred_acc1 = (acc_y_pred1 == acc_y_true1)
avg_acc1 = pred_acc1.sum() / len(pred_acc1)
pred_acc = (acc_y_pred == acc_y_true)
avg_acc = pred_acc.sum() / len(pred_acc)

for title, hits, overall, group_ids in (
    ("ERM model", pred_acc1, avg_acc1, acc_g_true1),
    ("\nRetrained model", pred_acc, avg_acc, acc_g_true),
):
    print(title)
    print(f"Average acc [n={len(hits)}]: {overall:.4f}")
    for g in range(4):  # group ids 0..3
        in_group = (group_ids == g)
        n_in_group = in_group.sum()
        # fraction of this group's examples that were classified correctly
        share_correct = hits[in_group].sum() / n_in_group
        print(f"Group {g} [n={n_in_group}]: group_acc = {share_correct:.4f}")

ERM model
Average acc [n=113939]: 0.9653
Group 0 [n=50311]: group_acc = 0.9671
Group 1 [n=46652]: group_acc = 0.9964
Group 2 [n=16012]: group_acc = 0.8990
Group 3 [n=964]: group_acc = 0.4627

Retrained model
Average acc [n=113939]: 0.9239
Group 0 [n=50311]: group_acc = 0.9172
Group 1 [n=46652]: group_acc = 0.9217
Group 2 [n=16012]: group_acc = 0.9518
Group 3 [n=964]: group_acc = 0.9180


In [57]:
# Count the examples on which the two models predict different labels,
# overall and within each group.
#
# The per-group comparison is only meaningful if both inference passes
# produced rows in the same order, so verify that explicitly.
assert np.array_equal(acc_g_true, acc_g_true1), \
    "the two inference runs are not aligned example-by-example"

disagree = (acc_y_pred != acc_y_pred1)  # True where the predictions differ
n_disagree = disagree.sum()  # vectorised; builtin sum() walks the array in Python
print(f"diff: {n_disagree} ({n_disagree*100/len(disagree)}%)")
for g in range(4):  # group ids 0..3
    in_group = (acc_g_true == g)
    g_count = in_group.sum()
    g_disagree = disagree[in_group].sum()
    # NOTE: the value after "%=" is a fraction in [0, 1], not a percentage.
    print(f"Group {g} [n={g_count}]: diff = {g_disagree} %={g_disagree/g_count}")

diff: 7537 (6.61494308357981%)
Group 0 [n=50311]: diff = 2609 %=0.0518574466816402
Group 1 [n=46652]: diff = 3489 %=0.07478779044842665
Group 2 [n=16012]: diff = 1000 %=0.062453160129902575
Group 3 [n=964]: diff = 439 %=0.4553941908713693


# Examine part 2

In [58]:
# Run both models over one data split and collect per-example outputs.
#
# Produces, for `model` (no suffix) and `model1` (suffix "1"):
#   acc_y_pred*  argmax over axis 1 of the model output
#   acc_y_true*  labels yielded by the loader
#   acc_g_true*  group ids yielded by the loader
#   indices*     dataset indices yielded by the loader
#   probs*       raw model outputs (no softmax is applied here)
from tqdm.notebook import tqdm

loader = part2_loader

runs = []
with torch.set_grad_enabled(is_training):  # no autograd bookkeeping during evaluation
    for net in (model, model1):
        # Accumulate per-batch arrays in lists and concatenate once at the
        # end; concatenating inside the loop re-copies everything each batch.
        output_chunks, label_chunks, group_chunks, index_chunks = [], [], [], []
        for batch in tqdm(loader):  # tqdm picks up len(loader) for the total
            x, y, g, data_idx = (t.to(device) for t in batch)
            outputs = net(x)

            output_chunks.append(outputs.detach().cpu().numpy())
            label_chunks.append(y.detach().cpu().numpy())
            group_chunks.append(g.detach().cpu().numpy())
            index_chunks.append(data_idx.detach().cpu().numpy())

        # np.concatenate raises on an empty list, so an empty loader fails
        # loudly here instead of leaving stale arrays from an earlier run.
        run_outputs = np.concatenate(output_chunks, axis=0)
        run_indices = np.concatenate(index_chunks)
        assert run_outputs.shape[0] == run_indices.shape[0]

        runs.append((
            np.argmax(run_outputs, axis=1),
            np.concatenate(label_chunks),
            np.concatenate(group_chunks),
            run_indices,
            run_outputs,
        ))

(acc_y_pred, acc_y_true, acc_g_true, indices, probs), \
    (acc_y_pred1, acc_y_true1, acc_g_true1, indices1, probs1) = runs

0it [00:00, ?it/s]

0it [00:00, ?it/s]

In [59]:
# Overall and per-group accuracy of both models on the split evaluated above.
# Variables with the "1" suffix belong to model1 (printed as "ERM model");
# the unsuffixed ones belong to model (printed as "Retrained model").
pred_acc1 = (acc_y_pred1 == acc_y_true1)
avg_acc1 = pred_acc1.sum() / len(pred_acc1)
pred_acc = (acc_y_pred == acc_y_true)
avg_acc = pred_acc.sum() / len(pred_acc)

for title, hits, overall, group_ids in (
    ("ERM model", pred_acc1, avg_acc1, acc_g_true1),
    ("\nRetrained model", pred_acc, avg_acc, acc_g_true),
):
    print(title)
    print(f"Average acc [n={len(hits)}]: {overall:.4f}")
    for g in range(4):  # group ids 0..3
        in_group = (group_ids == g)
        n_in_group = in_group.sum()
        # fraction of this group's examples that were classified correctly
        share_correct = hits[in_group].sum() / n_in_group
        print(f"Group {g} [n={n_in_group}]: group_acc = {share_correct:.4f}")

ERM model
Average acc [n=48831]: 0.9556
Group 0 [n=21318]: group_acc = 0.9579
Group 1 [n=20222]: group_acc = 0.9955
Group 2 [n=6868]: group_acc = 0.8682
Group 3 [n=423]: group_acc = 0.3546

Retrained model
Average acc [n=48831]: 0.9204
Group 0 [n=21318]: group_acc = 0.9134
Group 1 [n=20222]: group_acc = 0.9217
Group 2 [n=6868]: group_acc = 0.9365
Group 3 [n=423]: group_acc = 0.9527


In [61]:
# Count the examples on which the two models predict different labels,
# overall and within each group, alongside each group's share of the split.
#
# The per-group comparison is only meaningful if both inference passes
# produced rows in the same order, so verify that explicitly.
assert np.array_equal(acc_g_true, acc_g_true1), \
    "the two inference runs are not aligned example-by-example"

disagree = (acc_y_pred != acc_y_pred1)  # True where the predictions differ
n_total = len(disagree)
n_disagree = disagree.sum()  # vectorised; builtin sum() walks the array in Python
print(f"diff: {n_disagree} ({n_disagree*100/n_total}%)")
for g in range(4):  # group ids 0..3
    in_group = (acc_g_true == g)
    g_count = in_group.sum()
    g_disagree = disagree[in_group].sum()
    # NOTE: the value after "%=" is a fraction in [0, 1], not a percentage.
    print(f"Group {g} {g_count/n_total:.4f} [n={g_count}]: diff = {g_disagree} %={g_disagree/g_count}")

diff: 3274 (6.704757223894657%)
Group 0 0.4366 [n=21318]: diff = 1011 %=0.04742471151139882
Group 1 0.4141 [n=20222]: diff = 1493 %=0.07383048165364454
Group 2 0.1406 [n=6868]: diff = 517 %=0.07527664531158998
Group 3 0.0087 [n=423]: diff = 253 %=0.5981087470449172
