In [1]:
import sys
sys.path.append('../')

import os
import time
import numpy as np

from torch.utils.data import DataLoader
from torch.autograd import Variable
from torch import autograd
from torch.nn import functional as F

import torch
import torch.nn as nn


from utils.tools import get_config, default_loader, is_image_file, normalize
from tensorboardX import SummaryWriter
import matplotlib.pyplot as plt
import torchvision.utils as vutils

from sklearn.metrics import roc_auc_score
sys.path.append('../PerceptualSimilarity')
import models as PerceptualSimilarity

# personal library
from networks import autoencoder, simulator, discriminator
from dataloader import MVTecDataset

In [2]:
# 限制可以使用的 GPU
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="2,3"

In [3]:
# HYPER parameters
num_epochs = 50000
batch_size = 32
val_batch_size = 4
ae_lr = 1e-4
s_lr = 5e-4
d_lr = 5e-4
weight_decay = 1e-5
UPSET=True
expName = 'AEGAN-exp(wood + L1 + gradient)'
writer = SummaryWriter('checkpoints/'+expName)
TYPE='wood'

In [4]:
trainDatset = MVTecDataset.MVTecDataset(TYPE=TYPE, isTrain='train')
testDatset = MVTecDataset.MVTecDataset(TYPE=TYPE, isTrain='test')
valDataset = MVTecDataset.MVTecDataset(TYPE=TYPE, isTrain='val')

val_loader = DataLoader(
    dataset=valDataset,
    batch_size=val_batch_size, 
    shuffle=True,
    num_workers=4
)
train_loader = DataLoader(
    dataset=trainDatset,
    batch_size=batch_size, 
    shuffle=True,
    num_workers=4
)
test_loader = DataLoader(
    dataset=testDatset,
    batch_size=1,
    shuffle=True,
    num_workers=4
)

In [5]:
# Model
AE = autoencoder.Autoencoder().cuda()
S = nn.DataParallel(simulator.Simulator(3, 8)).cuda()
D = nn.DataParallel(discriminator.Discriminator(6, 16)).cuda()

# Loss
L1_loss = nn.L1Loss()
L2_loss = nn.MSELoss(reduction='none')
perceptual_loss = PerceptualSimilarity.PerceptualLoss(model='net-lin', net='alex', use_gpu=True, gpu_ids=[0])

# Optimizer
optimizer_AE = torch.optim.Adam(
    AE.parameters(), 
    lr=ae_lr,
    weight_decay=weight_decay
)
optimizer_S = torch.optim.Adam(
    S.parameters(), 
    lr=s_lr,
    weight_decay=weight_decay
)
optimizer_D = torch.optim.Adam(
    D.parameters(), 
    lr=d_lr,
    weight_decay=weight_decay
)


Tensor = torch.cuda.FloatTensor

Setting up Perceptual loss...
Loading model from: /root/AFS/Corn/AEGAN/PerceptualSimilarity/models/weights/v0.1/alex.pth
...[net-lin [alex]] initialized
...Done


In [6]:
AE.load_state_dict(torch.load('./save_weight/AE-wood-z-2x2-exp2/AE_2500.npy', map_location="cuda:0"), False)

AE = nn.DataParallel(AE)

In [7]:
# To Solve: RuntimeError: cuDNN error: CUDNN_STATUS_EXECUTION_FAILED
torch.backends.cudnn.enabled = False 

# 拿掉煩人的 warning
import warnings
warnings.filterwarnings("ignore")

In [8]:
LAMBDA = 10

def calc_gradient_penalty(netD, real_data, fake_data):
    # print "real_data: ", real_data.size(), fake_data.size()
    BATCH_SIZE = real_data.size(0)
    alpha = torch.rand(BATCH_SIZE, 1)
    alpha = alpha.expand(BATCH_SIZE, real_data.nelement()//BATCH_SIZE).contiguous().view(BATCH_SIZE, 6, 256, 256)
    alpha = alpha.cuda()

    interpolates = alpha * real_data + ((1 - alpha) * fake_data)

    interpolates = interpolates.cuda()
    interpolates = autograd.Variable(interpolates, requires_grad=True)

    disc_interpolates = netD(interpolates)

    gradients = autograd.grad(
        outputs=disc_interpolates, 
        inputs=interpolates,
        grad_outputs=torch.ones(disc_interpolates.size()).cuda(),
        create_graph=True, 
        retain_graph=True, 
        only_inputs=True
    )[0]
    gradients = gradients.view(gradients.size(0), -1)

    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA
    return gradient_penalty

def gradient_loss(gen_frames, gt_frames, alpha=1):
    def gradient(x):
        h_x = x.size()[-2]
        w_x = x.size()[-1]
        # gradient step=1
        left = x
        right = F.pad(x, [0, 1, 0, 0])[:, :, :, 1:]
        top = x
        bottom = F.pad(x, [0, 0, 0, 1])[:, :, 1:, :]

        # dx, dy = torch.abs(right - left), torch.abs(bottom - top)
        dx, dy = right - left, bottom - top 
        # dx will always have zeros in the last column, right-left
        # dy will always have zeros in the last row,    bottom-top
        dx[:, :, :, -1] = 0
        dy[:, :, -1, :] = 0

        return dx, dy

    # gradient
    gen_dx, gen_dy = gradient(gen_frames)
    gt_dx, gt_dy = gradient(gt_frames)
    #
    grad_diff_x = torch.abs(gt_dx - gen_dx)
    grad_diff_y = torch.abs(gt_dy - gen_dy)

    # condense into one tensor and avg
    return torch.mean(grad_diff_x ** alpha + grad_diff_y ** alpha)

def difNormalize(input_matrix, threshold=None):
    _min = torch.min(input_matrix)
    _max = torch.max(input_matrix)
    
    input_matrix = (input_matrix - _min) / (_max - _min)
    
    if threshold != None:
        input_matrix[input_matrix < threshold] = 0
        input_matrix[input_matrix >= threshold] = 1
        
    return input_matrix

In [None]:
for epoch in range(num_epochs): 
    start = time.time()
    ######## GAN ################
    one = torch.FloatTensor([1])
    mone = one * -1
    
    one = one.cuda()
    mone = mone.cuda()
    
    one = one.mean()
    mone = mone.mean()
    ## ==== GAN --> D =====
    for i in range(3):
        for index, img in enumerate(train_loader):
            AE.eval(), S.train(), D.train()

            img = Variable(img).cuda()

            # ====== AE ======
            blur_image = AE(img)

            _bs, _c, _w, _h = blur_image.shape
            noise = torch.zeros(_bs, 1, _w, _h )
            noise = noise + (0.01**0.5)*torch.randn(_bs, 1, _w, _h)
            noise = noise.cuda()

            blur_image_with_noise = torch.cat([blur_image, noise], 1)
            fake_image = S(blur_image_with_noise) # 當成是 residual
            
            fake_image = fake_image + blur_image # blur image + residual
            
            fake_pair = torch.cat([img, fake_image], 1)
            real_pair = torch.cat([img, img[torch.randperm(img.size(0)), :, :, :]], 1) if UPSET else torch.cat([img, img], 1)
            # ====== Train D ======
            for p in D.parameters():
                p.requires_grad = True

            optimizer_AE.zero_grad()
            optimizer_S.zero_grad()
            optimizer_D.zero_grad()


            real_D = D(real_pair)
            real_D = real_D.mean()
            real_D.backward(mone)


            fake_D = D(fake_pair)
            fake_D = fake_D.mean()
            fake_D.backward(one)

            gradient_penalty = calc_gradient_penalty(D, real_pair, fake_pair)
            gradient_penalty.backward()

            cost_D = fake_D - real_D + gradient_penalty
            Wasserstein_D = real_D - fake_D
            optimizer_D.step()
    
    ## ==== GAN --> G =====
    for index, img in enumerate(train_loader):
        AE.eval(), S.train(), D.train()

        img = Variable(img).cuda()
        # ======AE======
        blur_image = AE(img)

        _bs, _c, _w, _h = blur_image.shape
        noise = torch.zeros(_bs, 1, _w, _h )
        noise = noise + (0.01**0.5)*torch.randn(_bs, 1, _w, _h)
        noise = noise.cuda()

        blur_image_with_noise = torch.cat([blur_image, noise], 1)
        fake_image = S(blur_image_with_noise)
        
        fake_image = fake_image + blur_image
        
        fake_pair = torch.cat([img, fake_image], 1)
        # ====== Train G ======
        for p in D.parameters():
            p.requires_grad = False
        
        G_L1 = L1_loss(img, fake_image)
        grad_loss = gradient_loss(img, fake_image)
        G = D(fake_pair)
        G = G.mean()
        G = G + G_L1 * ( G / G_L1) + grad_loss
        
        optimizer_AE.zero_grad()
        optimizer_S.zero_grad()
        optimizer_D.zero_grad()
        
        G.backward(mone)
        
        cost_G = -G
        optimizer_S.step()
        
    
    # validation set
    for index, val_img in enumerate(val_loader):
        AE.eval(), S.eval(), D.eval()

        val_img = Variable(val_img).cuda()
        # ======AE======
        val_blur_image = AE(val_img)

        _bs, _c, _w, _h = val_blur_image.shape
        noise = torch.zeros(_bs, 1, _w, _h )
        noise = noise + (0.01**0.5)*torch.randn(_bs, 1, _w, _h)
        noise = noise.cuda()

        val_blur_image_with_noise = torch.cat([val_blur_image, noise], 1)

        val_fake_image = S(val_blur_image_with_noise)       
        val_fake_image = val_fake_image + val_blur_image
        
        
        val_fake_pair = torch.cat([val_img, val_fake_image], 1)
        val_real_pair = torch.cat([val_img, val_img[torch.randperm(val_img.size(0)), :, :, :]], 1) if UPSET else torch.cat([val_img, val_img], 1)
        
        val_real_D = D(val_real_pair)
        val_real_D = val_real_D.mean()
        
        val_fake_D = D(val_fake_pair)
        val_fake_D = val_fake_D.mean()
        
        val_gradient_penalty = calc_gradient_penalty(D, val_real_pair, val_fake_pair)
        
        val_G_L1 = L1_loss(val_img, val_fake_image)
        val_grad_loss = gradient_loss(val_img, val_fake_image)
        # =========== Losses =========
        val_Wasserstein_D = val_real_D - val_fake_D
        
        val_cost_G = -val_fake_D
        val_cost_D = val_fake_D - val_real_D + val_gradient_penalty
    
    # evaluate
    test_total_AUC = 0
    test_total_AUC2 = 0
    test_total_image = 0

    for index, (test_img, mask) in enumerate(test_loader):
        AE.eval(), S.eval(), D.eval()
        test_img = Variable(test_img).cuda()
        test_blur_image = AE(test_img)

        _bs, _c, _w, _h = test_blur_image.shape
        noise = torch.zeros(_bs, 1, _w, _h )
        noise = noise + (0.01**0.5)*torch.randn(_bs, 1, _w, _h)
        noise = noise.cuda()

        test_blur_image_with_noise = torch.cat([test_blur_image, noise], 1)

        test_fake_image = S(test_blur_image_with_noise)       
        test_fake_image = test_fake_image + test_blur_image

        # 計算 dif (相似度以及 L2)
        dif, _ = perceptual_loss.forward(test_fake_image, test_img)
        l2Dif = L2_loss(test_fake_image, test_img)
        l2Dif = torch.mean(l2Dif, 1, True)
        
        pred_mask2 = difNormalize(dif)
        pred_mask2 = torch.flatten(pred_mask2[0])
        
        pred_mask = difNormalize(dif[0] * l2Dif[0])
        pred_mask = torch.flatten(pred_mask)
        
        mask = torch.mean(mask, 1, True)
        true_mask = mask[0].cpu().detach().numpy().flatten()
        true_mask = true_mask.astype(int)

        AUC = roc_auc_score(true_mask, pred_mask.cpu().detach().numpy())
        AUC2 = roc_auc_score(true_mask, pred_mask2.cpu().detach().numpy())

        test_total_AUC += AUC
        test_total_AUC2 += AUC2
        test_total_image += 1
    
    # =================== GAN log========================
    end = time.time()
    print('epoch [{}/{}] s_loss:{:.4f} d_loss:{:.4f} val_s_loss:{:.4f} val_d_loss:{:.4f} cost:{:.2f}'.format(epoch+1, num_epochs, cost_G.item(), cost_D.item(), val_cost_G.item(), val_cost_D.item(), end-start ))
    writer.add_scalars('eval', {
        "auc_roc_score": test_total_AUC / test_total_image,
        "auc_roc_score(w/o L2)": test_total_AUC2 / test_total_image,
    }, epoch)
    
    writer.add_scalars('loss', {
        "Wasserstein Distance": Wasserstein_D.item(),
        "Val Wasserstein Distance": val_Wasserstein_D.item(),
        "gradient penalty": gradient_penalty,
        "val gradient penalty": val_gradient_penalty
    }, epoch)
    
    writer.add_scalars('gan loss', {
        "l1_loss": G_L1.item(),
        "g_loss": cost_G.item(),
        "d_loss": cost_D.item(),
        "gradient_loss": grad_loss.item(),
        "val_l1_loss": val_G_L1.item(),
        "val_g_loss": val_cost_G.item(),
        "val_d_loss": val_cost_D.item(),
        "val_gradient_loss": val_grad_loss.item(),
    }, epoch)

    writer.add_images('Blur', blur_image, epoch)
    writer.add_images('Reconstruct', fake_image, epoch)
    writer.add_images('Origin', img, epoch)

    writer.add_images('Val Blur', val_blur_image, epoch)
    writer.add_images('Val Reconstruct', val_fake_image, epoch)
    writer.add_images('Val Origin', val_img, epoch)


        
    if epoch % 10 == 0:
        if not os.path.exists('./save_weight/{}'.format(expName)):
            os.makedirs('./save_weight/{}'.format(expName))
        torch.save(S.state_dict(), './save_weight/{}/S_{}.npy'.format(expName, epoch))
        torch.save(D.state_dict(), './save_weight/{}/D_{}.npy'.format(expName, epoch))

epoch [1/50000] s_loss:1912.1523 d_loss:-278.1466 val_s_loss:1213.3286 val_d_loss:-225.7886 cost:26.74
epoch [2/50000] s_loss:1849.9199 d_loss:-181.2305 val_s_loss:795.5139 val_d_loss:59.6611 cost:22.13
epoch [3/50000] s_loss:1016.2395 d_loss:-187.1540 val_s_loss:545.4919 val_d_loss:-85.8004 cost:21.34
epoch [4/50000] s_loss:1561.7660 d_loss:-222.8959 val_s_loss:690.3129 val_d_loss:-88.0677 cost:21.55
epoch [5/50000] s_loss:1429.7407 d_loss:-333.0352 val_s_loss:735.8066 val_d_loss:-208.0760 cost:22.57
epoch [6/50000] s_loss:1331.9973 d_loss:-287.2206 val_s_loss:1054.0688 val_d_loss:-572.6839 cost:25.44
epoch [7/50000] s_loss:814.6937 d_loss:-209.0937 val_s_loss:1078.2573 val_d_loss:-625.2091 cost:23.04
epoch [8/50000] s_loss:859.4924 d_loss:-127.0710 val_s_loss:437.9287 val_d_loss:-43.8189 cost:24.72
epoch [9/50000] s_loss:742.3251 d_loss:-111.7458 val_s_loss:341.1018 val_d_loss:-36.7675 cost:23.39
epoch [10/50000] s_loss:593.0656 d_loss:-60.3526 val_s_loss:236.5579 val_d_loss:-31.5240

In [None]:
AE.eval()
S.eval()
for index, img in enumerate(test_loader):
    test_img = Variable(img[0]).cuda()

    # ======AE======
    blur_image = AE(test_img)
    
    noise = torch.zeros(blur_image.shape[0], 1, blur_image.shape[2], blur_image.shape[3] )
    noise = noise + (0.01**0.5)*torch.randn(blur_image.shape[0], 1, blur_image.shape[2], blur_image.shape[3])
    noise = noise.cuda()
    blur_image_with_noise = torch.cat([blur_image, noise], 1)
    fake_image = S(blur_image_with_noise)
    
    
    vutils.save_image(fake_image[0], './test_result/{}_simulated.png'.format(index))
    vutils.save_image(blur_image[0], './test_result/{}_blur.png'.format(index))
    vutils.save_image(test_img, './test_result/{}_origin.png'.format(index))