In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.dataset import random_split
import pandas as pd
#For reproducibility
np.random.seed(0)
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
from data import create_synth_data as csd
from scripts import bbse, methods
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from sklearn.model_selection import train_test_split



In [2]:
data_dir = '/home/parjanya/UCSD_courses/ECE228/melanoma'
df = pd.read_csv(data_dir + '/train.csv')
df['anatom_site_general_challenge'] = df['anatom_site_general_challenge'].astype('category')
print(df['anatom_site_general_challenge'].value_counts())
g = df['anatom_site_general_challenge'].cat.codes


y = df['target']


mapping = dict(enumerate(df['anatom_site_general_challenge'].cat.categories))
print("Mapping of 'anatom_site_general_challenge' to integers:", mapping)

torso              16845
lower extremity     8417
upper extremity     4983
head/neck           1855
palms/soles          375
oral/genital         124
Name: anatom_site_general_challenge, dtype: int64
Mapping of 'anatom_site_general_challenge' to integers: {0: 'head/neck', 1: 'lower extremity', 2: 'oral/genital', 3: 'palms/soles', 4: 'torso', 5: 'upper extremity'}


In [3]:
#Load X_resnet
X_resnet = np.load(data_dir + '/resnet_features.npy')

In [4]:

X_resnet = torch.tensor(X_resnet, dtype=torch.float32)
y = torch.tensor(y.values, dtype=torch.float32)
g = torch.tensor(g.values, dtype=torch.float32)
def stratified_split(X, y, g, test_size=0.5, train_pos_ratio=0.1, test_pos_ratio=0.5):
    pos_indices = (y == 1)
    neg_indices = (y == 0)
    size = min(pos_indices.sum(), neg_indices.sum())-1
    print(size)
    X_pos, y_pos, g_pos = X[pos_indices][:size], y[pos_indices][:size], g[pos_indices][:size]
    X_neg, y_neg, g_neg = X[neg_indices][:size], y[neg_indices][:size], g[neg_indices][:size]
    pos_train_size = int(train_pos_ratio * len(X_pos) / (train_pos_ratio + test_pos_ratio))
    neg_train_size = int(len(X_neg) * (1 - test_size))
    X_train_pos, X_test_pos, y_train_pos, y_test_pos, g_train_pos, g_test_pos = train_test_split(X_pos, y_pos, g_pos, train_size=pos_train_size, random_state=42)
    X_train_neg, X_test_neg, y_train_neg, y_test_neg, g_train_neg, g_test_neg = train_test_split(X_neg, y_neg, g_neg, train_size=neg_train_size, random_state=42)
    X_train = torch.cat([X_train_pos, X_train_neg], dim=0)
    y_train = torch.cat([y_train_pos, y_train_neg], dim=0)
    g_train = torch.cat([g_train_pos, g_train_neg], dim=0)
    X_test = torch.cat([X_test_pos, X_test_neg], dim=0)
    y_test = torch.cat([y_test_pos, y_test_neg], dim=0)
    g_test = torch.cat([g_test_pos, g_test_neg], dim=0)

    return X_train, X_test, y_train, y_test, g_train, g_test

X_train, X_test, y_train, y_test, g_train, g_test = stratified_split(X_resnet, y, g)

X_train = X_train.numpy()
X_test = X_test.numpy()
y_train = y_train.numpy()
y_test = y_test.numpy()
g_train = g_train.numpy()
g_test = g_test.numpy()


tensor(583)


In [6]:
group = 1
group_indices = (g_train == group)
group_indices_test = (g_test == group)
X_train_group = X_train[~group_indices]
y_train_group = y_train[~group_indices]
g_train_group = g_train[~group_indices]
X_test_group = X_test[group_indices_test]
y_test_group = y_test[group_indices_test]
g_test_group = g_test[group_indices_test]
y_pred_test_ours_prob_group, y_pred_val_ours_prob_group, y_val_ours_group, y1_group = methods.ours(X_train_group, y_train_group, g_train_group, X_test_group,1.0)
y_pred_test_ours_group = np.argmax(y_pred_test_ours_prob_group, axis=1)
y_pred_val_ours_group = np.argmax(y_pred_val_ours_prob_group, axis=1)
y_pred_test_erm_prob_group , y_pred_val_erm_prob_group, y_val_erm_group= methods.erm(X_train_group, y_train_group, X_test_group)
y_pred_test_erm_group = np.argmax(y_pred_test_erm_prob_group, axis=1)
y_pred_val_erm_group = np.argmax(y_pred_val_erm_prob_group, axis=1)

y_pred_test_irm_prob_group = methods.irm(X_train_group, y_train_group, g_train_group, X_test_group,0.1)
y_pred_test_irm_group = np.argmax(y_pred_test_irm_prob_group, axis=1)
y_pred_test_dro_prob_group = methods.group_dro(X_train_group, y_train_group, g_train_group, X_test_group)
y_pred_test_dro_group = np.argmax(y_pred_test_dro_prob_group, axis=1)


#Print test accuracy for ERM, IRM and DRO and ours
print("Test accuracy Ours: ", np.mean(y_pred_test_ours_group == y_test_group))
print("Test accuracy ERM: ", np.mean(y_pred_test_erm_group == y_test_group))
print("Test accuracy IRM: ", np.mean(y_pred_test_irm_group == y_test_group))
print("Test accuracy DRO: ", np.mean(y_pred_test_dro_group == y_test_group))



#Print recall based on this new y_pred
print("Recall Ours: ", np.mean(y_pred_test_ours_group[y_test_group==1] == y_test_group[y_test_group==1]))
print("Recall ERM: ", np.mean(y_pred_test_erm_group[y_test_group==1] == y_test_group[y_test_group==1]))
print("Recall IRM: ", np.mean(y_pred_test_irm_group[y_test_group==1] == y_test_group[y_test_group==1]))
print("Recall DRO: ", np.mean(y_pred_test_dro_group[y_test_group==1] == y_test_group[y_test_group==1]))

qy_hat = np.zeros(2)
qy_hat[0] = np.mean(y_pred_test_erm_group== 0)
qy_hat[1] = np.mean(y_pred_test_erm_group == 1)

qy, py = bbse.estimate_test_py(y_val_erm_group,y_pred_val_erm_group,qy_hat)
y_pred_test_erm_prob_bbse_group = bbse.recalibrate(y_pred_test_erm_prob_group, qy, py)
y_pred_test_erm_bbse_group = np.argmax(y_pred_test_erm_prob_bbse_group, axis=1)
print("Test accuracy ERM BBSE: ", np.mean(y_pred_test_erm_bbse_group == y_test_group))
print("Recall ERM BBSE: ", np.mean(y_pred_test_erm_bbse_group[y_test_group==1] == y_test_group[y_test_group==1]))

#Do BBSE for ours instead of ERM
qy_hat = np.zeros(2)
qy_hat[0] = np.mean(y_pred_test_ours_group== 0)
qy_hat[1] = np.mean(y_pred_test_ours_group == 1)

qy, py = bbse.estimate_test_py(y_val_ours_group,y_pred_val_ours_group,qy_hat)
y_pred_test_ours_prob_bbse_group = bbse.recalibrate(y_pred_test_erm_prob_group, qy, py)
y_pred_test_ours_bbse_group = np.argmax(y_pred_test_ours_prob_bbse_group, axis=1)
print("Test accuracy Ours BBSE: ", np.mean(y_pred_test_ours_bbse_group == y_test_group))
print("Recall Ours BBSE: ", np.mean(y_pred_test_ours_bbse_group[y_test_group==1] == y_test_group[y_test_group==1]))


Early stopping!
Early stopping!
Early stopping!
Test accuracy Ours:  0.4913294797687861
Test accuracy ERM:  0.4624277456647399
Test accuracy IRM:  0.4046242774566474
Test accuracy DRO:  0.45664739884393063
Recall Ours:  0.18867924528301888
Recall ERM:  0.16037735849056603
Recall IRM:  0.02830188679245283
Recall DRO:  0.32075471698113206
Test accuracy ERM BBSE:  0.4624277456647399
Recall ERM BBSE:  0.1320754716981132
Test accuracy Ours BBSE:  0.6589595375722543
Recall Ours BBSE:  0.8113207547169812
