# PGB1 logistic regression

The aim is to build a simple logistic regression model that predicts the probaility of presence or absence of the PG_binding_1 domain in a genome given the distribution of presence or absence of all other Pfam & TIGR domains. Phylogeny is also added (taxonomic rank: order) to manage confounding variables.

In [156]:
import os
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib as plt
import seaborn as sns
import statsmodels.api as sm
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from sklearn.utils.class_weight import compute_class_weight

cwd = os.getcwd()
if cwd.endswith('notebook'):
    os.chdir('..')
    cwd = os.getcwd()

In [157]:
sns.set_theme(palette='colorblind', font_scale=1.3)
palette = sns.color_palette().as_hex()

data_folder = Path('./data/')
assert data_folder.is_dir()

db_proka = Path('../db_proka/')
assert db_proka.is_dir()

## Load dataset

In [158]:
target_domain = 'PG_binding_1'

### Pfam

In [159]:
pfam_df = pd.read_csv(data_folder / 'pfam_bacteria.csv', index_col='assembly_accession')
assert target_domain in set(pfam_df.columns)

print(f'Number of genomes: {len(pfam_df):,}')
print(f'Number of domains: {len(pfam_df.columns):,}')

Number of genomes: 32,507
Number of domains: 3,981


In [160]:
p = 100 * len(pfam_df[pfam_df[target_domain] == 1]) / len(pfam_df)

print(f'{target_domain} is present in {p:.0f}% of all genomes.')

PG_binding_1 is present in 64% of all genomes.


### TIGR

In [161]:
TIGR_df = pd.read_csv(data_folder / 'TIGR_bacteria.csv', index_col='assembly_accession')

### Phylogeny

In [162]:
phylogeny_df = pd.read_csv(data_folder / 'taxonomy_order_bacteria.csv', index_col='assembly_accession')

### Merge datasets

In [163]:
data_df = pd.merge(
    pfam_df,
    TIGR_df,
    left_index=True,
    right_index=True,
    how='inner',
)
data_df = pd.merge(
    data_df,
    phylogeny_df,
    left_index=True,
    right_index=True,
    how='inner',
)

In [164]:
print(f'Number of genomes: {len(data_df):,}')
print(f'Number of columns: {len(data_df.columns):,}')

Number of genomes: 32,507
Number of columns: 6,704


## PGB1 exploration

In [165]:
gtdb_metadata = pd.read_csv(db_proka / 'gtdb_metadata.csv', index_col='ncbi_accession')
gtdb_metadata['domain'] = gtdb_metadata['gtdb_taxonomy'].apply(lambda t: t.split(';')[0].replace('d__', ''))
gtdb_metadata['gtdb_phylum'] = gtdb_metadata['gtdb_taxonomy'].apply(lambda t: t.split(';')[1].replace('p__', ''))
gtdb_metadata['gtdb_class'] = gtdb_metadata['gtdb_taxonomy'].apply(lambda t: t.split(';')[2].replace('c__', ''))
gtdb_metadata['gtdb_order'] = gtdb_metadata['gtdb_taxonomy'].apply(lambda t: t.split(';')[3].replace('o__', ''))
gtdb_metadata['gtdb_family'] = gtdb_metadata['gtdb_taxonomy'].apply(lambda t: t.split(';')[4].replace('f__', ''))
gtdb_metadata['gtdb_genus'] = gtdb_metadata['gtdb_taxonomy'].apply(lambda t: t.split(';')[5].replace('g__', ''))
gtdb_metadata['gtdb_species'] = gtdb_metadata['gtdb_taxonomy'].apply(lambda t: t.split(';')[6].replace('s__', ''))

bacterial_genomes = gtdb_metadata[gtdb_metadata['domain'] == 'Bacteria']

In [166]:
pgb1_metadata = pd.merge(
    pfam_df[[target_domain]],
    bacterial_genomes[['gtdb_phylum', 'gtdb_class', 'gtdb_order', 'gtdb_family', 'gtdb_genus', 'gtdb_species']],
    left_index=True,
    right_index=True,
)
pgb1_metadata.head()

Unnamed: 0,PG_binding_1,gtdb_phylum,gtdb_class,gtdb_order,gtdb_family,gtdb_genus,gtdb_species
GCA_000007325.1,1,Fusobacteriota,Fusobacteriia,Fusobacteriales,Fusobacteriaceae,Fusobacterium,Fusobacterium nucleatum
GCA_000008885.1,0,Pseudomonadota,Gammaproteobacteria,Enterobacterales,Enterobacteriaceae,Wigglesworthia,Wigglesworthia glossinidia_A
GCA_000009845.1,0,Bacillota,Bacilli,Acholeplasmatales,Acholeplasmataceae,Phytoplasma,Phytoplasma sp000009845
GCA_000010565.1,1,Bacillota_B,Desulfotomaculia,Desulfotomaculales,Pelotomaculaceae,Pelotomaculum,Pelotomaculum thermopropionicum
GCA_000011445.1,0,Bacillota,Bacilli,Mycoplasmatales,Mycoplasmataceae,Mycoplasma,Mycoplasma mycoides


In [167]:
taxonomy_to_subset = {}
for taxonomy in ['gtdb_phylum', 'gtdb_class', 'gtdb_order', 'gtdb_family', 'gtdb_genus']:
    g = pgb1_metadata[[taxonomy, target_domain]].groupby(taxonomy).agg(['sum', 'count'])

    subset_to_consider = sorted(set(g[
        #(g[(target_domain, 'sum')] > 0) &
        (g[(target_domain, 'count')] > 1)
    ].index))

    taxonomy_to_subset[taxonomy] = subset_to_consider

    p = 100 * len(subset_to_consider) / len(g.index)
    print(f'{taxonomy}: {len(subset_to_consider):,} out of {len(g.index):,} ({p:.0f} %)')

gtdb_phylum: 126 out of 155 (81 %)
gtdb_class: 300 out of 385 (78 %)
gtdb_order: 840 out of 1,172 (72 %)
gtdb_family: 1,850 out of 2,882 (64 %)
gtdb_genus: 5,860 out of 12,553 (47 %)


In [168]:
pgb1_subset = pgb1_metadata[
    pgb1_metadata['gtdb_phylum'].isin(taxonomy_to_subset['gtdb_phylum']) &
    pgb1_metadata['gtdb_class'].isin(taxonomy_to_subset['gtdb_class']) &
    pgb1_metadata['gtdb_order'].isin(taxonomy_to_subset['gtdb_order']) &
    pgb1_metadata['gtdb_family'].isin(taxonomy_to_subset['gtdb_family']) &
    pgb1_metadata['gtdb_genus'].isin(taxonomy_to_subset['gtdb_genus'])
]

p = 100 * len(pgb1_subset) / len(pgb1_metadata)
print(f'Relevant genomes: {len(pgb1_subset):,} out of {len(pgb1_metadata):,} ({p:.0f} %)')

Relevant genomes: 25,814 out of 32,507 (79 %)


In [169]:
# data_df = data_df.loc[pgb1_subset.index]

In [170]:
print(f'Number of genomes: {len(data_df):,}')
print(f'Number of columns: {len(data_df.columns):,}')

Number of genomes: 32,507
Number of columns: 6,704


In [171]:
n_genomes_with_domain = data_df[target_domain].sum()
p = 100 * n_genomes_with_domain / len(data_df)
print(f'Number of genomes with {target_domain}: {n_genomes_with_domain:,} ({p:.0f} %)')

Number of genomes with PG_binding_1: 20,955 (64 %)


## Prepare variables

In [172]:
response = data_df[target_domain]
predictors = data_df.loc[:, data_df.columns != target_domain]

# Add constant term for intercept
predictors = sm.add_constant(predictors)

# Split the data into training and validation sets
X_train, X_val, y_train, y_val = train_test_split(predictors, response, test_size=0.1, random_state=123)

# Calculate class weights
class_weights = compute_class_weight(class_weight='balanced', classes=np.array([0, 1]), y=y_train)
weights = np.array([class_weights[0] if x == 0 else class_weights[1] for x in y_train])
print(class_weights)

[1.4014179  0.77734084]


## Train logistic regression

In [173]:
reg_param = 1e-3

# Fitting the logistic regression model
model = sm.GLM(
    y_train, 
    X_train, 
    family=sm.families.Binomial(), 
    freq_weights=weights,
).fit_regularized(
    maxiter=100, 
    alpha=reg_param,
    L1_wt=0.5,
)



## Evaluate model

In [174]:
# Predict on the validation set
y_pred = model.predict(X_val)
y_pred_binary = [1 if x >= 0.5 else 0 for x in y_pred]

# Calculate validation metrics
accuracy = accuracy_score(y_val, y_pred_binary)
conf_matrix = confusion_matrix(y_val, y_pred_binary)
class_report = classification_report(y_val, y_pred_binary)

print("\nValidation set performance:")
print(f"Accuracy: {accuracy}")
print("Confusion Matrix:")
print(conf_matrix)
print("Classification Report:")
print(class_report)


Validation set performance:
Accuracy: 0.8741925561365733
Confusion Matrix:
[[ 993  121]
 [ 288 1849]]
Classification Report:
              precision    recall  f1-score   support

           0       0.78      0.89      0.83      1114
           1       0.94      0.87      0.90      2137

    accuracy                           0.87      3251
   macro avg       0.86      0.88      0.86      3251
weighted avg       0.88      0.87      0.88      3251



## Identify top domains

In [175]:
# Extract p-values and coefficients of the coefficients from the training model
p_values = model.pvalues
coefficients = model.params

# Filter significant p-values
significant_vars = p_values[p_values < 1e-2].sort_values()

# Removing the intercept from the significant variables
try:
    significant_vars = significant_vars.drop('const')
except:
    pass

# Extracting the values of the significant parameters
significant_coefficients = coefficients[significant_vars.index]
sorted_significant_coeff = sorted(significant_coefficients.items(), reverse=True, key=lambda t: t[1])

# Display the ranked significant regressors and their coefficients
print("\nRanked significant regressors from the training set and their coefficients:")
for i, (var, coef) in enumerate(sorted_significant_coeff):
    print(f"{var}: {coef:.3f}")

AttributeError: 'RegularizedResults' object has no attribute 'pvalues'

In [182]:
coefficients = model.params
sorted_coeff = sorted(coefficients.items(), reverse=True, key=lambda t: t[1])
for i, (var, coef) in enumerate(sorted_coeff):
    if coef > 0 or coef < 0:
        print(f"{var}: {coef:.3f}")

Scaffold: 1.906
PG_binding_2: 1.567
TIGR02283: 1.151
SpoIIIAC: 0.557
Spore_III_AB: 0.541
TIGR01925: 0.424
TIGR01424: 0.424
Com_YlbF: 0.406
TIGR02848: 0.343
Peptidase_M15_2: 0.335
TIGR00247: 0.335
Peptidase_S11: 0.327
TIGR01134: 0.289
SpoIIE_N: 0.276
YkuD: 0.273
TIGR01646: 0.269
TIGR00690: 0.267
DUF1540: 0.266
Peptidase_S41: 0.265
RimP_N: 0.264
Peptidase_M14: 0.249
HTH_40: 0.242
TIGR01740: 0.239
TIGR00225: 0.235
TIGR01449: 0.229
ExoD: 0.228
TIGR00858: 0.223
DUF378: 0.220
FxsA: 0.218
ATP-gua_Ptrans: 0.216
Peptidase_U4: 0.214
TIGR00117: 0.211
TIGR03998: 0.210
5-FTHF_cyc-lig: 0.209
DUF4430: 0.206
HTH_ParB: 0.198
PGDH_inter: 0.197
TIGR02483: 0.197
TIGR00706: 0.196
LpxI_C: 0.196
Cupin_7: 0.192
HATPase_c_2: 0.191
DUF808: 0.191
LMWPc: 0.189
TIGR01459: 0.189
CorA: 0.189
SepF: 0.188
Thioredox_DsbH: 0.188
UPF0047: 0.187
GerA: 0.186
Dodecin: 0.185
TIGR04018: 0.183
Toprim_C_rpt: 0.182
TIGR01887: 0.179
TIGR03599: 0.179
o_Erysipelotrichales: 0.178
CSD2: 0.175
PEPCK_N: 0.175
Fn3-like: 0.173
TIGR00611:

In [186]:
acc = TIGR_df[TIGR_df['TIGR02729'] == 1].index
gtdb_metadata.loc[acc].reset_index()[['gtdb_phylum', 'gtdb_class', 'gtdb_order', 'assembly_accession']].groupby([
    'gtdb_phylum', 'gtdb_class', 'gtdb_order'
]).count().sort_values('assembly_accession', ascending=False).head(20)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,assembly_accession
gtdb_phylum,gtdb_class,gtdb_order,Unnamed: 3_level_1
Pseudomonadota,Gammaproteobacteria,Burkholderiales,1695
Bacillota_A,Clostridia,Lachnospirales,1257
Bacteroidota,Bacteroidia,Bacteroidales,1226
Bacillota_A,Clostridia,Oscillospirales,1217
Bacteroidota,Bacteroidia,Flavobacteriales,1047
Pseudomonadota,Alphaproteobacteria,Rhizobiales,1022
Actinomycetota,Actinomycetia,Actinomycetales,986
Pseudomonadota,Alphaproteobacteria,Rhodobacterales,866
Pseudomonadota,Gammaproteobacteria,Pseudomonadales,731
Pseudomonadota,Gammaproteobacteria,Enterobacterales_A,574
