In [1]:
%matplotlib inline
import os
import ipywidgets as widgets
from matplotlib import pyplot as plt

# Text box in which the user types the path of the configuration file.
# Its current content is available afterwards as `file_path.value`.
file_path = widgets.Text(description='path')

# Render the widget as this cell's output.
display(file_path)

Text(value='', description='path')

In [2]:
import configparser
# Parse the experiment configuration selected through the `file_path` widget.
config = configparser.ConfigParser()

# ConfigParser.read() silently skips files it cannot open and returns the list
# of files it actually parsed, so an empty result means nothing was loaded.
# Fail here with a clear message instead of a later `KeyError: 'NN'`.
if not config.read(file_path.value):
    raise FileNotFoundError(
        "Could not read configuration file {!r}; enter a valid path in the "
        "'path' text box and re-run this cell.".format(file_path.value)
    )

# [NN] section: network and optimizer hyper-parameters.
nn_section = config['NN']
nz = int(nn_section['nz'])
ngf = int(nn_section['ngf'])
ndf = int(nn_section['ndf'])
lr = float(nn_section['lr'])
beta1 = float(nn_section['beta1'])

# [Training] section: data and schedule settings.
training_section = config['Training']
im_size = int(training_section['im_size'])
n_epoch = int(training_section['n_epoch'])
BATCH_SIZE = int(training_section['BATCH_SIZE'])
Ti_path = training_section['path']

In [3]:
import h5py
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch


class ctimage(Dataset):
    """Dataset over the ``'data'`` array of an HDF5 file.

    Each item is ``(image, 0.9)``: the ``idx``-th entry of ``'data'`` with a
    new leading axis added, and the constant target ``0.9`` (one-sided label
    smoothing for the "real" class).

    The HDF5 file is opened lazily, on first access of ``img``, so every
    DataLoader worker process ends up with its own handle. h5py objects cannot
    be pickled and should not be shared between processes, so keeping an open
    handle created in ``__init__`` is unsafe with ``num_workers > 0``.

    Args:
        path: location of the HDF5 file; it must contain a dataset ``'data'``.
    """

    def __init__(self, path):
        self.path = path
        self._file = None
        self._img = None
        # Read the number of samples once, then release the handle again.
        with h5py.File(path, 'r') as handle:
            self._length = handle['data'].shape[0]

    @property
    def img(self):
        """The ``'data'`` HDF5 dataset, opened on first use in this process."""
        if self._img is None:
            self._file = h5py.File(self.path, 'r')
            self._img = self._file['data']
        return self._img

    def __getstate__(self):
        # Drop open handles so the dataset can be sent to worker processes;
        # each process re-opens the file on demand through `img`.
        state = self.__dict__.copy()
        state['_file'] = None
        state['_img'] = None
        return state

    def __len__(self):
        return self._length

    def __getitem__(self, idx):
        # one-sided label smoothing: real samples are labelled 0.9, not 1.0
        return torch.from_numpy(np.expand_dims(self.img[idx], axis=0)), 0.9

# Dataset over the training-image file named in the configuration.
training_data = ctimage(Ti_path)

# Shuffled mini-batches, loaded by two background worker processes.
dataloader = DataLoader(
    training_data,
    shuffle=True,
    num_workers=2,
    batch_size=BATCH_SIZE,
)

In [4]:
import dcgan
from torch import nn
# Number of image channels passed to both network builders.
nc = 1

# Use the first GPU when one is present, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def init_weights(m):
    """DCGAN-style weight initialisation, intended for ``Module.apply``.

    * ``Conv3d`` / ``ConvTranspose3d``: weights drawn from N(0, 0.02).
    * ``BatchNorm3d``: scale drawn from N(1, 0.02), shift set to 0.

    Transposed convolutions are matched explicitly; an exact ``Conv3d`` type
    check would leave them at PyTorch's default initialisation. All other
    module types are left untouched.

    Args:
        m: a single (sub-)module; modified in place.
    """
    if isinstance(m, (nn.Conv3d, nn.ConvTranspose3d)):
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm3d):
        # Non-affine batch norm has no learnable scale/shift to initialise.
        if m.weight is not None:
            nn.init.normal_(m.weight, 1.0, 0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

# Build generator and discriminator from the configured sizes.
netG = dcgan.make_generator_model(im_size, nz, nc, ngf)
netD = dcgan.make_discriminator_model(im_size, nc, ndf)

# Module.apply calls init_weights on every sub-module of each network.
for net in (netG, netD):
    net.apply(init_weights)

# Echo the discriminator so the cell output shows its layer layout.
netD

Sequential(
  (0): Conv3d(1, 16, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
  (1): LeakyReLU(negative_slope=0.2, inplace)
  (3): Conv3d(16, 32, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
  (4): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): LeakyReLU(negative_slope=0.2, inplace)
  (6): Conv3d(32, 64, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
  (7): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (8): LeakyReLU(negative_slope=0.2, inplace)
  (9): Conv3d(64, 128, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
  (10): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (11): LeakyReLU(negative_slope=0.2, inplace)
  (12): Conv3d(128, 1, kernel_size=(4, 4, 4), stride=(1, 1, 1), bias=False)
  (13): Sigmoid()
)

In [5]:
import torch.optim as optim
import torch.nn as nn

# Binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# With several GPUs, wrap both networks so batches are split across devices.
multi_gpu = torch.cuda.device_count() > 1
if multi_gpu:
    netG = nn.DataParallel(netG)
    netD = nn.DataParallel(netD)

# Move networks and loss to the GPU whenever one is in use; on a CPU-only
# machine everything simply stays where it is.
if multi_gpu or torch.cuda.is_available():
    criterion = criterion.to(device)
    netG = netG.to(device)
    netD = netD.to(device)

# One Adam optimizer per network, sharing learning rate and betas.
adam_betas = (beta1, 0.999)
d_optimizer = optim.Adam(netD.parameters(), lr=lr, betas=adam_betas)
g_optimizer = optim.Adam(netG.parameters(), lr=lr, betas=adam_betas)

# Multiply each learning rate by 0.1 after every 500 scheduler steps.
d_schedule = optim.lr_scheduler.StepLR(d_optimizer, gamma=0.1, step_size=500)
g_schedule = optim.lr_scheduler.StepLR(g_optimizer, gamma=0.1, step_size=500)

In [6]:
import time
from torch import tensor
import os.path
from tqdm import tqdm

# GAN training loop.
#
# Per mini-batch: one discriminator update (real batch + generated batch),
# then `g_steps` generator updates. Loss D, Loss G, D(x) and D(G(z)) (before /
# after the generator update) are logged per iteration to
# ./training_curve/training_curve.hdf5, and both networks are checkpointed to
# ./training_checkpoints every 20 epochs.

training_curve = "training_curve"
work_dir = "training_checkpoints"

# Neither h5py.File nor torch.save creates missing parent directories.
os.makedirs(os.path.join('.', training_curve), exist_ok=True)
os.makedirs(os.path.join('.', work_dir), exist_ok=True)

num_iteration = len(dataloader)*n_epoch

# Constant targets, built once instead of on every iteration.
fake_label = torch.zeros(BATCH_SIZE, dtype=torch.float, device=device)
# For the generator loss, generated samples are labelled as real.
gen_target = torch.ones(BATCH_SIZE, dtype=torch.float, device=device)

# Number of generator updates per discriminator update.
g_steps = 1

# The context manager closes the HDF5 log even if training is interrupted.
with h5py.File(os.path.join('.',training_curve,'training_curve.hdf5'), "w") as hf:
    loss_d = hf.create_dataset("Loss D", (num_iteration,), dtype='f')
    loss_g = hf.create_dataset("Loss G", (num_iteration,), dtype='f')
    d_x = hf.create_dataset("D(x)", (num_iteration,), dtype='f')
    d_g_z = hf.create_dataset("D(G(z))", (num_iteration,2), dtype='f')

    gen_iterations = 0
    for epoch in tqdm(range(n_epoch)):
        start = time.time()
        for i_batch, sample_batched in enumerate(dataloader):
            # ---- discriminator update ----
            netD.zero_grad()

            real_img = sample_batched[0].to(device, dtype=torch.float)
            real_label = sample_batched[1].to(device, dtype=torch.float)

            # train with real
            # Flatten the discriminator output so it has the same shape as the
            # 1-D label vector, as BCELoss expects.
            # NOTE(review): assumes netD yields one score per sample - confirm.
            output = netD(real_img).view(-1)
            errD_real = criterion(output, real_label)
            errD_real.backward()
            D_x = output.mean().item()

            # train with fake
            # Latent vectors are drawn from N(0, 1), the same distribution the
            # generator step samples from (uniform noise was used here before).
            noise = torch.randn(BATCH_SIZE, nz, 1, 1, 1, dtype=torch.float, device=device)
            # detach(): no generator gradients are needed for this step.
            fake = netG(noise).detach()

            output = netD(fake).view(-1)
            errD_fake = criterion(output, fake_label)
            errD_fake.backward()

            D_G_z1 = output.mean().item()
            errD = errD_real + errD_fake
            d_optimizer.step()

            # ---- generator update(s) ----
            for _ in range(g_steps):
                netG.zero_grad()
                noise = torch.randn(BATCH_SIZE, nz, 1, 1, 1, dtype=torch.float, device=device)
                fake = netG(noise)
                output = netD(fake).view(-1)
                errG = criterion(output, gen_target)
                errG.backward()
                D_G_z2 = output.mean().item()
                g_optimizer.step()

            # ---- logging ----
            loss_d[gen_iterations] = errD.item()
            loss_g[gen_iterations] = errG.item()
            d_x[gen_iterations] = D_x
            # Write the row through the dataset itself: `d_g_z[i][0] = v` would
            # only modify a temporary NumPy copy and never reach the file.
            d_g_z[gen_iterations, :] = [D_G_z1, D_G_z2]

            gen_iterations += 1

        if epoch % 20 == 0:
            # Save the bare module's weights when wrapped in DataParallel, so
            # checkpoints load the same way with or without the wrapper.
            if isinstance(netG, torch.nn.DataParallel):
                G_Data = netG.module.state_dict()
            else:
                G_Data = netG.state_dict()
            if isinstance(netD, torch.nn.DataParallel):
                D_Data = netD.module.state_dict()
            else:
                D_Data = netD.state_dict()
            torch.save(G_Data, os.path.join(".",work_dir,"netG_epoch_{}.pth".format(epoch)))
            torch.save(D_Data, os.path.join(".",work_dir,"netD_epoch_{}.pth".format(epoch)))
            hf.flush()

        # Advance the LR schedules once per epoch, after the optimizer steps.
        d_schedule.step()
        g_schedule.step()

  return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)
  return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)
100%|██████████| 1/1 [00:21<00:00, 21.05s/it]
