In [None]:
import torch
import torch.nn as nn
from torchvision import datasets
import torch.utils.data as data
import torchvision.transforms as transforms
from torchvision import models
import torch.nn.functional as F
from torch.nn import init
import torchvision as tv
from torchsummary import summary
from torch.autograd import Variable
from pytorch_msssim import SSIM
from torch.optim import lr_scheduler


import datetime
import time
from tqdm import tqdm 

import numpy as np
from PIL import Image
import os
import datetime
import time
import pandas as pd


import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt 
# Prefer the second GPU (cuda:1) as before, but fall back to cuda:0 on
# single-GPU machines, where "cuda:1" is an invalid device ordinal and every
# later `.to(device)` would fail.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

In [None]:
class Config(object):
    """Experiment settings: run name, data/checkpoint paths, loader and optimiser options.

    Constructing a Config also creates the checkpoint directory tree rooted at
    ``./checkpoint/<name>`` (models, decoded/val/test result folders); the
    ``exist_ok=True`` calls make repeated construction harmless.
    """

    def __init__(self):
        self.name = 'vae_kl_100'
        self.dataset_name = 'mimic'
        self.dataroot = '../DATASET/images'

        # Output locations, all under ./checkpoint/<name>
        self.save_path = './checkpoint/' + self.name
        self.model_path = self.save_path + '/models'
        self.decode_path = self.save_path + '/decoded_results'
        self.val_path = self.save_path + '/val_results'
        self.test_path = self.save_path + '/test_results'

        self.num_threads = 8        # DataLoader worker processes
        self.shuffle_dataset = True
        self.random_seed = 24       # seeds numpy/torch once `opt` is created below

        self.lr = 0.00003

        self.serial_batches = False
        self.phase = 'train'

        self.train_batch_size = 6
        self.val_batch_size = 6
        self.test_batch_size = 1
        self.max_epochs = 500
        self.save_every = 1     # epochs between best-validation checkpoint checks
        self.plot_every = 1     # epochs between saving decoded training images

        for path in (self.save_path, self.model_path, self.decode_path,
                     self.val_path, self.test_path):
            os.makedirs(path, exist_ok=True)


opt = Config()

# random_seed was declared but never applied. Seed numpy and torch (all
# devices) so weight init, reparameterisation noise and DataLoader shuffling
# repeat across runs.
np.random.seed(opt.random_seed)
torch.manual_seed(opt.random_seed)

In [None]:
# Preprocessing per split. The model expects single-channel 224x224 inputs:
# the Encoder's first conv takes 1 channel and four 2x max-pools reduce 224 to
# the 14x14 its linear heads are sized for. Images stay in [0, 1] as produced
# by ToTensor(), because the decoder ends in a ReLU (non-negative output) and
# SSIM is configured with data_range=1.0.
#
# The previous version did not parse: Normalize was missing its std argument
# and closing paren, and a comma was missing in 'val'. It also stacked
# per-crop tensors although no TenCrop/FiveCrop produced crops, and used a
# 3-channel mean on 1-channel images. To normalise, compute single-channel
# mean/std (see the commented statistics cells) and make the decoder's output
# range match.
transform = {
    'train': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.Grayscale(num_output_channels=1),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]),
    # Used for both the validation and the test split (no augmentation).
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
    ]),
}
                                      
class Data(data.Dataset):
    """Image/label dataset serving one split ('train', 'val' or 'test').

    All three split CSVs are read on construction. ``phase`` selects which
    split ``__len__`` and ``__getitem__`` serve. Each CSV is indexed by its
    ``id`` column (joined to ``opt.dataroot`` to form the image path) and
    provides a ``Labels`` column.

    ``transform`` is expected to be a dict with 'train' and 'val' pipelines;
    the test split uses the 'val' pipeline.
    """

    def __init__(self, transform=None, phase=None):
        self.transform = transform
        self.phase = phase

        train_df = pd.read_csv("../DATASET/full_chex/train.csv")
        val_df = pd.read_csv("../DATASET/full_chex/val.csv")
        test_df = pd.read_csv("../DATASET/full_chex/test.csv")

        # id -> label maps; the image lists are the keys in CSV order.
        self.train_labels = train_df.set_index('id')['Labels'].to_dict()
        self.val_labels = val_df.set_index('id')['Labels'].to_dict()
        self.test_labels = test_df.set_index('id')['Labels'].to_dict()

        self.train_images = list(self.train_labels.keys())
        self.val_images = list(self.val_labels.keys())
        self.test_images = list(self.test_labels.keys())

    def _split(self):
        """Return ``(image ids, id->label map, transform key)`` for ``self.phase``.

        Raises:
            ValueError: if ``phase`` is not 'train', 'val' or 'test'.
        """
        if self.phase == 'train':
            return self.train_images, self.train_labels, 'train'
        if self.phase == 'val':
            return self.val_images, self.val_labels, 'val'
        if self.phase == 'test':
            return self.test_images, self.test_labels, 'val'
        raise ValueError(f"unknown phase {self.phase!r}; expected 'train', 'val' or 'test'")

    def __len__(self):
        images, _, _ = self._split()
        return len(images)

    def __getitem__(self, index):
        """Return ``{'img': transformed image, 'label': LongTensor([label])}``."""
        images, labels, transform_key = self._split()
        file = images[index]
        path = os.path.join(opt.dataroot, file)
        # Close the file handle promptly; with several DataLoader workers the
        # unclosed handles would otherwise accumulate.
        with Image.open(path) as img:
            if self.transform is not None:
                image = self.transform[transform_key](img)
            else:
                image = img.copy()  # load pixel data before the file is closed
        # `file` is itself a key of `labels`, so look it up directly. The old
        # substring scan was O(n) per item and could return the label of a
        # different id that merely contained this filename.
        label = torch.LongTensor([labels[file]])
        return {'img': image, 'label': label}


    
def _build_split(phase, batch_size, shuffle, label):
    """Point ``opt.phase`` at ``phase``, build its Data split plus DataLoader, and report the size."""
    opt.phase = phase
    dataset = Data(transform, opt.phase)
    loader = data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=opt.num_threads,
        shuffle=shuffle,
        pin_memory=False,
    )
    print(f'{len(dataset)} images loaded under {label}')
    return dataset, loader


image_dataset, train_loader = _build_split('train', opt.train_batch_size, True, 'train')
val_dataset, val_loader = _build_split('val', opt.val_batch_size, True, 'Validation')
test_dataset, test_loader = _build_split('test', opt.test_batch_size, False, 'test')

In [None]:
# mean = 0.
# std = 0.
# nb_samples = 0.
# for data_ in tqdm(train_loader):
#     images = data_['img']
#     batch_samples = images.size(0)
#     images = images.view(batch_samples, images.size(1), -1)
    
#     mean += images.mean(2).sum(0)
#     std += images.std(2).sum(0)
#     nb_samples += batch_samples

# mean /= nb_samples
# std /= nb_samples
# print(mean,std)

In [None]:
# mean_val = 0.
# std_val = 0.
# nb_samples = 0.
# for data_ in tqdm(val_loader):
#     images = data_['img']
#     batch_samples = images.size(0)
#     images = images.view(batch_samples, images.size(1), -1)
    
#     mean_val += images.mean(2).sum(0)
#     std_val += images.std(2).sum(0)
#     nb_samples += batch_samples

# mean_val /= nb_samples
# std_val /= nb_samples
# print(mean_val,std_val)

In [None]:
# mean_test = 0.
# std_test = 0.
# nb_samples = 0.
# for data_ in tqdm(test_loader):
#     images = data_['img']
#     batch_samples = images.size(0)
#     images = images.view(batch_samples, images.size(1), -1)
    
#     mean_test += images.mean(2).sum(0)
#     std_test += images.std(2).sum(0)
#     nb_samples += batch_samples

# mean_test /= nb_samples
# std_test /= nb_samples
# print(mean_test,std_test)

In [None]:
def down_pooling():
    """Build a fresh 2x2 max-pool (stride defaults to the kernel size, halving H and W)."""
    pool = nn.MaxPool2d(kernel_size=2)
    return pool

class Encoder(nn.Module):
    """Convolutional VAE encoder.

    Five double-conv stages (64 -> 128 -> 256 -> 512 -> 1024 channels) with a
    2x max-pool after each of the first four. The flattened 1024x14x14 feature
    map feeds two linear heads producing a 256-d mean and log-variance. A
    1-channel input whose spatial size pools down to 14x14 (i.e. 224x224) is
    therefore required.

    ``forward`` returns ``(z, mu, logvar)`` where ``z`` is a reparameterised
    sample.
    """

    @staticmethod
    def _conv_bn_act(in_ch, out_ch, activation):
        """3x3 same-padding conv -> BatchNorm2d -> the given activation."""
        return nn.Sequential(
            nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_ch),
            activation,
        )

    def __init__(self):
        super(Encoder, self).__init__()
        # NOTE(review): n_z is never read; the latent heads below are 256-d.
        self.n_z = 512

        make = self._conv_bn_act
        # Registration order matches the state_dict keys of saved checkpoints.
        self.conv1 = make(1, 64, nn.ReLU(inplace=True))
        self.conv11 = make(64, 64, nn.ReLU(inplace=True))
        self.conv2 = make(64, 128, nn.ReLU(inplace=True))
        self.conv22 = make(128, 128, nn.ReLU(inplace=True))
        # NOTE(review): conv3 alone uses LeakyReLU (all other stages use ReLU).
        self.conv3 = make(128, 256, nn.LeakyReLU(inplace=True))
        self.conv33 = make(256, 256, nn.ReLU(inplace=True))
        self.conv4 = make(256, 512, nn.ReLU(inplace=True))
        self.conv44 = make(512, 512, nn.ReLU(inplace=True))
        self.conv5 = make(512, 1024, nn.ReLU(inplace=True))
        self.conv55 = make(1024, 1024, nn.ReLU(inplace=True))
        self.down_pooling = nn.MaxPool2d(2)

        flat_features = 1024 * 14 * 14
        self.mu = nn.Linear(flat_features, 256)
        self.logvar = nn.Linear(flat_features, 256)

    def reparameterize(self, mu, log_var):
        """Draw ``mu + eps * exp(log_var / 2)`` with ``eps ~ N(0, I)`` shaped like ``mu``'s std.

        :param mu: latent mean from the encoder
        :param log_var: latent log-variance from the encoder
        """
        std = (0.5 * log_var).exp()
        noise = torch.randn_like(std)
        return mu + noise * std

    def forward(self, x):
        pooled_stages = (
            (self.conv1, self.conv11),
            (self.conv2, self.conv22),
            (self.conv3, self.conv33),
            (self.conv4, self.conv44),
        )
        for first, second in pooled_stages:
            x = self.down_pooling(second(first(x)))
        x = self.conv55(self.conv5(x))

        flat = x.view(x.size(0), -1)
        mu = self.mu(flat)
        logvar = self.logvar(flat)
        return self.reparameterize(mu, logvar), mu, logvar




In [None]:
def up_pooling(in_channels, out_channels, kernel_size=2, stride=2):
    """Learned upsampling unit: ConvTranspose2d (default 2x2, stride 2) -> BatchNorm2d -> ReLU."""
    layers = [
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)
def conv_bn_leru(in_channels, out_channels, kernel_size=3, stride=1, padding=1):
    """Two stacked Conv2d -> BatchNorm2d -> ReLU units.

    The first conv maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``. Despite the name, the activations are plain ReLU.
    """
    layers = []
    for conv_in in (in_channels, out_channels):
        layers.extend([
            nn.Conv2d(conv_in, out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ])
    return nn.Sequential(*layers)

class Decoder(nn.Module):
    """VAE decoder: 256-d latent -> 1-channel image.

    A linear layer expands the latent to a 1024x14x14 map. Four upsample
    stages (x2 each: 14 -> 28 -> 56 -> 112 -> 224) halve the channels down to
    64. A 1x1 conv then gives 1 channel, followed by a final ReLU, so outputs
    are non-negative.
    """

    _STAGES = (6, 7, 8, 9)

    def __init__(self):
        super(Decoder, self).__init__()

        self.fc = nn.Sequential(nn.Linear(256, 1024 * 14 * 14), nn.ReLU())
        # Registers up_pool6, conv6, ..., up_pool9, conv9 in that order so the
        # state_dict keys stay compatible with saved checkpoints.
        channel_plan = ((1024, 512), (512, 256), (256, 128), (128, 64))
        for stage, (cin, cout) in zip(self._STAGES, channel_plan):
            setattr(self, f'up_pool{stage}', up_pooling(cin, cin))
            setattr(self, f'conv{stage}', conv_bn_leru(cin, cout))

        self.conv10 = nn.Conv2d(64, 1, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        h = self.fc(x).view(x.size(0), 1024, 14, 14)
        for stage in self._STAGES:
            upsample = getattr(self, f'up_pool{stage}')
            refine = getattr(self, f'conv{stage}')
            h = refine(upsample(h))
        return self.relu(self.conv10(h))


In [None]:
class Sobel(nn.Module):
    """Edge-aware loss built from fixed 3x3 Sobel kernels.

    For each input it computes the squared gradient magnitude
    ``Gx^2 + Gy^2`` (1-channel, zero padding, same spatial size). It returns
    the mean squared difference between those maps for ``x`` and ``target``.
    The kernels are stored as ``nn.Parameter``s; freeze them if the module's
    parameters could reach an optimiser.
    """

    def __init__(self):
        super(Sobel, self).__init__()
        self.x_filter = np.array([[1, 0, -1], [2, 0, -2], [1, 0, -1]])
        # The vertical Sobel kernel is the transpose of the horizontal one.
        self.y_filter = np.ascontiguousarray(self.x_filter.T)

        self.convx = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1, bias=False)
        self.convy = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1, bias=False)

        def as_kernel(arr):
            # (3, 3) integer array -> float (out=1, in=1, 3, 3) conv weight
            return torch.from_numpy(arr).float()[None, None]

        self.weights_x = as_kernel(self.x_filter)
        self.weights_y = as_kernel(self.y_filter)
        self.convx.weight = nn.Parameter(self.weights_x)
        self.convy.weight = nn.Parameter(self.weights_y)

    def forward(self, x, target):
        def grad_energy(img):
            return self.convx(img).pow(2) + self.convy(img).pow(2)

        gap = grad_energy(x) - grad_energy(target)
        return gap.pow(2).mean()

In [None]:
# def weight_init(m):
#     if isinstance(m, nn.Conv2d):
#         init.xavier_normal_(m.weight)
#         init.constant_(m.bias, 0)
#     if isinstance(m, nn.Linear):
#         init.xavier_normal_(m.weight)
#         init.constant_(m.bias, 0)

In [None]:
def weight_init(m):
    """DCGAN-style initialisation, intended for ``module.apply(weight_init)``.

    - ``nn.Conv2d`` / ``nn.Linear``: weight ~ N(0, 0.02), bias = 0
    - ``nn.BatchNorm2d``: weight ~ N(1, 0.02), bias = 0

    A missing bias (``bias=False``) or missing affine weight is skipped instead
    of raising. Other module types are left untouched; note that
    ``nn.ConvTranspose2d`` is not a subclass of ``nn.Conv2d`` and so keeps its
    default initialisation.
    """
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        weight_mean = 0.0
    elif isinstance(m, nn.BatchNorm2d):
        weight_mean = 1.0
    else:
        return
    # nn.init functions run under no_grad, so the parameters can be passed directly.
    if m.weight is not None:
        nn.init.normal_(m.weight, weight_mean, 0.02)
    if m.bias is not None:
        nn.init.constant_(m.bias, 0.0)


In [None]:
def free_params(module: nn.Module):
    """Unfreeze ``module``: mark every one of its parameters as requiring gradients."""
    for param in module.parameters():
        param.requires_grad_(True)

def frozen_params(module: nn.Module):
    """Freeze ``module``: stop gradient tracking on every one of its parameters."""
    for param in module.parameters():
        param.requires_grad_(False)

In [None]:
# Build and initialise the encoder, then the decoder (same RNG order as
# construct -> init -> construct -> init), and move both to `device`.
encoder = Encoder().apply(weight_init).to(device)
decoder = Decoder().apply(weight_init).to(device)

# The Sobel kernels are fixed filters, so keep them out of gradient tracking.
sobel = Sobel().to(device)
frozen_params(sobel)

In [None]:
class SSIM_Loss(SSIM):
    """SSIM turned into a loss: ``1 - SSIM(img1, img2)``, which is 0 for identical images."""

    def forward(self, img1, img2):
        similarity = super().forward(img1, img2)
        return 1 - similarity

In [None]:
def _kl_divergence(mu, logvar):
    """KL(N(mu, exp(logvar)) || N(0, I)), summed over latent dims and averaged over the batch.

    ``mu`` / ``logvar`` are the (batch, latent) outputs of ``Encoder``. Reducing
    over dim=1 before the mean keeps the term independent of batch size; a full
    ``torch.sum`` followed by ``torch.mean`` would sum over the batch.
    """
    per_sample = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
    return per_sample.mean()


def _vae_loss_terms(dec_out, image, mu, logvar, criterion, l1_loss, p, l, s, kl_weight):
    """Return ``(total, reconstruction, kl, sobel)`` loss tensors for one batch.

    reconstruction = p * criterion + l * L1; kl = kl_weight * KL; sobel = s * Sobel edge loss.
    Uses the module-level ``sobel`` module.
    """
    rl = p * criterion(dec_out, image) + l * l1_loss(image, dec_out)
    kl = kl_weight * _kl_divergence(mu, logvar)
    sl = s * sobel(dec_out, image)
    return rl + kl + sl, rl, kl, sl


def _save_curves(path, curves, title, ylabel, xlabel='Epochs'):
    """Plot ``(x, y, label)`` curves on a fresh figure, save it to ``path`` and close it.

    A fresh figure per file keeps earlier curves from leaking into later plots.
    """
    fig, ax = plt.subplots()
    for x, y, label in curves:
        ax.plot(x, y, label=label)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    if any(label is not None for _, _, label in curves):
        ax.legend(loc='best')
    fig.savefig(path)
    plt.close(fig)


def train(**kwargs):
    """Train the VAE with MSE + L1 reconstruction, KL and Sobel-edge losses.

    Resumes from ``<model_path>/model.pth`` when it exists. After every epoch
    the model is evaluated on ``val_loader``. The best-validation checkpoint is
    written to ``model.pth`` and the best-training-loss weights to
    ``train_model.pth``. Learning-rate and loss curves are written to
    ``opt.save_path`` at the end.

    Uses the module-level ``encoder``, ``decoder``, ``sobel``, ``device``,
    ``train_loader`` and ``val_loader``. ``kwargs`` is accepted but ignored.
    """
    torch.cuda.empty_cache()
    opt = Config()
    print('loading the model...')
    sobel.eval()
    enc_optim = torch.optim.Adam(encoder.parameters(), lr=opt.lr, amsgrad=True)
    dec_optim = torch.optim.Adam(decoder.parameters(), lr=opt.lr, amsgrad=True)
    # Each scheduler drives its own optimizer. Previously both wrapped
    # enc_optim, so the decoder LR never annealed and the encoder LR was
    # stepped twice per epoch.
    enc_scheduler = lr_scheduler.CosineAnnealingLR(enc_optim, T_max=10, eta_min=0.000001)
    dec_scheduler = lr_scheduler.CosineAnnealingLR(dec_optim, T_max=10, eta_min=0.000001)
    criterion = nn.MSELoss()
    l1_loss = nn.L1Loss()

    prev_loss = float('inf')
    try:
        # map_location lets a checkpoint written on another GPU load on `device`.
        state = torch.load(os.path.join(opt.model_path, 'model.pth'), map_location=device)
    except FileNotFoundError:
        print("Pre-trained weights not found. Training from scratch.")
        e_counter = 0
        best_valid_loss = float('inf')
    else:
        encoder.load_state_dict(state['enc_state_dict'])
        decoder.load_state_dict(state['dec_state_dict'])
        print("Loaded pre-trained models with success.")
        # Optimizer moments and the cosine-schedule position are restored only
        # from checkpoints that carry scheduler state. Older checkpoints lack it
        # and restart the schedule as before.
        if 'enc_scheduler' in state and 'dec_scheduler' in state:
            enc_optim.load_state_dict(state['enc_optimizer'])
            dec_optim.load_state_dict(state['dec_optimizer'])
            enc_scheduler.load_state_dict(state['enc_scheduler'])
            dec_scheduler.load_state_dict(state['dec_scheduler'])
        best_valid_loss = state['valid_loss_min']
        print('Previously Trained for {} epoches'.format(state['epoch']))
        e_counter = state['epoch'] + 1

    # Loss weights (loop-invariant).
    kl_annealtime = 100
    p, l, s = 100, 10, 10
    kl_weight = 1. / kl_annealtime

    elrs, dlrs, t_loss, v_loss, epoches = [], [], [], [], []
    for epoch in range(e_counter, opt.max_epochs):
        encoder.train()
        decoder.train()
        epoch_start_time = time.time()
        print()
        print('==================================================================')
        print('-------------Epoch: {}/{}------------'.format(epoch, opt.max_epochs))

        # ---- training ----
        e_loss = e_rl = e_kl = e_sl = 0.0
        for idx, batch in enumerate(train_loader, 1):
            image = batch['img'].to(device)
            enc_optim.zero_grad()
            dec_optim.zero_grad()

            enc_out, mu, logvar = encoder(image)
            dec_out = decoder(enc_out)
            loss, rl, kl, sl = _vae_loss_terms(dec_out, image, mu, logvar,
                                               criterion, l1_loss, p, l, s, kl_weight)
            loss.backward()
            enc_optim.step()
            dec_optim.step()

            e_loss += loss.item()
            e_rl += rl.item()
            e_kl += kl.item()
            e_sl += sl.item()
            if idx % 1000 == 0:
                print(idx, '/', len(train_loader), '|Loss: %.3f | RL: %.3f | KL: %.3f | sobel: %.3f '
                      % (e_loss / idx, e_rl / idx, e_kl / idx, e_sl / idx))

        mean_loss = e_loss / len(train_loader)
        print('Train Loss: %.3f' % (mean_loss))
        t_loss.append(mean_loss)
        with open(f'{opt.save_path}/train_logs.txt', 'a') as file:
            file.write('epoch: ' + str(epoch) + ',loss: ' + str(mean_loss) + ',rl: ' + str(e_rl)
                       + ',kl: ' + str(e_kl) + ',sl: ' + str(e_sl) + '\n')

        if epoch % opt.plot_every == 0 and mean_loss < prev_loss:
            train_state = {
                'epoch': epoch,
                'loss_min': mean_loss,
                'enc_state_dict': encoder.state_dict(),
                'dec_state_dict': decoder.state_dict(),
            }
            torch.save(train_state, os.path.join(opt.model_path, 'train_model.pth'))
            decoded_path = os.path.join(opt.decode_path, 'decoded_%04d.png' % (epoch))
            tv.utils.save_image(dec_out.cpu().data, decoded_path)
            prev_loss = mean_loss

        # ---- validation ----
        print()
        print('...........................validation....................................')
        encoder.eval()
        decoder.eval()
        e_vloss = e_vrl = e_vkl = e_vsl = 0.0
        with torch.no_grad():
            for idx, batch in enumerate(val_loader, 1):
                image = batch['img'].to(device)
                enc_out, mu, logvar = encoder(image)
                dec_out = decoder(enc_out)
                eval_loss, rl, kl, sl = _vae_loss_terms(dec_out, image, mu, logvar,
                                                        criterion, l1_loss, p, l, s, kl_weight)
                # .item() keeps the accumulators as Python floats, so the log
                # lines below hold numbers rather than "tensor(...)" reprs.
                e_vloss += eval_loss.item()
                e_vrl += rl.item()
                e_vkl += kl.item()
                e_vsl += sl.item()
                if idx % 400 == 0:
                    print(idx, '/', len(val_loader), '|Loss: %.3f | RL: %.3f | KL: %.3f | sobel: %.3f '
                          % (e_vloss / idx, e_vrl / idx, e_vkl / idx, e_vsl / idx))

        valid_loss = e_vloss / len(val_loader)
        print('Valid Loss: %.3f' % (valid_loss))
        v_loss.append(valid_loss)
        epoches.append(epoch)

        # Record this epoch's LRs, then advance the schedules *before*
        # checkpointing, so a resumed run continues with the next epoch's LR.
        elrs.append(enc_optim.param_groups[0]["lr"])
        dlrs.append(dec_optim.param_groups[0]["lr"])
        print('Enc_learning_rate :', enc_scheduler.get_last_lr())
        print('Dec_learning_rate :', dec_scheduler.get_last_lr())
        enc_scheduler.step()
        dec_scheduler.step()

        is_save_epoch = epoch % opt.save_every == 0 or epoch == opt.max_epochs - 1
        if is_save_epoch and valid_loss < best_valid_loss:
            print('Validation loss decreased ({:.5f} --> {:.5f}). Saving model ...'.format(best_valid_loss, valid_loss))
            state = {
                'epoch': epoch,
                'valid_loss_min': valid_loss,
                'enc_state_dict': encoder.state_dict(),
                'dec_state_dict': decoder.state_dict(),
                'enc_optimizer': enc_optim.state_dict(),
                'dec_optimizer': dec_optim.state_dict(),
                'enc_scheduler': enc_scheduler.state_dict(),
                'dec_scheduler': dec_scheduler.state_dict(),
            }
            torch.save(state, os.path.join(opt.model_path, 'model.pth'))
            val_path = os.path.join(opt.val_path, 'fake_%04d.png' % (epoch))
            tv.utils.save_image(dec_out.cpu().data, val_path)
            with open(f'{opt.save_path}/val_logs.txt', 'a') as file:
                file.write('epoch: ' + str(epoch) + ',loss: ' + str(valid_loss) + ',rl: ' + str(e_vrl)
                           + ',kl: ' + str(e_vkl) + ',sl: ' + str(e_vsl) + '\n')
            best_valid_loss = valid_loss

        epoch_time = int(time.time() - epoch_start_time)
        print(f'-----------------Epoch cost time {epoch_time}s--------------------')

    _save_curves(os.path.join(opt.save_path, 'ELR.png'), [(epoches, elrs, None)],
                 title='Encoder learning rate', ylabel='Learning rate')
    _save_curves(os.path.join(opt.save_path, 'DLR.png'), [(epoches, dlrs, None)],
                 title='Decoder learning rate', ylabel='Learning rate')
    _save_curves(os.path.join(opt.save_path, 'losses.png'),
                 [(epoches, t_loss, 'Train'), (epoches, v_loss, 'Validation')],
                 title='Training Curve', ylabel='Loss')
    

In [None]:
# Launch training; train() builds its own Config() and resumes from
# <model_path>/model.pth when that checkpoint exists.
train()

In [None]:
# Reload the best-training-loss weights and score / dump reconstructions for
# the first few test images.
print('loading the model...')
encoder = Encoder()
decoder = Decoder()
encoder, decoder = encoder.to(device), decoder.to(device)

sobel = Sobel().to(device)
# map_location lets a checkpoint written on another GPU (e.g. cuda:1) load on `device`.
state = torch.load(os.path.join(opt.model_path, 'train_model.pth'), map_location=device)
encoder.load_state_dict(state['enc_state_dict'])
decoder.load_state_dict(state['dec_state_dict'])
# Inference mode. BatchNorm must use its running statistics rather than the
# statistics of the current single-image batch (test_batch_size is 1), and
# must not update those running statistics during testing.
encoder.eval()
decoder.eval()
sobel.eval()

print("Loaded pre-trained regressor with success.")
e_counter = state['epoch']
print('Previously Trained for {} epoches'.format(e_counter))

criterion = SSIM_Loss(data_range=1.0, size_average=True, channel=1)
l1_loss = nn.L1Loss()
kl_annealtime = 200
p_s = 100000
max_test_batches = 11  # only the first few test batches are scored and written to disk
val_epoch_loss = 0.0
i = 0
with torch.no_grad():
    for batch in tqdm(test_loader):
        image, label = batch['img'], batch['label']
        image, label = image.to(device), label.to(device)

        enc_out, mu, logvar = encoder(image)
        dec_out = decoder(enc_out)

        rl = p_s * criterion(dec_out, image)
        l1 = 1000 * l1_loss(image, dec_out)
        # KL summed over latent dims and averaged over the batch.
        KLD = torch.mean(-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1))
        kl_weight = (1. / kl_annealtime)
        kl = kl_weight * KLD
        sobel_loss = 1000 * sobel(dec_out, image)

        eval_loss = rl + l1 + kl + sobel_loss
        val_epoch_loss += eval_loss.item()

        # Pair each reconstruction with its own original. Previously every
        # original was written to orig_00.png, overwriting the last one.
        filename1 = 'fake_%02d.png' % (i)
        filename2 = 'orig_%02d.png' % (i)
        val_path1 = os.path.join(opt.test_path, filename1)
        val_path2 = os.path.join(opt.test_path, filename2)
        tv.utils.save_image(dec_out.cpu().data, val_path1)
        tv.utils.save_image(image.cpu().data, val_path2)
        i += 1
        if i == max_test_batches:
            break

# Average over the batches actually evaluated. The loop stops after
# max_test_batches, so dividing by len(test_loader) understated the loss.
valid_loss = val_epoch_loss / max(i, 1)
print('test Loss: %.5f' % (valid_loss))