# VAE with do notation

In [1]:
from tqdm import tqdm
import time
import argparse
import os
import sys
from abc import abstractmethod
from torch.nn import functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import pandas as pd
import numpy as np

In [3]:
import torch
from torch import nn
import torch.nn.init as init
from torchvision import datasets
from torchvision import transforms as T
from torchvision.datasets import ImageFolder
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from PIL import Image

In [4]:
sys.path.insert(0,"/cmlscratch/margot98/Causal_Disentangle/3dshape_dataset")

In [5]:
from datasets import ShapeDataset,CelebaDataset
from dataset_3d_shape import sample_3dshape_dataset

In [6]:
# data_path = "/cmlscratch/margot98/Causal_Disentangle/3dshape_data"
# attr_path =  "/cmlscratch/margot98/Causal_Disentangle/3dshape_data/sample_source.csv"

data_path = "/fs/cml-datasets/CelebA/Img/img_align_celeba"
attr_path =  "/fs/cml-datasets/CelebA/Anno/list_attr_celeba.txt"

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

In [7]:
print(device)

cuda


In [8]:
# img = Image.open(os.path.join(data_path,'92104.jpg'))
# img

In [9]:
image_size = 64
transform = T.Compose([
    T.Resize([image_size, image_size]),
    T.ToTensor(),
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# train_dataset = ShapeDataset( data_path='/cmlscratch/margot98/Causal_Disentangle/3dshape_data', 
#                               attr_path='/cmlscratch/margot98/Causal_Disentangle/3dshape_data/sample_source.csv',
#                               attr=['floor_hue', 'shape'],
#                               transform=transform
#                              )
train_dataset = train_dataset = CelebaDataset(data_path,attr_path,
                              attr=['Eyeglasses', 'Smiling'],
                              transform=transform)
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=128,
                          shuffle=True) 

In [10]:
def kaiming_init(m):
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        init.kaiming_normal_(m.weight)
        if m.bias is not None:
            m.bias.data.fill_(0)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        m.weight.data.fill_(1)
        if m.bias is not None:
            m.bias.data.fill_(0)
            
class View(nn.Module):
    def __init__(self, size):
        super(View, self).__init__()
        self.size = size

    def forward(self, tensor):
        return tensor.view(self.size)

In [11]:
class doVAE(nn.Module):

    def __init__(self,z_dim, c_dim = 2, in_channels=3):
        super(doVAE, self).__init__()
        
        self.z_dim = z_dim
        self.c_dim = c_dim
        # encoder output dim: c_Num*z_dim for mu_z, c_Num*z_dim for sigma_z, c_Num for mu_pi, c_Num for sigma_mu
        self.model_encoder= nn.Sequential(
            # B,  32, 32, 32
            nn.Conv2d(in_channels, out_channels=32, kernel_size=4, stride=2, padding=1),          
            nn.BatchNorm2d(32),
            nn.ReLU(),
            # B,  32, 16, 16
            nn.Conv2d(32, out_channels=32, kernel_size=4, stride=2, padding=1),  
            nn.BatchNorm2d(32),
            nn.ReLU(),
            # B,  64,  8,  8
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),  
            nn.BatchNorm2d(64),
            nn.ReLU(),
            # B,  64,  4,  4
            nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1),          
            nn.BatchNorm2d(64),
            nn.ReLU(),
            # B, 256,  1,  1
            nn.Conv2d(64, 256, kernel_size=4, stride=1),           
            nn.BatchNorm2d(256),
            nn.ReLU(),
            View((-1, 256*1*1)),                 # B, 256
            
            # for mu and logvar
            nn.Linear(256,  c_dim*(2*z_dim+2)),      
        )
        
            
        self.model_decoder = nn.Sequential(
            nn.Linear(z_dim, 256),               # B, 256
            View((-1, 256, 1, 1)),               # B, 256,  1,  1
            nn.ReLU(),
            nn.ConvTranspose2d(256, 64, 4),      # B,  64,  4,  4
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 64, 4, 2, 1), # B,  64,  8,  8
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, 2, 1), # B,  32, 16, 16
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 32, 4, 2, 1), # B,  32, 32, 32
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, in_channels, 4, 2, 1),  # B, 3, 64, 64
        )
        

        self.weight_init()
        
    def weight_init(self):
        for block in self._modules:
            for m in self._modules[block]:
                kaiming_init(m)
        
                
    def encode(self, X):
        
        result = self.model_encoder(X)
        X_len = X.shape[0]
        mu = torch.zeros(X_len,self.c_dim * self.z_dim)
        logvar = torch.zeros(X_len,self.c_dim * self.z_dim)
        mu_pi = torch.zeros(X_len,self.c_dim)
        var_pi = torch.zeros(X_len,self.c_dim)
        for c in range(self.c_dim):
            mu[:,c*self.z_dim:(c+1)*self.z_dim] = result[:,c*self.z_dim : (c+1)*self.z_dim]
            logvar[:,c*self.z_dim:(c+1)*self.z_dim] = result[:,(self.c_dim+c)*self.z_dim : (self.c_dim+c+1)*self.z_dim]
            mu_pi[:,c] = result[:,2*self.c_dim*self.z_dim + c]
            var_pi[:,c] = result[:,2*self.c_dim*self.z_dim + self.c_dim + c]
                                
        
        return mu, logvar, mu_pi, var_pi
                                 
    
    def decode(self, z):
        
        return self.model_decoder(z)
    
    
    def generate(self, num_samples, current_device):                         
        z = torch.randn(num_samples, self.z_dim)
        z = z.to(current_device)
        samples = self.decode(z)
        return samples
    

    def reparameterize(self, mu, logvar,mu_pi,var_pi):
        X_len = mu.shape[0]
        z_all = torch.zeros(X_len, self.z_dim)
        Pi = torch.zeros(X_len, self.c_dim)
        for c in range(self.c_dim):
            std = torch.exp(0.5 * logvar[:,c*self.z_dim:(c+1)*self.z_dim])
            eps = torch.randn_like(std)
            curr_z = eps * std + mu[:,c*self.z_dim:(c+1)*self.z_dim]
            
            std_pi = torch.exp(0.5 * var_pi[:,c])
            eps_pi = torch.randn_like(std_pi)
            curr_pi = eps_pi * std_pi + mu_pi[:,c]
            
            z_all += curr_z * curr_pi[:,None]
            Pi[:,c] = curr_pi
        
        return z_all, Pi
    
    

    def forward(self, X):
        mu, logvar, mu_pi, var_pi = self.encode(X)
        z,Pi = self.reparameterize(mu, logvar, mu_pi, var_pi)
        reconstructed_z = self.decode(z)
        return reconstructed_z, mu, logvar, Pi
    
    def calculate_loss(self, X, reconstructed_z, mu, logvar, Pi, label, pi_weight = 1.0):
        """
        Loss term currently contains: reconstruction loss, classification loss for p(c|x)
        """
        # classification loss:
        criterion = nn.CrossEntropyLoss()
        pi_loss = criterion(Pi, label.float())
        
        # reconstruction loss:
        rec_loss = F.mse_loss(reconstructed_z, X)
        
        loss = rec_loss + pi_weight * pi_loss
        
        return loss, rec_loss, pi_loss
        
    
#     def calculate_loss(self, x1, x2):
#         """
#         Here we compute the distance d(E[Z_i|do(Z_{-i}), X], E[Z_i|X])
#         E[Z_i | X] = \sum_{i}f(x_i) /N = \sum_{c} P(C = c)E[Z_i|X, C = c]
#         E[Z_i | do(Z_{-i}), X] = \sum_{c}P(C = c)E[Z_i|Z_{-i},X, C = c]
        
#         Computation:
#         sum_c P(C = c)\sum_i d(E[Z_i|do(Z_{-i}), X, C = c], E[Z_i|X, C = c])
#         """
#         prob_c = []
#         prob_c.append(len(x1)/(len(x1)+len(x2)))
#         prob_c.append(len(x2)/(len(x1)+len(x2)))
        
#         output_list, z_sample_list, z_mean_list, z_logvar_list = self.forward(x1, x2)
        
#         Dist = 0
#         for j in range(self.c_Num):
#             mu_j = z_mean_list[j]
#             logvar = z_logvar_list[j]
#             z_sample =  z_sample_list[j]
#             dist_j = 0
#             diff = mu_j - z_sample
#             length = len(mu_j)
#             for i in range(self.z_dim):
#                 E = mu_j[:,i]
#                 logsig_i_minus_i = torch.cat((logvar[:,i,:i].T, logvar[:,i,i+1:].T)).T
#                 logsig_minus_minus = torch.cat((logvar[:,:i,:i].reshape(length,i*i), 
#                                                 logvar[:,:i,i+1:].reshape(length,i*(self.z_dim-i-1)),
#                                                 logvar[:,i+1:,:i].reshape(length,i*(self.z_dim-i-1)),
#                                                 logvar[:,i+1:,i+1:].reshape(length,(self.z_dim-i-1)*(self.z_dim-i-1))), 
#                                                 axis = 1).reshape(-1,self.z_dim-1, self.z_dim-1)
                
#                 diff_minus = torch.cat((diff[:,:i].T, diff[:,i+1:].T)).T
#                 do_E = mu_j[:,i] + torch.sum(torch.matmul(logsig_i_minus_i.reshape(-1,1,self.z_dim - 1), 
#                                                           torch.linalg.inv(logsig_minus_minus)).squeeze()*diff_minus, axis = 1)
#                 # absolute value for distance measure
#                 dist_j += torch.abs(do_E - E)
                
#             Dist += torch.mean(dist_j)*prob_c[j]
        
#         ## MSE reconstruction loss
#         reconstruction_error = F.mse_loss(output_list[0], x1) + F.mse_loss(output_list[1], x2)
            
#         loss = reconstruction_error + Dist
            
#         return loss, reconstruction_error, Dist
    
    


In [12]:
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [13]:
CD_model = doVAE(z_dim=4)

In [14]:
train_samples = len(train_loader) * 128
learning_rate = 0.01*float(128)/256.
optimizer = torch.optim.SGD(CD_model.parameters(), lr=1e-3, 
                            momentum=0.9, weight_decay=5e-4)

In [15]:
for x, target in train_loader:
    # randomly generate target set, delete afterward
    summ = torch.sum(target, dim = 1)
    zero_index = (summ == 0).nonzero().squeeze().detach()
    target[zero_index,:] = torch.tensor([1,0])
    two_index = (summ == 2).nonzero().squeeze().detach()
    target[two_index,:] = torch.tensor([0,1])
    break
    

In [17]:
for epoch in tqdm(range(15)):
    CD_model.train()
    i = 0
    for x, target in train_loader:
        # randomly generate target set, delete afterward
        summ = torch.sum(target, dim = 1)
        zero_index = (summ == 0).nonzero().squeeze().detach()
        target[zero_index,:] = torch.tensor([1,0])
        two_index = (summ == 2).nonzero().squeeze().detach()
        target[two_index,:] = torch.tensor([0,1])
        #####
        
        i += 1
        batch_time = AverageMeter()
        losses = AverageMeter()
        end = time.time()
        total_start = time.time()
        
        reconstructed_z, mu, logvar, Pi = CD_model(x)
        
        loss, rec_loss, pi_loss = CD_model.calculate_loss(x, reconstructed_z, mu, logvar, Pi, target)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if i%15 == 0:
        
            reduced_loss = loss.data
            losses.update(float(reduced_loss), x.size(0))
            torch.cuda.synchronize()
            batch_time.update((time.time() - end)/15)
            end = time.time()

            print('Epoch: [{0}][{1}/{2}]\t'
                          'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                          'Loss {loss.val:.5f} ({loss.avg:.4f})\t'.format(
                           epoch, i, len(train_loader),
                           batch_time=batch_time,
                           loss=losses))
        

  0%|                                                    | 0/15 [00:00<?, ?it/s]

Epoch: [0][15/1583]	Time 0.107 (0.107)	Loss 1.19318 (1.1932)	
Epoch: [0][30/1583]	Time 0.025 (0.025)	Loss 1.03093 (1.0309)	
Epoch: [0][45/1583]	Time 0.025 (0.025)	Loss 0.89789 (0.8979)	
Epoch: [0][60/1583]	Time 0.027 (0.027)	Loss 1.00473 (1.0047)	
Epoch: [0][75/1583]	Time 0.025 (0.025)	Loss 0.83017 (0.8302)	
Epoch: [0][90/1583]	Time 0.025 (0.025)	Loss 0.81048 (0.8105)	
Epoch: [0][105/1583]	Time 0.027 (0.027)	Loss 0.78252 (0.7825)	
Epoch: [0][120/1583]	Time 0.024 (0.024)	Loss 0.69036 (0.6904)	
Epoch: [0][135/1583]	Time 0.024 (0.024)	Loss 0.72997 (0.7300)	
Epoch: [0][150/1583]	Time 0.026 (0.026)	Loss 0.83829 (0.8383)	
Epoch: [0][165/1583]	Time 0.024 (0.024)	Loss 0.74022 (0.7402)	
Epoch: [0][180/1583]	Time 0.024 (0.024)	Loss 0.72603 (0.7260)	
Epoch: [0][195/1583]	Time 0.027 (0.027)	Loss 0.65784 (0.6578)	
Epoch: [0][210/1583]	Time 0.029 (0.029)	Loss 0.58063 (0.5806)	
Epoch: [0][225/1583]	Time 0.030 (0.030)	Loss 0.73998 (0.7400)	
Epoch: [0][240/1583]	Time 0.027 (0.027)	Loss 0.62465 (0.6246)

  0%|                                                    | 0/15 [15:57<?, ?it/s]


KeyboardInterrupt: 