In [1]:
import numpy as np
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

from PIL import Image
from pydicom import dcmread
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter  # to print to tensorboard

In [2]:
def plot_to_tensorboard(
    writer, loss_critic, loss_gen, real, fake, cur_res, lr, tensorboard_step
):
    """Write one step of training diagnostics to TensorBoard.

    Args:
        writer: object exposing ``add_scalar`` and ``add_image``
            (e.g. a ``SummaryWriter``).
        loss_critic: value logged under ``data/loss_dis``.
        loss_gen: value logged under ``data/loss_gen``.
        real: batch of real images; at most the first 8 are plotted.
        fake: batch of generated images; at most the first 8 are plotted.
        cur_res: current resolution; converted with ``int()`` before logging.
        lr: value logged under ``data/cur_lr``.
        tensorboard_step: global step attached to every record.
    """

    def log_scalar(tag, value):
        # Every scalar shares the same global step.
        writer.add_scalar(tag, value, global_step=tensorboard_step)

    log_scalar("data/loss_dis", loss_critic)
    log_scalar("data/loss_gen", loss_gen)
    log_scalar("data/cur_resl", int(cur_res))
    log_scalar("data/cur_lr", lr)

    with torch.no_grad():
        # Build both grids first (up to 8 samples each), then write them.
        grids = [
            torchvision.utils.make_grid(batch[:8], normalize=True)
            for batch in (real, fake)
        ]
        for tag, grid in zip(("Real", "Fake"), grids):
            writer.add_image(tag, grid, global_step=tensorboard_step)

In [3]:
class Discriminator(nn.Module):
    """Fully connected binary classifier for flattened images.

    Maps a vector of ``in_features`` values through one hidden layer of
    128 units (LeakyReLU, slope 0.01) to a single sigmoid output in (0, 1).
    """

    def __init__(self, in_features):
        super().__init__()
        hidden_units = 128
        layers = [
            nn.Linear(in_features, hidden_units),
            nn.LeakyReLU(0.01),
            nn.Linear(hidden_units, 1),
            nn.Sigmoid(),  # squashes the score into a probability-like value
        ]
        self.disc = nn.Sequential(*layers)

    def forward(self, x):
        """Return one score per row of ``x``, shape ``(..., 1)``."""
        score = self.disc(x)
        return score


class Generator(nn.Module):
    """Fully connected generator mapping latent vectors to flattened images.

    A latent vector of size ``z_dim`` goes through one hidden layer of
    256 units (LeakyReLU, slope 0.01) and is projected to ``img_dim`` values.
    """

    def __init__(self, z_dim, img_dim):
        super().__init__()
        hidden_units = 256
        layers = [
            nn.Linear(z_dim, hidden_units),
            nn.LeakyReLU(0.01),
            nn.Linear(hidden_units, img_dim),
            # Tanh bounds every output value to [-1, 1].
            # NOTE(review): this presumes real images are scaled into the
            # same range - confirm against the preprocessing pipeline.
            nn.Tanh(),
        ]
        self.gen = nn.Sequential(*layers)

    def forward(self, x):
        """Return generated samples of shape ``(..., img_dim)``."""
        sample = self.gen(x)
        return sample

## Hyperparameters

In [4]:
# Seed torch's global RNG so that a Restart & Run All reproduces the same
# weight initialisation, latent noise and DataLoader shuffling order.
SEED = 42
torch.manual_seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 2e-4  # learning rate handed to the Adam optimizers
z_dim = 64  # size of the generator's latent input vector
batch_size = 32
num_epochs = 100

# Image geometry. `image_dim` (the flattened image length) is derived from the
# geometry so the two can never drift apart when the resolution is changed.
IMAGE_HEIGHT = 64
IMAGE_WIDTH = 64
IMAGE_CHANNELS = 1
image_dim = IMAGE_HEIGHT * IMAGE_WIDTH * IMAGE_CHANNELS

# The dataset location can be overridden with the NODULE_IMAGE_DIR environment
# variable; the default is the original path, so existing setups keep working.
TRAIN_DIR = os.environ.get(
    "NODULE_IMAGE_DIR", "/Storage/PauloOctavioDir/nodule_images/images"
)
TENSORBOARD_MODEL_NAME = "gan"

In [5]:
# Instantiate both networks and move them to the selected device.
disc = Discriminator(in_features=image_dim).to(device)
gen = Generator(z_dim=z_dim, img_dim=image_dim).to(device)

# One fixed batch of latent vectors, reused whenever sample images are logged
# so successive snapshots show the same latent points.
fixed_noise = torch.randn((batch_size, z_dim)).to(device)

# Preprocessing applied to every slice returned by the dataset:
# array -> PIL image -> 504-pixel centre crop -> resize to the training
# resolution -> tensor -> divide by 4000 (Normalize with mean 0, std 4000).
train_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.CenterCrop(size=504),
    transforms.Resize(size=(IMAGE_HEIGHT, IMAGE_WIDTH)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0,), std=(4000,)),
])

## Dataset

In [6]:
class LungNoduleDataset(Dataset):
    """Dataset of DICOM slices converted to (windowed) Hounsfield units.

    Every entry of ``image_dir`` except ``rtss.dcm`` is treated as one image
    file. File names are sorted so that an index always maps to the same file
    (``os.listdir`` returns entries in arbitrary order).

    Args:
        image_dir: directory containing the DICOM files.
        transform: optional callable applied to the float32 HU array; when
            ``None`` the array itself is returned.
    """

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        self.images = sorted(
            name for name in os.listdir(image_dir) if name != 'rtss.dcm'
        )

    def __len__(self):
        return len(self.images)

    @staticmethod
    def _to_hounsfield(pixel_array, slope, intercept):
        """Rescale stored pixel values to HU and apply the value window.

        Returns a new float32 array equal to ``slope * pixel + intercept``
        where values below -2000 are replaced by 0 and values above 3000 are
        capped at 3000.
        """
        image = np.array(pixel_array).astype('float32')
        # float32 scalars keep the result float32 while preserving
        # fractional slopes/intercepts (int() would truncate 0.5 to 0).
        image = np.float32(slope) * image + np.float32(intercept)
        # NOTE(review): values below -2000 are mapped to 0 rather than clipped
        # to -2000 - presumably to blank out-of-scan padding; confirm intent.
        image[image < -2000] = 0
        image[image > 3000] = 3000
        return image

    def __getitem__(self, index):
        """Load slice ``index`` and return it (transformed if requested)."""
        img_path = os.path.join(self.image_dir, self.images[index])
        data = dcmread(img_path)
        # Conversion to HU. Files without rescale tags are treated as having
        # the identity rescale (slope 1, intercept 0).
        slope = float(getattr(data, "RescaleSlope", 1))
        intercept = float(getattr(data, "RescaleIntercept", 0))
        image = self._to_hounsfield(data.pixel_array, slope, intercept)
        if self.transform is None:
            return image
        return self.transform(image)
    
# Build the training set from TRAIN_DIR and serve it in shuffled mini-batches.
dataset = LungNoduleDataset(TRAIN_DIR, train_transform)
loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

In [7]:
# Both networks are optimised with Adam at the same learning rate and trained
# with binary cross-entropy on the discriminator's sigmoid output.
opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)
criterion = nn.BCELoss()
writer = SummaryWriter(f"logs/{TENSORBOARD_MODEL_NAME}")
step = 0  # TensorBoard global step; advanced once per logged batch

try:
    for epoch in range(num_epochs):
        for batch_idx, real in enumerate(loader):
            real = real.view(-1, image_dim).to(device)
            # Loop-local name on purpose: the final batch of an epoch can be
            # smaller, and overwriting the module-level `batch_size` would leak
            # that size into any cell that is re-run afterwards.
            cur_batch_size = real.shape[0]

            ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
            noise = torch.randn(cur_batch_size, z_dim).to(device)
            fake = gen(noise)
            disc_real = disc(real).view(-1)
            lossD_real = criterion(disc_real, torch.ones_like(disc_real))
            # Detach so this backward pass stops at the discriminator. The
            # generator's graph stays intact for its own update below, and
            # no retain_graph=True is needed.
            disc_fake = disc(fake.detach()).view(-1)
            lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
            lossD = (lossD_real + lossD_fake) / 2
            disc.zero_grad()
            lossD.backward()
            opt_disc.step()

            ### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z))
            # where the second option of maximizing doesn't suffer from
            # saturating gradients
            output = disc(fake).view(-1)
            lossG = criterion(output, torch.ones_like(output))
            gen.zero_grad()
            lossG.backward()
            opt_gen.step()

            # Report and log once per epoch, on its first batch.
            if batch_idx == 0:
                # The run of 22 spaces reproduces the original message
                # byte-for-byte (it came from a backslash line continuation).
                print(
                    f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} "
                    + " " * 22
                    + f"Loss D: {lossD:.4f}, loss G: {lossG:.4f}"
                )

                with torch.no_grad():
                    # Samples from the fixed latent batch, reshaped to images.
                    fake_images = gen(fixed_noise).reshape(
                        -1, 1, IMAGE_HEIGHT, IMAGE_WIDTH
                    )
                    real_images = real.reshape(-1, 1, IMAGE_HEIGHT, IMAGE_WIDTH)
                    img_grid_fake = torchvision.utils.make_grid(
                        fake_images, normalize=True
                    )
                    img_grid_real = torchvision.utils.make_grid(
                        real_images, normalize=True
                    )

                    writer.add_image("Fake", img_grid_fake, global_step=step)
                    writer.add_image("Real", img_grid_real, global_step=step)

                    writer.add_scalar("data/loss_dis", lossD.item(), global_step=step)
                    writer.add_scalar("data/loss_gen", lossG.item(), global_step=step)

                    step += 1
finally:
    # Flush and close the event file even if training is interrupted.
    writer.close()

Epoch [0/100] Batch 0/183                       Loss D: 0.6844, loss G: 0.8069
Epoch [1/100] Batch 0/183                       Loss D: 0.0990, loss G: 3.8083
Epoch [2/100] Batch 0/183                       Loss D: 0.8560, loss G: 1.1472
Epoch [3/100] Batch 0/183                       Loss D: 0.1423, loss G: 2.7475
Epoch [4/100] Batch 0/183                       Loss D: 0.7423, loss G: 1.3737
Epoch [5/100] Batch 0/183                       Loss D: 0.6734, loss G: 0.9811
Epoch [6/100] Batch 0/183                       Loss D: 0.2890, loss G: 1.5054
Epoch [7/100] Batch 0/183                       Loss D: 0.8018, loss G: 0.6075
Epoch [8/100] Batch 0/183                       Loss D: 0.4050, loss G: 1.0769
Epoch [9/100] Batch 0/183                       Loss D: 0.7249, loss G: 0.6287
Epoch [10/100] Batch 0/183                       Loss D: 0.4371, loss G: 1.1917
Epoch [11/100] Batch 0/183                       Loss D: 0.4557, loss G: 1.2409
Epoch [12/100] Batch 0/183                       L