# Predict Protected Features

In this notebook, we test whether protected characteristics (age, sex, and race) can be predicted directly from the embeddings, both before and after orthogonalization.

In [None]:
import torch
%load_ext autoreload
%autoreload 2

import os

import numpy as np
from src.eval import EmbeddingEvaluator, Pathology
import pytorch_lightning as L

# Directory that holds the embedding arrays and the metadata CSV.
DATA_DIR = '../data'
# Embedding matrix to evaluate; swap in one of the commented alternatives below.
# A filename containing 'mimic' selects the MIMIC metadata loader in the next cell,
# any other filename selects the CheXpert loader.
EMBEDDING_FILE = 'mimic_cfm.npy'
#EMBEDDING_FILE = 'mimic_chess.npy'
#EMBEDDING_FILE = 'mimic_densenet_mimic.npy'

#EMBEDDING_FILE = 'chex_chess.npy'
#EMBEDDING_FILE = 'chex_densenet_chex.npy'

# Metadata CSV; only read when a MIMIC embedding file is selected.
META_FILE = 'mimic_meta.csv'

## Load Metadata and Embedding

In [None]:
from src.utils import get_mimic_meta_data, get_chexpert_meta_data

# Pick the metadata loader that matches the dataset named in the embedding file.
if 'mimic' not in EMBEDDING_FILE:
    train_df, val_df, test_df = get_chexpert_meta_data(DATA_DIR)
else:
    meta_path = os.path.join(DATA_DIR, META_FILE)
    train_df, val_df, test_df = get_mimic_meta_data(meta_path)
print(f'DATASET SIZES: TRAIN {len(train_df)} | VAL {len(val_df)} | TEST {len(test_df)}')


# Load the full embedding matrix and replace NaN/inf entries with finite numbers.
emb = np.nan_to_num(np.load(os.path.join(DATA_DIR, EMBEDDING_FILE)))

# The 'idx' column of each split selects that split's rows of the embedding matrix.
train_emb = emb[train_df['idx']]
test_emb = emb[test_df['idx']]

In [None]:
# Bundle the train/test metadata and embeddings into the project's evaluator.
# Later cells read `train_emb` / `test_emb` and their `*_ortho` counterparts from it.
# NOTE(review): presumably the orthogonalized embeddings are derived inside
# EmbeddingEvaluator — confirm in src.eval.
evaluator = EmbeddingEvaluator(train_df, test_df, train_emb, test_emb)

## 1. Predict Age

In [None]:
from tqdm import tqdm
from src.net import TensorDataset, ClassificationModule
from torch import nn, Tensor
from typing import Sequence
import warnings
from sklearn.metrics import r2_score, mean_absolute_error
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from pytorch_lightning.loggers import CSVLogger
from torch.utils.data import DataLoader


def train_age_regressor(
    x_train: np.ndarray,
    y_train: Sequence,
    x_val: Tensor,
    y_val: Sequence,
    max_epochs: int = 10,
    batch_size: int = 256,
):
    """Fit a linear probe that regresses age from embeddings and score it.

    A single ``nn.Linear`` layer with one output is trained with MSE loss via
    PyTorch Lightning, then evaluated on both the training and validation data.

    Args:
        x_train: Training embeddings; ``x_train.shape[1]`` sets the input width.
        y_train: Age targets for ``x_train``.
        x_val: Held-out embeddings.
        y_val: Age targets for ``x_val``.
        max_epochs: Number of training epochs.
        batch_size: Training batch size; validation uses twice this size.

    Returns:
        dict mapping ``'train'`` and ``'val'`` to a ``(mae, r2)`` tuple.
    """
    # Create PyTorch datasets and data loaders
    train_dataset = TensorDataset(x_train, y_train)
    val_dataset = TensorDataset(x_val, y_val)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=2 * batch_size)

    # Initialize the PyTorch Lightning model
    model = ClassificationModule(
        model=nn.Linear(x_train.shape[1], 1),
        loss_func=nn.MSELoss(),
    )

    # Initialize the PyTorch Lightning Trainer
    warnings.filterwarnings('ignore', category=PossibleUserWarning)
    trainer = L.Trainer(
        logger=CSVLogger('lightning_logs'),
        max_epochs=max_epochs,
        enable_model_summary=False,
        enable_progress_bar=False,
    )
    # Train the model
    trainer.fit(model, train_loader, val_loader)

    # Score the fitted probe; the training data is re-wrapped without shuffling.
    model.eval()
    eval_loaders = {
        'train': DataLoader(train_dataset, batch_size=batch_size),
        'val': val_loader,
    }
    res = {}
    with torch.no_grad():
        for name, loader in eval_loaders.items():
            y_pred_list = []
            y_true_list = []
            for x_batch, y_batch in loader:
                y_pred = model(x_batch)
                # Predictions and targets are both multiplied by 100 before scoring.
                # NOTE(review): presumably ages are handled as age/100 upstream, so
                # this reports MAE in years — confirm against TensorDataset / metadata.
                y_pred_list.extend(y_pred.detach().cpu().numpy() * 100)
                y_true_list.extend(y_batch.cpu().numpy() * 100)

            # Calculate metrics using scikit-learn
            mae = mean_absolute_error(y_true_list, y_pred_list)
            r2 = r2_score(y_true_list, y_pred_list)
            res[name] = (mae, r2)

    return res


### Retrieve from Original

In [None]:
L.seed_everything(1337424242)

# Ten repetitions of the age probe on the original embeddings. The seed is set
# once, so each repetition continues from the RNG state left by the previous one.
results = []
for _ in tqdm(range(10)):
    results.append(
        train_age_regressor(
            evaluator.train_emb,
            train_df['age'].tolist(),
            evaluator.test_emb,
            test_df['age'].tolist(),
        )
    )

In [None]:
# Collect the (MAE, R2) pair measured on the held-out split of every repetition.
test_tuples = [e['val'] for e in results]
maes, r2s = zip(*test_tuples)

# Raw string: '\p' is not a valid escape sequence, so a plain literal triggers a
# SyntaxWarning on recent Python versions. The printed LaTeX row is unchanged.
msg = (r'{:.3f} $\pm$ {:.3f} & {:.3f} $\pm$ {:.3f}'
       .format(np.mean(maes), np.std(maes), np.mean(r2s), np.std(r2s)))
print(msg)

### Retrieve from Orthogonalization

In [None]:
L.seed_everything(1337424242)

# Same protocol as for the original embeddings, but the probe is fitted on the
# orthogonalized embeddings exposed by the evaluator.
results = []
for _ in tqdm(range(10)):
    results.append(
        train_age_regressor(
            evaluator.train_emb_ortho,
            train_df['age'].tolist(),
            evaluator.test_emb_ortho,
            test_df['age'].tolist(),
        )
    )

In [None]:
# Collect the (MAE, R2) pair measured on the held-out split of every repetition.
test_tuples = [e['val'] for e in results]
maes, r2s = zip(*test_tuples)

# Raw string: '\p' is not a valid escape sequence, so a plain literal triggers a
# SyntaxWarning on recent Python versions. The printed LaTeX row is unchanged.
msg = (r'{:.3f} $\pm$ {:.3f} & {:.3f} $\pm$ {:.3f}'
       .format(np.mean(maes), np.std(maes), np.mean(r2s), np.std(r2s)))
print(msg)

## 2. Predict Sex

In [None]:
from src.utils import eval_predictions


def train_sex_regressor(
    x_train: np.ndarray,
    y_train: Sequence,
    x_val: Tensor,
    y_val: Sequence,
    max_epochs: int = 10,
    batch_size: int = 256,
):
    """Fit a linear probe that predicts a binary sex label from embeddings.

    A single-output ``nn.Linear`` layer is trained with ``BCEWithLogitsLoss``
    through PyTorch Lightning. Sigmoid probabilities for the training and the
    validation data are then passed to ``eval_predictions``.

    Returns:
        dict mapping ``'train'`` and ``'val'`` to an ``(AUC, SENS, SPEC)`` tuple.
    """
    fit_set = TensorDataset(x_train, y_train)
    held_out_set = TensorDataset(x_val, y_val)

    shuffled_fit_loader = DataLoader(fit_set, batch_size=batch_size, shuffle=True)
    held_out_loader = DataLoader(held_out_set, batch_size=2 * batch_size)

    # Linear probe wrapped in the project's Lightning module.
    probe = ClassificationModule(
        model=nn.Linear(x_train.shape[1], 1),
        loss_func=nn.BCEWithLogitsLoss()
    )

    # Silence Lightning's PossibleUserWarning before building the trainer.
    warnings.filterwarnings('ignore', category=PossibleUserWarning)
    L.Trainer(
        logger=CSVLogger('lightning_logs'),
        max_epochs=max_epochs,
        enable_model_summary=False,
        enable_progress_bar=False,
    ).fit(probe, shuffled_fit_loader, held_out_loader)

    # Evaluation: the training data is re-wrapped without shuffling.
    probe.eval()
    ordered_fit_loader = DataLoader(fit_set, batch_size=batch_size)
    scores = {}
    with torch.no_grad():
        for split, batches in (('train', ordered_fit_loader), ('val', held_out_loader)):
            probs = []
            targets = []
            for features, labels in batches:
                logits = probe(features)
                probs.extend(torch.sigmoid(logits).detach().numpy())
                targets.extend(labels.numpy())

            m = eval_predictions(np.asarray(targets), np.asarray(probs))
            scores[split] = (m['AUC'], m['SENS'], m['SPEC'])
    return scores


### Retrieve from Original

In [None]:
L.seed_everything(1337424242)

# Ten repetitions of the sex probe on the original embeddings.
results = []
for _ in tqdm(range(10)):
    # Binary target: 1 where sex == 'M', 0 for every other value.
    run_scores = train_sex_regressor(
        evaluator.train_emb,
        np.where(train_df['sex'] == 'M', 1, 0),
        evaluator.test_emb,
        np.where(test_df['sex'] == 'M', 1, 0),
    )
    results.append(run_scores)

In [None]:
# Collect the (AUC, SENS, SPEC) triple measured on the held-out split of every repetition.
test_tuples = [e['val'] for e in results]
aucs, sens, specs = zip(*test_tuples)

# Raw string: '\p' is not a valid escape sequence, so a plain literal triggers a
# SyntaxWarning on recent Python versions. The printed LaTeX row is unchanged.
msg = (r'{:.3f} $\pm$ {:.3f} & {:.3f} $\pm$ {:.3f} & {:.3f} $\pm$ {:.3f}'
       .format(np.mean(aucs), np.std(aucs), np.mean(sens), np.std(sens), np.mean(specs), np.std(specs)))
print(msg)

### Retrieve from Orthogonalization

In [None]:
L.seed_everything(1337424242)

# Ten repetitions of the sex probe on the orthogonalized embeddings.
results = []
for _ in tqdm(range(10)):
    # Binary target: 1 where sex == 'M', 0 for every other value.
    run_scores = train_sex_regressor(
        evaluator.train_emb_ortho,
        np.where(train_df['sex'] == 'M', 1, 0),
        evaluator.test_emb_ortho,
        np.where(test_df['sex'] == 'M', 1, 0),
    )
    results.append(run_scores)

In [None]:
# Collect the (AUC, SENS, SPEC) triple measured on the held-out split of every repetition.
test_tuples = [e['val'] for e in results]
aucs, sens, specs = zip(*test_tuples)

# Raw string: '\p' is not a valid escape sequence, so a plain literal triggers a
# SyntaxWarning on recent Python versions. The printed LaTeX row is unchanged.
msg = (r'{:.3f} $\pm$ {:.3f} & {:.3f} $\pm$ {:.3f} & {:.3f} $\pm$ {:.3f}'
       .format(np.mean(aucs), np.std(aucs), np.mean(sens), np.std(sens), np.mean(specs), np.std(specs)))
print(msg)

## 3. Predict Race

In [None]:
from sklearn.metrics import roc_auc_score
from sklearn.metrics import confusion_matrix


def eval_predictions_multiclass(true: np.ndarray, pred: np.ndarray):
    """Compute one-vs-rest AUC, sensitivity and specificity for every class.

    Args:
        true: Integer class labels, one per sample. The number of classes is
            taken as ``max(true) + 1``, i.e. labels are treated as 0..K-1.
        pred: Per-class scores with one row per sample and one column per
            class; the predicted class is the arg-max over columns.

    Returns:
        dict mapping each class index to ``{'AUC': ..., 'SENS': ..., 'SPEC': ...}``.
    """
    num_classes = int(np.max(true)) + 1

    pred_classes = np.argmax(pred, axis=1)
    # average=None yields one one-vs-rest AUC per class instead of a single mean.
    aucs = roc_auc_score(true, pred, multi_class='ovr', average=None)

    res = {}
    for class_label in range(num_classes):
        # Reduce the problem to "this class" (1) versus "any other class" (0).
        class_pred = (pred_classes == class_label).astype(int)
        class_true = (true == class_label).astype(int)

        # labels=[0, 1] pins the matrix to 2x2 even when one of the two values
        # never occurs, so the four-way unpacking below cannot fail.
        tn, fp, fn, tp = confusion_matrix(class_true, class_pred, labels=[0, 1]).ravel()

        res[class_label] = {
            'AUC': aucs[class_label],
            'SENS': tp / (tp + fn),
            'SPEC': tn / (tn + fp),
        }
    return res


In [None]:
def train_race_regressor(
    x_train: np.ndarray,
    y_train: Sequence,
    x_val: Tensor,
    y_val: Sequence,
    max_epochs: int = 10,
    batch_size: int = 256,
    num_classes: int = 3,
):
    """Fit a linear probe that predicts an integer-encoded race label from embeddings.

    A ``nn.Linear`` layer with ``num_classes`` outputs is trained with
    cross-entropy loss through PyTorch Lightning. Softmax probabilities for the
    training and validation data are scored by ``eval_predictions_multiclass``.

    Args:
        x_train: Training embeddings; ``x_train.shape[1]`` sets the input width.
        y_train: Integer class labels for ``x_train``.
        x_val: Held-out embeddings.
        y_val: Integer class labels for ``x_val``.
        max_epochs: Number of training epochs.
        batch_size: Training batch size; validation uses twice this size.
        num_classes: Number of output classes of the probe (previously fixed at 3).

    Returns:
        dict mapping ``'train'`` and ``'val'`` to the per-class metric dict
        returned by ``eval_predictions_multiclass``.
    """
    # Create PyTorch datasets and data loaders; labels are stored as torch.long.
    train_dataset = TensorDataset(x_train, y_train, label_dtype=torch.long)
    val_dataset = TensorDataset(x_val, y_val, label_dtype=torch.long)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=2 * batch_size)

    # Initialize the PyTorch Lightning model
    model = ClassificationModule(
        model=nn.Linear(x_train.shape[1], num_classes),
        loss_func=nn.CrossEntropyLoss(),
    )

    # Initialize the PyTorch Lightning Trainer
    warnings.filterwarnings('ignore', category=PossibleUserWarning)
    trainer = L.Trainer(
        logger=CSVLogger('lightning_logs'),
        max_epochs=max_epochs,
        enable_model_summary=False,
        enable_progress_bar=False,
    )

    # Train the model
    trainer.fit(model, train_loader, val_loader)

    # Test the model
    model.eval()
    # The (still shuffled) training loader is reused as-is: predictions and labels
    # are gathered from the same batches, so sample order does not affect the metrics.
    res = {}
    with torch.no_grad():
        for loader, name in [(train_loader, 'train'), (val_loader, 'val')]:
            y_pred_list = []
            y_true_list = []
            for x_batch, y_batch in loader:
                y_pred = model(x_batch)
                y_pred_list.extend(torch.softmax(y_pred, dim=1).cpu().detach().numpy())
                y_true_list.extend(y_batch.cpu().numpy())

            # Calculate metrics using scikit-learn
            res[name] = eval_predictions_multiclass(np.asarray(y_true_list), np.stack(y_pred_list))
    return res


In [None]:
import pandas as pd
L.seed_everything(1337424242)

# Encode race with ONE vocabulary derived from the training split. Factorizing
# train and test independently gives the same integer different meanings
# whenever the two splits do not contain exactly the same set of categories.
enc_race_train, unique_classes = pd.factorize(train_df['race'], sort=True)
print(unique_classes)

# Fail loudly on test categories the probe never saw instead of mis-encoding them.
unseen_races = set(test_df['race'].dropna().unique()) - set(unique_classes)
if unseen_races:
    raise ValueError(
        'Race categories present in test but not in train: '
        f'{sorted(map(str, unseen_races))}'
    )
enc_race_test = np.asarray(
    pd.Categorical(test_df['race'], categories=unique_classes).codes,
    dtype=enc_race_train.dtype,
)

# Ten repetitions of the race probe on the original embeddings.
results = []
for _ in tqdm(range(10)):
    r = train_race_regressor(
        evaluator.train_emb,
        enc_race_train,
        evaluator.test_emb,
        enc_race_test,
    )
    results.append(r)

### Retrieve from Original

In [None]:
# Per-class metric dicts measured on the held-out split of every repetition.
test_tuples = [e['val'] for e in results]

# Position i in unique_classes is the integer label i used as key in the metric dicts.
for i, cls in enumerate(unique_classes):
    print(cls, '----------------------')
    metrics = [t[i] for t in test_tuples]

    aucs = [m['AUC'] for m in metrics]
    sens = [m['SENS'] for m in metrics]
    specs = [m['SPEC'] for m in metrics]

    # Raw string: '\p' is not a valid escape sequence, so a plain literal triggers
    # a SyntaxWarning on recent Python versions. The printed LaTeX row is unchanged.
    msg = (r'{:.3f} $\pm$ {:.3f} & {:.3f} $\pm$ {:.3f} & {:.3f} $\pm$ {:.3f}'
           .format(np.mean(aucs), np.std(aucs), np.mean(sens), np.std(sens), np.mean(specs), np.std(specs)))
    print(msg)

### Retrieve from Orthogonalization

In [None]:
L.seed_everything(1337424242)

# Encode race with ONE vocabulary derived from the training split. Factorizing
# train and test independently gives the same integer different meanings
# whenever the two splits do not contain exactly the same set of categories.
enc_race_train, unique_classes = pd.factorize(train_df['race'], sort=True)
print(unique_classes)

# Fail loudly on test categories the probe never saw instead of mis-encoding them.
unseen_races = set(test_df['race'].dropna().unique()) - set(unique_classes)
if unseen_races:
    raise ValueError(
        'Race categories present in test but not in train: '
        f'{sorted(map(str, unseen_races))}'
    )
enc_race_test = np.asarray(
    pd.Categorical(test_df['race'], categories=unique_classes).codes,
    dtype=enc_race_train.dtype,
)

# Ten repetitions of the race probe on the orthogonalized embeddings.
results = []
for _ in tqdm(range(10)):
    r = train_race_regressor(
        evaluator.train_emb_ortho,
        enc_race_train,
        evaluator.test_emb_ortho,
        enc_race_test,
    )
    results.append(r)

In [None]:
# Per-class metric dicts measured on the held-out split of every repetition.
test_tuples = [e['val'] for e in results]

# Position i in unique_classes is the integer label i used as key in the metric dicts.
for i, cls in enumerate(unique_classes):
    print(cls, '----------------------')
    metrics = [t[i] for t in test_tuples]

    aucs = [m['AUC'] for m in metrics]
    sens = [m['SENS'] for m in metrics]
    specs = [m['SPEC'] for m in metrics]

    # Raw string: '\p' is not a valid escape sequence, so a plain literal triggers
    # a SyntaxWarning on recent Python versions. The printed LaTeX row is unchanged.
    msg = (r'{:.3f} $\pm$ {:.3f} & {:.3f} $\pm$ {:.3f} & {:.3f} $\pm$ {:.3f}'
           .format(np.mean(aucs), np.std(aucs), np.mean(sens), np.std(sens), np.mean(specs), np.std(specs)))
    print(msg)