In [None]:
import sys, os
import pickle
import multiprocessing as mp
import numpy as np
import pandas as pd
from src.neuron import *
from src.utils import *
from src.constants import * 
from src.network import *
from src.validation import *
from src.viz import *
from src.genetic_algorithm import *


In [None]:
# ~~Check best plot from each pkl file~~
# Go through all the pkl files in the data directory and find the dna with the highest dna_score.
# Then, print the dna and the dna_score.
# Then, run the network with that dna score, print the results (including the new DNA score), and plot them.
# May not work in .ipynb because of multiprocessing.

def process_pkl_file(pkl_path):
    """Re-evaluate the highest-scoring DNA stored in one pickle file.

    The pickle is read as a mapping whose values are iterables of runs. Runs
    that are dicts containing a 'dna_score' key are kept; anything else is
    reported and skipped. The run with the largest 'dna_score' is then
    re-evaluated with ``evaluate_dna`` on freshly created neurons and
    experiment inputs.

    Args:
        pkl_path: Path of the pickle file to load.

    Returns:
        A dict with the keys 'file', 'dna', 'scores', 'total_score',
        'max_score', 'neuron_data' and 'input_waves', or None when the file
        contains no scored runs or any step raises. Failures are printed
        rather than propagated so one bad file does not abort a pool of
        workers.
    """
    try:
        # NOTE: pickle.load can execute arbitrary code embedded in the file;
        # only point this at pickle files you produced yourself.
        with open(pkl_path, 'rb') as f:
            data = pickle.load(f)

        # Flatten the per-key run lists, keeping only well-formed runs.
        flattened_data = []
        for runs in data.values():
            for run in runs:
                if isinstance(run, dict) and 'dna_score' in run:
                    flattened_data.append(run)
                else:
                    print(f"Unexpected format in file {pkl_path}: {run}")

        # Without this guard max() would raise "max() arg is an empty
        # sequence", which says nothing about what is wrong with the file.
        if not flattened_data:
            print(f"No runs with a 'dna_score' found in file {pkl_path}")
            return None

        # Find DNA with highest score
        best_dna = max(flattened_data, key=lambda x: x['dna_score'])
        print(f"Best DNA: {best_dna}")
        print(f"Best DNA score: {best_dna['dna_score']}")

        # Get DNA matrix
        dna_matrix = load_dna(best_dna['dna'])

        # Prepare network components (the first value returned by
        # create_experiment is not needed here).
        all_neurons = create_neurons()
        _, input_waves, alpha_array = create_experiment()
        criteria_dict = define_criteria()
        max_score = TMAX // BIN_SIZE * len(CRITERIA_NAMES)

        # Evaluate DNA
        dna_score, neuron_data = evaluate_dna(
            dna_matrix=dna_matrix,
            neurons=all_neurons,
            alpha_array=alpha_array,
            input_waves=input_waves,
            criteria=criteria_dict,
            curr_dna=best_dna['dna']
        )

        # dna_score is a mapping of per-condition scores; sum them all.
        total_score = sum(dna_score.values())

        return {
            'file': os.path.basename(pkl_path),
            'dna': best_dna['dna'],
            'scores': dna_score,
            'total_score': total_score,
            'max_score': max_score,
            'neuron_data': neuron_data,
            'input_waves': input_waves
        }
    except Exception as e:
        print(f"Error processing file {pkl_path}: {e}")
        return None

if __name__ == '__main__':
    # Gather the path of every .pkl file sitting directly in the data directory.
    print("Starting to process pkl files")
    data_dir = 'data'
    pkl_files = [
        os.path.join(data_dir, name)
        for name in filter(lambda n: n.endswith('.pkl'), os.listdir(data_dir))
    ]

    # Hand one file to each worker; every worker returns a summary dict or None.
    with mp.Pool() as pool:
        results = pool.map(process_pkl_file, pkl_files)

    # Files that failed to process come back as None; discard them.
    results = list(filter(lambda r: r is not None, results))

    # Report the re-evaluated scores file by file.
    for result in results:
        print(f"\nResults for {result['file']}:")
        print(f"    === DNA: {result['dna']}")
        print(f"    === Control: {result['scores']['control']}/{result['max_score']}")
        print(f"    === Experimental: {result['scores']['experimental']}/{result['max_score']}")
        # Overall is expressed against twice max_score (control + experimental).
        print(f"    === Overall: {result['total_score']}({result['total_score']/(2*result['max_score']):.2%})")

        # Plotting is currently disabled; uncomment to draw each condition.
        # # Plot results
        # for condition in ['experimental', 'control']:
        #     target_neurons_hist_Vs = np.array([result['neuron_data'][condition][name]['hist_V'] for name in NEURON_NAMES])
        #     plot_neurons_interactive(
        #         hist_Vs=target_neurons_hist_Vs, 
        #         neuron_names=NEURON_NAMES, 
        #         sq_wave=result['input_waves'][0], 
        #         go_wave=result['input_waves'][1], 
        #         show_u=False,
        #         title=f"{result['file']} - {condition}"
        #     )


In [None]:
# Combine all the pkl files in the data directory into a single pkl file.

def combine_pkl_files(directory, output_file):
    """Merge the runs of every .pkl file in a directory into one pickle.

    Each input pickle is read as a mapping whose values are iterables of run
    dicts. Every run dict is tagged with a 'file' key holding the name of the
    file it came from and appended to one flat list, which is then pickled to
    ``output_file``.

    Args:
        directory: Directory scanned (non-recursively) for files ending in
            '.pkl'.
        output_file: Path the combined list is written to. If it lies inside
            ``directory`` it is skipped as an input, so re-running the
            function does not re-ingest its own output.

    Returns:
        None. The combined list is written to ``output_file``.
    """
    # Resolve once so the output can be recognised among the inputs.
    output_real = os.path.realpath(output_file)
    combined_data = []

    # Sort so the combined order does not depend on os.listdir's arbitrary order.
    for filename in sorted(os.listdir(directory)):
        if not filename.endswith('.pkl'):
            continue
        file_path = os.path.join(directory, filename)

        # A previous run may have left the combined file in this directory.
        if os.path.realpath(file_path) == output_real:
            continue

        # NOTE: pickle.load can execute arbitrary code embedded in the file;
        # only use this on pickle files you produced yourself.
        with open(file_path, 'rb') as file:
            data = pickle.load(file)

        if not isinstance(data, dict):
            print(f"Unexpected format in file {file_path}: expected a dict, got {type(data).__name__}")
            continue

        # Add the 'file' key to each entry
        for entries in data.values():
            for entry in entries:
                if not isinstance(entry, dict):
                    print(f"Unexpected format in file {file_path}: {entry}")
                    continue
                entry['file'] = filename
                combined_data.append(entry)

    # Save the combined data to a new pkl file
    with open(output_file, 'wb') as output:
        pickle.dump(combined_data, output)

# Usage: merge every pickle in the data directory into combined_data.pkl
# (written relative to the current working directory).
# NOTE(review): the input path is absolute and machine-specific — confirm it
# before running on another machine.
combine_pkl_files(
    directory='/Users/stevenwendel/Documents/GitHub/bg/data',
    output_file='combined_data.pkl',
)