In [7]:
import os
import pandas as pd
import torch
import numpy as np
from pytorch_lightning.core.lightning import LightningModule
from sklearn.linear_model import LogisticRegression, SGDClassifier
import torch.nn.functional as F
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

In [242]:
data_path = '/share/pi/nigam/projects/spfohl/cohorts/admissions/optum/'
predictions_path = os.path.join(data_path, 'experiments', 'baseline_tuning_fold_1_10', 'performance', 'LOS_7', '0.yaml', '1', 'output_df.parquet')
cohort_path = os.path.join(data_path, 'cohort', 'cohort.parquet')
row_id_map_path = os.path.join(data_path, 'merged_features_binary', 'features_sparse', 'features_row_id_map.parquet')

In [227]:
pred_df = pd.read_parquet(predictions_path)

In [238]:
cohort = pd.read_parquet(cohort_path)

In [243]:
row_id_map = pd.read_parquet(row_id_map_path)

In [245]:
row_id_map

Unnamed: 0,features_row_id,prediction_id
0,0,-9223363366502185856
1,1,-9223362673657268414
2,2,-9223362053444238164
3,3,-9223361452100897490
4,4,-9223360812677213811
...,...,...
8074566,8074566,9223364539900589339
8074567,8074567,9223365295068207882
8074568,8074568,9223366169655560425
8074569,8074569,9223367109999326759


In [244]:
cohort.head()

Unnamed: 0,person_id,admit_date,discharge_date,hospital_mortality,month_mortality,LOS_days,LOS_7,readmission_30,age_in_years,age_group,race_eth,gender_concept_name,prediction_id,fold_id
0,3,2017-09-12,2017-09-14,0,0,2,0,0,58,[55-65),Other,MALE,1669980515859799912,test
1,4,2007-09-17,2007-09-18,0,0,1,0,0,29,[18-30),Other,MALE,-3060212163314143226,test
2,11,2007-06-17,2007-06-18,0,0,1,0,0,27,[18-30),Other,FEMALE,2688580955788010273,7
3,14,2009-03-02,2009-03-04,0,0,2,0,0,35,[30-45),Other,FEMALE,-751964920177342014,3
4,16,2005-02-20,2005-02-22,0,0,2,0,0,62,[55-65),Other,FEMALE,-470516056504384768,8


In [228]:
pred_df

Unnamed: 0,phase,outputs,pred_probs,labels,row_id
0,val,-0.311713,0.334550,1,31
1,val,-0.869886,0.116189,0,36
2,val,-1.979475,0.009725,0,37
3,val,-0.322860,0.325983,1,51
4,val,-0.782082,0.135438,0,59
...,...,...,...,...,...
1533932,test,-0.252077,0.366276,1,8074549
1533933,test,-0.385773,0.296091,0,8074550
1533934,test,-0.637188,0.191029,0,8074552
1533935,test,-0.464459,0.253941,0,8074554


In [229]:
class LogProbModel(LightningModule):

    def __init__(self, apply_log_transform=True):
        super().__init__()
        self.layer = torch.nn.Linear(1, 2, bias=True)
        self.apply_log_transform = apply_log_transform

    def forward(self, x):
        return self.layer(x)
#         return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        x = x.unsqueeze(1)
#         print(x)
        if self.apply_log_transform:
            x = torch.log(x)
        y_hat = self.forward(x)
        loss = F.cross_entropy(y_hat, y)
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)
    
class LogProbModel2(LightningModule):

    def __init__(self, apply_log_transform=True):
        super().__init__()
        self.layer = torch.nn.Linear(1, 1, bias=True)
        self.apply_log_transform = apply_log_transform

    def forward(self, x):
        return self.layer(x)
    
    def forward_on_batch(self, batch):
        x, y = batch
        x = x.unsqueeze(1)
        y = y.unsqueeze(1)
        if self.apply_log_transform:
            x = torch.log(x)
        y_hat = self.forward(x)
        return F.binary_cross_entropy_with_logits(y_hat, y)
    
    def training_step(self, batch, batch_idx):
        return {'loss': self.forward_on_batch(batch)}
    


#     def configure_optimizers(self):
#         optimizer = torch.optim.Adam(self.parameters(), lr=0.1)
#         return optimizer

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.1)
        scheduler = {
            'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer),
            'monitor': 'val_loss',
            'interval': 'epoch',
            'frequency': 1
        }
        return [optimizer], [scheduler]

    def validation_step(self, batch, batch_idx):
        return {'val_loss': self.forward_on_batch(batch)}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'val_loss': avg_loss}


train_df = pred_df.query('phase == "val"')
val_df = pred_df.query('phase == "test"')
train_dataset = torch.utils.data.TensorDataset(torch.FloatTensor(train_df.pred_probs), torch.FloatTensor(train_df.labels))
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1024, num_workers=4)
val_dataset = torch.utils.data.TensorDataset(torch.FloatTensor(val_df.pred_probs.values), torch.FloatTensor(val_df.labels.values))
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1024, num_workers=4)
model = LogProbModel2()

early_stop_callback = EarlyStopping(
    monitor='val_loss',
    min_delta=0.0,
    verbose=True,
    patience=50
)

trainer = Trainer(
    max_epochs=100, 
    early_stop_callback=early_stop_callback, 
    progress_bar_refresh_rate=0, 
    checkpoint_callback=False, 
    logger=False,
    gpus=1 if torch.cuda.is_available() else 0
)
trainer.fit(model, train_loader, val_loader)
print(model.layer.weight)
print(model.layer.bias)

EarlyStopping mode auto is unknown, fallback to auto mode.
EarlyStopping mode set to min for monitoring val_loss.
GPU available: False, used: False
No environment variable for node rank defined. Set as 0.

  | Name  | Type   | Params
-----------------------------
0 | layer | Linear | 2     
Detected KeyboardInterrupt, attempting graceful shutdown...


Parameter containing:
tensor([[2.4244]], requires_grad=True)
Parameter containing:
tensor([1.9988], requires_grad=True)


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f25a593ad40>
Traceback (most recent call last):
  File "/share/pi/nigam/envs/anaconda/envs/prediction_utils/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 962, in __del__
    self._shutdown_workers()
  File "/share/pi/nigam/envs/anaconda/envs/prediction_utils/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 942, in _shutdown_workers
    w.join()
  File "/share/pi/nigam/envs/anaconda/envs/prediction_utils/lib/python3.7/multiprocessing/process.py", line 140, in join
    res = self._popen.wait(timeout)
  File "/share/pi/nigam/envs/anaconda/envs/prediction_utils/lib/python3.7/multiprocessing/popen_fork.py", line 48, in wait
    return self.poll(os.WNOHANG if timeout == 0.0 else 0)
  File "/share/pi/nigam/envs/anaconda/envs/prediction_utils/lib/python3.7/multiprocessing/popen_fork.py", line 28, in poll
    pid, sts = os.waitpid(self.pid, flag)
KeyboardInterrupt: 


In [233]:
train_df = pred_df.query('phase == "val"')

In [236]:
class CalibrationEvaluator:
    def get_calibration_df(
        self,
        df,
        group_vars=["config_filename", "phase", "task", "attribute", "group"],
        df_eval=None,
    ):
        group_vars = [var for var in group_vars if var in df.columns]
        model_dict = {}
        calibration_dict = {}
        for group, df_grouped in df.groupby(group_vars):
            df_grouped = df_grouped.query("pred_probs > 0")
            log_probs = np.log(df_grouped.pred_probs.values.reshape(-1, 1))
            model_dict[group] = LogisticRegression(solver="lbfgs", penalty="none")
            model_dict[group].fit(log_probs, df_grouped.labels.values)
            if df_eval is None:
                calibration_dict[group] = df_grouped.assign(
                    density_conditional_y1=model_dict[group].predict_proba(log_probs)[
                        :, -1
                    ]
                ).reset_index(drop=True)
            else:
                df_eval_grouped = self.filter_by_group_spec(df_eval, group_vars, group)
                df_eval_grouped = df_eval_grouped.query("pred_probs > 0")
                log_probs = np.log(df_eval_grouped.pred_probs.values.reshape(-1, 1))
                calibration_dict[group] = df_eval_grouped.assign(
                    density_conditional_y1=model_dict[group].predict_proba(log_probs)[
                        :, -1
                    ]
                ).reset_index(drop=True)
        calibration_df = pd.concat(calibration_dict).reset_index(drop=True)
        return calibration_df

    def get_calibration_df_combined(
        self,
        df,
        group_vars=[
            "sensitive_attribute",
            "config_filename",
            "phase",
            "task",
            "attribute",
        ],
    ):
        calibration_df_group = self.get_calibration_df(
            df, group_vars=group_vars + ["group"]
        )
        calibration_df_overall = self.get_calibration_df(df, group_vars=group_vars)
        calibration_df = calibration_df_group.merge(
            calibration_df_overall.rename(
                columns={"density_conditional_y1": "density_conditional_y1_overall"}
            )
        )
        return calibration_df

    def get_calibration_result(
        self, df, group_vars=["config_filename", "phase", "task", "attribute", "group"],
    ):
        group_vars = [var for var in group_vars if var in df.columns]
        return (
            df.assign(
                brier_diff_signed=lambda x: x.labels - x.pred_probs,
                brier_diff_squared=lambda x: x.brier_diff_signed ** 2,
                calib_diff_signed=lambda x: x.density_conditional_y1 - x.pred_probs,
                calib_diff_squared=lambda x: x.calib_diff_signed ** 2,
                calib_density_diff_signed=lambda x: x.density_conditional_y1
                - x.density_conditional_y1_overall,
                calib_density_diff_squared=lambda x: x.calib_density_diff_signed ** 2,
            )
            .groupby(group_vars)
            .agg(
                brier=("brier_diff_squared", lambda x: x.mean()),
                brier_signed=("brier_diff_signed", lambda x: x.mean()),
                calib_error=("calib_diff_squared", lambda x: x.mean()),
                calib_error_signed=("calib_diff_signed", lambda x: x.mean()),
                calib_group_error=("calib_density_diff_squared", lambda x: x.mean()),
                calib_group_error_signed=(
                    "calib_density_diff_signed",
                    lambda x: x.mean(),
                ),
            )
            .reset_index()
        )

    def filter_by_group_spec(self, df, group_vars, group_values):
        for group_var, group_value in zip(group_vars, group_values):
            df = df.loc[df[group_var] == group_value]
        return df

In [235]:
model_sk = LogisticRegression(solver="lbfgs", penalty="none")
%time model_sk.fit(np.log(train_df.pred_probs.values.reshape(-1, 1)), train_df.labels.values)
print([model_sk.coef_, model_sk.intercept_])

model_sk = SGDClassifier(loss='log', learning_rate='adaptive', eta0=1e-1, max_iter=1000, alpha=1e-12)

%time model_sk.fit(np.log(train_df.pred_probs.values.reshape(-1, 1)), train_df.labels.values)
print([model_sk.coef_, model_sk.intercept_])

CPU times: user 10.6 s, sys: 12.5 s, total: 23.1 s
Wall time: 925 ms
[array([[2.44633899]]), array([1.98917973])]
CPU times: user 8.53 s, sys: 2.65 s, total: 11.2 s
Wall time: 6.5 s
[array([[2.44681725]]), array([1.98977588])]


In [220]:
big_train_df = pd.concat([train_df for i in range(100)])

In [221]:
big_train_df

Unnamed: 0,phase,outputs,pred_probs,labels,row_id
0,val,-1.094916,0.079028,0,16
1,val,-2.456556,0.003679,0,21
2,val,-3.585434,0.000268,0,24
3,val,-0.885466,0.112826,0,35
4,val,-1.538378,0.027455,0,37
...,...,...,...,...,...
17873,val,-1.423776,0.037314,0,198601
17874,val,-1.970004,0.011250,0,198627
17875,val,-3.218438,0.000618,0,198629
17876,val,-2.217583,0.006336,0,198635


In [225]:
model_sk = LogisticRegression(solver="lbfgs", penalty="none")
%time model_sk.fit(np.log(big_train_df.pred_probs.values.reshape(-1, 1)), big_train_df.labels.values)
print([model_sk.coef_, model_sk.intercept_])

model_sk = SGDClassifier(
    loss='log', 
    early_stopping=True, 
    learning_rate='adaptive', 
    eta0=1e-1, 
    max_iter=100, 
    alpha=1e-12
)

%time model_sk.fit(np.log(big_train_df.pred_probs.values.reshape(-1, 1)), big_train_df.labels.values)
print([model_sk.coef_, model_sk.intercept_])

CPU times: user 23.7 s, sys: 27 s, total: 50.7 s
Wall time: 1.97 s
[array([[0.85900524]]), array([-0.49485377])]
CPU times: user 25.2 s, sys: 3.6 s, total: 28.8 s
Wall time: 24.1 s
[array([[0.86293342]]), array([-0.48265425])]


We are interested in constructing an estimator of the integrated calibration index (ICI) and matching conditional frequency metric (MCF) on a test set. These measures rely on estimating nuisance densities of p(y|f(x)) for the full test set and for each group. For scalability reasons, it is desired to be able to train models for these densities using minibatch SGD, which performs best with a validation set for early stopping. Furthermore, this approach can be improved by using a cross-fitting estimator, repeating the entire process on separate folds of the data.

For the standard procedure (no cross-fitting)
    * Split data into 90/10 split, stratified by group and outcome
    * For each group, train a logistic regression model using the corresponding within-group 10% as early-stopping validation. Repeat for the aggregate sample.
    * Compute ICI and MCF as plug-in estimates
   
For the cross-fitting procedure
    * Split data in to K-folds, stratified by group and outcome
    * For each fold i, perform the standard estimation procedure, training with data not in fold i and evaluating on fold i
    * Take the mean of each value of interest across the folds