In [1]:
import torch
from transformers import AutoModel, AutoTokenizer
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import numpy as np
import pandas as pd
from genplasmid.datasets import genbank_to_glm2, read_genbank
import warnings
from datasets import load_dataset

from Bio import BiopythonParserWarning

# Suppress the specific warning
warnings.filterwarnings("ignore", category=BiopythonParserWarning, message="Attempting to parse malformed locus line:")

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Load OpenPlasmid, keep only entries that carry GenBank text, and attach the
# gLM2-formatted sequence derived from that text.
data = (
    load_dataset("wconnell/openplasmid")
    .filter(lambda x: x['GenBank Raw'] != '')
    .map(lambda x: {'glm2_sequence': genbank_to_glm2(x['GenBank Raw'])})
)

# Metadata table for the train split, plus a parsed GenBank record per row.
all_feat = data['train'].to_pandas()
all_feat['GenBank'] = all_feat['GenBank Raw'].map(read_genbank)


In [None]:
import re

def clean_gene_name(gene):
    """Normalise a GenBank gene/product name into a coarse, comparable label.

    Parameters
    ----------
    gene : str
        Raw text of a CDS ``gene`` or ``product`` qualifier.

    Returns
    -------
    str
        A canonical label from ``gene_map`` (e.g. ``'GFP'``,
        ``'kanamycin resistance'``) when one of its keys occurs as a whole
        word (case-insensitive), otherwise the cleaned name, stripped and
        lower-cased.
    """
    # Drop species words only as whole words ("human insulin" -> "insulin").
    gene = re.sub(r'^(human|mouse|rat)\b\s*', '', gene, flags=re.IGNORECASE)
    # Drop a one-letter species prefix (hTERT, mIL2) only when a lower-case
    # h/m/r is directly followed by an upper-case letter. Matching these
    # letters case-insensitively with an optional space would also chop the
    # first letter off ordinary names ("rop" -> "op", "RFP" -> "FP",
    # "hph" -> "ph").
    gene = re.sub(r'^[hmr](?=[A-Z])', '', gene)
    # Drop a trailing "gene"/"protein" descriptor.
    gene = re.sub(r'\s*(gene|protein)$', '', gene, flags=re.IGNORECASE)

    # Remove parentheses and their contents
    gene = re.sub(r'\s*\([^)]*\)', '', gene)

    # Remove the text "or nptII" (the parentheses in the pattern form a regex
    # group; they are not matched literally).
    gene = re.sub(r'(or nptII)', '', gene)

    # Standardize common gene names; keys are matched as whole words.
    gene_map = {
        'neo': 'neomycin resistance',
        'amp': 'ampicillin resistance',
        'gfp': 'GFP',
        'egfp': 'GFP',
        'rfp': 'RFP',
        'dsred': 'RFP',
        'kan': 'kanamycin resistance',
    }

    for key, value in gene_map.items():
        if re.search(rf'\b{key}\b', gene, re.IGNORECASE):
            return value

    return gene.strip().lower()

def extract_cds_genes(record):
    """Collect the cleaned names of every CDS feature in a parsed record.

    For each CDS the ``gene`` qualifier is preferred and ``product`` is the
    fallback; CDS features carrying neither are skipped. Names are passed
    through ``clean_gene_name`` and returned in feature order.
    """
    names = []
    cds_features = (feat for feat in record.features if feat.type == 'CDS')
    for feat in cds_features:
        quals = feat.qualifiers
        label = quals.get('gene', []) or quals.get('product', [])
        if label:
            names.append(clean_gene_name(label[0]))
    return names

def calculate_gc_content(sequence):
    """Return the fraction of G and C bases in a nucleotide sequence.

    Parameters
    ----------
    sequence : str-like sequence, or an object with a ``.seq`` attribute
        The bases to inspect. A parsed record exposing ``.seq`` is unwrapped
        first. Assumes ``str()`` of the sequence yields the base letters
        (true for ``str`` and Biopython ``Seq``).

    Returns
    -------
    float
        ``(#G + #C) / length``, counted case-insensitively; 0.0 when empty.
    """
    # Accept a record (anything with .seq) as well as a bare sequence.
    seq = getattr(sequence, 'seq', sequence)
    # Upper-case first so lower-case (e.g. soft-masked) bases are counted too.
    bases = str(seq).upper()
    if not bases:
        return 0.0
    return (bases.count('G') + bases.count('C')) / len(bases)


In [None]:
# extract sequence features
all_feat['CDS genes'] = all_feat['GenBank'].map(extract_cds_genes)
all_feat['Sequence length'] = all_feat['GenBank'].map(lambda x: len(x.seq))
# calculate_gc_content expects a sequence (it calls .count on its argument),
# so pass the record's .seq rather than the record itself.
all_feat['GC content'] = all_feat['GenBank'].map(lambda x: calculate_gc_content(x.seq))

In [None]:
import itertools
from collections import Counter

# count labels
feature_counts = Counter(itertools.chain.from_iterable(all_feat['CDS genes']))
# Labels kept out of the keyword set. NOTE(review): presumably backbone genes
# present in most plasmids -- confirm. 'op' is what 'rop' becomes when
# clean_gene_name strips a leading 'r' as a species prefix, so both spellings
# are listed and the exclusion holds whichever cleaning rule produced the names.
exclude = ['bla', 'op', 'rop']
top_labels = [label for label, _ in feature_counts.most_common(15)]
# Least frequent first: map_genes returns the first keyword that matches, so
# rarer labels take precedence over the most common ones.
keywords = [label for label in reversed(top_labels) if label not in exclude]
keywords

# keywords = ['gfp', 'cas9', '6xhis', 'RFP', 'factor xa', 'aph-ii', 'aph-ia', 'laci', 't antigen']

In [None]:
def map_genes(gene, vocabulary=None):
    """Return the first keyword that occurs in any of a record's CDS names.

    Parameters
    ----------
    gene : iterable of str
        Cleaned CDS gene names for one record (a value of
        ``all_feat['CDS genes']``).
    vocabulary : sequence of str, optional
        Keywords in priority order (first match wins). Defaults to the
        module-level ``keywords`` list.

    Returns
    -------
    str or None
        The first keyword found as a case-insensitive substring of some gene
        name, or None if no keyword matches.
    """
    if vocabulary is None:
        vocabulary = keywords
    for keyword in vocabulary:
        # Escape so regex metacharacters in gene names ('+', '.', '[') match
        # literally, and test each name separately so a multi-word keyword
        # cannot match across the boundary between two gene names.
        pattern = re.compile(re.escape(keyword), re.IGNORECASE)
        if any(pattern.search(name) for name in gene):
            return keyword
    return None

# Label each record with the first keyword (in `keywords` order) found among its CDS names; None if none match.
curated_labels = all_feat['CDS genes'].map(map_genes)
all_feat['CDS curated features'] = curated_labels

In [None]:
# def map_genes(genes):
#     keywords = ['cas9', 'RFP', 'factor xa', 'gp41 peptide', 'luciferase', 'aph-ii', 'aph-ia', 'laci', 't antigen', 'factor xa', 'gfp', '6xhis', 'op',]
#     for gene in genes:
#         for keyword in keywords:
#             if re.search(keyword, gene, re.IGNORECASE):
#                 return keyword
#     return None

# all_feat['CDS curated features'] = all_feat['CDS features'].map(map_genes)

# How many records received each curated label (the NaN row counts unmatched records).
gene_counts = all_feat['CDS curated features'].value_counts(dropna=False)
print("\nCounts for each mapped gene:", gene_counts, sep="\n")

# Share of records that received any curated label.
mapped_percentage = all_feat['CDS curated features'].notna().mean() * 100
print(f"\nPercentage of rows with a mapped gene: {mapped_percentage:.2f}%")


In [None]:
# For each record, gather the upper-cased 'Entrez Gene' values of its
# 'Gene/Insert 1..3' entries; entries that are not dicts or have no value are skipped.
insert_columns = [f'Gene/Insert {i}' for i in range(1, 4)]
all_feat['Entrez Genes'] = pd.Series(
    [
        [
            entry['Entrez Gene'].upper()
            for entry in inserts
            if isinstance(entry, dict) and entry.get('Entrez Gene', None)
        ]
        for inserts in all_feat[insert_columns].itertuples(index=False, name=None)
    ],
    index=all_feat.index,
    dtype=object,
)

# The 20 most frequent Entrez symbols, most frequent first (NaN from empty lists is not counted).
common_entrez = (
    all_feat['Entrez Genes']
    .explode()
    .value_counts()
    .head(20)
    .index
    .tolist()
)

# Rank lookup: the most frequent symbol gets the largest rank and the 20th gets 0.
# NOTE(review): because get_highest_priority_gene takes the max, the most frequent
# symbol wins -- the opposite of the keyword ordering used for 'CDS curated
# features'. Confirm this is intended.
common_entrez_priority = {
    gene: len(common_entrez) - 1 - position
    for position, gene in enumerate(common_entrez)
}

def get_highest_priority_gene(genes):
    """Return the gene in ``genes`` with the largest rank in ``common_entrez_priority``.

    Genes missing from the rank table are ignored; returns None when no gene
    in ``genes`` is ranked.
    """
    ranked = [g for g in genes if g in common_entrez_priority]
    return max(ranked, key=common_entrez_priority.__getitem__, default=None)

# Reduce each record's Entrez list to its single highest-ranked symbol (None if none is ranked).
all_feat['Common Entrez Gene'] = all_feat['Entrez Genes'].map(get_highest_priority_gene)

# Spot-check the new column against its source list.
print(all_feat.loc[:, ['Entrez Genes', 'Common Entrez Gene']].head(10))

In [None]:
import scanpy as sc

# Reuse embeddings already present in the kernel namespace; otherwise read the cached array.
embeddings = embeddings if 'embeddings' in locals() else np.load('data/glm2v2_embeddings.npy')

# One observation per row of all_feat.
# NOTE(review): assumes the .npy rows are in the same order as all_feat -- TODO confirm.
adata = sc.AnnData(embeddings, obs=all_feat)
seq_len = adata.obs['Sequence length']
adata.obs['log(seq_len)'] = np.log10(seq_len)
sc.tl.pca(adata)
adata

In [None]:
# PCA of the embeddings coloured by curated CDS keyword. Records without a
# keyword (NaN) are drawn in light grey. The colour key is categorical, so
# continuous-scale options (vmin/vmax/vcenter) and ncols would have no effect.
sc.pl.pca(
    adata,
    color=['CDS curated features'],
    palette=sc.pl.palettes.vega_20_scanpy,
    na_color='lightgrey',
)

In [None]:
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score
from sklearn.cluster import KMeans

# Keep only records that received a curated keyword. .copy() makes the subset a
# real AnnData, so the obs column added below is not written into a view.
filtered_adata = adata[adata.obs['CDS curated features'].notna()].copy()

# Spectral clustering of the PCA coordinates, one cluster per keyword label.
from sklearn.cluster import SpectralClustering

n_clusters = len(filtered_adata.obs['CDS curated features'].unique())
spectral = SpectralClustering(n_clusters=n_clusters, random_state=42, affinity='nearest_neighbors')
cluster_labels = spectral.fit_predict(filtered_adata.obsm['X_pca'])

# Integer codes of the 'CDS curated features' labels
gene_labels = filtered_adata.obs['CDS curated features'].astype('category').cat.codes

# Agreement between the clustering and the keyword labels
nmi_score = normalized_mutual_info_score(gene_labels, cluster_labels)
ari_score = adjusted_rand_score(gene_labels, cluster_labels)

print(f"Normalized Mutual Information: {nmi_score:.4f}")
print(f"Adjusted Rand Index: {ari_score:.4f}")
print("")

# Store cluster ids as strings so scanpy colours them as categories. The column
# and panel title name the algorithm actually used (spectral, not k-means).
filtered_adata.obs['Spectral_cluster'] = cluster_labels.astype(str)
sc.pl.pca(
    filtered_adata,
    color=['CDS curated features', 'Spectral_cluster'],
    ncols=2,
    legend_fontsize='xx-small',
    title=['CDS curated features', 'Spectral Clusters'],
)

In [None]:
# PCA coloured by each record's representative Entrez symbol. Records without
# one are left transparent (na_color='none'). The colour key is categorical, so
# continuous-scale options (vmin/vmax/vcenter) and ncols would have no effect.
sc.pl.pca(
    adata,
    color=['Common Entrez Gene'],
    palette=sc.pl.palettes.vega_20_scanpy,
    na_color='none',
)

In [None]:
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score
from sklearn.cluster import KMeans

# Keep only records with one of the top-20 Entrez symbols. .copy() makes the
# subset a real AnnData, so the obs column added here (and anything computed
# on filtered_adata afterwards) is not written into a view.
filtered_adata = adata[adata.obs['Common Entrez Gene'].notna()].copy()

# K-means on the PCA coordinates, one cluster per Entrez label.
n_clusters = len(filtered_adata.obs['Common Entrez Gene'].unique())
kmeans = KMeans(n_clusters=n_clusters, random_state=42)
cluster_labels = kmeans.fit_predict(filtered_adata.obsm['X_pca'])

# Integer codes of the 'Common Entrez Gene' labels
gene_labels = filtered_adata.obs['Common Entrez Gene'].astype('category').cat.codes

# Agreement between the clustering and the Entrez labels
nmi_score = normalized_mutual_info_score(gene_labels, cluster_labels)
ari_score = adjusted_rand_score(gene_labels, cluster_labels)

print(f"Normalized Mutual Information: {nmi_score:.4f}")
print(f"Adjusted Rand Index: {ari_score:.4f}")
print("")

# Store cluster ids as strings: integer ids would be drawn on a continuous
# colour bar, and legend_loc='on data' would then have no effect.
filtered_adata.obs['KMeans_cluster'] = cluster_labels.astype(str)
sc.pl.pca(
    filtered_adata,
    color=['Common Entrez Gene', 'KMeans_cluster'],
    ncols=2,
    legend_loc='on data',
    legend_fontsize='xx-small',
    title=['Common Entrez Gene', 'KMeans Clusters'],
)

In [None]:
# Make sure cluster ids are categorical (strings) so they get a discrete legend;
# integer ids would be drawn on a continuous colour bar.
filtered_adata.obs['KMeans_cluster'] = filtered_adata.obs['KMeans_cluster'].astype(str)

# Neighbour graph and UMAP on the filtered records.
sc.pp.neighbors(filtered_adata)
sc.tl.umap(filtered_adata)
sc.pl.umap(
    filtered_adata,
    color=['Common Entrez Gene', 'KMeans_cluster'],
    ncols=2,
    legend_loc='on data',
    legend_fontsize='xx-small',
    title=['Common Entrez Gene', 'KMeans Clusters'],
)

In [None]:
# Create a new column for each category and delete if all NaN
# NOTE(review): `category_mapping` and adata.obs['Categories'] are not created in
# any earlier cell visible here, so on a fresh kernel this cell raises
# NameError/KeyError unless they are defined first -- TODO add the cell that builds them.
# Each added obs column holds the category name where `category in x` is true for
# that record's 'Categories' value, and NaN elsewhere. Columns with no hits are
# assigned and then deleted again.
# NOTE(review): `category in x` assumes every 'Categories' value supports `in`
# (e.g. list or string); a None/NaN value would raise TypeError -- confirm.
for category in list(category_mapping.keys()):
    adata.obs[category] = adata.obs['Categories'].map(lambda x: category if category in x else np.nan)
    if adata.obs[category].isna().all():
        del adata.obs[category]
        print(f"Deleted empty category: {category}")

# Plot PCA for the first 10 non-empty categories
# (obs columns are scanned in order; only those that are keys of category_mapping
# are plotted, and the loop stops after the 10th plot).
plotted_categories = 0
for category in adata.obs.columns:
    if category in category_mapping.keys():
        sc.pl.pca(
            adata,
            color=[category],
            palette=['#ff7f0e'],
            title=category,
            na_color='lightgrey',
        )
        plotted_categories += 1
        if plotted_categories == 10:
            break

# Number of category columns that survived the empty-column pruning above.
print(f"Total non-empty categories: {len([col for col in adata.obs.columns if col in category_mapping])}")