In [1]:
import json
import os
import glob
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots

### Plots for the 15 Configurations

In [None]:
def parse_batch_size(filename):
    """Extract the batch size encoded in a results filename.

    The extension is dropped, the remaining name is split on underscores, and
    the first token of the form ``bs<digits>`` is parsed, e.g.
    ``"lr0.01_bs32_ep100.json"`` -> ``32``.

    Args:
        filename: File name or path of a results file.

    Returns:
        The batch size as an ``int``.

    Raises:
        ValueError: If no ``bs<digits>`` token is present.
    """
    # Strip directory and extension first so a trailing "bs32.json" token parses.
    stem = os.path.splitext(os.path.basename(filename))[0]
    for token in stem.split("_"):
        # Require the exact "bs<digits>" shape so tokens like "abs1" don't match.
        if token.startswith("bs") and token[2:].isdecimal():
            return int(token[2:])
    raise ValueError(f"No 'bs<digits>' batch-size token in {filename!r}")


def build_learning_curve_figure(results, title, validation_mode=True):
    """Build a train (and optionally validation) RMSE-vs-epoch plotly figure.

    Args:
        results: Mapping with keys ``"epochs"`` and ``"train_rmse"``, plus
            ``"val_rmse"`` when ``validation_mode`` is true.
        title: Figure title text.
        validation_mode: Whether to add the validation RMSE trace.

    Returns:
        A ``plotly.graph_objects.Figure`` (not displayed or saved).
    """
    epochs = results["epochs"]
    curves = [("Train", "train_rmse")]
    if validation_mode:
        curves.append(("Validation", "val_rmse"))

    fig = go.Figure()
    for trace_name, key in curves:
        fig.add_trace(
            go.Scatter(
                x=epochs,
                y=results[key],
                mode="lines+markers",
                yaxis="y",
                name=trace_name,
                marker_size=3,
            )
        )

    fig.update_layout(
        plot_bgcolor="white",
        title=dict(text=title, xanchor="center", x=0.5),
        legend=dict(yanchor="top", y=0.99, xanchor="right", x=0.99),
        xaxis=dict(
            title=dict(text="Epochs", font=dict(size=16)),
            gridcolor="lightgrey",
            tickfont=dict(size=16),
            # Pad the x-range slightly beyond the data on both sides.
            range=[min(epochs) - 1, max(epochs) + 3],
        ),
        yaxis=dict(
            title=dict(text="RMSE", font=dict(size=16)),
            gridcolor="lightgrey",
            zeroline=False,
            tickfont=dict(size=16),
        ),
        height=350,
        width=500,
    )
    return fig


validation_mode = True
# you may change the path as needed
path = "../MemoryEfficientSharding/results/*ep100*"
visualizations_path = "visualizations"
os.makedirs(visualizations_path, exist_ok=True)

# glob order is unspecified; sort so output is produced deterministically.
files = sorted(glob.glob(path))
if not files:
    print(f"No result files matched {path!r}; nothing to plot.")

for file in files:
    # basename/splitext are portable (glob uses os.sep in returned paths).
    filename = os.path.basename(file)
    save_filename = f"{os.path.splitext(filename)[0]}.pdf"
    batch_size = parse_batch_size(filename)

    with open(file, "r") as f:
        results = json.load(f)

    fig = build_learning_curve_figure(
        results, f"Batch Size: {batch_size}", validation_mode
    )
    fig.write_image(os.path.join(visualizations_path, save_filename))


### Plot of the Improved Model

In [None]:
# Plot the learning curve from the results file matching "ep2500_with_validation".
validation_mode = True
# you may change the path as needed
path = "../MemoryEfficientSharding/results/*ep2500_with_validation*"
# Declared here too so this cell runs even if the cell above was skipped.
visualizations_path = "visualizations"
os.makedirs(visualizations_path, exist_ok=True)

# Sorted so that, if several files match, the same one is picked every run.
matches = sorted(glob.glob(path))
if not matches:
    raise FileNotFoundError(f"No result file matched {path!r}")
file = matches[0]

# basename/splitext are portable (glob uses os.sep in returned paths).
filename = os.path.basename(file)
stem = os.path.splitext(filename)[0]
save_filename = f"{stem}.pdf"
# First "bs<digits>" token of the name, or None if absent; not used in this
# figure's title.
batch_size = next(
    (
        int(token[2:])
        for token in stem.split("_")
        if token.startswith("bs") and token[2:].isdecimal()
    ),
    None,
)

with open(file, "r") as f:
    results = json.load(f)

epochs = results["epochs"]
curves = [("Train", "train_rmse")]
if validation_mode:
    curves.append(("Validation", "val_rmse"))

fig = go.Figure()
for trace_name, key in curves:
    fig.add_trace(
        go.Scatter(
            x=epochs,
            y=results[key],
            mode="lines+markers",
            yaxis="y",
            name=trace_name,
            marker_size=3,
        )
    )

fig.update_layout(
    plot_bgcolor="white",
    title=dict(text="Improved Model's Learning Curve", xanchor="center", x=0.5),
    legend=dict(yanchor="top", y=0.99, xanchor="right", x=0.99),
    xaxis=dict(
        title=dict(text="Epochs", font=dict(size=16)),
        gridcolor="lightgrey",
        tickfont=dict(size=16),
        # Pad the x-range slightly beyond the data on both sides.
        range=[min(epochs) - 1, max(epochs) + 3],
    ),
    yaxis=dict(
        title=dict(text="RMSE", font=dict(size=16)),
        gridcolor="lightgrey",
        zeroline=False,
        tickfont=dict(size=16),
    ),
    height=350,
    width=500,
)

fig.write_image(os.path.join(visualizations_path, save_filename))
fig.show()