In [None]:
import os

# --- Environment configuration ---------------------------------------------
# Set before JAX is imported so the values are picked up when it initialises.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# Force JAX onto the CPU. JAX_PLATFORMS is the current variable name;
# JAX_PLATFORM_NAME is kept for older JAX versions that only read the legacy name.
# NOTE(review): with JAX pinned to CPU, the two GPU settings above may only
# matter for other GPU libraries used later in the notebook (e.g. the Torch part).
os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["JAX_PLATFORM_NAME"] = "cpu"

import pathlib
import glob
from copy import deepcopy

import tqdm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import jax
import jax.numpy as jnp
import equinox as eqx

In [None]:
from mc2.utils.data_inspection import (
    get_available_material_names,
    get_file_overview,
    load_and_process_single_from_full_file_overview
)
from mc2.data_management import FrequencySet, MaterialSet, DataSet

In [None]:
# Load the preprocessed 3C90 material set from the processed-data directory.
# NOTE(review): the file is a .pickle — if load_from_file unpickles it, only load trusted files.
processed_data_dir = pathlib.Path("../../data/processed")
material_set = MaterialSet.load_from_file(processed_data_dir.joinpath("3C90.pickle"))

In [None]:
# Split the material set 70 / 15 / 15 into train / val / test with a fixed seed (12).
split_kwargs = dict(train_frac=0.7, val_frac=0.15, test_frac=0.15, seed=12)
(
    train_material_set,
    val_material_set,
    test_material_set,
) = material_set.split_into_train_val_test(**split_kwargs)

In [None]:
# Sequence counts per frequency set, for the full material set and for each split.
# The loop variable is `fs` (not `set`) so the builtin `set` is not shadowed.
frequencies_total = [(fs.frequency, len(fs.T)) for fs in material_set.frequency_sets]
frequencies_train = [(fs.frequency, len(fs.T)) for fs in train_material_set.frequency_sets]
frequencies_val = [(fs.frequency, len(fs.T)) for fs in val_material_set.frequency_sets]
frequencies_test = [(fs.frequency, len(fs.T)) for fs in test_material_set.frequency_sets]


def _counts_aligned_to(reference_frequencies, frequency_counts):
    """Return the count for each reference frequency, looked up in ``(frequency, count)`` pairs.

    Matching is by frequency value instead of list position, so a split whose
    frequency sets are missing or ordered differently still lines up with the
    total. Frequencies absent from ``frequency_counts`` get a count of 0.
    Assumes each frequency occurs at most once per material set.
    """
    # float() gives a hashable key even if frequencies are numpy/JAX scalars.
    count_by_freq = {float(freq): count for freq, count in frequency_counts}
    return np.array([count_by_freq.get(float(freq), 0) for freq in reference_frequencies])


frequencies = [f for f, _ in frequencies_total]
counts_total = np.array([c for _, c in frequencies_total])
counts_train = _counts_aligned_to(frequencies, frequencies_train)
counts_val = _counts_aligned_to(frequencies, frequencies_val)
counts_test = _counts_aligned_to(frequencies, frequencies_test)

# One categorical bar slot per frequency.
x = np.arange(len(frequencies))

fig, ax = plt.subplots()

# Stacked train, val, test
ax.bar(x, counts_train, label='Train')
ax.bar(x, counts_val, bottom=counts_train, label='Val')
ax.bar(x, counts_test, bottom=counts_train + counts_val, label='Test')

# Total as a hatched outline; if the splits partition the data it coincides
# with the top of the stack.
ax.bar(
    x, counts_total,
    color='none',
    edgecolor='gray',
    hatch='//',
    label='Total'
)

# Formatting
ax.set_xticks(x)
ax.set_xticklabels([f"{freq:.0f} Hz" for freq in frequencies], rotation=45, ha='right')
ax.set_ylabel("Number of Sequences")
ax.set_xlabel("Frequency")
ax.set_title("Sequence Count per Frequency")
ax.legend()
ax.grid(True)
fig.tight_layout()
plt.show()

In [None]:
# Stacked histogram of the temperatures `T` of entry 1 of each material set,
# broken down by split. The bin edges come from the full set, so all split
# histograms share identical bins.
# NOTE(review): this overwrites the counts_* names from the previous cell.
subset_idx = 1

counts_total, bin_edges = jnp.histogram(material_set[subset_idx].T)
bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
bin_width = bin_edges[1] - bin_edges[0]

# Per-split histograms on the shared bins.
counts_train, _ = jnp.histogram(train_material_set[subset_idx].T, bins=bin_edges)
counts_val, _ = jnp.histogram(val_material_set[subset_idx].T, bins=bin_edges)
counts_test, _ = jnp.histogram(test_material_set[subset_idx].T, bins=bin_edges)

# Stack train -> val -> test, then overlay the full-set histogram as a hatched outline.
plt.bar(bin_centers, counts_train, width=bin_width, label='Train')
plt.bar(bin_centers, counts_val, width=bin_width, bottom=counts_train, label='Val')
plt.bar(bin_centers, counts_test, width=bin_width, bottom=counts_train + counts_val, label='Test')
plt.bar(
    bin_centers,
    counts_total,
    width=bin_width,
    color='none',
    edgecolor='gray',
    hatch='//',
    label='Total',
)

plt.xlabel("Temperature")
plt.ylabel("Count")
plt.title("Stacked Histogram of Temperatures")
plt.legend()
plt.grid(True)
plt.show()

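## Torch

Load the 3C90 data through the pandas-based loader and split it into train / validation / test sets.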

In [None]:
from mc2.data_management import (
    AVAILABLE_MATERIALS,
    load_data_into_pandas_df,
    book_keeping,
    get_train_val_test_pandas_dicts,
)

# Switch JAX to 64-bit precision from here on. This only affects arrays created
# after this point; the jnp.histogram results computed in earlier cells predate it.
jax.config.update("jax_enable_x64", True)

In [None]:
# Load the 3C90 data via the pandas loader and split it 70 / 15 / 15 (seed 13).
# NOTE(review): the MaterialSet split earlier in the notebook uses seed=12 —
# confirm the two splits are not meant to match.
data_d = load_data_into_pandas_df(material="3C90")
training_data, validation_data, test_data = get_train_val_test_pandas_dicts(
    data_dict=data_d,
    train_frac=0.7,
    val_frac=0.15,
    test_frac=0.15,
    seed=13,
)

In [None]:
# Shape of the "3C90_1_B" entry in the training split.
np.shape(training_data["3C90_1_B"].to_numpy())

In [None]:
# Shape of the "3C90_1_B" entry in the validation split.
np.shape(validation_data["3C90_1_B"].to_numpy())

In [None]:
# Shape of the "3C90_1_B" entry in the test split.
np.shape(test_data["3C90_1_B"].to_numpy())