In [2]:
import sys
sys.path.append('../')

import os
import numpy as np

from torch.utils.data import DataLoader
from torch.autograd import Variable
from torch import autograd

import torch
import torch.nn as nn


from utils.tools import get_config, default_loader, is_image_file, normalize
from tensorboardX import SummaryWriter
import matplotlib.pyplot as plt
import torchvision.utils as vutils

# personal library
from networks import autoencoder, simulator, discriminator
from dataloader import MVTecDataset

In [3]:
# HYPER parameters
device_id = 0
num_epochs = 50000
batch_size = 32
val_batch_size = 4
ae_lr = 1e-4
weight_decay = 1e-5
expName = 'AE-capsule-z-2x2-exp2'
writer = SummaryWriter('checkpoints/'+expName)

In [4]:
trainDatset = MVTecDataset.MVTecDataset(TYPE='capsule', isTrain='train')
testDatset = MVTecDataset.MVTecDataset(TYPE='capsule', isTrain='test')
valDataset = MVTecDataset.MVTecDataset(TYPE='capsule', isTrain='val')

val_loader = DataLoader(
    dataset=valDataset,
    batch_size=val_batch_size, 
    shuffle=True,
    num_workers=4
)
train_loader = DataLoader(
    dataset=trainDatset,
    batch_size=batch_size, 
    shuffle=True,
    num_workers=4
)
test_loader = DataLoader(
    dataset=testDatset,
    batch_size=1,
    shuffle=True,
    num_workers=4
)

In [5]:
AE = autoencoder.Autoencoder().cuda(device_id)
L1_loss = torch.nn.L1Loss()

optimizer_AE = torch.optim.Adam(
    AE.parameters(), 
    lr=ae_lr,
    weight_decay=weight_decay
)

Tensor = torch.cuda.FloatTensor

In [6]:
# To Solve: RuntimeError: cuDNN error: CUDNN_STATUS_EXECUTION_FAILED
torch.backends.cudnn.enabled = False 

In [None]:
for epoch in range(num_epochs):
    ## ===== AE ======= 
    for img in train_loader:
        AE.train()

        img = Variable(img).cuda(device_id)    
        # ======AE======
        blur_image = AE(img)
        ae_loss = L1_loss(blur_image, img)

        optimizer_AE.zero_grad()
        ae_loss.backward()
        optimizer_AE.step()
    
    # Validation set
    for val_img in val_loader:
        AE.eval()

        val_img = Variable(val_img).cuda(device_id)    
        # ======AE======
        val_blur_image = AE(val_img)
        val_ae_loss = L1_loss(val_blur_image, val_img)
    
    for (test_img, mask) in test_loader:
        AE.eval()

        test_img = Variable(test_img).cuda(device_id)    
        # ======AE======
        test_blur_image = AE(test_img)
        test_ae_loss = L1_loss(test_blur_image, test_img)

    # =================== AE log========================
    print('epoch [{}/{}] ae_loss:{:.4f} val_ae_loss:{:.4f} test_ae_loss:{:.4f}'.format(epoch+1, num_epochs, ae_loss.item(), val_ae_loss.item(), test_ae_loss.item()))
    writer.add_scalars('ae_loss', {
        'train': ae_loss.item(),
        'test': test_ae_loss.item(),
        'val': val_ae_loss.item()
    }, epoch)
    
    writer.add_images('Blur', blur_image, epoch)
    writer.add_images('Val Blur', val_blur_image, epoch)
    writer.add_images('Test Blur', test_blur_image, epoch)
    writer.add_images('Origin', img, epoch)
    writer.add_images('Val Origin', val_img, epoch)
    writer.add_images('Test Origin', test_img, epoch)
    
    if epoch % 10 == 0:
        if not os.path.exists('./save_weight/{}'.format(expName)):
            os.makedirs('./save_weight/{}'.format(expName))
        torch.save(AE.state_dict(), './save_weight/{}/AE_{}.npy'.format(expName, epoch))

epoch [1/50000] ae_loss:0.2814 val_ae_loss:0.2768 test_ae_loss:0.2913
epoch [2/50000] ae_loss:0.2732 val_ae_loss:0.2806 test_ae_loss:0.2944
epoch [3/50000] ae_loss:0.2721 val_ae_loss:0.2840 test_ae_loss:0.2899
epoch [4/50000] ae_loss:0.2623 val_ae_loss:0.2761 test_ae_loss:0.2913
epoch [5/50000] ae_loss:0.2631 val_ae_loss:0.2715 test_ae_loss:0.2850
epoch [6/50000] ae_loss:0.2820 val_ae_loss:0.2714 test_ae_loss:0.2752
epoch [7/50000] ae_loss:0.2466 val_ae_loss:0.2443 test_ae_loss:0.2653
epoch [8/50000] ae_loss:0.2471 val_ae_loss:0.2574 test_ae_loss:0.2778
epoch [9/50000] ae_loss:0.2389 val_ae_loss:0.2458 test_ae_loss:0.2442
epoch [10/50000] ae_loss:0.2442 val_ae_loss:0.2382 test_ae_loss:0.2505
epoch [11/50000] ae_loss:0.2380 val_ae_loss:0.2653 test_ae_loss:0.2464
epoch [12/50000] ae_loss:0.2444 val_ae_loss:0.2068 test_ae_loss:0.2487
epoch [13/50000] ae_loss:0.2273 val_ae_loss:0.2296 test_ae_loss:0.2385
epoch [14/50000] ae_loss:0.2269 val_ae_loss:0.2294 test_ae_loss:0.2325
epoch [15/50000

In [17]:
AE.eval()
S.eval()
for index, img in enumerate(test_loader):
    test_img = Variable(img[0]).cuda(device_id)

    # ======AE======
    blur_image = AE(test_img)
    
    noise = torch.zeros(blur_image.shape[0], 1, blur_image.shape[2], blur_image.shape[3] )
    noise = noise + (0.01**0.5)*torch.randn(blur_image.shape[0], 1, blur_image.shape[2], blur_image.shape[3])
    noise = noise.cuda(device_id)
    blur_image_with_noise = torch.cat([blur_image, noise], 1)
    fake_image = S(blur_image_with_noise)
    
    
    vutils.save_image(fake_image[0], './test_result/{}_simulated.png'.format(index))
    vutils.save_image(blur_image[0], './test_result/{}_blur.png'.format(index))
    vutils.save_image(test_img, './test_result/{}_origin.png'.format(index))