In [None]:
import numpy as np
import random
import itertools
import scipy.stats
import matplotlib.pyplot as plt


We assume a logistic decision rule:
$$
P(Y = y | X = x) = \frac{1}{1 + \exp(-yx^Tw)}
$$
where $y \in \{+1,-1\}$ and $x,w \in \mathbb{R}^d$.

In [None]:
#Computes P(Y = y| X = x)
def Pr_Y_given_X(y,X,w):
    """Logistic probability P(Y = y | X) = 1 / (1 + exp(-y * X w)).

    Parameters
    ----------
    y : label the probability is evaluated for (the notebook uses +1 / -1).
    X : array whose last axis is multiplied against ``w`` via ``np.matmul``.
    w : weight vector.

    Returns
    -------
    The element-wise logistic of ``y * (X @ w)``.
    """
    margin = y * np.matmul(X, w)
    return 1 / (1 + np.exp(-margin))

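We assume that $X \sim \mathrm{Mul}(n, p_1, \dots, p_k)$, i.e. $X$ is the vector of category counts from $n$ multinomial trials over $k$ categories.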

In [None]:
#Returns all possible multinomial outcomes
def partitions(n, b):
    """Yield every multinomial outcome of ``n`` trials over ``b`` categories.

    Each yielded value is a length-``b`` integer NumPy array of non-negative
    counts that sum to ``n``. Outcomes are produced in the order of
    ``itertools.combinations_with_replacement(range(b), n)``.

    For ``n == 0`` a single all-zero vector of length ``b`` is yielded, so
    callers can always stack the outcomes into a 2-D array.
    """
    for chosen in itertools.combinations_with_replacement(range(b), n):
        # Start from an explicit zero vector: the built-in sum() over an empty
        # combination would return the scalar 0 instead of an array.
        counts = np.zeros(b, dtype=int)
        for category in chosen:
            counts[category] += 1
        yield counts

Let $\mathcal{X}$ be the support of $X$. Then we have:
\begin{align}
P(Y = y) = \sum_{x \in \mathcal{X}}P(Y = y | X = x)P(X = x)
\end{align}

In [None]:
#Computes P(Y = y)
def Pr_Y(y,w,probs,n_trials):
    """Marginal probability P(Y = y) = sum_x P(Y = y | X = x) P(X = x).

    Parameters
    ----------
    y : label whose marginal probability is computed.
    w : weight vector of length ``len(probs) + 1``; its last entry multiplies
        the constant-1 offset column appended to every outcome.
    probs : multinomial category probabilities.
    n_trials : number of multinomial trials.

    Returns
    -------
    The dot product of the conditional probabilities with the multinomial pmf
    over every outcome produced by ``partitions``.
    """
    # The number of categories follows from probs instead of being fixed at 4.
    AllX = np.array(list(partitions(n_trials,len(probs))))
    # Append a column of ones so the last weight acts as an intercept.
    AllXOffset = np.hstack((AllX, np.ones((AllX.shape[0],1))))
    PrX = scipy.stats.multinomial.pmf(AllX, n = np.full(AllX.shape[0],n_trials), p = probs)
    return np.dot(Pr_Y_given_X(y,AllXOffset,w), PrX)

In [None]:
# Marginal P(Y = 1) for the two groups under one shared weight vector.
w = np.array([0.5, 0, 0.2, 0.1, -0.6])
probs1 = [0.3,0.2,0.4,0.1]
probs2 = [0.6,0.2,0.1,0.1]
for label, group_probs in (("P(Y = 1) for group 1:", probs1),
                           ("P(Y = 1) for group 2:", probs2)):
    PrY = Pr_Y(1,w,group_probs,5)
    print(label, PrY)

Equality of odds: assuming $\hat{Y}$ and $Y$ are conditionally independent given $X$, we want the following
\begin{align}
P(\hat{Y} = \hat{y} | Y = y) &= \sum_{x \in \mathcal{X}}P(\hat{Y} = \hat{y} | Y = y, X = x)P(X = x | Y = y)\\
&= \frac{\sum_{x \in \mathcal{X}}P(\hat{Y} = \hat{y} | X = x)P(Y = y| X = x)P(X = x)}{P(Y = y)}
\end{align}
to be the same for both groups for $y, \hat{y} \in \{+1,-1\}.$

In [None]:
def Pr_Yhat_given_Y(y_outcome,y_given,w_hat,w_nat,probs,n_trials):
    """P(Yhat = y_outcome | Y = y_given) for one group.

    Computes
        sum_x P(Yhat = y_outcome | x; w_hat) P(Y = y_given | x; w_nat) P(x)
        / sum_x P(Y = y_given | x; w_nat) P(x)
    over every multinomial outcome x.

    Parameters
    ----------
    y_outcome : value of the predicted label Yhat.
    y_given : value of the label Y being conditioned on.
    w_hat : weights used for the Yhat probabilities.
    w_nat : weights used for the Y probabilities.
    probs : multinomial category probabilities of the group.
    n_trials : number of multinomial trials.
    """
    # The number of categories follows from probs instead of being fixed at 4.
    AllX = np.array(list(partitions(n_trials,len(probs))))
    # Append a column of ones so the last weight acts as an intercept.
    AllXOffset = np.hstack((AllX, np.ones((AllX.shape[0],1))))
    PrX = scipy.stats.multinomial.pmf(AllX, n = np.full(AllX.shape[0],n_trials), p = probs)
    PrGiven = Pr_Y_given_X(y_given,AllXOffset,w_nat)
    joint = np.dot(np.multiply(Pr_Y_given_X(y_outcome,AllXOffset,w_hat), PrGiven),PrX)
    # The marginal P(Y = y_given) reuses the arrays above rather than
    # enumerating the whole support a second time.
    return joint / np.dot(PrGiven, PrX)

Predictive Value Parity: assuming $Y$ and $\hat{Y}$ are conditionally independent given $X$, we want the following
\begin{align}
P(Y = y | \hat{Y} = \hat{y}) &= \sum_{x \in \mathcal{X}}P(Y = y | \hat{Y} = \hat{y}, X = x)P(X = x | \hat{Y} = \hat{y})\\
&= \frac{\sum_{x \in \mathcal{X}}P(Y = y | X = x)P(\hat{Y} = \hat{y}| X = x)P(X = x)}{P(\hat{Y} = \hat{y})}
\end{align}
to be the same for both groups for $y, \hat{y} \in \{+1,-1\}.$

In [None]:
def Pr_Y_given_Yhat(y_outcome,y_given,w_hat,w_nat,probs,n_trials):
    """P(Y = y_outcome | Yhat = y_given) for one group.

    Computes
        sum_x P(Y = y_outcome | x; w_nat) P(Yhat = y_given | x; w_hat) P(x)
        / sum_x P(Yhat = y_given | x; w_hat) P(x)
    over every multinomial outcome x.

    Parameters
    ----------
    y_outcome : value of the label Y.
    y_given : value of the predicted label Yhat being conditioned on.
    w_hat : weights used for the Yhat probabilities.
    w_nat : weights used for the Y probabilities.
    probs : multinomial category probabilities of the group.
    n_trials : number of multinomial trials.
    """
    # The number of categories follows from probs instead of being fixed at 4.
    AllX = np.array(list(partitions(n_trials,len(probs))))
    # Append a column of ones so the last weight acts as an intercept.
    AllXOffset = np.hstack((AllX, np.ones((AllX.shape[0],1))))
    PrX = scipy.stats.multinomial.pmf(AllX, n = np.full(AllX.shape[0],n_trials), p = probs)
    PrGiven = Pr_Y_given_X(y_given,AllXOffset,w_hat)
    joint = np.dot(np.multiply(Pr_Y_given_X(y_outcome,AllXOffset,w_nat), PrGiven),PrX)
    # The marginal P(Yhat = y_given) reuses the arrays above rather than
    # enumerating the whole support a second time.
    return joint / np.dot(PrGiven, PrX)


We define accuracy to be: $$P(Y = 1; \hat{w}) - P(Y = 1; w^\natural)$$

In [None]:
def accuracy(y,w_hat,w_nat,probs,n_trials):
    """Signed gap P(Y = y; w_hat) - P(Y = y; w_nat) for one group.

    Parameters
    ----------
    y : label whose marginal probability is compared.
    w_hat : weights of the first marginal.
    w_nat : weights of the second marginal (subtracted).
    probs : multinomial category probabilities of the group.
    n_trials : number of multinomial trials.
    """
    # Pr_Y enumerates the outcomes itself, so nothing needs to be built here.
    return Pr_Y(y,w_hat,probs,n_trials) - Pr_Y(y,w_nat,probs,n_trials)

In [None]:
# Group-specific multinomial parameters and the two weight vectors compared.
probs1 = [0.3,0.2,0.4,0.1]
probs2 = [0.6,0.2,0.1,0.1]
n_trials1 = 4
n_trials2 = 4
w_hat = np.array([0.9, 0, 0.2, 0.1, -0.7])
w_nat = np.array([0.7, 0, 0.2, 0.1, -0.7])

# P(Yhat = 1 | Y = 1) per group.
PrYhatGivenY1 = Pr_Yhat_given_Y(1,1,w_hat, w_nat, probs1, n_trials1)
PrYhatGivenY2 = Pr_Yhat_given_Y(1,1,w_hat, w_nat, probs2, n_trials2)
print("P(Yhat = 1 | Y = 1) for group 1: ", PrYhatGivenY1)
print("P(Yhat = 1 | Y = 1) for group 2: ", PrYhatGivenY2)

# These calls condition on Yhat = -1 (second argument), so the printed labels
# say "Yhat = -1"; they previously claimed "Yhat = 1".
PrYGivenYhat1 = Pr_Y_given_Yhat(1,-1,w_hat, w_nat, probs1,n_trials1)
PrYGivenYhat2 = Pr_Y_given_Yhat(1,-1,w_hat, w_nat, probs2,n_trials2)
print("P(Y = 1 | Yhat = -1) for group 1: ", PrYGivenYhat1)
print("P(Y = 1 | Yhat = -1) for group 2: ", PrYGivenYhat2)

In [None]:
def plot_equal_odd(f, y_outcome, y_given, probs1, probs2, w_nat, ws, x, n_trials,marker, color):
    """Plot the between-group difference of P(Yhat = y_outcome | Y = y_given).

    For every ``w_hat`` in ``ws`` the conditional probability is evaluated for
    both groups (``probs1`` and ``probs2``), and group 1 minus group 2 is drawn
    against ``x`` with the plotting callable ``f`` (e.g. ``plt.plot``).
    """
    def per_group(probs):
        values = [Pr_Yhat_given_Y(y_outcome, y_given, w_hat, w_nat, probs, n_trials)
                  for w_hat in ws]
        return np.array(values)

    first = per_group(probs1)
    second = per_group(probs2)
    f(x, first - second, marker = marker, color = color)
def plot_pred_value_parity(f, y_outcome, y_given, probs1, probs2, w_nat, ws, x, n_trials, marker, color):
    """Plot the between-group difference of P(Y = y_outcome | Yhat = y_given).

    For every ``w_hat`` in ``ws`` the conditional probability is evaluated for
    both groups (``probs1`` and ``probs2``), and group 1 minus group 2 is drawn
    against ``x`` with the plotting callable ``f`` (e.g. ``plt.plot``).
    """
    def per_group(probs):
        values = [Pr_Y_given_Yhat(y_outcome, y_given, w_hat, w_nat, probs, n_trials)
                  for w_hat in ws]
        return np.array(values)

    first = per_group(probs1)
    second = per_group(probs2)
    f(x, first - second, marker = marker, color = color)
def plot_accuracy(f, y, probs, w_nat, ws, x, n_trials, marker, color):
    """Plot ``accuracy(y, w_hat, w_nat, probs, n_trials)`` for every ``w_hat`` in ``ws``.

    The values are collected in a list and drawn against ``x`` with the
    plotting callable ``f`` (e.g. ``plt.plot`` or ``plt.scatter``).
    """
    gaps = []
    for w_hat in ws:
        gaps.append(accuracy(y,w_hat,w_nat,probs,n_trials))
    f(x, gaps, marker = marker, color = color)

In [None]:
#We vary the first coordinate of w from -2 to 2 and plot the changes in the metrics
x = np.arange(-20, 21).reshape(-1, 1) / 10
# The remaining four weights are held fixed and repeated once per sweep value.
w = np.tile([0.1,0.2,0.1,-0.4], (x.shape[0],1))
ws = np.hstack((x,w))

# Accuracy gap per group: group 1 in blue, group 2 in red.
for group_probs, colour in ((probs1, "blue"), (probs2, "red")):
    plot_accuracy(plt.plot, 1, group_probs, w_nat, ws, x, 4, '.', colour)

# Fairness gaps: 'o' for equality of odds, 'x' for predictive value parity;
# green conditions on +1, black conditions on -1.
for given, colour in ((1, "green"), (-1, "black")):
    plot_equal_odd(plt.plot, 1, given, probs1, probs2, w_nat, ws, x, 4, 'o', colour)
    plot_pred_value_parity(plt.plot, 1, given, probs1, probs2, w_nat, ws, x, 4, 'x', colour)



In [None]:
# Same sweep, restricted to the accuracy gaps and predictive value parity.
for group_probs, colour in ((probs1, "blue"), (probs2, "red")):
    plot_accuracy(plt.plot, 1, group_probs, w_nat, ws, x, 4, '.', colour)

for given, colour in ((1, "green"), (-1, "black")):
    plot_pred_value_parity(plt.plot, 1, given, probs1, probs2, w_nat, ws, x, 4, 'x', colour)


In [None]:
#We generate random ws and plot the spread of the metrics

# Seed a local generator so this cell gives the same figure on every
# Restart & Run All.
n_samples = 10000
rng = np.random.default_rng(0)
# NOTE(review): 2*u - 0.5 with u in [0, 1) draws from [-0.5, 1.5), which is not
# symmetric around zero -- confirm this range is intended.
ws = 2*rng.random((n_samples,5))-0.5
# Normalise every row to sum to 1 (vectorised; no per-row Python lambda).
# NOTE(review): rows whose sum is close to zero produce very large weights.
ws = ws / ws.sum(axis=1, keepdims=True)
x = range(n_samples)
plot_accuracy(plt.scatter, 1, probs1, w_nat, ws, x, 4, '.', "blue")
plot_accuracy(plt.scatter, 1, probs2, w_nat, ws, x, 4, '.', "red")

plot_equal_odd(plt.scatter, 1, 1, probs1, probs2, w_nat, ws, x, 4, 'x', "green")
plot_pred_value_parity(plt.scatter, 1, 1, probs1, probs2, w_nat, ws, x, 4, 'x', "black")

In [None]:
# Smaller random sample of weight vectors; seeded so the figure is reproducible.
n_samples = 100
rng = np.random.default_rng(0)
# NOTE(review): 2*u - 0.5 with u in [0, 1) draws from [-0.5, 1.5), which is not
# symmetric around zero -- confirm this range is intended.
ws = 2*rng.random((n_samples,5))-0.5
# Normalise every row to sum to 1 (vectorised; no per-row Python lambda).
ws = ws / ws.sum(axis=1, keepdims=True)
x = range(n_samples)
plot_equal_odd(plt.scatter, 1, 1, probs1, probs2, w_nat, ws, x, 4, 'x', "green")
plot_pred_value_parity(plt.scatter, 1, 1, probs1, probs2, w_nat, ws, x, 4, 'x', "black")