# core

> Load session data from MATLAB `.mat` files into pandas and xarray structures.

In [None]:
#| default_exp core

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
def foo():
    """Placeholder with no behavior; takes no arguments and returns ``None``."""

In [None]:
#| hide
# Presumably regenerates the package module(s) from the notebook's `#| export` cells — TODO confirm.
import nbdev; nbdev.nbdev_export()

In [None]:
import datetime
import os
from pathlib import Path

import numpy as np
import pandas as pd
import scipy
import xarray as xr
from datatree import DataTree


def extinction_learning(datatree: DataTree) -> xr.DataArray:
    """Produce extinction learning results (placeholder implementation).

    Args:
        datatree: Session data to analyse. Currently unused; the body is a stub.

    Returns:
        A 1-D ``xr.DataArray`` of length 3 along the ``"Trial"`` dimension,
        filled with NaN until the real analysis is implemented.
    """
    # np.full instead of np.ndarray(...): the latter returns uninitialized
    # memory, so the placeholder values would be arbitrary and nondeterministic.
    return xr.DataArray(data=np.full(3, np.nan), dims=("Trial",))



In [None]:

def date_from_filename(filename: str) -> datetime.datetime:
    """Parse the session start time encoded in a ``.mat`` file name.

    The file stem is expected to end in ``<day>_<month>_<year>_<hour>_<minute>``
    (``%d_%m_%Y_%H_%M``), preceded by an experiment label, e.g.
    ``Exp1_05_03_2021_14_30.mat``. Only the last five underscore-separated
    fields of the stem are parsed. The label may therefore have any length and
    may contain underscores, and ``filename`` may include directories.

    Args:
        filename: The file name (or path) to parse.

    Returns:
        The parsed time as a timezone-aware ``datetime`` tagged with UTC.
        No conversion is performed.
        NOTE(review): the file name carries no zone; confirm UTC is correct.

    Raises:
        ValueError: If the stem does not end in a ``%d_%m_%Y_%H_%M`` timestamp.
    """
    stem = Path(filename).stem
    # Take the trailing five fields instead of a fixed character offset, so the
    # label length no longer matters.
    timestamp = "_".join(stem.rsplit("_", 5)[-5:])
    return datetime.datetime.strptime(timestamp, "%d_%m_%Y_%H_%M").replace(
        tzinfo=datetime.timezone.utc,
    )



In [None]:

def session_attrs(filepath: Path) -> dict:
    """Read the non-array fields of the ``c`` struct stored in a MATLAB file.

    Args:
        filepath: Path to a ``.mat`` file containing a variable named ``c``.

    Returns:
        A mapping from each field name of ``c`` to its loaded value. Fields
        whose value (after ``squeeze_me=True``) is a NumPy array are dropped,
        so only scalar-like values such as numbers and strings remain.

    Raises:
        KeyError: If the file has no variable named ``c``.
    """
    # Import the submodule explicitly: a bare ``import scipy`` only exposes
    # ``scipy.io`` on SciPy releases that lazily load submodules.
    import scipy.io

    record = scipy.io.loadmat(filepath, squeeze_me=True)["c"]
    # ``c`` is expected to be a single struct. ``.item()`` gives its field values
    # in ``dtype.names`` order, so compute it once and pair it with the names.
    values = record.item()
    return {
        name: value
        for name, value in zip(record.dtype.names, values)
        if not isinstance(value, np.ndarray)
    }


def dt_index_from_data_dir(data_dir: Path) -> pd.DatetimeIndex:
    """Build a DatetimeIndex of session start times from ``*.mat`` file names.

    Each file stem is expected to end in ``<day>_<month>_<year>_<hour>_<minute>``,
    preceded by an experiment label (e.g. ``Exp1_05_03_2021_14_30.mat``).
    Only the last five underscore-separated fields are parsed, so the label
    itself may contain underscores.

    Parameters
    ----------
    data_dir : Path
        Directory containing the session ``.mat`` files (not searched recursively).

    Returns
    -------
    pd.DatetimeIndex
        Timezone-naive start times named ``"session_start_time"``, in the order
        yielded by ``data_dir.glob("*.mat")``.
    """
    date_fields = ["day", "month", "year", "hour", "minute"]
    # rsplit from the right keeps the timestamp fields intact even when the
    # experiment label contains underscores.
    rows = [
        path.stem.rsplit("_", len(date_fields))[-len(date_fields):]
        for path in data_dir.glob("*.mat")
    ]
    # pd.to_datetime assembles datetimes from a frame with these column names,
    # coercing the string fields to numbers.
    return pd.Index(
        pd.to_datetime(pd.DataFrame.from_records(rows, columns=date_fields)),
        name="session_start_time",
    )



In [None]:
#| export
def sessions(data_dir: str) -> xr.Dataset:
    """Load every session in a data directory as an ``xr.Dataset``.

    One record is read per ``*.mat`` file (via ``session_attrs``). The records
    are indexed by session start time (via ``dt_index_from_data_dir``) and the
    resulting DataFrame is converted with ``DataFrame.to_xarray``.

    Args:
        data_dir: Directory of ``.mat`` files, resolved relative to
            ``$XDG_DATA_HOME``. If that variable is unset or empty,
            ``~/.local/share`` is used, per the XDG Base Directory spec.

    Returns:
        A Dataset indexed by ``session_start_time`` with one variable per
        session field.
    """
    # NOTE(review): this cell is `#| export`ed, but the import cell and the
    # helper cells it calls are not, so the generated module may lack `xr`,
    # `pd`, `session_attrs` and `dt_index_from_data_dir`. Confirm and export them.
    xdg_data_home = os.environ.get("XDG_DATA_HOME")
    data_home = Path(xdg_data_home) if xdg_data_home else Path.home() / ".local" / "share"
    session_dir = data_home / data_dir
    # NOTE(review): the records and the index come from two separate glob()
    # passes and are paired by position. This relies on both passes listing
    # the directory in the same order.
    records = [session_attrs(filepath) for filepath in session_dir.glob("*.mat")]
    return pd.DataFrame.from_records(
        records,
        index=dt_index_from_data_dir(session_dir),
    ).to_xarray()


def session(mat_file: Path) -> xr.DataArray:
    """Load one session's non-array fields as an ``xr.DataArray``.

    Args:
        mat_file: Path to a session ``.mat`` file (read with ``session_attrs``).

    Returns:
        A 0-d placeholder ``DataArray`` whose ``attrs`` hold the session fields.
    """
    # DataArray.from_dict expects xarray's own serialisation schema
    # ({"data", "dims", "coords", "attrs", ...}) and raises ValueError when
    # "data" is missing, which arbitrary session field names do not provide.
    # The fields are therefore carried as attrs instead.
    # NOTE(review): confirm this is the intended representation.
    return xr.DataArray(attrs=session_attrs(mat_file))