In [None]:
from glob import glob
import pandas as pd
import xarray as xr

import contextlib
import os
import pathlib
import re
import shutil
import time
from datetime import datetime
from typing import Union

import joblib
import numpy as np
import pandas as pd
import xarray as xr
from joblib import Parallel, delayed
from tqdm import tqdm

from pandas.testing import assert_frame_equal

from pismragis.processing import convert_netcdf_to_dataframe, ncfile2dataframe
from pismragis.analysis import prepare_df, sensitivity_analysis


In [None]:
# Regenerate the reference data for the multi-file conversion: convert every
# matching scalar time-series file (without derived variables) and store the
# combined frame as both Parquet and CSV.
infiles = glob("../tests/data/ts_gris_g1200m_v2023_RAGIS_id_*_1980-1-1_2020-1-1.nc")
df = convert_netcdf_to_dataframe(infiles, add_vars=False)
df.to_parquet("../tests/data/test_scalar.parquet")
df.to_csv("../tests/data/test_scalar.csv")

In [None]:
# Same conversion with yearly resampling; only rows whose time lies between
# 1980-01-01 and 1982-01-01 are kept before writing the reference files.
# Reuses `infiles` from the previous cell.
df = convert_netcdf_to_dataframe(infiles, add_vars=False, resample="yearly")
df = df[df.time.between("1980-01-01", "1982-01-01")].reset_index(drop=True)
df.to_parquet("../tests/data/test_scalar_YM.parquet")
df.to_csv("../tests/data/test_scalar_YM.csv")


In [None]:
# Single ensemble member (id 0) used by the single-file conversion cells below.
infile = "../tests/data/ts_gris_g1200m_v2023_RAGIS_id_0_1980-1-1_2020-1-1.nc"

In [None]:
# Regenerate the single-file reference data (no derived variables) as
# Parquet and CSV.
df = ncfile2dataframe(infile, add_vars=False)
df.to_parquet("../tests/data/test_scalar_file.parquet")
df.to_csv("../tests/data/test_scalar_file.csv")

In [None]:
# Regenerate the single-file reference data including the derived variables
# (add_vars=True) as Parquet and CSV.
df = ncfile2dataframe(infile, add_vars=True)
df.to_parquet("../tests/data/test_scalar_file_add_vars.parquet")
df.to_csv("../tests/data/test_scalar_file_add_vars.csv")

In [None]:
# Worker count handed to sensitivity_analysis in the cells below
# (rebound to 4 in a later cell).
n_jobs = 1

In [None]:
    # Regenerate the sensitivity-analysis reference: load the yearly test frame,
    # drop the "Year" and "resolution_m" columns, order rows by time and id,
    # run sensitivity_analysis and store the result as Parquet.
    # NOTE(review): `seed` is not assigned in any visible cell; define it before
    # running this cell on a fresh kernel.
    X_df = (
        pd.read_parquet("../tests/data/test_scalar_YM.parquet")
        .drop(columns=["Year", "resolution_m"])
        .sort_values(by=["time", "id"])
    )
    ensemble_file = "../tests/data/gris_ragis_ocean_simple_lhs_50_w_posterior.csv"
    Y_df = sensitivity_analysis(
        X_df, ensemble_file=ensemble_file, n_jobs=n_jobs, seed=seed
    )
    Y_df.to_parquet("../tests/data/test_sensitivity.parquet")


In [None]:
    # Recompute the sensitivity indices from the `X_df` / `ensemble_file` of the
    # previous cell and compare them with the stored reference (rtol=1e-04).
    # NOTE(review): relies on `seed`, which no visible cell defines.
    Y_true_df = pd.read_parquet("../tests/data/test_sensitivity.parquet")
    Y_df = sensitivity_analysis(
        X_df, ensemble_file=ensemble_file, n_jobs=n_jobs, seed=seed
    )
    assert_frame_equal(Y_df, Y_true_df, rtol=1e-04)


In [None]:
    # Build the same input through xarray instead of ncfile2dataframe: open all
    # member files stacked along a new "id" dimension, average the selected
    # variables over "1AS" bins for 1980-01-01..1982-01-01, flatten to a frame
    # and compare the resulting sensitivity indices with the stored reference.
    # NOTE(review): `nc_add_id` and `ens_vars_dict` are not defined or imported
    # in any visible cell; `seed` is undefined as well.
    # NOTE(review): ncfile2dataframe in this notebook resamples with "1YS";
    # "1AS" looks like the older alias for the same frequency - confirm against
    # the installed pandas/xarray versions.
    ds = xr.open_mfdataset(
        "../tests/data/ts_gris_g1200m_v2023_RAGIS_id_*.nc",
        combine="nested",
        concat_dim="id",
        preprocess=nc_add_id,
        parallel=True,
    )
    ens = (
        ds.sel(time=slice("1980-01-01", "1982-01-01"))[ens_vars_dict.keys()]
        .resample(time="1AS")
        .mean()
    )
    # Rename the variables via ens_vars_dict, drop rows with missing values and
    # order them by time and id.
    X_xr_df = (
        ens.to_dataframe()
        .rename(columns=ens_vars_dict)
        .reset_index()
        .dropna()
        .sort_values(by=["time", "id"])
        .reset_index(drop=True)
    )

    # Looser tolerances (atol=1e-01) than in the parquet-based check above.
    Y_xr_df = sensitivity_analysis(
        X_xr_df, ensemble_file=ensemble_file, n_jobs=n_jobs, seed=seed
    )
    assert_frame_equal(Y_xr_df, Y_true_df, atol=1e-01, rtol=1e-06)


In [None]:
    # Single ensemble member (id 0) and its stored reference frames.
    infile = "../tests/data/ts_gris_g1200m_v2023_RAGIS_id_0_1980-1-1_2020-1-1.nc"

    df_parquet_true = pd.read_parquet("../tests/data/test_scalar_file.parquet")
    # `infer_datetime_format` is deprecated since pandas 2.0 (consistent format
    # inference is now the default), so `parse_dates` alone is used here.
    df_csv_true = pd.read_csv(
        "../tests/data/test_scalar_file.csv",
        index_col=0,
        parse_dates=["time"],
    )

    # Fresh conversion to be compared against the references in the next cell.
    df = ncfile2dataframe(infile, add_vars=False)


In [None]:
    # The fresh single-file conversion must match the stored Parquet reference.
    assert_frame_equal(df, df_parquet_true)


In [None]:
# Display the Parquet reference frame for visual inspection.
df_parquet_true

In [None]:
# NOTE(review): `Y_true` is not defined in any visible cell, and this path has
# no "../" prefix, unlike the other cells that use "../tests/data/..." -
# confirm the intended working directory before running.
Y_true.to_parquet("tests/data/test_sensitivity.parquet")

In [None]:
# Open all member files stacked along a new "id" dimension (no preprocess hook).
# NOTE(review): the path has no "../" prefix, unlike the other cells - confirm
# the intended working directory.
ds = xr.open_mfdataset("tests/data/ts_gris_g1200m_v2023_RAGIS_id_*.nc", combine="nested", concat_dim="id", parallel=True)

In [None]:
# Average the two selected variables over "1AS" bins for the 1980-01-01 to
# 1983-01-01 slice and flatten the result to a frame without missing rows.
ens_vars = ["grounding_line_flux", "limnsw"]
ens = ds.sel(time=slice("1980-01-01" ,"1983-01-01"))[ens_vars].resample(time="1AS").mean()
ens_df = ens.to_dataframe().reset_index().dropna()

In [None]:
# Display the list of selected variable names.
ens_vars

In [None]:
# Rebinds the worker count used by the cells above (was 1).
n_jobs=4

# Run the sensitivity analysis on the xarray-derived frame and keep only the
# `sens_vars` columns.
# NOTE(review): `sens_vars` is not defined in any visible cell.
Y_true_xr = sensitivity_analysis(ens_df, ensemble_file=ensemble_file, n_jobs=n_jobs, calc_variables=ens_vars)[sens_vars]

In [None]:
    # Yearly-resampled single-file conversion checked against both stored
    # reference formats.
    # NOTE(review): no visible cell writes test_scalar_file_YM.parquet/.csv;
    # they are presumably produced elsewhere - confirm they exist.
    df_parquet_true = pd.read_parquet("../tests/data/test_scalar_file_YM.parquet")
    # `infer_datetime_format` is deprecated since pandas 2.0 (consistent format
    # inference is now the default), so `parse_dates` alone is used here.
    df_csv_true = pd.read_csv(
        "../tests/data/test_scalar_file_YM.csv",
        index_col=0,
        parse_dates=["time"],
    )

    df = ncfile2dataframe(infile, resample="yearly", add_vars=False)
    assert_frame_equal(df, df_parquet_true)
    assert_frame_equal(df, df_csv_true)


In [None]:
def ncfile2dataframe(
    infile: Union[str, pathlib.Path],
    resample: Union[str, None] = None,
    add_vars: bool = True,
    norm_year: Union[None, float] = None,
    verbose: bool = False,
) -> pd.DataFrame:
    """
    Convert a netCDF file of scalar time series to a pandas.DataFrame.

    The ensemble member id and the grid resolution are parsed from the file
    name, which must contain ``id_<ID>_`` and ``gris_g<DX>m``.

    Parameters
    ----------
    infile : str or pathlib.Path
        Path to an existing netCDF file.
    resample : str or None
        "monthly" averages over "1MS" bins, "yearly" over "1YS" bins; any
        other value leaves the data as stored.
    add_vars : bool
        If True, pass the frame through ``add_vars_to_dataframe``.
    norm_year : float or None
        Decimal year whose row is subtracted from "limnsw (kg)" and, when
        ``add_vars`` is True, from "Cumulative ice sheet mass change (Gt)"
        and "SLE (cm)", so that these columns are zero at that year.
        Falsy values disable the normalisation.
    verbose : bool
        If True, print the file name before opening it.

    Returns
    -------
    pandas.DataFrame
        One row per time step holding the former time index, "id", one column
        per data variable (named "<var> (<units>)" when the variable has a
        ``units`` attribute), "Year" and "resolution_m".

    Raises
    ------
    AssertionError
        If the file does not exist or its name lacks the id / resolution tags.
    IndexError
        If ``norm_year`` does not occur among the decimal years of the file.

    Notes
    -----
    NOTE(review): ``to_decimal_year`` is neither defined nor imported in the
    visible part of this notebook; it has to be in scope before calling.
    """

    if isinstance(infile, pathlib.Path):
        assert infile.exists()
    else:
        assert os.path.isfile(infile)
    if verbose:
        print(f"Opening {infile}")

    # str() works for both str and Path, so one code path serves the regexes.
    infile_str = str(infile)

    m_id_re = re.search("id_(.+?)_", infile_str)
    assert m_id_re is not None
    m_id: Union[str, int]
    try:
        m_id = int(m_id_re.group(1))
    except ValueError:
        # Non-numeric ids are kept as strings.
        m_id = str(m_id_re.group(1))

    m_dx_re = re.search("gris_g(.+?)m", infile_str)
    assert m_dx_re is not None
    m_dx = int(m_dx_re.group(1))

    # Bookkeeping variables that are not copied into the frame.
    skipped_vars = (
        "time_bounds",
        "time_bnds",
        "timestamp",
        "run_stats",
        "pism_config",
    )

    with xr.open_dataset(infile) as ds:
        if resample == "monthly":
            ds = ds.resample(time="1MS").mean()
        elif resample == "yearly":
            ds = ds.resample(time="1YS").mean()

        datetimeindex = ds.indexes["time"]
        # assumes the time index yields pandas Timestamps (to_pydatetime) - TODO confirm
        years = [to_decimal_year(x.to_pydatetime()) for x in datetimeindex]
        nt = len(datetimeindex)

        S = [pd.Series(data=np.repeat(m_id, nt), index=datetimeindex, name="id")]
        for m_var in ds.data_vars:
            if m_var in skipped_vars:
                continue
            if hasattr(ds[m_var], "units"):
                m_S_name = f"{m_var} ({ds[m_var].units})"
            else:
                m_S_name = f"{m_var}"
            data = np.squeeze(ds[m_var].values)
            S.append(pd.Series(data=data, index=datetimeindex, name=m_S_name))
        S.append(pd.Series(data=years, index=datetimeindex, name="Year"))

        df = pd.concat(S, axis=1).reset_index()
        df["resolution_m"] = m_dx

        if add_vars:
            df = add_vars_to_dataframe(df)

        if norm_year:
            matches = np.nonzero(np.array(years) == norm_year)[0]
            if matches.size == 0:
                raise IndexError(
                    f"norm_year {norm_year} not found in the time axis of {infile}"
                )
            norm_year_idx = matches[0]
            norm_cols = ["limnsw (kg)"]
            if add_vars:
                norm_cols += ["Cumulative ice sheet mass change (Gt)", "SLE (cm)"]
            # Subtract the reference-year value from every normalised column
            # (the derived columns were previously *added* to, which doubled
            # the value at norm_year instead of zeroing it).
            for col in norm_cols:
                df[col] -= df[col].iloc[norm_year_idx]

    return df


def add_vars_to_dataframe(df: pd.DataFrame) -> pd.DataFrame:
    """
    Add derived variables to a DataFrame in place and return it.

    Nothing is added unless the frame has a "limnsw (kg)" column. With it:

    - "Cumulative ice sheet mass change (Gt)": "limnsw (kg)" relative to its
      first row, divided by 1e12 (kg to Gt).
    - "SLE (cm)": the cumulative mass change multiplied by ``gt2cmsle``.
    - "Rate of ice discharge (Gt/yr)": negated
      "grounding_line_flux (Gt year-1)", if that column exists.
    - "Rate of surface mass balance (Gt/yr)": copy of
      "tendency_of_ice_mass_due_to_surface_mass_flux (Gt year-1)", if present.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame to extend; it is modified in place.

    Returns
    -------
    pandas.DataFrame
        The same object that was passed in.

    Notes
    -----
    NOTE(review): ``gt2cmsle`` is neither defined nor imported in the visible
    part of this notebook; it has to be in scope before calling.
    """

    if "limnsw (kg)" in df.columns:
        mass_kg = df["limnsw (kg)"]
        # .iloc[0] takes the first row by position; the label lookup [0] fails
        # (or picks another row) whenever the index does not start at 0.
        df["Cumulative ice sheet mass change (Gt)"] = (
            mass_kg - mass_kg.iloc[0]
        ) / 1e12
        df["SLE (cm)"] = df["Cumulative ice sheet mass change (Gt)"] * gt2cmsle
        if "grounding_line_flux (Gt year-1)" in df.columns:
            df["Rate of ice discharge (Gt/yr)"] = -df["grounding_line_flux (Gt year-1)"]
        if "tendency_of_ice_mass_due_to_surface_mass_flux (Gt year-1)" in df.columns:
            df["Rate of surface mass balance (Gt/yr)"] = df[
                "tendency_of_ice_mass_due_to_surface_mass_flux (Gt year-1)"
            ]
    return df


In [None]:
    # Check the notebook-local ncfile2dataframe against the stored references,
    # first at native time resolution, then with yearly resampling.
    # `infer_datetime_format` is deprecated since pandas 2.0 (consistent format
    # inference is now the default), so `parse_dates` alone is used below.
    df_parquet_true = pd.read_parquet("../tests/data/test_scalar_file.parquet")
    df_csv_true = pd.read_csv(
        "../tests/data/test_scalar_file.csv",
        index_col=0,
        parse_dates=["time"],
    )

    df = ncfile2dataframe(infile, add_vars=False)
    assert_frame_equal(df, df_parquet_true)
    assert_frame_equal(df, df_csv_true)

    # NOTE(review): no visible cell writes test_scalar_file_YM.parquet/.csv;
    # they are presumably produced elsewhere - confirm they exist.
    df_parquet_true = pd.read_parquet("../tests/data/test_scalar_file_YM.parquet")
    df_csv_true = pd.read_csv(
        "../tests/data/test_scalar_file_YM.csv",
        index_col=0,
        parse_dates=["time"],
    )

    df = ncfile2dataframe(infile, resample="yearly", add_vars=False)
    assert_frame_equal(df, df_parquet_true)
    assert_frame_equal(df, df_csv_true)


In [None]:
# Display the path of the single-member input file currently in use.
infile

In [None]:
    # Load the single-file reference that includes the derived variables.
    df_parquet_true = pd.read_parquet("../tests/data/test_scalar_file_add_vars.parquet")


In [None]:
# Display the reference frame loaded in the previous cell.
df_parquet_true

In [None]:
    # Convert with derived variables, normalised to decimal year 1992.0.
    df = ncfile2dataframe(infile, add_vars=True, norm_year=1992.0)


In [None]:
# Display the normalised frame for visual comparison with the reference.
df

In [None]:
# Display the un-normalised add_vars reference again for side-by-side comparison.
df_parquet_true

In [None]:
def convert_netcdf_to_dataframe(
    infiles: list,
    resample: Union[str, None] = None,
    n_jobs: int = 4,
    add_vars: bool = True,
    norm_year: Union[None, float] = None,
    verbose: bool = False,
) -> pd.DataFrame:
    """
    Convert a list of netCDF files to a single pandas.DataFrame.

    Each file is converted with ``ncfile2dataframe`` through a joblib
    ``Parallel`` pool while a tqdm progress bar tracks completed files; the
    per-file frames are then concatenated.

    Parameters
    ----------
    infiles : list
        Paths of the netCDF files to convert; must not be empty.
    resample : str or None
        Forwarded to ``ncfile2dataframe``.
    n_jobs : int
        Number of joblib workers.
    add_vars : bool
        Forwarded to ``ncfile2dataframe``.
    norm_year : float or None
        Forwarded to ``ncfile2dataframe``.
    verbose : bool
        Forwarded to ``ncfile2dataframe``.

    Returns
    -------
    pandas.DataFrame
        All per-file frames concatenated, sorted by "time" and "id", with a
        fresh 0-based index.

    Raises
    ------
    ValueError
        If ``infiles`` is empty (for instance a glob pattern that matched
        nothing).

    Notes
    -----
    NOTE(review): ``tqdm_joblib`` is defined in a later cell of this notebook;
    that cell has to be executed before this function is called.
    """
    n_files = len(infiles)
    if n_files == 0:
        # Fail early with a clear message; pd.concat would otherwise raise a
        # generic ValueError only after the worker pool has been set up.
        raise ValueError("No netCDF files to convert: 'infiles' is empty")

    print("Converting netcdf files to pandas.DataFrame")
    print("-------------------------------------------")
    start_time = time.perf_counter()
    with tqdm_joblib(tqdm(desc="Processing files", total=n_files)):
        result = Parallel(n_jobs=n_jobs)(
            delayed(ncfile2dataframe)(infile, resample, add_vars, norm_year, verbose)
            for infile in infiles
        )
    finish_time = time.perf_counter()
    time_elapsed = finish_time - start_time
    print(f"Conversion finished in {time_elapsed:.0f} seconds")
    print("-------------------------------------------")

    df = pd.concat(result)

    return df.sort_values(by=["time", "id"]).reset_index(drop=True)


In [None]:
    # Multi-file conversion with derived variables, normalised to decimal year
    # 1992.0.
    # NOTE(review): no visible cell writes test_scalar_add_vars.parquet, and
    # `df_parquet_true` is not compared against `df` here.
    # NOTE(review): the notebook-local convert_netcdf_to_dataframe needs
    # `tqdm_joblib`, which is only defined in the following cell.
    df_parquet_true = pd.read_parquet("../tests/data/test_scalar_add_vars.parquet")
    infiles = glob("../tests/data/ts_gris_g1200m_v2023_RAGIS_id_*_1980-1-1_2020-1-1.nc")

    df = convert_netcdf_to_dataframe(infiles, add_vars=True, norm_year=1992.0)


In [None]:
@contextlib.contextmanager
def tqdm_joblib(tqdm_object):
    """
    Report joblib batch completions to the given tqdm progress bar.

    While the context is active, ``joblib.parallel.BatchCompletionCallBack``
    is replaced by a subclass that advances ``tqdm_object`` by the batch size
    before delegating to the regular callback. On exit the previous callback
    class is put back and the progress bar is closed, even if the body raised.

    Yields
    ------
    The ``tqdm_object`` that was passed in.
    """
    previous_callback = joblib.parallel.BatchCompletionCallBack

    class TqdmBatchCompletionCallback(previous_callback):
        """Batch-completion callback that also advances the progress bar."""

        def __call__(self, *args, **kwargs):
            tqdm_object.update(n=self.batch_size)
            return super().__call__(*args, **kwargs)

    joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback
    try:
        yield tqdm_object
    finally:
        # Undo the patch first, then release the bar.
        joblib.parallel.BatchCompletionCallBack = previous_callback
        tqdm_object.close()


In [None]:
# Display the most recently computed frame.
df

In [None]:
    # Yearly-resampled conversion of the files whose names end in
    # "_1980-1-1_1982-1-1.nc".
    # NOTE(review): every other cell globs "..._1980-1-1_2020-1-1.nc"; confirm
    # that files matching this pattern exist, otherwise `infiles` is empty.
    infiles = glob("../tests/data/ts_gris_g1200m_v2023_RAGIS_id_*_1980-1-1_1982-1-1.nc")
    df = convert_netcdf_to_dataframe(infiles, resample="yearly", add_vars=False)
