In [33]:
import torch
import numpy as np
import math
import torch.nn.functional as F
import json
import dill
from tqdm import tqdm
from scipy.integrate import odeint
from scipy.special import legendre
from itertools import combinations_with_replacement

In [34]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed every RNG the notebook draws from (initial conditions and noise via
# numpy, network weights and SINDy coefficients via torch) so that
# "Restart & Run All" reproduces the same results.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

## SINDy Utils

In [35]:
def sindy_library_pt(z, latent_dim, poly_order, include_sine=False):
    """
    Description: Builds the SINDy library for a first-order dynamical system.

    Column order: a constant column of ones, then every monomial of total degree
    1..poly_order in the latent variables (in
    itertools.combinations_with_replacement order), then optionally sin(z_i)
    for each latent variable.

    Args:
        z (torch.Tensor): Input tensor of shape (batch_size, latent_dim), representing latent states.
        latent_dim (int): Number of latent variables (dimensions).
        poly_order (int): Maximum degree of polynomial terms to include in the library.
        include_sine (bool): Whether to include sine terms in the library.

    Returns:
        torch.Tensor: A matrix (batch_size, library_size) where each column is a function of z.
        It has the same dtype and device as ``z``.
    """
    # Build the constant column with z's own dtype/device so the stacked library
    # never mixes float32 with float64, or CPU with GPU tensors.
    library = [torch.ones(z.size(0), dtype=z.dtype, device=z.device)]

    for degree in range(1, poly_order + 1):
        for combination in combinations_with_replacement(range(latent_dim), degree):
            # e.g. (0, 0) -> z_0^2, (1, 2) -> z_1 * z_2
            library.append(torch.prod(z[:, list(combination)], dim=1))

    if include_sine:
        library.extend(torch.sin(z[:, i]) for i in range(latent_dim))

    # (batch_size, num_features)
    return torch.stack(library, dim=1)

def sindy_library_pt_order2(z, dz, latent_dim, poly_order, include_sine=False):
    """
    Build the SINDy library for a second-order system.

    The library is built over the 2 * latent_dim variables [z, dz]: a constant
    column, every monomial of total degree 1..poly_order in those variables
    (itertools.combinations_with_replacement order), then optionally the sine
    of each variable.

    Args:
        z (torch.Tensor): latent states, shape (batch_size, latent_dim).
        dz (torch.Tensor): latent first derivatives, same shape as z.
        latent_dim (int): number of latent variables.
        poly_order (int): maximum total polynomial degree.
        include_sine (bool): whether to append sin() of every variable.

    Returns:
        torch.Tensor: (batch_size, num_features) library with the dtype/device of [z, dz].
    """
    state = torch.cat([z, dz], dim=1)
    n_vars = 2 * latent_dim

    # Constant column matching the inputs' dtype/device (no float32/float64 mix).
    library = [torch.ones(state.size(0), dtype=state.dtype, device=state.device)]

    for degree in range(1, poly_order + 1):
        for combination in combinations_with_replacement(range(n_vars), degree):
            library.append(torch.prod(state[:, list(combination)], dim=1))

    if include_sine:
        library.extend(torch.sin(state[:, i]) for i in range(n_vars))

    return torch.stack(library, dim=1)

def library_size(latent_dim, poly_order):
    """Return the number of polynomial columns in a SINDy library.

    Counts the monomials in ``latent_dim`` variables of every total degree from
    0 (the constant term) up to ``poly_order``. A degree-k monomial is a
    multiset of k variables, and there are C(latent_dim + k - 1, k) of those.
    """
    return sum(
        math.comb(latent_dim + degree - 1, degree) for degree in range(poly_order + 1)
    )

# Generate data

## Lorenz Code

In [36]:
def get_lorenz_data(n_ics, noise_strength=0):
    """
    Generate a set of Lorenz training data for multiple random initial conditions.

    Initial conditions are drawn uniformly from the box centred on (0, 0, 25)
    with half-widths (36, 48, 41). Each trajectory is sampled on t = 0, 0.02,
    ..., < 5 and lifted to 128 spatial points by generate_lorenz_data() with
    cubic modes and every Lorenz variable scaled by 1/40.

    Arguments:
        n_ics - Integer specifying the number of initial conditions to use.
        noise_strength - Amount of noise to add to the data.

    Return:
        data - Dictionary containing elements of the dataset. See generate_lorenz_data()
        doc string for list of contents. The "x", "dx" and "ddx" entries are flattened
        to (n_ics * n_steps, 128) rows and carry additive Gaussian noise scaled by
        noise_strength.
    """
    time = np.arange(0, 5, 0.02)
    input_dim = 128
    n_rows = n_ics * time.size

    centre = np.array([0, 0, 25])
    span = 2 * np.array([36, 48, 41])
    ics = span * (np.random.rand(n_ics, 3) - 0.5) + centre

    data = generate_lorenz_data(
        ics,
        time,
        input_dim,
        linear=False,
        normalization=np.array([1 / 40, 1 / 40, 1 / 40]),
    )

    # Flatten (n_ics, n_steps, input_dim) into rows and perturb each signal;
    # the noise for x, dx and ddx is drawn in that order.
    for key in ("x", "dx", "ddx"):
        data[key] = data[key].reshape((-1, input_dim)) + noise_strength * np.random.randn(
            n_rows, input_dim
        )

    return data

def simulate_lorenz(z0, t, sigma=10.0, beta=8 / 3, rho=28.0):
    """
    Simulate the Lorenz dynamics.

    Arguments:
        z0 - Initial condition in the form of a 3-value list or array.
        t - Array of time points at which to simulate.
        sigma, beta, rho - Lorenz parameters

    Returns:
        z, dz, ddz - Arrays of shape (len(t), 3): the trajectory integrated by
        odeint, the vector field evaluated along it (1st derivative), and the time
        derivative of the vector field along it (2nd derivative).
    """

    def rhs(state, _time):
        x, y, w = state
        return [sigma * (y - x), x * (rho - w) - y, x * y - beta * w]

    z = odeint(rhs, z0, t)

    # The system is autonomous, so both derivatives are evaluated column-wise on
    # the whole trajectory at once (no per-sample loop, no time step needed).
    x, y, w = z[:, 0], z[:, 1], z[:, 2]
    dz = np.column_stack([sigma * (y - x), x * (rho - w) - y, x * y - beta * w])

    dx, dy, dw = dz[:, 0], dz[:, 1], dz[:, 2]
    ddz = np.column_stack(
        [
            sigma * (dy - dx),
            dx * (rho - w) + x * (-dw) - dy,
            dx * y + x * dy - beta * dw,
        ]
    )
    return z, dz, ddz

def lorenz_coefficients(normalization, poly_order=3, sigma=10.0, beta=8 / 3, rho=28.0):
    """
    Generate the ground-truth SINDy coefficient matrix Xi for the rescaled Lorenz system.

    With z_i = n_i * Z_i (n = normalization), the Lorenz equations become
        dz_0 = -sigma z_0 + sigma (n_0 / n_1) z_1
        dz_1 = rho (n_1 / n_0) z_0 - z_1 - n_1 / (n_0 n_2) z_0 z_2
        dz_2 = -beta z_2 + n_2 / (n_0 n_1) z_0 z_1
    Row indices follow the column order of sindy_library_pt with 3 latent
    variables: 1 -> z_0, 2 -> z_1, 3 -> z_2, 5 -> z_0 z_1, 6 -> z_0 z_2.

    Arguments:
        normalization - 3-element list or array specifying the scaling of each Lorenz variable
        poly_order - Polynomial order of the SINDy model.
        sigma, beta, rho - Parameters of the Lorenz system

    Returns:
        Xi - array of shape (library_size(3, poly_order), 3); column j holds the
        coefficients of dz_j.
    """
    n0, n1, n2 = normalization[0], normalization[1], normalization[2]

    # (library row, equation) -> coefficient, filled in this order.
    nonzero_terms = {
        (1, 0): -sigma,
        (2, 0): sigma * n0 / n1,
        (1, 1): rho * n1 / n0,
        (2, 1): -1,
        (6, 1): -n1 / (n0 * n2),
        (3, 2): -beta,
        (5, 2): n2 / (n0 * n1),
    }

    Xi = np.zeros((library_size(3, poly_order), 3))
    for (row, col), value in nonzero_terms.items():
        Xi[row, col] = value
    return Xi

def generate_lorenz_data(
    ics, t, n_points, linear=True, normalization=None, sigma=10, beta=8 / 3, rho=28
):
    """
    Generate high-dimensional Lorenz data set.

    Arguments:
        ics - Nx3 array of N initial conditions
        t - array of time points over which to simulate
        n_points - size of the high-dimensional dataset created
        linear - Boolean value. If True, high-dimensional dataset is a linear combination
        of the Lorenz dynamics. If False, the dataset also includes cubic modes.
        normalization - Optional 3-value array for rescaling the 3 Lorenz variables.
        sigma, beta, rho - Parameters of the Lorenz dynamics.

    Returns:
        data - Dictionary containing elements of the dataset. This includes the time points (t),
        spatial mapping (y_spatial), high-dimensional modes used to generate the full dataset
        (modes), low-dimensional Lorenz dynamics (z, along with 1st and 2nd derivatives dz and
        ddz), high-dimensional dataset (x, along with 1st and 2nd derivatives dx and ddx), and
        the true Lorenz coefficient matrix for SINDy.
        z, dz, ddz have shape (N, len(t), 3); x, dx, ddx have shape (N, len(t), n_points).
    """
    n_ics = ics.shape[0]
    n_steps = t.size
    d = 3

    # Latent trajectories for every initial condition.
    z = np.zeros((n_ics, n_steps, d))
    dz = np.zeros(z.shape)
    ddz = np.zeros(z.shape)
    for i in range(n_ics):
        z[i], dz[i], ddz[i] = simulate_lorenz(
            ics[i], t, sigma=sigma, beta=beta, rho=rho
        )

    if normalization is not None:
        z *= normalization
        dz *= normalization
        ddz *= normalization

    # Spatial grid on [-1, 1] and the first 2*d Legendre polynomials on it.
    y_spatial = np.linspace(-1, 1, n_points)
    modes = np.zeros((2 * d, n_points))
    for i in range(2 * d):
        modes[i] = legendre(i)(y_spatial)

    # Lift to n_points dimensions by broadcasting: z[..., k, None] has shape
    # (n_ics, n_steps, 1) and a mode has shape (n_points,), so every product is
    # (n_ics, n_steps, n_points).
    #   x = sum_k P_k z_k                (+ sum_k P_{k+3} z_k^3 if not linear)
    # and dx, ddx follow from the chain rule:
    #   d/dt z^3 = 3 z^2 dz,   d2/dt2 z^3 = 6 z dz^2 + 3 z^2 ddz.
    x = np.zeros((n_ics, n_steps, n_points))
    dx = np.zeros(x.shape)
    ddx = np.zeros(x.shape)
    for k in range(d):
        zk, dzk, ddzk = z[..., k, None], dz[..., k, None], ddz[..., k, None]
        x += modes[k] * zk
        dx += modes[k] * dzk
        ddx += modes[k] * ddzk
        if not linear:
            cubic_mode = modes[d + k]
            x += cubic_mode * zk**3
            dx += cubic_mode * (3 * zk**2 * dzk)
            ddx += cubic_mode * (6 * zk * dzk**2 + 3 * zk**2 * ddzk)

    sindy_coefficients = lorenz_coefficients(
        [1, 1, 1] if normalization is None else normalization,
        sigma=sigma,
        beta=beta,
        rho=rho,
    )

    return {
        "t": t,
        "y_spatial": y_spatial,
        "modes": modes,
        "x": x,
        "dx": dx,
        "ddx": ddx,
        "z": z,
        "dz": dz,
        "ddz": ddz,
        "sindy_coefficients": sindy_coefficients.astype(np.float32),
    }

## Lorenz Data

In [37]:
# Scale of the Gaussian noise added to x, dx and ddx in every split.
noise_strength = 1e-6
# Splits are generated in order (train, validation, test) from 10, 5 and 5 trajectories.
training_data, validation_data, test_data = (
    get_lorenz_data(n_trajectories, noise_strength=noise_strength)
    for n_trajectories in (10, 5, 5)
)

# Neural Network

## Custom Linear Layer

In [38]:
class LinearLayer(torch.nn.Linear):
    """Linear layer that carries time derivatives through the network.

    The input is the row-wise concatenation of equally sized batches
    ``[x; dx]`` (order 1) or ``[x; dx; ddx]`` (order 2), and the output uses the
    same layout. With pre-activation ``a = W x + b`` and activation ``f``, a
    hidden layer returns (chain rule)::

        y   = f(a)
        dy  = f'(a) * (W dx)
        ddy = f''(a) * (W dx) ** 2 + f'(a) * (W ddx)

    The last layer (``last=True``) is purely affine: ``y = W x + b``,
    ``dy = W dx``, ``ddy = W ddx`` (the bias is constant in time).

    Supported activations: ``torch.nn.ReLU``, ``torch.nn.ELU`` and
    ``torch.nn.Sigmoid``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        activation_function: torch.nn.Module,
        last: bool = False,
        order: int = 1,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        """
        Args:
            in_features: size of each input sample.
            out_features: size of each output sample.
            activation_function: non-linearity applied by hidden layers.
            last: if True the layer is affine only (no activation).
            order: 1 propagates [x; dx]; any other value propagates [x; dx; ddx].
            bias: whether the affine map has a bias.
            device: device for the parameters.
            dtype: parameter dtype; defaults to torch.float64 (the dtype of the
                notebook's numpy data).

        Raises:
            TypeError: if no derivative is implemented for the activation.
        """
        super().__init__(
            in_features,
            out_features,
            bias,
            device,
            dtype=torch.float64 if dtype is None else dtype,
        )

        self.activation_function = activation_function
        self.activation_derivative = self.__get_activation_derivative()
        self.activation_2nd_derivative = self.__get_activation_2nd_derivative()

        self.last = last
        self.order = order

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Propagate the stacked state/derivative tensor according to ``order``."""
        if self.order == 1:
            return self.forward_dx(input)
        return self.forward_ddx(input)

    def forward_dx(self, input: torch.Tensor) -> torch.Tensor:
        """First-order pass on the stacked tensor ``[x; dx]``."""
        half = input.shape[0] // 2
        x, dx = input[:half], input[half:]

        pre_activation = F.linear(x, self.weight, self.bias)
        # No bias on derivatives: d/dt (W x + b) = W dx. Also works when bias=False.
        w_dx = F.linear(dx, self.weight)

        if self.last:
            return torch.cat((pre_activation, w_dx), dim=0)

        y = self.activation_function(pre_activation)
        dy = self.activation_derivative(pre_activation) * w_dx
        return torch.cat((y, dy), dim=0)

    def forward_ddx(self, input: torch.Tensor) -> torch.Tensor:
        """Second-order pass on the stacked tensor ``[x; dx; ddx]``."""
        third = input.shape[0] // 3
        x, dx, ddx = input[:third], input[third : 2 * third], input[2 * third :]

        pre_activation = F.linear(x, self.weight, self.bias)
        w_dx = F.linear(dx, self.weight)
        w_ddx = F.linear(ddx, self.weight)

        if self.last:
            return torch.cat((pre_activation, w_dx, w_ddx), dim=0)

        d1 = self.activation_derivative(pre_activation)
        d2 = self.activation_2nd_derivative(pre_activation)

        y = self.activation_function(pre_activation)
        dy = d1 * w_dx
        # d2/dt2 f(a(t)) = f''(a) * (da/dt)^2 + f'(a) * d2a/dt2
        ddy = d2 * w_dx**2 + d1 * w_ddx
        return torch.cat((y, dy, ddy), dim=0)

    def __get_activation_derivative(self):
        """Return a callable computing f'(a) for the configured activation."""
        act = self.activation_function
        if isinstance(act, torch.nn.ReLU):
            return lambda x: (x > 0).to(x.dtype)
        if isinstance(act, torch.nn.ELU):
            # ELU'(a) = 1 for a > 0, alpha * exp(a) otherwise. exp is evaluated on
            # min(a, 0) so large positive inputs cannot overflow to inf.
            return lambda x: torch.where(
                x > 0,
                torch.ones_like(x),
                act.alpha * torch.exp(torch.clamp(x, max=0.0)),
            )
        if isinstance(act, torch.nn.Sigmoid):
            return lambda x: self.dsigmoid(x)
        raise TypeError(f"Unsupported activation function {type(act).__name__}")

    def __get_activation_2nd_derivative(self):
        """Return a callable computing f''(a).

        The optional second argument is accepted and ignored, for compatibility
        with the previous ``(x, dx)`` calling convention.
        """
        act = self.activation_function
        if isinstance(act, torch.nn.ReLU):
            return lambda x, dx=None: torch.zeros_like(x)
        if isinstance(act, torch.nn.ELU):
            # ELU''(a) = 0 for a > 0, alpha * exp(a) otherwise.
            return lambda x, dx=None: torch.where(
                x > 0,
                torch.zeros_like(x),
                act.alpha * torch.exp(torch.clamp(x, max=0.0)),
            )
        if isinstance(act, torch.nn.Sigmoid):
            # sigma'' = sigma (1 - sigma) (1 - 2 sigma)
            return lambda x, dx=None: self.dsigmoid(x) * (1 - 2 * act(x))
        raise TypeError(f"Unsupported activation function {type(act).__name__}")

    def dsigmoid(self, x):
        """Derivative of the sigmoid evaluated at pre-activation ``x``."""
        sigmoid = self.activation_function(x)
        return sigmoid * (1 - sigmoid)

## Loss Function

In [39]:
class Loss(torch.nn.Module):
    """Weighted sum of reconstruction, SINDy-consistency and sparsity terms.

    Each error term is a mean squared error. The sparsity term is the mean of
    |Xi| * mask over all coefficients and can be switched off with
    set_regularization(False).
    """

    def __init__(
        self,
        lambda_1: float,
        lambda_2: float,
        lambda_3: float,
        lambda_r: float,
        order: int = 1,
        *args,
        **kwargs
    ) -> None:
        """Custom loss function based on multiple MSEs.

        Args:
            lambda_1 (float): weight of the decoder (reconstruction) loss
            lambda_2 (float): weight of the SINDy loss in latent space (z)
            lambda_3 (float): weight of the SINDy loss in input space (x)
            lambda_r (float): weight of the SINDy coefficient regularization
            order (int, optional): Order of the model, 1 or 2. Defaults to 1.
        """
        super().__init__(*args, **kwargs)
        self.lambda_1 = lambda_1
        self.lambda_2 = lambda_2
        self.lambda_3 = lambda_3
        self.lambda_r = lambda_r

        # Sparsity penalty is active until set_regularization(False) is called.
        self.regularization = True

        self.forward = self.forward_dx if order == 1 else self.forward_ddx

    def _sparsity_penalty(self, sindy_coeffs, coeff_mask):
        """Weighted mean absolute value of the masked coefficients (0 when disabled)."""
        return (
            int(self.regularization)
            * self.lambda_r
            * torch.mean(torch.abs(sindy_coeffs) * coeff_mask)
        )

    def forward_dx(
        self,
        x,
        dx,
        dz,
        dz_pred,
        x_decode,
        dx_decode,
        sindy_coeffs: torch.Tensor,
        coeff_mask,
    ) -> torch.Tensor:
        """First-order loss: reconstruction of x, match of dz and of dx."""
        reconstruction = self.lambda_1 * torch.mean((x - x_decode) ** 2)
        latent_dynamics = self.lambda_2 * torch.mean((dz - dz_pred) ** 2)
        input_dynamics = self.lambda_3 * torch.mean((dx - dx_decode) ** 2)
        return (
            reconstruction
            + latent_dynamics
            + input_dynamics
            + self._sparsity_penalty(sindy_coeffs, coeff_mask)
        )

    def forward_ddx(
        self,
        x,
        dx,
        dz,
        dz_pred,
        x_decode,
        dx_decode,
        sindy_coeffs: torch.Tensor,
        ddz,
        ddx,
        ddx_decode,
        coeff_mask,
    ) -> torch.Tensor:
        """Second-order loss: reconstruction of x, match of ddz and of ddx.

        ``dz_pred`` holds the SINDy prediction of ddz here; ``dx``, ``dz`` and
        ``dx_decode`` are accepted but not used.
        """
        reconstruction = self.lambda_1 * torch.mean((x - x_decode) ** 2)
        latent_dynamics = self.lambda_2 * torch.mean((ddz - dz_pred) ** 2)
        input_dynamics = self.lambda_3 * torch.mean((ddx - ddx_decode) ** 2)
        return (
            reconstruction
            + latent_dynamics
            + input_dynamics
            + self._sparsity_penalty(sindy_coeffs, coeff_mask)
        )

    def set_regularization(self, include_regularization: bool) -> None:
        """Enable or disable the coefficient sparsity term."""
        self.regularization = include_regularization

## Autoencoder

In [40]:
class AutoEncoder(torch.nn.Module):

    RELU = "relu"
    SIGMOID = "sigmoid"
    ELU = "elu"

    def __init__(
        self, params: dict = {}, name: str = "encoder", *args, **kwargs
    ) -> None:

        super().__init__(*args, **kwargs)
        self.params = params

        activation = self.params["activation"]
        self.activation_function = self.__get_activation(activation)

        self.weights = (
            [self.params["input_dim"]]
            + self.params["widths"]
            + [self.params["latent_dim"]]
        )
        self.order = self.params["model_order"]

        if self.weights is None:
            raise TypeError("Missing weight param")

        if name == "encoder":
            self.__create_encoder()
        elif name == "decoder":
            self.__create_decoder()

    def __create_encoder(self) -> None:
        """Creates the encoder based on weights and activation function"""
        layers = []
        for curr_weights, next_weights in zip(self.weights[:-1], self.weights[1:]):
            layers.append(
                LinearLayer(
                    curr_weights,
                    next_weights,
                    self.activation_function,
                    len(layers) + 2 == len(self.weights),
                    self.order,
                )
            )
        self.net = torch.nn.Sequential(*layers)

    def __create_decoder(self) -> None:
        """Creates decoder, the weights are swapped and reversed compared to the encoder"""
        layers = []
        for curr_weights, next_weights in zip(
            reversed(self.weights[1:]), reversed(self.weights[:-1])
        ):
            layers.append(
                LinearLayer(
                    curr_weights,
                    next_weights,
                    self.activation_function,
                    len(layers) + 2 == len(self.weights),
                    self.order,
                )
            )

        self.net = torch.nn.Sequential(*layers)

    def __get_activation(self, activation: str = "relu") -> torch.nn.Module:
        match (activation):
            case self.RELU:
                return torch.nn.ReLU()
            case self.SIGMOID:
                return torch.nn.Sigmoid()
            case self.ELU:
                return torch.nn.ELU()
            case _:
                raise TypeError(f"Invalid activation function {activation}")

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        """Forward function of the autoencoder

        Args:
            x (List[Tensor]): either the List has 2 or 3 elements
            if it has 2 elements the model order has to be set to 1
            if it has 3 elements the model order has to be set to 2

        Returns:
            List[torch.Tensor]: returns the forward passed list whit the same number of elements as the input
        """
        return self.net(x)

## SINDy Network

In [41]:
class SINDy(torch.nn.Module):
    """

    Description: Custom neural network module that embeds a SINDy model into an autoencoder.

    The encoder maps the input (and its time derivatives) to a latent state z.
    The product Theta @ Xi of a function library and a learned coefficient
    matrix predicts dz (model_order == 1) or ddz (otherwise). The decoder maps
    the latent quantities back to input space.

    """

    def __init__(
        self,
        encoder: AutoEncoder,
        decoder: AutoEncoder,
        device: str,
        params: dict = {},
        *args,
        **kwargs
    ) -> None:
        """

        Description: Constructor for the SINDy class. Initializes the model parameters, encoder, and decoder.

        Args:
            encoder (AutoEncoder): The encoder part of the autoencoder.
            decoder (AutoEncoder): The decoder part of the autoencoder.
            device (str): Device on which the coefficients and mask are created.
            params (Dict): A dictionary containing model and SINDy parameters (only read).
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        """
        super().__init__(*args, **kwargs)

        self.params = params
        self.encoder = encoder
        self.decoder = decoder

        # Autoencoder parameters
        self.input_dim = self.params["input_dim"]
        self.latent_dim = self.params["latent_dim"]

        # Order of the dynamical system: 1 -> dz = Theta(z) Xi, otherwise ddz = Theta(z, dz) Xi
        self.model_order = self.params["model_order"]

        # Library parameters
        self.poly_order = self.params["poly_order"]
        self.include_sine = self.params.get("include_sine") or False

        # Xi must have exactly as many rows as the library built in forward():
        # the second-order library uses the 2 * latent_dim variables [z, dz],
        # and sine terms add one column per library variable.
        n_library_vars = (
            self.latent_dim if self.model_order == 1 else 2 * self.latent_dim
        )
        self.library_dim = library_size(n_library_vars, self.poly_order)
        if self.include_sine:
            self.library_dim += n_library_vars

        # Coefficient parameters
        self.sequential_thresholding = self.params["sequential_thresholding"]
        self.coefficient_initialization = self.params["coefficient_initialization"]
        self.coefficient_mask = torch.ones((self.library_dim, self.latent_dim)).to(
            device
        )
        self.coefficient_threshold = self.params["coefficient_threshold"]

        # Greek letter 'Xi' in the paper. Learned during training (different from linear regression).
        sindy_coefficients = self.init_sindy_coefficients(
            self.coefficient_initialization
        )
        self.sindy_coefficients = torch.nn.Parameter(
            sindy_coefficients.to(torch.float64).to(device)
        )

    def init_sindy_coefficients(self, name="normal", std=1.0, k=1) -> torch.Tensor:
        """

        Description: Initializes the SINDy coefficients based on the specified method. These coefficients are learned during training.

        Args:
            name (str): The method for initializing the coefficients. Options are 'xavier', 'uniform', 'constant', and 'normal'.
            std (float): Standard deviation for normal initialization.
            k (float): Constant value for constant initialization.

        Returns:
            torch.Tensor: (library_dim, latent_dim) float32 tensor.

        Raises:
            ValueError: for an unknown initialization name.

        """
        sindy_coefficients = torch.zeros((self.library_dim, self.latent_dim))

        if name == "xavier":
            return torch.nn.init.xavier_uniform_(sindy_coefficients)
        if name == "uniform":
            # torch.nn.init.uniform_ takes the bounds as `a` and `b`.
            return torch.nn.init.uniform_(sindy_coefficients, a=0.0, b=1.0)
        if name == "constant":
            return torch.ones_like(sindy_coefficients) * k
        if name == "normal":
            return torch.nn.init.normal_(sindy_coefficients, mean=0, std=std)
        raise ValueError(f"Unknown coefficient initialization {name!r}")

    def _active_coefficients(self) -> torch.Tensor:
        """Xi with thresholded entries zeroed when sequential thresholding is on."""
        if self.sequential_thresholding:
            return self.coefficient_mask * self.sindy_coefficients
        return self.sindy_coefficients

    def forward(self, *args, **kwargs):
        """Dispatch to forward_dx(x, dx) for first-order models, forward_ddx(x, dx, ddx) otherwise."""
        if self.model_order == 1:
            return self.forward_dx(*args, **kwargs)
        return self.forward_ddx(*args, **kwargs)

    def forward_dx(self, x, dx) -> torch.Tensor:
        """

        Description: Forward pass for the SINDy model with first-order derivatives.

        Args:
            x (torch.Tensor): Input tensor representing the state of the system.
            dx (torch.Tensor): Input tensor representing the first-order derivatives of the state.

        Returns:
            tuple: (x, dx, dz_predict, dz, x_decode, dx_decode, sindy_coefficients)

        """
        # Encode the stacked [x; dx] into [z; dz].
        out_encode = self.encoder(torch.cat((x, dx)))
        half = out_encode.shape[0] // 2
        z, dz = out_encode[:half], out_encode[half:]

        Theta = sindy_library_pt(z, self.latent_dim, self.poly_order, self.include_sine)
        dz_predict = torch.matmul(Theta, self._active_coefficients())

        # Decode the latent state together with the predicted derivative.
        out_decode = self.decoder(torch.cat((z, dz_predict)))
        half = out_decode.shape[0] // 2
        x_decode, dx_decode = out_decode[:half], out_decode[half:]

        return (
            x,
            dx,
            dz_predict,
            dz,
            x_decode,
            dx_decode,
            self.sindy_coefficients,
        )

    def forward_ddx(self, x: torch.Tensor, dx: torch.Tensor, ddx: torch.Tensor):
        """

        Description: Forward pass for the SINDy model with second-order derivatives.

        Args:
            x (torch.Tensor): Input tensor representing the state of the system.
            dx (torch.Tensor): Input tensor representing the first-order derivatives of the state.
            ddx (torch.Tensor): Input tensor representing the second-order derivatives of the state.

        Returns:
            tuple: (x, dx, ddz_predict, dz, x_decode, dx_decode, sindy_coefficients,
            ddz_predict, z). The prediction appears twice, for backward compatibility.

        """
        out = self.encoder(torch.cat((x, dx, ddx)))
        slicer = out.shape[0] // 3
        z, dz, ddz = out[:slicer], out[slicer : 2 * slicer], out[2 * slicer :]

        Theta = sindy_library_pt_order2(
            z, dz, self.latent_dim, self.poly_order, self.include_sine
        )
        # Same masking rule as the first-order pass (coefficient_mask * Xi).
        sindy_predict = torch.matmul(Theta, self._active_coefficients())

        out_decode = self.decoder(torch.cat((z, dz, sindy_predict)))
        slicer = out_decode.shape[0] // 3
        x_decode, dx_decode, ddx_decode = (
            out_decode[:slicer],
            out_decode[slicer : 2 * slicer],
            out_decode[2 * slicer :],
        )

        # NOTE(review): ddz and ddx_decode are computed but not returned, while
        # Loss.forward_ddx needs both; extend the return contract when the
        # second-order training path is wired up.
        return (
            x,
            dx,
            sindy_predict,
            dz,
            x_decode,
            dx_decode,
            self.sindy_coefficients,
            sindy_predict,
            z,
        )

# Training

In [44]:
# Number of stacked (trajectory, time-step) rows in the training set.
n_training_samples = training_data["x"].shape[0]
print(n_training_samples)

2500


In [42]:
def train(sindy, num_epochs, optimizer, criterion, data=None, config=None):
    """Full-batch training loop for a first-order SINDy autoencoder.

    Each epoch does one forward/backward pass over the whole training set and one
    optimizer step. When the criterion's regularization is on and
    config["sequential_thresholding"] is set, every config["threshold_frequency"]
    epochs (never at epoch 0) coefficients with |Xi| <= config["coefficient_threshold"]
    are masked out, and the model is checkpointed with dill to "model_lorenz_1_<epoch>".

    Args:
        sindy: first-order SINDy model, called as sindy(x, dx).
        num_epochs: number of full-batch epochs.
        optimizer: optimizer over sindy's parameters.
        criterion: first-order Loss instance.
        data: dict with "x" and "dx" arrays; defaults to the global training_data.
        config: hyper-parameter dict; defaults to the global params.
    """
    data = training_data if data is None else data
    config = params if config is None else config

    # The training set never changes: convert and copy it to the model's device once.
    model_device = sindy.sindy_coefficients.device
    x_train = torch.from_numpy(data["x"]).to(device=model_device)
    dx_train = torch.from_numpy(data["dx"]).to(device=model_device)

    sindy.train()
    progress = tqdm(range(num_epochs), desc="Training")
    for epoch in progress:
        optimizer.zero_grad()

        # Forward pass
        (
            x,
            dx,
            dz_predict,
            dz,
            x_decode,
            dx_decode,
            sindy_coefficients,
        ) = sindy(x_train, dx_train)
        loss = criterion(
            x,
            dx,
            dz,
            dz_predict,
            x_decode,
            dx_decode,
            sindy_coefficients,
            sindy.coefficient_mask,
        )
        # Report on the progress bar rather than printing one line per epoch.
        progress.set_postfix(loss=f"{loss.item():.4f}")

        # Backward pass
        loss.backward()
        optimizer.step()

        # Optional coefficient thresholding
        if (
            criterion.regularization
            and config["sequential_thresholding"]
            and epoch > 0
            and epoch % config["threshold_frequency"] == 0
        ):
            sindy.coefficient_mask = (
                torch.abs(sindy_coefficients) > config["coefficient_threshold"]
            )
            progress.write(
                f"THRESHOLDING: {torch.sum(sindy.coefficient_mask)} active coefficients"
            )
            # Close the checkpoint file deterministically so it is fully flushed.
            with open(f"model_lorenz_1_{epoch}", "wb") as checkpoint:
                dill.dump(sindy, checkpoint)


In [None]:
with open("params.json") as params_file:
    params = json.load(params_file)

encoder = AutoEncoder(params, "encoder")
decoder = AutoEncoder(params, "decoder")
sindy = SINDy(encoder=encoder, decoder=decoder, device=device, params=params)
sindy = sindy.to(device=device)

# The loss must use the same derivative order as the model (forward_dx vs forward_ddx).
criterion = Loss(
    params["loss_weight_decoder"],
    params["loss_weight_sindy_z"],
    params["loss_weight_sindy_x"],
    params["loss_weight_sindy_regularization"],
    order=params["model_order"],
)
optimizer = torch.optim.Adam(sindy.parameters(), lr=params["learning_rate"])

# NOTE(review): this ties the number of full-batch epochs to the number of
# training rows (2500 here); confirm that is intended rather than a dedicated
# epoch setting in params.json.
num_epochs = training_data["x"].shape[0]


### Train with regularization
criterion.set_regularization(True)
train(sindy, num_epochs, optimizer, criterion)

with open("model_lorenz_1", "wb") as model_file:
    dill.dump(sindy, model_file)