In [1]:
# Work from the repository root, one level above this notebook's folder.
%cd ..

/home/marco/contrast-gan-3D


In [2]:
from contrast_gan_3D.data.CCTADataLoader import CCTADataLoader
from contrast_gan_3D.experiments.conf_2D import *

# Data loader over a single ASOCA scan, used as a quick check of the 2D patch pipeline.
loader = CCTADataLoader(
    ["/home/marco/data/ASOCA_Philips/images/ASOCA-000.h5"],  # one scan only
    train_patch_size, train_batch_size, rng,
    scaler=scaler,
)

In [4]:
# Draw one training batch and show the shape of its "data" entry.
batch = loader.generate_train_batch()
batch_data = batch["data"]
print(batch_data.shape)

(6, 1, 128, 128)


In [None]:
import os

# Enumerate GPUs in PCI bus order (the order nvidia-smi shows) instead of CUDA's
# default "fastest first", so GPU indices used below refer to the physical cards.
# Must be set before CUDA is initialised.
# https://discuss.pytorch.org/t/gpu-device-ordering/60785/2
os.environ.update(CUDA_DEVICE_ORDER="PCI_BUS_ID")


import numpy as np
import torch

from contrast_gan_3D.alias import ScanType
from contrast_gan_3D.config import CHECKPOINTS_DIR

from contrast_gan_3D.experiments.gradient_penalty_conf import *
from contrast_gan_3D.model.loss import HULoss
from contrast_gan_3D.trainer.Trainer import Trainer
from contrast_gan_3D.trainer.utils import create_dataloaders, cval_paths
from contrast_gan_3D.utils import set_GPU

In [None]:
# Run configuration for this session.
run_id = "pippo"
fold_idx = 0
profiler_dir = None
val_batch_size = 2

# Cross-validation folds built from the ostia spreadsheet.
train_folds, val_folds = cval_paths(n_cval_folds, "/home/marco/data/ostia_final.xlsx")

# HU bounds mapped through the same scaler that is handed to the dataloaders.
scaled_HU_bounds = scaler(np.array(desired_HU_bounds))
print(scaled_HU_bounds)

train_loaders, val_loaders = create_dataloaders(
    train_folds[fold_idx], val_folds[fold_idx],  # paths of the selected fold
    train_patch_size, val_patch_size,            # patch sizes
    train_batch_size, val_batch_size,            # batch sizes
    rng, scaler,
    train_transform=train_transform,
)

In [None]:
from matplotlib import pyplot as plt

from contrast_gan_3D.alias import ScanType
from contrast_gan_3D.data.HD5Scan import HD5Scan

# Collect, per scan type, the CCTA values selected by each validation scan's
# label map (cast to bool).
# Chunks are gathered in a list and concatenated once per scan type. Calling
# np.append inside the loop would copy the whole accumulated array for every
# scan, which is quadratic in the number of voxels.
centerlines_pixels = {}
for sc in ScanType:
    print(sc)
    chunks = []
    for p in val_loaders[sc.value].generator._data:
        print(p)
        with HD5Scan(p) as scan:
            # assumes the label map is non-zero on artery-centerline voxels — TODO confirm
            chunks.append(scan.ccta[scan.labelmap[::].astype(bool)])
    # The leading empty float64 array reproduces np.append's float64 promotion.
    # It also yields an empty array when a scan type has no validation scans.
    centerlines_pixels[sc] = np.concatenate([np.array([]), *chunks])


[len(c) for c in centerlines_pixels.values()]

In [None]:
# Overlay one histogram per scan type. density=True normalises each histogram,
# so scan types with different voxel counts remain comparable.
fig, ax = plt.subplots()

for sc, v in centerlines_pixels.items():
    ax.hist(v, bins=80, alpha=0.5, density=True, label=sc.name)
ax.set(xlabel="HU", ylabel="Density")
# Axes-level legend: tight_layout does not account for a figure-level legend,
# which could overlap the plot.
ax.legend()
fig.suptitle("Arteries centerlines HU distribution")
fig.tight_layout()
plt.show()
plt.close(fig)

In [None]:
from pathlib import Path
from typing import List


# NOTE: histogram computation needs some form of patch aggregation
def compute_histograms(paths: List[Path]):
    """Placeholder for computing histograms of the scans at ``paths``.

    TODO: not implemented yet; it needs some form of patch aggregation.

    Raises:
        NotImplementedError: always, so that a call fails loudly instead of
            silently returning ``None``.
    """
    raise NotImplementedError("compute_histograms is not implemented yet")

def plot_histograms():
    """Placeholder for plotting the histograms produced by ``compute_histograms``.

    TODO: not implemented yet.

    Raises:
        NotImplementedError: always, so that a call fails loudly instead of
            silently returning ``None``.
    """
    raise NotImplementedError("plot_histograms is not implemented yet")

In [None]:
from contrast_gan_3D.trainer.logger.LoggerInterface import SingleThreadedLogger

# Wrap the configured logger in a SingleThreadedLogger. The isinstance guard
# keeps this cell idempotent: re-running it must not wrap an already-wrapped logger.
if not isinstance(logger_interface, SingleThreadedLogger):
    logger_interface = SingleThreadedLogger(logger_interface.logger)

# GPU to run on. This is presumably a PCI-bus index, because CUDA_DEVICE_ORDER is
# set to PCI_BUS_ID earlier in the notebook — TODO confirm set_GPU semantics.
GPU_INDEX = 3
device = set_GPU(GPU_INDEX)

trainer = Trainer(
    train_iterations,
    val_iterations,
    validate_every,
    train_generator_every,
    log_every,
    log_images_every,
    generator_class,
    critic_class,
    generator_optim_class,
    critic_optim_class,
    # NOTE(review): HULoss receives the validation batch shape. Training batches
    # concatenate the low and high batches (see the commented alternative).
    # Confirm which shape HULoss must match.
    HULoss(*scaled_HU_bounds, (val_batch_size, 1, *val_patch_size)),
    # HULoss(*scaled_HU_bounds, (train_batch_size * 2, 1, *train_patch_size)),
    logger_interface,
    CHECKPOINTS_DIR / f"{run_id}.pt",
    weight_clip=weight_clip,
    generator_lr_scheduler_class=generator_lr_scheduler_class,
    critic_lr_scheduler_class=critic_lr_scheduler_class,
    device=device,
    checkpoint_every=checkpoint_every,
)
# Optionally resume from an earlier run's checkpoint:
# trainer.load_checkpoint(
#     # "/home/marco/contrast-gan-3D/logs/model_checkpoints/9hnh7gto.pt"
#     "/home/marco/contrast-gan-3D/logs/model_checkpoints/07qiygyk.pt"
# )

In [None]:
from contrast_gan_3D.model import utils as model_utils

# Parameter counts: critic first, then generator.
for net in (trainer.critic, trainer.generator):
    print(model_utils.count_parameters(net))

In [None]:
# Architecture summaries, separated by a dashed line.
print(trainer.critic, "-----------", trainer.generator, sep="\n")

In [None]:
# One training batch per scan type. The unpacking below relies on ScanType
# listing its members in (opt, low, high) order.
patches = [next(train_loaders[st.value]) for st in ScanType]
opt, low, high = patches
# Stack the low and high batches along dim 0 into a single "subopt" batch.
subopt = torch.cat([low["data"], high["data"]]).to(trainer.device, non_blocking=True)
opt = opt["data"].to(trainer.device, non_blocking=True)

In [None]:
# Convolution filter shapes for the critic, then (after a separator) the generator,
# both for a single sample shaped like `opt` (batch axis dropped).
for i, net in enumerate((trainer.critic, trainer.generator)):
    if i:
        print("----")
    model_utils.print_convolution_filters_shape(net, opt.shape[1:])

In [None]:
from torch import nn

# Bottleneck idea (ResNet-style): 1x1x1 convolutions shrink the channels
# (256 -> 64) and restore them (64 -> 256) around the 3x3x3 convolution.
# After the separator, a plain 256 -> 256 3x3x3 convolution for comparison.
# Each layer is given the input shape it actually receives: the 2nd and 3rd
# convolutions take 64 channels, not 256.
spatial_shape = (128, 128, 128)
inp_shape = (256, *spatial_shape)
bottleneck_shape = (64, *spatial_shape)
model_utils.print_convolution_filters_shape(nn.Conv3d(256, 64, 1, 1, 0), inp_shape)
model_utils.print_convolution_filters_shape(nn.Conv3d(64, 64, 3, 1, 1), bottleneck_shape)
model_utils.print_convolution_filters_shape(nn.Conv3d(64, 256, 1, 1, 0), bottleneck_shape)
print('------')
model_utils.print_convolution_filters_shape(nn.Conv3d(256, 256, 3, 1, 1), inp_shape)

In [None]:
# Generator forward pass on the concatenated low/high batch. The reconstruction
# is the input minus the predicted attenuations.
attenuations = trainer.generator(subopt)
recon = subopt - attenuations

# Quick sanity check: value range of the second sample's attenuations.
print(attenuations[1].min(), attenuations[1].max())

# Critic scores: `opt` as the real batch; reconstructions as the fake batch.
# The fake batch is detached from the generator graph, so this critic pass
# does not backpropagate into the generator.
# Keep this statement order: reordering the forward passes could change
# random-number consumption or normalisation statistics, if the networks use them.
D_real  = trainer.critic(opt)
D_fake = trainer.critic(recon.detach())

In [None]:
from torchviz import make_dot

from contrast_gan_3D.model.utils import wgan_gradient_penalty

# Critic loss from the fake/real scores computed in the previous cell.
loss_D = trainer.loss_GAN(D_fake, D_real)

# The gradient penalty presumably mixes real and fake samples, so both must
# have the same shape. `recon` stacks the low and high batches, so tile `opt`
# along the batch axis to match.
n_repeats = recon.shape[0] // opt.shape[0]
real_for_gp = opt.repeat((n_repeats,) + (1,) * (opt.dim() - 1))
assert real_for_gp.shape == recon.shape, (real_for_gp.shape, recon.shape)
gp = wgan_gradient_penalty(
    real_for_gp,
    recon,
    trainer.critic,
    trainer.device,
)

# Toggle for adding the gradient penalty to the critic loss (disabled here).
ADD_GRADIENT_PENALTY = False
if ADD_GRADIENT_PENALTY:
    # Out-of-place addition: an in-place update of a graph tensor can break autograd.
    loss_D = loss_D + gp


# loss_D.backward()

# Generator loss: negated loss_GAN on the critic's scores for the reconstructions.
# They are not detached here, so gradients can flow back to the generator.
loss_G = -trainer.loss_GAN(trainer.critic(recon))
# loss_G.backward()

In [None]:
# Render (with torchviz) the autograd graph that produced the gradient penalty `gp`.
make_dot(gp)

In [None]:
# trainer.logger_interface([sample], [recon], [attenuations], 0, "train")

In [None]:
# EXAMPLE TO CREATE SITK IMAGES TO IMPORT INTO ITK-SNAP
from pathlib import Path

from contrast_gan_3D.data.HD5Scan import HD5Scan
from contrast_gan_3D.utils import io_utils

from pprint import pprint  # required by pprint(scan.meta) below

# Set to True to export the scan below as an .mhd volume that ITK-SNAP can open.
EXPORT_TO_ITKSNAP = False

test_folder = Path("/home/marco/data/after_preproc")

p = Path("/home/marco/data/MMWHS/ct_test/ct_test_2035_image.h5")
print(p)

if EXPORT_TO_ITKSNAP:
    # Create the output folder (and any missing parents) only when exporting.
    test_folder.mkdir(parents=True, exist_ok=True)
    with HD5Scan(p) as scan:
        pprint(scan.meta)
        print(scan.ccta.shape)

        # HWD -> DHW (xyz->zyx, numpy to sitk convention)
        ccta = scan.ccta[::].transpose(2, 0, 1)
        io_utils.to_itksnap_volume(
            ccta,
            scan.meta["offset"],
            scan.meta["spacing"],
            test_folder / f"{p.stem}.mhd",
        )