In [1]:
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image

In [2]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [82]:
# Training / model hyper-parameters.
batch_size = 64   # images per DataLoader batch
hidden = 256      # latent dimensionality, passed as BetaVAE(latent=hidden)
beta = 20         # KL weight used by loss_function
M_N = 1           # extra multiplier on the KL term (loss uses beta * M_N * KLD)

In [89]:
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torch.utils.data import SubsetRandomSampler
import random
from random import choices
# Seed both RNGs: `random` drives the subset selection below, while
# SubsetRandomSampler draws its per-epoch ordering from torch's global RNG.
random.seed(52)
torch.manual_seed(52)
data_transform = transforms.Compose([transforms.ToTensor(),])

your_name_dataset = ImageFolder(root="../png/142p_data",
                                transform=data_transform)
N = len(your_name_dataset)
sample_N = 6666
# Draw a subset of *distinct* image indices over the whole dataset
# (index 0 included). Sampling without replacement guarantees that every
# one of the `sample_N` slots is a different image; `min` avoids a
# ValueError when the dataset has fewer than `sample_N` images.
sample_indices = random.sample(range(N), k=min(sample_N, N))
data_loader = torch.utils.data.DataLoader(your_name_dataset,
                                          # must stay False: shuffling is done by the sampler
                                          shuffle=False,
                                          batch_size=batch_size,
                                          pin_memory=True,
                                          sampler=SubsetRandomSampler(sample_indices)
                                          )

In [87]:
# Inspect one decoded image: its tensor shape and its flattened length.
first_image = your_name_dataset[0][0]
input_dim = first_image.numel()
print(first_image.shape)
print(input_dim)

torch.Size([3, 142, 189])
80514


In [52]:
from torch.autograd import Variable

def _cal_dim(x, k, s, p=0):
    return (x - k + p) // s + 1
    
class View(nn.Module):
    """Wrap ``Tensor.view`` as a module so a reshape can live inside ``nn.Sequential``.

    Args:
        dim: target shape handed unchanged to ``Tensor.view`` (may contain one ``-1``).
    """

    def __init__(self, dim):
        super(View, self).__init__()
        self.dim = dim

    def forward(self, x):
        target_shape = self.dim
        reshaped = x.view(target_shape)
        return reshaped

class BetaVAE(nn.Module):
    """Convolutional beta-VAE.

    The kernel sizes and strides are chosen for 142 x 189 inputs. The encoder
    reduces such an image to a 512 x 1 x 1 feature map. The decoder maps a
    latent vector back to ``in_channels x 142 x 189``.

    Args:
        latent: dimensionality of the latent code ``z``.
        in_channels: number of image channels.

    NOTE(review): other input sizes do not reduce to 1 x 1. In that case the
    final ``View((-1, 512))`` would silently fold spatial positions into the
    batch dimension, so confirm that inputs are always 142 x 189.
    """

    def __init__(self, latent=10, in_channels=3):
        super().__init__()
        self.latent = latent
        self.in_channels = in_channels

        # Shapes in the comments are each layer's OUTPUT for a (B, C, 142, 189) input.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=(8, 9), stride=2),  # -> (B, 32, 68, 91)
            nn.ReLU(True),
            nn.Conv2d(32, 32, kernel_size=(6, 7), stride=2),           # -> (B, 32, 32, 43)
            nn.ReLU(True),
            nn.Conv2d(32, 64, kernel_size=(6, 7), stride=2),           # -> (B, 64, 14, 19)
            nn.ReLU(True),
            nn.Conv2d(64, 256, kernel_size=(6, 7), stride=2),          # -> (B, 256, 5, 7)
            nn.ReLU(True),
            nn.Conv2d(256, 512, kernel_size=(5, 7), stride=1),         # -> (B, 512, 1, 1)
            nn.ReLU(True),
            View((-1, 512 * 1 * 1)),                                   # -> (B, 512)
        )
        # Posterior heads: mean and log-variance (fc_sigma's output is used as `logvar`).
        self.fc_mu = nn.Linear(512 * 1 * 1, latent)
        self.fc_sigma = nn.Linear(512 * 1 * 1, latent)

        # Mirror of the encoder; shapes are each layer's OUTPUT.
        # There is no output activation, so reconstructions are unbounded values.
        self.decoder = nn.Sequential(
            nn.Linear(latent, 512 * 1 * 1),                                     # -> (B, 512)
            View((-1, 512, 1, 1)),                                              # -> (B, 512, 1, 1)
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, kernel_size=(5, 7), stride=1),         # -> (B, 256, 5, 7)
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 64, kernel_size=(6, 7), stride=2),          # -> (B, 64, 14, 19)
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 32, kernel_size=(6, 7), stride=2),           # -> (B, 32, 32, 43)
            nn.ReLU(True),
            nn.ConvTranspose2d(32, 32, kernel_size=(6, 7), stride=2),           # -> (B, 32, 68, 91)
            nn.ReLU(True),
            nn.ConvTranspose2d(32, in_channels, kernel_size=(8, 9), stride=2),  # -> (B, C, 142, 189)
        )

    def forward(self, x):
        """Encode ``x``, sample ``z`` with the reparameterization trick, and decode it.

        Returns:
            Tuple ``(x_recon, mu, logvar)``; ``x_recon`` is the raw decoder output.
        """
        features = self.encoder(x)
        mu = self.fc_mu(features)
        logvar = self.fc_sigma(features)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decoder(z)
        return x_recon, mu, logvar

    def reparameterize(self, mu, logvar):
        """Return ``mu + eps * exp(logvar / 2)`` with ``eps ~ N(0, I)``.

        ``torch.randn_like`` already yields a plain tensor. The deprecated
        ``Variable`` wrapper is not needed (it has been a no-op since torch 0.4).
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std


In [53]:
# x = _cal_dim(142, 8, 2)   # 68
# x = _cal_dim(x, 6, 2)     # 32
# x = _cal_dim(x, 6, 2)     # 14
# x = _cal_dim(x, 6, 2)     #  5
# x = _cal_dim(x, 5, 1)
# print(x)

In [54]:
# x = _cal_dim(189, 9, 2)  # 91
# x = _cal_dim(x, 7, 2)     # 43
# x = _cal_dim(x, 7, 2) # 19
# x = _cal_dim(x, 7, 2) # 7
# x = _cal_dim(x, 7, 1)
# print(x)

In [55]:
def loss_function(recon_x, x, mu, logvar, beta, M_N):
    """beta-VAE objective: sigmoid-MSE reconstruction plus weighted KL divergence.

    Args:
        recon_x: raw decoder output with the batch as its first dimension. A
            sigmoid is applied here before comparing it with ``x``.
        x: target batch with the same number of elements per sample as ``recon_x``.
        mu: posterior means, batch-first.
        logvar: posterior log-variances, same shape as ``mu``.
        beta: weight of the KL term.
        M_N: extra multiplier on the KL term.

    Returns:
        Scalar tensor ``recon_loss + beta * M_N * KLD``. Both terms are
        averaged over the batch, so the KL weight does not depend on the
        batch size.
    """
    b = x.size(0)
    # reshape (not view) so non-contiguous tensors are accepted as well.
    recon_x, x = recon_x.reshape(b, -1), x.reshape(b, -1)
    # torch.sigmoid replaces the deprecated F.sigmoid.
    recon_x = torch.sigmoid(recon_x)
    recon_loss = F.mse_loss(recon_x, x, reduction='mean')
    # See Appendix B of the VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # -KL = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    # Sum over latent dimensions per sample, then average over the batch so
    # the KL term uses the same batch reduction as recon_loss.
    kld_terms = (1 + logvar - mu.pow(2) - logvar.exp()).reshape(b, -1)
    KLD = (-0.5 * kld_terms.sum(dim=1)).mean()

    return recon_loss + beta * M_N * KLD

In [56]:
# Release cached CUDA blocks and print the allocator report. The guard keeps
# the cell runnable on CPU-only machines, where querying CUDA memory would
# force CUDA initialisation. print() renders the multi-line report instead of
# showing an escaped string repr.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print(torch.cuda.memory_summary())



In [57]:
# Build the beta-VAE with a `hidden`-dimensional latent space, show its
# layers, move it to `device`, and attach an Adam optimizer.
model = BetaVAE(latent=hidden)
print(model)
model = model.to(device)  # nn.Module.to moves parameters in place and returns the same module
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

BetaVAE(
  (encoder): Sequential(
    (0): Conv2d(3, 32, kernel_size=(8, 9), stride=(2, 2))
    (1): ReLU(inplace=True)
    (2): Conv2d(32, 32, kernel_size=(6, 7), stride=(2, 2))
    (3): ReLU(inplace=True)
    (4): Conv2d(32, 64, kernel_size=(6, 7), stride=(2, 2))
    (5): ReLU(inplace=True)
    (6): Conv2d(64, 256, kernel_size=(6, 7), stride=(2, 2))
    (7): ReLU(inplace=True)
    (8): Conv2d(256, 512, kernel_size=(5, 7), stride=(1, 1))
    (9): ReLU(inplace=True)
    (10): View()
  )
  (fc_mu): Linear(in_features=512, out_features=256, bias=True)
  (fc_sigma): Linear(in_features=512, out_features=256, bias=True)
  (decoder): Sequential(
    (0): Linear(in_features=256, out_features=512, bias=True)
    (1): View()
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(5, 7), stride=(1, 1))
    (4): ReLU(inplace=True)
    (5): ConvTranspose2d(256, 64, kernel_size=(6, 7), stride=(2, 2))
    (6): ReLU(inplace=True)
    (7): ConvTranspose2d(64, 32, kernel_size=(6, 7)

In [93]:
epochs = 200
log_interval = 20  # print a progress line every `log_interval` batches
# len(DataLoader) is already the number of *batches* per epoch.
length = len(data_loader)
num_batches = max(length, 1)  # guards the division below against an empty loader

for epoch in range(1, epochs + 1):
    model.train()
    train_loss = 0.0  # sum of per-batch (batch-mean) losses
    for batch_idx, (data, _) in enumerate(data_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar, beta, M_N)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx, length,
                100. * batch_idx / num_batches,
                loss.item()))

    # Each loss.item() is already a per-batch mean, so the epoch average is
    # their mean over batches (no batch_size factor).
    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / num_batches))



KeyboardInterrupt: 