In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.dataset import random_split
#For reproducibility
np.random.seed(0)
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
from data import create_synth_data as csd
from scripts import bbse, methods


In [2]:
X,y,g = csd.create_synthetic_data(10000)
X_train, X_test = X[g!=3], X[g==3]
y_train, y_test = y[g!=3], y[g==3]

In [3]:
y_pred_test_ours_prob, y_pred_val_ours_prob, y_val_ours, y1 = methods.ours(X_train, y_train, g[g!=3], X_test,1.0)
y_pred_test_ours = np.argmax(y_pred_test_ours_prob, axis=1)
y_pred_val_ours = np.argmax(y_pred_val_ours_prob, axis=1)
y_pred_test_erm_prob , y_pred_val_erm_prob, y_val_erm= methods.erm(X_train, y_train, X_test)
y_pred_test_erm = np.argmax(y_pred_test_erm_prob, axis=1)
y_pred_val_erm = np.argmax(y_pred_val_erm_prob, axis=1)

y_pred_test_irm_prob = methods.irm(X_train, y_train, g[g!=3], X_test,0.1)
y_pred_test_irm = np.argmax(y_pred_test_irm_prob, axis=1)
y_pred_test_dro_prob = methods.group_dro(X_train, y_train, g[g!=3], X_test)
y_pred_test_dro = np.argmax(y_pred_test_dro_prob, axis=1)



Early stopping!
Early stopping!
Early stopping!
Early stopping!


In [4]:
print(y_pred_val_erm.shape)
print(y_val_erm.shape)

(748,)
(748,)


In [5]:
#Print test accuracy for ERM, IRM and DRO and ours
print("Test accuracy Ours: ", np.mean(y_pred_test_ours == y_test))
print("Test accuracy ERM: ", np.mean(y_pred_test_erm == y_test))
print("Test accuracy IRM: ", np.mean(y_pred_test_irm == y_test))
print("Test accuracy DRO: ", np.mean(y_pred_test_dro == y_test))


#Print normalised confusion matrix for ERM and ours for validation and test data separately i.e. the entries should represent p(y_hat|y) and each row should sum to 1
print("Confusion matrix for ERM on validation data: ")
print(np.around(np.mean(y_pred_val_erm[y_val_erm==0], axis=0), 3))
print(np.around(np.mean(y_pred_val_erm[y_val_erm==1], axis=0), 3))
print("Confusion matrix for Ours on validation data: ")
print(np.around(np.mean(y_pred_val_ours[y_val_ours==0], axis=0), 3))
print(np.around(np.mean(y_pred_val_ours[y_val_ours==1], axis=0), 3))
print("Confusion matrix for ERM on test data: ")
print(np.around(np.mean(y_pred_test_erm[y_test==0], axis=0), 3))
print(np.around(np.mean(y_pred_test_erm[y_test==1], axis=0), 3))
print("Confusion matrix for Ours on test data: ")
print(np.around(np.mean(y_pred_test_ours[y_test==0], axis=0), 3)) 
print(np.around(np.mean(y_pred_test_ours[y_test==1], axis=0), 3))








Test accuracy Ours:  0.9231074118113357
Test accuracy ERM:  0.8751486325802615
Test accuracy IRM:  0.8620689655172413
Test accuracy DRO:  0.896551724137931
Confusion matrix for ERM on validation data: 
0.049
0.912
Confusion matrix for Ours on validation data: 
0.065
0.919
Confusion matrix for ERM on test data: 
0.052
0.857
Confusion matrix for Ours on test data: 
0.074
0.922


In [6]:
qy_hat = np.zeros(2)
qy_hat[0] = np.mean(y_pred_test_erm== 0)
qy_hat[1] = np.mean(y_pred_test_erm == 1)

qy, py = bbse.estimate_test_py(y_val_erm,y_pred_val_erm,qy_hat)
print(qy, py)
y_pred_test_erm_prob_bbse = bbse.recalibrate(y_pred_test_erm_prob, qy, py)
#Calculate accuracy of recalibrated ERM
y_pred_test_erm_bbse = np.argmax(y_pred_test_erm_prob_bbse, axis=1)
print("Test accuracy ERM BBSE: ", np.mean(y_pred_test_erm_bbse == y_test))

[0.24737344 0.75262656] [0.62165975 0.37834425]
Test accuracy ERM BBSE:  0.926674593737614


In [7]:
#Do BBSE for ours instead of ERM
qy_hat = np.zeros(2)
qy_hat[0] = np.mean(y_pred_test_ours== 0)
qy_hat[1] = np.mean(y_pred_test_ours == 1)

qy, py = bbse.estimate_test_py(y_val_ours,y_pred_val_ours,qy_hat)
print(qy, py)
y_pred_test_ours_prob_bbse = bbse.recalibrate(y_pred_test_ours_prob, qy, py)
#Calculate accuracy of recalibrated ours
y_pred_test_ours_bbse = np.argmax(y_pred_test_ours_prob_bbse, axis=1)
print("Test accuracy Ours BBSE: ", np.mean(y_pred_test_ours_bbse == y_test))

[0.19160753 0.80839247] [0.62165975 0.37834425]
Test accuracy Ours BBSE:  0.9397542608006342
