# Overview of all materials:

This notebook provides an overview of all materials and the performance of a model for each of them.
Note that all of these models are trained in a cross-validation setting (i.e., some of the training data is intentionally withheld to investigate the extrapolation performance of the models).

If you have questions, please:
- post them as an issue in https://github.com/upb-lea/magnet-challenge-2/issues
- or write an e-mail to hendrik.vater@uni-siegen.de

(In the following, all paths are given relative to the repository root.)

In [None]:
# optional setup
%load_ext autoreload
%autoreload 2

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # choose cuda-device
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # disable preallocation of memory

import jax
import jax.numpy as jnp
jax.config.update("jax_platform_name", "cpu")  # optionally run on cpu

In [None]:
from mc2.data_management import AVAILABLE_MATERIALS, MaterialSet, DataSet

from mc2.utils.data_plotting import plot_sequence_prediction, plot_hysteresis_prediction
from mc2.utils.model_evaluation import reconstruct_model_from_file, plot_model_frequency_sweep, evaluate_cross_validation
from mc2.utils.final_data_evaluation import FINAL_SCENARIOS_PER_MATERIAL

## Present materials:

The raw material data is to be stored at `data/raw` and (when all data is present) should look like this:

```text
└── data/raw
    ├── 3C90/
    ├── 3C94/
    ├── 3E6/
    ├── 3F4/
    ├── 77/
    ├── 78/
    ├── Material A/
    ├── Material B/
    ├── Material C/
    ├── Material D/
    ├── Material E/
    ├── N27/
    ├── N30/
    ├── N49/
    ├── N87/
    └── sort_raw_data.py
```

Then run `python data/raw/sort_raw_data.py` from the command line to finish preparing the raw data. The folder should then look like this:

```text
└── data/raw
    ├── 3C90/
    ├── 3C94/
    ├── 3E6/
    ├── 3F4/
    ├── 77/
    ├── 78/
    ├── A/
    ├── B/
    ├── C/
    ├── D/
    ├── E/
    ├── N27/
    ├── N30/
    ├── N49/
    ├── N87/
    └── sort_raw_data.py
```
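Before loading, it can help to confirm that the folder matches the layout above. The following is a minimal sketch and not part of the repository: the helper name `missing_material_folders` and the hard-coded `EXPECTED_FOLDERS` list are assumptions, with the folder names copied from the listing shown here.

```python
from pathlib import Path

# Folder names copied from the listing above (assumed to be complete).
EXPECTED_FOLDERS = [
    "3C90", "3C94", "3E6", "3F4", "77", "78",
    "A", "B", "C", "D", "E",
    "N27", "N30", "N49", "N87",
]


def missing_material_folders(raw_dir, expected=EXPECTED_FOLDERS):
    """Return the expected folder names that are not present in raw_dir."""
    raw_dir = Path(raw_dir)
    if not raw_dir.is_dir():
        # Nothing is present if the directory itself does not exist.
        return list(expected)
    present = {p.name for p in raw_dir.iterdir() if p.is_dir()}
    return [name for name in expected if name not in present]


# Path is relative to the repository root, as elsewhere in this notebook.
missing = missing_material_folders("data/raw")
print("All material folders found." if not missing else f"Missing: {missing}")
```

An empty result means every expected folder is in place; any listed name points to a material that still needs to be downloaded or sorted.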


Upon the first load of each file, it will be stored in the cache at `data/cache` in the form of `.parquet` files.
The first load will be substantially slower, as the data is loaded directly from `.csv`.
Subsequent loads of the material data will be much faster.

In [None]:
print("Number of materials:", len(AVAILABLE_MATERIALS))
print(AVAILABLE_MATERIALS)
print()

The whole data set is quite large: ~17 GB.

If it does not fit into your RAM or VRAM, you may use a subset of the materials at a time, 
e.g., `data_set = DataSet.from_material_names(AVAILABLE_MATERIALS[:5])` loads the first 5 materials in the list.
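If the materials are to be processed one group after another, a small chunking helper keeps the peak memory bounded. The helper below is a hypothetical sketch (`chunked` is not part of `mc2`), and `example_names` is a stand-in list; in the notebook, each group would be passed to `DataSet.from_material_names`.

```python
def chunked(names, size):
    """Yield consecutive groups of at most `size` names."""
    if size < 1:
        raise ValueError("size must be at least 1")
    for start in range(0, len(names), size):
        yield names[start:start + size]


# Stand-in list for illustration; in the notebook use AVAILABLE_MATERIALS.
example_names = ["3C90", "3C94", "3E6", "3F4", "77", "78", "A"]

for group in chunked(example_names, 3):
    # In the notebook: data_set = DataSet.from_material_names(group)
    print(group)
```

Only one group is held in memory at a time, provided the previous `data_set` is no longer referenced when the next one is loaded.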

In [None]:
data_set = DataSet.from_material_names(AVAILABLE_MATERIALS)

Take a moment to inspect the `data_set` object. It is quite expressive thanks to being an `equinox.Module` (much of the implementation builds on this package; see https://docs.kidger.site/equinox/).

- A `DataSet` consists of a collection of `MaterialSet`s (one for each material)
- A `MaterialSet` consists of a collection of `FrequencySet`s (one for each frequency)
- A `FrequencySet` holds the actual data in the form of `jax.Array`s

Check out the methods of each class. There are handy ways to filter each object.
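The nesting described above can be pictured with plain dataclasses. This is a simplified stand-in, not the `mc2` implementation: the `Toy*` class names and their fields are hypothetical, nested lists replace `jax.Array`s, and only the idea that a filter returns a new, smaller object is mirrored.

```python
from dataclasses import dataclass, replace
from typing import List


@dataclass(frozen=True)
class ToyFrequencySet:
    frequency: float          # in Hz
    B: List[List[float]]      # one row per sequence
    H: List[List[float]]


@dataclass(frozen=True)
class ToyMaterialSet:
    material_name: str
    frequency_sets: List[ToyFrequencySet]

    def filter_frequencies(self, frequencies):
        kept = [fs for fs in self.frequency_sets if fs.frequency in frequencies]
        return replace(self, frequency_sets=kept)


@dataclass(frozen=True)
class ToyDataSet:
    material_sets: List[ToyMaterialSet]

    def filter_materials(self, names):
        kept = [ms for ms in self.material_sets if ms.material_name in names]
        return replace(self, material_sets=kept)

    def filter_frequencies(self, frequencies):
        filtered = [ms.filter_frequencies(frequencies) for ms in self.material_sets]
        return replace(self, material_sets=filtered)


toy = ToyDataSet([
    ToyMaterialSet("A", [ToyFrequencySet(50_000, [[0.0, 0.1]], [[0.0, 5.0]]),
                         ToyFrequencySet(80_000, [[0.0, 0.2]], [[0.0, 9.0]])]),
    ToyMaterialSet("B", [ToyFrequencySet(80_000, [[0.0, 0.3]], [[0.0, 7.0]])]),
])

# Filters chain because each one returns a new object of the same type.
subset_toy = toy.filter_materials(["A"]).filter_frequencies([80_000])
print(subset_toy)
```

Because every filter leaves the original object untouched, the full `toy` set stays available for further, different selections.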

In [None]:
display(data_set)

In [None]:
data_set.at_material("3C90")  # only data for 3C90

In [None]:
data_set.filter_frequencies([50_000])  # only Frequency sets with f=50 kHz

In [None]:
print(data_set.material_names)
subset = data_set.filter_materials(["A", "B", "C", "D", "E"]).filter_frequencies([80_000, 800_000])
print(subset)

### Data visualization:

In [None]:
import matplotlib.pyplot as plt
from mc2.utils.data_plotting import plot_hysteresis, plot_single_sequence, plot_frequency_sweep

In [None]:
# HB curves
for material_set in data_set:
    print(f"Full sequences overlaid for Material: '{material_set.material_name}'")
    fig, axs = plt.subplots(1,1,figsize=(5, 5))
    for frequency_set in material_set: 
        for sequence_idx in range(frequency_set.B.shape[0]):
            B_values = frequency_set.B[sequence_idx, :]
            H_values = frequency_set.H[sequence_idx, :]
            axs.plot(H_values, B_values, alpha=.3)
    axs.set_title(material_set.material_name)
    axs.set_ylabel("B in Vs/m^2")
    axs.set_xlabel("H in A/m")
    axs.grid(True, alpha=0.3)
    fig.tight_layout()
    plt.show()

In [None]:
# simply filter the data set if you only want to consider a subset of the full material data,
# e.g., only plot materials 'A' and 'B' at 25 degrees Celsius
sub_set = data_set.filter_materials(["A", "B"]).filter_temperatures([25])

for material_set in sub_set:
    print(f"Full sequences overlaid for Material: '{material_set.material_name}'")
    fig, axs = plt.subplots(1,1,figsize=(5, 5))
    for frequency_set in material_set: 
        for sequence_idx in range(frequency_set.B.shape[0]):
            B_values = frequency_set.B[sequence_idx, :]
            H_values = frequency_set.H[sequence_idx, :]
            axs.plot(H_values, B_values, alpha=.3)
    axs.set_title(material_set.material_name)
    axs.set_ylabel("B in Vs/m^2")
    axs.set_xlabel("H in A/m")
    axs.grid(True, alpha=0.3)
    fig.tight_layout()
    plt.show()

In [None]:
for material_set in data_set:
    print(f"Exemplary trajectories for Material: '{material_set.material_name}'")
    plot_frequency_sweep(material_set, loader_key=jax.random.PRNGKey(0), sequence_length=500, batch_size=2)
    plt.show()

## Present models:

In [None]:
from mc2.utils.model_evaluation import get_exp_ids

In [None]:
all_models = get_exp_ids()
all_models

In [None]:
get_exp_ids(material_name="3C90", model_type="GRU8", exp_name="reduced-features-f32")

In [None]:
exp_ids = {
    '3C90': '3C90_GRU8_reduced-features-f32_b5ce7dc9_seed12',
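    # NOTE: this entry reuses the 3C90 experiment ID; if a 3C94-specific run exists,
    # look it up with get_exp_ids(material_name="3C94", ...) and substitute it here.
    '3C94': '3C90_GRU8_reduced-features-f32_b5ce7dc9_seed12',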
    '3E6': '3E6_GRU8_reduced-features-f32_7ff91a7c_seed12',
    '3F4': '3F4_GRU8_reduced-features-f32_a83212e4_seed12',
    '77': '77_GRU8_reduced-features-f32_2eb8cc0c_seed12',
    '78': '78_GRU8_reduced-features-f32_3406a9c8_seed12',
    'N27': 'N27_GRU8_reduced-features-f32_2a482429_seed12',
    'N30': 'N30_GRU8_reduced-features-f32_b3ec1c0f_seed12',
    'N49': 'N49_GRU8_reduced-features-f32_6f23a1f0_seed12',
    'N87': 'N87_GRU8_reduced-features-f32_3f598f03_seed12',
    "A": 'A_GRU8_reduced-features-f32_2a1473b6_seed12',
    "B": 'B_GRU8_reduced-features-f32_c785b2c3_seed12',
    "C": 'C_GRU8_reduced-features-f32_348e220c_seed12',
    "D": 'D_GRU8_reduced-features-f32_b6ac55b5_seed12',
    "E": 'E_GRU8_reduced-features-f32_e88a2583_seed12',
}
models = {material_name: reconstruct_model_from_file(exp_id) for material_name, exp_id in exp_ids.items()}

The models are also `equinox.Module`s. They mostly consist of three parts: the actual data-driven model; a normalizer that converts between the raw material data and data normalized to the range from -1 to 1 (which is easier for data-driven models to handle); and a featurization function, which manipulates or extends the input data to the data-driven model to make it easier to interpret. For instance, the first and second derivatives of the magnetic flux are particularly helpful in predicting the magnetic field.
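To make the normalizer and featurization ideas concrete, here is a minimal NumPy sketch on synthetic data. It is not the repository's code: the function names, the max-abs scaling, and the finite-difference features are assumptions chosen for illustration.

```python
import numpy as np


def normalize_max_abs(x, max_abs):
    """Scale x so that values within [-max_abs, max_abs] land in [-1, 1]."""
    return x / max_abs


def add_derivative_features(b, dt):
    """Stack B with finite-difference estimates of dB/dt and d^2B/dt^2."""
    db = np.gradient(b, dt)      # first derivative (central differences inside)
    d2b = np.gradient(db, dt)    # second derivative, obtained from the first
    return np.stack([b, db, d2b], axis=-1)


# Synthetic example: one period of a 50 kHz sine with 0.2 T amplitude.
f = 50_000.0
t = np.linspace(0.0, 1.0 / f, 200, endpoint=False)
b = 0.2 * np.sin(2 * np.pi * f * t)

b_norm = normalize_max_abs(b, np.max(np.abs(b)))
features = add_derivative_features(b_norm, t[1] - t[0])
print(features.shape)  # (200, 3)
```

Note that the derivative columns are not rescaled in this sketch; in practice each feature would typically get its own normalization so that all model inputs share a comparable range.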

In [None]:
models["A"]

In [None]:
models["A"].n_params  # not necessarily available for all model types (yet)

In [None]:
from mc2.utils.model_evaluation import plot_model_frequency_sweep

In [None]:
for material_name, model in models.items():
    print(f"Exemplary trajectories including model prediction for Material: '{material_name}' with model '{exp_ids[material_name]}'")
    material_set = data_set.at_material(material_name)
    plot_model_frequency_sweep(model, material_set, loader_key=jax.random.PRNGKey(129), past_size=500)
    plt.show()

In [None]:
for material_name, model in models.items():
    print(f"Long-term predictions for Material: '{material_name}' with model '{exp_ids[material_name]}'")
    # sample and predict full sequences for all frequencies: