# ROC curves post processing method to improve fairness (particularly separation) in LS-SVM

**Objectives**

We want to emphasize unfairness in some AI/ML algorithm and then try to improve fairness of the machine by using post-correction process. Particularly with ROC curves. 

Therefore, we generate 2 Classes $C_1$ et $C_2$. Each class has a 2 groups : one majority and one minority. Thus, the protected attribute is being in a minority.

The Goal is to show some misclassification of minority observations, and the correct it with ROC curves : the best operating point (the optimal pair of thresholds is hypothetically at the intersection of the 2 ROC curves. The next objective is to compute the threshold with the list of points we obtain by displaying both curves.

For this notebook, we work on **real data (fashion-MNIST)**.

## ROC CURVES

* abscissas : $P(\hat{Y} = 1 \mid Y = 0, A = a)$

* ordinates : $P(\hat{Y} = 1 \mid Y = 1, A =a)$

* Here, $A \in ZOO$. $A=0$ means "in majority" and $A=1$ means "in minority"

A point is plotted for a certain threshold $\ \xi$. The curve is the union of all points generated.

In [3]:
# import the modules we need
import numpy as np
%matplotlib qt
import matplotlib.pyplot as plt
import scipy.linalg
import scipy.special
import scipy.stats
from tensorflow.keras.datasets import mnist,fashion_mnist
from math import *
pi = np.pi

Fashion-MNIST

**Labels**

0 : T-Shirt / Top

1 : Trouser

2 : Pullover

3 : Dress

4 : Coat

5 : Sandal

6 : Shirt

7 : Sneaker

8 : Bag

9 : Ankle Boot




**Possible Classifications**

N°1

$C_1$ = { 0,  2} ; $C_2$ = { 1, 3}

N°2

$C_1$ = {0,6} ; $C_2$ = {7, 5}



## ROC curves methods to improve Fairness in ML : Focus on Separation

In [5]:
#Probabilities belonging to each class. The 2 first ones are for C1. The 2 lasts are for C2. The sum must be equal to 1.
cs = [6/16, 1/16, 7/16, 2/16] 
k = len(cs)  # nb of classes
n = 4096 # nb of training samples
n_test = 4096 # nb of testing samples

gamma = 1  # regularization
f = lambda t : np.exp(-t/2) # RBF kernel

#xi_loop = list(np.linspace(0., 0.1, 200)) + list(np.linspace(0.1, 1, 25))
xi_loop = list(np.linspace(-1, 0, 50)) + list(np.linspace(0, 0.2, 150)) + list(np.linspace(0.2, 1, 50))
store_error = np.zeros( (len(xi_loop),4) , dtype = float )


"""
data_choice = 'MNIST'  # MNIST or Fashion-MNIST

if data_choice == 'MNIST':
    selected_labels=[7,9]
    (init_data, init_labels), _ = mnist.load_data()

if data_choice == 'fashion':
    selected_labels=[8,9]
    (init_data, init_labels), _ = fashion_mnist.load_data()


"""
# get the data
(init_data, init_labels), _ = fashion_mnist.load_data()

# put data in good shape & get data details
idx_init_labels = np.argsort(np.array(init_labels))    
labels = init_labels[idx_init_labels]
init_data = init_data[idx_init_labels,:,:]
data = np.transpose(init_data.reshape(np.shape(init_data)[0],np.shape(init_data)[1]*np.shape(init_data)[2]))
init_n,p=np.shape(data)[1],np.shape(data)[0]

# N1 = [0, 2, 1, 3] or n2 = [0, 6, 7, 5]
first_choice = [0, 2, 1, 3]
second_choice = [0, 6, 7, 5]
selected_labels = first_choice

# put data between 0 and 1
data = data/data.max()
mean_data=np.mean(data,axis=1).reshape(len(data),1)

# normalize data (CLT)
norm2_data=0
for i in range(init_n):
    norm2_data+=1/init_n*np.linalg.norm(data[:,i]-mean_data)**2

data=(data-mean_data)/np.sqrt(norm2_data)*np.sqrt(p)

#select data we want considering their labels
selected_data = []

for i in range(len(selected_labels)):
    selected_data.append(data[:,[ x for x in range(init_n) if labels[x]==selected_labels[i] ] ])
    if i==0:    
        cascade_selected_data = selected_data[-1]
    else:
        np.concatenate([cascade_selected_data,selected_data[-1]],axis=1)

# recentering of the k classes
mean_selected_data  = np.mean(cascade_selected_data,axis=1).reshape(len(cascade_selected_data),1)
norm2_selected_data = np.mean(np.sum(np.power(np.abs(cascade_selected_data-mean_selected_data),2),0))

for j in range(len(selected_labels)):
    selected_data[j] = (selected_data[j] - mean_selected_data) / np.sqrt(norm2_selected_data) * np.sqrt(p)

np.random.seed(928)
for iter,xi in enumerate(xi_loop):

    nb_average_loop = 30
    store_output = np.zeros((nb_average_loop, 4))
    for  average_index in range(nb_average_loop):
        X=np.zeros( (p,n) )
        X_test=np.zeros( (p,n_test) )
        for i in range(k):
            data = selected_data[i][ :, np.random.permutation(np.shape(selected_data[i])[1])]
            #print("data = ", data)
            X[ : , int(np.sum(cs[ : i])*n) : int(np.sum(cs[: i + 1]) * n)] = data[ : , : max(int(n*cs[i]), ceil(n * cs[i]))  ]
            X_test[:,int(np.sum(cs[:i])*n_test):int(np.sum(cs[:i+1])*n_test)]=data[:,n:n+ max(int(n_test*cs[i]), ceil(n_test * cs[i]))  ] 
            #X_test[ : , int(np.sum(cs[ : i]) * n_test) : int(np.sum(cs[ : i + 1]) * n_test)] = data[ : , n + int(n_test * sum(cs[:i])) : n + int(n_test * sum(cs[: i + 1]))  ]

        # kernel matrix
        XX = X.T@X / p
        K = f(-2 * XX + np.diag(XX).reshape(1, n)+np.diag(XX).reshape(n, 1))
        #print("X = ", X)
        #print("Xtest = ", X_test)

        # target Y
        y = np.concatenate([-np.ones(int((cs[0] + cs[1]) * n)),np.ones(int((cs[2] + cs[3]) * n))])

        # target test Y_test
        y_test = np.concatenate([-np.ones(int((cs[0] + cs[1]) * n_test)),np.ones(int((cs[2] + cs[3]) * n_test))])

        # Q et Q^{-1}
        inv_Q = K + n / gamma * np.eye(n)
        Q_y = np.linalg.solve(inv_Q, y)
        Q_1 = np.linalg.solve(inv_Q, np.ones(n))
        
        # alpha & b
        b = np.sum(Q_y) / np.sum(Q_1)
        alpha = Q_y - Q_1 * b

        # classifcation/soft scores
        g = lambda Y : alpha@f(np.diag(XX).reshape( (n, 1) ) + np.diag(Y.T@Y / p).reshape( (1,np.size(Y, 1)) ) - 2 * (X.T@Y / p)) + b
        #g = lambda Y : alpha@f(np.diag(XX).reshape( (n,1) )+np.diag(Y.T@Y/p).reshape( (1,np.size(Y,1)) )-2*(X.T@Y/p))+b
        g_test = g(X_test)
        #for i in range(len(g_test)):
            #if g_test[i] < 0:
                #g_test[i] = -g_test[i]
        #print(g_test)

        # compute of False Alarm Rate (FAR) & Correct Detection Rate (CDR) for each class considering the sensitive/protected attribute
        # "Y / Ŷ  = 0" <=> "in C1"
        # "Y / Ŷ = 1" <=> "in C2"
        # m means minority
        # M means majority
        FAR_m = float(np.sum(g_test[  int(cs[0] * n_test) :  int(sum(cs[:2]) * n_test)] >= xi)) / float(ceil(cs[1] * n_test))
        CDR_m = float( np.sum( g_test[ int((sum(cs[: - 1]) * n_test)) : ] >= xi)) / float(ceil(cs[3] * n_test))
        
        FAR_M = float( np.sum(  g_test[  : int(cs[0] * n_test)  ] >= xi)) /  float(int(cs[0] * n_test))
        CDR_M = float( np.sum( g_test[ int(sum(cs[:2]) * n_test) : int(sum(cs[:3]) * n_test)] >= xi)) / float(int(cs[2] * n_test))

        store_output[average_index, 0] = FAR_m
        store_output[average_index, 1] = CDR_m
        store_output[average_index, 2] = FAR_M
        store_output[average_index, 3] = CDR_M


    #print(store_output)
    store_error[iter, :] = [ np.mean(store_output[ : , 0]), np.mean( store_output[ : , 1]), np.mean(store_output[:, 2]), np.mean(store_output[:, 3]) ]

    
def find_xi(labs, lord):
  imin = 0
  d2min = labs[0]**2 + (lord[0] - 1)**2
  for i in range(1, len(labs)):
    dist2 = labs[i]**2 + (lord[1] - 1)**2
    if dist2 < d2min:
      imin = i
      d2min = dist2
  return imin

# get indices of empirical xi
iminority = find_xi(store_error[ : , 0], store_error[ : , 1])
imajority = find_xi(store_error[ : , 2], store_error[ : , 3])
xi_theoric = (cs[2]+cs[3]) - (cs[1] + cs[0]) 
print("xi theorique = ", xi_theoric )
print("xi minorite = ", xi_loop[iminority])
print("xi majorite = ", xi_loop[imajority])
        
        
        

#print(store_error)
# DISPLAY ROC CURVE 
#plt.errorbar(xi_loop, store_error[:,0], store_error[:,1])
#plt.axvline(cs[1]-cs[0],ls='--',color='m')
#plt.xlabel(r'Decision threshold $\xi$')
#plt.ylabel('Misclassification rate')
%matplotlib qt
plt.plot(store_error[ : , 0 ], store_error[ : , 1 ], 'b', linestyle = '-')
plt.plot(store_error[ : , 2], store_error[ : , 3], 'm', linestyle = '-')
plt.legend(["minorité", "majorité"])
plt.xlabel('False Alarm Rate')
plt.ylabel('Correct Decision Rate')
plt.show()

xi theorique =  0.125
xi minorite =  0.15033557046979867
xi majorite =  0.1906040268456376
