In [1]:
%cd ..

%config InlineBackend.figure_format = "retina"

/home/marco/contrast-gan-3D


In [2]:
import numpy as np
from torch.optim import Adam
from torch.optim.lr_scheduler import MultiStepLR

from contrast_gan_3D.constants import TRAIN_PATCH_SIZE
from contrast_gan_3D.data import utils as dset_utils
from contrast_gan_3D.data.CCTADataset import CCTADataset
from contrast_gan_3D.model.discriminator import NLayerDiscriminator
from contrast_gan_3D.model.generator import ResnetGenerator
from contrast_gan_3D.trainer import utils as train_utils
from contrast_gan_3D.trainer.Reloader import Reloader
from contrast_gan_3D.trainer.Trainer import Trainer

In [3]:
# Spreadsheet describing the cross-validation split.
# NOTE(review): absolute, user-specific path — consider making it configurable.
CROSSVAL_XLSX = "/home/marco/data/ostia_final.xlsx"
# Cross-validation fold to train/validate on. -1 (the last fold) is what the
# previous loop-over-all-folds version of this cell ended up leaving in
# `train_loaders` / `val_loaders`; use 0 to train on the first fold only.
FOLD_IDX = -1

train_folds, val_folds = train_utils.crossval_paths(CROSSVAL_XLSX)

train_by_lab = [train_utils.divide_scans_in_fold(f) for f in train_folds]
val_by_lab = [train_utils.divide_scans_in_fold(f) for f in val_folds]
# Folds are paired by position; a length mismatch would silently misalign them.
assert len(train_by_lab) == len(val_by_lab), "train/val fold counts differ"

rng = np.random.default_rng(42)
train_transforms = dset_utils.make_train_transforms()


def make_fold_loaders(train_split, val_split, rng, transform):
    """Build one `Reloader` per label for the train and validation split of a fold.

    Args:
        train_split: mapping label -> sequence of scan paths (training split).
        val_split: mapping label -> sequence of scan paths (validation split).
        rng: numpy Generator shared by every dataset.
        transform: augmentation pipeline, applied to the training datasets only.

    Returns:
        (train_loaders, val_loaders): two dicts mapping label -> Reloader.
    """
    train_loaders = {
        label: Reloader(
            CCTADataset(
                paths,
                [label] * len(paths),
                TRAIN_PATCH_SIZE,
                rng=rng,
                transform=transform,
            ),
            reload=True,
            batch_size=1,
            shuffle=True,
        )
        for label, paths in train_split.items()
    }
    # An earlier variant passed -1 instead of TRAIN_PATCH_SIZE for validation.
    # NOTE(review): in the logged run below, every validation loss is 0.0 from
    # iteration 3 on while n_opt/n_subopt stay 13/11 — check whether a
    # Reloader with reload=False is exhausted after its first pass.
    val_loaders = {
        label: Reloader(
            CCTADataset(paths, [label] * len(paths), TRAIN_PATCH_SIZE, rng=rng),
            reload=False,
            batch_size=1,
            shuffle=False,
        )
        for label, paths in val_split.items()
    }
    return train_loaders, val_loaders


# Build loaders for the selected fold only. The previous version built every
# fold in a loop and silently kept just the last one.
train_loaders, val_loaders = make_fold_loaders(
    train_by_lab[FOLD_IDX], val_by_lab[FOLD_IDX], rng, train_transforms
)

In [4]:
lr = 2e-4
b1 = 5e-1
b2 = 0.999
milestones = [3, 7]
lr_gamma = 0.1
train_generator_every = 2
hu_loss_kwargs = {"bias": -1024, "factor": 600, "min_HU": 350, "max_HU": 450}
train_iterations = 10

generator = ResnetGenerator(6, 2, 64)
generator_optim = Adam(generator.parameters(), lr=lr, betas=(b1, b2))
generator_lr_scheduler = MultiStepLR(
    generator_optim, milestones=milestones, gamma=lr_gamma
)

discriminator = NLayerDiscriminator(1, 1, 3, 64)
discriminator_optim = Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))
discriminator_lr_scheduler = MultiStepLR(
    discriminator_optim, milestones=milestones, gamma=lr_gamma
)

trainer = Trainer(
    generator,
    discriminator,
    generator_optim,
    discriminator_optim,
    train_generator_every,
    generator_lr_scheduler=generator_lr_scheduler,
    discriminator_lr_scheduler=discriminator_lr_scheduler,
    device_num=4,
    **hu_loss_kwargs
)

In [5]:
trainer.fit(12, train_loaders, val_loaders)

  0%|          | 0/12 [00:00<?, ?it/s]

[1mtrain[0m iteration 0 loss:
	D: -0.013175791129469872
	G: 62133.02734375
	adversarial: 0.006100583355873823
	similarity: -0.999999463558197
	HU: 62134.01953125

[1mvalidation[0m iteration 1 loss:
	n_opt: 13
	n_subopt: 11
	D: 0.006461285054683685
	adversarial_real: 0.003233474213629961
	similarity: -0.9999995665116743
	adversarial_fake: -0.0032339888540181246
	adversarial: 0.0032339888540181246

[1mtrain[0m iteration 2 loss:
	D: 0.013791349716484547
	G: 94563.0546875
	adversarial: -0.001969830831512809
	similarity: -0.9999997019767761
	HU: 94564.0546875

[1mvalidation[0m iteration 3 loss:
	n_opt: 13
	n_subopt: 11
	D: 0.0
	adversarial_real: 0.0
	similarity: 0.0
	adversarial_fake: 0.0
	adversarial: 0.0

[1mtrain[0m iteration 4 loss:
	D: -0.007481778040528297
	G: 278065.3125
	adversarial: 0.007585951127111912
	similarity: -0.9999997019767761
	HU: 278066.3125

[1mvalidation[0m iteration 5 loss:
	n_opt: 13
	n_subopt: 11
	D: 0.0
	adversarial_real: 0.0
	similarity: 0.0
	adversari

KeyboardInterrupt: 