# Show the training progress

It is useful to see how the training loss changes over the course of training.
This notebook provides a simple way to visualize the training progress.
It assumes that the metrics are recorded during training into `.csv` files that can be read at any time.
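
The training script that writes these files is not part of this notebook. As a minimal sketch of the assumption, a logger could append one row per epoch as shown here. The helper name `append_metrics_row` is hypothetical. The column names `loss`, `psnr` and `ssim` match the columns this notebook reads further down.

```python
import csv
import os


def append_metrics_row(csv_path, metrics, fieldnames=("loss", "psnr", "ssim")):
    """Append one row of metrics to csv_path (hypothetical helper).

    The header is written only when the file does not exist yet or is empty,
    so the file stays readable after every epoch.
    """
    write_header = not os.path.exists(csv_path) or os.path.getsize(csv_path) == 0
    with open(csv_path, "a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=list(fieldnames))
        if write_header:
            writer.writeheader()
        # Keep only the expected columns, in a fixed order
        writer.writerow({name: metrics[name] for name in fieldnames})
```

Because each row is appended and the file is closed straight away, the notebook can read the file while training is still running.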

In [13]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Import libraries

In [14]:
import os
from pprint import pprint
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from utils.makepath import makepath as mkp

## Specify the paths to the training logs

First, make sure we know the relative path to the root of the project.

In [15]:
# Root directory of the project is two levels up from this notebook.
# Change this if the notebook is moved.
num_levels_up = 2

root_dir = mkp(".")
for _ in range(num_levels_up):
    root_dir = mkp(root_dir, "..")

os.listdir(root_dir)

['scripts',
 'requirements.txt',
 'mri.egg-info',
 'venv',
 'README.md',
 'figures',
 'config',
 'LICENSE',
 'utils',
 'networks',
 'tmp',
 'dyn_mri_test.py',
 '.gitignore',
 'gradops',
 'pyproject.toml',
 'gifs',
 'data',
 'pdhg',
 'data_lib',
 'wandb',
 'encoding_objects',
 '.git']

In [16]:
model_save_dir = mkp(
    root_dir,
    "tmp",
    # "example_model"
    # "mri_model_09_10-14_09"
    "mri_model_09_12-23_02-good_TGV-sigma_to_0_2-R_from_4"
    # "mri_model_09_14-14_37-good_TV-sigma_to_0_2-R_from_4"
)

# Get the metric log CSV files
csv_files = [
    # This is helpful when training with large epochs.
    # "train_intermediate_metrics.csv",

    "train_epoch_metrics.csv",
    "val_epoch_metrics.csv",
]
csv_filepaths = [mkp(model_save_dir, f) for f in csv_files]
pprint(csv_filepaths)

[PosixPath('../../tmp/mri_model_09_12-23_02-good_TGV-sigma_to_0_2-R_from_4/train_epoch_metrics.csv'),
 PosixPath('../../tmp/mri_model_09_12-23_02-good_TGV-sigma_to_0_2-R_from_4/val_epoch_metrics.csv')]
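
If training has only just started, one of these files may not exist yet, and `pd.read_csv` would then raise `FileNotFoundError`. A small optional check, sketched here with a hypothetical helper `split_existing`, separates the paths that can already be read from those that cannot. It assumes the paths are `pathlib`-compatible, as the `PosixPath` output above suggests.

```python
from pathlib import Path


def split_existing(paths):
    """Split paths into (existing, missing) lists, preserving order."""
    existing, missing = [], []
    for p in map(Path, paths):
        (existing if p.is_file() else missing).append(p)
    return existing, missing
```

For example, `existing, missing = split_existing(csv_filepaths)` could be run before plotting, and only the paths in `existing` plotted.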


## Read the training logs and plot the training progress

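Read the log files and plot the metrics to show the training progress.
As noted at the top of this notebook, we assume that the metrics are recorded during training into `.csv` files that can be read at any time.
Each file is expected to contain the columns `loss`, `psnr`, and `ssim`.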

In [17]:
# Function to plot graphs from a CSV file
def plot_graphs_from_csv(file_path):
    # Load the CSV file
    df = pd.read_csv(file_path)

    # Extract the base name of the file without extension for titles
    base_name = os.path.basename(file_path).split('.')[0]

    # Create subplots
    fig = make_subplots(
        rows=1, cols=3,
        subplot_titles=(
            'Loss (should go down)',
            'PSNR (should go up)',
            'SSIM (should go up)'
        )
    )

    # Add traces for loss
    fig.add_trace(go.Scatter(y=df['loss'], mode='lines', name='Loss'), row=1, col=1)

    # Add traces for PSNR
    fig.add_trace(go.Scatter(y=df['psnr'], mode='lines', name='PSNR'), row=1, col=2)

    # Add traces for SSIM
    fig.add_trace(go.Scatter(y=df['ssim'], mode='lines', name='SSIM'), row=1, col=3)

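    # Update layout
    fig.update_layout(
        title_text=f'{base_name}: Metrics over Epochs',
        height=360,
        width=1000
    )

    # Label the axes of every subplot. Setting `xaxis_title`/`yaxis_title`
    # in update_layout would only label the first subplot.
    fig.update_xaxes(title_text='Epochs')
    fig.update_yaxes(title_text='Value')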

    # Show the figure
    fig.show()

# Function to read files and plot
def read_and_plot():
    for file_path in csv_filepaths:
        plot_graphs_from_csv(file_path)

Re-run the following cell whenever you want to see the latest training progress.
Three metrics are plotted: Loss (in this case the Mean Squared Error, MSE),
Peak Signal-to-Noise Ratio (PSNR), and Structural Similarity Index (SSIM).
Ideally, Loss should decrease over time, while PSNR and SSIM should increase.
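
Loss and PSNR tend to move in opposite directions because PSNR is defined from the MSE as $10 \log_{10}(\mathrm{MAX}^2 / \mathrm{MSE})$. The sketch below illustrates this relation. It assumes image values scaled to $[0, 1]$ (so `max_value=1.0`), which this notebook does not confirm, and the function name `psnr_from_mse` is hypothetical. The logged PSNR may also be averaged differently from the logged loss, so the two curves need not match this formula exactly.

```python
import math


def psnr_from_mse(mse, max_value=1.0):
    """PSNR in dB for a given MSE, assuming values in [0, max_value]."""
    if mse <= 0:
        # A perfect reconstruction has unbounded PSNR
        return math.inf
    return 10.0 * math.log10(max_value ** 2 / mse)
```

Under this assumption, reducing the MSE by a factor of 10 raises the PSNR by 10 dB.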

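Each CSV file listed in `csv_files` produces one row of three plots, in the order listed.
With `train_intermediate_metrics.csv` enabled, there are three rows. The top row shows the intermediate results;
it might be empty if the epoch is short and the intermediate interval was not set to be smaller than one epoch.
The middle row shows the metrics for the training set by epoch,
while the bottom row shows the metrics for the validation set, also by epoch.
With the intermediate file commented out, as in the cell above, only the training and validation rows appear.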

In [18]:
# Re-run this line to update the plots
read_and_plot()