# Notebook for testing generic ideas and developing functions

In [1]:
# Generic Imports
import re
from functools import partial
from collections import defaultdict
from itertools import combinations

# Numeric imports
from math import ceil
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Typing and Subclassing
from typing import Any, Callable, ClassVar, Iterable, Optional, Union
from dataclasses import dataclass, field
from abc import ABC, abstractmethod, abstractproperty

# File I/O
import argparse
from pathlib import Path
import csv, json, pickle
from shutil import copyfile, rmtree
import importlib.resources as impres

## Removing fields from XML (useful for the annoying barostat state in OpenMM states)

In [None]:
import xml.etree.ElementTree as ET

# Visit every registered simulation and clear the first <Parameters> element
# found in each checkpoint that is stored as XML.
# NOTE: the cleared tree only exists in memory; nothing in this cell writes it back to disk.
for sim_dir, sim_paths_file in pdir.simulation_paths.items():
    sim_paths = SimulationPaths.from_file(sim_paths_file)
    chk = sim_paths.checkpoint
    if chk.suffix != '.xml':
        continue # skip anything that is not an XML checkpoint

    tree = ET.parse(chk)
    root = tree.getroot()

    # first <Parameters> element in document order (StopIteration if there is none)
    par = next(root.iter('Parameters'))
    par.clear()

## Testing dynamic checkpoint file updating

In [None]:
import pickle

class Test:
    '''Toy object which re-pickles itself to its checkpoint file whenever an attribute is assigned'''
    def __init__(self, val : int, checkpoint : Path) -> None:
        # Both assignments are routed through __setattr__; the first happens before
        # checkpoint_path exists, so the file is only written from the second onwards
        self.val = val
        self.checkpoint_path = checkpoint

    def to_file(self):
        '''Pickle the entire instance to checkpoint_path; does nothing if no checkpoint path has been set yet'''
        if not hasattr(self, 'checkpoint_path'):
            return

        with self.checkpoint_path.open('wb') as file:
            pickle.dump(self, file)

    def __setattr__(self, __name: str, __value: Any) -> None:
        '''Store the attribute as usual, then refresh the on-disk checkpoint and echo the assignment'''
        super().__setattr__(__name, __value)
        self.to_file()
        print(__name, __value)

In [None]:
# Make sure the pickle file exists on disk, then exercise the auto-checkpointing
p = Path('test.pkl')
p.touch()

t = Test(val=5, checkpoint=p)
t.other = 'word' # any later assignment re-pickles the instance to p

In [None]:
# Read back whatever the last attribute assignment wrote to the checkpoint file
with p.open('rb') as file:
    v = pickle.load(file)

vars(v)
v.foo = 'bar' # the reloaded copy still carries checkpoint_path, so this rewrites the file as well

## Experimenting with grid size optimization with respect to aspect ratio and number of squares

In [None]:
from math import ceil, sqrt, floor

def size_penalty(N_targ : int, N_real : int) -> float:
    '''Squared relative deviation of the realized cell count from the target cell count (0 when the two agree)'''
    fill_ratio = N_real / N_targ
    return (fill_ratio - 1)**2

def aspect_penalty(a_targ : float, a_real : float) -> float:
    '''Mismatch between a target and a realized aspect ratio: 1 minus the smaller of the two quotients between them
    Symmetric in its arguments and equal to 0 when the two aspects agree'''
    # (a squared relative deviation, as in size_penalty, was tried here as an alternative)
    targ_over_real = a_targ / a_real
    real_over_targ = a_real / a_targ
    return 1 - min(targ_over_real, real_over_targ)

def dims(N : int, a : float=1/1, w1=1, w2=1) -> tuple[int, int]:
    '''Given a number of cells N and a target aspect ratio a (rows / columns), return the (rows, columns)
    dimensions of the 2D grid which accommodates at least N cells and best balances having few surplus
    cells against having an aspect ratio close to a

    Every row count from 1 to N is tried, paired with the fewest columns that still fit N cells
    The winner minimizes w1*size_penalty + w2*aspect_penalty; ties go to the candidate with the fewest rows

    Raises ValueError if N is less than 1, since no grid can be sized for that'''
    if N < 1:
        raise ValueError(f'Cannot size a grid for N={N} cells; N must be at least 1')

    def total_penalty(shape : tuple[int, int]) -> float:
        '''Weighted sum of the surplus-cell and aspect-mismatch penalties for one candidate grid'''
        nrows, ncols = shape
        return w1*size_penalty(N, nrows*ncols) + w2*aspect_penalty(a, nrows/ncols)

    # one candidate per possible row count, with just enough columns to hold all N cells
    candidates = (
        (nrows, ceil(N / nrows))
            for nrows in range(1, N + 1)
    )
    return min(candidates, key=total_penalty)

# Visual check of dims(): for each cell count N, build a figure laid out on the chosen grid and title it with N
a = 2/1 # target aspect ratio passed to dims(), which compares it against rows / columns
for N in range(1, 20):
    nrows, ncols = dims(N, a)
    # NOTE(review): presize_subplots is a project helper; presumably returns a (figure, axes) pair like plt.subplots - confirm
    fig, ax = plotutils.presize_subplots(nrows=nrows, ncols=ncols, scale=1)
    fig.suptitle(f'N = {N}')