In [1]:
import sys
sys.path.append('../')

import os
import numpy as np

from torch.utils.data import DataLoader
from torch.autograd import Variable
from torch import autograd

import torch
import torch.nn as nn


from utils.tools import get_config, default_loader, is_image_file, normalize
from tensorboardX import SummaryWriter
import matplotlib.pyplot as plt
import torchvision.utils as vutils

# personal library
from networks import autoencoder, simulator, discriminator
from dataloader import MVTecDataset

In [2]:
# HYPER parameters
device_id = 1
num_epochs = 50000
batch_size = 32
val_batch_size = 4
ae_lr = 1e-4
weight_decay = 1e-5
expName = 'AE-wood-z-2x2-exp2'
writer = SummaryWriter('checkpoints/'+expName)

In [3]:
trainDatset = MVTecDataset.MVTecDataset(TYPE='wood', isTrain='train')
testDatset = MVTecDataset.MVTecDataset(TYPE='wood', isTrain='test')
valDataset = MVTecDataset.MVTecDataset(TYPE='wood', isTrain='val')

val_loader = DataLoader(
    dataset=valDataset,
    batch_size=val_batch_size, 
    shuffle=True,
    num_workers=4
)
train_loader = DataLoader(
    dataset=trainDatset,
    batch_size=batch_size, 
    shuffle=True,
    num_workers=4
)
test_loader = DataLoader(
    dataset=testDatset,
    batch_size=1,
    shuffle=True,
    num_workers=4
)

In [4]:
AE = autoencoder.Autoencoder().cuda(device_id)
L1_loss = torch.nn.L1Loss()

optimizer_AE = torch.optim.Adam(
    AE.parameters(), 
    lr=ae_lr,
    weight_decay=weight_decay
)

Tensor = torch.cuda.FloatTensor

In [5]:
# To Solve: RuntimeError: cuDNN error: CUDNN_STATUS_EXECUTION_FAILED
torch.backends.cudnn.enabled = False 

In [6]:
for epoch in range(num_epochs):
    ## ===== AE ======= 
    for img in train_loader:
        AE.train()

        img = Variable(img).cuda(device_id)    
        # ======AE======
        blur_image = AE(img)
        ae_loss = L1_loss(blur_image, img)

        optimizer_AE.zero_grad()
        ae_loss.backward()
        optimizer_AE.step()
    
    # Validation set
    for val_img in val_loader:
        AE.eval()

        val_img = Variable(val_img).cuda(device_id)    
        # ======AE======
        val_blur_image = AE(val_img)
        val_ae_loss = L1_loss(val_blur_image, val_img)
    
    for (test_img, mask) in test_loader:
        AE.eval()

        test_img = Variable(test_img).cuda(device_id)    
        # ======AE======
        test_blur_image = AE(test_img)
        test_ae_loss = L1_loss(test_blur_image, test_img)

    # =================== AE log========================
    print('epoch [{}/{}] ae_loss:{:.4f} val_ae_loss:{:.4f} test_ae_loss:{:.4f}'.format(epoch+1, num_epochs, ae_loss.item(), val_ae_loss.item(), test_ae_loss.item()))
    writer.add_scalars('ae_loss', {
        'train': ae_loss.item(),
        'test': test_ae_loss.item(),
        'val': val_ae_loss.item()
    }, epoch)
    
    writer.add_images('Blur', blur_image, epoch)
    writer.add_images('Val Blur', val_blur_image, epoch)
    writer.add_images('Test Blur', test_blur_image, epoch)
    writer.add_images('Origin', img, epoch)
    writer.add_images('Val Origin', val_img, epoch)
    writer.add_images('Test Origin', test_img, epoch)
    
    if epoch % 10 == 0:
        if not os.path.exists('./save_weight/{}'.format(expName)):
            os.makedirs('./save_weight/{}'.format(expName))
        torch.save(AE.state_dict(), './save_weight/{}/AE_{}.npy'.format(expName, epoch))

epoch [1/50000] ae_loss:0.1578 val_ae_loss:0.1346 test_ae_loss:0.1508
epoch [2/50000] ae_loss:0.1530 val_ae_loss:0.1346 test_ae_loss:0.1268
epoch [3/50000] ae_loss:0.1459 val_ae_loss:0.1326 test_ae_loss:0.1356
epoch [4/50000] ae_loss:0.1447 val_ae_loss:0.1339 test_ae_loss:0.1365
epoch [5/50000] ae_loss:0.1422 val_ae_loss:0.1245 test_ae_loss:0.1271
epoch [6/50000] ae_loss:0.1416 val_ae_loss:0.1263 test_ae_loss:0.1280
epoch [7/50000] ae_loss:0.1390 val_ae_loss:0.1292 test_ae_loss:0.1317
epoch [8/50000] ae_loss:0.1379 val_ae_loss:0.1269 test_ae_loss:0.1408
epoch [9/50000] ae_loss:0.1370 val_ae_loss:0.1216 test_ae_loss:0.1326
epoch [10/50000] ae_loss:0.1348 val_ae_loss:0.1228 test_ae_loss:0.1219
epoch [11/50000] ae_loss:0.1341 val_ae_loss:0.1287 test_ae_loss:0.1905
epoch [12/50000] ae_loss:0.1360 val_ae_loss:0.1356 test_ae_loss:0.1219
epoch [13/50000] ae_loss:0.1324 val_ae_loss:0.1199 test_ae_loss:0.1212
epoch [14/50000] ae_loss:0.1319 val_ae_loss:0.1195 test_ae_loss:0.1441
epoch [15/50000

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3343, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-6-942a09966249>", line 12, in <module>
    ae_loss.backward()
  File "/opt/conda/lib/python3.7/site-packages/torch/tensor.py", line 185, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/opt/conda/lib/python3.7/site-packages/torch/autograd/__init__.py", line 127, in backward
    allow_unreachable=True)  # allow_unreachable flag
KeyboardInterrupt

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 2044, in showtraceback
    stb = value._render_traceback_()
AttributeError: 'KeyboardInterrupt' object has no attribute '_render_traceback_'

During handling of the above exception, another exception 

TypeError: object of type 'NoneType' has no len()

In [17]:
AE.eval()
S.eval()
for index, img in enumerate(test_loader):
    test_img = Variable(img[0]).cuda(device_id)

    # ======AE======
    blur_image = AE(test_img)
    
    noise = torch.zeros(blur_image.shape[0], 1, blur_image.shape[2], blur_image.shape[3] )
    noise = noise + (0.01**0.5)*torch.randn(blur_image.shape[0], 1, blur_image.shape[2], blur_image.shape[3])
    noise = noise.cuda(device_id)
    blur_image_with_noise = torch.cat([blur_image, noise], 1)
    fake_image = S(blur_image_with_noise)
    
    
    vutils.save_image(fake_image[0], './test_result/{}_simulated.png'.format(index))
    vutils.save_image(blur_image[0], './test_result/{}_blur.png'.format(index))
    vutils.save_image(test_img, './test_result/{}_origin.png'.format(index))