In [None]:
# Move the kernel's working directory one level up
# (presumably to the repository root so relative paths resolve - TODO confirm).
%cd ..

# Render inline matplotlib figures in "retina" (high-DPI) format.
%config InlineBackend.figure_format = "retina"

In [None]:
import os

# https://discuss.pytorch.org/t/gpu-device-ordering/60785/2
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from batchgenerators.dataloading.single_threaded_augmenter import \
    SingleThreadedAugmenter
from batchgenerators.utilities.file_and_folder_operations import load_pickle
from torch import Tensor

from contrast_gan_3D import config, utils
from contrast_gan_3D.alias import ScanType
from contrast_gan_3D.data import utils as data_u
from contrast_gan_3D.eval.CCTAContrastCorrector import CCTAContrastCorrector
# from contrast_gan_3D.experiments.basic_conf import *
from contrast_gan_3D.experiments.conf_2D import *
from contrast_gan_3D.model import utils as model_utils
# from contrast_gan_3D.experiments.test_conf_2D import *
# from contrast_gan_3D.experiments.gradient_penalty_conf import *
from contrast_gan_3D.model.loss import HULoss
from contrast_gan_3D.trainer import utils as train_u
from contrast_gan_3D.trainer.logger.LoggerInterface import SingleThreadedLogger
from contrast_gan_3D.trainer.Trainer import Trainer
from contrast_gan_3D.utils import set_GPU
from contrast_gan_3D.utils import visualization as viz

In [None]:
# Cross-validation splits stored as a pickle; the object is indexed below with the
# "train" and "test" keys.
# NOTE(review): absolute, user-specific path - consider deriving it from a configurable root.
splits = load_pickle("/home/marco/thesis_project/contrast-gan-3D/cross_val_splits.pkl")
print(len(splits))

# Work on the first fold only.
fold_idx = 0
train_fold = splits["train"][fold_idx]
val_fold = splits["test"][fold_idx]

In [None]:
# Re-wrap the underlying logger of the configured logger interface in the
# single-threaded variant.
logger_interface = SingleThreadedLogger(logger_interface.logger)

# One sample per scan type; swap in `train_batch_size` to reproduce the training setup.
# chosen_bs = train_batch_size
chosen_bs = {
    scan_type.value: 1 for scan_type in (ScanType.OPT, ScanType.LOW, ScanType.HIGH)
}
# chosen_ps = train_patch_size
chosen_ps = val_patch_size

# Batch shapes as (batch, channel, *patch): the suboptimal batch stacks the LOW and HIGH
# samples, the optimal batch holds the OPT samples only.
subopt_bs = (
    sum(chosen_bs[scan_type.value] for scan_type in (ScanType.LOW, ScanType.HIGH)),
    1,
    *chosen_ps,
)
opt_bs = (chosen_bs[ScanType.OPT.value], 1, *chosen_ps)
print(subopt_bs, opt_bs)

In [None]:
# Build the train/validation loaders for the selected fold; the same batch-size mapping
# is used for both training and validation.
train_loaders, val_loaders = train_u.create_dataloaders(
    train_fold,
    val_fold,
    train_patch_size,
    val_patch_size,
    chosen_bs,
    chosen_bs,
    rng,
    scaler=scaler,
    num_workers=num_workers,
    train_transform=train_transform,
    seed=seed,
)

# Re-wrap every loader's generator and transform in a SingleThreadedAugmenter, keeping
# the original dictionary keys.
train_loaders, val_loaders = (
    {
        key: SingleThreadedAugmenter(loader.generator, loader.transform)
        for key, loader in loaders.items()
    }
    for loaders in (train_loaders, val_loaders)
)

# chosen_loaders = train_loaders
chosen_loaders = val_loaders

In [None]:
# Desired HU bounds passed through the same scaler that is handed to the dataloaders.
scaled_HU_bounds = scaler(np.array(desired_HU_bounds))
print(scaled_HU_bounds)

device = set_GPU(7)

# Override the configured frequencies: log, log images and step the generator every iteration.
log_images_every = train_generator_every = log_every = 1

trainer = Trainer(
    train_iterations,
    val_iterations,
    validate_every,
    train_generator_every,
    train_critic_every,
    log_every,
    log_images_every,
    generator_class,
    critic_class,
    generator_optim_class,
    critic_optim_class,
    HULoss(*scaled_HU_bounds, subopt_bs),
    logger_interface,
    val_batch_size,
    weight_clip=weight_clip,
    generator_lr_scheduler_class=generator_lr_scheduler_class,
    critic_lr_scheduler_class=critic_lr_scheduler_class,
    device=device,
    checkpoint_every=None,
    rng=rng,
)

# Alternative checkpoints kept for quick switching:
# checkpoint_path = "/home/marco/contrast-gan-3D/logs/model_checkpoints/9hnh7gto.pt"
# checkpoint_path = "/home/marco/contrast-gan-3D/logs/model_checkpoints/07qiygyk.pt"
# checkpoint_path = train_u.find_latest_checkpoint(config.CHECKPOINTS_DIR / "6en9vikh")
checkpoint_path = Path(
    train_u.find_latest_checkpoint(config.CHECKPOINTS_DIR / "17pgi67n")
)

trainer.load_checkpoint(checkpoint_path)

In [None]:
# Parameter counts: critic first, then generator.
for _net in (trainer.critic, trainer.generator):
    print(model_utils.count_parameters(_net))
# print(trainer.critic)
# print("-----------")
# print(trainer.generator)

In [None]:
# Mimic the names used inside Trainer methods so that their bodies can be pasted and
# executed cell by cell.
self = trainer
iteration = 0

print(list(ScanType))
# Draw one batch per scan type, in ScanType iteration order.
patches = [next(chosen_loaders[st.value]) for st in ScanType]
print([batch["data"].shape for batch in patches])

opt_d, low_d, high_d = patches

In [None]:
# Run the generator on the batch unpacked as `low_d`. This pass is only used for
# inspection/plotting, so autograd is disabled: otherwise the whole computation graph is
# built and kept alive on the GPU for no reason.
with torch.no_grad():
    low = low_d["data"].to(device, non_blocking=True)
    attenuation_low: Tensor = self.generator(low)
low, attenuation_low = utils.to_CPU(low), utils.to_CPU(attenuation_low)
# The corrected scan is the input minus the attenuation predicted by the generator.
low_recon = low - attenuation_low
# Hand cached, unused GPU blocks back to the driver before the next forward pass.
torch.cuda.empty_cache()

In [None]:
# Run the generator on the batch unpacked as `high_d`. This pass is only used for
# inspection/plotting, so autograd is disabled: otherwise the whole computation graph is
# built and kept alive on the GPU for no reason.
with torch.no_grad():
    high = high_d["data"].to(device, non_blocking=True)
    attenuation_high: Tensor = self.generator(high)
high, attenuation_high = utils.to_CPU(high), utils.to_CPU(attenuation_high)
# The corrected scan is the input minus the attenuation predicted by the generator.
high_recon = high - attenuation_high
# Hand cached, unused GPU blocks back to the driver before the next forward pass.
torch.cuda.empty_cache()

In [None]:
# Run the generator on the batch unpacked as `opt_d`. This pass is only used for
# inspection/plotting, so autograd is disabled: otherwise the whole computation graph is
# built and kept alive on the GPU for no reason.
# (The former `if True:` wrapper was a no-op and has been dropped.)
with torch.no_grad():
    opt = opt_d["data"].to(device, non_blocking=True)
    attenuation_opt: Tensor = self.generator(opt)
opt, attenuation_opt = utils.to_CPU(opt), utils.to_CPU(attenuation_opt)
# The corrected scan is the input minus the attenuation predicted by the generator.
opt_recon = opt - attenuation_opt
# Hand cached, unused GPU blocks back to the driver.
torch.cuda.empty_cache()

In [None]:
# Lazily pick how many samples get logged: 64 by default, capped by the smallest batch
# size when the reconstruction is not 5-dimensional (2D case).
if self.train_log_sample_size is None:
    is_2D = len(high_recon.shape) != 5
    batch_sizes = [len(batch["data"]) for batch in patches] if is_2D else []
    self.train_log_sample_size = min([*batch_sizes, 64])

In [None]:
from matplotlib import cm
from contrast_gan_3D.utils import swap_last_dim

# Sanity check of the shape produced by colour-mapping a random volume and passing it
# through `swap_last_dim`, for a deep (128) and a shallow (6) last dimension.
for depth in (128, 6):
    colored = cm.RdBu(torch.rand((1, 1, 512, 512, depth)))
    print(swap_last_dim(colored.squeeze()).shape)

In [None]:
# Call the trainer's logger interface on the sampled batches. The three lists follow the
# (opt, low, high) order in which `patches` was unpacked; the remaining arguments are the
# scan types, the iteration counter, the "train" tag and the number of samples to log.
# NOTE(review): argument meaning is inferred from the call site - confirm against the
# logger interface signature.
self.logger_interface(
    patches,
    [opt_recon, low_recon, high_recon],
    [attenuation_opt, attenuation_low, attenuation_high],
    list(ScanType),
    iteration,
    "train",
    self.train_log_sample_size,
)

In [None]:
# Index selecting the first batch element and every second position along the last axis.
every_other_slice = [0, ..., slice(0, opt_recon.shape[-1], 2)]
fig = logger_interface.logger.create_attenuation_grid(
    scaler.unscale(opt_recon), every_other_slice, False
)
plt.show()
plt.close(fig)

In [None]:
print(desired_HU_bounds, scaler(np.array(desired_HU_bounds)))

# Subsample every 10th value to keep the histograms cheap.
p = opt.ravel()[::10]
p_unscaled = scaler.unscale(p)
p_recon = scaler.unscale(opt_recon.ravel()[::10])
name = opt_d["name"]

# Side by side: scaled generator input, unscaled input, unscaled reconstruction.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
panels = (
    (p, "Generator's input"),
    (p_unscaled, "Original"),
    (p_recon, "Reconstructed"),
)
for ax, (values, title) in zip(axes, panels):
    ax.hist(values, bins=80)
    ax.set_title(title)

fig.suptitle(name)
plt.tight_layout()
plt.show()
# plt.savefig(savefolder / f"{name}_hist.png")
plt.close(fig)

In [None]:
# Shared histogram styling; semi-transparent so the two distributions can be overlaid.
args = {
    "alpha": 0.5,
    "bins": 80,
    # "density": True
}

# For every scan type, compare the unscaled values selected by the batch's "seg" entry
# before and after correction.
for d, tensor, rec_tensor, sc in zip(
    (opt_d, low_d, high_d),
    (opt, low, high),
    (opt_recon, low_recon, high_recon),
    ScanType,
):
    mask = d["seg"]
    og_ctls = scaler.unscale(tensor[mask])
    recon_ctls = scaler.unscale(rec_tensor[mask])
    assert og_ctls.shape == recon_ctls.shape

    fig, axes = plt.subplots(figsize=(8, 5))
    for values, label in ((og_ctls, "Original"), (recon_ctls, "Corrected")):
        axes.hist(values, label=label, **args)
    fig.legend()
    fig.suptitle(d["name"][0] + f" {sc.name}")
    fig.tight_layout()
    plt.show()
    plt.close(fig)

In [None]:
# Load the full scan (and its metadata) of the first sample in the `low_d` batch.
patient_path = low_d["path"][0]
scan, meta = data_u.load_patient(patient_path)
print(meta["name"], patient_path)

In [None]:
# Build a corrector from the same generator class, scaler, device and checkpoint that
# were used for the trainer above.
# NOTE(review): inference uses `train_patch_size`, while the sampled batches in this
# notebook use `val_patch_size` - confirm this is intended.
corrector_3D = CCTAContrastCorrector(
    generator_class,
    scaler,
    device,
    inference_patch_size=train_patch_size,
    checkpoint_path=checkpoint_path,
)

In [None]:
# Correct the first entry along the last axis of the loaded scan; the patient name is
# passed as `desc`.
# NOTE(review): `scan` was loaded from `low_d["path"][0]`, yet the result is called
# `high_corrected` - the name looks stale; confirm which scan type is meant.
high_corrected = corrector_3D(scan[..., 0], desc=meta["name"])
print(scan.shape, high_corrected.shape)
# Hand cached, unused GPU blocks back to the driver.
torch.cuda.empty_cache()

In [None]:
# Save the corrected scan under the patient's name, passing along the offset and spacing
# from the loaded metadata.
# NOTE(review): absolute, user-specific output path - consider making it configurable.
savefolder = Path("/home/marco/data/test_inference/")
corrector_3D.save_scan(
    high_corrected, meta["offset"], meta["spacing"], savefolder / f"{meta['name']}"
)

In [None]:
# Disabled by default: for every scan type, plot (or save) the intensity histograms, the
# axial slices with centerlines and the attenuation grid of each sampled patch.
if False:
    folder = Path("/home/marco/thesis_project/contrast-gan-3D/assets")

    # True: display figures inline; False: write them to disk instead.
    SHOW = True

    def _title_and_emit(fig, title, out_path):
        """Title the figure, then show it or save it to `out_path`, and close it."""
        fig.suptitle(title)
        plt.tight_layout()
        if SHOW:
            plt.show()
        else:
            plt.savefig(out_path)
        plt.close(fig)

    for sc in ScanType:
        print(sc)
        print("----------------------------------")
        savefolder = folder / f"{sc.name}_unnormed"
        savefolder.mkdir(exist_ok=True, parents=True)

        batch = patches[sc.value]
        for i, p in enumerate(batch["data"]):
            p_seg = batch["seg"][i]
            p_unscaled = scaler.unscale(p)
            name = batch["name"][i]

            # Histograms of the scaled vs. unscaled intensities (every 5th value).
            fig, axes = plt.subplots(1, 2, figsize=(12, 5))
            for ax, values, title in zip(
                axes, (p, p_unscaled), ("Scaled", "Original")
            ):
                ax.hist(values.ravel()[::5], bins=80)
                ax.set_title(title)
            _title_and_emit(fig, name, savefolder / f"{name}_hist.png")

            # Every second axial slice together with its centerline segmentation.
            fig = viz.plot_axial_slices_and_centerlines(
                p_unscaled[..., ::2],
                p_seg[..., ::2],
                # **logger_interface.logger.grid_args,
                normalize=True,
                value_range=(p_unscaled.min().item(), p_unscaled.max().item()),
                cbar=True,
            )
            _title_and_emit(fig, name, savefolder / f"{name}.png")

            # Attenuation-style grid of the unscaled patch itself.
            fig = logger_interface.logger.create_attenuation_grid(
                p_unscaled, [0, ..., slice(0, 128, 2)], scale_by_factor=False
            )
            _title_and_emit(fig, name, savefolder / f"{name}_attenuation.png")

In [None]:
# Per-sample input shape (batch dimension dropped) of the first sampled batch.
sample_shape = patches[0]["data"].shape[1:]

# Show the convolution filter shapes of the critic, then of the generator.
model_utils.compute_convolution_filters_shape(trainer.critic, sample_shape, show=True)
print("----")
model_utils.compute_convolution_filters_shape(trainer.generator, sample_shape, show=True)

In [None]:
# `subopt` was never defined in this notebook (it only existed as leftover kernel state),
# so this cell failed with a NameError on a fresh "Restart & Run All". Build it here from
# the LOW and HIGH batches, matching `subopt_bs` (LOW + HIGH batch sizes).
# NOTE(review): the LOW-then-HIGH concatenation order is assumed - confirm against Trainer.
subopt = torch.cat([low_d["data"], high_d["data"]]).to(device, non_blocking=True)
# `opt` was passed through `utils.to_CPU` by an earlier cell; take it from the batch again
# and place it on the device so that it can be fed to the critic.
opt = opt_d["data"].to(device, non_blocking=True)

# Gradients are intentionally left enabled here: the outputs feed the loss and
# gradient-penalty computations further down.
attenuations = trainer.generator(subopt)
recon = subopt - attenuations

print(attenuations[1].min(), attenuations[1].max())

D_real = trainer.critic(opt)
# Detach so that the critic's output on fake samples does not backpropagate into the generator.
D_fake = trainer.critic(recon.detach())

In [None]:
from torchviz import make_dot

from contrast_gan_3D.model.utils import wgan_gradient_penalty

# Critic loss on the fake and real critic outputs.
loss_D = trainer.loss_GAN(D_fake, D_real)

# Tile the optimal batch twice along the batch dimension (all other dimensions are kept),
# presumably so that it matches the LOW + HIGH batch in `recon` - TODO confirm.
repeats = (2, *(1 for _ in opt.shape[1:]))
gp = wgan_gradient_penalty(opt.repeat(repeats), recon, trainer.critic, trainer.device)
# The gradient penalty is computed for inspection only; flip the flag to add it to the loss.
if False:
    loss_D += gp

# loss_D.backward()

# Generator loss: negated GAN loss of the critic's output on the reconstructions.
loss_G = -trainer.loss_GAN(trainer.critic(recon))
# loss_G.backward()

In [None]:
# Render the autograd graph of the gradient penalty with torchviz (displayed as the cell output).
make_dot(gp)

In [None]:
from torch import nn

# Bottleneck idea: squeeze the channels with a 1x1x1 convolution, apply the 3x3x3
# convolution on the reduced channels, then expand back with another 1x1x1 convolution.
inp_shape = (256, 128, 128, 128)
bottleneck_specs = (
    (256, 64, 1, 1, 0),
    (64, 64, 3, 1, 1),
    (64, 256, 1, 1, 0),
)
for conv_args in bottleneck_specs:
    model_utils.print_convolution_filters_shape(nn.Conv3d(*conv_args), inp_shape)
print('------')
# Reference: a single full-width 3x3x3 convolution.
model_utils.print_convolution_filters_shape(nn.Conv3d(256, 256, 3, 1, 1), inp_shape)