In [1]:
import pickle
import numpy as np
import sys
import copy
import os
import matplotlib.pyplot as plt
import glob
import torch
from model import MixtureOfExperts

In [2]:
## loading pytorch model

# Checkpoint to load. The number of experts is encoded as the final
# "_<N>" field of the file name, e.g. "..._numexperts_64.pt" -> 64.
model_name = "model_experiment_5_1101443/checkpoint_savetime_1101443_batchsize_16384_numexperts_64.pt"
device = 'cpu'
# Parse from the basename stem so the extension length (".pt", ".pth", ...) does not matter.
num_experts = int(os.path.splitext(os.path.basename(model_name))[0].split('_')[-1])

# Architecture hyperparameters -- must match those used to train the checkpoint,
# otherwise load_state_dict will fail on shape mismatches.
num_vars = 6            # model input dimension (gate's first Linear has in_features=6)
gate_features = 128     # hidden width of the gate MLP
expert_features = 32
out_features = 3
model = MixtureOfExperts(num_vars, gate_features, expert_features, out_features, num_experts)

# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints
# (consider weights_only=True on torch versions that support it).
model.load_state_dict(torch.load(model_name, map_location=torch.device(device)))
model.to(torch.double)
model.to(device)
# eval() disables dropout; the returned module is displayed as the cell output.
model.eval()

MixtureOfExperts(
  (gate): SoftmaxGate(
    (linear_gate): LinearHN(
      (hn): MLP(
        (sequential): Sequential(
          (0): Linear(in_features=6, out_features=128, bias=True)
          (1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (2): SiLU()
          (3): Dropout(p=0.5, inplace=False)
          (4): Linear(in_features=128, out_features=128, bias=True)
          (5): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (6): SiLU()
          (7): Dropout(p=0.5, inplace=False)
          (8): Linear(in_features=128, out_features=128, bias=True)
          (9): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (10): SiLU()
          (11): Dropout(p=0.5, inplace=False)
          (12): Linear(in_features=128, out_features=448, bias=True)
        )
      )
      (linear): Linear(in_features=6, out_features=64, bias=True)
    )
    (softmax): Softmax(dim=-1)
  )
  (experts): ModuleList(
    (0-63): 64 x LinearHN(
      (hn): MLP(
    

In [3]:
# Load the fitted input/output data transforms and the initial data that are
# passed to ModelWrapper.
# NOTE(review): pickle.load can execute arbitrary code -- only load these files
# from a trusted source.
with open("transform.pickle", "rb") as fhandle:
    data_transform = pickle.load(fhandle)

with open("init_data.pickle", "rb") as fhandle:
    init_data = pickle.load(fhandle)

In [4]:
## surrogate eval and jacobian wrapper, for pointwise evaluations

class ModelWrapper:
    """Pointwise numpy front-end for a trained torch surrogate model.

    Inputs are mapped through ``data_transform['input']`` before the model is
    called. Outputs are mapped back through ``data_transform['output']``, so
    callers work in untransformed units.

    Parameters
    ----------
    model : torch.nn.Module
        Surrogate network, called on a batch of transformed inputs. Assumed to
        be in double precision and eval mode -- TODO confirm at call site.
    data_transform : dict
        Must have an ``'input'`` entry exposing ``length``, ``transform`` and
        ``forward_derivative``. Must have an ``'output'`` entry exposing
        ``length``, ``inverse_transform`` and ``inverse_derivative``.
    init_data
        Stored as ``self.init_data``; no method of this class reads it.
    """

    def __init__(self, model, data_transform, init_data):
        self.data_transform = data_transform
        self.model = model
        self.init_data = init_data

        self.ninputs = self.data_transform['input'].length
        self.noutputs = self.data_transform['output'].length

    def eval(self, x: np.ndarray) -> np.ndarray:
        """Evaluate the surrogate at a single point.

        Parameters
        ----------
        x : np.ndarray
            One input point, shape ``(n,)``.

        Returns
        -------
        np.ndarray
            The model output, passed through the output inverse transform and
            flattened to 1-D.
        """
        # Add a leading batch axis of size 1: (n,) -> (1, n).
        x_batch = np.array([x])
        z = torch.tensor(self.data_transform['input'].transform(x_batch), dtype=torch.double)
        with torch.no_grad():
            outs = self.model(z)
        return self.data_transform['output'].inverse_transform(outs.numpy()).reshape((-1,))

    def eval_jacobian(self, x: np.ndarray) -> np.ndarray:
        """Jacobian of the surrogate output w.r.t. the untransformed input.

        The model Jacobian is obtained in one reverse pass. The transformed
        input is replicated ``noutputs`` times and the identity matrix is
        back-propagated, so row ``i`` of the input gradient is d y_i / d z.
        This is adapted from
        https://gist.github.com/sbarratt/37356c46ad1350d4c30aefbd488a4faa.
        The result is then scaled elementwise by the input transform's
        ``forward_derivative`` and the output transform's
        ``inverse_derivative``.

        Parameters
        ----------
        x : np.ndarray
            One input point, shape ``(n,)``.

        Returns
        -------
        np.ndarray
            The full Jacobian with respect to every input, shape
            ``(noutputs, ninputs)`` (e.g. (3, 6) for this model).
        """
        x_batch = np.array([x])
        Jin = self.data_transform['input'].forward_derivative(x_batch)
        z = torch.tensor(self.data_transform['input'].transform(x_batch), dtype=torch.double).squeeze()
        z = z.repeat(self.noutputs, 1)
        z.requires_grad_(True)
        y = self.model(z)
        # autograd.grad differentiates w.r.t. z only. Unlike y.backward(), it
        # does not accumulate gradients into the model parameters' .grad on
        # every call.
        seed = torch.eye(self.noutputs, dtype=y.dtype, device=y.device)
        (dz,) = torch.autograd.grad(y, z, grad_outputs=seed)
        Jmod = dz.detach().numpy()
        Jout = self.data_transform['output'].inverse_derivative(y[0].detach().numpy()[None, :])
        # NOTE(review): assumes inverse_derivative returns shape (noutputs,),
        # so Jout[:, None] scales each row -- TODO confirm.
        J = Jout[:, None] * (Jmod * Jin)
        # NOTE(review): a previous comment described a (3, 3) Jacobian for
        # "just the state variables", but the full (noutputs, ninputs) matrix
        # is returned. solve_ivp expects a square Jacobian -- confirm whether
        # a column slice was intended.
        return J

    def eval_func_wrapper(self, t, x):
        """``fun(t, x)`` adapter for scipy.integrate.solve_ivp; ``t`` is ignored."""
        return self.eval(x)

    def eval_jacobian_wrapper(self, t, x):
        """``jac(t, x)`` adapter for scipy.integrate.solve_ivp; ``t`` is ignored."""
        return self.eval_jacobian(x)
    

In [5]:
# Wrap the loaded surrogate and its data transforms for pointwise numpy evaluation.
Phi = ModelWrapper(model=model, data_transform=data_transform, init_data=init_data)

In [7]:
# Spot-check: evaluate the surrogate at one sample input point.
set1 = np.array([0.119132960099, 600.0690046472204, 0.0, 4669.9511535360325, 4406641771830.477, 1.018454806316354e-09])
Phi.eval(set1)

array([3.97202342e-12, 8.20187736e+17, 7.71206023e+09])