In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from database.query import traj_group
from clustering.substate_clusters import substates
from plot.plot_utilities import edgeformat, hist1d, hist2d, savefig

In [None]:
# Input traj_ids
# Trajectory IDs returned by traj_group(3); later cells pass this list to
# the helix, HOLE2, side-chain and PCA loaders.
traj_ids = traj_group(3)
# Persist with the IPython %store magic so another session can restore it
# with `%store -r traj_ids`.
%store traj_ids

# Trajectory IDs returned by traj_group(2).
# NOTE(review): the name suggests these are closed-state trajectories —
# confirm what groups 2 and 3 denote in the database.
traj_ids_closed = traj_group(2)
%store traj_ids_closed

# TM helix positions at the extracellular end

In [None]:
from conf.tmhelix import helix_positions, outer_leaflet_defs

# Helix numbers included in the calculation
helixnums = [1, 2, 6, 8, 11, 12]

# Create an instance of helical distance calculations for the selected
# trajectories, using the outer-leaflet level definitions
helix_dist = helix_positions(traj_ids, helixnums, leveldefs=outer_leaflet_defs)

# Calculate inter-residue distances
# NOTE(review): later cells read helix_dist.helix_com; presumably this call
# is what populates it — confirm against conf.tmhelix.
helix_dist.hdist_level()

## xy-coordinates of TM1 and TM11 and cluster analysis

In [None]:
# For each helix of interest, select its rows from helix_com and keep only the
# x and y columns as a plain two-column array.
tm1xy, tm2xy, tm11xy = (
    helix_dist.helix_com.query(f'helix == {helix_id}')[['x', 'y']].values
    for helix_id in (1, 2, 11)
)

### TM1-xy: 2 clusters

In [None]:
# Number of clusters to fit
N = 2
# Set up clustering of the TM1 points; x and y are passed as separate arguments
clusters = substates(N, *tm1xy.T)
# GM (Gaussian mixture) does a bit better
clusters.gaussian_mixture()

# Cluster assignments; the next cell uses them to color the TM1 scatter plot
tm1xy_clusters = clusters.states

In [None]:
# One row of three panels sharing the y-axis. Only the first and last panels
# receive data; the middle one (axs[1]) is formatted but left empty.
fig, axs = plt.subplots(
    nrows=1, ncols=3,
    figsize=(10, 5),
    sharey=True,
    gridspec_kw={'wspace': 0.1},
)

# Left panel: TM1 xy points colored by cluster, cluster centers as red crosses
axs[0].scatter(tm1xy[:, 0], tm1xy[:, 1], c=tm1xy_clusters, s=2)
axs[0].scatter(*clusters.centers.T, marker='x', color='red')

# Write each cluster's index next to its center
for i, mu in enumerate(clusters.centers):
    axs[0].annotate(i, mu, color='red', fontsize=12)

# Right panel: contour view of the 2D histogram of the same points
tm1xy_hist = hist2d(*tm1xy.T, bins=50, range=[[0,30], [50,70]])
tm1xy_hist.hist2d_contour(axs[2])

# Identical formatting for all three panels
for ax in axs.ravel():
    edgeformat(ax)
    ax.set_aspect('equal', adjustable='box', anchor='C')
    ax.set_xlim(10, 30)
    ax.set_ylim(50, 70)
    ax.grid(True, linestyle='--')

### TM11-xy: 2 clusters

In [None]:
# Number of clusters to fit
N = 2
# Set up clustering of the TM11 points; x and y are passed as separate arguments.
# Note that this rebinds `clusters`, which previously held the TM1 result.
clusters = substates(N, *tm11xy.T)
# Fit with a Gaussian mixture, as for TM1
clusters.gaussian_mixture()

# Cluster assignments; the next cell uses them to color the TM11 scatter plot
tm11xy_clusters = clusters.states

In [None]:
# One row of three panels sharing the y-axis. Only the first and last panels
# receive data; the middle one (axs[1]) is formatted but left empty.
fig, axs = plt.subplots(
    nrows=1, ncols=3,
    figsize=(10, 5),
    sharey=True,
    gridspec_kw={'wspace': 0.1},
)

# Left panel: TM11 xy points colored by cluster, cluster centers as red crosses
axs[0].scatter(tm11xy[:, 0], tm11xy[:, 1], c=tm11xy_clusters, s=2)
axs[0].scatter(*clusters.centers.T, marker='x', color='red')

# Write each cluster's index next to its center
for i, mu in enumerate(clusters.centers):
    axs[0].annotate(i, mu, color='red', fontsize=12)

# Right panel: contour view of the 2D histogram of the same points
tm11xy_hist = hist2d(*tm11xy.T, bins=50, range=[[0,20], [45,65]])
tm11xy_hist.hist2d_contour(axs[2])

# Identical formatting for all three panels
for ax in axs.ravel():
    edgeformat(ax)
    ax.set_aspect('equal', adjustable='box', anchor='C')
    ax.set_xlim(0, 20)
    ax.set_ylim(45, 65)
    ax.grid(True, linestyle='--')

# HOLE2 profile

In [None]:
from utils.dataset import read_trajdata

hole2_data, nframes_avail, traj_ids_avail = read_trajdata('hole2', traj_ids=traj_ids)

hole2_df = pd.DataFrame()

hole2_df['traj_id'] = np.repeat(traj_ids_avail, nframes_avail)
hole2_df['timestep'] = np.hstack([np.arange(n) for n in nframes_avail])
hole2_df[np.arange(90,150+1)] = hole2_data

# Replace -1 with NaN
hole2_df.replace(-1, np.nan, inplace=True)

hole2_df = hole2_df.query('traj_id in @traj_ids')

%store hole2_df
hole2_df

# Translocations on record

In [None]:
from database.query import get_translocation

transloc_df = pd.DataFrame(get_translocation(), columns=get_translocation()[0].keys())
transloc_df['timestep'] = (transloc_df['timestep'] * transloc_df['stepsize'] / 1000).astype(int)
transloc_df['stepsize'] = 1000

transloc_df = transloc_df.query('traj_id in @traj_ids')

%store transloc_df
transloc_df

# Sidechain positions

In [None]:
from conf.conf_analysis import sc_central

## R334

In [None]:
# Load the side-chain data for residue 334 across the selected trajectories.
# `scdat` is a temporary that each residue section below rebinds.
scdat = sc_central(traj_ids, resids=334)
scdat.load_sccinfo()
# Keep the coordinate table under a residue-specific name
r334sc = scdat.sc_coord_set

# New figure with a single axes for the histogram
fig, axs = plt.subplots()

# 60-bin histogram of the 'z' column over the range 130-145
r334_hist = hist1d(r334sc['z'], bins=60, range=[130,145])
r334_hist.plot(axs)

# Persist the table with the IPython %store magic (restore with `%store -r r334sc`)
%store r334sc

## R134

In [None]:
# Load the side-chain data for residue 134 across the selected trajectories.
# `scdat` is a temporary that each residue section rebinds.
scdat = sc_central(traj_ids, resids=134)
scdat.load_sccinfo()
# Keep the coordinate table under a residue-specific name
# NOTE(review): unlike r334sc, e1124sc and e1126sc, r134sc is never persisted
# with %store in this notebook — confirm that is intentional.
r134sc = scdat.sc_coord_set

# New figure with a single axes for the histogram
fig, axs = plt.subplots()

# 60-bin histogram of the 'z' column over the range 110-120, drawn in dark green
r134_hist = hist1d(r134sc['z'], bins=60, range=[110,120])
r134_hist.plot(axs, color='darkgreen')

## E1124

In [None]:
# Load the side-chain data for residue 1124 across the selected trajectories.
# `scdat` is a temporary that each residue section rebinds.
scdat = sc_central(traj_ids, resids=1124)
scdat.load_sccinfo()
# Keep the coordinate table under a residue-specific name
e1124sc = scdat.sc_coord_set

In [None]:
# New figure with a single axes for the histogram
fig, axs = plt.subplots()

# 60-bin histogram of the E1124 'z' column over the range 130-165
e1124_hist = hist1d(e1124sc['z'], bins=60, range=[130,165])
e1124_hist.plot(axs)

# Persist the table with the IPython %store magic (restore with `%store -r e1124sc`)
%store e1124sc

## E1126

In [None]:
# Load the side-chain data for residue 1126 across the selected trajectories.
# After this cell `scdat` refers to the residue-1126 loader; the state
# assignment cell further down reads scdat.sc_coord_set and relies on that.
scdat = sc_central(traj_ids, resids=1126)
scdat.load_sccinfo()
# Keep the coordinate table under a residue-specific name
e1126sc = scdat.sc_coord_set

In [None]:
fig, axs = plt.subplots()

e1124_hist = hist1d(e1126sc['z'], bins=60, range=[130,165])
e1124_hist.plot(axs)

%store e1126sc

# PCA

In [None]:
from pca.pca_analysis_workflow import analyze_pca

In [None]:
# Label for this PCA run; the state assignment cell uses it as the name of
# the column holding the PCA cluster states.
name = 'tmpc1v2'
# Data location passed to analyze_pca
# NOTE(review): presumably resolved relative to a project data root — confirm.
datadir = '6msm_tmpc/all_tmpc.stride1.realign'
# Number of principal components passed to analyze_pca
n_pcs = 2

# Number of clusters and the two axis ranges passed to state_clustering
N = 4
xrange = [-70,70]
yrange = [-70,70]

main = analyze_pca(datadir, n_pcs, traj_ids)
# Return values are bound to `_` so the cell does not echo them
_ = main.variance_plots()
_ = main.state_clustering(N, xrange, yrange)

# State assignments

In [None]:
states_df = pd.DataFrame()

states_df[['traj_id', 'timestep']] = scdat.sc_coord_set[['traj_id', 'timestep']]
states_df[name] = main.clusters.states

states_df['r334'] = (r334sc['z'] < 136).astype(int)
states_df['r134'] = (r134sc['z'] < 115).astype(int)
states_df['e1124'] = (e1124sc['z'] > 140).astype(int)
states_df['e1126'] = (e1126sc['z'] > 139).astype(int)

%store states_df
states_df

# State labels and color scheme

In [None]:
# One entry per state: (state index, math-text label, descriptive name, plot color)
state_labels = [
    (0, r"$\beta$", "stray", '#ff7f01'),
    (1, r"$\delta$", "intermediate", 'green'),
    (2, r"$\gamma$", "closed", '#0000ff'),
    (3, r"$\alpha$", "open", '#ff0000'),
]

# Lookup tables keyed by state index: index -> label and index -> color
map_assign = dict((entry[0], entry[1]) for entry in state_labels)
color_assign = dict((entry[0], entry[3]) for entry in state_labels)

# Persist both lookup tables with the IPython %store magic so another session
# can restore them with `%store -r map_assign color_assign`.
%store map_assign
%store color_assign