#### Kolmogorov Data Evaluation

In [1]:
# Connect Google Drive
# Colab-only setup: mount Google Drive, then move into the project folder. Later cells use
# relative paths (./datasets/..., ./best_models/...) that resolve against this directory.
from google.colab import drive
drive.mount('/content/gdrive')
# NOTE(review): hard-coded Drive layout ("Colab Notebooks/Colab Notebooks/autoencoders");
# adjust this path if the project lives elsewhere in Drive.
%cd /content/gdrive/MyDrive/Colab Notebooks/Colab Notebooks/autoencoders

# Display the resulting working directory.
%pwd

Mounted at /content/gdrive
/content/gdrive/.shortcut-targets-by-id/1UMow24kXYpDLYgShcir7-CB3ZYQsgEih/Colab Notebooks/autoencoders


'/content/gdrive/.shortcut-targets-by-id/1UMow24kXYpDLYgShcir7-CB3ZYQsgEih/Colab Notebooks/autoencoders'

In [3]:
# Encode the Kolmogorov test set with each model's best checkpoint and save the latents.
import os
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

# ---- Paths (relative to the working directory set by %cd in the first cell) ----
kolm_root     = Path("./datasets/kolmogorov_samples")
kolm_out_root = Path("./best_models/output_kolmogorov")

X_TEST_PATH = kolm_root / "X_test.npy"               # input test array

enc_save_root = kolm_root / "encoded_test_latent"    # destination for the encoded latents
enc_save_root.mkdir(parents=True, exist_ok=True)

# Where each model's best checkpoint lives: <kolm_out_root>/<dir>/<ckpt>.
kolm_model_cfgs = {
    model_key: {"dir": kolm_out_root / subdir, "ckpt": ckpt_file}
    for model_key, subdir, ckpt_file in (
        ("AuE",   "output_kolm_ae",     "best_overall_AuE.pt"),
        ("VAE",   "output_kolm_vae",    "best_overall_VAE.pt"),
        ("VQVAE", "output_kolm_vqvae",  "best_overall_VQVAE.pt"),
        ("VQVA2", "output_kolm_vqvae2", "best_overall_VQVA2.pt"),
    )
}

# Model hyper-parameters (channel counts must agree with the checkpoints' tensor shapes)
# and the input scaling applied by the dataset.
BOTTLENECK_CH = 56     # latent channels (AuE / VAE / VQVAE, and the VQVA2 bottom level)
CODEBOOK_SIZE = 512    # number of VectorQuantizer codebook entries
COMMIT_BETA   = 0.25   # VectorQuantizer beta_commit (not used by its forward pass)
TOP_CH        = 56     # VQVA2 top-level latent channels
DATA_SCALE    = 8.0    # KolmogorovTestDataset divides every sample by this

class KolmogorovTestDataset(Dataset):
    """Memory-mapped view of ``X_test.npy`` yielding float32 tensors divided by ``DATA_SCALE``.

    Samples are read from disk on demand. A 2-D sample ``(H, W)`` gets a leading channel
    axis and becomes ``(1, H, W)``; samples of any other rank keep their shape.
    """

    def __init__(self, x_test_path: Path):
        # mmap_mode="r": do not pull the whole test array into RAM up front.
        self.X = np.load(x_test_path, mmap_mode="r")

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        # Copy out of the read-only memmap so torch.from_numpy receives an owned, writable array.
        sample = np.array(self.X[idx], copy=True)
        if sample.ndim == 2:
            sample = sample[np.newaxis, ...]  # (H, W) -> (1, H, W)
        return torch.from_numpy(sample).float() / DATA_SCALE

# Fail fast if the test array is not where the config section expects it.
assert X_TEST_PATH.exists(), f"Missing: {X_TEST_PATH}"
test_dataset = KolmogorovTestDataset(X_TEST_PATH)

# shuffle=False keeps the original sample order, so row i of each saved latent array
# corresponds to X_test[i].
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=0)

example_sample = test_dataset[0]
print("X_test samples:", len(test_dataset))
print("Example shape:", tuple(example_sample.shape), "dtype:", example_sample.dtype)

class Snake(nn.Module):
    """Snake activation ``x + sin(a * x)**2 / a`` with a learnable scalar frequency.

    The frequency used is ``|alpha| + 1e-6``: the sign of ``alpha`` is irrelevant and the
    division can never be by zero.
    """

    def __init__(self, alpha=1.0):
        super().__init__()
        # Scalar parameter; the attribute name "alpha" is its state_dict key.
        self.alpha = nn.Parameter(torch.tensor(alpha))

    def forward(self, x):
        freq = self.alpha.abs() + 1e-6
        periodic = torch.sin(freq * x).pow(2)
        return x + (1.0 / freq) * periodic

def _conv_gn_snake_pair(cin, cout):
    """Build ``[Conv3x3 -> GroupNorm -> Snake] x 2`` mapping ``cin -> cout -> cout`` channels.

    ``padding=1`` keeps H and W unchanged. The returned ``nn.Sequential`` has the same
    layout (indices 0-5) as the stacks previously duplicated in EncoderBlock and
    DecoderBlock, so existing checkpoints (``net.0.weight`` ... ``net.5.alpha``) still load.
    """
    import math

    # GroupNorm requires num_channels % num_groups == 0. Keep the original choice,
    # min(8, cout), whenever it is valid. Otherwise (cout > 8 and not a multiple of 8,
    # which used to raise at construction) fall back to gcd(8, cout), which always
    # divides cout.
    groups = min(8, cout)
    if cout % groups:
        groups = math.gcd(8, cout)
    return nn.Sequential(
        nn.Conv2d(cin, cout, 3, padding=1),
        nn.GroupNorm(num_groups=groups, num_channels=cout), Snake(),
        nn.Conv2d(cout, cout, 3, padding=1),
        nn.GroupNorm(num_groups=groups, num_channels=cout), Snake(),
    )

class EncoderBlock(nn.Module):
    """Spatial-size-preserving conv block used in the encoders (cin -> cout channels)."""

    def __init__(self, cin, cout):
        super().__init__()
        self.net = _conv_gn_snake_pair(cin, cout)

    def forward(self, x):
        return self.net(x)

class DecoderBlock(nn.Module):
    """Spatial-size-preserving conv block used in the decoders (cin -> cout channels).

    Same layer stack as EncoderBlock; kept as its own class so model definitions read
    naturally and checkpoint keys stay unchanged.
    """

    def __init__(self, cin, cout):
        super().__init__()
        self.net = _conv_gn_snake_pair(cin, cout)

    def forward(self, x):
        return self.net(x)

# MODELS
class AuE(nn.Module):
    """Deterministic convolutional autoencoder.

    Encoder: EncoderBlock(1->32), stride-2 conv (32->64), EncoderBlock(64->64), stride-2 conv
    (64->ch). Each stride-2 conv halves H and W (128 -> 64 -> 32 for the inputs here).
    Decoder: mirrors this with two 2x2 stride-2 transposed convs, then a 1x1 conv and Tanh,
    so reconstructions lie in (-1, 1).

    ``forward`` returns ``(xhat, {})``. The empty dict keeps the same return shape as VAE.
    """

    def __init__(self, ch=BOTTLENECK_CH):
        super().__init__()
        # Attribute names and Sequential indices form the state_dict keys ("enc.N...",
        # "dec.N..."), so the layer order must stay exactly as saved in the checkpoints.
        encoder_layers = [
            EncoderBlock(1, 32),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),   # 128 -> 64
            Snake(),
            EncoderBlock(64, 64),
            nn.Conv2d(64, ch, 4, stride=2, padding=1),   # 64 -> 32
            Snake(),
        ]
        self.enc = nn.Sequential(*encoder_layers)

        decoder_layers = [
            DecoderBlock(ch, 128),
            nn.ConvTranspose2d(128, 64, 2, stride=2),    # 32 -> 64
            Snake(),
            DecoderBlock(64, 64),
            nn.ConvTranspose2d(64, 32, 2, stride=2),     # 64 -> 128
            Snake(),
            nn.Conv2d(32, 1, 1),
            nn.Tanh(),
        ]
        self.dec = nn.Sequential(*decoder_layers)

    def forward(self, x):
        latent = self.enc(x)
        return self.dec(latent), {}

class VAE(nn.Module):
    """Convolutional VAE whose forward pass decodes the posterior mean.

    The encoder produces 128 feature channels at H/4 x W/4 (two stride-2 convs). Two 1x1
    convs then give ``mu`` and ``logvar`` with ``ch`` channels each. ``forward`` decodes
    ``mu`` directly (no reparameterised sampling) and returns
    ``(xhat, {"mu": mu, "logvar": logvar})``. Reconstructions pass through Tanh.
    """

    def __init__(self, ch=BOTTLENECK_CH):
        super().__init__()
        # Attribute names and Sequential indices are state_dict keys; keep them stable.
        encoder_layers = [
            EncoderBlock(1, 32),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),
            Snake(),
            EncoderBlock(64, 64),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            Snake(),
        ]
        self.enc = nn.Sequential(*encoder_layers)

        # 1x1 heads for the diagonal-Gaussian posterior parameters.
        self.mu     = nn.Conv2d(128, ch, 1)
        self.logvar = nn.Conv2d(128, ch, 1)

        decoder_layers = [
            DecoderBlock(ch, 128),
            nn.ConvTranspose2d(128, 64, 2, stride=2),
            Snake(),
            DecoderBlock(64, 64),
            nn.ConvTranspose2d(64, 32, 2, stride=2),
            Snake(),
            nn.Conv2d(32, 1, 1),
            nn.Tanh(),
        ]
        self.dec = nn.Sequential(*decoder_layers)

    def forward(self, x):
        features = self.enc(x)
        mu = self.mu(features)
        logvar = self.logvar(features)
        # Deterministic reconstruction from the mean; logvar is only reported.
        xhat = self.dec(mu)
        return xhat, {"mu": mu, "logvar": logvar}

class VectorQuantizer(nn.Module):
    """Nearest-neighbour codebook lookup with a straight-through gradient.

    Args:
        K: number of codebook vectors.
        D: code dimension; must equal the channel dimension of the input.
        beta_commit: stored as ``self.beta`` but not used by ``forward`` here
            (presumably the commitment-loss weight used during training).
    """

    def __init__(self, K, D, beta_commit=COMMIT_BETA):
        super().__init__()
        self.K, self.D = K, D
        self.beta = beta_commit
        self.codebook = nn.Embedding(K, D)
        bound = 1.0 / D
        nn.init.uniform_(self.codebook.weight, -bound, bound)

    def forward(self, z_e):
        """Quantize ``z_e`` of shape (B, D, H, W) to its nearest codebook vectors (same shape)."""
        batch, dim, height, width = z_e.shape
        flat = z_e.permute(0, 2, 3, 1).contiguous().view(-1, dim)    # (B*H*W, D)
        codes = self.codebook.weight                                  # (K, D)
        # Squared Euclidean distance ||z||^2 + ||e||^2 - 2 z.e, shape (B*H*W, K).
        sq_dist = (flat.pow(2).sum(1, keepdim=True) + codes.pow(2).sum(1) - 2 * flat @ codes.t())
        nearest = sq_dist.argmin(dim=1)
        z_q = self.codebook(nearest).view(batch, height, width, dim).permute(0, 3, 1, 2).contiguous()
        # Straight-through estimator: the value is z_q, the gradient flows to z_e unchanged.
        return z_e + (z_q - z_e).detach()

class VQVAE(nn.Module):
    """Encoder half of the VQ-VAE: conv encoder followed by a VectorQuantizer.

    No decoder is defined in this class. ``forward`` returns ``(z_q, {})``, where ``z_q``
    is the straight-through quantized latent with ``D`` channels at H/4 x W/4.
    """

    def __init__(self, K=CODEBOOK_SIZE, D=BOTTLENECK_CH, beta_commit=COMMIT_BETA):
        super().__init__()
        # Attribute names ("encoder", "quant") and indices must match the checkpoint keys.
        encoder_layers = [
            EncoderBlock(1, 32),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),
            Snake(),
            EncoderBlock(64, 64),
            nn.Conv2d(64, D, 4, stride=2, padding=1),
            Snake(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)
        self.quant = VectorQuantizer(K, D, beta_commit=beta_commit)

    def forward(self, x):
        quantized = self.quant(self.encoder(x))
        return quantized, {}

class VQVA2(nn.Module):
    """Two-level (bottom/top) encoder of the hierarchical VQ model.

    Despite the name, this class defines neither quantizers nor a decoder. ``forward``
    returns the continuous latents: the bottom latent (``D`` channels) concatenated on
    dim 1 with the top latent (``top_ch`` channels), upsampled by nearest-neighbour
    interpolation to the bottom latent's spatial size.
    ``K`` and ``beta_commit`` are accepted but unused here.
    """

    def __init__(self, K=CODEBOOK_SIZE, D=BOTTLENECK_CH, beta_commit=COMMIT_BETA, top_ch=TOP_CH):
        super().__init__()
        bottom_layers = [
            EncoderBlock(1, 32),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),
            Snake(),
            EncoderBlock(64, 64),
            nn.Conv2d(64, D, 4, stride=2, padding=1),
            Snake(),
        ]
        self.enc_bottom = nn.Sequential(*bottom_layers)

        top_layers = [
            nn.Conv2d(D, 128, 3, padding=1),
            Snake(),
            nn.Conv2d(128, top_ch, 4, stride=2, padding=1),  # 32 -> 16
            Snake(),
        ]
        self.enc_top = nn.Sequential(*top_layers)

    def forward(self, x):
        bottom = self.enc_bottom(x)                                   # (B, D, 32, 32)
        top = self.enc_top(bottom)                                    # (B, top_ch, 16, 16)
        top_up = F.interpolate(top, size=bottom.shape[-2:], mode="nearest")
        return torch.cat([bottom, top_up], dim=1), {}                 # (B, D + top_ch, 32, 32)

# LOAD BEST MODEL + ENCODE BATCH
def build_model_by_name(name: str) -> nn.Module:
    """Return a freshly constructed model (default hyper-parameters) for ``name``.

    Raises:
        ValueError: if ``name`` is not one of "AuE", "VAE", "VQVAE", "VQVA2".
    """
    registry = (
        ("AuE", AuE),
        ("VAE", VAE),
        ("VQVAE", VQVAE),
        ("VQVA2", VQVA2),
    )
    for model_key, constructor in registry:
        if name == model_key:
            return constructor()
    raise ValueError(f"Unknown model: {name}")

def load_best_kolm_model_for_encoding(model_name: str) -> nn.Module:
    """Build ``model_name`` and load its best Kolmogorov checkpoint, ready for encoding.

    The checkpoint is ``kolm_model_cfgs[model_name]["dir"] / kolm_model_cfgs[model_name]["ckpt"]``.
    It may be a raw state dict, or a dict that wraps the state dict under ``"model"``,
    ``"model_state"`` or ``"state_dict"`` (checked in that order). A ``module.`` /
    ``_orig_mod.`` prefix shared by every key is stripped first.

    Loading is non-strict: checkpoint keys the model does not define are ignored. Keys the
    model defines but the checkpoint lacks would leave those parameters at their random
    initialisation, so they trigger a ``UserWarning`` instead of passing silently.

    Args:
        model_name: key of ``kolm_model_cfgs`` that is also accepted by ``build_model_by_name``.

    Returns:
        The model moved to ``device`` and switched to eval mode.

    Raises:
        KeyError: if ``model_name`` is not in ``kolm_model_cfgs``.
        AssertionError: if the checkpoint file does not exist.
    """
    import warnings

    info = kolm_model_cfgs[model_name]
    ckpt_path = info["dir"] / info["ckpt"]
    assert ckpt_path.exists(), f"Checkpoint not found: {ckpt_path}"

    # NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
    state = torch.load(ckpt_path, map_location="cpu")

    sd = state
    if isinstance(state, dict):
        for wrapper_key in ("model", "model_state", "state_dict"):
            if wrapper_key in state:
                sd = state[wrapper_key]
                break

    # Checkpoints saved from nn.DataParallel / torch.compile prefix every key; with
    # strict=False such a checkpoint would otherwise load *nothing* without complaint.
    if isinstance(sd, dict) and sd:
        for prefix in ("module.", "_orig_mod."):
            if all(isinstance(k, str) and k.startswith(prefix) for k in sd):
                sd = {k[len(prefix):]: v for k, v in sd.items()}

    model = build_model_by_name(model_name)
    # strict=False tolerates extra checkpoint entries (e.g. modules this encode-only
    # class does not define); missing entries are reported below.
    result = model.load_state_dict(sd, strict=False)
    missing = list(result.missing_keys)
    if missing:
        shown = ", ".join(missing[:5])
        more = "" if len(missing) <= 5 else f", ... (+{len(missing) - 5} more)"
        warnings.warn(
            f"{model_name}: {len(missing)} parameter(s) not found in {ckpt_path} and left "
            f"at random initialisation: {shown}{more}",
            stacklevel=2,
        )

    model.to(device).eval()
    return model

@torch.no_grad()
def encode_batch(model_name: str, model: nn.Module, xb: torch.Tensor) -> torch.Tensor:
    """Run only the encoder part of ``model`` on one batch and return the latent on the CPU.

    Per model:
      * "AuE"   -> ``model.enc(x)``
      * "VAE"   -> posterior mean ``model.mu(model.enc(x))`` (no sampling)
      * "VQVAE" -> encoder output ``model.encoder(x)``; ``model.quant`` is NOT applied
      * "VQVA2" -> bottom latent concatenated (dim 1) with the top latent, upsampled
                   (nearest) to the bottom latent's spatial size; no quantization

    NOTE(review): the VQ models therefore yield continuous pre-quantization latents rather
    than codebook vectors. Confirm this is what the downstream evaluation expects.

    Raises:
        ValueError: for an unrecognised ``model_name``.
    """
    inputs = xb.to(device)

    if model_name == "AuE":
        latent = model.enc(inputs)
    elif model_name == "VAE":
        latent = model.mu(model.enc(inputs))
    elif model_name == "VQVAE":
        latent = model.encoder(inputs)
    elif model_name == "VQVA2":
        bottom = model.enc_bottom(inputs)
        top = model.enc_top(bottom)
        top_up = F.interpolate(top, size=bottom.shape[-2:], mode="nearest")
        latent = torch.cat([bottom, top_up], dim=1)
    else:
        raise ValueError(f"Unknown model: {model_name}")

    return latent.detach().cpu()

# ENCODE FULL TEST SET + SAVE
# Encode the whole test set with every configured model and save one .npy file per model.
all_latent_paths = {}

for name in [n for n in ["AuE", "VAE", "VQVAE", "VQVA2"] if n in kolm_model_cfgs]:
    print(f"\n=== Encoding Kolmogorov test set – {name} ===")
    model = load_best_kolm_model_for_encoding(name)

    # test_loader is unshuffled, so row i of Z is the latent of X_test[i].
    # Each batch is a CPU tensor: (B, C, 32, 32), with C doubled by the VQVA2 concat.
    latents = [encode_batch(name, model, xb) for xb in test_loader]
    Z = torch.cat(latents, dim=0).numpy()

    out_path = enc_save_root / f"kolm_{name}_test_latent.npy"
    np.save(out_path, Z)
    all_latent_paths[name] = str(out_path)

    print(f"  Saved {name} latents: {Z.shape} -> {out_path}")

print("\nAll encoded test latents saved:")
for k, v in all_latent_paths.items():
    print(f"  {k}: {v}")

Device: cuda
X_test samples: 2000
Example shape: (1, 128, 128) dtype: torch.float32

=== Encoding Kolmogorov test set – AuE ===
  Saved AuE latents: (2000, 56, 32, 32) -> datasets/kolmogorov_samples/encoded_test_latent/kolm_AuE_test_latent.npy

=== Encoding Kolmogorov test set – VAE ===
  Saved VAE latents: (2000, 56, 32, 32) -> datasets/kolmogorov_samples/encoded_test_latent/kolm_VAE_test_latent.npy

=== Encoding Kolmogorov test set – VQVAE ===
  Saved VQVAE latents: (2000, 56, 32, 32) -> datasets/kolmogorov_samples/encoded_test_latent/kolm_VQVAE_test_latent.npy

=== Encoding Kolmogorov test set – VQVA2 ===
  Saved VQVA2 latents: (2000, 112, 32, 32) -> datasets/kolmogorov_samples/encoded_test_latent/kolm_VQVA2_test_latent.npy

All encoded test latents saved:
  AuE: datasets/kolmogorov_samples/encoded_test_latent/kolm_AuE_test_latent.npy
  VAE: datasets/kolmogorov_samples/encoded_test_latent/kolm_VAE_test_latent.npy
  VQVAE: datasets/kolmogorov_samples/encoded_test_latent/kolm_VQVAE_te