In [1]:
import sys
import argparse
import random
import copy
import pyro
import torch
import os
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.metrics import roc_auc_score, accuracy_score
import numpy as np
import os
os.environ['CUDA_DEVICE_ORDER']='PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES']='0'
sys.path.append('..')
sys.path.append('../..')
from train_setup import setup_directories, setup_tensorboard, setup_logging
from train_setup import setup_dataloaders
# From datasets import get_attr_max_min
from utils import EMA, seed_all
from dscm import DSCM
from hvae2 import HVAE2
import torch.nn.functional as F
from pgm.train_pgm import sup_epoch, eval_epoch
from pgm.utils_pgm import check_nan, update_stats, calculate_loss, plot_cf
from dscm import vae_preprocess
from pgm.layers import TraceStorage_ELBO
from pgm.chest_pgm import FlowPGM

from train_cf import norm, loginfo, preprocess, inv_preprocess

### Load predictors

In [None]:
class Hparams:
    """Minimal hyper-parameter namespace: mapping keys become attributes."""

    def update(self, dict):
        """Copy every (key, value) pair of ``dict`` onto this object as attributes."""
        for attr_name, attr_value in dict.items():
            setattr(self, attr_name, attr_value)

# Load the auxiliary predictor used to score counterfactuals (EMA weights).
predictor_path = '../pgm/checkpoints/scanner_sex_finding/sup_aux_padchest/checkpoint.pt'
print(f'\nLoading predictor checkpoint: {predictor_path}')
# map_location='cpu': only GPU 0 is visible (CUDA_VISIBLE_DEVICES), so a checkpoint
# whose tensors were saved on another device index would fail to deserialize.
# load_state_dict below copies the CPU tensors into the CUDA parameters.
predictor_checkpoint = torch.load(predictor_path, map_location='cpu')
predictor_args = Hparams()
predictor_args.update(predictor_checkpoint['hparams'])

# Override after update() so it wins over any value stored in the checkpoint.
predictor_args.loss_norm = "l2"

predictor = FlowPGM(predictor_args).cuda()
predictor.load_state_dict(predictor_checkpoint['ema_model_state_dict'])


### Load PGM

In [None]:
pgm_path = '../pgm/checkpoints/scanner_sex_finding/sup_pgm_padchest/checkpoint.pt'
print(f'\nLoading PGM checkpoint: {pgm_path}')
# Load onto the CPU so the checkpoint deserializes regardless of which GPU index it
# was saved from (only GPU 0 is visible here); weights are copied to CUDA below.
pgm_checkpoint = torch.load(pgm_path, map_location='cpu')
pgm_args = Hparams()
pgm_args.update(pgm_checkpoint['hparams'])
pgm = FlowPGM(pgm_args).cuda()
pgm.load_state_dict(pgm_checkpoint['ema_model_state_dict'])

### Load HVAE

In [None]:
# Load deep VAE
beta = 3
vae_path = f"../checkpoints/scanner_sex_finding/padchest224_224_beta_{beta}/checkpoint.pt"

print(f'\nLoading VAE checkpoint: {vae_path}')
# Load onto the CPU so the checkpoint deserializes regardless of which GPU index it
# was saved from (only GPU 0 is visible here); weights are copied to CUDA below.
vae_checkpoint = torch.load(vae_path, map_location='cpu')
vae_args = Hparams()
vae_args.update(vae_checkpoint['hparams'])

# Override *after* update(): set before it (as previously), any batch_size stored
# in the checkpoint silently replaced this value. vae_args is later passed to
# setup_dataloaders, which presumably reads batch_size — confirm.
vae_args.batch_size = 10

vae = HVAE2(vae_args).cuda()
vae.load_state_dict(vae_checkpoint['ema_model_state_dict'])

### Set up dataset

In [None]:
# Build the data loaders from the VAE hyper-parameters: no caching, shuffled train split.
dataloaders = setup_dataloaders(
    vae_args,
    cache=False,
    shuffle_train=True,
)

### Test predictor (the evaluation call in this cell is currently commented out)

In [6]:
# Single-particle ELBO estimator, passed to model.forward() in eval_counterfactuals.
# Optional predictor sanity check (disabled):
#   test_stats = eval_epoch(predictor, dataloaders['valid'])
#   print('test | ' + ' - '.join(f'{k}: {v:.4f}' for k, v in test_stats.items()))
elbo_fn = TraceStorage_ELBO(
    num_particles=1,
)

### Load DSCM

In [None]:
# `Hparams` is already defined once, in the "Load predictors" cell. Re-defining it
# here would silently shadow that class, so this cell reuses it instead.

args = Hparams()

dscm_dir = "padchest_beta_9_5_focus_finding_soft_lr_1e4_lagrange_lr_1_damping_10"
which_checkpoint = "6500_checkpoint"

args.load_path = f'checkpoints/scanner_sex_finding/{dscm_dir}/{which_checkpoint}.pt'
print(args.load_path)
# Load onto the CPU so the checkpoint deserializes regardless of which GPU index it
# was saved from (only GPU 0 is visible here); load_state_dict copies the weights.
dscm_checkpoint = torch.load(args.load_path, map_location='cpu')
args.update(dscm_checkpoint['hparams'])

# Apply overrides before constructing DSCM so they are in place if its __init__
# reads them (previously this was set after construction).
args.cf_particles = 1
model = DSCM(args, pgm, predictor, vae)
model.load_state_dict(dscm_checkpoint['ema_model_state_dict'])
model.cuda()

# Inference only: freeze every parameter.
for p in model.parameters():
    p.requires_grad = False


### Plot functions

In [8]:
_VMIN,_VMAX = -120, 120
def undo_norm(pa):
    for k, v in pa.items():
        if k =="age":
            pa[k] = (v + 1) / 2 *100 # [-1,1] -> [0,100]
    return pa

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colors

class MidpointNormalize(colors.Normalize):
    """Colour normaliser that sends ``midpoint`` to the centre (0.5) of the colormap.

    Values are mapped piecewise-linearly from ``[-v_ext, midpoint, v_ext]`` to
    ``[0, 0.5, 1]``, where ``v_ext = max(|vmin|, |vmax|)``. ``clip`` is accepted
    for API compatibility but not used by ``__call__``; the returned masked
    array carries no mask.
    """

    def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):
        self.midpoint = midpoint
        super().__init__(vmin, vmax, clip)

    def __call__(self, value, clip=None):
        half_span = np.abs([self.vmin, self.vmax]).max()
        anchors = [-half_span, self.midpoint, half_span]
        return np.ma.masked_array(np.interp(value, anchors, [0, 0.5, 1]))

def _diverging_cbar_ticks(x_min, x_max, step=20):
    """Colorbar tick positions for an image whose values span [x_min, x_max].

    The negative end is ``-abs(x_min)`` and the positive end ``x_max``, each
    rounded towards 0 to a multiple of ``step``; half-way ticks are inserted
    when either end exceeds ``step``. Ascending for x_min <= 0 <= x_max.
    """
    neg = abs(x_min)
    neg = -(neg - neg % step)
    pos = x_max - x_max % step
    ticks = [neg, 0, pos]
    if abs(neg) > step or pos > step:
        ticks.insert(1, neg // 2)
        # insert(-1, ...) puts the positive half-tick between 0 and `pos`; the
        # previous insert(-2, ...) landed it before 0, out of order.
        ticks.insert(-1, pos // 2)
    return ticks


@torch.no_grad()
def plot(x, fig=None, ax=None, nrows=1, cmap='Greys_r', norm=None, cbar=False, set_cbar_ticks=True, logger=None):
    """Show a batch of images on a grid of axes, optionally with colorbars.

    Args:
        x: tensor whose first dimension indexes images; each image is
            ``squeeze()``d before ``imshow`` (e.g. (N, 1, H, W) — TODO confirm).
        fig: figure that owns ``ax``; required for colorbars (falls back to the
            first axes' figure when None).
        ax: a single Axes, a 1-D sequence of Axes, or an (nrows, ncols) grid;
            created with ``plt.subplots`` when None.
        nrows: number of grid rows; columns are ``x.shape[0] // nrows``.
        cmap: colormap name passed to ``imshow``.
        norm: when not None, each image gets its own ``MidpointNormalize``
            centred at 0 over that image's min/max (the object passed in is
            only used as a flag).
        cbar: draw a horizontal colorbar beneath each image.
        set_cbar_ticks: use ticks rounded to multiples of 20 around 0; otherwise
            up to ~5 ticks in scientific notation.
        logger: unused; kept for call compatibility.

    Returns:
        ``(fig, ax)`` where ``ax`` is the object passed in (or created).
    """
    m, n = nrows, x.shape[0] // nrows
    if ax is None:
        fig, ax = plt.subplots(m, n, figsize=(n * 4, 8))
    # Normalise any accepted `ax` layout into an (m, ncols) object array once so
    # indexing is uniform. (Previously `[ax]` was re-wrapped on every row when
    # n == 1, and `ax[[i, j]]` fancy-indexed whole rows of a numpy grid.)
    grid = np.asarray(ax, dtype=object).reshape(m, -1)

    images = []
    for i in range(m):
        for j in range(n):
            axis = grid[i, j]
            img = x[i * n + j].squeeze()
            img_norm = None
            if norm is not None:
                # Per-image norm so that 0 sits at the colormap centre in every panel.
                img_norm = MidpointNormalize(vmin=img.min(), midpoint=0, vmax=img.max())
            images.append(axis.imshow(img, cmap=cmap, norm=img_norm))
            axis.axes.xaxis.set_ticks([])
            axis.axes.yaxis.set_ticks([])

    if cbar:
        if fig is not None:
            fig.subplots_adjust(wspace=-0.275, hspace=0.25)
        else:
            # Without a figure, add_axes below used to crash.
            fig = grid[0, 0].figure
        for i in range(m):
            for j in range(n):
                pos = grid[i, j].get_position()
                # Thin horizontal strip just below the image.
                cbar_ax = fig.add_axes([pos.x0, pos.y0 - 0.015, pos.width, 0.0075])
                colorbar = plt.colorbar(images[i * n + j], cax=cbar_ax,
                                        orientation="horizontal")
                img = x[i * n + j].squeeze()
                if set_cbar_ticks:
                    colorbar.set_ticks(
                        _diverging_cbar_ticks(img.min().item(), img.max().item()))
                else:
                    colorbar.ax.locator_params(nbins=5)
                    colorbar.formatter.set_powerlimits((0, 0))
                colorbar.outline.set_visible(False)
    return fig, ax


@torch.no_grad()
def save_plot(save_path, obs, cfs, do, var_cf_x=None, num_images=10):
    """Render observations vs. counterfactuals with ``plot_cf`` and write the figure.

    Args:
        save_path: output image path.
        obs: observed batch dict; key 'x' holds the images, every other key a parent.
        cfs: counterfactual batch dict with the same layout as ``obs``.
        do: dict of intervened variables (passed through ``inv_preprocess``).
        var_cf_x: optional per-pixel counterfactual variance (adds an
            "Uncertainty" row).
        num_images: number of columns to draw.
    """
    obs_parents = {k: v for k, v in obs.items() if k != 'x'}
    cf_parents = {k: v for k, v in cfs.items() if k != 'x'}
    fig = plot_cf(
        obs['x'], cfs['x'],
        inv_preprocess(obs_parents),  # pa
        inv_preprocess(cf_parents),   # cf_pa
        inv_preprocess(do),           # the intervention itself
        var_cf_x=var_cf_x,            # counterfactual variance per pixel
        num_images=num_images,
    )
    # Save/close the figure plot_cf returned, not whatever pyplot treats as current.
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)


@torch.no_grad()
def plot_cf(x, cf_x, pa, cf_pa, do, var_cf_x=None, num_images=8, logger=None):
    """Plot observations, counterfactuals and their difference as a grid.

    Rows (top to bottom): observation, counterfactual, treatment effect
    ``cf_x - x`` (diverging colormap with colorbars) and, if ``var_cf_x`` is
    given, the per-pixel standard deviation ``sqrt(var_cf_x)``. Columns are the
    first ``num_images`` batch items (fewer if the batch is smaller).

    Args:
        x: observed images; mapped to [0, 255] via ``(x + 1) * 127.5``, so
            presumably in [-1, 1] — TODO confirm.
        cf_x: counterfactual images, same layout and scaling as ``x``.
        pa: observed parents; column titles are drawn only when it has 'sex'
            (then 'finding' and 'scanner' must be present too).
        cf_pa: counterfactual parents; labels under the counterfactual row are
            drawn only when it has 'sex' (then 'finding'/'scanner' too).
        do: intervened variables; only 'sex', 'finding' and 'scanner' are
            rendered into the ``do(...)`` title.
        var_cf_x: optional per-pixel counterfactual variance.
        num_images: maximum number of columns.
        logger: unused; kept for call compatibility.

    Returns:
        The matplotlib Figure.
    """
    # Never draw more columns than the batch holds; otherwise the label loop
    # below indexes past the end of `pa` / `do`.
    n = min(num_images, x.shape[0])
    x = (x[:n].detach().cpu() + 1) * 127.5
    cf_x = (cf_x[:n].detach().cpu() + 1) * 127.5
    fs = 12  # font size
    m = 3 if var_cf_x is None else 4  # nrows
    panel_size = 5  # inches per panel
    # squeeze=False keeps `ax` 2-D so ax[row, col] also works when n == 1.
    fig, ax = plt.subplots(m, n, figsize=(n * panel_size, m * panel_size), squeeze=False)
    plot(x, ax=ax[0])
    plot(cf_x, ax=ax[1])
    plot(cf_x - x, ax=ax[2], fig=fig, cmap='RdBu_r', cbar=True,
         norm=MidpointNormalize(midpoint=0))
    if var_cf_x is not None:
        plot(var_cf_x[:n].detach().sqrt().cpu(),
             fig=fig, cmap='jet', ax=ax[3], cbar=True, set_cbar_ticks=False)

    sex_categories = ['male', 'female']  # 0, 1
    finding_categories = ['No finding', 'Finding']  # 0, 1
    scanner_categories = ['Phillips', 'Imaging']  # 0, 1
    # Intervened variable -> (short label, category names) for the do(...) title.
    do_labels = {
        'sex': ('s', sex_categories),
        'finding': ('f', finding_categories),
        'scanner': ('sc', scanner_categories),
    }

    for j in range(n):
        parts = []
        for k, v in do.items():
            if k in do_labels:
                short, names = do_labels[k]
                parts.append(short + '{{=}}' + names[int(v[j].item())])
        # join() avoids a stray trailing ", " when `do` holds unlabelled keys.
        msg = ', '.join(parts)

        if 'sex' in pa:
            obs_s = sex_categories[int(pa['sex'][j].item())]
            obs_f = finding_categories[int(pa['finding'][j].item())]
            obs_sc = scanner_categories[int(pa['scanner'][j].item())]
            ax[0, j].set_title(f's={obs_s}, f={obs_f}, sc={obs_sc}',
                               pad=8, fontsize=fs,
                               multialignment='center', linespacing=1.5)
            ax[1, j].set_title(f'do(${msg}$)', fontsize=fs, pad=10)

        # Counterfactual parent values under the counterfactual row.
        if 'sex' in cf_pa:
            cf_s = sex_categories[int(cf_pa['sex'][j].item())]
            cf_f = finding_categories[int(cf_pa['finding'][j].item())]
            cf_sc = scanner_categories[int(cf_pa['scanner'][j].item())]
            ax[1, j].set_xlabel(
                rf'$\widetilde{{s}}{{=}}{cf_s}, \ \widetilde{{f}}{{=}}{cf_f}, \ \widetilde{{sc}}{{=}}{cf_sc}$',
                labelpad=9, fontsize=fs, multialignment='center', linespacing=1.25)

    ax[0, 0].set_ylabel('Observation', fontsize=fs + 2, labelpad=8)
    ax[1, 0].set_ylabel('Counterfactual', fontsize=fs + 2, labelpad=8)
    ax[2, 0].set_ylabel('Treatment Effect', fontsize=fs + 2, labelpad=8)
    if var_cf_x is not None:
        ax[3, 0].set_ylabel('Uncertainty', fontsize=fs + 2, labelpad=8)

    return fig

### Evaluate Counterfactuals

In [9]:
# Fresh Hparams holding only the output folder for the figures written by
# eval_counterfactuals (it reads args.save_dir).
args = Hparams()
args.save_dir = f"results_tech_demo/{dscm_dir}/{which_checkpoint}"
os.makedirs(args.save_dir, exist_ok=True)

for _submodule in (model.pgm, model.predictor, model.vae):
    _submodule.eval()

# Empty per-variable containers; eval_counterfactuals builds its own and the
# evaluation loop below rebinds `preds` / `targets`.
dag_variables = list(model.pgm.variables.keys())
preds = {name: [] for name in dag_variables}
targets = {name: [] for name in dag_variables}

@torch.no_grad()
def eval_counterfactuals(model, dataloader, predictor, do_pa=None):
    """Intervene on one parent per batch and score predicted counterfactual parents.

    For every batch, the parent ``do_pa`` (or a random DAG variable per batch
    when ``do_pa`` is falsy) is set to ``1 - observed value``. ``model.forward``
    produces the counterfactuals, and ``predictor.predict`` on them is compared
    against the intervened value (for the intervened variable) or
    ``out['cf_pa']`` (for all other variables). Batches whose counterfactuals
    contain NaNs are skipped. Figures for batches 1-19 (skipped batches still
    count) are written to ``args.save_dir``.

    Keeps all predictions in memory, which can be large for big datasets.
    Relies on notebook globals: ``preprocess``, ``norm``, ``elbo_fn``, ``args``
    and ``save_plot``.

    Returns:
        ``(stats, preds, targets)``: ``stats`` maps metric names to floats;
        ``preds`` / ``targets`` map variable names to stacked CPU tensors.

    Raises:
        RuntimeError: if no batch produced NaN-free counterfactuals.
    """
    model.pgm.eval()
    model.predictor.eval()
    predictor.eval()
    model.vae.eval()
    dag_variables = list(model.pgm.variables.keys())
    preds = {k: [] for k in dag_variables}
    targets = {k: [] for k in dag_variables}
    plt_counter = 0
    cf_particles = 1

    for batch in tqdm(dataloader):
        plt_counter += 1
        batch = preprocess(batch)
        # Intervene on a single parent: the requested one, or a random DAG variable.
        do_k = do_pa if do_pa else random.choice(dag_variables)
        # Hard intervention flipping the observed value.
        # NOTE(review): `1 - x` is a flip only for binary {0, 1} parents; with
        # do_pa=None a non-binary variable (e.g. age/race, which the stats below
        # handle) could be chosen — confirm the DAG's variables.
        do = preprocess(norm({do_k: 1 - batch[do_k].clone()}))

        out = model.forward(batch, do, elbo_fn, cf_particles=cf_particles)
        cf_pa = out['cf_pa']

        nans = 0
        for k, v in out['cfs'].items():
            k_nans = torch.isnan(v).sum()
            nans += k_nans
            if k_nans > 0:
                print(f'\nFound {k_nans} nans in cf {k}.')
        if nans > 0:
            continue

        predict_out = predictor.predict(**out['cfs'])
        for k, v in predict_out.items():
            preds[k].extend(v)

        # Interventions are the targets for the intervened variable; the model's
        # counterfactual parents are the targets for every other variable.
        for k in targets:
            targets[k].extend(do[k] if k in do else cf_pa[k])

        if plt_counter < 20:
            save_path = os.path.join(args.save_dir, f'test_{do_k}_{plt_counter}_cfs.png')
            save_plot(save_path, batch, out['cfs'], do,
                      var_cf_x=out['var_cf_x'],
                      num_images=4)

    if not any(preds.values()):
        raise RuntimeError(
            'eval_counterfactuals: no batch produced NaN-free counterfactuals; nothing to score.')

    for k, v in preds.items():
        preds[k] = torch.stack(v).squeeze().cpu()
        targets[k] = torch.stack(targets[k]).squeeze().cpu()

    stats = {}
    for k in dag_variables:
        if k in ['sex', 'finding', 'scanner']:
            stats[k + '_rocauc'] = roc_auc_score(
                targets[k].numpy(), preds[k].numpy(), average='macro')
        elif k == 'age':
            # Mean absolute error, scaled by 100.
            stats[k] = torch.mean(torch.abs(targets[k] - preds[k])).item() * 100
        elif k == 'race':
            num_corrects = (targets[k].argmax(-1) == preds[k].argmax(-1)).sum()
            stats[k + "_acc"] = num_corrects.item() / targets[k].shape[0]
            stats[k + "_rocauc"] = roc_auc_score(
                targets[k].numpy(),
                preds[k].numpy(),
                multi_class="ovr",
                average="macro",)
    return stats, preds, targets


In [None]:
# Evaluate counterfactuals under a hard intervention on each parent in turn.
do_variables = ['scanner', 'sex', 'finding']
stats_do = {k: {} for k in do_variables}
preds_do = {k: {} for k in do_variables}
targets_do = {k: {} for k in do_variables}

# Reference numbers from an earlier run:
# base_stats = {
#     'race_rocauc':0.8468415443534377,
#     'sex_rocauc': 0.996207396871533,
#     'finding_rocauc':0.9443025803264049,
#     'age':6.17903545498848 ,
# }

for do_v in stats_do.keys():
    stats, preds, targets = eval_counterfactuals(model, dataloaders['valid'], predictor, do_pa=do_v)
    stats_do[do_v] = stats
    # Keep every intervention's predictions/targets. Previously preds_do/targets_do
    # were declared but never filled, so only the last run survived in preds/targets.
    preds_do[do_v] = preds
    targets_do[do_v] = targets

for do_v in stats_do.keys():
    print(f'do_{do_v} | '+' - '.join(f'{k}: {v:.3f}' for k,v in stats_do[do_v].items()))

In [None]:
# One summary line per intervened variable, e.g. "do_sex | sex_rocauc: 0.987 - ...".
for do_v, do_stats in stats_do.items():
    summary = ' - '.join(f'{k}: {v:.3f}' for k, v in do_stats.items())
    print(f'do_{do_v} | ' + summary)