In [None]:
import os
import time
import numpy as np
from tqdm import tqdm
from scipy.integrate import solve_ivp

import jax
import flax.linen as nn
import jax.numpy as jnp
import optax

# Switches consulted by the `if GENERATE:` / `if TRAIN:` cells further down.
GENERATE = False
TRAIN = False

# Dataset size used by the generation cell.
SIZE = 1_000_000

# Directory that receives checkpoints and the training log.
savedir = "models/simple"
os.makedirs(savedir, exist_ok=True)

# Show which backend JAX selected.
backend = jax.default_backend()
print(backend)

import logging

logging.basicConfig(
    filename=f'{savedir}/training.log',
    level=logging.INFO,
    format='%(asctime)s - %(message)s',
)

: 

In [2]:
class MLP(nn.Module):
    """Fully connected network: three SELU hidden layers plus a linear output layer."""

    # Nominal input width. Not read inside __call__; each Dense layer takes its
    # input width from the array it is applied to.
    dim: int
    # Width of the final linear layer.
    out_dim: int = 1
    # Width of each hidden layer.
    w: int = 64

    @nn.compact
    def __call__(self, x):
        """Apply the network to ``x`` and return the output of the last Dense layer."""
        x = nn.Dense(self.w)(x)
        x = nn.selu(x)
        x = nn.Dense(self.w)(x)
        x = nn.selu(x)
        x = nn.Dense(self.w)(x)
        x = nn.selu(x)
        x = nn.Dense(self.out_dim)(x)
        # Without this return the module evaluated to None, so any loss built on
        # its output would fail.
        return x

: 

In [3]:
# Instantiate the network; the first field (`dim`) is passed by keyword.
model = MLP(dim=4)

In [4]:
# List the devices JAX can see.
jax.devices()

[CudaDevice(id=0), CudaDevice(id=1), CudaDevice(id=2), CudaDevice(id=3)]

In [5]:
# Smoke test: allocate a small array. In the recorded run this call raised the
# cuDNN initialisation error shown in the output below.
jax.numpy.ones(10)

E1019 00:39:21.639217   25957 cuda_dnn.cc:502] There was an error before creating cudnn handle (500): cudaErrorSymbolNotFound : named symbol not found
E1019 00:39:21.639794   25957 cuda_dnn.cc:502] There was an error before creating cudnn handle (500): cudaErrorSymbolNotFound : named symbol not found


XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

In [4]:
# Smoke test: binds a plain array of ones to `key` (not a PRNG key). The recorded
# run failed with the cuDNN initialisation error shown in the output below.
key = jax.numpy.ones(12)

E1019 00:27:13.255363   23608 cuda_dnn.cc:502] There was an error before creating cudnn handle (500): cudaErrorSymbolNotFound : named symbol not found
E1019 00:27:13.255979   23608 cuda_dnn.cc:502] There was an error before creating cudnn handle (500): cudaErrorSymbolNotFound : named symbol not found


XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

In [None]:
# Repeat of the allocation smoke test; the recorded output below shows the same
# cuDNN initialisation failure.
key = jax.numpy.ones(12)

E1019 00:25:21.153652   23134 cuda_dnn.cc:502] There was an error before creating cudnn handle (500): cudaErrorSymbolNotFound : named symbol not found
E1019 00:25:21.154205   23134 cuda_dnn.cc:502] There was an error before creating cudnn handle (500): cudaErrorSymbolNotFound : named symbol not found


XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

In [9]:
@jax.jit
def predict(params, inputs):
    """Evaluate the module-level ``model`` on ``inputs`` with the given ``params``."""
    variables = {"params": params}
    return model.apply(variables, inputs)

@jax.jit
def sample_conditional_pt(x0, x1, t, sigma, key=None):
    """Draw a sample around the straight line between ``x0`` and ``x1`` at time ``t``.

    Args:
        x0: Source samples; the leading axis is the batch axis.
        x1: Target samples, broadcast-compatible with ``x0``.
        t: One time value per batch row; reshaped to broadcast over the trailing axes of ``x0``.
        sigma: Scale applied to the Gaussian noise added to the interpolated mean.
        key: Optional JAX PRNG key for the noise. When omitted, ``PRNGKey(0)`` is
            used, which reproduces the previous behaviour (identical noise on
            every call). Pass a fresh key per call to get independent noise.

    Returns:
        ``t * x1 + (1 - t) * x0 + sigma * epsilon`` with ``epsilon`` standard normal
        of shape ``x0.shape``.
    """
    if key is None:
        # Backward-compatible default; deterministic across calls.
        key = jax.random.PRNGKey(0)
    t = t.reshape(-1, *([1] * (x0.ndim - 1)))
    mu_t = t * x1 + (1 - t) * x0
    epsilon = jax.random.normal(key, x0.shape)
    return mu_t + sigma * epsilon

@jax.jit
def compute_conditional_vector_field(x0, x1):
    """Return the regression target ``x1 - x0``."""
    velocity = x1 - x0
    return velocity

@jax.jit
def loss_ffm_function(params, x1, x0, d, e, key=None):
    """Compute the flow-matching regression loss on one mini-batch.

    Args:
        params: Model parameters forwarded to ``predict``.
        x1: Target samples, one row per batch element.
        x0: Source samples. Previously this argument was silently overwritten;
            it is now used as given. Pass ``None`` to draw standard-normal
            source samples with the same shape as ``x1``.
        d: Conditioning columns concatenated to the network input.
        e: Conditioning columns concatenated to the network input.
        key: Optional JAX PRNG key. Defaults to ``PRNGKey(0)`` so existing
            five-argument calls keep working (and stay deterministic).

    Returns:
        Mean squared error between the predicted and the target vector field.
    """
    if key is None:
        key = jax.random.PRNGKey(0)
    # Separate streams for the source samples and the time draw; using one key
    # for both made them functions of the same random bits.
    x0_key, t_key = jax.random.split(key)
    if x0 is None:
        x0 = jax.random.normal(x0_key, x1.shape)
    t = jax.random.uniform(t_key, (x0.shape[0],))
    xt = sample_conditional_pt(x0, x1, t, sigma=0.01)
    ut = compute_conditional_vector_field(x0, x1)
    inputs = jnp.concatenate([xt, d, e, t[:, None]], axis=-1)
    vt = predict(params, inputs)
    loss = jnp.mean((vt - ut) ** 2)
    return loss

@jax.jit
def update_model(state, grads):
    """Return the state produced by ``state.apply_gradients`` for ``grads``."""
    next_state = state.apply_gradients(grads=grads)
    return next_state


In [None]:
import os
import time

import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax import linen as nn
from flax import serialization
from flax.training import train_state
from tqdm import tqdm

# Data generation: draw m and e uniformly, evaluate the forward relation with a
# small Gaussian perturbation, and store the columns (m, e, d) in data.npy.
SIZE = 10000  # Example size, adjust as needed
m = np.random.uniform(size=SIZE)
e = np.random.uniform(size=SIZE)
noise = np.random.normal(scale=1e-4, size=SIZE)
cubic_term = np.power(e, 2) * np.power(m, 3)
decay_term = m * np.exp(-np.abs(0.2 - e))
d = cubic_term + decay_term + noise
data = np.vstack((m, e, d)).T
np.save('data.npy', data)

# Model definition
class MLP(nn.Module):
    """Three SELU-activated Dense layers of width ``w`` followed by a Dense output layer."""

    # Nominal input width; not read inside __call__.
    dim: int
    # Width of the output layer.
    out_dim: int = 1
    # Width of every hidden layer.
    w: int = 64

    @nn.compact
    def __call__(self, x):
        """Map ``x`` through the hidden stack and return the linear output."""
        hidden = x
        # Layers are created in the same order as before, so the automatically
        # assigned submodule names are unchanged.
        for _ in range(3):
            hidden = nn.selu(nn.Dense(self.w)(hidden))
        return nn.Dense(self.out_dim)(hidden)

# Network instance used by `predict` and initialised in the training setup below.
# The loss concatenates four blocks (xt, d, e, t) as its input; `dim` itself is
# not read inside MLP.__call__.
model = MLP(dim=4)

@jax.jit
def predict(params, inputs):
    """Forward pass of the module-level ``model`` with the supplied parameters."""
    variables = {"params": params}
    return model.apply(variables, inputs)

@jax.jit
def sample_conditional_pt(x0, x1, t, sigma, key=None):
    """Sample a noisy point on the linear path from ``x0`` to ``x1`` at time ``t``.

    Args:
        x0: Source samples; the leading axis is the batch axis.
        x1: Target samples, broadcast-compatible with ``x0``.
        t: One time value per batch row; reshaped to broadcast against ``x0``.
        sigma: Scale of the additive Gaussian noise.
        key: Optional JAX PRNG key for the noise. If omitted, ``PRNGKey(0)`` is
            used, matching the earlier behaviour where every call produced the
            same noise. Supply a fresh key per call for independent draws.

    Returns:
        ``t * x1 + (1 - t) * x0 + sigma * epsilon``, ``epsilon`` being standard
        normal with shape ``x0.shape``.
    """
    if key is None:
        # Backward-compatible default; deterministic across calls.
        key = jax.random.PRNGKey(0)
    t = t.reshape(-1, *([1] * (x0.ndim - 1)))
    mu_t = t * x1 + (1 - t) * x0
    epsilon = jax.random.normal(key, x0.shape)
    return mu_t + sigma * epsilon

@jax.jit
def compute_conditional_vector_field(x0, x1):
    """Return the target vector field, i.e. the difference ``x1 - x0``."""
    velocity = x1 - x0
    return velocity

@jax.jit
def loss_ffm_function(params, x1, x0, d, e, key):
    """Mean squared error between the predicted and target vector field on one batch.

    ``key`` drives the per-row time draw; ``d`` and ``e`` are concatenated to the
    network input together with the sampled point and the time column.
    """
    n_rows = x0.shape[0]
    t = jax.random.uniform(key, (n_rows,))
    xt = sample_conditional_pt(x0, x1, t, sigma=0.01)
    target = compute_conditional_vector_field(x0, x1)
    features = jnp.concatenate([xt, d, e, t[:, None]], axis=-1)
    residual = predict(params, features) - target
    return jnp.mean(residual ** 2)

@jax.jit
def update_model(state, grads):
    """Apply ``grads`` through ``state.apply_gradients`` and return the new state."""
    next_state = state.apply_gradients(grads=grads)
    return next_state

# Training setup
key = jax.random.PRNGKey(0)
# A JAX key should be consumed once. Derive dedicated keys for parameter
# initialisation and for the permutation, and keep `key` as the root from which
# the training loop splits its per-step keys.
key, init_key, perm_key = jax.random.split(key, 3)
batch_size = 512
num_epochs = 20000  # number of optimisation steps (one mini-batch each)
learning_rate = 0.001
optimizer = optax.adamw(learning_rate=learning_rate)
# Dummy input of width 4 fixes the parameter shapes of the first Dense layer.
params = model.init(init_key, jnp.ones((1, 4)))
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=params["params"],
    tx=optimizer
)

# Load data
X = jnp.load('data.npy')
dataset = jnp.asarray(X)
# Row-shuffled copy of the dataset; the loop below samples from `dataset` directly.
loader = jax.random.permutation(perm_key, dataset)

losses = []
# Best mini-batch loss seen so far. It is compared BEFORE the new value is
# recorded; comparing against a list that already contains the new value can
# never succeed, so the best checkpoint was never written.
best_loss = float("inf")
# Build the differentiated loss once instead of on every step.
loss_and_grad = jax.value_and_grad(loss_ffm_function, has_aux=False)
n_rows = len(dataset)

start = time.time()
for k in tqdm(range(num_epochs)):
    # One independent key per random draw (batch indices, source samples, loss).
    key, index_key, x0_key, loss_key = jax.random.split(key, 4)
    batch_indices = jax.random.choice(index_key, n_rows, (batch_size,))
    batch = dataset[batch_indices]

    # Columns of the stored data are (m, e, d).
    x0 = jax.random.uniform(x0_key, (batch_size, 1))
    x1 = batch[:, 0].reshape(-1, 1)
    d = batch[:, 2].reshape(-1, 1)
    e = batch[:, 1].reshape(-1, 1)

    loss, grads = loss_and_grad(state.params, x1, x0, d, e, loss_key)
    state = update_model(state, grads)
    loss_value = loss.item()
    losses.append(loss_value)

    if loss_value < best_loss:
        best_loss = loss_value
        print(f"Loss is less than min loss at iteration {k+1}")
        # The file holds flax msgpack bytes (restore with
        # flax.serialization.from_bytes); the file name is kept unchanged.
        # The saved state is the one AFTER this step's update.
        with open("model_best_iter.pkl", "wb") as f:
            f.write(serialization.to_bytes(state))

    if (k+1) % 1000 == 0:
        end = time.time()
        print(f"{k+1}: loss {loss_value:0.3f} time {(end - start):0.2f}")

with open("model.pkl", "wb") as f:
    f.write(serialization.to_bytes(state))
print(f"Model saved to model.pkl")
np.save("losses.npy", np.array(losses))

In [None]:
def sample_conditional_pt(x0, x1, t, sigma):
    """PyTorch variant: noisy sample on the straight path from ``x0`` to ``x1`` at ``t``.

    ``t`` holds one value per batch row and is reshaped so it broadcasts over the
    remaining axes of ``x0``; the noise is standard normal scaled by ``sigma``.
    """
    trailing_axes = (1,) * (x0.dim() - 1)
    t = t.reshape(-1, *trailing_axes)
    noise = torch.randn_like(x0)
    path_mean = (1 - t) * x0 + t * x1
    return path_mean + sigma * noise

def compute_conditional_vector_field(x0, x1):
    """Return the difference ``x1 - x0`` used as the regression target."""
    velocity = x1 - x0
    return velocity


def d_by_m_e(m, e, noise=0.0):
    """Forward relation ``d = e**2 * m**3 + m * exp(-|0.2 - e|) + noise``.

    Args:
        m: Scalar or array-like model parameter(s).
        e: Scalar or array-like, broadcast-compatible with ``m``.
        noise: Additive term. Defaults to 0.0 so the function can be called with
            only ``(m, e)`` for noise-free data.

    Returns:
        The elementwise result as a NumPy scalar or array.
    """
    return np.power(e,2) * np.power(m,3) + m * np.exp(- np.abs(0.2 - e)) + noise

In [None]:
SIZE = 1000000
data_dir = 'data/25_points'

# Either generate SIZE (m, e, d) samples and store them, or load stored arrays.
if GENERATE:
    m_arr = np.zeros((SIZE, 6), dtype='float')
    e_arr = np.zeros((SIZE, 25), dtype='float')
    d_arr = np.zeros((SIZE, 50), dtype='float')
    d_noise_arr = np.zeros((SIZE, 50), dtype='float')

    for i in tqdm(range(SIZE)):
        # Redraw until the forward model yields exactly 50 values.
        constrain = False
        while not constrain:
            m = np.random.uniform(size=6)
            e = np.sort(np.random.uniform(low=1, high=3, size=25))
            # Noise is added separately below, so the forward model is evaluated
            # noise-free here (the call previously omitted the third argument).
            # NOTE(review): with the d_by_m_e defined in this notebook, m of
            # length 6 and e of length 25 do not broadcast, and the result
            # could not have length 50 - presumably a different forward model
            # was intended; confirm.
            d = d_by_m_e(m, e, 0.0).flatten()
            constrain = d.shape[0] == 50
            
        m_arr[i] = m
        e_arr[i] = e
        d_arr[i] = d
        noise = np.concatenate([np.random.normal(0, 2, size=25), np.random.normal(0, 1, size=25)])
        d_noise_arr[i] = d + noise
        
    # Make sure the target directory exists so the long generation run is not
    # lost to a missing-directory error at save time.
    os.makedirs(data_dir, exist_ok=True)
    np.save(f'{data_dir}/m.npy', m_arr)
    np.save(f'{data_dir}/e.npy', e_arr)
    np.save(f'{data_dir}/d.npy', d_arr)
    np.save(f'{data_dir}/d_noise.npy', d_noise_arr)
else:
    # NOTE(review): `torch` is not imported in the visible cells - confirm.
    m_arr = torch.tensor(np.load(f'{data_dir}/m.npy'), dtype=torch.float32)
    e_arr = torch.tensor(np.load(f'{data_dir}/e.npy'), dtype=torch.float32)
    d_arr = torch.tensor(np.load(f'{data_dir}/d.npy'), dtype=torch.float32)

In [None]:
# PyTorch training cell: fits the vector-field network on (m, e, d) triples and
# keeps the checkpoint with the lowest mean epoch loss.
if TRAIN:
    batch_size = 4096

    dataset = torch.utils.data.TensorDataset(m_arr, e_arr, d_arr)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # NOTE(review): `sigma` is not used below; the sampling call hard-codes
    # sigma=0.01. Confirm which value is intended.
    sigma = 0.1
    dim = m_arr[0].shape[0] + e_arr[0].shape[0] + d_arr[0].shape[0]
    print(dim)
    # NOTE(review): this MLP signature (`time_varying`, `.to(device)`) and
    # `device` are not defined in the visible cells - presumably a torch model
    # from elsewhere; confirm.
    model = MLP(dim=dim, out_dim=6, w=256,time_varying=True).to(device)
    optimizer = torch.optim.Adam(model.parameters())

    start = time.time()

    n_iter = 2000

    # Seeded with inf so the first epoch always counts as an improvement.
    loss_iter = [np.inf]

    for k in tqdm(range(n_iter)):
        loss_epoch = []
        for m, e, d in loader:
            optimizer.zero_grad()
            # noise = torch.cat([
            #     torch.randn(d.shape[0], 4) * 2,
            #     torch.randn(d.shape[0], 4) * 1
            # ], dim=1)
            # d = d + noise
            d = d.to(device)
            e = e.to(device)
            x1 = m.to(device)
            x0 = torch.rand(x1.shape[0], 6).to(device)
            t = torch.rand(x0.shape[0]).type_as(x0)
            xt = sample_conditional_pt(x0, x1, t, sigma=0.01)
            ut = compute_conditional_vector_field(x0, x1)
            vt = model(torch.cat([xt, d, e, t[:, None]], dim=-1))
            loss = torch.mean((vt - ut) ** 2)
            
            loss_epoch.append(loss.item())
            loss.backward()
            optimizer.step()
            
        # Compute the epoch mean once and reuse it.
        epoch_loss = np.mean(loss_epoch)
        if epoch_loss < np.min(loss_iter):
            torch.save(model.state_dict(), f"{savedir}/model_best_epoch.pt")
            logging.info(f"Model saved to {savedir}/model_best_epoch.pt")
            # Log the loss of the checkpoint just written; the previous code
            # logged the prior minimum (inf on the first save).
            logging.info(f"{epoch_loss:0.4f}")
        loss_iter.append(epoch_loss)

    torch.save(model.state_dict(), f"{savedir}/model.pt")
    logging.info(f"Model saved to {savedir}/model.pt")
    np.save(f'{savedir}/losses_4.npy', np.array(loss_iter))

In [None]:
# Single worked example: build one reference (m, e, d) triple, then rebuild the
# network and load the weights saved under `savedir`.
# Random starting point for the ODE integration.
m0 = torch.rand(1, 6)
m = [0.4, 0.3, 0.3, 0.1, 0.15, 0.6]
e = np.linspace(1,3,25)
# NOTE(review): this call passes no `noise` argument, and m (length 6) and
# e (length 25) have different lengths - confirm both against the d_by_m_e
# actually in scope.
d = d_by_m_e(m ,e).flatten()
# Add a leading batch axis of size 1 to each quantity.
m = torch.tensor(m,  dtype=torch.float32).unsqueeze(0)
e = torch.tensor(e,  dtype=torch.float32).unsqueeze(0)
d = torch.tensor(d,  dtype=torch.float32).unsqueeze(0)
dim = m[0].shape[0] + e[0].shape[0] + d[0].shape[0]
print(dim)
# NOTE(review): this MLP signature (`time_varying`) and `load_state_dict` are not
# provided by the MLP class visible in this notebook - presumably a torch model
# defined elsewhere; confirm.
model = MLP(dim=dim, out_dim=6, w=256,time_varying=True)
model.load_state_dict(torch.load(f"{savedir}/model.pt",  map_location='cpu'))

def ode_function(t, m, d, e):
    """Right-hand side for ``solve_ivp``: evaluate the module-level ``model`` at ``(m, d, e, t)``.

    ``t`` and ``m`` are converted to float32 tensors with a leading batch axis,
    concatenated with ``d`` and ``e``, and the first row of the model output is
    returned as a NumPy array.
    """
    t_col = torch.tensor(t, dtype=torch.float32)[None, None]
    m_row = torch.tensor(m, dtype=torch.float32)[None]
    net_input = torch.cat([m_row, d, e, t_col], dim=-1)
    velocity = model(net_input)
    return velocity.detach().numpy()[0]

# Integrate the learned vector field from the random start m0 over t in [0, 1].
solution = solve_ivp(ode_function, t_span=[0, 1], y0=m0[0].numpy(), t_eval=None, args=(d, e))
m_pred = solution.y[:, -1]

# Compare in NumPy: `d` and `e` are torch tensors while the forward model and the
# ODE solution are NumPy, so convert explicitly instead of relying on implicit
# Tensor/ndarray mixing.
d_true = d.numpy()
d_pred = d_by_m_e(m_pred, e[0].numpy(), 0.0).flatten()
diff_norm = np.linalg.norm(d_true - d_pred) / np.linalg.norm(d_true)

# Emit every report line to both stdout and the log file.
report_lines = [
    f'm_pred = {m_pred}',
    f'd = {d}',
    f'd_pred = {d_pred}',
    f'diff norm = {diff_norm}',
]
for line in report_lines:
    print(line)
for line in report_lines:
    logging.info(line)

In [None]:
errors = []
sols = []

# The reference problem is identical for every trial, so build it once instead
# of rebuilding it in each of the 1000 iterations.
m = [0.4, 0.3, 0.3, 0.1, 0.15, 0.6]
e_grid = np.linspace(1,3,25)
d_clean = d_by_m_e(m, e_grid, 0.0).flatten()
m = torch.tensor(m,  dtype=torch.float32).unsqueeze(0)
e = torch.tensor(e_grid,  dtype=torch.float32).unsqueeze(0)
d = torch.tensor(d_clean,  dtype=torch.float32).unsqueeze(0)
# NumPy views used for the error computation (explicit rather than implicit
# Tensor/ndarray arithmetic).
d_ref = d.numpy()
e_ref = e[0].numpy()

for i in tqdm(range(1000)):
    # Only the random starting point changes between trials.
    m0 = torch.rand(1, 6)

    solution = solve_ivp(ode_function, t_span=[0, 1], y0=m0[0].numpy(), t_eval=None, args=(d, e))
    m_pred = solution.y[:, -1]
    d_pred = d_by_m_e(m_pred, e_ref, 0.0).flatten()
    sols.append(m_pred)
    errors.append(np.linalg.norm(d_ref - d_pred) / np.linalg.norm(d_ref))

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
sns.set(style="whitegrid")

# One column per recovered parameter. The second label previously lacked its
# closing `$`, so it was not rendered as math.
df = pd.DataFrame(sols, columns = [r'$\beta_1$', r'$\alpha$', r'$\gamma^r$', r'$\gamma^d_1$', r'$\beta_2$', r'$\gamma^d_2$'])

plt.figure(dpi=300, figsize=(12,6))
ax = sns.kdeplot(df, fill=True, alpha=0.5, common_norm=True)
plt.title(r'Joint probability distribution $\rho(m|d,e)$', fontsize=16, fontweight='bold')
plt.xlabel('Parameter value')
# Restyle the legend seaborn already created. Calling plt.legend() here would
# rebuild it from labelled artists and drop the seaborn entries.
sns.move_legend(ax, 'best', fontsize=12)
plt.xlim(0,1)
plt.savefig('seir_25p.png')
plt.show()

error_mean, error_std = np.mean(errors), np.std(errors)
print(error_mean, error_std)
# logging.info treats its first argument as a %-format string, so passing two
# numbers positionally produced a logging error instead of a log line.
logging.info(f'{error_mean} {error_std}')