# Analysis of Learned AKOrN Model Dynamics

This notebook analyzes the learned parameters and dynamics of an AKOrN model trained on CIFAR-10 classification.
We will examine:

1. **Omega (Ω)**: The learned natural frequencies/rotational matrices
2. **J**: The learned connectivity/coupling matrices 
3. **c**: The external input/bias terms from layer outputs
4. **Dynamics**: Simulate the learned Kuramoto-like dynamics

Model details:
- Architecture: 3-layer AKOrN with channels [128, 256, 512]
- Oscillator dimension: n=2 (complex oscillators)
- Time steps: T=4 per layer
- Best accuracy: 49.33% on CIFAR-10

In [None]:
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
import json
from pathlib import Path
import einops
from einops import rearrange
from sklearn.decomposition import PCA
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Add source directory to path
#sys.path.append('/source')
from models.classification.knet import AKOrN
from data.augs import augmentation_strong

# Set style for better plots
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

# Seed the RNGs once, up front: later cells draw the initial oscillator state
# with torch.randn_like, so without a seed every run yields different figures.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 1. Load Learned Model and Configuration

In [None]:
# Paths of the trained checkpoint and of the hyper-parameters it was trained with.
checkpoint_path = "results/20250701_561545.opbs/akorn_cifar10_final.pth"
config_path = "results/20250701_561545.opbs/parameters.json"

# Load configuration
with open(config_path, 'r') as f:
    config = json.load(f)

print("Model Configuration:")
for key, value in config.items():
    print(f"  {key}: {value}")

# Load checkpoint
checkpoint = torch.load(checkpoint_path, map_location=device)

# Pick the summary line from the keys the checkpoint really contains, not from
# the file name: a renamed file must not trigger a KeyError on a missing key.
if 'final_accuracy' in checkpoint:
    print(f"\nLoaded final checkpoint with accuracy {checkpoint['final_accuracy']:.2f}%")
elif 'epoch' in checkpoint and 'loss' in checkpoint:
    print(f"\nLoaded checkpoint from epoch {checkpoint['epoch']} with loss {checkpoint['loss']:.4f}")
else:
    print("\nLoaded checkpoint (no accuracy/epoch metadata found)")

# Create model with same configuration
model = AKOrN(
    n=config['n'],
    ch=config['ch'],
    out_classes=config['num_classes'],
    L=config['L'],
    T=config['T'],
    J=config['J'],
    ksizes=config['ksizes'],
    ro_ksize=config['ro_ksize'],
    ro_N=config['ro_N'],
    norm=config['norm'],
    c_norm=config['c_norm'],
    gamma=config['gamma'],
    use_omega=config['use_omega'],
    init_omg=config['init_omg'],
    global_omg=config['global_omg'],
    learn_omg=config['learn_omg'],
    ensemble=config['ensemble']
).to(device)

# Restore the trained weights and switch to inference mode.
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print(f"\nModel loaded successfully!")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

## 2. Analysis of Learned Omega Parameters

Omega represents the natural frequencies/rotational dynamics of the oscillators.

In [None]:
def extract_omega_parameters(model):
    """Collect the learned omega parameter of every layer that exposes one.

    For each entry of ``model.layers`` the module at index 2 is inspected; if
    it has an ``omg`` attribute that in turn has ``omg_param``, that parameter
    is detached, moved to the CPU and converted to a NumPy array. A short
    summary (shape, values, norm) is printed per matching layer.

    Args:
        model: Object with an indexable ``layers`` sequence.

    Returns:
        list of NumPy arrays, one per layer that has an omega parameter, in
        layer order. Layers without one are skipped.
    """
    collected = []

    for idx, layer in enumerate(model.layers):
        k_layer = layer[2]
        has_omega = hasattr(k_layer, 'omg') and hasattr(k_layer.omg, 'omg_param')
        if not has_omega:
            continue

        values = k_layer.omg.omg_param.detach().cpu().numpy()
        collected.append(values)
        print(f"Layer {idx}: omega shape = {values.shape}")
        print(f"  Omega values: {values}")
        print(f"  Omega magnitude: {np.linalg.norm(values):.4f}")

    return collected

omega_params = extract_omega_parameters(model)

# Visualize omega parameters: one bar chart per layer that has an omega.
# plt.subplots(1, 0) raises, so skip plotting when no layer exposes an omega.
if omega_params:
    # squeeze=False always yields a 2-D axes array, so the single-layer case
    # needs no special handling.
    fig, axes = plt.subplots(1, len(omega_params), figsize=(15, 4), squeeze=False)

    for i, (ax, omega) in enumerate(zip(axes[0], omega_params)):
        # NOTE(review): the two-bar layout assumes omega_param has exactly two
        # entries per layer — confirm against the model's omega shape.
        ax.bar(['Real', 'Imag'], omega)
        ax.set_title(f'Layer {i} Omega Parameters')
        ax.set_ylabel('Value')
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()
else:
    print("No omega parameters found; skipping the omega plot.")

# Analyze omega evolution across layers
omega_magnitudes = [np.linalg.norm(omega) for omega in omega_params]
print(f"\nOmega magnitude progression: {omega_magnitudes}")

In [None]:
# Print the module stored at index 2 of the first layer. This is the module
# whose `omg` and `connectivity` attributes the extraction helpers read.
layer = model.layers[0]
print(layer[2])

In [None]:
# Show the raw omega parameter of `layer` (set in the previous cell). It is
# left as a torch tensor here; the NumPy conversion is intentionally disabled.
omega_param = layer[2].omg.omg_param#.detach().cpu().numpy()
print(omega_param)

## 3. Analysis of Learned Connectivity Matrices (J)

The connectivity matrices determine how oscillators couple with each other.

In [None]:
def extract_connectivity_weights(model):
    """Gather the connectivity weights (and biases) of every layer as NumPy arrays.

    For each entry of ``model.layers`` the module at index 2 is inspected; if
    it has a ``connectivity`` attribute, that attribute's ``weight`` and
    (when not ``None``) ``bias`` are detached and copied to the CPU.
    Per-layer summary statistics are printed along the way.

    Args:
        model: Object with an indexable ``layers`` sequence.

    Returns:
        list of dicts with keys ``'weight'``, ``'bias'`` (``None`` when the
        module has no bias) and ``'shape'`` (the weight's shape), one per
        layer that has a ``connectivity`` module, in layer order.
    """
    records = []

    for idx, layer in enumerate(model.layers):
        k_layer = layer[2]
        if not hasattr(k_layer, 'connectivity'):
            continue

        conn = k_layer.connectivity
        weight = conn.weight.detach().cpu().numpy()
        bias = None if conn.bias is None else conn.bias.detach().cpu().numpy()

        records.append({'weight': weight, 'bias': bias, 'shape': weight.shape})

        print(f"Layer {idx}: Connectivity weight shape = {weight.shape}")
        print(f"  Weight statistics: mean={weight.mean():.4f}, std={weight.std():.4f}")
        print(f"  Weight range: [{weight.min():.4f}, {weight.max():.4f}]")
        if bias is not None:
            print(f"  Bias statistics: mean={bias.mean():.4f}, std={bias.std():.4f}")

    return records

# Extract the coupling weights of every layer (also prints per-layer statistics);
# the cells below index this list by layer.
connectivity_weights = extract_connectivity_weights(model)


## $J_{ij}$'s as 2x2 matrices
Here we take a closer look at the connectivity matrices $J$.

Take the first layer as an example. The whole weight tensor has shape $(n_{\text{output ch}}, n_{\text{input ch}}, H, W)$, for instance $(128,128,9,9)$: the first coordinate indexes the output channel (belonging to the oscillator that is influenced), the second the input channel (belonging to the oscillator that exerts the influence), and the last two give the location in the 9x9 kernel. Since each oscillator is two-dimensional, the connectivity $J_{ij}$, i.e. the influence of oscillator $j$ on oscillator $i$, is the 2x2 block at index $[2i:2i+2,\ 2j:2j+2]$ (Python slice notation).

We therefore have 64x64x9x9 ≈ 3.3e5 connectivity matrices in this layer. That is far too many to inspect individually, so we look at summary statistics of the connectivity matrices instead. To do so, we first examine
1. Strength of the connectivity
1. 

In [None]:
# For the first layer, draw one subplot per kernel position (i, j). Each subplot
# is a (C_out/2 x C_in/2) matrix whose (k, l) entry is the Frobenius norm of the
# 2x2 block J[2k:2k+2, 2l:2l+2, i, j]. All subplots share one colour scale.
J = connectivity_weights[0]['weight']
C_out, C_in, H, W = J.shape
num_out, num_in = C_out // 2, C_in // 2

# View J as (num_out, 2, num_in, 2, H, W) so that axes 1 and 3 hold the 2x2
# blocks, then take the matrix (Frobenius) norm over those two axes in one call
# instead of ~num_out*num_in*H*W Python-level calls.
blocks = J[:2 * num_out, :2 * num_in].reshape(num_out, 2, num_in, 2, H, W)
block_norms = np.linalg.norm(blocks, axis=(1, 3)).astype(np.float64)  # (num_out, num_in, H, W)
matrices = block_norms.transpose(2, 3, 0, 1)                           # (H, W, num_out, num_in)

# Global colour limits so the subplots are directly comparable.
vmin = matrices.min()
vmax = matrices.max()

# squeeze=False keeps `axes` 2-D even for a 1x1 kernel.
fig, axes = plt.subplots(H, W, figsize=(2 * W, 2 * H), squeeze=False)
for i in range(H):
    for j in range(W):
        ax = axes[i, j]
        im = ax.imshow(matrices[i, j], cmap='viridis', vmin=vmin, vmax=vmax)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(f'({i},{j})', fontsize=8)
plt.suptitle(f'Frobenius Norms of $J_{{ij}}$ Blocks (Each subplot: {num_out}x{num_in})', fontsize=16)
plt.tight_layout(rect=[0, 0, 1, 0.97])
cbar = fig.colorbar(im, ax=axes, orientation='vertical', fraction=0.01, pad=0.01)
plt.show()

In [None]:
# Display the configured `L`, which is passed to AKOrN as `L` and used as the
# loop bound of the per-layer analyses.
config['L']

In [None]:
# For every layer: summarise, per kernel position (i, j), the distribution of
# 2x2-block Frobenius norms over all (output oscillator, input oscillator) pairs.

# The layer index has its own name: the original reused `l` for both the layer
# loop and the innermost loop, so the figure title showed a leaked inner index.
for layer_idx in range(config['L']):

    J = connectivity_weights[layer_idx]['weight']
    C_out, C_in, H, W = J.shape
    matrix_shape = (C_out // 2, C_in // 2)

    # Reshape so axes 1 and 3 hold the 2x2 blocks, then take all Frobenius norms
    # in a single vectorised call.
    blocks = J[:2 * matrix_shape[0], :2 * matrix_shape[1]].reshape(
        matrix_shape[0], 2, matrix_shape[1], 2, H, W)
    block_norms = np.linalg.norm(blocks, axis=(1, 3)).astype(np.float64)  # (num_out, num_in, H, W)

    # Statistics over the oscillator-pair axes, leaving one value per kernel position.
    means = block_norms.mean(axis=(0, 1))
    stds = block_norms.std(axis=(0, 1))
    mins = block_norms.min(axis=(0, 1))
    maxs = block_norms.max(axis=(0, 1))

    # Plot the four statistics in a 2x2 grid of HxW heat maps
    fig, axes = plt.subplots(2, 2, figsize=(10, 8))
    stat_titles = ['Mean', 'Std', 'Min', 'Max']
    stat_arrays = [means, stds, mins, maxs]
    cmaps = ['viridis', 'magma', 'Blues', 'Reds']

    for ax, stat, title, cmap in zip(axes.flat, stat_arrays, stat_titles, cmaps):
        im = ax.imshow(stat, cmap=cmap)
        ax.set_title(f'{title} of {matrix_shape[0]}x{matrix_shape[1]} Block ({H}x{W} kernels)', fontsize=14)
        ax.set_xlabel('Kernel W')
        ax.set_ylabel('Kernel H')
        ax.grid(False)
        plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    plt.suptitle(
        f'Layer {layer_idx}: Summary Statistics for Frobenius norms of '
        f'{matrix_shape[1]} Oscillators ({H}x{W} Kernel Grid)',
        fontsize=18,
    )
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plt.show()

In [None]:
# # This time, can you draw a 64x64 subplot, where kl'th subplot is a 9x9 matrix whose ij'th element is the Frobenius norm of J[2k:2k+2, 2l:2l+2, i, j]? 
# # Use the same color limit for all subplot. Add titles and suptitle if necessary.
# # Assume J = connectivity_weights[0]['weight'] (shape: [128, 128, 9, 9])
# J = connectivity_weights[0]['weight']
# C_out, C_in, H, W = J.shape
# n = 2  # oscillator dimension

# # Number of oscillators per channel
# num_out = C_out // n
# num_in = C_in // n

# # Compute the 64x64 grid of 9x9 matrices (each entry is a 9x9 Frobenius norm map)
# frobenius_maps = np.zeros((num_out, num_in, H, W))
# for k in range(num_out):
#     for l in range(num_in):
#         for i in range(H):
#             for j in range(W):
#                 block = J[n*k:n*k+n, n*l:n*l+n, i, j]
#                 frobenius_maps[k, l, i, j] = np.linalg.norm(block, ord='fro')

# # Find global vmin/vmax for consistent color scale
# vmin = frobenius_maps.min()
# vmax = frobenius_maps.max()

# fig, axes = plt.subplots(num_out, num_in, figsize=(num_in, num_out), dpi=120)
# if num_out == 1 and num_in == 1:
#     axes = np.array([[axes]])
# elif num_out == 1 or num_in == 1:
#     axes = axes.reshape(num_out, num_in)

# for k in range(num_out):
#     for l in range(num_in):
#         ax = axes[k, l]
#         im = ax.imshow(frobenius_maps[k, l], cmap='viridis', vmin=vmin, vmax=vmax)
#         ax.set_xticks([])
#         ax.set_yticks([])
#         if k == 0:
#             ax.set_title(f'in {l}', fontsize=6)
#         if l == 0:
#             ax.set_ylabel(f'out {k}', fontsize=6)

# plt.suptitle('Each subplot: 9x9 Frobenius norm map for J[2k:2k+2, 2l:2l+2, :, :]', fontsize=12)
# plt.tight_layout(rect=[0, 0, 1, 0.97])
# fig.colorbar(im, ax=axes.ravel().tolist(), orientation='vertical', fraction=0.01, pad=0.01)
# plt.show()

In [None]:
# Compute and visualize summary statistics (mean, std, min, max) for each 9x9 matrix in the 64x64 grid

In [None]:
# For every layer: summarise, per (output oscillator, input oscillator) pair, the
# distribution of 2x2-block Frobenius norms over the HxW kernel positions.
for lay in range(config['L']):

    J = connectivity_weights[lay]['weight']
    C_out, C_in, H, W = J.shape
    n = 2  # oscillator dimension: each oscillator spans n consecutive channels
    num_out = C_out // n
    num_in = C_in // n
    matrix_shape = (num_out, num_in)

    # frobenius_maps[k, l] is the HxW map of ||J[n*k:n*k+n, n*l:n*l+n, i, j]||_F.
    # Reshape so axes 1 and 3 hold the n x n blocks and take every norm in one
    # vectorised call rather than num_out*num_in*H*W Python-level calls.
    blocks = J[:n * num_out, :n * num_in].reshape(num_out, n, num_in, n, H, W)
    frobenius_maps = np.linalg.norm(blocks, axis=(1, 3)).astype(np.float64)

    # Statistics over the kernel axes, leaving one value per oscillator pair.
    mean_grid = frobenius_maps.mean(axis=(2, 3))
    std_grid = frobenius_maps.std(axis=(2, 3))
    min_grid = frobenius_maps.min(axis=(2, 3))
    max_grid = frobenius_maps.max(axis=(2, 3))

    stat_grids = [mean_grid, std_grid, min_grid, max_grid]
    stat_titles = ['Mean', 'Std', 'Min', 'Max']
    cmaps = ['viridis', 'magma', 'Blues', 'Reds']

    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    for ax, grid, title, cmap in zip(axes.flat, stat_grids, stat_titles, cmaps):
        im = ax.imshow(grid, cmap=cmap)
        ax.set_title(f'{title} of {H}x{W} Frobenius Norms ({matrix_shape[0]}x{matrix_shape[1]} grid)')
        ax.set_xlabel('Input Oscillator')
        ax.set_ylabel('Output Oscillator')
        ax.grid(False)
        plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    plt.suptitle(f'Summary Statistics for Each {H}x{W} Matrix in {matrix_shape[0]}x{matrix_shape[1]} Grid', fontsize=16)
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plt.show()

## How many clusters of J_ij?

In [None]:
# Goal: treat every 2x2 block J[2k:2k+2, 2l:2l+2, i, j] of the first layer's
# coupling tensor as a point in R^4 (the flattened block) and look for clusters
# under the Euclidean distance.

from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

# ------------------------------------------------------------
# 0. Input: J is a NumPy array of shape (C_out, C_in, H, W),
#    e.g. (128, 128, 9, 9) for the first layer.
# ------------------------------------------------------------
J = connectivity_weights[0]['weight']
C_out, C_in, H, W = J.shape
assert C_out % 2 == 0 and C_in % 2 == 0, "channels must pair up into 2-D oscillators"

# ------------------------------------------------------------
# 1. Pull out every 2x2 block:
#    reshape -> transpose gives axes (k, l, i, j, 2, 2),
#    then collapse the first four into one sample axis.
# ------------------------------------------------------------
blocks = (
    J.reshape(C_out // 2, 2, C_in // 2, 2, H, W)   # (k, 2, l, 2, H, W)
     .transpose(0, 2, 4, 5, 1, 3)                   # (k, l, H, W, 2, 2)
     .reshape(-1, 2, 2)                             # (n_blocks, 2, 2)
)

# ------------------------------------------------------------
# 2. Flatten each block to vec(A) = (a11, a12, a21, a22)
# ------------------------------------------------------------
X = blocks.reshape(blocks.shape[0], -1)    # (n_blocks, 4)

# ------------------------------------------------------------
# 3. Optional z-score standardisation (zero mean, unit variance per column);
#    currently disabled so distances are measured on the raw weights.
# ------------------------------------------------------------
#X = StandardScaler().fit_transform(X)

# ------------------------------------------------------------
# 4. Try k = 2..10 and keep the k with the highest silhouette score
# ------------------------------------------------------------
from tqdm import tqdm

best_k, best_score, best_model = None, -1, None
for n_clusters in tqdm(range(2, 11)):
    km = KMeans(n_clusters=n_clusters, n_init='auto', random_state=0).fit(X)
    # The silhouette is estimated on a random subsample; fixing random_state
    # makes the score, and therefore the selected k, reproducible across runs.
    score = silhouette_score(X, km.labels_, sample_size=10_000, random_state=0)
    if score > best_score:
        best_k, best_score, best_model = n_clusters, score, km
print(f"採択: k={best_k}, silhouette={best_score:.3f}")

labels  = best_model.labels_                      # (n_blocks,)
centers = best_model.cluster_centers_.reshape(best_k, 2, 2)

# ------------------------------------------------------------
# 5. Quick look at the result: cluster sizes and centres as 2x2 matrices
# ------------------------------------------------------------
for cluster_id in range(best_k):
    cnt = (labels == cluster_id).sum()
    print(f"Cluster {cluster_id}: {cnt:6d} mats  |  center =\n{centers[cluster_id]}")


In [None]:
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
# Project the flattened blocks onto three principal components (the third one
# is kept for the 3-D view) and scatter the first two, coloured by cluster label.
pca = PCA(n_components=3, random_state=0)
X_pca = pca.fit_transform(X)

fig, ax = plt.subplots(figsize=(6, 5))
ax.scatter(X_pca[:, 0], X_pca[:, 1], s=5, c=labels, cmap='tab10')
ax.set_title('PCA (2-D) of 2×2 blocks')
ax.set_xlabel('PC1')
ax.set_ylabel('PC2')
fig.tight_layout()
plt.show()

In [None]:
from mpl_toolkits.mplot3d import Axes3D  # 明示的に import（必須）
from mpl_toolkits.mplot3d import proj3d  # 投影処理用（通常不要）

# 3-D scatter of the first three principal components, coloured by cluster label.
fig = plt.figure(figsize=(6, 5))
ax = fig.add_subplot(1, 1, 1, projection='3d')

# Coordinates are the columns of X_pca (N, 3); colours come from labels (N,).
pc_x, pc_y, pc_z = X_pca[:, 0], X_pca[:, 1], X_pca[:, 2]
ax.scatter(pc_x, pc_y, pc_z, c=labels, s=5, cmap='tab10')

ax.set_title("3D Scatter")
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_zlabel("z")
fig.tight_layout()
plt.show()

In [None]:
# Embed a fixed random subsample of 20,000 blocks with t-SNE and colour the
# points by their K-means label.
N = X.shape[0]
rng = np.random.default_rng(0)
idx_vis = rng.choice(N, size=20_000, replace=False)
X_vis, labels_vis = X[idx_vis], labels[idx_vis]

tsne = TSNE(n_components=2, perplexity=30, init='pca',
            learning_rate='auto', random_state=0)
X_tsne = tsne.fit_transform(X_vis)

fig, ax = plt.subplots(figsize=(6, 5))
ax.scatter(X_tsne[:, 0], X_tsne[:, 1], s=5, c=labels_vis, cmap='tab10')
ax.set_title('t-SNE of 2×2 blocks')
ax.set_xlabel('dim-1')
ax.set_ylabel('dim-2')
fig.tight_layout()
plt.show()

In [None]:
import umap

# ------------------------------------------------------------
# 1. Random subsample for plotting (10k-20k points recommended)
# ------------------------------------------------------------
N = X.shape[0]
rng = np.random.default_rng(0)
idx_vis = rng.choice(N, size=20_000, replace=False)
X_vis, labels_vis = X[idx_vis], labels[idx_vis]

# ------------------------------------------------------------
# 2. 2-D UMAP embedding
#    - n_neighbors: neighbourhood scale (tune within 5-50)
#    - min_dist   : how tightly points may be packed (0-0.5)
# ------------------------------------------------------------
umap_settings = dict(
    n_components=2,
    n_neighbors=30,
    min_dist=0.1,
    metric="euclidean",
    random_state=0,
)
mapper = umap.UMAP(**umap_settings)
X_umap = mapper.fit_transform(X_vis)   # (20_000, 2)

# ------------------------------------------------------------
# 3. Scatter plot coloured by cluster label
# ------------------------------------------------------------
fig, ax = plt.subplots(figsize=(6, 5))
ax.scatter(X_umap[:, 0], X_umap[:, 1], s=5, c=labels_vis, cmap="tab10", alpha=0.8)
ax.set_title("UMAP of 2×2 blocks")
ax.set_xlabel("UMAP-1")
ax.set_ylabel("UMAP-2")
fig.tight_layout()
plt.show()

In [None]:
# Refit K-means with a fixed k=2 on the flattened blocks X, independently of the
# silhouette-based selection. NOTE(review): `km` is not read by any later
# cell visible here — confirm whether this cell is still needed.
k = 2
km = KMeans(n_clusters=k, n_init='auto', random_state=0).fit(X)

In [None]:
# # Visualize connectivity weight distributions
# fig, axes = plt.subplots(2, len(connectivity_weights), figsize=(5*len(connectivity_weights), 8))
# if len(connectivity_weights) == 1:
#     axes = axes.reshape(-1, 1)

# for i, conn in enumerate(connectivity_weights):
#     weight = conn['weight']
#     bias = conn['bias']
    
#     # Weight distribution
#     axes[0, i].hist(weight.flatten(), bins=50, alpha=0.7, density=True)
#     axes[0, i].set_title(f'Layer {i} Weight Distribution')
#     axes[0, i].set_xlabel('Weight Value')
#     axes[0, i].set_ylabel('Density')
#     axes[0, i].grid(True, alpha=0.3)
    
#     # Weight magnitude heatmap (first few filters)
#     # Show average over spatial dimensions for first 16 filters
#     if len(weight.shape) == 4:  # Conv weight [out_ch, in_ch, h, w]
#         weight_viz = np.mean(np.abs(weight[:16, :16]), axis=(2, 3))  # Average over spatial dims
#         im = axes[1, i].imshow(weight_viz, cmap='viridis', aspect='auto')
#         axes[1, i].set_title(f'Layer {i} Weight Magnitude (16x16 filters)')
#         axes[1, i].set_xlabel('Input Channel')
#         axes[1, i].set_ylabel('Output Channel')
#         axes[1, i].grid(False)
#         plt.colorbar(im, ax=axes[1, i])

# plt.tight_layout()
# plt.show()

# # Analyze kernel patterns
# print("\nAnalyzing learned kernel patterns:")
# for i, conn in enumerate(connectivity_weights):
#     weight = conn['weight']
#     if len(weight.shape) == 4:  # Conv kernels
#         kernel_size = weight.shape[2]
#         print(f"\nLayer {i} (kernel size {kernel_size}x{kernel_size}):")
        
#         # Compute average kernel
#         avg_kernel = np.mean(weight, axis=(0, 1))  # Average over input/output channels
#         print(f"  Average kernel center value: {avg_kernel[kernel_size//2, kernel_size//2]:.4f}")
#         print(f"  Average kernel edge/center ratio: {np.mean(avg_kernel[0, :]) / avg_kernel[kernel_size//2, kernel_size//2]:.4f}")

### Visualize Individual Kernels

In [None]:
# def visualize_conv_kernels(connectivity_weights, layer_idx=0, num_kernels=16):
#     """Visualize individual convolutional kernels"""
#     if layer_idx >= len(connectivity_weights):
#         print(f"Layer {layer_idx} not found")
#         return
    
#     weight = connectivity_weights[layer_idx]['weight']  # [out_ch, in_ch, h, w]
#     out_ch, in_ch, h, w = weight.shape
    
#     # Select kernels to visualize
#     num_kernels = min(num_kernels, out_ch)
#     kernel_indices = np.linspace(0, out_ch-1, num_kernels, dtype=int)
    
#     fig, axes = plt.subplots(4, 4, figsize=(12, 12))
#     axes = axes.ravel()
    
#     for i, kernel_idx in enumerate(kernel_indices):
#         if i >= 16:
#             break
            
#         # Average over input channels for visualization
#         kernel = np.mean(weight[kernel_idx], axis=0)
        
#         im = axes[i].imshow(kernel, cmap='RdBu_r', vmin=-np.abs(kernel).max(), vmax=np.abs(kernel).max())
#         axes[i].set_title(f'Filter {kernel_idx}')
#         axes[i].axis('off')
#         plt.colorbar(im, ax=axes[i], fraction=0.046, pad=0.04)
    
#     # Hide unused subplots
#     for i in range(len(kernel_indices), 16):
#         axes[i].axis('off')
    
#     plt.suptitle(f'Layer {layer_idx} Convolutional Kernels ({h}x{w})', fontsize=16)
#     plt.tight_layout()
#     plt.show()

# # Visualize kernels for each layer
# for layer_idx in range(len(connectivity_weights)):
#     visualize_conv_kernels(connectivity_weights, layer_idx, num_kernels=16)

## 4. Simulate Dynamics with Learned Parameters

Now we'll simulate the actual Kuramoto dynamics using learned parameters on real CIFAR-10 images.

In [None]:
# CIFAR-10 test split. Only ToTensor is applied here; the simulation cells pass
# the image through model.rgb_normalize themselves.
transform_test = transforms.Compose([transforms.ToTensor()])
test_dataset = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform_test,
)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Human-readable names for the ten CIFAR-10 labels, indexed by label id.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Pick one fixed test image and add a batch dimension so the model accepts it.
idx_img = 17
sample_image, sample_label = test_dataset[idx_img]
sample_image = sample_image.unsqueeze(0).to(device)

print(f"Sample image shape: {sample_image.shape}")
print(f"Sample label: {sample_label} ({classes[sample_label]})")

# Show the sample: move channels last (C, H, W) -> (H, W, C) for imshow.
img_np = sample_image[0].cpu().permute(1, 2, 0).numpy()
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(img_np)
ax.set_title(f'Sample CIFAR-10 Image: {classes[sample_label]}')
ax.axis('off')
plt.show()


In [None]:
def simulate_layer_dynamics(model, input_tensor, x, layer_idx, return_trajectory=True):
    """Run the model up to ``layer_idx`` and record that layer's oscillator dynamics.

    The conditioning signal ``c`` is ``model.conv0(model.rgb_normalize(input_tensor))``.
    Layers before ``layer_idx`` are advanced to their final state (transition,
    K-layer, readout); the target layer's K-layer output is returned in full.

    Args:
        model: Model exposing ``conv0``, ``rgb_normalize``, ``layers``, ``T``
            and ``gamma``. Each entry of ``layers`` unpacks into five items,
            of which the first (a pair of transitions), third (K-layer) and
            fourth (readout) are used.
        input_tensor: Input passed through ``rgb_normalize`` and ``conv0``.
        x: Initial oscillator state with the same shape as ``c``, or ``None``
            to draw one with ``torch.randn_like(c)``.
        layer_idx: Index of the layer whose dynamics are recorded.
        return_trajectory: When false, the trajectory and energies are
            returned as ``None``.

    Returns:
        Tuple ``(final_x, xs, es, layer_input)``: the last state of the target
        layer, the K-layer's two outputs (or ``None, None``), and a copy of
        the state fed into the target K-layer.

    Raises:
        ValueError: If ``x`` is given and its shape differs from ``c``'s.
    """
    model.eval()

    def _steps_for(idx):
        # model.T may be one value shared by all layers or a per-layer sequence.
        return model.T[idx] if hasattr(model.T, '__getitem__') else model.T

    with torch.no_grad():
        c = model.conv0(model.rgb_normalize(input_tensor))
        if x is None:
            x = torch.randn_like(c)
        elif x.shape != c.shape:
            raise ValueError(f"Input tensor shape {x.shape} does not match expected shape {c.shape}")

        # Advance through the preceding layers, keeping only each final state.
        for idx in range(layer_idx):
            transition, _, kuramoto, readout, _ = model.layers[idx]
            x = transition[0](x)
            c = transition[1](c)
            states, _ = kuramoto(x, c, _steps_for(idx), model.gamma)
            x = states[-1]
            c = readout(x)

        # Target layer: keep what goes into the K-layer and everything it returns.
        transition, _, kuramoto, _, _ = model.layers[layer_idx]
        x = transition[0](x)
        c = transition[1](c)
        layer_input = x.clone()
        xs, es = kuramoto(x, c, _steps_for(layer_idx), model.gamma)

    if not return_trajectory:
        return xs[-1], None, None, layer_input
    return xs[-1], xs, es, layer_input  # final_x, trajectory, energies, input

# Simulate dynamics for each layer and keep the results as NumPy arrays,
# keyed by layer index.
layer_results = {}

for layer_idx in range(config['L']):
    print(f"\nSimulating Layer {layer_idx} dynamics...")

    # Passing x=None lets simulate_layer_dynamics draw the initial state with the
    # correct (feature-map) shape. The original also built an image-shaped random
    # tensor here that was never passed on; it has been removed.
    final_x, xs, es, layer_input = simulate_layer_dynamics(
        model, sample_image, None, layer_idx, return_trajectory=True)

    # Convert once and reuse for both storage and the printed summary.
    trajectory_np = [state.cpu().numpy() for state in xs] if xs else None
    energies_np = [e.cpu().numpy() for e in es] if es else None

    layer_results[layer_idx] = {
        'input': layer_input.cpu().numpy(),
        'final_output': final_x.cpu().numpy(),
        'trajectory': trajectory_np,
        'energies': energies_np
    }

    if xs is not None:
        print(f"  Trajectory length: {len(xs)} time steps")
        print(f"  Input shape: {layer_input.shape}")
        print(f"  Output shape: {final_x.shape}")

    # Guard on the converted list so an empty energy list cannot reach min()/max().
    if energies_np:
        print(f"  Energy range: [{min([e.min() for e in energies_np]):.4f}, {max([e.max() for e in energies_np]):.4f}]")

print("\nDynamics simulation completed!")

## See what's happening closely in the first layer (layer 0)

In [None]:
model.eval()

layer_idx = 0  # Change this to analyze different layers

with torch.no_grad():
    # Conditioning signal from the image, plus a random initial oscillator state
    # of the same shape.
    c_raw = model.conv0(model.rgb_normalize(sample_image))
    x_raw = torch.randn_like(c_raw)

    layer_0 = model.layers[layer_idx]
    transition_layer_0, _, k_layer_0, readout_layer_0, _ = layer_0

    # The layer's transition pair maps the raw state and raw signal into the
    # inputs of its K-layer.
    x, c = transition_layer_0[0](x_raw), transition_layer_0[1](c_raw)

# Show channel 0 of each tensor, before and after the transition.
panels = [
    (x_raw, 'Raw Input x_raw'),
    (c_raw, 'Raw Input c_raw'),
    (x, 'Transformed Input x'),
    (c, 'Transformed c'),
]
fig, axes = plt.subplots(2, 2, figsize=(4, 4))
axes = axes.ravel()
for panel_ax, (panel_tensor, panel_title) in zip(axes, panels):
    panel_ax.imshow(panel_tensor[0, 0, :, :].cpu().numpy())
    panel_ax.set_title(panel_title)
    panel_ax.axis('off')
plt.tight_layout()
plt.show()


In [None]:
# Elementwise comparison of channel 0 before and after the transition: the
# result is a boolean map, all True only if the transition left it unchanged.
# NOTE(review): this displays the full boolean tensor; a single-value summary
# may be easier to read.
x_raw[0,0,:,:] == x[0,0,:,:]  # Check if the raw input and transformed input are the same

In [None]:
# Step count the model itself uses for this layer (kept for reference).
T_val = model.T[layer_idx] if hasattr(model.T, '__getitem__') else model.T

# set a counterfactual T to see how the dynamics proceeds
T_long = 20

# Inference only: without no_grad the K-layer would record an autograd graph
# across all T_long steps, which the phase plots never use.
with torch.no_grad():
    layer_xs, layer_es = k_layer_0(x, c, T_long, model.gamma)
# x = layer_xs[-1]  # Take final state

# layer_xs: list of tensors, each tensor (1, C, H, W) is the state at a time step


In [None]:
# Shape of the first recorded state in the trajectory.
layer_xs[0].shape

In [None]:
# Squared norm of the channel pair (0, 1) at every spatial location in the first
# recorded state. NOTE(review): presumably this checks that each 2-D oscillator
# has unit norm (values close to 1) — confirm against the K-layer.
layer_xs[0][0,0,:,:]**2 + layer_xs[0][0,1,:,:]**2


In [None]:

# Turn every recorded state into phase angles: channels are read in pairs
# (even, odd), and the phase of each pair is atan2(odd, even).
layer_thetas = []
for state in layer_xs:
    n_osc = state.shape[1]  # channel count of the state (two channels per oscillator)
    odd_idx = range(1, n_osc, 2)
    even_idx = range(0, n_osc, 2)
    phase = torch.atan2(state[0, odd_idx, :, :], state[0, even_idx, :, :])
    layer_thetas.append(phase.unsqueeze(0))  # restore the leading batch axis

In [None]:
# Phase maps of one oscillator over time; only the first time step is drawn here.
osc = 0

phase_frames = []
for lt in layer_thetas:
    frame = lt[0, osc]
    if hasattr(lt, 'cpu'):
        # torch tensors are detached and moved to host memory before stacking
        frame = frame.detach().cpu().numpy()
    phase_frames.append(frame)
theta_series = np.stack(phase_frames, axis=0)  # [T, H, W]
T = theta_series.shape[0]

fig, ax = plt.subplots(figsize=(6, 6))
# 'twilight' is a cyclic colour map, so -pi and +pi get the same colour.
im = ax.imshow(theta_series[0], cmap='twilight', vmin=-np.pi, vmax=np.pi)
fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
ax.set_title(f"Oscillator Phase Evolution (osc={osc})")
ax.axis('off')

In [None]:
# ------------------------------------------------------------------
# 1. Collect the phase snapshots for every oscillator only once
# ------------------------------------------------------------------
# Each entry of layer_thetas is (1, n_osc, H, W); concatenating along the first
# axis gives one array indexed as [time, oscillator, H, W].
theta_all = torch.cat(layer_thetas, dim=0).detach().cpu().numpy()
n_osc = theta_all.shape[1]
theta_series_all = [theta_all[:, osc_idx] for osc_idx in range(n_osc)]  # each [T, H, W]

T = theta_all.shape[0]                    # number of animation frames

# ------------------------------------------------------------------
# 2. Build the 8×8 panel once, store the AxesImage handles
# ------------------------------------------------------------------
fig, axes = plt.subplots(8, 8, figsize=(15, 15), constrained_layout=True)
ims = []                                  # AxesImage objects, one per drawn oscillator

for ax in axes.flat:
    ax.axis('off')

# zip stops at the shorter sequence: with fewer than 64 oscillators the spare
# panels stay blank instead of raising IndexError, and with more than 64 only
# the first 64 are shown.
for ax, series in zip(axes.flat, theta_series_all):
    ims.append(ax.imshow(series[0], cmap='twilight', vmin=-np.pi, vmax=np.pi))

# ------------------------------------------------------------------
# 3. Animation function that updates every panel
# ------------------------------------------------------------------
def animate(t):
    """Set every panel to frame ``t`` and return the updated artists."""
    for im, series in zip(ims, theta_series_all):
        im.set_array(series[t])
    return ims                           # return the artists you updated

# Use the FuncAnimation name imported in the first cell. The original called
# `animation.FuncAnimation`, but `animation` is only imported in a later cell,
# so a fresh kernel raised NameError here.
ani = FuncAnimation(fig, animate, frames=T,
                    interval=200, blit=True)     # blit True is fine here


In [None]:
# Render the animation built in the previous cell as an interactive HTML widget.
HTML(ani.to_jshtml())

In [None]:
# Can you summarize what I have done until here from "See what's happening..." and make a function that returns the 8x8 animation for arbitrary input x?
import numpy as np

import matplotlib.pyplot as plt
import matplotlib.animation as animation

def create_8x8_animation(model, input_img, x=None, layer_idx=0, n_steps=20, fps=5):
    """
    Create an 8x8 animation showing the phase evolution of oscillators for arbitrary input x.

    The conditioning signal ``c`` is ``model.conv0(model.rgb_normalize(input_img))``.
    ``x`` and ``c`` are passed through the transition of
    ``model.layers[layer_idx]`` and then evolved by that layer's K-layer for
    ``n_steps`` steps. Channels are read as interleaved (even, odd) pairs and
    the phase of each pair is ``atan2(odd, even)``. Only batch element 0 is
    visualised.

    Note: the stem output is fed directly into the transition of the selected
    layer; layers before ``layer_idx`` are not run by this function.

    Args:
        model: The AKOrN model. It is switched to eval mode as a side effect.
        input_img: Image batch used to compute the conditioning signal ``c``.
        x: Optional initial oscillator state. If None, Gaussian noise with the
            shape of the stem output is used.
        layer_idx: Index into ``model.layers`` of the layer to simulate.
        n_steps: Number of Kuramoto steps requested from the K-layer.
        fps: Frames per second; the frame interval is ``1000 // fps`` ms.

    Returns:
        ani: The ``FuncAnimation`` object. At most the first 64 oscillator
        pairs are shown; unused panels stay blank when there are fewer.

    Raises:
        ValueError: If the K-layer returns no states to animate.
    """
    model.eval()
    with torch.no_grad():
        # Conditioning signal from the image stem, plus the initial state.
        c = model.conv0(model.rgb_normalize(input_img))
        if x is None:
            x = torch.randn_like(c)

        # Only the transition pair and the K-layer of the target layer are used.
        transition_layer, _, k_layer, _, _ = model.layers[layer_idx]
        x, c = transition_layer[0](x), transition_layer[1](c)

        # Simulate dynamics
        layer_xs, _ = k_layer(x, c, n_steps, model.gamma)

        # Phase of every (even, odd) channel pair for batch element 0.
        phase_frames = [
            torch.atan2(state[0, 1::2], state[0, 0::2]).detach().cpu().numpy()
            for state in layer_xs
        ]

    if not phase_frames:
        raise ValueError("k_layer returned no states to animate")

    phases = np.stack(phase_frames, axis=0)          # [T, pairs, H, W]
    n_frames, n_pairs = phases.shape[:2]
    theta_series_all = [phases[:, pair] for pair in range(n_pairs)]

    # Create 8x8 grid animation. zip() stops at the shorter sequence, so a
    # layer with fewer than 64 pairs no longer raises IndexError.
    fig, axes = plt.subplots(8, 8, figsize=(15, 15), constrained_layout=True)
    panels = list(axes.flat)
    ims = [
        ax.imshow(series[0], cmap='twilight', vmin=-np.pi, vmax=np.pi)
        for ax, series in zip(panels, theta_series_all)
    ]
    for ax in panels:
        ax.axis('off')
    shown_series = theta_series_all[:len(ims)]

    def animate(t):
        """Load frame ``t`` into every panel and return the updated artists."""
        for im, series in zip(ims, shown_series):
            im.set_array(series[t])
        return ims

    ani = animation.FuncAnimation(fig, animate, frames=n_frames, interval=1000 // fps, blit=True)
    return ani

In [None]:
# Pick a second test image. The following cells turn it into `x_sample`, which is
# passed to create_8x8_animation as the initial oscillator state x.
idx_img_x = 23
x_image, _ = test_dataset[idx_img_x]
# Add a leading batch axis and move the image to the compute device.
x_image = x_image.unsqueeze(0).to(device)

print(f"A different image that serve as x: {x_image.shape}")
# Label printing is disabled: the label of this image is discarded above.
# print(f"Sample label: {sample_label} ({classes[sample_label]})")
#print(f"Sample label: {sample_label.item()} ({classes[sample_label.item()]})")

# Visualize the sample
# permute(1, 2, 0) moves axis 0 to the end for imshow
# (presumably channels-first -> channels-last; confirm against the dataset transform).
img_np_x = x_image[0].cpu().permute(1, 2, 0).numpy()
plt.figure(figsize=(6, 6))
plt.imshow(img_np_x)
plt.title(f'A different CIFAR-10 Image')
#plt.title(f'Sample CIFAR-10 Image: {classes[sample_label.item()]}')
plt.axis('off')
plt.show()


In [None]:
# Select layer 0 and unpack its five components; positions 1 and 4 are ignored.
layer_idx = 0
layer_0 = model.layers[layer_idx]
transition_layer_0, _, k_layer_0, readout_layer_0, _ = layer_0

# Stem features of the second image: rgb_normalize followed by conv0.
# No transition layer is applied in this cell; the unpacked layer components
# above are not used here.
x_sample = model.conv0(model.rgb_normalize(x_image))

In [None]:
# Animate layer 0 for 20 steps: `sample_image` supplies the conditioning signal
# and `x_sample` (stem features of the second image) is the initial state x.
ani = create_8x8_animation(model, sample_image, x_sample, layer_idx=0, n_steps=20, fps=5)
# Render inline as an HTML/JavaScript player.
HTML(ani.to_jshtml())

In [None]:

# Save as GIF
# Keep the path in one place so the confirmation message always names the
# file that was actually written.
gif_path = 'oscillator_phase_8x8_evolution.gif'
ani.save(gif_path, writer='pillow', fps=5)
# NOTE(review): `fig` is whatever figure the notebook last bound to that name;
# create_8x8_animation does not return its figure, so this may not be the
# figure that belongs to `ani`. Confirm which figure should be closed.
plt.close(fig)
print(f"Animation saved as {gif_path}")

## Single-Channel Phase Animation

In [None]:
import matplotlib.animation as animation

# Build a per-time-step image series from the notebook-level `layer_thetas`
# (one entry per time step) by taking index 0 of every entry.
# NOTE(review): imshow below needs each frame to be 2-D. If the entries of
# `layer_thetas` carry a leading batch axis, `lt[0]` is still 3-D and a second
# index would be required. Confirm against the cell that builds `layer_thetas`.
theta_series = [lt[0].cpu().numpy() if hasattr(lt[0], 'cpu') else lt[0] for lt in layer_thetas]  # one array per time step
theta_series = np.stack(theta_series, axis=0)  # time becomes axis 0
T = theta_series.shape[0]

fig, ax = plt.subplots(figsize=(6, 6))
# Colour limits pinned to [-pi, pi] so all frames share one phase colour scale.
im = ax.imshow(theta_series[0], cmap='hsv', vmin=-np.pi, vmax=np.pi)
plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
ax.set_title("Oscillator Phase Evolution (channel 0)")
ax.axis('off')

# NOTE: this rebinds the notebook-level names `animate`, `T`, `fig` and `im`.
# No FuncAnimation is created in this cell, so nothing is animated until a
# later cell passes `fig` and `animate` to one.
def animate(t):
    """Show frame ``t`` of ``theta_series`` and put the step index in the title."""
    im.set_array(theta_series[t])
    ax.set_title(f"Oscillator Phase Evolution (t={t})")
    return [im]

In [None]:

def run_dynamics_to_layer(model, x, c, layer_idx=0, return_trajectory=True):
    """Propagate (x, c) through the layers before ``layer_idx`` and simulate that layer.

    This cell previously held only a function body at module level (``return``
    outside a function, undefined ``x``/``c``/``return_trajectory``), which is
    a SyntaxError. The same steps are wrapped in a function here.

    Args:
        model: Object exposing ``layers``, ``T`` and ``gamma``. Each layer must
            unpack into five items, of which the transition pair (index 0),
            the K-layer (index 2) and the readout (index 3) are used.
        x: Oscillator state entering layer 0's transition.
        c: Conditioning signal entering layer 0's transition.
        layer_idx: Index of the layer whose dynamics are returned.
        return_trajectory: When False, the trajectory and energies are
            replaced by None in the result.

    Returns:
        ``(final_x, trajectory, energies, layer_input)`` where ``layer_input``
        is a clone of the state handed to the target K-layer.
    """
    def steps_for(i):
        # model.T may be one value shared by all layers or an indexable per-layer collection.
        return model.T[i] if hasattr(model.T, '__getitem__') else model.T

    with torch.no_grad():
        # Forward through layers up to target layer
        for i in range(layer_idx):
            transition_layer, _, k_layer, readout_layer, _ = model.layers[i]

            # Apply transition
            x, c = transition_layer[0](x), transition_layer[1](c)

            # Intermediate layers only contribute their final state.
            layer_xs, _ = k_layer(x, c, steps_for(i), model.gamma)
            x = layer_xs[-1]

            # The readout of the final state is the next layer's conditioning signal.
            c = readout_layer(x)

        # Now simulate dynamics for the target layer
        transition_layer, _, k_layer, _, _ = model.layers[layer_idx]
        x, c = transition_layer[0](x), transition_layer[1](c)
        layer_input = x.clone()  # Save input to KLayer

        # Apply KLayer with full trajectory
        xs, es = k_layer(x, c, steps_for(layer_idx), model.gamma)

    if return_trajectory:
        return xs[-1], xs, es, layer_input  # final_x, trajectory, energies, input
    return xs[-1], None, None, layer_input


## 5. Energy Evolution Analysis

In [None]:
def plot_energy_evolution(layer_results):
    """Plot energy evolution for all layers.

    One panel is drawn per entry of ``layer_results``, ordered by key. For
    each time step the mean of that step's energy values is plotted, with a
    band of one standard deviation around it. Entries whose ``'energies'`` is
    None leave their panel empty.

    Args:
        layer_results: Mapping from layer index to a dict with an
            ``'energies'`` entry (a sequence of per-step values exposing
            ``.mean()`` and ``.std()``, or None).

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    if not layer_results:
        # plt.subplots(1, 0) would raise; there is simply nothing to draw.
        print("No layer results to plot")
        return

    n_panels = len(layer_results)
    # squeeze=False always yields a 2-D axes array, so one panel needs no special case.
    fig, axes = plt.subplots(1, n_panels, figsize=(5 * n_panels, 4), squeeze=False)

    # Panels are addressed by position, not by layer key, so non-contiguous
    # keys (e.g. {0, 2}) no longer index past the end of the axes array.
    ordered = sorted(layer_results.items(), key=lambda item: item[0])
    for ax, (layer_idx, results) in zip(axes[0], ordered):
        energies = results['energies']
        if energies is None:
            continue

        # float() accepts numpy scalars and 0-d torch tensors alike.
        mean_energies = np.array([float(e.mean()) for e in energies])
        std_energies = np.array([float(e.std()) for e in energies])
        time_steps = range(len(mean_energies))

        ax.plot(time_steps, mean_energies, 'b-', linewidth=2, label='Mean Energy')
        ax.fill_between(time_steps,
                        mean_energies - std_energies,
                        mean_energies + std_energies,
                        alpha=0.3, color='blue')

        ax.set_title(f'Layer {layer_idx} Energy Evolution')
        ax.set_xlabel('Time Step')
        ax.set_ylabel('Energy')
        ax.grid(True, alpha=0.3)
        ax.legend()

    plt.tight_layout()
    plt.show()

plot_energy_evolution(layer_results)

# Print energy statistics: first-step vs last-step mean energy per layer.
print("Energy Evolution Statistics:")
for layer_idx, results in layer_results.items():
    energies = results['energies']
    # Skip layers without recorded energies (None or an empty sequence).
    if energies is None or len(energies) == 0:
        continue

    # float() accepts numpy scalars and 0-d torch tensors alike.
    initial_energy = float(energies[0].mean())
    final_energy = float(energies[-1].mean())
    energy_change = final_energy - initial_energy

    print(f"\nLayer {layer_idx}:")
    print(f"  Initial energy: {initial_energy:.4f}")
    print(f"  Final energy: {final_energy:.4f}")
    print(f"  Energy change: {energy_change:.4f}")
    if initial_energy != 0:
        print(f"  Relative change: {energy_change/abs(initial_energy)*100:.2f}%")
    else:
        # A relative change against a zero baseline is undefined.
        print("  Relative change: n/a (initial energy is zero)")

## 6. Oscillator State Visualization

Visualize how the oscillator states evolve over time.

In [None]:
def _plot_initial_final_maps(maps, quantity, layer_idx, cmap, vmin, vmax):
    """Show the first and last time step of ``maps`` for up to four channels.

    ``maps`` is indexed as ``[time, channel, row, col]``. Row 0 of the figure
    holds t=0 and row 1 holds the last step; every panel shares the colour
    limits ``vmin``/``vmax``.
    """
    n_steps, n_channels = maps.shape[:2]
    num_channels_to_plot = min(4, n_channels)

    # squeeze=False keeps axes 2-D even when a single channel is plotted.
    fig, axes = plt.subplots(2, num_channels_to_plot,
                             figsize=(4*num_channels_to_plot, 8), squeeze=False)

    for ch_idx in range(num_channels_to_plot):
        for row, t in enumerate((0, n_steps - 1)):
            ax = axes[row, ch_idx]
            im = ax.imshow(maps[t, ch_idx], cmap=cmap, vmin=vmin, vmax=vmax)
            ax.set_title(f'Channel {ch_idx} {quantity} (t={t})')
            ax.axis('off')
            plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    plt.suptitle(f'Layer {layer_idx} Oscillator {quantity}s', fontsize=16)
    plt.tight_layout()
    plt.show()


def visualize_oscillator_dynamics(layer_results, layer_idx=0, spatial_downsample=4):
    """Visualize oscillator dynamics with spatial downsampling.

    The stored trajectory of ``layer_idx`` is stacked over time, batch element
    0 is kept, channels are grouped into oscillators of ``config['n']``
    components and the spatial grid is subsampled with stride
    ``spatial_downsample``. For two-component oscillators the magnitude and
    phase at the first and last step are plotted for up to four oscillators.

    Args:
        layer_results: Mapping from layer index to a dict with a
            ``'trajectory'`` entry (sequence of per-step states, or None).
        layer_idx: Layer to visualize.
        spatial_downsample: Stride used to subsample both spatial axes.

    Returns:
        ``(magnitude, phase)`` arrays indexed ``[time, oscillator, row, col]``
        on success. ``(None, None)`` when the layer has no trajectory or the
        oscillator dimension is not 2, so callers can always unpack a pair.
    """
    if layer_idx not in layer_results or layer_results[layer_idx]['trajectory'] is None:
        print(f"No trajectory data for layer {layer_idx}")
        return None, None

    trajectory = layer_results[layer_idx]['trajectory']

    # Stack over time and keep batch element 0: [T, 1, C, H, W] -> [T, C, H, W]
    trajectory_array = np.stack(trajectory, axis=0)[:, 0]

    T, C, H, W = trajectory_array.shape
    n = config['n']  # oscillator dimension

    print(f"Trajectory shape: {trajectory_array.shape}")
    print(f"Channels: {C}, Oscillator dim: {n}")

    # Reshape to oscillator format: [T, C//n, n, H, W]
    oscillator_trajectory = trajectory_array.reshape(T, C//n, n, H, W)

    # Downsample spatially for visualization. The strided slice can yield one
    # more row/column than H // stride, so it is trimmed to the floor size.
    H_ds = H // spatial_downsample
    W_ds = W // spatial_downsample

    oscillator_trajectory_ds = oscillator_trajectory[:, :, :, ::spatial_downsample, ::spatial_downsample]
    oscillator_trajectory_ds = oscillator_trajectory_ds[:, :, :, :H_ds, :W_ds]

    print(f"Downsampled shape: {oscillator_trajectory_ds.shape}")

    if n != 2:
        # Magnitude/phase plots are only defined here for 2-component oscillators.
        print(f"Visualization for n={n} dimensional oscillators not implemented yet")
        return None, None

    # Treat the two components as real and imaginary parts.
    real_part = oscillator_trajectory_ds[:, :, 0]  # [T, C//2, H_ds, W_ds]
    imag_part = oscillator_trajectory_ds[:, :, 1]  # [T, C//2, H_ds, W_ds]

    magnitude = np.sqrt(real_part**2 + imag_part**2)
    phase = np.arctan2(imag_part, real_part)

    # One global maximum keeps all magnitude panels on the same colour scale.
    _plot_initial_final_maps(magnitude, 'Magnitude', layer_idx, 'viridis', 0, magnitude.max())
    _plot_initial_final_maps(phase, 'Phase', layer_idx, 'hsv', -np.pi, np.pi)

    return magnitude, phase

# Visualize dynamics for each layer
for layer_idx in range(config['L']):
    print(f"\n=== Layer {layer_idx} Oscillator Dynamics ===")
    # The helper may hand back None instead of a pair when a layer has nothing
    # to show; normalise the result so the unpacking cannot raise TypeError.
    result = visualize_oscillator_dynamics(layer_results, layer_idx, spatial_downsample=2)
    magnitude, phase = result if result is not None else (None, None)

## Animation of Kuramoto Dynamics Evolution

Let's create animations showing how the random input x evolves through each Kuramoto time step at each layer, driven by the external input c from the CIFAR-10 image.

In [None]:
def create_kuramoto_animation(layer_results, layer_idx=0, fps=2, figsize=(12, 8)):
    """Create animation of Kuramoto dynamics evolution for a specific layer.

    The figure has six panels. The top row shows the magnitude and phase of
    oscillator 0 and the magnitude averaged over oscillators. The bottom row
    shows the energy curve up to the current frame and histograms of all
    magnitudes and phases.

    Args:
        layer_results: Mapping from layer index to a dict with ``'trajectory'``
            and ``'energies'`` entries for that layer.
        layer_idx: Layer to animate.
        fps: Frames per second; the frame interval is ``1000 // fps`` ms.
        figsize: Figure size passed to ``plt.subplots``.

    Returns:
        ``(anim, fig)`` on success. ``(None, None)`` when the layer has no
        trajectory or ``config['n']`` is not 2. A pair is returned on every
        path so that ``anim, fig = create_kuramoto_animation(...)`` never
        fails to unpack.
    """
    if layer_idx not in layer_results or layer_results[layer_idx]['trajectory'] is None:
        print(f"No trajectory data for layer {layer_idx}")
        return None, None

    trajectory = layer_results[layer_idx]['trajectory']
    energies = layer_results[layer_idx]['energies']
    T = len(trajectory)

    print(f"Creating animation for Layer {layer_idx} with {T} time steps...")

    n = config['n']
    if n != 2:
        print(f"Animation for n={n} dimensional oscillators not implemented yet")
        return None, None

    # Keep batch element 0 of every step; energies may be scalars or per-batch arrays.
    trajectory_np = [x[0] for x in trajectory]
    energies_np = [e[0] if len(e.shape) > 0 else e for e in energies]

    C, H, W = trajectory_np[0].shape

    # Group channels into (real, imag) pairs and derive magnitude/phase per step.
    magnitudes = []
    phases = []
    for state in trajectory_np:
        pairs = state.reshape(C//n, n, H, W)
        real_part, imag_part = pairs[:, 0], pairs[:, 1]  # each [C//2, H, W]
        magnitudes.append(np.sqrt(real_part**2 + imag_part**2))
        phases.append(np.arctan2(imag_part, real_part))

    # Colour limits are fixed over the whole animation so frames are comparable.
    ch0_mag_max = max([m[0].max() for m in magnitudes])
    avg_mag_max = max([np.mean(m, axis=0).max() for m in magnitudes])

    # Create figure with subplots
    fig, axes = plt.subplots(2, 3, figsize=figsize)

    # First channel magnitude
    im1 = axes[0, 0].imshow(magnitudes[0][0], cmap='viridis', vmin=0, vmax=ch0_mag_max)
    axes[0, 0].set_title('Channel 0 Magnitude')
    axes[0, 0].axis('off')
    plt.colorbar(im1, ax=axes[0, 0], fraction=0.046, pad=0.04)

    # First channel phase
    im2 = axes[0, 1].imshow(phases[0][0], cmap='hsv', vmin=-np.pi, vmax=np.pi)
    axes[0, 1].set_title('Channel 0 Phase')
    axes[0, 1].axis('off')
    plt.colorbar(im2, ax=axes[0, 1], fraction=0.046, pad=0.04)

    # Average magnitude across channels
    im3 = axes[0, 2].imshow(np.mean(magnitudes[0], axis=0), cmap='plasma', vmin=0, vmax=avg_mag_max)
    axes[0, 2].set_title('Average Magnitude')
    axes[0, 2].axis('off')
    plt.colorbar(im3, ax=axes[0, 2], fraction=0.046, pad=0.04)

    # Energy evolution plot: a growing line plus a marker at the current step.
    energy_line, = axes[1, 0].plot([], [], 'b-', linewidth=2)
    energy_point, = axes[1, 0].plot([], [], 'ro', markersize=8)
    axes[1, 0].set_xlim(0, T-1)
    axes[1, 0].set_ylim(min(energies_np), max(energies_np))
    axes[1, 0].set_title('Energy Evolution')
    axes[1, 0].set_xlabel('Time Step')
    axes[1, 0].set_ylabel('Energy')
    axes[1, 0].grid(True, alpha=0.3)

    # Magnitude histogram
    axes[1, 1].set_title('Magnitude Distribution')
    axes[1, 1].set_xlabel('Magnitude')
    axes[1, 1].set_ylabel('Frequency')

    # Phase histogram
    axes[1, 2].set_title('Phase Distribution')
    axes[1, 2].set_xlabel('Phase')
    axes[1, 2].set_ylabel('Frequency')

    # Time step text
    time_text = fig.suptitle(f'Layer {layer_idx} - Time Step: 0/{T-1}', fontsize=14)

    plt.tight_layout()

    def redraw_histogram(ax, values, title, xlabel, **hist_kwargs):
        """Clear ``ax`` and draw a fresh 30-bin histogram of ``values``."""
        ax.clear()
        ax.hist(values.flatten(), bins=30, alpha=0.7, **hist_kwargs)
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel('Frequency')

    def animate(frame):
        """Update every panel to time step ``frame``."""
        im1.set_array(magnitudes[frame][0])
        im2.set_array(phases[frame][0])
        im3.set_array(np.mean(magnitudes[frame], axis=0))

        # Update energy plot
        energy_line.set_data(range(frame+1), energies_np[:frame+1])
        energy_point.set_data([frame], [energies_np[frame]])

        # Histograms are rebuilt from scratch each frame.
        redraw_histogram(axes[1, 1], magnitudes[frame], 'Magnitude Distribution',
                         'Magnitude', color='blue')
        redraw_histogram(axes[1, 2], phases[frame], 'Phase Distribution',
                         'Phase', color='red', range=(-np.pi, np.pi))

        # Update time step
        time_text.set_text(f'Layer {layer_idx} - Time Step: {frame}/{T-1}')

        return [im1, im2, im3, energy_line, energy_point, time_text]

    # blit=False because the histogram axes are cleared and redrawn every frame.
    anim = FuncAnimation(fig, animate, frames=T, interval=1000//fps, blit=False, repeat=True)

    return anim, fig

def create_multi_layer_animation(layer_results, fps=1, figsize=(15, 10)):
    """Create animation showing dynamics across all layers simultaneously.

    Each layer with a stored trajectory gets one column: the channel-averaged
    oscillator magnitude on top and the energy curve below. The animation runs
    for the longest trajectory; layers with fewer steps stop updating once
    their own trajectory ends.

    Args:
        layer_results: Mapping from layer index to a dict with ``'trajectory'``
            and ``'energies'`` entries.
        fps: Frames per second; the frame interval is ``1000 // fps`` ms.
        figsize: Figure size passed to ``plt.subplots``.

    Returns:
        ``(anim, fig)`` on success. ``(None, None)`` when no layer has a
        trajectory or ``config['n']`` is not 2.
    """
    # Find layers with trajectory data
    active_layers = [i for i in layer_results.keys() if layer_results[i]['trajectory'] is not None]

    if not active_layers:
        print("No trajectory data found")
        return None, None

    n = config['n']
    if n != 2:
        # Without this guard no per-layer data would be prepared and the first
        # animation frame would fail with a KeyError.
        print(f"Animation for n={n} dimensional oscillators not implemented yet")
        return None, None

    print(f"Creating multi-layer animation for layers: {active_layers}")

    # Get maximum time steps across all layers
    max_T = max([len(layer_results[i]['trajectory']) for i in active_layers])

    # One column per active layer; squeeze=False keeps axes 2-D for one layer.
    n_layers = len(active_layers)
    fig, axes = plt.subplots(2, n_layers, figsize=figsize, squeeze=False)

    # Prepare data for each layer
    layer_data = {}
    images = {}
    energy_lines = {}
    energy_points = {}

    for idx, layer_idx in enumerate(active_layers):
        trajectory = layer_results[layer_idx]['trajectory']
        energies = layer_results[layer_idx]['energies']
        T = len(trajectory)

        # Keep batch element 0; energies may be scalars or per-batch arrays.
        trajectory_np = [x[0] for x in trajectory]
        energies_np = [e[0] if len(e.shape) > 0 else e for e in energies]

        C, H, W = trajectory_np[0].shape

        # Channel-averaged magnitude of the (real, imag) pairs at every step.
        magnitudes = []
        for state in trajectory_np:
            pairs = state.reshape(C//n, n, H, W)
            magnitude = np.sqrt(pairs[:, 0]**2 + pairs[:, 1]**2)
            magnitudes.append(np.mean(magnitude, axis=0))

        layer_data[layer_idx] = {
            'magnitudes': magnitudes,
            'energies': energies_np,
            'T': T
        }

        # Initialize magnitude image with a colour scale fixed over all frames.
        vmax = max([m.max() for m in magnitudes])
        im = axes[0, idx].imshow(magnitudes[0], cmap='viridis', vmin=0, vmax=vmax)
        axes[0, idx].set_title(f'Layer {layer_idx} - Magnitude')
        axes[0, idx].axis('off')
        plt.colorbar(im, ax=axes[0, idx], fraction=0.046, pad=0.04)
        images[layer_idx] = im

        # Initialize energy plot
        line, = axes[1, idx].plot([], [], 'b-', linewidth=2)
        point, = axes[1, idx].plot([], [], 'ro', markersize=6)
        axes[1, idx].set_xlim(0, T-1)
        axes[1, idx].set_ylim(min(energies_np), max(energies_np))
        axes[1, idx].set_title(f'Layer {layer_idx} - Energy')
        axes[1, idx].set_xlabel('Time Step')
        axes[1, idx].set_ylabel('Energy')
        axes[1, idx].grid(True, alpha=0.3)
        energy_lines[layer_idx] = line
        energy_points[layer_idx] = point

    # Time step text
    time_text = fig.suptitle(f'Multi-Layer Kuramoto Dynamics - Time Step: 0/{max_T-1}', fontsize=16)

    plt.tight_layout()

    def animate(frame):
        """Advance every layer that still has data at ``frame``."""
        updated_artists = []

        for layer_idx in active_layers:
            data = layer_data[layer_idx]

            # Shorter trajectories freeze on their last drawn frame.
            if frame >= data['T']:
                continue

            images[layer_idx].set_array(data['magnitudes'][frame])
            updated_artists.append(images[layer_idx])

            energy_lines[layer_idx].set_data(range(frame+1), data['energies'][:frame+1])
            energy_points[layer_idx].set_data([frame], [data['energies'][frame]])
            updated_artists.extend([energy_lines[layer_idx], energy_points[layer_idx]])

        # Update time step
        time_text.set_text(f'Multi-Layer Kuramoto Dynamics - Time Step: {frame}/{max_T-1}')
        updated_artists.append(time_text)

        return updated_artists

    # Create animation
    anim = FuncAnimation(fig, animate, frames=max_T, interval=1000//fps, blit=False, repeat=True)

    return anim, fig

# Create animations for each layer individually.
# Successful animations are kept in `layer_animations`, keyed by layer index,
# so the objects stay referenced after this cell finishes.
print("Creating individual layer animations...")
layer_animations = {}

for layer_idx in range(config['L']):
    # Only layers that actually have a stored trajectory are animated.
    if layer_idx in layer_results and layer_results[layer_idx]['trajectory'] is not None:
        print(f"\nCreating animation for Layer {layer_idx}...")
        anim, fig = create_kuramoto_animation(layer_results, layer_idx, fps=2)

        if anim is not None:
            layer_animations[layer_idx] = anim

            # Display the animation as an inline HTML/JavaScript player.
            # NOTE(review): `display` is not imported in the visible import cell;
            # presumably the IPython-provided builtin. Confirm when run outside a notebook.
            print(f"Displaying Layer {layer_idx} animation:")
            display(HTML(anim.to_jshtml()))
            plt.show()

# Create multi-layer animation (one column per layer with a trajectory).
print("\nCreating multi-layer animation...")
multi_anim, multi_fig = create_multi_layer_animation(layer_results, fps=1)

# The builder returns (None, None) when there is nothing to animate.
if multi_anim is not None:
    print("Displaying multi-layer animation:")
    display(HTML(multi_anim.to_jshtml()))
    plt.show()

print("\nAnimation creation completed!")

## 7. Order Parameter Analysis

Compute and analyze the order parameter (synchronization measure) for each layer.

In [None]:
def compute_order_parameter(trajectory, n=None):
    """Compute Kuramoto order parameter for trajectory.

    For every time step, batch element 0 is taken, its channels are grouped
    into oscillators of ``n`` components, each oscillator is averaged over the
    spatial grid, and the norm of that mean vector is averaged over
    oscillators. For ``n == 2`` this norm equals the modulus of the mean
    complex value.

    Args:
        trajectory: Sequence of per-step states indexed ``[batch, C, H, W]``,
            given as numpy arrays or torch tensors.
        n: Oscillator dimension. Defaults to the notebook-level
            ``config['n']`` when omitted.

    Returns:
        1-D numpy array with one order-parameter value per time step.

    Raises:
        ValueError: If the channel count is not divisible by ``n``.
    """
    if n is None:
        n = config['n']

    order_params = []
    for state in trajectory:
        # Accept torch tensors (possibly on GPU or attached to a graph) as well as arrays.
        if hasattr(state, 'detach'):
            state = state.detach().cpu().numpy()
        state = np.asarray(state)[0]  # Remove batch dimension: [C, H, W]

        C, H, W = state.shape
        if C % n != 0:
            raise ValueError(f"channel count {C} is not divisible by oscillator dimension {n}")

        # Reshape to oscillators: [C//n, n, H, W]
        oscillators = state.reshape(C//n, n, H, W)

        # Mean oscillator vector over space, then its length per oscillator.
        mean_osc = oscillators.mean(axis=(2, 3))         # [C//n, n]
        order_param = np.linalg.norm(mean_osc, axis=1)   # [C//n]

        # Average over oscillators
        order_params.append(order_param.mean())

    return np.array(order_params)

# Compute order parameters for all layers that have a stored trajectory.
order_parameters = {}

for layer_idx, results in layer_results.items():
    if results['trajectory'] is not None:
        order_param = compute_order_parameter(results['trajectory'])
        order_parameters[layer_idx] = order_param

        print(f"Layer {layer_idx} order parameter evolution:")
        print(f"  Initial: {order_param[0]:.4f}")
        print(f"  Final: {order_param[-1]:.4f}")
        print(f"  Change: {order_param[-1] - order_param[0]:.4f}")

# Plot order parameter evolution
if order_parameters:
    # squeeze=False keeps axes 2-D, so a single layer needs no special case.
    fig, axes = plt.subplots(1, len(order_parameters),
                             figsize=(5*len(order_parameters), 4), squeeze=False)

    # Panels are addressed by position rather than by layer key: when some
    # layers are skipped above, the keys are no longer valid axis indices.
    ordered = sorted(order_parameters.items(), key=lambda item: item[0])
    for ax, (layer_idx, order_param) in zip(axes[0], ordered):
        time_steps = range(len(order_param))
        ax.plot(time_steps, order_param, 'r-', linewidth=2, marker='o')
        ax.set_title(f'Layer {layer_idx} Order Parameter')
        ax.set_xlabel('Time Step')
        ax.set_ylabel('Order Parameter')
        ax.grid(True, alpha=0.3)
        ax.set_ylim(0, 1)

    plt.tight_layout()
    plt.show()
else:
    # plt.subplots(1, 0) would raise; report the situation instead.
    print("No trajectory data available; skipping order parameter plot")

## 8. Layer Output Analysis (c terms)

Analyze the layer outputs that serve as external inputs (c terms) to subsequent layers.

In [None]:
def get_layer_outputs(model, input_tensor):
    """Collect the conditioning signal ``c`` after the stem and after every layer.

    The first entry is the ``conv0`` output of the normalized input. Each
    following entry is the readout of a layer's final K-layer state, i.e. the
    ``c`` that would be handed to the next layer. The oscillator state starts
    as Gaussian noise shaped like the stem output.

    Returns:
        List of numpy arrays, one for the stem plus one per layer.
    """
    with torch.no_grad():
        cond = model.conv0(model.rgb_normalize(input_tensor))
        state = torch.randn_like(cond)
        collected = [cond.cpu().numpy()]  # stem output comes first

        for depth, (transition, _, kuramoto, readout, _) in enumerate(model.layers):
            # Transition pair: element 0 acts on the state, element 1 on c.
            state, cond = transition[0](state), transition[1](cond)

            # model.T is either indexable per layer or a single shared value.
            if hasattr(model.T, '__getitem__'):
                n_steps = model.T[depth]
            else:
                n_steps = model.T

            states, _energies = kuramoto(state, cond, n_steps, model.gamma)
            state = states[-1]  # only the final state is carried forward

            # The readout of the final state is this layer's output.
            cond = readout(state)
            collected.append(cond.cpu().numpy())

    return collected

# Get layer outputs for the sample image.
# NOTE: entry 0 is the conv0 stem output and entry i >= 1 is the readout of
# model.layers[i-1], so the "Layer {i}" labels below are offset by one relative
# to the indices of model.layers.
layer_outputs = get_layer_outputs(model, sample_image)

# Summary statistics over every element of each output array.
print("Layer Output Analysis:")
for i, output in enumerate(layer_outputs):
    print(f"\nLayer {i} output:")
    print(f"  Shape: {output.shape}")
    print(f"  Mean: {output.mean():.4f}")
    print(f"  Std: {output.std():.4f}")
    print(f"  Min: {output.min():.4f}")
    print(f"  Max: {output.max():.4f}")

# Visualize layer outputs: one column per output, bar chart on top, image below.
fig, axes = plt.subplots(2, len(layer_outputs), figsize=(4*len(layer_outputs), 8))
if len(layer_outputs) == 1:
    # With a single column plt.subplots returns a 1-D array; restore 2-D indexing.
    axes = axes.reshape(-1, 1)

for i, output in enumerate(layer_outputs):
    # Drop the batch dimension (all channels are kept).
    output_viz = output[0]  # [C, H, W]

    # Show channel-wise average
    channel_avg = np.mean(output_viz, axis=(1, 2))  # Average over spatial dims
    axes[0, i].bar(range(len(channel_avg)), channel_avg)
    axes[0, i].set_title(f'Layer {i} Channel Averages')
    axes[0, i].set_xlabel('Channel')
    axes[0, i].set_ylabel('Average Activation')

    # Show spatial pattern for first channel (skipped if there are no channels)
    if output_viz.shape[0] > 0:
        im = axes[1, i].imshow(output_viz[0], cmap='viridis')
        axes[1, i].set_title(f'Layer {i} First Channel')
        axes[1, i].axis('off')
        plt.colorbar(im, ax=axes[1, i], fraction=0.046, pad=0.04)

plt.tight_layout()
plt.show()

## 9. Summary and Analysis

Summarize the key findings about the learned AKOrN model.

In [None]:
# Text-only summary of the notebook: configuration, learned parameters,
# per-layer dynamics and a fixed list of observations.
print("=" * 60)
print("LEARNED AKORN MODEL ANALYSIS SUMMARY")
print("=" * 60)

print(f"\nModel Configuration:")
# Channel widths are derived as config['ch'] doubled once per layer.
print(f"  Architecture: {config['L']} layers with channels {[config['ch'] * (2**i) for i in range(config['L'])]}")
print(f"  Oscillator dimension: {config['n']}")
print(f"  Time steps per layer: {config['T']}")
# NOTE: the accuracy is a hard-coded string, not read from the config or checkpoint.
print(f"  Best accuracy: 49.33% on CIFAR-10")

# `omega_params` and `connectivity_weights` are produced by earlier analysis cells.
print(f"\nOmega (Natural Frequencies):")
for i, omega in enumerate(omega_params):
    print(f"  Layer {i}: magnitude = {np.linalg.norm(omega):.4f}, values = {omega}")

print(f"\nConnectivity Matrices (J):")
for i, conn in enumerate(connectivity_weights):
    weight = conn['weight']
    print(f"  Layer {i}: shape = {weight.shape}, mean = {weight.mean():.4f}, std = {weight.std():.4f}")

print(f"\nDynamics Analysis:")
for layer_idx in range(config['L']):
    # A layer is reported only when it has both recorded energies and an
    # entry in `order_parameters`; otherwise nothing is printed for it.
    if layer_idx in layer_results and layer_results[layer_idx]['energies'] is not None:
        energies = layer_results[layer_idx]['energies']
        initial_energy = energies[0].mean()
        final_energy = energies[-1].mean()

        order_param = order_parameters.get(layer_idx, None)
        if order_param is not None:
            order_change = order_param[-1] - order_param[0]
            print(f"  Layer {layer_idx}: Energy {initial_energy:.3f} → {final_energy:.3f}, Order parameter Δ = {order_change:.4f}")

# Static takeaways; these lines are not computed from the results above.
print(f"\nKey Observations:")
print(f"  • Omega parameters show layer-specific natural frequencies")
print(f"  • Connectivity matrices exhibit learned spatial coupling patterns")
print(f"  • Energy evolution indicates convergence to stable states")
print(f"  • Order parameters reveal synchronization dynamics")
print(f"  • Layer outputs provide structured external inputs to subsequent layers")

print("\n" + "="*60)