In [1]:
import sys
import argparse
import random
import copy
import pyro
import torch
import os
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.metrics import roc_auc_score, accuracy_score
import numpy as np
import os
# os.environ['CUDA_DEVICE_ORDER']='PCI_BUS_ID'
# os.environ['CUDA_VISIBLE_DEVICES']='0'
sys.path.append('..')
sys.path.append('../..')
from causal_models.trainer import preprocess_batch
from train_setup import setup_directories, setup_tensorboard, setup_logging
from train_setup import setup_dataloaders
# From datasets import get_attr_max_min
from utils import EMA, seed_all
import torch.nn.functional as F
from pgm.train_pgm_segmentor import sup_epoch, eval_epoch
from pgm.utils_pgm import check_nan, update_stats, calculate_loss, plot_cf
from pgm.layers import TraceStorage_ELBO
from pgm.chest_pgm_segmentor import FlowPGM_with_seg

from dscm.train_cf import preprocess

### Load checkpoint

In [2]:
class Hparams:
    """Plain attribute container.

    `update` copies every entry of a mapping onto the instance as an attribute,
    so hyperparameters can be read with attribute access (``args.batch_size``).
    """

    def update(self, dict):
        """Set one attribute per (key, value) pair of `dict`, overwriting existing ones.

        Note: the parameter name shadows the builtin `dict`; it is kept as-is so
        keyword callers (``update(dict=...)``) keep working.
        """
        pairs = dict.items()
        for pair in pairs:
            # Each pair is (attribute_name, value).
            setattr(self, *pair)

# Load the trained segmentation predictor and restore the EMA model weights
# stored in its checkpoint.
exp_name = "sup_aux_mimic_256_64_segmentor"
predictor_path = f'../pgm/checkpoints/Left-Lung_volume_Right-Lung_volume_Heart_volume/{exp_name}/checkpoint.pt'
print(f'\nLoading predictor checkpoint: {predictor_path}')
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
predictor_checkpoint = torch.load(predictor_path)
# Rebuild the hyperparameter namespace the model was trained with.
predictor_args = Hparams()
predictor_args.update(predictor_checkpoint['hparams'])

predictor = FlowPGM_with_seg(predictor_args).cuda()
predictor.load_state_dict(predictor_checkpoint['ema_model_state_dict'])
# The predictor is only used for inference in this notebook. Switch to eval
# mode immediately so layers like dropout/batch-norm (if present) don't run in
# training mode when a later cell calls the model directly.
predictor.eval()

# predictor_args is also what the dataloaders are built from, so this sets
# their batch size.
predictor_args.batch_size = 5



Loading predictor checkpoint: ../pgm/checkpoints/Left-Lung_volume_Right-Lung_volume_Heart_volume/sup_aux_mimic_256_64_segmentor/checkpoint.pt


### Dataloader

In [3]:
# Build the dataloaders from the checkpoint hparams (this notebook later uses the
# 'valid' and 'test' entries), with caching disabled and no training-set shuffle.
dataloaders = setup_dataloaders(predictor_args, shuffle_train=False, cache=False)

{'project_name': 'chest X ray generation', 'seed': 11, 'mixed_precision': False, 'is_unit_test_config': False, 'data': {'batch_size': 5, 'num_workers': 12, 'pin_memory': True, 'input_channels': 1, 'weights': 'None', 'augmentations': {'resize': [256, 64], 'center_crop': 'None', 'random_rotation': 0, 'horizontal_flip': False, 'vertical_flip': False, 'random_crop': 'None', 'random_color_jitter': 0.1, 'random_erase_scale': [0.0, 0.0], 'sharp': 0.0}, 'prop_train': 1.0, '_target_': 'data_handling.chest_xray.MimicDataModule', 'dataset': 'mimic', 'domain': 'None', 'cache': False, 'seg_target_list': ['Left-Lung', 'Right-Lung', 'Heart']}, 'trainer': {'name': 'base', 'lr': 0.001, 'num_epochs': 400, 'patience_for_scheduler': 10, 'metric_to_monitor': 'Val/loss', 'metric_to_monitor_mode': 'min', 'val_check_interval': 'None', 'weight_decay': 0.0, 'use_train_augmentations': True, 'loss': 'ce', 'contrastive_temperature': 0.1, 'return_two_views': False, 'finetune_path': 'None', 'device': [0], 'max_steps

  df['disease'] = df['disease'].replace({'No Finding': 0, 'Pleural Effusion': 1})


sex
Male      0.516252
Female    0.483748
Name: proportion, dtype: float64
disease
No Finding          0.541031
Pleural Effusion    0.458969
Name: proportion, dtype: float64
race
White    0.764848
Black    0.188804
Asian    0.046348
Name: proportion, dtype: float64
Len dataset 9968
Test df: 


  df['disease'] = df['disease'].replace({'No Finding': 0, 'Pleural Effusion': 1})


sex
Male      0.522024
Female    0.477976
Name: proportion, dtype: float64
disease
No Finding          0.546651
Pleural Effusion    0.453349
Name: proportion, dtype: float64
race
White    0.772491
Black    0.186999
Asian    0.040511
Name: proportion, dtype: float64
Len dataset 30535


  df['disease'] = df['disease'].replace({'No Finding': 0, 'Pleural Effusion': 1})


In [5]:
# Pull a single validation batch to inspect its fields and tensor shapes.
batch = next(iter(dataloaders["valid"]))

In [7]:
# Which fields does a batch carry, and what shape is one volume target?
print(batch.keys())
print(batch['Left-Lung_volume'].shape)

dict_keys(['shortpath', 'age', 'race', 'sex', 'finding', 'x', 'Left-Lung', 'Left-Lung_volume', 'Right-Lung', 'Right-Lung_volume', 'Heart', 'Heart_volume', 'pa'])
torch.Size([5, 1])


### Evaluate segmentor

In [4]:
# Evaluate the segmentor on one split. Derive the printed label from the split
# actually evaluated so the log cannot claim "test" while using validation data.
eval_split = 'valid'
test_stats = eval_epoch(predictor, dataloaders[eval_split])
print(f'{eval_split} | ' + ' - '.join(f'{k}: {v:.4f}' for k, v in test_stats.items()))

 26%|██▌       | 505/1966 [00:35<01:41, 14.38it/s]


KeyboardInterrupt: 

### Plot segmentations

In [None]:
def plot_segmentations(orig_batch, pred_batch, num_rows=5):
    """Plot input images next to ground-truth and predicted segmentation masks.

    Each row is one sample. Columns are the observation, then the ground-truth
    masks (left lung, right lung, heart), then the predicted masks in the same
    order.

    Args:
        orig_batch: dict holding the image tensor under 'x' and ground-truth
            masks under 'Left-Lung', 'Right-Lung' and 'Heart'. Tensors are
            indexed as ``[i, 0]``, i.e. assumed (batch, channel, H, W) — TODO confirm.
        pred_batch: dict with predicted masks under the same three keys.
        num_rows: maximum number of samples (rows) to draw. Clamped to the
            batch size so a short batch does not raise IndexError.

    Returns:
        The matplotlib Figure. The caller is responsible for saving and closing it.

    Raises:
        KeyError: if either batch lacks one of the segmentation keys.
    """
    seg_keys = ["Left-Lung", "Right-Lung", "Heart"]
    seg_titles = ["Left Lung", "Right Lung", "Heart"]
    s = 3    # panel size in inches
    fs = 12  # base font size for column titles

    # Map the image back to a 0..255 display range; assumes 'x' is normalised to [-1, 1].
    x = (orig_batch['x'].detach().cpu() + 1) * 127.5
    gt_segs = {k: orig_batch[k].detach().cpu() for k in seg_keys}
    pred_segs = {k: pred_batch[k].detach().cpu() for k in seg_keys}

    n = min(num_rows, x.shape[0])
    n_seg = len(seg_keys)
    m = 1 + 2 * n_seg  # observation + GT masks + predicted masks

    # squeeze=False keeps `ax` 2-D even when n == 1, so ax[i, j] is always valid.
    fig, ax = plt.subplots(n, m, figsize=(m * s, n * s), squeeze=False)

    for i in range(n):
        ax[i, 0].imshow(x[i, 0], cmap="gray")
        for j, key in enumerate(seg_keys):
            ax[i, 1 + j].imshow(gt_segs[key][i, 0], cmap="gray")
            ax[i, 1 + n_seg + j].imshow(pred_segs[key][i, 0], cmap="gray")
        for panel in ax[i]:
            panel.set_xticks([])
            panel.set_yticks([])

    titles = (['Observation']
              + [f'GT {t}' for t in seg_titles]
              + [f'Pred {t}' for t in seg_titles])
    for j, title in enumerate(titles):
        ax[0, j].set_title(title, fontsize=fs + 2, pad=8)
    return fig


In [None]:
# Save GT-vs-predicted segmentation figures for the first test batches. The
# last batch's predicted volumes stay in `pred_volumes` for the cells below.
save_dir = f"pred_segs/{exp_name}/"
os.makedirs(save_dir, exist_ok=True)
plt_count = 0

# Inference only: put the model in eval mode and disable autograd tracking.
# Otherwise a computation graph is built and kept alive for every batch.
predictor.eval()
with torch.no_grad():
    for batch in tqdm(dataloaders['test']):
        batch = preprocess(batch)
        pred_segs = predictor.predict_segmentations(**batch)
        plt_count += 1

        pred_volumes = predictor.predict_volumes(**batch)
        save_path = os.path.join(save_dir, f'test_segmentations_{plt_count}.png')
        fig = plot_segmentations(orig_batch=batch,
                                 pred_batch=pred_segs,
                                 num_rows=5)
        # Save and close this specific figure, rather than whatever pyplot
        # considers "current", so memory does not grow across iterations.
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close(fig)
        # NOTE(review): `> 20` stops after 21 figures; use `>= 20` if 20 was intended.
        if plt_count > 20:
            break

In [None]:
# Predicted volumes for the last batch processed by the plotting loop.
for _vol_key in ('Left-Lung_volume', 'Right-Lung_volume', 'Heart_volume'):
    print(pred_volumes[_vol_key])

In [None]:
# Ground-truth volumes of the most recent `batch` (first column of each tensor),
# to compare against the predicted volumes.
for _vol_key in ('Left-Lung_volume', 'Right-Lung_volume', 'Heart_volume'):
    print(batch[_vol_key][:, 0])