In [1]:
%matplotlib inline
import os
import ipywidgets as widgets
from matplotlib import pyplot as plt

# Text box where the user types the path of the .ini training config; the
# next cell reads `file_path.value`.
# NOTE(review): results depend on a value typed by hand here, so
# "Restart & Run All" alone cannot reproduce a run.
file_path = widgets.Text(description='path')

display(file_path)

Text(value='', description='path')

In [2]:
import configparser

# Read the .ini training configuration whose path was typed into the widget
# in the first cell.
config_path = file_path.value
config = configparser.ConfigParser()
# ConfigParser.read() silently skips files it cannot open and returns the list
# of files it actually parsed. Fail here with a clear message instead of with
# a KeyError on the first section lookup below (e.g. when the widget is empty).
if not config.read(config_path):
    raise FileNotFoundError(f"Could not read config file: {config_path!r}")

# [NN]: architecture and optimiser settings.
# NOTE(review): num_unit, num_block, growth and alpha (and im_size below) are
# not referenced anywhere else in the visible notebook -- the generator is
# built as dn.Generator(4, 4). Confirm whether they should be wired in.
num_unit = config.getint('NN', 'n_u')
num_block = config.getint('NN', 'n_block')
growth = config.getint('NN', 'growth')

lr = config.getfloat('NN', 'lr')        # learning rate for both Adam optimisers
beta1 = config.getfloat('NN', 'beta1')  # Adam beta1
alpha = config.getfloat('NN', 'alpha')

# [Training]: loop settings and data / output locations.
im_size = config.getint('Training', 'im_size')
n_epoch = config.getint('Training', 'n_epoch')
BATCH_SIZE = config.getint('Training', 'BATCH_SIZE')
hr_path = config['Training']['hr_path']
lr_path = config['Training']['lr_path']  # low-resolution data (not learning rate)
lr_path_test = config['Training']['lr_test_data']
hr_path_test = config['Training']['hr_test_data']
log_name = config['Training']['log']
checkpoint_dir = config['Training']['checkpoint_dir']
checkpoint_dir = config['Training']['checkpoint_dir']

In [3]:
import h5py
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch

class ctimage(Dataset):
    """Paired high-/low-resolution samples read lazily from two HDF5 files.

    Each file must contain a dataset named ``'data'``. Sample ``idx`` of the
    high-resolution file is paired with sample ``idx`` of the low-resolution
    file. The dataset length is taken from the high-resolution file.

    Both files stay open for the lifetime of the object, because reads happen
    in ``__getitem__``. Call :meth:`close` to release them explicitly; the
    finalizer calls it too.
    NOTE(review): open h5py handles are generally not safe to share with
    forked DataLoader workers. The notebook's DataLoader uses num_workers=0.

    Args:
        path_hr: path of the high-resolution HDF5 file.
        path_lr: path of the low-resolution HDF5 file.
    """

    def __init__(self, path_hr, path_lr):
        self.img_hr = h5py.File(path_hr, 'r')
        self.img_lr = h5py.File(path_lr, 'r')

    def close(self):
        """Close both HDF5 files.

        Safe to call more than once, and on an object whose ``__init__``
        failed before both files were opened.
        """
        for attr in ('img_hr', 'img_lr'):
            handle = getattr(self, attr, None)
            if handle is not None:
                handle.close()
                setattr(self, attr, None)

    def __del__(self):
        # Delegates to close(), which tolerates attributes that __init__ never
        # got to set (e.g. the second h5py.File call raised).
        self.close()

    def __len__(self):
        return self.img_hr['data'].shape[0]

    def __getitem__(self, idx):
        """Return ``(hr, lr)`` tensors for sample ``idx``.

        A new leading axis of size 1 is prepended to each sample. This is
        presumably the single input channel the 3-D convolutions expect: the
        discriminator's first layer is ``Conv3d(1, 64, ...)``.
        """
        hr_vol = np.expand_dims(self.img_hr['data'][idx], axis=0)
        lr_vol = np.expand_dims(self.img_lr['data'][idx], axis=0)
        return torch.from_numpy(hr_vol), torch.from_numpy(lr_vol)

def _load_h5_tensor(path):
    """Load the entire ``'data'`` dataset of an HDF5 file as a tensor.

    A new leading axis of size 1 is prepended (``np.expand_dims(..., axis=0)``).
    The ``with`` block closes the file even if the read raises.
    """
    with h5py.File(path, 'r') as f:
        # `[()]` reads the whole dataset into a NumPy array.
        return torch.from_numpy(np.expand_dims(f['data'][()], axis=0))


# Training pairs are read lazily, one sample per __getitem__ call.
training_data = ctimage(hr_path, lr_path)

dataloader = DataLoader(training_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

# Held-out test volumes are loaded fully into memory.
# NOTE(review): axis=0 adds one axis in front of the *whole* stored array. If
# 'data' holds a stack of samples (N, D, H, W), the result is (1, N, D, H, W):
# N ends up in the channel slot rather than the batch slot. Confirm the layout
# (axis=1 may have been intended) before using these for evaluation.
test_lr_img = _load_h5_tensor(lr_path_test)
test_hr_img = _load_h5_tensor(hr_path_test)

In [4]:
import DenseNet as dn
from torch import nn

def init_weights(m):
    """DCGAN-style weight initialisation, intended for ``Module.apply``.

    * ``nn.Conv3d``: weight ~ N(0, 0.02); the bias is left untouched.
    * ``nn.BatchNorm3d``: weight ~ N(1, 0.02), bias = 0.

    Only modules whose type is exactly one of these classes are modified;
    subclasses and all other layers are left as-is. BatchNorm layers built with
    ``affine=False`` have no weight or bias and are skipped instead of raising.

    Args:
        m: the module currently visited by ``Module.apply``.
    """
    if type(m) is nn.Conv3d:
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
    elif type(m) is nn.BatchNorm3d and m.affine:
        # nn.init runs under no_grad, replacing the discouraged `.data` mutation.
        nn.init.normal_(m.weight, mean=1.0, std=0.02)
        nn.init.zeros_(m.bias)

# Build the generator and discriminator from the local DenseNet module.
# NOTE(review): the generator's arguments are hard-coded to (4, 4) while the
# config cell reads num_unit / num_block / growth and never uses them; confirm
# which of those these positional arguments are meant to be.
netG = dn.Generator(4,4)#fill input
netD = dn.Discriminator()

# Re-initialise Conv3d / BatchNorm3d layers of both networks via init_weights.
# Module.apply returns the module itself, so the last line's value (netD) is
# what Jupyter echoes as this cell's output.
netG.apply(init_weights)
netD.apply(init_weights)

Discriminator(
  (conv1): Conv3d(1, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1))
  (lrelu1): LeakyReLU(negative_slope=0.01)
  (block1): DiscriminatorBlock(
    (conv1): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(2, 2, 2))
    (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (lrelu): LeakyReLU(negative_slope=0.01)
  )
  (block2): DiscriminatorBlock(
    (conv1): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1))
    (bn1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (lrelu): LeakyReLU(negative_slope=0.01)
  )
  (block3): DiscriminatorBlock(
    (conv1): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(2, 2, 2))
    (bn1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (lrelu): LeakyReLU(negative_slope=0.01)
  )
  (block4): DiscriminatorBlock(
    (conv1): Conv3d(128, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1))
    (bn1): BatchNorm3d(256, eps=1e-05, momentum=

In [5]:
import torch.optim as optim
import torch.nn as nn

criterion = nn.BCELoss()
gen_loss = dn.GeneratorLoss()

# Run on the GPU when one is available, otherwise fall back to the CPU.
# (Calling .cuda() unconditionally fails on CPU-only machines.)
# TODO: multi-GPU training via nn.DataParallel was sketched here previously;
# if reinstated, wrap netG/netD after moving them to the device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

netG = netG.to(device)
netD = netD.to(device)
criterion = criterion.to(device)
gen_loss = gen_loss.to(device)

test_lr_img = test_lr_img.to(device)
test_hr_img = test_hr_img.to(device)

# Both networks use Adam with the configured lr and beta1, and beta2 fixed here.
ADAM_BETA2 = 0.999
d_optimizer = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, ADAM_BETA2))
g_optimizer = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, ADAM_BETA2))

# Multiply the learning rate by LR_DECAY every LR_STEP scheduler steps
# (the training loop steps both schedulers once per epoch).
LR_STEP = 500
LR_DECAY = 0.1
d_schedule = optim.lr_scheduler.StepLR(d_optimizer, step_size=LR_STEP, gamma=LR_DECAY)
g_schedule = optim.lr_scheduler.StepLR(g_optimizer, step_size=LR_STEP, gamma=LR_DECAY)

In [None]:
import time
from torch import tensor
import torchvision
import os.path
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

# TensorBoard logging. SummaryWriter ignores `comment` whenever `log_dir` is
# given, so log_name previously had no effect and every run wrote into the
# same "training_curve" directory. Give each run its own sub-directory instead.
writer = SummaryWriter(log_dir=os.path.join("training_curve", log_name))

WORK_DIR = "training_checkpoints"
CHECKPOINT_EVERY = 20   # save both nets every this many epochs (and after the last)
REAL_TARGET_D = 0.9     # one-sided label smoothing: D's target for real samples
FAKE_TARGET_D = 0.0     # D's target for generated samples
REAL_TARGET_G = 1.0     # G is trained against "real" targets

save_dir = os.path.join(".", WORK_DIR, checkpoint_dir)
os.makedirs(save_dir, exist_ok=True)  # torch.save does not create directories

# Put batches on whichever device the networks were moved to earlier.
device = next(netG.parameters()).device

gen_iterations = 0
for epoch in tqdm(range(n_epoch)):
    for sample_batched in dataloader:
        hr_img = sample_batched[0].float().to(device)
        lr_img = sample_batched[1].float().to(device)

        # ---- Discriminator update ----
        netD.zero_grad()

        # Train with real. Targets are built with the shape of D's output, so
        # they match the final, possibly smaller batch (the DataLoader keeps it
        # by default). A fixed BATCH_SIZE-long label vector did not.
        output = netD(hr_img)
        errD_real = criterion(output, torch.full_like(output, REAL_TARGET_D))
        errD_real.backward()

        # Train with fake; detach so this backward pass does not reach netG.
        fake = netG(lr_img)
        output = netD(fake.detach())
        errD_fake = criterion(output, torch.full_like(output, FAKE_TARGET_D))
        errD_fake.backward()

        errD = errD_real + errD_fake
        d_optimizer.step()

        # ---- Generator update: fake labels are real for the generator cost ----
        netG.zero_grad()
        output = netD(fake)
        real_label = torch.full_like(output, REAL_TARGET_G)
        errG = gen_loss(output, real_label, fake, hr_img)
        errG.backward()
        g_optimizer.step()

        writer.add_scalar('Loss/d', errD.item(), gen_iterations)
        writer.add_scalar('Loss/g', errG.item(), gen_iterations)

        # TODO: periodic evaluation on test_lr_img / test_hr_img (previously
        # sketched here as commented-out code).

        gen_iterations += 1

    # Step the LR schedules once per epoch, *after* the optimiser updates.
    # Stepping them first is the pre-1.1 PyTorch order: it emits a warning and
    # makes StepLR decay one epoch early.
    d_schedule.step()
    g_schedule.step()

    if epoch % CHECKPOINT_EVERY == 0 or epoch == n_epoch - 1:
        # Save the underlying model if wrapped in nn.DataParallel, so saved keys
        # carry no 'module.' prefix.
        G_Data = (netG.module if isinstance(netG, nn.DataParallel) else netG).state_dict()
        D_Data = (netD.module if isinstance(netD, nn.DataParallel) else netD).state_dict()
        torch.save(G_Data, os.path.join(save_dir, "netG_epoch_{}.pth".format(epoch)))
        torch.save(D_Data, os.path.join(save_dir, "netD_epoch_{}.pth".format(epoch)))

  return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)


In [None]:
# Close the TensorBoard SummaryWriter created in the training cell.
writer.close()