In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection

from clustering.substate_clusters import substates
from plot.plot_utilities import edgeformat, hist2d, savefig
from utils.core_utilities import overlapping_split

# Setup

In [None]:
# Restore variables persisted by IPython's %store magic in an earlier
# notebook/session. Nothing in this notebook creates them, so they must have
# been stored beforehand for the cells that use them to work.
# (Comments are kept on their own lines: the magic consumes the rest of its line.)
%store -r states_df
%store -r traj_ids

In [None]:
# Which two coordinates to analyze/serve as the analysis basis.
# pca_df is restored from the IPython %store database (saved in an earlier session).
%store -r pca_df

# Column names of the two coordinates used for every 2D plot and histogram.
name_d1 = 'pc1'
name_d2 = 'pc2'

# datf = pd.DataFrame(tm11xy, columns=[name_d1, name_d2])
# datf is an alias of pca_df (same object, not a copy).
datf = pca_df

# The two plotted coordinates over all rows of datf.
d1 = datf[name_d1]
d2 = datf[name_d2]

# Data range: used both as axis limits and as the 2D-histogram range.
# Note: in this notebook `xrange` is a plain list of [min, max].
xrange = [-70,70]
yrange = [-60,60]

# First argument passed to substates(); presumably the number of substates — TODO confirm.
N = 4

# Every column from the third one onward, as a NumPy array (input to substates()).
# NOTE(review): assumes the first two columns of datf are not features
# (e.g. traj_id / timestep) — confirm against how pca_df was built.
alldat = datf.iloc[:,2:].values

## Plot specifications

In [None]:
def plotspec(axs, xlim=None, ylim=None):
    """Apply the common 2D-projection formatting to a single Axes.

    Sets an equal aspect ratio, fixes the x/y limits and turns on a dashed
    grid.

    Parameters
    ----------
    axs : object
        The axes to format; must expose ``set_aspect``, ``set_xlim``,
        ``set_ylim`` and ``grid`` (a matplotlib Axes in this notebook).
        Despite the plural name this is one Axes, not an array of them.
    xlim, ylim : sequence of two numbers, optional
        ``(min, max)`` limits for each axis. When omitted, the notebook-level
        ``xrange`` / ``yrange`` are used, which is the original behaviour.
    """
    # Fall back to the notebook-wide ranges only when no explicit limits are
    # given, so the helper can also be reused for other projections.
    if xlim is None:
        xlim = xrange
    if ylim is None:
        ylim = yrange

    axs.set_aspect('equal', adjustable='box', anchor='C')
    axs.set_xlim(*xlim)
    axs.set_ylim(*ylim)
    axs.grid(True, linestyle='--')

## Preparatory calculations

In [None]:
# --- Substate clustering ---
# Clusters on every feature column of `alldat` (one positional argument per
# column). The two-coordinate variant is kept for reference:
# states = substates(N, d1, d2)
states = substates(N, *alldat.T)

# --- 2D histogram of the two plotted coordinates over all data ---
calc_dens = hist2d(d1, d2, bins=50, range=[xrange, yrange])
# Density grid plus the x/y edge arrays used when drawing it.
dens, xedges, yedges = (
    calc_dens.dens,
    calc_dens.xplot_edges,
    calc_dens.yplot_edges,
)

# What do the trajectories look like on this slice of conformational space?

## Where do translocations occur?

In [None]:
# Restore the translocation-event table from the IPython %store database
# (it must have been stored in an earlier notebook/session).
%store -r transloc_df
# Trajectories with translocation events in the previous paper (Zeng, Linsdell, Pomes, 2023).
# Only these traj_ids are plotted by the cells below.
traj2plot = [12, 15]

In [None]:
# (traj_id, timestep) keys of the translocation events that belong to the
# trajectories selected in traj2plot.
_event_keys = ['traj_id', 'timestep']
tf_translocate = transloc_df.loc[transloc_df['traj_id'].isin(traj2plot), _event_keys]

# Inner-join onto the PCA table to get the PCA coordinates of each event frame.
pca_translocate = pca_df.merge(tf_translocate, on=_event_keys)

In [None]:
# Trajectory drawing settings
traj_stride = 1  # keep every traj_stride-th frame when drawing a trajectory
# NOTE(review): avg_window is not used anywhere in this cell — confirm whether
# smoothing was meant to be applied, or whether a later cell reads it.
avg_window = 5

# Frames highlighted with cyan 'X' markers on the middle panel, keyed by traj_id.
# A mapping replaces the former `if t == 12 ... else ...`, whose `else` branch
# drew trajectory 15's frames for *any* trajectory other than 12.
# NOTE(review): what these timesteps represent is not documented here — confirm.
highlight_timesteps = {
    12: [900, 1000],
    15: [480, 730, 1000],
}

# One three-panel figure per selected trajectory:
#   panel 0 - the path in (name_d1, name_d2), coloured by segment index
#   panel 1 - the visited points over the contours of the all-data density
#   panel 2 - filled density contours of this trajectory alone
for t in traj_ids:
    if t not in traj2plot:
        continue

    # Select this trajectory's rows once and reuse them for every artist.
    traj_df = pca_df[pca_df['traj_id'] == t]
    traj_d1 = traj_df[name_d1][::traj_stride]
    traj_d2 = traj_df[name_d2][::traj_stride]
    transloc_t = pca_translocate[pca_translocate['traj_id'] == t]
    # Highlight frames are looked up in the unstrided rows, as before.
    marked = traj_df[traj_df['timestep'].isin(highlight_timesteps.get(t, []))]

    fig, axs = plt.subplots(1, 3, figsize=(12, 4), sharey=True, gridspec_kw={'wspace': 0.1})

    # Panel 0: consecutive points joined into segments, coloured by their order.
    prep_segments = overlapping_split(np.vstack([traj_d1, traj_d2]).T)
    traj2d = LineCollection(prep_segments, linewidth=1, cmap=plt.cm.viridis)
    traj2d.set_array(np.arange(len(prep_segments)))
    axs[0].add_collection(traj2d)
    # Magenta star marks the first plotted frame of the trajectory.
    axs[0].scatter(traj_d1.iloc[0], traj_d2.iloc[0], s=4, marker='*', zorder=10, c='magenta')

    # Panel 1: this trajectory's points over the all-data density contours.
    axs[1].scatter(traj_d1, traj_d2, s=2, c='red')
    axs[1].contour(xedges, yedges, dens.T)

    # Translocation events (black 'x') on both of the first two panels.
    for ax in axs[:2]:
        ax.scatter(transloc_t[name_d1], transloc_t[name_d2], c='black', zorder=2, marker='x')

    # Highlighted frames (cyan 'X'), drawn after the events so they sit on top.
    if not marked.empty:
        axs[1].scatter(marked[name_d1], marked[name_d2], s=36, c='cyan', zorder=2, marker='X')

    # Panel 2: a filled contour just for the current trajectory.
    hist2d(traj_d1, traj_d2, bins=50, range=[xrange, yrange]).hist2d_contourf(axs[2])

    for ax in axs:
        edgeformat(ax)
        # ax.set_title('traj_id={}, E={}'.format(t, get_trajattr(t, 'voltage')))
        # Shared aspect/limits/grid formatting (same calls that were inlined here).
        plotspec(ax)

    # savefig(f"{t}_traj2d_translocation_marked.pdf")