# Operator Learning with DeepXDE

## Imports

In [1]:

import numpy as np
import torch
import torch.utils.data as  dt
import torch.nn as nn
import torch.optim as optim
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint
from jax.config import config
from tqdm import trange
import matplotlib.pyplot as plt
import matplotlib as mpl

In [2]:
%matplotlib notebook
%load_ext autoreload
%autoreload 2

$N$: Number of functions $u(x)$ in training dataset  
$P$: Number fo points inside domain at which $G(u)$ is evaluated (output evaluations)  
$m$: Number of points at which input function is evaluated  

## Data Generation

Effective data generation requires parallel data generation when training.  
__ID__:  Python string that identifies sample of sample
__train__: Trining data
__validation__: Validation data points  

We will access training and validation samples using `ID`.

In [3]:
# Define RBF kernel
def RBF(x1, x2, params):
    output_scale, lengthscales = params
    diffs =jnp.expand_dims(x1 / lengthscales, 1) - \
           jnp.expand_dims(x2 / lengthscales, 0)
    r2 =jnp.sum(diffs**2, axis=2)
    return output_scale *jnp.exp(-0.5 * r2)

We obtain the corresponding 10000 ODE solutions by solving: 

$$\frac{dv(x)}{dx}=u(x)$$

Using an explicit Runge-Kutta method(RK45)→ JAX's odeint functiom.

In [4]:
def plot_uv(x,u,y,v):
    fig,ax = plt.subplots(figsize = (8,6))
    with mpl.rc_context({'font.size':18}):
        ax.plot(x,u,'k--',label = '$u(x) = ds/dx$',
                linewidth = 1.5)
        ax.plot(y,v,'o--',label = '$s(x)=s(0) + \int u(t)dt|_{t=y}$',
                linewidth = 1.5)
        ax.set_label('x')
        ax.set_ylabel('u')
        ax.tick_params(axis = 'y',color = 'black')
        ax.legend(loc = 'lower right',ncol=1)
        return ax

In [5]:
# Geneate training data corresponding to one input sample
def generate_one_training_data(key, m=100, P=1,
                               length_scale=0.2):
    # Sample GP prior at a fine grid
    N = 512
    gp_params = (1.0, length_scale)
    jitter = 1e-10
    X =jnp.linspace(0, 1, N)[:,None]
    K = RBF(X, X, gp_params)
    L =jnp.linalg.cholesky(K + jitter*np.eye(N))
    gp_sample =jnp.dot(L, jax.random.normal(key, (N,)))

    # Create a callable interpolation function  
    u_fn = lambda x, t:jnp.interp(t, X.flatten(), gp_sample)

    # Input sensor locations and measurements
    x =jnp.linspace(0, 1, m)
    u = jax.vmap(u_fn, in_axes=(None,0))(0.0, x)

    # Output sensor locations and measurements
    y_train = jax.random.uniform(key, (P,)).sort() 
    v_train = odeint(u_fn, 0.0,jnp.hstack((0.0, y_train)))[1:] # JAX has a bug and always returns s(0), so add a dummy entry to y and return s[1:]

    # Tile inputs
    u_train =jnp.tile(u, (P,1))

    return u_train, y_train, v_train

# Geneate test data corresponding to one input sample
def generate_one_test_data(key, m=100, P=100,
                           length_scale=0.2):
    # Sample GP prior at a fine grid
    N = 512
    gp_params = (1.0, length_scale)
    jitter = 1e-10
    X =jnp.linspace(0, 1, N)[:,None]
    K = RBF(X, X, gp_params)
    L =jnp.linalg.cholesky(K + jitter*np.eye(N))
    gp_sample =jnp.dot(L, jax.random.normal(key, (N,)))

    # Create a callable interpolation function  
    u_fn = lambda x, t:jnp.interp(t, X.flatten(), gp_sample)

    # Input sensor locations and measurements
    x =jnp.linspace(0, 1, m)
    u = jax.vmap(u_fn, in_axes=(None,0))(0.0, x)

    # Output sensor locations and measurements
    y =jnp.linspace(0, 1, P)
    v = odeint(u_fn, 0.0, y)

    # Tile inputs
    u =jnp.tile(u, (P,1))

    return u, y, v 

In [17]:
u_train.shape

(10000, 1, 100)

In [6]:
# Training Data
N_train = 10000 #Number of functions
m = 100 # number of input sensors
P_train = 1   # number of output sensors
key_train = jax.random.PRNGKey(0)  # use different key for generating training data and test data 
config.update("jax_enable_x64", True) # Enable double precision



In [18]:
keys_train = jax.random.split(key_train, N_train) # Obtain 10000 random numbers
gen_fn = jax.jit(lambda key: generate_one_training_data(key, m, P_train)) #lets call our function
u_train, y_train, v_train = map(np.array,jax.vmap(gen_fn)(keys_train))

In [22]:
u_train = u_train.reshape(N_train * P_train,-1)
y_train = y_train.reshape(N_train,-1,P_train)
v_train = v_train.reshape(N_train * P_train,-1)
print(u_train.shape)
print(y_train.shape)
print(v_train.shape)

(10000, 100)
(10000, 1, 1)
(10000, 1)


In [23]:
# Testing Data
N_test = 1 # number of input samples 
P_test = m   # number of sensors 
key_test = jax.random.PRNGKey(12345) # A different key 

keys_test = jax.random.split(key_test, N_test)
gen_fn = jax.jit(lambda key: generate_one_test_data(key, m, P_test))
u_test, y_test, v_test = map(np.array,jax.vmap(gen_fn)(keys_test))

In [24]:
u_test = u_test.reshape(N_test * P_test,-1)
y_test = y_test.reshape(N_test,-1,P_test)
v_test = v_test.reshape(N_test * P_test,-1)
print(u_test.shape)
print(y_test.shape)
print(v_test.shape)

(100, 100)
(1, 1, 100)
(100, 1)


## Model Compile

In [9]:
device = torch.device("cuda:0" if torch.has_cuda else "cpu")

In [25]:
class DataGenerator(dt.Dataset):
    
    def __init__(self,inputs:torch.Tensor,location:torch.Tensor,outputs:torch.Tensor):
        self.input_signals = inputs
        self.collocation_points = location
        self.output_signals = outputs
    
    def __len__(self):
        return len(self.input_signals)
    
    def __getitem__(self, index) -> tuple:
        v = self.input_signals[index,:]
        y = self.collocation_points[index,:,:]
        u = self.output_signals[index,:]
        return ((v,y),u)

In [26]:
training_set = DataGenerator(u_train,y_train,v_train)
training_loader = dt.DataLoader(training_set,batch_size = 32,
                                shuffle=True)

In [27]:
for ((b_in,y_loc),d_out) in training_loader:
    print(b_in.shape)
    print(y_loc.shape)
    print(d_out.shape)
    break

torch.Size([32, 100])
torch.Size([32, 1, 1])
torch.Size([32, 1])


In [29]:
class NN(nn.Module):
    """Base class for neural network modules"""
    def __init__(self):
        super().__init__()
        self.regulariser = None
    @property
    def num_trainable_parameters(self):
        """Evaluate number of trainable parameters for NN"""
        return sum(v.numel() for v in self.parameters() if v.requires_grad)
#%%
class MLP(NN):
    """Mulilayer perceptron network fully connected"""
    def __init__(self,layer_sizes,
                 activation = nn.ReLU(),
                 **init_kwargs):
        super().__init__()
        self.activation = activation       
        self.layers = nn.ModuleList(
            [nn.Linear(l_in,l_out,dtype = torch.float32) 
                for (l_in,l_out) in zip(layer_sizes,layer_sizes[1:])])

        self.apply(self._init_weights,**init_kwargs)
        
    def _init_weights(self,module:nn.Linear,initialiser = nn.init.xavier_normal_,
                      zero_initialiser = nn.init.zeros_):
        if isinstance(module,nn.Linear):
            initialiser(module.weight)
            zero_initialiser(module.bias)   
    
    def forward(self,inputs):
        x = inputs
        for layer in self.layers[:-1]:
            x = self.activation(layer(x))
        x = self.layers[-1](x)
        
        return self.activation(x)

In [37]:
class DeepONet(NN):
    """Deep operator network for dataset in the format of Cartesian product.

    Args:
        layer_sizes_branch: A list of integers as the width of a fully connected network,
            or `(dim, f)` where `dim` is the input dimension and `f` is a network
            function. The width of the last layer in the branch and trunk net should be
            equal.
        layer_sizes_trunk (list): A list of integers as the width of a fully connected
            network.
        activation: If `activation` is a ``string``, then the same activation is used in
            both trunk and branch nets. If `activation` is a ``dict``, then the trunk
            net uses the activation `activation["trunk"]`, and the branch net uses
            `activation["branch"]`.
    """

    def __init__(
        self,
        layer_sizes_branch:list,
        layer_sizes_trunk:list,
        *args,**kwargs):
        
        super().__init__()
        #activation_branch = activation_trunk = activation

        self.branch = MLP(layer_sizes_branch, *args,**kwargs)
        self.trunk = MLP(layer_sizes_trunk, *args,**kwargs)
        
        self.b = torch.tensor(0.0,requires_grad = True)
    
    def forward(self, inputs):
        v_func:torch.Tensor = inputs[0] # Input signal (batch_size,resolution)
        y_loc:torch.Tensor = inputs[1].swapaxes(2,1) # Collocation points (batch_size,input_dim,num_points)->(b,p,n)
        print(v_func.shape)
        # Branch net to encode the input function
        v_func = self.branch(v_func)
        
        # Trunk net to encode the domain of the output function
        y_loc = self.trunk(y_loc).swapaxes(2,1) #Output dim (batch_size,output_layer_dim,num_points)
        
        # Dot product
        if v_func.shape[-1] != y_loc.shape[1]:
            raise AssertionError(
                "Output sizes of branch net and trunk net do not match.")
        x = torch.einsum("bl,blp->bp", v_func, y_loc)
        
        # Add bias
        x += self.b
        
        return x

In [38]:
# Choose a network
m = 100
dim_x = 1
net = DeepONet(
    [m, 40, 40],
    [dim_x, 40, 40])


In [39]:
net.branch.layers[0].bias

Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       requires_grad=True)

In [40]:
loss_fn = nn.MSELoss()
niter = 10000
opt = optim.Adam(net.parameters(),lr = 1e-3)

In [41]:
2000 % 1000

0

In [42]:
len(training_loader)

313

In [43]:
for epoch in range(5):
    print(f'EPOCH {epoch}')
    
    net.train(True)
    
    running_loss = 0.
    last_loss = 0. 
    
    for i,data in enumerate(training_loader):
        inputs,labels = data
        output = net(inputs)
        
        opt.zero_grad()
        
        loss = loss_fn(output,labels)
        loss.backward()
        
        opt.step()
        
        running_loss += loss.item()
        
        if i%50 ==49:
            last_loss = running_loss/32
            print(f'Batch :{i} \t Loss\t{last_loss:.5f}')
            running_loss = 0.

EPOCH 0
torch.Size([32, 100])


RuntimeError: mat1 and mat2 must have the same dtype

In [45]:
v_func,y_loc  = inputs
v_func.shape

torch.Size([32, 100])

In [46]:
net.branch(v_func)

RuntimeError: mat1 and mat2 must have the same dtype

In [49]:
print(v_func.dtype)
for layer in net.branch.layers:
    print(layer.weight.dtype)

torch.float64
torch.float32
torch.float32


In [None]:

# Define a Model
model = dde.Model(data, net)

# Compile and Train
model.compile("adam", lr=0.001, metrics=["mean l2 relative error"])
losshistory, train_state = model.train(iterations=10000)

# Plot the loss trajectory
dde.utils.plot_loss_history(losshistory)
plt.show()

In [None]:
net