In [None]:
# Libraries
from pathlib import Path
from matplotlib import pyplot as plt
from boutdata import collect
import math 
import torch
from torch import nn, Tensor
from torch.nn.functional import softplus
from torch.distributions import Distribution, Bernoulli
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision.transforms import ToTensor
from functools import reduce
from typing import *
import matplotlib
from IPython.display import Image, display, clear_output
import numpy as np
%matplotlib nbagg
%matplotlib inline
import seaborn as sns
import pandas as pd
from collections import defaultdict
sns.set_style("whitegrid")

In [22]:
# Load and plot data

# Absolute path to either the BOUT output directory or an individual BOUT.dmp file.
DATA_LOCATION = Path(
	r"data\BOUT.dmp.0.nc"
 )


def load_density(path_hint: Path):
	"""Return the density field collected from BOUT output files."""
	path_hint = path_hint.expanduser().resolve()

	# If the user points at a single file we only need the parent directory.
	bout_dir = path_hint.parent if path_hint.is_file() else path_hint

	density = collect("pe", path=str(bout_dir))
	return density.squeeze()


def plot_timestep(density,idx):
	"""Plot the density at the final time step."""
	fig, ax = plt.subplots()
	contour = ax.contourf(density[idx, :, :].T)
	fig.colorbar(contour, ax=ax, label="n (arb. units)")
	ax.set_title("Density at final timestep")
	ax.set_xlabel("x index")
	ax.set_ylabel("z index")
	plt.show()

density_data = load_density(DATA_LOCATION)
#plot_timestep(density_data,-1)


mxsub = 8 mysub = 1 mz = 1024

nxpe = 128, nype = 1, npes = 128

Reading from 0: [0-9][0-0] -> [0-9][0-0]

Reading from 1: [2-9][0-0] -> [10-17][0-0]

Reading from 2: [2-9][0-0] -> [18-25][0-0]

Reading from 3: [2-9][0-0] -> [26-33][0-0]

Reading from 4: [2-9][0-0] -> [34-41][0-0]

Reading from 5: [2-9][0-0] -> [42-49][0-0]

Reading from 6: [2-9][0-0] -> [50-57][0-0]

Reading from 7: [2-9][0-0] -> [58-65][0-0]

Reading from 8: [2-9][0-0] -> [66-73][0-0]

Reading from 9: [2-9][0-0] -> [74-81][0-0]

Reading from 10: [2-9][0-0] -> [82-89][0-0]

Reading from 11: [2-9][0-0] -> [90-97][0-0]

Reading from 12: [2-9][0-0] -> [98-105][0-0]

Reading from 13: [2-9][0-0] -> [106-113][0-0]

Reading from 14: [2-9][0-0] -> [114-121][0-0]

Reading from 7: [2-9][0-0] -> [58-65][0-0]

Reading from 8: [2-9][0-0] -> [66-73][0-0]

Reading from 9: [2-9][0-0] -> [74-81][0-0]

Reading from 10: [2-9][0-0] -> [82-89][0-0]

Reading from 11: [2-9][0-0] -> [90-97][0-0]

Reading from 12: [2-9][0-0] -> [98-105][0-0]


## VAE

In [23]:
class ReparameterizedDiagonalGaussian(Distribution):
    """
    A distribution `N(y | mu, sigma I)` compatible with the reparameterization trick given `epsilon ~ N(0, 1)`.
    """
    def __init__(self, mu: Tensor, log_sigma:Tensor):
        assert mu.shape == log_sigma.shape, f"Tensors `mu` : {mu.shape} and ` log_sigma` : {log_sigma.shape} must be of the same shape"
        self.mu = mu
        self.sigma = log_sigma.exp()
        
    def sample_epsilon(self) -> Tensor:
        """`\eps ~ N(0, I)`"""
        return torch.empty_like(self.mu).normal_()
        
    def sample(self) -> Tensor:
        """sample `z ~ N(z | mu, sigma)` (without gradients)"""
        with torch.no_grad():
            return self.rsample()
        
    def rsample(self) -> Tensor:
        """sample `z ~ N(z | mu, sigma)` (with the reparameterization trick) """
        epsilon = self.sample_epsilon()
        return self.mu + self.sigma * epsilon
            
    def log_prob(self, z:Tensor) -> Tensor:
        """return the log probability: log `p(z)`"""
        return -0.5 * (math.log(2 * math.pi) + 2 * torch.log(self.sigma) + ((z - self.mu) ** 2) / (self.sigma ** 2))




In [24]:
class VariationalAutoencoder(nn.Module):
    def __init__(self, input_hw: Tuple[int, int], latent_features: int, in_channels: int = 1) -> None:
        super().__init__()
        self.input_hw = input_hw
        self.latent_features = latent_features
        self.in_channels = in_channels
        
        # Convolutional encoder keeps the code compact while extracting spatial features.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )
        
        # Define encoder output dimensions
        with torch.no_grad():
            dummy = torch.zeros(1, in_channels, *input_hw)
            encoder_out = self.encoder(dummy)
        self.encoder_shape = encoder_out.shape[1:]
        self.encoder_features = int(np.prod(self.encoder_shape))
        
        self.fc_mu = nn.Linear(self.encoder_features, latent_features)
        self.fc_log_sigma = nn.Linear(self.encoder_features, latent_features)
        self.fc_decode = nn.Linear(latent_features, self.encoder_features)
        
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(self.encoder_shape[0], 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, in_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid(),
        )
        
        # Prior parameters stay identical to the dense version.
        self.register_buffer("prior_params", torch.zeros(1, 2 * latent_features))
        
    def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        h = self.encoder(x)
        h = h.view(h.size(0), -1)
        mu = self.fc_mu(h)
        log_sigma = self.fc_log_sigma(h)
        return mu, log_sigma
        
    def decode(self, z: Tensor) -> Tensor:
        h = self.fc_decode(z)
        h = h.view(z.size(0), *self.encoder_shape)
        return self.decoder(h)
        
    def posterior(self, x: Tensor) -> Distribution:
        mu, log_sigma = self.encode(x)
        return ReparameterizedDiagonalGaussian(mu, log_sigma)
        
    def prior(self, batch_size: int = 1) -> Distribution:
        prior_params = self.prior_params.expand(batch_size, -1)
        mu, log_sigma = prior_params.chunk(2, dim=-1)
        return ReparameterizedDiagonalGaussian(mu, log_sigma)
        
    def observation_model(self, z: Tensor) -> Distribution:
        probs = self.decode(z).clamp(1e-6, 1 - 1e-6)
        return Bernoulli(probs=probs, validate_args=False)
        
    def forward(self, x: Tensor) -> Dict[str, Any]:
        qz = self.posterior(x)
        pz = self.prior(batch_size=x.size(0))
        z = qz.rsample()
        px = self.observation_model(z)
        return {"px": px, "pz": pz, "qz": qz, "z": z}
        
    def sample_from_prior(self, batch_size: int = 16) -> Dict[str, Any]:
        pz = self.prior(batch_size=batch_size)
        z = pz.rsample()
        px = self.observation_model(z)
        return {"px": px, "pz": pz, "z": z}


latent_features = 2
# Example usage once you have tensors shaped as (batch, 1, H, W):
# vae = VariationalAutoencoder(input_hw=(64, 64), latent_features=latent_features)
# print(sum(p.numel() for p in vae.parameters())/1e6, "M parameters")

In [25]:
density_data_tesnor = torch.as_tensor(load_density(DATA_LOCATION), dtype=torch.float32)

mxsub = 8 mysub = 1 mz = 1024

nxpe = 128, nype = 1, npes = 128

Reading from 0: [0-9][0-0] -> [0-9][0-0]

Reading from 1: [2-9][0-0] -> [10-17][0-0]

Reading from 2: [2-9][0-0] -> [18-25][0-0]

Reading from 3: [2-9][0-0] -> [26-33][0-0]

Reading from 4: [2-9][0-0] -> [34-41][0-0]

Reading from 5: [2-9][0-0] -> [42-49][0-0]

Reading from 6: [2-9][0-0] -> [50-57][0-0]

Reading from 7: [2-9][0-0] -> [58-65][0-0]

Reading from 8: [2-9][0-0] -> [66-73][0-0]

Reading from 9: [2-9][0-0] -> [74-81][0-0]

Reading from 10: [2-9][0-0] -> [82-89][0-0]

Reading from 3: [2-9][0-0] -> [26-33][0-0]

Reading from 4: [2-9][0-0] -> [34-41][0-0]

Reading from 5: [2-9][0-0] -> [42-49][0-0]

Reading from 6: [2-9][0-0] -> [50-57][0-0]

Reading from 7: [2-9][0-0] -> [58-65][0-0]

Reading from 8: [2-9][0-0] -> [66-73][0-0]

Reading from 9: [2-9][0-0] -> [74-81][0-0]

Reading from 10: [2-9][0-0] -> [82-89][0-0]

Reading from 11: [2-9][0-0] -> [90-97][0-0]

Reading from 12: [2-9][0-0] -> [98-105][0-0]

Reading 

In [None]:
latent_features = 2
VAE = VariationalAutoencoder(density_data_tesnor[0].shape,latent_features)
print(VAE)

VariationalAutoencoder(
  (encoder): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (3): ReLU()
  )
  (fc_mu): Linear(in_features=4210688, out_features=2, bias=True)
  (fc_log_sigma): Linear(in_features=4210688, out_features=2, bias=True)
  (fc_decode): Linear(in_features=2, out_features=4210688, bias=True)
  (decoder): Sequential(
    (0): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (1): ReLU()
    (2): ConvTranspose2d(32, 1, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (3): Sigmoid()
  )
)
