# PSSM

> Functions related to PSSMs

## Setup

In [None]:
#| default_exp pssm

In [None]:
#| export
import numpy as np, pandas as pd
from katlas.data import *
from katlas.utils import *
from fastcore.meta import delegates
from tqdm import tqdm
from tqdm.contrib.concurrent import process_map
from functools import partial
from fastcore.meta import delegates

# for plot
from matplotlib import pyplot as plt
import logomaker,math
import seaborn as sns

from matplotlib.colors import TwoSlopeNorm

# for plot two heatmaps
from matplotlib.patches import Rectangle
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.ticker import FuncFormatter

# hierarchical clustering
from scipy.cluster.hierarchy import linkage, fcluster,dendrogram

In [None]:
#| hide
pd.set_option('display.max_rows', 5) # long outputs show only the first/last rows
pd.set_option('display.max_columns', 100) # show up to 100 columns; wider frames are truncated

```python
from katlas.pssm import *
```

In [None]:
#| hide
#| export
# Small constant that keeps log2 and ratios finite when a probability is 0 (used by get_pssm_LO and get_entropy).
EPSILON = 1e-8

## PSSM

We compute a **position-specific scoring matrix (PSSM)** from a list of aligned site sequences; here each entry is the probability of observing a residue at a given position.

For each position $i$ (e.g., from $-7$ to $+7$), the probability of observing amino acid $x$ is:

$$
P_i(x) = \frac{\text{count of amino acid } x \text{ at position } i}{\text{total count of included residues at position } i}
$$

The following 23 residue types are included; any other character in the sequences is ignored and does not count toward the position total:

- Standard amino acids:  
  `A`, `C`, `D`, `E`, `F`, `G`, `H`, `I`, `K`, `L`, `M`, `N`, `P`, `Q`, `R`, `S`, `T`, `V`, `W`, `Y`  
- Modified amino acids:  
  `s`, `t`, `y` (often used to denote phosphorylated `S`, `T`, `Y`)

In the output, the modified residues are renamed as:
- `s` → `pS`  
- `t` → `pT`  
- `y` → `pY`

The resulting matrix has:
- **Rows**: Amino acids (including `pS`, `pT`, `pY`),
- **Columns**: Sequence positions (centered on the phosphosite),
- **Values**: Probabilities of each amino acid at each position.

In [None]:
#| export
def get_prob(data, col: str='site_seq', aa_order=None):
    """Get the probability matrix of PSSM from phosphorylation site sequences.

    `data` is a DataFrame holding the sequences in column `col`, or a Series / list /
    tuple / 1-D array of sequences. Sequences go through `check_seqs` and must share a
    single odd length so that position 0 is the central residue.

    `aa_order` lists the residues to count, in output row order. It defaults to the 20
    standard amino acids followed by `s`, `t`, `y`. Any other character is ignored and
    does not count toward a position's total.

    Returns a DataFrame with residues as rows (`s`/`t`/`y` renamed to `pS`/`pT`/`pY`) and
    positions `-(L//2) .. L//2` as columns. Each column holds the residue frequencies at
    that position; a position with no counted residue is all zeros.

    Raises `ValueError` if `col` is missing or the sequence length is even, and
    `TypeError` for unsupported input types.
    """
    if aa_order is None:
        # Built per call rather than as a shared mutable default argument.
        aa_order = list('PGACSTVILMFYWHKRQNDEsty')
    else:
        aa_order = list(aa_order)

    # --- Normalize input to Series of sequences ---
    if isinstance(data, pd.DataFrame):
        if col not in data.columns:
            raise ValueError(f"Column '{col}' not found in DataFrame.")
        site = data[col]
    elif isinstance(data, (pd.Series, list, tuple, np.ndarray)):
        site = pd.Series(data)
    else:
        raise TypeError("Input must be a DataFrame, Series, or list of sequences.")

    site = check_seqs(site)

    # One row per sequence, one column per character.
    site_array = np.array(site.apply(list).tolist())
    seq_len = site_array.shape[1]
    if seq_len % 2 == 0:
        # An even length has no central residue; -(L//2)..L//2 would give L+1 labels for L columns.
        raise ValueError(f"Sequences must have an odd length centred on the site; got length {seq_len}.")

    half = seq_len // 2
    position = list(range(-half, half + 1))

    site_df = pd.DataFrame(site_array, columns=position)
    melted = site_df.melt(var_name='Position', value_name='aa')

    # Count each residue per position, keeping only residues listed in aa_order.
    grouped = melted.groupby(['Position', 'aa']).size().reset_index(name='Count')
    grouped = grouped[grouped.aa.isin(aa_order)].reset_index(drop=True)

    pivot_df = grouped.pivot(index='aa', columns='Position', values='Count').fillna(0)
    # Normalize each position by the total count of included residues.
    pssm_df = pivot_df / pivot_df.sum()

    # Restore full row/column layout; residues or positions never observed become 0.
    pssm_df = pssm_df.reindex(index=aa_order, columns=position, fill_value=0)
    pssm_df = pssm_df.rename(index={'s': 'pS', 't': 'pT', 'y': 'pY'})

    return pssm_df

In [None]:
data = Data.get_ks_dataset() # table with (at least) kinase_uniprot, site_seq, site and kinase_group columns, as used below

In [None]:
data_k = data[data.kinase_uniprot=='P06493'] # CDK1

In [None]:
get_prob(data_k['site_seq'].tolist()) # get_prob also accepts a plain list of sequences

Position,-20,-19,-18,-17,-16,-15,-14,-13,-12,-11,-10,-9,-8,-7,-6,-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20
aa,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1
P,0.100762,0.083686,0.082422,0.090680,0.079631,0.082008,0.088333,0.077815,0.091584,0.083128,0.071487,0.080526,0.085106,0.091354,0.097561,0.087520,0.091350,0.104116,0.159677,0.082192,0.000000,0.758065,0.086360,0.088781,0.084951,0.101297,0.090171,0.107492,0.086743,0.098280,0.088889,0.085950,0.091211,0.078138,0.078138,0.085762,0.096909,0.096477,0.063973,0.081218,0.094996
G,0.069433,0.073542,0.068966,0.054576,0.081308,0.064435,0.084167,0.072848,0.080858,0.075720,0.087099,0.089565,0.060556,0.075041,0.083740,0.076985,0.080032,0.082324,0.066935,0.096696,0.000000,0.031452,0.092817,0.062954,0.050162,0.081037,0.066613,0.065961,0.076105,0.066339,0.074074,0.065289,0.067993,0.087282,0.068994,0.071607,0.085213,0.060403,0.074074,0.082910,0.064461
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
pT,0.028789,0.027895,0.031960,0.020151,0.020956,0.038494,0.019167,0.031457,0.023102,0.023868,0.032868,0.027938,0.033552,0.027732,0.038211,0.038088,0.039612,0.047619,0.045161,0.031426,0.324738,0.016129,0.037934,0.025827,0.038835,0.021070,0.035743,0.035831,0.027005,0.026208,0.023868,0.026446,0.022388,0.033250,0.023275,0.022481,0.028404,0.024329,0.028620,0.027919,0.018660
pY,0.003387,0.005917,0.011775,0.007557,0.005029,0.010879,0.005000,0.004139,0.006601,0.009877,0.009860,0.011504,0.009820,0.006525,0.010569,0.012156,0.016168,0.008071,0.012097,0.008058,0.010475,0.004032,0.009685,0.015335,0.009709,0.014587,0.004874,0.006515,0.005728,0.007371,0.010700,0.007438,0.010779,0.007481,0.008313,0.004996,0.004177,0.002517,0.008418,0.007614,0.008482


In [None]:
pssm_df = get_prob(data_k) # or get_prob(data_k['site_seq'])
pssm_df.head() # first five residue rows across all positions

Position,-20,-19,-18,-17,-16,-15,-14,-13,-12,-11,-10,-9,-8,-7,-6,-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20
aa,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1
P,0.100762,0.083686,0.082422,0.09068,0.079631,0.082008,0.088333,0.077815,0.091584,0.083128,0.071487,0.080526,0.085106,0.091354,0.097561,0.08752,0.09135,0.104116,0.159677,0.082192,0.0,0.758065,0.08636,0.088781,0.084951,0.101297,0.090171,0.107492,0.086743,0.09828,0.088889,0.08595,0.091211,0.078138,0.078138,0.085762,0.096909,0.096477,0.063973,0.081218,0.094996
G,0.069433,0.073542,0.068966,0.054576,0.081308,0.064435,0.084167,0.072848,0.080858,0.07572,0.087099,0.089565,0.060556,0.075041,0.08374,0.076985,0.080032,0.082324,0.066935,0.096696,0.0,0.031452,0.092817,0.062954,0.050162,0.081037,0.066613,0.065961,0.076105,0.066339,0.074074,0.065289,0.067993,0.087282,0.068994,0.071607,0.085213,0.060403,0.074074,0.08291,0.064461
A,0.07536,0.071006,0.079058,0.058774,0.070411,0.067782,0.075833,0.086921,0.068482,0.069959,0.067379,0.062449,0.069558,0.070962,0.069919,0.073744,0.074373,0.071832,0.083065,0.087027,0.0,0.022581,0.09201,0.056497,0.079288,0.057536,0.077985,0.072476,0.06874,0.066339,0.07572,0.071074,0.065506,0.063175,0.069825,0.07244,0.067669,0.079698,0.058081,0.06176,0.07888
C,0.013548,0.006762,0.007569,0.009236,0.01425,0.011715,0.008333,0.00745,0.009901,0.007407,0.010682,0.008217,0.007365,0.008972,0.010569,0.005673,0.004042,0.007264,0.01371,0.008864,0.0,0.001613,0.008878,0.006457,0.010518,0.009724,0.008936,0.010586,0.007365,0.010647,0.006584,0.01157,0.012438,0.004988,0.011638,0.010824,0.011696,0.009228,0.013468,0.007614,0.012723
S,0.04149,0.053254,0.043734,0.052057,0.035205,0.041004,0.050833,0.050497,0.043729,0.039506,0.04355,0.041085,0.040098,0.038336,0.034146,0.029984,0.034762,0.016949,0.020968,0.028203,0.0,0.003226,0.029056,0.02502,0.038026,0.033225,0.044679,0.041531,0.043372,0.044226,0.050206,0.041322,0.046434,0.044057,0.050707,0.045795,0.045948,0.04698,0.050505,0.054146,0.047498


## Transform PSSM

In [None]:
#| export
def pSTY2sty(string):
    "Rewrite the phospho labels pS, pT and pY inside `string` as the lowercase letters s, t and y."
    for phospho_label, short_label in (('pS', 's'), ('pT', 't'), ('pY', 'y')):
        string = string.replace(phospho_label, short_label)
    return string

In [None]:
pspa_scale = Data.get_pspa_all_scale()
pspa_scale.columns.map(pSTY2sty) # convert any pS/pT/pY in the column labels to s/t/y

Index(['-5P', '-5G', '-5A', '-5C', '-5S', '-5T', '-5V', '-5I', '-5L', '-5M',
       ...
       '4H', '4K', '4R', '4Q', '4N', '4D', '4E', '4s', '4t', '4y'],
      dtype='object', length=230)

In [None]:
#| export
def flatten_pssm(pssm_df,
                 use_sty=False, # if True, use s,t,y instead of pS,pT,pY
                 column_wise=True, # if True, column major flatten; else row wise flatten (for pytorch training)
                ):
    """Flatten PSSM dataframe to dictionary.

    Keys are ``f'{position}{residue}'`` (e.g. ``'-20P'``, ``'0pS'``); values are the matrix entries.
    With ``column_wise=True`` entries are ordered position by position (all residues of the first
    column, then the next column); otherwise residue by residue. If two entries produce the same
    key, the later value wins.
    """
    pssm_df = pssm_df.copy()

    # Convert pS,pT,pY to s,t,y
    if use_sty:
        pssm_df.index = pssm_df.index.map(pSTY2sty)

    if column_wise:
        # unstack -> MultiIndex (position, residue), positions outermost
        stacked = pssm_df.unstack()
        keys = [f'{pos}{aa}' for pos, aa in stacked.index]
    else:
        # transpose first -> MultiIndex (residue, position), residues outermost
        stacked = pssm_df.T.unstack()
        keys = [f'{pos}{aa}' for aa, pos in stacked.index]

    # Both orders build keys the same way, so non-string residue labels work in either mode.
    return dict(zip(keys, stacked.tolist()))

In [None]:
flat_pssm = pd.Series(flatten_pssm(pssm_df)) # keys are '<position><residue>', e.g. '-20P'
flat_pssm

-20P    0.100762
-20G    0.069433
          ...   
20pT    0.018660
20pY    0.008482
Length: 943, dtype: float64

In [None]:
#| export
def recover_pssm(flat_pssm: pd.Series):
    """Recover 2D PSSM from flattened PSSM Series.
    Only includes amino acids present in `flat_pssm`, preserving canonical order.

    Keys are expected to look like ``'<position><residue>'`` (e.g. ``'-5P'``, ``'0pS'``, ``'3s'``).
    Lowercase ``s``/``t``/``y`` rows are renamed to ``pS``/``pT``/``pY`` unless a residue label
    starting with ``'p'`` is already present. Residues outside the canonical 23 are dropped and
    missing (residue, position) combinations are filled with 0.
    """
    df = flat_pssm.copy().reset_index()
    df.columns = ['info', 'value']
    # Split each key into its leading (possibly negative) integer position and the residue label.
    parts = df['info'].str.extract(r'(-?\d+)\s*(.*)')
    df['Position'] = parts[0].astype(int)
    df['aa'] = parts[1]

    df = df.pivot(index='aa', columns='Position', values='value').fillna(0)

    # Only convert lowercase s/t/y when the input does not already use pS/pT/pY labels.
    if not any(x.startswith('p') for x in df.index):
        df = df.rename(index={'s': 'pS', 't': 'pT', 'y': 'pY'})

    aa_order = list('PGACSTVILMFYWHKRQNDE') + ['pS', 'pT', 'pY']
    return df.reindex(index=[aa for aa in aa_order if aa in df.index])

In [None]:
out = recover_pssm(flat_pssm) # rebuild the 2D PSSM from the flattened Series
out

Position,-20,-19,-18,-17,-16,-15,-14,-13,-12,-11,-10,-9,-8,-7,-6,-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20
aa,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1
P,0.100762,0.083686,0.082422,0.090680,0.079631,0.082008,0.088333,0.077815,0.091584,0.083128,0.071487,0.080526,0.085106,0.091354,0.097561,0.087520,0.091350,0.104116,0.159677,0.082192,0.000000,0.758065,0.086360,0.088781,0.084951,0.101297,0.090171,0.107492,0.086743,0.098280,0.088889,0.085950,0.091211,0.078138,0.078138,0.085762,0.096909,0.096477,0.063973,0.081218,0.094996
G,0.069433,0.073542,0.068966,0.054576,0.081308,0.064435,0.084167,0.072848,0.080858,0.075720,0.087099,0.089565,0.060556,0.075041,0.083740,0.076985,0.080032,0.082324,0.066935,0.096696,0.000000,0.031452,0.092817,0.062954,0.050162,0.081037,0.066613,0.065961,0.076105,0.066339,0.074074,0.065289,0.067993,0.087282,0.068994,0.071607,0.085213,0.060403,0.074074,0.082910,0.064461
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
pT,0.028789,0.027895,0.031960,0.020151,0.020956,0.038494,0.019167,0.031457,0.023102,0.023868,0.032868,0.027938,0.033552,0.027732,0.038211,0.038088,0.039612,0.047619,0.045161,0.031426,0.324738,0.016129,0.037934,0.025827,0.038835,0.021070,0.035743,0.035831,0.027005,0.026208,0.023868,0.026446,0.022388,0.033250,0.023275,0.022481,0.028404,0.024329,0.028620,0.027919,0.018660
pY,0.003387,0.005917,0.011775,0.007557,0.005029,0.010879,0.005000,0.004139,0.006601,0.009877,0.009860,0.011504,0.009820,0.006525,0.010569,0.012156,0.016168,0.008071,0.012097,0.008058,0.010475,0.004032,0.009685,0.015335,0.009709,0.014587,0.004874,0.006515,0.005728,0.007371,0.010700,0.007438,0.010779,0.007481,0.008313,0.004996,0.004177,0.002517,0.008418,0.007614,0.008482


In [None]:
out.loc[pssm_df.index,pssm_df.columns].equals(pssm_df) # round trip: flatten -> recover reproduces the original PSSM

True

Or recover from PSPA data, where s, t, y will be converted to pS, pT, and pY:

In [None]:
pspa = Data.get_pspa_all_norm() # appears to hold one flattened PSSM row per kinase name (e.g. 'AAK1')

In [None]:
flat_pssm_pspa = pspa.loc['AAK1'].dropna() # drop NaN entries before recovering
flat_pssm_pspa

-5P    0.0720
-5G    0.0245
        ...  
0T     1.0000
0Y     0.0000
Name: AAK1, Length: 213, dtype: float64

In [None]:
recovered = recover_pssm(flat_pssm_pspa) # lowercase s/t/y keys become pS/pT/pY rows
recovered

Position,-5,-4,-3,-2,-1,0,1,2,3,4
aa,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
P,0.0720,0.0534,0.1084,0.0226,0.1136,0.0,0.0463,0.0527,0.0681,0.0628
G,0.0245,0.0642,0.0512,0.0283,0.0706,0.0,0.7216,0.0749,0.0923,0.0702
...,...,...,...,...,...,...,...,...,...,...
pT,0.0201,0.0332,0.0303,0.0209,0.0121,1.0,0.0123,0.0409,0.0335,0.0251
pY,0.0611,0.0339,0.0274,0.0486,0.0178,0.0,0.0100,0.0410,0.0359,0.0270


PSPA data are not normalized per position, and the recovered `pssm_df` also contains duplicate entries at position 0: the plain S, T, Y rows repeat the values of pS, pT, pY.

We therefore zero out the redundant entries at position 0 (keeping only pS/pT/pY) and rescale each position so that it sums to 1.

In [None]:
#| export
def _clean_zero(pssm_df):
    "Zero out non-last three values in position 0 (keep only s,t,y values at center)"
    pssm_df = pssm_df.copy()
    pssm_df.loc[pssm_df.index[:-3], 0] = 0
    return pssm_df

In [None]:
_clean_zero(pssm_df) # position 0 keeps only the phospho-acceptor rows

Position,-20,-19,-18,-17,-16,-15,-14,-13,-12,-11,-10,-9,-8,-7,-6,-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20
aa,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1
P,0.100762,0.083686,0.082422,0.090680,0.079631,0.082008,0.088333,0.077815,0.091584,0.083128,0.071487,0.080526,0.085106,0.091354,0.097561,0.087520,0.091350,0.104116,0.159677,0.082192,0.000000,0.758065,0.086360,0.088781,0.084951,0.101297,0.090171,0.107492,0.086743,0.098280,0.088889,0.085950,0.091211,0.078138,0.078138,0.085762,0.096909,0.096477,0.063973,0.081218,0.094996
G,0.069433,0.073542,0.068966,0.054576,0.081308,0.064435,0.084167,0.072848,0.080858,0.075720,0.087099,0.089565,0.060556,0.075041,0.083740,0.076985,0.080032,0.082324,0.066935,0.096696,0.000000,0.031452,0.092817,0.062954,0.050162,0.081037,0.066613,0.065961,0.076105,0.066339,0.074074,0.065289,0.067993,0.087282,0.068994,0.071607,0.085213,0.060403,0.074074,0.082910,0.064461
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
pT,0.028789,0.027895,0.031960,0.020151,0.020956,0.038494,0.019167,0.031457,0.023102,0.023868,0.032868,0.027938,0.033552,0.027732,0.038211,0.038088,0.039612,0.047619,0.045161,0.031426,0.324738,0.016129,0.037934,0.025827,0.038835,0.021070,0.035743,0.035831,0.027005,0.026208,0.023868,0.026446,0.022388,0.033250,0.023275,0.022481,0.028404,0.024329,0.028620,0.027919,0.018660
pY,0.003387,0.005917,0.011775,0.007557,0.005029,0.010879,0.005000,0.004139,0.006601,0.009877,0.009860,0.011504,0.009820,0.006525,0.010569,0.012156,0.016168,0.008071,0.012097,0.008058,0.010475,0.004032,0.009685,0.015335,0.009709,0.014587,0.004874,0.006515,0.005728,0.007371,0.010700,0.007438,0.010779,0.007481,0.008313,0.004996,0.004177,0.002517,0.008418,0.007614,0.008482


In [None]:
#| export
def clean_zero_normalize(pssm_df):
    """Mask the center position with `_clean_zero`, then rescale every position to sum to 1.

    Column labels are cast to int first, so positions given as strings (e.g. '-5') are accepted.
    The input is not modified.
    """
    out = pssm_df.copy()
    out.columns = out.columns.astype(int)
    out = _clean_zero(out)
    col_totals = out.sum()
    return out.div(col_totals, axis='columns')

This function applies phosphosite-specific cleaning and normalization to a PSSM.

At the center position ($i = 0$), only the phospho-acceptor rows — `pS`, `pT` and `pY` (`s`, `t`, `y` before renaming), which are the last three rows in the canonical residue order — are retained. All other amino acid values at position 0 are set to 0.

After masking, the matrix is column-normalized to ensure the probabilities at each position sum to 1:

$$
P'_i(x) = \frac{P_i(x)}{\sum_{x'} P_i(x')}
$$


In [None]:
norm_pssm = clean_zero_normalize(recovered) # mask position 0, then make each position sum to 1
norm_pssm.head()

Position,-5,-4,-3,-2,-1,0,1,2,3,4
aa,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
P,0.058446,0.041715,0.0861,0.017935,0.096068,0.0,0.042649,0.040482,0.05264,0.05026
G,0.019888,0.050152,0.040667,0.022459,0.059704,0.0,0.664702,0.057536,0.071346,0.056182
A,0.023054,0.055152,0.08888,0.042695,0.032558,0.0,0.02874,0.057613,0.044987,0.051701
C,0.037016,0.043747,0.052025,0.046663,0.026469,0.0,0.020542,0.052543,0.057355,0.048259
S,0.0345,0.048356,0.041859,0.044044,0.046089,0.0,0.013172,0.042403,0.044987,0.044818


## PSSM of Log odds 

In [None]:
#| export
def get_pssm_LO(pssm_df,
                site_type, # S, T, Y, ST, or STY
               ):
    """Get log odds PSSM: log2 (freq pssm/background pssm).

    The background is row ``f'ks_{site_type}'`` of ``Data.get_ks_background()``, rebuilt with
    `recover_pssm`. EPSILON is added to numerator and denominator so zero frequencies give finite
    scores; any remaining inf/NaN values are replaced with 0. The result has the same row and
    column order as `pssm_df`.

    Raises `ValueError` if `pssm_df` and the background do not cover the same residues and positions.
    """
    bg_pssms = Data.get_ks_background()
    flat_bg = bg_pssms.loc[f'ks_{site_type}']
    pssm_bg = recover_pssm(flat_bg)
    # Division aligns on labels; labels present on only one side give all-NaN rows/columns, which are dropped.
    pssm_odds = ((pssm_df+EPSILON)/(pssm_bg+EPSILON)).dropna(axis=0,how='all').dropna(axis=1, how='all')
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if pssm_odds.shape != pssm_df.shape:
        raise ValueError(
            f"PSSM labels do not match the 'ks_{site_type}' background: "
            f"input shape {pssm_df.shape}, aligned shape {pssm_odds.shape}.")
    # Label alignment can reorder (sort) rows/columns; restore the input's order.
    pssm_odds = pssm_odds.reindex(index=pssm_df.index, columns=pssm_df.columns)
    return np.log2(pssm_odds).replace([np.inf, -np.inf], 0).fillna(0)

Let $P_i(x)$ be the frequency of amino acid $x$ at position $i$ in the input PSSM, and let $B_i(x)$ be the background frequency of amino acid $x$ at the same position, derived from a background model corresponding to the specified site type (`S`, `T`, `Y`, `ST`, or `STY`).

The log-odds score at each position $i$ for amino acid $x$ is computed as:

$$
\mathrm{LO}_i(x) = \log_2 \left( \frac{P_i(x) + \varepsilon}{B_i(x) + \varepsilon} \right)
$$

where $\varepsilon = 10^{-8}$ is a small constant added for numerical stability and to avoid division by zero.

This results in a matrix where:

- Positive values indicate enrichment over background,
- Negative values indicate depletion relative to background,
- Zero indicates no difference from the expected background.

In [None]:
data_y = data_k[data_k.site.str[0]=='Y'] # CDK1 rows whose site label starts with 'Y' (presumably tyrosine sites)

In [None]:
pssm_y = get_prob(data_y,'site_seq')

In [None]:
pssm_LO = get_pssm_LO(pssm_y,'Y') # log-odds against the 'ks_Y' background
pssm_LO.head()

Position,-20,-19,-18,-17,-16,-15,-14,-13,-12,-11,-10,-9,-8,-7,-6,-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20
aa,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1
P,3.346408,-22.44393,-22.483757,-22.372869,-22.33063,-22.38935,-22.374998,-22.472709,-22.395309,1.987578,-22.494346,0.445197,-22.568837,1.592465,0.481417,-22.533689,-22.473171,0.558965,0.126645,-22.159026,0.0,-21.219306,-22.470595,-22.994372,-22.573771,0.385159,2.042311,-22.44457,-22.627742,-22.737503,-22.549071,-22.614291,-22.347248,-22.450324,-22.595771,-22.811749,-22.536017,-22.489922,-22.51522,-22.34967,-22.682968
G,-22.787984,-22.672109,-22.655669,-22.817496,-22.708532,-22.680581,-22.968862,-22.718379,-22.624127,-22.752478,-22.755633,-22.704314,1.820249,-22.602035,1.70629,-22.640845,1.574374,2.033874,2.576212,-22.485306,0.0,1.630733,1.522227,-22.301926,-22.682104,-22.810346,0.400104,0.432959,1.787902,-22.691406,-22.606718,-22.871788,2.968486,-22.716477,0.361061,0.294925,-22.713856,0.228288,-22.828378,0.399785,-22.816906
A,-22.752579,-22.787101,2.244467,-22.501932,1.566475,0.633103,0.209785,-22.641046,-22.668327,-22.755401,-22.86234,-22.751717,0.30285,-22.57604,-22.685439,-22.680965,-22.568254,-22.71928,-22.719076,-22.414048,0.0,-22.70892,1.140178,0.379419,0.404062,-22.548438,1.174414,1.373115,-22.526813,-22.87308,-22.715532,-22.607643,1.739891,-22.700813,0.116988,-22.679525,3.01613,1.791966,-22.670499,-22.569718,0.388038
C,-20.604884,2.94648,-20.610481,-20.69634,-20.705479,-20.613911,-20.674559,-20.555965,-20.395309,-20.465246,-20.723212,-20.683076,-20.273082,-20.562865,-20.151208,-20.637713,-20.490617,-20.197935,-20.073742,-20.197731,0.0,-20.723779,-20.427527,-20.857363,-20.533508,-20.380203,-20.784654,-20.54043,-20.32518,-20.087586,-20.474525,-20.292364,-20.675574,-20.827394,-20.497208,-20.41133,-20.843591,-20.77741,-20.951953,-20.735729,-20.464922
S,-22.084627,-22.150318,-22.199808,-22.059594,-22.210853,-22.220374,-22.174129,0.71176,-22.086471,-21.971941,-21.907636,-21.850281,-22.02135,-22.147827,-21.770936,-21.831141,-21.924603,-21.627747,-21.621229,-21.09212,0.0,-21.5004,-21.519269,-21.615861,-22.015041,-22.077133,-21.879565,-21.929472,-21.962609,-21.994476,-21.856395,-22.144041,-22.079396,-22.071062,-22.125238,-22.129558,-22.149269,-22.20452,-22.361597,-22.234534,-22.274444


In [None]:
pssm_y[0][pssm_y[0]==1].index # residue(s) with probability 1 at the center position

Index(['pY'], dtype='object', name='aa')

In [None]:
pssm_LO[0].sort_values() # all log-odds at the center position are 0 for this single site-type (Y) PSSM

aa
P     0.0
G     0.0
     ... 
pT    0.0
pY    0.0
Name: 0, Length: 23, dtype: float64

In [None]:
#| export
def get_pssm_LO_flat(flat_pssm,
                    site_type, # S, T, Y, ST, or STY
                    ):
    "Log-odds PSSM from a flattened PSSM: rebuild the 2D matrix with `recover_pssm`, then score it with `get_pssm_LO`."
    return get_pssm_LO(recover_pssm(flat_pssm), site_type)

In [None]:
pssm_LO = get_pssm_LO_flat(flat_pssm,'STY') # the flattened input is recovered to 2D first

In [None]:
pssm_LO

Position,-20,-19,-18,-17,-16,-15,-14,-13,-12,-11,-10,-9,-8,-7,-6,-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20
aa,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1
P,0.586661,0.394140,0.374573,0.453820,0.289470,0.358996,0.482618,0.248724,0.474829,0.389672,0.109081,0.301898,0.336760,0.501053,0.515588,0.480427,0.515543,0.755197,1.058815,0.375288,0.000000,2.390017,0.400188,0.338692,0.471584,0.492282,0.449106,0.719679,0.363679,0.522566,0.321624,0.304491,0.525998,0.281515,0.236806,0.264957,0.545151,0.560067,-0.022594,0.367797,0.484510
G,0.016858,0.076312,-0.002345,-0.350629,0.235031,-0.081357,0.200054,0.090158,0.242282,0.091220,0.285770,0.325105,-0.114557,0.097086,0.230386,0.205338,0.155133,0.172609,0.046828,0.230478,0.000000,-1.025210,0.477712,-0.167647,-0.357066,0.253819,0.024980,-0.031828,0.169769,-0.015015,0.151653,-0.064451,0.045944,0.379475,0.067256,0.072006,0.275037,-0.179219,0.127575,0.299872,-0.100680
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
pT,0.911584,0.912650,1.056510,0.470853,0.670212,1.311175,0.429132,1.116517,0.677640,0.513274,0.960222,0.625347,0.681687,0.332545,0.822099,0.639714,0.489330,0.812538,0.380904,0.113449,0.517671,-0.851332,0.094978,-0.048091,0.336723,-0.327027,0.592702,0.869663,0.410747,0.426680,0.319514,0.750686,0.527680,1.073488,0.566722,0.518353,0.979661,0.740521,1.009246,0.904865,0.266921
pY,-1.508258,-0.327763,0.479616,-0.163146,-0.511691,0.115212,-0.743236,-0.951924,-0.347715,-0.199543,-0.067569,0.148547,0.000651,-0.746039,-0.322579,-0.285517,0.017292,-1.040118,-0.684355,-1.496942,-4.500376,-2.396887,-1.123007,-0.237850,-0.777896,-0.000605,-1.465025,-0.917736,-0.957110,-0.367131,0.108010,-0.646108,0.325134,-0.135098,-0.112158,-0.764045,-1.033425,-1.598003,0.047446,-0.108367,-0.179734


## PSSMs of clusters

In [None]:
#| export
def get_cluster_pssms(df, 
                    cluster_col, 
                    seq_col='site_seq', 
                    id_col = 'sub_site',
                    count_thr=10, # if less than the count threshold, not include in the return
                    valid_thr=None, # fraction (0-1) of non-zero entries in the pssm; clusters at or below it are skipped
                    IC_thr = None, # currently unused
                      plot=False):
    """Extract motifs from clusters in a dataframe.

    Clusters in `cluster_col` are visited largest first. For each cluster with at least
    `count_thr` unique `id_col` entries, a PSSM is computed from `seq_col` with `get_prob`
    and flattened with `flatten_pssm`. Returns a DataFrame with one row per kept cluster
    (indexed by cluster id) and one column per flattened '<position><residue>' key.
    If `id_col` is None, every row counts and rows are not de-duplicated.
    """
    # Cluster sizes in unique substrates; without an id column every row counts.
    if id_col is not None:
        value_counts = df.drop_duplicates(subset=[cluster_col, id_col])[cluster_col].value_counts()
    else:
        value_counts = df[cluster_col].value_counts()

    pssms, ids = [], []
    for cluster_id, counts in tqdm(value_counts.items(), total=len(value_counts)):
        if count_thr is not None and counts < count_thr:
            continue

        df_cluster = df[df[cluster_col] == cluster_id]
        if id_col is not None:
            df_cluster = df_cluster.drop_duplicates(id_col)

        n = len(df_cluster)
        pssm = get_prob(df_cluster, seq_col)
        # Fraction of PSSM cells that are non-zero.
        valid_score = (pssm != 0).sum().sum() / (pssm.shape[0] * pssm.shape[1])

        if valid_thr is not None and valid_score <= valid_thr:
            continue

        pssms.append(flatten_pssm(pssm))
        ids.append(cluster_id)

        if plot:
            # NOTE(review): plot_logo is not defined in this section; assumed to come from elsewhere in the module.
            plot_logo(pssm, title=f'Cluster {cluster_id} (n={n})', figsize=(14, 1))
            plt.show()
            plt.close()

    return pd.DataFrame(pssms, index=ids)

In [None]:
data_10 = data.head(10) # NOTE(review): not used below; the next cell runs on the full `data`

In [None]:
get_cluster_pssms(data,'kinase_group') # one flattened PSSM per kinase group with >= 10 unique sub_site entries

100%|█████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 10.17it/s]


Unnamed: 0,-20P,-20G,-20A,-20C,-20S,-20T,-20V,-20I,-20L,-20M,-20F,-20Y,-20W,-20H,-20K,-20R,-20Q,-20N,-20D,-20E,-20pS,-20pT,-20pY,-19P,-19G,-19A,-19C,-19S,-19T,-19V,-19I,-19L,-19M,-19F,-19Y,-19W,-19H,-19K,-19R,-19Q,-19N,-19D,-19E,-19pS,-19pT,-19pY,-18P,-18G,-18A,-18C,...,18E,18pS,18pT,18pY,19P,19G,19A,19C,19S,19T,19V,19I,19L,19M,19F,19Y,19W,19H,19K,19R,19Q,19N,19D,19E,19pS,19pT,19pY,20P,20G,20A,20C,20S,20T,20V,20I,20L,20M,20F,20Y,20W,20H,20K,20R,20Q,20N,20D,20E,20pS,20pT,20pY
TK,0.056754,0.069067,0.070474,0.016886,0.045028,0.039400,0.058630,0.049836,0.083255,0.024038,0.032716,0.023335,0.007974,0.021107,0.069184,0.056168,0.045966,0.036585,0.058865,0.081027,0.024508,0.011843,0.017355,0.057915,0.067158,0.073125,0.016731,0.045396,0.037440,0.053001,0.049959,0.087633,0.023751,0.031356,0.018018,0.009594,0.022230,0.074529,0.063999,0.043173,0.041652,0.060840,0.073359,0.025389,0.010881,0.012870,0.058123,0.067810,0.062792,0.014006,...,0.080097,0.027670,0.012621,0.012743,0.058601,0.065424,0.065302,0.014864,0.045078,0.038986,0.056408,0.048124,0.087354,0.018884,0.042276,0.017178,0.008894,0.018519,0.073830,0.061769,0.047149,0.040327,0.059698,0.078216,0.028509,0.010599,0.014011,0.066015,0.072738,0.065037,0.013692,0.048900,0.037897,0.054523,0.050122,0.077262,0.023227,0.039364,0.018337,0.008435,0.019071,0.070905,0.059413,0.040465,0.038509,0.056357,0.083496,0.025917,0.012836,0.017482
CMGC,0.080589,0.070340,0.083792,0.013709,0.050865,0.035874,0.053812,0.035874,0.074824,0.022037,0.029468,0.015247,0.007559,0.020884,0.069058,0.057143,0.044331,0.032031,0.053299,0.078668,0.042921,0.020115,0.007559,0.079718,0.075496,0.072937,0.012028,0.058733,0.035061,0.053871,0.032885,0.074472,0.021753,0.026104,0.013436,0.007806,0.021241,0.064491,0.063596,0.046449,0.037492,0.055534,0.080614,0.042482,0.018298,0.005502,0.074292,0.068930,0.074419,0.013786,...,0.084768,0.052185,0.016556,0.005960,0.074212,0.073813,0.074212,0.011970,0.053731,0.034047,0.050805,0.031520,0.074345,0.017423,0.028195,0.015029,0.010374,0.020215,0.070887,0.063173,0.045884,0.033914,0.056124,0.086980,0.048278,0.018619,0.006251,0.082465,0.066961,0.074980,0.012563,0.050789,0.034215,0.050521,0.038626,0.079925,0.019246,0.028869,0.016840,0.006816,0.023924,0.072307,0.060011,0.046111,0.035285,0.057070,0.074044,0.043972,0.018043,0.006415
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
CK1,0.057860,0.076419,0.076965,0.014192,0.039847,0.037118,0.063865,0.052402,0.070961,0.017467,0.027293,0.012555,0.014192,0.024017,0.079694,0.050764,0.033297,0.035480,0.058406,0.086790,0.038210,0.015830,0.016376,0.061069,0.070883,0.069793,0.007634,0.033806,0.032715,0.053435,0.043075,0.087786,0.023991,0.022901,0.016903,0.010905,0.020720,0.076881,0.059978,0.056161,0.037077,0.061069,0.087786,0.041439,0.017448,0.006543,0.050000,0.067935,0.073913,0.010870,...,0.084629,0.041451,0.015544,0.012090,0.053148,0.068746,0.069324,0.012709,0.036973,0.021953,0.047371,0.042172,0.082034,0.017909,0.030040,0.017909,0.020220,0.019064,0.095321,0.047371,0.041594,0.034084,0.065280,0.098209,0.041594,0.021953,0.015020,0.051865,0.074592,0.088578,0.015152,0.039627,0.034382,0.055361,0.039627,0.074592,0.020979,0.034382,0.020396,0.008741,0.026224,0.079837,0.046620,0.038462,0.043124,0.067599,0.087995,0.027972,0.012238,0.011655
Atypical,0.071895,0.065359,0.063492,0.014939,0.053221,0.045752,0.058824,0.050420,0.085901,0.020542,0.017740,0.020542,0.005602,0.034547,0.049486,0.057890,0.040149,0.034547,0.059757,0.078431,0.041083,0.025210,0.004669,0.042870,0.052190,0.074557,0.013979,0.063374,0.036347,0.054986,0.031687,0.089469,0.027959,0.036347,0.015843,0.010252,0.026095,0.074557,0.056850,0.054054,0.034483,0.063374,0.068966,0.046598,0.018639,0.006524,0.051068,0.063138,0.072423,0.012071,...,0.072175,0.047483,0.013295,0.015195,0.059829,0.081671,0.066477,0.017094,0.047483,0.044634,0.050332,0.034188,0.063628,0.010446,0.037037,0.016144,0.006648,0.027540,0.047483,0.055081,0.057930,0.043685,0.059829,0.091168,0.039886,0.027540,0.014245,0.077290,0.060115,0.062977,0.015267,0.055344,0.029580,0.065840,0.026718,0.092557,0.017176,0.039122,0.023855,0.009542,0.016221,0.046756,0.054389,0.050573,0.055344,0.062977,0.068702,0.040076,0.020992,0.008588


## Entropy

In [None]:
#| export
def get_entropy(pssm_df,# a dataframe of pssm with index as aa and column as position
            return_min=False, # return min entropy as a single value or return all entropy as a pd.series
            exclude_zero=False, # exclude the column of 0 (center position) in the entropy calculation
            clean_zero=True, # if true, zero out non-last three values in position 0 (keep only s,t,y values at center)
            ): 
    """Shannon entropy (bits) of every position of a PSSM.

    Each column is first normalized to sum to 1, then H = -sum(p * log2(p + EPSILON)).
    Columns whose values sum to zero are reported with entropy 0.
    Lower entropy means the position carries more information.
    Returns a pd.Series indexed by integer position, or its minimum as a float
    when `return_min` is True.
    """
    df = pssm_df.copy()
    df.columns = df.columns.astype(int)

    has_center = 0 in df.columns
    if has_center and clean_zero:
        df = _clean_zero(df)
    if has_center and exclude_zero:
        # drop the integer 0 column (and any string label starting with '0')
        center_cols = [c for c in df.columns
                       if c == 0 or (isinstance(c, str) and c.startswith('0'))]
        if center_cols:
            df = df.drop(columns=center_cols)

    probs = df / df.sum()
    entropy = -np.sum(probs * np.log2(probs + EPSILON), axis=0)
    # all-zero columns become NaN after normalization (their sum is 0): report 0
    entropy[probs.sum() == 0] = 0

    if return_min:
        return float(entropy.min())
    return entropy

Let $P_i(x)$ be the probability of amino acid $x$ at position $i$ in the PSSM, with $i \in \{-k, \dots, -1, 0, +1, \dots, +k\}$. The entropy at each position $i$ is defined as:

$$
H_i = - \sum_{x} P_i(x) \log_2 \left( P_i(x) + \varepsilon \right)
$$

where $\varepsilon = 10^{-8}$ is a small constant added for numerical stability.

If `exclude_zero=True`, the central position $i = 0$ is omitted from the entropy calculation.

If `clean_zero=True`, all values at position $i = 0$ are zeroed out except for the phospho-acceptor rows (pS, pT and pY, the last three rows of the matrix). Serine, threonine and tyrosine are typically the only possible phospho-acceptors in kinase motif analysis.

If `return_min=True`, the function returns the minimum entropy across all positions:

$$
H_{\text{spec}} = \min_i H_i
$$

Otherwise, the function returns the full vector $\{H_i\}$ for each position $i$, reflecting how much information (or uncertainty) is contained at each position in the motif.


In [None]:
# get entropy per position
get_entropy(pssm_df).sort_values()

Position
 0     0.987416
 1     1.740698
         ...   
-18    4.284598
 14    4.285491
Length: 41, dtype: float64

In [None]:
# calculate minimum entropy of surrouding positions
get_entropy(pssm_df,return_min=True,exclude_zero=True)

1.7406981100302623

In [None]:
#| export
@delegates(get_entropy)
def get_entropy_flat(flat_pssm:pd.Series,**kwargs): 
    "Entropy per position of a flattened PSSM: unflatten with `recover_pssm`, then apply `get_entropy`."
    return get_entropy(recover_pssm(flat_pssm), **kwargs)

In [None]:
get_entropy_flat(flat_pssm).sort_values()

Position
 0     0.987416
 1     1.740698
         ...   
-18    4.284598
 14    4.285491
Length: 41, dtype: float64

In [None]:
get_entropy_flat(flat_pssm,return_min=True,exclude_zero=True)

1.7406981100302623

In [None]:
# test equal
(get_entropy_flat(flat_pssm).round(5) == get_entropy(pssm_df).round(5)).value_counts()

True    41
Name: count, dtype: int64

## Information Content

In [None]:
#| export
@delegates(get_entropy)
def get_IC(pssm_df,**kwargs):
    """
    Calculate the information content (bits) from a frequency matrix,
    using log2(3) for the middle position and log2(len(pssm_df)) for others.
    The higher the more information it contains.

    kwargs are forwarded to `get_entropy` (`return_min` is not supported here,
    since a per-position Series is needed). Returns a pd.Series indexed by position.
    """
    entropy_position = get_entropy(pssm_df, **kwargs)

    # Maximum possible entropy: log2(N), N = number of rows (residue types) in the PSSM
    max_entropy = pd.Series(np.log2(len(pssm_df)), index=entropy_position.index)

    # Only the three acceptors are possible at the centre, so its maximum entropy is log2(3).
    # Guard on presence: assigning to a missing label would append a spurious NaN position 0.
    if not kwargs.get('exclude_zero', False) and 0 in max_entropy.index:
        max_entropy.loc[0] = np.log2(3)

    # information_content = max_entropy - entropy --> log2(N) - entropy
    IC_position = max_entropy - entropy_position

    # get_entropy reports exactly 0 for all-zero (empty) columns; give them IC 0 as well
    IC_position[entropy_position == 0] = 0
    return IC_position

Let $P_i(x)$ be the frequency (probability) of amino acid $x$ at position $i$ in the PSSM. The standard information content (IC) at position $i$ is defined as:

$$
\mathrm{IC}_i = \max H_i - H_i
$$

which is:

$$
\mathrm{IC}_i = \log_2(N) - H_i
$$

where $N$ is the number of rows (residue types) in the PSSM, i.e. $N = \text{len}(\text{pssm\_df})$. For example, $N = 23$ when pS, pT and pY are included.

At the center position ($i = 0$), only three amino acids (S, T, Y) are relevant, so the maximum entropy at each position is defined as:

$$
\max H_i =
\begin{cases}
\log_2(3) & \text{if } i = 0 \\
\log_2(N) & \text{otherwise}
\end{cases}
$$


In [None]:
# the higher the more conserved
get_IC(pssm_df,exclude_zero=True).sort_values()

Position
 14    0.238071
-18    0.238964
         ...   
 3     0.575586
 1     2.782864
Length: 40, dtype: float64

Check all zero cases:

In [None]:
pssm_df2=pssm_df.copy()

In [None]:
pssm_df2[-20]=0

In [None]:
get_entropy(pssm_df2,exclude_zero=True).sort_values()

Position
-20    0.000000
 1     1.740698
         ...   
-18    4.284598
 14    4.285491
Length: 40, dtype: float64

In [None]:
#| export
@delegates(get_IC)
def get_IC_flat(flat_pssm:pd.Series,**kwargs):
    """Information content (bits) per position of a flattened PSSM.

    The series is unflattened with `recover_pssm` and passed to `get_IC`
    (log2(3) maximum at the centre, log2(number of rows) elsewhere)."""
    return get_IC(recover_pssm(flat_pssm), **kwargs)

In [None]:
get_IC_flat(flat_pssm,exclude_zero=True).sort_values()

Position
 14    0.238071
-18    0.238964
         ...   
 3     0.575586
 1     2.782864
Length: 40, dtype: float64

In [None]:
(get_IC_flat(flat_pssm).round(5) == get_IC(pssm_df).round(5)).value_counts()

True    41
Name: count, dtype: int64

## Overall specificity

In [None]:
#| export
def get_specificity(pssm_df):
    "Specificity score 2*max(IC) + var(IC) over the non-centre positions that have IC > 0."
    ic = get_IC(pssm_df, exclude_zero=True)
    informative = ic[ic > 0]  # ignore empty / uninformative positions
    return float(2*informative.max() + informative.var())

The overall specificity of a PSSM combines two metrics. Both are computed over the surrounding positions (position 0 excluded), keeping only positions with IC > 0: the maximum IC and the variance of the IC values:

$$
\text{Specificity Score} = 2 \times \max(\text{IC}) + \mathrm{Var}(\text{IC})
$$

In [None]:
get_specificity(pssm_df)

5.7236223416577445

In [None]:
#| export
def get_specificity_flat(flat_pssm):
    """Get specificity score of a flattened pssm, excluding zero position.

    Same definition as `get_specificity`: 2*max(IC) + var(IC), computed only over
    positions with IC > 0.
    """
    ICs = get_IC_flat(flat_pssm, exclude_zero=True)
    # only consider IC with values, consistent with get_specificity
    # (empty positions would otherwise shrink/inflate the variance)
    ICs = ICs[ICs > 0]
    return float(2*ICs.max() + ICs.var())

In [None]:
get_specificity_flat(flat_pssm)

5.7236223416577445

## Plot

### Heatmap

In [None]:
#| export
@delegates(sns.heatmap)
def plot_heatmap_simple(matrix, # a matrix of values
                 title: str='heatmap', # title of the heatmap
                 figsize: tuple=(6,7), # (width, height)
                 cmap: str='binary', # color map, default is dark&white
                 **kwargs, # arguments for sns.heatmap()
                 ):
    """Plot heatmap based on a matrix of values in a new figure.

    `square=True` and `annot=False` are defaults that `kwargs` may override.
    """
    plt.figure(figsize=figsize)
    # Merge defaults with kwargs instead of passing both: passing square/annot
    # explicitly used to raise "got multiple values for keyword argument".
    heatmap_kws = {'square': True, 'annot': False, **kwargs}
    sns.heatmap(matrix, cmap=cmap, **heatmap_kws)
    plt.title(title)
    plt.ylabel('')
    plt.xlabel('')
    plt.yticks(rotation=0)

In [None]:
# plot_heatmap_simple(pssm_df,'kinase',figsize=(10,7))

In [None]:
#| export
def plot_heatmap(heatmap_df, ax=None, position_label=True, figsize=(5, 6), include_zero=True,scale_pos_neg=False, colorbar_title='Prob.'):
    """Plot a PSSM-like matrix (rows: residues, columns: positions) as a diverging heatmap.

    heatmap_df: matrix to plot; position 0 is assumed to be column len(columns)//2.
    ax: axes to draw on; a new figure of size `figsize` is created when None.
    position_label: show position tick labels (on top); hide them when False.
    include_zero: if False, mask out the centre column.
    scale_pos_neg: scale negative and positive values separately with TwoSlopeNorm.
        That needs data on both sides of 0; otherwise the centred scale is used.
    colorbar_title: title shown above the colour bar.
    Returns the axes.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)

    n_cols = len(heatmap_df.columns)
    zero_position = n_cols // 2
    second_position = math.ceil(n_cols / 2)
    # If they overlap (even column count), move the second line one step right
    # so the two lines bracket the centre column.
    if second_position == zero_position:
        second_position += 1

    mask = np.zeros_like(heatmap_df, dtype=bool)
    if not include_zero:
        mask[:, zero_position] = True  # Mask position 0 if include_zero is False

    heatmap_kws = dict(cmap='coolwarm', linewidth=0.3, ax=ax, mask=mask)
    vmin, vmax = heatmap_df.min().min(), heatmap_df.max().max()
    if scale_pos_neg and vmin < 0 < vmax:
        heatmap_kws['norm'] = TwoSlopeNorm(vmin=vmin, vcenter=0, vmax=vmax)
    else:
        # TwoSlopeNorm raises unless vmin < 0 < vmax; fall back to a single
        # diverging scale centred at 0.
        heatmap_kws['center'] = 0
    sns.heatmap(heatmap_df, **heatmap_kws)

    # Access and format the color bar
    colorbar = ax.collections[0].colorbar
    colorbar.ax.set_title(colorbar_title, loc='center')

    # Vertical lines around the centre column
    ax.axvline(zero_position, color='black', linewidth=0.5)
    ax.axvline(second_position, color='black', linewidth=0.5)

    # Format the heatmap border
    ax.patch.set_edgecolor("black")
    ax.patch.set_linewidth(1.5)

    # Hide axis labels
    ax.set_ylabel("")
    ax.set_xlabel("")
    ax.xaxis.set_ticks_position('top')
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
    if not position_label:
        ax.set_xticklabels([])

    return ax

This function visualizes a PSSM or log-odds matrix as a heatmap with diverging color scales centered at 0.

**Color scale behavior**:

- By default (`scale_pos_neg=False`), the colormap is **centered at 0**, but the full data range determines the color intensity:

  $$
  \text{color range} = [\min(\text{data}), \max(\text{data})], \quad \text{with center at } 0
  $$

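  Seaborn uses one shared scale on both sides of 0, so values of equal magnitude above and below zero get equally intense colors. The side with the smaller range therefore never reaches the darkest color. This is useful when you want to compare positive and negative magnitudes directly.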

- If `scale_pos_neg=True`, the function uses a **balanced diverging scale** via `TwoSlopeNorm`, such that:

  $$
  \text{min color} = \min(\text{data}), \quad
  \text{center} = 0, \quad
  \text{max color} = \max(\text{data})
  $$

  The positive and negative ranges are **scaled separately**, so the most negative and the most positive values both reach full color intensity. This is especially helpful when the positive and negative ranges differ greatly in magnitude, as is common for log-odds matrices.

**Additional visual features**:
- The center position ($i = 0$) can be masked out using `include_zero=False`.

In [None]:
# plot_heatmap(pssm_df-0.3,scale_pos_neg=False,figsize=(20, 6));

In [None]:
# plot_heatmap(pssm_df-0.3,scale_pos_neg=True,figsize=(20, 6));

In [None]:
plt.close('all')

In [None]:
#| export
def plot_two_heatmaps(matrix1, matrix2, 
                      kinase_name="Kinase", title1='CDDM',title2='PSPA',
                      figsize=(4,4.5), cbar=True,scale_01=False,
                      **kwargs):
    """
    Plot two side-by-side heatmaps with black rectangle borders,
    titles on top, shared kinase label below, and only left plot showing y-axis labels.

    matrix1, matrix2: PSSMs with integer positions as columns; position 0 is dropped
        and positions -1 and 1 must be present (a separator is drawn between them).
    cbar: add a horizontal colour bar under each heatmap.
    scale_01: fix the colour range of both heatmaps to [0, 1].
    kwargs: forwarded to sns.heatmap for both panels.
    """
    fig, axes = plt.subplots(1, 2, figsize=figsize, gridspec_kw={'wspace': 0.05})
    matrix1 = matrix1.drop(columns=0)
    matrix2 = matrix2.drop(columns=0)
    range_kws = {'vmin': 0, 'vmax': 1} if scale_01 else {}

    panels = [(axes[0], matrix1, "Reds", title1),
              (axes[1], matrix2, "Blues", title2)]
    for ax, matrix, cmap, title in panels:
        sns.heatmap(matrix, square=False, cmap=cmap, annot=False, cbar=False,
                    ax=ax, **range_kws, **kwargs)
        ax.set_title(title, fontsize=12, pad=0)
        ax.tick_params(axis="y", rotation=0)  # horizontal y tick labels
        # black rectangle border around the heatmap
        ax.add_patch(Rectangle((0, 0), matrix.shape[1], matrix.shape[0],
                               fill=False, edgecolor='black', lw=1.5))

    # y labels only on the left heatmap
    axes[0].tick_params(left=True, bottom=True)
    axes[1].tick_params(left=False, labelleft=False, bottom=True)

    # Separator at the boundary between positions -1 and +1 (position 0 was dropped)
    cols = list(matrix1.columns)
    xpos = (cols.index(-1) + cols.index(1)) / 2 + 0.5
    for ax in axes:
        ax.axvline(xpos, color='black', lw=0.75)
        ax.set_xticks([])
        ax.set_ylabel("")

    if cbar:
        for ax in axes:
            cax = make_axes_locatable(ax).append_axes("bottom", size="5%", pad=0.25)
            colorbar = plt.colorbar(ax.collections[0], cax=cax, orientation="horizontal")
            # remove trailing zeros like 0.00 → 0, 1.00 → 1
            colorbar.ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: f"{x:g}"))
            cax.tick_params(labelsize=8)

    # Shared kinase label
    fig.suptitle(kinase_name, fontsize=14, x=0.52,y=0.96)

In [None]:
# pssm1 = recover_pssm(target.loc[idx])
# pssm2 = recover_pssm(pred.loc[idx])
# plot_two_heatmaps(pssm1,pssm2,f'{idx}:{score:.3f}','PSP','Predicted')

### Logo motif

In [None]:
#| export
def change_center_name(df):
    """Return a copy of `df` with position 0's pS/pT/pY values moved to the S/T/Y rows.

    The pS/pT/pY cells at position 0 are set to 0, so a logo shows plain S, T, Y at the centre.
    """
    renamed = df.copy()
    for phospho, plain in (('pS', 'S'), ('pT', 'T'), ('pY', 'Y')):
        renamed.loc[plain, 0] = renamed.loc[phospho, 0]
    renamed.loc[['pS', 'pT', 'pY'], 0] = 0
    return renamed

Now instead of pS, pT, and pY, the center name becomes S, T and Y:

In [None]:
change_center_name(pssm_df)[0].sort_values(ascending=False).head()

aa
S    0.664786
T    0.324738
Y    0.010475
A    0.000000
G    0.000000
Name: 0, dtype: float64

In [None]:
#| export
def get_pos_min_max(pssm_df):
    """
    Over all positions except 0, return (most negative column sum of negative values,
    largest column sum of positive values).
    """
    flanks = pssm_df.drop(columns=0)
    most_negative = flanks.where(flanks < 0).sum().min()
    most_positive = flanks.where(flanks > 0).sum().max()
    return most_negative, most_positive

In [None]:
#| export
def scale_zero_position(pssm_df):
    """
    Scale position 0 so that:
    - Positive values match the max positive column sum of other positions
    - Negative values match the min (most negative) column sum of other positions
    Values keep their relative proportions within each sign. Returns a new
    DataFrame; the input is not modified.
    """
    pssm_df = pssm_df.copy()  # previously the caller's frame was modified in place
    max_sum_neg, max_sum_pos = get_pos_min_max(pssm_df)

    zero_col = pssm_df[0]
    zero_col_pos = zero_col[zero_col > 0]
    zero_col_neg = zero_col[zero_col < 0]

    scaled_col = zero_col.copy()
    # distribute the neighbours' extreme stack height proportionally over the centre values
    if not zero_col_pos.empty and zero_col_pos.sum() != 0:
        scaled_col.loc[zero_col_pos.index] = max_sum_pos * (zero_col_pos / zero_col_pos.sum())
    if not zero_col_neg.empty and zero_col_neg.sum() != 0:
        scaled_col.loc[zero_col_neg.index] = max_sum_neg * (zero_col_neg / zero_col_neg.sum())

    pssm_df[0] = scaled_col
    return pssm_df
    

This function rescales **position 0** in a log-odds PSSM so that its total positive and negative stack heights match those of the most extreme positions on either side.

This ensures the central position visually matches the dynamic range of surrounding positions in log-odds logo plots.


In [None]:
#| export
def scale_pos_neg_values(pssm_df):
    """
    Divide positive values by the largest positive column sum and negative values by
    the magnitude of the most negative column sum (position 0 excluded from both sums).
    Signs are preserved; a zero scale factor leaves that part unscaled.
    """
    max_sum_neg, max_sum_pos = get_pos_min_max(pssm_df)
    positive = pssm_df.clip(lower=0)
    negative = pssm_df.clip(upper=0)

    if max_sum_pos != 0:
        positive = positive / max_sum_pos
    if max_sum_neg != 0:
        negative = negative / abs(max_sum_neg)  # abs() keeps negatives negative

    return positive + negative

In [None]:
#| export
def convert_logo_df(pssm_df,scale_zero=True,scale_pos_neg=False):
    """Prepare a PSSM for logo plotting.

    Renames the centre pS/pT/pY to S/T/Y. Optionally rescales position 0 to the extremes
    of its neighbours (`scale_zero`) and optionally scales positive and negative values
    (`scale_pos_neg`).
    """
    logo_df = change_center_name(pssm_df)
    if scale_zero:
        logo_df = scale_zero_position(logo_df)
    if scale_pos_neg:
        logo_df = scale_pos_neg_values(logo_df)
    return logo_df

In [None]:
#| export
def plot_logo_raw(pssm_df,ax=None,title='Motif',ytitle='Bits',figsize=(10,2)):
    """Draw a Logomaker logo of `pssm_df` (rows: residues, columns: positions).

    The matrix is transposed for Logomaker. A new figure of size `figsize` is created
    when `ax` is None.
    """
    target_ax = plt.subplots(figsize=figsize)[1] if ax is None else ax
    logo = logomaker.Logo(pssm_df.T, color_scheme='kinase_protein', flip_below=False, ax=target_ax)
    logo.ax.set_ylabel(ytitle)
    logo.style_xticks(fmt='%d')
    target_ax.set_title(title)

In [None]:
# plot_logo_raw(pssm_df)

In [None]:
#| export
def get_logo_IC(pssm_df):
    """
    For plotting: scale every column (position) of a frequency PSSM by its information
    content from `get_IC` (log2(3) maximum at the centre, log2(len(pssm_df)) elsewhere).
    """
    ic_per_position = get_IC(pssm_df)
    return pssm_df.mul(ic_per_position, axis=1)



To visualize the motif using Logomaker, the scaled PSSM is computed by weighting each amino acid’s frequency at position $i$ by the position’s information content:

$$
\text{PSSM\_scaled}_i(x) = P_i(x) \cdot \mathrm{IC}_i
$$

This results in a matrix where the total stack height at each position equals the information content, and each letter’s height is proportional to its contribution. This is the standard format used by Logomaker to generate sequence logos.


In [None]:
get_logo_IC(pssm_df)

Position,-20,-19,-18,-17,-16,-15,-14,-13,-12,-11,-10,-9,-8,-7,-6,-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20
aa,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1
P,0.028979,0.021736,0.019696,0.026208,0.020191,0.020036,0.026940,0.020702,0.024652,0.022535,0.017114,0.021499,0.023403,0.026208,0.031930,0.024593,0.032412,0.039787,0.068936,0.029358,0.000000,2.109590,0.033707,0.051101,0.028948,0.037501,0.028892,0.036076,0.027630,0.028354,0.025461,0.025050,0.022227,0.020044,0.018602,0.024891,0.026862,0.027559,0.015651,0.022606,0.025402
G,0.019969,0.019102,0.016480,0.015774,0.020616,0.015743,0.025669,0.019380,0.021765,0.020527,0.020852,0.023912,0.016652,0.021528,0.027406,0.021633,0.028396,0.031460,0.028898,0.034539,0.000000,0.087526,0.036227,0.036235,0.017093,0.030001,0.021343,0.022138,0.024241,0.019139,0.021217,0.019028,0.016569,0.022389,0.016426,0.020782,0.023620,0.017255,0.018122,0.023077,0.017237
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
pT,0.008280,0.007245,0.007637,0.005824,0.005313,0.009405,0.005845,0.008369,0.006219,0.006470,0.007869,0.007459,0.009226,0.007956,0.012506,0.010703,0.014055,0.018197,0.019497,0.011225,0.194046,0.044885,0.014806,0.014866,0.013233,0.007800,0.011453,0.012025,0.008602,0.007561,0.006837,0.007708,0.005456,0.008529,0.005541,0.006525,0.007873,0.006950,0.007002,0.007771,0.004990
pY,0.000974,0.001537,0.002814,0.002184,0.001275,0.002658,0.001525,0.001101,0.001777,0.002677,0.002361,0.003071,0.002700,0.001872,0.003459,0.003416,0.005737,0.003084,0.005222,0.002878,0.006260,0.011221,0.003780,0.008827,0.003308,0.005400,0.001562,0.002186,0.001825,0.002127,0.003065,0.002168,0.002627,0.001919,0.001979,0.001450,0.001158,0.000719,0.002059,0.002119,0.002268


In [None]:
#| export
def plot_logo(pssm_df,title='Motif', scale_zero=True,ax=None,figsize=(10,1)):
    "Plot an information-content logo from a frequency PSSM."
    logo_df = convert_logo_df(get_logo_IC(pssm_df), scale_zero=scale_zero)
    plot_logo_raw(logo_df, ax=ax, title=title, ytitle='IC (bits)', figsize=figsize)

In [None]:
# plot_logo(pssm_df,scale_zero=False,figsize=(10,1))

Keeping `scale_zero=True` (the default) rescales position 0 to the height of its neighbours, which gives a better view of the flanking amino acids.

In [None]:
# plot_logo(pssm_df,figsize=(10,1))

In [None]:
# plt.close('all')

### Logo motif of log-odds

In [None]:
#| export
def plot_logo_LO(pssm_LO,title='Motif', acceptor=None, scale_zero=True,scale_pos_neg=True,ax=None,figsize=(10,1)):
    """Plot a logo of a log-odds PSSM (rows: residues, columns: positions).

    acceptor: optional phospho-acceptor(s) to show at position 0. Accepts 'S', 'T', 'Y'
        or a combination such as 'STY' (case-insensitive). Each selected pS/pT/pY cell
        at position 0 is set to a placeholder value (0.1) so its letter is visible.
    scale_zero: rescale position 0 to the extremes of its neighbours.
    scale_pos_neg: scale positive and negative values separately.
    Raises ValueError if `acceptor` contains anything other than S, T or Y.
    """
    if acceptor is not None:
        acceptor = acceptor.upper()
        # explicit check instead of assert (asserts are stripped under -O)
        if not acceptor or not set(acceptor) <= set('STY'):
            raise ValueError(f"acceptor must consist of 'S', 'T' and/or 'Y', got {acceptor!r}")
        pssm_LO = pssm_LO.copy()  # do not modify the caller's matrix
        for aa in acceptor:
            pssm_LO.loc[f'p{aa}', 0] = 0.1  # give it a value so that it can be shown on the motif

    pssm_LO = convert_logo_df(pssm_LO, scale_zero=scale_zero, scale_pos_neg=scale_pos_neg)
    ytitle = "Scaled Log-Odds" if scale_pos_neg else "Log-Odds (bits)"
    plot_logo_raw(pssm_LO, ax=ax, title=title, ytitle=ytitle, figsize=figsize)

To ensure the phosphorylated residue is visible at the center of a log-odds motif (position 0), two mechanisms are used:

1. **Acceptor override**:
   If the center column is entirely zero (e.g., masked), the user can specify an `acceptor` (`'S'`, `'T'`, `'Y'`, or `'STY'`). The function then assigns a small nonzero value (e.g., 0.1) to the corresponding phospho-residue row (`pS`, `pT`, `pY`) at position 0. This ensures the central letter appears in the logo plot, even when real log-odds values are absent.

2. **Stack height rescaling**:
   To maintain visual consistency with surrounding columns, position 0 is rescaled so that its total positive and negative stack heights match the most extreme values observed elsewhere.

Together, these adjustments ensure that:
- The phospho-acceptor appears explicitly at the center,
- The visual scale remains consistent with neighboring positions,
- The resulting logo can faithfully reflect both biological relevance and statistical signal.


In [None]:
pssm_LO = get_pssm_LO(pssm_df,'STY')
# plot_logo_LO(pssm_LO,scale_zero=False,scale_pos_neg=False)

In [None]:
## with zero position scaled to the max
# plot_logo_LO(pssm_LO,scale_zero=True,scale_pos_neg=False)

In [None]:
# # scaled positive and negative values for better visualization
# plot_logo_LO(pssm_LO,scale_zero=True,scale_pos_neg=True)

In [None]:
# for those specific site type (S,T or Y), show acceptor in the middle instead of empty
pssm_LO = get_pssm_LO(pssm_y,'Y')
# plot_logo_LO(pssm_LO,acceptor='Y')

In [None]:
# plt.close('all')

### Multiple logos

As multiple figures:

In [None]:
#| export
def plot_logos_idx(pssms_df,*idxs):
    "Plot one logo figure per ID, taking flattened PSSMs from the rows of `pssms_df` (index = IDs)."
    for motif_id in idxs:
        plot_logo(recover_pssm(pssms_df.loc[motif_id]), title=f'Motif {motif_id}', figsize=(14,1))
        plt.show()
        plt.close()

In [None]:
pssms = Data.get_cddm()

In [None]:
# plot_logos_idx(pssms,*pssms.index[:2])

In one figure:

In [None]:
#| export
def plot_logos(pssms_df, 
               count_dict=None, # used to display n in motif title
               path=None, # if given, the figure is also saved to this file
               prefix='Motif'):
    """
    Plot all logos from a dataframe of flattened PSSMs as subplots in a single figure.

    Each subplot is titled `f'{prefix} {idx}'`, followed by `(n=...)` when `count_dict`
    provides a count for that index. When `path` is given, the figure is saved there.
    """
    n = len(pssms_df)
    hspace = 0.7
    # 14 is width, 1 is height for each logo; squeeze=False keeps axes 2-D even when n == 1
    fig, axes = plt.subplots(nrows=n, figsize=(14, n * (1+hspace)),
                             gridspec_kw={'hspace': hspace+0.1}, squeeze=False)

    for ax, idx in zip(axes[:, 0], pssms_df.index):
        pssm = recover_pssm(pssms_df.loc[idx])
        if count_dict is not None:
            title = f'{prefix} {idx} (n={count_dict[idx]:,})'
        else:
            title = f'{prefix} {idx}'
        plot_logo(pssm, title=title, ax=ax)

    # `path` was previously accepted but ignored
    if path is not None:
        fig.savefig(path, bbox_inches='tight')

In [None]:
# plot_logos(pssms.head(2))

In [None]:
# plt.close('all')

### Logo motif + Heatmap

In [None]:
#| export
def plot_logo_heatmap(pssm_df, # column is position, index is aa
                       title='Motif',
                       figsize=(17,10),
                       include_zero=False
                      ):
    """Plot the IC logo (top-left cell) above a full-width heatmap of the same PSSM."""
    fig = plt.figure(figsize=figsize)
    grid = fig.add_gridspec(2, 2, height_ratios=[1, 5], width_ratios=[4, 1], hspace=0.11, wspace=0)

    plot_logo(pssm_df, ax=fig.add_subplot(grid[0, 0]), title=title)
    plot_heatmap(pssm_df, ax=fig.add_subplot(grid[1, :]), position_label=False, include_zero=include_zero)

In [None]:
# plot_logo_heatmap(pssm_df,'Kinase',(17,10))

In [None]:
#| export
def plot_logo_heatmap_LO(pssm_LO, # pssm of log-odds
                         title='Motif',
                         acceptor=None,
                         figsize=(17,10),
                         include_zero=False,
                         scale_pos_neg=True
                        ):
    """Plot the log-odds logo (top-left cell) above a full-width heatmap of the log-odds values (bits)."""
    fig = plt.figure(figsize=figsize)
    grid = fig.add_gridspec(2, 2, height_ratios=[1, 5], width_ratios=[4, 1], hspace=0.11, wspace=0)

    plot_logo_LO(pssm_LO, acceptor=acceptor, ax=fig.add_subplot(grid[0, 0]), title=title)
    plot_heatmap(pssm_LO, ax=fig.add_subplot(grid[1, :]), position_label=False,
                 include_zero=include_zero, scale_pos_neg=scale_pos_neg, colorbar_title='bits')

In [None]:
# plot_logo_heatmap_LO(pssm_LO,acceptor='Y')

In [None]:
pssm_LO = get_pssm_LO(pssm_df,'STY')
# plot_logo_heatmap_LO(pssm_LO,scale_pos_neg=False) # normal color scale

In [None]:
# plt.close('all')

## PSPA

### Plot

In [None]:
#| hide
#| export
def change_center_name_series(s: pd.Series) -> pd.Series:
    """Series version of `change_center_name`: return a copy with each present pS/pT/pY
    value moved to S/T/Y and the phospho entry set to 0."""
    out = s.copy()
    for phospho, plain in (("pS", "S"), ("pT", "T"), ("pY", "Y")):
        if phospho not in out.index:
            continue
        out[plain] = out[phospho]
        out[phospho] = 0
    return out

In [None]:
#| hide
#| export
def recover_pssm_pspa(row):
    "Unflatten a PSPA row with `recover_pssm`, apply `_clean_zero`, and drop positions whose values are all zero."
    pssm = _clean_zero(recover_pssm(row))
    non_empty = pssm.sum() != 0
    return pssm.loc[:, non_empty]

In [None]:
#| hide
#| export
def preprocess_pssm_pspa(pssm):
    """Turn a PSPA PSSM into a log-ratio matrix for logo plotting.

    - Flanking positions: pS is dropped, pT is renamed 'pS/pT', and values become
      log2(value / median of that position).
    - Position 0: pS/pT/pY are moved to S/T/Y, with the same drop/rename applied.
      It is then rescaled to the extremes of the flanking positions.
    The input is not modified.
    """
    center = change_center_name_series(pssm[0])
    center = center.drop(index='pS').rename(index={'pT': 'pS/pT'})

    flank = pssm.drop(columns=0)
    flank = flank.drop(index='pS').rename(index={'pT': 'pS/pT'})
    flank = np.log2(flank / flank.median())
    flank[0] = center
    return scale_zero_position(flank)

In [None]:
#| export
def plot_logo_pspa(row,title='Motif',figsize=(5,2)):
    "Plot a PSPA logo (log2 of value over position median) from a flattened PSPA row."
    logo_pssm = preprocess_pssm_pspa(recover_pssm_pspa(row))
    plot_logo_raw(logo_pssm, ytitle='log₂(Value / Median)', title=title, figsize=figsize)

In [None]:
#| export
def plot_logo_heatmap_pspa(row, # row of Data.get_pspa_all_norm()
                       title='Motif',
                       figsize=(6,10),
                       include_zero=False
                      ):
    """Plot a PSPA log-ratio logo (top-left cell) above a full-width heatmap of the PSPA values."""
    pssm = recover_pssm_pspa(row)

    fig = plt.figure(figsize=figsize)
    grid = fig.add_gridspec(2, 2, height_ratios=[1, 5], width_ratios=[4, 1], hspace=0.11, wspace=0)

    logo_ax = fig.add_subplot(grid[0, 0])
    plot_logo_raw(preprocess_pssm_pspa(pssm), ax=logo_ax, ytitle='log₂(Value / Median)', title=title)

    plot_heatmap(pssm, ax=fig.add_subplot(grid[1, :]), position_label=False,
                 include_zero=include_zero, colorbar_title='Value')

In [None]:
pspa= Data.get_pspa_all_norm()

In [None]:
# plot_logo_heatmap_pspa(pspa.iloc[0],title='kinase')

### Calculations

In [None]:
#| export
def raw2norm(df: pd.DataFrame, # single kinase's df has position as index, and single amino acid as columns
             PDHK: bool=False, # whether this kinase belongs to PDHK family 
            ):
    """Normalize single ST kinase data (Johnson et al. PSPA normalization).

    - Each position (row) is divided by the sum of the randomized residues: every
      column except S, T, C and the phospho-residues s/t/y (and also Y for PDHK).
    - The C column is rescaled so that its median becomes 1/17 (1/16 for PDHK).
    - S and T are set to the median of the randomized residues at each position.
    Values are rounded to 4 decimals. Excluded columns missing from `df` are skipped
    (e.g. 's', which `get_one_kinase` drops by default).
    """
    # s was previously not excluded, so it inflated the sum when present
    columns_to_exclude = ['S', 'T', 'C', 's', 't', 'y']
    if PDHK:
        columns_to_exclude.append('Y')
        divisor = 16
    else:
        divisor = 17
    # tolerate matrices without some of these columns (drop() would raise KeyError)
    columns_to_exclude = [col for col in columns_to_exclude if col in df.columns]

    s = df.drop(columns=columns_to_exclude).sum(axis=1)
    df2 = df.div(s, axis=0)
    df2['C'] = df2['C'] / (df2['C'].median() * divisor)
    # S and T are not randomized: use the median of the randomized residues
    position_median = df2.drop(columns=columns_to_exclude).median(axis=1)
    df2['S'] = position_median
    df2['T'] = position_median
    return df2.round(4)

This function implements the normalization method from [Johnson et al. Nature: An atlas of substrate specificities for the human serine/threonine kinome](https://www.nature.com/articles/s41586-022-05575-3#Sec6)

Specifically,
> - matrices were column-normalized at all positions by the sum of the 17 randomized amino acids (excluding serine, threonine and cysteine), to yield PSSMs. 
>- PDHK1 and PDHK4 were normalized to the 16 randomized amino acids (excluding serine, threonine, cysteine and additionally tyrosine)
>- The cysteine row was scaled by its median to be 1/17 (1/16 for PDHK1 and PDHK4). 
>- The serine and threonine values in each position were set to be the median of that position.
>- The S0/T0 ratio was determined by summing the values of S and T rows in the matrix (SS and ST, respectively), accounting for the different S vs. T composition of the central (1:1) and peripheral (only S or only T) positions (Sctrl and Tctrl, respectively), and then normalizing to the higher value among the two (S0 and T0, respectively, Supplementary Note 1)

This function is usually called through `get_one_kinase` below by passing `normalize=True`. The S0/T0 ratio step quoted above is not part of this function; it is computed in `get_logo`.

In [None]:
#| export
def get_one_kinase(df: pd.DataFrame, #stacked dataframe (paper's raw data)
                   kinase:str, # a specific kinase
                   normalize: bool=False, # normalize according to the paper; special for PDHK1/4
                   drop_s: bool= True, # drop s as s is a duplicates of t in PSPA
                  ):
    """Obtain a specific kinase data from stacked dataframe.

    Column labels of `df` look like '-5P' or '4t': a signed position followed by a
    one-letter residue code. Returns a matrix with positions as index ('position')
    and residues as columns ('aa').
    """
    row = df.loc[kinase]
    labels = pd.Series(row.index.astype(str))
    # Build the long table directly; reset_index() naming depended on df.columns.name
    long_df = pd.DataFrame({
        'position': labels.str.extract(r'(-?\d+)', expand=False).astype(int).to_numpy(),
        'aa': labels.str[-1].to_numpy(),
        'value': row.to_numpy(),
    })
    pp = long_df.pivot(index='position', columns='aa', values='value')

    if drop_s and 's' in pp.columns:
        pp = pp.drop(columns=['s'])
    if normalize:
        pp = raw2norm(pp, PDHK=kinase in ('PDHK1', 'PDHK4'))
    return pp

Retrieve a single kinase's data from PSPA data with kinases as the index and position+amino acid labels (e.g. `-5P`) as columns.

In [None]:
data = Data.get_pspa_st_norm()

In [None]:
get_one_kinase(data,'PDHK1')

aa,A,C,D,E,F,G,H,I,K,L,M,N,P,Q,R,S,T,V,W,Y,t,y
position,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
-5,0.0594,0.0625,0.0589,0.0550,0.0775,0.0697,0.0687,0.0590,0.0515,0.0657,0.0687,0.0613,0.0451,0.0424,0.0594,0.0594,0.0594,0.0573,0.1001,0.0775,0.0583,0.0658
-4,0.0618,0.0621,0.0550,0.0511,0.0739,0.0715,0.0598,0.0601,0.0520,0.0614,0.0744,0.0549,0.0637,0.0552,0.0617,0.0608,0.0608,0.0519,0.0916,0.0739,0.0528,0.0752
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3,0.0486,0.0609,0.0938,0.0684,0.1024,0.0676,0.0544,0.0583,0.0388,0.0552,0.0637,0.0505,0.0686,0.0502,0.0561,0.0588,0.0588,0.0593,0.0641,0.1024,0.0539,0.0431
4,0.0565,0.0749,0.0631,0.0535,0.0732,0.0655,0.0664,0.0625,0.0496,0.0552,0.0627,0.0640,0.0677,0.0553,0.0604,0.0626,0.0626,0.0579,0.0864,0.0732,0.0548,0.0575


In [None]:
pspa=Data.get_pspa_all_norm()

### Plot PSPA logo motif (old)

In [None]:
#| export
def get_logo(df: pd.DataFrame, # stacked Dataframe with kinase as index, substrates as columns
             kinase: str, # a specific kinase name in index
             ):
    """Plot the PSPA motif logo of one kinase from a stacked PSPA dataframe.

    Every position of the kinase's normalized matrix is shown as
    log2(value / median of that position's row). An extra row is appended at
    position 0 with only an S and a T bar; their heights split the tallest
    positive stack among the other positions according to an S/T preference
    derived from the raw (un-normalized) S and T signals summed over positions.
    """
    # S/T preference comes from the raw matrix, summed over all positions
    raw = get_one_kinase(df, kinase, normalize=False)
    s_total = raw['S'].sum()
    t_total = raw['T'].sum()

    # linear cross-correction of the two totals, scaled by the larger one
    s_ctrl = 0.75*s_total - 0.25*t_total
    t_ctrl = 0.75*t_total - 0.25*s_total
    scale = max(s_ctrl, t_ctrl)
    s_rel = s_ctrl / scale
    t_rel = t_ctrl / scale

    # fractions of the position-0 bar assigned to S and T (they sum to 1)
    denom = s_rel + t_rel
    s_frac = s_rel / denom
    t_frac = t_rel / denom

    # log2 of each value relative to its row (position) median, on the normalized matrix
    norm = get_one_kinase(df, kinase, normalize=True)
    log_ratio = np.log2(norm.div(norm.median(axis=1), axis=0))

    # height of the tallest stack of positive log-ratios over all positions
    positive_stack = lambda row: row[row > 0].sum()
    tallest = log_ratio.apply(positive_stack, axis=1).max()

    zero_row = pd.DataFrame({'S': s_frac*tallest, 'T': t_frac*tallest}, index=[0])
    stacked = pd.concat([log_ratio, zero_row], ignore_index=False).fillna(0)

    plot_logo_raw(stacked.T, title=kinase, ytitle='log₂(Value / Median)')

This function replicates the motif logo from [Johnson et al. Nature: An atlas of substrate specificities for the human serine/threonine kinome](https://www.nature.com/articles/s41586-022-05575-3). Given raw PSPA data, it plots the motif logo of a kinase.

In [None]:
import pandas as pd

In [None]:
# load raw PSPA data
df = pd.read_csv('https://github.com/sky1ove/katlas_raw/raw/refs/heads/main/nbs/raw/pspa_st_raw.csv').set_index('kinase')
df.head()

Unnamed: 0_level_0,-5P,-5G,-5A,-5C,-5S,-5T,-5V,-5I,-5L,-5M,...,4H,4K,4R,4Q,4N,4D,4E,4s,4t,4y
kinase,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
AAK1,7614134.38,2590563.43,3001315.49,4696631.43,4944311.77,8315837.72,10056545.0,16433061.43,10499735.53,9133577.86,...,6020662.73,8938081.41,9983402.01,6833481.55,6364453.29,4189045.89,4921595.57,2705053.53,2705053.53,2909279.71
ACVR2A,4991039.28,5783855.86,7015770.78,8367603.09,7072052.48,7601399.57,7188292.41,7513915.73,7159894.71,6266122.81,...,6039472.76,5556300.56,5178734.62,6490097.7,5862480.97,6742905.78,6750653.36,7414220.16,7414220.16,6209576.97
ACVR2B,26480329.1,25689687.16,28137300.9,45175909.3,32876722.9,33516959.03,27011194.06,21996255.94,23412987.54,25670581.4,...,27984195.21,22496915.32,24236904.72,29132857.3,26527389.14,36388726.15,34729319.54,37906081.09,37906081.09,31761418.56
AKT1,18399509.29,18104681.05,16831835.48,17247743.9,22647275.57,17801288.32,13037570.99,13271896.32,14156489.52,15409761.84,...,29511541.69,50942663.29,48152924.11,32693882.62,28896602.57,19701350.3,13887460.52,17483074.6,17483074.6,11696833.54
AKT2,5439237.54,5569477.23,5805462.7,6301076.01,5004932.12,4812022.8,3906822.27,3776845.45,4450344.85,4629319.8,...,6812201.58,11590683.5,9932525.89,6544476.93,6252360.75,3629091.99,3510048.19,5499662.3,5499662.3,4188620.88


In [None]:
# # plot logo of a kinase
# get_logo(df, 'AAK1')

In [None]:
# plt.close('all')

## Compare PSSM

In [None]:
pssms = Data.get_pspa_all_scale()

In [None]:
# one example
pssm_df = recover_pssm(pssms.iloc[1])
pssm_df2 = recover_pssm(pssms.iloc[0])

### KL divergence

In [None]:
#| export
def kl_divergence(p1,  # target pssm p (DataFrame/Series or np.ndarray, shape: (AA, positions))
                  p2,  # pred pssm q (same type family and same shape as p1)
                 ):
    """
    KL divergence D_KL(p1 || p2), computed independently for each column (position).

    p1 and p2 are either pandas objects (a DataFrame with amino acids as index and
    positions as columns, or a flattened Series) or plain numpy arrays of the same
    shape. When both are pandas objects they are first aligned on their labels
    (inner join) so that the same amino acid/position is compared; numpy arrays
    have no labels and are compared positionally.

    Entries where p1 + p2 is not positive (e.g. both zero) or is NaN are set to 0
    in both inputs. EPSILON is added inside the log ratio to avoid log(0) and
    division by zero.

    Returns a numpy array with one divergence per column for 2-D input, or a
    single value (the sum over all entries) for 1-D input. No averaging across
    positions is performed here.
    """
    assert p1.shape == p2.shape
    # Only pandas objects carry labels to align on; ndarrays are used as-is.
    if hasattr(p1, 'align') and hasattr(p2, 'align'):
        p1, p2 = p1.align(p2, join='inner', axis=None)
    # Mask invalid entries (both zero, or NaN since NaN > 0 is False)
    valid = (p1 + p2) > 0
    p1 = np.where(valid, p1, 0.0)
    p2 = np.where(valid, p2, 0.0)

    # KL divergence: sum_x p1(x) log(p1(x)/p2(x))
    kl = np.sum(p1 * np.log((p1 + EPSILON) / (p2 + EPSILON)), axis=0)

    return kl

The Kullback–Leibler (KL) divergence between two probability distributions $P$ and $Q$ is defined as:

$$
\mathrm{KL}(P \| Q) = \sum_{x \in \mathcal{X}} P(x) \log \left( \frac{P(x)}{Q(x)} \right)
$$

This measures the information lost when $Q$ is used to approximate $P$. It is **not symmetric**, i.e.,

$$
\mathrm{KL}(P \| Q) \ne \mathrm{KL}(Q \| P)
$$

and it is **non-negative**, meaning:

$$
\mathrm{KL}(P \| Q) \ge 0
$$

with equality if and only if $P = Q$ almost everywhere.

In practical computation, to avoid numerical instability when $P(x) = 0$ or $Q(x) = 0$, we add a small constant $\varepsilon$ (here `EPSILON = 1e-8`):

$$
\mathrm{KL}_\varepsilon(P \| Q) = \sum_{x \in \mathcal{X}} P(x) \log \left( \frac{P(x) + \varepsilon}{Q(x) + \varepsilon} \right)
$$

In [None]:
kl_divergence(pssm_df,pssm_df2)

array([0.29182172, 0.11138481, 0.24590698, 0.46021635, 0.36874823,
       0.53858511, 1.51571614, 0.02905442, 0.08530757, 0.07753394])

In [None]:
kl_divergence(pssm_df,pssm_df2).mean()

np.float64(0.37242752573216287)

In [None]:
kl_divergence(pssm_df,pssm_df2).max()

np.float64(1.5157161422110503)

In [None]:
#| export
def kl_divergence_flat(p1_flat, # pd.Series of target flattened pssm p
                       p2_flat, # pd.Series of pred flattened pssm q
                       ):
    """Mean per-position KL divergence D_KL(p1 || p2) between two flattened PSSMs.

    p1_flat and p2_flat are pd.Series whose index labels contain an integer
    position followed by an amino acid (presumably of the form '-5P').
    `kl_divergence` on 1-D input returns one value summed over all aligned
    entries; it is divided by the number of distinct positions parsed from
    p1_flat's labels and returned as a Python float.
    """
    # 1-D input: kl_divergence already sums over every entry, so no .mean() here
    summed_kl = kl_divergence(p1_flat, p2_flat)
    n_positions = p1_flat.index.str.extract(r'(-?\d+)').drop_duplicates().shape[0]
    return float(summed_kl / n_positions)

In [None]:
%%time
kl_divergence_flat(pssms.iloc[1],pssms.iloc[0])

CPU times: user 1.37 ms, sys: 0 ns, total: 1.37 ms
Wall time: 1.37 ms


0.37242752573216287

### JS divergence

In [None]:
#| export
def js_divergence(p1, # pssm (DataFrame/Series or np.ndarray)
                  p2, # pssm (same shape as p1)
                  index=True, # if True and p1 is a DataFrame, return a pd.Series indexed by position
                 ):
    """Jensen-Shannon divergence between two PSSMs, computed per column (position).

    p1 and p2 are pandas objects (a DataFrame with amino acids as index and
    positions as columns, or a flattened Series) or numpy arrays of the same
    shape. When both are pandas objects they are aligned on their labels (inner
    join) first; numpy arrays are compared positionally.

    Entries where p1 + p2 is not positive (e.g. both zero) or is NaN are set to 0.
    EPSILON guards the log ratios against log(0) and division by zero.

    Returns a pd.Series indexed by position when `index` is True and p1 is a
    DataFrame. Otherwise it returns the raw result: a numpy array for 2-D input,
    or a single summed value for 1-D input.
    """
    assert p1.shape==p2.shape
    # Only pandas objects carry labels to align on; ndarrays are used as-is.
    if hasattr(p1, 'align') and hasattr(p2, 'align'):
        p1, p2 = p1.align(p2, join='inner', axis=None)
    # A Series or ndarray has no column labels, so a positional Series only
    # makes sense for DataFrame input (previously this raised AttributeError).
    positions = p1.columns if index and isinstance(p1, pd.DataFrame) else None

    valid = (p1 + p2) > 0
    p1 = np.where(valid, p1, 0.0)
    p2 = np.where(valid, p2, 0.0)

    # mixture distribution M = (P + Q) / 2
    m = 0.5 * (p1 + p2)

    js = 0.5 * np.sum(p1 * np.log((p1+ EPSILON) / (m + EPSILON)), axis=0) + \
         0.5 * np.sum(p2 * np.log((p2+ EPSILON) / (m + EPSILON)), axis=0)
    return js if positions is None else pd.Series(js, index=positions)

The Jensen-Shannon divergence between two probability distributions $ P $ and $ Q $ is defined as:

$$
\mathrm{JS}(P \| Q) = \frac{1}{2} \, \mathrm{KL}(P \| M) + \frac{1}{2} \, \mathrm{KL}(Q \| M)
$$

where $ M = \frac{1}{2}(P + Q) $ is the average (mixture) distribution, and $ \mathrm{KL} $ denotes the Kullback–Leibler divergence:

$$
\mathrm{KL}(P \| Q) = \sum_{x \in \mathcal{X}} P(x) \log \left( \frac{P(x)}{Q(x)} \right)
$$

Therefore,

$$
\mathrm{JS}_\varepsilon(P \| Q) = \frac{1}{2} \sum_{x \in \mathcal{X}} P(x) \log \left( \frac{P(x) + \varepsilon}{M(x) + \varepsilon} \right)
+ \frac{1}{2} \sum_{x \in \mathcal{X}} Q(x) \log \left( \frac{Q(x) + \varepsilon}{M(x) + \varepsilon} \right)
$$

In [None]:
js_divergence(pssm_df,pssm_df2)

Position
-5    0.065539
-4    0.025712
        ...   
 3    0.020949
 4    0.018206
Length: 10, dtype: float64

In [None]:
js_divergence(pssm_df,pssm_df2).max()

np.float64(0.34404931056288773)

In [None]:
js_divergence(pssm_df,pssm_df2).mean()

np.float64(0.08286124552178498)

In [None]:
#| export
def js_divergence_flat(p1_flat, # pd.Series of flattened pssm
                       p2_flat, # pd.Series of flattened pssm
                       ):
    """Mean per-position JS divergence between two flattened PSSMs.

    p1_flat and p2_flat are pd.Series whose index labels contain an integer
    position followed by an amino acid (presumably of the form '-5P').
    The JS divergence summed over all aligned entries is divided by the number
    of distinct positions parsed from p1_flat's labels; a Python float is returned.
    """
    summed_js = js_divergence(p1_flat, p2_flat, index=False)  # 1-D input -> one summed value
    position_labels = p1_flat.index.str.extract(r'(-?\d+)')
    n_positions = len(position_labels.drop_duplicates())
    return float(summed_js / n_positions)

In [None]:
%%time
js_divergence_flat(pssms.iloc[1],pssms.iloc[0])

CPU times: user 1.26 ms, sys: 0 ns, total: 1.26 ms
Wall time: 1.26 ms


0.08286124552178498

### JS similarity

To convert the Jensen–Shannon divergence into a similarity measure, we first normalize it to bits by dividing by log(2), ensuring that the divergence lies within the range [0, 1]. 
$$
\mathrm{JS}_{\text{bits}}(P \| Q) = \frac{\mathrm{JS}(P \| Q)}{\log 2}
$$

The similarity is then defined as one minus this normalized divergence:
$$
\mathrm{Sim}_{\mathrm{JS}}(P, Q) = 1 - \mathrm{JS}_{\text{bits}}(P \| Q)
$$

Thus, $\mathrm{Sim}_{\mathrm{JS}}$ ranges from 0 (completely dissimilar) to 1 (identical distributions).

In [None]:
#| export
def js_similarity(pssm1,pssm2):
    """Per-position Jensen-Shannon similarity between two PSSMs.

    The JS divergence (natural log) is converted to bits by dividing by ln 2,
    which maps it into [0, 1]; the similarity is 1 minus that value, so 1 means
    identical columns.
    """
    jsd_bits = js_divergence(pssm1, pssm2) / np.log(2)
    return 1 - jsd_bits

In [None]:
js_similarity(pssm_df,pssm_df2)

Position
-5    0.905448
-4    0.962905
        ...   
 3    0.969777
 4    0.973734
Length: 10, dtype: float64

In [None]:
#| export
def js_similarity_flat(p1_flat,p2_flat):
    """Overall JS similarity between two flattened PSSMs.

    The mean per-position JS divergence from `js_divergence_flat` is converted
    to bits (divided by ln 2) and subtracted from 1.
    """
    mean_jsd_bits = js_divergence_flat(p1_flat, p2_flat) / np.log(2)
    return 1 - mean_jsd_bits

In [None]:
# #| export
# def weighted_js_similarity(pssm1,pssm2):
#     pssm = (pssm1+pssm2)/2
#     ic = get_IC(pssm)
    
#     similarity = js_similarity(pssm1,pssm2)
#     return (ic*similarity)/ic.sum()

### Cosine similarity

In [None]:
#| export
def cosine_similarity(pssm1: pd.DataFrame, pssm2: pd.DataFrame) -> pd.Series:
    """Cosine similarity per position (column) between two PSSMs.

    For every column label of pssm1, the matching column of pssm2 is taken and
    the two are aligned on their amino-acid index (inner join) before the dot
    product. A column whose vector has zero norm in either PSSM scores 0.0.
    Returns a pd.Series keyed by pssm1's column labels.
    """
    assert pssm1.shape == pssm2.shape, "PSSMs must have the same shape"

    def _column_cosine(col):
        # align so the same amino acid is multiplied with itself
        a, b = pssm1[col].align(pssm2[col], join='inner')
        len_a = np.linalg.norm(a)
        len_b = np.linalg.norm(b)
        if len_a == 0 or len_b == 0:
            return 0.0
        return sum(a*b) / (len_a * len_b)

    return pd.Series({col: _column_cosine(col) for col in pssm1.columns})

The cosine similarity between two vectors $P$ and $Q$ (e.g., two PSSM columns representing amino acid probability distributions) is defined as:

$$
\mathrm{cos}(P, Q) = \frac{P \cdot Q}{\|P\| \, \|Q\|}
$$

where $ P \cdot Q = \sum_{i=1}^{n} P_i Q_i $ is the dot product between $ P $ and $ Q $, and $ \|P\| = \sqrt{\sum_{i=1}^{n} P_i^2} $ is the Euclidean norm of $ P $.

Since all entries of $ P $ and $ Q $ are nonnegative probabilities (i.e., $ P_i, Q_i \in [0,1] $), the cosine similarity lies within the range:

$$
0 \leq \mathrm{cos}(P, Q) \leq 1
$$


Because PSSM entries are non-negative, the cosine similarity always lies in $[0, 1]$. `cosine_similarity` returns 0 for a position whose column is all zeros in either PSSM.

In [None]:
cosine_similarity(pssm_df,pssm_df2).sort_values()

 1    0.130818
-2    0.606234
        ...   
 4    0.934967
 2    0.971066
Length: 10, dtype: float64

In [None]:
cosine_similarity(pssm_df,pssm_df2).mean()

np.float64(0.754148470457778)

In [None]:
#| export
def cosine_overall_flat(pssm1_flat, pssm2_flat):
    """Cosine similarity between two flattened PSSMs treated as single vectors.

    The two pd.Series are aligned on their shared index labels (inner join) so
    that matching amino-acid/position entries are multiplied together.
    Returns 0.0 when either aligned vector has zero norm.
    """
    a, b = pssm1_flat.align(pssm2_flat, join='inner')
    len_a = np.linalg.norm(a)
    len_b = np.linalg.norm(b)
    has_zero_vector = len_a == 0 or len_b == 0
    return 0.0 if has_zero_vector else sum(a*b) / (len_a * len_b)

In [None]:
cosine_overall_flat(pssms.iloc[0],pssms.iloc[0])

np.float64(1.0000000000000004)

In [None]:
cosine_overall_flat(pssms.iloc[0],pssms.iloc[1])

np.float64(0.6614783212500965)

## End

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()

In [None]:
#| hide

# Sometimes we need to duplicate values at zero positions to ensure that the scoring function treat zero position S/T/Y same as s/t/y.
# def duplicate_zero(pssm_df):
#     """
#     For position 0, synchronize values across standard (S, T, Y) and their
#     phosphorylated variants (s, t, y or pS, pT, pY). Any non-zero value is
#     treated as meaningful and copied to all variants in the group.
#     """
#     pssm_df = pssm_df.copy()

#     variant_groups = [
#         ['S', 's', 'pS'],
#         ['T', 't', 'pT'],
#         ['Y', 'y', 'pY']
#     ]

#     for group in variant_groups:
#         # Filter to existing amino acids in index
#         present = [aa for aa in group if aa in pssm_df.index]
#         if not present:
#             continue

#         # Look for non-zero value in existing variants
#         shared_value = 0
#         for aa in present:
#             val = pssm_df.at[aa, 0]
#             if val != 0:
#                 shared_value = val
#                 break  # prefer the first non-zero one

#         # Assign the shared value to all existing variants
#         for aa in present:
#             pssm_df.at[aa, 0] = shared_value

#     return pssm_df

# pssm_df[0].to_dict() # original

# duplicate_zero(pssm_df)[0].to_dict() # after














# def get_freq(df_k: pd.DataFrame, # a dataframe for a single kinase that contains phosphorylation sequence splitted by their position
#              aa_order = [i for i in 'PGACSTVILMFYWHKRQNDEsty'], # amino acid to include in the full matrix 
#              aa_order_paper = [i for i in 'PGACSTVILMFYWHKRQNDEsty'], # amino acid to include in the partial matrix
#              position = [i for i in range(-7,8)], # position to include in the full matrix
#              position_paper = [-5,-4,-3,-2,-1,1,2,3,4] # position to include in the partial matrix
#              ):
    
#     "Get frequency matrix given a dataframe of phosphorylation sites for a single kinase"
    

#     #Count frequency for each amino acid at each position
#     melted_k = df_k.melt(
#                     value_vars=[i for i in range(-7, 8)],
#                     var_name='Position', 
#                     value_name='aa')
    
#     # Group by Position and Amino Acid and count occurrences
#     grouped = melted_k.groupby(['Position', 'aa']).size().reset_index(name='Count')
    

#     # Remove wired amino acid
#     aa_include = [i for i in 'PGACSTVILMFYWHKRQNDEsty']
#     grouped = grouped[grouped.aa.isin(aa_include)].reset_index(drop=True)
    
#     # get pivot table
#     pivot_k = grouped.pivot(index='aa', columns='Position', values='Count').fillna(0)
    
#     # Get frequency by dividing the sum of each column
#     freq_k = pivot_k/pivot_k.sum()

    
#     # data from the kinase-substrate dataset, and format is Lew's paper's format
#     paper = freq_k.reindex(index=aa_order_paper,columns=position_paper,fill_value=0)

#     # full pivot data from kinase-substrate dataset
#     full = freq_k.reindex(index=aa_order,columns=position, fill_value=0)

    
#     return paper,full

# # get frequency matrix
# paper_format, full = get_freq(data_k)
# paper_format.head()

# def get_unique_site(df:pd.DataFrame = None,# dataframe that contains phosphorylation sites
#                     seq_col: str='site_seq', # column name of site sequence
#                     id_col: str='gene_site' # column name of site id
#                    ):
#     "Remove duplicates among phosphorylation sites; return df with new columns of acceptor and number of duplicates"
    
#     unique = df.groupby(seq_col).agg(
#         {id_col: lambda r: '|'.join(r.unique())} )
#     unique['num_site'] = unique[id_col].str.split('|').apply(len) 
#     unique = unique.reset_index()
#     position = len(unique[seq_col][0])//2
#     unique['acceptor'] = unique[seq_col].str[position]
    
#     return unique

# As there are lots of duplicates of the phosphorylation site sequence in the dataset, it could be helpful to remove the duplicated sequences. 

# Implement `get_unique_site` to get unique phosphorylation sites. Need to inform columns of sequence and id.

# df = Data.get_ochoa_site()
# unique = get_unique_site(df,seq_col='site_seq',id_col='gene_site')
# unique.sort_values('num_site',ascending=False).head()

# #| export
# def scale_zero_position(pssm_df):
#     "Scale position 0 to the max sum of neigboring position for better visualization."
#     pssm_df = pssm_df.copy()
#     m = pssm_df.sum()[pssm_df.sum().index!=0].max()
#     pssm_df[0] = m*(pssm_df[0]/pssm_df[0].sum())
#     return pssm_df