In [3]:
import sys
sys.path.append('../')

import os
import numpy as np

from torch.utils.data import DataLoader
from torch.autograd import Variable
from torch import autograd

import torch
import torch.nn as nn


from utils.tools import get_config, default_loader, is_image_file, normalize
from tensorboardX import SummaryWriter
import matplotlib.pyplot as plt
import torchvision.utils as vutils

# personal library
from networks import autoencoder, simulator, discriminator
from dataloader import MVTecDataset

In [4]:
# HYPER parameters
device_id = 3
num_epochs = 50000
batch_size = 32
val_batch_size = 4
ae_lr = 1e-4
weight_decay = 1e-5
expName = 'AE-cable-z-2x2-exp1'
writer = SummaryWriter('checkpoints/'+expName)

In [5]:
trainDatset = MVTecDataset.MVTecDataset(TYPE='cable', isTrain='train')
testDatset = MVTecDataset.MVTecDataset(TYPE='cable', isTrain='test')
valDataset = MVTecDataset.MVTecDataset(TYPE='cable', isTrain='val')

val_loader = DataLoader(
    dataset=valDataset,
    batch_size=val_batch_size, 
    shuffle=True,
    num_workers=4
)
train_loader = DataLoader(
    dataset=trainDatset,
    batch_size=batch_size, 
    shuffle=True,
    num_workers=4
)
test_loader = DataLoader(
    dataset=testDatset,
    batch_size=1,
    shuffle=True,
    num_workers=4
)

In [6]:
AE = autoencoder.Autoencoder().cuda(device_id)
L1_loss = torch.nn.L1Loss()

optimizer_AE = torch.optim.Adam(
    AE.parameters(), 
    lr=ae_lr,
    weight_decay=weight_decay
)

Tensor = torch.cuda.FloatTensor

In [7]:
# To Solve: RuntimeError: cuDNN error: CUDNN_STATUS_EXECUTION_FAILED
torch.backends.cudnn.enabled = False 

In [None]:
for epoch in range(num_epochs):
    ## ===== AE ======= 
    for img in train_loader:
        AE.train()

        img = Variable(img).cuda(device_id)    
        # ======AE======
        blur_image = AE(img)
        ae_loss = L1_loss(blur_image, img)

        optimizer_AE.zero_grad()
        ae_loss.backward()
        optimizer_AE.step()
    
    # Validation set
    for val_img in val_loader:
        AE.eval()

        val_img = Variable(val_img).cuda(device_id)    
        # ======AE======
        val_blur_image = AE(val_img)
        val_ae_loss = L1_loss(val_blur_image, val_img)
    
    for (test_img, mask) in test_loader:
        AE.eval()

        test_img = Variable(test_img).cuda(device_id)    
        # ======AE======
        test_blur_image = AE(test_img)
        test_ae_loss = L1_loss(test_blur_image, test_img)

    # =================== AE log========================
    print('epoch [{}/{}] ae_loss:{:.4f} val_ae_loss:{:.4f} test_ae_loss:{:.4f}'.format(epoch+1, num_epochs, ae_loss.item(), val_ae_loss.item(), test_ae_loss.item()))
    writer.add_scalars('ae_loss', {
        'train': ae_loss.item(),
        'test': test_ae_loss.item(),
        'val': val_ae_loss.item()
    }, epoch)
    
    writer.add_images('Blur', blur_image, epoch)
    writer.add_images('Val Blur', val_blur_image, epoch)
    writer.add_images('Test Blur', test_blur_image, epoch)
    writer.add_images('Origin', img, epoch)
    writer.add_images('Val Origin', val_img, epoch)
    writer.add_images('Test Origin', test_img, epoch)
    
    if epoch % 10 == 0:
        if not os.path.exists('./save_weight/{}'.format(expName)):
            os.makedirs('./save_weight/{}'.format(expName))
        torch.save(AE.state_dict(), './save_weight/{}/AE_{}.npy'.format(expName, epoch))

epoch [1/50000] ae_loss:0.2332 val_ae_loss:0.1996 test_ae_loss:0.2190
epoch [2/50000] ae_loss:0.2297 val_ae_loss:0.2018 test_ae_loss:0.1969
epoch [3/50000] ae_loss:0.2138 val_ae_loss:0.2001 test_ae_loss:0.1952
epoch [4/50000] ae_loss:0.2113 val_ae_loss:0.2146 test_ae_loss:0.1927
epoch [5/50000] ae_loss:0.2197 val_ae_loss:0.2079 test_ae_loss:0.2054
epoch [6/50000] ae_loss:0.2160 val_ae_loss:0.1929 test_ae_loss:0.1982
epoch [7/50000] ae_loss:0.2151 val_ae_loss:0.2139 test_ae_loss:0.1967
epoch [8/50000] ae_loss:0.2047 val_ae_loss:0.2058 test_ae_loss:0.1994
epoch [9/50000] ae_loss:0.2044 val_ae_loss:0.2073 test_ae_loss:0.1998
epoch [10/50000] ae_loss:0.2036 val_ae_loss:0.2090 test_ae_loss:0.1994
epoch [11/50000] ae_loss:0.2028 val_ae_loss:0.2050 test_ae_loss:0.1861
epoch [12/50000] ae_loss:0.2071 val_ae_loss:0.1889 test_ae_loss:0.1928
epoch [13/50000] ae_loss:0.2063 val_ae_loss:0.2011 test_ae_loss:0.1979
epoch [14/50000] ae_loss:0.2145 val_ae_loss:0.1938 test_ae_loss:0.2082
epoch [15/50000

In [17]:
AE.eval()
S.eval()
for index, img in enumerate(test_loader):
    test_img = Variable(img[0]).cuda(device_id)

    # ======AE======
    blur_image = AE(test_img)
    
    noise = torch.zeros(blur_image.shape[0], 1, blur_image.shape[2], blur_image.shape[3] )
    noise = noise + (0.01**0.5)*torch.randn(blur_image.shape[0], 1, blur_image.shape[2], blur_image.shape[3])
    noise = noise.cuda(device_id)
    blur_image_with_noise = torch.cat([blur_image, noise], 1)
    fake_image = S(blur_image_with_noise)
    
    
    vutils.save_image(fake_image[0], './test_result/{}_simulated.png'.format(index))
    vutils.save_image(blur_image[0], './test_result/{}_blur.png'.format(index))
    vutils.save_image(test_img, './test_result/{}_origin.png'.format(index))