# Clustering for group info

## TODO
- First load the trained ERM model (should try highly regularized or not) 
- Then extract the features of the training examples (save into a file somewhere?) 
- Then perform k-means on the feature space of each (label × correctly/incorrectly classified) partition to obtain pseudo-group labels 
- Then perform some algorithm that can take into account the group information for lowering robust loss

In [14]:
from data.celebA_dataset import CelebADataset
from models import model_attributes
from data.dro_dataset import DRODataset
import torch
import numpy as np

# Use the GPU when one is visible; assign device = 'cpu' by hand to force a CPU run.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print('Using {} device'.format(device))

# ---- CelebA configuration ----
# dir that contains data
root_dir = '/home/thiennguyen/research/datasets/celebA/'
# we are classifying whether the input image is blond or not
target_name = 'Blond_Hair'
# the spurious attribute we want the model NOT to rely on... here it's the gender
confounder_names = ['Male']
# which model processes the images --> determines the input size used to rescale them
model_type = 'resnet10vw'
augment_data = False
# how much of the train data to use --> examples are randomly removed if less than 1
fraction = 1.0
splits = ['train', 'val', 'test']

full_dataset = CelebADataset(
    root_dir=root_dir,
    target_name=target_name,
    confounder_names=confounder_names,
    # this string is used to get the model's input size (for resizing) and input type (image or precomputed)
    model_type=model_type,
    # augmenting adds random resized crop and random flip
    augment_data=augment_data,
)

# Subset objects holding the train/val/test indices; also implements the train subsampling.
subsets = full_dataset.get_splits(
    splits,
    train_frac=fraction,
    subsample_to_minority=False,
)

# Every split is wrapped with the same group/class metadata taken from the full dataset.
dro_kwargs = dict(
    process_item_fn=None,
    n_groups=full_dataset.n_groups,
    n_classes=full_dataset.n_classes,
    group_str_fn=full_dataset.group_str,
)
dro_subsets = [DRODataset(subsets[split], **dro_kwargs) for split in splits]

train_data, val_data, test_data = dro_subsets
train_loader = train_data.get_loader(train=True, reweight_groups=False, batch_size=128)
val_loader = val_data.get_loader(train=False, reweight_groups=None, batch_size=128)
test_loader = test_data.get_loader(train=False, reweight_groups=None, batch_size=128)

# Pull one batch as a sanity check: each batch is (inputs, labels, groups, dataset indices).
example_batch = tuple(t.to(device) for t in next(iter(train_loader)))
x, y, g, idxs = example_batch

######################################################
###########        Load models and optimizers.  ######
######################################################
from variable_width_resnet import resnet10vw

n_classes = train_data.n_classes
n_groups = train_data.n_groups
margin_shape = n_classes  # the network emits one output per class

resnet_width = 16
lr = 0.001
weight_decay = 0.001

# Locations of the two saved ERM checkpoints.
# NOTE(review): "sr" presumably stands for the strongly regularized run (the wd0.1 directory) — confirm.
log_path = 'logs/cluster_w16s0/'
log_sr_path = 'logs/cluster_w16s0wd0.1/'
model_sr_path = log_sr_path + 'joint/model_100.pth'
model_path = log_path + 'joint/model_100.pth'

# map_location remaps the stored tensors onto the device selected above, so a checkpoint
# written on a GPU machine still loads when this notebook falls back to the CPU.
modeln = resnet10vw(resnet_width, num_classes=margin_shape)
modeln.load_state_dict(torch.load(model_path, map_location=device))
modeln.to(device)

model_sr = resnet10vw(resnet_width, num_classes=margin_shape)
model_sr.load_state_dict(torch.load(model_sr_path, map_location=device))
model_sr.to(device)


def freeze_all_but_last_layer(model):
    """Disable gradients for every parameter of ``model`` except ``fc.weight`` and ``fc.bias``.

    The model is modified in place and nothing is returned.  An
    ``AssertionError`` is raised if, after freezing, the number of parameters
    that still require gradients is not exactly two.
    """
    trainable_names = ('fc.weight', 'fc.bias')
    for param_name, tensor in model.named_parameters():
        if param_name in trainable_names:
            continue  # the last layer stays trainable
        tensor.requires_grad = False

    # sanity check: only the last layer's weight and bias may remain trainable
    still_trainable = [p for p in model.parameters() if p.requires_grad]
    assert len(still_trainable) == 2  # fc.weight, fc.bias


# Work with the checkpoint loaded from log_path from here on.
model = modeln

# Loss and optimizer; only the parameters that still require gradients are handed to SGD.
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(
    [p for p in model.parameters() if p.requires_grad],
    lr=lr,
    momentum=0.9,
    weight_decay=weight_decay,
)

Using cuda device


## Get features partition and cluster
- Load the model 
- Run it on the training data and partition it into 4 different sets for clustering

In [15]:
import itertools  # for cartesian product for loops
from tqdm.auto import tqdm
import os

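# while looping through the batches, get_output_sets (below) fills a dictionary of parallel lists
# (position i of every list refers to the same batch):
# {
#     activations : [],  # the features, i.e. the inputs to the last layer
#     predicted : [],    # the model's predictions (compare with labels to get correctly-classified)
#     labels : [],       # the corresponding labels 
#     idx : [],          # the dataset indices of the examples
#     # groups : [],     # not saved -- is this cheating? 
# }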


def get_save_features_hook_fn(list_for_features: [torch.Tensor]):
    """
        Build a forward hook that records the hooked layer's input.

        Args:
            list_for_features: list that receives one tensor per forward call.

        Returns:
            a hook function; each time it fires, the first input of the layer
            is detached, moved to the CPU and appended to list_for_features.
    """
    def _record_layer_input(module, layer_inputs, layer_output):
        # hook inputs arrive as a tuple; the first entry is the feature batch
        first_input = layer_inputs[0]
        list_for_features.append(first_input.detach().cpu())

    return _record_layer_input
    

def get_output_sets(loader, model, classifying_groups=False):
    """Run ``model`` over ``loader`` and collect features, predictions, labels and indices.

    The model is switched to eval mode and a forward hook is registered on
    ``model.fc`` for the duration of the pass, so the input of the last layer
    is captured for every batch.  The hook is always removed again, even when
    the pass is interrupted or raises, so no stale hook keeps appending
    features during later forward calls.

    Args:
        loader: iterable yielding ``(x, y, g, idx)`` batches of tensors.
        model: network with an ``fc`` attribute; it is called as ``model(x)``.
        classifying_groups: accepted for backward compatibility; it does not
            influence the result (the stored labels are always ``y``).

    Returns:
        dict with the keys ``'activations'``, ``'predicted'``, ``'labels'`` and
        ``'idx'``; each maps to a list holding one CPU tensor per batch.
    """
    output_sets = {'activations': [],  # the features
                   'predicted': [],  # the model's prediction
                   # 'groups' : [], # is this cheating?
                   'labels': [],  # the corresponding labels
                   'idx': []
                   }
    model.eval()
    # Whenever model.forward is called, the hook saves the features into the activations list.
    feature_handler = model.fc.register_forward_hook(get_save_features_hook_fn(output_sets['activations']))

    try:
        with torch.no_grad():
            for batch in tqdm(loader):
                x, y, _, idx = tuple(t.to(device) for t in batch)
                outputs = model(x)
                # index of the largest output per example
                _, predicted = torch.max(outputs, 1)

                output_sets['labels'].append(y.detach().cpu())
                output_sets['predicted'].append(predicted.detach().cpu())
                output_sets['idx'].append(idx.detach().cpu())
    finally:
        # unregister the forward hook no matter how the loop ended
        feature_handler.remove()
    return output_sets


# Run the selected model over the training loader once and cache the result under
# log_path, so that the clustering cells can reload it from disk.
output_dict_save_path = os.path.join(log_path, 'output_sets.pth')
output_sets = get_output_sets(train_loader, model)
torch.save(output_sets, output_dict_save_path)

HBox(children=(IntProgress(value=0, max=1272), HTML(value='')))




## Cluster the sets
- First we load the output sets again

In [3]:
import torch
import os
from sklearn.cluster import KMeans


# Per-partition lookup tables, keyed by 'class{label}_correct' / 'class{label}_wrong'.
group_features_dict = {}  # features of each label / correctly-classified partition
group_idx_dict = {}  # dataset indices of the same examples, as numpy arrays
output_sets = torch.load(output_dict_save_path)

print(len(output_sets['activations']), output_sets['activations'][0].shape)

# Stitch the per-batch tensors back together into one tensor per field.
x_feature = torch.cat(output_sets['activations'])
y_array = torch.cat(output_sets['labels'])
predicted = torch.cat(output_sets['predicted'])
idxs = torch.cat(output_sets['idx'])

# Partition into the combinations: label x [correctly, incorrectly classified].
is_correct = y_array == predicted
for label in range(n_classes):
    has_label = y_array == label
    for stat, select in (('correct', has_label & is_correct),
                         ('wrong', has_label & ~is_correct)):
        key = f'class{label}_{stat}'
        group_features_dict[key] = x_feature[select]
        group_idx_dict[key] = idxs[select].numpy()

1272 torch.Size([128, 128])


- Cluster the subgroups

In [5]:
from copy import copy  # copy the labels from the cluster
# Settings for the per-partition k-means runs below.
# NOTE(review): the trailing notes look like "n_clusters, max_iter: <score>" records from
# earlier experiments — confirm what the numbers measure before relying on them.
n_clusters = 2  # 4, 300: 781, 
max_iter = 300  # 5, 300: 535.927, 400: 


def get_idxs_to_subgroup_labels(cluster_model, n_clusters, group_features_dict, n_classes, group_idx_dict, max_iter=300):
    """Cluster every (label, correct/wrong) partition and map example indices to pseudo-group ids.

    A fresh ``KMeans`` is fit on each ``'class{label}_{stat}'`` entry of
    ``group_features_dict``.  The cluster ids of each partition are shifted by
    the number of pseudo-groups created so far, so every partition gets its own
    range of ids.

    Args:
        cluster_model: accepted for backward compatibility only; it is not used
            (a new ``KMeans`` is always created per partition).
        n_clusters: number of clusters requested per partition.  A partition
            with fewer points than this is clustered with one cluster per
            point, and an empty partition contributes no pseudo-group.
        group_features_dict: maps ``'class{label}_{stat}'`` to that partition's features.
        n_classes: number of class labels; labels ``0 .. n_classes - 1`` are visited.
        group_idx_dict: maps the same keys to the partition's example indices,
            in the same order as the features.
        max_iter: ``max_iter`` forwarded to ``KMeans``.

    Returns:
        ``(idxs_to_subgroup_labels, group_label_counter)``: the dict from
        example index to pseudo-group id, and the total number of pseudo-groups.
    """
    cluster_assignments = {}
    for label in range(n_classes):
        for stat in ['correct', 'wrong']:
            key = f'class{label}_{stat}'
            features = group_features_dict[key]
            n_points = len(features)
            if n_points == 0:
                # nothing to cluster (e.g. no misclassified example of this class)
                cluster_assignments[key] = np.empty(0, dtype=int)
                continue
            # KMeans refuses to fit more clusters than samples, so cap the request
            kmeans = KMeans(min(n_clusters, n_points), max_iter=max_iter)
            kmeans.fit(features)
            cluster_assignments[key] = copy(kmeans.labels_)

    # Now we map each idx for each point to its appropriate group label
    idxs_to_subgroup_labels = {}

    group_label_counter = 0  # offset so that each partition's clusters get unique numbers
    for key, assignments in cluster_assignments.items():
        shifted = assignments + group_label_counter
        idxs_to_subgroup_labels.update(zip(group_idx_dict[key], shifted))
        group_label_counter += len(np.unique(assignments))

    print("n_groups = ", group_label_counter)
    return idxs_to_subgroup_labels, group_label_counter

# The first argument of get_idxs_to_subgroup_labels is not used by the function (it builds its
# own KMeans per partition), and no `cluster_model` is defined in this notebook, so pass None
# instead of a name that raises NameError on a fresh run.
# NOTE(review): the features were extracted with the log_path model but the labels are saved
# under log_sr_path — confirm this is the intended location.
idxs_to_subgroup_labels, group_label_counter = get_idxs_to_subgroup_labels(
    None, n_clusters, group_features_dict, n_classes, group_idx_dict, max_iter)
torch.save(idxs_to_subgroup_labels, os.path.join(log_sr_path, 'idx_to_subgroup_labels.pth'))

n_groups =  8


In [6]:
from data.pseudogrouplabels_dataset import PseudoGroupLabelsDataset

# Wrap the training set together with the k-means pseudo-group assignment.
# (presumably PseudoGroupLabelsDataset makes the dataloader serve these pseudo-group ids as the
# group labels — confirm in data/pseudogrouplabels_dataset.py)
pgl_train_data = PseudoGroupLabelsDataset(
    train_data,
    idxs_to_subgroup_labels,
    group_label_counter,
)
pgl_train_loader = pgl_train_data.get_loader(
    train=True,
    reweight_groups=False,
    batch_size=128,
)

In [9]:
# Spot check: look up the group of training example 355 (the cluster analysis below treats
# the value returned by get_g as the real group id).
train_data.get_g(355)

0

## Cluster Analysis

In [16]:
# Cross-tabulate real groups against pseudo-groups.
# The number of real groups is read from the dataset instead of being hard-coded.
n_true_groups = train_data.n_groups
# dict of {real_group: {pseudo_group_label: count}}
group_to_pseudogroup = {i: {sg: 0 for sg in range(group_label_counter)} for i in range(n_true_groups)}
# dict of {pseudo_group_label: {real_group: count}}
pseudogroup_to_group = {sg: {i: 0 for i in range(n_true_groups)} for sg in range(group_label_counter)}

for idx, sgl in idxs_to_subgroup_labels.items():
    true_group = train_data.get_g(idx)  # look the real group up once per example
    group_to_pseudogroup[true_group][sgl] += 1
    pseudogroup_to_group[sgl][true_group] += 1
    

In [20]:
from pprint import pprint
# Show both count tables, each under its heading.
for heading, table in (("groups to pseudogroups", group_to_pseudogroup),
                       ('pseudogroups to groups', pseudogroup_to_group)):
    print(heading)
    pprint(table)
# NOTE(review): this legend is written out by hand for two clusters per partition;
# it goes stale if n_clusters changes.
print("pseudogroups: \n\t0,1:\tLabel0_correct\n\t2,3:\tLabel0_wrong\n\t4,5:\tLabel1_correct\n\t6,7:\tLabel1_wrong")

groups to pseudogroups
{0: {0: 61904, 1: 9504, 2: 69, 3: 152, 4: 0, 5: 0, 6: 0, 7: 0},
 1: {0: 62342, 1: 4520, 2: 2, 3: 10, 4: 0, 5: 0, 6: 0, 7: 0},
 2: {0: 0, 1: 0, 2: 0, 3: 0, 4: 4401, 5: 2549, 6: 9546, 7: 6384},
 3: {0: 0, 1: 0, 2: 0, 3: 0, 4: 28, 5: 6, 6: 369, 7: 984}}
pseudogroups to groups
{0: {0: 61904, 1: 62342, 2: 0, 3: 0},
 1: {0: 9504, 1: 4520, 2: 0, 3: 0},
 2: {0: 69, 1: 2, 2: 0, 3: 0},
 3: {0: 152, 1: 10, 2: 0, 3: 0},
 4: {0: 0, 1: 0, 2: 4401, 3: 28},
 5: {0: 0, 1: 0, 2: 2549, 3: 6},
 6: {0: 0, 1: 0, 2: 9546, 3: 369},
 7: {0: 0, 1: 0, 2: 6384, 3: 984}}
pseudogroups: 
	0,1:	Label0_correct
	2,3:	Label0_wrong
	4,5:	Label1_correct
	6,7:	Label1_wrong


## Finally retrain with gDRO

In [21]:
from losses import LossComputer
from train import run_epoch 
import csv

class BasicLogger:
    """Minimal logger that sends every message to standard output via ``print``."""

    def write(self, msg):
        """Print ``msg`` followed by a newline."""
        print(msg)
        
# Switch for the robust objective; it is handed to LossComputer and to the training run_epoch call.
robust = True
logger = BasicLogger()

# reduction='none' keeps one loss value per example instead of averaging over the batch.
criterion = torch.nn.CrossEntropyLoss(reduction='none').to(device)
loss_computer = LossComputer(
    criterion,
    is_robust=robust,
    dataset=pgl_train_data,
)
wd = 0.0001

######################################################
###########        Load models and optimizers.  ######
######################################################
# Start again from a freshly initialized network (no checkpoint is loaded here).
model = resnet10vw(resnet_width, num_classes=n_classes)
model.to(device)

# SGD over every parameter that requires gradients, using the weight decay `wd` set above.
optimizer = torch.optim.SGD(
    [p for p in model.parameters() if p.requires_grad],
    lr=lr,
    momentum=0.9,
    weight_decay=wd,
)

######################################################
###########          Log files                  ######
######################################################
mode = 'w'
log_dir = os.path.join(log_sr_path, 'cluster_gDRO')
# exist_ok avoids the separate exists() check and the race between check and creation
os.makedirs(log_dir, exist_ok=True)
# NOTE: despite their names, train_path / val_path are open file handles, not paths.
# Files given to the csv module are opened with newline='' so the writer controls the line
# endings itself (otherwise extra blank rows can appear on some platforms).
train_path = open(os.path.join(log_dir, 'train.csv'), mode, newline='')
val_path = open(os.path.join(log_dir, 'val.csv'), mode, newline='')

# Training is logged per pseudo-group, validation per real group.
n_groups = pgl_train_data.n_groups
val_n_groups = 4  # NOTE(review): hard-coded; presumably equals the validation set's number of groups — confirm
total_acc_per_group = [f'total_acc:g{i}' for i in range(n_groups)]
group_accs = [f'group{i}_acc' for i in range(n_groups)]
group_margins = [f'group{i}_margin' for i in range(n_groups)]
train_columns = ['epoch', 'total_acc',  'split_acc', 'loss',
                 'avg_margin'] + total_acc_per_group + group_accs + group_margins
# the same three names are reused for the validation columns (no 'loss' column there)
total_acc_per_group = [f'total_acc:g{i}' for i in range(val_n_groups)]
group_accs = [f'group{i}_acc' for i in range(val_n_groups)]
group_margins = [f'group{i}_margin' for i in range(val_n_groups)]
valtest_columns = ['epoch', 'total_acc',  'split_acc',
                       'avg_margin'] + total_acc_per_group + group_accs + group_margins

train_writer = csv.DictWriter(train_path, fieldnames=train_columns)
val_writer = csv.DictWriter(val_path, fieldnames=valtest_columns)
train_writer.writeheader()
val_writer.writeheader()

In [None]:
# Train on the pseudo-group loader and evaluate on the validation loader after every epoch.
# NOTE(review): the training call is given `epoch + 1` while the validation call is given
# `epoch`, so the two CSV files number the same pass differently — confirm this is intended.
try:
    for epoch in range(25):
        # train
        logger.write(f'Train epoch {epoch}')
        run_epoch(epoch + 1, model, device, optimizer, pgl_train_loader, loss_computer, train_writer, logger,
                  is_training=True, is_robust=robust, classifying_groups=False)
        # validate
        run_epoch(epoch, model, device, optimizer, val_loader, loss_computer, val_writer, logger,
                  is_training=False, is_robust=True, classifying_groups=False)
        # push this epoch's rows to disk so they survive a crash in a later epoch
        train_path.flush()
        val_path.flush()
finally:
    # also flush when the loop is interrupted part-way through an epoch
    train_path.flush()
    val_path.flush()

Train epoch 0


HBox(children=(IntProgress(value=0, max=1272), HTML(value='')))

[1,   200] loss: 0.524, avg_margin: -0.081. adv_probs: tensor([0.1689, 0.2455, 0.0489, 0.0550, 0.1050, 0.0911, 0.1256, 0.1599],
       device='cuda:0')
[1,   400] loss: 0.566, avg_margin: 0.077. adv_probs: tensor([0.1207, 0.3461, 0.0186, 0.0227, 0.0691, 0.0486, 0.1146, 0.2597],
       device='cuda:0')
[1,   600] loss: 0.608, avg_margin: 0.118. adv_probs: tensor([0.0858, 0.4021, 0.0060, 0.0081, 0.0469, 0.0251, 0.0980, 0.3280],
       device='cuda:0')
[1,   800] loss: 0.626, avg_margin: 0.118. adv_probs: tensor([0.0687, 0.4272, 0.0020, 0.0029, 0.0355, 0.0152, 0.0915, 0.3572],
       device='cuda:0')
[1,  1000] loss: 0.629, avg_margin: 0.112. adv_probs: tensor([0.0586, 0.4382, 0.0006, 0.0011, 0.0289, 0.0093, 0.0876, 0.3757],
       device='cuda:0')
[1,  1200] loss: 0.622, avg_margin: 0.107. adv_probs: tensor([5.1335e-02, 4.4440e-01, 2.0637e-04, 3.7226e-04, 2.4104e-02, 5.8176e-03,
        8.2343e-02, 3.9142e-01], device='cuda:0')

{'avg_margin': '0.1040',
 'epoch': 1,
 'group0_acc': '0.764

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))


{'avg_margin': '0.0726',
 'epoch': 0,
 'group0_acc': '0.7058',
 'group0_margin': '0.0549',
 'group1_acc': '0.6841',
 'group1_margin': '0.0442',
 'group2_acc': '0.7234',
 'group2_margin': '0.2152',
 'group3_acc': '0.5000',
 'group3_margin': '-0.0619',
 'split_acc': '0.6974',
 'total_acc': '0.6974',
 'total_acc:g0': '0.7058',
 'total_acc:g1': '0.6841',
 'total_acc:g2': '0.7234',
 'total_acc:g3': '0.5000'}
Train epoch 1


HBox(children=(IntProgress(value=0, max=1272), HTML(value='')))

[2,   200] loss: 0.634, avg_margin: 0.003. adv_probs: tensor([4.8634e-02, 4.4819e-01, 4.1912e-05, 8.4238e-05, 1.8769e-02, 3.3075e-03,
        8.2048e-02, 3.9893e-01], device='cuda:0')
[2,   400] loss: 0.624, avg_margin: 0.004. adv_probs: tensor([4.8999e-02, 4.4877e-01, 1.4149e-05, 2.9624e-05, 1.6416e-02, 2.4464e-03,
        8.2678e-02, 4.0065e-01], device='cuda:0')
[2,   600] loss: 0.626, avg_margin: 0.014. adv_probs: tensor([4.6417e-02, 4.5010e-01, 4.3421e-06, 9.9370e-06, 1.5692e-02, 1.7681e-03,
        8.3443e-02, 4.0256e-01], device='cuda:0')
[2,   800] loss: 0.610, avg_margin: 0.022. adv_probs: tensor([4.4866e-02, 4.5120e-01, 1.4053e-06, 3.4755e-06, 1.4234e-02, 1.2377e-03,
        8.6231e-02, 4.0223e-01], device='cuda:0')
[2,  1000] loss: 0.613, avg_margin: 0.025. adv_probs: tensor([4.3483e-02, 4.5325e-01, 4.3567e-07, 1.1681e-06, 1.3672e-02, 8.5380e-04,
        8.8704e-02, 4.0003e-01], device='cuda:0')
[2,  1200] loss: 0.615, avg_margin: 0.036. adv_probs: tensor([3.8636e-02, 4.5963

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))


{'avg_margin': '0.2489',
 'epoch': 1,
 'group0_acc': '0.8101',
 'group0_margin': '0.2693',
 'group1_acc': '0.8358',
 'group1_margin': '0.2890',
 'group2_acc': '0.6900',
 'group2_margin': '0.1114',
 'group3_acc': '0.3846',
 'group3_margin': '-0.3595',
 'split_acc': '0.7995',
 'total_acc': '0.7995',
 'total_acc:g0': '0.8101',
 'total_acc:g1': '0.8358',
 'total_acc:g2': '0.6900',
 'total_acc:g3': '0.3846'}
Train epoch 2


HBox(children=(IntProgress(value=0, max=1272), HTML(value='')))

[3,   200] loss: 0.605, avg_margin: 0.100. adv_probs: tensor([3.3321e-02, 4.6369e-01, 3.3071e-08, 1.0899e-07, 9.6964e-03, 3.1633e-04,
        8.6839e-02, 4.0614e-01], device='cuda:0')
[3,   400] loss: 0.603, avg_margin: 0.074. adv_probs: tensor([3.2709e-02, 4.6592e-01, 1.0824e-08, 3.9180e-08, 9.7871e-03, 2.3403e-04,
        9.0592e-02, 4.0076e-01], device='cuda:0')
[3,   600] loss: 0.591, avg_margin: 0.082. adv_probs: tensor([3.1163e-02, 4.6702e-01, 3.6785e-09, 1.4132e-08, 8.6278e-03, 1.5406e-04,
        9.4334e-02, 3.9870e-01], device='cuda:0')
[3,   800] loss: 0.590, avg_margin: 0.086. adv_probs: tensor([2.9399e-02, 4.6900e-01, 1.2055e-09, 5.0048e-09, 7.5135e-03, 1.0158e-04,
        9.6213e-02, 3.9777e-01], device='cuda:0')
[3,  1000] loss: 0.592, avg_margin: 0.087. adv_probs: tensor([2.7891e-02, 4.7029e-01, 4.1564e-10, 1.8717e-09, 6.3696e-03, 6.7708e-05,
        9.9070e-02, 3.9632e-01], device='cuda:0')
[3,  1200] loss: 0.610, avg_margin: 0.085. adv_probs: tensor([2.5946e-02, 4.7321

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))


{'avg_margin': '0.0090',
 'epoch': 2,
 'group0_acc': '0.6384',
 'group0_margin': '-0.0055',
 'group1_acc': '0.6385',
 'group1_margin': '0.0090',
 'group2_acc': '0.6708',
 'group2_margin': '0.0677',
 'group3_acc': '0.4670',
 'group3_margin': '-0.2401',
 'split_acc': '0.6416',
 'total_acc': '0.6416',
 'total_acc:g0': '0.6384',
 'total_acc:g1': '0.6385',
 'total_acc:g2': '0.6708',
 'total_acc:g3': '0.4670'}
Train epoch 3


HBox(children=(IntProgress(value=0, max=1272), HTML(value='')))

[4,   200] loss: 0.588, avg_margin: 0.074. adv_probs: tensor([2.4224e-02, 4.7343e-01, 2.9336e-11, 1.6063e-10, 4.5277e-03, 2.2510e-05,
        1.0444e-01, 3.9336e-01], device='cuda:0')
[4,   400] loss: 0.576, avg_margin: 0.085. adv_probs: tensor([2.3854e-02, 4.7317e-01, 9.8072e-12, 5.8204e-11, 4.2773e-03, 1.6649e-05,
        1.0607e-01, 3.9261e-01], device='cuda:0')
[4,   600] loss: 0.586, avg_margin: 0.096. adv_probs: tensor([2.2432e-02, 4.7764e-01, 3.4801e-12, 2.2347e-11, 3.3526e-03, 1.0624e-05,
        1.0453e-01, 3.9204e-01], device='cuda:0')
[4,   800] loss: 0.592, avg_margin: 0.100. adv_probs: tensor([2.1318e-02, 4.7454e-01, 1.2651e-12, 8.0543e-12, 2.6148e-03, 6.1267e-06,
        1.0087e-01, 4.0065e-01], device='cuda:0')
[4,  1000] loss: 0.589, avg_margin: 0.094. adv_probs: tensor([2.1387e-02, 4.7669e-01, 4.3282e-13, 3.3177e-12, 2.1611e-03, 4.0020e-06,
        9.8867e-02, 4.0089e-01], device='cuda:0')
[4,  1200] loss: 0.589, avg_margin: 0.091. adv_probs: tensor([2.1174e-02, 4.7736

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))


{'avg_margin': '0.1074',
 'epoch': 3,
 'group0_acc': '0.6676',
 'group0_margin': '0.0720',
 'group1_acc': '0.6938',
 'group1_margin': '0.1401',
 'group2_acc': '0.6910',
 'group2_margin': '0.1451',
 'group3_acc': '0.4560',
 'group3_margin': '-0.3134',
 'split_acc': '0.6800',
 'total_acc': '0.6800',
 'total_acc:g0': '0.6676',
 'total_acc:g1': '0.6938',
 'total_acc:g2': '0.6910',
 'total_acc:g3': '0.4560'}
Train epoch 4


HBox(children=(IntProgress(value=0, max=1272), HTML(value='')))

[5,   200] loss: 0.580, avg_margin: 0.132. adv_probs: tensor([1.9869e-02, 4.8016e-01, 3.7669e-14, 3.1826e-13, 1.3725e-03, 1.3710e-06,
        9.6347e-02, 4.0225e-01], device='cuda:0')
[5,   400] loss: 0.570, avg_margin: 0.110. adv_probs: tensor([2.0422e-02, 4.8246e-01, 1.3646e-14, 1.1774e-13, 1.2561e-03, 9.3779e-07,
        1.0460e-01, 3.9126e-01], device='cuda:0')
[5,   600] loss: 0.572, avg_margin: 0.129. adv_probs: tensor([1.8421e-02, 4.7883e-01, 4.8816e-15, 4.8164e-14, 9.8565e-04, 6.1071e-07,
        9.8808e-02, 4.0295e-01], device='cuda:0')
[5,   800] loss: 0.567, avg_margin: 0.132. adv_probs: tensor([1.7601e-02, 4.8241e-01, 1.6519e-15, 1.8253e-14, 8.1975e-04, 3.8425e-07,
        9.9456e-02, 3.9972e-01], device='cuda:0')
[5,  1000] loss: 0.550, avg_margin: 0.132. adv_probs: tensor([1.8061e-02, 4.8146e-01, 6.2096e-16, 7.2589e-15, 6.2605e-04, 2.1990e-07,
        9.6834e-02, 4.0302e-01], device='cuda:0')
[5,  1200] loss: 0.561, avg_margin: 0.137. adv_probs: tensor([1.7214e-02, 4.8336

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))


{'avg_margin': '0.0849',
 'epoch': 4,
 'group0_acc': '0.6586',
 'group0_margin': '0.0737',
 'group1_acc': '0.6624',
 'group1_margin': '0.1352',
 'group2_acc': '0.6058',
 'group2_margin': '0.0017',
 'group3_acc': '0.4341',
 'group3_margin': '-0.3692',
 'split_acc': '0.6505',
 'total_acc': '0.6505',
 'total_acc:g0': '0.6586',
 'total_acc:g1': '0.6624',
 'total_acc:g2': '0.6058',
 'total_acc:g3': '0.4341'}
Train epoch 5


HBox(children=(IntProgress(value=0, max=1272), HTML(value='')))

[6,   200] loss: 0.537, avg_margin: 0.221. adv_probs: tensor([1.5881e-02, 4.8686e-01, 5.9590e-17, 7.4782e-16, 3.9768e-04, 6.8503e-08,
        9.6398e-02, 4.0046e-01], device='cuda:0')
[6,   400] loss: 0.555, avg_margin: 0.195. adv_probs: tensor([1.5396e-02, 4.8234e-01, 2.2846e-17, 2.7650e-16, 3.2736e-04, 4.2513e-08,
        9.4787e-02, 4.0715e-01], device='cuda:0')
[6,   600] loss: 0.554, avg_margin: 0.174. adv_probs: tensor([1.5858e-02, 4.8123e-01, 8.1436e-18, 1.1158e-16, 2.8064e-04, 2.7520e-08,
        9.6101e-02, 4.0653e-01], device='cuda:0')
[6,   800] loss: 0.555, avg_margin: 0.174. adv_probs: tensor([1.5184e-02, 4.8497e-01, 2.9766e-18, 4.3809e-17, 2.1098e-04, 1.5659e-08,
        9.0830e-02, 4.0881e-01], device='cuda:0')
[6,  1000] loss: 0.551, avg_margin: 0.169. adv_probs: tensor([1.5711e-02, 4.8570e-01, 1.0572e-18, 1.8258e-17, 1.7349e-04, 1.0672e-08,
        9.2163e-02, 4.0625e-01], device='cuda:0')
[6,  1200] loss: 0.545, avg_margin: 0.168. adv_probs: tensor([1.5677e-02, 4.8314

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))


{'avg_margin': '0.1913',
 'epoch': 5,
 'group0_acc': '0.7219',
 'group0_margin': '0.1187',
 'group1_acc': '0.7823',
 'group1_margin': '0.2537',
 'group2_acc': '0.7171',
 'group2_margin': '0.2536',
 'group3_acc': '0.5055',
 'group3_margin': '-0.2241',
 'split_acc': '0.7443',
 'total_acc': '0.7443',
 'total_acc:g0': '0.7219',
 'total_acc:g1': '0.7823',
 'total_acc:g2': '0.7171',
 'total_acc:g3': '0.5055'}
Train epoch 6


HBox(children=(IntProgress(value=0, max=1272), HTML(value='')))

[7,   200] loss: 0.529, avg_margin: 0.194. adv_probs: tensor([1.5469e-02, 4.8466e-01, 1.0961e-19, 2.3789e-18, 1.1645e-04, 4.3554e-09,
        1.0269e-01, 3.9707e-01], device='cuda:0')
[7,   400] loss: 0.541, avg_margin: 0.204. adv_probs: tensor([1.4714e-02, 4.8835e-01, 4.1437e-20, 1.0857e-18, 9.3078e-05, 2.7317e-09,
        1.0101e-01, 3.9584e-01], device='cuda:0')
[7,   600] loss: 0.534, avg_margin: 0.223. adv_probs: tensor([1.3354e-02, 4.8595e-01, 1.5821e-20, 4.3013e-19, 7.6210e-05, 1.5429e-09,
        1.0432e-01, 3.9630e-01], device='cuda:0')


# PLAYGROUND

In [None]:
# # Checking whether the number between saving and loading lists above is the same as brute force

# import itertools

# model = model_sr
# model.eval()

# test_dict = {f'class{label}_{stat}':[] for label,stat in itertools.product(range(n_classes), ['correct', 'wrong'])}

# with torch.set_grad_enabled(False):
#     for batch_idx, batch in enumerate(tqdm(loader)):
#         batch = tuple(t.to(device) for t in batch)
#         x, y, g, idx = batch
#         outputs = model(x)  
#         to_predict = y
#         _, predicted = torch.max(outputs.data, 1)
        
#         for label in range(n_classes):
#             test_dict[f'class{label}_correct'] += [x[(to_predict == predicted) & (y == label)].cpu().numpy()]
#             test_dict[f'class{label}_wrong'] += [x[(to_predict != predicted) & (y == label)].cpu().numpy()]


# for key in test_dict:
#     print(key, np.concatenate(test_dict[key], axis=0).shape)

In [None]:
# model.eval()

# output_sets = {'activations' : [],  # the features
#                'predicted' : [], # the model's prediction
#                'labels' : [],       # the corresponding labels 
#                'idx' : []
#              }
# # the line below makes it so that whenever model.forward is called, we save the features to our activations list.
# # the feature handler ensures we remove the hook after we are done extracting the features
# feature_handler = model.fc.register_forward_hook(get_save_features_hook_fn(output_sets['activations']))
# stop = 0
# with torch.set_grad_enabled(False):
#     for batch_idx, batch in enumerate(tqdm(train_loader)):
#         batch = tuple(t.to(device) for t in batch)
#         x, y, g, idx = batch

#         outputs = model(x)  
#         to_predict =  y
#         _, predicted = torch.max(outputs.data, 1)
#         output_sets['labels'].append(y.cpu())
#         output_sets['predicted'].append(predicted.cpu())
#         output_sets['idx'].append(idx.cpu())
#         print('x', x.norm())
#         print('pred', predicted)
#         print('y',y)
#         print('idx', idx)
#         print('----')
#         break
        
    

# # print()
# # for k in output_sets:
# #     if k == 'activations':
# #         print(k, [output_sets[k][j].norm() for j in range(len(output_sets[k]))])
# #     else:
# #         print(k, output_sets[k])
# #         print(torch.cat(output_sets[k]))
        
