# In [1]:
import datetime
import os
import json
from collections import defaultdict
import matplotlib.pyplot as plt
import re

# Grid-search settings. The directory names parsed below encode these values.
num_epochs = 30
batch_size = [64, 256]  # batch sizes covered by the grid
lr = 0.1
height = 28
num_sum_units, num_input_units = 8, 8
input_dim = 784
data = 'mnist'
lam = [0, 0.1, 1.0]  # lambda values covered by the grid
patience = 5

result_dir = 'results_PC_grid'

# Run directories are expected to be named
#   <8 digits>_<6 digits>_S<seed>_E<epochs>_BS<batch size>_LR<lr>_H<height>_
#   NSM<sum units>_NIM<input units>_ID<input dim>_<data>_LAM<lam>_PAT<patience>
# where the leading digits are a timestamp (presumably YYYYmmdd_HHMMSS -- confirm).
# The LR field is written with `{lr:.0e}`, which yields e.g. "1e-01" but also
# "1e+00" for lr >= 1, so the numeric classes below must accept a '+' sign.
pattern = re.compile(
    r"(?P<datetime>\d{8}_\d{6})_"
    r"S(?P<seed>\d+)_"
    r"E(?P<num_epochs>\d+)_"
    r"BS(?P<batch_size>\d+)_"
    r"LR(?P<lr>[\de.+-]+)_"
    r"H(?P<height>\d+)_"
    r"NSM(?P<num_sum_units>\d+)_"
    r"NIM(?P<num_input_units>\d+)_"
    r"ID(?P<input_dim>\d+)_"
    r"(?P<data>[^_]+)_"
    r"LAM(?P<lam>[\de.+-]+)_"
    r"PAT(?P<patience>\d+)"
)

# Sort the entries: run directories start with a fixed-width timestamp, so the
# sorted order is chronological (assuming YYYYmmdd_HHMMSS). If several runs share
# the same batch size, lambda and seed, the most recent one deterministically
# overwrites the older ones. Plain os.listdir order is arbitrary.
directories = sorted(os.listdir(result_dir))

# data_dict[batch_size][lam][seed] -> value of 'nll_val' from that run's results.json
data_dict = defaultdict(lambda: defaultdict(dict))

# Use the configured grid values as the filter. Copy them into sets *before* the
# loop, and use distinct loop names so the module-level lists are not clobbered.
wanted_batch_sizes = set(batch_size)
wanted_lams = set(lam)

# NOTE(review): the other hyperparameters encoded in the name (epochs, lr,
# height, ...) are not used for filtering, so runs that differ only in those
# settings are pooled under the same (batch size, lambda, seed) key.
for directory in directories:
    match = pattern.match(directory)
    if not match:
        continue  # not a run directory

    params = match.groupdict()
    run_batch_size = int(params['batch_size'])
    run_lam = float(params['lam'])
    if run_batch_size not in wanted_batch_sizes or run_lam not in wanted_lams:
        continue

    results_path = os.path.join(result_dir, directory, 'results.json')
    if not os.path.exists(results_path):
        continue  # run without a results file (e.g. unfinished)

    with open(results_path, 'r', encoding='utf-8') as f:
        results = json.load(f)

    nll_val = results.get('nll_val')
    if nll_val:  # skip runs whose results lack a (truthy) 'nll_val'
        data_dict[run_batch_size][run_lam][int(params['seed'])] = nll_val

# Scatter-plot grid: one row per batch size, one column per lambda. Each point
# pairs a seed's nll_val at the baseline lambda (x) with that same seed's nll_val
# at the column's lambda (y). Only seeds present under both settings are drawn.
batch_sizes = [64, 256]
lams = [0.1, 1.0]
baseline_lam = 0  # every lambda in `lams` is compared against this one

# squeeze=False keeps `axs` 2-D even if a list has a single entry. The figsize
# scales with the grid and equals (15, 10) for the current 2x2 layout.
fig, axs = plt.subplots(
    len(batch_sizes), len(lams),
    figsize=(7.5 * len(lams), 5 * len(batch_sizes)),
    squeeze=False,
)

# Distinct loop names so the module-level `batch_size` / `lam` are not rebound.
for i, bs in enumerate(batch_sizes):
    for j, lam_val in enumerate(lams):
        ax = axs[i, j]
        by_lam = data_dict[bs]
        # Sorted so points and legend entries appear in a stable seed order.
        seeds = sorted(by_lam[baseline_lam].keys() & by_lam[lam_val].keys())
        for seed in seeds:
            ax.scatter(by_lam[baseline_lam][seed], by_lam[lam_val][seed], label=f'Seed {seed}')
        ax.set_title(f'Batch Size {bs}, Lambda {baseline_lam} vs Lambda {lam_val}')
        ax.set_xlabel(f'nll_val (Lambda {baseline_lam})')
        ax.set_ylabel(f'nll_val (Lambda {lam_val})')
        if seeds:  # avoid the "no artists with labels" warning on empty panels
            ax.legend()
        ax.grid(True)

fig.tight_layout()
plt.show()

