In [1]:
import pickle
import numpy as np
import sys
import copy
import os
import matplotlib.pyplot as plt
import glob
import torch
from model import MixtureOfExperts


In [2]:
# Rebuild the MixtureOfExperts architecture and restore its saved weights.

model_name = f"model_experiment_5_1101443/checkpoint_savetime_1101443_batchsize_16384_numexperts_64.pt"
device = 'cpu'

# The expert count is the last "_"-separated token of the file name with its
# three-character ".pt" extension sliced off ("64.pt" -> 64).
checkpoint_suffix = model_name.rsplit('_', 1)[-1]
num_experts = int(checkpoint_suffix[:-3])
print(num_experts)

# Sizes handed to the MixtureOfExperts constructor.
num_vars = 6
gate_features = 128
expert_features = 32
out_features = 3

model = MixtureOfExperts(num_vars, gate_features, expert_features, out_features, num_experts)
state_dict = torch.load(model_name, map_location=torch.device(device))
model.load_state_dict(state_dict)
# Cast to double precision, place on the target device, then switch to
# inference mode; the last expression displays the module structure.
model.to(torch.double).to(device)
model.eval()

64


MixtureOfExperts(
  (gate): SoftmaxGate(
    (linear_gate): LinearHN(
      (hn): MLP(
        (sequential): Sequential(
          (0): Linear(in_features=6, out_features=128, bias=True)
          (1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (2): SiLU()
          (3): Dropout(p=0.5, inplace=False)
          (4): Linear(in_features=128, out_features=128, bias=True)
          (5): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (6): SiLU()
          (7): Dropout(p=0.5, inplace=False)
          (8): Linear(in_features=128, out_features=128, bias=True)
          (9): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (10): SiLU()
          (11): Dropout(p=0.5, inplace=False)
          (12): Linear(in_features=128, out_features=448, bias=True)
        )
      )
      (linear): Linear(in_features=6, out_features=64, bias=True)
    )
    (softmax): Softmax(dim=-1)
  )
  (experts): ModuleList(
    (0-63): 64 x LinearHN(
      (hn): MLP(
    

In [3]:
# ## loading dataset 
# # This step requires the HT9 dataset. To remove this dependence, we pickle only the necessary information.

# from loaders import filtered_train_test_dataset, load_simulations, load_data

# input_names = ['vmJ2', 'temperature', 'evm', 'rhoc', 'rhow', 'flux']
# output_names = ['evm', 'rhoc', 'rhow']

# # LOADING THE DATASET AND DATA TRANSFORMS
# train_dataset, test_dataset = filtered_train_test_dataset()
# data_transform = train_dataset.data_transform
# xtr_full, ytr_full = load_data('train')

# simdata = load_simulations(mode='valid')

# # pickling some run specific information -- needed for the model

# sim_id = 0
# init_data = {'vmJ2' : simdata['sim'][sim_id]['xinit'][0],
#              'temperature' :simdata['sim'][sim_id]['xinit'][1],
#              'evm' : simdata['sim'][sim_id]['xinit'][2],
#              'rhoc' : simdata['sim'][sim_id]['xinit'][3],
#              'rhow' : simdata['sim'][sim_id]['xinit'][4],
#              'flux' : simdata['sim'][sim_id]['xinit'][5],
#              }

# with open(f"init_data.pickle","wb") as fhandle:
#     pickle.dump(init_data, fhandle)

# # pickling the data transform

# with open(f"transform.pickle","wb") as fhandle:
#     pickle.dump(data_transform, fhandle)
    

In [4]:
# Restore the pickled data transform and the run-specific initial data that
# the surrogate wrapper needs, so the HT9 dataset itself is not required.
# NOTE: pickle.load can execute arbitrary code while unpickling -- only load
# these files if they come from a trusted source.
with open("transform.pickle", "rb") as fhandle:
    data_transform = pickle.load(fhandle)

with open("init_data.pickle", "rb") as fhandle:
    init_data = pickle.load(fhandle)

In [5]:
## surrogate eval and jacobian wrapper, for pointwise evaluations

class ModelWrapper:
    """Pointwise evaluation and Jacobian of the surrogate in untransformed units.

    The wrapped model takes six inputs (vmJ2, temperature, three state
    variables, flux). vmJ2, temperature and flux are held fixed at the values
    stored in ``init_data``; only the three state variables are supplied by
    the caller.

    Parameters
    ----------
    model : callable torch module mapping transformed inputs to transformed
        outputs.
    data_transform : mapping with keys ``'input'`` and ``'output'``; each value
        must expose ``length`` plus the transform / derivative methods used
        below (``transform``, ``forward_derivative``, ``inverse_transform``,
        ``inverse_derivative``).
    init_data : mapping providing at least ``'vmJ2'``, ``'temperature'`` and
        ``'flux'``.
    """

    def __init__(self, model, data_transform, init_data):
        self.data_transform = data_transform
        self.model = model
        self.init_data = init_data

        self.ninputs = self.data_transform['input'].length
        self.noutputs = self.data_transform['output'].length

    def _extend_state(self, x):
        """Return the (1, 6) model input row built from the three state values in ``x``."""
        return np.array([[self.init_data['vmJ2'],
                          self.init_data['temperature'],
                          x[0], x[1], x[2],
                          self.init_data['flux']]])

    def eval(self, x: np.ndarray):
        """Evaluate the surrogate at a single state point.

        ``x`` is a 1-D array whose first three entries are the state variables.
        Returns the inverse-transformed model output flattened to 1-D.
        """
        x_extended = self._extend_state(x)
        z = torch.tensor(self.data_transform['input'].transform(x_extended), dtype=torch.double)
        # eval surrogates
        with torch.no_grad():
            outs = self.model(z)
        return self.data_transform['output'].inverse_transform(outs.numpy()).reshape((-1,))

    def eval_jacobian(self, x: np.ndarray):
        '''
        Jacobian of ``eval`` with respect to the three state variables.

        pytorch jacobian modified from
        https://gist.github.com/sbarratt/37356c46ad1350d4c30aefbd488a4faa

        Returns a (noutputs, 3) array. The chain rule is applied elementwise
        (Jout * Jmod * Jin), which assumes the input and output transforms act
        independently on each variable -- TODO confirm against data_transforms.py.
        '''
        x_extended = self._extend_state(x)
        Jin = self.data_transform['input'].forward_derivative(x_extended)
        z = torch.tensor(self.data_transform['input'].transform(x_extended), dtype=torch.double).squeeze()
        # One copy of the input per output: seeding the backward pass with the
        # identity then yields d y[i] / d z in row i.
        z = z.repeat(self.noutputs, 1)
        z.requires_grad_(True)
        y = self.model(z)
        # autograd.grad differentiates w.r.t. z only, so repeated calls (e.g.
        # from an ODE solver) do not accumulate gradients on the model parameters.
        (grad_z,) = torch.autograd.grad(
            y, z, grad_outputs=torch.eye(self.noutputs, dtype=y.dtype)
        )
        Jmod = grad_z.numpy()
        # Use the transform stored on the instance, not a notebook-level global.
        Jout = self.data_transform['output'].inverse_derivative(y[0].detach().numpy()[None, :])
        J = Jout[:, None] * (Jmod * Jin)  # (3,6) jacobian for all 6 inputs
        return J[:, 2:5]  # (3,3) jacobian for just the state variables

    def eval_func_wrapper(self, t, x):
        '''
        wrapper for scipy.integrate.solve_ivp; ``t`` is accepted but unused
        '''
        return self.eval(x)

    def eval_jacobian_wrapper(self, t, x):
        '''
        wrapper for scipy.integrate.solve_ivp; ``t`` is accepted but unused
        '''
        return self.eval_jacobian(x)
    

In [6]:
# Comparing finite difference to formulated Jacobian
# from scipy.optimize import check_grad

# sim_id = 0
# init_data = {'vmJ2' : simdata['sim'][sim_id]['xinit'][0],
#              'temperature' :simdata['sim'][sim_id]['xinit'][1],
#              'evm' : simdata['sim'][sim_id]['xinit'][2],
#              'rhoc' : simdata['sim'][sim_id]['xinit'][3],
#              'rhow' : simdata['sim'][sim_id]['xinit'][4],
#              'flux' : simdata['sim'][sim_id]['xinit'][5],
#              } #used just to define a model
# f = Func(model, data_transform, init_data)

# x0 = xtr_full[1000, 2:5]

# check_grad(f.eval, f.eval_jacobian, x0)

In [7]:
from scipy.optimize import approx_fprime

# Compare the analytic Jacobian of the wrapped surrogate with a
# forward finite-difference approximation at one example point.

# sim_id = 0
# init_data = {'vmJ2' : simdata['sim'][sim_id]['xinit'][0],
#              'temperature' :simdata['sim'][sim_id]['xinit'][1],
#              'evm' : simdata['sim'][sim_id]['xinit'][2],
#              'rhoc' : simdata['sim'][sim_id]['xinit'][3],
#              'rhow' : simdata['sim'][sim_id]['xinit'][4],
#              'flux' : simdata['sim'][sim_id]['xinit'][5],
#              } #used just to define a model
f = ModelWrapper(model, data_transform, init_data)

#x0 = xtr_full[50000, 2:5] #requires HT9 dataset 
x0 = np.array([2.32967174e-02, 1.90846970e+12, 6.66795851e+12]) #just an example point
#x0 = 100*np.random.rand(3) #this data point is unlikely to be in the relevant region of the input space, but we can still perform evaluations as a test.

J = f.eval_jacobian(x0) # based on formulas implemented above and in data_transforms.py

# The components of x0 span ~14 orders of magnitude, so a single absolute step
# cannot suit all of them. An absolute step such as 1e-30 is below double
# precision resolution here (x0 + 1e-30 == x0); recent SciPy versions silently
# replace such a step with a relative one. Make that per-component step
# explicit so the result does not depend on the fallback.
fd_step = np.sqrt(np.finfo(float).eps) * np.maximum(1.0, np.abs(x0))
Jfd = approx_fprime(x0, f.eval, epsilon=fd_step)

# The analytic and finite-difference Jacobians should agree to within
# forward-difference truncation error.
np.testing.assert_allclose(J, Jfd, rtol=1e-4)

In [8]:
# Analytic Jacobian returned by f.eval_jacobian(x0).
J

array([[-3.80411722e+00,  1.17431627e-13,  1.38752595e-13],
       [ 2.80443323e+13, -5.50367257e-01, -9.22529403e-01],
       [ 1.68246314e+14, -2.07063543e+00, -3.91520040e+00]])

In [9]:
# Finite-difference Jacobian from approx_fprime, for comparison with J.
Jfd

array([[-3.80411587e+00,  1.17431617e-13,  1.38752645e-13],
       [ 2.80443218e+13, -5.50367405e-01, -9.22529742e-01],
       [ 1.68246211e+14, -2.07063525e+00, -3.91520167e+00]])

In [10]:
# Surrogate output at x0, after the inverse output transform.
f.eval(x0)

array([ 9.08092910e-02, -7.28390357e+11, -2.73717052e+12])