In [None]:
import os
os.environ["KERAS_BACKEND"] = "jax"

import numpy as np
import bayesflow as bf
import keras

import matplotlib.pyplot as plt
from FyeldGenerator import generate_field
import colorcet as cc
from tqdm.notebook import tqdm
from resnet import ResNetSummary

In [None]:
def generate_power_spectrum(alpha, scale):
    """Build a power-law spectrum ``P(k) = k**(-alpha) * scale**2``.

    Parameters
    ----------
    alpha : float
        Power-law exponent; it is applied with a negative sign.
    scale : float
        Amplitude factor; it enters the spectrum squared.

    Returns
    -------
    callable
        Function of the wavenumber ``k`` (scalar or array) that evaluates
        the spectrum.
    """

    def power_spectrum(k):
        return scale**2 * np.power(k, -alpha)

    return power_spectrum


def distribution(shape=(32, 32)):
    """Draw complex noise with independent standard-normal real and imaginary parts.

    Uses NumPy's global random state (``np.random.normal``); the real part is
    drawn first, then the imaginary part.

    Parameters
    ----------
    shape : tuple of int, optional
        Shape of the returned array. Defaults to ``(32, 32)``.

    Returns
    -------
    numpy.ndarray
        Complex array of the requested shape.
    """
    real_part = np.random.normal(loc=0, scale=1., size=shape)
    imag_part = np.random.normal(loc=0, scale=1., size=shape)
    return real_part + imag_part * 1j

In [None]:
# Grid shape handed to `generate_field` for every simulated field.
shape = (32, 32)

# Number of example exponents to visualise.
n_examples = 5

# Evenly spaced exponents from 2 to 5 (inclusive), each paired with a
# unit-scale power spectrum.
alphas = np.linspace(2, 5, n_examples)
spectra = list(map(lambda exponent: generate_power_spectrum(exponent, 1), alphas))

In [None]:
def plot_distribution(shape=shape):
    """Complex standard-normal noise drawn from a freshly seeded generator.

    A new ``np.random.default_rng(seed=42)`` is created on every call, so
    repeated calls with the same ``shape`` return identical noise.

    Parameters
    ----------
    shape : tuple of int, optional
        Shape of the returned array; defaults to the value the module-level
        ``shape`` had when this function was defined.

    Returns
    -------
    numpy.ndarray
        Complex array of the requested shape.
    """
    seeded_rng = np.random.default_rng(seed=42)
    # The generator expression is consumed in order: real part first, then imaginary.
    real_part, imag_part = (
        seeded_rng.normal(loc=0, scale=1., size=shape) for _ in range(2)
    )
    return real_part + imag_part * 1j
# One panel per example spectrum, laid out in a single row.
fig, axs = plt.subplots(1, n_examples, figsize=(n_examples * 3, 4))

for ax, alpha, power_spectrum in zip(axs, alphas, spectra):
    field = generate_field(plot_distribution, power_spectrum, shape)
    # Symmetric colour limits so that zero maps to the centre of the colormap.
    max_magnitude = np.abs(field).max()
    ax.imshow(field, cmap=cc.cm.coolwarm, vmin=-max_magnitude, vmax=max_magnitude)
    ax.set_title(f"$\\alpha={alpha:.2f}$")
    ax.set_axis_off()

In [None]:
rng = np.random.default_rng()


def prior():
    """Draw one parameter set from the prior.

    ``log_std`` is sampled from a normal with scale 0.3 (default location) and
    ``alpha`` from a normal with location 3 and scale 0.5, both using the
    module-level ``rng``. The pair is also tiled over the module-level
    ``shape`` so that every grid position carries ``[log_std, alpha]`` in a
    trailing axis of length 2.

    Returns
    -------
    dict
        Keys ``"log_std"``, ``"alpha"`` and ``"params_expanded"`` (array of
        shape ``shape + (2,)``).
    """
    log_std = rng.normal(scale=0.3)
    alpha = rng.normal(loc=3, scale=0.5)
    theta = np.array([log_std, alpha])
    # Broadcast the two parameters across the whole grid.
    params_expanded = theta.reshape(1, 1, 2) * np.ones(shape + (2,))
    return dict(log_std=log_std, alpha=alpha, params_expanded=params_expanded)


def likelihood(log_std, alpha):
    """Simulate one field for the given parameters.

    Parameters
    ----------
    log_std : float
        Log of the spectrum scale; exponentiated before use.
    alpha : float
        Power-law exponent passed to ``generate_power_spectrum``.

    Returns
    -------
    dict
        ``{"field": array}`` where the array is the generated field with a
        trailing axis of length 1 appended, divided by 50.
    """
    spectrum = generate_power_spectrum(alpha, np.exp(log_std))
    raw_field = generate_field(distribution, spectrum, shape)
    # Append a trailing channel axis and apply the fixed 1/50 rescaling.
    return {"field": raw_field[..., None] / 50.}


# Combine the prior and the likelihood into a single BayesFlow simulator.
simulator = bf.make_simulator([prior, likelihood])

In [None]:
@bf.utils.serialization.serializable("custom")
class ResNetSubnet(bf.networks.SummaryNetwork):
    """Stack of 3x3, same-padded ``Conv2D`` layers used as a subnet.

    The network is applied to the channel-wise concatenation of ``x``, the
    time ``t`` broadcast to one extra channel, and (when given) ``conditions``.

    Parameters
    ----------
    widths : sequence of int, optional
        Number of filters of each convolution, in order. The last entry is the
        number of output channels.
    activation : str, optional
        Activation passed to every convolution.
    **kwargs
        Forwarded to the base class constructor.
    """

    def __init__(
        self,
        widths=(8, 16, 32),
        activation="mish",
        **kwargs,
    ):

        super().__init__(**kwargs)

        # Kept so that the output shape can be derived from the configuration
        # instead of being hard-coded.
        self.widths = tuple(widths)

        layers = [
            keras.layers.Conv2D(width, kernel_size=3, activation=activation, padding='SAME')
            for width in self.widths
        ]
        self.net = bf.networks.Sequential(layers)

    def build(self, x_shape, t_shape, conditions_shape):
        """Build the convolution stack for the concatenated input.

        Input channels = channels of ``x`` + 1 (broadcast time) + channels of
        ``conditions`` (0 when ``conditions_shape`` is None).
        """
        condition_channels = 0 if conditions_shape is None else conditions_shape[-1]
        in_channels = x_shape[-1] + 1 + condition_channels
        self.net.build(tuple(x_shape[:-1]) + (in_channels,))

    def call(self, x, t, conditions, training=False):
        """Concatenate ``x``, broadcast ``t`` and ``conditions`` on the last axis and apply the stack."""
        # Give the time input the same leading dimensions as x, with one channel.
        t = keras.ops.broadcast_to(t, keras.ops.shape(x)[:-1] + (1,))
        parts = (x, t) if conditions is None else (x, t, conditions)
        return self.net(keras.ops.concatenate(parts, axis=-1), training=training)

    def compute_output_shape(self, x_shape, t_shape, conditions_shape):
        """Same leading dimensions as ``x``; channels equal the last entry of ``widths``."""
        return tuple(x_shape[:-1]) + (self.widths[-1],)

In [None]:
# Pre-simulate the offline training set (5000 draws) and a small validation set (50 draws).
training_data = simulator.sample(5000)
validation_data = simulator.sample(50)

In [None]:
# Adapter applied to the simulator output before it reaches the approximator:
# cast float64 arrays to float32, then rename the expanded parameters to
# "inference_conditions" and the field to "inference_variables".
adapter = (
    bf.adapters.Adapter()
    .convert_dtype("float64", "float32")
    .rename("params_expanded", "inference_conditions")
    .rename("field", "inference_variables")
)

# Test Training

In [None]:
# Diffusion model that uses the ResNetSubnet class defined above as its subnet.
# NOTE(review): concatenate_subnet_input=False presumably makes the model call
# the subnet with x, t and conditions as separate arguments (matching
# ResNetSubnet.call) — confirm against the BayesFlow documentation.
inference_network = bf.networks.DiffusionModel(subnet=ResNetSubnet, concatenate_subnet_input=False)

# Workflow tying together simulator, network and adapter; standardization is
# explicitly disabled via standardize=None.
workflow = bf.workflows.BasicWorkflow(
    simulator=simulator,
    inference_network=inference_network,
    adapter=adapter,
    standardize=None,
)

In [None]:
# Train on the pre-simulated data for 100 epochs with batches of 16, passing
# `validation_data` for validation; the returned object is kept as `history`.
history = workflow.fit_offline(
    data=training_data,
    epochs=100,
    validation_data=validation_data,
    batch_size=16,
)

In [None]:
# Display the `dims` attribute of the inference network's base distribution (cell output).
workflow.approximator.inference_network.base_distribution.dims

In [None]:
# Identity check: is the workflow's inference network the very object created above?
inference_network is workflow.approximator.inference_network

In [None]:
# Draw a single fresh simulation. Note: this rebinds `validation_data`,
# replacing the 50-draw set that was passed to `fit_offline`.
validation_data = simulator.sample(1)

In [None]:
# Latent noise for one sample; the hard-coded 32x32 matches `shape = (32, 32)` set earlier.
z = keras.random.normal((1, 32, 32, 1))
# Conditions: the expanded parameters of the single simulation drawn above, as float32.
conditions = keras.ops.convert_to_tensor(validation_data["params_expanded"], dtype="float32")


# Run the network in the inverse direction on the noise, given the conditions.
sample = inference_network(z, conditions=conditions, inverse=True)

In [None]:
# Convert the backend tensor to a NumPy array (rebinds `sample`).
sample = keras.ops.convert_to_numpy(sample)

In [None]:
# Show the first entry of the generated batch with the "seismic" colormap.
plt.imshow(sample[0], cmap="seismic")

In [None]:
# Plot the training history with BayesFlow's loss diagnostic; the result is kept in `f`.
f = bf.diagnostics.plots.loss(history)

In [None]:
# Evaluate the diagnostics on the first 100 training simulations.
small_training_data = {k: v[:100] for k, v in training_data.items()}

# `test_data` was never defined in this notebook; the subset built above is the
# data set this cell prepares, so it is the one passed to the diagnostics.
workflow.plot_custom_diagnostics(
    test_data=small_training_data,
    plot_fns={
        "recovery": bf.diagnostics.recovery,
        "calibration": bf.diagnostics.calibration_ecdf,
    },
)

# Evaluations

In [None]:
# Selection of the trained approximator to evaluate.
target = "NLE"
models = [
    "consistency_model",
    "diffusion_edm_vp",
    "flow_matching",
]
# Field sizes 8, 16, ..., 256.
scales = [2**n for n in range(3, 9)]
# Config names; each one is an underscore-separated list of scales.
model_configs = ["8_16", "32_64_128_256"]


def config_for_scale(scale):
    """Return the config name used for ``scale``.

    The first entry of ``model_configs`` is chosen when ``scale`` is exactly
    one of its underscore-separated tokens; otherwise the second entry is
    used. Tokens are compared exactly, so e.g. ``6`` or ``1`` no longer match
    ``"8_16"`` the way a plain substring test would.
    """
    small_config, large_config = model_configs
    return small_config if str(scale) in small_config.split("_") else large_config


# One checkpoint path per (model, scale) pair, models varying slowest.
checkpoint_paths = [
    f"{model}/{target}/checkpoints/{scale}_shape_config_{config_for_scale(scale)}.keras"
    for model in models
    for scale in scales
]
print(checkpoint_paths)
# Index 7 is the second model ("diffusion_edm_vp") at the second scale (16).
checkpoint_path = checkpoint_paths[7]
print(checkpoint_path)
# The file name has the form "<shape>_shape_config_<config>.keras"; recover both parts.
current_shape = int(checkpoint_path.rsplit("/", 1)[-1].partition("_")[0])
current_config = checkpoint_path.rpartition("_shape_config_")[-1].partition(".keras")[0]

# Restore the saved approximator and print its summary.
approximator = keras.saving.load_model(checkpoint_path)
approximator.summary()

# Override the integration settings stored on the inference network.
approximator.inference_network.integrate_kwargs = dict(method="rk45", steps=500)


In [None]:
# Fresh, unseeded generator for the prior draws below (rebinds `rng`).
rng = np.random.default_rng()
# Square field shape at the resolution parsed from the checkpoint file name (rebinds `shape`).
shape = (current_shape, current_shape)

def generate_power_spectrum(alpha, scale):
    """Build a power-law spectrum ``P(k) = k**(-alpha) * scale**2``.

    Parameters
    ----------
    alpha : float
        Power-law exponent; it is applied with a negative sign.
    scale : float
        Amplitude factor; it enters the spectrum squared.

    Returns
    -------
    callable
        Function of the wavenumber ``k`` (scalar or array) that evaluates
        the spectrum.
    """

    def power_spectrum(k):
        return scale**2 * np.power(k, -alpha)

    return power_spectrum

def distribution(shape):
    """Draw complex noise with independent standard-normal real and imaginary parts.

    Uses NumPy's global random state (``np.random.normal``); the real part is
    drawn first, then the imaginary part.

    Parameters
    ----------
    shape : tuple of int
        Shape of the returned array.

    Returns
    -------
    numpy.ndarray
        Complex array of the requested shape.
    """
    real_part = np.random.normal(loc=0, scale=1., size=shape)
    imag_part = np.random.normal(loc=0, scale=1., size=shape)
    return real_part + imag_part * 1j

def prior():
    """Draw one parameter set from the prior.

    ``log_std`` is sampled from a normal with scale 0.3 (default location) and
    ``alpha`` from a normal with location 3 and scale 0.5, both using the
    module-level ``rng``. The pair is also tiled over the module-level
    ``shape`` so that every grid position carries ``[log_std, alpha]`` in a
    trailing axis of length 2.

    Returns
    -------
    dict
        Keys ``"log_std"``, ``"alpha"`` and ``"params_expanded"`` (array of
        shape ``shape + (2,)``).
    """
    log_std = rng.normal(scale=0.3)
    alpha = rng.normal(loc=3, scale=0.5)
    theta = np.array([log_std, alpha])
    # Broadcast the two parameters across the whole grid.
    params_expanded = theta.reshape(1, 1, 2) * np.ones(shape + (2,))
    return dict(log_std=log_std, alpha=alpha, params_expanded=params_expanded)


def likelihood(log_std, alpha):
    """Simulate one field for the given parameters.

    Parameters
    ----------
    log_std : float
        Log of the spectrum scale; exponentiated before use.
    alpha : float
        Power-law exponent passed to ``generate_power_spectrum``.

    Returns
    -------
    dict
        ``{"field": array}`` where the array is the generated field with a
        trailing axis of length 1 appended, divided by 50.
    """
    spectrum = generate_power_spectrum(alpha, np.exp(log_std))
    raw_field = generate_field(distribution, spectrum, shape)
    # Append a trailing channel axis and apply the fixed 1/50 rescaling.
    return {"field": raw_field[..., None] / 50.}


# Rebuild the simulator from the prior and likelihood defined in this cell.
simulator = bf.make_simulator([prior, likelihood])

In [None]:
# Draw one simulation and generate one sample from the loaded approximator,
# conditioned on that simulation's expanded parameters.
validation_data = simulator.sample(1)
z = keras.random.normal((1, current_shape, current_shape, 1))
conditions = keras.ops.convert_to_tensor(validation_data["params_expanded"], dtype="float32")
sample = approximator.inference_network(z, conditions=conditions, inverse=True)

# Side-by-side comparison: simulator output (left) vs. network output (right).
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(validation_data["field"][0, :, :, 0], cmap=cc.cm.coolwarm)
ax[0].set_title("Simulated Field")
ax[1].imshow(keras.ops.convert_to_numpy(sample)[0, :, :, 0], cmap=cc.cm.coolwarm)
ax[1].set_title("Generated Field")
plt.show()

In [None]:
# Number of samples generated per call of the inference network.
batch_size = 100


def generate_classifier_data(approximator, simulator, n_samples):
    """Create paired simulated and generated fields for a two-sample classifier.

    ``n_samples`` simulations are drawn from ``simulator``; for each one the
    approximator's inference network is run in the inverse direction on
    standard-normal noise, conditioned on that simulation's
    ``"params_expanded"``. Generation is done in chunks of the module-level
    ``batch_size``; a final, smaller chunk is processed when ``n_samples`` is
    not a multiple of ``batch_size``, so both outputs always contain
    ``n_samples`` entries.

    Parameters
    ----------
    approximator
        Object exposing ``inference_network(z, conditions=..., inverse=True)``.
    simulator
        Object exposing ``sample(n)`` that returns a dict with the keys
        ``"field"`` and ``"params_expanded"``.
    n_samples : int
        Number of simulated/generated pairs; must be at least 1.

    Returns
    -------
    dict
        ``{"simulated": data["field"], "generated": numpy.ndarray}``.

    Raises
    ------
    ValueError
        If ``n_samples`` is smaller than 1.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be at least 1, got {n_samples}")

    data = simulator.sample(n_samples)
    z = keras.random.normal((n_samples, current_shape, current_shape, 1))
    conditions = keras.ops.convert_to_tensor(data["params_expanded"], dtype="float32")

    # Collect the chunks and join them once at the end instead of
    # re-concatenating the growing result in every iteration.
    generated = []
    for start in tqdm(range(0, n_samples, batch_size)):
        stop = start + batch_size  # slicing clips the last chunk automatically
        samples_batch = approximator.inference_network(
            z[start:stop], conditions=conditions[start:stop], inverse=True
        )
        generated.append(keras.ops.convert_to_numpy(samples_batch))

    return {
        "simulated": data["field"],
        "generated": np.concatenate(generated, axis=0),
    }

In [None]:
# Classifier settings per model configuration. The two entries differ only in
# the `widths` value handed to the summary network.
classifier_kwargs = {
    config_key: {
        "summary_kwargs": {
            "summary_dim": 1,
            "widths": widths,
            "use_batchnorm": False,
            "dropout": 0.0,
        },
    }
    for config_key, widths in (
        ("shape_config_32_64_128_256", [16, 16]),
        ("shape_config_8_16", (8, 8)),
    )
}

def make_classifier(current_shape, current_config):
    """Build a Keras model that applies a ``ResNetSummary`` to inputs of the given shape.

    Parameters
    ----------
    current_shape : tuple
        Input shape passed to ``keras.Input``.
    current_config : str
        Config name; the settings are looked up in ``classifier_kwargs`` under
        the key ``"shape_config_<current_config>"``.

    Returns
    -------
    keras.Model
        Model mapping the input to the summary network's output.
    """
    inputs = keras.Input(current_shape)
    settings = classifier_kwargs[f"shape_config_{current_config}"]
    summary_net = ResNetSummary(**settings["summary_kwargs"])
    return keras.Model(inputs=inputs, outputs=summary_net(inputs))

# Build the classifier for single-channel square inputs at the checkpoint's
# resolution and print its summary.
classifier = make_classifier((current_shape, current_shape, 1), current_config)
classifier.summary()


In [None]:
# Training set for the classifier: simulated fields are labelled 1, generated fields 0.
traindata = generate_classifier_data(approximator, simulator, n_samples=10000)
x = np.concatenate((traindata["simulated"], traindata["generated"]), axis=0)
y = np.concatenate(
    (
        np.ones((len(traindata["simulated"]), 1)),
        np.zeros((len(traindata["generated"]), 1)),
    ),
    axis=0,
)

# Validation set, built the same way from 100 fresh pairs.
validation_data = generate_classifier_data(approximator, simulator, n_samples=100)
x_val = np.concatenate((validation_data["simulated"], validation_data["generated"]), axis=0)
y_val = np.concatenate(
    (
        np.ones((len(validation_data["simulated"]), 1)),
        np.zeros((len(validation_data["generated"]), 1)),
    ),
    axis=0,
)

In [None]:
# Overlay histograms of all values in the simulated and the generated
# training fields (100 bins each, semi-transparent).
plt.figure()
plt.hist(traindata["simulated"].flatten(), bins=100, alpha=0.5, label="simulated")
plt.hist(traindata["generated"].flatten(), bins=100, alpha=0.5, label="generated")
plt.legend()
plt.show()

In [None]:
epochs = 100
classifier.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-4),
    # The loss is configured for raw logits (no sigmoid on the model output) ...
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    # ... so accuracy must split predictions at logit 0 (probability 0.5).
    # The metric's default threshold of 0.5 would be applied to the logits
    # directly and misreport the accuracy.
    metrics=[keras.metrics.BinaryAccuracy(name="accuracy", threshold=0.0)],
)
# Train on the simulated-vs-generated data; note this rebinds `history`.
history = classifier.fit(
    x,
    y,
    epochs=epochs,
    batch_size=16,
    validation_data=(x_val, y_val),
)

In [None]:
# Plot the classifier's training loss, validation loss and training accuracy per epoch.
# NOTE(review): accuracy shares the axis labelled "Loss", and the validation
# accuracy is not plotted — consider a separate axis or a more general label.
plt.figure()
plt.plot(history.history["loss"], label="train loss")
plt.plot(history.history["val_loss"], label="val loss")
plt.plot(history.history["accuracy"], label="train accuracy")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()