Goal: train a genome-level classifier that predicts whether a genome is a PICI, cfPICI, P4, or phage.

Model input: for each genome, a vector describing its proteins by PHROG functional category (e.g., the number of tail, head, integrase, lysis, etc. proteins). The final output is a trained model that classifies new genomes from their PHROG function profiles.

- First, construct the training data. The PHROG annotations are already available, so use the true functional categories to build the input vectors. Build each vector following the order of the proteins along the genome, reversing that order if the majority of proteins lie on the negative strand.
- Once the feature vectors are built for each genome, train a multiclass classifier (like XGBoost or LightGBM).
- Target label: the genome type (PICI, cfPICI, P4, or phage).

Whole workflow:
   Bacterial Genome 
   → Sliding Window 
   → Protein Prediction 
   → PHROG Function Prediction (your trained predictors)
   → Feature Vector Construction 
   → Multi-class Classification

In [1]:
import ast

import numpy as np
import pandas as pd
import xgboost as xgb
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report

# data preparation

In [13]:
# Per-protein annotation table; the cells below use its acc, what, start,
# strand and pcat columns.
annotation = pd.read_parquet(
    path="../dataset/Phage_and_Satellites_Pann_Pcat_Pcol.pa"
)

In [74]:
# def get_acc_info(annotation, type: str):
#     annotation_type = annotation[annotation["what"] == type]
#     acc_df = pd.DataFrame(columns=["acc", "average_strand", "num_proteins"])
#     for acc in annotation_type["acc"].unique():
#         data = annotation_type[annotation_type["acc"] == acc]
#         num_proteins = len(data)
#         average_strand = data["strand"].mean()
#         acc_df = pd.concat(
#             [
#                 acc_df,
#                 pd.DataFrame(
#                     {
#                         "acc": [acc],
#                         "average_strand": [average_strand],
#                         "num_proteins": [num_proteins],
#                     }
#                 ),
#             ],
#         )
#     return acc_df


# acc_pici_df = get_acc_info(annotation, "PICI")
# acc_cfpici_df = get_acc_info(annotation, "CFPICI")
# acc_p4_df = get_acc_info(annotation, "P4")
# acc_phage_df = get_acc_info(annotation, "phage")
# acc_phage_df = pd.read_csv("../dataset/acc_phage_df.csv")

In [7]:
# acc_pici_df["what"] = "PICI"
# acc_cfpici_df["what"] = "CFPICI"
# acc_p4_df["what"] = "P4"
# acc_phage_df["what"] = "phage"
# all_acc_df = pd.concat([acc_pici_df, acc_cfpici_df, acc_p4_df, acc_phage_df])
# all_acc_df.to_csv("../dataset/acc_all_df.csv", index=False)

In [None]:
# all_acc = all_acc_df["acc"].unique()
# annotation_phage = annotation[annotation["what"] == "phage"]
# annotation_pici = annotation[annotation["what"] == "PICI"]
# annotation_cfpici = annotation[annotation["what"] == "CFPICI"]
# annotation_p4 = annotation[annotation["what"] == "P4"]
# pici_acc = annotation_pici["acc"].unique()
# cfpici_acc = annotation_cfpici["acc"].unique()
# p4_acc = annotation_p4["acc"].unique()
# phage_acc = annotation_phage["acc"].unique()
# print(f"all_acc: {len(all_acc)}")
# print(f"pici_acc: {len(pici_acc)}")
# print(f"cfpici_acc: {len(cfpici_acc)}")
# print(f"p4_acc: {len(p4_acc)}")
# print(f"phage_acc: {len(phage_acc)}")

In [None]:
# import matplotlib.pyplot as plt

# plt.figure(figsize=(10, 6))
# plt.subplot(2, 2, 1)
# plt.hist(acc_pici_df["average_strand"], bins=100, edgecolor="black")
# plt.title("PICI Average Strand")
# plt.xlabel("Average Strand")
# plt.ylabel("Frequency")

# plt.subplot(2, 2, 2)
# plt.hist(acc_cfpici_df["average_strand"], bins=100, edgecolor="black")
# plt.title("CFPICI Average Strand")
# plt.xlabel("Average Strand")
# plt.ylabel("Frequency")

# plt.subplot(2, 2, 3)
# plt.hist(acc_p4_df["average_strand"], bins=100, edgecolor="black")
# plt.title("P4 Average Strand")
# plt.xlabel("Average Strand")
# plt.ylabel("Frequency")

# plt.subplot(2, 2, 4)
# plt.hist(acc_phage_df["average_strand"], bins=100, edgecolor="black")
# plt.title("phage Average Strand")
# plt.xlabel("Average Strand")
# plt.ylabel("Frequency")

# plt.tight_layout()
# plt.show()

In [None]:
# plt.figure(figsize=(10, 8))
# plt.subplot(2, 2, 1)
# plt.hist(acc_pici_df["num_proteins"], bins=100, edgecolor="black")
# plt.title("Number of Proteins, PICI")
# plt.xlabel("Number of Proteins")
# plt.ylabel("Frequency")

# plt.subplot(2, 2, 2)
# plt.hist(acc_cfpici_df["num_proteins"], bins=100, edgecolor="black")
# plt.title("Number of Proteins, CFPICI")
# plt.xlabel("Number of Proteins")
# plt.ylabel("Frequency")

# plt.subplot(2, 2, 3)
# plt.hist(acc_p4_df["num_proteins"], bins=100, edgecolor="black")
# plt.title("Number of Proteins, P4")
# plt.xlabel("Number of Proteins")
# plt.ylabel("Frequency")

# plt.subplot(2, 2, 4)
# plt.hist(acc_phage_df["num_proteins"], bins=100, edgecolor="black")
# plt.title("Number of Proteins, phage")
# plt.xlabel("Number of Proteins")
# plt.ylabel("Frequency")

# plt.tight_layout()
# plt.show()

# feature vector

In [21]:
# from pici_predictor.phrog_function import function_name_raw_to_num

In [22]:
# PHROG functional category label -> integer ID used in the feature vectors.
# IDs are assigned 1..11 in the order listed; 0 is left free because
# prepare_sequence_vector pads short genomes with 0.
function_name_raw_to_num = {
    category: code
    for code, category in enumerate(
        [
            "lysis",
            "tail",
            "connector",
            "DNA, RNA and nucleotide metabolism",
            "head and packaging",
            "other",
            "transcription regulation",
            "moron, auxiliary metabolic gene and host takeover",
            "unknown function",
            "integration and excision",
            "unknown_no_hit",
        ],
        start=1,
    )
}

In [19]:
# One row per accession with its genome type and strand/protein-count summary,
# as exported by the (commented-out) data-preparation cells above.
all_acc_df = pd.read_csv(filepath_or_buffer="../dataset/acc_all_df.csv")

In [23]:
def prepare_sequence_vector(
    annotation, acc, average_strand, max_length=30, *, mapping=None
):
    """Encode one genome's proteins as a fixed-length sequence of category IDs.

    The rows of ``annotation`` whose ``acc`` equals ``acc`` are ordered by
    ``start``. When ``average_strand`` is negative, that order is reversed so
    the genome is read in the orientation of most of its proteins. Each
    protein's ``pcat`` label is mapped to an integer ID. The sequence is then
    cut to, or right-padded with 0 up to, exactly ``max_length`` entries.

    Args:
        annotation: Protein table with ``acc``, ``start`` and ``pcat`` columns.
            A subset containing only this accession's rows works as well.
        acc: Accession of the genome to encode.
        average_strand: Mean strand of the genome's proteins. Values < 0
            trigger the reversal.
        max_length: Length of the returned vector.
        mapping: ``pcat`` label -> integer ID. Defaults to the module-level
            ``function_name_raw_to_num``. IDs should be >= 1 so they stay
            distinguishable from the 0 padding.

    Returns:
        A list of ``max_length`` ints (0 = padding).

    Raises:
        KeyError: if a protein's ``pcat`` label is not in ``mapping``.
    """
    if mapping is None:
        mapping = function_name_raw_to_num

    element_proteins = annotation[annotation["acc"] == acc].sort_values("start")
    if average_strand < 0:
        # Read the genome from its other end so the order follows the strand
        # carrying most of the proteins.
        element_proteins = element_proteins.iloc[::-1]

    try:
        function_vector = [mapping[func] for func in element_proteins["pcat"]]
    except KeyError as err:
        raise KeyError(
            f"Unknown PHROG category {err.args[0]!r} for accession {acc!r}"
        ) from err

    # Truncate to max_length, then pad with 0 to reach exactly max_length.
    function_vector = function_vector[:max_length]
    function_vector += [0] * (max_length - len(function_vector))
    return function_vector

In [24]:
# Index the annotation table by accession once. Filtering the full table with
# `annotation["acc"] == acc` for every genome costs O(genomes x proteins).
# With precomputed row positions, each lookup touches only that genome's rows.
# Re-run this cell if `annotation` is reloaded.
_acc_row_positions = annotation.groupby("acc", sort=False).indices


def get_feature_vector(row):
    """Return the 30-long PHROG category sequence for one row of all_acc_df.

    ``row`` must provide ``acc`` and ``average_strand``. Only the annotation
    rows of that accession are passed to ``prepare_sequence_vector``. An
    accession absent from ``annotation`` gets an empty subset, i.e. a vector
    of padding only.
    """
    positions = _acc_row_positions.get(row["acc"], [])
    proteins = annotation.iloc[positions]
    return prepare_sequence_vector(
        proteins, row["acc"], row["average_strand"], max_length=30
    )


all_acc_df["feature_vector"] = all_acc_df.apply(get_feature_vector, axis=1)

In [25]:
# Save next to the other dataset tables. The training section reloads this file
# from "../dataset/all_acc_vector.csv"; writing it to the working directory left
# that reload reading a stale or missing file.
# NOTE: CSV stores each feature_vector list as its string form, e.g. "[5, 2, 0]",
# so the column must be parsed back into lists after reading.
all_acc_df.to_csv("../dataset/all_acc_vector.csv", index=False)

# training

In [2]:
# to_csv stored each feature_vector list as its string form ("[5, 2, ...]").
# ast.literal_eval turns it back into a list of ints while reading; it only
# evaluates Python literals, never arbitrary code.
all_acc_vector = pd.read_csv(
    "../dataset/all_acc_vector.csv",
    converters={"feature_vector": ast.literal_eval},
)

In [None]:
# Convert to numpy arrays for training.
# Use the reloaded table (all_acc_vector), not all_acc_df, which exists only if
# the feature-construction cells ran in this kernel session. A feature_vector
# read from CSV without a converter is a string such as "[5, 2, 0]", so parse
# strings and pass real lists through unchanged.
feature_vectors = [
    ast.literal_eval(vec) if isinstance(vec, str) else vec
    for vec in all_acc_vector["feature_vector"]
]
X = np.array(feature_vectors, dtype=np.int64)
y = all_acc_vector["what"]

# XGBoost's multiclass classifier expects integer labels 0..K-1. Keep the
# sorted class names so predicted codes can be decoded back to genome types.
y_encoded, class_names = pd.factorize(y, sort=True)

# X has shape (n_genomes, 30): one category ID per position (0 = padding).
# y has shape (n_genomes,): the genome type label of each row.

In [None]:
# For XGBoost configuration:
# Categorical splits are used only when the input data itself carries
# categorical feature types. With a plain numpy X, enable_categorical=True has
# no effect. Wrap X in a DataFrame of pandas "category" columns with a fixed
# category set, so training and later predictions share the same codes.
FEATURE_CATEGORIES = list(range(12))  # 0 = padding, 1-11 = function categories
X_categorical = pd.DataFrame(
    {
        f"pos_{i}": pd.Categorical(X[:, i], categories=FEATURE_CATEGORIES)
        for i in range(X.shape[1])
    }
)

xgb_params = {
    "enable_categorical": True,  # needs category-typed input such as X_categorical
    "tree_method": "hist",  # histogram method supports categorical splits
}