# Clustering for group info

## TODO
- First load the trained ERM model (should try highly regularized or not) 
- Then extract the features of the training examples (save into a file somewhere?) 
- Then perform k-means on the feature space of each label x correctlyclassified to obtain pseudo-group-labels 
- Then perform some algorithm that can take into account the group information for lowering robust loss

In [1]:
from data.celebA_dataset import CelebADataset
from models import model_attributes
from data.dro_dataset import DRODataset
import torch
import numpy as np

# device='cpu'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} device'.format(device))

root_dir = '/home/thiennguyen/research/datasets/celebA/'  # dir that contains data
target_name= 'Blond_Hair'  # we are classifying whether the input image is blond or not
confounder_names= ['Male']  # we aim to avoid learning spurious features... here it's the gender
model_type= 'resnet10vw'  # what model we are using to process --> this is to determine the input size to rescale the image
augment_data= False
fraction=1.0
splits = ['train', 'val', 'test']
full_dataset = CelebADataset(root_dir=root_dir,
        target_name=target_name,
        confounder_names=confounder_names,
        model_type=model_type,  # this string is to get the model's input size (for resizing) and input type (image or precomputed)
        augment_data=augment_data)  # augment data adds random resized crop and random flip.

subsets = full_dataset.get_splits(       # basically return the Subsets object with the appropriate indices for train/val/test
        splits,                          # also implements subsampling --> just remove random indices of the appropriate groups in train
        train_frac=fraction,   # fraction means how much of the train data to use --> randomly remove if less than 1
        subsample_to_minority=False)

dro_subsets = [  
    DRODataset(
        subsets[split],  # process each subset separately --> applying the transform parameter.
        process_item_fn=None,
        n_groups=full_dataset.n_groups,
        n_classes=full_dataset.n_classes,
        group_str_fn=full_dataset.group_str) \
    for split in splits]

train_data, val_data, test_data = dro_subsets
train_loader = train_data.get_loader(train=True, reweight_groups=False, batch_size=128)
val_loader = val_data.get_loader(train=False, reweight_groups=None, batch_size=5)
test_loader = test_data.get_loader(train=False, reweight_groups=None, batch_size=128)
example_batch = next(iter(train_loader))
example_batch = tuple(t.to(device) for t in example_batch)
x, y, g, idxs = example_batch

######################################################
###########        Load models and optimizers.  ######
######################################################
from variable_width_resnet import resnet10vw

n_classes = train_data.n_classes
n_groups = train_data.n_groups
margin_shape = n_classes

resnet_width = 16
lr = 0.001
weight_decay = 0.001
# load saved model
log_path = 'logs/cluster_w16s0/'
log_sr_path = 'logs/cluster_w16s0wd0.1/'
model_sr_path = log_sr_path + 'joint/model_100.pth'
model_path = log_path + 'joint/model_100.pth'

modeln = resnet10vw(resnet_width, num_classes=margin_shape)
modeln.load_state_dict(torch.load(model_path))
modeln.to(device)

model_sr = resnet10vw(resnet_width, num_classes=margin_shape)
model_sr.load_state_dict(torch.load(model_sr_path))
model_sr.to(device)


def freeze_all_but_last_layer(model):
    # freeze everything except the last layer
    for name, param in model.named_parameters():
        if name not in ['fc.weight', 'fc.bias']:
            param.requires_grad = False
    # make sure that really worked
    parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
    assert len(parameters) == 2  # fc.weight, fc.bias


# initialize the criterion and optimizer
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=lr,
            momentum=0.9,
            weight_decay=weight_decay)

Using cuda device


NameError: name 'model' is not defined

## Train ERM Model 

## Evaluate ERM Model and Obtain Error Set
- Load the model 
- Run it on the training data and partition it into 4 different sets for clustering

In [4]:
import itertools  # for cartesian product for loops
from tqdm.auto import tqdm
import os

# while looping through the batch, save a dictionary (where the position of the list match up):
# {
#     activations : [],  # the features
#     labels : [],       # the corresponding labels 
#     correctly_classified : [], # whether the given model correctly classified
#     groups : [], # is this cheating? 
# }


def get_save_features_hook_fn(list_for_features: [torch.Tensor]):
    """
        Returns:
            function for forward hook that appends the input (i.e. features) of the last layer to list_for_features
    """
    def hook_fn(self, inp, outp):
        list_for_features.append(inp[0].detach().cpu())  # this saves the feature to the given feature list
    return hook_fn
    
def get_output_sets(loader, model, classifying_groups = False):
    output_sets = {'activations' : [],  # the features
                   'predicted' : [], # the model's prediction
#                    'groups' : [], # is this cheating? 
                   'labels' : [],       # the corresponding labels 
                   'idx' : []
                 }
    model.eval()
    # the line below makes it so that whenever model.forward is called, we save the features to our activations list.
    # the feature handler ensures we remove the hook after we are done extracting the features
    feature_handler = model.fc.register_forward_hook(get_save_features_hook_fn(output_sets['activations']))

    with torch.set_grad_enabled(False):
        for batch_idx, batch in enumerate(tqdm(loader)):
            batch = tuple(t.to(device) for t in batch)
            x, y, g, idx = batch
            outputs = model(x)  
            to_predict = g if classifying_groups else y
            _, predicted = torch.max(outputs.data, 1)
            
            output_sets['labels'].append(y.detach().cpu())
            output_sets['predicted'].append(predicted.detach().cpu())
            output_sets['idx'].append(idx.detach().cpu())
    
    feature_handler.remove()  # we unregister the forward hook. 
    return output_sets

output_sets = get_output_sets(train_loader, model_sr)

# save output_sets
output_dict_save_path = os.path.join(log_sr_path,  'output_sets.pth')
torch.save(output_sets, output_dict_save_path)

HBox(children=(IntProgress(value=0, max=1272), HTML(value='')))

correct = tensor(145254, device='cuda:0')
correct0 =  tensor(138270, device='cuda:0')


## Cluster the sets
- First we load the output sets again

In [5]:
import torch
import os
from sklearn.cluster import KMeans


group_features_dict = {}  # this contains the features of the different labels/correctly-classified groups
group_idx_dict = {}
output_sets = torch.load(output_dict_save_path)

print(len(output_sets['activations']), output_sets['activations'][0].shape)
x_feature = torch.cat(output_sets['activations'])
y_array = torch.cat(output_sets['labels'])
predicted = torch.cat(output_sets['predicted'])
idxs = torch.cat(output_sets['idx'])

for label in range(n_classes):
    correct_select = (y_array == label) & (y_array == predicted)
    group_features_dict[f'class{label}_correct'] = x_feature[correct_select]
    group_idx_dict[f'class{label}_correct'] = idxs[correct_select]
    
    wrong_select = (y_array == label) & (y_array != predicted)
    group_features_dict[f'class{label}_wrong'] = x_feature[wrong_select]
    group_idx_dict[f'class{label}_wrong'] = idxs[wrong_select]

1272 torch.Size([128, 128])


### Now we can cluster each subgroup

In [18]:
from copy import copy
n_clusters = 4  # 4, 300: 781, 
max_iter = 300  # 5, 300: 535.927, 400: 

cluster_assignments = {}

for label in range(n_classes):
    for stat in ['correct', 'wrong']:
        cluster_model = KMeans(n_clusters, max_iter=max_iter)
        cluster_model.fit(group_features_dict[f'class{label}_{stat}'])
        cluster_assignments[f'class{label}_{stat}'] = copy(cluster_model.labels_)

    
print(cluster_model.inertia_)

781.2484003058303


## Finally retrain with gDRO

In [19]:
a = torch.tensor([1,2,3])
b = torch.tensor([4,5,6])
torch.stack([a,b])

tensor([[1, 2, 3],
        [4, 5, 6]])

# PLAYGROUND

In [14]:
x_feature[(y_array == 0) & (y_array == predicted)].shape

torch.Size([138270, 128])

In [7]:
# print(y_array[0:30])
print(output_sets['labels'][0])
print(output_sets['idx'][0])


# tensor([0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0,
#         0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
#         0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
#         0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0,
#         0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
#         1, 0, 0, 0, 0, 0, 0, 0])
# tensor([116512,  86576, 137033, 137783,  41092, 124173,  55239,  72699,  91503,
#          39249, 140607,  33405,  35304,  98070,  89792,  30422,  42185, 127753,
#         133399,  49669,  10597,   1907,  23412, 122663, 132795,  11350,  49146,
#         136059,  92930,  17800, 143768,  43261, 100543,  16203,  18786,  53670,
#          73070,  79761, 117203, 119135,  90050,  85270, 117438, 161047, 146605,
#         158457,  53547, 154838,  26025,  21408,  24719,  12405,  11870,  68837,
#           3229, 136679,  47042,  74359,  86986, 102683, 105254,  91918, 124831,
#         143581,  32862,  67845,  90193,  23095,  80341,   4068,  44223, 113745,
#          17024,  99588,  53284,  18192,  82366,  78202,  74944,  48619, 105401,
#          45699,  86736, 162564, 127440, 129152,  90281, 121655, 137138, 105555,
#         144738,   1960, 105210, 102164,  29702, 160608, 116402,  48100, 112067,
#           1149, 132157,  81982, 130300,  18743,  73463,  61076,  15688, 124098,
#          36976,  70666,  53980,  46209,  15880, 155393, 152950,  68808,  98545,
#          83046, 162318, 121945,  63439, 151894, 139379, 149241, 133724, 107680,
#         122864, 126772])


tensor([0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0,
        0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
        1, 0, 0, 0, 0, 0, 0, 0])
tensor([116512,  86576, 137033, 137783,  41092, 124173,  55239,  72699,  91503,
         39249, 140607,  33405,  35304,  98070,  89792,  30422,  42185, 127753,
        133399,  49669,  10597,   1907,  23412, 122663, 132795,  11350,  49146,
        136059,  92930,  17800, 143768,  43261, 100543,  16203,  18786,  53670,
         73070,  79761, 117203, 119135,  90050,  85270, 117438, 161047, 146605,
        158457,  53547, 154838,  26025,  21408,  24719,  12405,  11870,  68837,
          3229, 136679,  47042,  74359,  86986, 102683, 105254,  91918, 124831,
       

In [5]:
# Checking whether the number between saving and loading lists above is the same as brute force

import itertools

model.eval()

test_dict = {f'class{label}_{stat}':[] for label,stat in itertools.product(range(n_classes), ['correct', 'wrong'])}

with torch.set_grad_enabled(False):
    for batch_idx, batch in enumerate(tqdm(loader)):
        batch = tuple(t.to(device) for t in batch)
        x, y, g, idx = batch
        outputs = model(x)  
        to_predict = y
        _, predicted = torch.max(outputs.data, 1)
        
        for label in range(n_classes):
            test_dict[f'class{label}_correct'] += [x[(to_predict == predicted) & (y == label)].cpu().numpy()]
            test_dict[f'class{label}_wrong'] += [x[(to_predict != predicted) & (y == label)].cpu().numpy()]



HBox(children=(IntProgress(value=0, max=1272), HTML(value='')))




In [9]:
for key in test_dict:
    print(key, np.concatenate(test_dict[key], axis=0).shape)

class0_correct (138503, 3, 224, 224)
class0_wrong (0, 3, 224, 224)
class1_correct (23646, 3, 224, 224)
class1_wrong (621, 3, 224, 224)


In [23]:
a = np.random.rand(5,3)
b = np.random.rand(5,3)
c = np.array([0,1,0,1,0])
d = np.array([0,0,1,1,0])
print('a\n',a)
# print(b)
print(c)
print(d)


a
 [[0.19225536 0.72425263 0.78405661]
 [0.69527028 0.74528752 0.11246908]
 [0.44786416 0.15620919 0.70220325]
 [0.47935582 0.34016257 0.35106018]
 [0.55154411 0.66761181 0.01698646]]
[[0.21949124 0.86183075 0.90310517]
 [0.94968505 0.3059399  0.84207108]
 [0.01026236 0.90690631 0.04055816]
 [0.31990618 0.9845833  0.43567266]
 [0.14328209 0.80166115 0.72445191]]
[0 1 0 1 0]
[0 0 1 1 0]


In [27]:
d==True
a[(c==0) & (d==False)]

array([[0.19225536, 0.72425263, 0.78405661],
       [0.55154411, 0.66761181, 0.01698646]])

In [None]:
model.eval()

output_sets = {'activations' : [],  # the features
               'predicted' : [], # the model's prediction
               'labels' : [],       # the corresponding labels 
               'idx' : []
             }
# the line below makes it so that whenever model.forward is called, we save the features to our activations list.
# the feature handler ensures we remove the hook after we are done extracting the features
feature_handler = model.fc.register_forward_hook(get_save_features_hook_fn(output_sets['activations']))
stop = 0
with torch.set_grad_enabled(False):
    for batch_idx, batch in enumerate(tqdm(train_loader)):
        batch = tuple(t.to(device) for t in batch)
        x, y, g, idx = batch

        outputs = model(x)  
        to_predict =  y
        _, predicted = torch.max(outputs.data, 1)
        output_sets['labels'].append(y.cpu())
        output_sets['predicted'].append(predicted.cpu())
        output_sets['idx'].append(idx.cpu())
        print('x', x.norm())
        print('pred', predicted)
        print('y',y)
        print('idx', idx)
        print('----')
        break
        
    

# print()
# for k in output_sets:
#     if k == 'activations':
#         print(k, [output_sets[k][j].norm() for j in range(len(output_sets[k]))])
#     else:
#         print(k, output_sets[k])
#         print(torch.cat(output_sets[k]))
        
