In [1]:
%matplotlib inline
import os
import ipywidgets as widgets
from matplotlib import pyplot as plt

# Path to the INI config read by the next cell.
# Pre-filling the widget from the TRAIN_CONFIG_PATH environment variable lets
# "Restart Kernel -> Run All" work without typing into the widget first. The
# value can still be edited interactively; re-run the config cell afterwards.
file_path = widgets.Text(
    value=os.environ.get('TRAIN_CONFIG_PATH', ''),
    placeholder='path/to/config.ini',
    description='path',
)

display(file_path)

Text(value='', description='path')

In [2]:
import configparser
config = configparser.ConfigParser()
# ConfigParser.read() silently skips files it cannot open and returns the list
# of files it did parse. Fail loudly here rather than with a KeyError below.
if not config.read(file_path.value):
    raise FileNotFoundError(
        "could not read config file {!r}; set the path in the widget above "
        "and re-run this cell".format(file_path.value)
    )

# --- [NN] section: network / optimiser hyper-parameters ---
num_unit = config.getint('NN', 'n_u')
num_block = config.getint('NN', 'n_block')
growth = config.getint('NN', 'growth')

lr = config.getfloat('NN', 'lr')        # learning rate (not "low resolution")
beta1 = config.getfloat('NN', 'beta1')  # Adam beta1
alpha = config.getfloat('NN', 'alpha')

# --- [Training] section: data, schedule and output locations ---
im_size = config.getint('Training', 'im_size')
n_epoch = config.getint('Training', 'n_epoch')
BATCH_SIZE = config.getint('Training', 'BATCH_SIZE')
hr_path = config.get('Training', 'hr_path')
lr_path = config.get('Training', 'lr_path')  # low-resolution training data
lr_path_test = config.get('Training', 'lr_test_data')
hr_path_test = config.get('Training', 'hr_test_data')
log_name = config.get('Training', 'log')
checkpoint_dir = config.get('Training', 'checkpoint_dir')

In [3]:
import h5py
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch

class ctimage(Dataset):
    """Paired high-/low-resolution image dataset over two indexable sources.

    ``img_hr[idx]`` and ``img_lr[idx]`` are paired by index. Each returned
    sample gets a new leading axis of size 1 (a channel axis) and is converted
    to a torch tensor. In this notebook the sources are the ``'data'`` datasets
    of open h5py files, so samples are read from disk on access.

    Args:
        img_hr: high-resolution samples, indexable along the first axis.
        img_lr: low-resolution samples, paired index-for-index with ``img_hr``.

    Raises:
        ValueError: if the two sources hold different numbers of samples.
    """

    def __init__(self, img_hr, img_lr):
        # Mismatched lengths would otherwise silently drop or mis-pair samples.
        if len(img_hr) != len(img_lr):
            raise ValueError(
                "img_hr and img_lr must have the same number of samples "
                "({} != {})".format(len(img_hr), len(img_lr))
            )
        self.img_hr = img_hr
        self.img_lr = img_lr

    def __len__(self):
        # len() works for numpy arrays, h5py datasets and plain sequences alike.
        return len(self.img_hr)

    def __getitem__(self, idx):
        hr_img = torch.from_numpy(np.expand_dims(self.img_hr[idx], axis=0))
        lr_img = torch.from_numpy(np.expand_dims(self.img_lr[idx], axis=0))
        return hr_img, lr_img

# HDF5 handles stay open for the whole session: the datasets below index into
# them lazily. They are closed in the last cell of the notebook.
hr_file = h5py.File(hr_path, 'r')
lr_file = h5py.File(lr_path, 'r')

t_hr_file = h5py.File(hr_path_test, 'r')
t_lr_file = h5py.File(lr_path_test, 'r')

training_data = ctimage(hr_file['data'], lr_file['data'])
test_data = ctimage(t_hr_file['data'], t_lr_file['data'])

# drop_last=True keeps every training batch at exactly BATCH_SIZE samples, so
# label tensors sized by BATCH_SIZE always match. Note that this needs at least
# BATCH_SIZE training samples, otherwise no batch is produced.
dataloader = DataLoader(training_data, batch_size=BATCH_SIZE, shuffle=True,
                        num_workers=0, drop_last=True)
# No shuffling for evaluation: test loss and logged images are then computed
# on the same samples in the same order every time, so epochs are comparable.
test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False,
                             num_workers=0)

In [4]:
import DenseNet as dn
import dcgan as dc
from torch import nn

def init_weights(m):
    """DCGAN-style weight initialisation, meant to be used via ``Module.apply``.

    * ``Conv3d`` / ``ConvTranspose3d`` weights ~ N(0, 0.02); biases untouched.
    * Affine ``BatchNorm3d``: scale ~ N(1, 0.02), shift = 0.
    * Every other module (and non-affine BatchNorm, whose weight is ``None``)
      is left unchanged.
    """
    # Transposed convolutions are included to match the reference DCGAN init,
    # in case the generator upsamples with them.
    if isinstance(m, (nn.Conv3d, nn.ConvTranspose3d)):
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm3d) and m.affine:
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.zeros_(m.bias)

# Build both networks with the project-local `dcgan` factories; the meaning of
# the positional arguments is defined there.
# NOTE(review): the displayed discriminator begins with Conv3d(1, 16, ...), so
# `1` is presumably the input channel count and `16` the base width; confirm.
netG = dc.make_generator_model(im_size, 1, 1, 16)
netD = dc.make_discriminator_model(im_size, 1, 16)

# Module.apply returns the module itself. Initialise G first, then D, which is
# the same order (and so the same RNG consumption) as before.
netG, netD = (net.apply(init_weights) for net in (netG, netD))

netD  # display the discriminator architecture

Sequential(
  (0): Conv3d(1, 16, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
  (1): LeakyReLU(negative_slope=0.2, inplace=True)
  (3): Conv3d(16, 32, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
  (4): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): LeakyReLU(negative_slope=0.2, inplace=True)
  (6): Conv3d(32, 64, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
  (7): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (8): LeakyReLU(negative_slope=0.2, inplace=True)
  (9): Conv3d(64, 128, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
  (10): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (11): LeakyReLU(negative_slope=0.2, inplace=True)
  (12): Conv3d(128, 1, kernel_size=(4, 4, 4), stride=(1, 1, 1), bias=False)
  (13): Sigmoid()
)

In [5]:
import torch.optim as optim
import torch.nn as nn

# Adversarial criterion for D, and the project's composite loss for G.
criterion = nn.BCELoss()
gen_loss = dn.GeneratorLoss()

# Use the GPU when present. With more than one GPU the models are first wrapped
# in DataParallel, then everything is moved to the default CUDA device.
if torch.cuda.is_available():
    if torch.cuda.device_count() > 1:
        netG, netD = nn.DataParallel(netG), nn.DataParallel(netD)
    netG, netD = netG.cuda(), netD.cuda()
    criterion, gen_loss = criterion.cuda(), gen_loss.cuda()

# One Adam optimiser per network, sharing lr / beta1 from the config.
d_optimizer, g_optimizer = (
    optim.Adam(net.parameters(), lr=lr, betas=(beta1, 0.999))
    for net in (netD, netG)
)

# Divide the learning rate by 10 every 500 scheduler steps.
d_schedule, g_schedule = (
    optim.lr_scheduler.StepLR(opt, step_size=500, gamma=0.1)
    for opt in (d_optimizer, g_optimizer)
)

In [6]:
import time
from torch import tensor
import torchvision
import os.path
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(comment=log_name)

# Take the device from the model instead of hard-coding .cuda(), so the loop
# also runs when the networks were left on the CPU.
device = next(netG.parameters()).device

REAL_LABEL_D = 0.9     # one-sided label smoothing for D's "real" targets
FAKE_LABEL = 0.0
REAL_LABEL_G = 1.0     # G is scored against "real" targets
CHECKPOINT_EVERY = 20  # epochs between checkpoints / test evaluations
TEST_SLICE = 10        # index on the last spatial axis logged as a 2-D image
work_dir = "training_checkpoints"
ckpt_path = os.path.join(".", work_dir, checkpoint_dir)
# NOTE: the recorded run hit "CUDA out of memory". If that happens, lower
# BATCH_SIZE (or im_size) in the config file.


def _unwrap(net):
    """Return the wrapped module of an ``nn.DataParallel`` (or ``net`` itself)."""
    return getattr(net, "module", net)


def _train_step(hr_img, lr_img):
    """Run one D update and one G update on a batch; return ``(errD, errG)``."""
    # Labels are built with the same shape as D's output, so BCE never
    # broadcasts and partial batches work too.
    netD.zero_grad()
    output = netD(hr_img)
    errD_real = criterion(output, torch.full_like(output, REAL_LABEL_D))
    errD_real.backward()

    fake = netG(lr_img)
    # detach(): the discriminator update must not push gradients into G.
    output = netD(fake.detach())
    errD_fake = criterion(output, torch.full_like(output, FAKE_LABEL))
    errD_fake.backward()
    d_optimizer.step()

    netG.zero_grad()
    output = netD(fake)
    errG = gen_loss(output, torch.full_like(output, REAL_LABEL_G), fake, hr_img)
    errG.backward()
    g_optimizer.step()
    return errD_real + errD_fake, errG


def _save_checkpoint(epoch):
    """Write the G and D state dicts for ``epoch`` under ``ckpt_path``."""
    os.makedirs(ckpt_path, exist_ok=True)
    torch.save(_unwrap(netG).state_dict(),
               os.path.join(ckpt_path, "netG_epoch_{}.pth".format(epoch)))
    torch.save(_unwrap(netD).state_dict(),
               os.path.join(ckpt_path, "netD_epoch_{}.pth".format(epoch)))


def _evaluate(epoch):
    """Log the mean test generator loss and one sample image slice for ``epoch``."""
    netD.eval()
    netG.eval()
    total_loss, n_batches = 0.0, 0
    with torch.no_grad():
        for i_batch, (hr_batch, lr_batch) in enumerate(test_dataloader):
            hr_img_t = hr_batch.float().to(device)
            lr_img_t = lr_batch.float().to(device)

            result_img = netG(lr_img_t)
            output = netD(result_img)
            errT = gen_loss(output, torch.full_like(output, REAL_LABEL_G),
                            result_img, hr_img_t)
            total_loss += errT.item()
            n_batches += 1

            if i_batch == 0:
                # make_grid returns a single CHW image, so use add_image
                # (add_images expects a 4-D NCHW batch).
                img_grid = torchvision.utils.make_grid(result_img[:, :, :, :, TEST_SLICE])
                writer.add_image('test output', img_grid, epoch)
    if n_batches:
        writer.add_scalar('Loss/g_test', total_loss / n_batches, epoch)


gen_iterations = 0
for epoch in tqdm(range(n_epoch)):
    # Evaluation switches to eval mode, so restore train mode every epoch.
    netD.train()
    netG.train()

    for hr_batch, lr_batch in dataloader:
        hr_img = hr_batch.float().to(device)
        lr_img = lr_batch.float().to(device)

        errD, errG = _train_step(hr_img, lr_img)

        writer.add_scalar('Loss/d', errD.item(), gen_iterations)
        writer.add_scalar('Loss/g', errG.item(), gen_iterations)
        gen_iterations += 1

    d_schedule.step()
    g_schedule.step()

    if epoch % CHECKPOINT_EVERY == 0:
        _save_checkpoint(epoch)
        _evaluate(epoch)
        

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)



RuntimeError: CUDA out of memory. Tried to allocate 4.59 GiB (GPU 0; 3.94 GiB total capacity; 2.88 GiB already allocated; 4.88 MiB free; 14.29 MiB cached)

In [None]:
# Flush any pending TensorBoard events and release the SummaryWriter created
# in the training cell.
writer.close()

In [None]:
# Release the four HDF5 handles opened in the data-loading cell. The datasets
# index into these files on access, so this must run only after training and
# evaluation are finished.
for h5_handle in (hr_file, lr_file, t_hr_file, t_lr_file):
    h5_handle.close()