In [1]:
from torch import nn
from torch.nn import functional as F
import torchvision.transforms as tranforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torch.optim as optim
import torchvision
import torch
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt


# VAE and Beta-VAE

In [2]:
# Report which GPU training will run on. The CUDA queries raise on a CPU-only
# machine, so only issue them when CUDA is actually available.
if torch.cuda.is_available():
    gpu_summary = (
        torch.cuda.device_count(),
        torch.cuda.get_device_name(torch.cuda.current_device()),
    )
else:
    gpu_summary = "CUDA not available; running on CPU"
gpu_summary

In [4]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [5]:
# CelebA: every image is resized to 64x64 and converted to a tensor.
tranform = tranforms.Compose([
    tranforms.Resize((64, 64)),
    tranforms.ToTensor(),
])
dataset = torchvision.datasets.ImageFolder(root="./img_align_celeba", transform=tranform)
train_data = DataLoader(dataset=dataset, batch_size=64, shuffle=True)


In [8]:
# dSprites: load the image array, closing the .npz archive as soon as it is read.
with np.load('./dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz') as dsprites_npz:
    dsprites_imgs = dsprites_npz['imgs']
# from_numpy shares memory with the array (no extra full copy); add a channel
# axis so each sample is (1, 64, 64), then convert to float for the model.
dataset_two = torch.from_numpy(dsprites_imgs).unsqueeze(dim=1).float()
del dsprites_imgs  # drop the raw array; dataset_two holds the data now
train_data_two = DataLoader(dataset=dataset_two, batch_size=256, shuffle=True)


In [7]:
# Choose the training set: dSprites (1-channel) when True, CelebA (3-channel) when False.
USE_DSPRITES = True
if USE_DSPRITES:
    train_data = train_data_two
    in_channels = 1
else:
    in_channels = 3  # train_data already holds the CelebA loader

In [11]:


class VAE(nn.Module):
    """Convolutional (beta-)VAE for 64x64 images.

    The encoder applies six stride-2 convolutions (64 -> 1 spatially), flattens
    the resulting 1024x1x1 map and projects it to ``latent_dim`` features; two
    linear heads then give the posterior mean and log-variance. The decoder
    mirrors this with six stride-2 transposed convolutions (1 -> 64) and a
    sigmoid output layer.

    Args:
        in_channels: number of image channels (1 for dSprites, 3 for CelebA).
        latent_dim: size of the latent code z.
    """

    def __init__(self, in_channels, latent_dim):
        super(VAE, self).__init__()

        self.latent_dim = latent_dim
        self.in_channels = in_channels

        # ---- Encoder ----
        # Five conv + BatchNorm + LeakyReLU stages, each halving H and W.
        encoder_channels = [self.in_channels, 32, 64, 128, 256, 512]
        units = [
            nn.Sequential(
                nn.Conv2d(c_in, out_channels=c_out, kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(),
            )
            for c_in, c_out in zip(encoder_channels[:-1], encoder_channels[1:])
        ]
        # Last stage reaches 1x1 spatially and has no BatchNorm.
        units.append(nn.Sequential(
            nn.Conv2d(512, out_channels=1024, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(),
        ))
        units.append(nn.Flatten(start_dim=1))
        units.append(nn.Linear(1024, latent_dim))
        self.encoder = nn.Sequential(*units)
        # Separate heads so the posterior mean and log-variance can differ.
        self.mean_module = nn.Linear(self.latent_dim, self.latent_dim)
        self.variance_module = nn.Linear(self.latent_dim, self.latent_dim)

        # ---- Decoder ----
        self.decoder_input = nn.Linear(latent_dim, 1024)
        # Six transposed-conv stages, each doubling H and W (1 -> 64).
        decoder_channels = [1024, 512, 256, 128, 64, 32, 32]
        units = [
            nn.Sequential(
                nn.ConvTranspose2d(c_in,
                                   c_out,
                                   kernel_size=3,
                                   stride=2,
                                   padding=1,
                                   output_padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(),
            )
            for c_in, c_out in zip(decoder_channels[:-1], decoder_channels[1:])
        ]
        # Map back to image channels; sigmoid keeps pixel values in (0, 1).
        units.append(nn.Sequential(
            nn.Conv2d(32, out_channels=self.in_channels, kernel_size=3, padding=1),
            nn.Sigmoid(),
        ))
        self.decoder = nn.Sequential(*units)

    def encode(self, input):
        """Encode an image batch into the parameters of q(z|x).

        Args:
            input: image batch with ``in_channels`` channels at 64x64.

        Returns:
            ``[mu, log_var]``, each of shape ``(batch, latent_dim)``.
        """
        result = self.encoder(input)
        mu = self.mean_module(result)
        log_var = self.variance_module(result)
        return [mu, log_var]

    def decode(self, z):
        """Decode latent codes of shape ``(..., latent_dim)`` into images.

        Leading dimensions of ``z`` are flattened into the batch, so the output
        is ``(prod(leading dims), in_channels, 64, 64)``.
        """
        result = self.decoder_input(z)
        result = result.view(-1, 1024, 1, 1)
        result = self.decoder(result)
        return result

    def Reparm(self, mean, log_var):
        """Reparameterisation trick: return ``mean + eps * exp(0.5 * log_var)``, eps ~ N(0, I)."""
        m = torch.exp(0.5 * log_var)
        epsilon = torch.randn_like(m)
        return epsilon * m + mean

    def forward(self, input):
        """Encode, sample z with the reparameterisation trick, and decode.

        Returns:
            ``[x_hat, mu, log_var]``.
        """
        mu, log_var = self.encode(input)
        x_hat = self.decode(self.Reparm(mu, log_var))
        return [x_hat, mu, log_var]

    def total_loss(self,
                   reconstruction_error,
                   mu, log_var, beta=1):
        """Return ``reconstruction_error + beta * KL(q(z|x) || N(0, I))``.

        The KL divergence is summed over latent dimensions and averaged over
        the batch, so ``reconstruction_error`` should be on the same
        per-sample scale for ``beta`` to carry its intended weight.
        """
        temp = torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1)
        kld_loss = torch.mean(-0.5 * temp, dim=0)

        loss = reconstruction_error + beta * kld_loss
        return loss

    def sample_n(self, n):
        """Decode ``n`` latent codes drawn from the N(0, I) prior.

        The codes are created on the same device as the model's parameters.
        """
        device = next(self.parameters()).device
        z = torch.randn(n, self.latent_dim, device=device)
        return self.decode(z)


In [9]:
epochs = 5        # number of full passes over the training set
batch_size = 256  # mini-batch size for stochastic gradient descent
                  # NOTE(review): the DataLoaders above hard-code their own batch sizes
lr = 0.001        # optimizer learning rate (step size)

In [None]:
# Build the model on whichever device was selected (works with or without a GPU).
model = VAE(in_channels=in_channels, latent_dim=5).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
# Squared error summed over every pixel of every image in the batch.
criterion = nn.MSELoss(reduction='sum')

In [37]:
next(model.parameters()).device.type == "cuda"

True

In [38]:
beta = 4  # KL weight: 1 gives the plain VAE objective, >1 a beta-VAE

In [39]:
num_steps = len(train_data)  # batches per epoch of the selected loader, including a final partial batch

In [12]:



def fit_celeb(model, dataloader, beta):
    """Train ``model`` on a loader that yields ``(images, labels)`` pairs (CelebA).

    Runs ``epochs`` passes using the notebook-level ``optimizer``, ``criterion``
    and ``device``. ``beta`` weights the KL term of the loss.

    Returns:
        Average per-sample loss over the final epoch (NaN if ``epochs`` is 0).
    """
    # Ensure BatchNorm uses batch statistics even if eval() was called earlier.
    model.train()
    train_loss = float("nan")
    for epoch in range(epochs):
        print("\n\n\nEpoch ", epoch)
        running_loss = 0.0
        for data, _ in tqdm(dataloader, total=len(dataloader)):
            data = data.to(device)
            batch_n = data.size(0)

            optimizer.zero_grad()
            x_hat, mu, logvar = model(data)
            # criterion uses reduction='sum' (whole batch) while total_loss averages
            # the KL over the batch; divide so both terms are per sample and beta
            # keeps its intended weight.
            recon_error = criterion(x_hat, data) / batch_n
            loss = model.total_loss(recon_error, mu, logvar, beta=beta)
            running_loss += loss.item() * batch_n
            loss.backward()
            optimizer.step()

        train_loss = running_loss / len(dataloader.dataset)
    return train_loss

In [13]:
def fit_dpsrite(model, dataloader, beta):
    """Train ``model`` on a loader that yields bare image batches (dSprites).

    Runs ``epochs`` passes using the notebook-level ``optimizer``, ``criterion``
    and ``device``. ``beta`` weights the KL term of the loss.

    Returns:
        Average per-sample loss over the final epoch (NaN if ``epochs`` is 0).
    """
    # Ensure BatchNorm uses batch statistics even if eval() was called earlier.
    model.train()
    train_loss = float("nan")
    for epoch in range(epochs):
        print("\n\n\nEpoch ", epoch)
        running_loss = 0.0
        for data in tqdm(dataloader, total=len(dataloader)):
            data = data.to(device)
            batch_n = data.size(0)

            optimizer.zero_grad()
            x_hat, mu, logvar = model(data)
            # criterion uses reduction='sum' (whole batch) while total_loss averages
            # the KL over the batch; divide so both terms are per sample and beta
            # keeps its intended weight.
            recon_error = criterion(x_hat, data) / batch_n
            loss = model.total_loss(recon_error, mu, logvar, beta=beta)
            running_loss += loss.item() * batch_n
            loss.backward()
            optimizer.step()

        train_loss = running_loss / len(dataloader.dataset)
    return train_loss

In [None]:
# dSprites batches are bare image tensors while CelebA batches are
# (images, labels) pairs, so use the loop that matches the selected dataset.
if in_channels == 1:
    train_loss = fit_dpsrite(model, train_data, beta=beta)
else:
    train_loss = fit_celeb(model, train_data, beta=beta)
train_loss


In [27]:
# Inspect one dSprites sample without writing back into the training tensor.
dataset_two[0].size()

torch.Size([1, 64, 64])

In [42]:
# Pickles the whole nn.Module (not just a state_dict).
# NOTE(review): absolute machine-specific path; consider a configurable directory.
checkpoint_path = "/home/liyana/disk/liyana/VAE/saved_models/model3_beta4_d1.pth"
torch.save(model, checkpoint_path)

In [None]:
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
# NOTE: this unpickles a full nn.Module — only load trusted files.
model = torch.load("/home/liyana/disk/liyana/VAE/saved_models/model1_beta1_d2.pth", map_location=device)
model.eval()

In [64]:
z_test = torch.randn((1, model.latent_dim))  # one latent vector drawn from N(0, I)

In [65]:
m = torch.squeeze(model.decode(z_test.to(device)))

In [66]:
m = m.detach().to("cpu")

In [13]:
m.size()

torch.Size([3, 64, 64])

In [None]:

# `m` is one decoded image with size-1 dims squeezed away:
# (C, H, W) for RGB models, (H, W) for single-channel models.
fig, ax = plt.subplots()
if model.in_channels == 3:
    ax.imshow(m.permute(1, 2, 0))
else:
    ax.imshow(m, cmap='gray')
ax.set_axis_off()
plt.show()

In [19]:
sample = torch.randn((100, model.latent_dim))  # 100 latent vectors drawn from N(0, I)

In [67]:
m = (
    model.decode(sample.to(device))
    .squeeze()
    .detach()
    .cpu()
)

In [None]:
# Grid of decoded prior samples, 10 per row.
n_images = m.shape[0]
n_cols = 10
n_rows = (n_images + n_cols - 1) // n_cols
fig = plt.figure(1, figsize=(40, 40))
for idx in range(n_images):
    ax = fig.add_subplot(n_rows, n_cols, idx + 1, xticks=[], yticks=[])
    if model.in_channels == 1:
        ax.imshow(m[idx], cmap="gray")
    else:
        ax.imshow(m[idx].permute(1, 2, 0))  # (C, H, W) -> (H, W, C) for imshow
plt.show()
    

In [81]:
#### Disentanglement plot ###############
# Works for both datasets.
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
# NOTE: this unpickles a full nn.Module — only load trusted files.
model = torch.load("/home/liyana/disk/liyana/VAE/saved_models/model3_beta4_d2.pth", map_location=device)
model.eval()

VanillaVAE_truth_celeb(
  (encoder): Sequential(
    (0): Sequential(
      (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (1): Sequential(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (3): Sequential(
      (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )

In [82]:
z = torch.randn(1, model.latent_dim)  # base latent code, shape (1, latent_dim)
step = 0.5
N = 10
traversal_start = -2.5
# Exactly N values start, start + step, ..., used to sweep one latent dimension
# (with the defaults: -2.5, -2.0, ..., 2.0).
replacement_vec = traversal_start + step * torch.arange(N, dtype=z.dtype)

In [83]:
# Traversal for latent dimension 0: N copies of z with column 0 swept.
Base_matrix = z.repeat(N, 1)
index = 0
Base_matrix[:, index] = replacement_vec
# clone() so later in-place edits of Base_matrix cannot leak into the stored traversal.
container = Base_matrix.unsqueeze(dim=0).clone()

In [84]:
for index in range(1, model.latent_dim):
    # Start every traversal from an unmodified copy of z so that exactly one
    # latent dimension varies per row.
    traversal = z.repeat(N, 1)
    traversal[:, index] = replacement_vec
    container = torch.cat([container, traversal.unsqueeze(dim=0)], dim=0)

In [85]:
m = (
    model.decode(container.to(device))
    .detach()
    .cpu()
)
# Single-channel images: drop the channel axis so imshow gets (H, W) frames.
m = m.squeeze(dim=1) if model.in_channels == 1 else m

In [None]:
# One row per latent dimension, N columns per traversal. decode() flattened
# container (latent_dim, N, latent_dim) row-major, so m[d * N + j] is step j of dim d.
n_rows = model.latent_dim
fig = plt.figure(1, figsize=(40, 40))
for idx in range(n_rows * N):
    ax = fig.add_subplot(n_rows, N, idx + 1, xticks=[], yticks=[])
    if model.in_channels == 1:
        ax.imshow(m[idx], cmap="gray")
    else:
        ax.imshow(m[idx].permute(1, 2, 0))  # (C, H, W) -> (H, W, C) for imshow
plt.show()