In [5]:
import os
import torch
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from torchvision.utils import save_image
import torchsummary

# Output directory for the images saved during training.
# exist_ok=True makes the cell safely re-runnable and avoids the race between
# checking for the directory and creating it.
os.makedirs('./VAE_img', exist_ok=True)

In [6]:
def normalization(tensor, min_value, max_value):
    """Linearly rescale ``tensor`` so its values span ``[min_value, max_value]``.

    The global minimum of ``tensor`` is mapped to ``min_value`` and the global
    maximum to ``max_value``.

    Args:
        tensor: Input tensor (must be non-empty).
        min_value: Lower bound of the output range.
        max_value: Upper bound of the output range.

    Returns:
        A new tensor of the same shape as ``tensor``. When every element of
        ``tensor`` is equal (zero value range) all outputs are ``min_value``
        instead of NaN.
    """
    shifted = tensor - tensor.min()
    span = shifted.max()
    # A constant input has span == 0; dividing by it would give 0/0 = NaN.
    # Divide by 1 in that case: ``shifted`` is all zeros, so the result is all zeros.
    safe_span = torch.where(span == 0, torch.ones_like(span), span)
    unit = shifted / safe_span
    return unit * (max_value - min_value) + min_value

def value_round(tensor):
    """Return a tensor with every element of ``tensor`` rounded to the nearest integer value."""
    rounded = torch.round(input=tensor)
    return rounded

def to_img(x):
    """View a batch tensor as single-channel 28x28 images.

    The first dimension of ``x`` is kept as the batch dimension; the remaining
    elements of each sample are viewed as a ``(1, 28, 28)`` image.
    """
    n_images = x.size(0)
    image_shape = (n_images, 1, 28, 28)
    return x.view(*image_shape)

# Preprocessing pipeline: convert the image to a tensor, rescale it to [0, 1],
# then round every pixel so the images become binary (0 or 1).
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda pixels: normalization(pixels, 0, 1)),
    transforms.Lambda(value_round),
])

batch_size = 128

# FashionMNIST, downloaded on first use into ./FashionMNIST_DATASET.
dataset = FashionMNIST(
    './FashionMNIST_DATASET',
    transform=img_transform,
    download=True,
)
# Shuffled mini-batches of `batch_size` samples.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./FashionMNIST_DATASET\FashionMNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:07<00:00, 3581354.36it/s]


Extracting ./FashionMNIST_DATASET\FashionMNIST\raw\train-images-idx3-ubyte.gz to ./FashionMNIST_DATASET\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./FashionMNIST_DATASET\FashionMNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 112588.40it/s]


Extracting ./FashionMNIST_DATASET\FashionMNIST\raw\train-labels-idx1-ubyte.gz to ./FashionMNIST_DATASET\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./FashionMNIST_DATASET\FashionMNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:02<00:00, 1697391.08it/s]


Extracting ./FashionMNIST_DATASET\FashionMNIST\raw\t10k-images-idx3-ubyte.gz to ./FashionMNIST_DATASET\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./FashionMNIST_DATASET\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 5256153.11it/s]

Extracting ./FashionMNIST_DATASET\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz to ./FashionMNIST_DATASET\FashionMNIST\raw






In [7]:
class VariationalAutoEncoder(nn.Module):
    """Fully connected variational autoencoder for flattened 28x28 images.

    The encoder maps a 784-dimensional input to ``2 * latent_dim`` values: the
    first half is the mean and the second half the log-variance of the
    approximate posterior. The decoder maps a ``latent_dim`` sample back to 784
    values in ``[0, 1]`` (sigmoid output).

    Args:
        latent_dim: Size of the latent code. Defaults to 20, which reproduces
            the original 784-400-40 / 20-400-784 architecture.
    """

    def __init__(self, latent_dim=20):
        super(VariationalAutoEncoder,self).__init__()
        self.latent_dim = latent_dim
        self.encoder = nn.Sequential(
            nn.Linear(28*28,400),
            nn.ReLU(True),
            # Half of the outputs are the mean, the other half the log-variance.
            nn.Linear(400,2*latent_dim)
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim,400),
            nn.ReLU(True),
            nn.Linear(400,28*28),
            nn.Sigmoid()
        )

    def _encode(self,x):
        """Run the encoder and split its output into ``(mean, logvar)``."""
        h = self.encoder(x)
        return h[:,:self.latent_dim], h[:,self.latent_dim:]

    def reparameterization(self,mean,logvar):
        """Sample ``z = mean + std * eps`` with ``eps ~ N(0, 1)``.

        Args:
            mean: Posterior means.
            logvar: Posterior log-variances, same shape as ``mean``.

        Returns:
            A sample with the same shape, dtype and device as the inputs.
        """
        # exp(0.5 * logvar) == sqrt(exp(logvar)) but does not overflow as early.
        std = torch.exp(0.5 * logvar)
        # randn_like draws the noise on the same device/dtype as std, so the
        # model works on CPU as well as on CUDA.
        eps = torch.randn_like(std)
        return eps.mul(std).add(mean)

    def forward(self,x):
        """Encode ``x``, sample a latent code and decode it.

        Args:
            x: Batch of flattened images, shape ``(batch, 784)``.

        Returns:
            Tuple ``(x_gen, mean, logvar)``: the reconstruction and the
            posterior parameters used to sample the latent code.
        """
        mean, logvar = self._encode(x)
        z = self.reparameterization(mean,logvar)
        x_gen = self.decoder(z)

        return x_gen, mean, logvar

    def interpolation(self,x_1,x_2,alpha):
        """Decode a sample drawn between the posteriors of ``x_1`` and ``x_2``.

        The means and log-variances of both inputs are mixed linearly:
        ``alpha = 0`` uses the posterior of ``x_1`` and ``alpha = 1`` that of
        ``x_2``. A latent code is then sampled from the mixed parameters.

        Args:
            x_1: Batch of flattened images, shape ``(batch, 784)``.
            x_2: Batch of flattened images, same shape as ``x_1``.
            alpha: Mixing weight.

        Returns:
            The decoded images, shape ``(batch, 784)``.
        """
        mean_1, logvar_1 = self._encode(x_1)
        mean_2, logvar_2 = self._encode(x_2)

        traverse_m = (1-alpha) * mean_1 + alpha * mean_2
        traverse_logvar = (1-alpha) * logvar_1 + alpha * logvar_2
        z = self.reparameterization(traverse_m,traverse_logvar)
        generated_image = self.decoder(z)

        return generated_image

In [8]:
import pytorch_model_summary

# Select the GPU when one is available and the CPU otherwise, instead of
# calling .cuda() unconditionally.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VariationalAutoEncoder().to(device)
# A dummy batch holding one flattened 28x28 image is passed through the model
# to produce the layer summary.
print(pytorch_model_summary.summary(model,torch.zeros(1,784,device=device),show_input=True))

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
# Reconstruction term: binary cross-entropy averaged over every pixel in the batch.
BCE = nn.BCELoss()
num_epochs, learning_rate = 50,1e-3
optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate)
# Keep the data on whatever device the model already lives on.
device = next(model.parameters()).device

for epoch in range(num_epochs):
    for data in dataloader:
        img,_ = data
        # Flatten each image to a vector and move it to the model's device.
        img = img.view(img.size(0),-1).to(device)
        x_gen,mu,logvar = model(img)
        # KL(q(z|x) || N(0, I)) = -0.5 * sum(1 + logvar - mu^2 - exp(logvar))
        NKLD = 1 + logvar - mu.pow(2) - logvar.exp()
        KLD = -0.5 * torch.sum(NKLD)

        # Put the KL term on the same per-pixel scale as the mean-reduced BCE.
        # img.numel() uses the actual batch size, so the smaller final batch
        # of an epoch is not under-weighted as it was with the nominal batch_size.
        KLD = KLD / img.numel()
        loss = BCE(x_gen,img) + KLD

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if epoch % 10 == 0 or (epoch+1) == num_epochs:
        # The reported loss is that of the last mini-batch of the epoch.
        print(f'epoch: [{epoch+1} / {num_epochs}] loss:{loss.item():.4f}')
        # Visualisation only: no gradients are needed for anything below.
        with torch.no_grad():
            x_gt = to_img(img.detach().cpu())
            x_gen = to_img(x_gen.detach().cpu())
            save_image(x_gt,'./VAE_img/ground_truth_{}.png'.format(epoch))
            save_image(x_gen,'./VAE_img/generated_{}.png'.format(epoch))
            # Interpolate between the first two images of a fresh batch.
            batch = next(iter(dataloader))[0]
            batch = batch.view(batch.size(0),-1).to(device)
            x_1 = batch[0:1]
            x_2 = batch[1:2]
            generated_images = []
            for alpha in torch.arange(0.0,1.0,0.1):
                generated_images.append(model.interpolation(x_1,x_2,alpha))
            generated_images = torch.cat(generated_images,0).cpu()
            save_image(generated_images.view(-1,1,28,28),'./VAE_img/interpolation_{}.png'.format(epoch))

torch.save(model.state_dict(),'./Variational_AutoEncoder.pth')