#### TODO list:
  - test and pick the best loss function
  - test and run the training code
    - text output
    - img output
    - save models
  - convert to .py file for easy future running (tuning)

In [1]:
from model import Generator, Discriminator
from loss import SRGAN_Loss
from utils import TrainDatasetFromFolder, ValDatasetFromFolder, display_transform, scale_lr2hr

import os
from math import log10
import pandas as pd
import torch.optim as optim
import torch.utils.data
import torchvision.utils as utils
import pytorch_ssim
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

In [2]:
### global parameters
SCALE_FACTOR = 4
TRAIN_HR_DIR = "temp_data/train_images_hr/"
TRAIN_LR_DIR = "temp_data/train_images_lr/"
VAL_HR_DIR = "temp_data/val_images_hr/"
VAL_LR_DIR = "temp_data/val_images_lr/"
RESULTS_DIR = "results/" + "SR" + str(SCALE_FACTOR) + "/"

# network parameters
CONTENT_LOSS = "both"  # try "both" in future test
ADVERSARIAL_LOSS = "bce"
TV_LOSS_ON = False

# training parameters
BATCH_SIZE = 1
NUM_EPOCHS = 10
NUM_WORKERS = 0  # workers for loading data

# GPU

In [3]:
train_dataset = TrainDatasetFromFolder(hr_dir=TRAIN_HR_DIR, lr_dir=TRAIN_LR_DIR)
train_loader = DataLoader(dataset=train_dataset, num_workers=NUM_WORKERS, batch_size=BATCH_SIZE, shuffle=True)
val_dataset = ValDatasetFromFolder(hr_dir=VAL_HR_DIR, lr_dir=VAL_LR_DIR)
val_loader = DataLoader(dataset=val_dataset, num_workers=NUM_WORKERS, batch_size=1, shuffle=False)

netG = Generator(scale_factor=SCALE_FACTOR)
print('# generator parameters:', sum(param.numel() for param in netG.parameters()))
netD = Discriminator()
print('# discriminator parameters:', sum(param.numel() for param in netD.parameters()))

generator_criterion = SRGAN_Loss(content_loss=CONTENT_LOSS, 
                                 adversarial_loss=ADVERSARIAL_LOSS, 
                                 tv_loss_on=TV_LOSS_ON)

if torch.cuda.is_available():
    netG.cuda()
    netD.cuda()
    generator_criterion.cuda()
    
optimizerG = optim.Adam(netG.parameters())
optimizerD = optim.Adam(netD.parameters())

results = {'d_loss': [], 'g_loss': [], 'd_score': [], 'g_score': [], 'psnr': [], 'ssim': []}
if not os.path.isdir(RESULTS_DIR):
    os.mkdir(RESULTS_DIR)
out_path_val = RESULTS_DIR + "val_predict/"
if not os.path.isdir(out_path_val):
    os.mkdir(out_path_val)
out_path_net = RESULTS_DIR + "net_weights/"
if not os.path.isdir(out_path_net):
    os.mkdir(out_path_net)

for epoch in range(1, NUM_EPOCHS + 1):
    train_bar = tqdm(train_loader)
    running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 'd_score': 0, 'g_score': 0}

    netG.train()
    netD.train()
    
    ### training
    for lr_img, hr_img in train_bar:
        g_update_first = True
        batch_size = lr_img.size(0)
        running_results['batch_sizes'] += batch_size
        
        ############################
        # (1) Update D network: maximize D(x)-1-D(G(z))
        ###########################
        real_img = Variable(hr_img)
        if torch.cuda.is_available():
            real_img = real_img.cuda()
        z = Variable(lr_img)
        if torch.cuda.is_available():
            z = z.cuda()
        fake_img = netG(z)

        netD.zero_grad()
        real_out = netD(real_img)
        fake_out = netD(fake_img)
        d_loss = 1 - real_out.mean() + fake_out.mean()  # L1 loss
        #  could also try a BCELoss: log(1-real_out)+log(fake_out)
        d_loss.backward(retain_graph=True)
        optimizerD.step()
        
        ############################
        # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
        ###########################
        netG.zero_grad()
        g_loss = generator_criterion(fake_out, fake_img, real_img)
        g_loss.backward()

        fake_img = netG(z)
        fake_out = netD(fake_img).mean()

        optimizerG.step()

        # loss for current batch before optimization 
        running_results['g_loss'] += g_loss.item() * batch_size
        running_results['d_loss'] += d_loss.item() * batch_size
        running_results['d_score'] += real_out.sum().item()
        running_results['g_score'] += fake_out.sum().item()

        train_bar.set_description(desc='[%d/%d] Loss_D: %.4f Loss_G: %.4f D(hr): %.4f D(G(lr)): %.4f' % (
            epoch, NUM_EPOCHS, running_results['d_loss'] / running_results['batch_sizes'],
            running_results['g_loss'] / running_results['batch_sizes'],
            running_results['d_score'] / running_results['batch_sizes'],
            running_results['g_score'] / running_results['batch_sizes']))
        pass
    
    ### evaluating
    netG.eval()
    out_path_val_epoch = out_path_val + "epoch_%d/" % epoch
    if not os.path.isdir(out_path_val_epoch):
        os.mkdir(out_path_val_epoch)
    with torch.no_grad():
        valing_results = {'mse': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'batch_sizes': 0}
        val_bar = tqdm(val_loader)
        val_images = []
        val_names = []
        for val_lr, val_hr, val_name in val_bar:
            batch_size = val_lr.size(0)
            valing_results['batch_sizes'] += batch_size
            lr = val_lr
            hr = val_hr
            lr2hr = scale_lr2hr((256, 256))(lr.squeeze(0))
            if torch.cuda.is_available():
                lr = lr.cuda()
                hr = hr.cuda()
            sr = netG(lr)

            batch_mse = ((sr - hr) ** 2).data.mean()
            valing_results['mse'] += batch_mse * batch_size
            batch_ssim = pytorch_ssim.ssim(sr, hr).item()
            valing_results['ssims'] += batch_ssim * batch_size
            valing_results['psnr'] = 10 * log10(1 / (valing_results['mse'] / valing_results['batch_sizes']))
            valing_results['ssim'] = valing_results['ssims'] / valing_results['batch_sizes']
            val_bar.set_description(
                desc='[validation] PSNR: %.4f dB SSIM: %.4f' % (
                    valing_results['psnr'], valing_results['ssim']))
            val_images.extend(
                [display_transform()(lr2hr.squeeze(0)), display_transform()(hr.data.cpu().squeeze(0)),
                 display_transform()(sr.data.cpu().squeeze(0))])
            val_names.extend(val_name)
        val_images = torch.stack(val_images)
        val_images = torch.chunk(val_images, val_images.size(0) // 3)
        # val_save_bar = zip(tqdm(val_images, desc='[saving validating results]'), val_names)
        for image, name in zip(val_images, val_names):
            image = utils.make_grid(image, nrow=3, padding=5)
            utils.save_image(image, out_path_val_epoch + '%s.png' % name.strip(".tif"), padding=5)
    
    ### save model parameters
    torch.save(netG.state_dict(), out_path_net + "netG_epoch_%d.pth" % epoch)
    torch.save(netD.state_dict(), out_path_net + "netD_epoch_%d.pth" % epoch)
    
    ### save loss\scores\psnr\ssim
    results['d_loss'].append(running_results['d_loss'] / running_results['batch_sizes'])
    results['g_loss'].append(running_results['g_loss'] / running_results['batch_sizes'])
    results['d_score'].append(running_results['d_score'] / running_results['batch_sizes'])
    results['g_score'].append(running_results['g_score'] / running_results['batch_sizes'])
    results['psnr'].append(valing_results['psnr'])
    results['ssim'].append(valing_results['ssim'])

    if epoch % 1 == 0 and epoch != 0:
        data_frame = pd.DataFrame(
            data={'Loss_D': results['d_loss'], 'Loss_G': results['g_loss'], 'Score_D': results['d_score'],
                  'Score_G': results['g_score'], 'PSNR': results['psnr'], 'SSIM': results['ssim']},
            index=range(1, epoch + 1))
        data_frame.to_csv(RESULTS_DIR + 'train_stats.csv', index_label='Epoch')

# generator parameters: 734219
# discriminator parameters: 138908865


[1/10] Loss_D: 0.9966 Loss_G: 0.0111 D(hr): 0.8815 D(G(lr)): 0.8790: 100%|██████████| 587/587 [02:04<00:00,  4.70it/s]
[validation] PSNR: 21.1037 dB SSIM: 0.5346: 100%|██████████| 5/5 [00:00<00:00, 30.56it/s]
[2/10] Loss_D: 1.0000 Loss_G: 0.0030 D(hr): 1.0000 D(G(lr)): 1.0000: 100%|██████████| 587/587 [02:05<00:00,  4.68it/s]
[validation] PSNR: 23.0675 dB SSIM: 0.6305: 100%|██████████| 5/5 [00:00<00:00, 31.88it/s]
[3/10] Loss_D: 1.0000 Loss_G: 0.0023 D(hr): 1.0000 D(G(lr)): 1.0000: 100%|██████████| 587/587 [02:03<00:00,  4.77it/s]
[validation] PSNR: 24.2256 dB SSIM: 0.6780: 100%|██████████| 5/5 [00:00<00:00, 37.24it/s]
[4/10] Loss_D: 1.0000 Loss_G: 0.0022 D(hr): 1.0000 D(G(lr)): 1.0000: 100%|██████████| 587/587 [02:04<00:00,  4.72it/s]
[validation] PSNR: 18.1952 dB SSIM: 0.5701: 100%|██████████| 5/5 [00:00<00:00, 31.39it/s]
[5/10] Loss_D: 1.0000 Loss_G: 0.0034 D(hr): 1.0000 D(G(lr)): 1.0000: 100%|██████████| 587/587 [02:05<00:00,  4.69it/s]
[validation] PSNR: 24.9274 dB SSIM: 0.7173: 1

# CPU

In [None]:
train_dataset = TrainDatasetFromFolder(hr_dir=TRAIN_HR_DIR, lr_dir=TRAIN_LR_DIR)
train_loader = DataLoader(dataset=train_dataset, num_workers=NUM_WORKERS, batch_size=BATCH_SIZE, shuffle=True)
val_dataset = ValDatasetFromFolder(hr_dir=VAL_HR_DIR, lr_dir=VAL_LR_DIR)
val_loader = DataLoader(dataset=val_dataset, num_workers=NUM_WORKERS, batch_size=1, shuffle=False)

netG = Generator(scale_factor=SCALE_FACTOR)
print('# generator parameters:', sum(param.numel() for param in netG.parameters()))
netD = Discriminator()
print('# discriminator parameters:', sum(param.numel() for param in netD.parameters()))

generator_criterion = SRGAN_Loss(content_loss=CONTENT_LOSS, 
                                 adversarial_loss=ADVERSARIAL_LOSS, 
                                 tv_loss_on=TV_LOSS_ON)

    
optimizerG = optim.Adam(netG.parameters())
optimizerD = optim.Adam(netD.parameters())

results = {'d_loss': [], 'g_loss': [], 'd_score': [], 'g_score': [], 'psnr': [], 'ssim': []}
if not os.path.isdir(RESULTS_DIR):
    os.mkdir(RESULTS_DIR)
out_path_val = RESULTS_DIR + "val_predict/"
if not os.path.isdir(out_path_val):
    os.mkdir(out_path_val)
out_path_net = RESULTS_DIR + "net_weights/"
if not os.path.isdir(out_path_net):
    os.mkdir(out_path_net)

for epoch in range(1, NUM_EPOCHS + 1):
    train_bar = tqdm(train_loader)
    running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 'd_score': 0, 'g_score': 0}

    netG.train()
    netD.train()
    
    ### training
    for lr_img, hr_img in train_bar:
        g_update_first = True
        batch_size = lr_img.size(0)
        running_results['batch_sizes'] += batch_size
        
        ############################
        # (1) Update D network: maximize D(x)-1-D(G(z))
        ###########################
        real_img = Variable(hr_img)
        z = Variable(lr_img)
        fake_img = netG(z)

        netD.zero_grad()
        real_out = netD(real_img)
        fake_out = netD(fake_img)
        d_loss = 1 - real_out.mean() + fake_out.mean()  # L1 loss
        #  could also try a BCELoss: log(1-real_out)+log(fake_out)
        d_loss.backward(retain_graph=True)
        optimizerD.step()
        
        ############################
        # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
        ###########################
        netG.zero_grad()
        g_loss = generator_criterion(fake_out, fake_img, real_img)
        g_loss.backward()

        fake_img = netG(z)
        fake_out = netD(fake_img).mean()

        optimizerG.step()

        # loss for current batch before optimization 
        running_results['g_loss'] += g_loss.item() * batch_size
        running_results['d_loss'] += d_loss.item() * batch_size
        running_results['d_score'] += real_out.sum().item()
        running_results['g_score'] += fake_out.sum().item()

        train_bar.set_description(desc='[%d/%d] Loss_D: %.4f Loss_G: %.4f D(hr): %.4f D(G(lr)): %.4f' % (
            epoch, NUM_EPOCHS, running_results['d_loss'] / running_results['batch_sizes'],
            running_results['g_loss'] / running_results['batch_sizes'],
            running_results['d_score'] / running_results['batch_sizes'],
            running_results['g_score'] / running_results['batch_sizes']))
        pass
    
    ### evaluating
    netG.eval()
    out_path_val_epoch = out_path_val + "epoch_%d/" % epoch
    if not os.path.isdir(out_path_val_epoch):
        os.mkdir(out_path_val_epoch)
    with torch.no_grad():
        valing_results = {'mse': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'batch_sizes': 0}
        val_bar = tqdm(val_loader)
        val_images = []
        val_names = []
        for val_lr, val_hr, val_name in val_bar:
            batch_size = val_lr.size(0)
            valing_results['batch_sizes'] += batch_size
            lr = val_lr
            hr = val_hr
            lr2hr = scale_lr2hr((256, 256))(lr.squeeze(0))
            sr = netG(lr)

            batch_mse = ((sr - hr) ** 2).data.mean()
            valing_results['mse'] += batch_mse * batch_size
            batch_ssim = pytorch_ssim.ssim(sr, hr).item()
            valing_results['ssims'] += batch_ssim * batch_size
            valing_results['psnr'] = 10 * log10(1 / (valing_results['mse'] / valing_results['batch_sizes']))
            valing_results['ssim'] = valing_results['ssims'] / valing_results['batch_sizes']
            val_bar.set_description(
                desc='[converting LR images to SR images] PSNR: %.4f dB SSIM: %.4f' % (
                    valing_results['psnr'], valing_results['ssim']))
            val_images.extend(
                [display_transform()(lr2hr.squeeze(0)), display_transform()(hr.data.cpu().squeeze(0)),
                 display_transform()(sr.data.cpu().squeeze(0))])
            val_names.extend(val_name)
        val_images = torch.stack(val_images)
        val_images = torch.chunk(val_images, val_images.size(0) // 3)
        # val_save_bar = zip(tqdm(val_images, desc='[saving validating results]'), val_names)
        for image, name in zip(vap_images, val_names):
            image = utils.make_grid(image, nrow=3, padding=5)
            utils.save_image(image, out_path_val_epoch + '%s.png' % name.strip(".tif"), padding=5)
    
    ### save model parameters
    torch.save(netG.state_dict(), out_path_net + "netG_epoch_%d.pth" % epoch)
    torch.save(netD.state_dict(), out_path_net + "netD_epoch_%d.pth" % epoch)
    
    ### save loss\scores\psnr\ssim
    results['d_loss'].append(running_results['d_loss'] / running_results['batch_sizes'])
    results['g_loss'].append(running_results['g_loss'] / running_results['batch_sizes'])
    results['d_score'].append(running_results['d_score'] / running_results['batch_sizes'])
    results['g_score'].append(running_results['g_score'] / running_results['batch_sizes'])
    results['psnr'].append(valing_results['psnr'])
    results['ssim'].append(valing_results['ssim'])

    if epoch % 1 == 0 and epoch != 0:
        data_frame = pd.DataFrame(
            data={'Loss_D': results['d_loss'], 'Loss_G': results['g_loss'], 'Score_D': results['d_score'],
                  'Score_G': results['g_score'], 'PSNR': results['psnr'], 'SSIM': results['ssim']},
            index=range(1, epoch + 1))
        data_frame.to_csv(RESULTS_DIR + 'train_stats.csv', index_label='Epoch')

# Test

In [None]:
def torch_cuda_memory(device=None):
    print("memory allocated: %.2f" % (torch.cuda.max_memory_allocated(device=device)/1024/1024/1024))
    print("memory cached: %.2f" % (torch.cuda.max_memory_cached(device=device)/1024/1024/1024))

In [None]:
# prints currently alive Tensors and Variables
import torch
import gc
for obj in gc.get_objects():
    try:
        if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
            print(type(obj), obj.size())
    except:
        pass

In [10]:
for obj in gc.get_objects():
    if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
        print(reduce(obj.mul, obj.size()) if len(obj.size()) > 0 else 0, type(obj), obj.size())

ModuleNotFoundError: No module named '_dbm'

In [None]:
train_dataset = TrainDatasetFromFolder(hr_dir=TRAIN_HR_DIR, lr_dir=TRAIN_LR_DIR)
train_loader = DataLoader(dataset=train_dataset, num_workers=NUM_WORKERS, batch_size=BATCH_SIZE, shuffle=True)
val_dataset = ValDatasetFromFolder(hr_dir=VAL_HR_DIR, lr_dir=VAL_LR_DIR)
val_loader = DataLoader(dataset=val_dataset, num_workers=NUM_WORKERS, batch_size=1, shuffle=False)

netG = Generator(scale_factor=SCALE_FACTOR)
print('# generator parameters:', sum(param.numel() for param in netG.parameters()))
netD = Discriminator()
print('# discriminator parameters:', sum(param.numel() for param in netD.parameters()))

generator_criterion = SRGAN_Loss(content_loss=CONTENT_LOSS, 
                                 adversarial_loss=ADVERSARIAL_LOSS, 
                                 tv_loss_on=TV_LOSS_ON)

if torch.cuda.is_available():
    netG.cuda()
    netD.cuda()
    generator_criterion.cuda()
    
if torch.cuda.is_available():
    netG.cuda()
    netD.cuda()
    generator_criterion.cuda()
    
optimizerG = optim.Adam(netG.parameters())
optimizerD = optim.Adam(netD.parameters())

results = {'d_loss': [], 'g_loss': [], 'd_score': [], 'g_score': [], 'psnr': [], 'ssim': []}
if not os.path.isdir(RESULTS_DIR):
    os.mkdir(RESULTS_DIR)
out_path_val = RESULTS_DIR + "val_predict/"
if not os.path.isdir(out_path_val):
    os.mkdir(out_path_val)
out_path_net = RESULTS_DIR + "net_weights/"
if not os.path.isdir(out_path_net):
    os.mkdir(out_path_net)

for epoch in range(1, NUM_EPOCHS + 1):
    train_bar = tqdm(train_loader)
    running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 'd_score': 0, 'g_score': 0}

    netG.train()
    netD.train()
    
    ### training
    for lr_img, hr_img in train_bar:
        break
    break

In [None]:
g_update_first = True
batch_size = lr_img.size(0)
running_results['batch_sizes'] += batch_size

############################
# (1) Update D network: maximize D(x)-1-D(G(z))
###########################
real_img = Variable(hr_img)
if torch.cuda.is_available():
    real_img = real_img.cuda()
z = Variable(lr_img)
if torch.cuda.is_available():
    z = z.cuda()

In [None]:
fake_img = netG(z)

In [None]:
loss = 1 - real_out.mean()
loss.backward()

In [None]:

netD.zero_grad()
real_out = netD(real_img)
fake_out = netD(fake_img)
d_loss = 1 - real_out.mean() + fake_out.mean()  # L1 loss
#  could also try a BCELoss: log(1-real_out)+log(fake_out)
d_loss.backward(retain_graph=True)
optimizerD.step()

In [None]:
torch_cuda_memory()

In [None]:
netD.zero_grad()

In [None]:
torch.cuda.memory_allocated(device=None)

In [None]:
torch.cuda.max_memory_allocated(device=None)

In [None]:
torch.cuda.reset_max_memory_allocated(device=None)

In [None]:
torch.cuda.memory_cached(device=None)

In [None]:
torch.cuda.max_memory_cached(device=None)

In [None]:
torch.cuda.reset_max_memory_cached(device=None)

In [2]:
from tqdm import tqdm
from time import sleep

for z in range(2):
    a = tqdm(range(10), desc='1st loop')
    for i in a:
        sleep(0.1)
        a.set_description("abc")
    b = tqdm(range(5), desc='2nd loop')
    for j in b:
        sleep(0.3)
        b.set_description("cba")
    c = zip(tqdm(range(200), desc='3nd loop'), range(200))
    for k1, k2 in c:
        sleep(0.01)