In [None]:
! pip install -v "git+https://github.com/AI-multimodal/OmniXAS.git"

In [None]:
import json
import os

import numpy as np
from matplotlib import pyplot as plt
from sklearn.multioutput import MultiOutputRegressor
from sklearn.svm import SVR

from omnixas.data import MLSplits
from omnixas.model.trained_model import ModelMetrics


In [None]:
def fetch_dataset_elements(data_dir="dataset/omnixas_2/features/m3gnet/"):
    """Return the sorted element symbols for which ``.json`` feature files exist.

    Each ``*.json`` file directly inside ``data_dir`` contributes one entry.
    That entry is the text after the last underscore and before the first dot
    of the file name; file names are presumably of the form
    ``<prefix>_<ELEMENT>.json`` (TODO confirm against the dataset).

    Args:
        data_dir: Directory to scan (non-recursively).

    Returns:
        list[str]: Element symbols in sorted order. Sorting makes
        ``elements[0]`` deterministic, because ``os.listdir`` returns entries
        in arbitrary, filesystem-dependent order.

    Raises:
        FileNotFoundError: if ``data_dir`` does not exist (from ``os.listdir``).
    """
    elements = [
        file_name.split("_")[-1].split(".")[0]
        for file_name in os.listdir(data_dir)
        # Match on the extension, not a substring: names such as
        # "notes_json.txt" or "x_Fe.jsonl" must not be picked up.
        if file_name.endswith(".json")
    ]
    return sorted(elements)

In [None]:
# Load ML split data
data_dir = "dataset/omnixas_2"
# Derive the feature directory from `data_dir` so that relocating the dataset
# only requires changing `data_dir` above.
elements = fetch_dataset_elements(os.path.join(data_dir, "features", "m3gnet"))



In [None]:
if not elements:
    raise RuntimeError(
        f"No element feature files found; check the dataset under data_dir={data_dir!r}"
    )
element = elements[0]  # select different elements here
# Use a context manager so the file handle is closed right after parsing.
with open(f"{data_dir}/splits/split_{element}.json") as split_file:
    split_json = json.load(split_file)
# NOTE(review): `parse_obj` is the pydantic v1 API; under pydantic v2 the
# equivalent is `MLSplits.model_validate` — confirm which version OmniXAS uses.
split = MLSplits.parse_obj(split_json)

In [None]:
# MultiOutputRegressor fits one copy of the base estimator per target column.
# Any estimator with the scikit-learn fit/predict interface can be swapped in.
base_estimator = SVR()
model = MultiOutputRegressor(base_estimator)
model.fit(split.train.X, split.train.y)

In [None]:
# Predict on the validation split; keep its ground truth alongside for scoring.
predictions = model.predict(split.val.X)
targets = split.val.y

In [None]:
# One line per validation sample: `predictions` has one row per sample, and
# transposing makes each sample a column, which matplotlib draws as a line.
fig, ax = plt.subplots()
ax.plot(predictions.T, alpha=0.5)
ax.set_xlabel("Output index")
ax.set_ylabel("Predicted value")
ax.set_title(f"Validation predictions: {element}");

In [None]:
def get_eta(split, metrics):
    """Ratio of a train-mean baseline's error to the model's error on validation.

    The baseline predicts the per-column mean of ``split.train.y`` for every
    validation sample. Both the baseline and the model are scored by
    ``median_of_mse_per_spectra``. A result above 1 means the model has a
    lower median per-spectrum MSE than the baseline.

    Args:
        split: ML split object exposing ``train.y`` and ``val.y`` arrays.
        metrics: ``ModelMetrics`` computed for the model on ``split.val``.

    Returns:
        baseline median-of-MSE divided by the model's median-of-MSE.
    """
    column_means = split.train.y.mean(axis=0)
    val_targets = split.val.y
    n_val = val_targets.shape[0]
    baseline_metrics = ModelMetrics(
        targets=val_targets,
        predictions=np.tile(column_means, (n_val, 1)),  # same mean row for every sample
    )
    baseline_error = baseline_metrics.median_of_mse_per_spectra
    model_error = metrics.median_of_mse_per_spectra
    return baseline_error / model_error

In [None]:
# Score the model on the validation split and compare it to the train-mean baseline.
metrics = ModelMetrics(targets=targets, predictions=predictions)
eta = get_eta(split, metrics)
print(
    f"MSE: {metrics.mse}",
    f"eta: {eta}",
)

In [None]:
# Distribution of per-spectrum validation errors on a natural-log scale.
# NOTE(review): a spectrum with MSE exactly 0 would give -inf here and break
# the automatic histogram range.
fig, ax = plt.subplots()
ax.hist(np.log(metrics.mse_per_spectra), bins=20, alpha=0.5, density=True)
ax.set_xlabel("log(MSE)")
ax.set_ylabel("Density")
ax.set_title(f"Element: {element} \n eta: {round(eta, 2)}");

In [None]:
# Plot target vs. predicted spectrum for each entry of `metrics.deciles`.
# Each entry `d` is indexed as d[0] (target) and d[1] (prediction).
deciles = metrics.deciles
# Size the grid from the data instead of hard-coding 9 panels, so no entry is
# silently dropped by `zip` and no empty axes are left over.
n_panels = len(deciles)
fig, axs = plt.subplots(
    n_panels, 1, figsize=(6, 20 * n_panels / 9), squeeze=False
)
for i, (d, ax) in enumerate(zip(deciles, axs[:, 0]), start=1):
    target, prediction = d[0], d[1]
    ax.plot(target, label="target")
    ax.plot(prediction, label="prediction")
    # Shade the gap between target and prediction to visualize the error.
    ax.fill_between(
        range(len(target)),
        target,
        prediction,
        alpha=0.5,
        interpolate=True,
    )
    ax.legend()
    ax.set_title(f"Decile {i}")
fig.tight_layout()
# `plt.show()` works with Jupyter's inline backend; `fig.show()` targets GUI backends.
plt.show()