# An Overview of Symplectic Spectrum Gaussian Processes

This notebook is an exploration and implementation based on the work of Yusuke Tanaka, Tomoharu Iwata, and Naonori Ueda from NTT Communication Science Laboratories. The original code from the authors can be found at [SSGP on GitHub](https://github.com/yusuk-e/SSGP).

## Problem Statement
The primary problem the authors are addressing is learning the dynamics of Hamiltonian systems from noisy and sparse data. In mathematical terms, they are focusing on systems where the evolution over time can be described by Hamiltonian mechanics. The Hamiltonian, $ H(\boldsymbol{x}) $, represents the total energy of the system, and is a function of the state $ \boldsymbol{x} $ in phase space. This state $ \boldsymbol{x} $ combines both generalized coordinates $ \boldsymbol{x}^{\mathrm{q}} $ and momenta $ \boldsymbol{x}^{\mathrm{p}} $.

The key equation governing the dynamics of such a system is given by:

$$
\frac{d \boldsymbol{x}}{d t} = (\mathbf{S} - \mathbf{R}) \nabla H(\boldsymbol{x}) =: \boldsymbol{f}(\boldsymbol{x})
$$

Here, $ \mathbf{S} $ is a skew-symmetric matrix, and $ \mathbf{R} $ is a positive semi-definite dissipation matrix, which introduces dissipative effects like friction into the system. When $ \mathbf{R} = \mathbf{O} $ (the zero matrix), the system conserves total energy. The function $ \boldsymbol{f}(\boldsymbol{x}) $ represents the time derivatives of the state and has a symplectic geometric structure.
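
The roles of $ \mathbf{S} $ and $ \mathbf{R} $ can be checked numerically. The sketch below is illustrative only: it assumes a toy Hamiltonian $ H(q, p) = (q^2 + p^2)/2 $ (not one of the paper's systems), and the helper names `grad_H`, `vector_field`, and `energy_rate` are hypothetical. Because $ \mathbf{S} $ is skew-symmetric, $ \nabla H^{\top} \mathbf{S} \nabla H = 0 $, so the energy rate along the flow reduces to $ dH/dt = -\nabla H^{\top} \mathbf{R} \nabla H \leq 0 $.

```python
import numpy as np

# Canonical skew-symmetric matrix for one degree of freedom (D = 2).
S = np.array([[0.0, 1.0],
              [-1.0, 0.0]])


def grad_H(x):
    """Gradient of the assumed toy Hamiltonian H(q, p) = (q**2 + p**2) / 2."""
    return np.asarray(x, dtype=float).copy()


def vector_field(x, R):
    """f(x) = (S - R) grad H(x)."""
    return (S - R) @ grad_H(x)


def energy_rate(x, R):
    """dH/dt along the flow, i.e. grad H(x) . f(x)."""
    return float(grad_H(x) @ vector_field(x, R))


x = np.array([1.0, 0.5])
rate_conservative = energy_rate(x, np.zeros((2, 2)))      # R = O: energy is conserved
rate_dissipative = energy_rate(x, np.diag([0.0, 0.05]))   # friction acting on the momentum
print(rate_conservative, rate_dissipative)
```

With $ \mathbf{R} = \mathbf{O} $ the rate vanishes, while the friction term makes it strictly negative whenever the momentum is nonzero.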

The authors' goal is to model these dynamics with a probabilistic model, specifically a Gaussian process (GP) that respects the symplectic structure of Hamiltonian systems. The model is designed to learn from data that are both sparse (few samples at low temporal resolution) and noisy. They propose Symplectic Spectrum Gaussian Processes (SSGPs), which use random Fourier features to approximate the Hamiltonian dynamics efficiently and allow variational inference with ordinary differential equation (ODE) solvers.

The innovation lies in the ability to predict the dynamics of such systems from arbitrary initial conditions and to decompose the dynamics into conservative and dissipative terms, which is particularly challenging when the available data is sparse and noisy. The approach extends to scenarios where derivative observations are unavailable, relying instead on trajectory data.


## Paper Example
![](https://cdn.mathpix.com/cropped/2024_01_06_22f307116735454ec646g-10.jpg?height=598&width=1372&top_left_y=250&top_left_x=366)

Prediction results of SSGP. The color indicates time-evolution, starting at blue and ending at red. The first and second columns are the true trajectories for the dissipative systems and their conservative terms, respectively. The third column is the prediction for the dissipative systems in Task 1. The fourth and fifth columns are the predicted conservative and dissipative terms in Task 2, respectively. Here, the dissipative terms are multiplied by 30 for enhanced clarity. Comparisons with other models are shown in Appendix G of the paper.

### Dynamic Systems

The authors test the Symplectic Spectrum Gaussian Processes (SSGP) model on two physical systems: the pendulum and the Duffing oscillator. The Hamiltonians are defined as follows:

Pendulum System:
$$
H(\boldsymbol{x}) = 2 m g l(1 - \cos x^{\mathrm{q}}) + \frac{l^{2}(x^{\mathrm{p}})^{2}}{2 m}
$$
Parameters: $ g = 3 $, $ m = l = 1 $.

Duffing Oscillator System:
$$
H(\boldsymbol{x}) = \frac{1}{2}(x^{\mathrm{p}})^{2} + \frac{\alpha}{2}(x^{\mathrm{q}})^{2} + \frac{\beta}{4}(x^{\mathrm{q}})^{4}
$$
Parameters: $ \alpha = -1 $, $ \beta = 1 $ (the values hard-coded in `ODE_duffing.H` in the code below).

### Tasks

**Task 1: Normal Prediction**
Performance was assessed by comparing predicted state trajectories $\{\hat{\boldsymbol{x}}_{i j}\}$ against ground truth $\{\boldsymbol{x}_{i j}^{\text{true}}\}$ from the test set. The metric was mean squared error (MSE): $\frac{1}{I} \sum_{i=1}^{I}\left(\frac{1}{J_{i}} \sum_{j=1}^{J_{i}}\|\hat{\boldsymbol{x}}_{i j}-\boldsymbol{x}_{i j}^{\text{true}}\|^{2}\right)$. Evaluations were done for pendulum and Duffing oscillator systems, both with and without energy dissipation.
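
As a minimal sketch of this metric, the function below averages the squared state error over the time steps of each trajectory and then over trajectories. The name `trajectory_mse` and the toy arrays are assumptions made for illustration; they are not taken from the authors' code.

```python
import numpy as np


def trajectory_mse(predicted, true):
    """Average over trajectories of the mean squared state error per trajectory.

    `predicted` and `true` are sequences of arrays with shape (J_i, D);
    the trajectory lengths J_i may differ between trajectories.
    """
    per_trajectory = []
    for x_hat, x_true in zip(predicted, true):
        diff = np.asarray(x_hat, dtype=float) - np.asarray(x_true, dtype=float)
        # Squared Euclidean norm per time step, averaged over the J_i steps.
        per_trajectory.append(np.mean(np.sum(diff ** 2, axis=-1)))
    return float(np.mean(per_trajectory))


# Two toy trajectories of different length (made-up numbers).
predicted = [np.zeros((3, 2)), np.zeros((2, 2))]
true = [np.ones((3, 2)), np.zeros((2, 2))]
mse = trajectory_mse(predicted, true)
print(mse)
```

Keeping the trajectories in a list rather than a single array allows a different $J_{i}$ per trajectory, as in the formula.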

**Task 2: Predicting Dynamics for Unseen Friction Coefficients**
The authors showcased SSGP's ability to separate conservative and dissipative dynamics by training on datasets generated with friction ($ \mathbf{R} = \operatorname{diag}(0,0.05) $) and then predicting the dynamics of the corresponding conservative system ($ \mathbf{R} = \mathbf{O} $). The same MSE metric was used for evaluation.

## Method Overview

![](https://cdn.mathpix.com/cropped/2024_01_06_22f307116735454ec646g-04.jpg?height=333&width=1374&top_left_y=251&top_left_x=381)

Schematic diagram of SSGP: Generative processes of noisy trajectories. (a) We model the unknown Hamiltonian $H(\boldsymbol{x})$ by a single-output GP. Here, the color represents the magnitude of energy. (b) The vector field $\boldsymbol{f}(\boldsymbol{x})$ is calculated by applying the differential operator $\mathcal{L}$ to $H(\boldsymbol{x})$. (c) We sample the initial condition from the standard Gaussian distribution and solve the ODE defined by $\boldsymbol{f}(\boldsymbol{x})$ to obtain the noiseless trajectory $\left\{\boldsymbol{x}_{i j}\right\}$ depicted by black dots. (d) The noisy trajectory $\left\{\boldsymbol{y}_{i j}\right\}$ depicted by red dots is observed by adding Gaussian noise.

### Hamiltonian Mechanics

Consider a system with $N$ degrees of freedom. In the Hamiltonian formalism, the continuous-time evolution of the system is described in phase space, that is, the product space of generalized coordinates $\boldsymbol{x}^{\mathrm{q}}=\left(x_{1}^{\mathrm{q}}, \ldots, x_{N}^{\mathrm{q}}\right)$ and generalized momenta $\boldsymbol{x}^{\mathrm{p}}=\left(x_{1}^{\mathrm{p}}, \ldots, x_{N}^{\mathrm{p}}\right)$. Let $\boldsymbol{x}=\left(\boldsymbol{x}^{\mathrm{q}}, \boldsymbol{x}^{\mathrm{p}}\right) \in \mathbb{R}^{D}$ be a state of the system, where $D=2 N$. The system's evolution is determined by the Hamiltonian $H(\boldsymbol{x}): \mathbb{R}^{D} \rightarrow \mathbb{R}$, which denotes the system's total energy. Traditionally, the Hamiltonian is manually designed to suit the system. The dynamics of a Hamiltonian system with additive dissipative terms is given by

$$
\frac{d \boldsymbol{x}}{d t}=(\mathbf{S}-\mathbf{R}) \nabla H(\boldsymbol{x})=: \boldsymbol{f}(\boldsymbol{x}), \quad \text { where } \quad \mathbf{S}=\left(\begin{array}{cc}
\mathbf{O} & \mathbf{I} \\
-\mathbf{I} & \mathbf{O} 
\end{array}\right) \quad \quad (1)
$$

Here, $\nabla H(\boldsymbol{x}): \mathbb{R}^{D} \rightarrow \mathbb{R}^{D}$ is the gradient of the Hamiltonian with respect to the state $\boldsymbol{x}$, $\mathbf{S} \in \mathbb{R}^{D \times D}$ is the skew-symmetric matrix, $\mathbf{R} \in \mathbb{R}^{D \times D}$ is the positive semi-definite dissipation matrix, $\mathbf{I}$ is the identity matrix, and $\mathbf{O}$ is the zero matrix. In (1), the time derivative of the state is defined by the function $\boldsymbol{f}(\boldsymbol{x}): \mathbb{R}^{D} \rightarrow \mathbb{R}^{D}$, which is a special kind of vector field that has a symplectic geometric structure (called a Hamiltonian vector field or symplectic gradient). The dynamics on this vector field conserve the total energy when $\mathbf{R}=\mathbf{O}$. Given the vector field $\boldsymbol{f}(\boldsymbol{x})$ and the initial condition $\boldsymbol{x}_{1}$ at time $t_{1}$, one can predict the state $\boldsymbol{x}_{t}$ at time $t$ by integrating $\boldsymbol{f}(\boldsymbol{x})$ from $t_{1}$ to $t$: $\boldsymbol{x}_{t}=\boldsymbol{x}_{1}+\int_{t_{1}}^{t} \boldsymbol{f}(\boldsymbol{x}) d t$.
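
In practice the integral is evaluated with a numerical ODE solver. The sketch below uses a fixed-step fourth-order Runge-Kutta scheme, chosen here only for illustration (the notebook's own code relies on `torchdiffeq` with `dopri5`). It applies the scheme to the pendulum Hamiltonian from the Dynamic Systems section with $g=3$, $m=l=1$; all function names are hypothetical.

```python
import numpy as np

M_MASS, GRAVITY, LENGTH = 1.0, 3.0, 1.0
S = np.array([[0.0, 1.0], [-1.0, 0.0]])


def hamiltonian(x):
    """Pendulum energy H = 2 m g l (1 - cos q) + l^2 p^2 / (2 m)."""
    q, p = x
    return 2 * M_MASS * GRAVITY * LENGTH * (1 - np.cos(q)) + LENGTH**2 * p**2 / (2 * M_MASS)


def grad_hamiltonian(x):
    q, p = x
    return np.array([2 * M_MASS * GRAVITY * LENGTH * np.sin(q), LENGTH**2 * p / M_MASS])


def f(x, R):
    """Vector field f(x) = (S - R) grad H(x) from equation (1)."""
    return (S - R) @ grad_hamiltonian(x)


def integrate(x1, R, t1, t, dt=0.01):
    """Approximate x_t = x_1 + int_{t1}^{t} f(x) dt with fixed-step RK4."""
    x = np.array(x1, dtype=float)
    for _ in range(int(round((t - t1) / dt))):
        k1 = f(x, R)
        k2 = f(x + 0.5 * dt * k1, R)
        k3 = f(x + 0.5 * dt * k2, R)
        k4 = f(x + dt * k3, R)
        x = x + dt / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4)
    return x


x1 = np.array([1.0, 0.0])
energy_start = hamiltonian(x1)
energy_conservative = hamiltonian(integrate(x1, np.zeros((2, 2)), 0.0, 5.0))
energy_dissipative = hamiltonian(integrate(x1, np.diag([0.0, 0.05]), 0.0, 5.0))
print(energy_start, energy_conservative, energy_dissipative)
```

With $\mathbf{R}=\mathbf{O}$ the energy stays constant up to the integration error, whereas the friction term makes it decay.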


### GP Priors for Hamiltonian Systems with Additive Dissipation

In SSGP, the Hamiltonian $ H(\boldsymbol{x}) $ is modeled as a Gaussian Process (GP) with zero mean. The vector field $ \boldsymbol{f}(\boldsymbol{x}) $ is derived from this GP using the differential operator $ \mathcal{L} = (\mathbf{S} - \mathbf{R}) \nabla $:

$$
\boldsymbol{f}(\boldsymbol{x}) = \mathcal{L} H(\boldsymbol{x}), \quad H(\boldsymbol{x}) \sim \mathcal{G}\mathcal{P}\left(0, \gamma(\boldsymbol{x}, \boldsymbol{x}^{\prime})\right)  \quad \quad (2)
$$

Here, $ \gamma(\boldsymbol{x}, \boldsymbol{x}^{\prime}) $ is the covariance function. The vector field is a multi-output GP:

$$
\boldsymbol{f}(\boldsymbol{x}) \sim \mathcal{G}\mathcal{P}(\mathbf{0}, \mathbf{K}(\boldsymbol{x}, \boldsymbol{x}^{\prime}))  \quad \quad (3)
$$

The covariance function $ \mathbf{K}(\boldsymbol{x}, \boldsymbol{x}^{\prime}) $ incorporates the geometric structure of Hamiltonian systems:

$$
\mathbf{K}(\boldsymbol{x}, \boldsymbol{x}^{\prime}) = \mathcal{L} \mathcal{L}^{\top} \gamma(\boldsymbol{x}, \boldsymbol{x}^{\prime}) = (\mathbf{S} - \mathbf{R}) \nabla^{2}(\mathbf{S} - \mathbf{R})^{\top} \gamma(\boldsymbol{x}, \boldsymbol{x}^{\prime})  \quad \quad (4)
$$

A common choice for $ \gamma $ is the ARD Gaussian kernel:

$$
\gamma(\boldsymbol{x}, \boldsymbol{x}^{\prime}) = \sigma_{0}^{2} \exp\left(-\frac{1}{2}(\boldsymbol{x} - \boldsymbol{x}^{\prime})^{\top} \boldsymbol{\Lambda}^{-1}(\boldsymbol{x} - \boldsymbol{x}^{\prime})\right)  \quad \quad (5)
$$

where $ \sigma_{0}^{2} $ is the signal variance and $ \boldsymbol{\Lambda} = \operatorname{diag}(\lambda_{1}^{2}, \ldots, \lambda_{D}^{2}) $ collects the squared length scales $ \lambda_{d} $, one per input dimension.
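
For concreteness, the ARD Gaussian kernel (5) can be evaluated on batches of states as in the sketch below. The function name `ard_gaussian_kernel` and the example inputs are assumptions for illustration.

```python
import numpy as np


def ard_gaussian_kernel(X, X2, sigma0, lengthscales):
    """gamma(x, x') = sigma0^2 exp(-0.5 (x - x')^T Lambda^{-1} (x - x')),
    with Lambda = diag(lengthscales**2).

    X has shape (N, D), X2 has shape (N2, D); the result has shape (N, N2).
    """
    X = np.atleast_2d(X) / lengthscales
    X2 = np.atleast_2d(X2) / lengthscales
    # Pairwise squared distances after rescaling each dimension by its length scale.
    sq_dist = np.sum((X[:, None, :] - X2[None, :, :]) ** 2, axis=-1)
    return sigma0**2 * np.exp(-0.5 * sq_dist)


X = np.array([[0.0, 0.0], [1.0, 2.0]])
lengthscales = np.array([1.0, 2.0])   # lambda_1, lambda_2
gram = ard_gaussian_kernel(X, X, sigma0=2.0, lengthscales=lengthscales)
print(gram)
```

The diagonal equals $ \sigma_{0}^{2} $, and a larger $ \lambda_{d} $ makes the kernel less sensitive to differences along dimension $ d $.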


### Spectral Representations

In SSGP, the Hamiltonian $H(\boldsymbol{x})$ is approximated as a Gaussian Process (GP) using Random Fourier Features (RFF) that capture symplectic structures in Hamiltonian systems. This approach is particularly beneficial for estimating GP posteriors and efficiently sampling vector fields, especially when only trajectory data are available.

The Hamiltonian is represented as:

$$
H(\boldsymbol{x}) = \sum_{m=1}^{M} \boldsymbol{w}_{m} \boldsymbol{\phi}_{m}(\boldsymbol{x}), \quad \boldsymbol{w}_{m} \sim \mathcal{N}\left(\mathbf{0}, \frac{\sigma_{0}^{2}}{M} \mathbf{I}\right)  \quad \quad (6)
$$

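where $\boldsymbol{\phi}_{m}(\boldsymbol{x})$ are the basis functions and $\boldsymbol{w}_{m}$ are the weights. Consistent with the feature maps in (9), the basis functions are the trigonometric pair $\boldsymbol{\phi}_{m}(\boldsymbol{x}) = \left[\cos \left(2 \pi \boldsymbol{s}_{m}^{\top} \boldsymbol{x}\right), \sin \left(2 \pi \boldsymbol{s}_{m}^{\top} \boldsymbol{x}\right)\right]$. The spectral points $\boldsymbol{s}_{m}$ are sampled from the spectral density of the kernel (5):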

$$
p(\boldsymbol{s}) = \mathcal{N}\left(\mathbf{0},\left(4 \pi^{2} \boldsymbol{\Lambda}\right)^{-1}\right)  \quad \quad (7)
$$
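
Equations (6) and (7) can be sanity-checked by comparing the random-feature kernel $\frac{\sigma_{0}^{2}}{M} \sum_{m} \boldsymbol{\phi}_{m}(\boldsymbol{x}) \boldsymbol{\phi}_{m}(\boldsymbol{x}^{\prime})^{\top}$ with the exact kernel (5). The sketch below assumes the trigonometric basis $\boldsymbol{\phi}_{m}(\boldsymbol{x}) = [\cos(2 \pi \boldsymbol{s}_{m}^{\top} \boldsymbol{x}), \sin(2 \pi \boldsymbol{s}_{m}^{\top} \boldsymbol{x})]$, which is what the feature maps in (9) imply; the function names and test points are hypothetical.

```python
import numpy as np


def sample_spectral_points(num_basis, lengthscales, rng):
    """Draw s_m ~ N(0, (4 pi^2 Lambda)^{-1}) with Lambda = diag(lengthscales**2)."""
    std = 1.0 / (2.0 * np.pi * lengthscales)
    return rng.standard_normal((num_basis, lengthscales.size)) * std


def rff_kernel(x, x2, spectral, sigma0):
    """(sigma0^2 / M) sum_m phi_m(x) . phi_m(x') with phi_m = [cos, sin]."""
    phase = 2.0 * np.pi * spectral @ x
    phase2 = 2.0 * np.pi * spectral @ x2
    inner = np.cos(phase) * np.cos(phase2) + np.sin(phase) * np.sin(phase2)
    return sigma0**2 * np.mean(inner)


def exact_kernel(x, x2, lengthscales, sigma0):
    """ARD Gaussian kernel (5)."""
    r = (x - x2) / lengthscales
    return sigma0**2 * np.exp(-0.5 * r @ r)


rng = np.random.default_rng(0)
lengthscales = np.array([1.0, 2.0])
x, x2 = np.array([0.3, -0.5]), np.array([1.0, 0.4])
spectral = sample_spectral_points(20000, lengthscales, rng)
approx = rff_kernel(x, x2, spectral, sigma0=1.0)
exact = exact_kernel(x, x2, lengthscales, sigma0=1.0)
print(approx, exact)
```

The Monte Carlo error of the approximation shrinks at the usual $1/\sqrt{M}$ rate as more spectral points are drawn.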

Applying the differential operator $\mathcal{L}$ to this approximation yields the spectral representation of the vector field:

$$
\boldsymbol{f}(\boldsymbol{x}) = \mathcal{L} H(\boldsymbol{x}) =: \boldsymbol{\Psi}(\boldsymbol{x}) \boldsymbol{w}^{\top}  \quad \quad (8)
$$

with the feature maps $\boldsymbol{\Psi}(\boldsymbol{x})$ defined as:

$$
\boldsymbol{\Psi}_{m}(\boldsymbol{x}) = 2 \pi(\mathbf{S}-\mathbf{R}) \boldsymbol{s}_{m}\left[-\sin \left(2 \pi \boldsymbol{s}_{m}^{\top} \boldsymbol{x}\right), \cos \left(2 \pi \boldsymbol{s}_{m}^{\top} \boldsymbol{x}\right)\right]  \quad \quad (9)
$$

These are referred to as Symplectic Random Fourier Features (S-RFF). This novel integration of symplectic structures into random features bridges GP modeling and Hamiltonian mechanics. The distribution of $\boldsymbol{f}$ is obtained by integrating out $\boldsymbol{w}$:

$$
p(\boldsymbol{f}) = \mathcal{N}\left(\mathbf{0}, \tilde{\mathbf{K}}\left(\boldsymbol{x}, \boldsymbol{x}^{\prime}\right)\right), \quad \tilde{\mathbf{K}}\left(\boldsymbol{x}, \boldsymbol{x}^{\prime}\right) = \frac{\sigma_{0}^{2}}{M} \boldsymbol{\Psi}(\boldsymbol{x}) \boldsymbol{\Psi}\left(\boldsymbol{x}^{\prime}\right)^{\top}  \quad \quad (10)
$$

The covariance function $\tilde{\mathbf{K}}$ for the GP approximation improves in quality with more spectral points, as shown by comparing the Gram matrices of $\mathbf{K}$ and $\tilde{\mathbf{K}}$.
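
To make (8) and (9) concrete, the sketch below builds the S-RFF matrix $\boldsymbol{\Psi}(\boldsymbol{x})$ and checks that $\boldsymbol{\Psi}(\boldsymbol{x}) \boldsymbol{w}^{\top}$ agrees with $(\mathbf{S}-\mathbf{R}) \nabla H(\boldsymbol{x})$ computed by finite differences on the random-feature Hamiltonian (6). It assumes the basis $\boldsymbol{\phi}_{m} = [\cos, \sin]$ implied by (9), unit length scales, and $\mathbf{R} = \operatorname{diag}(0, 0.05)$; all names are hypothetical and this is not the authors' implementation.

```python
import numpy as np


def rff_hamiltonian(x, spectral, w):
    """H(x) = sum_m w_m . phi_m(x), phi_m(x) = [cos(2 pi s_m^T x), sin(2 pi s_m^T x)]."""
    phase = 2.0 * np.pi * spectral @ x                        # (M,)
    return float(np.sum(w[:, 0] * np.cos(phase) + w[:, 1] * np.sin(phase)))


def srff_features(x, spectral, S, R):
    """Stack Psi_m(x) = 2 pi (S - R) s_m [-sin, cos] into a (D, 2M) matrix."""
    phase = 2.0 * np.pi * spectral @ x                        # (M,)
    direction = 2.0 * np.pi * spectral @ (S - R).T            # (M, D), row m = 2 pi (S - R) s_m
    trig = np.stack([-np.sin(phase), np.cos(phase)], axis=1)  # (M, 2)
    psi = np.einsum("md,mk->dmk", direction, trig)            # (D, M, 2)
    return psi.reshape(S.shape[0], -1)


def finite_difference_field(x, spectral, w, S, R, eps=1e-6):
    """(S - R) grad H(x) with the gradient taken by central differences."""
    grad = np.zeros_like(x)
    for d in range(x.size):
        step = np.zeros_like(x)
        step[d] = eps
        grad[d] = (rff_hamiltonian(x + step, spectral, w)
                   - rff_hamiltonian(x - step, spectral, w)) / (2.0 * eps)
    return (S - R) @ grad


rng = np.random.default_rng(0)
num_basis, sigma0 = 50, 1.0
S = np.array([[0.0, 1.0], [-1.0, 0.0]])
R = np.diag([0.0, 0.05])
spectral = rng.standard_normal((num_basis, 2)) / (2.0 * np.pi)   # unit length scales
w = rng.standard_normal((num_basis, 2)) * sigma0 / np.sqrt(num_basis)
x = np.array([0.4, -1.2])

f_srff = srff_features(x, spectral, S, R) @ w.reshape(-1)
f_check = finite_difference_field(x, spectral, w, S, R)
print(f_srff, f_check)
```

Because the field is linear in $\boldsymbol{w}$, a Gaussian distribution over the weights directly induces a Gaussian distribution over $\boldsymbol{f}$, which is what (10) expresses.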


### Generative processes of noisy observations
Suppose that we have a collection of $I$ trajectories $\left\{\left(t_{i j}, \boldsymbol{y}_{i j}\right) \mid i=1, \ldots, I ; j=1, \ldots, J_{i}\right\}$, where $J_{i}$ is the number of samples in the $i$-th trajectory. Each sample is specified by a pair $\left(t_{i j}, \boldsymbol{y}_{i j}\right)$, which represents the observation of noisy state $\boldsymbol{y}_{i j}$ at time $t_{i j}$. We treat the noiseless state $\boldsymbol{x}_{i j}$, the counterpart of $\boldsymbol{y}_{i j}$, as a latent variable. We assume that the observation model of $\boldsymbol{y}_{i j}$ is a Gaussian distribution with a variance of $\sigma^{2}$. Letting $\mathbf{Y}=\left\{\boldsymbol{y}_{i j}\right\}$, the marginal likelihood (i.e., evidence) is given by

$$
p(\mathbf{Y})=\int p(\boldsymbol{f}) \prod_{i=1}^{I}\left[\int p\left(\boldsymbol{y}_{i 1} \mid \boldsymbol{x}_{i 1}\right) p\left(\boldsymbol{x}_{i 1}\right) \prod_{j=2}^{J_{i}} p\left(\boldsymbol{y}_{i j} \mid \boldsymbol{x}_{i j}\right) p\left(\boldsymbol{x}_{i j} \mid \boldsymbol{f}, \boldsymbol{x}_{i 1}\right) d \boldsymbol{x}_{i 1}\right] d \boldsymbol{f}  \quad \quad (11)
$$

where $p(\boldsymbol{f})$ is the GP prior (10) of the vector field, and $p\left(\boldsymbol{x}_{i 1}\right)$ is the prior distribution of the initial condition $\boldsymbol{x}_{i 1}$. Given $\boldsymbol{f}$ and $\boldsymbol{x}_{i 1}$, the state $\boldsymbol{x}_{i j}$ is deterministically given by solving the ODE; thus, we can write the conditional distribution $p\left(\boldsymbol{x}_{i j} \mid \boldsymbol{f}, \boldsymbol{x}_{i 1}\right)$ in (11) using Dirac's delta function, as follows:

$$
p\left(\boldsymbol{x}_{i j} \mid \boldsymbol{f}, \boldsymbol{x}_{i 1}\right)=\delta\left(\boldsymbol{x}_{i j}-\left[\boldsymbol{x}_{i 1}+\int_{t_{i 1}}^{t_{i j}} \boldsymbol{f}(\boldsymbol{x}) d t\right]\right)  \quad \quad (12)
$$

Note that, although we omit the observation time points $\left\{t_{i j}\right\}$ in (11), it is actually conditioned on $\left\{t_{i j}\right\}$. For simplicity, we adopt this notation hereinafter.
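
The generative process for a single trajectory can be sketched as follows: draw the initial condition from a standard Gaussian, solve the ODE up to each observation time, then add Gaussian noise. The example assumes a fixed stand-in vector field (a damped harmonic oscillator) instead of a sample from the GP prior, and a simple fixed-step RK4 solver; both choices and all names are illustrative assumptions.

```python
import numpy as np


def damped_oscillator(x):
    """Stand-in field (S - R) grad H with H = (q^2 + p^2) / 2 and R = diag(0, 0.05)."""
    q, p = x
    return np.array([p, -q - 0.05 * p])


def rk4_step(f, x, dt):
    k1 = f(x)
    k2 = f(x + 0.5 * dt * k1)
    k3 = f(x + 0.5 * dt * k2)
    k4 = f(x + dt * k3)
    return x + dt / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4)


def generate_noisy_trajectory(f, times, sigma, rng, dim=2, substeps=20):
    """Sample x_1 ~ N(0, I), integrate f to every observation time,
    and add N(0, sigma^2 I) observation noise."""
    x = rng.standard_normal(dim)                  # initial condition x_{i1}
    states = [x]
    for t_prev, t_next in zip(times[:-1], times[1:]):
        dt = (t_next - t_prev) / substeps
        for _ in range(substeps):
            x = rk4_step(f, x, dt)
        states.append(x)
    states = np.stack(states)                     # noiseless states x_{ij}, shape (J, D)
    observations = states + sigma * rng.standard_normal(states.shape)  # noisy y_{ij}
    return states, observations


rng = np.random.default_rng(0)
times = np.linspace(0.0, 5.0, 16)
states, observations = generate_noisy_trajectory(damped_oscillator, times, sigma=0.1, rng=rng)
print(states.shape, observations.shape)
```

Given the field and the initial condition, the states are deterministic, which is why (12) is a delta function; the only randomness left in the observations is the additive noise.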

### Inference

Exact calculation of the marginal likelihood (11) is intractable because of the ODE solving process, so the authors resort to variational inference, as described next.

### Parameter Learning
The parameters are estimated by maximizing the evidence lower bound (ELBO):

$$
\log p(\mathbf{Y}) \geq \sum_{i=1}^{I}\left[\sum_{j=1}^{J_{i}} \mathbb{E}_{q\left(\boldsymbol{x}_{i j}\right)}\left[\log p\left(\boldsymbol{y}_{i j} \mid \boldsymbol{x}_{i j}\right)\right]-\mathrm{KL}\left[q\left(\boldsymbol{x}_{i 1}\right) \| p\left(\boldsymbol{x}_{i 1}\right)\right]\right]-\mathrm{KL}[q(\boldsymbol{w}) \| p(\boldsymbol{w})] \quad (13)
$$

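The variational distributions of the weights and of the initial condition are assumed to be Gaussian: $q(\boldsymbol{w}) = \mathcal{N}(\boldsymbol{b}, \mathbf{C})$ and $q\left(\boldsymbol{x}_{i 1}\right) = \mathcal{N}\left(\boldsymbol{y}_{i 1}, \mathbf{A}\right)$, the latter centered at the first observed state (these are the distributions sampled in (16) and (18)). The variational distribution of a later state $\boldsymbol{x}_{i j}$ is obtained by propagating both through the ODE: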

$$
q\left(\boldsymbol{x}_{i j}\right) = \iint p\left(\boldsymbol{x}_{i j} \mid \boldsymbol{f}, \boldsymbol{x}_{i 1}\right)\left[\int p(\boldsymbol{f} \mid \boldsymbol{w}) q(\boldsymbol{w}) d \boldsymbol{w}\right] q\left(\boldsymbol{x}_{i 1}\right) d \boldsymbol{x}_{i 1} d \boldsymbol{f} \quad (14)
$$

The expectation in ELBO is approximated using Monte Carlo integration:

$$
\mathbb{E}_{q\left(\boldsymbol{x}_{i j}\right)}\left[\log p\left(\boldsymbol{y}_{i j} \mid \boldsymbol{x}_{i j}\right)\right] \approx \frac{1}{K} \sum_{k=1}^{K} \log p\left(\boldsymbol{y}_{i j} \mid \boldsymbol{x}_{i j}^{(k)}\right) \quad (15)
$$
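
The Monte Carlo estimate (15) is a plain sample average of Gaussian log-likelihoods. In the sketch below the samples $\boldsymbol{x}_{i j}^{(k)}$ are drawn from a Gaussian stand-in so that the estimate can be compared with a closed form; in SSGP they would instead come from ODE solves. The names and numbers are assumptions for illustration.

```python
import numpy as np


def gaussian_log_likelihood(y, x, sigma):
    """log N(y | x, sigma^2 I) for x of shape (..., D)."""
    dim = y.shape[-1]
    sq_err = np.sum((y - x) ** 2, axis=-1)
    return -0.5 * dim * np.log(2.0 * np.pi * sigma**2) - 0.5 * sq_err / sigma**2


def monte_carlo_expectation(y, state_samples, sigma):
    """(1 / K) sum_k log p(y | x^(k)) for samples x^(k) of shape (K, D)."""
    return float(np.mean(gaussian_log_likelihood(y, state_samples, sigma)))


rng = np.random.default_rng(0)
y = np.array([0.2, -0.1])
mean, spread, sigma = np.array([0.0, 0.0]), 0.1, 0.5

# Stand-in for the ODE-propagated samples: x^(k) ~ N(mean, spread^2 I).
samples = mean + spread * rng.standard_normal((100_000, 2))
estimate = monte_carlo_expectation(y, samples, sigma)

# Closed form of the same expectation for the Gaussian stand-in.
closed_form = (-0.5 * 2 * np.log(2.0 * np.pi * sigma**2)
               - 0.5 * (np.sum((y - mean) ** 2) + 2 * spread**2) / sigma**2)
print(estimate, closed_form)
```

In training, $K$ is typically much smaller than in this check, because each sample requires an ODE solve.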

Monte Carlo samples are generated as follows:

$$
\begin{aligned}
\boldsymbol{x}_{i 1}^{(k)} & = \boldsymbol{y}_{i 1}+\sqrt{\mathbf{A}} \boldsymbol{\epsilon}_{i}^{(k)}, \quad \boldsymbol{\epsilon}_{i}^{(k)} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) \quad (16) \\
\boldsymbol{x}_{i 2}^{(k)}, \ldots, \boldsymbol{x}_{i J_{i}}^{(k)} & = \text{ODESolve}\left(\boldsymbol{x}_{i 1}^{(k)}, \boldsymbol{f}^{(k)}(\boldsymbol{x}), t_{i 2}, \ldots, t_{i J_{i}}\right) \quad (17)
\end{aligned}
$$

The sample of the vector field from the variational posterior is:

$$
\boldsymbol{f}^{(k)}(\boldsymbol{x})=\boldsymbol{\Psi}(\boldsymbol{x})\left[\frac{1}{L} \sum_{l=1}^{L} \boldsymbol{w}^{(k, l)}\right]^{\top}, \quad \boldsymbol{w}^{(k, l)}=\boldsymbol{b}+\sqrt{\mathbf{C}} \boldsymbol{\epsilon}^{(k, l)} \quad (18)
$$
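
The reparameterized draw $\boldsymbol{w} = \boldsymbol{b} + \sqrt{\mathbf{C}} \boldsymbol{\epsilon}$ in (18) can be sketched as below. Here $\sqrt{\mathbf{C}}$ is taken to be the Cholesky factor of $\mathbf{C}$, which is an assumption of this example (any matrix square root gives the same distribution); `sample_weights` and the small $\boldsymbol{b}$, $\mathbf{C}$ are hypothetical.

```python
import numpy as np


def sample_weights(b, C, num_samples, rng):
    """Reparameterized draws w = b + sqrt(C) eps with eps ~ N(0, I).

    sqrt(C) is the Cholesky factor here, so C = chol @ chol.T.
    Returns an array of shape (num_samples, len(b)).
    """
    chol = np.linalg.cholesky(C)
    eps = rng.standard_normal((num_samples, b.size))
    return b + eps @ chol.T        # each row is distributed as N(b, C)


rng = np.random.default_rng(0)
b = np.array([1.0, -2.0, 0.5])
C = np.array([[1.0, 0.2, 0.0],
              [0.2, 0.5, 0.1],
              [0.0, 0.1, 0.3]])

draws = sample_weights(b, C, num_samples=200_000, rng=rng)
w_bar = draws[:8].mean(axis=0)     # average of L = 8 draws, as in the bracket of (18)
print(draws.mean(axis=0), w_bar)
```

Writing the sample as a deterministic function of $\boldsymbol{b}$, $\mathbf{C}$, and the noise $\boldsymbol{\epsilon}$ is what lets gradients of the ELBO flow back to the variational parameters.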

Because $\boldsymbol{w}$ is sampled during training, parameter optimization accounts for the uncertainty in the vector field. The model extends to high-dimensional data by combining it with an autoencoder, in which case $\boldsymbol{x}$ and $\boldsymbol{x}^{\prime}$ in $\mathbf{K}(\boldsymbol{x}, \boldsymbol{x}^{\prime})$ are latent vectors.

### Prediction
The variational posterior $q(\boldsymbol{f})$ is Gaussian with mean $\tilde{\boldsymbol{m}}^{*}(\boldsymbol{x})=\boldsymbol{\Psi}(\boldsymbol{x}) \boldsymbol{b}^{\top}$ and covariance $\tilde{\mathbf{K}}^{*}\left(\boldsymbol{x}, \boldsymbol{x}^{\prime}\right)=\boldsymbol{\Psi}(\boldsymbol{x}) \mathbf{C} \boldsymbol{\Psi}^{\top}\left(\boldsymbol{x}^{\prime}\right)$. Predictions and their uncertainties are made by integrating $\tilde{\boldsymbol{m}}^{*}(\boldsymbol{x})$ and using $\tilde{\mathbf{K}}^{*}\left(\boldsymbol{x}, \boldsymbol{x}^{\prime}\right)$.

### On Computational Complexity
The primary computational cost lies in sampling $\boldsymbol{w}$, specifically in computing $\sqrt{\mathbf{C}}$, which is $\mathcal{O}\left(M^{3}\right)$. The cost grows with the number of basis functions, but, importantly, it is incurred outside the ODE solver.

In [5]:
! pip install autograd
! pip install torchdiffeq



In [6]:
# Symplectic Spectrum Gaussian Processes | 2022
# Yusuke Tanaka

import os
import csv
import pickle
import pandas as pd
import numpy as np

def check_dir(dir):
    if not os.path.exists(dir):
        os.makedirs(dir)

def csv_read(file):
    # Read all rows of a CSV file into a list of lists of strings.
    with open(file, newline='') as f:
        return [row for row in csv.reader(f)]

def csv_write(file, D):
    # Write a 1-D sequence as one row, or a 2-D sequence as one row per entry.
    with open(file, 'w', newline='') as f:
        csvWriter = csv.writer(f, lineterminator='\n')
        if np.ndim(D) == 1:
            csvWriter.writerow(D)
        elif np.ndim(D) == 2:
            for line in D:
                csvWriter.writerow(line)

def pkl_read(file):
    # Load a pickled object; the file is closed even if unpickling fails.
    with open(file, 'rb') as f:
        return pickle.load(f)

def pkl_write(file, D):
    with open(file, 'wb') as f:
        pickle.dump(D, f, protocol=4)


In [7]:
# Symplectic Spectrum Gaussian Processes | 2022
# Yusuke Tanaka

import pdb
import json
import argparse
import math
import autograd.numpy as np
import autograd
import torch
from torch import nn
from torchdiffeq import odeint
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import os, sys

torch.set_default_dtype(torch.float64)
# Use the current working directory as the base output directory
# (same result as the IPython `%pwd` magic, but also valid in plain Python).
THIS_DIR = os.getcwd()

xmin = -3.2; xmax = 3.2; ymin = -3.2; ymax = 3.2
DPI = 200
FORMAT = 'pdf'
LINE_SEGMENTS = 10
ARROW_SCALE = 100
ARROW_WIDTH = 6e-3
LINE_WIDTH = 2

class Args:
    def __init__(self, dictionary):
        for key, value in dictionary.items():
            setattr(self, key, value)

args = Args({
    'batch_time': 1,        
    'learn_rate': 1e-3,      
    'total_steps': 100,     
    'print_every': 100,      
    'sigma': 0.1,            
    'eta': 0.0,                
    'samples': 10,           
    'timescale': 3,          
    'name': 'pendulum',      
    's': 0,                  
    'gridsize': 15,          
    'seed': 0,              
    'num_basis': 100,      
    'friction': False,      
    'train_samples': 5, 
    'val_samples': 5,
    'datasets': 5,
    'T': 5,
    'radius_a': 1.,
    'radius_b': 1.,
    'save_dir': THIS_DIR,
    'input_dim': 2
})


class ODE_pendulum(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.M = self.permutation_tensor(input_dim)

    def forward(self, t, x):
        # torchdiffeq calls forward(t, x); the field is autonomous, so t is unused.
        return self.time_derivative(x)

    def time_derivative(self, x):
        # Conservative part: S * grad H (x must be tracked by autograd).
        H = self.H(x)
        dH = torch.autograd.grad(H.sum(), x)[0]
        field = dH @ self.M.t()
        # Dissipative part: R = diag(0, eta), i.e. friction acts on the momentum only.
        dH[:,0] = 0
        field = field - args.eta * dH
        return field
        
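    def H(self, coords):
        # Pendulum energy with the coefficients of the original code. Note that the
        # formula in the text with g = 3, m = l = 1 evaluates to 6*(1 - cos q) + p**2/2.
        q, p = coords[:,0], coords[:,1]
        H = 3*(1-torch.cos(q)) + p**2
        return H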

    def permutation_tensor(self,n):
        M = torch.eye(n)
        M = torch.cat([M[n//2:], -M[:n//2]])
        return M

class ODE_duffing(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.M = self.permutation_tensor(input_dim)

    def forward(self, t, x):
        # torchdiffeq calls forward(t, x); the field is autonomous, so t is unused.
        return self.time_derivative(x)

    def time_derivative(self, x):
        # Conservative part: S * grad H (x must be tracked by autograd).
        H = self.H(x)
        dH = torch.autograd.grad(H.sum(), x)[0]
        field = dH @ self.M.t()
        # Dissipative part: R = diag(0, eta), i.e. friction acts on the momentum only.
        dH[:,0] = 0
        field = field - args.eta * dH
        return field
        
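    def H(self, coords):
        # Accept a single state of shape (2,) or a batch of shape (N, 2).
        # Checking the number of dimensions (not len) avoids misreading a batch of two states.
        if coords.dim() == 1:
            q, p = coords[0], coords[1]
        else:
            q, p = coords[:,0], coords[:,1]
        # Duffing Hamiltonian with alpha = -1, beta = 1.
        H = .5*p**2 - .5*q**2 + .25*q**4
        return H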

    def permutation_tensor(self,n):
        M = torch.eye(n)
        M = torch.cat([M[n//2:], -M[:n//2]])
        return M

def vis_obs(save_dir, field, data, trajectory_name):
    y = data[trajectory_name]
    t_eval = data['t']
    fig = plt.figure(figsize=(11.3, 21), facecolor='white', dpi=DPI)
    N = y.shape[0]
    N = 28 if N > 28 else N
    for i in range(N):
        t = torch.tensor(np.linspace(0, t_eval[-1], t_eval.shape[0]))
        ax = fig.add_subplot(math.ceil(N/4), 4, i+1, frameon=True)
        ax.set_aspect('equal', adjustable='box')
        ax.quiver(field['x'][:,0], field['x'][:,1], field['dx'][:,0], field['dx'][:,1],
                   scale=ARROW_SCALE, width=ARROW_WIDTH,
                   cmap='gray_r', color=(.5,.5,.5))
        ax.scatter(y[i][:,0], y[i][:,1], c=t, s=16, cmap='coolwarm')
        plt.axis([xmin, xmax, ymin, ymax])
        plt.xlabel("$q$", fontsize=12)
        plt.ylabel("$p$", rotation=0, fontsize=12)
        plt.title("Sample " + str(i+1))
    plt.tight_layout()
    fig.savefig('{}/{}.{}'.format(save_dir, trajectory_name, FORMAT))
    plt.close()

def vis_energy(save_dir, data, e_name):
    es = data[e_name]
    t_eval = data['t']
    
    fig = plt.figure(figsize=(11.3, 21), facecolor='white', dpi=DPI)
    N = es.shape[0]
    N = 28 if N > 28 else N
    for i in range(N):
        t = t_eval
        ax = fig.add_subplot(math.ceil(N/4), 4, i+1, frameon=True)
        ax.plot(t, es[i],'-', color='black')
        ymax = data['es'].max()
        ax.axis([0, args.T, 0, ymax*1.5])
        plt.xlabel("time", fontsize=12)
        plt.ylabel("Energy", rotation=90, fontsize=12)
        plt.title("Sample " + str(i+1))
    plt.tight_layout()
    fig.savefig('{}/{}.{}'.format(save_dir, e_name, FORMAT))
    plt.close()

def get_field(ode, xmin, xmax, ymin, ymax, gridsize):
    field = {}

    # meshgrid to get vector field
    b, a = np.meshgrid(np.linspace(xmin, xmax, gridsize), np.linspace(ymin, ymax, gridsize))
    x = np.stack([b.flatten(), a.flatten()]).T
    x = torch.tensor(x, requires_grad=True)

    # get vector directions
    dx = ode.time_derivative(x)
    field['x'] = x.detach().numpy()
    field['dx'] = dx.detach().numpy()
    field['mesh_a'] = a
    field['mesh_b'] = b
    return field

def vis_field(save_dir, field, name):
    a = field['mesh_a']
    b = field['mesh_b']
    fig = plt.figure(figsize=(4,3), facecolor='white', dpi=DPI)
    ax = fig.subplots()
    ax.set_aspect('equal', adjustable='box')
    scale = ARROW_SCALE
    ax.quiver(field['x'][:,0], field['x'][:,1], field['dx'][:,0], field['dx'][:,1],
              scale=scale, width=ARROW_WIDTH)
    plt.tight_layout()
    fig.savefig('{}/{}.{}'.format(save_dir, name, FORMAT))
    plt.close()

def get_init():

    N = samples
    if args.name in ['pendulum']:
        np.random.seed(args.seed)
        x0s = np.random.rand(N,2)*2.-1
        x0s = x0s.T
        np.random.seed(args.seed)
        radius = np.random.rand(N,1)*args.radius_a+args.radius_b
        x0s = (x0s / np.sqrt((x0s**2).sum(0))).T * radius
        x0s = torch.tensor(x0s, requires_grad=True)

    elif args.name in ['duffing']:
        x0s = []
        np.random.seed(args.seed)
        while True:
            x0 = np.random.rand(2)*6.-3
            en = ode.H(torch.tensor(x0))
            if (args.radius_b <= en) and (en <= args.radius_a+args.radius_b):
                x0s.append(x0)
                if len(x0s) == N:
                    break
        x0s = torch.tensor(np.stack(x0s), requires_grad=True)

    return x0s

def path_arrange(path):
    # odeint returns (time, sample, dim); reorder to (sample, time, dim).
    return path.permute(1, 0, 2).contiguous()

def generate_data(ode):

    data = {'meta': locals()}
    dt = 1/args.timescale
    
    xs, ys, dys, es = [], [], [], []
    x0s = get_init()
    t = torch.tensor(np.linspace(0, args.T, int(args.timescale*args.T+1)))
    xs = odeint(ode, x0s, t, method='dopri5', atol=1e-8, rtol=1e-8)
    xs = path_arrange(xs)
    for x in xs:
        e = ode.H(x)
        es.append(e)
    es = torch.stack(es)
    np.random.seed(args.seed)
    noise = np.random.normal(0,args.sigma,[xs.shape[0],xs.shape[1],xs.shape[2]])
    ys = xs + torch.tensor(noise)
    
    for y in ys:
        dy = torch.diff(y, dim=0) / dt
        dys.append(dy)
    dys = torch.stack(dys)

    data['xs'] = xs.detach().numpy()
    data['ys'] = ys.detach().numpy()
    data['dys'] = dys.detach().numpy()
    data['t'] = t.detach().numpy()
    data['es'] = es.detach().numpy()
    field = get_field(ode, xmin, xmax, ymin, ymax, args.gridsize)

    vis_obs(save_dir, field, data, trajectory_name='xs')
    vis_obs(save_dir, field, data, trajectory_name='ys')
    vis_field(save_dir, field, 'field')
    vis_energy(save_dir, data, e_name='es')
    return data

def split(s):
    train_split_id = args.train_samples
    ids = [i for i in range(samples)]
    if samples == 10:
        np.random.seed(s)
        np.random.shuffle(ids)
        train_ids = ids[:train_split_id]; val_ids = ids[train_split_id:]
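    else:
        # Larger datasets extend the training ids of the next smaller dataset.
        prev_samples_by_size = {15: 10, 20: 15, 30: 20, 50: 30}
        if samples not in prev_samples_by_size:
            raise ValueError('unsupported number of samples: {}'.format(samples))
        prev_samples = prev_samples_by_size[samples]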
            
        prev_dir = ( args.save_dir + '/' + args.name  + '/' + str(args.eta) + '/train/'
                    + str(args.sigma) + '/' + str(prev_samples) + '/' + str(args.timescale))
        filename = prev_dir + '/' + str(s) + '/train_ids.csv'
        train_ids = csv_read(filename)
        train_ids = list(np.array(train_ids[0]).astype(int))

        a_ids = tuple(set(ids) - set(train_ids))
        a_ids = [a_ids[i] for i in range(len(a_ids))]
        a_samples = train_split_id - len(train_ids)

        np.random.seed(s)
        np.random.shuffle(a_ids)
        train_ids.extend(a_ids[:a_samples]); val_ids = a_ids[a_samples:]

    split_data = {}
    for k in ['xs', 'ys', 'dys', 'es']:
        split_data['val_' + k], split_data[k] = data[k][val_ids], data[k][train_ids]
    split_data['val_t'], split_data['t'] = data['t'], data['t']

    return split_data, train_ids


if __name__ == "__main__":
    samples = args.train_samples + args.val_samples
    if args.name == 'pendulum':
        ode = ODE_pendulum(args.input_dim)
    elif args.name == 'duffing':
        ode = ODE_duffing(args.input_dim)

    save_dir = ( args.save_dir + '/' + args.name  + '/' + str(args.eta) + '/train/' + str(args.sigma)
                 + '/' + str(samples) + '/' + str(args.timescale))
    os.makedirs(save_dir, exist_ok=True)
    data = generate_data(ode)
    pkl_write(save_dir + '/data.pkl', data)
    field = get_field(ode, xmin, xmax, ymin, ymax, args.gridsize)
    pkl_write(save_dir + '/field.pkl', field)

    for s in range(args.datasets):
        split_data, train_ids = split(s)
        save_dir_s = save_dir + '/' + str(s)
        os.makedirs(save_dir_s, exist_ok=True)
        pkl_write(save_dir_s + '/dataset.pkl', split_data)
        csv_write(save_dir_s + '/train_ids.csv', train_ids)

        vis_obs(save_dir_s, field, split_data, trajectory_name='ys')
        vis_obs(save_dir_s, field, split_data, trajectory_name='val_ys')
        

    filename = '{}/{}.json'.format(save_dir, args.name)
    with open(filename, 'w') as f:
        json.dump(vars(args), f)


In [8]:
# Symplectic Spectrum Gaussian Processes | 2022
# Yusuke Tanaka

import matplotlib.pyplot as plt
import seaborn as sns

dpi=100
sns.set()
sns.set_style("whitegrid", {'grid.linestyle': '--'})
sns.set_context("paper", 1.5, {"lines.linewidth": 1.5})
sns.set_palette("deep")

def plot(file, x, ys, xlabel, ylabel, legend):
    colors = ['blue','orange','green']
    fig = plt.figure(figsize=(8, 4), facecolor='white', dpi=dpi)
    for i in range(2):
        ax = fig.add_subplot(1, 2, i+1, frameon=True)
        ax.plot(x, ys[i], c=colors[i])
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        plt.title(legend[i])
        ax.yaxis.offsetText.set_fontsize(16)
        plt.gca().ticklabel_format(style="sci", scilimits=(0,0), axis="y")
    plt.tight_layout()
    plt.savefig(file, format='pdf')
    plt.close()


In [9]:
# Symplectic Spectrum Gaussian Processes | 2022
# Yusuke Tanaka

import math
import numpy as np
import os, torch, pickle, zipfile
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import seaborn as sns
import pdb
torch.set_default_dtype(torch.float64)

DPI = 200
FORMAT = 'pdf'
LINE_SEGMENTS = 10
ARROW_SCALE = 100
ARROW_WIDTH = 6e-3
LINE_WIDTH = 2
xmin = -3.2; xmax = 3.2; ymin = -3.2; ymax = 3.2


def get_batch(args, x, t_eval, batch_step):
  n_samples, n_points, input_dim = x.shape
  N = n_samples
  n_ids = torch.from_numpy(np.arange(N))
  p_ids = torch.from_numpy(np.random.choice(np.arange(n_points-batch_step, dtype=np.int64), N, replace=True))
  batch_x0 = x[n_ids,p_ids].reshape([N,1,input_dim])
  batch_step += 1
  batch_t = t_eval[:batch_step]
  batch_x = ( torch.stack([x[n_ids, p_ids+i] for i in range(batch_step)], dim=0)
              .reshape([batch_step,N,1,input_dim]) )
  return batch_x0, batch_t, batch_x

def arrange(args, x, t_eval):
  n_samples, n_points, input_dim = x.shape
  n_ids = np.arange(n_samples, dtype=np.int64)
  p_ids = np.array([0]*n_samples)
  batch_x0 = x[n_ids,p_ids].reshape([n_samples,1,input_dim])
  batch_t = t_eval
  batch_x = torch.stack([x[n_ids, p_ids+i] for i in range(n_points)],dim=0).reshape([n_points,n_samples,1,input_dim])
  return batch_x0, batch_t, batch_x

def get_field(func, xmin, xmax, ymin, ymax, gridsize):
  field = {'meta': locals()}

  # meshgrid to get vector field
  b, a = np.meshgrid(np.linspace(xmin, xmax, gridsize), np.linspace(ymin, ymax, gridsize))
  ys = np.stack([b.flatten(), a.flatten()])
  ys = torch.tensor( ys, dtype=torch.float64, requires_grad=True).t()

  # get vector directions
  dydt = func(torch.tensor([0]),ys)
  field['x'] = ys.cpu().detach().numpy()
  field['dx'] = dydt.squeeze().cpu().detach().numpy()
  field['mesh_a'] = a
  field['mesh_b'] = b
  return field

def vis_path(filename, field, y, t, xmin, xmax, ymin, ymax):
  fig = plt.figure(figsize=(21, 11.3), facecolor='white', dpi=DPI)
  N = y.shape[0]
  N = 28 if N > 28 else N
  for i in range(N):
    ax = fig.add_subplot(math.ceil(N/7), 7, i+1, frameon=True)
    ax.set_aspect('equal', adjustable='box')
    ax.quiver(field['x'][:,0], field['x'][:,1], field['dx'][:,0], field['dx'][:,1],
              scale=ARROW_SCALE, width=ARROW_WIDTH,
              cmap='gray_r', color=(.5,.5,.5))
    ax.scatter(y[i][:,0], y[i][:,1], c=t, s=0.5, cmap='coolwarm')
    plt.axis([xmin, xmax, ymin, ymax])
    plt.xlabel("$x_q$", fontsize=12)
    plt.ylabel("$x_p$", rotation=0, fontsize=12)
    plt.title("Sample " + str(i+1))
    plt.grid(False)
  plt.tight_layout()
  fig.savefig(filename)
  plt.close()

def vis_path_2d(filename, y, t_eval):
  fig = plt.figure(figsize=(21, 11.3), facecolor='white', dpi=DPI)
  N = y.shape[0]
  N = 28 if N > 28 else N
  for i in range(N):
    xs1 = y[i,:,0]; xs2 = y[i,:,1]
    ax = fig.add_subplot(math.ceil(N/7), 7, i+1, frameon=True)
    ax.scatter(t_eval, xs1, s=0.5)
    ax.scatter(t_eval, xs2, s=0.5)
    plt.axis([t_eval.min().item(), t_eval.max().item(), -3, 3])
    plt.xlabel("$t$", fontsize=12)
    plt.ylabel("$x_q$", rotation=0, fontsize=12)
    plt.title("Sample " + str(i+1))
    plt.grid(False)
  plt.tight_layout()
  fig.savefig(filename)
  plt.close()

def vis_field(filename, field, xmin, xmax, ymin, ymax):
  fig = plt.figure(figsize=(4,3), facecolor='white', dpi=DPI)
  ax = fig.subplots()
  ax.set_aspect('equal', adjustable='box')
  scale = ARROW_SCALE
  ax.quiver(field['x'][:,0], field['x'][:,1], field['dx'][:,0], field['dx'][:,1],
            scale=scale, width=ARROW_WIDTH,
            cmap='gray_r', color=(.5,.5,.5))
  plt.axis([xmin, xmax, ymin, ymax])
  plt.xlabel("$x_q$", fontsize=12)
  plt.ylabel("$x_p$", rotation=0, fontsize=12)
  plt.grid(False)
  plt.tight_layout()
  fig.savefig(filename)
  plt.close()

def vis_err(filename, es, t):
  fig = plt.figure(figsize=(21, 11.3), facecolor='white', dpi=DPI)
  N = es.shape[0]
  N = 28 if N > 28 else N
  for i in range(N):
    ax = fig.add_subplot(math.ceil(N/7), 7, i+1, frameon=True)
    ax.plot(t, es[i],'-', color='black')
    ax.axis([0, t.max(), 0, es.max()*1.2])
    plt.xlabel("time", fontsize=12)
    plt.ylabel("MSE", rotation=90, fontsize=12)
    plt.title("Sample " + str(i+1))
  plt.tight_layout()
  fig.savefig(filename)
  plt.close()

def vis_energy(filename, true, es, t):
  fig = plt.figure(figsize=(21, 11.3), facecolor='white', dpi=DPI)
  N = es.shape[0]
  N = 28 if N > 28 else N
  if es.max() > 0:
    ymax = es.max() if true.max() < es.max() else true.max()
  else:
    ymax = es.min() if true.min() > es.min() else true.min()
  for i in range(N):
    ax = fig.add_subplot(math.ceil(N/7), 7, i+1, frameon=True)
    ax.plot(t, true[i],'-', color='black')
    ax.plot(t, es[i],'-', color='red')
    ax.axis([0, t.max(), 0, ymax*1.2])
    plt.xlabel("time", fontsize=12)
    plt.ylabel("Energy", rotation=90, fontsize=12)
    plt.title("Sample " + str(i+1))
  plt.tight_layout()
  fig.savefig(filename)
  plt.close()

def path_arrange(path):
  x = []
  for i in range(path.shape[1]):
    x.append(path[:,i,:])
  return torch.stack(x)

def d_pendulum_energy(coords):
  q1, q2, p1, p2 = coords[:,:,0], coords[:,:,1], coords[:,:,2], coords[:,:,3]
  m1 = .2; m2 = .1
  H = ( (m2*p1**2 + (m1+m2)*p2**2 - 2*m2*p1*p2*np.cos(q1-q2))
        / (2*m2*(m1+m2*np.sin(q1-q2)**2))
        - (m1+m2)*9.8*np.cos(q1)
        - m2*9.8*np.cos(q2) )
  return H

def pendulum_energy(coords):
    qs = coords[:,:,0]; ps = coords[:,:,1]
    energy = 3*(1-np.cos(qs)) + ps**2
    return energy

def duffing_energy(coords):
    qs = coords[:,:,0]; ps = coords[:,:,1]
    energy= .5*ps**2 - .5*qs**2 + .25*qs**4
    return energy

def real_pend_energy(coords):
    qs = coords[:,:,0]; ps = coords[:,:,1]
    energy = 2.4*(1-np.cos(qs)) + ps**2
    return energy
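
As a quick sanity check on the energy helpers above, the sketch below integrates the pendulum whose Hamiltonian matches `pendulum_energy`, $H(q, p) = 3(1 - \cos q) + p^2$, and checks that the energy stays nearly constant along the trajectory. The leapfrog integrator, the step size, and the names `leapfrog_pendulum`, `dt`, and `n_steps` are assumptions made for illustration; they are not part of the original notebook.

```python
import numpy as np

def pendulum_energy(coords):
    # Same formula as the notebook helper; coords has shape (n_samples, n_points, 2).
    qs = coords[:, :, 0]; ps = coords[:, :, 1]
    return 3 * (1 - np.cos(qs)) + ps**2

def leapfrog_pendulum(q0, p0, dt=0.01, n_steps=1000):
    # Hypothetical helper. Hamilton's equations for H = 3(1 - cos q) + p^2 are
    # dq/dt = dH/dp = 2p and dp/dt = -dH/dq = -3 sin q.
    traj = np.empty((n_steps + 1, 2))
    q, p = q0, p0
    traj[0] = q, p
    for k in range(n_steps):
        p_half = p - 0.5 * dt * 3 * np.sin(q)   # half kick
        q = q + dt * 2 * p_half                 # full drift
        p = p_half - 0.5 * dt * 3 * np.sin(q)   # half kick
        traj[k + 1] = q, p
    return traj

traj = leapfrog_pendulum(q0=1.0, p0=0.0)
energy = pendulum_energy(traj[None])            # add the n_samples axis
drift = np.abs(energy - energy[0, 0]).max()
print("max energy drift:", drift)
```

A small drift is expected because leapfrog is a symplectic integrator; a learned model with $\mathbf{R} = \mathbf{O}$ can be compared against this kind of baseline using the same energy function.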


In [10]:
# This code is taken from https://github.com/steveli/pytorch-sqrtm.git

import torch
from torch.autograd import Function
import numpy as np
import scipy.linalg


class MatrixSquareRoot(Function):
    """Square root of a positive definite matrix.

    NOTE: matrix square root is not differentiable for matrices with
          zero eigenvalues.
    """
    @staticmethod
    def forward(ctx, input):
        # np.float_ was removed in NumPy 2.0; np.float64 is the same dtype.
        m = input.detach().cpu().numpy().astype(np.float64)
        sqrtm = torch.from_numpy(scipy.linalg.sqrtm(m).real).to(input)
        ctx.save_for_backward(sqrtm)
        return sqrtm

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = None
        if ctx.needs_input_grad[0]:
            sqrtm, = ctx.saved_tensors
            # np.float_ was removed in NumPy 2.0; np.float64 is the same dtype.
            sqrtm = sqrtm.data.cpu().numpy().astype(np.float64)
            gm = grad_output.data.cpu().numpy().astype(np.float64)

            # Given a positive semi-definite matrix X,
            # since X = X^{1/2} X^{1/2}, we can compute the gradient of the
            # matrix square root d(X^{1/2}) by solving the Sylvester equation:
            # dX = d(X^{1/2}) X^{1/2} + X^{1/2} d(X^{1/2}).
            grad_sqrtm = scipy.linalg.solve_sylvester(sqrtm, sqrtm, gm)

            grad_input = torch.from_numpy(grad_sqrtm).to(grad_output)
        return grad_input


sqrtm = MatrixSquareRoot.apply


def main():
    from torch.autograd import gradcheck
    k = torch.randn(20, 10).double()
    # Create a positive definite matrix
    pd_mat = (k.t().matmul(k)).requires_grad_()
    test = gradcheck(sqrtm, (pd_mat,))
    print(test)


if __name__ == '__main__':
    main()


True
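
The `True` above is the result of `gradcheck` on the Sylvester-equation backward pass. To see why that backward pass works, here is a NumPy-only sketch that solves the same equation $S\,G_X + G_X\,S = G_S$ (with $S = X^{1/2}$) in the eigenbasis of $S$ and compares the result with a finite-difference derivative. The helper names `sym_sqrt` and `sqrt_backward`, the test loss, and the eigendecomposition route are illustrative assumptions; the notebook itself calls `scipy.linalg.sqrtm` and `scipy.linalg.solve_sylvester`.

```python
import numpy as np

def sym_sqrt(X):
    # Square root of a symmetric positive definite matrix via eigendecomposition.
    lam, Q = np.linalg.eigh(X)
    return (Q * np.sqrt(lam)) @ Q.T

def sqrt_backward(S, grad_S):
    # Solve S @ G + G @ S = grad_S for G. In the eigenbasis of S the equation
    # decouples into (s_i + s_j) * G_ij = grad_S_ij.
    s, Q = np.linalg.eigh(S)
    G = Q.T @ grad_S @ Q
    return Q @ (G / (s[:, None] + s[None, :])) @ Q.T

rng = np.random.default_rng(0)
K = rng.normal(size=(6, 4))
X = K.T @ K + np.eye(4)              # well-conditioned positive definite matrix
W = rng.normal(size=(4, 4))          # loss L(X) = sum(W * sqrt(X)), so dL/dS = W
S = sym_sqrt(X)
grad_X = sqrt_backward(S, W)

# Central finite difference along a random symmetric direction E.
E = rng.normal(size=(4, 4)); E = (E + E.T) / 2
h = 1e-6
fd = (np.sum(W * sym_sqrt(X + h * E)) - np.sum(W * sym_sqrt(X - h * E))) / (2 * h)
analytic = np.sum(grad_X * E)
print(fd, analytic)
```

The two printed numbers should agree to several digits, which is the same property `gradcheck` verifies for the PyTorch `Function`.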


In [11]:
# Symplectic Spectrum Gaussian Processes | 2022
# Yusuke Tanaka

import sys
import pdb
import torch
import torch.nn as nn
import numpy as np
import math
torch.set_default_dtype(torch.float64)


class SSGP(torch.nn.Module):
  def __init__(self, input_dim, basis, friction):
    super(SSGP, self).__init__()
    self.sigma = nn.Parameter(torch.tensor([1e-1]))
    self.a = nn.Parameter(torch.ones(input_dim)*1e-1)
    self.b = nn.Parameter(1e-4 * (torch.rand(basis*2)-0.5))
    self.init_C(basis)
    self.sigma_0 = nn.Parameter(torch.tensor([1e-0]))
    self.lam = nn.Parameter(torch.ones(input_dim)*1.5)
    if friction:
      self.eta = nn.Parameter(torch.tensor([1e-16]))
    else:
      self.eta = torch.tensor([0.0])
    self.M = self.permutation_tensor(input_dim)
    np.random.seed(0)
    tmp = torch.tensor(np.random.normal(0, 1, size=(int(basis/2.), input_dim)))
    self.epsilon = torch.vstack([tmp,-tmp])
    self.d = input_dim
    self.num_basis = basis

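  def sampling_epsilon_f(self):
    # Reparameterized draw of the feature weights: w = b + C^{1/2} eps with
    # eps ~ N(0, I). The same covariance C is used for the sine and cosine blocks.
    C = self.make_C()
    sqrt_C = sqrtm(C)
    sqrt_C = torch.block_diag(sqrt_C,sqrt_C)
    epsilon = torch.tensor(np.random.normal(0, 1, size=(1,sqrt_C.shape[0]))).T
    self.w = self.b + (sqrt_C @ epsilon).squeeze()
    # Average num+1 = 100 independent draws to reduce the variance of w.
    num = 99
    for i in range(num):
      epsilon = torch.tensor(np.random.normal(0, 1, size=(1,sqrt_C.shape[0]))).T
      self.w += self.b + (sqrt_C @ epsilon).squeeze()
    self.w = self.w/(num+1)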

  def mean_w(self):
    self.w = self.b * 1
    
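  def forward(self, t, x):
    # Spectral points of the random Fourier features, scaled by the length-scales lam: (basis, d).
    s = self.epsilon @ torch.diag((1 / torch.sqrt(4*math.pi**2 * self.lam**2)))
    # R is zero on the coordinate block, so eta**2 * R adds friction to the momenta only.
    R = torch.eye(self.d)
    R[:int(self.d/2),:int(self.d/2)] = 0
    # Row i of mat is 2*pi*(M - eta^2 R) s_i: the (S - R) matrix applied to each feature's gradient direction.
    mat = 2*math.pi*((self.M-self.eta**2*R)@s.T).T
    x = x.squeeze()
    samples = x.shape[0]
    # Phases 2*pi*s_i.x for every feature and every sample: (basis, samples).
    sim = 2*math.pi*s@x.squeeze().T
    basis_s = -torch.sin(sim); basis_c = torch.cos(sim)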

    # deterministic
    tmp = []
    for i in range(self.d):
      tmp.extend([mat[:,i]]*samples)
    tmp = torch.stack(tmp).T
    aug_mat = torch.vstack([tmp,tmp])
    aug_s = torch.hstack([basis_s]*self.d); aug_c = torch.hstack([basis_c]*self.d)
    aug_basis = torch.vstack([aug_s, aug_c])
    PHI = aug_mat * aug_basis
    aug_W = torch.stack([self.w]*samples*self.d).T
    F = PHI * aug_W
    f = torch.vstack(torch.split(F.sum(axis=0),samples)).T
    return f.reshape([samples,1,self.d])

  def neg_loglike(self, batch_x, pred_x):
    # batch_x is laid out as (time points, samples, 1, input_dim); only the product of the sizes is used.
    n_points, n_samples, _, input_dim = batch_x.shape
    likelihood = ( (-(pred_x-batch_x)**2/self.sigma**2/2).nansum()
                   - torch.log(self.sigma**2)/2*n_samples*n_points*input_dim)
    return -likelihood

  def KL_x0(self, x0):
    n, d = x0.shape
    S = torch.diag(self.a**2)
    return .5*((x0*x0).sum() + n*torch.trace(S) - n*torch.logdet(S))

  def KL_w(self):
    # Closed-form KL( N(b, blockdiag(C, C)) || N(0, v*I) ) with v = 2*sigma_0**2/num,
    # written without the additive constant -num/2 (which has no gradient).
    num = self.b.shape[0]
    C = self.make_C()
    C = torch.block_diag(C,C)
    term3 = (self.b*self.b).sum() / (self.sigma_0**2 / num * 2)
    term2 = torch.diag(C).sum() / (self.sigma_0**2 / num * 2)
    term1_1 = torch.log(self.sigma_0**2 / num * 2) * num
    term1_2 = torch.logdet(C)
    return .5*( term1_1 - term1_2 + term2 + term3)

  def sampling_x0(self, x0):
    n, _, d = x0.shape
    return (x0 + torch.sqrt(torch.stack([self.a**2]*n).reshape([n,1,d]))
            * (torch.normal(0,1, size=(x0.shape[0],1,x0.shape[2]))))

  def permutation_tensor(self,n):
    M = torch.eye(n)
    M = torch.cat([M[n//2:], -M[:n//2]])
    return M

  def init_C(self, basis):
    C = torch.linalg.cholesky(torch.ones(basis,basis)*1e-2+torch.eye(basis)*1e-2)
    C_line = C.reshape([(basis)**2])
    ids = torch.where(C_line!=0)[0]
    self.c = nn.Parameter(C_line[ids])
    ids = []
    for i in range(basis):
      for j in range(i+1):
        ids.append([i,j])
    ids = torch.tensor(ids)
    self.ids0 = ids[:,0]
    self.ids1 = ids[:,1]
    
  def make_C(self):
    C = torch.zeros(self.num_basis,self.num_basis)
    C[self.ids0,self.ids1] = self.c
    C = C@C.T
    return C
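
Reading `forward` as $(\mathbf{M} - \eta^2 \mathbf{R})\,\nabla H$ for a Hamiltonian built from random Fourier features, the structure it enforces can be checked numerically: with $\eta = 0$ the field is orthogonal to $\nabla H$ (energy is conserved), and with $\eta \neq 0$ the energy can only decrease. The sketch below does this in NumPy for a two-dimensional state. The feature count `m`, the random draws, and the helper names `grad_H` and `field` are assumptions for illustration and do not reproduce the trained model.

```python
import numpy as np

rng = np.random.default_rng(0)
d, m = 2, 8
s = rng.normal(size=(m, d))                 # spectral points (hypothetical draw)
w = rng.normal(size=2 * m)                  # weights for the cos and sin features
M = np.array([[0.0, 1.0], [-1.0, 0.0]])     # what permutation_tensor(2) returns
R = np.diag([0.0, 1.0])                     # friction acts on the momentum only

def grad_H(x):
    # H(x) = sum_i w_i cos(2 pi s_i.x) + w_{m+i} sin(2 pi s_i.x)
    phase = 2 * np.pi * s @ x
    coeff = -w[:m] * np.sin(phase) + w[m:] * np.cos(phase)
    return 2 * np.pi * (coeff @ s)

def field(x, eta=0.0):
    # dx/dt = (M - eta^2 R) grad H(x)
    return (M - eta**2 * R) @ grad_H(x)

xs = rng.uniform(-3, 3, size=(100, d))
# dH/dt along the flow equals grad_H . field
power_conservative = np.array([grad_H(x) @ field(x) for x in xs])
power_dissipative = np.array([grad_H(x) @ field(x, eta=0.5) for x in xs])
print(np.abs(power_conservative).max(), power_dissipative.max())
```

Because $\mathbf{M}$ is skew-symmetric, the first quantity vanishes for any weights `w`, which is why the conservation property holds by construction rather than being learned.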




In [12]:
# Symplectic Spectrum Gaussian Processes | 2022
# Yusuke Tanaka

#os.environ['KMP_DUPLICATE_LIB_OK']='TRUE'
#num_threads = '1'
#os.environ['OMP_NUM_THREADS'] = num_threads
#os.environ['MKL_NUM_THREADS'] = num_threads
#os.environ['NUMEXPR_NUM_THREADS'] = num_threads

import pdb
import os, sys
import copy
import time
import json
import argparse
import math
import numpy as np
import torch
from torchdiffeq import odeint
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

THIS_DIR = os.getcwd()  # current working directory (plain-Python equivalent of %pwd)

torch.set_default_dtype(torch.float64)
DPI = 200
FORMAT = 'pdf'
LINE_SEGMENTS = 10
ARROW_SCALE = 100
ARROW_WIDTH = 6e-3
LINE_WIDTH = 2
xmin = -3.2; xmax = 3.2; ymin = -3.2; ymax = 3.2


# Now you can access the arguments using dot notation
print(args.batch_time)
print(args.learn_rate)

def train():
    # set random seed
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # init model and optimizer
    output_dim = input_dim
    model = SSGP(input_dim, args.num_basis, args.friction).double()
    optim = torch.optim.Adam(model.parameters(), args.learn_rate)

    # train loop
    stats = {'train_loss': [], 'val_loss': []}
    t0 = time.time()
    min_val_loss = 1e+10
    for step in range(args.total_steps+1):

        # train step
        batch_y0, batch_t, batch_ys = get_batch(args, ys, t_eval, batch_step)
        s_batch_x0 = model.sampling_x0(batch_y0)
        model.sampling_epsilon_f()
        pred_x = odeint(model, s_batch_x0, batch_t, method='dopri5', atol=1e-8, rtol=1e-8)
        neg_loglike = model.neg_loglike(batch_ys, pred_x)
        KL_x0 = model.KL_x0(batch_y0.squeeze())
        KL_w = model.KL_w()
        loss = neg_loglike + KL_x0 + KL_w
        loss.backward(); optim.step(); optim.zero_grad()
        train_loss = loss.detach().item()/batch_y0.shape[0]/batch_t.shape[0]
        # run validation data
        with torch.no_grad():
            batch_y0, batch_t, batch_ys = arrange(args, val_ys, t_eval)
            s_batch_x0 = model.sampling_x0(batch_y0)
            model.mean_w()
            pred_val_x = odeint(model, s_batch_x0, t_eval, method='dopri5', atol=1e-8, rtol=1e-8)
            val_neg_loglike = model.neg_loglike(batch_ys, pred_val_x)
            loss = val_neg_loglike
            val_loss = loss.item()/batch_y0.shape[0]/t_eval.shape[0]

        # logging
        stats['train_loss'].append(train_loss)
        stats['val_loss'].append(val_loss)
        if step % args.print_every == 0:
            print("step {}, time {:.2e}, train_loss {:.4e}, val_loss {:.4e}"
                  .format(step, time.time()-t0, train_loss, val_loss))
            t0 = time.time()

        if val_loss < min_val_loss:
            best_model = copy.deepcopy(model)
            min_val_loss = val_loss; best_train_loss = train_loss
            best_step = step
            
    return best_model, stats, best_train_loss, min_val_loss, best_step

def param_save(model):
    csv_write(save_dir + '/' + 'sigma.csv', model['sigma'].cpu().detach().numpy())
    csv_write(save_dir + '/' + 'a.csv', model['a'].cpu().detach().numpy())
    csv_write(save_dir + '/' + 'b.csv', model['b'].cpu().detach().numpy())
    csv_write(save_dir + '/' + 'c.csv', model['c'].cpu().detach().numpy())
    csv_write(save_dir + '/' + 'sigma_0.csv', model['sigma_0'].cpu().detach().numpy())
    csv_write(save_dir + '/' + 'lam.csv', model['lam'].cpu().detach().numpy())
    if args.friction:
        csv_write(save_dir + '/' + 'eta.csv', model['eta'].cpu().detach().numpy())

    
if __name__ == "__main__":

    # setting
    label = 'SSGP/' + str(args.num_basis)
    i_dir = ( './' + args.name + '/' + str(args.eta) + '/train/' + str(args.sigma) 
              + '/' + str(args.samples) + '/' + str(args.timescale) + '/' + str(args.s))
    print(i_dir)
    save_dir = ( './' + args.name + '/' + str(args.eta) + '/result/' + str(args.sigma) 
                 + '/' + str(args.samples) + '/' + str(args.timescale) + '/' + str(args.s) + '/' + label)
    os.makedirs(save_dir, exist_ok=True)

    # input
    filename = i_dir + '/dataset.pkl'
    data = pkl_read(filename)
    ys = torch.tensor( data['ys'], requires_grad=False)
    val_ys = torch.tensor( data['val_ys'], requires_grad=False)
    t_eval = torch.tensor( data['t'])
    n_samples, n_points, input_dim = ys.shape
    batch_step = int(((len(t_eval)-1)/t_eval[-1]).item() * args.batch_time)

    # learning
    t0 = time.time()
    model, stats, train_loss, val_loss, step = train()
    train_time = time.time() - t0

    # save
    path = '{}/model.tar'.format(save_dir)
    torch.save(model.state_dict(), path)
    param_save(model.state_dict())
    
    path = '{}/model.json'.format(save_dir)
    with open(path, 'w') as f:
        json.dump(vars(args), f)

    path = '{}/result.csv'.format(save_dir)
    csv_write(path, np.array(['val_step',step,'train_loss',train_loss,
                                     'val_loss',val_loss,'train_time',train_time]))

    # vis
    ## learning curve
    filename = save_dir + '/learning_curve.pdf'
    x = np.arange(len(stats['train_loss']))
    plot(filename, x, [stats['train_loss'],stats['val_loss']],
                  'epoch','neg_loglike', ['train','validation'])

    ## pred field
    if input_dim == 2:
        model.mean_w()
        pred_field = get_field(model.forward, xmin, xmax, ymin, ymax, args.gridsize)
        filename = save_dir + '/pred_field.pdf'
        vis_field(filename, pred_field, xmin, xmax, ymin, ymax)


1
0.001
./pendulum/0.0/train/0.1/10/3/0
step 0, time 4.35e-02, train_loss 2.0027e+02, val_loss 2.7076e+02
step 100, time 9.43e+00, train_loss 2.2586e+01, val_loss 8.1981e+01


KeyboardInterrupt: 
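
The training loss above adds `KL_w` to the negative log-likelihood. As a reading aid, the following NumPy sketch writes out the closed-form KL divergence between a Gaussian $\mathcal{N}(\boldsymbol{b}, \mathbf{C})$ and an isotropic prior $\mathcal{N}(\mathbf{0}, v\mathbf{I})$, and compares it with an expression that mirrors the terms in `KL_w`. Interpreting `KL_w` as this divergence with the constant $-k/2$ dropped is an assumption based on the code, and the function names `kl_to_isotropic` and `kl_w_like` are hypothetical.

```python
import numpy as np

def kl_to_isotropic(b, C, v):
    """KL( N(b, C) || N(0, v*I) ) for a k-dimensional Gaussian."""
    k = b.shape[0]
    _, logdet_C = np.linalg.slogdet(C)
    return 0.5 * (np.trace(C) / v + b @ b / v - k + k * np.log(v) - logdet_C)

def kl_w_like(b, C, v):
    # Same terms as KL_w in the notebook: the expression above without "- k".
    k = b.shape[0]
    _, logdet_C = np.linalg.slogdet(C)
    return 0.5 * (k * np.log(v) - logdet_C + np.trace(C) / v + b @ b / v)

rng = np.random.default_rng(0)
k = 6
L = np.tril(rng.normal(size=(k, k))) + 3 * np.eye(k)   # lower-triangular factor
C = L @ L.T                                            # positive definite, as in make_C
b = rng.normal(size=k)
v = 0.5

full = kl_to_isotropic(b, C, v)
dropped = kl_w_like(b, C, v)
print(full, dropped, full - dropped)
```

The difference is the constant $-k/2$, so the two versions give the same gradients; only the reported loss value is shifted.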