In [1]:
# prerequisites
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image

import numpy as np
import random
import matplotlib.pyplot as plt

In [2]:
bs = 50
# MNIST Dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
# train_dataset = datasets.MNIST(root='./mnist_data/', train=True, transform=transforms.ToTensor(), download=True)
# test_dataset = datasets.MNIST(root='./mnist_data/', train=False, transform=transforms.ToTensor(), download=False)
train_dataset = datasets.MNIST(root='./mnist_data/', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./mnist_data/', train=False, transform=transform, download=False)

# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=bs, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=bs, shuffle=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 176227259.31it/s]

Extracting ./mnist_data/MNIST/raw/train-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 28896873.53it/s]


Extracting ./mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 52179654.42it/s]

Extracting ./mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 7076719.45it/s]


Extracting ./mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw



In [3]:
# class VAE(nn.Module):
#     def __init__(self, x_dim, h_dim1, h_dim2, z_dim):
#         super(VAE, self).__init__()

#         # encoder part
#         self.fc1 = nn.Linear(x_dim, h_dim1)
#         self.fc2 = nn.Linear(h_dim1, h_dim2)
#         self.fc31 = nn.Linear(h_dim2, z_dim)
#         self.fc32 = nn.Linear(h_dim2, z_dim)
#         # decoder part
#         self.fc4 = nn.Linear(z_dim, h_dim2)
#         self.fc5 = nn.Linear(h_dim2, h_dim1)
#         self.fc6 = nn.Linear(h_dim1, x_dim)

#     def encoder(self, x):
#         h = F.relu(self.fc1(x))
#         h = F.relu(self.fc2(h))
#         return self.fc31(h), self.fc32(h) # mu, log_var

#     def sampling(self, mu, log_var):
#         std = torch.exp(0.5*log_var)
#         eps = torch.randn_like(std)
#         return eps.mul(std).add_(mu) # return z sample

#     def decoder(self, z):
#         h = F.relu(self.fc4(z))
#         h = F.relu(self.fc5(h))
#         return F.sigmoid(self.fc6(h))

#     def forward(self, x):
#         mu, log_var = self.encoder(x.reshape(x.shape[0], 729))
#         z = self.sampling(mu, log_var)
#         return self.decoder(z), mu, log_var

# # build model
# vae = VAE(x_dim=729, h_dim1= 512, h_dim2=256, z_dim=2)
# if torch.cuda.is_available():
#     vae.cuda()
import torch
import torch.nn as nn
import torch.nn.functional as F


class VAE(nn.Module):
    def __init__(self, x_dim, h_dim1, h_dim2, z_dim):
        super(VAE, self).__init__()

        # encoder part
        self.fc1 = nn.Linear(x_dim, h_dim1)
        self.fc2 = nn.Linear(h_dim1, h_dim2)
        self.fc31 = nn.Linear(h_dim2, z_dim)
        self.fc32 = nn.Linear(h_dim2, z_dim)
        # decoder part
        self.fc4 = nn.Linear(z_dim, h_dim2)
        self.fc5 = nn.Linear(h_dim2, h_dim1)
        self.fc6 = nn.Linear(h_dim1, x_dim)

    def encoder(self, x):
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        return self.fc31(h), self.fc32(h) # mu, log_var

    def sampling(self, mu, log_var):
        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu) # return z sample

    def decoder(self, z):
        h = F.relu(self.fc4(z))
        h = F.relu(self.fc5(h))
        return F.sigmoid(self.fc6(h))

    def forward(self, x):
        mu, log_var = self.encoder(x.reshape(x.shape[0], 729))
        z = self.sampling(mu, log_var)
        return self.decoder(z), mu, log_var
# build model
vae = VAE(x_dim=729, h_dim1= 512, h_dim2=256, z_dim=2)
torch.cuda.empty_cache()
if torch.cuda.is_available():
    vae.cuda()
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# vae.to(device)

In [4]:
optimizer = optim.Adam(vae.parameters())
# return reconstruction error + KL divergence losses
def loss_function(recon_x, x, mu, log_var):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 729), reduction='sum')
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return BCE + KLD

In [5]:
def make_full_img_arr(img_ls, M, N):
    # e.g., img_ls = [[[1, 2], [5, 6]], [[3, 4], [7, 8]], [[9, 10], [13, 14]], [[11, 12], [15, 16]]]
    seg_nrow = img_ls[0].shape[0]
    seg_ncol = img_ls[0].shape[1]

    tmp_arr = []
    for i in range(0, len(img_ls), N):
        col_arr = img_ls[i]
        for j in range(1, M):
            col_arr = np.concatenate((col_arr, img_ls[i+j]), axis=1)
        tmp_arr.append(col_arr)

    tmp_arr = np.array(tmp_arr)
    tmp_arr = tmp_arr.reshape(M*seg_nrow, N*seg_ncol)
    return tmp_arr


def mirror_arr(np_arr):
    return np.flip(np_arr, axis=1)


def flip_arr(np_arr):
    return np.flip(np_arr, axis=0)


def rotate_90_arr(np_arr, k):
    return np.rot90(np_arr, k=k)


def make_segments_tensor(data_point, M, N, is_augment=True):
    nrow = data_point.shape[0]
    ncol = data_point.shape[1]
    segments_ls = []
    for i in range(M):
        for j in range(N):
            segment = np.array(data_point[
                i*(nrow//M):(i+1)*(nrow//M),
                j*(ncol//N):(j+1)*(ncol//N)
                ])
            if is_augment:
                u = random.uniform(0, 1)
                if u < 0.5:
                    segment = mirror_arr(segment)

                u = random.uniform(0, 1)
                if u < 0.5:
                    segment = flip_arr(segment)

                u = random.uniform(0, 1)
                if u < 0.5:
                    segment = rotate_90_arr(segment, k=1)

            segments_ls.append(segment)

    segments_tensor = torch.Tensor(np.array(segments_ls))
    return segments_tensor


def augment_image(data_point, M, N):
    segments_tensor = make_segments_tensor(data_point, M, N)

    # Specify the dimension along which you want to shuffle (0 for rows, 1 for columns)
    dim_to_shuffle = 0

    # Generate a random permutation of indices for the specified dimension
    permuted_indices = torch.randperm(segments_tensor.size(dim_to_shuffle))

    # Use the permutation to shuffle the tensor along the specified dimension
    shuffled_tensor = segments_tensor.index_select(dim_to_shuffle, permuted_indices)

    augmented_data_point = torch.tensor(make_full_img_arr(shuffled_tensor, M, N)).unsqueeze(dim=0).unsqueeze(dim=0)
    return augmented_data_point


def append_augmented_data(batch_data, M, N, rep=1):
    for i in range(len(batch_data)):
        data_point = batch_data[i][0]
        for j in range(rep):
          augmented_data_point = augment_image(data_point, M, N).numpy()
          # batch_data = torch.cat((batch_data, augmented_data_point), dim=0)
          batch_data = np.concatenate((batch_data, augmented_data_point), axis=0)
    batch_data = torch.tensor(batch_data)
    # return batch_data[len(batch_data):]
    return batch_data


def imshow(img):
    # img = img / 0.5 + 0.5     # unnormalize
    img = img * 0.5 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

In [6]:
M = 3
N = 3
rep = 2
def train(epoch):
    vae.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.cuda()
        # fit image size to M, N sliceable data
        while data.shape[2] % M != 0:
            data = data[:, :, :-1, :]
        while data.shape[3] % N != 0:
            data = data[:, :, :, :-1]
        ground_truth = torch.cat((data, data), axis=0)
        data_cpu = data.cpu()
        data = append_augmented_data(data_cpu.detach().numpy(), M, N).cuda()
        # for i in range(rep):
        #     data = torch.cat((data, append_augmented_data(data_cpu.detach().numpy(), M, N).cuda()), dim=0).cuda()

        # shuffle data and ground_truth
        dim_to_shuffle = 0
        permuted_indices = torch.randperm(ground_truth.size(dim_to_shuffle)).cuda()

        data = data.index_select(dim_to_shuffle, permuted_indices)
        ground_truth = ground_truth.index_select(dim_to_shuffle, permuted_indices)

        optimizer.zero_grad()

        recon_batch, mu, log_var = vae(data)
        loss = loss_function(recon_batch, ground_truth, mu, log_var)

        loss.backward()
        train_loss += loss.item()
        optimizer.step()

        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item() / len(data)))
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss / len(train_loader.dataset)))

In [7]:
for epoch in range(1, 51):
    train(epoch)
    # test()

====> Epoch: 1 Average loss: -84117.4600
====> Epoch: 2 Average loss: -92951.0141
====> Epoch: 3 Average loss: -90820.2012
====> Epoch: 4 Average loss: -88691.1913
====> Epoch: 5 Average loss: -85888.7197
====> Epoch: 6 Average loss: -82903.3486
====> Epoch: 7 Average loss: -79762.1633
====> Epoch: 8 Average loss: -76756.3891
====> Epoch: 9 Average loss: -74793.6841
====> Epoch: 10 Average loss: -73974.2881
====> Epoch: 11 Average loss: -73835.7921
====> Epoch: 12 Average loss: -74258.8270
====> Epoch: 13 Average loss: -73521.5387
====> Epoch: 14 Average loss: -73499.0750
====> Epoch: 15 Average loss: -73993.0521
====> Epoch: 16 Average loss: -75123.7252
====> Epoch: 17 Average loss: -76137.5873
====> Epoch: 18 Average loss: -77101.2247
====> Epoch: 19 Average loss: -77825.4004
====> Epoch: 20 Average loss: -79260.7754
====> Epoch: 21 Average loss: -79647.7736
====> Epoch: 22 Average loss: -79613.1292
====> Epoch: 23 Average loss: -80456.7098
====> Epoch: 24 Average loss: -80836.9153
=

In [None]:
from google.colab import drive

# Mount your Google Drive
drive.mount('/content/gdrive')

# Specify the Google Drive folder path where you want to save the file
drive_folder_path = '/content/gdrive/My Drive/Colab Notebooks/'

# Save the model or tensor to your Google Drive folder
torch.save(vae.state_dict(), drive_folder_path + "vae" + '_2.pth')