# Experiment 1: RGM Fundamentals (Visualizing Renormalization)

이 노트북은 **Spatial Renormalizing Generative Model**(RGM)의 핵심 원리인 **"Renormalization (재규격화)"** 과정을 시각적으로 보여 주는 독립적인 실험입니다.

RGM의 핵심 가설은 다음과 같습니다:
> "세상은 계층적(Hierarchical)이며, 국소적인 세부 사항(Local Details)은 상위 레벨에서 추상적인 개념(Abstract Concepts)으로 통합(Renormalized)된다."

이 실험에서는 다음 세 가지를 검증합니다:
1.  **Abstraction (추상화)**: 픽셀(Pixels) $\to$ 국소 특징($z_1$) $\to$ 전역 개념($z_2$)으로 정보가 압축되는 과정.
2.  **Instantiation (구체화)**: 동일한 전역 개념($z_2$)이 다양한 국소 특징($z_1$)으로 발현되는 과정 (One Concept, Many Variations).
3.  **Locality (국소성)**: $z_1$의 변화는 이미지의 특정 영역에만 영향을 미치며, 이는 $z_1$이 공간적 위상(Spatial Topology)을 보존함을 의미함.

In [2]:
# 1. Setup and Imports
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader

# Make `src` importable: the project root is the parent of the notebook's
# working directory. Register it only once, even if this cell is re-run.
cwd = os.getcwd()
project_root = os.path.abspath(os.path.join(cwd, os.pardir))
already_on_path = project_root in sys.path
if not already_on_path:
    sys.path.append(project_root)

from src.models.spatial_rgm import SpatialRGM

# Prefer the GPU whenever PyTorch can see one.
has_cuda = torch.cuda.is_available()
device = torch.device("cuda") if has_cuda else torch.device("cpu")
print(f"Using device: {device}")

Using device: cpu


In [3]:
# 2. Prepare Data and Model
# MNIST training split; ToTensor converts each image to a float tensor.
transform = transforms.Compose([transforms.ToTensor()])
dataset = MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
loader = DataLoader(dataset, shuffle=True, batch_size=64)

# Two-level model: latent_dim codes per z1 cell, num_classes values for z2.
model = SpatialRGM(num_classes=10, latent_dim=32)
model = model.to(device)

# Quick Training Function (if no pretrained model exists)
def train_quick(model, loader, epochs=3, *, device=None, lr=1e-3,
                cls_weight=1.0, prior_weight=0.1):
    """Briefly train a two-level RGM so the visual experiments have a usable model.

    The training objective is re-implemented as a standalone loop that drives
    the model's sub-modules directly. The loss is the sum of:

    * reconstruction loss: MSE between ``model.dec1(z1)`` and the input;
    * classification loss: cross-entropy between the level-2 logits and labels;
    * prior-matching loss: KL(posterior || prior) over the z1 code axis
      (dim=1), averaged over batch elements and spatial cells. The posterior
      comes from ``z1_proj``; the prior is predicted from z2 by ``dec2``.

    Args:
        model: RGM exposing ``enc1``, ``z1_proj``, ``enc2``, ``dec1``, ``dec2``,
            ``reparameterize_gumbel`` and ``latent_dim``.
        loader: Iterable of ``(images, labels)`` batches.
        epochs: Number of passes over ``loader``.
        device: Device to move batches to. Defaults to the device of the
            model's first parameter.
        lr: Adam learning rate.
        cls_weight: Weight of the classification loss.
        prior_weight: Weight of the prior-matching (KL) loss.

    Returns:
        List with the mean loss of each epoch (``nan`` for an epoch with no
        batches).
    """
    if device is None:
        # Batches must live on the same device as the model's parameters.
        device = next(model.parameters()).device
    # Cells run earlier may have switched the model to eval mode. Re-enable
    # training-mode behaviour (dropout / batch-norm, if the model has any).
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    print("Training RGM for demonstration...")
    epoch_losses = []
    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()

            # Bottom-up pass: pixels -> z1 (spatial codes) -> z2 (class).
            h1 = model.enc1(x)
            z1_logits = model.z1_proj(h1)
            z1_sample = model.reparameterize_gumbel(z1_logits)
            z2_logits = model.enc2(z1_sample)
            z2_sample = model.reparameterize_gumbel(z2_logits)
            # Top-down prior over z1 predicted from z2, with the same layout
            # as the posterior logits so the two can be compared cell-by-cell.
            z1_prior_logits = model.dec2(z2_sample).view_as(z1_logits)
            recon = model.dec1(z1_sample)

            recon_loss = F.mse_loss(recon, x)
            cls_loss = F.cross_entropy(z2_logits, y)

            # Categorical KL over the code axis (dim=1). log_softmax is
            # numerically stable and avoids the bias of an additive epsilon.
            log_p = F.log_softmax(z1_logits, dim=1)
            log_q = F.log_softmax(z1_prior_logits, dim=1)
            kl_sum = (log_p.exp() * (log_p - log_q)).sum()
            # Average per (batch element, grid cell) instead of a fixed 7x7.
            num_cells = x.size(0) * z1_logits.shape[-2] * z1_logits.shape[-1]
            prior_loss = kl_sum / num_cells

            loss = recon_loss + cls_weight * cls_loss + prior_weight * prior_loss
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        mean_loss = total_loss / num_batches if num_batches else float("nan")
        epoch_losses.append(mean_loss)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {mean_loss:.4f}")
    return epoch_losses

# Short demonstration run (3 epochs) so the experiments below use a trained model.
train_quick(model=model, loader=loader, epochs=3)

Training RGM for demonstration...
Epoch 1/3, Loss: 0.6217
Epoch 1/3, Loss: 0.6217
Epoch 2/3, Loss: 0.1968
Epoch 2/3, Loss: 0.1968
Epoch 3/3, Loss: 0.1688
Epoch 3/3, Loss: 0.1688


## Experiment 1: Hierarchical Abstraction (Bottom-Up)
이미지가 어떻게 $z_1$ (7x7 Grid)과 $z_2$ (Global Class)로 압축되는지 시각화합니다.
- **Input**: 원본 이미지 (28x28)
- **Level 1 ($z_1$)**: 7x7 그리드의 각 셀이 어떤 특징을 잡고 있는지 확인.
- **Level 2 ($z_2$)**: 최종적으로 어떤 숫자로 인식했는지 확인.

In [None]:
model.eval()
x, y = next(iter(loader))
x = x[0:1].to(device)

with torch.no_grad():
    # Level 1: per-cell logits over the latent_dim-way code vocabulary.
    h1 = model.enc1(x)
    z1_logits = model.z1_proj(h1)
    z1_idx = z1_logits.argmax(dim=1)  # (1, 7, 7) - most likely code per cell

    # Level 2: classify from a sampled z1, mirroring the training path.
    # NOTE(review): reparameterize_gumbel may be stochastic, in which case the
    # prediction can vary between runs — confirm whether evaluation should use
    # a deterministic (argmax) code instead.
    z1_sample = model.reparameterize_gumbel(z1_logits)
    z2_logits = model.enc2(z1_sample)
    z2_probs = F.softmax(z2_logits, dim=1)[0].cpu().numpy()
    pred_class = z2_logits.argmax(dim=1).item()

codes = z1_idx[0].cpu().numpy()

plt.figure(figsize=(10, 4))
plt.subplot(1, 3, 1)
plt.title("Input Image")
plt.imshow(x[0, 0].cpu(), cmap='gray')
plt.axis('off')

plt.subplot(1, 3, 2)
plt.title("Level 1: z1 Codes (7x7)\n(Local Features)")
# Pin the colour range to the full code vocabulary. Otherwise imshow rescales
# to this image's min/max code, and colours are not comparable across images.
plt.imshow(codes, cmap='tab20', vmin=0, vmax=model.latent_dim - 1)
plt.colorbar()
# tab20 has only 20 colours, so distinct codes can share a colour. Write the
# code index into each cell to keep them distinguishable.
for (row, col), code in np.ndenumerate(codes):
    plt.text(col, row, str(code), ha='center', va='center', fontsize=7)
plt.axis('off')

plt.subplot(1, 3, 3)
plt.title(f"Level 2: z2 Prediction\n(Global Concept: {pred_class})")
plt.bar(range(len(z2_probs)), z2_probs)
plt.xticks(range(len(z2_probs)))
plt.show()

## Experiment 2: Concept-Conditional Generation (Top-Down)
**"개념(Concept)은 하나지만, 표현(Instance)은 다양하다."**

상위 레벨 변수 $z_2$를 고정(예: 숫자 '3')하고, 하위 레벨 변수 $z_1$을 $P(z_1|z_2)$에서 여러 번 샘플링하여 복원합니다.
이 실험은 $z_2$가 "**무엇**(What)"을 결정하고, $z_1$이 "**어떻게**(How)"를 결정한다는 것을 보여 줍니다.

In [None]:
target_digit = 3
num_samples = 5

# Do not rely on an earlier cell having switched the model to eval mode.
model.eval()

plt.figure(figsize=(15, 3))
with torch.no_grad():
    # 1. Fix z2: the same one-hot class vector (10 classes) for every sample.
    z2_fixed = torch.zeros(num_samples, 10, device=device)
    z2_fixed[:, target_digit] = 1.0

    # 2. Predict the z1 prior from z2: per-cell logits over latent_dim codes.
    z1_prior_logits = model.dec2(z2_fixed).view(num_samples, model.latent_dim, 7, 7)

    # 3. Sample z1 from this prior (all sample-to-sample diversity comes from here).
    # dim=1 is the code axis. gumbel_softmax defaults to dim=-1, which would
    # draw one-hot vectors across the grid width instead of across codes.
    z1_sampled = F.gumbel_softmax(z1_prior_logits, tau=1.0, hard=True, dim=1)

    # 4. Decode to pixels.
    generated = model.dec1(z1_sampled)

for i in range(num_samples):
    plt.subplot(1, num_samples, i+1)
    plt.title(f"Sample {i+1} (Digit {target_digit})")
    plt.imshow(generated[i, 0].cpu(), cmap='gray')
    plt.axis('off')
plt.suptitle(f"Fixed Concept z2={target_digit} -> Diverse Instances via z1 Sampling", fontsize=14)
plt.show()

## Experiment 3: Local vs Global Perturbation
**"Renormalization은 공간적 정보를 보존하면서 압축한다."**

1.  **Local Perturbation**: $z_1$ (7x7) 그리드에서 **단 하나의 셀**만 값을 변경했을 때, 이미지의 **해당 위치**만 변하는지 확인합니다.
2.  **Global Perturbation**: $z_2$ (숫자 클래스)를 변경했을 때, 이미지가 **전역적으로** 변하는지 확인합니다.

In [None]:
# Base Image
# Do not rely on an earlier cell having switched the model to eval mode.
model.eval()
x, _ = next(iter(loader))
x = x[0:1].to(device)

cell_row, cell_col = 3, 3  # centre cell of the 7x7 z1 grid
code_shift = 5             # cyclic offset applied to the code at that cell

with torch.no_grad():
    # Encode to get the base z1.
    h1 = model.enc1(x)
    z1_logits = model.z1_proj(h1)
    z1_hard = model.reparameterize_gumbel(z1_logits)

    # --- 1. Local Perturbation ---
    # Replace the code at a single cell with a different one-hot code.
    # Every other cell is left untouched.
    z1_perturbed = z1_hard.clone()
    current_code = z1_perturbed[0, :, cell_row, cell_col].argmax().item()
    new_code = (current_code + code_shift) % model.latent_dim
    z1_perturbed[0, :, cell_row, cell_col] = 0
    z1_perturbed[0, new_code, cell_row, cell_col] = 1

    recon_base = model.dec1(z1_hard)
    recon_local = model.dec1(z1_perturbed)

    # --- 2. Global Perturbation ---
    # Switch z2 to the next class and regenerate z1 from its prior.
    z2_logits = model.enc2(z1_hard)
    # Use the model's most likely class rather than a random Gumbel sample,
    # so the "current class" shown in the title is deterministic.
    current_class = z2_logits.argmax(dim=1).item()
    target_class = (current_class + 1) % z2_logits.size(1)

    z2_new = torch.zeros_like(z2_logits)
    z2_new[0, target_class] = 1.0

    # Predict the z1 prior for the new z2 and sample it.
    z1_new_prior_logits = model.dec2(z2_new).view(1, model.latent_dim, 7, 7)
    # dim=1 samples across codes. gumbel_softmax defaults to dim=-1, which
    # would sample across the grid width instead.
    z1_from_new_z2 = F.gumbel_softmax(z1_new_prior_logits, hard=True, dim=1)
    recon_global = model.dec1(z1_from_new_z2)

# Visualization
plt.figure(figsize=(12, 4))

plt.subplot(1, 4, 1)
# This panel shows the decoded base z1, not the raw input image.
plt.title("Original (Reconstruction)")
plt.imshow(recon_base[0, 0].cpu(), cmap='gray')
plt.axis('off')

plt.subplot(1, 4, 2)
plt.title("Local Perturbation\n(Center z1 changed)")
plt.imshow(recon_local[0, 0].cpu(), cmap='gray')
plt.axis('off')

plt.subplot(1, 4, 3)
plt.title("Difference (Local)")
diff = torch.abs(recon_base - recon_local)
plt.imshow(diff[0, 0].cpu(), cmap='hot')
plt.axis('off')

plt.subplot(1, 4, 4)
plt.title(f"Global Perturbation\n(Class {current_class}->{target_class})")
plt.imshow(recon_global[0, 0].cpu(), cmap='gray')
plt.axis('off')

plt.show()