In [None]:
import sys
import argparse
import random
import copy
import pyro
import torch
import os
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.metrics import roc_auc_score, accuracy_score
import numpy as np
import os
os.environ['CUDA_DEVICE_ORDER']='PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES']='0'
sys.path.append('..')
sys.path.append('../..')
from causal_models.trainer import preprocess_batch
from train_setup import setup_directories, setup_tensorboard, setup_logging
from train_setup import setup_dataloaders
# From datasets import get_attr_max_min
from utils import EMA, seed_all
from dscm.dscm import DSCM
from hvae2 import HVAE2
import torch.nn.functional as F
from pgm.train_pgm import sup_epoch, eval_epoch
from pgm.utils_pgm import check_nan, update_stats, calculate_loss, plot_cf
from dscm.dscm import vae_preprocess
from pgm.layers import TraceStorage_ELBO
from pgm.chest_pgm_segmentor import FlowPGM_with_seg 

from unet import ResUnet


### Load checkpoints and set up dataloaders

In [None]:
class Hparams:
    """Plain attribute container for hyperparameters (e.g. a checkpoint's 'hparams' dict)."""

    def update(self, dict):
        """Set one attribute on this object per key of the mapping `dict`.

        Attributes that already exist with the same name are overwritten.
        """
        for name in dict:
            setattr(self, name, dict[name])

# Load the predictor checkpoint; only its 'hparams' entry is read in this cell.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
predictor_path = '../pgm/checkpoints/Left-Lung_volume_Right-Lung_volume_Heart_volume/sup_aux_mimic_256_64_with_segmentation/checkpoint.pt'
print(f'\nLoading predictor checkpoint: {predictor_path}')
predictor_checkpoint = torch.load(predictor_path)
predictor_args = Hparams()
predictor_args.update(predictor_checkpoint['hparams'])

# Notebook-specific overrides, applied after update() so they take effect.
predictor_args.loss_norm = "l2"
predictor_args.setup = "sup_seg"

# Load the deep VAE checkpoint; `beta` selects which trained run's directory is used.
beta = 3
vae_path = f"../checkpoints/Left-Lung_volume_Right-Lung_volume_Heart_volume/mimic_crop_256_64_beta_{beta}_segmentations/checkpoint.pt"

print(f'\nLoading VAE checkpoint: {vae_path}')
vae_checkpoint = torch.load(vae_path)
vae_args = Hparams()
vae_args.update(vae_checkpoint['hparams'])
# Override *after* update(): assigning it beforehand was silently overwritten
# whenever the checkpoint's hparams contained a batch_size entry.
vae_args.batch_size = 10


In [None]:
# Instantiate FlowPGM_with_seg from the predictor hyperparameters (including the
# loss_norm / setup overrides applied when the checkpoint was loaded).
# NOTE(review): no state_dict from `predictor_checkpoint` is loaded here, so unless
# FlowPGM_with_seg restores weights itself, the model's parameters are freshly
# initialised — confirm this is intended before trusting its predictions.
model = FlowPGM_with_seg(predictor_args)

In [None]:
# Build the dataloaders from the VAE's hyperparameters; the "valid" split is used
# for inference below. cache=False / shuffle_train=False presumably disable dataset
# caching and training-set shuffling — TODO confirm against train_setup.
dataloaders = setup_dataloaders(vae_args, cache=False, shuffle_train=False)

In [None]:
# Predict segmentations for validation batches, stopping after the second one.
# Each iteration overwrites `segs` and `batch`, so only the last processed batch
# (inputs/targets in `batch`, predictions in `segs`) remains after the loop.
# NOTE(review): the first batch's predictions are computed and then discarded.
for_count = 0
for for_count, batch in enumerate(tqdm(dataloaders["valid"]), start=1):
    segs = model.predict_segmentations(**batch)
    if for_count >= 2:
        break


In [None]:
def BCEDiceloss(input, target, smooth=1.0):
    """Binary cross-entropy plus soft Dice loss, computed over all elements.

    Both tensors are flattened, so the Dice term is computed globally across the
    whole batch rather than per sample.

    Args:
        input: predicted probabilities; BCELoss expects values in [0, 1].
        target: ground-truth mask with the same number of elements as `input`
            (BCELoss expects values in [0, 1]).
        smooth: additive smoothing in the Dice numerator and denominator
            (default 1.0, the constant previously hard-coded); avoids 0/0 on
            empty masks.

    Returns:
        0-dim float64 tensor equal to ``BCE + (1 - Dice)``.
    """
    # .double() keeps the tensor on its current device; the former
    # .type(torch.DoubleTensor) silently copied CUDA tensors to the CPU.
    # .reshape(-1) also accepts non-contiguous inputs, where .view(-1) raises.
    pred = input.reshape(-1).double()
    truth = target.reshape(-1).double()
    # BCE loss (mean over all elements).
    bce_loss = torch.nn.BCELoss()(pred, truth)
    # Soft Dice coefficient, kept in float64 like the BCE term (the old .float()
    # casts downcast only the Dice part).
    dice_coef = (2.0 * (pred * truth).sum() + smooth) / (
        pred.sum() + truth.sum() + smooth
    )
    return bce_loss + (1 - dice_coef)

def segmentation_loss(pred_batch, target_batch, structures=("Left-Lung", "Right-Lung", "Heart")):
    """Calculate the segmentation loss as a sum of BCEDiceloss over selected keys.

    Args:
        pred_batch: mapping of key -> predicted tensor (keys such as "Left-Lung").
        target_batch: mapping that must contain a ground-truth tensor for every
            key of `pred_batch`.
        structures: keys that contribute to the loss (default: the three
            structures previously hard-coded). Keys of `pred_batch` outside this
            set are still size-checked but not scored.

    Returns:
        The summed loss tensor, or the int 0 if no key of `pred_batch` is in
        `structures`.

    Raises:
        AssertionError: if a predicted tensor's size differs from its target's.
        KeyError: if `target_batch` lacks a key present in `pred_batch`.
    """
    loss = 0
    for k in pred_batch.keys():
        assert pred_batch[k].size() == target_batch[k].size(), f"{k} size does not match, pred_batch size {pred_batch[k].size()}; target batch size {target_batch[k].size()}"
        if k in structures:
            loss += BCEDiceloss(pred_batch[k], target_batch[k])
    return loss

In [None]:
# Score the last batch's predicted segmentations against its ground-truth masks.
loss = segmentation_loss(segs, batch)


In [None]:
# Show the summed BCE + Dice segmentation loss for the last evaluated batch.
print(loss)