In [1]:
%cd ..

/home/marco/contrast-gan-3D


In [2]:
import os

# https://discuss.pytorch.org/t/gpu-device-ordering/60785/2
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"


import numpy as np
import torch

from contrast_gan_3D.alias import ScanType
from contrast_gan_3D.config import CHECKPOINTS_DIR

from contrast_gan_3D.experiments.test_conf import *
# from contrast_gan_3D.experiments.gradient_penalty_conf import *
from contrast_gan_3D.model.loss import HULoss
from contrast_gan_3D.trainer.Trainer import Trainer
from contrast_gan_3D.trainer.utils import create_dataloaders, cval_paths
from contrast_gan_3D.utils import set_GPU

In [3]:
profiler_dir = None
run_id = "pippo"
fold_idx = 0

train_folds, val_folds = cval_paths(cval_folds, "/home/marco/data/ostia_final.xlsx")

scaled_HU_bounds = scaler(np.array(desired_HU_bounds))
print(scaled_HU_bounds)

val_batch_size = 2
train_loaders, val_loaders = create_dataloaders(
    train_folds[fold_idx],
    val_folds[fold_idx],
    train_patch_size,
    val_patch_size,
    train_batch_size,
    val_batch_size,
    scaler=scaler,
    train_transform=train_transform,
)

[0.18666667 0.35333333]


In [4]:
from contrast_gan_3D.trainer.logger.LoggerInterface import SingleThreadedLogger

logger_interface = SingleThreadedLogger(logger_interface.logger)

device = set_GPU(3)

trainer = Trainer(
    train_iterations,
    val_iterations,
    validate_every,
    train_generator_every,
    log_every,
    log_images_every,
    generator_class,
    critic_class,
    generator_optim_class,
    critic_optim_class,
    HULoss(*scaled_HU_bounds, (val_batch_size, 1, *val_patch_size)),
    # HULoss(*scaled_HU_bounds, (train_batch_size * 2, 1, *train_patch_size)),
    logger_interface,
    CHECKPOINTS_DIR / f"{run_id}.pt",
    weight_clip=weight_clip,
    generator_lr_scheduler_class=generator_lr_scheduler_class,
    critic_lr_scheduler_class=critic_lr_scheduler_class,
    device=device,
    checkpoint_every=checkpoint_every,
)
# trainer.load_checkpoint(
#     # "/home/marco/contrast-gan-3D/logs/model_checkpoints/9hnh7gto.pt"
#     "/home/marco/contrast-gan-3D/logs/model_checkpoints/07qiygyk.pt"
# )

[2024-04-12 13:47:37,522: INFO] Using device: cuda:3 (contrast_gan_3D.trainer.Trainer:65)
[2024-04-12 13:47:37,617: INFO] Starting from iteration 0 (contrast_gan_3D.trainer.Trainer:321)


In [5]:
from contrast_gan_3D.model import utils as model_utils

print(model_utils.count_parameters(trainer.critic))
print(model_utils.count_parameters(trainer.generator))

697809
1035297


In [6]:
print(trainer.critic)
print("-----------")
print(trainer.generator)

PatchGANDiscriminator(
  (model): Sequential(
    (first): ConvBlock3D(
      (conv): Conv3d(1, 16, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
      (normalization): Identity()
      (activation_fn): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (middle): Sequential(
      (0): ConvBlock3D(
        (conv): Conv3d(16, 32, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
        (normalization): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activation_fn): LeakyReLU(negative_slope=0.2, inplace=True)
      )
      (1): ConvBlock3D(
        (conv): Conv3d(32, 64, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
        (normalization): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activation_fn): LeakyReLU(negative_slope=0.2, inplace=True)
      )
      (2): ConvBlock3D(
        (conv): Conv3d(64, 128, kernel_size=(4, 4, 4), stride=(2,

In [7]:
patches = [next(train_loaders[scan_type.value]) for scan_type in ScanType]
opt, low, high = patches
opt = opt["data"].to(trainer.device, non_blocking=True)
subopt = torch.cat([low["data"], high["data"]])
subopt = subopt.to(trainer.device, non_blocking=True)

using pin_memory on device 3
using pin_memory on device 3
using pin_memory on device 3


In [8]:
model_utils.print_convolution_filters_shape(trainer.critic, opt.shape[1:])
print("----")
model_utils.print_convolution_filters_shape(trainer.generator, opt.shape[1:])

Input shape: [1, 128, 128, 128]
model.first.conv                         -> [16, 64, 64, 64]       # params: 1040       weight: [16, 1, 4, 4, 4]     bias: [16]
model.middle.0.conv                      -> [32, 32, 32, 32]       # params: 32768      weight: [32, 16, 4, 4, 4]   
model.middle.1.conv                      -> [64, 16, 16, 16]       # params: 131072     weight: [64, 32, 4, 4, 4]   
model.middle.2.conv                      -> [128, 8, 8, 8]         # params: 524288     weight: [128, 64, 4, 4, 4]  
model.last                               -> [1, 7, 7, 7]           # params: 8193       weight: [1, 128, 4, 4, 4]    bias: [1]
----
Input shape: [1, 128, 128, 128]
model.first.conv                         -> [16, 128, 128, 128]    # params: 5488       weight: [16, 1, 7, 7, 7]    
model.downsampling.0.conv                -> [32, 64, 64, 64]       # params: 13824      weight: [32, 16, 3, 3, 3]   
model.downsampling.1.conv                -> [64, 32, 32, 32]       # params: 55296      wei

In [14]:
from torch import nn

# idea of 1D convolution: bottleneck!
inp_shape = (256, 128, 128, 128)
model_utils.print_convolution_filters_shape(nn.Conv3d(256, 64, 1, 1, 0), inp_shape)
model_utils.print_convolution_filters_shape(nn.Conv3d(64, 64, 3, 1, 1), inp_shape)
model_utils.print_convolution_filters_shape(nn.Conv3d(64, 256, 1, 1, 0), inp_shape)
print('------')
model_utils.print_convolution_filters_shape(nn.Conv3d(256, 256, 3, 1, 1), inp_shape)

Input shape: [256, 128, 128, 128]
                                         -> [64, 128, 128, 128]    # params: 16448      weight: [64, 256, 1, 1, 1]   bias: [64]
Input shape: [256, 128, 128, 128]
                                         -> [64, 128, 128, 128]    # params: 110656     weight: [64, 64, 3, 3, 3]    bias: [64]
Input shape: [256, 128, 128, 128]
                                         -> [256, 128, 128, 128]   # params: 16640      weight: [256, 64, 1, 1, 1]   bias: [256]
------
Input shape: [256, 128, 128, 128]
                                         -> [256, 128, 128, 128]   # params: 1769728    weight: [256, 256, 3, 3, 3]  bias: [256]


In [22]:
attenuations = trainer.generator(subopt)
recon = subopt - attenuations

print(attenuations[1].min(), attenuations[1].max())

D_real  = trainer.critic(opt)
D_fake = trainer.critic(recon.detach())

tensor(-0.9999, device='cuda:3', grad_fn=<MinBackward1>) tensor(0.9673, device='cuda:3', grad_fn=<MaxBackward1>)


In [None]:
from torchviz import make_dot

from contrast_gan_3D.model.utils import wgan_gradient_penalty

loss_D = trainer.loss_GAN(D_fake, D_real)
gp = wgan_gradient_penalty(
    opt.repeat((2,) + (1,) * len(opt.shape[1:])),
    recon,
    trainer.critic,
    trainer.device,
)
if False:
    loss_D += gp


# loss_D.backward()

loss_G = -trainer.loss_GAN(trainer.critic(recon))
# loss_G.backward()

In [None]:
make_dot(gp)

In [None]:
# trainer.logger_interface([sample], [recon], [attenuations], 0, "train")

In [None]:
# EXAMPLE TO CREATE SITK IMAGES TO IMPORT INTO ITK-SNAP
from pathlib import Path

from contrast_gan_3D.data.HD5Scan import HD5Scan
from contrast_gan_3D.utils import io_utils

test_folder = Path("/home/marco/data/after_preproc")
test_folder.mkdir(exist_ok=True)

p = Path("/home/marco/data/MMWHS/ct_test/ct_test_2035_image.h5")
print(p)

if False:
    with HD5Scan(p) as scan:
        pprint(scan.meta)
        print(scan.ccta.shape)

        # HWD -> DHW (xyz->zyx, numpy to sitk convention)
        ccta = scan.ccta[::].transpose(2, 0, 1)
        io_utils.to_itksnap_volume(
            ccta,
            scan.meta["offset"],
            scan.meta["spacing"],
            test_folder / f"{p.stem}.mhd",
        )