In [1]:
import argparse
import os
from math import log10

import pandas as pd
import torch.optim as optim
import torch.utils.data
import torchvision.utils as utils
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm_notebook

In [2]:
import sys

# Make the project root importable so the `src.*` imports below resolve; this assumes the
# notebook's working directory is one level below the project root -- TODO confirm.
sys.path.append('..')
from src.configs import config
from src.modules.srgan.data_utils import HPATrainDatasetFromFolder, HPAValDatasetFromFolder, \
    RecursionTrainDatasetFromFolder, RecursionValDatasetFromFolder, display_transform
from src.modules.srgan.loss import GeneratorLoss
from src.modules.srgan.model import Generator, Discriminator
import src.modules.srgan.ssim as pytorch_ssim

# Automatically re-import edited project modules so changes under src/ are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
def _make_srgan_datasets(dataset_type, train_data_dir, valid_data_dir, crop_size, upscale_factor):
    """Build the (train, validation) dataset pair for ``dataset_type`` ('hpa' or 'rx').

    Raises:
        ValueError: if ``dataset_type`` is not recognised.
    """
    if dataset_type == 'hpa':
        train_cls, val_cls = HPATrainDatasetFromFolder, HPAValDatasetFromFolder
    elif dataset_type == 'rx':
        train_cls, val_cls = RecursionTrainDatasetFromFolder, RecursionValDatasetFromFolder
    else:
        raise ValueError('Unknown dataset type ' + str(dataset_type))
    train_set = train_cls(train_data_dir, crop_size=crop_size, upscale_factor=upscale_factor)
    val_set = val_cls(valid_data_dir, upscale_factor=upscale_factor)
    return train_set, val_set


def _srgan_train_step(netG, netD, optimizerG, optimizerD, generator_criterion, lr_img, hr_img):
    """Run one discriminator update followed by one generator update on a single batch.

    Returns:
        ``(g_loss, d_loss, d_score, g_score)`` as Python floats. ``g_loss`` and
        ``g_score`` (mean D(G(z))) are re-evaluated after both updates; ``d_score``
        (mean D(x)) is the value seen before the discriminator step, and
        ``d_loss = 1 - d_score + g_score``.
    """
    fake_img = netG(lr_img)

    # (1) Update D network: maximize D(x) - 1 - D(G(z)).
    # fake_img is detached so this backward stays out of G's graph and needs no retain_graph.
    netD.zero_grad()
    real_out = netD(real_img_or(hr_img)).mean() if False else netD(hr_img).mean()
    d_fake_out = netD(fake_img.detach()).mean()
    d_loss = 1 - real_out + d_fake_out
    d_loss.backward()
    optimizerD.step()

    # (2) Update G network: minimize 1 - D(G(z)) + Perception Loss + Image Loss + TV Loss.
    # D is re-run *after* its step: backpropagating through the graph recorded before
    # optimizerD.step() would use weights that step modified in place, which recent PyTorch
    # rejects with an in-place-modification RuntimeError.
    netG.zero_grad()
    fake_out = netD(fake_img).mean()
    g_loss = generator_criterion(fake_out, fake_img, hr_img)
    g_loss.backward()
    optimizerG.step()

    # Post-update statistics for the progress bar only; no autograd graph is needed.
    with torch.no_grad():
        fake_img = netG(lr_img)
        fake_out = netD(fake_img).mean()
        g_loss = generator_criterion(fake_out, fake_img, hr_img)
    d_score = real_out.item()
    g_score = fake_out.item()
    return g_loss.item(), 1 - d_score + g_score, d_score, g_score


def _srgan_validate(netG, val_loader, device, out_path, epoch):
    """Evaluate ``netG`` on ``val_loader`` and save LR-restore / HR / SR comparison grids.

    PSNR is derived from the mean MSE over all validation images with a peak value of 1
    (assumes images are scaled to [0, 1] -- TODO confirm in data_utils). A perfect
    reconstruction yields ``inf`` instead of a ZeroDivisionError.

    Returns:
        ``(psnr, ssim)`` as Python floats (``(0.0, 0.0)`` for an empty loader).
    """
    netG.eval()
    os.makedirs(out_path, exist_ok=True)
    mse_sum = 0.0
    ssim_sum = 0.0
    n_seen = 0
    psnr = 0.0
    ssim = 0.0
    val_images = []
    with torch.no_grad():
        val_bar = tqdm_notebook(val_loader)
        for val_lr, val_hr_restore, val_hr in val_bar:
            n = val_lr.size(0)
            n_seen += n
            lr = val_lr.to(device)
            hr = val_hr.to(device)
            sr = netG(lr)

            # .item() keeps the accumulators as Python floats rather than device tensors.
            mse_sum += ((sr - hr) ** 2).mean().item() * n
            ssim_sum += pytorch_ssim.ssim(sr, hr).item() * n
            mean_mse = mse_sum / n_seen
            psnr = 10 * log10(1 / mean_mse) if mean_mse > 0 else float('inf')
            ssim = ssim_sum / n_seen
            val_bar.set_description(
                desc='[converting LR images to SR images] PSNR: %.4f dB SSIM: %.4f' % (psnr, ssim))

            # One LR-restore / HR / SR triplet per image; squeeze(0) relies on batch_size=1.
            val_images.extend([display_transform()(val_hr_restore.squeeze(0)),
                               display_transform()(hr.cpu().squeeze(0)),
                               display_transform()(sr.cpu().squeeze(0))])

        if val_images:
            # Fixed-size groups of 15 images (5 triplets) per grid; the last group holds the
            # remainder, which is still a multiple of 3, so rows never mix triplets.
            grids = torch.split(torch.stack(val_images), 15)
            for index, grid in enumerate(tqdm_notebook(grids, desc='[saving training results]'), start=1):
                image = utils.make_grid(grid, nrow=3, padding=5)
                utils.save_image(image, os.path.join(out_path, 'epoch_%d_index_%d.png' % (epoch, index)),
                                 padding=5)
    return psnr, ssim


def train_SRGAN(train_data_dir, valid_data_dir, process_data_dir,
                dataset_type='hpa', crop_size=88, upscale_factor=4, num_epochs=10,
                perception_enabled=True, batch_size=48, num_workers=24, val_num_workers=8):
    """Train an SRGAN generator/discriminator pair, validating and checkpointing every epoch.

    Artifacts are written under ``<process_data_dir>/<dataset_type>/``:
      * ``training_results/SRF_<upscale_factor>/epoch_<e>_index_<i>.png`` -- validation grids,
        each row an LR-restore / HR / SR triplet;
      * ``epochs/netG_epoch_<upscale_factor>_<e>.pth`` and ``netD_epoch_...`` -- state dicts;
      * ``statistics/srf_<upscale_factor>_train_results.csv`` -- per-epoch Loss_D, Loss_G,
        Score_D, Score_G, PSNR and SSIM (rewritten after every epoch).

    Args:
        train_data_dir: folder handed to the training dataset class.
        valid_data_dir: folder handed to the validation dataset class.
        process_data_dir: root folder for all outputs.
        dataset_type: ``'hpa'`` or ``'rx'`` (Recursion); selects the dataset classes.
        crop_size: training crop size passed to the training dataset.
        upscale_factor: super-resolution factor for the datasets and the generator.
        num_epochs: number of training epochs.
        perception_enabled: forwarded to ``GeneratorLoss``.
        batch_size: training batch size (validation always uses 1).
        num_workers: DataLoader workers for training.
        val_num_workers: DataLoader workers for validation.

    Raises:
        ValueError: if ``dataset_type`` is not recognised.
    """
    train_set, val_set = _make_srgan_datasets(dataset_type, train_data_dir, valid_data_dir,
                                              crop_size, upscale_factor)
    train_loader = DataLoader(dataset=train_set, num_workers=num_workers, batch_size=batch_size, shuffle=True)
    # batch_size=1 is required by the squeeze(0) calls in the validation image dump.
    val_loader = DataLoader(dataset=val_set, num_workers=val_num_workers, batch_size=1, shuffle=False)

    run_dir = os.path.join(process_data_dir, dataset_type)
    images_dir = os.path.join(run_dir, 'training_results', 'SRF_' + str(upscale_factor))
    epochs_dir = os.path.join(run_dir, 'epochs')
    stats_dir = os.path.join(run_dir, 'statistics')

    # Initialize the networks
    netG = Generator(upscale_factor)
    print('# generator parameters:', sum(param.numel() for param in netG.parameters()))
    netD = Discriminator()
    print('# discriminator parameters:', sum(param.numel() for param in netD.parameters()))
    generator_criterion = GeneratorLoss(perception_enabled=perception_enabled)

    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    if use_cuda:
        netG.cuda()
        netD.cuda()
        generator_criterion.cuda()

    optimizerG = optim.Adam(netG.parameters())
    optimizerD = optim.Adam(netD.parameters())

    results = {'d_loss': [], 'g_loss': [], 'd_score': [], 'g_score': [], 'psnr': [], 'ssim': []}

    for epoch in range(1, num_epochs + 1):
        netG.train()
        netD.train()
        totals = {'batch_sizes': 0, 'd_loss': 0.0, 'g_loss': 0.0, 'd_score': 0.0, 'g_score': 0.0}
        train_bar = tqdm_notebook(train_loader)
        for data, target in train_bar:
            n = data.size(0)
            g_loss, d_loss, d_score, g_score = _srgan_train_step(
                netG, netD, optimizerG, optimizerD, generator_criterion,
                data.to(device), target.to(device))
            totals['batch_sizes'] += n
            totals['g_loss'] += g_loss * n
            totals['d_loss'] += d_loss * n
            totals['d_score'] += d_score * n
            totals['g_score'] += g_score * n
            seen = totals['batch_sizes']
            train_bar.set_description(desc='[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f' % (
                epoch, num_epochs, totals['d_loss'] / seen, totals['g_loss'] / seen,
                totals['d_score'] / seen, totals['g_score'] / seen))

        psnr, ssim = _srgan_validate(netG, val_loader, device, images_dir, epoch)

        # save model parameters
        os.makedirs(epochs_dir, exist_ok=True)
        torch.save(netG.state_dict(), os.path.join(epochs_dir, 'netG_epoch_%d_%d.pth' % (upscale_factor, epoch)))
        torch.save(netD.state_dict(), os.path.join(epochs_dir, 'netD_epoch_%d_%d.pth' % (upscale_factor, epoch)))

        # save loss / scores / psnr / ssim
        seen = totals['batch_sizes']
        for key in ('d_loss', 'g_loss', 'd_score', 'g_score'):
            results[key].append(totals[key] / seen)
        results['psnr'].append(psnr)
        results['ssim'].append(ssim)

        # Rewritten every epoch so an interrupted run still leaves its full history on disk.
        os.makedirs(stats_dir, exist_ok=True)
        data_frame = pd.DataFrame(
            data={'Loss_D': results['d_loss'], 'Loss_G': results['g_loss'], 'Score_D': results['d_score'],
                  'Score_G': results['g_score'], 'PSNR': results['psnr'], 'SSIM': results['ssim']},
            index=range(1, epoch + 1))
        data_frame.to_csv(os.path.join(stats_dir, 'srf_' + str(upscale_factor) + '_train_results.csv'),
                          index_label='Epoch')

In [4]:
# HPA input folders and the shared output root, relative to the working directory.
TRAIN_DATA_DIR, VAL_DATA_DIR, PROCESS_DATA_DIR = (
    '../data/hpa/train',
    '../data/hpa/valid',
    '../srgan_training',
)

In [5]:
# HPA run at the default 4x upscale factor, perception loss disabled, 4 epochs.
train_SRGAN(
    TRAIN_DATA_DIR,
    VAL_DATA_DIR,
    PROCESS_DATA_DIR,
    perception_enabled=False,
    num_epochs=4,
)

# generator parameters: 713481
# discriminator parameters: 5214273


HBox(children=(IntProgress(value=0, max=647), HTML(value='')))






HBox(children=(IntProgress(value=0, max=20), HTML(value='')))




HBox(children=(IntProgress(value=0, description='[saving training results]', max=4, style=ProgressStyle(descri…




HBox(children=(IntProgress(value=0, max=647), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20), HTML(value='')))




HBox(children=(IntProgress(value=0, description='[saving training results]', max=4, style=ProgressStyle(descri…




HBox(children=(IntProgress(value=0, max=647), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20), HTML(value='')))




HBox(children=(IntProgress(value=0, description='[saving training results]', max=4, style=ProgressStyle(descri…




HBox(children=(IntProgress(value=0, max=647), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20), HTML(value='')))




HBox(children=(IntProgress(value=0, description='[saving training results]', max=4, style=ProgressStyle(descri…




In [5]:
# HPA run at 2x upscaling with the default (enabled) perception loss, 6 epochs.
train_SRGAN(
    train_data_dir=TRAIN_DATA_DIR,
    valid_data_dir=VAL_DATA_DIR,
    process_data_dir=PROCESS_DATA_DIR,
    dataset_type='hpa',
    upscale_factor=2,
    num_epochs=6,
)

# generator parameters: 565768
# discriminator parameters: 5214273


HBox(children=(IntProgress(value=0, max=644), HTML(value='')))






HBox(children=(IntProgress(value=0, max=50), HTML(value='')))




HBox(children=(IntProgress(value=0, description='[saving training results]', max=10, style=ProgressStyle(descr…




HBox(children=(IntProgress(value=0, max=644), HTML(value='')))




HBox(children=(IntProgress(value=0, max=50), HTML(value='')))




HBox(children=(IntProgress(value=0, description='[saving training results]', max=10, style=ProgressStyle(descr…




HBox(children=(IntProgress(value=0, max=644), HTML(value='')))




HBox(children=(IntProgress(value=0, max=50), HTML(value='')))




HBox(children=(IntProgress(value=0, description='[saving training results]', max=10, style=ProgressStyle(descr…




HBox(children=(IntProgress(value=0, max=644), HTML(value='')))




HBox(children=(IntProgress(value=0, max=50), HTML(value='')))




HBox(children=(IntProgress(value=0, description='[saving training results]', max=10, style=ProgressStyle(descr…




HBox(children=(IntProgress(value=0, max=644), HTML(value='')))




HBox(children=(IntProgress(value=0, max=50), HTML(value='')))




HBox(children=(IntProgress(value=0, description='[saving training results]', max=10, style=ProgressStyle(descr…




HBox(children=(IntProgress(value=0, max=644), HTML(value='')))




HBox(children=(IntProgress(value=0, max=50), HTML(value='')))




HBox(children=(IntProgress(value=0, description='[saving training results]', max=10, style=ProgressStyle(descr…




In [7]:
# Recursion (rx) input folders, relative to the working directory.
RX_TRAIN_DIR, RX_VAL_DIR = '../data/recursion/train', '../data/recursion/valid'

In [None]:
# Recursion run at 2x upscaling with training crop_size=122, 5 epochs.
train_SRGAN(
    train_data_dir=RX_TRAIN_DIR,
    valid_data_dir=RX_VAL_DIR,
    process_data_dir=PROCESS_DATA_DIR,
    dataset_type='rx',
    crop_size=122,
    upscale_factor=2,
    num_epochs=5,
)

# generator parameters: 565768
# discriminator parameters: 5214273


HBox(children=(IntProgress(value=0, max=1515), HTML(value='')))

#### 