In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import json
import os
import yaml
import glob

from tqdm import tqdm

from einops import rearrange

from the_well.benchmark.metrics import VRMSE
from the_well.data import WellDataset, datasets
from the_well.utils.download import well_download
from the_well.benchmark import models
from the_well.benchmark.models import UNetConvNext

from safetensors.torch import load_file

In [None]:
def load_unetconvnext_from_local(dataset_name: str,
                                 model_base="data/model-data/benchmarks",
                                 dataset_base="data/datasets"):
    """Build a UNetConvNext for ``dataset_name`` and load its locally stored weights.

    The checkpoint directory is searched for under ``model_base``: its name
    must start with ``"UNetConvNext"`` and contain ``dataset_name``. The
    directory has to hold ``config.json`` and ``model.safetensors``. Spatial
    metadata is read from ``<dataset_base>/<dataset_name>/<dataset_name>.yaml``.

    Args:
        dataset_name: Dataset the checkpoint was trained on.
        model_base: Directory containing one sub-directory per checkpoint.
        dataset_base: Directory containing one sub-directory per dataset.

    Returns:
        Tuple ``(model, model_kwargs, model_dir)``: the model in eval mode,
        the keyword arguments used to construct it, and the checkpoint path.

    Raises:
        FileNotFoundError: If no checkpoint directory matches, if
            ``config.json`` or ``model.safetensors`` is missing, or if the
            dataset metadata YAML is missing.
        KeyError: If ``config.json`` lacks ``dim_in``/``dim_out`` or the
            metadata lacks ``n_spatial_dims``/``spatial_resolution``.
    """
    # Locate model directory. os.listdir() returns entries in arbitrary order,
    # so sort to make the choice deterministic.
    candidates = sorted(
        d for d in os.listdir(model_base)
        if d.startswith("UNetConvNext") and dataset_name in d
    )
    if not candidates:
        raise FileNotFoundError(f"No local model found for dataset '{dataset_name}' in {model_base}")

    # Prefer the directory that ends with the dataset name, so that a dataset
    # whose name is a prefix of another one does not pick up the wrong model.
    match = next((d for d in candidates if d.endswith(dataset_name)), candidates[0])

    model_dir = os.path.join(model_base, match)
    config_path = os.path.join(model_dir, "config.json")
    safetensor_path = os.path.join(model_dir, "model.safetensors")

    if not os.path.exists(config_path) or not os.path.exists(safetensor_path):
        raise FileNotFoundError("Missing config.json or model.safetensors in model directory.")

    # Load config.json
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    # Load metadata YAML
    yaml_path = os.path.join(dataset_base, dataset_name, f"{dataset_name}.yaml")
    if not os.path.exists(yaml_path):
        raise FileNotFoundError(f"Missing metadata YAML: {yaml_path}")

    with open(yaml_path, "r", encoding="utf-8") as f:
        metadata = yaml.safe_load(f)

    # Extract model parameters: channel counts are mandatory, the architecture
    # knobs fall back to defaults, and the spatial layout comes from the
    # dataset metadata.
    model_kwargs = {
        "dim_in": config["dim_in"],
        "dim_out": config["dim_out"],
        "stages": config.get("stages", 4),
        "blocks_per_stage": config.get("blocks_per_stage", 1),
        "blocks_at_neck": config.get("blocks_at_neck", 1),
        "init_features": config.get("init_features", 32),
        "gradient_checkpointing": config.get("gradient_checkpointing", False),
        "n_spatial_dims": metadata["n_spatial_dims"],
        "spatial_resolution": tuple(metadata["spatial_resolution"]),
    }

    # Instantiate and load weights
    model = UNetConvNext(**model_kwargs)
    state_dict = load_file(safetensor_path)
    model.load_state_dict(state_dict)
    model.eval()

    return model, model_kwargs, model_dir

Load model and dataset

In [None]:
# Dataset location and name; later cells reuse both variables.
dataset_base = "data/datasets"
dataset_name = "turbulent_radiative_layer_2D"

# Pretrained model for this dataset (eval mode), its constructor kwargs and
# the checkpoint directory it was loaded from.
model, model_kwargs, path = load_unetconvnext_from_local(
    dataset_name,
    dataset_base=dataset_base,
)

# Test split with 4 input steps and 1 output step per sample. Built-in
# normalization is off; the notebook standardizes inputs itself.
testset = datasets.WellDataset(
    well_base_path="data/datasets/",
    well_dataset_name=dataset_name,
    well_split_name="test",
    use_normalization=False,
    n_steps_input=4,
    n_steps_output=1,
)

# Fixed-order batches of 4 samples.
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=4,
    shuffle=False,
    num_workers=4,
)

Get the training-data statistics from the stats.yaml file and define the pre- and post-processing functions

In [None]:
yaml_path = os.path.join(dataset_base, dataset_name, "stats.yaml")

with open(yaml_path, "r") as f:
    stats = yaml.safe_load(f)


def _channel_stats(kind):
    """Flatten ``stats[kind]`` into [density, pressure, velocity_x, velocity_y]."""
    entry = stats[kind]
    velocity = entry["velocity"]
    return [entry["density"], entry["pressure"], velocity[0], velocity[1]]


mean_vals = _channel_stats("mean")
std_vals = _channel_stats("std")

device = "cpu" # change if using gpu


def _as_channel_last(values):
    """Turn a flat list into a (1, 1, 1, 1, len(values)) tensor on ``device``."""
    return torch.tensor(values).view(1, 1, 1, 1, -1).to(device)


# Four leading singleton axes so the statistics broadcast against 5-D inputs
# whose last axis is the field/channel axis.
mean_tensor = _as_channel_last(mean_vals)  # shape (1,1,1,1,4)
std_tensor = _as_channel_last(std_vals)    # shape (1,1,1,1,4)


def preprocess(x):
    """Standardize ``x`` channel-wise with the statistics from stats.yaml."""
    centered = x - mean_tensor
    return centered / std_tensor


def postprocess(x):
    """Invert ``preprocess``: map standardized values back to the original scale."""
    rescaled = x * std_tensor
    return rescaled + mean_tensor

Register hooks for target layers

In [None]:
# Most recent output of each hooked layer, keyed by the layer's module name.
activations = {}

def get_activation(name):
    """Return a forward hook that stores the module output under ``name``.

    The output is detached before it is stored in ``activations``; every
    forward pass overwrites the previous entry for the same name.
    """
    def hook(module, input, output):
        activations[name] = output.detach()
    return hook

# Handles returned by register_forward_hook, kept so the hooks can be removed.
hooks = []

target_layer_names = ["encoder.3.blocks.1.dwconv",]  # update this list

# Fail fast on a mistyped layer name: otherwise no hook is registered, nothing
# is captured, and the problem only shows up when loading activations later.
available_layer_names = {name for name, _ in model.named_modules()}
missing_layer_names = [name for name in target_layer_names if name not in available_layer_names]
if missing_layer_names:
    raise ValueError(f"Target layers not found in model: {missing_layer_names}")

for name, module in model.named_modules():
    if name in target_layer_names:
        hook = module.register_forward_hook(get_activation(name))
        hooks.append(hook)

Collect activations for real data and noise, save each batch to disk

In [None]:
base_dir = "data/activations"
model_name = "UNetConvNext"
split_name = "test"

# Same layout that load_streamed_activations reads from:
# <base_dir>/<model_name>-<dataset_name>/<split_name>
save_dir = f"{base_dir}/{model_name}-{dataset_name}/{split_name}"

def save_batch(save_dir, layer_name, data, kind, batch_idx):
    """Write one batch of activations to ``save_dir``.

    Args:
        save_dir: Target directory; created if it does not exist yet.
        layer_name: Name of the hooked layer, used in the file name.
        data: Tensor to store; it is moved to CPU before saving.
        kind: File-name prefix, "real" or "noise" in this notebook.
        batch_idx: Index of the batch, used in the file name.
    """
    os.makedirs(save_dir, exist_ok=True)
    # Write into save_dir (not the global base_dir) so that the files end up
    # where the loader looks and different models/datasets/splits stay apart.
    path = os.path.join(save_dir, f"{kind}_{layer_name}_batch{batch_idx}.pt")
    torch.save(data.cpu(), path)

model = model.to(device)

with torch.no_grad():
    for batch_idx, batch in enumerate(tqdm(testloader, desc="Streaming activations to disk")):
        # Real data: standardize, then fold the time axis into the channels
        x = batch["input_fields"].to(device)
        x = preprocess(x)
        x = rearrange(x, "B Ti Lx Ly F -> B (Ti F) Lx Ly")
        _ = model(x)

        # The forward hooks have just refreshed `activations` for this input
        for name, act in activations.items():
            save_batch(save_dir, name, act.flatten(start_dim=1), "real", batch_idx)

        # Noise: standard-normal input with the same shape as the real batch
        noise_x = torch.randn_like(x)
        _ = model(noise_x)

        for name, act in activations.items():
            save_batch(save_dir, name, act.flatten(start_dim=1), "noise", batch_idx)

        # Free up memory
        del x, noise_x

In [None]:
# Remove every registered forward hook so that later forward passes no longer
# write into `activations`.
for handle in hooks:
    handle.remove()

In [None]:
def load_streamed_activations(base_dir, model_name, dataset_name, split_name, kind="real", layer_name=None):
    """Load the per-batch activation files of one kind and stack them.

    Files are searched for in ``<base_dir>/<model_name>-<dataset_name>/<split_name>``
    using the pattern ``<kind>_<layer_name>_batch*.pt`` (or
    ``<kind>_*_batch*.pt`` when ``layer_name`` is not given).

    Args:
        base_dir: Root directory of the stored activations.
        model_name: Model part of the sub-directory name.
        dataset_name: Dataset part of the sub-directory name.
        split_name: Split sub-directory, e.g. "test".
        kind: File-name prefix to load.
        layer_name: Layer to load; if falsy, files of all layers are matched.

    Returns:
        The loaded tensors concatenated along dim 0, in batch-index order.

    Raises:
        FileNotFoundError: If no file matches the pattern.
    """
    activation_dir = f"{base_dir}/{model_name}-{dataset_name}/{split_name}"
    pattern = f"{kind}_{layer_name}_batch*.pt" if layer_name else f"{kind}_*_batch*.pt"

    def batch_order(path):
        # A plain string sort puts "batch10" before "batch2". Sort by the
        # prefix (kind + layer) first and by the numeric batch index second.
        prefix, _, tail = os.path.basename(path).rpartition("_batch")
        index = tail[: -len(".pt")]
        if index.isdecimal():
            return (prefix, 0, int(index), "")
        # Unexpected suffix: keep it, ordered after the numbered batches.
        return (prefix, 1, 0, index)

    filepaths = sorted(glob.glob(os.path.join(activation_dir, pattern)), key=batch_order)

    if not filepaths:
        raise FileNotFoundError(f"No matching activation files found at: {activation_dir}/{pattern}")

    # Load and concatenate; map to CPU so the result can be converted with
    # .numpy() regardless of the device the tensors were saved from.
    batches = [torch.load(path, map_location="cpu") for path in filepaths]
    return torch.cat(batches, dim=0)

Load activations from disk

In [None]:
# Both tensors come from the same layer of the test split; only `kind` differs.
_probe_layer = "encoder.3.blocks.1.dwconv"
real_activations, noise_activations = (
    load_streamed_activations(
        "data/activations",
        model_name,
        dataset_name,
        "test",
        kind=kind,
        layer_name=_probe_layer,
    )
    for kind in ("real", "noise")
)

See if noise and real activations are linearly separable

In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import numpy as np

# real_activations / noise_activations are expected to be torch.Tensors
X_real, X_noise = real_activations.numpy(), noise_activations.numpy()

# Stack both sets row-wise; label 1 = real, label 0 = noise
X = np.concatenate((X_real, X_noise), axis=0)
y = np.repeat([1.0, 0.0], [len(X_real), len(X_noise)])

# Stratified 70/30 split with a fixed seed
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.3,
    random_state=42,
    stratify=y,
)

# Linear probe: a logistic regression that separates the two classes
clf = LogisticRegression(max_iter=1000).fit(X_train, y_train)

y_pred = clf.predict(X_test)
acc = accuracy_score(y_test, y_pred)
print(f"Test accuracy: {acc:.2f}")

Visualize decision boundary in 2D

In [None]:
from sklearn.preprocessing import StandardScaler

# Standardize the held-out features, then project them to 2D.
# Scaler and PCA are fit on X_test only: this cell is purely a visualization
# of the held-out points, not an evaluation.
scaler = StandardScaler()
X_test_scaled = scaler.fit_transform(X_test)
# Fixed seed so the projection is reproducible even if PCA selects a
# randomized SVD solver for large inputs.
X_test_pca = PCA(n_components=2, random_state=42).fit_transform(X_test_scaled)

# Train a new classifier in 2D just for visualization
clf_2d = LogisticRegression(max_iter=1000)
clf_2d.fit(X_test_pca, y_test)

# Range of grid: data extent plus a margin of 1 on each side
x_min, x_max = X_test_pca[:, 0].min() - 1, X_test_pca[:, 0].max() + 1
y_min, y_max = X_test_pca[:, 1].min() - 1, X_test_pca[:, 1].max() + 1

# Cap grid resolution: a fixed number of points per axis keeps the mesh size
# independent of the data range
grid_points = 300  # reduce to 200–500 as needed
xx, yy = np.meshgrid(
    np.linspace(x_min, x_max, grid_points),
    np.linspace(y_min, y_max, grid_points)
)
# Predicted class for every grid node, reshaped back to the mesh layout
Z = clf_2d.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)

# Plot
plt.figure(figsize=(8, 6))
plt.contourf(xx, yy, Z, cmap=plt.cm.coolwarm, alpha=0.3)
plt.scatter(X_test_pca[y_test == 0, 0], X_test_pca[y_test == 0, 1], label='Noise', alpha=0.7)
plt.scatter(X_test_pca[y_test == 1, 0], X_test_pca[y_test == 1, 1], label='Real', alpha=0.7)
plt.legend()
plt.title("Decision Boundary (PCA projection)")
plt.xlabel("PCA Component 1")
plt.ylabel("PCA Component 2")
plt.grid(True)
plt.tight_layout()
plt.show()