In [1]:
import argparse
import os
from math import log10

import pandas as pd
import torch.optim as optim
import torch.utils.data
import torchvision.utils as utils
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm

import pytorch_ssim
from data_utils import TrainDatasetLoader, ValidationDatasetLoader, display_transform
from loss import GeneratorLoss
from model import Generator, Discriminator
from pathlib import Path


In [2]:
# Working directories the rest of the notebook reads from / writes to.
path_l = ['./epochs', './statistics', './data']
for p in path_l:
    # exist_ok=True makes the cell idempotent without a separate exists() check
    # (no check-then-create race), and still raises if the path is a regular file.
    os.makedirs(p, exist_ok=True)


In [3]:
# Scale factor handed to both dataset loaders and to the Generator; it is also
# embedded in the names of the statistics CSV and the saved checkpoints.
UPSCALE_FACTOR = 4
# Crop size handed to TrainDatasetLoader (presumably the HR patch edge in pixels - TODO confirm).
CROP_SIZE = 88

In [4]:
# Datasets and loaders
# Training samples are read from ./data/train, validation samples from ./data/valid.
train_set = TrainDatasetLoader(
    Path('./data/train'),
    crop_size=CROP_SIZE,
    factor=UPSCALE_FACTOR,
)
val_set = ValidationDatasetLoader(
    Path('./data/valid'),
    factor=UPSCALE_FACTOR,
)

# Training: shuffled mini-batches of 64.
# Validation: one sample at a time, in dataset order.
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=1)
val_loader = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=1)


In [5]:
# Training schedule and models
# The training loop iterates over epochs START_EPOCH + 1 .. NUM_EPOCHS (inclusive).
START_EPOCH, NUM_EPOCHS = 0, 100

# Both networks are moved to the GPU right after construction.
netG = Generator(UPSCALE_FACTOR).cuda()
netD = Discriminator().cuda()

# Loss module used for the generator update.
generator_criterion = GeneratorLoss().cuda()

# One Adam optimizer per network, with the library's default hyper-parameters.
optimizerG = optim.Adam(params=netG.parameters())
optimizerD = optim.Adam(params=netD.parameters())




In [8]:
# Per-epoch history; the training loop appends one value per metric per completed epoch.
results = {'d_loss':[], 'g_loss':[], 'd_score':[], 'g_score':[], 'psnr':[], 'ssim':[]}

# Set to False to ignore any earlier run and keep START_EPOCH as configured above.
resume = True

# The training loop rewrites this file at the end of every epoch.
stats_csv = './statistics/srf_' + str(UPSCALE_FACTOR) + '_train_results.csv'

# Only read the history when it actually exists, so a fresh checkout
# (Restart Kernel -> Run All with no previous run) does not crash here.
result_dataframe = pd.read_csv(stats_csv) if (resume and os.path.exists(stats_csv)) else None

if result_dataframe is not None and not result_dataframe.empty:
    # Restore the metric history: results key <- CSV column
    for key, column in (('d_loss', 'Loss_D'), ('g_loss', 'Loss_G'),
                        ('d_score', 'Score_D'), ('g_score', 'Score_G'),
                        ('psnr', 'PSNR'), ('ssim', 'SSIM')):
        results[key] = result_dataframe[column].values.tolist()

    # Last completed epoch; cast so START_EPOCH is a plain int rather than numpy.int64
    START_EPOCH = int(result_dataframe['Epoch'].iloc[-1])

    # Checkpoint names must match what the training loop writes:
    # 'epochs/net{G,D}_epoch_<UPSCALE_FACTOR>_<epoch>.pth' (the factor used to be hard-coded as 4).
    generator_checkpoint = './epochs/netG_epoch_%d_%d.pth' % (UPSCALE_FACTOR, START_EPOCH)
    discriminator_checkpoint = './epochs/netD_epoch_%d_%d.pth' % (UPSCALE_FACTOR, START_EPOCH)
    checkpointG = torch.load(generator_checkpoint, map_location="cuda:0")
    checkpointD = torch.load(discriminator_checkpoint, map_location="cuda:0")
    netG.load_state_dict(checkpointG)
    netD.load_state_dict(checkpointD)
    # NOTE: only the network weights are checkpointed; the Adam optimizers created
    # earlier are not restored, so their internal state starts fresh after a resume.

    print('train starts in epoch', START_EPOCH+1)
elif resume:
    print('no training history found at', stats_csv, '- starting from epoch', START_EPOCH+1)

[0.9989655017852784, 0.974418044090271, 0.9066504240036012, 0.8817505240440369, 0.9933868050575256, 0.9786656498908995, 0.9965881109237672, 0.9808578491210938, 0.9328036308288574, 0.9789286851882936]
[11.254087991095718, 11.85107821304836, 12.684156954604148, 15.773243793125008, 13.904385940061026, 14.714163991601396, 6.337763003747671, 10.585449590819724, 10.694705869605093, 11.556622504523002]
train starts in epoch 11


In [9]:

# Main training loop. For every epoch: train D and G on the training set,
# evaluate G on the validation set (PSNR / SSIM), save comparison images,
# checkpoint both networks and rewrite the statistics CSV.

# Output locations, kept under separate names so neither silently overwrites the other.
image_dir = 'training_results/SRF_' + str(UPSCALE_FACTOR) + '/'
stats_dir = 'statistics/'
# Created once up front instead of being re-checked on every epoch.
os.makedirs(image_dir, exist_ok=True)

for epoch in range(START_EPOCH+1, NUM_EPOCHS + 1):
    train_bar = tqdm(train_loader)
    # Sample-weighted running sums; divided by 'batch_sizes' to get epoch means.
    running_results = {'batch_sizes':0, 'd_loss':0, 'g_loss':0, 'd_score':0, 'g_score':0}

    netG.train()
    netD.train()
    for data, target in train_bar:
        batch_size = data.size(0)
        running_results['batch_sizes'] += batch_size

        real_img = target.cuda()
        z = data.cuda()

        # Single generator forward pass per batch, shared by both updates below.
        fake_img = netG(z)

        # Train Discriminator network : maximize D(x) - 1 - D(G(z))
        # Give a big score for choosing the original image.
        # The fake batch is detached: the D update needs no gradients through G,
        # so nothing has to be back-propagated into (or retained for) the generator.
        netD.zero_grad()

        real_out = netD(real_img).mean()
        fake_out = netD(fake_img.detach()).mean()

        d_loss = 1 - real_out + fake_out

        d_loss.backward()
        optimizerD.step()

        # Train Generator network : minimize 1 - D(G(z)) + Perception Loss + Image Loss + TV Loss
        # D is re-evaluated with its freshly updated weights; reusing the pre-step
        # D output here would back-propagate through parameters modified in place.
        netG.zero_grad()

        fake_out = netD(fake_img).mean()

        g_loss = generator_criterion(fake_out, fake_img, real_img)
        g_loss.backward()
        optimizerG.step()

        # fake_out was computed before optimizerG.step(), which is also when the
        # previous version sampled the logged D(G(z)) value (via an extra forward pass).
        running_results['g_loss'] += g_loss.item() * batch_size
        running_results['d_loss'] += d_loss.item() * batch_size
        running_results['d_score'] += real_out.item() * batch_size
        running_results['g_score'] += fake_out.item() * batch_size

        train_bar.set_description(desc = '[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f' % (
            epoch, NUM_EPOCHS, running_results['d_loss'] / running_results['batch_sizes'],
            running_results['g_loss'] / running_results['batch_sizes'],
            running_results['d_score'] / running_results['batch_sizes'],
            running_results['g_score'] / running_results['batch_sizes'],
        ))

    # Validation
    netG.eval()

    with torch.no_grad():
        val_bar = tqdm(val_loader)
        # 'mse' and 'ssims' are sample-weighted sums; 'psnr' and 'ssim' are the running means.
        valing_results = {'mse': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'batch_sizes': 0}
        val_images = []
        for val_lr, val_hr_restore, val_hr in val_bar:
            batch_size = val_lr.size(0)
            valing_results['batch_sizes'] += batch_size
            lr = val_lr.cuda()
            hr = val_hr.cuda()
            sr = netG(lr)

            # .item() keeps the accumulators as plain Python floats instead of CUDA tensors.
            batch_mse = ((sr - hr) ** 2).mean().item()
            valing_results['mse'] += batch_mse * batch_size

            batch_ssim = pytorch_ssim.ssim(sr, hr).item()
            valing_results['ssims'] += batch_ssim * batch_size

            # PSNR (Peak Signal-to-Noise Ratio)
            # PSNR = 10 * log10( R^2 / MSE )
            # The peak value R is fixed at 1 (assumes images are scaled to [0, 1] - TODO confirm).
            mean_mse = valing_results['mse'] / valing_results['batch_sizes']
            # A perfect reconstruction has zero MSE; report infinite PSNR rather than dividing by zero.
            valing_results['psnr'] = 10 * log10((1**2) / mean_mse) if mean_mse > 0 else float('inf')

            valing_results['ssim'] = valing_results['ssims'] / valing_results['batch_sizes']
            val_bar.set_description(
                desc = '[converting LR images to SR images] PSNR: %.4f dB SSIM: %.4f' % (
                valing_results['psnr'], valing_results['ssim']))

            # One triplet per validation sample: restored HR, ground-truth HR, generated SR.
            val_images.extend([
                display_transform()(val_hr_restore.squeeze(0)),
                display_transform()(hr.cpu().squeeze(0)),
                display_transform()(sr.cpu().squeeze(0)),
            ])

        if val_images:
            val_images = torch.stack(val_images)
            # 5 triplets (15 images) per saved grid. torch.split takes the group size,
            # so it also works with fewer than 15 images and keeps every group a
            # multiple of 3 (torch.chunk(size // 15) failed / misaligned rows there).
            images_per_grid = 5 * 3
            val_images = torch.split(val_images, images_per_grid)
            val_save_bar = tqdm(val_images, desc = '[saving training results]')
            for index, image in enumerate(val_save_bar, start = 1):
                image = utils.make_grid(image, nrow = 3, padding = 5)
                utils.save_image(image, image_dir + 'epoch_%d_index_%d.png' % (epoch, index), padding = 5)

    # save model
    torch.save(netG.state_dict(), 'epochs/netG_epoch_%d_%d.pth' % (UPSCALE_FACTOR, epoch))
    torch.save(netD.state_dict(), 'epochs/netD_epoch_%d_%d.pth' % (UPSCALE_FACTOR, epoch))

    results['d_loss'].append(running_results['d_loss'] / running_results['batch_sizes'])
    results['g_loss'].append(running_results['g_loss'] / running_results['batch_sizes'])
    results['d_score'].append(running_results['d_score'] / running_results['batch_sizes'])
    results['g_score'].append(running_results['g_score'] / running_results['batch_sizes'])
    results['psnr'].append(valing_results['psnr'])
    results['ssim'].append(valing_results['ssim'])

    # Rewrite the full statistics file after every epoch (epoch is always >= 1 here).
    # NOTE(review): the index assumes each results list holds exactly `epoch` entries.
    # Re-running this cell after an interrupt without re-running the cell that
    # initialises `results` / START_EPOCH breaks that assumption.
    data_frame = pd.DataFrame(
        data = {'Loss_D': results['d_loss'], 'Loss_G': results['g_loss'], 'Score_D': results['d_score'],
               'Score_G': results['g_score'], 'PSNR': results['psnr'], 'SSIM': results['ssim']},
        index = range(1, epoch + 1))
    data_frame.to_csv(stats_dir + 'srf_' + str(UPSCALE_FACTOR) + '_train_results.csv', index_label = 'Epoch')
        
        
        
        
        
        

[11/100] Loss_D: 0.8542 Loss_G: 0.0270 D(x): 0.4374 D(G(z)): 0.4072: 100%|███████████████| 1/1 [00:01<00:00,  1.48s/it]
[converting LR images to SR images] PSNR: 11.7220 dB SSIM: 0.4212: 100%|███████████████| 50/50 [00:03<00:00, 12.87it/s]
[saving training results]: 100%|███████████████████████████████████████████████████████| 10/10 [00:06<00:00,  1.65it/s]
[12/100] Loss_D: 0.8616 Loss_G: 0.0330 D(x): 0.4760 D(G(z)): 0.3714: 100%|███████████████| 1/1 [00:01<00:00,  1.39s/it]
[converting LR images to SR images] PSNR: 11.5947 dB SSIM: 0.5343: 100%|███████████████| 50/50 [00:04<00:00, 12.11it/s]
[saving training results]: 100%|███████████████████████████████████████████████████████| 10/10 [00:05<00:00,  1.72it/s]
[13/100] Loss_D: 0.9153 Loss_G: 0.0193 D(x): 0.5013 D(G(z)): 0.3434: 100%|███████████████| 1/1 [00:01<00:00,  1.40s/it]
  0%|                                                                                           | 0/50 [00:00<?, ?it/s]


KeyboardInterrupt: 