In [3]:
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
import equinox as eqx
import time
import requests
import itertools

from urllib.parse import urlencode
from jax import random, vmap
from jax.lax import scan
from jax.tree_util import tree_map

%config InlineBackend.figure_format='retina'

In [None]:
from flow_matching import train as train_fm, priors, flows
from flow_matching.architectures import ffno as ffno_fm

from flow_matching.integrators import explicit_Euler, integrator
from flow_matching.evaluate import get_statistics_scan, compute_error as compute_error_fm

In [27]:
# Dataset: local .npz archive (path suggests PDEBench 2D Darcy flow, beta = 100,
# training split) with 'features', 'targets' and 'coordinates' entries.
# NOTE(review): an earlier revision appears to have downloaded 'Diffusion.npz' from
# https://disk.yandex.ru/d/7XkIfN8hJ2a2RA via a `download_from_yandex_disk` helper
# that is not visible here; the local copy below is used instead.
DARCY_PATH = "/mnt/local/dataset/by-domain/pde/PDE_datasets/PDE_bench/2D_DarcyFlow_beta100.0_Train.npz"


def normalize_by_dataset_max(x):
    """Divide `x` by its largest per-sample norm over the whole dataset.

    Assumes `x` is laid out as (sample, channel, H, W), matching the shapes this
    cell prints. The norm is taken over axes (2, 3) per sample and channel, then
    maximised over samples (axis 0) with channels kept, so each channel is
    scaled by its own factor.

    NOTE(review): with a 2-axis tuple, `ord=jnp.inf` selects the *matrix*
    infinity norm (sum of |x| over axis 3, maximised over axis 2), not max|x|.
    On a 128x128 grid this can be up to 128x larger than max|x|. If elementwise
    sup-norm scaling was intended, use
    `jnp.max(jnp.abs(x), axis=(0, 2, 3), keepdims=True)` instead; it is left
    unchanged here to preserve existing results.
    """
    per_sample_norm = jnp.linalg.norm(x, ord=jnp.inf, axis=(2, 3), keepdims=True)
    return x / jnp.max(per_sample_norm, axis=0, keepdims=True)


# For .npz files jnp.load hands back NumPy's NpzFile; reading every entry inside the
# `with` block closes the underlying file handle deterministically (`del` alone does not).
with jnp.load(DARCY_PATH) as Diffusion_data:
    Diffusion_features = Diffusion_data['features']
    Diffusion_targets = Diffusion_data['targets']
    Diffusion_coordinates = Diffusion_data['coordinates']

# Same normalisation rule for both fields; the targets are normalised from the array
# already loaded above rather than re-reading the 'targets' entry from the archive.
Diffusion_features = normalize_by_dataset_max(Diffusion_features)
Diffusion_targets = normalize_by_dataset_max(Diffusion_targets)

print("features", Diffusion_features.shape)
print("targets", Diffusion_targets.shape)
print("coordinates", Diffusion_coordinates.shape)

del Diffusion_data

features (10000, 1, 128, 128)
targets (10000, 1, 128, 128)
coordinates (2, 128, 128)


In [None]:
# --- Hyperparameters ---------------------------------------------------------
D = 2                 # passed to flow_FFNO; presumably the spatial dimension (coordinates have 2 channels)
lr_init = 1e-4        # initial value of the learning-rate schedule below
N_processor = 32      # hidden width; middle entry of N_features
N_train = 4000        # samples [0, N_train) are used for training, the rest are held out
N_run = 10000         # number of optimisation steps (length of the scan)
N_batch = 10          # sample indices drawn per step
N_layers = 4          # passed to flow_FFNO
N_modes = 16          # passed to flow_FFNO
N_drop = N_run // 4   # schedule time constant: lr shrinks by a factor `gamma` every N_drop steps
gamma = 0.5           # learning-rate decay factor per N_drop steps
scale = 0.001         # prior_params[2], forwarded to priors.normal_periodic
po = 2.0              # prior_params[3], forwarded to priors.normal_periodic
N = 150               # forwarded to priors.get_basis_normal_periodic

key = random.PRNGKey(11)
keys = random.split(key, 3)  # [0]: model init, [1]: batch indices, [2]: carried into training

# --- Model -------------------------------------------------------------------
# Input channels: coordinates + features + a target-shaped channel (presumably the
# interpolated state) + 1 extra channel (presumably the flow time t).
# Output channels: the target channels.
N_features = [
    Diffusion_coordinates.shape[0] + Diffusion_features.shape[1] + Diffusion_targets.shape[1] + 1,
    N_processor,
    Diffusion_targets.shape[1],
]
model = ffno_fm.flow_FFNO(N_layers, N_features, N_modes, D, keys[0])

# --- Optimiser ---------------------------------------------------------------
# `learning_rate` stays bound to the schedule (as before); the scalar lives in `lr_init`.
# exponential_decay with the default staircase=False gives
#   lr(step) = lr_init * gamma ** (step / N_drop).
learning_rate = optax.exponential_decay(lr_init, N_drop, gamma)
optim = optax.lion(learning_rate=learning_rate)
opt_state = optim.init(eqx.filter(model, eqx.is_array))

# --- Train / test split and per-step batches ---------------------------------
ind = jnp.arange(Diffusion_features.shape[0])
ind_train, ind_test = ind[:N_train], ind[N_train:]
# One row of N_batch training indices per step, drawn WITH replacement,
# so a batch may contain the same sample more than once.
n = random.choice(keys[1], ind_train, shape=(N_run, N_batch), replace=True)

# Order must match what train_fm.make_step_scan unpacks (defined outside this notebook).
carry = [model, Diffusion_targets, Diffusion_features, Diffusion_coordinates, opt_state, keys[2]]

# --- Flow and prior ----------------------------------------------------------
flow_params = [0.0]  # forwarded unchanged to flows.optimal_transport


def flow(target_1, target_0, t):
    """Optimal-transport interpolation path used by the flow-matching loss."""
    return flows.optimal_transport(target_1, target_0, t, flow_params)


basis, freq = priors.get_basis_normal_periodic(Diffusion_coordinates, N)
prior_params = [basis, freq, scale, po]


def prior(key):
    """Draw one sample from the periodic normal prior built above."""
    return priors.normal_periodic(key, prior_params)


def make_step_scan_(state, batch):
    """Adapt train_fm.make_step_scan to the (carry, x) -> (carry, y) signature of lax.scan."""
    return train_fm.make_step_scan(state, batch, optim, flow, prior)


# --- Training ----------------------------------------------------------------
# All N_run steps run inside one lax.scan; `losses` stacks the per-step outputs.
carry, losses = scan(make_step_scan_, carry, n)
model = carry[0]

fig, ax = plt.subplots()
ax.plot(losses)
ax.set_yscale("log")
ax.set(xlabel="step", ylabel="loss", title="Flow-matching FFNO training loss");