# pssm.core

> Functions related to PSSMs

In [None]:
#| default_exp pssm.core

## Overview

**PSSM Calculation**

To calculate a position-specific probability matrix (PSSM) from phosphorylation site sequences:

```python
pssm_df = get_prob(
    data=df,       # input DataFrame, Series, or list of sequences
    col='site_seq',    # column name containing sequences (if DataFrame)
)
```
Alternatively, pass a Series or a list of sequences directly:
```python
pssm_df = get_prob(
    data=df['site_seq'],       # input DataFrame, Series, or list of sequences
)
```


To extract flattened PSSMs for each cluster in a dataset:

```python
cluster_pssms = get_cluster_pssms(
    df=data,                    # DataFrame containing sequences and cluster labels
    cluster_col='kinase_group', # column name for cluster identifiers
    seq_col='site_seq',         # column name for site sequences
    id_col='sub_site',          # column for deduplication within clusters (None to skip)
    count_thr=10,               # minimum sequences required per cluster
    valid_thr=None,             # minimum fraction of non-zero PSSM values (None to skip)
    plot=False,                 # whether to plot sequence logos for each cluster
)
```

---

**PSSM Transformation**

To flatten a 2D PSSM into a dictionary (for storage or model input):

```python
flat_dict = flatten_pssm(
    pssm_df=pssm_df,     # 2D PSSM with amino acids as rows, positions as columns
    column_wise=True,    # True for column-major order; False for row-wise (PyTorch)
)
```

To recover a 2D PSSM from a flattened Series:

```python
pssm_df = recover_pssm(
    flat_pssm=flat_pssm,  # flattened PSSM as pd.Series with position+aa index
)
```

To clean the center position and normalize each column to sum to 1:

```python
pssm_norm = clean_zero_normalize(
    pssm_df=pssm,  # 2D PSSM (zeros out non-s/t/y at position 0, then normalizes)
)
```

---

**Entropy Analysis**

To calculate Shannon entropy per position (lower entropy = more conserved):

```python
entropy = get_entropy(
    pssm_df=pssm_df,      # 2D PSSM matrix
    return_min=False,     # True to return single minimum value
    exclude_zero=True,    # True to exclude center position from calculation
    clean_zero=True,      # True to zero out non-s/t/y at position 0
)
```

To calculate entropy from a flattened PSSM:

```python
entropy = get_entropy_flat(
    flat_pssm=flat_pssm,  # flattened PSSM as pd.Series
    return_min=True,      # return minimum entropy across positions
    exclude_zero=True,    # exclude center position
)
```

---

**Information Content**

To calculate information content per position (higher IC = more conserved):

```python
ic = get_IC(
    pssm_df=pssm_df,    # 2D PSSM matrix
    exclude_zero=True,  # exclude center position from calculation
)
```

To calculate IC from a flattened PSSM:

```python
ic = get_IC_flat(
    flat_pssm=flat_pssm,  # flattened PSSM as pd.Series
    exclude_zero=True,    # exclude center position
)
```

---

**Overall Specificity**

To compute an overall specificity score combining max IC and IC variance:

```python
spec = get_specificity(
    pssm_df=pssm_df,  # 2D PSSM matrix (excludes position 0 automatically)
)
```

To compute specificity from a flattened PSSM:

```python
spec = get_specificity_flat(
    flat_pssm=flat_pssm,  # flattened PSSM as pd.Series
)
```

## Setup

In [None]:
#| export
import numpy as np, pandas as pd
from katlas.data import *
from katlas.utils import *
from fastcore.meta import delegates
from tqdm import tqdm
from tqdm.contrib.concurrent import process_map
from functools import partial

# for type
from typing import Sequence

In [None]:
#| hide
pd.set_option('display.max_rows', 5)
pd.set_option('display.max_columns', 100) # show all columns

In [None]:
#| hide
#| export
EPSILON = 1e-8  # added inside log2 in get_entropy so zero probabilities never hit log2(0)

## PSSM

We compute a **position-specific probability matrix (PSSM)** from a list of aligned site sequences.

For each position $i$ (e.g., from $-7$ to $+7$), the probability of observing amino acid $x$ is:

$$
P_i(x) = \frac{\text{count of amino acid } x \text{ at position } i}{\text{total counts at position } i}
$$

The following 23 residue symbols are included (any other character is ignored):

- Standard amino acids:  
  `A`, `C`, `D`, `E`, `F`, `G`, `H`, `I`, `K`, `L`, `M`, `N`, `P`, `Q`, `R`, `S`, `T`, `V`, `W`, `Y`  
- Modified amino acids:  
  `s`, `t`, `y` (often used to denote phosphorylated `S`, `T`, `Y`)


The resulting matrix has:
- **Rows**: Amino acids,
- **Columns**: Sequence positions (centered on the phosphosite),
- **Values**: Probabilities of each amino acid at each position.

In [None]:
#| export
def get_prob(data: pd.DataFrame | pd.Series | Sequence[str], # input data, list or df
             col: str='site_seq', # column name if input is df
             ):
    """
    Get the probability matrix of PSSM from phosphorylation site sequences.

    Sequences are aligned windows of odd length ``2k + 1``; the returned columns
    are positions ``-k .. +k``. Only the 23 symbols in ``PGACSTVILMFYWHKRQNDEsty``
    are counted (any other character is ignored), each position is normalised
    over those counts, and rows keep that order for residues observed at least
    once. Positions with no recognised residue are all-zero columns.

    Raises:
        ValueError: ``col`` is missing from a DataFrame, no sequences remain,
            or the sequence length is even (no single centre position).
        TypeError: a bare string/bytes is passed, or the input cannot be
            turned into a pandas Series.
    """
    aa_order = list('PGACSTVILMFYWHKRQNDEsty')

    if isinstance(data, pd.DataFrame):
        if col not in data.columns:
            raise ValueError(f"Column '{col}' not found in DataFrame.")
        site = data[col]
    else:
        if isinstance(data, (str, bytes)):
            raise TypeError("Input looks like a single sequence string; pass [seq] or a Series instead.")
        try:
            site = pd.Series(data, copy=False)
        except Exception as err:
            raise TypeError("Input must be a DataFrame, Series, or list of sequences.") from err

    site = check_seqs(site)
    if len(site) == 0:
        raise ValueError("No sequences provided.")

    site_array = np.array(site.apply(list).tolist())
    seq_len = site_array.shape[1]
    if seq_len % 2 == 0:
        # range(-k, k + 1) below yields seq_len + 1 labels for an even length,
        # which would otherwise fail later with an opaque pandas shape error.
        raise ValueError(f"Expected odd sequence length (centered window). Got even length: {seq_len}")

    half = seq_len // 2
    position = list(range(-half, half + 1))  # +1 because range excludes the stop value

    site_df = pd.DataFrame(site_array, columns=position)
    melted = site_df.melt(var_name='Position', value_name='aa')

    # Count each (position, residue) pair, keeping only recognised residues.
    grouped = melted.groupby(['Position', 'aa']).size().reset_index(name='Count')
    grouped = grouped[grouped.aa.isin(aa_order)].reset_index(drop=True)

    pivot_df = grouped.pivot(index='aa', columns='Position', values='Count').fillna(0)
    pssm_df = pivot_df / pivot_df.sum()

    # Fixed residue order; positions absent from the pivot are filled with 0.
    ordered_aa = [aa for aa in aa_order if aa in pssm_df.index]
    pssm_df = pssm_df.reindex(index=ordered_aa, columns=position, fill_value=0)

    return pssm_df

In [None]:
data = Data.get_ks_dataset()
data_k = data[data.kinase_uniprot=='P49841'] # CDK1

In [None]:
get_prob(data_k, col='site_seq').shape

(23, 41)

In [None]:
# or
pssm_df = get_prob(data_k['site_seq'].tolist())
pssm_df.tail()

Position,-20,-19,-18,-17,-16,-15,-14,-13,-12,-11,-10,-9,-8,-7,-6,-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20
aa,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1
D,0.059524,0.065476,0.044577,0.080238,0.051775,0.048744,0.047267,0.054572,0.064611,0.048387,0.052709,0.058565,0.049635,0.049563,0.071325,0.069666,0.058055,0.063768,0.037627,0.062229,0.0,0.034783,0.046579,0.056934,0.030702,0.084919,0.064516,0.056464,0.051128,0.072289,0.059091,0.054962,0.033846,0.063174,0.053042,0.054859,0.063091,0.05538,0.054054,0.05609,0.05314
E,0.061012,0.066964,0.077266,0.056464,0.075444,0.075332,0.090103,0.094395,0.066079,0.095308,0.086384,0.070278,0.068613,0.08309,0.093159,0.076923,0.076923,0.053623,0.060781,0.069465,0.0,0.023188,0.081514,0.075912,0.071637,0.080527,0.055718,0.063893,0.081203,0.078313,0.068182,0.087023,0.072308,0.064715,0.098284,0.073668,0.056782,0.072785,0.057234,0.057692,0.074074
s,0.037202,0.041667,0.035661,0.031204,0.038462,0.038405,0.044313,0.041298,0.070485,0.045455,0.04978,0.048316,0.10365,0.056851,0.05968,0.068215,0.223512,0.073913,0.082489,0.0767,0.677279,0.037681,0.081514,0.106569,0.274854,0.057101,0.076246,0.074294,0.129323,0.057229,0.054545,0.041221,0.083077,0.066256,0.045242,0.047022,0.069401,0.033228,0.063593,0.051282,0.045089
t,0.019345,0.010417,0.022288,0.019316,0.019231,0.016248,0.016248,0.011799,0.014684,0.020528,0.033675,0.036603,0.024818,0.017493,0.024745,0.039187,0.065312,0.030435,0.050651,0.026049,0.285094,0.021739,0.029112,0.037956,0.089181,0.033675,0.033724,0.040119,0.039098,0.031627,0.05,0.025954,0.030769,0.015408,0.020281,0.010972,0.014196,0.015823,0.015898,0.014423,0.014493
y,0.004464,0.008929,0.007429,0.008915,0.008876,0.01034,0.007386,0.00295,0.005874,0.016129,0.010249,0.014641,0.005839,0.01312,0.014556,0.017417,0.007257,0.007246,0.005789,0.01013,0.037627,0.02029,0.007278,0.017518,0.019006,0.004392,0.004399,0.008915,0.006015,0.00753,0.006061,0.012214,0.007692,0.004622,0.01248,0.004702,0.006309,0.006329,0.012719,0.00641,0.011272


### PSSM with weight score

In [None]:
#| export
def get_pssm_weight(
    data: pd.DataFrame,
    seq_col: str = "seq",
    score_col: str = "enrichment",          # selected/input ratio OR already-log2
    to_log2: bool = True,
    aa_order: str = "PGACSTVILMFYWHKRQNDEsty",
    center: str='pos',# per-position median centering or 'glob' global median centering
    count_thr: int=5, # threshold to filter out if <count_thr
    alpha: str|float = 'auto', # shrinkage strength: value * count / (count + alpha)
):
    """
    Position-specific amino acid enrichment matrix:
    PSSM(aa,pos) = mean( log2(score) ) over peptides with aa at pos.
    """
    # --- get sequences + scores ---
    if isinstance(data, pd.DataFrame):
        if seq_col not in data.columns:
            raise ValueError(f"Column '{seq_col}' not found.")
        if score_col not in data.columns:
            raise ValueError(f"Column '{score_col}' not found.")
        seq = data[seq_col].astype(str)
        score = pd.to_numeric(data[score_col], errors="coerce")
        d = pd.DataFrame({"seq": seq, "score": score})
    else:
        # if only pass sequences, this cannot compute enrichment-based PSSM
        raise TypeError("For enrichment-based PSSM, pass a DataFrame with seq_col and score_col.")

    # basic QC
    d = d.dropna(subset=["seq", "score"]).copy()
    d = d[d["seq"].str.len() == d["seq"].str.len().mode()[0]]  # keep modal length (usually 11)

    seq_len = int(d["seq"].str.len().iloc[0])
    center_idx = seq_len // 2

    # log2 transform
    if to_log2:
        # score is selected/input ratio; must be >0
        d = d[d["score"] > 0].copy()
        d["L"] = np.log2(d["score"].astype(float))
    else:
        d["L"] = d["score"].astype(float)

    # positions labels (-5..+5 for length 11)
    positions = list(range(-center_idx, center_idx + 1))

    # explode sequence into columns
    site_array = np.array(d["seq"].apply(list).tolist())
    site_df = pd.DataFrame(site_array, columns=positions)
    site_df["L"] = d["L"].to_numpy()

    # melt so each residue-position pair carries the peptide's log2 enrichment
    melted = site_df.melt(id_vars="L", var_name="Position", value_name="aa")

    # keep only those in AA order
    aa_order_list = [a for a in aa_order if a in melted["aa"].unique()]
    melted = melted[melted["aa"].isin(aa_order_list)]

    # aggregate: mean log2 enrichment for (Position, aa)
    stats = (
        melted.groupby(["aa", "Position"])["L"]
        .agg(["mean", "count"])
        .reset_index()
    )
    
    # build matrices
    pssm = stats.pivot(index="aa", columns="Position", values="mean")
    counts = stats.pivot(index="aa", columns="Position", values="count")

    # shrinkage
    if alpha is not None:
        if alpha=='auto':
            c = counts.to_numpy().ravel()
            c = c[c > 0]
            a = 0.5* np.median(c)
        elif isinstance(alpha,(int,float)):
            a = float(alpha)
        else:
            raise ValueError("alpha must be None, 'auto', or a numeric value")
        shrink_factor = counts / (counts + a)
        pssm = pssm * shrink_factor

    # mask low-support cells
    pssm[counts < count_thr] = np.nan

    # optional centering
    if center=='pos': pssm = pssm - pssm.median()
    elif center=='glob': pssm = pssm - np.nanmedian(pssm.to_numpy())

    return pssm.reindex(aa_order_list)

## Transform PSSM

In [None]:
#| export
def flatten_pssm(pssm_df,
                 column_wise=True, # if True, column major flatten; else row wise flatten (for pytorch training)
                ):
    """
    Flatten a PSSM DataFrame (residues as rows, positions as columns) into a dict.

    Keys are ``f"{position}{residue}"`` (e.g. ``'-20P'``). With ``column_wise``
    positions form the outer loop and residues the inner loop; otherwise
    residues are outer and positions inner. The input is not modified.
    """
    # unstack() moves the column labels to the outer level, so transposing
    # first swaps which axis ends up outermost in the flattened order.
    frame = pssm_df if column_wise else pssm_df.T
    long_form = frame.unstack().reset_index(name='value')
    outer_labels = long_form.iloc[:, 0]
    inner_labels = long_form.iloc[:, 1]

    if column_wise:
        keys = outer_labels.astype(str) + inner_labels              # position + residue
    else:
        keys = inner_labels.astype(str) + outer_labels.astype(str)  # position + residue

    long_form['position_residue'] = keys
    return long_form.set_index('position_residue')['value'].to_dict()

In [None]:
flat_pssm = pd.Series(flatten_pssm(pssm_df))
flat_pssm

-20P    0.069940
-20G    0.087798
          ...   
20t     0.014493
20y     0.011272
Length: 943, dtype: float64

In [None]:
flat_pssm.reset_index()[0]

0      0.069940
1      0.087798
         ...   
941    0.014493
942    0.011272
Name: 0, Length: 943, dtype: float64

In [None]:
#| export
def recover_pssm(flat_pssm: pd.Series):
    """
    Recover a 2D PSSM from a flattened PSSM Series.

    Keys look like ``f"{position}{residue}"`` (e.g. ``'-20P'``): everything but
    the last character is the integer position and the last character is the
    residue. Missing (residue, position) cells become 0. Rows follow
    ``PGACSTVILMFYWHKRQNDEsty`` (residues outside it are dropped); columns are
    sorted positions. A plain dict (as returned by ``flatten_pssm``) is
    accepted as well as a Series.
    """
    flat = flat_pssm if isinstance(flat_pssm, pd.Series) else pd.Series(flat_pssm)
    # Build the long table straight from the index rather than via
    # reset_index(), which fails when the Series name clashes with the
    # generated index column name.
    keys = flat.index.astype(str)
    long_df = pd.DataFrame({
        'Position': keys.str[:-1].astype(int),
        'aa': keys.str[-1],
        'value': flat.to_numpy(),
    })

    df = long_df.pivot(index='aa', columns='Position', values='value').fillna(0)
    aa_order = tuple('PGACSTVILMFYWHKRQNDEsty')
    order = [aa for aa in aa_order if aa in df.index]
    return df.reindex(index=order).sort_index(axis=1)

In [None]:
out = recover_pssm(flat_pssm)
out

Position,-20,-19,-18,-17,-16,-15,-14,-13,-12,-11,-10,-9,-8,-7,-6,-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20
aa,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1
P,0.069940,0.077381,0.062407,0.084695,0.082840,0.072378,0.087149,0.081121,0.077827,0.071848,0.101025,0.074671,0.084672,0.106414,0.112082,0.094340,0.079826,0.133333,0.150507,0.131693,0.000000,0.420290,0.141194,0.094891,0.049708,0.187408,0.082111,0.077266,0.064662,0.106928,0.074242,0.077863,0.078462,0.070878,0.095164,0.092476,0.074132,0.090190,0.071542,0.099359,0.078905
G,0.087798,0.087798,0.069837,0.065379,0.076923,0.054653,0.088626,0.095870,0.064611,0.068915,0.070278,0.087848,0.058394,0.065598,0.066958,0.063861,0.068215,0.101449,0.076700,0.105644,0.000000,0.075362,0.065502,0.086131,0.046784,0.054173,0.060117,0.083210,0.084211,0.058735,0.063636,0.083969,0.084615,0.075501,0.076443,0.081505,0.067823,0.060127,0.077901,0.088141,0.066023
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
t,0.019345,0.010417,0.022288,0.019316,0.019231,0.016248,0.016248,0.011799,0.014684,0.020528,0.033675,0.036603,0.024818,0.017493,0.024745,0.039187,0.065312,0.030435,0.050651,0.026049,0.285094,0.021739,0.029112,0.037956,0.089181,0.033675,0.033724,0.040119,0.039098,0.031627,0.050000,0.025954,0.030769,0.015408,0.020281,0.010972,0.014196,0.015823,0.015898,0.014423,0.014493
y,0.004464,0.008929,0.007429,0.008915,0.008876,0.010340,0.007386,0.002950,0.005874,0.016129,0.010249,0.014641,0.005839,0.013120,0.014556,0.017417,0.007257,0.007246,0.005789,0.010130,0.037627,0.020290,0.007278,0.017518,0.019006,0.004392,0.004399,0.008915,0.006015,0.007530,0.006061,0.012214,0.007692,0.004622,0.012480,0.004702,0.006309,0.006329,0.012719,0.006410,0.011272


In [None]:
out.equals(pssm_df)

True

Or recover from PSPA data

In [None]:
pspa = Data.get_pspa()

In [None]:
pssm=recover_pssm(pspa.loc['AAK1'].dropna())
pssm

Position,-5,-4,-3,-2,-1,0,1,2,3,4
aa,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
P,0.0720,0.0534,0.1084,0.0226,0.1136,0.0,0.0463,0.0527,0.0681,0.0628
G,0.0245,0.0642,0.0512,0.0283,0.0706,0.0,0.7216,0.0749,0.0923,0.0702
...,...,...,...,...,...,...,...,...,...,...
t,0.0201,0.0332,0.0303,0.0209,0.0121,1.0,0.0123,0.0409,0.0335,0.0251
y,0.0611,0.0339,0.0274,0.0486,0.0178,0.0,0.0100,0.0410,0.0359,0.0270


PSPA values are not normalized per position.

So we zero out the redundant entries at position 0 (keeping only `s`/`t`/`y`) and rescale each position to sum to 1.

In [None]:
pssm.index[pssm.index.isin(['s','t','y'])]

Index(['s', 't', 'y'], dtype='object', name='aa')

In [None]:
#| export
def _clean_zero(pssm_df):
    """
    Zero out every residue except s, t, y at position 0 (the centre column).

    Rows are matched by label, not by their place in the index. Works on a
    copy. If there is no column labelled 0, the copy is returned unchanged
    instead of a new column 0 being appended by ``.loc`` enlargement.
    """
    pssm_df = pssm_df.copy()
    if 0 not in pssm_df.columns:
        return pssm_df
    keep = {'s', 't', 'y'}
    # dict.fromkeys de-duplicates labels while keeping a deterministic order
    other_aa = list(dict.fromkeys(aa for aa in pssm_df.index if aa not in keep))
    pssm_df.loc[other_aa, 0] = 0
    return pssm_df

In [None]:
_clean_zero(pssm)

Position,-5,-4,-3,-2,-1,0,1,2,3,4
aa,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
P,0.0720,0.0534,0.1084,0.0226,0.1136,0.0,0.0463,0.0527,0.0681,0.0628
G,0.0245,0.0642,0.0512,0.0283,0.0706,0.0,0.7216,0.0749,0.0923,0.0702
...,...,...,...,...,...,...,...,...,...,...
t,0.0201,0.0332,0.0303,0.0209,0.0121,1.0,0.0123,0.0409,0.0335,0.0251
y,0.0611,0.0339,0.0274,0.0486,0.0178,0.0,0.0100,0.0410,0.0359,0.0270


In [None]:
#| export
def clean_zero_normalize(pssm_df):
    """
    Keep only s, t, y at position 0, then rescale every position to sum to 1.

    Works on a copy. Column labels are cast to int first so the centre can be
    addressed as position 0. A position whose values are all zero after
    cleaning divides 0 by 0 and comes out as NaN.
    """
    cleaned = pssm_df.copy()
    cleaned.columns = cleaned.columns.astype(int)
    cleaned = _clean_zero(cleaned)
    return cleaned.div(cleaned.sum(), axis='columns')

This function applies phosphosite-specific cleaning and normalization to a PSSM.

At the center position ($i = 0$), only the rows for the phosphorylatable residues `s`, `t`, and `y` are retained (selected by label, not by row order). All other amino acid values at position 0 are set to 0.

After masking, the matrix is column-normalized to ensure the probabilities at each position sum to 1:

$$
\hat{P}_i(x) = \frac{P_i(x)}{\sum_{x'} P_i(x')}
$$


In [None]:
clean_zero_normalize(pssm)

Position,-5,-4,-3,-2,-1,0,1,2,3,4
aa,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
P,0.058446,0.041715,0.086100,0.017935,0.096068,0.000000,0.042649,0.040482,0.052640,0.050260
G,0.019888,0.050152,0.040667,0.022459,0.059704,0.000000,0.664702,0.057536,0.071346,0.056182
...,...,...,...,...,...,...,...,...,...,...
t,0.016316,0.025935,0.024067,0.016586,0.010233,0.908018,0.011330,0.031418,0.025895,0.020088
y,0.049598,0.026482,0.021763,0.038568,0.015053,0.000000,0.009211,0.031495,0.027750,0.021609


## PSSMs of clusters

In [None]:
#| export
def get_cluster_pssms(df,
                    cluster_col,
                    seq_col='site_seq',
                    id_col = 'sub_site',
                    count_thr=10, # if less than the count threshold, not include in the return
                    valid_thr=None, # fraction of non-zero values in pssm; cluster kept only if strictly above
                    plot=False):
    """
    Extract motifs from clusters in a dataframe.

    For each cluster (largest first) with at least ``count_thr`` rows, build a
    probability PSSM from ``seq_col`` with ``get_prob`` and flatten it with
    ``flatten_pssm``. When ``id_col`` is given, rows repeating the same
    (cluster, id) pair are counted once. Returns a DataFrame with one row per
    kept cluster (indexed by cluster id) and one column per position+residue key.
    """
    # One row per (cluster, id) so repeated entries don't inflate cluster sizes.
    if id_col is not None:
        df = df.drop_duplicates(subset=[cluster_col, id_col]).copy()

    cluster_sizes = df[cluster_col].value_counts()
    flat_rows, kept_ids = [], []

    for cluster_id, size in tqdm(cluster_sizes.items(), total=len(cluster_sizes)):
        if count_thr is not None and size < count_thr:
            continue

        members = df[df[cluster_col] == cluster_id]
        n = len(members)
        pssm = get_prob(members, seq_col)

        # Sparsity guard: share of PSSM cells that are non-zero.
        nonzero_frac = pssm.ne(0).sum().sum() / pssm.size
        if valid_thr is not None and nonzero_frac <= valid_thr:
            continue

        flat_rows.append(flatten_pssm(pssm))
        kept_ids.append(cluster_id)

        if plot:
            # NOTE(review): plot_logo and plt are not imported explicitly in this
            # module; presumably they arrive via the star imports — confirm before
            # relying on plot=True.
            plot_logo(pssm, title=f'Cluster {cluster_id} (n={n})', figsize=(14, 1))
            plt.show()
            plt.close()

    return pd.DataFrame(flat_rows, index=kept_ids)

In [None]:
get_cluster_pssms(data,'kinase_group')


  0%|          | 0/10 [00:00<?, ?it/s]


 10%|█         | 1/10 [00:00<00:01,  5.11it/s]


 20%|██        | 2/10 [00:00<00:01,  5.45it/s]


 30%|███       | 3/10 [00:00<00:01,  5.62it/s]


 40%|████      | 4/10 [00:00<00:00,  6.10it/s]


 50%|█████     | 5/10 [00:00<00:00,  7.04it/s]


 70%|███████   | 7/10 [00:00<00:00,  8.60it/s]


 90%|█████████ | 9/10 [00:01<00:00, 10.54it/s]


100%|██████████| 10/10 [00:01<00:00,  8.72it/s]




Unnamed: 0,-20P,-20G,-20A,-20C,-20S,-20T,-20V,-20I,-20L,-20M,-20F,-20Y,-20W,-20H,-20K,-20R,-20Q,-20N,-20D,-20E,-20s,-20t,-20y,-19P,-19G,-19A,-19C,-19S,-19T,-19V,-19I,-19L,-19M,-19F,-19Y,-19W,-19H,-19K,-19R,-19Q,-19N,-19D,-19E,-19s,-19t,-19y,-18P,-18G,-18A,-18C,...,18E,18s,18t,18y,19P,19G,19A,19C,19S,19T,19V,19I,19L,19M,19F,19Y,19W,19H,19K,19R,19Q,19N,19D,19E,19s,19t,19y,20P,20G,20A,20C,20S,20T,20V,20I,20L,20M,20F,20Y,20W,20H,20K,20R,20Q,20N,20D,20E,20s,20t,20y
TK,0.056754,0.069067,0.070474,0.016886,0.045028,0.039400,0.058630,0.049836,0.083255,0.024038,0.032716,0.023335,0.007974,0.021107,0.069184,0.056168,0.045966,0.036585,0.058865,0.081027,0.024508,0.011843,0.017355,0.057915,0.067158,0.073125,0.016731,0.045396,0.037440,0.053001,0.049959,0.087633,0.023751,0.031356,0.018018,0.009594,0.022230,0.074529,0.063999,0.043173,0.041652,0.060840,0.073359,0.025389,0.010881,0.012870,0.058123,0.067810,0.062792,0.014006,...,0.080097,0.027670,0.012621,0.012743,0.058601,0.065424,0.065302,0.014864,0.045078,0.038986,0.056408,0.048124,0.087354,0.018884,0.042276,0.017178,0.008894,0.018519,0.073830,0.061769,0.047149,0.040327,0.059698,0.078216,0.028509,0.010599,0.014011,0.066015,0.072738,0.065037,0.013692,0.048900,0.037897,0.054523,0.050122,0.077262,0.023227,0.039364,0.018337,0.008435,0.019071,0.070905,0.059413,0.040465,0.038509,0.056357,0.083496,0.025917,0.012836,0.017482
CMGC,0.080589,0.070340,0.083792,0.013709,0.050865,0.035874,0.053812,0.035874,0.074824,0.022037,0.029468,0.015247,0.007559,0.020884,0.069058,0.057143,0.044331,0.032031,0.053299,0.078668,0.042921,0.020115,0.007559,0.079718,0.075496,0.072937,0.012028,0.058733,0.035061,0.053871,0.032885,0.074472,0.021753,0.026104,0.013436,0.007806,0.021241,0.064491,0.063596,0.046449,0.037492,0.055534,0.080614,0.042482,0.018298,0.005502,0.074292,0.068930,0.074419,0.013786,...,0.084768,0.052185,0.016556,0.005960,0.074212,0.073813,0.074212,0.011970,0.053731,0.034047,0.050805,0.031520,0.074345,0.017423,0.028195,0.015029,0.010374,0.020215,0.070887,0.063173,0.045884,0.033914,0.056124,0.086980,0.048278,0.018619,0.006251,0.082465,0.066961,0.074980,0.012563,0.050789,0.034215,0.050521,0.038626,0.079925,0.019246,0.028869,0.016840,0.006816,0.023924,0.072307,0.060011,0.046111,0.035285,0.057070,0.074044,0.043972,0.018043,0.006415
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
CK1,0.057860,0.076419,0.076965,0.014192,0.039847,0.037118,0.063865,0.052402,0.070961,0.017467,0.027293,0.012555,0.014192,0.024017,0.079694,0.050764,0.033297,0.035480,0.058406,0.086790,0.038210,0.015830,0.016376,0.061069,0.070883,0.069793,0.007634,0.033806,0.032715,0.053435,0.043075,0.087786,0.023991,0.022901,0.016903,0.010905,0.020720,0.076881,0.059978,0.056161,0.037077,0.061069,0.087786,0.041439,0.017448,0.006543,0.050000,0.067935,0.073913,0.010870,...,0.084629,0.041451,0.015544,0.012090,0.053148,0.068746,0.069324,0.012709,0.036973,0.021953,0.047371,0.042172,0.082034,0.017909,0.030040,0.017909,0.020220,0.019064,0.095321,0.047371,0.041594,0.034084,0.065280,0.098209,0.041594,0.021953,0.015020,0.051865,0.074592,0.088578,0.015152,0.039627,0.034382,0.055361,0.039627,0.074592,0.020979,0.034382,0.020396,0.008741,0.026224,0.079837,0.046620,0.038462,0.043124,0.067599,0.087995,0.027972,0.012238,0.011655
Atypical,0.071895,0.065359,0.063492,0.014939,0.053221,0.045752,0.058824,0.050420,0.085901,0.020542,0.017740,0.020542,0.005602,0.034547,0.049486,0.057890,0.040149,0.034547,0.059757,0.078431,0.041083,0.025210,0.004669,0.042870,0.052190,0.074557,0.013979,0.063374,0.036347,0.054986,0.031687,0.089469,0.027959,0.036347,0.015843,0.010252,0.026095,0.074557,0.056850,0.054054,0.034483,0.063374,0.068966,0.046598,0.018639,0.006524,0.051068,0.063138,0.072423,0.012071,...,0.072175,0.047483,0.013295,0.015195,0.059829,0.081671,0.066477,0.017094,0.047483,0.044634,0.050332,0.034188,0.063628,0.010446,0.037037,0.016144,0.006648,0.027540,0.047483,0.055081,0.057930,0.043685,0.059829,0.091168,0.039886,0.027540,0.014245,0.077290,0.060115,0.062977,0.015267,0.055344,0.029580,0.065840,0.026718,0.092557,0.017176,0.039122,0.023855,0.009542,0.016221,0.046756,0.054389,0.050573,0.055344,0.062977,0.068702,0.040076,0.020992,0.008588


## Entropy

In [None]:
#| export
def get_entropy(pssm_df,# a dataframe of pssm with index as aa and column as position
            return_min=False, # return min entropy as a single value or return all entropy as a pd.series
            exclude_zero=False, # exclude the column of 0 (center position) in the entropy calculation
            clean_zero=True, # if true, zero out non-s/t/y values in position 0 (keep only s,t,y values at center)
            ):
    """
    Calculate Shannon entropy (bits) per position of a PSSM surrounding 0.

    The less entropy, the more information a position contains. Each column is
    renormalised to sum to 1 before ``H = -sum p * log2(p + EPSILON)`` is taken;
    all-zero columns are reported with entropy 0.

    Returns a Series indexed by integer position, or its minimum as a float
    when ``return_min`` is True.
    """
    pssm_df = pssm_df.copy()
    # Positions may be stored as strings; cast so the centre is column 0.
    pssm_df.columns = pssm_df.columns.astype(int)
    if 0 in pssm_df.columns:
        if clean_zero:
            pssm_df = _clean_zero(pssm_df)
        if exclude_zero:
            pssm_df = pssm_df.drop(columns=0)

    pssm_df = pssm_df / pssm_df.sum()
    per_position = -np.sum(pssm_df * np.log2(pssm_df + EPSILON), axis=0)
    # All-zero columns become NaN after 0/0 normalisation; their (NaN-skipping)
    # sum is 0, so report them as zero entropy.
    per_position[pssm_df.sum() == 0] = 0
    return float(per_position.min()) if return_min else per_position

Let $P_i(x)$ be the probability of amino acid $x$ at position $i$ in the PSSM, with $i \in \{-k, \dots, -1, 0, +1, \dots, +k\}$. The entropy at each position $i$ is defined as:

$$
H_i = - \sum_{x} P_i(x) \log_2 \left( P_i(x) + \varepsilon \right)
$$

where $\varepsilon = 10^{-8}$ is a small constant added for numerical stability.

If `exclude_zero=True`, the central position $i = 0$ is omitted from the entropy calculation.

If `clean_zero=True`, all values at position $i = 0$ are zeroed out except for the phosphorylated residues `s`, `t`, and `y` (phospho-serine, phospho-threonine, and phospho-tyrosine), typically the only possible phospho-acceptors in kinase motif analysis.

If `return_min=True`, the function returns the minimum entropy across all positions:

$$
H_{\text{spec}} = \min_i H_i
$$

Otherwise, the function returns the full vector $\{H_i\}$ for each position $i$, reflecting how much information (or uncertainty) is contained at each position in the motif.


In [None]:
# get entropy per position
get_entropy(pssm_df).sort_values()

Position
 0     1.074964
 1     3.346861
         ...   
-9     4.297737
-12    4.302189
Length: 41, dtype: float64

In [None]:
# calculate minimum entropy of surrouding positions
get_entropy(pssm_df,return_min=True,exclude_zero=True)

3.3468606104695913

In [None]:
#| export
@delegates(get_entropy)
def get_entropy_flat(flat_pssm:pd.Series,**kwargs):
    "Entropy per position of a flattened PSSM; keyword options are passed to `get_entropy`."
    # Rebuild the residue x position matrix, then reuse the 2D implementation.
    return get_entropy(recover_pssm(flat_pssm), **kwargs)

In [None]:
get_entropy_flat(flat_pssm).sort_values()

Position
 0     1.074964
 1     3.346861
         ...   
-9     4.297737
-12    4.302189
Length: 41, dtype: float64

In [None]:
get_entropy_flat(flat_pssm,return_min=True,exclude_zero=True)

3.3468606104695913

In [None]:
# test equal
(get_entropy_flat(flat_pssm).round(5) == get_entropy(pssm_df).round(5)).value_counts()

True    41
Name: count, dtype: int64

## Information Content

In [None]:
#| export
@delegates(get_entropy)
def get_IC(pssm_df,**kwargs):
    """
    Calculate the information content (bits) per position of a frequency matrix.

    ``IC_i = maxH_i - H_i`` where ``maxH_i`` is log2(3) for the middle position
    (unless it is excluded) and log2(len(pssm_df)) for all others. The higher,
    the more information a position contains. Keyword arguments are forwarded
    to `get_entropy`; positions whose entropy is exactly 0 (all-zero columns
    are reported that way by `get_entropy`) get IC 0.

    Returns a Series indexed by position. Raises ValueError if ``return_min``
    is requested, because IC is always reported per position.
    """
    if kwargs.get('return_min', False):
        raise ValueError("get_IC returns per-position values; return_min is not supported.")

    entropy_position = get_entropy(pssm_df, **kwargs)
    max_entropy = pd.Series(np.log2(len(pssm_df)), index=entropy_position.index)

    # Centre position: only s/t/y are allowed there, so its maximum is log2(3).
    # Use .loc with a membership check so a PSSM without a 0 column is not
    # silently enlarged with a spurious NaN position 0.
    exclude_zero = kwargs.get('exclude_zero', False)
    if not exclude_zero and 0 in max_entropy.index:
        max_entropy.loc[0] = np.log2(3)

    # information_content = max_entropy - entropy --> log2(N) - entropy
    IC_position = max_entropy - entropy_position

    # if entropy is zero, set to zero as there's no value
    IC_position[entropy_position == 0] = 0
    return IC_position

Let $P_i(x)$ be the frequency (probability) of amino acid $x$ at position $i$ in the PSSM. The standard information content (IC) at position $i$ is defined as:

$$
\mathrm{IC}_i = \max H_i - H_i
$$

which is:

$$
\mathrm{IC}_i = \log_2(N) - H_i
$$

where $N$ is the number of amino-acid rows in the PSSM (i.e., the number of residues over which $P_i$ is defined).

At the center position ($i = 0$), only three residues (`s`, `t`, `y`) are possible, so the maximum entropy at each position is defined as:

$$
\max H_i =
\begin{cases}
\log_2(3) & \text{if } i = 0 \\
\log_2(N) & \text{otherwise}
\end{cases}
$$


In [None]:
# the higher the more conserved
get_IC(pssm_df,exclude_zero=True).sort_values()

Position
-12    0.221373
-9     0.225825
         ...   
 4     0.717760
 1     1.176701
Length: 40, dtype: float64

Check the case of an all-zero column (its entropy is reported as 0):

In [None]:
pssm_df2=pssm_df.copy()

In [None]:
pssm_df2[-20]=0

In [None]:
get_entropy(pssm_df2,exclude_zero=True).sort_values()

Position
-20    0.000000
 1     3.346861
         ...   
-9     4.297737
-12    4.302189
Length: 40, dtype: float64

In [None]:
#| export
@delegates(get_IC)
def get_IC_flat(flat_pssm:pd.Series,**kwargs):
    """Information content (bits) per position of a flattened PSSM Series.

    The Series is rebuilt into a 2D PSSM with `recover_pssm`, and every keyword
    argument goes straight to `get_IC` (log2(3) maximum for the middle
    position, log2(number of residue rows) elsewhere)."""
    return get_IC(recover_pssm(flat_pssm), **kwargs)

In [None]:
get_IC_flat(flat_pssm,exclude_zero=True).sort_values()

Position
-12    0.221373
-9     0.225825
         ...   
 4     0.717760
 1     1.176701
Length: 40, dtype: float64

In [None]:
(get_IC_flat(flat_pssm).round(5) == get_IC(pssm_df).round(5)).value_counts()

True    41
Name: count, dtype: int64

## Overall specificity

In [None]:
#| export
def get_specificity(pssm_df):
    """
    Get specificity score of a pssm, excluding zero position.

    Score = 2 * max(IC) + Var(IC) over the surrounding positions with positive
    IC. Var is pandas' sample variance (ddof=1); with a single informative
    position the variance term is taken as 0 instead of NaN. Returns NaN when
    no position has positive IC.
    """
    ICs = get_IC(pssm_df, exclude_zero=True)
    # only consider IC with values
    ICs = ICs[ICs > 0]
    spread = ICs.var() if len(ICs) > 1 else 0.0
    return float(2 * ICs.max() + spread)

We evaluate the overall specificity of a PSSM by combining two metrics over the surrounding positions with positive IC: the maximum IC and the variance of the IC values:

$$
\text{Specificity Score} = 2 \times \max(\text{IC}) + \mathrm{Var}(\text{IC})
$$

In [None]:
get_specificity(pssm_df)

2.381609408364424

In [None]:
#| export
def get_specificity_flat(flat_pssm):
    """Get specificity score of a flattened pssm, excluding zero position.

    Rebuilds the 2D PSSM with `recover_pssm` and delegates to
    `get_specificity`, so both entry points apply the same positive-IC filter
    and scoring formula."""
    return get_specificity(recover_pssm(flat_pssm))

In [None]:
get_specificity_flat(flat_pssm)

2.381609408364424

## Export -

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()