In [1]:
# Mount Google Drive so the training file under "/content/drive/My Drive/data" is reachable.
from google.colab import drive
drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Configuration (default hyperparameters)

In [2]:
!pip install ml_collections



In [3]:
import jax.numpy as jnp
import ml_collections


def get_config():
    """Return the default configuration for the KdV FNO experiment.

    Top-level ``use_wandb`` toggles Weights & Biases logging; the nested
    sections are ``wandb``, ``data``, ``arch``, ``training`` and ``optim``.

    Returns:
        ml_collections.ConfigDict: a freshly built config.
    """
    cfg = ml_collections.ConfigDict()

    cfg.use_wandb = True

    # --- Weights & Biases run metadata ---
    cfg.wandb = wandb_cfg = ml_collections.ConfigDict()
    wandb_cfg.project = "kdv"
    wandb_cfg.name = "current_model"
    wandb_cfg.tags = None
    wandb_cfg.group = None

    # --- Simulation grid (nt, nx select the "pde_{nt}-{nx}" dataset) ---
    cfg.data = data_cfg = ml_collections.ConfigDict()
    data_cfg.nt = 140
    data_cfg.nx = 256
    data_cfg.L = 128  # NOTE(review): not read by any visible cell
    data_cfg.T = 140  # NOTE(review): not read by any visible cell

    # --- FNO architecture ---
    cfg.arch = arch_cfg = ml_collections.ConfigDict()
    arch_cfg.modes = 16
    arch_cfg.width = 64
    arch_cfg.seed = 0

    # --- Training loop ---
    cfg.training = train_cfg = ml_collections.ConfigDict()
    train_cfg.batch_size = 16
    train_cfg.time_history = 20
    train_cfg.time_future = 20
    # Each "epoch" is one pass over the train loader: 30 * 2 * nt = 8400 passes.
    train_cfg.epochs = 30 * 2 * data_cfg.nt
    train_cfg.seed = 1

    # --- Optimizer ---
    cfg.optim = optim_cfg = ml_collections.ConfigDict()
    optim_cfg.optimizer = "adamw"
    optim_cfg.learning_rate = 0.001
    optim_cfg.b1 = 0.9
    optim_cfg.b2 = 0.999
    optim_cfg.eps = 1e-8
    optim_cfg.eps_root = 0.0
    optim_cfg.weight_decay = 0.01
    optim_cfg.scale = 0.4
    # Compared against the optimizer step count (one update per batch), not
    # the epoch. NOTE(review): every epoch has at least one batch, so the last
    # decay is reached by epoch 3600 of 8400 at the latest — confirm intent.
    optim_cfg.boundaries = jnp.array([10, 1200, 2400, 3600])
    return cfg


# models

In [4]:
from functools import partial
from typing import Callable, Dict, Tuple

import jax
import jax.numpy as jnp
import ml_collections
import optax
from flax import linen as nn
from flax.training.train_state import TrainState
from jax import jit, random, value_and_grad, vmap


def _create_optimizer(
    config: ml_collections.ConfigDict,
) -> optax.GradientTransformation:
    """Build the optimizer described by the ``optim`` section of the config.

    The learning rate is piecewise constant. It equals ``config.learning_rate``
    until the optimizer step count reaches the first entry of
    ``config.boundaries``. From the step equal to the i-th boundary
    (inclusive, 1-based i) it becomes ``learning_rate * config.scale ** i``.
    Boundaries are counted in optimizer updates (one per batch), not epochs.

    Args:
        config: the ``optim`` sub-config. Reads ``learning_rate``,
            ``boundaries`` (ascending; a jnp array, list or tuple), ``scale``,
            ``b1``, ``b2``, ``eps``, ``eps_root``, ``weight_decay`` and, if
            present, ``optimizer`` ("adamw", the default, or "adam").

    Returns:
        The optax gradient transformation.

    Raises:
        ValueError: if ``config.optimizer`` names an unsupported optimizer.
    """
    learning_rate = config.learning_rate

    # asarray accepts a list/tuple in the config as well as a jnp array.
    boundaries = jnp.asarray(config.boundaries)
    decayed_lrs = learning_rate * config.scale ** jnp.arange(
        1, boundaries.shape[0] + 1
    )

    def lr_schedule(step):
        # Index of the last boundary already reached; -1 before the first one.
        idx = jnp.sum(boundaries <= step) - 1
        # decayed_lrs[-1] is evaluated when idx < 0 but discarded by where().
        return jnp.where(idx < 0, learning_rate, decayed_lrs[idx])

    # With no boundaries there is nothing to index; use a constant rate.
    schedule = lr_schedule if boundaries.shape[0] > 0 else learning_rate

    # Honour the config's optimizer name (it used to be ignored).
    name = getattr(config, "optimizer", "adamw")
    if name == "adamw":
        return optax.adamw(
            learning_rate=schedule,
            b1=config.b1,
            b2=config.b2,
            eps=config.eps,
            eps_root=config.eps_root,
            weight_decay=config.weight_decay,
        )
    if name == "adam":
        return optax.adam(
            learning_rate=schedule,
            b1=config.b1,
            b2=config.b2,
            eps=config.eps,
            eps_root=config.eps_root,
        )
    raise ValueError(
        f"Unsupported optimizer {name!r}; expected 'adamw' or 'adam'."
    )


def _create_train_state(
    config: ml_collections.ConfigDict,
) -> TrainState:
    """Initialise FNO1d parameters and bundle them with the optimizer.

    Parameters come from ``model.init`` with ``PRNGKey(config.arch.seed)``.
    The dummy input has shape ``(config.data.nx, config.training.time_history)``,
    i.e. a single un-batched sample, which is how ``FNO.loss`` applies the
    model via ``vmap``.

    Returns:
        A ``TrainState`` holding ``apply_fn``, the params and the optimizer
        built from ``config.optim``.
    """
    network = FNO1d(
        width=config.arch.width,
        modes=config.arch.modes,
        time_future=config.training.time_future,
    )

    sample = jnp.ones((config.data.nx, config.training.time_history))
    init_key = random.PRNGKey(config.arch.seed)
    variables = network.init(init_key, sample)

    return TrainState.create(
        apply_fn=network.apply,
        params=variables,
        tx=_create_optimizer(config.optim),
    )


class FNO:
    """Training wrapper that owns the FNO1d ``TrainState``.

    ``step`` is jitted with ``self`` as a static argument, so it is keyed on
    object identity (default ``__hash__``). The state is passed in and
    returned explicitly rather than read from ``self.state``.
    """

    def __init__(
        self,
        config: ml_collections.ConfigDict,
    ):
        self.config = config
        self.state = _create_train_state(config)

    def loss(
        self,
        params: Dict,
        state: TrainState,
        batch: Tuple[jnp.ndarray, jnp.ndarray],
    ) -> jnp.ndarray:
        """Summed squared error of the model over a batch.

        Args:
            params: parameters to evaluate (the differentiated argument).
            state: supplies ``apply_fn``; its own params are not used here.
            batch: ``(inputs, targets)``; the model is vmapped over axis 0.

        Returns:
            Scalar sum of squared errors over every element (not a mean).
        """
        inputs, targets = batch
        preds = vmap(lambda sample: state.apply_fn(params, sample))(inputs)
        return jnp.square(preds - targets).sum()

    @partial(jit, static_argnums=(0,))
    def step(
        self,
        state: TrainState,
        batch: Tuple[jnp.ndarray, jnp.ndarray],
    ) -> Tuple[TrainState, jnp.ndarray]:
        """Apply one gradient update.

        Returns:
            ``(new_state, loss)``, where loss is the pre-update batch loss.
        """
        loss, grads = value_and_grad(self.loss)(state.params, state, batch)
        new_state = state.apply_gradients(grads=grads)
        return new_state, loss


class SpectralConv1d(nn.Module):
    """Fourier layer that mixes channels on the lowest ``modes`` frequencies.

    The input ``x`` is ``(spatial, width)``. The layer takes an rFFT along
    axis 0 and applies a learned complex ``width x width`` matrix per retained
    mode. All higher modes are zeroed, and an inverse rFFT returns to
    ``spatial`` points.
    """

    width: int
    modes: int

    def setup(
        self,
    ):
        scale = 1 / (self.width * self.width)
        # Real (index 0) and imaginary (index 1) parts in one real array,
        # each drawn uniformly from [0, scale).
        self.weights = self.param(
            "global_kernel",
            lambda rng, shape: random.uniform(
                rng, shape, minval=0, maxval=scale
            ),
            (2, self.modes, self.width, self.width),
        )

    # Parameters are declared in setup(), so no @nn.compact is needed here.
    def __call__(
        self,
        x: jnp.ndarray,
    ) -> jnp.ndarray:
        spatial_resolution = x.shape[0]

        x_ft = jnp.fft.rfft(x, axis=0)
        # An rFFT of n points yields n // 2 + 1 frequencies. Use at most that
        # many modes so coarse grids don't cause an einsum shape mismatch.
        n_modes = min(self.modes, x_ft.shape[0])
        x_ft_trunc = x_ft[:n_modes, :]

        R = jax.lax.complex(
            self.weights[0, :n_modes], self.weights[1, :n_modes]
        )
        R_x_ft = jnp.einsum("Mio,Mi->Mo", R, x_ft_trunc)

        result = jnp.zeros((x_ft.shape[0], self.width), dtype=x_ft.dtype)
        result = result.at[:n_modes, :].set(R_x_ft)

        return jnp.fft.irfft(result, n=spatial_resolution, axis=0)


class FNOBlock1d(nn.Module):
    """One FNO layer: ``activation(spectral_conv(x) + local_conv(x))``.

    Attributes:
        width: output channels of both paths.
        modes: Fourier modes kept by the spectral path.
        activation: elementwise nonlinearity applied to the sum.
    """

    width: int
    modes: int
    activation: Callable

    @nn.compact
    def __call__(
        self,
        x: jnp.ndarray,
    ) -> jnp.ndarray:
        # Global path: channel mixing in Fourier space.
        spectral_out = SpectralConv1d(self.width, self.modes)(x)
        # Local path: kernel-size-1 convolution, i.e. per-position channel mixing.
        local_out = nn.Conv(self.width, kernel_size=(1,), name="local")(x)
        combined = spectral_out + local_out
        return self.activation(combined)


class FNO1d(nn.Module):
    """1-D Fourier Neural Operator.

    The model works in three stages:

    - Lifts the last axis to ``width`` channels.
    - Applies ``n_layers`` FNO blocks with GELU.
    - Projects through a hidden layer of ``proj_width`` units to
      ``time_future`` outputs on the last axis.

    Attributes:
        width: channel width inside the FNO blocks.
        modes: Fourier modes kept by each spectral convolution.
        time_future: size of the output's last axis.
        n_layers: number of FNO blocks (default 4, the original architecture).
        proj_width: hidden width of the output projection (default 128).
    """

    width: int
    modes: int
    time_future: int
    n_layers: int = 4
    proj_width: int = 128

    @nn.compact
    def __call__(
        self,
        x: jnp.ndarray,
    ) -> jnp.ndarray:
        x = nn.Dense(self.width)(x)

        # Blocks are auto-named FNOBlock1d_0 ... FNOBlock1d_{n_layers-1}, the
        # same names the previously unrolled code produced, so existing
        # parameter trees remain compatible.
        for _ in range(self.n_layers):
            x = FNOBlock1d(self.width, self.modes, nn.gelu)(x)

        x = nn.Dense(self.proj_width)(x)
        x = nn.gelu(x)
        x = nn.Dense(self.time_future)(x)
        return x


# Data utilities (HDF5 loading and window extraction)

In [5]:
from typing import Tuple

import jax.numpy as jnp
import torch
import h5py
from torch.utils.data import DataLoader, Dataset


def to_coords(
    x: torch.Tensor,
    t: torch.Tensor,
) -> torch.Tensor:
    """Combine spatial and temporal grids into one coordinate tensor.

    Args:
        x: tensor of spatial positions, length nx (assumed 1-D).
        t: tensor of time stamps, length nt (assumed 1-D).

    Returns:
        Tensor of shape ``(nt, nx, 2)`` with ``out[j, i] == (x[i], t[j])``.
    """
    # indexing="xy" yields (nt, nx) grids directly. That equals the old
    # default "ij" grids transposed, without torch's warning about the
    # missing `indexing` argument.
    x_grid, t_grid = torch.meshgrid(x, t, indexing="xy")
    return torch.stack((x_grid, t_grid), -1)


class HDF5Dataset(Dataset):
    """Map-style dataset over one split (group) of a PDE trajectory HDF5 file.

    The ``mode`` group is read for a trajectory dataset named
    ``pde_{nt}-{nx}`` plus per-sample ``x``, ``t``, ``dx`` and ``dt``
    datasets. All are indexed by sample along axis 0.

    With ``load_all=False`` the file stays open for lazy reads; call
    ``close()`` to release it.
    """

    def __init__(
        self,
        path: str,
        mode: str,
        nt: int,
        nx: int,
        dtype=torch.float64,
        load_all: bool = False,
    ):
        super().__init__()
        f = h5py.File(path, "r")
        self.mode = mode
        # Stored for API compatibility; __getitem__ always casts to float32.
        self.dtype = dtype
        self.data = f[self.mode]
        self.dataset = f"pde_{nt}-{nx}"
        # Keep the handle so the lazily-read file can be closed later.
        self._file = f

        if load_all:
            # Copy every key __getitem__ reads. Previously only the trajectory
            # was copied, so x/t/dx/dt lookups failed after the file was closed.
            keys = (self.dataset, "x", "t", "dx", "dt")
            data = {key: self.data[key][:] for key in keys}
            f.close()
            self._file = None
            self.data = data

    def close(self) -> None:
        """Close the underlying HDF5 file if it is still open."""
        handle = getattr(self, "_file", None)
        if handle is not None:
            handle.close()
            self._file = None

    def __len__(
        self,
    ) -> int:
        return self.data[self.dataset].shape[0]

    def __getitem__(
        self,
        idx: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(u, dx, dt)`` for sample ``idx``, all cast to float32.

        In "train" mode ``dx``/``dt`` are 0-d tensors recomputed from the
        coordinate grid. Otherwise they are the stored values as shape-(1,)
        tensors.
        """
        u = self.data[self.dataset][idx]
        x = self.data["x"][idx]
        t = self.data["t"][idx]
        dx = self.data["dx"][idx]
        dt = self.data["dt"][idx]

        if self.mode == "train":
            coords = to_coords(torch.tensor(x), torch.tensor(t))
            u = torch.tensor(u)
            dx = coords[0, 1, 0] - coords[0, 0, 0]
            dt = coords[1, 0, 1] - coords[0, 0, 1]
        else:
            u = torch.from_numpy(u)
            dx = torch.tensor([dx])
            dt = torch.tensor([dt])
        return u.float(), dx.float(), dt.float()


def create_dataloader(
    data_string: str,
    mode: str,
    nt: int,
    nx: int,
    batch_size: int,
) -> DataLoader:
    """Open the ``mode`` split of an HDF5 file as a shuffling DataLoader.

    Args:
        data_string: path to the HDF5 file.
        mode: split/group name inside the file (e.g. "train").
        nt: time resolution; with ``nx`` selects the ``pde_{nt}-{nx}`` dataset.
        nx: spatial resolution.
        batch_size: samples per batch.

    Raises:
        Exception: if the dataset or loader cannot be built. The underlying
            error is chained as ``__cause__``.
    """
    try:
        dataset = HDF5Dataset(data_string, mode, nt=nt, nx=nx)
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    except Exception as err:
        # Keep the generic message, but chain the real cause (missing file,
        # bad group name, ...) instead of discarding it. KeyboardInterrupt and
        # SystemExit are no longer intercepted.
        raise Exception("Datasets could not be loaded properly") from err
    return loader


def create_data(
    datapoints: torch.Tensor,
    start_time: list,
    time_future: int,
    time_history: int,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Cut input/target windows out of a batch of trajectories.

    For trajectory ``dp`` with start index ``s``, the input is
    ``dp[s : s + time_history]`` and the target is the ``time_future`` steps
    immediately after it.

    Args:
        datapoints: batch of trajectories; each item is sliced along its axis 0.
        start_time: one start index per trajectory (any iterable of ints). As
            with ``zip``, extra entries on either side are ignored.
        time_future: number of target steps.
        time_history: number of input steps.

    Returns:
        ``(data, labels)`` as JAX arrays of shape ``(B, time_history, ...)``
        and ``(B, time_future, ...)``. If nothing is paired, both are empty
        float32 arrays of shape ``(0,)``.

    Raises:
        ValueError: if a window would run past the end of a trajectory.
    """
    inputs, targets = [], []
    for dp, start in zip(datapoints, start_time):
        start = int(start)
        end_time = start + time_history
        target_end_time = end_time + time_future
        if target_end_time > dp.shape[0]:
            raise ValueError(
                f"window [{start}, {target_end_time}) exceeds trajectory "
                f"length {dp.shape[0]}"
            )
        inputs.append(dp[start:end_time])
        targets.append(dp[end_time:target_end_time])

    if not inputs:
        empty = jnp.zeros((0,), dtype=jnp.float32)
        return empty, empty

    # A single stack avoids the quadratic copying of repeated torch.cat.
    return jnp.array(torch.stack(inputs)), jnp.array(torch.stack(targets))


# train

In [6]:
from typing import List, Tuple

import jax.numpy as jnp
import ml_collections
import numpy as npy
from flax.training.train_state import TrainState
from torch.utils.data import DataLoader
from wandb.sdk.wandb_run import Run
from tqdm import trange
from jax import random


def train_model(
    config: ml_collections.ConfigDict,
    run: Run = None,
) -> TrainState:
    """Train an FNO on the KdV training split and return the final TrainState.

    Each epoch is one pass over the training loader, and every sample gets a
    freshly drawn window start. The epoch mean of (batch summed squared error
    / batch_size) is shown on the progress bar, and logged to ``run`` if
    given, every 10 epochs and on the final epoch.

    Raises:
        ValueError: if ``config.data.nt`` leaves no valid window start.
    """
    training = config.training
    time_history = training.time_history
    time_future = training.time_future
    batch_size = training.batch_size

    # Loop-invariant: window starts are multiples of time_history, from
    # time_history up to the last start that leaves room for the input and
    # target windows.
    max_start_time = (config.data.nt - time_history) - time_future
    possible_start_times = jnp.arange(
        time_history, max_start_time + 1, time_history
    )
    if possible_start_times.shape[0] == 0:
        raise ValueError(
            f"nt={config.data.nt} is too short for time_history={time_history}"
            f" and time_future={time_future}"
        )

    def train(
        state: TrainState,
        key: jnp.ndarray,
        loader: DataLoader,
    ) -> Tuple[TrainState, npy.ndarray, jnp.ndarray]:
        """Run one pass over ``loader``.

        Returns the updated state, each batch's summed loss divided by
        ``batch_size``, and the advanced PRNG key.
        """
        losses = npy.zeros(len(loader))
        for i, (u, _, _) in enumerate(loader):
            key, subkey = random.split(key)
            start_time = random.choice(
                subkey, possible_start_times, (batch_size,)
            )

            # Keyword arguments: create_data's positional order is
            # (datapoints, start_time, time_future, time_history), which the
            # previous call had swapped. Starts are moved to host once.
            data, labels = create_data(
                u,
                npy.asarray(start_time).tolist(),
                time_future=time_future,
                time_history=time_history,
            )

            # (B, time, space) -> (B, space, time): the model's last axis is time.
            batch = (
                jnp.permute_dims(data, (0, 2, 1)),
                jnp.permute_dims(labels, (0, 2, 1)),
            )
            state, loss = model.step(state, batch)
            losses[i] = loss.item()
        return state, losses / batch_size, key

    train_loader = create_dataloader(
        "/content/drive/My Drive/data/KdV_train.h5",
        mode="train",
        nt=config.data.nt,
        nx=config.data.nx,
        batch_size=batch_size,
    )

    epochs = training.epochs
    pbar = trange(epochs)

    model = FNO(config)

    key = random.PRNGKey(training.seed)
    for epoch in pbar:
        model.state, losses, key = train(model.state, key, train_loader)

        # The loop variable stops at epochs - 1, so that is the final epoch.
        if epoch % 10 == 0 or epoch == epochs - 1:
            loss = losses.mean()
            pbar.set_postfix({"train_loss": loss})
            if run is not None:
                run.log({"train_loss": loss})
    return model.state


# main

In [7]:
import wandb

config = get_config()


def main():
    """Optionally start a W&B run, train the model, and return ``(state, run)``.

    ``run`` is the run started here, or None when ``config.use_wandb`` is off.
    """
    if config.use_wandb:
        wandb.init(
            project=config.wandb.project,
            name=config.wandb.name,
            tags=config.wandb.tags,
            group=config.wandb.group,
        )
        run = wandb.run
    else:
        run = None
    model_state = train_model(config, run)
    # Return the local run rather than the global wandb.run. Otherwise a run
    # left active by an earlier execution in this kernel would be handed to
    # the save cell even when W&B is disabled.
    return model_state, run


model_state, run = main()


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
100%|██████████| 8400/8400 [3:37:46<00:00,  1.56s/it, train_loss=0.149]


# Save the trained parameters as a W&B artifact

In [8]:
if run is not None:
    import pickle

    # Artifact name and local file name are both derived from the run name.
    artifact_name = f"{run.name}_model_params"
    params_path = artifact_name + ".pkl"

    # Persist the trained parameter pytree locally. Only unpickle this file
    # from a trusted source: pickle.load can execute arbitrary code.
    with open(params_path, "wb") as params_file:
        pickle.dump(model_state.params, params_file)

    # Upload the file as a W&B "model" artifact, then close the run.
    model_artifact = wandb.Artifact(artifact_name, type="model")
    model_artifact.add_file(params_path)
    wandb.log_artifact(model_artifact)
    wandb.finish()


0,1
train_loss,█▆▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
train_loss,0.14933
