# Predict Protected Features


In [1]:
import torch
%load_ext autoreload
%autoreload 2

import os

import numpy as np
from src.eval import EmbeddingEvaluator, Disease

# Input locations. To probe a different encoder, point EMBEDDING_FILE at one of
# the other embedding files previously used here: 'mimic_cfm.npy' or
# 'mimic_densenet_mimic.npy'. META_FILE is only read by the MIMIC metadata loader.
DATA_DIR, META_FILE = '../data', 'mimic_meta.csv'
EMBEDDING_FILE = 'chex_densenet_chex.npy'

## Load Metadata and Embedding

In [2]:
from src.utils import get_mimic_meta_data, get_chexpert_meta_data

# Metadata splits. For MIMIC instead of CheXpert, use:
#   get_mimic_meta_data(os.path.join(DATA_DIR, META_FILE))
train_df, val_df, test_df = get_chexpert_meta_data(DATA_DIR)
print(f'DATASET SIZES: TRAIN {len(train_df)} | VAL {len(val_df)} | TEST {len(test_df)}')

# Full embedding matrix with non-finite entries replaced (np.nan_to_num); each
# split picks its rows using the `idx` column of its metadata frame.
emb = np.nan_to_num(np.load(os.path.join(DATA_DIR, EMBEDDING_FILE)))
train_emb, test_emb = emb[train_df['idx']], emb[test_df['idx']]

DATASET SIZES: TRAIN 76205 | VAL 12673 | TEST 38240


In [3]:
# Bundle both splits' metadata and embeddings; later cells read
# evaluator.train_emb / evaluator.test_emb and their `_ortho` counterparts.
evaluator = EmbeddingEvaluator(
    train_df,
    test_df,
    train_emb,
    test_emb,
)

## Predict Age

In [4]:
from src.net import TensorDataset, ClassificationModule
from torch import nn, Tensor
from typing import Sequence, Tuple
import pytorch_lightning as L
import warnings
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error

from torch.utils.data import DataLoader

L.seed_everything(1337424242)


def train_age_regressor(
    x_train: np.ndarray,
    y_train: Sequence,
    x_val: Tensor,
    y_val: Sequence,
    max_epochs: int = 10,
    batch_size: int = 256,
) -> Tuple[np.ndarray, np.ndarray]:
    """Train a linear age probe on embeddings, print MAE / R^2, return predictions.

    A single ``nn.Linear(x_train.shape[1], 1)`` layer, wrapped in
    ``ClassificationModule`` with ``nn.MSELoss``, is fitted with a PyTorch
    Lightning ``Trainer``. The fitted model is then evaluated on the training
    split (re-read without shuffling) and on the validation split.

    Args:
        x_train: Training embeddings; ``x_train.shape[1]`` is the input size.
        y_train: Training ages, one per row of ``x_train``.
        x_val: Validation embeddings (this notebook passes the test split).
        y_val: Validation ages, one per row of ``x_val``.
        max_epochs: Number of training epochs.
        batch_size: Training batch size; validation batches are twice as large.

    Returns:
        ``(train_pred, val_pred)``: predictions for each split, in input order,
        multiplied by 100 exactly like the targets used for the printed metrics.
        NOTE(review): the x100 presumably undoes an age/100 scaling done inside
        ``TensorDataset`` -- confirm in ``src.net``.
    """
    train_dataset = TensorDataset(x_train, y_train)
    val_dataset = TensorDataset(x_val, y_val)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=2 * batch_size)

    model = ClassificationModule(
        model=nn.Linear(x_train.shape[1], 1),
        loss_func=nn.MSELoss()
    )

    trainer = L.Trainer(max_epochs=max_epochs)
    trainer.fit(model, train_loader, val_loader)

    model.eval()
    # Unshuffled copy of the train loader so the returned train predictions line
    # up with the rows of x_train / y_train.
    eval_train_loader = DataLoader(train_dataset, batch_size=batch_size)
    predictions = {}
    with torch.no_grad():
        for loader, name in [(eval_train_loader, 'train'), (val_loader, 'val')]:
            pred_batches = []
            true_batches = []
            for x_batch, y_batch in loader:
                pred_batches.append(model(x_batch).detach().cpu().numpy() * 100)
                true_batches.append(y_batch.cpu().numpy() * 100)
            y_pred = np.concatenate(pred_batches)
            y_true = np.concatenate(true_batches)
            predictions[name] = y_pred

            mae = mean_absolute_error(y_true, y_pred)
            r2 = r2_score(y_true, y_pred)
            print(f'{name} MAE: {mae:.4f} | R2: {r2:.4f}')

    # Previously the predictions were collected and then discarded, despite the
    # Tuple return annotation.
    return predictions['train'], predictions['val']


Global seed set to 1337424242


### Age NON-ortho

In [5]:
# Age probe on the plain (non-`_ortho`) embeddings.
train_age_regressor(
    x_train=evaluator.train_emb,
    y_train=train_df['age'].tolist(),
    x_val=evaluator.test_emb,
    y_val=test_df['age'].tolist(),
);

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name      | Type    | Params
--------------------------------------
0 | model     | Linear  | 1.0 K 
1 | loss_func | MSELoss | 0     
--------------------------------------
1.0 K     Trainable params
0         Non-trainable params
1.0 K     Total params
0.004     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.


train MAE: 10.5945 | R2: 0.4089
val MAE: 10.6137 | R2: 0.3897


### Age ortho


In [6]:
# Age probe on the `_ortho` embeddings.
train_age_regressor(
    x_train=evaluator.train_emb_ortho,
    y_train=train_df['age'].tolist(),
    x_val=evaluator.test_emb_ortho,
    y_val=test_df['age'].tolist(),
);

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name      | Type    | Params
--------------------------------------
0 | model     | Linear  | 1.0 K 
1 | loss_func | MSELoss | 0     
--------------------------------------
1.0 K     Trainable params
0         Non-trainable params
1.0 K     Total params
0.004     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.


train MAE: 14.0806 | R2: -0.0065
val MAE: 13.8702 | R2: -0.0082


## Predict Sex

In [7]:

from src.utils import eval_predictions

from torch.utils.data import DataLoader

L.seed_everything(1337424242)


def train_sex_regressor(
    x_train: np.ndarray,
    y_train: Sequence,
    x_val: Tensor,
    y_val: Sequence,
    max_epochs: int = 10,
    batch_size: int = 256,
) -> Tuple[np.ndarray, np.ndarray]:
    """Train a linear binary probe on embeddings, print metrics, return probabilities.

    A single ``nn.Linear(x_train.shape[1], 1)`` layer, wrapped in
    ``ClassificationModule`` with ``nn.BCEWithLogitsLoss``, is fitted with a
    PyTorch Lightning ``Trainer``. Sigmoid probabilities for the training split
    (re-read without shuffling) and the validation split are passed to
    ``eval_predictions``, which prints the metrics.

    Args:
        x_train: Training embeddings; ``x_train.shape[1]`` is the input size.
        y_train: Binary training targets in {0, 1} (this notebook encodes 'M' as 1).
        x_val: Validation embeddings (this notebook passes the test split).
        y_val: Binary validation targets in {0, 1}.
        max_epochs: Number of training epochs.
        batch_size: Training batch size; validation batches are twice as large.

    Returns:
        ``(train_prob, val_prob)``: sigmoid outputs for each split, in input order.
    """
    train_dataset = TensorDataset(x_train, y_train)
    val_dataset = TensorDataset(x_val, y_val)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=2 * batch_size)

    model = ClassificationModule(
        model=nn.Linear(x_train.shape[1], 1),
        loss_func=nn.BCEWithLogitsLoss()
    )

    trainer = L.Trainer(max_epochs=max_epochs)
    trainer.fit(model, train_loader, val_loader)

    model.eval()
    # Unshuffled copy of the train loader so the returned train probabilities
    # line up with the rows of x_train / y_train.
    eval_train_loader = DataLoader(train_dataset, batch_size=batch_size)
    predictions = {}
    with torch.no_grad():
        for loader, name in [(eval_train_loader, 'train'), (val_loader, 'val')]:
            prob_batches = []
            true_batches = []
            for x_batch, y_batch in loader:
                # The loss works on logits, so apply the sigmoid explicitly here.
                logits = model(x_batch)
                prob_batches.append(torch.sigmoid(logits).detach().cpu().numpy())
                true_batches.append(y_batch.cpu().numpy())
            y_prob = np.concatenate(prob_batches)
            y_true = np.concatenate(true_batches)
            predictions[name] = y_prob

            eval_predictions(y_true, y_prob)

    # Previously the probabilities were collected and then discarded, despite
    # the Tuple return annotation.
    return predictions['train'], predictions['val']


Global seed set to 1337424242


### Sex NON-ortho

In [8]:
# Sex probe on the plain (non-`_ortho`) embeddings; 'M' -> 1, anything else -> 0.
train_sex_regressor(
    x_train=evaluator.train_emb,
    y_train=np.where(train_df['sex'] == 'M', 1, 0),
    x_val=evaluator.test_emb,
    y_val=np.where(test_df['sex'] == 'M', 1, 0),
);

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name      | Type              | Params
------------------------------------------------
0 | model     | Linear            | 1.0 K 
1 | loss_func | BCEWithLogitsLoss | 0     
------------------------------------------------
1.0 K     Trainable params
0         Non-trainable params
1.0 K     Total params
0.004     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.


METRICS:	AUC 0.8640 | ACC 0.7859 | SENS 0.8417 | SPEC 0.7065 | PREC 0.8034 | F1 0.8221
METRICS:	AUC 0.8683 | ACC 0.7900 | SENS 0.8472 | SPEC 0.7102 | PREC 0.8029 | F1 0.8245


### Sex ortho


In [9]:
# Sex probe on the `_ortho` embeddings; 'M' -> 1, anything else -> 0.
train_sex_regressor(
    x_train=evaluator.train_emb_ortho,
    y_train=np.where(train_df['sex'] == 'M', 1, 0),
    x_val=evaluator.test_emb_ortho,
    y_val=np.where(test_df['sex'] == 'M', 1, 0),
);

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name      | Type              | Params
------------------------------------------------
0 | model     | Linear            | 1.0 K 
1 | loss_func | BCEWithLogitsLoss | 0     
------------------------------------------------
1.0 K     Trainable params
0         Non-trainable params
1.0 K     Total params
0.004     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.


METRICS:	AUC 0.4950 | ACC 0.5875 | SENS 1.0000 | SPEC 0.0000 | PREC 0.5875 | F1 0.7402
METRICS:	AUC 0.5028 | ACC 0.5822 | SENS 1.0000 | SPEC 0.0000 | PREC 0.5822 | F1 0.7360


In [10]:

from src.utils import eval_predictions, eval_predictions_multiclass

from torch.utils.data import DataLoader

L.seed_everything(1337424242)


def train_race_regressor(
        x_train: np.ndarray,
        y_train: Sequence,
        x_val: Tensor,
        y_val: Sequence,
        max_epochs: int = 10,
        batch_size: int = 256,
        num_classes: int = 3,
) -> Tuple[np.ndarray, np.ndarray]:
    """Train a linear multiclass probe on embeddings, print metrics, return probabilities.

    A single ``nn.Linear(x_train.shape[1], num_classes)`` layer, wrapped in
    ``ClassificationModule`` with ``nn.CrossEntropyLoss``, is fitted with a
    PyTorch Lightning ``Trainer``. Softmax probabilities for the training split
    (re-read without shuffling) and the validation split are passed to
    ``eval_predictions_multiclass``, which prints the metrics.

    Args:
        x_train: Training embeddings; ``x_train.shape[1]`` is the input size.
        y_train: Integer class codes in ``[0, num_classes)`` (loaded as torch.long).
        x_val: Validation embeddings (this notebook passes the test split).
        y_val: Integer class codes, encoded with the same mapping as ``y_train``.
        max_epochs: Number of training epochs.
        batch_size: Training batch size; validation batches are twice as large.
        num_classes: Number of output classes (defaults to the original 3).

    Returns:
        ``(train_prob, val_prob)``: per-class softmax probabilities for each
        split, one row per sample, in input order.
    """
    train_dataset = TensorDataset(x_train, y_train, label_dtype=torch.long)
    val_dataset = TensorDataset(x_val, y_val, label_dtype=torch.long)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=2 * batch_size)

    model = ClassificationModule(
        model=nn.Linear(x_train.shape[1], num_classes),
        loss_func=nn.CrossEntropyLoss()
    )

    trainer = L.Trainer(max_epochs=max_epochs)
    trainer.fit(model, train_loader, val_loader)

    model.eval()
    # Remove shuffle for train predictions. This loader was missing, so train
    # predictions came out in shuffled order, misaligned with x_train's rows.
    eval_train_loader = DataLoader(train_dataset, batch_size=batch_size)
    predictions = {}
    with torch.no_grad():
        for loader, name in [(eval_train_loader, 'train'), (val_loader, 'val')]:
            prob_batches = []
            true_batches = []
            for x_batch, y_batch in loader:
                logits = model(x_batch)
                prob_batches.append(torch.softmax(logits, dim=1).detach().cpu().numpy())
                true_batches.append(y_batch.cpu().numpy())
            y_prob = np.concatenate(prob_batches)
            y_true = np.concatenate(true_batches)
            predictions[name] = y_prob

            eval_predictions_multiclass(y_true, y_prob)

    # Previously the probabilities were collected and then discarded, despite
    # the Tuple return annotation.
    return predictions['train'], predictions['val']


Global seed set to 1337424242


## Race non-ortho

In [11]:
import pandas as pd

# Encode race with ONE mapping derived from the train split. pd.factorize numbers
# values by order of first appearance, so factorizing train and test separately
# can give the same race different integer labels in the two splits.
race_train_codes, race_categories = pd.factorize(train_df['race'])
race_test_codes = race_categories.get_indexer(test_df['race'])
assert (race_train_codes >= 0).all(), 'train split has missing race values'
assert (race_test_codes >= 0).all(), 'test split has race values missing from the train encoding'

# Race probe on the plain (non-`_ortho`) embeddings.
train_race_regressor(
    evaluator.train_emb,
    race_train_codes,
    evaluator.test_emb,
    race_test_codes,
);

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name      | Type             | Params
-----------------------------------------------
0 | model     | Linear           | 3.1 K 
1 | loss_func | CrossEntropyLoss | 0     
-----------------------------------------------
3.1 K     Trainable params
0         Non-trainable params
3.1 K     Total params
0.012     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.


Specificity per class: [0.05522484823480875, 0.9882777554986582, 0.9998300499936269]
METRICS:	AUC 0.7311 | ACC 0.7795 | SENS 0.3536 | SPEC 0.6811 | PREC 0.6413 | F1 0.3340
Specificity per class: [0.05562172463077656, 0.9883399815894446, 0.9996900884656562]
METRICS:	AUC 0.7245 | ACC 0.7824 | SENS 0.3539 | SPEC 0.6812 | PREC 0.5511 | F1 0.3352


## Race ortho

In [12]:
import pandas as pd

# Encode race with ONE mapping derived from the train split. pd.factorize numbers
# values by order of first appearance, so factorizing train and test separately
# can give the same race different integer labels in the two splits.
race_train_codes, race_categories = pd.factorize(train_df['race'])
race_test_codes = race_categories.get_indexer(test_df['race'])
assert (race_train_codes >= 0).all(), 'train split has missing race values'
assert (race_test_codes >= 0).all(), 'test split has race values missing from the train encoding'

# Race probe on the `_ortho` embeddings.
train_race_regressor(
    evaluator.train_emb_ortho,
    race_train_codes,
    evaluator.test_emb_ortho,
    race_test_codes,
);

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name      | Type             | Params
-----------------------------------------------
0 | model     | Linear           | 3.1 K 
1 | loss_func | CrossEntropyLoss | 0     
-----------------------------------------------
3.1 K     Trainable params
0         Non-trainable params
3.1 K     Total params
0.012     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.
  _warn_prf(average, modifier, msg_start, len(result))


Specificity per class: [0.0, 1.0, 1.0]
METRICS:	AUC 0.5021 | ACC 0.7774 | SENS 0.3333 | SPEC 0.6667 | PREC 0.2591 | F1 0.2916
Specificity per class: [0.0, 1.0, 1.0]
METRICS:	AUC 0.5136 | ACC 0.7804 | SENS 0.3333 | SPEC 0.6667 | PREC 0.2601 | F1 0.2922


  _warn_prf(average, modifier, msg_start, len(result))
