### Background:

Brownian dynamics:

\begin{equation}
  dX_t = -\nabla V(X_t)\,dt + \sqrt{2\beta^{-1}} dB_t\,, \quad t > 0\,.
\end{equation}

Under mild growth conditions on $V$, $X_t$ is ergodic with respect to the unique invariant density 
$$\pi(x) = \frac{1}{Z}\mathrm{e}^{-\beta V(x)},$$  
where $$Z = \int_{\mathbb{R}^d} \mathrm{e}^{-\beta V(x)} dx$$ is a normalizing constant.

The associated semigroup is  
\begin{equation}
  (T_tg)(x) = \mathbb{E}[g(X_t)|X_0=x], \quad x\in\mathbb{R}^d\,.
\end{equation}
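The semigroup can be estimated by Monte Carlo: start many independent Euler-Maruyama paths at the same $x$ and average $g(X_t)$. This is a minimal sketch. It assumes the quadratic potential $V(x)=x^2/2$ (not the notebook's potential), chosen only because $\mathbb{E}[X_t\mid X_0=x]=x\,\mathrm{e}^{-t}$ is then known in closed form. The helper name `estimate_semigroup` is hypothetical.

```python
import numpy as np

def estimate_semigroup(g, grad_V, x, t, beta=1.0, dt=1e-3, n_paths=20000, seed=0):
    """Monte Carlo estimate of (T_t g)(x) = E[g(X_t) | X_0 = x] via Euler-Maruyama."""
    rng = np.random.default_rng(seed)
    X = np.full(n_paths, float(x))            # every path starts at x
    n_steps = int(round(t / dt))
    for _ in range(n_steps):
        X = X - grad_V(X) * dt + np.sqrt(2.0 * dt / beta) * rng.standard_normal(n_paths)
    return g(X).mean()

# assumed quadratic potential V(x) = x^2 / 2 (Ornstein-Uhlenbeck process), so grad V(x) = x
grad_V_ou = lambda x: x
est = estimate_semigroup(lambda y: y, grad_V_ou, x=1.0, t=0.5, beta=2.0)
exact = np.exp(-0.5)                           # E[X_t | X_0 = 1] = e^{-t} for this potential
print(est, exact)
```

The same function works for any vectorized gradient, e.g. the `gradV` defined later in this notebook, but then no closed-form reference value is available.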

### This notebook illustrates how to solve the problem:

\begin{equation}
  \nu_2 = \max_{g\in L^2_\pi(\mathbb{R}^d), \langle g, \mathbf{1}\rangle_\pi = 0} \frac{\langle T_\tau g, g\rangle_\pi}{\langle g, g\rangle_\pi} = \max_{g\in L^2_\pi(\mathbb{R}^d),\, \langle g,\mathbf{1}\rangle_\pi = 0} \frac{\mathbb{E}_{X_0\sim \pi} \big[g(X_0)g(X_\tau)\big]}{\mathbb{E}_\pi(g^2)},       \qquad (*)
\end{equation}
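for fixed $\tau > 0$. Since the dynamics is reversible, $T_\tau$ is self-adjoint on $L^2_\pi(\mathbb{R}^d)$ and its largest eigenvalue is $1$, with the constant eigenfunction $\mathbf{1}$. The constraint $\langle g,\mathbf{1}\rangle_\pi=0$ excludes this eigenfunction, so (assuming, as for confining potentials, that the top of the spectrum is discrete) the maximum $\nu_2$ is the second largest eigenvalue of $T_\tau$, attained at the corresponding eigenfunction.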
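To make (*) concrete, consider the Ornstein-Uhlenbeck case $V(x)=x^2/2$ in 1D. This is an assumption made only because everything is explicit there. In that case $\pi$ is Gaussian with variance $\beta^{-1}$, $X_\tau$ given $X_0$ can be sampled exactly, and $g(x)=x$ attains the maximum with $\nu_2=\mathrm{e}^{-\tau}$. The sketch below estimates the Rayleigh quotient in (*) from pairs $(X_0, X_\tau)$ with $X_0\sim\pi$. The helper name `rayleigh_quotient` is hypothetical.

```python
import numpy as np

def rayleigh_quotient(g, x0, x_tau):
    """Empirical E_pi[g(X_0) g(X_tau)] / E_pi[g^2] from paired samples."""
    return np.mean(g(x0) * g(x_tau)) / np.mean(g(x0) ** 2)

rng = np.random.default_rng(1)
beta, tau, n = 2.0, 0.5, 200_000
# X_0 ~ pi = N(0, 1/beta) for the assumed potential V(x) = x^2/2
x0 = rng.normal(0.0, np.sqrt(1.0 / beta), size=n)
# exact OU transition: X_tau = e^{-tau} X_0 + sqrt((1 - e^{-2 tau}) / beta) * xi
x_tau = np.exp(-tau) * x0 + np.sqrt((1.0 - np.exp(-2 * tau)) / beta) * rng.standard_normal(n)

nu_est = rayleigh_quotient(lambda y: y, x0, x_tau)          # should be close to e^{-tau}
nu_cubic = rayleigh_quotient(lambda y: y**3, x0, x_tau)     # another mean-zero g: smaller ratio
print(nu_est, np.exp(-tau), nu_cubic)
```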

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from scipy import linalg

## Part 1: define the system and generate data

#### First, define the potential $V$ and its gradient

The system is in 1D.

In [None]:
# potential V, one-dimensional
def V(x):
    y1 = x**8
    y2 = 0.8 * np.exp(-80 * x**2)
    y3 = 0.55 * np.exp(-80 * (x-0.5)**2)
    y4 = 0.3 * np.exp(-80 * (x+0.5)**2)

    y = 2 * (y1 + y2 + y3 + y4)

    return y

# gradient of V
def gradV(x):
    y1 = 8 * x**7 
    y2 = - 0.8 * 160 * x * np.exp(-80 * x**2)
    y3 = - 0.55 * 160 * (x - 0.5) * np.exp(-80 * (x-0.5)**2)
    y4 = - 0.3 * 160 * (x + 0.5) * np.exp(-80 * (x+0.5)**2)

    y = 2 * (y1 + y2 + y3 + y4)

    return y
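Because `gradV` is coded by hand, a quick central-difference comparison against `V` catches sign or factor mistakes. This is a minimal sketch. The two functions are copied from the cell above so that the block runs on its own, and the step `h` and the test grid are arbitrary choices.

```python
import numpy as np

def V(x):
    return 2 * (x**8 + 0.8 * np.exp(-80 * x**2) + 0.55 * np.exp(-80 * (x - 0.5)**2)
                + 0.3 * np.exp(-80 * (x + 0.5)**2))

def gradV(x):
    return 2 * (8 * x**7 - 0.8 * 160 * x * np.exp(-80 * x**2)
                - 0.55 * 160 * (x - 0.5) * np.exp(-80 * (x - 0.5)**2)
                - 0.3 * 160 * (x + 0.5) * np.exp(-80 * (x + 0.5)**2))

x = np.linspace(-1.0, 1.0, 2001)
h = 1e-5
fd = (V(x + h) - V(x - h)) / (2 * h)      # central difference, O(h^2) truncation error
max_err = np.max(np.abs(fd - gradV(x)))
print('max |finite difference - gradV| =', max_err)
```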

#### function to sample the process

In [None]:
# sample the SDE using Euler-Maruyama scheme
def sample(beta=1.0, dt=0.001, N=10000, seed=42):
    rng = np.random.default_rng(seed=seed)
    X = 0.0
    traj = []
    tlist = []
    for i in range(N):
        traj.append(X)
        tlist.append(dt*i)        
        b = rng.normal()
        X = X - gradV(X) * dt + np.sqrt(2 * dt/beta) * b

    return np.array(tlist), np.array(traj)  
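The step size `dt` controls the discretization bias of the Euler-Maruyama scheme. For the quadratic potential $V(x)=x^2/2$ (an assumption, used only because it can be solved by hand), one step reads $X_{n+1}=(1-\Delta t)X_n+\sqrt{2\Delta t/\beta}\,\xi_n$. Its stationary variance is $\frac{2\Delta t/\beta}{1-(1-\Delta t)^2}=\frac{1}{\beta(1-\Delta t/2)}$ instead of the exact $1/\beta$. The sketch below checks this on a long chain. It only illustrates the bias and is not part of the notebook's pipeline.

```python
import numpy as np

beta, dt, n_steps = 2.0, 0.1, 400_000
rng = np.random.default_rng(3)
noise = np.sqrt(2 * dt / beta) * rng.standard_normal(n_steps)

x = 0.0
samples = np.empty(n_steps)
for n in range(n_steps):
    x = (1.0 - dt) * x + noise[n]        # Euler-Maruyama step with grad V(x) = x
    samples[n] = x

empirical_var = samples[1000:].var()     # discard a short burn-in
predicted_var = 1.0 / (beta * (1.0 - dt / 2))
exact_var = 1.0 / beta
print(empirical_var, predicted_var, exact_var)
```

Halving `dt` roughly halves the gap between the predicted and exact variances, which is one way to sanity-check the choice of `dt` in the parameters below.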

#### parameters

In [None]:
# coefficient in SDE
beta = 2.0
# step-size 
dt = 0.005
# number of sampling steps 
N = 50000
# range of the domain 
xmin, xmax = -1.0, 1.0

### plot the potential $V(x)$ and the invariant density (Boltzmann density)
$$\pi(x) = \frac{1}{Z}\mathrm{e}^{-\beta V(x)},$$  
where $$Z = \int_{\mathbb{R}} \mathrm{e}^{-\beta V(x)} dx$$ is a normalizing constant.

In [None]:
# uniform grid on [xmin, xmax]
xvec = np.linspace(xmin, xmax, 101)

# potential on grid
pot_vals = V(xvec)

# compute invariant density
density_unnormalized = np.exp(-beta * pot_vals)
# normalizing constant Z
z = np.sum(density_unnormalized) * (xmax-xmin) / 100
# normalize to get the density
density_pi = density_unnormalized / z

fig = plt.figure(figsize=(12,4))
ax = fig.add_subplot(1, 2, 1)
# plot V
ax.plot(xvec, pot_vals)
ax.set_xlabel(r'x')
ax.set_title(r'V')

ax = fig.add_subplot(1, 2, 2)
# plot invariant density
ax.plot(xvec, density_pi)
ax.set_xlabel(r'x')
ax.set_title(r'invariant density')
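The normalizing constant above is a simple Riemann sum on 101 grid points of $[x_{\min}, x_{\max}]$. As a cross-check, $Z$ can be computed with adaptive quadrature. This is a minimal sketch. It assumes, as the $x^8$ term suggests, that $\mathrm{e}^{-\beta V}$ is negligible beyond $|x|=1.5$ for $\beta=2$, and it copies `V` so that it runs on its own.

```python
import numpy as np
from scipy import integrate

beta = 2.0

def V(x):
    return 2 * (x**8 + 0.8 * np.exp(-80 * x**2) + 0.55 * np.exp(-80 * (x - 0.5)**2)
                + 0.3 * np.exp(-80 * (x + 0.5)**2))

def boltzmann(x):
    return np.exp(-beta * V(x))

# adaptive quadrature on [-1.5, 1.5]; `points` marks where the narrow Gaussian bumps of V sit
Z_quad, abserr = integrate.quad(boltzmann, -1.5, 1.5, points=[-0.5, 0.0, 0.5])

# the notebook's Riemann sum on 101 grid points of [-1, 1]
xvec = np.linspace(-1.0, 1.0, 101)
Z_riemann = np.sum(boltzmann(xvec)) * 2.0 / 100

print(Z_quad, Z_riemann)
```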

### get trajectory data by sampling

In [None]:
tvec, traj = sample(beta, dt=dt, N=N)

### display the sampled trajectory and verify that the trajectory is long enough.

If the simulation time is long enough, the empirical density (histogram) of the trajectory data should match the invariant density, by the ergodic theorem.
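Comparing histograms is a qualitative check. A more quantitative one attaches an error bar to a time average using block (batch) means. Split the trajectory into `n_blocks` consecutive pieces, average each piece, and use the spread of the block averages. If the error bar is not small compared with the quantity of interest, the trajectory is too short. This is a minimal sketch. `block_average` is a hypothetical helper, and the number of blocks is a tuning choice: each block must be much longer than the correlation time of the process.

```python
import numpy as np

def block_average(samples, n_blocks=20):
    """Time average of `samples` and its standard error estimated from block means."""
    samples = np.asarray(samples, dtype=float)
    block_len = len(samples) // n_blocks
    blocks = samples[: block_len * n_blocks].reshape(n_blocks, block_len)
    block_means = blocks.mean(axis=1)
    mean = block_means.mean()
    stderr = block_means.std(ddof=1) / np.sqrt(n_blocks)
    return mean, stderr

rng = np.random.default_rng(4)
# strongly correlated test signal: AR(1) chain x_k = 0.99 x_{k-1} + noise (mean 0)
n = 200_000
noise = rng.standard_normal(n)
x = np.empty(n)
x[0] = 0.0
for k in range(1, n):
    x[k] = 0.99 * x[k - 1] + noise[k]

mean, stderr = block_average(x)
naive = x.std() / np.sqrt(n)    # ignores correlations, hence far too optimistic
print(mean, stderr, naive)
```

In the notebook one could apply it to, e.g., `(traj > 0)` to estimate the fraction of time spent in $x>0$ together with its uncertainty.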


In [None]:
fig = plt.figure(figsize=(12,4))
ax = fig.add_subplot(1, 2, 1)

# plot trajectory vs time
ax.plot(tvec, traj, alpha=0.5)
ax.set_ylim([xmin, xmax])
ax.set_xlabel(r'time')
ax.set_ylabel(r'x')
ax.set_title('trajectory')

ax1 = fig.add_subplot(1, 2, 2)

# plot empirical density of the trajectory data
ax1.hist(traj, 50, density=True, label='empirical density')

# plot the invariant density
ax1.plot(xvec, density_pi, label='invariant density')

ax1.set_title('empirical and invariant density')
ax1.legend()

## Part 2: Markov state models (MSMs)


### decomposition of the space $\mathbb{R}$:

1. $D_1=(-\infty, x_{\min})$.

2. The interval $[x_{\min}, x_{\max}]$ is divided uniformly into $M'$ sub-intervals $D_2,\dots, D_{M'+1}$.

3. $D_{M'+2}=(x_{\max}, +\infty)$.

Therefore, there are $M=M'+2$ subsets in total ($x_{\min}$ and $x_{\max}$ are `xmin` and `xmax` in the code).
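For long trajectories it is convenient to assign all samples to subsets at once. The sketch below is a vectorized counterpart of a scalar subset-index function. Index $k$ corresponds to $D_{k+1}$, so indices run from $0$ to $M'+1$. The name `find_indices` is hypothetical, and this version assigns the boundary point $x=x_{\max}$ to the last interior sub-interval.

```python
import numpy as np

def find_indices(x, xmin, xmax, M_prime):
    """Subset index for each sample: 0 for x < xmin, M_prime + 1 for x > xmax,
    and 1..M_prime for the uniform sub-intervals of [xmin, xmax]."""
    x = np.asarray(x, dtype=float)
    bin_width = (xmax - xmin) / M_prime
    inner = np.floor((x - xmin) / bin_width).astype(int)
    inner = np.clip(inner, 0, M_prime - 1) + 1      # x == xmax goes to the last interior bin
    return np.where(x < xmin, 0, np.where(x > xmax, M_prime + 1, inner))

print(find_indices([-2.0, -1.0, 0.005, 1.0, 3.0], -1.0, 1.0, 200))
```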

### function $g$ is approximated by piecewise constant functions:
\begin{equation}
  g(x)=\sum_{i=1}^{M} \omega_i \mathbf{1}_{D_i}(x) \,, \quad
  x\in\mathbb{R}\,, 
\end{equation}
where $\mathbf{1}_{D_i}(x)$ denotes the indicator function associated to the set $D_i$.

### The problem (*) becomes 
\begin{equation}
  \max_{\omega \in \mathbb{R}^M, \langle \omega, \mathbf{1}_M \rangle_{\hat{\pi}} = 0} \frac{\langle \hat{P} \omega, \omega \rangle_{\hat{\pi}}}{\langle \omega, \omega \rangle_{\hat{\pi}}}\,. \qquad  (\hat{*})
\end{equation}
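This is the analogue of (*) for the Markov chain on $\{1,\dots,M\}$ with transition matrix $\hat{P}$ and invariant distribution $\hat{\pi}$ (both defined below), where $\langle a, b\rangle_{\hat{\pi}} = \sum_{i=1}^{M} \hat{\pi}_i a_i b_i$.

With the exact (non-empirical) entries, the reversibility of the dynamics gives $\hat{\pi}_i\hat{P}_{ij} = \hat{\pi}_j\hat{P}_{ji}$, so $\hat{P}$ is self-adjoint with respect to $\langle\cdot,\cdot\rangle_{\hat{\pi}}$. Its largest eigenvalue is $1$ with eigenvector $\mathbf{1}_M$, which the constraint excludes. Hence ($\hat{*}$) is equivalent to finding the second largest eigenvalue $\hat{\nu}_2$ and its eigenvector in the matrix eigenvalue problem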

$$\hat{P}v=\hat{\nu} v.$$
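Here is a small numerical illustration of this equivalence for a made-up reversible three-state chain. The matrix is an arbitrary example, not estimated from data. The chain is built from a symmetric matrix $C$ via $\hat P_{ij} = C_{ij}/\sum_k C_{ik}$, which makes it reversible with $\hat\pi_i\propto\sum_k C_{ik}$. The second largest eigenvalue should bound the Rayleigh quotient of every $\hat\pi$-mean-zero vector, and its eigenvector should attain the bound.

```python
import numpy as np

# symmetric "count" matrix -> reversible transition matrix (arbitrary example)
C = np.array([[8.0, 1.0, 0.5],
              [1.0, 6.0, 2.0],
              [0.5, 2.0, 9.0]])
P = C / C.sum(axis=1, keepdims=True)
pi_hat = C.sum(axis=1) / C.sum()

def rayleigh(w):
    """<P w, w>_pi / <w, w>_pi with the pi-weighted inner product."""
    return np.sum(pi_hat * (P @ w) * w) / np.sum(pi_hat * w * w)

eigvals, eigvecs = np.linalg.eig(P)
order = np.argsort(eigvals.real)[::-1]
nu = eigvals.real[order]
v2 = eigvecs[:, order[1]].real

rng = np.random.default_rng(5)
best_random = -np.inf
for _ in range(10_000):
    w = rng.standard_normal(3)
    w -= np.sum(pi_hat * w)           # enforce <w, 1>_pi = 0 (pi_hat sums to one)
    best_random = max(best_random, rayleigh(w))

print('nu_1, nu_2:', nu[:2])
print('Rayleigh quotient of v_2:', rayleigh(v2))
print('best over random mean-zero vectors:', best_random)
```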

### expressions of the matrix $\hat{P}$ and the invariant distribution $\hat{\pi}$ of the Markov chain 

$\hat{\pi}=(\hat{\pi}_1, \dots, \hat{\pi}_{M})^\top \in \mathbb{R}^{M}$, where 
  \begin{equation}
    \hat{\pi}_i = \mathbb{E}_{x\sim \pi}\big[\mathbf{1}_{D_i}(x)\big] = \int_{\mathbb{R}^d} \mathbf{1}_{D_i}(x) \pi(x)dx \,.
  \end{equation}

entries of the matrix $\hat{P} \in \mathbb{R}^{M\times M}$:
\begin{equation}
  \hat{P}_{ij} = \frac{\mathbb{E}_{X_0\sim \pi}\big[\mathbf{1}_{D_i}(X_0) \mathbf{1}_{D_j}(X_\tau)\big]}{ \mathbb{E}_{x\sim\pi}
  \big[\mathbf{1}_{D_i}(x)\big]}\,,\quad 1\le i,j\le M\,.
\end{equation}
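A short derivation shows why these definitions turn (*) into ($\hat{*}$). Substitute $g=\sum_{i}\omega_i\mathbf{1}_{D_i}$ and use the definitions of $\hat\pi$ and $\hat P$:
\begin{equation}
  \begin{aligned}
    \mathbb{E}_{X_0\sim\pi}\big[g(X_0)g(X_\tau)\big] &= \sum_{i,j=1}^{M}\omega_i\omega_j\,\mathbb{E}_{X_0\sim\pi}\big[\mathbf{1}_{D_i}(X_0)\mathbf{1}_{D_j}(X_\tau)\big] = \sum_{i,j=1}^{M}\hat{\pi}_i\hat{P}_{ij}\,\omega_j\,\omega_i = \langle \hat{P}\omega,\omega\rangle_{\hat{\pi}}\,,\\
    \mathbb{E}_\pi\big(g^2\big) &= \sum_{i=1}^{M}\hat{\pi}_i\,\omega_i^2 = \langle\omega,\omega\rangle_{\hat{\pi}}\,, \qquad
    \langle g,\mathbf{1}\rangle_\pi = \sum_{i=1}^{M}\hat{\pi}_i\,\omega_i = \langle\omega,\mathbf{1}_M\rangle_{\hat{\pi}}\,,
  \end{aligned}
\end{equation}
where $\langle a,b\rangle_{\hat{\pi}}=\sum_i\hat{\pi}_i a_i b_i$. The second line uses that the $D_i$ are disjoint, so $\mathbf{1}_{D_i}\mathbf{1}_{D_j}=0$ for $i\neq j$.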

### empirical expressions used in practice (by the ergodic theorem):
\begin{equation}
  \begin{aligned}
    \hat{\pi}_i =& \int_{\mathbb{R}^d} \mathbf{1}_{D_i}(x) \pi(x)dx = \lim_{T\rightarrow \infty} \frac{1}{T} \int_0^ T \mathbf{1}_{D_i}(X_t) dt \approx \frac{1}{N}\sum_{n=0}^{N-1} \mathbf{1}_{D_i}(X_n) 
  \end{aligned}
\end{equation}  
and
\begin{equation}
  \begin{aligned}    
    \hat{P}_{ij} =& \frac{\mathbb{E}_{X_0\sim \pi}\big[\mathbf{1}_{D_i}(X_0) \mathbf{1}_{D_j}(X_\tau)\big]}{ \mathbb{E}_{x\sim\pi}
    \big[\mathbf{1}_{D_i}(x)\big]} = \frac{\lim_{T\rightarrow \infty}
    \frac{1}{T} \int_0^ T \mathbf{1}_{D_i}(X_t) \mathbf{1}_{D_j}(X_{t+\tau})
    dt}{\lim_{T\rightarrow \infty} \frac{1}{T} \int_0^ T
    \mathbf{1}_{D_i}(X_t) dt} \approx \frac{\frac{1}{N-n'}\sum_{n=0}^{N-n'-1} \mathbf{1}_{D_i}(X_n) \mathbf{1}_{D_j}(X_{n+n'})}{\hat{\pi}_i}     
  \end{aligned}
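\end{equation}

Here $X_n$ denotes the $n$-th state of the sampled trajectory (time $n\,\Delta t$, with $\Delta t$ = `dt`), $N$ is the number of samples, and the lag time is $\tau = n'\Delta t$. In the code below, $\hat{\pi}_i$ is computed from the first $N-n'$ samples. As a result, each row of $\hat{P}$ is exactly the fraction of transitions out of $D_i$ that land in $D_j$ and sums to one.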


In [None]:
# determine the index of the subset containing state x
# (index k, 0 <= k <= M_prime+1, corresponds to the subset D_{k+1})

def find_index(x, xmin, xmax, M_prime):
    if x < xmin:  # 0 if x lies below [xmin, xmax]
        return 0
    if x > xmax:  # M_prime+1 if x lies above [xmin, xmax]
        return M_prime + 1
    # the interval [xmin, xmax] is divided uniformly into M_prime sub-intervals;
    # min(...) keeps x == xmax in the last sub-interval
    bin_width = (xmax - xmin) / M_prime
    idx = min(int(np.floor((x - xmin) / bin_width)), M_prime - 1) + 1
    return idx
    
M_prime = 200

n_prime=10 
tau = dt * n_prime

# total number of states is M = M_prime + 2
# initialize the vector and matrix of counts
pi_hat = np.zeros((M_prime+2))
P = np.zeros((M_prime+2, M_prime+2))

for i in range(N-n_prime):
    x = traj[i]
    idx1= find_index(x, xmin, xmax, M_prime)
    # count the state
    pi_hat[idx1] += 1
    
    x_next = traj[i+n_prime]    
    idx2= find_index(x_next, xmin, xmax, M_prime)
    # count the pair 
    P[idx1,idx2] += 1

# normalize the visit counts to get the weights pi_hat
pi_hat /= np.sum(pi_hat)    

# divide each row by its visit count to get the transition matrix P_hat (rows sum to one)
for i in range(M_prime+2):
    if pi_hat[i] > 0:  
        P[i, :] /= pi_hat[i] * (N-n_prime)

print ('size of P:', P.shape)
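The loop above can also be written with vectorized counting. This is a minimal sketch. It assumes the trajectory has already been mapped to an integer state sequence `states` with values in $0,\dots,M-1$, for example by applying `find_index` to every entry of `traj`. `np.add.at` accumulates repeated index pairs correctly, which plain fancy-index assignment would not. Rows of states never visited are left as zeros, as in the loop. The function name `estimate_msm` and the toy sequence are hypothetical.

```python
import numpy as np

def estimate_msm(states, n_states, lag):
    """Empirical pi_hat and row-stochastic P_hat from an integer state sequence."""
    states = np.asarray(states)
    start, end = states[:-lag], states[lag:]             # pairs (X_n, X_{n+lag})
    counts = np.zeros((n_states, n_states))
    np.add.at(counts, (start, end), 1.0)                 # counts[i, j] += 1 for each pair
    visits = counts.sum(axis=1)                          # times each state starts a pair
    pi_hat = visits / visits.sum()
    P_hat = np.divide(counts, visits[:, None], out=np.zeros_like(counts),
                      where=visits[:, None] > 0)
    return pi_hat, P_hat

states = [0, 0, 1, 2, 1, 1, 0, 2, 2, 1]
pi_hat, P_hat = estimate_msm(states, n_states=4, lag=1)
print(pi_hat)
print(P_hat)
```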

### Solve the matrix eigenvalue problem

$$\hat{P}v=\hat{\nu} v.$$

In [None]:
# solve the eigenvalue problem
la, v = linalg.eig(P)

# sort the eigenvalues so that the largest ones are ranked first. 
idx = la.real.argsort()[::-1]

# get the real parts
la = la.real[idx]
v = v[:,idx].real

print('eigenvalues:\n', la)
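Because $\hat\pi_i\hat P_{ij}$ estimates the symmetric quantity $\mathbb{E}_\pi[\mathbf{1}_{D_i}(X_0)\mathbf{1}_{D_j}(X_\tau)]$ only up to sampling noise, `linalg.eig` may return slightly complex eigenvalues. One common alternative imposes reversibility by symmetrizing the transition counts, which is a modelling choice and not what the cell above does. It then works with $S = D^{1/2}\hat P D^{-1/2}$, $D=\mathrm{diag}(\hat\pi)$, which is symmetric, so `scipy.linalg.eigh` applies. The eigenvectors $v = D^{-1/2}u$ are then automatically orthonormal in $\langle\cdot,\cdot\rangle_{\hat\pi}$. The helper name `reversible_eigs` and the toy counts are hypothetical.

```python
import numpy as np
from scipy import linalg

def reversible_eigs(counts):
    """Eigenpairs of the reversible MSM built from symmetrized transition counts."""
    C = 0.5 * (counts + counts.T)                 # impose detailed balance
    visited = C.sum(axis=1) > 0                    # drop states that were never visited
    C = C[np.ix_(visited, visited)]
    pi_hat = C.sum(axis=1) / C.sum()
    P = C / C.sum(axis=1, keepdims=True)
    d = np.sqrt(pi_hat)
    S = d[:, None] * P / d[None, :]                # D^{1/2} P D^{-1/2}, symmetric
    S = 0.5 * (S + S.T)                            # remove round-off asymmetry
    w, u = linalg.eigh(S)                          # real eigenvalues in ascending order
    w, u = w[::-1], u[:, ::-1]                     # largest first
    v = u / d[:, None]                             # right eigenvectors of P
    return w, v, pi_hat, P

counts = np.array([[50.0, 5.0, 0.0],
                   [3.0, 40.0, 6.0],
                   [0.0, 8.0, 60.0]])
w, v, pi_hat, P = reversible_eigs(counts)
print(w)
```

In the notebook, `counts` would be the unnormalized pair counts accumulated in the loop, taken before the row normalization.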

### let's plot the approximation of the eigenfunctions

In [None]:
fig = plt.figure(figsize=(12,4))
ax = fig.add_subplot(1, 2, 1)

# midpoints of the M_prime sub-intervals of [xmin, xmax]
bins = xmin + (np.arange(M_prime) + 0.5) * (xmax - xmin) / M_prime
# we confine ourselves to the range [xmin, xmax]
eigvec = v[1:-1,:]

# normalize to get eigenfunctions!
for i in range(2):
    scale = np.sqrt(np.sum(eigvec[:,i]**2*pi_hat[1:-1]))
    eigvec[:,i] /= scale
   
for i in range(2):   
    print ('eig %d: %.3f' % (i+1, la[i]))
    ax.plot(bins, eigvec[:,i], alpha=0.5, label='eig_%d' % (i+1))

ax.set_xlim([xmin, xmax])
ax.set_xlabel(r'x')
ax.set_title('eigenfunctions')
plt.legend()

## Part 3: Solving the problem using neural networks

### change the optimization problem  (*) to an unconstrained optimization problem (by including penalties)

\begin{equation}
  \min_{g\in \Phi} \Big[-\frac{1}{\tau}\frac{\langle T_\tau g, g\rangle_\pi}{\langle g, g\rangle_\pi} + \alpha \big(\mathbb{E}_\pi g\big)^2 + \alpha \big(\mathbb{E}_\pi (g^2)-1\big)^2\Big]
\end{equation}

where $\Phi$ is the class of functions represented by the neural network, $\alpha > 0$ is a large penalty constant, and $\frac{1}{\tau}$ is a scaling factor. The two penalty terms enforce, respectively,
1. the constraint $\langle g, \mathbf{1}\rangle_\pi=0$;
2. the normalization $\langle g, g\rangle_\pi=1$.

### Loss in practice:

\begin{equation}
  \mathrm{Loss}(g)= -\frac{1}{\tau}\frac
  {\frac{1}{N-n'} \sum_{n=0}^{N-n'-1} g(X_{n}) g(X_{n+n'})}{
\frac{1}{N} \sum_{n=0}^{N-1} g^2(X_{n})} + \alpha \Big(\frac{1}{N} \sum_{n=0}^{N-1} g(X_{n})\Big)^2 + \alpha \Big(\frac{1}{N} \sum_{n=0}^{N-1} g^2(X_{n})-1\Big)^2\,,
\end{equation}
which is minimized over $g\in\Phi$. In the code, all averages are estimated on mini-batches of the pairs $(X_n, X_{n+n'})$.
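The loss can be checked independently of the network. The NumPy sketch below mirrors the expression used in the training loop, with all averages taken over the same batch of pairs. `penalized_loss` is a hypothetical name. For the Ornstein-Uhlenbeck example $V(x)=x^2/2$ (an assumption, used because the answer is known), $g(x)=\sqrt{\beta}\,x$ satisfies both constraints. The penalties therefore nearly vanish, and the loss should be close to $-\mathrm{e}^{-\tau}/\tau$.

```python
import numpy as np

def penalized_loss(y, y_tau, tau, alpha):
    """Scaled Rayleigh-quotient term plus mean and normalization penalties on a batch."""
    ratio = np.mean(y * y_tau) / np.mean(y ** 2)
    penalty_mean = alpha * np.mean(y) ** 2
    penalty_norm = alpha * (np.mean(y ** 2) - 1.0) ** 2
    return -ratio / tau + penalty_mean + penalty_norm, ratio

rng = np.random.default_rng(6)
beta, tau, alpha, n = 2.0, 0.05, 10.0, 500_000
x0 = rng.normal(0.0, np.sqrt(1.0 / beta), size=n)          # X_0 ~ pi for V(x) = x^2/2
x_tau = np.exp(-tau) * x0 + np.sqrt((1 - np.exp(-2 * tau)) / beta) * rng.standard_normal(n)

g = lambda x: np.sqrt(beta) * x                              # mean 0 and E_pi[g^2] = 1
loss, ratio = penalized_loss(g(x0), g(x_tau), tau, alpha)
print(loss, -np.exp(-tau) / tau)
```

Shifting $g$ by a constant raises the Rayleigh term but is outweighed by the penalties, which is why a large $\alpha$ keeps the optimizer away from near-constant functions.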


In [None]:
# batch-size
batch_size = 5000
# penalty constant in the loss
alpha = 10
# total training epochs
total_epochs = 300

# represent the function g using a neural network
model = nn.Sequential(
            nn.Linear(1, 20),  
            nn.ReLU(),
            nn.Linear(20, 20),
            nn.ReLU(),
            nn.Linear(20, 1))
# Adam
optimizer = torch.optim.Adam(model.parameters(), lr=2e-3)

# to evaluate the loss, we need paired data at times t and t+tau
# we build a dataloader that provides mini-batches of such pairs
x = torch.tensor(traj[:-n_prime], dtype=torch.float32).reshape(-1,1)
x_tau = torch.tensor(traj[n_prime:], dtype=torch.float32).reshape(-1,1)
traj_pair = torch.utils.data.TensorDataset(x, x_tau)

data_loader = DataLoader(traj_pair, batch_size=batch_size, shuffle=True, drop_last=True)

# lists to record the loss, constraint, and approximation of eigenvalue 
loss_list = []
constraint_list = []
eig_list = []

for epoch in range(total_epochs):   # for each epoch
    
    for idx, data in enumerate(data_loader):  # loop over all mini-batches 
        
        # g(x(t))
        y = model(data[0])
        # g(x(t+tau))
        y_tau = model(data[1])
        # Rayleigh ratio in the maximization problem
        eig_loss = (y*y_tau).mean() / (y**2).mean() 
        # objective consisting of Rayleigh ratio and penalties
        loss = -1.0 / tau * eig_loss + alpha * (y.mean())**2 + alpha * ((y**2).mean()-1)**2
        
        optimizer.zero_grad()
        # compute gradients by backpropagation
        loss.backward()
        # update weights (one Adam step)
        optimizer.step()
        
        if idx == 0:
            # record the loss, eigenvalue estimate, and mean of g once per epoch
            loss_list.append(loss.item())
            eig_list.append(eig_loss.item())
            constraint_list.append(y.mean().item())
            if epoch % 20 == 0:
                print('epoch=%d\n   loss=%.4f, eig=%.4f, constraints=[%.3f, %.3f]'
                      % (epoch, loss.item(), eig_loss.item(), y.mean().item(), (y**2).mean().item()))
    

# estimate the eigenvalue on all pairs (no gradients needed)
with torch.no_grad():
    y = model(x)
    y_tau = model(x_tau)
    eig_val = (y*y_tau).mean() / (y**2).mean()

print('estimated eigenvalue: %.4f' % eig_val.item())

fig, ax = plt.subplots(1,2, figsize=(10, 4))

ax[0].plot(loss_list)
ax[0].set_xlabel('epoch')
ax[0].set_title('loss vs epoch')

ax[1].plot(constraint_list)
ax[1].set_ylim([-0.5, 0.5])
ax[1].set_xlabel('epoch')
ax[1].set_title('constraint vs epoch')

### Compare the solutions obtained using Markov state models and neural networks

In [None]:
fig, ax = plt.subplots(1,1, figsize=(5, 5))

# evaluate the neural network solution on the grid
with torch.no_grad():
    y = model(torch.tensor(xvec, dtype=torch.float32).reshape(-1,1))

# eigenfunctions are determined only up to sign; uncomment to flip if the curves are mirrored
#y *= -1.0

ax.plot(xvec, y.detach().numpy(), '.', c='r', label='neural network')

# plot the solution from Markov state models
ax.plot(bins, eigvec[:,1], c='b', label='markov state model' )

ax.legend() 
ax.set_xlabel('x')
ax.set_title('eigenfunction')
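Eigenfunctions are defined only up to a sign, so a quantitative comparison should first align them. This is a minimal sketch. It assumes both approximations have been evaluated on the same points together with weights proportional to the invariant density there, for instance the network evaluated at `bins` with `pi_hat[1:-1]` as weights. `align_and_compare` is a hypothetical helper, and the arrays below are placeholders, not notebook results.

```python
import numpy as np

def align_and_compare(f_ref, f, weights):
    """Flip the sign of f to match f_ref and return it with the pi-weighted RMS difference."""
    weights = np.asarray(weights, dtype=float) / np.sum(weights)   # discrete pi on the points
    if np.sum(weights * f_ref * f) < 0:                             # negative overlap -> flip
        f = -f
    rms = np.sqrt(np.sum(weights * (f - f_ref) ** 2))
    return f, rms

x = np.linspace(-1.0, 1.0, 101)
weights = np.exp(-x**2)                          # placeholder density on the points
f_ref = np.tanh(5 * x)                           # placeholder "MSM" eigenfunction
f_nn = -np.tanh(5 * x) + 0.01 * np.sin(7 * x)    # placeholder with the opposite sign
f_aligned, rms = align_and_compare(f_ref, f_nn, weights)
print('weighted RMS difference:', rms)
```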