In [None]:
from pathlib import Path
import datetime
import shutil
import pandas as pd

from scipy.ndimage import gaussian_filter1d
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm

In [None]:
def save_figure(fig, path, dpi=300):
    """Save a figure to a file, creating missing parent directories first.

    Args:
        fig: Object exposing ``savefig(path, dpi=...)``, e.g. a matplotlib figure.
        path: Destination file, given either as a ``str`` or a ``pathlib.Path``.
        dpi: Resolution forwarded unchanged to ``fig.savefig``.
    """
    # Accept plain strings as well: `.parent` only exists on Path objects.
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(path, dpi=dpi)

In [None]:
# INPUTS
INPUT_FILE = "../03_data_complexity/01_extracted_features/internal_datasets/all_features_cropped_truncated.csv"
# NOTE: an empty string resolves to the current working directory. The cleanup
# step below never deletes the working directory (or one of its parents).
OUTPUT_FOLDER = ""
FEATURES = ["HoG", "Rank #unique colors"]

# Starting dates for different datasets
STARTING_DATES = {
    'accessions_dataset1': datetime.datetime(year=2022, month=5, day=11, hour=0, minute=0, second=0),
    'accessions_dataset2': datetime.datetime(year=2022, month=8, day=2, hour=0, minute=0, second=0),
}

# Delays, relative to STARTING_DATES, with which a replicate is allowed to start.
# A replicate starting with one of these delays is shifted back by that delay;
# any other delay is rejected with a ValueError.
ALLOWED_START_DELAYS = {
    'accessions_dataset1': (datetime.timedelta(days=0),),
    'accessions_dataset2': (datetime.timedelta(days=0), datetime.timedelta(days=7)),
}

# Special y-axis limits for specific data
SPECIAL_Y_LIM_DATA = {
    ('HoG', 'accessions_dataset2'): [4.6, 4.85],
}

####################################################################################################

# Convert paths to Path objects
data_path = Path(INPUT_FILE)
output_folder = Path(OUTPUT_FOLDER)

# Remove the output folder if it exists, so figures from a previous run do not linger.
# Path("") is ".", which always exists: deleting it would wipe the working directory,
# so the working directory and its ancestors are explicitly protected.
if output_folder.exists():
    resolved_output = output_folder.resolve()
    working_dir = Path.cwd().resolve()
    if resolved_output == working_dir or resolved_output in working_dir.parents:
        print(f"Not clearing '{resolved_output}': it is (or contains) the working directory.")
    else:
        shutil.rmtree(output_folder)

# Read the data from CSV
data = pd.read_csv(data_path)

# Extract datetime components from filenames: the part before the first '-' is split
# on '_', the first and last tokens are dropped, and the first three remaining tokens
# are read as year, month and day (time of day is fixed to midnight).
filenames = [name.split('-')[0].split('_')[1:-1] for name in data['filename'].values]
datetimes = [
    datetime.datetime(int(parts[0]), int(parts[1]), int(parts[2]), 0, 0, 0)
    for parts in filenames
]
data['datetimes'] = datetimes

# Sort data by datetimes and reset the index
data = data.sort_values(by='datetimes').reset_index(drop=True)

# Add a string representation of datetimes
data['datetimes_str'] = data['datetimes'].apply(str)

# Start from the raw timestamps so the corrected column always exists with a datetime
# dtype; rows of datasets without an entry in STARTING_DATES keep their raw timestamps.
data['datetimes_corrected'] = data['datetimes']

# Iterate over groups of datasets, classes, and replicates
for dataset_name, dataset_data in data.groupby('dataset'):
    if dataset_name not in STARTING_DATES:
        # No expected start date is known, so there is nothing to validate or correct.
        continue

    expected_start = STARTING_DATES[dataset_name]
    allowed_delays = ALLOWED_START_DELAYS.get(dataset_name, (datetime.timedelta(days=0),))

    for class_name, class_data in dataset_data.groupby('class'):
        for replicate_name, replicate_data in class_data.groupby('replicate'):
            starting_date = replicate_data['datetimes'].min()
            start_delay = starting_date - expected_start

            # Ensure the replicate starts on the expected date or with an allowed delay
            if start_delay not in allowed_delays:
                raise ValueError(f'Wrong starting date for {dataset_name}')

            # Shift the replicate back so that it starts on the expected date
            # (a zero delay leaves the timestamps unchanged).
            data.loc[replicate_data.index, 'datetimes_corrected'] = replicate_data['datetimes'] - start_delay

# Sort data by corrected datetimes and reset the index
data = data.sort_values(by='datetimes_corrected').reset_index(drop=True)

# Add a string representation of corrected datetimes
data['datetimes_corrected_str'] = data['datetimes_corrected'].apply(str)


In [None]:
def _smoothed_band(ds_data, feature, level, n_std):
    """Return per-date mean and mean +/- n_std * std of `feature`, raw and Gaussian-smoothed."""
    date_col = 'datetimes_corrected_str'

    # Average the feature within each (level, date) pair first ...
    per_group = (
        ds_data[[level, date_col, feature]]
        .groupby([level, date_col])
        .mean()
        .reset_index()
    )

    # ... then take the mean and standard deviation across the groups for each date.
    band = (
        per_group.groupby(date_col)[feature]
        .agg(['mean', 'std'])
        .rename(columns={'mean': f'Orig. {feature} Mean', 'std': f'Orig. {feature} Std'})
    )
    band[f'Orig. {feature} Mean - {n_std} x Std'] = band[f'Orig. {feature} Mean'] - n_std * band[f'Orig. {feature} Std']
    band[f'Orig. {feature} Mean + {n_std} x Std'] = band[f'Orig. {feature} Mean'] + n_std * band[f'Orig. {feature} Std']
    band = band.reset_index()

    # Smooth along the date axis; the kernel width is a tenth of the number of dates.
    sigma = band.shape[0] / 10
    for suffix in ('Mean', f'Mean - {n_std} x Std', f'Mean + {n_std} x Std'):
        band[f'{feature} {suffix}'] = gaussian_filter1d(band[f'Orig. {feature} {suffix}'], sigma)

    return band


def plot_timeseries_pale(dataset_name, ds_data, feature, level="class", transparency=1, special_y_lim=None, n_std=2):
    """Plot `feature` over time for every group, overlaid with a smoothed mean and band.

    One line per value of `level` is drawn with the given transparency. On top of
    these, the Gaussian-smoothed mean across groups (red) and the smoothed
    mean +/- `n_std` standard deviations (black) are drawn.

    Args:
        dataset_name: Dataset identifier; only used to look up `special_y_lim`.
        ds_data: DataFrame containing the columns `level`, `feature` and
            'datetimes_corrected_str'.
        feature: Name of the column to plot on the y-axis.
        level: Name of the column whose values define the individual lines.
        transparency: Alpha applied to the per-group lines.
        special_y_lim: Optional mapping from (feature, dataset_name) to y-axis
            limits. None (the default) or a missing key leaves the limits automatic.
        n_std: Number of standard deviations between the mean and each band line.

    Returns:
        The matplotlib figure holding the plot.
    """
    j_data = _smoothed_band(ds_data, feature, level, n_std)

    # Create plot
    fig, ax = plt.subplots(figsize=(20, 10), dpi=300)

    # Set y-axis limits if specified (special_y_lim may legitimately be None)
    y_lim_key = (feature, dataset_name)
    if special_y_lim is not None and y_lim_key in special_y_lim:
        ax.set_ylim(special_y_lim[y_lim_key])

    # Plot each group's data with transparency. Alpha is set right away so that only
    # the group lines, and not the summary lines added afterwards, are affected.
    hue_order = sorted(ds_data[level].unique())
    sns.lineplot(data=ds_data, x='datetimes_corrected_str', y=feature, hue=level, hue_order=hue_order, ax=ax, errorbar=None)
    for line in ax.lines:
        line.set_alpha(transparency)

    # Plot the mean and standard deviation lines
    # NOTE(review): `dashes` is seaborn's style-mapping argument; without a `style`
    # variable it may be ignored and the mean line drawn solid - confirm whether
    # `linestyle=` was intended.
    sns.lineplot(data=j_data, x='datetimes_corrected_str', y=f'{feature} Mean',
                 color='red', dashes=[10, 10], linewidth=5, ax=ax, errorbar=None)
    sns.lineplot(data=j_data, x='datetimes_corrected_str', y=f'{feature} Mean - {n_std} x Std',
                 color='black', linewidth=5, ax=ax, errorbar=None)
    sns.lineplot(data=j_data, x='datetimes_corrected_str', y=f'{feature} Mean + {n_std} x Std',
                 color='black', linewidth=5, ax=ax, errorbar=None)

    # Format x-axis labels: keep only the date part of each timestamp string
    xticklabels = [stamp.split(' ')[0] for stamp in ds_data['datetimes_corrected_str'].unique()]
    ax.xaxis.set_ticks(range(len(xticklabels)))
    ax.set_xticklabels(xticklabels, rotation=45)

    # Set labels and title
    ax.set_xlabel('Date', fontsize=14)
    ax.set_ylabel(feature, fontsize=14)
    ax.set_title(f'{feature} over time', fontsize=18, fontweight='bold')

    # Add grid and place the legend outside the axes, to the right
    ax.grid()
    sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))

    return fig

In [None]:
# Loop through each dataset in the data and write one figure per (dataset, feature)
for dataset_name in tqdm(data['dataset'].unique(), desc="Processing datasets..."):
    # Filter dataset-specific data
    dataset_data = data[data['dataset'] == dataset_name]

    # Loop through each feature for the dataset
    for feature in tqdm(FEATURES, desc="Processing class features...", leave=False):
        # Generate the plot for the current dataset and feature
        fig = plot_timeseries_pale(dataset_name, dataset_data, feature, level="class", transparency=0.3, special_y_lim=SPECIAL_Y_LIM_DATA)

        try:
            # Build the file name from the feature: lower-case, spaces replaced by '_'
            feature_filename = feature.lower().replace(' ', '_')
            save_path = output_folder / dataset_name / f"{feature_filename}.png"

            # Save the plot
            save_figure(fig, save_path)
        finally:
            # Always release this specific figure, even if saving failed, so that
            # figures do not pile up in memory across iterations.
            plt.close(fig)