In [None]:
import torch, json
import torch.nn as nn
import numpy as np
import models_vit
from Dataset_Fair import build_combined_dataset
from torch.utils.data import Dataset, DataLoader, SequentialSampler
from sklearn.decomposition import TruncatedSVD
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, precision_score, recall_score, roc_auc_score, average_precision_score
from sklearn.preprocessing import label_binarize

  warn(


In [None]:
def calculate_metrics(y_true, y_pred_probs):

    y_prob = torch.softmax(y_pred_probs, dim=1)
    y_pred = torch.argmax(y_prob, dim=1).cpu().numpy()  
    y_true = y_true.cpu().numpy()                             
    
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, average='weighted')
    recall = recall_score(y_true, y_pred, average='weighted')
    
    try:
        overall_unique_classes = np.unique(y_true)
        if len(overall_unique_classes) == 2:
            y_score_reduced = y_prob[:, overall_unique_classes[0]]
            overall_auroc = roc_auc_score(y_true, y_score_reduced.detach().numpy())
        if len(overall_unique_classes) >= 3:
            y_true_onehot = label_binarize(y_true, classes=range(3))
            overall_auroc = roc_auc_score(y_true_onehot, y_prob, multi_class='ovr', average='macro')
        else:
            print("Batch contains only one class. Setting AUROC to 0.333")
            overall_auroc = 0.33
    except ValueError:
        auroc = float('nan') 

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'auroc': overall_auroc,
    }


In [None]:
def evaluate_distinct_attributes(attributes, true_labels, predicted_labels, attribute_names, attribute_values):
    results = {attr: {value: {'accuracy': 0.0, 'precision': 0.0, 'auroc': 0.0, 'count': 0} 
                      for value in values} 
               for attr, values in attribute_values.items()}
    
    overall_correct = 0
    total_samples = true_labels.size(0)

    all_true_labels = true_labels.detach().numpy()
    predicted_labels = torch.softmax(predicted_labels, dim=1).detach()
    all_predicted_probs = torch.argmax(predicted_labels, dim=1).detach().numpy()  
    for attr_idx, attr_name in enumerate(attribute_names):
        for value in attribute_values[attr_name]:
              indices = (attributes[:, attr_idx] == value).nonzero(as_tuple=True)[0]
            if indices.numel() > 0:  
                
                subset_true_labels = all_true_labels[indices]
                subset_predicted_labels = all_predicted_probs[indices]
                
                if np.ndim(subset_predicted_labels) == 0: 
                    subset_predicted_labels = np.array([subset_predicted_labels])
                
                if np.ndim(subset_true_labels) == 0: 
                    subset_true_labels = np.array([subset_true_labels])

                # Compute metrics
                accuracy = accuracy_score(subset_true_labels, subset_predicted_labels)
                precision = precision_score(subset_true_labels, subset_predicted_labels, average='weighted',zero_division=0)
                unique_classes = np.unique(subset_true_labels)

                if len(unique_classes) == 2:
                    y_score_reduced = predicted_labels[indices][:, unique_classes[0]]# y_score_reduced.detach().numpy()[:, 1]
                    auroc = roc_auc_score(subset_true_labels, y_score_reduced.detach().numpy(), multi_class='ovr', average='macro')
                if len(unique_classes) >= 3:
                    y_true_onehot = label_binarize(subset_true_labels, classes=[0, 1, 2])

                    auroc = roc_auc_score(y_true_onehot, predicted_labels[indices].detach().numpy(), multi_class='ovr', average='macro')
                else:
                    # Assign a default AUROC value when only one class is present
                    print("Batch contains only one class. Setting AUROC to 0.333")
                    auroc = 0.33
                # Store results
                results[attr_name][value]['accuracy'] += accuracy
                results[attr_name][value]['precision'] += precision
                results[attr_name][value]['auroc'] += auroc
                results[attr_name][value]['count'] += indices.size(0)
            
    overall_results = {}
    for attr, values in results.items():
        overall_results[attr] = {}
        for value, metrics in values.items():
            if metrics['count'] > 0: 
                overall_results[attr][value] = {
                    'accuracy': metrics['accuracy'],
                    'precision': metrics['precision'],
                    'auroc': metrics['auroc'],
                    'count': metrics['count'],
                }

    return overall_results


In [None]:
def e_dist(A, B, cosine=False, eps=1e-10):
        A_n = (A ** 2).sum(axis=1).reshape(-1, 1)
        B_n = (B ** 2).sum(axis=1).reshape(1, -1)
        inner = np.matmul(A, B.T)
        if cosine:
            return 1 - inner / (np.sqrt(A_n * B_n) + eps)
        else:
            return A_n - 2 * inner + B_n

In [None]:
import torch
import numpy as np
from sklearn.decomposition import TruncatedSVD

def Fair_predicted_output(predicted_tensor, pooled_embeddings, attributes_tensor, tau_values, theta_values, lambda_GLIF_NRW=0.001):
    """
    Fairness post-processing function that considers all demographics at once and applies class-specific tau and theta.
    
    Args:
        predicted_tensor (Tensor): Raw predicted values (logits or class predictions).
        pooled_embeddings (Tensor): Embeddings for each instance.
        attributes_tensor (Tensor): Tensor containing sensitive attributes (e.g., gender, race, etc.).
        sensitive_attribute_indices (list): List of indices for the sensitive attributes to use for fairness.
        tau_values (dict): Dictionary with tau values for each class {class: tau}.
        theta_values (dict): Dictionary with theta values for each class {class: theta}.
        lambda_GLIF_NRW (float): Regularization parameter for the fairness adjustment.

    Returns:
        y_updated (Tensor): Updated predictions after fairness adjustment.
    """
    all_embeddings = pooled_embeddings.detach()
    sensitive_attributes = np.array(attributes_tensor)
    
    one_hot_list = []

    for attr_idx in range(attributes_tensor.shape[1]):
        num_classes = np.max(sensitive_attributes[:, attr_idx]) + 1
        one_hot = np.eye(num_classes)[sensitive_attributes[:, attr_idx]]
        one_hot_list.append(torch.tensor(one_hot, dtype=torch.float32))
    
    one_hot_sensitive = torch.cat(one_hot_list, dim=1)
    
    prohibited_subspace = one_hot_sensitive.T @ pooled_embeddings
    one_hot_column = one_hot_sensitive.shape[1]
    
    tSVD = TruncatedSVD(n_components=one_hot_column)
    tSVD.fit(prohibited_subspace)
    svd_sens_directions = tSVD.components_
    svd_sens_directions = torch.tensor(svd_sens_directions).float()

    basis = svd_sens_directions.cpu().numpy().T
    proj = np.linalg.inv(np.matmul(basis.T, basis))
    proj = np.matmul(basis, proj)
    proj = np.matmul(proj, basis.T)
    proj_compl = np.eye(proj.shape[0]) - proj
    proj_compl = torch.tensor(proj_compl).float().to('cpu')

    fair_space_data = all_embeddings @ proj_compl.T
    fair_space_data = fair_space_data.cpu().numpy()

    fair_space_data_squared_distances = e_dist(fair_space_data, fair_space_data, cosine=True)
    fair_space_data_squared_distances = torch.relu(torch.tensor(fair_space_data_squared_distances).to('cpu'))
    
    fair_similarity_W = torch.zeros_like(fair_space_data_squared_distances)

    for idx, label in enumerate(predicted_tensor):
        class_tau = tau_values[label.item()]  
        class_theta = theta_values[label.item()] 
        fair_similarity_W[idx] = torch.exp(-fair_space_data_squared_distances[idx] * class_theta) * \
                                 (fair_space_data_squared_distances[idx] <= class_tau).float()

    D_ii = torch.diag_embed(fair_similarity_W.sum(1))
    D_ii_to_minus_half = torch.diag_embed(fair_similarity_W.sum(1).pow(-.5))
    W_tilde = D_ii_to_minus_half @ fair_similarity_W @ D_ii_to_minus_half
    D_tilde_to_minus_one = torch.diag_embed(W_tilde.sum(1).pow(-1))
    W = D_tilde_to_minus_one @ W_tilde
    L = torch.eye(D_ii.shape[0]) - W
    L = (L.T + L) / 2

    matrix_to_invert = lambda_GLIF_NRW * L + torch.eye(L.shape[0])
    condition_number = torch.linalg.cond(matrix_to_invert)
    
    avg_degree = fair_similarity_W.sum(1).mean()
    y_updated = torch.inverse(lambda_GLIF_NRW * avg_degree * L + torch.eye(L.shape[0])) @ predicted_tensor
    y_updated = torch.clamp(y_updated, min=-10, max=10)

    return y_updated


In [None]:
def load_dict(file_path):
    try:
        with open(file_path, 'r') as file:
            return json.load(file)
    except FileNotFoundError:
        return {}

def save_dict(dictionary , file_path):
    with open(file_path, 'w') as file:
        json.dump(dictionary, file)

In [23]:
def eav(Fair_record):
    overall_accuracy = 0.0
    overall_precision = 0.0
    overall_auroc = 0.0
    for batch_no in range(len(Fair_record)):
        overall_accuracy += Fair_record[str(batch_no)]['Overall']['accuracy']
        overall_precision += Fair_record[str(batch_no)]['Overall']['precision']
        overall_auroc += Fair_record[str(batch_no)]['Overall']['auroc']

    print("Overall Accuracy : " + str(overall_accuracy/len(Fair_record)))
    print("Overall Precision : " + str(overall_precision/len(Fair_record)))
    print("Overall AUROC : " + str(overall_auroc/len(Fair_record)))
    
    record = {'race' : {0 : [], 1 : [], 2: []},
               'male' : {0 : [], 1 : []}, 
               'hispanic' : {0 : [], 1 : []}, 
               'maritalstatus' : {-1 : [], 0 : [], 1 : [], 2: [],3 : [], 4 : []}, 
               'language' : {0 : [], 1 : [], 2: []}
              }

    for batch_no in range(len(Fair_record)):
        for attribute in attribute_names:
            temp_record = Fair_record[str(batch_no)][attribute]
            for value in temp_record.keys():
                record[attribute][int(value)].append(temp_record[value])
    sum_record =  {'race' : {0 : {}, 1 : {}, 2: {}},
               'male' : {0 : {}, 1 : {}}, 
               'hispanic' : {0 : {}, 1 : {}}, 
               'maritalstatus' : {-1 : {}, 0 : {}, 1 : {}, 2: {},3 : {}, 4 : {}}, 
               'language' : {0 : {}, 1 : {}, 2: {}}
              }

    for attribute in record.keys(): 
        for value in record[attribute].keys():
            temp_record = record[attribute][value]
            count = 0
            temp_acc = 0.0
            temp_prec = 0.0
            temp_auroc = 0.0
            for temp_matrix in temp_record:
                temp_acc += temp_matrix['accuracy'] * temp_matrix['count'] 
                temp_prec += temp_matrix['precision'] * temp_matrix['count'] 
                temp_auroc += temp_matrix['auroc'] * temp_matrix['count']
                count += temp_matrix['count']
    #             record[attribute][int(value)].append(temp_record[value])
            if count > 0 : 
                sum_record[attribute][value] = {"accuracy" : temp_acc/count, 
                                                "precision" : temp_prec/count , 
                                                "auroc" : temp_auroc/count}
            else:
                print(f"Count is 0 for attribute :  {attribute} , category : {value}")
    return sum_record

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Fair_file_path = './Fair_Performance_record.json'
Original_file_path = './Performance_record.json'

In [None]:
root_dirs1 = './Glaucoma/'
root_dirs2 = './DR/'
root_dirs = [root_dirs1, root_dirs2]
test_dataset = build_combined_dataset(root_dirs=root_dirs, phase='test', input_size=200)
test_loader = DataLoader(test_dataset,batch_size=500,pin_memory=True, shuffle=True)

In [None]:
model = models_vit.load_pretrained_vit_base(target_size=200,global_pool=True, num_classes=3)
custom_weights = torch.load('./best_student_model.pth', map_location='cpu')  # Adjust map_location if using GPU
model.load_state_dict(custom_weights, strict=False)
# model.to(device)
model.eval()

Global Pooling Current State : True
1000
LayerNorm((768,), eps=1e-06, elementwise_affine=True)
Position interpolate from 14x14 to 12x12


VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (1): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine

In [11]:
attribute_names = ['race','male', 'hispanic', 'maritalstatus', 'language']
attribute_values = { 
   'race': [0, 1, 2],
    'male': [0, 1],
    'language': [0, 1, 2],
    'hispanic': [0, 1],
    'maritalstatus': [-1,0, 1,2,3,4]}

In [None]:
Fair_result = {}
original_result = {}
class_specific_tau = {0: 0.8, 1: 1.1, 2: 1.4}  # Adjust these values based on your requirements
class_specific_theta = {0: 0.001, 1: 0.002, 2: 0.015} # Adjust these values based on your requirements
for batch_no,(data, labels, attributes) in enumerate(test_loader):
    print("Prediction start from the model")
    with torch.no_grad():
        Embedding = model.get_spatial_feature_maps(data)
        pooled_embeddings = F.adaptive_avg_pool2d(Embedding, (1, 1)).squeeze(-1).squeeze(-1)
        predicted_tensor,_ = model(data)
    temp_dict_fair = dict()
    temp_dict_original = dict()
    print("Prediction recieved from the model")
    fair_output = Fair_predicted_output(predicted_tensor, pooled_embeddings, attributes, 
                                            class_specific_tau, class_specific_theta, lambda_GLIF_NRW=0.001)
    attribute_value = {Attibute : attribute_values[attribute_names[attribute_no]]}

    overall_results = evaluate_distinct_attributes(attributes, labels, predicted_tensor, attribute_names, attribute_values)
    F_overall_results = evaluate_distinct_attributes(attributes, labels, fair_output, attribute_names, attribute_values)
    temp_dict_fair.update(F_overall_results)
    temp_dict_original.update(overall_results)
    Fair_overall_result = calculate_metrics(labels.detach(), fair_output.detach())
    UnFair_overall_result = calculate_metrics(labels.detach(), predicted_tensor.detach())
    temp_dict_fair['Overall'] = Fair_overall_result
    temp_dict_original['Overall'] = UnFair_overall_result
    Fair_result[batch_no] = temp_dict_fair
    original_result[batch_no] = temp_dict_original
    save_dict(Fair_result, Fair_file_path)
    save_dict(original_result, Original_file_path)
    Fair_record = load_dict(Fair_file_path)
    Original_record = load_dict(Original_file_path)

    

Prediction start from the model
Prediction recieved from the model
tensor(1.8662)
tensor(0.)
Batch contains only one class. Setting AUROC to 0.333
Batch contains only one class. Setting AUROC to 0.333
Prediction start from the model
Prediction recieved from the model
tensor(1.8686)
tensor(0.)
Prediction start from the model
Prediction recieved from the model
tensor(1.8726)
tensor(0.)
Batch contains only one class. Setting AUROC to 0.333
Batch contains only one class. Setting AUROC to 0.333
Prediction start from the model
Prediction recieved from the model
tensor(1.8321)
tensor(0.)
Batch contains only one class. Setting AUROC to 0.333
Batch contains only one class. Setting AUROC to 0.333
Prediction start from the model
Prediction recieved from the model
tensor(1.9171)
tensor(0.)
Prediction start from the model
Prediction recieved from the model
tensor(1.8234)
tensor(0.)
Prediction start from the model
Prediction recieved from the model
tensor(1.8983)
tensor(0.)
Batch contains only one c

In [24]:
Fair_record = load_dict(Fair_file_path)
fair_res = eav(Fair_record)
Original_record = load_dict(Original_file_path)
res = eav(Original_record)


compile_res = {}
for attribute in fair_res.keys():
    for value in fair_res[attribute].keys():
        if bool(fair_res[attribute][value]) and bool(res[attribute][value]):
            compile_res[attribute + "_" + str(value)] = [fair_res[attribute][value]['auroc'],res[attribute][value]['auroc']]
            
compile_res

Overall Accuracy : 0.7273333333333332
Overall Precision : 0.7573536886596446
Overall AUROC : 0.8761363922374534
Overall Accuracy : 0.7373333333333334
Overall Precision : 0.7330731135574499
Overall AUROC : 0.8573694796703628


{'race_0': [0.8945186353556517, 0.8777567362332971],
 'race_1': [0.8675337262798222, 0.8470081060197792],
 'race_2': [0.8747767373519817, 0.8559381871782273],
 'male_0': [0.8744718537618532, 0.8555479442771696],
 'male_1': [0.8772443784881646, 0.8585984871608685],
 'hispanic_0': [0.8765793351579889, 0.8577618796304949],
 'hispanic_1': [0.8502193406340489, 0.8326452587762472],
 'maritalstatus_-1': [0.8669838227737584, 0.8467761740686974],
 'maritalstatus_0': [0.8826708300718905, 0.8650800904153172],
 'maritalstatus_1': [0.8655562495211385, 0.8448161540305714],
 'maritalstatus_2': [0.8943512996168117, 0.8751783985508284],
 'maritalstatus_3': [0.8700527501532159, 0.8481969500546621],
 'maritalstatus_4': [0.5046367521367521, 0.5046367521367521],
 'language_0': [0.8759793236347047, 0.85713702110791],
 'language_1': [0.7715398308856254, 0.756928991742076],
 'language_2': [0.8798126523651418, 0.8613928033459316]}