In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

import pathlib
import glob
from copy import deepcopy

import tqdm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import jax
import jax.numpy as jnp
import equinox as eqx

In [None]:
from mc2.utils.data_inspection import (
    get_available_material_names,
    get_file_overview,
    load_and_process_single_from_full_file_overview
)
from mc2.data_management import FrequencySet, MaterialSet, DataSet

## Load data from pickle

In [None]:
dataset = DataSet.load_from_file(pathlib.Path("../../data/processed") / "ten_mat_data.pickle")

## Materials excluded from the analysis for now, because their data is incomplete:
# N49: the 50 kHz and 80 kHz sets are missing, and the 320 kHz set has no data at 25 degrees
EXCLUDED_MATERIALS = ["N49"]

print(dataset.material_names)
print(len(dataset.material_names))

# Build a new list rather than mutating `dataset.material_names` in place (so no copy is needed);
# excluded names that are not present in the file are simply ignored instead of raising ValueError.
available_materials = [name for name in dataset.material_names if name not in EXCLUDED_MATERIALS]
print(available_materials)
print(len(available_materials))

dataset = dataset.filter_materials(available_materials)
assert dataset.material_names == available_materials

# Structure of the data set

- `DataSet`: holds a list of `MaterialSet`s, which can be accessed either by index or via `at_material(material_name: str)`
- `MaterialSet`: holds a list of `FrequencySet`s, which can be accessed either by index or via `at_frequency(frequency: float)`
- `FrequencySet`: holds the raw data as arrays of shape `(n_sequences, sequence_length)` for `B` and `H`, and `(n_sequences,)` for `T`

In [None]:
# Materials can be accessed by position or by name; both must yield the same MaterialSet.
assert dataset[0] == dataset.at_material("78")
assert dataset[1] == dataset.at_material("3C90")
assert isinstance(dataset.at_material("78"), MaterialSet)

# Likewise, frequency sets can be accessed by position or by frequency value (in Hz).
material_set_for_78 = dataset[0]
assert material_set_for_78[0] == material_set_for_78.at_frequency(50_000)
assert material_set_for_78[0] == material_set_for_78.at_frequency(material_set_for_78.frequencies[0])
assert isinstance(dataset[0][0], FrequencySet)

# Chained index access and chained name/value access reach the same FrequencySet.
frequency_set_for_78_at_50kHz = dataset[0][0]
assert frequency_set_for_78_at_50kHz == dataset.at_material("78").at_frequency(50_000)

Filtering examples (three main functions):

- `{DataSet}.filter_materials(list[str] | str) -> {DataSet}`
- `{DataSet/MaterialSet}.filter_frequencies(list[float] | jnp.Array | float) -> {DataSet/MaterialSet}`
- `{DataSet/MaterialSet/FrequencySet}.filter_temperatures(list[float] | jnp.Array | float) -> {DataSet/MaterialSet/FrequencySet}` 

In [None]:
# After filtering, each MaterialSet must contain exactly the requested frequencies, in order.
# `jnp.array_equal` also compares shapes, so a missing or extra frequency fails the assertion
# cleanly instead of raising an "ambiguous truth value" error or silently broadcasting.
# NOTE: presumes every remaining material has 50 kHz and 80 kHz data (N49, which lacks them, is excluded above).
dataset_at_80kHz = dataset.filter_frequencies([80_000])
for m_set in dataset_at_80kHz:
    assert jnp.array_equal(m_set.frequencies, jnp.array([80_000]))

dataset_at_50_and_80kHz = dataset.filter_frequencies([50_000, 80_000])
for m_set in dataset_at_50_and_80kHz:
    assert jnp.array_equal(m_set.frequencies, jnp.array([50_000, 80_000]))

In [None]:
# After filtering to 25 degrees, every FrequencySet must contain only that temperature.
# `jnp.array_equal` fails with a plain AssertionError if other temperatures survive,
# instead of raising an "ambiguous truth value" error on a multi-element comparison.
dataset_at_25degrees = dataset.filter_temperatures([25])
for m_set in dataset_at_25degrees:
    for f_set in m_set:
        assert jnp.array_equal(jnp.unique(f_set.T), jnp.array([25]))

# EDA

### TODOs
- Build an exploratory data analysis (EDA):
  - extend it with further plots
  - generalize it to all materials through subplots
- Use the EDA from MC1 as a reference: https://github.com/upb-lea/hardcore-magnet-challenge/blob/main/notebooks/wk-1.1-eda.ipynb


### Questions
- Is the end of a given sequence the starting point of the next one?
- ...

In [None]:
# Number of datapoints (in millions) per frequency, one panel per material.
# The grid is derived from the number of materials (5 panels per row) instead of
# being hard-coded to 2x5, so it adapts when materials are added or excluded.
n_materials = len(dataset.material_names)
n_cols = 5
n_rows = -(-n_materials // n_cols)  # ceiling division: enough rows for all materials

fig, axs = plt.subplots(n_rows, n_cols, figsize=(20, 3 * n_rows), sharey=True, sharex=True, squeeze=False)
fig.suptitle("Datapoints per frequency")

for ax, material_name in zip(axs.flat, dataset.material_names):
    material_set = dataset.at_material(material_name)
    # B has shape (n_sequences, sequence_length), so the number of datapoints is their product
    full_set_lengths = [
        n_sequences * sequence_length
        for n_sequences, sequence_length in (freq_set.B.shape for freq_set in material_set)
    ]
    ax.plot(material_set.frequencies / 1e3, np.array(full_set_lengths) / 1e6)
    ax.set_title(material_name)
    ax.grid()

for idx, ax in enumerate(axs.flat):
    if idx >= n_materials:
        ax.set_visible(False)  # unused panel in the last row
    elif idx + n_cols >= n_materials:
        # bottom-most visible panel of its column: sharex hides its tick labels by default
        ax.tick_params(labelbottom=True)
        ax.set_xlabel("f in kHz")

for ax in axs[:, 0]:
    ax.set_ylabel("# of datapoints in M")

fig.tight_layout()

plt.show()

In [None]:
# Unique temperature levels per frequency, for every material (not just the one
# left over in `material_set` from an earlier loop).
unique_temperatures = {}
for material_name in dataset.material_names:
    material_set = dataset.at_material(material_name)
    unique_temperatures[material_name] = {
        int(frequency): jnp.unique(freq_set.T)
        for frequency, freq_set in zip(material_set.frequencies, material_set)
    }

print("unique temperatures per frequency")
display(unique_temperatures)

# One panel per material (5 per row); each panel shows the temperature of every
# sequence, with one line per frequency set.
n_materials = len(dataset.material_names)
n_cols = 5
n_rows = -(-n_materials // n_cols)  # ceiling division: enough rows for all materials

fig, axs = plt.subplots(n_rows, n_cols, figsize=(20, 3 * n_rows), sharey=True, squeeze=False)
fig.suptitle("Temperatures per sequence")

for ax, material_name in zip(axs.flat, dataset.material_names):
    for freq_set in dataset.at_material(material_name):
        ax.plot(freq_set.T, label=str(int(freq_set.frequency / 1e3)) + " kHz")
    ax.set_title(material_name)
    ax.legend(fontsize="small")
    ax.grid()

for idx, ax in enumerate(axs.flat):
    if idx >= n_materials:
        ax.set_visible(False)  # unused panel in the last row
    elif idx + n_cols >= n_materials:
        ax.set_xlabel("sequence index")  # bottom-most visible panel of its column

for ax in axs[:, 0]:
    ax.set_ylabel("T in °C")

fig.tight_layout()
plt.show()

For every frequency, the temperature levels are always 25, 50, and 70 °C (confirmed at least for "3C90").

In [None]:
# Number of datapoints (in millions) per temperature level, one panel per material.
# T has one value per sequence (shape (n_sequences,)), so each sequence contributes
# `sequence_length` datapoints to its temperature level.
n_materials = len(dataset.material_names)
n_cols = 5
n_rows = -(-n_materials // n_cols)  # ceiling division: enough rows for all materials

fig, axs = plt.subplots(n_rows, n_cols, figsize=(20, 3 * n_rows), sharey=True, sharex=True, squeeze=False)
fig.suptitle("Datapoints per temperature")

for ax, material_name in zip(axs.flat, dataset.material_names):
    material_set = dataset.at_material(material_name)
    temperatures = np.unique(np.concatenate([np.asarray(freq_set.T) for freq_set in material_set]))

    n_data_points = np.zeros(temperatures.size, dtype=np.int64)
    for freq_set in material_set:
        _, sequence_length = freq_set.B.shape
        # `temperatures` is sorted and contains every value of freq_set.T, so
        # searchsorted yields the exact temperature-level index of each sequence
        level_idx = np.searchsorted(temperatures, np.asarray(freq_set.T))
        n_data_points += sequence_length * np.bincount(level_idx, minlength=temperatures.size)

    ax.plot(temperatures, n_data_points / 1e6, marker="o")
    ax.set_title(material_name)
    ax.grid()

for idx, ax in enumerate(axs.flat):
    if idx >= n_materials:
        ax.set_visible(False)  # unused panel in the last row
    elif idx + n_cols >= n_materials:
        # bottom-most visible panel of its column: sharex hides its tick labels by default
        ax.tick_params(labelbottom=True)
        ax.set_xlabel("T in °C")

for ax in axs[:, 0]:
    ax.set_ylabel("# of datapoints in M")

fig.tight_layout()
plt.show()