In [1]:
%matplotlib inline
import os
import ipywidgets as widgets
from matplotlib import pyplot as plt

# Hyperparameters are entered through widgets; later cells read each widget's
# `.value`, so set these before running the rest of the notebook.
nz = widgets.IntText(value=512, description='nz')            # size of latent z vector
ngf = widgets.IntText(value=32, description='ngf')           # passed to dcgan.make_generator_model
ndf = widgets.IntText(value=16, description='ndf')           # passed to dcgan.make_discriminator_model
n_epoch = widgets.IntText(value=25, description='n epoch')   # number of training epochs
# NOTE(review): the DCGAN paper uses a learning rate of 2e-4; confirm 0.002 is intended.
lr = widgets.FloatText(value=0.002, description='learning rate')
beta1 = widgets.FloatText(value=0.5, description='beta 1')   # Adam beta1 for both optimizers
im_size = widgets.IntText(value=64, description='im size')
file_path = widgets.Text(description='path')                 # HDF5 file holding a 'data' dataset

# Layout: two rows of numeric fields with the path field underneath.
u_box = widgets.HBox([nz, ngf, ndf, n_epoch])
b_box = widgets.HBox([lr, beta1, im_size])
v_b = widgets.VBox([u_box, b_box, file_path])

display(v_b)

VBox(children=(HBox(children=(IntText(value=512, description='nz'), IntText(value=32, description='ngf'), IntT…

In [6]:
import h5py
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch

BATCH_SIZE: int = 5  # samples per DataLoader batch; also the size of each generated fake batch

class ctimage(Dataset):
    """Dataset over the ``data`` array of an HDF5 file, one sample per first-axis index.

    ``__getitem__`` returns ``(volume, 0.9)``. ``volume`` is the stored sample
    with a new leading (channel) axis from ``np.expand_dims(..., axis=0)``,
    wrapped with ``torch.from_numpy``. ``0.9`` is the smoothed "real" label
    (one-sided label smoothing).

    The HDF5 file is opened lazily, in whichever process first reads an item,
    rather than in ``__init__``. h5py file handles must not be shared across
    ``fork``, so when this dataset is used with DataLoader worker processes
    each worker opens its own handle. ``__init__`` opens the file only briefly
    to read the sample count. Avoid reading items in the parent process before
    the loader's workers start, or the parent's handle would be inherited.

    Args:
        path: Path to an HDF5 file containing a dataset named ``data``.
    """

    def __init__(self, path):
        self.path = path
        with h5py.File(path, 'r') as f:
            self._len = f['data'].shape[0]
        # Opened on demand, per process, by the ``img`` property.
        self._file = None
        self._img = None

    @property
    def img(self):
        """The ``data`` HDF5 dataset, opened in the current process on first use."""
        if self._img is None:
            self._file = h5py.File(self.path, 'r')
            self._img = self._file['data']
        return self._img

    def __getstate__(self):
        # Never pickle an open h5py handle (e.g. when DataLoader workers use
        # the "spawn" start method); the receiving process reopens lazily.
        state = self.__dict__.copy()
        state['_file'] = None
        state['_img'] = None
        return state

    def __len__(self):
        return self._len

    def __getitem__(self, idx):
        # one-sided label smoothing: real samples are labelled 0.9 instead of 1.0
        return torch.from_numpy(np.expand_dims(self.img[idx], axis=0)), 0.9

# Wrap the HDF5 volumes in a shuffled, batched loader backed by two worker processes.
training_data = ctimage(file_path.value)
dataloader = DataLoader(
    training_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
)

In [7]:
import dcgan
from torch import nn
nc: int = 1  # image channels for both networks (the dataset adds a single channel axis)

def init_weights(m):
    """DCGAN-style weight initialisation, applied per module via ``net.apply``.

    * ``Conv3d`` and ``ConvTranspose3d``: weight ~ N(0, 0.02).
    * ``BatchNorm3d`` with affine parameters: weight ~ N(1, 0.02), bias = 0.

    Every other module type is left untouched. Transposed convolutions are
    included because a DCGAN generator typically upsamples with them
    (presumably what ``dcgan.make_generator_model`` builds; TODO confirm).
    An exact ``type(m) == nn.Conv3d`` test would skip them and leave
    PyTorch's default init in place.

    Args:
        m: A single ``nn.Module``; ``Module.apply`` calls this once per submodule.
    """
    if isinstance(m, (nn.Conv3d, nn.ConvTranspose3d)):
        m.weight.data.normal_(0.0, 0.02)
    elif isinstance(m, nn.BatchNorm3d) and m.weight is not None:
        # A BatchNorm built with affine=False has no weight/bias to initialise.
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

# Build both networks from the widget settings (generator first, as before),
# then apply the custom initialisation to each. The cell's last expression is
# netD.apply(...), which returns netD, so its summary is the cell output.
netG, netD = (
    dcgan.make_generator_model(im_size.value, nz.value, nc, ngf.value),
    dcgan.make_discriminator_model(im_size.value, nc, ndf.value),
)

netG.apply(init_weights)
netD.apply(init_weights)

Sequential(
  (0): Conv3d(1, 16, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
  (1): LeakyReLU(negative_slope=0.2, inplace)
  (3): Conv3d(16, 32, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
  (4): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): LeakyReLU(negative_slope=0.2, inplace)
  (6): Conv3d(32, 64, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
  (7): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (8): LeakyReLU(negative_slope=0.2, inplace)
  (9): Conv3d(64, 128, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
  (10): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (11): LeakyReLU(negative_slope=0.2, inplace)
  (12): Conv3d(128, 1, kernel_size=(4, 4, 4), stride=(1, 1, 1), bias=False)
  (13): Sigmoid()
)

In [8]:
import torch.optim as optim
import torch.nn as nn

# Binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()

# Single source of truth for placement: training batches are moved with
# `.to(device)`, so the models and loss are moved to the same `device` rather
# than to whatever `.cuda()` considers the current GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

print(device)

# Module.to() moves parameters in place and returns the module; on CPU it is a no-op.
netG = netG.to(device)
netD = netD.to(device)
criterion = criterion.to(device)

# Optimizers are created after the move so they reference the on-device parameters.
d_optimizer = optim.Adam(netD.parameters(), lr=lr.value, betas=(beta1.value, 0.999))
g_optimizer = optim.Adam(netG.parameters(), lr=lr.value, betas=(beta1.value, 0.999))

# Multiply both learning rates by 0.1 every 500 scheduler steps; the training
# loop steps these once per epoch.
d_schedule = optim.lr_scheduler.StepLR(d_optimizer, step_size=500, gamma=0.1)
g_schedule = optim.lr_scheduler.StepLR(g_optimizer, step_size=500, gamma=0.1)

cuda:0


In [9]:
import time
from torch import tensor
import os.path
from tqdm import tqdm

training_curve = "training_curve"
work_dir = "training_checkpoints"
# h5py.File(..., "w") and torch.save do not create missing parent directories.
os.makedirs(os.path.join('.', training_curve), exist_ok=True)
os.makedirs(os.path.join('.', work_dir), exist_ok=True)

num_iteration = len(dataloader)*n_epoch.value

# Targets that do not depend on batch contents, built once instead of every
# iteration. Fakes are labelled 0 for the discriminator. The generator wants
# its fakes classified as real (1.0, no smoothing). Smoothed real targets
# (0.9) come from the dataset.
fake_label = torch.zeros(BATCH_SIZE, dtype=torch.float, device=device)
gen_label = torch.ones(BATCH_SIZE, dtype=torch.float, device=device)

gen_iterations = 0
# The context manager flushes and closes the curve file even when training is
# interrupted (e.g. KeyboardInterrupt), so partially written curves stay readable.
with h5py.File(os.path.join('.', training_curve, 'training_curve.hdf5'), "w") as hf:
    loss_d = hf.create_dataset("Loss D", (num_iteration,), dtype='f')
    loss_g = hf.create_dataset("Loss G", (num_iteration,), dtype='f')
    d_x = hf.create_dataset("D(x)", (num_iteration,), dtype='f')
    d_g_z = hf.create_dataset("D(G(z))", (num_iteration, 2), dtype='f')

    for epoch in tqdm(range(n_epoch.value)):
        start = time.time()
        for i_batch, sample_batched in enumerate(dataloader):
            # ---- Discriminator update ----
            netD.zero_grad()

            real_img = sample_batched[0].to(device, dtype=torch.float)
            real_label = sample_batched[1].to(device, dtype=torch.float)

            # Train with real. netD yields one value per sample in a
            # (batch, 1, 1, 1, 1) map; flatten it to (batch,) so it matches
            # the target shape, as BCELoss requires.
            output = netD(real_img).reshape(-1)
            errD_real = criterion(output, real_label)
            errD_real.backward()
            D_x = output.detach().mean()

            # Train with fake. Sample z from N(0, 1), the same latent
            # distribution used for the generator update below.
            noise = torch.randn(BATCH_SIZE, nz.value, 1, 1, 1, dtype=torch.float, device=device)
            fake = netG(noise).detach()
            output = netD(fake).reshape(-1)
            errD_fake = criterion(output, fake_label)
            errD_fake.backward()
            D_G_z1 = output.detach().mean()

            errD = errD_real + errD_fake
            d_optimizer.step()

            # ---- Generator update(s) ----
            g_iter = 1  # generator steps per discriminator step
            while g_iter != 0:
                netG.zero_grad()
                noise = torch.randn(BATCH_SIZE, nz.value, 1, 1, 1, dtype=torch.float, device=device)
                fake = netG(noise)
                output = netD(fake).reshape(-1)
                errG = criterion(output, gen_label)
                errG.backward()
                D_G_z2 = output.detach().mean()
                g_optimizer.step()
                g_iter -= 1

            loss_d[gen_iterations] = errD.item()
            loss_g[gen_iterations] = errG.item()
            d_x[gen_iterations] = D_x.item()
            # Write through a 2-D index. `d_g_z[i][0] = v` reads row i into a
            # temporary NumPy array and modifies that copy, not the file.
            d_g_z[gen_iterations, 0] = D_G_z1.item()
            d_g_z[gen_iterations, 1] = D_G_z2.item()

            gen_iterations += 1

        # Step the LR schedulers once per epoch, after this epoch's optimizer
        # updates (the order PyTorch >= 1.1 expects).
        d_schedule.step()
        g_schedule.step()

        # Checkpoint every 20 epochs and always after the final epoch.
        if epoch % 20 == 0 or epoch == n_epoch.value - 1:
            torch.save(netG.state_dict(), os.path.join(".", work_dir, "netG_epoch_{}.pth".format(epoch)))
            torch.save(netD.state_dict(), os.path.join(".", work_dir, "netD_epoch_{}.pth".format(epoch)))
            hf.flush()

 24%|██▍       | 728/3000 [35:20:46<111:32:02, 176.73s/it]Traceback (most recent call last):
  File "/usr/lib/python3.6/multiprocessing/queues.py", line 240, in _feed
    send_bytes(obj)
Traceback (most recent call last):
  File "/usr/lib/python3.6/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/usr/lib/python3.6/multiprocessing/queues.py", line 240, in _feed
    send_bytes(obj)
  File "/usr/lib/python3.6/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/usr/lib/python3.6/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/usr/lib/python3.6/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe
  File "/usr/lib/python3.6/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/usr/lib/python3.6/multiprocessing/connec

KeyboardInterrupt: 