In [1]:
import os
import rootutils

%load_ext autoreload
%autoreload 2

rootutils.setup_root(os.path.abspath('./'), indicator=".project-root", pythonpath=True, dotenv=True, cwd=True)

PosixPath('/faststorage2/users/a.varlamov/cover_test')

In [2]:
import pandas as pd
import os
from PIL import Image
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

In [3]:
def _contains_token_run(tokens: list[str], wanted: list[str]) -> bool:
    """Return True if ``wanted`` appears as a contiguous slice of ``tokens``."""
    width = len(wanted)
    return any(tokens[k:k + width] == wanted for k in range(len(tokens) - width + 1))


def _find_run_folder(run_dirs: list[str], activation: str, dt: float, t1: float):
    """
    Return the name of the run folder for (activation, dt, t1), or None if none matches.

    A folder matches when its name, split on '_', contains the activation name
    as a contiguous run of tokens (so 'leaky_relu' works) and both ``t1=<t1>``
    and ``dt=<dt>`` as whole tokens. Whole-token matching stops e.g. 'elu'
    from matching 'gelu' / 'relu' folders, which a plain substring test does.
    """
    act_tokens = activation.split('_')
    t1_token = f't1={t1}'
    dt_token = f'dt={dt}'
    matches = []
    for name in run_dirs:
        tokens = name.split('_')
        if t1_token in tokens and dt_token in tokens and _contains_token_run(tokens, act_tokens):
            matches.append(name)
    if not matches:
        return None
    # A compound activation such as 'leaky_relu' still contains the token 'relu',
    # so for activation='relu' both folders match. Prefer the name with the fewest
    # tokens (the plain 'relu' run); min() keeps run_dirs' order among ties.
    return min(matches, key=lambda name: len(name.split('_')))


def collate_energy_plots(
    base_dir: str,
    activation: str,
    dt_values: list[float],
    t1_values: list[float],
    nrows: int,
    ncols: int,
    output_path: str,
    figsize_scales: tuple[float, float] = (10, 6),
    img_name: str = 'energy_plot_10samples.png'
) -> None:
    """
    Collects per-run plot images into a grid for a given activation function.

    Rows follow ``dt_values`` and columns follow ``t1_values``. For each
    (dt, t1) cell the run folder under ``base_dir`` (name starting with
    'run_') that contains the activation, ``t1=<t1>`` and ``dt=<dt>`` as
    '_'-separated tokens is located, and its ``img_name`` image is drawn in
    that cell. Cells with no matching folder or image show 'Missing'.

    Parameters:
    - base_dir: path to the experiments folder
    - activation: activation function name (e.g., 'gelu', 'leaky_relu')
    - dt_values: list of dt values in the order for rows
    - t1_values: list of t1 values in the order for columns
    - nrows: number of grid rows (normally len(dt_values); must not be smaller)
    - ncols: number of grid columns (normally len(t1_values); must not be smaller)
    - output_path: file path to save the composite image
    - figsize_scales: (width, height) in inches allotted to each grid cell
    - img_name: name of the image file inside each run folder

    Raises:
    - ValueError: if the grid cannot hold every (dt, t1) pair
    - FileNotFoundError: if base_dir does not exist
    """
    if nrows < len(dt_values) or ncols < len(t1_values):
        raise ValueError(
            f'Grid {nrows}x{ncols} is too small for '
            f'{len(dt_values)} dt values x {len(t1_values)} t1 values'
        )

    # List run folders once rather than once per grid cell; sorting makes the
    # chosen folder independent of filesystem listing order.
    run_dirs = sorted(
        name for name in os.listdir(base_dir)
        if name.startswith('run_') and os.path.isdir(os.path.join(base_dir, name))
    )

    # squeeze=False always returns a 2-D array of Axes, even for a 1x1 grid
    # (where plt.subplots would otherwise return a bare Axes object).
    fig, axes = plt.subplots(
        nrows,
        ncols,
        figsize=(figsize_scales[0] * ncols, figsize_scales[1] * nrows),
        squeeze=False,
    )
    try:
        for i, dt in enumerate(dt_values):
            for j, t1 in enumerate(t1_values):
                ax = axes[i, j]
                folder = _find_run_folder(run_dirs, activation, dt, t1)
                img_path = os.path.join(base_dir, folder, img_name) if folder is not None else None
                if img_path is not None and os.path.isfile(img_path):
                    # matplotlib converts the PIL image to an array inside imshow,
                    # so the file can be closed as soon as it returns.
                    with Image.open(img_path) as img:
                        ax.imshow(img)
                else:
                    # Either no run folder matched or it lacks the image file.
                    ax.text(0.5, 0.5, 'Missing', ha='center', va='center',
                            transform=ax.transAxes)

                # Hide ticks and frame; title/ylabel are kept as grid labels.
                ax.set_xticks([])
                ax.set_yticks([])
                ax.set_frame_on(False)
                # Column titles for t1
                if i == 0:
                    ax.set_title(f't1={t1}', pad=12, fontsize=20)
                # Row labels for dt
                if j == 0:
                    ax.set_ylabel(f'dt={dt}', rotation=90, labelpad=16, fontsize=20)

        # Blank surplus cells when the grid is larger than the value lists.
        for i in range(nrows):
            for j in range(ncols):
                if i >= len(dt_values) or j >= len(t1_values):
                    axes[i, j].set_axis_off()

        fig.suptitle(f'Activation: {activation}', fontsize=40)
        # Slightly negative spacing packs the tiles tightly together.
        fig.subplots_adjust(wspace=-0.05, hspace=-0.05)
        fig.savefig(output_path, bbox_inches='tight', pad_inches=0)
    finally:
        # Release the figure even on error so repeated calls don't pile up figures.
        plt.close(fig)


In [4]:
# Single-activation preview: GELU energy plots, rows = dt, columns = t1,
# saved to the working directory.
dt_values = [0.01, 0.05, 0.1]
t1_values = [1.0, 3.0, 5.0, 10.0, 15.0]

collate_energy_plots(
    activation='gelu',
    output_path='energy_plot_gelu.png',
    base_dir='experiments',
    dt_values=dt_values,
    nrows=len(dt_values),
    t1_values=t1_values,
    ncols=len(t1_values),
    figsize_scales=(10, 6),
)

In [5]:
def collate_all_activations(
    base_dir: str,
    activations: list[str],
    dt_values: list[float],
    t1_values: list[float],
    output_dir: str,
    img_name: str = 'energy_plot_10samples.png',
    figsize_scales: tuple[float, float] = (10, 6),
) -> None:
    """
    Apply collate_energy_plots for multiple activation functions.

    One composite image per activation is written to ``<output_dir>/<activation>.png``,
    using a grid of len(dt_values) rows by len(t1_values) columns.

    Parameters:
    - base_dir: path to experiments folder
    - activations: list of activation names
    - dt_values: list of dt values for rows
    - t1_values: list of t1 values for columns
    - output_dir: folder to save composite images (created if missing)
    - img_name: filename inside each run folder
    - figsize_scales: (width, height) in inches per grid cell, forwarded to
      collate_energy_plots (same default as there)
    """
    os.makedirs(output_dir, exist_ok=True)
    nrows = len(dt_values)
    ncols = len(t1_values)

    for act in tqdm(activations, desc='Collating activations'):
        out_path = os.path.join(output_dir, f'{act}.png')
        collate_energy_plots(
            base_dir=base_dir,
            activation=act,
            dt_values=dt_values,
            t1_values=t1_values,
            nrows=nrows,
            ncols=ncols,
            output_path=out_path,
            figsize_scales=figsize_scales,
            img_name=img_name
        )

In [6]:
# Full sweep: one energy-plot grid per activation, saved under energy_plots/.
collate_all_activations(
    activations=[
        "gelu", "softplus", "silu", "tanh", "leaky_relu",
        "elu", "relu", "mish", "squareplus", "sigmoid",
    ],
    dt_values=[0.01, 0.05, 0.1],
    t1_values=[1.0, 3.0, 5.0, 10.0, 15.0],
    base_dir='experiments',
    img_name='energy_plot_10samples.png',
    output_dir='energy_plots',
)

Collating activations:   0%|          | 0/10 [00:00<?, ?it/s]

In [7]:
# Same sweep for the training curves, saved under training_plots/.
collate_all_activations(
    activations=[
        "gelu", "softplus", "silu", "tanh", "leaky_relu",
        "elu", "relu", "mish", "squareplus", "sigmoid",
    ],
    dt_values=[0.01, 0.05, 0.1],
    t1_values=[1.0, 3.0, 5.0, 10.0, 15.0],
    base_dir='experiments',
    img_name='training_plot.png',
    output_dir='training_plots',
)

Collating activations:   0%|          | 0/10 [00:00<?, ?it/s]

---
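## Learning rate 1e-4

Re-collate the grids for the runs trained with learning rate 1e-4 (softplus, tanh and sigmoid; t1 ∈ {5, 10, 15}).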

In [9]:
# Write to an LR-specific folder so the softplus/tanh/sigmoid composites from
# the sweep above (energy_plots/<act>.png) are not silently overwritten.
# NOTE(review): base_dir is the same 'experiments' folder as above, and run
# folders are matched only by activation / t1 / dt — confirm the 1e-4 runs are
# the ones picked up (e.g. point base_dir at an LR-specific folder).
collate_all_activations(
    base_dir='experiments',
    activations=["softplus", "tanh", "sigmoid"],
    dt_values=[0.01, 0.05, 0.1],
    t1_values=[5.0, 10.0, 15.0],
    output_dir='energy_plots_lr1e-4',
    img_name='energy_plot_10samples.png'
)

Collating activations:   0%|          | 0/3 [00:00<?, ?it/s]

In [10]:
# Write to an LR-specific folder so the softplus/tanh/sigmoid composites from
# the sweep above (training_plots/<act>.png) are not silently overwritten.
# NOTE(review): base_dir is the same 'experiments' folder as above, and run
# folders are matched only by activation / t1 / dt — confirm the 1e-4 runs are
# the ones picked up (e.g. point base_dir at an LR-specific folder).
collate_all_activations(
    base_dir='experiments',
    activations=["softplus", "tanh", "sigmoid"],
    dt_values=[0.01, 0.05, 0.1],
    t1_values=[5.0, 10.0, 15.0],
    output_dir='training_plots_lr1e-4',
    img_name='training_plot.png'
)


Collating activations:   0%|          | 0/3 [00:00<?, ?it/s]