In [None]:
# Move one directory up from the notebook's location. Per the captured output below,
# this lands in the repository root (contrast-gan-3D), presumably so repo-relative
# paths resolve.
# NOTE(review): `%cd ..` is not idempotent — re-running this cell moves up yet again.
%cd ..

# Render inline matplotlib figures at high ("retina") resolution.
%config InlineBackend.figure_format = "retina"

/home/marco/thesis_project/contrast-gan-3D


In [None]:
import os

# https://discuss.pytorch.org/t/gpu-device-ordering/60785/2
# Enumerate CUDA devices by PCI bus ID (the order nvidia-smi shows) instead of CUDA's
# default "fastest first" order. The variable only has an effect if it is set before
# CUDA is initialised, which is presumably why it precedes the project imports below.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

import numpy as np

from contrast_gan_3D.alias import ScanType
# Experiment configuration arrives through star imports. Names used later but never
# defined in this notebook (n_cval_folds, scaler, rng, train_batch_size, ...)
# presumably come from these modules. test_conf_2D is imported second, so it overrides
# any name it shares with basic_conf.
from contrast_gan_3D.experiments.basic_conf import *
from contrast_gan_3D.experiments.test_conf_2D import *
# from contrast_gan_3D.experiments.gradient_penalty_conf import *
from contrast_gan_3D.model.loss import HULoss
from contrast_gan_3D.trainer.Trainer import Trainer
from contrast_gan_3D.trainer.utils import create_dataloaders, cval_paths
from contrast_gan_3D.utils import set_GPU

In [None]:
# Single-fold cross-validation setup:
# - split the scans into folds,
# - derive the scaled HU bounds,
# - build the train/validation loaders for the selected fold.
profiler_dir = None
run_id = "pippo"
fold_idx = 0

ostia_table_path = "/home/marco/data/ostia_final.xlsx"
train_folds, val_folds = cval_paths(n_cval_folds, ostia_table_path)
fold_train_paths = train_folds[fold_idx]
fold_val_paths = val_folds[fold_idx]

# Pass the desired HU window through the same scaler that create_dataloaders receives.
scaled_HU_bounds = scaler(np.array(desired_HU_bounds))
print(scaled_HU_bounds)

train_loaders, val_loaders = create_dataloaders(
    fold_train_paths,
    fold_val_paths,
    train_patch_size,
    val_patch_size,
    train_batch_size,
    val_batch_size,
    rng,
    scaler,
    train_transform=train_transform,
)

In [None]:
from contrast_gan_3D.trainer.logger.LoggerInterface import SingleThreadedLogger

# Re-wrap the configured logger in SingleThreadedLogger (presumably a synchronous,
# single-threaded variant — confirm in LoggerInterface).
logger_interface = SingleThreadedLogger(logger_interface.logger)

# Work with the training split: its batch sizes, patch size and loaders.
chosen_bs, chosen_ps, chosen_loaders = train_batch_size, train_patch_size, train_loaders

# logger_interface.logger.sample_size = chosen_ps[ScanType.OPT.value]

# Batch shapes laid out as (N, 1, *patch). subopt_bs pools the LOW and HIGH batches
# and is the shape HULoss is built with; opt_bs covers the OPT batch.
n_subopt = chosen_bs[ScanType.LOW.value] + chosen_bs[ScanType.HIGH.value]
subopt_bs = (n_subopt, 1) + tuple(chosen_ps)
opt_bs = (chosen_bs[ScanType.OPT.value], 1) + tuple(chosen_ps)
print(subopt_bs, opt_bs)

In [None]:
# set_GPU(3) presumably selects CUDA device index 3 (in PCI bus order, see imports cell).
device = set_GPU(3)

# HU-range loss sized for the pooled LOW+HIGH batch shape computed in the previous cell.
hu_loss = HULoss(*scaled_HU_bounds, subopt_bs)

trainer = Trainer(
    train_iterations,
    val_iterations,
    validate_every,
    train_generator_every,
    log_every,
    log_images_every,
    generator_class,
    critic_class,
    generator_optim_class,
    critic_optim_class,
    hu_loss,
    logger_interface,
    val_batch_size,
    weight_clip=weight_clip,
    generator_lr_scheduler_class=generator_lr_scheduler_class,
    critic_lr_scheduler_class=critic_lr_scheduler_class,
    device=device,
    checkpoint_every=checkpoint_every,
    rng=rng,
)

# To resume from an earlier run, load its checkpoint here, e.g.:
# trainer.load_checkpoint("/home/marco/contrast-gan-3D/logs/model_checkpoints/07qiygyk.pt")

In [None]:
from contrast_gan_3D.model import utils as model_utils

# Print the parameter count of each network: critic first, then generator.
for network in (trainer.critic, trainer.generator):
    print(model_utils.count_parameters(network))

In [None]:
# Dump both architectures (critic, then generator), separated by a dashed rule.
print(trainer.critic, "-----------", trainer.generator, sep="\n")

In [None]:
# Draw one batch per scan type, in ScanType iteration order, and show the data shapes.
patches = [next(chosen_loaders[member.value]) for member in ScanType]
[batch["data"].shape for batch in patches]

In [None]:
# Smoke-test one training step on the sampled batches. The second argument is presumably
# the iteration index — TODO confirm against Trainer.train_step.
step_idx = 0
trainer.train_step(patches, step_idx)

In [None]:
# Convolution filter shapes of both networks for a single sample. The first axis of the
# batch data (presumably the batch axis) is dropped.
sample_shape = patches[0]["data"].shape[1:]
model_utils.compute_convolution_filters_shape(trainer.critic, sample_shape, show=True)
print("----")
model_utils.compute_convolution_filters_shape(trainer.generator, sample_shape, show=True)

In [None]:
from matplotlib import pyplot as plt

from contrast_gan_3D.alias import ScanType
from contrast_gan_3D.data.HD5Scan import HD5Scan

# For each scan type, gather the CCTA values at every voxel where the scan's labelmap is
# non-zero (the artery centerlines, per the variable name and plot title), across all
# validation scans.
# Per-scan chunks are collected in a list and concatenated once per scan type.
# Calling np.append inside the loop copied the whole accumulated array for every scan,
# which is quadratic in the total number of voxels.
centerlines_pixels = {}
for sc in ScanType:
    print(sc)
    # Seed with an empty float64 array. This keeps the result dtype, and the behaviour
    # when a split has no scans, identical to the previous np.append accumulation.
    chunks = [np.array([])]
    val_load = val_loaders[sc.value]
    for p in val_load.generator._data:
        print(p)
        with HD5Scan(p) as scan:
            mask = scan.labelmap[::].astype(bool)
            # np.append flattened its input; ravel keeps that behaviour.
            chunks.append(np.ravel(scan.ccta[mask]))
    centerlines_pixels[sc] = np.concatenate(chunks)


[len(c) for c in centerlines_pixels.values()]

In [None]:
# Overlaid, density-normalised histograms of the centerline values, one per scan type.
fig, ax = plt.subplots()

for sc, v in centerlines_pixels.items():
    ax.hist(v, bins=80, alpha=0.5, density=True, label=sc.name)
# Label the axes so the figure stands alone. The y-axis is a density because
# density=True. Attach the legend to the axes (fig.legend floats on the figure and
# can overlap the plot).
ax.set_xlabel("Intensity (HU)")
ax.set_ylabel("Density")
ax.legend()
fig.suptitle("Arteries centerlines HU distribution")
fig.tight_layout()
plt.show()
plt.close(fig)

In [None]:
from torch import nn

# Compare a 1x1x1 bottleneck (256 -> 64 -> 64 -> 256 channels) with a single full
# 3x3x3 convolution at 256 channels.
# inp_shape is (channels, *spatial), matching the shape[1:] convention used earlier.
inp_shape = (256, 128, 128, 128)
# The 1x1 squeeze outputs 64 channels, and every conv here preserves the spatial size
# (stride 1, padding matched to the kernel). The two inner convs therefore receive
# 64-channel inputs, not the 256-channel inp_shape.
bottleneck_shape = (64, *inp_shape[1:])
model_utils.print_convolution_filters_shape(nn.Conv3d(256, 64, 1, 1, 0), inp_shape)
model_utils.print_convolution_filters_shape(nn.Conv3d(64, 64, 3, 1, 1), bottleneck_shape)
model_utils.print_convolution_filters_shape(nn.Conv3d(64, 256, 1, 1, 0), bottleneck_shape)
print("------")
model_utils.print_convolution_filters_shape(nn.Conv3d(256, 256, 3, 1, 1), inp_shape)

In [None]:
# Manual generator/critic forward pass, for debugging the GAN losses.
# NOTE(review): `subopt` and `opt` are not defined in any cell of this notebook, so this
# cell only runs with leftover kernel state. Judging by the shapes built earlier, they are
# presumably the pooled LOW+HIGH batch (subopt_bs) and the OPT batch (opt_bs), on
# trainer.device. TODO: define them explicitly, e.g. from `patches`.
attenuations = trainer.generator(subopt)
# Reconstruction = input minus the generator output.
recon = subopt - attenuations

# Value range of the generator output for the second sample in the batch.
print(attenuations[1].min(), attenuations[1].max())

D_real  = trainer.critic(opt)
# detach() cuts the graph, so gradients of the critic's score on the fakes do not flow
# back into the generator.
D_fake = trainer.critic(recon.detach())

In [None]:
from torchviz import make_dot

from contrast_gan_3D.model.utils import wgan_gradient_penalty

# Whether to add the WGAN gradient penalty to the critic loss (replaces `if False:`).
USE_GRADIENT_PENALTY = False

loss_D = trainer.loss_GAN(D_fake, D_real)

# The real (OPT) batch is tiled along the batch axis to match the reconstructions'
# batch size. The tiling factor is derived from the shapes instead of being hard-coded
# to 2, so the cell still works when bs[LOW] + bs[HIGH] != 2 * bs[OPT].
n_repeats, remainder = divmod(recon.shape[0], opt.shape[0])
assert remainder == 0, (
    f"recon batch ({recon.shape[0]}) is not a multiple of opt batch ({opt.shape[0]})"
)
gp = wgan_gradient_penalty(
    opt.repeat((n_repeats,) + (1,) * len(opt.shape[1:])),
    recon,
    trainer.critic,
    trainer.device,
)
if USE_GRADIENT_PENALTY:
    loss_D += gp

# loss_D.backward()

# Generator loss: negated critic score on the (non-detached) reconstructions.
loss_G = -trainer.loss_GAN(trainer.critic(recon))
# loss_G.backward()

In [None]:
# Render the autograd graph that produced the gradient-penalty tensor.
gp_graph = make_dot(gp)
gp_graph