In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image
from tqdm import tqdm
import torchvision.datasets as datasets
from torch.utils.data import DataLoader

The encoder and decoder are mirrored networks consisting of two layers. In the encoder, we take the input data to a hidden dimension through a linear layer, and then we pass the hidden state to two different linear layers that output the mean and standard deviation of the latent distribution, respectively.

We then sample from the latent distribution and input it to the decoder that in turn outputs a vector of the same shape as the input.

### Creating the latent space using an encoder

In [None]:
# Debugging helper: an identity layer that reports the shape of its input
class PrintSize(nn.Module):
  """Pass-through module that prints the shape of whatever flows through it.

  Intended to be dropped between layers while debugging; it holds no
  parameters and does not modify the tensor.
  """

  def __init__(self):
    super().__init__()

  def forward(self, x):
    """Print ``x.shape`` and hand ``x`` back unchanged."""
    print(x.shape)
    return x

In [None]:
# Configuration
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Passed as `input_dim` when the model is constructed
# (presumably the side length of the input images in pixels — confirm).
INPUT_DIM = 224
# Output width of the latent linear layer (`z_dim`).
Z_DIM = 20
# Output width of the hidden linear layer (`h_dim`); used as that argument's default.
H_DIM = 20
# Number of passes over the training dataloader.
NUM_EPOCHS = 10
# NOTE(review): only the commented-out MNIST DataLoader refers to BATCH_SIZE;
# the GalaxyDataModule calls in this notebook do not pass it — confirm the
# batch size actually used.
BATCH_SIZE = 132
# Learning rate handed to the Adam optimizer.
LR_RATE = 1e-3

In [None]:
# Scratch check: 218 is the model's `vector_length` for the default settings;
# the second pooling layer uses int() of this quotient as its output side.
218/4

In [None]:
# NOTE(review): scratch arithmetic; where 1458 comes from is not evident in
# this notebook — confirm its purpose or remove the cell.
1458/H_DIM

In [None]:
class VariationalAutoEncoder(nn.Module):
    """Convolutional autoencoder for single-channel, square images.

    Despite the class name, the model is deterministic: ``encode`` returns a
    single latent tensor, and no mean/standard-deviation pair is produced or
    sampled from.

    Encoder: two unpadded convolutions, two adaptive average pools that shrink
    the feature map to a quarter of its side, then two linear layers applied
    along the last (width) axis.
    Decoder: two linear layers back to the pooled width, then three transposed
    convolutions that upsample the map again.

    Args:
        input_dim: side length, in pixels, of the square input images.
        z_dim: output features of the latent linear layer.
        h_dim: output features of the hidden linear layer.
        kernel_size: integer kernel size of the two encoder convolutions.
        hidden_channels: channel count between the two encoder convolutions
            and after the first transposed convolution.

    Raises:
        ValueError: if ``input_dim`` is so small that the pooled feature map
            would have no pixels.

    Note:
        The decoder's kernel sizes and strides are fixed, so the
        reconstruction has side ``4 * (vector_length // 4) + 8``. With the
        default ``kernel_size=4`` this equals ``input_dim`` whenever
        ``input_dim`` is a multiple of 4 (e.g. 224 -> 224); other combinations
        yield a reconstruction whose size differs from the input.
    """

    def __init__(self, input_dim=INPUT_DIM, z_dim=Z_DIM, h_dim=H_DIM, kernel_size=4, hidden_channels=16):
        super().__init__()

        # Each unpadded, stride-1 convolution shortens the side by
        # (kernel_size - 1), so this is the feature-map side after conv1 and
        # conv2: 218 for the default 224-pixel input with kernel_size=4.
        self.vector_length = int(input_dim) - 2 * (kernel_size - 1)
        half_side = self.vector_length // 2
        quarter_side = self.vector_length // 4
        if quarter_side < 1:
            raise ValueError(
                f'input_dim={input_dim} is too small for kernel_size={kernel_size}: '
                'the pooled feature map would be empty'
            )

        # encoder
        self.conv1 = nn.Conv2d(1, hidden_channels, kernel_size)
        self.conv2 = nn.Conv2d(hidden_channels, 1, kernel_size)
        self.pool1 = nn.AdaptiveAvgPool2d((half_side, half_side))
        self.pool2 = nn.AdaptiveAvgPool2d((quarter_side, quarter_side))
        # The linear layers act on the last axis only, i.e. on each row of the
        # pooled feature map independently.
        self.img_2hid = nn.Linear(quarter_side, h_dim)
        self.hid_2z = nn.Linear(h_dim, z_dim)

        # decoder
        self.z_2hid = nn.Linear(z_dim, h_dim)
        self.hid_2img = nn.Linear(h_dim, quarter_side)
        self.convt1 = nn.ConvTranspose2d(1, hidden_channels, 2)
        self.convt2 = nn.ConvTranspose2d(hidden_channels, 4, 3, stride=2)
        self.convt3 = nn.ConvTranspose2d(4, 1, 4, stride=2)

    def encode(self, x):
        """Map a batch of images to the latent representation.

        Args:
            x: tensor whose channel axis has size 1 (``conv1`` expects one
                input channel), laid out as (batch, 1, height, width).

        Returns:
            Tensor of shape (batch, 1, vector_length // 4, z_dim).
        """
        z = F.relu(self.conv1(x))
        z = F.relu(self.conv2(z))
        z = self.pool1(z)
        z = self.pool2(z)
        # No non-linearity between the two linear layers.
        z = self.img_2hid(z)
        z = self.hid_2z(z)
        return z

    def decode(self, z):
        """Map a latent tensor produced by ``encode`` back to image space.

        Args:
            z: tensor whose last axis has size ``z_dim``.

        Returns:
            Single-channel tensor; the final ReLU keeps every value >= 0.
        """
        x = self.z_2hid(z)
        x = self.hid_2img(x)
        x = F.relu(self.convt1(x))
        x = F.relu(self.convt2(x))
        x = F.relu(self.convt3(x))
        return x

    def forward(self, x):
        """Encode ``x`` and return its reconstruction."""
        latent = self.encode(x)
        x_reconst = self.decode(latent)
        return x_reconst

In [None]:
# Dataset loading
# dataset = datasets.MNIST(root="dataset/", train=True, transform=transforms.ToTensor(), download=True)
# train_loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)
from galaxy_datasets import gz2  # or gz_hubble, gz_candels, ...

# Directory handed to the gz2 loader as its data root.
# NOTE(review): absolute, user-specific path — will not exist on other machines.
GZ2_ROOT = '/Users/padmavenkatraman/Documents/SSI/SSI_Projects/gz2/'

# Returns the catalog plus the names of its label columns; download=True asks
# the loader to fetch the data (presumably skipped when already present — confirm).
catalog, label_cols = gz2(
    root=GZ2_ROOT,
    train=True,
    download=True,
)

In [None]:
# Inspect the loaded catalog.
catalog

In [None]:
# Integer code for each class name found in the `summary` column.
# NOTE(review): the None -> 8 entry looks intended to catch rows without a
# summary — confirm pandas matches missing values against a None key.
SUMMARY_CODES = {
    'smooth_round': 1,
    'smooth_inbetween': 2,
    'smooth_cigar': 3,
    'featured_without_bar_or_spiral': 4,
    'edge_on_disk': 5,
    'barred_spiral': 6,
    'unbarred_spiral': 7,
    None: 8,
}
# Adds the column in place, so re-running this cell overwrites it.
catalog['summary_val'] = catalog['summary'].map(SUMMARY_CODES)

In [None]:
# Inspect the catalog again, now that the `summary_val` column has been added.
catalog

In [None]:
# from galaxy_datasets.pytorch.galaxy_datamodule import GalaxyDataModule

# datamodule = GalaxyDataModule(
#     catalog=catalog.sample(20000),
#     label_cols =[],
#     # optional args to specify augmentations
# )

# datamodule.prepare_data()
# datamodule.setup()
# '''
# for images, labels in datamodule.train_dataloader():
#     print(images.shape, labels.shape)
#     break'''


from galaxy_datasets.pytorch.galaxy_datamodule import GalaxyDataModule

# Training subset: 20,000 distinct rows drawn at random from the rows whose
# `label` column equals 1.
# NOTE(review): the draw is unseeded, so every fresh run trains on a different
# subset — consider passing random_state to .sample().
label1_rows = catalog[catalog['label'] == 1]
datamodule = GalaxyDataModule(
    catalog=label1_rows.sample(20000, replace=False),
    label_cols=[],
)
datamodule.prepare_data()
datamodule.setup()

# Dataloaders consumed by train().
train_data = datamodule.train_dataloader()
val_data = datamodule.val_dataloader()

In [None]:
def show_image(img_arr):
  """Print the dimensions of a channel-first image tensor and draw it.

  ``img_arr`` is indexed as (channels, height, width). It is detached, moved
  to the CPU and converted to NumPy, then transposed to (height, width,
  channels) before being passed to ``plt.imshow``.
  """
  chw = img_arr.detach().cpu().numpy()
  dims = chw.shape
  print(f'image input dimensions = {dims}')
  for caption, axis in (('number of channels', 0), ('height', 1), ('width', 2)):
    print(f'{caption} = {dims[axis]}')

  # Move the channel axis to the end for plotting.
  hwc = chw.transpose(1, 2, 0)
  print(f'image plotting dimensions = {hwc.shape}')
  plt.imshow(hwc)

In [None]:
def plot_real_reconst(old, reconst):
    """Show an original image and its reconstruction side by side.

    Both arguments are channel-first tensors. Each is detached, moved to the
    CPU and transposed to channel-last before being drawn; ``old`` goes in the
    left panel and ``reconst`` in the right, both with axes hidden.
    """
    fig, ax = plt.subplots(1, 2)

    # Convert both tensors before drawing anything.
    images = [t.detach().cpu().numpy().transpose(1, 2, 0) for t in (old, reconst)]

    for panel, image in zip(ax, images):
        panel.imshow(image)
    for panel in ax:
        panel.axis('off')
    plt.show()
    

In [None]:
def early_stopping(train_loss, validation_loss, min_delta_frac):
    """Decide whether validation loss has drifted too far above training loss.

    The gap is measured as a fraction of the training loss:
    ``(validation_loss - train_loss) / train_loss``.

    Args:
        train_loss: current training loss.
        validation_loss: current validation loss.
        min_delta_frac: largest fractional gap that is still tolerated.

    Returns:
        True (after printing the gap) when the fractional gap exceeds
        ``min_delta_frac``, otherwise False.
    """
    if train_loss == 0:
        # The fraction is undefined here; treat any positive validation loss
        # as an unbounded gap instead of raising ZeroDivisionError.
        val_train_diff = float('inf') if validation_loss > 0 else 0.0
    else:
        val_train_diff = (validation_loss - train_loss) / train_loss

    if val_train_diff > min_delta_frac:
        print('Fractional Difference', val_train_diff)
        return True
    # Explicit False so callers always receive a bool rather than None.
    return False

# Define train function
def _validation_losses(model, loss_fn, val_dataloader):
    """Return the per-batch loss over ``val_dataloader`` with gradients off.

    The model is switched to eval mode for the pass and restored to whatever
    mode it was in afterwards.
    """
    was_training = model.training
    model.eval()
    losses = []
    try:
        with torch.no_grad():
            for x_val, _ in val_dataloader:
                x_val_dev = x_val.to(device)
                val_loss = loss_fn(model(x_val_dev), x_val_dev)
                losses.append(val_loss.cpu().detach().numpy())
    finally:
        model.train(was_training)
    return losses


def train(num_epochs, model, optimizer, loss_fn, training_dataloader, val_dataloader,
          checkpoint_dir='/Users/padmavenkatraman/Documents/SSI/SSI_Projects/model_save_files',
          run_tag='6'):
    """Train ``model`` to reconstruct its input and checkpoint every epoch.

    Each batch is pushed through the model and ``loss_fn`` compares the
    reconstruction with the input. On every 10th batch (starting with batch 0)
    one original/reconstruction pair is plotted and a full pass over
    ``val_dataloader`` is run. At the end of each epoch a checkpoint is
    written to ``{checkpoint_dir}/{epoch}_{run_tag}.pth``.

    Args:
        num_epochs: number of passes over ``training_dataloader``.
        model: module mapping a batch to a reconstruction of the same shape.
        optimizer: optimizer already bound to ``model``'s parameters.
        loss_fn: callable ``loss_fn(reconstruction, target)`` returning a scalar tensor.
        training_dataloader: iterable of ``(batch, _)`` pairs used for training.
        val_dataloader: iterable of ``(batch, _)`` pairs used for validation.
        checkpoint_dir: directory that receives the per-epoch checkpoints.
        run_tag: suffix that distinguishes this run's checkpoint files.

    Returns:
        ``(all_train_loss, all_val_loss)``: every per-batch training loss and
        every per-batch validation loss recorded, as 0-d NumPy arrays in the
        order they were computed.
    """
    all_train_loss, all_val_loss = [], []
    for epoch in range(num_epochs):
        train_loss_per_epoch, val_loss_per_epoch = [], []
        print(f'epoch number = {epoch}')
        for i, (x, _) in enumerate(training_dataloader):
            # Move the batch to the device once and reuse it as the target.
            x_dev = x.to(device)

            # Forward pass
            x_reconst = model(x_dev)
            loss = loss_fn(x_reconst, x_dev)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_numpy = loss.cpu().detach().numpy()
            train_loss_per_epoch.append(loss_numpy)
            all_train_loss.append(loss_numpy)

            if i % 10 == 0:
                print('JUST TRAINED ANOTHER 10......')
                plot_real_reconst(x[0], x_reconst[0])
                # Full validation pass; this dominates the cost of the step.
                batch_val_losses = _validation_losses(model, loss_fn, val_dataloader)
                val_loss_per_epoch.extend(batch_val_losses)
                all_val_loss.extend(batch_val_losses)

            if len(train_loss_per_epoch) % 20 == 0:
                print(f'Current Loss {train_loss_per_epoch[-1]}')

        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': np.mean(train_loss_per_epoch),
                # Key kept as 'test_loss' for existing readers; the value is
                # the mean *validation* loss of this epoch.
                'test_loss': np.mean(val_loss_per_epoch),
            },
            f'{checkpoint_dir}/{epoch}_{run_tag}.pth',
        )
    return all_train_loss, all_val_loss


In [None]:
# Build the model on the selected device; h_dim, kernel_size and
# hidden_channels keep their default values.
model = VariationalAutoEncoder(INPUT_DIM, Z_DIM).to(device)
# Adam over all model parameters with the configured learning rate.
optimizer = torch.optim.Adam(model.parameters(), lr=LR_RATE)


In [None]:
# Initialize the loss: mean squared error averaged over all elements
loss_fn = nn.MSELoss(reduction="mean")
# Run training; returns the per-batch training and validation loss histories
tr_loss, val_loss = train(NUM_EPOCHS, model, optimizer, loss_fn, train_data,val_data)

In [None]:
# Restore a saved checkpoint into the current model and optimizer.
# map_location=device remaps the stored tensors onto the device chosen in the
# configuration cell, so a checkpoint written on a GPU also loads on a
# CPU-only machine.
# NOTE(review): this reads '9_5_2.pth', while the train() call above writes
# files named '<epoch>_6.pth' — confirm this is the intended run.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
trained = torch.load(
    '/Users/padmavenkatraman/Documents/SSI/SSI_Projects/model_save_files/9_5_2.pth',
    map_location=device,
)
model.load_state_dict(trained['model_state_dict'])
optimizer.load_state_dict(trained['optimizer_state_dict'])
# Epoch index and mean training loss recorded when the checkpoint was written.
epoch = trained['epoch']
loss = trained['train_loss']

In [None]:
# Test split of the datamodule built from the label == 1 rows.
test_data = datamodule.test_dataloader()


In [None]:
# Comparison set: 10,000 distinct rows drawn at random from the rows whose
# `summary_val` is 3 (the code assigned to 'smooth_cigar').
cigar_rows = catalog[catalog['summary_val'] == 3]
datamodule2 = GalaxyDataModule(
    catalog=cigar_rows.sample(10000, replace=False),
    label_cols=[],
)
datamodule2.prepare_data()
datamodule2.setup()
test_data2 = datamodule2.test_dataloader()



In [None]:
# Comparison set: 1,000 distinct rows drawn at random from the rows whose
# `label` is 3.
# NOTE(review): this filters on `label`, whereas the cigar set filters on
# `summary_val`. The histogram below calls this set 'edge on', and
# 'edge_on_disk' is 5 in `summary_val`; how `label` is encoded is not visible
# in this notebook — confirm that label == 3 selects the intended class.
label3_rows = catalog[catalog['label'] == 3]
datamodule3 = GalaxyDataModule(
    catalog=label3_rows.sample(1000, replace=False),
    label_cols=[],
)
datamodule3.prepare_data()
datamodule3.setup()
test_data3 = datamodule3.test_dataloader()

In [None]:
def test_loop(testdata, num_loops):
    test_loss_all = []
    with torch.no_grad():
        while len(test_loss_all) < num_loops:
            for j, (xt,_) in enumerate(testdata):
                
                # validation
                xrt = model(xt.to(device))
                reconst_loss_val = loss_fn(xrt, xt.to(device))
                # kl_div_val = - torch.mean(1 + torch.log(sigma_val.pow(2)) - torch.linalg.vector_norm(mu_val) - sigma_val.pow(2))
        #                reconst_loss_val = torch.sum((x_reconst - x_val).pow(2))
        #                kl_div_val = -torch.mean(1 + torch.log(torch.linalg.vector_norm(sigma_val)) - torch.linalg.vector_norm(mu_val) - torch.linalg.vector_norm(sigma_val))
                test_loss = reconst_loss_val# + kl_div_val
                test_loss_numpy = test_loss.cpu().detach().numpy()
                test_loss_all.append(test_loss_numpy)
                if len(test_loss_all)%10 == 0:
                    plot_real_reconst(xt[0], xrt[0])
    return test_loss_all

In [None]:
# Per-batch reconstruction losses for the label == 3 test set.
test_edge = test_loop(test_data3, 256)


In [None]:
# Per-batch reconstruction losses for the label == 1 test set ...
test_smooth = test_loop(test_data, 256)
# ... and for the summary_val == 3 test set.
test_cigar = test_loop(test_data2, 256)


In [None]:
# Number of loss values collected for each of the three test sets.
len(test_smooth), len(test_cigar), len(test_edge)

In [None]:
# Overlay the density-normalised loss histograms of the three test sets on
# one shared set of bin edges.
loss_bins = np.linspace(20, 35, 50)
# NOTE(review): only the first 88 smooth-round losses are plotted; the reason
# for that cut-off is not evident from this notebook.
hist_specs = [
    (test_smooth[:88], 'r', 'smooth rounded'),
    (test_cigar, 'b', 'smooth cigar'),
    (test_edge, 'k', 'edge on'),
]
for losses, colour, name in hist_specs:
    plt.hist(losses, bins=loss_bins, density=True, edgecolor='white', alpha=0.5, color=colour, label=name)

plt.xlabel("MSE Loss")
plt.title('reconstruction loss')
plt.legend();

In [None]:
def read_in_epoch_loss(losstype,
                       checkpoint_dir='/Users/padmavenkatraman/Documents/SSI/SSI_Projects/model_save_files',
                       run_tags=('5', '5_2'),
                       num_epochs=10):
    """Read one stored value per epoch from a sequence of checkpoint files.

    Files are expected at ``{checkpoint_dir}/{epoch}_{tag}.pth``. All epochs
    of the first tag are read first, then all epochs of the next tag, so
    consecutive runs are concatenated in order.

    Args:
        losstype: key to read from each checkpoint dict
            (e.g. ``'train_loss'`` or ``'test_loss'``).
        checkpoint_dir: directory holding the checkpoint files.
        run_tags: file-name suffixes of the runs to concatenate, in order.
        num_epochs: number of epoch files (0 .. num_epochs - 1) per run.

    Returns:
        List of the ``losstype`` values, ``len(run_tags) * num_epochs`` long.
    """
    loss_list = []
    for tag in run_tags:
        for epoch in range(num_epochs):
            # map_location='cpu' lets checkpoints written on a GPU be read on
            # a CPU-only machine; only a scalar is taken from each file.
            checkpoint = torch.load(f'{checkpoint_dir}/{epoch}_{tag}.pth', map_location='cpu')
            loss_list.append(checkpoint[losstype])
    return loss_list

In [None]:
# Per-epoch losses read back from the saved checkpoints; the 'test_loss' key
# is plotted as the validation curve.
train_loss_list = read_in_epoch_loss('train_loss')
val_loss_list = read_in_epoch_loss('test_loss')

train_epochs = np.arange(len(train_loss_list))
val_epochs = np.arange(len(val_loss_list))
plt.plot(train_epochs, train_loss_list, label='train loss', color='k')
plt.plot(val_epochs, val_loss_list, label='validation loss', color='r', alpha=0.7)

# 20 evenly spaced ticks from the lowest validation loss to the highest
# training loss.
tick_low, tick_high = min(val_loss_list), max(train_loss_list)
plt.yticks(np.linspace(tick_low, tick_high, 20))

# Best values of each curve, rounded to two decimals.
min_tr_loss = np.round(min(train_loss_list), 2)
min_val_loss = np.round(min(val_loss_list), 2)

plt.title('Training vs Validation loss for a convolutional autoencoder')
plt.legend()
plt.ylabel('MSE Loss')
plt.xlabel('Number of Epochs')
plt.savefig('loss.png', dpi=200)

#plt.title(f'minimum train loss = {min_tr_loss}; minimum val loss = {min_val_loss}');
