# Notebook for analyzing simulation data

# Imports

In [15]:
import warnings
# Filters are installed process-wide (there is no `with warnings.catch_warnings()` block here),
# so they stay in effect for the rest of the notebook session
warnings.filterwarnings('ignore', category=UserWarning)
warnings.filterwarnings('ignore', category=DeprecationWarning)

import logging
from polysaccharide2.genutils.logutils.IOHandlers import LOG_FORMATTER
# Configure the root logger using the project's message/date formats;
# force=True removes and closes any handlers already on the root logger (e.g. from re-running this cell)
logging.basicConfig(
    level=logging.INFO,
    format=LOG_FORMATTER._fmt,
    datefmt=LOG_FORMATTER.datefmt,
    force=True
)
LOGGER = logging.getLogger(__name__)

from pathlib import Path
from collections import defaultdict

import numpy as np
import pandas as pd

import mdtraj
from openmm.unit import Unit
from openmm.unit import nanometer

from polysaccharide2.analysis import mdtrajutils
from polysaccharide2.genutils.fileutils.pathutils import assemble_path
from polysaccharide2.openmmtools.serialization import SimulationPaths, SimulationParameters

# Defining Paths

## Simulation data

In [3]:
# Root of the raw simulation output; every entry directly inside it is indexed by its stem
raw_data_dir = Path('wasp_sims')
mol_dirs = {entry.stem : entry for entry in raw_data_dir.iterdir()}

## Data output

In [41]:
data_dir = Path('analysis_output/data')
# parents=True so the intermediate 'analysis_output' directory is created too on a fresh checkout,
# rather than mkdir raising FileNotFoundError
data_dir.mkdir(parents=True, exist_ok=True)

colina_dir   = data_dir / 'colina_data'   # reference data to compare against
openff_dir   = data_dir / 'openff_data'   # data produced by this notebook
combined_dir = data_dir / 'combined_data' # OpenFF and reference data side by side

for super_dir in (openff_dir, combined_dir, colina_dir):
    super_dir.mkdir(exist_ok=True)
    for subdir_name in ('rdfs', 'props'):
        subdir = super_dir / subdir_name
        subdir.mkdir(exist_ok=True)
        # bind e.g. `openff_data_props` / `colina_data_rdfs` as notebook-level names used by later cells
        globals()[f'{super_dir.name}_{subdir_name}'] = subdir

## Plot output

In [42]:
# Figures are written into their own folder under the data output directory,
# with one subfolder per kind of analysis
figure_dir = data_dir / 'figures'
figure_dir.mkdir(exist_ok=True)

for category in ('rdfs', 'props'):
    category_dir = figure_dir / category
    category_dir.mkdir(exist_ok=True)
    # expose `figures_rdfs` / `figures_props` as notebook-level names
    globals()[f'{figure_dir.name}_{category}'] = category_dir

# Analyzing individual trajectories

In [None]:
# RDF radial range and trajectory loading settings
min_rad : float = 0.0
max_rad : float = 2.0
rad_unit : Unit = nanometer
stride : int = 1                   # load only every `stride`-th frame (time points are subsampled to match)
prevent_overwrites : bool = False  # if True, outputs already present on disc are kept rather than recomputed

def _needs_computing(output_path, prevent_overwrites : bool) -> bool:
    '''Decide whether an analysis output must be (re)computed

    True when no output path has been recorded yet, when the recorded file is
    missing from disc, or when overwriting existing outputs is allowed
    '''
    return (output_path is None) or not (output_path.exists() and prevent_overwrites)

for i, sim_paths_path in enumerate(raw_data_dir.glob('**/production/*_paths.json')):
    working_dir = sim_paths_path.parent
    assert(working_dir.is_dir())
    prefix = working_dir.name

    LOGGER.info(f'Analyzing trajectory found in: {working_dir} (# {i + 1})')
    sim_paths = SimulationPaths.from_file(sim_paths_path)
    sim_params = SimulationParameters.from_file(sim_paths.parameters_path)

    # decide up front what needs computing, so the trajectory is only loaded when some output is actually required
    compute_props = _needs_computing(sim_paths.time_data_path, prevent_overwrites)
    compute_rdfs  = _needs_computing(sim_paths.spatial_data_path, prevent_overwrites)
    if not (compute_props or compute_rdfs):
        LOGGER.info('All analysis outputs already present and overwrites are prevented; skipping\n')
        continue

    # load MDTraj trajectories
    LOGGER.info(f'Loading trajectory from {sim_paths.trajectory_path}')
    traj = mdtraj.load(sim_paths.trajectory_path, top=sim_paths.topology_path, stride=stride)

    # computing and saving shape property time series (solvent is excluded from shape properties)
    if compute_props:
        LOGGER.info('Stripping solvent')
        traj_no_solv = traj.remove_solvent(inplace=False)
        prop_data = mdtrajutils.acquire_time_props(traj_no_solv, time_points=sim_params.integ_params.time_points[::stride])
        LOGGER.info('Computed time series\' from trajectory')
        time_data_path = assemble_path(working_dir, prefix, extension='csv', postfix='time_series')
        prop_data.to_csv(time_data_path, index=False)
        sim_paths.time_data_path = time_data_path

    # RDFs (computed on the full trajectory, since the pairs involve water)
    if compute_rdfs:
        ## determine IDs of relevant atom pairs
        pair_dict = {
            'chain O - water O' : traj.top.select_pairs('not water and element O', 'water and element O'),
            # 'water O - water O' : traj.top.select_pairs('water and element O', 'water and element O') # too many waters, prohibitive memory usage
        }

        if 'N' in mdtrajutils.unique_elem_types(traj):
            pair_dict['chain N - water O'] = traj.top.select_pairs('not water and element N', 'water and element O')

        ## computing and saving RDFs
        rdf_data = mdtrajutils.acquire_rdfs(traj, pair_dict, min_rad=min_rad, max_rad=max_rad, rad_unit=rad_unit)
        LOGGER.info('Computed radial distribution functions (RDFs) from trajectory')
        rdf_data_path = assemble_path(working_dir, prefix, extension='csv', postfix='rdfs')
        rdf_data.to_csv(rdf_data_path, index=False)
        sim_paths.spatial_data_path = rdf_data_path

    LOGGER.info(f'Saving updated paths to {sim_paths_path}\n') # add newline for breathing room
    sim_paths.to_file(sim_paths_path) # update JSON file with paths on disc

# Collating individually processed data into unified collections

In [52]:
# Statistics taken across conformers: the mean is reported as the observable, the standard deviation as its uncertainty
stat_fns = {
    'observables'   : np.mean,
    'uncertainties' : np.std
}

# [molecule][charge method][property name] -> list of per-simulation time-averaged values
total_data_props = defaultdict(lambda : defaultdict(lambda : defaultdict(list)))
# [molecule][charge method][atom pair name] -> DataFrame holding one RDF column per conformer
total_data_rdfs  = defaultdict(lambda : defaultdict(lambda : defaultdict(pd.DataFrame)))

for sim_paths_path in raw_data_dir.glob('**/production/*_paths.json'):
    # extract level info from simulation filetree
    working_dir = sim_paths_path.parent
    assert(working_dir.is_dir())
    # Path.parts is separator-agnostic, unlike splitting the string form on '/' (which fails on Windows)
    mol_name, charge_method, conf_name, prefix = working_dir.relative_to(raw_data_dir).parts

    # load analyzed data
    sim_paths = SimulationPaths.from_file(sim_paths_path)

    # reading property data
    time_data = pd.read_csv(sim_paths.time_data_path)
    _, time_samples = mdtrajutils.props_to_plot_data(time_data)
    for prop_name, time_series in time_samples.items():
        total_data_props[mol_name][charge_method][prop_name].append(time_series.mean()) # take equilibrium average over each time series

    # reading RDF data
    # NOTE(review): radii_openff is overwritten on every pass, so only the last file's radius axis survives;
    # this assumes all RDFs share the same radial bins
    rdf_data = pd.read_csv(sim_paths.spatial_data_path)
    radii_openff, rdfs = mdtrajutils.rdfs_to_plot_data(rdf_data)
    for atom_pair_name, rdf in rdfs.items():
        total_data_rdfs[mol_name][charge_method][atom_pair_name][conf_name] = rdf

## Processing shape properties

In [54]:
# Write one property table per molecule: columns are the force field / charge method combinations,
# rows are indexed by (statistic name, property name) and hold that statistic over the per-conformer averages
for mol_name, mol_dict in total_data_props.items():
    dframe = pd.DataFrame.from_dict({
        f'Sage 2.0.0 - {chg_method}' : {
            (stat_name, prop_name) : stat_fn(prop_data)
                for stat_name, stat_fn in stat_fns.items()
                    for prop_name, prop_data in data_dict.items()
        }
        for chg_method, data_dict in mol_dict.items()
    })
    dframe.to_csv(openff_data_props / f'{mol_name}.csv') # index deliberately left in

In [55]:
# Place each reference property table next to the matching OpenFF table (matched by filename)
for ref_data_path in sorted(colina_data_props.iterdir()):
    filename = ref_data_path.name
    new_data_path = openff_data_props / filename
    if not new_data_path.exists():
        LOGGER.warning(f'No OpenFF property data found matching reference file {filename}; skipping')
        continue

    new_data = pd.read_csv(new_data_path, index_col=[0, 1])
    ref_data = pd.read_csv(ref_data_path, index_col=[0, 1])
    data = pd.concat([new_data, ref_data], axis=1)

    data.to_csv(combined_data_props / filename)

## Processing RDFs

In [56]:
# Write one RDF table per molecule: columns are the force field / charge method combinations, rows are
# indexed by (atom pair, quantity), and each cell holds a whole list of values over the radius axis
for mol_name, mol_dict in total_data_rdfs.items():
    df_dict = defaultdict(dict)
    # take both the label and the values from the same column, instead of hard-coding the unit-dependent name
    radius_col = radii_openff.columns[0]

    for chg_method, data_dict in mol_dict.items():
        framework = f'Sage 2.0.0 - {chg_method}'
        for elem_pair, rdf_data in data_dict.items():
            df_dict[framework][(elem_pair, radius_col)] = radii_openff[radius_col].to_list()
            for stat_name, stat_fn in stat_fns.items():
                # reduce across conformers (one per column); .tolist() gives builtin floats, whereas list() would
                # keep np.float64 scalars whose repr under NumPy >= 2 ("np.float64(...)") ends up in the CSV text
                df_dict[framework][(elem_pair, stat_name)] = stat_fn(rdf_data.to_numpy(), axis=1).tolist()

    dframe = pd.DataFrame.from_dict(df_dict)
    dframe.to_csv(openff_data_rdfs / f'{mol_name}.csv') # index deliberately left in

In [57]:
# Place each OpenFF RDF table next to the matching reference table (matched by filename)
for openff_path in sorted(openff_data_rdfs.iterdir()):
    filename = openff_path.name
    colina_path = colina_data_rdfs / filename
    if not colina_path.exists():
        LOGGER.warning(f'No reference RDF data found matching OpenFF file {filename}; skipping')
        continue

    openff_rdfs = pd.read_csv(openff_path, index_col=(0, 1))
    colina_rdfs = pd.read_csv(colina_path, index_col=(0, 1))

    combined_rdfs = pd.concat([openff_rdfs, colina_rdfs], axis=1)
    combined_rdfs.to_csv(combined_data_rdfs / filename)