In [None]:
import os
import sys
sys.path.append("../../gaia-surrogate")
import torch
from gaia.models import TrainingModel
from gaia.training import get_checkpoint_file
# Checkpoint directory of the run to export. The default is the original run;
# set GAIA_MODEL_DIR to export a different run without editing the notebook.
model_dir = os.environ.get(
    "GAIA_MODEL_DIR",
    "/proj/gaia-climate/team/kirill/gaia-surrogate/lightning_logs/version_1/",
)
# Load onto CPU so the traced modules below are built and saved as CPU graphs.
model = TrainingModel.load_from_checkpoint(
    get_checkpoint_file(model_dir),
    map_location="cpu",
)


In [None]:
# (Intentionally empty.) A stray `%debug` post-mortem magic was removed here:
# interactive debugging does not belong in a notebook meant to Restart & Run All.

In [None]:
# Trace the bare network (no normalization or variable reordering) as a first export check.
# NOTE: the export cell further down writes the same file names and overwrites these files.
_ = model.eval()
example = torch.rand(10, 164)  # presumably the network's input width — TODO derive from hparams
with torch.no_grad():
    out = model.model(example)
    traced_script_module = torch.jit.trace(model.model, example)
    # Sanity check: the traced graph must reproduce the eager-mode result.
    torch.testing.assert_close(traced_script_module(example), out)
traced_script_module.save("traced_model.pt")
with open("traced_model_arch_printed.txt", "w") as arch_file:
    arch_file.write(str(traced_script_module))

In [None]:
from collections import OrderedDict
import numpy as np

In [None]:
# Random permutations of the variable names, used to exercise ModelForExport's
# reordering. Seeded so that re-running the notebook exports the same layout.
SHUFFLE_SEED = 0
rng = np.random.default_rng(SHUFFLE_SEED)

inputs = list(model.hparams.input_index.keys())
outputs = list(model.hparams.output_index.keys())

rng.shuffle(inputs)
rng.shuffle(outputs)



In [None]:
# Show the shuffled variable orderings the export wrapper will be built with.
{"inputs": inputs, "outputs": outputs}

In [None]:
class ModelForExport(torch.nn.Module):
    """Export wrapper: reorder inputs -> input_normalize -> network -> output_normalize -> reorder outputs.

    Lets callers supply input variables in ``input_order`` and receive output
    variables in ``output_order``, instead of the column layouts recorded in the
    training model's ``hparams.input_index`` / ``hparams.output_index``.

    Args:
        training_model: object exposing ``input_normalize``, ``output_normalize``,
            ``model`` and ``hparams.input_index`` / ``hparams.output_index``
            (variable name -> ``(start, end)`` column span along dim 1).
        input_order: input variable names in the caller's column order; must
            name every key of ``hparams.input_index`` exactly once.
        output_order: output variable names in the caller's column order; every
            name must be a key of ``hparams.output_index``.

    Raises:
        ValueError: if ``input_order`` is not a permutation of the input keys,
            or ``output_order`` contains a name that is not an output key.
    """

    def __init__(self, training_model, input_order, output_order):
        super().__init__()
        self.input_normalize = training_model.input_normalize
        self.output_normalize = training_model.output_normalize
        self.model = training_model.model

        input_index = training_model.hparams.input_index
        output_index = training_model.hparams.output_index
        # Materialize once: both orders are traversed more than once below.
        input_order = list(input_order)
        output_order = list(output_order)

        # A duplicated input name silently misaligns columns, and a missing or
        # unknown one fails with a bare KeyError, so reject both up front.
        if len(set(input_order)) != len(input_order) or set(input_order) != set(input_index.keys()):
            raise ValueError(
                "input_order must name every input variable exactly once; "
                f"got {input_order}, expected a permutation of {list(input_index.keys())}"
            )
        unknown_outputs = [k for k in output_order if k not in output_index]
        if unknown_outputs:
            raise ValueError(f"unknown output variables in output_order: {unknown_outputs}")

        # Column span of each variable in the caller's layout: variables are
        # packed back-to-back in input_order, each keeping its training width.
        caller_span = OrderedDict()
        offset = 0
        for name in input_order:
            start, end = input_index[name]
            width = end - start
            caller_span[name] = (offset, offset + width)
            offset += width

        # Gathering caller columns in training-key order rebuilds the layout the
        # network was trained on.
        self.register_buffer(
            "input_order",
            torch.cat([torch.arange(*caller_span[k]) for k in input_index.keys()]),
        )
        # Gathering network output columns in output_order yields the caller's layout.
        self.register_buffer(
            "output_order",
            torch.cat([torch.arange(*output_index[k]) for k in output_order]),
        )

    def forward(self, x):
        """Map caller-ordered inputs to caller-ordered outputs.

        Args:
            x: tensor whose dim 1 holds the input columns in ``input_order`` layout.

        Returns:
            Tensor whose dim 1 holds the output columns in ``output_order`` layout.
        """
        x = x[:, self.input_order, ...]
        x = self.input_normalize(x)

        y = self.model(x)
        # normalize=False presumably inverts the output normalization — confirm in gaia.
        y = self.output_normalize(y, normalize=False)
        return y[:, self.output_order, ...]
    
    
model_for_export = ModelForExport(model, inputs, outputs).eval()
# The input_order buffer holds one gather index per input column, so its length
# is the width the wrapper expects (replaces the hard-coded 164).
example = torch.rand(10, model_for_export.input_order.numel())
with torch.no_grad():
    out = model_for_export(example)
    traced_script_module = torch.jit.trace(model_for_export, example)
    # Sanity check: the traced graph must reproduce the eager-mode result.
    torch.testing.assert_close(traced_script_module(example), out)
traced_script_module.save("traced_model.pt")

with open("traced_model_arch_printed.txt", "w") as arch_file:
    arch_file.write(str(traced_script_module))