In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms.functional as TF
from PIL import Image
import numpy as np
from UNet3d import UNet3D
from dataset import DataLoaderTurb
from torch.utils.data import DataLoader
from network import *
from utils import *

In [2]:
# Text file that the training loop appends one summary line to at every check.
log_path = "./log.txt"

In [3]:
# Parameters handed to the Generator for its turbulence model.
turb_params = dict(
    img_size=(128, 128),
    D=0.1,                # Aperture diameter
    r0=0.05,              # Fried parameter
    L=1000,               # Propagation distance
    thre=0.002,           # Suppresses small values in the tilt correlation matrix.
                          # Increase this threshold if the pixel displacement
                          # appears to be scattering.
    adj=1,                # Adjusting factor of delta0 for the tilt matrix
    wavelength=0.500e-6,
    corr=-0.02,           # Correlation strength for the PSF without tilt.
                          # Suggested range: (-1 ~ -0.01)
    zer_scale=1,          # Manual scale for the Zernike coefficients of the
                          # PSF without tilt
)
# --- Training hyper-parameters ---------------------------------------------
simu_bs = 10          # batch_size passed to the Generator
bs = 10               # passed to UNet3D as n_frames
lr_g = 1e-7           # generator learning rate
lr_d = 1e-7           # discriminator learning rate
start_iter = 360000   # iteration whose checkpoints are loaded to resume training

# torch.device only accepts lower-case device types; 'CPU' raises a
# RuntimeError, which made the original fallback unusable.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Place the restorer on the selected device instead of hard-coding .cuda(),
# so the CPU fallback above is actually honoured.
gen_net = UNet3D(n_frames=bs, feat_channels=[32, 128, 128, 256, 512]).to(device).train()

net_G = Generator(turb_params, batch_size=simu_bs, restorer=gen_net, device=device)

# From iteration 5000 onwards resume from the periodic checkpoint of that
# step; before that, start from the initial generator weights.
if start_iter >= 5000:
    ckpt_path_G = f'./checkpoints/model_G_{start_iter}.pth'
else:
    ckpt_path_G = './checkpoints/init_G_res.pth'

# map_location remaps tensors saved on another device (e.g. a GPU checkpoint
# opened on a CPU-only machine). The optimizer entry of the checkpoint is not
# read here, so lr_g keeps the value set above.
checkpoint = torch.load(ckpt_path_G, map_location=device)

# Checkpoints are either a bare state dict or a dict wrapping it under
# 'state_dict'; accept both. Kept as the last expression so the notebook
# displays the key-matching report.
net_G.load_state_dict(checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint)

<All keys matched successfully>

In [4]:
# Build the discriminator on the same device as the generator rather than
# unconditionally calling .cuda().
net_D = Discriminator().to(device)

# Any non-zero start_iter means a previous run exists: restore the
# discriminator weights saved at that step and tell TurbGAN to continue.
resume_training = start_iter > 0
if resume_training:
    # map_location lets a checkpoint saved on another device load here.
    checkpoint = torch.load(f'./checkpoints/model_D_{start_iter}.pth', map_location=device)
    # Accept both a bare state dict and one wrapped under 'state_dict'.
    net_D.load_state_dict(checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint)
    # The optimizer entry of the checkpoint is not read here; lr_d keeps its configured value.

TGAN = TurbGAN(save_path='./', model_g=net_G, model_d=net_D, lr_g=lr_g, lr_d=lr_d,
               continue_train=resume_training)

In [5]:
# Ground-truth image, forced to 3-channel RGB and converted to a tensor;
# the training loop compares every reconstruction against it for PSNR.
gt_rgb = Image.open('./image/gt.jpg').convert("RGB")
target = TF.to_tensor(gt_rgb)

In [6]:
# Two independent shuffled loaders over the same image folder: one feeds the
# generator (batch size `bs`), the other supplies the discriminator's real
# samples (batch size `simu_bs`). Each gets its own dataset instance.
G_loader = DataLoader(
    dataset=DataLoaderTurb('./image/img_in'),
    batch_size=bs,
    shuffle=True,
    num_workers=8,
    drop_last=True,
    pin_memory=True,
)
D_loader = DataLoader(
    dataset=DataLoaderTurb('./image/img_in'),
    batch_size=simu_bs,
    shuffle=True,
    num_workers=8,
    drop_last=True,
    pin_memory=True,
)

In [7]:
# --- Training loop ------------------------------------------------------------
# Runs from the step after `start_iter` up to and including `max_niter`.
# Every `check_freq` steps it logs the window averages, refreshes the "best"
# checkpoints when the mean PSNR improves, and calls TGAN.update_learning_rate /
# TGAN.save_results. Every `save_freq` steps it calls TGAN.save_networks.
niter = start_iter + 1
max_niter = 400000    # last iteration to run (inclusive)
check_freq = 500      # steps between log lines / best-checkpoint checks
save_freq = 1000      # steps between periodic TGAN.save_networks calls

# Per-step values collected since the last check.
current_loss_G = []
current_loss_D = []
current_psnr = []

# Per-window means, kept for inspection after training.
loss_G_mean = []
loss_D_mean = []
psnr_mean = []

# NOTE(review): best_psnr restarts at 0 even when resuming, so the first check
# after a resume always overwrites best_G.pth / best_D.pth - confirm intended.
best_psnr = 0

# The reference image never changes, so scale it to [0, 255] once instead of
# rebuilding the array on every step.
target_np = target.detach().numpy() * 255

# One (name suffix, model, optimizer) triple per "best" checkpoint file.
best_ckpt_parts = (
    ('G', TGAN.model_G, TGAN.optimizer_G),
    ('D', TGAN.model_D, TGAN.optimizer_D),
)

while niter <= max_niter:
    first_step_of_pass = niter
    for G_in, D_real in zip(G_loader, D_loader):
        # Stop mid-pass once the iteration budget is exhausted; the original
        # `while True` never consulted max_niter and ran until interrupted.
        if niter > max_niter:
            break

        TGAN.set_input(G_in, D_real)
        TGAN.optimize(niter)

        generated = TGAN.recon.squeeze().detach()
        current_loss_G.append(TGAN.loss_G.item())
        current_loss_D.append(TGAN.loss_D.item())
        current_psnr.append(
            calculate_psnr(generated.cpu().numpy() * 255, target_np, border=0))

        if niter >= check_freq and niter % check_freq == 0:
            # Local names on purpose: overwriting the config globals lr_g/lr_d
            # would silently change what a re-run of the earlier cells sees.
            cur_lr_g = TGAN.optimizer_G.param_groups[0]['lr']
            cur_lr_d = TGAN.optimizer_D.param_groups[0]['lr']

            # Average over the samples actually collected. When resuming from a
            # step that is not a multiple of check_freq the first window is
            # shorter, and dividing by check_freq would under-report the means.
            n_window = len(current_loss_G)
            loss_G = sum(current_loss_G) / n_window
            loss_D = sum(current_loss_D) / n_window
            psnr = sum(current_psnr) / n_window

            message = 'step: {:d} lr_g: {:.9f} lr_d: {:.9f} loss_G: {:7f} loss_D: {:7f} psnr: {:4f}'.format(
                niter, cur_lr_g, cur_lr_d, loss_G, loss_D, psnr)
            print(message)
            with open(log_path, 'a') as log_file:
                log_file.write('{}\n'.format(message))

            loss_D_mean.append(loss_D)
            loss_G_mean.append(loss_G)
            psnr_mean.append(psnr)
            current_loss_G = []
            current_loss_D = []
            current_psnr = []

            if psnr > best_psnr:
                # Save generator and discriminator together so the "best"
                # pair always comes from the same step.
                for suffix, model, optimizer in best_ckpt_parts:
                    torch.save({'step': niter,
                                'best_psnr': psnr,
                                'state_dict': model.state_dict(),
                                'optimizer': optimizer.state_dict(),
                                }, f"{TGAN.ckpt_path}/best_{suffix}.pth")
                best_psnr = psnr

            TGAN.update_learning_rate(loss_G)
            TGAN.save_results(niter)

        if niter % save_freq == 0:
            TGAN.save_networks(niter)
        niter += 1

    # Both loaders use drop_last=True, so a dataset smaller than one batch
    # yields nothing; fail loudly instead of spinning in the outer loop.
    if niter == first_step_of_pass:
        raise RuntimeError(
            'G_loader/D_loader produced no batches; check the dataset size against the batch sizes')

step: 360500 lr_g: 0.000000100 lr_d: 0.000000100 loss_G: 0.249516 loss_D: -0.271137 psnr: 25.942674
step: 361000 lr_g: 0.000000100 lr_d: 0.000000100 loss_G: 0.251849 loss_D: -0.280724 psnr: 25.925190
step: 361500 lr_g: 0.000000100 lr_d: 0.000000100 loss_G: 0.238906 loss_D: -0.288604 psnr: 25.934529
step: 362000 lr_g: 0.000000100 lr_d: 0.000000100 loss_G: 0.232285 loss_D: -0.275303 psnr: 25.935929
step: 362500 lr_g: 0.000000100 lr_d: 0.000000100 loss_G: 0.239640 loss_D: -0.274953 psnr: 25.927555
step: 363000 lr_g: 0.000000100 lr_d: 0.000000100 loss_G: 0.238940 loss_D: -0.259674 psnr: 25.921523
step: 363500 lr_g: 0.000000100 lr_d: 0.000000100 loss_G: 0.260273 loss_D: -0.259619 psnr: 25.924279
step: 364000 lr_g: 0.000000100 lr_d: 0.000000100 loss_G: 0.244067 loss_D: -0.271567 psnr: 25.938342
step: 364500 lr_g: 0.000000100 lr_d: 0.000000100 loss_G: 0.225094 loss_D: -0.271821 psnr: 25.929645
step: 365000 lr_g: 0.000000100 lr_d: 0.000000100 loss_G: 0.220689 loss_D: -0.264335 psnr: 25.932609


KeyboardInterrupt: 