# Imports

In [None]:
import warnings
from pprint import pprint

import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from fitter import Fitter
from scipy import stats

# Functions

In [None]:
def trim_non_visible(s):
    """Return ``s`` without leading/trailing whitespace when it is a string.

    Any non-string value (numbers, NaN, timestamps, ...) is passed through
    unchanged, so this is safe to apply element-wise to a whole DataFrame.
    """
    return s.strip() if isinstance(s, str) else s


def read_data(file):
    """Read the first four columns of a raw data file into a tidy frame.

    Columns 0 and 1 (date and time parts) are joined with a space and parsed
    into a single ``date`` column; columns 2 and 3 become ``m`` and ``v``.
    Rows whose date and time parts are both missing (e.g. empty lines such
    as ``,,,``) are dropped, string cells are stripped of surrounding
    whitespace, and the result is sorted by ``date``.

    Args:
        file: Path or file-like object accepted by ``pandas.read_csv``; the
            first line is treated as the header.

    Returns:
        A DataFrame with columns ``["date", "m", "v"]`` sorted by ``date``.

    Raises:
        ValueError: from ``pd.to_datetime`` when a row has only one of the
            date/time parts or a part that cannot be parsed.
    """
    raw = pd.read_csv(file, delimiter=",", usecols=[0, 1, 2, 3])
    date_col, time_col, m_col, v_col = raw.columns

    # Drop rows that carry neither a date nor a time part before joining,
    # so they cannot turn into unparseable "nan nan" stamps.
    has_stamp = raw[date_col].notna() | raw[time_col].notna()
    raw = raw[has_stamp]

    # Combining columns on read via ``parse_dates=[[0, 1]]`` is deprecated
    # since pandas 2.2 (and later removed), so the parts are joined by hand.
    # NOTE(review): assumes both parts are read as text (e.g. "2020-01-01",
    # "10:00"); purely numeric stamps would be inferred as numbers here.
    data = pd.DataFrame(
        {
            "date": raw[date_col].astype(str) + " " + raw[time_col].astype(str),
            "m": raw[m_col],
            "v": raw[v_col],
        }
    )
    data["date"] = pd.to_datetime(data["date"])

    # DataFrame.applymap was renamed to DataFrame.map in pandas 2.1 and the
    # old name is deprecated; use whichever this pandas version provides.
    elementwise = data.map if hasattr(data, "map") else data.applymap
    data = elementwise(trim_non_visible)
    return data.sort_values(by=["date"])


def _get_time_differences(df):
    """Hours elapsed between each row's ``date`` and the previous row's.

    The first row has no predecessor and therefore yields NaN.
    """
    seconds_per_hour = 3600
    gaps = df["date"].diff()
    return gaps.dt.total_seconds() / seconds_per_hour


def add_time_differences(df):
    """Store the hours since the previous row in a ``time_differences`` column.

    ``df`` is modified in place and also returned, so calls can be chained.
    """
    hours_between = _get_time_differences(df)
    df["time_differences"] = hours_between
    return df


def add_energy(df):
    """Store ``0.5 * m * v**2`` (kinetic energy) in an ``e`` column.

    ``df`` is modified in place and also returned, so calls can be chained.
    """
    mass, speed = df["m"], df["v"]
    df["e"] = 0.5 * mass * speed ** 2
    return df


def reorder_columns(df):
    """Return ``df`` restricted to the known columns, in canonical order.

    Canonical columns that are missing from ``df`` are skipped; columns not
    in the canonical list are dropped from the result.
    """
    canonical = ("zone", "date", "time_differences", "m", "v", "e")
    present = set(df.columns)
    return df[[name for name in canonical if name in present]]


def scatter_plot(
    df: pd.DataFrame,
    col: str,
    c="zone",
    colorbar=False,
    colormap="viridis",
    title=None,
):
    """Plot column ``col`` of ``df`` against ``date`` as a scatter plot.

    Args:
        df: Data containing a ``date`` column, ``col`` and the column named
            by ``c``. It is not modified.
        col: Column plotted on the y-axis.
        c: Column used to colour the points.
        colorbar: Whether to draw a colour bar.
        colormap: Matplotlib colormap name.
        title: Plot title; defaults to ``"<COL> vs. Date"``. The number of
            records is always appended on a second line.

    The figure is not shown here; call ``plt.show()`` afterwards.
    """
    if title is None:
        title = f"{col.upper()} vs. Date"
    title = f"{title}\nnumber of records: {len(df)}"
    # Convert on a new frame so the caller's DataFrame is left untouched.
    plot_df = df.assign(date=pd.to_datetime(df["date"]))
    ax = plot_df.plot.scatter(
        x="date", y=col, c=c, colorbar=colorbar, colormap=colormap
    )
    ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m-%d"))
    ax.xaxis.set_major_locator(mdates.AutoDateLocator())
    plt.title(title)
    plt.xticks(rotation=90)


def plot_histogram(df: pd.DataFrame, col: str, zone: int, title: str = None):
    """Plot column ``col`` of ``df`` as a histogram and show it.

    Args:
        df: Data for a single zone.
        col: Column to histogram.
        zone: Zone number, used only in the default title.
        title: Plot title; defaults to ``"<COL> for Zone <zone>"``. The number
            of records is always appended on a second line.
    """
    if title is None:
        title = f"{col.upper()} for Zone {zone}"
    title = title + f"\nnumber of records: {len(df)}"
    values = df[col]
    # 6x the square-root rule, based on the non-missing values that are
    # actually drawn (Series.hist ignores NaN). At least one bin, so an empty
    # or all-NaN selection does not raise.
    bins = max(1, int(np.sqrt(values.count())) * 6)
    values.hist(bins=bins)
    plt.xlabel(col.upper())
    plt.ylabel("Frequency")
    plt.title(title)
    plt.show()


def fit(values, distributions=None):
    """Fit candidate distributions to ``values`` and rank them by fit quality.

    Each distribution's parameters are estimated by maximum likelihood
    (``distribution.fit``) and scored with the p-value of a one-sample
    Kolmogorov-Smirnov test against the fitted distribution. Because the
    parameters are estimated from the same sample, these p-values are
    optimistic; treat them as a ranking score rather than a formal test.

    Args:
        values: One-dimensional sample of observations.
        distributions: Iterable of scipy continuous distributions (objects
            with ``fit``, ``cdf`` and ``name``). Defaults to a set of common
            candidates.

    Returns:
        A list of ``(distribution, p_value, params)`` tuples sorted from the
        best (highest p-value) to the worst fit. Distributions whose fit
        fails are skipped with a warning; NaN p-values are ranked last.
    """
    if distributions is None:
        distributions = [
            stats.norm,
            stats.expon,
            stats.uniform,
            stats.gamma,
            stats.lognorm,
            stats.pareto,
            stats.weibull_max,
        ] if False else [
            stats.norm,
            stats.expon,
            stats.uniform,
            stats.gamma,
            stats.lognorm,
            stats.pareto,
            stats.weibull_min,
            stats.weibull_max,
        ]
    results = []

    for distribution in distributions:
        try:
            # Maximum-likelihood parameters: (shape..., loc, scale).
            params = distribution.fit(values)
            # Pass the cdf itself (not its name) so distributions that are
            # not registered by name in scipy.stats work as well.
            _, p = stats.kstest(values, distribution.cdf, args=params)
        except (ValueError, RuntimeError, FloatingPointError) as err:
            # One bad candidate should not abort the whole ranking.
            warnings.warn(
                f"Could not fit {distribution.name}: {err}", stacklevel=2
            )
            continue
        results.append((distribution, p, params))

    # Highest p-value first; NaN p-values would make the ordering undefined,
    # so they are mapped to -inf and end up last.
    results.sort(
        key=lambda item: -np.inf if np.isnan(item[1]) else item[1],
        reverse=True,
    )
    return results


def plot_qq(values, fit_results, zone, col, num=5):
    """Show one Q-Q plot per distribution for the ``num`` best fits.

    Args:
        values: Observed sample.
        fit_results: Ranked ``(distribution, p_value, params)`` tuples, best
            first, as returned by ``fit``.
        zone: Zone number, used in the plot title.
        col: Column name, used (upper-cased) in the plot title.
        num: How many of the leading entries of ``fit_results`` to plot.
    """
    leading = fit_results[:num]
    for dist, _p_value, dist_params in leading:
        plt.figure(figsize=(6, 6))
        stats.probplot(values, dist=dist, sparams=dist_params, plot=plt)
        heading = (
            f"Q-Q Plot for {dist.name} distribution vs {col.upper()} in zone {zone}"
        )
        plt.title(heading)
        plt.axis("equal")
        plt.show()

# EDA

In [None]:
data_1 = read_data("data/out_1.csv")
data_1["zone"] = 1
data_2 = read_data("data/out_2.csv")
data_2["zone"] = 2
# Join the two zones and sort by date. Each file is indexed from 0 on its
# own, so a plain concat would produce duplicate row labels, which can break
# later label-aligned assignments; ignore_index gives every row a unique label.
# A stable sort keeps zone 1 before zone 2 for identical timestamps.
df = pd.concat([data_1, data_2], ignore_index=True).sort_values(
    by=["date"], kind="stable", ignore_index=True
)

In [None]:
# Descriptive statistics of the combined data.
summary = df.describe()
print(summary)

## NAs, Zeros and Empty Strings

In [None]:
# Per-column counts of NAs, zeros and empty strings, printed in that order.
checks = {
    "Number of NAs in each column:": df.isna().sum(),
    "\nNumber of zeros in each column:": (df == 0).sum(),
    "\nNumber of empty strings in each column:": (df == "").sum(),
}
for heading, counts in checks.items():
    print(heading)
    print(counts)

In [None]:
# summarize column 'm' for each zone
print("\n", df.groupby("zone")["m"].describe())

# Replace zeros in 'm' (assumed to be missing measurements) with the median of
# the NON-ZERO values of the same zone. Including the zeros would pull the
# median down, and if more than half of a zone's values were zero the median
# would itself be 0 and nothing would be replaced. A zone with only zeros
# gets NaN, i.e. the values are marked as missing.
df["m"] = df.groupby("zone")["m"].transform(
    lambda zone_m: zone_m.replace(0, zone_m[zone_m != 0].median())
)

# summarize column 'm' for each zone
print("\n", df.groupby("zone")["m"].describe())

## Time Differences and Energy

In [None]:
# Add time differences and energy to the dataframe and reorder the columns.
# Time differences are computed within each zone: the two zones are separate
# sources, so a row's "previous event" must come from the same zone.
# Differencing the interleaved, date-sorted frame would mix both zones and
# distort the per-zone distributions analysed below.
df = pd.concat(
    [add_time_differences(zone_df.copy()) for _, zone_df in df.groupby("zone")]
).sort_values(by=["date"], kind="stable")
df = reorder_columns(add_energy(df))
# Convert the zone column to a categorical variable so that it is not used
# as a numerical variable.
df["zone"] = df["zone"].astype("category")

## Write data to disk

In [None]:
# Save the combined data of both zones (one row per record, with its zone) to
# a single CSV file, without the DataFrame index.
df.to_csv("data/data.csv", index=False)

## Visualization
### Scatter

In [None]:
# One date scatter plot per quantity, coloured by zone (the default).
for quantity in ("m", "v", "e"):
    scatter_plot(df, quantity)
# Energy once more, this time coloured by mass with a colour bar.
scatter_plot(df, "e", c="m", colorbar=True)
plt.show()

### Histogram

In [None]:
# Per-zone histograms of every quantity, quantity-major order.
histogram_jobs = [
    (quantity, zone_id)
    for quantity in ("time_differences", "m", "v", "e")
    for zone_id in (1, 2)
]
for quantity, zone_id in histogram_jobs:
    plot_histogram(df[df["zone"] == zone_id], quantity, zone_id)

# Fit Distributions

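Parameters are fitted by maximum likelihood estimation (MLE), and goodness of fit is measured by the p-value of a Kolmogorov–Smirnov (KS) test. Is this a good strategy? Note that KS p-values are optimistic when the parameters are estimated from the same sample, so they are better suited to ranking candidate distributions than to formal acceptance tests.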

In [None]:
# Fit candidate distributions per quantity and zone, report the five best
# fits and show their Q-Q plots.
for column in ["time_differences", "m", "v", "e"]:
    for zone_id in [1, 2]:
        zone_values = df.loc[df["zone"] == zone_id, column].dropna()
        ranked = fit(zone_values)
        print(f"Zone {zone_id}, column {column}")
        for dist, p_value, dist_params in ranked[:5]:
            print(f"{dist.name}\n\tp: {p_value}\n\tparams: {dist_params}")
        plot_qq(zone_values, ranked, zone_id, column)