In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import trange

import sys
sys.path.insert(0, '../src')

from megs.model.mPCA import mPCA
from megs.data import image, DataLoader

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pickle

# Load mm object from the file
# NOTE(review): pickle.load can execute arbitrary code embedded in the file, so
# "morphmodel.pkl" must come from a trusted source.
# Unpickling presumably needs the class of `mm` to be importable (hence the
# sys.path tweak above) — TODO confirm.
filename = "morphmodel.pkl"
with open(filename, "rb") as file:
    mm = pickle.load(file)

In [None]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# View every row of the data matrix as a stack of three 64x64 maps. The number
# of samples is read from the matrix itself rather than being hard-coded, so the
# cell keeps working when the pickled model holds a different number of rows.
datamatrix = mm.datamatrix.reshape(mm.datamatrix.shape[0], 3, 64, 64)

# Use only the first map to learn the model -> shape (n_samples, 64, 64)
data = datamatrix[:, 0]

targets = mm.scores
# targets = datamatrix

# Split the data into train and test sets (80/20, fixed seed for repeatability)
train_data, test_data, train_targets, test_targets = train_test_split(
    data, targets, test_size=0.2, random_state=42
)

# Convert every split to a float32 tensor living on `device`.
train_data = torch.as_tensor(train_data, dtype=torch.float32, device=device)
train_targets = torch.as_tensor(train_targets, dtype=torch.float32, device=device)

test_data = torch.as_tensor(test_data, dtype=torch.float32, device=device)
test_targets = torch.as_tensor(test_targets, dtype=torch.float32, device=device)

# Pair inputs with targets so a loader can yield (data, target) batches
train_dataset = TensorDataset(train_data, train_targets)
test_dataset = TensorDataset(test_data, test_targets)

# Define the batch size for training
batch_size = 128

# Create the loaders. The torch class is referenced by its full path because the
# bare name `DataLoader` is rebound by `from megs.data import image, DataLoader`
# in the import cell, which would otherwise be the class called here.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributions as distributions
from torch.distributions import kl_divergence
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms

class PlanarFlow(nn.Module):
    """Single planar normalizing-flow layer: ``f(z) = z + u_hat * tanh(z @ w^T + b)``.

    Parameters ``u`` and ``w`` have shape ``(1, dim)`` and ``b`` has shape ``(1,)``.
    ``u`` is never used directly; ``get_u_hat`` re-parameterizes it so that
    ``w . u_hat >= -1``, which is the condition for the tanh planar map to be
    invertible.

    Args:
        dim: Size of the last dimension of the tensors passed to ``forward``.
    """

    def __init__(self, dim):
        super(PlanarFlow, self).__init__()
        # u and w start as small Gaussian noise (std 0.01); b starts at zero.
        self.u = nn.Parameter(torch.randn(1, dim).normal_(0, 0.01))
        self.w = nn.Parameter(torch.randn(1, dim).normal_(0, 0.01))
        self.b = nn.Parameter(torch.randn(1).zero_())

    def forward(self, z):
        """Apply the flow to ``z``.

        Args:
            z: Tensor whose last dimension equals ``dim``; for a ``(batch, dim)``
                input the outputs have the shapes given below.

        Returns:
            Tuple ``(f_z, log_det)``: the transformed sample ``(batch, dim)`` and
            ``log|det df/dz|`` with shape ``(batch, 1)``.
        """
        u_hat = self.get_u_hat()
        psi = self.get_psi(z)
        f_z = z + (u_hat * torch.tanh(z @ self.w.t() + self.b))

        # det(df/dz) = 1 + psi(z) . u_hat for a planar flow.
        log_det = torch.log(torch.abs(1 + psi @ u_hat.t()))

        return f_z, log_det

    def get_psi(self, z):
        """Return ``psi(z) = tanh'(z @ w^T + b) * w`` (same shape as ``z`` for 2-D input)."""
        return (1 - torch.tanh(z @ self.w.t() + self.b) ** 2) @ self.w

    def get_u_hat(self):
        """Return ``u`` adjusted so that ``w . u_hat = -1 + softplus(w . u)``.

        The correction is applied along ``w / ||w||^2``. Dividing by the squared
        norm (not the plain norm) is what makes the dot product with ``w`` land
        exactly on the target value for any length of ``w``.
        """
        uw = torch.dot(self.u[0], self.w[0])
        # softplus(x) == log(1 + exp(x)) but does not overflow for large x.
        m_uw = -1 + nn.functional.softplus(uw)
        w_norm_sq = torch.sum(self.w ** 2)
        return self.u + (m_uw - uw) * (self.w / w_norm_sq)

class PCALayer(nn.Module):
    """Layer that right-multiplies its input by a fixed ``eigengalaxies`` tensor.

    The tensor is stored as a parameter with gradients disabled, so it moves with
    the module (``.to(device)``) but is not updated by training.
    """

    def __init__(self, eigengalaxies):
        super().__init__()
        frozen_basis = nn.Parameter(eigengalaxies, requires_grad=False)
        self.eigengalaxies = frozen_basis

    def forward(self, x):
        """Insert a size-1 axis at position 2 of ``x`` and matmul with the basis.

        NOTE(review): for a 2-D ``x`` the inserted axis is the contraction axis,
        so the product only lines up when the second-to-last dimension of
        ``eigengalaxies`` is 1 — confirm this is the intended shape.
        """
        expanded = torch.unsqueeze(x, 2)
        return expanded @ self.eigengalaxies

class VAE(nn.Module):
    """Variational autoencoder whose posterior sample is passed through planar flows.

    Pipeline: an MLP encoder produces ``mu`` and ``log_var``; a sample is drawn
    with the reparameterization trick; ``flow_steps`` ``PlanarFlow`` layers
    transform it in sequence; the decoder
    (``PCALayer -> ReLU -> Flatten -> Linear(64*64*3, input_dim) -> Sigmoid``)
    maps the result back to input space.

    Args:
        input_dim: Width of the encoder input and of the decoder output.
        latent_dim: Width of ``mu`` / ``log_var`` and of every flow layer.
        eigengalaxies: Tensor handed to the decoder's ``PCALayer``; kept frozen.
        flow_steps: Number of ``PlanarFlow`` layers to stack.
    """

    def __init__(self, input_dim, latent_dim, eigengalaxies, flow_steps):
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        # Stored with gradients off; the decoder's PCALayer is built from this.
        self.eigengalaxies = nn.Parameter(eigengalaxies, requires_grad=False)
        flow_stack = [PlanarFlow(latent_dim) for _ in range(flow_steps)]
        self.flows = nn.ModuleList(flow_stack)

        hidden_width = 256
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, hidden_width),
            nn.ReLU(),
        )

        # Two heads on the shared encoder features.
        self.fc_mu = nn.Linear(hidden_width, latent_dim)
        self.fc_var = nn.Linear(hidden_width, latent_dim)

        self.decoder = nn.Sequential(
            PCALayer(self.eigengalaxies),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(64*64*3, input_dim),
            nn.Sigmoid(),
        )

    def encode(self, x):
        """Return ``(mu, log_var)`` computed from the encoder features of ``x``."""
        features = self.encoder(x)
        mean = self.fc_mu(features)
        log_variance = self.fc_var(features)
        return mean, log_variance

    def reparameterize(self, mu, log_var):
        """Draw ``mu + eps * exp(log_var / 2)`` with ``eps`` standard normal."""
        sigma = (log_var / 2).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Run ``z`` through the decoder stack."""
        return self.decoder(z)

    def forward(self, x):
        """Encode, sample, apply every flow in order, then decode.

        Returns:
            ``(reconstruction, mu, log_var, zK, list_ladj)`` where ``zK`` is the
            sample after the last flow and ``list_ladj`` holds one log-det term
            per flow. The pre-flow sample is not part of the return value.
        """
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        log_dets = []
        for step in self.flows:
            z, step_ladj = step(z)
            log_dets.append(step_ladj)
        reconstruction = self.decode(z)
        return reconstruction, mu, log_var, z, log_dets

def loss_function(recon_x, x, mu, log_var, z_0, z_k, list_ladj):
    """Negative evidence lower bound for a VAE with a normalizing-flow posterior.

    The value is ``BCE + KLD - sum(log|det J|)``. The log-det-Jacobian terms enter
    with a minus sign because the flow bound subtracts them from the energy;
    adding them would train the flows to shrink volume instead.

    Args:
        recon_x: Reconstruction with values in [0, 1], same shape as ``x``.
        x: Target with values in [0, 1].
        mu: Posterior mean from the encoder.
        log_var: Posterior log-variance from the encoder.
        z_0: Accepted for interface compatibility; not read.
        z_k: Accepted for interface compatibility; not read.
        list_ladj: List of per-flow log-det-Jacobian tensors; may be empty.

    Returns:
        Scalar tensor, summed (not averaged) over the batch.
    """
    BCE = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')
    # Closed-form KL between N(mu, exp(log_var)) and a standard normal.
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    if list_ladj:
        ladj = torch.sum(torch.cat(list_ladj))
    else:
        # No flow layers: torch.cat would reject an empty list.
        ladj = recon_x.new_zeros(())

    return BCE + KLD - ladj

# Pick the accelerator when torch can see one, otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Basis returned by the morphology model, as a float32 tensor on `device`.
eigendat = mm.get_eigengalaxies()
eigengalaxies = torch.from_numpy(eigendat).to(device=device, dtype=torch.float32)

# VAE with input_dim=64*64*3, a 60-dimensional latent space and 10 flow steps.
vae = VAE(
    input_dim=64*64*3,
    latent_dim=60,
    eigengalaxies=eigengalaxies,
    flow_steps=10,
)
vae = vae.to(device)

# Adam with its default hyper-parameters over every parameter the model exposes.
optimizer = optim.Adam(vae.parameters())

def train(epoch):
    """Run one optimisation pass over ``train_loader`` and print the mean loss.

    Relies on the module-level ``vae``, ``optimizer``, ``train_loader``,
    ``device`` and ``loss_function``. Targets yielded by the loader are ignored.

    NOTE(review): the loader built earlier in this notebook yields single-map
    batches, while ``vae`` is constructed with ``input_dim=64*64*3`` — confirm
    the batch shape matches what the encoder expects.

    Args:
        epoch: Epoch number, used only in the printed summary.
    """
    vae.train()
    train_loss = 0
    for batch, _ in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        # VAE.forward returns (reconstruction, mu, log_var, zK, list_ladj): the
        # fourth item is the post-flow sample, and the pre-flow sample is not
        # exposed. loss_function reads neither of its z_0 / z_k arguments, so
        # the same tensor is passed for both to satisfy the signature.
        recon_batch, mu, log_var, z_k, list_ladj = vae(batch)
        loss = loss_function(recon_batch, batch, mu, log_var, z_k, z_k, list_ladj)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
    print(f'====> Epoch: {epoch} Average loss: {train_loss / len(train_loader.dataset):.4f}')

def main():
    """Call ``train`` once per epoch for epochs 1 through 100, in order."""
    first_epoch, last_epoch = 1, 100
    epoch = first_epoch
    while epoch <= last_epoch:
        train(epoch)
        epoch += 1

# Entry-point guard: start training only when this code runs as the top-level
# module.
# NOTE(review): a notebook kernel normally executes cells with
# __name__ == "__main__", so running this cell presumably starts the full
# training loop immediately — confirm that is intended.
if __name__ == "__main__":
    main()
