In [1]:
import jax.numpy as jnp

from jax import grad
import jax

import numpy as np
import matplotlib.pyplot as plt
import time
from functools import partial

# Dataset

In [2]:
class Cheb_poly():
    def __init__(self, N, M):
        super(Cheb_poly, self).__init__()
        self.N = N
        self.M = M
        
    def make_one_function(self):
        coefficents = []
        for i in range(self.N):
            coefficents.append(np.random.uniform(-self.M,self.M))
        Cheb = np.polynomial.chebyshev.Chebyshev(coefficents,[0,1])
        return Cheb
    
    def sample_one_function(self, Cheb, x, y):
        return(Cheb(x),y,Cheb.integ(lbnd=0)(y))
    
    def sample_functions(self, n, m):
        input_u_list = []
        input_y_list = []
        output_list = []
        input_u = np.zeros((2,n,np.max(m)))
        input_y = np.zeros((2,n,1))
        output = np.zeros((2,n,1))
        for i in range(n):
            #y = np.array([1])
            func = self.make_one_function() # Make new u(x) function
            y = np.random.uniform(0, 1, 1) # Output location
            for j in range(m.shape[0]):
                x = np.linspace(0, 1, m[j]) # Input Location
                p1, p2, p3 = self.sample_one_function(func,x, y)
                input_u[j,i,:] = np.pad(p1,(0, int(np.max(m)-m[j])))
                input_y[j,i,:] = p2
                output[j,i,:] = p3
                
        for j in range(m.shape[0]):
            input_u_list.append(input_u[j,:,0:m[j]])
            input_y_list.append(input_y[j])
            output_list.append(output[j])
            
        return input_u_list, input_y_list, output_list
    
    def sample_functions_test(self, n, m, y_dim):
        input_u_list = []
        input_y_list = []
        output_list = []
        input_u = np.zeros((2,n,np.max(m)))
        input_y = np.zeros((2, n, y_dim))
        output = np.zeros((2, n, y_dim))
        for i in range(n):
            #y = np.array([1])
            func = self.make_one_function() # Make new u(x) function
            y = np.linspace(0, 1, y_dim) # Output location
            for j in range(m.shape[0]):
                x = np.linspace(0, 1, m[j]) # Input Location
                p1, p2, p3 = self.sample_one_function(func,x, y)
                input_u[j,i,:] = np.pad(p1,(0, int(np.max(m)-m[j])))
                input_y[j,i,:] = p2
                output[j,i,:] = p3
                
        for j in range(m.shape[0]):
            input_u_list.append(input_u[j,:,0:m[j]])
            input_y_list.append(input_y[j])
            output_list.append(output[j])
            
        return input_u_list, input_y_list, output_list
    
Cheb = Cheb_poly(5, 1) # Chebyshev polynomial degree, absolute coefficent bound value i.e. [-1,1]
sensors = np.array([10, 100])
input_u, input_y, output = Cheb.sample_functions(10000, sensors) # Number of u(x) functions, Number of sensors   
input_u_test, input_y_test, output_test = Cheb.sample_functions(10000, sensors) # Number of u(x) functions, Number of sensors 

# DeepOpNet

```
import torch
import torch.nn as nn

# Base Neural Network class
class Net(nn.Module):
    def __init__(self, layers, act=nn.ReLU()):
        super(Net, self).__init__()
        self.act = act
        self.fc = nn.ModuleList()
        for i in range(len(layers) - 1):
            self.fc.append(nn.Linear(layers[i], layers[i+1]))
    def forward(self, x):
        for i in range(len(self.fc) - 1):
            x = self.fc[i](x)
            x = self.act(x)
        x = self.fc[-1](x)
        return x
    
branch = Net([sensors[0],50,50])
trunk = Net([1,50,50])

torch_branch_params = []
for i in range(len(branch.fc)):
    W = branch.fc[i].weight.data.numpy()
    b = branch.fc[i].bias.data.numpy()
    
    torch_branch_params.append((W.T, b))
    
torch_trunk_params = []
for i in range(len(trunk.fc)):
    W = trunk.fc[i].weight.data.numpy()
    b = trunk.fc[i].bias.data.numpy()
    
    torch_trunk_params.append((W.T, b))
```

## InitialisingParams

In [3]:
def init_params(layers:list, seed:int) -> list:
    initialiser = jax.nn.initializers.glorot_normal()
    key = jax.random.PRNGKey(seed)
    
    params=[]
    for i in range(len(layers)-1):
        key, subkey = jax.random.split(key)
        W = initialiser(key=subkey, shape=(layers[i],layers[i+1]), dtype=jnp.float32)
        b = jnp.zeros(shape=(layers[i+1],), dtype=jnp.float32)
        params.append(dict(Weights=W, Biases=b))
        
    return params

# def init_params(torch_params:list) -> list:
#     params=[]
#     for i in range(len(torch_params)):
#         W = torch_params[i][0]
#         b = torch_params[i][1]
#         params.append(dict(Weights=W, Biases=b))
        
#     return params

In [4]:
# branchParams = init_params(torch_branch_params)
branchParams = init_params([sensors[0],50,50], 0)

# Verifying shapes
jax.tree_map(lambda x:x.shape, branchParams)

[{'Biases': (50,), 'Weights': (10, 50)},
 {'Biases': (50,), 'Weights': (50, 50)}]

In [5]:
# trunkParams = init_params(torch_trunk_params)
trunkParams = init_params([1,50,50], 1)

# Verifying shapes
jax.tree_map(lambda x:x.shape, trunkParams)

[{'Biases': (50,), 'Weights': (1, 50)}, {'Biases': (50,), 'Weights': (50, 50)}]

## Forward

### BranchNet

In [6]:
@jax.jit
def branch_forward(branchParams:list, inputU:np.ndarray) -> np.ndarray:
    hidden = branchParams[:-1]
    last = branchParams[-1]
    x = inputU
    for layer in hidden:
        z = jnp.dot(x,layer["Weights"]) + layer["Biases"]
        x = jax.nn.relu(z)
        
    z = jnp.dot(x,last["Weights"]) + last["Biases"]
    return z

### TrunkNet

In [7]:
@jax.jit
def trunk_forward(trunkParams:list, inputY:np.ndarray) -> np.ndarray:
    hidden = trunkParams[:-1]
    last = trunkParams[-1]
    x = inputY
    for layer in hidden:
        z = jnp.dot(x,layer["Weights"]) + layer["Biases"]
        x = jax.nn.relu(z)
        
    z = jnp.dot(x,last["Weights"]) + last["Biases"]
    return z

### DeepOpNet

In [8]:
@jax.jit
def deep_op_net(branchParams:list, trunkParams:list, 
                inputU:np.ndarray, inputY:np.ndarray) -> np.ndarray:
    branchOut = branch_forward(branchParams, inputU)
    trunkOut = trunk_forward(trunkParams, inputY)
    
    return jnp.sum(branchOut*trunkOut, axis=-1, keepdims=True)

## LossFunction

In [9]:
@jax.jit
def loss(branchParams:list, trunkParams:list, 
         inputU:np.ndarray, inputY:np.ndarray, 
         output:np.ndarray) -> np.ndarray:
    pred = deep_op_net(branchParams, trunkParams, inputU, inputY)
    return jnp.mean((output - pred)**2)

## BackwardPass

In [10]:
@jax.jit
def update(branchParams:list, trunkParams:list, 
           inputU:np.ndarray, inputY:np.ndarray, 
           output:np.ndarray, lr:float):

    d_branchParams, d_trunkParams = jax.grad(loss, argnums=(0,1))(branchParams, trunkParams, 
                                                                   inputU, inputY, 
                                                                   output)
    return (jax.tree_map(lambda p,g:p-lr*g, branchParams, d_branchParams), 
            jax.tree_map(lambda p,g:p-lr*g, trunkParams, d_trunkParams))

# Evaluating

## Initial Loss

In [11]:
mse = loss(branchParams, trunkParams, input_u[0], input_y[0], output[0])
mse

DeviceArray(0.21452382, dtype=float32)

## Performing 1000 Updates

In [12]:
for i in range(1000):
    branchParams, trunkParams = update(branchParams, trunkParams, 
                                       input_u[0], input_y[0], 
                                       output[0], 0.001)

## Loss after updates

In [13]:
loss(branchParams, trunkParams, input_u[0], input_y[0], output[0])

DeviceArray(0.02939601, dtype=float32)