In [None]:
# Automatically reload edited Python modules (e.g. the mc2 package) before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import os
# Expose only GPU 0 to CUDA and disable XLA's up-front GPU memory preallocation.
# These variables are read when JAX initialises its backends, so they must be
# set before jax is imported.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

# NOTE(review): pathlib, glob, deepcopy, tqdm, pd, jnp and eqx are not used by
# any visible cell of this notebook — confirm before removing.
import pathlib
import glob
from copy import deepcopy

import tqdm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import jax
# Force JAX onto the CPU backend.
# NOTE(review): with this set, the CUDA / XLA GPU settings above appear to have
# no effect on computation here — confirm CPU-only evaluation is intended.
jax.config.update("jax_platform_name", "cpu")
# 64-bit floats stay disabled (JAX's default is 32-bit); uncomment to enable.
# jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import equinox as eqx 

In [None]:
from mc2.utils.final_data_evaluation import (
    FINAL_MATERIALS, TestSet, ResultSet, predict_test_scenarios, validate_result_set, visualize_result_set
)
from mc2.utils.model_evaluation import reconstruct_model_from_exp_id, get_exp_ids

---

## Check out available models:

In [None]:
# Display the materials evaluated in this notebook (one test set is loaded per entry below).
FINAL_MATERIALS

In [None]:
# For every final material, show which trained experiments (of any model type)
# are available to choose from in the next section.
for name in FINAL_MATERIALS:
    print(name)
    available_ids = get_exp_ids(material_name=name, model_type=None)
    print(available_ids)
    print()  # blank separator line between materials

## Choose and load models:

In [None]:
# The single experiment chosen per material (IDs taken from the listing above).
exp_ids = {
    "A": 'A_GRU_fe0f6b18-a096-41',
    "B": 'B_GRU_62500cee-b06f-48',
    "C": 'C_GRU_d01265a7-ca67-41',
    "D": 'D_GRU_fcb7f6b4-95c4-4c',
    "E": 'E_GRU_20182caa-07d2-41',
}

# Fail fast, before reconstructing any model, if a material has no chosen
# experiment: the inspection cells index result_sets by every FINAL_MATERIALS entry.
missing_materials = [m for m in FINAL_MATERIALS if m not in exp_ids]
assert not missing_materials, f"No exp_id chosen for material(s): {missing_materials}"

models = {material_name: reconstruct_model_from_exp_id(exp_id) for material_name, exp_id in exp_ids.items()}

## Run models on test data:

In [None]:
# Load the test set of every final material, keyed by material name.
test_data = dict(
    (name, TestSet.from_material_name(name)) for name in FINAL_MATERIALS
)

In [None]:
# Inspect the loaded test sets (one TestSet per material).
# NOTE(review): repr size depends on TestSet — consider a summary if output is large.
test_data

In [None]:
# Predict every material's test scenarios with the model chosen for it.
result_sets = predict_test_scenarios(models, test_data, exp_ids)

In [None]:
# Inspect the result sets returned by predict_test_scenarios.
result_sets

## Inspect results:

In [None]:
# Check each material's predictions against its test set; the assertion message
# names the mismatch so a wrong pairing is diagnosable without a debugger.
for material_name in FINAL_MATERIALS:
    result_set = result_sets[material_name]
    test_set = test_data[material_name]
    assert result_set.material_name == material_name, (
        f"result set stored under {material_name!r} reports material "
        f"{result_set.material_name!r}"
    )
    validate_result_set(result_set, test_set)

In [None]:
# Plot the predictions of every material, one figure batch per material.
for name, predictions in result_sets.items():
    print("Visualization for material:", name)
    visualize_result_set(predictions)
    plt.show()

## Store results:

In [None]:
from mc2.data_management import DATA_ROOT

In [None]:
# Store each material's predicted H as a comma-separated CSV file named after
# its experiment ID, under DATA_ROOT / "final_results".
save_path = DATA_ROOT / "final_results"
# Create the output directory if needed; writing into a missing directory
# would otherwise raise FileNotFoundError on a fresh checkout.
os.makedirs(save_path, exist_ok=True)

for result_set in result_sets.values():
    csv_path = save_path / f"{result_set.exp_id}_final_test_prediction.csv"
    # np.savetxt accepts a path-like and opens/closes the file itself.
    np.savetxt(csv_path, result_set.H, delimiter=",")