In [38]:
import os

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch import nn, optim
from torchinfo import summary

from utils import load_all_data, squeeze_and_concat, filter_mask_keep_labels, multiclass_dice_loss
from datasets import MultiTissueDataset
from unet_advanced import UNetAdvanced as UNetGan
from gan_basic import DiscriminatorModel
from train_utils import EarlyStopping


In [2]:
# Restrict this process to GPU 0. The variable is only honoured if CUDA has not been
# initialised yet in this kernel (torch.cuda initialises lazily on first use).
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Seed every RNG the notebook relies on so runs are reproducible.
SEED = 42
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators
np.random.seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Selected device: {device}")
print('Num available GPUs: ', torch.cuda.device_count())

# get_device_properties raises when CUDA is unavailable, so only query it on a GPU run.
if device.type == 'cuda':
    p = torch.cuda.get_device_properties(device)
    print(f"Device: {p.name} (Memory: {p.total_memory / 1e9:.2f} GB)")

Selected device: cuda
Num available GPUs:  1
Device: NVIDIA TITAN RTX (Memory: 25.19 GB)


In [41]:
# Data import. The dataset root can be overridden with the ACDC_DATA_FOLDER environment
# variable so the notebook also runs outside the original cluster; the default is unchanged.
DATA_FOLDER = os.environ.get(
    "ACDC_DATA_FOLDER", "/scratch/pdiciano/GenAI/ACDC_mine/data/ACDC_tissue_prop"
)

data = load_all_data(DATA_FOLDER)
data_concat = squeeze_and_concat(data)

# Labels kept in the generator input. Original note: right ventricle, left ventricle,
# right myocardium, left myocardium -- TODO confirm against the dataset's label map.
mask_keep_labels = [0, 1, 2, 3]
data_concat['input_labels'] = filter_mask_keep_labels(data_concat['multiClassMasks'], mask_keep_labels)

In [54]:
# Generator: 4 input channels -> 12 per-class output channels.
# Discriminator: the 4 generator-input channels plus 1 mask channel = 5 channels.
gen = UNetGan(in_ch=4,
              num_classes=12,
              dropout_p=0.3)
discr = DiscriminatorModel(in_ch=5,
                           base_ch=64)


In [55]:
# Sanity check of nn.L1Loss: with the default reduction="mean" it is the mean absolute
# difference over all elements. A locally seeded generator makes the displayed value
# reproducible without disturbing the global RNG used by later cells.
l1 = nn.L1Loss()

inp = torch.randn(2, 4, 256, 256, generator=torch.Generator().manual_seed(0))
out = torch.ones(2, 4, 256, 256)

loss = l1(inp, out)
assert torch.isclose(loss, (inp - out).abs().mean())
loss


tensor(1.1662)

In [56]:
# Dataset over the concatenated arrays, served in shuffled mini-batches of 8.
dataset = MultiTissueDataset(data_concat)
dataloader = DataLoader(dataset=dataset, batch_size=8, shuffle=True)

In [57]:
# A single batch, used to debug one training step.
batch = next(iter(dataloader))

# Adversarial loss on raw (pre-sigmoid) discriminator scores, plus an L1 reconstruction term.
criterion_GAN, criterion_L1 = nn.BCEWithLogitsLoss(), nn.L1Loss()

In [58]:
# Smoke check: BCE-with-logits also accepts soft targets in [0, 1], not only hard 0/1 labels.
criterion_GAN(input=torch.rand(2, 1), target=torch.rand(2, 1))

tensor(0.9424)

In [None]:
# One conditional-GAN training step (pix2pix style) on the single debug `batch`.
# The discriminator is conditioned on the generator input: it scores
# cat(input, mask) and learns to separate ground-truth masks from generated ones.
lambda_l1 = 1.0  # weight of the L1 term relative to the adversarial term

# Move the models before building their optimizers (torch.optim recommends this order).
gen.to(device)
discr.to(device)
gen.train()
discr.train()

# NOTE(review): optimizers are re-created every time this cell runs, which resets Adam's
# moment estimates; move them to their own cell once this step is put in a loop.
optim_gen = optim.Adam(gen.parameters(), lr=1e-4)
optim_discr = optim.Adam(discr.parameters(), lr=1e-4)

# Conv layers and torch.cat need matching floating-point inputs on the models' device.
input = batch['input_label'].to(device).float()
gt = batch['multiClassMask'].to(device)
gt_map = gt.unsqueeze(1).float()  # assumes gt is (B, H, W) class indices -- TODO confirm

# --- Generator forward ------------------------------------------------------------
# argmax is not differentiable: feeding argmax(gen(input)) straight into D and the L1
# term leaves the generator without gradients, so optim_gen.step() would do nothing.
# A straight-through estimator keeps the argmax class-index map in the forward pass
# while back-propagating through the softmax probabilities.
logits = gen(input)  # assumed (B, num_classes, H, W)
probs = torch.softmax(logits, dim=1)
hard = torch.movedim(F.one_hot(probs.argmax(dim=1), num_classes=probs.shape[1]), -1, 1).to(probs.dtype)
one_hot_st = hard - probs.detach() + probs  # value == hard, gradient == d(probs)
class_ids = torch.arange(probs.shape[1], device=probs.device, dtype=probs.dtype)
gen_img = (torch.movedim(one_hot_st, 1, -1) @ class_ids).unsqueeze(1)  # (B, 1, H, W)

# --- Discriminator step: real -> 1, fake -> 0 -------------------------------------
discr_in_real = torch.cat((input, gt_map), dim=1)
discr_in_fake = torch.cat((input, gen_img.detach()), dim=1)  # detach: no generator update here

discr_real = discr(discr_in_real)
discr_fake = discr(discr_in_fake)

loss = criterion_GAN(discr_real, torch.ones_like(discr_real)) + \
    criterion_GAN(discr_fake, torch.zeros_like(discr_fake))

optim_discr.zero_grad()
loss.backward()
optim_discr.step()

# --- Generator step ---------------------------------------------------------------
discr_in_fake = torch.cat((input, gen_img), dim=1)  # no detach now: gradients reach gen
discr_fake = discr(discr_in_fake)

# Compare in (B, 1, H, W) layout; squeezing would also drop the batch dim when B == 1.
l1_loss = criterion_L1(gen_img, gt_map)
# TODO(review): for segmentation, F.cross_entropy(logits, gt.long()) is the more usual
# reconstruction term than L1 on class indices.

# Invert real/fake labels for generator loss: G wants D to call its output real.
loss_gen = criterion_GAN(discr_fake, torch.ones_like(discr_fake)) + lambda_l1 * l1_loss

# This backward also fills discr's .grad; optim_discr.zero_grad() clears it before D's next update.
optim_gen.zero_grad()
loss_gen.backward()
optim_gen.step()

