In [None]:
# INSTALL OMNIXAS PACKAGE
# run this once in the beginning to install the package
#  no need to run second time
! pip install -v "git+https://github.com/AI-multimodal/OmniXAS.git"

In [None]:
import json
import os

import numpy as np
from matplotlib import pyplot as plt
from sklearn.multioutput import MultiOutputRegressor
from sklearn.svm import SVR

from omnixas.data import MLSplits
from omnixas.model.metrics import ModelMetrics
from omnixas.model.xasblock_regressor import XASBlockRegressor

In [None]:
# Root directory of the dataset; this notebook reads its "spectra" and "splits" sub-directories.
DATA_DIR = "dataset/omnixas_2"

def fetch_dataset_elements(data_dir):
    """Return the element labels for which a JSON data file exists in `data_dir`.

    The label is parsed from the file name: the text after the last
    underscore, up to the first following dot (a file named ``x_Cu.json``
    yields ``"Cu"``).

    Args:
        data_dir: Path of the directory to scan.

    Returns:
        Alphabetically sorted list of labels, one per file whose name ends in
        ``.json``. Empty if the directory contains no such file.

    Raises:
        OSError: Propagated from ``os.listdir`` when `data_dir` cannot be read.
    """
    # Match on the ".json" suffix only: a plain substring test would also pick
    # up unrelated files that merely contain "json" somewhere in their name.
    json_files = [name for name in os.listdir(data_dir) if name.endswith(".json")]
    labels = [name.split("_")[-1].split(".")[0] for name in json_files]
    # os.listdir returns entries in arbitrary order; sort so that positional
    # selections such as elements[0] are reproducible.
    return sorted(labels)

# Discover which elements have spectra files in the dataset.
elements = fetch_dataset_elements(f"{DATA_DIR}/spectra")
n_elements = len(elements)
print(f"Found data for {n_elements} elements")

In [None]:
# Element to model. Change this value to work on a different element;
# the available labels are listed in `elements` (e.g. `elements[0]`).
element = "Cu"

# Load the train/val/test split definition for the chosen element.
# The context manager makes sure the file handle is closed after reading.
with open(f"{DATA_DIR}/splits/split_{element}.json") as split_file:
    split_json = json.load(split_file)
split = MLSplits.parse_obj(split_json)

In [None]:
# XASBlock model: collect the hyper-parameters first, then build and train.
xasblock_settings = dict(
    directory=f"checkpoints/{element}",
    max_epochs=100,
    # Training stops if val_loss does not improve for 25 epochs.
    early_stopping_patience=25,
    # True deletes the save directory first; otherwise new files are added to it.
    overwrite_save_dir=True,
    input_dim=64,
    output_dim=200,
    hidden_dims=[200, 200],
    # Starting learning rate; it is optimized by the lr finder later.
    initial_lr=1e-2,
    batch_size=128,
)
model = XASBlockRegressor(**xasblock_settings)

# The full split object must be passed because it also contains the
# validation data that is used for logging during training.
model.fit(split)
# model.load()  # alternative: load a saved model from disk

In [None]:
# Use this cell to monitor training progress in TensorBoard.
# For guidance on interpreting train/val loss curves, see:
# https://machinelearningmastery.com/learning-curves-for-diagnosing-machine-learning-model-performance/
%load_ext tensorboard
# %reload_ext tensorboard  # use this instead to restart tensorboard

%tensorboard --logdir checkpoints/

In [None]:
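# SIMPLE MODELS
# Alternatively, use any simple scikit-learn model: https://scikit-learn.org/1.5/supervised_learning.html
# Uncomment the two lines below to replace `model` with, for example, a support-vector regressor:
# model = MultiOutputRegressor(SVR())
# model.fit(split.train.X, split.train.y)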

In [None]:
# Evaluate on the validation split: model outputs and the matching ground truth.
predictions = model.predict(split.val.X)
targets = split.val.y

In [None]:
# Plot the predictions for the validation set.
# Transposing draws one line per row of `predictions`
# (assumes a 2-D array with one row per spectrum -- TODO confirm).
fig, ax = plt.subplots()
ax.plot(predictions.T, alpha=0.5)
# Title and axis labels so the figure can be read on its own.
ax.set_title("Predictions on the validation set")
ax.set_xlabel("Output index")
ax.set_ylabel("Predicted value")
plt.show()

In [None]:
def get_eta(split, metrics):
    """Compare a model against a constant "training mean" baseline.

    The baseline predicts, for every validation sample, the mean of the
    training targets (taken along axis 0). Its `median_of_mse_per_spectra`
    is divided by that of the supplied `metrics`.

    Args:
        split: Object exposing `train.y` and `val.y` target arrays.
        metrics: Metrics of the model under evaluation; only its
            `median_of_mse_per_spectra` attribute is read.

    Returns:
        Baseline value divided by model value. If the attribute is an error
        measure (lower is better), a result above 1 means the model beats
        the baseline.
    """
    baseline_row = split.train.y.mean(axis=0)
    val_targets = split.val.y
    # Repeat the mean row once per validation sample so that the baseline
    # predictions have the same number of rows as the targets.
    n_val = val_targets.shape[0]
    baseline_predictions = np.tile(baseline_row, (n_val, 1))
    baseline_metrics = ModelMetrics(
        targets=val_targets,
        predictions=baseline_predictions,
    )
    baseline_error = baseline_metrics.median_of_mse_per_spectra
    model_error = metrics.median_of_mse_per_spectra
    return baseline_error / model_error
# Score the validation predictions, then relate them to the mean baseline.
metrics = ModelMetrics(targets=targets, predictions=predictions)
eta = get_eta(split, metrics)
print(
    f"MSE: {metrics.mse}",
    f"eta: {eta}",
)

In [None]:
# Distribution of the per-spectrum MSE on a natural-log scale.
log_mse = np.log(metrics.mse_per_spectra)
fig, ax = plt.subplots()
ax.hist(log_mse, bins=20, alpha=0.5, density=True)
ax.set_xlabel("log(MSE)")
ax.set_ylabel("Density")
ax.set_title(f"Element: {element} \n eta: {round(eta, 2)}")
plt.show()

In [None]:
# Compare target and prediction for each entry of `metrics.deciles`,
# one panel per entry (each entry is indexed as [0] = target, [1] = prediction).
deciles = metrics.deciles
fig, axs = plt.subplots(9, 1, figsize=(6, 20))
for decile_number, (pair, ax) in enumerate(zip(deciles, axs), start=1):
    target_curve, predicted_curve = pair[0], pair[1]
    ax.plot(target_curve, label="target")
    ax.plot(predicted_curve, label="prediction")
    # Shade the gap between the two curves to make the error visible.
    ax.fill_between(
        range(len(target_curve)),
        target_curve,
        predicted_curve,
        alpha=0.5,
        interpolate=True,
    )
    ax.legend()
    ax.set_title(f"Decile {decile_number}")
fig.tight_layout()
# plt.show() rather than fig.show(): Figure.show() warns on non-GUI backends
# such as the notebook's inline backend, and plt.show() matches the other
# plotting cells in this notebook.
plt.show()