In [None]:
import numpy as np
import torch
import time
import os
import sys
import torchsde
from torch import nn
import matplotlib.pyplot as plt

def pbc_duplicate(x1, x2, u):
    """Stack two coordinate arrays together with their shifted copies.

    Four copies are stacked along a new leading axis, in this order:
    unshifted, shifted by ``u[0]`` in ``x1`` only, shifted by ``u[1]`` in
    ``x2`` only, and shifted in both.

    Args:
        x1: first coordinate array.
        x2: second coordinate array.
        u: indexable pair of offsets; ``u[0]`` is added to ``x1`` and
            ``u[1]`` to ``x2``.

    Returns:
        Tuple ``(x1_images, x2_images)``, each with a new leading axis of
        length 4.
    """
    x1_shifted = u[0] + x1
    x2_shifted = u[1] + x2
    x1_images = np.stack([x1, x1_shifted, x1, x1_shifted])
    x2_images = np.stack([x2, x2, x2_shifted, x2_shifted])
    return x1_images, x2_images
def cutoff_func(data):
    """Return a 0/1 mask shaped and typed like ``data``.

    Entries are 0 where ``data`` is strictly positive and 1 everywhere else
    (including where the comparison is false because of NaN).
    """
    is_positive = data > torch.zeros_like(data)
    return (~is_positive).to(data.dtype)

In [None]:
#####################################################################################################
# defining a DeepRitz block: f_i(s) = ϕ(Wi2 . ϕ(Wi1 . s +bi1) + bi2) + s
class DeepRitz_block(nn.Module):
    """Residual block ``f(s) = phi(W2 . phi(W1 . s + b1) + b2) + s`` with ReLU ``phi``.

    Both linear layers map ``h_size`` features to ``h_size`` features, so the
    skip connection adds tensors of matching width.
    """

    def __init__(self, h_size):
        super().__init__()
        self.dim_h = h_size

        # A single ReLU module is reused after each linear layer.
        self.activation_function = nn.ReLU()
        relu = self.activation_function
        self._block = nn.Sequential(
            nn.Linear(self.dim_h, self.dim_h),
            relu,
            nn.Linear(self.dim_h, self.dim_h),
            relu,
        )

    def forward(self, x):
        """Apply the two-layer stack and add the input back (skip connection)."""
        transformed = self._block(x)
        return transformed + x

# defining the neural network constructed by DeepRitz blocks
class Neural_Network(nn.Module):
    """Network assembled from DeepRitz blocks.

    Three sub-networks are built:

    * ``_model_t``: zero-pad to ``h_size`` -> block0 -> block0 -> Linear(h_size, 1);
      used by :meth:`rotation`.
    * ``fe``: zero-pad to ``h_size`` -> block0 -> block0; used by :meth:`forward`.
    * ``fg``: ``block_size`` repetitions of ``_block`` -> Linear(h_size, 2);
      used by :meth:`forward`.

    The *same* ``_block0`` instance fills all four positions in ``_model_t``
    and ``fe``, and the same ``_block`` instance is repeated in ``fg``, so
    those positions share one set of weights.
    NOTE(review): confirm this weight sharing is intended.

    Args:
        in_size: width of the last axis of the input before zero padding.
        h_size: hidden width of every DeepRitz block.
        block_size: number of times ``_block`` is repeated in ``fg``.
        cutoff_distance: distance used by the cutoff mask and decay factor.
        theta: constant added to the angle predicted in :meth:`rotation`.
        dev: stored as ``self.dev``; not otherwise used in this class.
    """

    def __init__(self, in_size=1, h_size = 10, block_size = 1, cutoff_distance = 4., theta=0., dev="cpu"):
        super().__init__()
        self.num_blocks = block_size
        self.in_size = in_size
        self.dim_h = h_size
        self.dev = dev
        self.CUTOFF = cutoff_distance
        self.theta = theta

        # The two building blocks; each is reused wherever it appears below.
        self._block0 = DeepRitz_block(self.dim_h)
        self._block = DeepRitz_block(self.dim_h)

        # Right-pad the last axis with zeros so the input matches the hidden width.
        pad_right = self.dim_h - self.in_size

        self._model_t = nn.Sequential(
            nn.ConstantPad1d((0, pad_right), 0),
            self._block0,
            self._block0,
            nn.Linear(self.dim_h, 1),
        )

        self.fe = nn.Sequential(
            nn.ConstantPad1d((0, pad_right), 0),
            self._block0,
            self._block0,
        )

        repeated = [self._block] * self.num_blocks
        self.fg = nn.Sequential(*repeated, nn.Linear(self.dim_h, 2))

    def rotation(self, dx):
        """Return 2x2 rotation matrices for the angle ``_model_t(dx) + theta``.

        The output has the shape of the predicted angle with its last axis
        replaced by ``(2, 2)``, laid out as ``[[cos, -sin], [sin, cos]]``.
        """
        angle = self._model_t(dx) + self.theta
        cos_a, sin_a = torch.cos(angle), torch.sin(angle)
        flat = torch.cat([cos_a, -sin_a, sin_a, cos_a], -1)
        return flat.view([*angle.shape[:-1], 2, 2])

    def decay(self, dx):
        """Smooth weight ``sqrt(cos(pi * dx / CUTOFF) + 1)``."""
        phase = np.pi * dx/self.CUTOFF
        return (torch.cos(phase) + 1.).pow(1/2)

    def forward(self, dx):
        """Embed each entry with ``fe``, window it, pool, and map through ``fg``.

        The first channel of ``dx`` is treated as a distance: features are
        multiplied by a hard mask (0 where that channel exceeds ``CUTOFF``)
        and by :meth:`decay`. Entry 0 along the second-to-last axis is left
        out of the sum (presumably the particle itself; verify against the
        caller). The pooled features go through ``fg``, giving a last axis of
        size 2.
        """
        distance = dx[...,:1]
        features = self.fe(dx) * cutoff_func(distance - self.CUTOFF)
        features = features * self.decay(distance)
        pooled = features[...,1:,:].sum(-2)
        return self.fg(pooled)

# defining the dynamical system of interest: dX_t = F(X_t)dt + \sqrt{2\epsilon}dW_t
# active Brownian particles
class ODE(nn.Module):
    """Deterministic drift: pairwise repulsion plus self-propulsion.

    Args:
        num_particles: number of particles; fixes the size of ``corrector``.
        v: self-propulsion speed multiplying ``(cos(theta), sin(theta))``.
        unit_size: stored as ``self.unit_size``; not used inside this class.
        dev: device on which ``corrector`` is created; stored as ``self.dev``.
    """

    def __init__(self, num_particles, v, unit_size, dev):
        super().__init__()
        self.dev = dev
        self.epsilon = 1.
        self.dim_r = num_particles
        self.unit_size = unit_size
        self.v = v

        # 2**(1/3) on the diagonal: added to the squared self-distance (0) it
        # lands exactly where the relu in F() switches the pair term off.
        # Registered as a non-persistent buffer so Module.to()/.double() move
        # it along with the module while state_dict() stays unchanged.
        self.register_buffer(
            "corrector",
            2**(1/3) * torch.eye(num_particles, device=dev),
            persistent=False,
        )

    def F(self, theta, dX1, dX2):
        """Evaluate the drift.

        Args:
            theta: orientation angles; last axis has one entry per particle.
            dX1, dX2: pairwise displacement components; the last two axes are
                ``(num_particles, num_particles)`` and the last one is summed.

        Returns:
            ``cat([F1, F2, 0], -1)``: the two force components followed by a
            zero block of the same shape (no deterministic drift on the
            third block).
        """
        # Squared separation; broadcasting adds the diagonal offset to every
        # leading batch entry.
        dis = dX1**2 + dX2**2 + self.corrector

        # The bracket is positive only for dis < 2**(1/3); relu zeroes it
        # beyond that, so the interaction is purely repulsive.
        # NOTE(review): with dis = r**2 this is r**-13 - r**-7 / 2 multiplied
        # by the un-normalised displacement, i.e. r times the textbook WCA
        # force (which would use dis**-7 and dis**-4). Confirm it is intended.
        dU = torch.relu(dis**(-13/2) - 1/2 * dis**(-7/2))

        F1 = 48.*self.epsilon * (dU * dX1).sum(-1) + self.v * torch.cos(theta)
        F2 = 48.*self.epsilon * (dU * dX2).sum(-1) + self.v * torch.sin(theta)

        # zeros_like allocates directly on F1's device with F1's dtype.
        return torch.cat([F1, F2, torch.zeros_like(F1)], -1)

    def forward(self, theta, dX1, dX2):
        """Alias for :meth:`F` so the module can be called directly."""
        return self.F(theta, dX1, dX2)

In [None]:
# defining the SDE with trial driven force u(x) and fixed diffusion matrix D
class SDE(nn.Module):
    """SDE with a trial control force u(x) and a fixed diffusion matrix D.

    The state vector's last axis is laid out as ``[x-coordinates,
    y-coordinates, angles]``, each block of length ``dim_r = dim_x // 3``.

    Args:
        Drift: module called as ``Drift(angles, dX1, dX2)`` that returns the
            uncontrolled drift F(x).
        Diffusion: square diffusion matrix; its second dimension defines
            ``dim_x``.
        num_neighbor: number of nearest entries kept per particle.
        unit_size: indexable pair with the periodic box lengths.
        dev: device to which ``Drift`` and ``Diffusion`` are moved.
        control_net: optional control network exposing ``__call__(dx)``. When
            omitted, the module-level ``neural_network`` is looked up at call
            time (the original notebook behaviour). If an ``nn.Module`` is
            passed it becomes a registered submodule.
        cutoff: optional distance above which neighbour distances are zeroed.
            When omitted, the module-level ``CUTOFF`` is looked up at call
            time.
    """

    # Change to 'ito' to integrate in the Ito sense.
    sde_type = 'stratonovich'
    noise_type = 'general'

    def __init__(self, Drift, Diffusion, num_neighbor, unit_size, dev = "cpu",
                 control_net=None, cutoff=None):
        super(SDE, self).__init__()
        self.dev = dev
        self.dim_x = Diffusion.size(1)
        self.num_neighbor = num_neighbor
        self.unit_size = unit_size
        self.dim_r = self.dim_x//3
        # Start of the angle block; integer arithmetic avoids float truncation.
        self.dim_x_ABP = 2 * self.dim_x // 3

        # drift & const diffusion matrix D:
        self._drift_0 = Drift.to(dev)
        self._diffusion = Diffusion.to(dev)
        self.eyes = torch.eye(self.dim_r, self.dim_r).to(dev)
        self.gamma = 1.

        self._control_net = control_net
        self._cutoff = cutoff

    def _get_control_net(self):
        """Return the injected control network, else the global ``neural_network``."""
        if self._control_net is not None:
            return self._control_net
        return neural_network

    def _get_cutoff(self):
        """Return the injected cutoff, else the global ``CUTOFF``."""
        if self._cutoff is not None:
            return self._cutoff
        return CUTOFF

    def MIC(self, x, axis=0):
        """Minimum-image convention: map ``x`` into ``[-L/2, L/2)`` with ``L = unit_size[axis]``."""
        box = self.unit_size[axis]
        return (x + box/2) % box - box/2

    def PBC(self, x):
        """Wrap positions into the box and angles into ``[0, 2*pi)``.

        The wrap is done IN PLACE: the tensor passed in is modified and then
        returned.
        """
        x[...,:self.dim_r] = x[...,:self.dim_r] % self.unit_size[0]
        x[...,self.dim_r:2*self.dim_r] = x[...,self.dim_r:2*self.dim_r] % self.unit_size[1]
        x[...,self.dim_x_ABP:] = x[...,self.dim_x_ABP:] % (2 * np.pi)
        return x

    def NN_input(self, x):
        """Build pairwise displacements and the nearest-neighbour distances.

        Returns:
            ``(dX1, dX2, dist, index)`` where ``dX1[..., i, j]`` and
            ``dX2[..., i, j]`` are the minimum-image components of
            ``r_i - r_j``, ``index`` holds the ``num_neighbor`` smallest-
            distance column indices per row, and ``dist`` holds the gathered
            distances with a trailing singleton axis. Distances above the
            cutoff are replaced by 0 *after* the sort, so ``index`` still
            reflects the true ordering.
        """
        pos1 = x[...,:self.dim_r]
        pos2 = x[...,self.dim_r:2*self.dim_r]
        dX1 = self.MIC(pos1.unsqueeze(-1) - pos1.unsqueeze(-2), 0)
        dX2 = self.MIC(pos2.unsqueeze(-1) - pos2.unsqueeze(-2), 1)
        D = torch.sqrt(dX1**2 + dX2**2)
        index = torch.argsort(D)[...,:self.num_neighbor]
        D[D > self._get_cutoff()] = 0.
        return dX1, dX2, torch.gather(D,-1,index)[...,None], index

    def control_force(self, dX1, dX2, dx, index):
        """Control force per particle, last axis ``(x, y)``.

        The network output is multiplied by the sum of unit vectors towards
        the selected neighbours. The network's learned rotation is not
        applied here; the original evaluated ``rotation(dx)`` and discarded
        the result, so that call has been removed.
        """
        dX1 = torch.gather(dX1,-1,index)
        dX2 = torch.gather(dX2,-1,index)
        D = torch.sqrt(dX1.pow(2) + dX2.pow(2))
        # 1e-9 guards the zero-distance entry against division by zero.
        direction = torch.stack([(dX1/(D+1e-9)), (dX2/(D+1e-9))],-1)
        return self._get_control_net()(dx) * direction.sum(-2)

    def _uncontrolled(self, x):
        """Wrap ``x`` in place and return ``(F0, dX1, dX2, dx, index)``."""
        x = self.PBC(x)
        dX1, dX2, dx, index = self.NN_input(x)
        F0 = self._drift_0(x[...,self.dim_x_ABP:], dX1, dX2)
        return F0, dX1, dX2, dx, index

    # the trial driven force u(x)
    def drift(self, t, x):
        """Controlled drift ``F(x) + u(x)``; wraps ``x`` in place via :meth:`PBC`."""
        F0, dX1, dX2, dx, index = self._uncontrolled(x)
        F = self.control_force(dX1, dX2, dx, index)
        # No control on the angle block; zeros_like follows x's device and
        # dtype instead of relying on a notebook-level ``dev`` variable.
        no_torque = torch.zeros_like(x[...,self.dim_x_ABP:])
        F = torch.cat([F[...,0], F[...,1], no_torque], -1)
        return F + F0

    # the diffusion matrix
    def diffusion(self, t, x):
        """Constant diffusion matrix expanded over the batch axis ``x.size(0)``."""
        return self._diffusion.expand(x.size(0), self.dim_x, self.dim_x)

    # the F(x)
    def drift_0(self, t, x):
        """Uncontrolled drift ``F(x)``; wraps ``x`` in place via :meth:`PBC`."""
        return self._uncontrolled(x)[0]

In [None]:
# Seed the global RNGs before anything random is created (network weights,
# initial angles) so that Restart & Run All is repeatable.
# NOTE(review): torchsde is presumed to draw its default Brownian-motion
# entropy from NumPy's global RNG; confirm against the installed version.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# num_neighbor: nearest entries kept per particle (written as 20 + 1).
# dim_h / num_blocks: hidden width and block count of the control network.
num_neighbor, dim_h, num_blocks = 20+1, 10, 1
# NOTE: here `dim_x` is the number of particles, whereas SDE.dim_x is the
# full state dimension (3 * number of particles).
dim_x, aspect, rho, v = 400, 1, .6, 100
CUTOFF = 3.5

# Box side lengths for number density `rho` and aspect ratio `aspect`.
unit_size = np.sqrt(dim_x / rho) + 0.0001
unit_size = np.array([unit_size * np.sqrt(aspect), unit_size / np.sqrt(aspect)])

dev = "cuda:0" if torch.cuda.is_available() else "cpu"
print(dev)

neural_network = Neural_Network(1, dim_h, num_blocks, CUTOFF, 0., dev).to(dev)

# Translational and rotational diffusion constants.
Dt = 1.
Dr = 3.*Dt
# Block-diagonal noise amplitude sqrt(2 D) over the [x, y, angle] blocks.
diffusion = torch.tensor( np.sqrt(2 * np.kron(np.diag([Dt, Dt, Dr]), np.eye(dim_x))) ).float().to(dev)

ode = ODE(dim_x, v, unit_size, dev)
sde = SDE(ode, diffusion, num_neighbor, unit_size, dev).to(dev)
# Pick the integrator that matches the SDE interpretation.
if sde.sde_type == 'ito':
    sde_method = 'euler'
else:
    sde_method = 'midpoint'

# Per-training-step histories, appended to by the training loop.
Lambda_k = []
A_k = []
K_k = []
swaps_k = []

In [None]:
dt = .08e-4
batch_size = 10
biasing = -12.

# Particles start on an n_side x n_side lattice. The side is derived from
# dim_x (it was hard-coded as 20), and a non-square particle count fails
# here instead of as a shape mismatch further down.
n_side = int(round(np.sqrt(dim_x)))
assert n_side * n_side == dim_x, "dim_x must be a perfect square for the lattice initial condition"
# The lattice spans the full box width but only the fraction 20/35 of the
# box height, as in the original initial condition.
y_fill = 20 / 35
X, Y = np.meshgrid(np.arange(n_side)/n_side * unit_size[0],
                   np.arange(n_side)/n_side * y_fill * unit_size[1])

# Same lattice for every batch member; only the angles are random. The
# original also drew random positions and immediately overwrote them.
lattice = torch.tensor(np.hstack([X.flatten(), Y.flatten()]), dtype=torch.float32)
angles = torch.rand([batch_size, dim_x]) * 2 * np.pi
x0 = torch.cat([lattice.expand(batch_size, -1), angles], -1).to(dev)

# Relax under the uncontrolled drift ('drift_0') and keep the last state.
# The time grid is built from an integer count so its length cannot change
# with floating-point rounding of a fractional arange step.
n_relax = 100
with torch.no_grad():
    x_init = torchsde.sdeint(sde, x0.to(dev), torch.arange(n_relax, device=dev) * dt, dt = dt,
                             method=sde_method, names={'drift': 'drift_0', 'diffusion': 'diffusion'})[-1]

optimizer = torch.optim.Adagrad(neural_network.parameters(), lr = 1e-2)
print('done')

In [None]:
# Each training trajectory covers n_intervals steps of size dt. The grid is
# built from an integer count: torch.arange with a fractional step can gain or
# lose the end point through rounding, which would change the number of
# summed intervals and bias the dt/T normalisation below.
n_intervals = 20
T = n_intervals * dt
ts = torch.arange(n_intervals + 1, device=dev) * dt

step = 0
while step < 300:
    start_time = time.time()

    # Sample a trajectory under the current control without tracking
    # gradients; the next iteration restarts from its end point.
    with torch.no_grad():
        traj = torchsde.sdeint(sde, x_init.to(dev), ts, dt = dt,
                               method=sde_method, names={'drift': 'drift', 'diffusion': 'diffusion'})
        x_init = traj[-1].detach()

        dX1, dX2, dx, index = sde.NN_input(traj)
        # Uncontrolled drift on the position blocks at every time but the last.
        F0 = ode(traj[..., 2*dim_x:], dX1, dX2)[:-1,:,:2*dim_x]
        # Orientation unit vectors, averaged over each interval's end points.
        b = torch.cat([torch.cos(traj[...,2*dim_x:]), torch.sin(traj[...,2*dim_x:])], 2)
        b_mid = (b[1:,]+b[:-1])/2

    optimizer.zero_grad()

    # Re-evaluate the control force with gradients on the sampled states;
    # the last time point is dropped to line up with F0 and b_mid.
    F = sde.control_force(dX1, dX2, dx, index)[:-1]
    F = torch.cat([F[...,0], F[...,1]],-1)

    # Time-averaged control cost and observable, one value per batch member.
    K_T = torch.sum(F**2, (0,-1))/Dt/2 *dt/T
    A_T = v/Dt * torch.sum(b_mid * (F + F0), (0,-1)) *dt/T

    loss_batch = (K_T - biasing * A_T) / dim_x

    loss_batch.mean().backward()
    optimizer.step()

    K_k.append(K_T.cpu().detach().numpy().flatten()/ dim_x)
    A_k.append(A_T.cpu().detach().numpy().flatten()/ dim_x)
    Lambda_k.append(biasing * A_k[-1] - K_k[-1])
    step += 1

    t_simul = time.time()
    # Printed as "psi: biasing * A - K = Lambda", matching how Lambda_k is
    # computed (the earlier message showed K - biasing * A, the opposite sign).
    print('%i - %.2f sec - psi: %.4f * %.4f - %.4f = %.4f'
            % (step, float(t_simul-start_time), biasing, np.mean(A_k[-1]),
               np.mean(K_k[-1]), np.mean(Lambda_k[-1])))

In [None]:
# Training curves: batch-averaged Lambda, K and A per training step.
fig, ax = plt.subplots(1,3, figsize=(18,3))
psi = np.array(Lambda_k).mean(-1)
observable = np.array(A_k).mean(-1)

curves = (psi, np.array(K_k).mean(-1), observable)
ylabels = (r'$\psi(\lambda)$', r'$D_{KL}$', r'$A_T$')
for i in range(3):
    ax[i].plot(curves[i])
    ax[i].set_ylabel(ylabels[i])
    ax[i].grid()
    ax[i].set_xlabel('training steps')
plt.show()

In [None]:
# Snapshot of one batch member with the control force drawn on each particle.
b = 0
scale = 1000  # not used by the quiver call below

# F was evaluated on every time point except the last one, so its final
# entry belongs to the second-to-last trajectory frame. Index both with the
# same frame so the arrows sit on the positions they were computed from.
frame = F.shape[0] - 1

fig_snapshot, ax = plt.subplots(figsize=(10,6))
ax.set_aspect(1)
# Wrap x into the box as well, so no particle falls outside the axis limits.
x = traj[frame,b,:dim_x].cpu().detach().numpy() % unit_size[0]
# y is shifted by 4 before wrapping (display offset only).
y = (traj[frame,b,dim_x:dim_x*2].cpu().detach().numpy() + 4) % unit_size[1]
Fx = F[frame,b,:dim_x].cpu().detach().numpy()
Fy = F[frame,b,dim_x:].cpu().detach().numpy()

ax.scatter(x,y,s=200*200/dim_x, alpha=0.5, c='steelblue',linewidths=1.2)
ax.quiver(x,y, Fx, Fy, scale_units='height')
ax.set_xlim([0, unit_size[0]])
ax.set_ylim([0, unit_size[1]])
ax.set_xlabel('x')
ax.set_ylabel('y')
plt.show()