# For plotting figures from processed data post-analysis

# Imports

In [None]:
import warnings
# Silence all UserWarning and DeprecationWarning messages for the rest of the session.
# (warnings.catch_warnings() only has an effect when entered via a `with` block, so it is not used here.)
warnings.filterwarnings('ignore', category=UserWarning)
warnings.filterwarnings('ignore', category=DeprecationWarning)

import logging
from polysaccharide2.genutils.logutils.IOHandlers import LOG_FORMATTER
# Route logging through the project's standard message/date format;
# force=True removes any handlers already attached to the root logger (e.g. from a previous run of this cell)
logging.basicConfig(
    level=logging.INFO,
    format=LOG_FORMATTER._fmt,
    datefmt=LOG_FORMATTER.datefmt,
    force=True,
)
LOGGER = logging.getLogger(__name__)

from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from polysaccharide2.graphics import plotutils

# Defining Paths

## Simulation data

In [None]:
# Map each molecule name to its per-molecule directory under wasp_sims/
sim_data_dir = Path('wasp_sims')
mol_dirs = {
    path.stem : path
        for path in sorted(sim_data_dir.iterdir()) # sorted for a reproducible processing order
            if path.is_dir() # skip stray files, which would break the per-molecule directory scans later on
}

## Analyzed data

In [None]:
# Combined (all-framework) analysis outputs: per-molecule property tables and RDF tables
combined_data_dir = Path('analysis_output/data/combined_data')
data_dir_props = combined_data_dir / 'props'
data_dir_rdfs  = combined_data_dir / 'rdfs'

## Plot output

In [None]:
# Output directories for the figures produced below, with one subdirectory per figure category
figure_dir = Path('analysis_output/figures')
figure_dir.mkdir(parents=True, exist_ok=True) # parents=True also creates analysis_output/ if it is missing

figure_subdirs : dict[str, Path] = {}
for subdir_name in ('rdfs', 'props', 'charges'):
    subdir = figure_dir / subdir_name
    subdir.mkdir(parents=True, exist_ok=True)
    figure_subdirs[subdir_name] = subdir

# bound explicitly (rather than through globals()) so later cells' references are visible to readers and linters
figures_rdfs    = figure_subdirs['rdfs']
figures_props   = figure_subdirs['props']
figures_charges = figure_subdirs['charges']

# Shape properties

In [None]:
# Colors for each MD framework, drawn from matplotlib's 'tab20c' colormap,
# which holds 5 hue families ('blue', 'orange', 'green', 'purple', 'grey') of 4 shades each
hues_per_color = 4
color_set = ('blue', 'orange', 'green', 'purple', 'grey')
cmap = plt.get_cmap('tab20c')
cdict, carr = plotutils.label_discrete_cmap(cmap, color_set, hues_per_color) # cdict keys as used below: '<family><shade index>', e.g. 'blue0'

# (framework label, cdict key) pairs; this insertion order also fixes the column/bar order used in the plots below
framework_color_keys = (
    ('Sage 2.0.0 - RCT'              , 'blue0'  ),
    ('Sage 2.0.0 - Espaloma-AM1-BCC' , 'blue2'  ),
    ('DREIDING - RESP'               , 'purple0'),
    ('CHARMM - Mulliken'             , 'orange0'),
    ('CHARMM-c - Mulliken'           , 'orange2'),
    ('GAFF - RESP'                   , 'green0' ),
    ('GAFF2 - RESP'                  , 'green2' ),
)
dset_colors = {
    framework : cdict[color_key]
        for framework, color_key in framework_color_keys
}

# Bar charts of shape properties (one panel per property, one bar per framework) for each molecule
framework_order = {framework : rank for rank, framework in enumerate(dset_colors)} # rank of each framework, by color definition order
for prop_data_path in sorted(data_dir_props.iterdir()):
    if not prop_data_path.is_file():
        continue # skip stray subdirectories; only files are read as CSV tables
    mol_name = prop_data_path.stem
    dframe = pd.read_csv(prop_data_path, index_col=[0, 1])
    dframe = dframe.sort_index(axis=1, key=lambda cols : [framework_order[c] for c in cols]) # sort in same order as colors are defined
    obs = dframe.loc['observables']
    std = dframe.loc['uncertainties']

    frameworks = obs.columns
    num_dsets = len(frameworks)
    num_props = len(obs.index)
    x_pos = np.arange(num_dsets)
    bar_colors  = [dset_colors[fw] for fw in frameworks]
    tick_labels = [fw.replace(' - ', '\n + ') for fw in frameworks]

    fig, ax = plotutils.presize_subplots(nrows=1, ncols=num_props)
    # NOTE(review): presize_subplots may hand back a bare Axes when ncols == 1; atleast_1d makes both cases iterable
    for axis, (prop_name, prop_data) in zip(np.atleast_1d(ax).flatten(), obs.iterrows()):
        uncert = std.loc[prop_name].to_numpy()

        axis.set_title(prop_name)
        axis.bar(x_pos, prop_data, yerr=uncert, color=bar_colors)
        axis.set_xticks(x_pos)
        axis.set_xticklabels(tick_labels, rotation=-30)

    fig.savefig(figures_props / f'{mol_name}_properties.png', bbox_inches='tight')
    plt.close(fig)

# RDFs

In [None]:
from ast import literal_eval
from mpl_toolkits.axes_grid1.inset_locator import inset_axes


# --- Figure layout ---
fontsize = 14 # font size for axis labels and legends
scale = 10    # passed to presize_subplots as `scale`
aspect = 4/5  # passed to presize_subplots as `elongation`

# --- Radial window for the main (zoomed) plot, in nm ---
# 0.2-1.2 nm, i.e. 2-12 Angstrom (a 0.0-0.5 nm window was also used previously)
subrange_min, subrange_max = 0.2, 1.2

# --- Inset sizing ---
# extra headroom above the tallest error band, as a fraction of that height, to accommodate the inset
inset_head_fract : float = 1.2
inset_perc_x = 40 # inset width, as a percentage of the main axes
inset_perc_y = 40 # inset height, as a percentage of the main axes

# x-axis label for the main plot and the inset (radii from new data)
radii_label = 'Radius (nanometer)'

# loop over and plot: one figure per (molecule, RDF), with a zoomed main plot and a full-range inset
framework_rank = {framework : rank for rank, framework in enumerate(dset_colors)} # column order, matching color definitions
for data_path_rdf in sorted(data_dir_rdfs.iterdir()):
    if not data_path_rdf.is_file():
        continue # skip stray subdirectories; only files are read as CSV tables
    mol_name = data_path_rdf.stem
    combined_rdf_dframe = pd.read_csv(data_path_rdf, index_col=[0, 1])
    combined_rdf_dframe = combined_rdf_dframe.sort_index(axis=1, key=lambda cols : [framework_rank[c] for c in cols]) # sort in same order as colors are defined

    # NOTE(review): MultiIndex.levels holds each level's unique values in sorted order rather than file order,
    # so the (radii, mean, std) unpacking below relies on the series names sorting that way -- confirm
    rdf_labels, series_names = combined_rdf_dframe.index.levels
    for rdf_name in rdf_labels:
        subframe = combined_rdf_dframe.xs(rdf_name)
        subframe = subframe.dropna(axis=1, inplace=False) # drop framework columns with missing entries for this RDF
        subframe = subframe.map(literal_eval) # de-stringify lists throughout

        fig, ax = plotutils.presize_subplots(nrows=1, ncols=1, scale=scale, elongation=aspect)
        fig.suptitle(mol_name)

        ax.set_xlabel(radii_label, fontsize=fontsize)
        ax.set_ylabel(rdf_name, fontsize=fontsize)
        ax.set_xlim(subrange_min, subrange_max)

        inset_ax = inset_axes(ax, # generate longer length-scale inset
            width=f'{inset_perc_x}%',
            height=f'{inset_perc_y}%',
            borderpad=3,
            loc='upper right'
        )
        inset_ax.set_xlabel(radii_label, fontsize=fontsize)

        ymax_subrange = None # tallest upper error bound inside the subrange, over ALL frameworks
        for framework, data in subframe.items():
            radii, rdf_mean, rdf_std = [
                np.array(data[ser_name])
                    for ser_name in series_names
            ]
            err_lower = rdf_mean - rdf_std # compute error band boundaries
            err_upper = rdf_mean + rdf_std

            # extract region corresponding to 2-12 angstroms shown by Rukmani
            in_subrange = (subrange_min < radii) & (radii < subrange_max)
            radii_subrange     = radii[in_subrange]
            rdf_mean_subrange  = rdf_mean[in_subrange]
            err_lower_subrange = err_lower[in_subrange]
            err_upper_subrange = err_upper[in_subrange]
            if err_upper_subrange.size > 0: # .max() of an empty array raises
                local_max = err_upper_subrange.max()
                ymax_subrange = local_max if (ymax_subrange is None) else max(ymax_subrange, local_max)

            # plot reduced main plot
            ax.plot(radii_subrange, rdf_mean_subrange, label=framework)
            ax.fill_between(radii_subrange, err_lower_subrange, err_upper_subrange, alpha=0.5)

            # plot longer length-scale inset
            inset_ax.plot(radii, rdf_mean, label=framework)
            inset_ax.fill_between(radii, err_lower, err_upper, alpha=0.5)

        # set once, after every framework is drawn, so no framework's curve is clipped;
        # the extra headroom accommodates the inset in the upper right
        if ymax_subrange is not None:
            ax.set_ylim(0, ymax_subrange * (1 + inset_head_fract))

        ax.legend(loc='lower right', fontsize=fontsize)
        fig.savefig(figures_rdfs / f'{mol_name}_{rdf_name}.png', bbox_inches='tight')
        plt.close(fig)

# Comparing charges

In [None]:
from math import ceil
from abc import ABC, abstractstaticmethod


class _classproperty:
    '''Read-only property evaluated on the class rather than on an instance

    Replaces stacking @classmethod over @property, which was deprecated in Python 3.11
    and no longer works from Python 3.13 (the attribute stops evaluating to the return value)'''
    def __init__(self, fget) -> None:
        self.fget = fget
        self.__doc__ = fget.__doc__

    def __get__(self, instance, owner=None):
        if owner is None: # some callers pass only the instance
            owner = type(instance)
        return self.fget(owner)


class BinSizer(ABC):
    '''Abstract base for auto-sizing histogram bins
    Child class implementations taken from https://en.wikipedia.org/wiki/Histogram#Number_of_bins_and_width'''
    @abstractstaticmethod
    def num_bins(N : int) -> int:
        '''Number of histogram bins to use for a sample of N values'''
        raise NotImplementedError

    @_classproperty
    def registry(cls) -> dict[str, type['BinSizer']]:
        '''Name-indexed dict of the classes directly inheriting from this one,
        accessed as an attribute, e.g. BinSizer.registry['Sturges']'''
        return {
            subcls.__name__ : subcls
                for subcls in cls.__subclasses__()
        }

class Sturges(BinSizer):
    '''Based on Sturges' Formula, k = ceil(log2(N)) + 1 (requires N >= 1)'''
    @staticmethod
    def num_bins(N: int) -> int:
        return 1 + ceil(np.log2(N))

class Sqrt(BinSizer):
    '''Based on square root, k = ceil(sqrt(N))'''
    @staticmethod
    def num_bins(N: int) -> int:
        return ceil(np.sqrt(N))

class Rice(BinSizer):
    '''Based on Rice's Rule, k = ceil(2 * N^(1/3))'''
    @staticmethod
    def num_bins(N: int) -> int:
        return ceil(2 * N**(1/3))

In [None]:
from typing import Optional

from ast import literal_eval
import matplotlib.patches as mpl_patches
from matplotlib.colors import Normalize, Colormap

from openff.toolkit import Molecule

from polysaccharide2.openfftools import topology

from polysaccharide2.rdutils.rdkdraw import rdmol_prop_heatmap_colorscaled
from polysaccharide2.rdutils.labeling import molwise
from polysaccharide2.rdutils.rdconvert import RDConverter
from polysaccharide2.rdutils import rdprops

from polysaccharide2.genutils.maths.greek import GREEK_UPPER
from polysaccharide2.genutils.maths.statistics import RMSE
from polysaccharide2.graphics import plotutils


def generate_charge_plots(cmols : dict[str, Molecule], cmap : Colormap, n_bins : int=50, fontsize : int=14, cvtr : Optional[RDConverter]=None) -> tuple[tuple[plt.Figure, plt.Axes], tuple[plt.Figure, plt.Axes]]:
    '''Create a charge-difference heatmap and histogram for a pair of parameterized Molecules

    Parameters
    ----------
    cmols : dict[str, Molecule]
        Exactly two charged Molecules keyed by a label (e.g. the charge method);
        the labels are joined as "A vs B" for the colorbar/axis text, and the molecules are
        passed to rdprops.difference_rdmol in dict insertion order
    cmap : Colormap
        Colormap used for both the heatmap and the histogram bar colors
    n_bins : int, default 50
        Number of histogram bins
    fontsize : int, default 14
        Font size for the histogram axis labels and the RMSE legend
    cvtr : RDConverter, optional
        If given, applied to the difference molecule before the heatmap is drawn

    Returns
    -------
    ((hm_fig, hm_ax), (hist_fig, hist_ax))
        The heatmap figure/axes and the histogram figure/axes

    Raises
    ------
    ValueError
        If cmols does not contain exactly two Molecules
    '''
    # 0a) extract molecules
    if len(cmols) != 2:
        raise ValueError(f'Expected exactly 2 charged Molecules to compare, got {len(cmols)}: {list(cmols)}')
    chgd_offmol_1, chgd_offmol_2 = cmols.values()
    rdmols = [
        molwise.assign_ordered_atom_map_nums(offmol.to_rdkit(), in_place=False)
            for offmol in (chgd_offmol_1, chgd_offmol_2)
    ]

    # 0b) create string labels
    method_str = ' vs '.join(cmols)
    chg_delta_str = f'{GREEK_UPPER["delta"]}q (elem. charge)'
    diff_str = f'{chg_delta_str} : {method_str}'

    # 1) generate heatmap
    diffmol = rdprops.difference_rdmol(*rdmols, prop='PartialCharge')
    if cvtr is not None:
        diffmol = cvtr.convert(diffmol)
    hm_fig, hm_ax = rdmol_prop_heatmap_colorscaled(diffmol, prop='DeltaPartialCharge', cmap=cmap, cbar_label=diff_str, orient='vertical')

    # 2) generate histogram
    # NOTE(review): heatmap reads the atom prop 'DeltaPartialCharge' while this reads the molecule prop
    # 'DeltaPartialCharges'; presumably difference_rdmol sets both -- confirm
    deltas = np.array(literal_eval(diffmol.GetProp('DeltaPartialCharges')))
    vmin, vmax = deltas.min(), deltas.max()
    norm = Normalize(vmin, vmax)

    hist_fig, hist_ax = plotutils.presize_subplots(1, 1, scale=8, elongation=4/5)
    bin_vals, bin_edges, patches = hist_ax.hist(deltas, bins=n_bins, orientation='horizontal')

    # color each bar by its bin's midpoint, so the bar colors represent the values the bar holds
    # (coloring by the lower edge biased every bar toward vmin and never reached the top of the colormap)
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
    for bin_center, artist in zip(bin_centers, patches):
        artist.set_facecolor(cmap(norm(bin_center)))

    # computing and labelling charge RMSE between the two molecules' partial charges
    rmse = RMSE(chgd_offmol_1.partial_charges, chgd_offmol_2.partial_charges)
    labels = [f'RMSE = {round(rmse.magnitude, 5)} e']
    handles = [mpl_patches.Rectangle((0, 0), 1, 1, fc="white", ec="white", lw=0, alpha=0)] * len(labels) # invisible handles to pin the annotation text to
    hist_ax.legend(handles, labels, loc='best', handlelength=0, fontsize=fontsize)

    # labelling and sizing axes
    _ = hist_ax.set_yticks((vmin, 0, vmax))
    hist_ax.set_xlabel('Number of atoms', fontsize=fontsize)
    hist_ax.set_ylabel(diff_str, fontsize=fontsize)

    return ((hm_fig, hm_ax), (hist_fig, hist_ax))

In [None]:
from polysaccharide2.genutils.fileutils.pathutils import assemble_path
# Charge-comparison figures (heatmap + histogram) for each molecule, for its reduced structures and its full structure
fontsize : int = 15
charge_cmap = plt.get_cmap('turbo')
charge_methods = ('RCT', 'Espaloma-AM1-BCC') # charge-method pair compared for the full molecules, in this order

for mol_name, mol_dir in mol_dirs.items():
    if not mol_dir.is_dir():
        continue # only per-molecule directories hold the SDF files read below
    mol_dir_chg = figures_charges / mol_name
    mol_dir_chg.mkdir(exist_ok=True)

    # plot charges on reductions
    redux_sdf_paths = sorted( # sorted so which molecule is passed first doesn't depend on filesystem listing order
        path
            for path in mol_dir.iterdir()
                if (path.suffix == '.sdf') and ('reduced' in path.name)
    )
    redux_cmols = {
        path.stem.split('_')[-1] : next(topology.topology_from_sdf(path).molecules)
            for path in redux_sdf_paths
    }
    # labels found in charge_methods go first, in that order, to match the full-molecule comparison below
    # (presumably the filename suffixes are charge-method names -- confirm); other labels keep sorted order
    redux_cmols = dict(sorted(
        redux_cmols.items(),
        key=lambda item : charge_methods.index(item[0]) if (item[0] in charge_methods) else len(charge_methods),
    ))

    ((redux_hm_fig, redux_hm_ax), (redux_hist_fig, redux_hist_ax)) = generate_charge_plots(redux_cmols, charge_cmap, n_bins=20, fontsize=fontsize)
    redux_hm_fig.savefig(  mol_dir_chg / f'{mol_name}_redux_heatmap.png', bbox_inches='tight')
    redux_hist_fig.savefig(mol_dir_chg / f'{mol_name}_redux_histogram.png', bbox_inches='tight')
    # close BOTH figures: a bare plt.close() only closes the current figure, leaking the other every iteration
    plt.close(redux_hm_fig)
    plt.close(redux_hist_fig)

    # plot charges on full molecules
    full_cmols = {}
    for charge_method in charge_methods:
        full_sdf_path = assemble_path(mol_dir / charge_method, f'{mol_name}_{charge_method}', extension='sdf')
        full_cmols[charge_method] = next(topology.topology_from_sdf(full_sdf_path).molecules)

    ((hm_fig, hm_ax), (hist_fig, hist_ax)) = generate_charge_plots(full_cmols, charge_cmap, n_bins=20, fontsize=fontsize)
    hm_fig.savefig(  mol_dir_chg / f'{mol_name}_heatmap.png', bbox_inches='tight')
    hist_fig.savefig(mol_dir_chg / f'{mol_name}_histogram.png', bbox_inches='tight')
    plt.close(hm_fig)
    plt.close(hist_fig)