In [None]:
from pathlib import Path
import torch
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

def extract_data(checkpoint):
    """
    returns a dict
    {
        'n_conv_list':
        'kernel':
        'initial_filters':
        'val_acc_array':
        'best_acc':
    }
    """
    hp = {}
    conv_list = checkpoint['model_config']['config']['conv']
    
    hp['kernel'] = conv_list[0][2][0]
    hp['initial_filters'] = conv_list[0][1][0]
    hp['n_conv'] = []
    for i in conv_list: hp['n_conv'].append(len(i[0]))
        
    val_acc = checkpoint['val_acc']
    # clean up the list
    temp_array = []
    for i in range(len(val_acc)):
        
        temp_array.append([val_acc[i][0], val_acc[i][1], torch.tensor(val_acc[i][2]).mean().item()])
    temp_array = np.array(temp_array)
    
    n_batch = np.int32(np.max(temp_array[:, 1])) + 1
    
    hp['val_acc_array'] = temp_array[np.arange(0,498*n_batch + 1, n_batch),:]
    hp['best_acc'] = np.mean(np.flip(np.sort(hp['val_acc_array'][:, -1]))[0:10])
    
    return hp
        



In [None]:
# Root of the hyperparameter-sweep outputs.
# NOTE(review): absolute, user-specific location — adjust when running elsewhere.
OUTPUT_DIR = Path.home() / 'group' / 'project' / 'scripting' / 'output'

# Path.glob yields entries in filesystem order, which is not guaranteed.
# Sorting makes checkpoints[0] and the row order of the summary table
# reproducible across runs and machines.
directories = sorted(OUTPUT_DIR.glob('220728_hyper*'))
checkpoints = []
for directory in directories:
    checkpoints.extend(sorted(directory.glob('*/*.tar')))

print(len(checkpoints))



In [None]:
# Load every checkpoint on the CPU and keep only its summarised hyperparameters,
# printing a running count after each block of 100 files.
hp_data = []

for n_loaded, checkpoint_path in enumerate(checkpoints, start=1):
    hp_data.append(
        extract_data(torch.load(checkpoint_path, map_location=torch.device('cpu')))
    )
    if n_loaded % 100 == 0:
        print(n_loaded)

In [None]:
# One row per checkpoint, one column per key returned by extract_data.
hp_data = pd.DataFrame(data=hp_data)

In [None]:
# The 30 runs with the highest best_acc, best first.
hp_data.sort_values(by='best_acc', ascending=False).head(30)

In [None]:
# Inspect the structure of the first checkpoint (top-level keys and model config).
checkpoint = torch.load(checkpoints[0], map_location='cpu')
print(checkpoint.keys())
print(checkpoint['model_config'])

In [None]:
# Summary dict for the checkpoint loaded in the previous cell.
extract_data(checkpoint=checkpoint)