In [1]:
from google.colab import drive
drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# configs / defaults

In [2]:
!pip install ml_collections



In [3]:
import ml_collections
from flax import linen as nn


def get_config():
    config = ml_collections.ConfigDict()

    config.use_wandb = True

    # Weights & Biases
    config.wandb = wandb = ml_collections.ConfigDict()
    wandb.project = "burgers"
    wandb.name = "best_model"
    wandb.tags = None
    wandb.group = None

    # Simulation Settings
    config.data = data = ml_collections.ConfigDict()
    data.in_channels = 2
    data.out_channels = 1
    data.spatial_dims = 2048

    # FNO Architecture
    config.arch = arch = ml_collections.ConfigDict()
    arch.modes = 24
    arch.width = 64
    arch.num_layers = 12
    arch.seed = 0
    arch.activation = nn.tanh
    arch.layer_init = nn.initializers.glorot_uniform()
    arch.lift_init = nn.initializers.glorot_uniform()
    arch.proj_init = nn.initializers.glorot_uniform()

    # Training
    config.training = training = ml_collections.ConfigDict()
    training.batch_size = 128
    training.epochs = 8000
    training.seed = 1

    # Optimizer
    config.optim = optim = ml_collections.ConfigDict()
    optim.optimizer = "adam"
    optim.learning_rate = 1e-3
    optim.b1 = 0.9
    optim.b2 = 0.999
    optim.eps = 1e-8
    optim.eps_root = 0.0

    optim.transition_steps = 250
    optim.transition_begin = 0
    optim.decay_rate = 0.9
    return config


# models

In [4]:
from typing import Callable, Dict, Tuple
from functools import partial

import jax.numpy as jnp
import ml_collections
import optax
import jax
from flax.training.train_state import TrainState
from flax import linen as nn
from jax import random, vmap, grad, jit


def _create_optimizer(
    config: ml_collections.ConfigDict,
) -> optax.GradientTransformation:
    lr = optax.exponential_decay(
        init_value=config.learning_rate,
        transition_steps=config.transition_steps,
        decay_rate=config.decay_rate,
    )

    optimizer = optax.adam(
        learning_rate=lr,
        b1=config.b1,
        b2=config.b2,
        eps=config.eps,
        eps_root=config.eps_root,
    )
    return optimizer


def _create_train_state(
    config: ml_collections.ConfigDict,
) -> TrainState:
    arch = config.arch
    data = config.data

    model = FNO1d(
        arch.width,
        data.out_channels,
        arch.modes,
        arch.activation,
        arch.num_layers,
        arch.lift_init,
        arch.proj_init,
        arch.layer_init,
    )

    dummy_input = jnp.ones((data.spatial_dims, data.in_channels))
    key = random.PRNGKey(arch.seed)
    params = model.init(key, dummy_input)

    tx = _create_optimizer(config.optim)

    state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)
    return state


class FNO:
    def __init__(
        self,
        config: ml_collections.ConfigDict,
    ):
        self.config = config
        self.state = _create_train_state(config)

    def loss(
        self,
        params: Dict,
        state: TrainState,
        batch: Tuple[jnp.ndarray, jnp.ndarray],
    ) -> jnp.ndarray:
        data, labels = batch
        pred = vmap(lambda x: state.apply_fn(params, x))(data)
        loss = jnp.square(pred - labels)
        return loss.mean()

    @partial(jit, static_argnums=(0,))
    def step(
        self,
        state: TrainState,
        batch: Tuple[jnp.ndarray, jnp.ndarray],
    ) -> TrainState:
        grads = grad(self.loss)(state.params, state, batch)
        state = state.apply_gradients(grads=grads)
        return state


class SpectralConv1d(nn.Module):
    width: int
    modes: int

    def setup(
        self,
    ):
        scale = 1 / (self.width * self.width)
        self.weights = self.param(
            "global_kernel",
            lambda rng, shape: random.uniform(
                rng, shape, minval=-scale, maxval=scale
            ),
            (2, self.modes, self.width, self.width),
        )

    @nn.compact
    def __call__(
        self,
        x: jnp.ndarray,
    ) -> jnp.ndarray:
        spatial_resolution = x.shape[0]

        x_ft = jnp.fft.rfft(x, axis=0)
        x_ft_trunc = x_ft[: self.modes, :]

        R = jax.lax.complex(self.weights[0, ...], self.weights[1, ...])

        R_x_ft = jnp.einsum("Mio,Mi->Mo", R, x_ft_trunc)

        result = jnp.zeros((x_ft.shape[0], self.width), dtype=x_ft.dtype)
        result = result.at[: self.modes, :].set(R_x_ft)

        inv_ft_R_x_ft = jnp.fft.irfft(result, n=spatial_resolution, axis=0)
        return inv_ft_R_x_ft


class FNOBlock1d(nn.Module):
    width: int
    modes: int
    activation: Callable
    layer_init: Callable

    @nn.compact
    def __call__(
        self,
        x: jnp.ndarray,
    ) -> jnp.ndarray:
        spectral_conv = SpectralConv1d(self.width, self.modes)(x)
        local_conv = nn.Conv(
            self.width,
            kernel_size=(1,),
            kernel_init=self.layer_init,
            name="local",
        )(x)
        return self.activation(spectral_conv + local_conv)


class FNO1d(nn.Module):
    width: int
    out_channels: int
    modes: int
    activation: Callable
    num_layers: int
    lift_init: Callable
    proj_init: Callable
    layer_init: Callable

    @nn.compact
    def __call__(
        self,
        x: jnp.ndarray,
    ) -> jnp.ndarray:
        x = nn.Conv(
            self.width,
            kernel_size=(1,),
            kernel_init=self.lift_init,
            name="lifting",
        )(x)
        for _ in range(self.num_layers):
            x = FNOBlock1d(
                self.width, self.modes, self.activation, self.layer_init
            )(x)
        x = nn.Conv(
            self.out_channels,
            kernel_size=(1,),
            kernel_init=self.proj_init,
            name="projection",
        )(x)
        return x


# utils

In [5]:
from functools import partial
from typing import Tuple

import jax.numpy as jnp
from torch.utils.data import Dataset
from jax import random, jit


class DataGenerator(Dataset):
    def __init__(
        self,
        inputs: jnp.ndarray,
        outputs: jnp.ndarray,
        batch_size: int,
        key: jnp.ndarray,
    ):
        self.inputs = inputs
        self.outputs = outputs
        self.N = outputs.shape[0]
        self.batch_size = batch_size
        self.key = key

    @partial(jit, static_argnums=(0,))
    def __data_generation(
        self,
        key: jnp.ndarray,
        inputs: jnp.ndarray,
        outputs: jnp.ndarray,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        idx = random.choice(key, self.N, (self.batch_size,), replace=False)
        inputs = inputs[idx, ...]
        outputs = outputs[idx, ...]
        return inputs, outputs

    def __getitem__(
        self,
        index: int,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        self.key, subkey = random.split(self.key)
        inputs, outputs = self.__data_generation(
            self.key, self.inputs, self.outputs
        )
        return inputs, outputs


# train

In [6]:
import ml_collections
import numpy as npy
from flax.training.train_state import TrainState
from wandb.sdk.wandb_run import Run
from tqdm import trange
from jax import random

data = npy.load(
    "/content/drive/My Drive/data/burgers_data.npy", allow_pickle=True
).item()

a_with_mesh = data["a_with_mesh"]
u = data["u"]

train_split = int(0.75 * a_with_mesh.shape[0])

train_x = a_with_mesh[:train_split, ::4, :]
train_x_mean = train_x.mean((0, 1))
train_x_std = train_x.std((0, 1))
train_x = (train_x - train_x_mean) / train_x_std

train_y = u[:train_split, ::4, :]
train_y_mean = train_y.mean((0, 1))
train_y_std = train_y.std((0, 1))
train_y = (train_y - train_y_mean) / train_y_std


def train_model(
    config: ml_collections.ConfigDict,
    run: Run = None,
) -> TrainState:
    dataset = DataGenerator(
        train_x,
        train_y,
        config.training.batch_size,
        random.PRNGKey(config.training.seed),
    )

    epochs = config.training.epochs
    pbar = trange(epochs)
    data = iter(dataset)

    model = FNO(config)

    for epoch in pbar:
        batch = next(data)
        model.state = model.step(model.state, batch)

        if (epoch % 20 == 0) or (epoch == epochs):
            loss = model.loss(model.state.params, model.state, batch)
            pbar.set_postfix({"train_loss": loss})

            if run is not None:
                run.log({"train_loss": loss})

    return model.state


# main

In [7]:
import wandb

config = get_config()


def main():
    use_wandb = config.use_wandb

    if use_wandb == True:
        wandb.init(
            project=config.wandb.project,
            name=config.wandb.name,
            tags=config.wandb.tags,
            group=config.wandb.group,
        )
        run = wandb.run
        model_state = train_model(config, run)
    else:
        run = None
        model_state = train_model(config, run)
    return model_state, run


model_state, run = main()


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mskoohy[0m ([33mskoohy-penn[0m). Use [1m`wandb login --relogin`[0m to force relogin


100%|██████████| 8000/8000 [20:09<00:00,  6.61it/s, train_loss=2.63977e-06]


# save model

In [8]:
if run is not None:
    import pickle

    model_params_filename = f"{run.name}_model_params"
    with open(model_params_filename + ".pkl", "wb") as f:
        pickle.dump(model_state.params, f)

    artifact = wandb.Artifact(model_params_filename, type="model")
    artifact.add_file(model_params_filename + ".pkl")
    wandb.log_artifact(artifact)
    wandb.finish()


0,1
train_loss,▄█▃▃▂▁▂▁▂▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
train_loss,0.0
