In [None]:
import torch
import torch.nn.functional as F
from torchvision.transforms import ToTensor
from torchvision.transforms.functional import crop
import torchvision.models as models
import os
import cv2
import random
import scipy
import skimage
import numpy as np
import matplotlib.pyplot as plt
import glob
import time
from PIL import Image
from tqdm import tqdm
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
from torch import nn, optim
from torch.autograd import Variable
from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM
import torch.optim as optim
from Model import *
from Prepare import preprocessing
from Generate import generate

In [None]:
# Build the shared configuration object; later cells read sizes, learning rate,
# Adam betas, batch size, loss selectors and loss weights from it.
from Config import OutpaintingConfig
config = OutpaintingConfig()
# NOTE(review): display() presumably prints the active settings -- confirm in Config.py.
config.display()

In [None]:
# Image geometry comes from the shared config object.
cropped_size, output_size = config.CROPPED_SIZE, config.OUTPUT_SIZE
# Half of the size difference (floor division); presumably the margin that is
# outpainted on each side of the cropped image -- confirm against Prepare.py.
expand_size = (output_size - cropped_size) // 2

### Preprocessing

In [None]:
# Both preprocessing runs share the same geometry arguments.
size_args = (cropped_size, output_size, expand_size)
# First run relies on preprocessing()'s default target_dir; the second one
# explicitly targets the 'val' split.
preprocessing(*size_args)
preprocessing(*size_args, target_dir='val')

### Declare

In [None]:
# Devices used throughout the notebook: the GPU for training, the CPU for
# converting tensors to PIL images.
gpu_device = torch.device('cuda:0')
cpu_device = torch.device("cpu")

# Instantiate both networks and move their parameters to the GPU.
gen = Generator()
dis = Discriminator()
gen.to(gpu_device)
dis.to(gpu_device)

In [None]:
# --- Generator-side criteria and optimizer ---
SSIM_MODULE = SSIM(data_range=1, size_average=True, channel=3)
Loss_L1 = nn.L1Loss()
optimizer_G = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=config.ADAM_BETAS)

# --- Discriminator-side criterion and optimizer ---
Loss_BCE = nn.BCELoss()
optimizer_D = optim.Adam(dis.parameters(), lr=config.LEARNING_RATE, betas=config.ADAM_BETAS)

# --- Perceptual-loss network ---
# VGG16 is used only as a fixed feature extractor and is never optimized here.
# eval() turns off the Dropout layers in its classifier head, which would
# otherwise make the perceptual loss random from call to call; freezing the
# parameters stops autograd from computing gradients for them (gradients still
# flow back to the generator through the VGG input).
VGG16_MODULE = models.vgg16(pretrained=True).to(gpu_device)
VGG16_MODULE.eval()
VGG16_MODULE.requires_grad_(False)
Loss_MSE = nn.MSELoss()

# --- Training data ---
trainset = OutpaintingDataset('./dataset/train')
trainloader = torch.utils.data.DataLoader(trainset, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=4)

### Load Model

In [None]:
# Restore the generator, its optimizer and the logged training history so that
# training can resume from a saved checkpoint.
checkpoint_G = torch.load('generator_final.tar')
gen.load_state_dict(checkpoint_G['model_state_dict'])
optimizer_G.load_state_dict(checkpoint_G['optimizer_state_dict'])
epoch = checkpoint_G['epoch']
loss_pxl = checkpoint_G['loss_pxl']
loss_adv = checkpoint_G['loss_adv']
loss_pxl_array = checkpoint_G['loss_pxl_array']
loss_adv_array = checkpoint_G['loss_adv_array']
# The perceptual-loss history is restored only when the checkpoint has it;
# checkpoints without the key leave loss_per_array untouched, so the
# fresh-kernel defaults further down still apply.
if 'loss_per_array' in checkpoint_G:
    loss_per_array = checkpoint_G['loss_per_array']

In [None]:
# Restore the discriminator weights and optimizer state from disk.
checkpoint_D = torch.load('discriminator_final.tar')
dis.load_state_dict(checkpoint_D['model_state_dict'])
optimizer_D.load_state_dict(checkpoint_D['optimizer_state_dict'])
# Pull the bookkeeping values out in one go; note that `epoch` overrides
# whatever an earlier cell assigned.
epoch, loss_D, loss_D_array = (
    checkpoint_D[key] for key in ('epoch', 'loss', 'loss_D_array')
)

### Train

In [None]:
# Fresh-kernel defaults: keep any value restored by the "Load Model" cells and
# only initialise names that do not exist yet. Catching NameError specifically
# (instead of a bare `except`) avoids swallowing KeyboardInterrupt/SystemExit
# or hiding unrelated errors.
try:
    epoch
except NameError:
    epoch = 1

try:
    loss_pxl_array
except NameError:
    loss_pxl_array = []

try:
    loss_per_array
except NameError:
    loss_per_array = []

try:
    loss_adv_array
except NameError:
    loss_adv_array = []

try:
    loss_D_array
except NameError:
    loss_D_array = []

In [None]:
# Adversarial training loop.
#
# Per batch: one generator step (weighted pixel + perceptual + adversarial
# loss), then one discriminator step (real vs. detached fake). Running loss
# averages are printed and appended to the history arrays every `n` batches,
# and a preview figure (input / generated / ground truth) is saved per epoch.
Tensor = torch.cuda.FloatTensor

n = 1000  # logging interval, in batches
alpha_pxl = config.LOSS_WEIGHTS['PIXEL']
alpha_per = config.LOSS_WEIGHTS['PER']
alpha_adv = config.LOSS_WEIGHTS['ADV']

# Fail fast on an unsupported loss selector instead of silently reusing a
# stale loss value (or hitting a NameError) inside the loop.
if config.PIXEL_LOSS not in ('L1', 'MSE'):
    raise ValueError(f"Unsupported config.PIXEL_LOSS: {config.PIXEL_LOSS!r} (expected 'L1' or 'MSE')")
if config.PER_LOSS not in ('SSIM', 'VGG'):
    raise ValueError(f"Unsupported config.PER_LOSS: {config.PER_LOSS!r} (expected 'SSIM' or 'VGG')")

while epoch <= 125:

    running_loss_pxl = 0.0
    running_loss_per = 0.0
    running_loss_adv = 0.0
    running_loss_D = 0.0

    for i, data in enumerate(trainloader, 0):
        # dataset input
        inputs, gt = data
        inputs = inputs.to(gpu_device)
        gt = gt.to(gpu_device)
        gt_cr = crop(gt, 0, expand_size, output_size, cropped_size)

        ###-----------###
        ### GENERATOR ###
        ###-----------###
        optimizer_G.zero_grad()
        outputs = gen(inputs)
        # Crop once and reuse the result for every reconstruction loss below.
        outputs_cr = crop(outputs, 0, expand_size, output_size, cropped_size)

        # Discriminator targets, created directly on the GPU.
        valid = torch.ones(outputs.shape[0], 1, device=gpu_device)
        fake = torch.zeros(outputs.shape[0], 1, device=gpu_device)

        # pixel loss
        if config.PIXEL_LOSS == 'L1':
            loss_pxl = Loss_L1(outputs_cr, gt_cr)
        else:  # 'MSE'
            loss_pxl = Loss_MSE(outputs_cr, gt_cr)

        # structural / perceptual loss
        if config.PER_LOSS == 'SSIM':
            loss_per = 1 - SSIM_MODULE(outputs_cr, gt_cr)
        else:  # 'VGG'
            # The ground-truth features are a constant target: no graph needed.
            with torch.no_grad():
                gt_features = VGG16_MODULE(gt_cr)
            loss_per = Loss_MSE(VGG16_MODULE(outputs_cr), gt_features)

        # adversarial loss from the discriminator
        loss_adv = Loss_BCE(dis(outputs), valid)
        loss_G = alpha_pxl*loss_pxl + alpha_per*loss_per + alpha_adv*loss_adv

        # generator step
        loss_G.backward()
        optimizer_G.step()

        ###---------------###
        ### DISCRIMINATOR ###
        ###---------------###

        # Also clears the gradients that the generator's backward pass left on
        # the discriminator parameters.
        optimizer_D.zero_grad()

        loss_valid = Loss_BCE(dis(gt), valid)
        # detach(): the discriminator step must not back-propagate into the generator
        loss_fake = Loss_BCE(dis(outputs.detach()), fake)
        loss_D = loss_valid + loss_fake

        loss_D.backward()
        optimizer_D.step()

        ###---------###
        ### LOGGING ###
        ###---------###

        running_loss_pxl += loss_pxl.item()
        running_loss_per += loss_per.item()
        running_loss_adv += loss_adv.item()
        running_loss_D += loss_D.item()

        if i % n == n-1:
            print(f'[{epoch}, {i + 1:5d}] loss_pxl: {running_loss_pxl / n:.4f} loss_per: {running_loss_per / n:.3f} loss_adv: {running_loss_adv / n:.3f} loss_D: {running_loss_D / n:.3f}')
            loss_pxl_array.append(running_loss_pxl / n)
            # Was `running_loss_vgg`, an undefined name that raised NameError here.
            loss_per_array.append(running_loss_per / n)
            loss_adv_array.append(running_loss_adv / n)
            loss_D_array.append(running_loss_D / n)

            running_loss_pxl = 0.0
            running_loss_per = 0.0
            running_loss_adv = 0.0
            running_loss_D = 0.0

    # End-of-epoch preview built from the first sample of the last batch.
    # no_grad(): this forward pass is for display only.
    with torch.no_grad():
        gen_preview = gen(inputs)[0]
    in_img = transforms.ToPILImage()(torch.squeeze(inputs[0], 0).to(cpu_device))
    gen_img = transforms.ToPILImage()(torch.squeeze(gen_preview, 0).to(cpu_device))
    gt_img = transforms.ToPILImage()(torch.squeeze(gt[0], 0).to(cpu_device))

    fig, axarr = plt.subplots(1,3,figsize=(18,6))
    axarr[0].imshow(in_img)
    axarr[1].imshow(gen_img)
    axarr[2].imshow(gt_img)

    fig.savefig('epoch ' + str(epoch).zfill(3) + '.jpg', dpi=50)
    # Show this epoch's figure, then release it so open figures do not pile up
    # over the whole run.
    plt.show()
    plt.close(fig)

    epoch += 1

print('Done')

### Save

In [None]:
# Persist generator and discriminator checkpoints, tagged with the current epoch.
# The last-batch losses are stored as plain floats so the checkpoint does not
# carry autograd-tracked CUDA tensors; float() accepts both a one-element
# tensor and a value that is already a float.
torch.save({
            'epoch': epoch,
            'model_state_dict': gen.state_dict(),
            'optimizer_state_dict': optimizer_G.state_dict(),
            'loss_pxl': float(loss_pxl),
            'loss_per': float(loss_per),
            'loss_adv': float(loss_adv),
            'loss_pxl_array': loss_pxl_array,
            # Saved alongside the other histories so all of them can be
            # restored together when training resumes.
            'loss_per_array': loss_per_array,
            'loss_adv_array': loss_adv_array,
            }, 'gen_model-vgg-'+str(epoch).zfill(3)+'.tar')
torch.save({
            'epoch': epoch,
            'model_state_dict': dis.state_dict(),
            'optimizer_state_dict': optimizer_D.state_dict(),
            'loss': float(loss_D),
            'loss_D_array': loss_D_array,
            }, 'dis_model-'+str(epoch).zfill(3)+'.tar')

### Show

In [None]:
# Pick a random validation image (e.g. 'badlands00000002.jpg') and show the
# cropped input, the generated result and the ground truth side by side.
img_name = random.choice(os.listdir('dataset/val/cropped'))
cropped_img_path = 'dataset/val/cropped/'+img_name
gt_img_path = 'dataset/val/gt/'+img_name

in_img = Image.open(cropped_img_path)
gen_img = generate(cropped_img_path)
gt_img = Image.open(gt_img_path)

fig, axarr = plt.subplots(1,3,figsize=(18,6))
for panel, picture in zip(axarr, (in_img, gen_img, gt_img)):
    panel.imshow(picture)

# Optional: write the comparison to disk.
# fig.savefig('test.jpg', dpi=50)

In [None]:
# Scatter the logged pixel-loss averages against their logging-step index.
log_steps = range(len(loss_pxl_array))
plt.scatter(log_steps, loss_pxl_array)
plt.show()