<a href="https://colab.research.google.com/github/pachterlab/LSCHWCP_2023/blob/main/Notebooks/Figure_3/Figure_3b/validate_palmdb2palmdb.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Validate the alignment of PalmDB nucleotide sequences to the corresponding amino acid reference with kallisto translated search

In [1]:
!pip install -q biopython kb_python
from Bio import SeqIO
import pandas as pd
import numpy as np
from itertools import product
import kb_python.utils as kb_utils
%config InlineBackend.figure_format='retina'
import matplotlib.pyplot as plt
from tqdm import tqdm
# Shared tqdm layout: description/percentage, the bar, n/total, then elapsed and remaining time
TQDM_BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]"

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m11.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.1/13.1 MB[0m [31m53.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m119.2/119.2 kB[0m [31m5.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.8/4.8 MB[0m [31m46.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m45.2/45.2 MB[0m [31m8.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m41.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m21.9/21.9 MB[0m [31m35.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━

Install kallisto and bustools:

In [None]:
# Install kallisto from source
# NOTE(review): this clones the default branch, so the kallisto version is not pinned;
# the `--aa` translated-search options used later presumably need a recent kallisto —
# consider checking out a tagged release for reproducibility.
!git clone -q https://github.com/pachterlab/kallisto.git
# Not re-runnable as-is: `git clone` and `mkdir build` both fail once the directories exist.
!cd kallisto && mkdir build && cd build && cmake .. && make

# Install bustools from source (same re-run caveat as above)
!git clone -q https://github.com/BUStools/bustools.git
!cd bustools && mkdir build && cd build && cmake .. && make

# Define paths to kallisto and bustools binaries
# (absolute paths: assumes the notebook runs from /content, Colab's default working directory)
kallisto = "/content/kallisto/build/src/kallisto"
bustools = "/content/bustools/build/src/bustools"

  Compatibility with CMake < 3.5 will be removed from a future version of
  CMake.

  Update the VERSION argument <min> value or use a ...<max> suffix to tell
  CMake that the project does not need compatibility with older versions.

[0m
-- The C compiler identification is GNU 11.4.0
-- The CXX compiler identification is GNU 11.4.0
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Check for working C compiler: /usr/bin/cc - skipped
-- Detecting C compile features
-- Detecting C compile features - done
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Check for working CXX compiler: /usr/bin/c++ - skipped
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Performing Test COMPILER_SUPPORTS_CXX17
-- Performing Test COMPILER_SUPPORTS_CXX17 - Success
[0mshared build[0m
-- Performing Test CMAKE_HAVE_LIBC_PTHREAD
-- Performing Test CMAKE_HAVE_LIBC_PTHREAD - Success
-- Found Threads: TRUE  
-- Found ZLIB: /us

Load data:

In [None]:
!wget https://raw.githubusercontent.com/pachterlab/LSCHWCP_2023/main/PalmDB/palmdb_clustered_t2g.txt
!wget https://raw.githubusercontent.com/pachterlab/LSCHWCP_2023/main/PalmDB/ID_to_taxonomy_mapping.csv
!wget https://raw.githubusercontent.com/pachterlab/LSCHWCP_2023/main/PalmDB/palmdb_rdrp_seqs.fa
!wget https://github.com/pachterlab/LSCHWCP_2023/raw/main/Notebooks/Figure_3/Figure_3b/palmdb_rdrp_seqs_nucleotides.fasta.zip
!unzip palmdb_rdrp_seqs_nucleotides.fasta.zip

virus_t2g = "palmdb_clustered_t2g.txt"
u_tax_dp = "ID_to_taxonomy_mapping.csv"
palmdb = "palmdb_rdrp_seqs.fa"

# PalmDB amino acid sequences reverse translated to nucleotides
palmdb_nn_fasta = "palmdb_rdrp_seqs_nucleotides.fasta"

Reverse translate virus ref amino acid sequences to nucleotides:

In [None]:
# Adapted from https://github.com/Edinburgh-Genome-Foundry/DnaChisel
from Bio.Data import CodonTable

def flatten(l):
    """Concatenate an iterable of iterables into one flat list (one level deep)."""
    flat = []
    for sublist in l:
        flat.extend(sublist)
    return flat

def get_backtranslation_table(table_name="Standard"):
    """Build a map from one-letter amino acid codes to the codons encoding them.

    Parameters
    ----------
    table_name
      Name of a genetic code in ``Bio.Data.CodonTable.unambiguous_dna_by_name``
      (e.g. "Standard").

    Returns
    -------
    dict
      Keys are the amino acids of the table's forward table plus "*" (the
      table's stop codons), "START" (its start codons), "X" (every codon listed
      here that is not a stop codon, sorted) and the ambiguity codes "B" (N or D),
      "J" (L or I) and "Z" (E or Q). Values are lists of codon strings; codons
      for a single residue keep the forward table's iteration order.
    """
    table = CodonTable.unambiguous_dna_by_name[table_name]
    back_translation_table = {}
    for codon, amino_acid in table.forward_table.items():
        back_translation_table.setdefault(amino_acid, []).append(codon)
    back_translation_table["*"] = table.stop_codons
    back_translation_table["START"] = table.start_codons
    # Sort the set difference: iteration order of a set of str depends on hash
    # randomization, so an unsorted list made reverse_translate's "first codon"
    # for X differ between Python sessions.
    back_translation_table["X"] = sorted(
        set(flatten(back_translation_table.values())) - set(back_translation_table["*"])
    )
    back_translation_table["B"] = back_translation_table["N"] + back_translation_table["D"]
    back_translation_table["J"] = back_translation_table["L"] + back_translation_table["I"]
    back_translation_table["Z"] = back_translation_table["E"] + back_translation_table["Q"]
    return back_translation_table

def reverse_translate(protein_sequence, randomize_codons=False, table="Standard"):
    """Return a DNA sequence which translates to the provided protein sequence.

    Parameters
    ----------

    protein_sequence
      A sequence string of amino acids in one-letter code, e.g. "MVKK...".
      Besides the residues of the genetic code, "*" (stop) and the codes
      "X", "B", "J" and "Z" are accepted (see ``get_backtranslation_table``).

    randomize_codons
      If False, the first codon listed for each amino acid is used, which can
      create biases (GC content, etc.). If True, each amino acid is replaced by
      a codon drawn uniformly at random with ``np.random`` (seed it for
      reproducible output).

    table
      Name of a Biopython genetic code table (a key of
      ``Bio.Data.CodonTable.unambiguous_dna_by_name``), e.g. "Standard".

    Raises
    ------
    KeyError
      If ``protein_sequence`` contains a character with no codons in the table.
    """
    backtranslation_table = get_backtranslation_table(table_name=table)

    unknown = set(protein_sequence) - set(backtranslation_table)
    if unknown:
        raise KeyError(
            f"No codons known for residue(s) {sorted(unknown)} in table '{table}'"
        )

    if randomize_codons:
        # Draw each index directly in [0, n_codons): reducing a fixed-range random
        # number modulo n_codons (as before) slightly favoured low indices.
        return "".join(
            backtranslation_table[aa][np.random.randint(len(backtranslation_table[aa]))]
            for aa in protein_sequence
        )
    return "".join(backtranslation_table[aa][0] for aa in protein_sequence)

In [None]:
# Reverse translate virus ref amino acid sequences to nucleotides and save to fasta.
# SeqIO comes from the imports in the first code cell.
with open(palmdb) as handle, open(palmdb_nn_fasta, "w") as palm_nuc:
    for record in SeqIO.parse(handle, "fasta"):
        # randomize_codons stays at its default (False): each residue is written
        # as the first codon in its back-translation table entry.
        nucleotides = reverse_translate(str(record.seq), table="Standard")
        palm_nuc.write(f">{record.id}\n{nucleotides}\n")

Create a fake R1 file that assigns a unique barcode to each reverse-translated PalmDB sequence:

In [None]:
# Every possible 10-nt A/C/G/T string (4**10 = 1,048,576), in lexicographic
# order, used as fake barcodes so each PalmDB sequence can get its own.
barcodes = list(map("".join, product("ACGT", repeat=10)))

n_barcodes = len(barcodes)
print(n_barcodes)
barcodes[:10]

In [None]:
# Create fake R1 file with barcodes: record i of the reverse-translated PalmDB
# fasta gets barcodes[i], and its header keeps the PalmDB ID, so this file and
# palmdb_nn_fasta (passed together to `kallisto bus`) stay paired record-for-record.
virus_ids = []
palmdb_nn_R1 = "palmdb_barcodes_R1.fasta"
with open(palmdb_nn_R1, "w") as fake_R1:
    # Use every record instead of a hard-coded count: zip() over a fixed slice of
    # `barcodes` silently dropped any sequences beyond it.
    for i, record in enumerate(SeqIO.parse(palmdb_nn_fasta, "fasta")):
        if i >= len(barcodes):
            raise ValueError(
                f"{palmdb_nn_fasta} has more than {len(barcodes):,} records; "
                "generate longer fake barcodes"
            )
        fake_R1.write(f">{record.id}\n{barcodes[i]}\n")
        virus_ids.append(record.id)

In [None]:
# Number of PalmDB sequences written to the fake R1 file (one barcode each)
len(virus_ids)

In [None]:
# Lookup table from each fake barcode back to the PalmDB ID it was assigned to.
# Slicing by len(virus_ids) (rather than a hard-coded count) keeps both columns
# the same length and in the same order as the fake R1 file.
bc2virus = pd.DataFrame({
    "virus_ID": virus_ids,
    "barcode": barcodes[:len(virus_ids)],
})
bc2virus.head()

Align:

In [None]:
# Use kallisto to align with fake barcodes
virus_index = "virus_index.idx"
out_folder = "palmdb_mapping_validation"

!$kallisto bus \
        -i $virus_index \
        -o $out_folder \
        --aa \
        -x 0,0,0:0,0,10:1,0,0 \
        -t 20 \
        $palmdb_nn_R1 $palmdb_nn_fasta

!$bustools sort -o $out_folder/output_sorted.bus $out_folder/output.bus

!$bustools count \
    --genecounts \
    --cm \
    -o $out_folder/bustools_count/ \
    -g $virus_t2g \
    -e $out_folder/matrix.ec \
    -t $out_folder/transcripts.txt \
    $out_folder/output_sorted.bus

In [None]:
# Load data
# Filepath to counts
X = f"{out_folder}/bustools_count/output.mtx"
# Filepath to barcode metadata
var_path = f"{out_folder}/bustools_count/output.genes.txt"
# Filepath to gene metadata
obs_path = f"{out_folder}/bustools_count/output.barcodes.txt"

# Create AnnData object
adata = kb_utils.import_matrix_as_anndata(X, obs_path, var_path)

adata

In [None]:
# Barcode metadata before virus IDs are attached (presumably indexed by barcode; the next cell merges on "barcode")
adata.obs

In [None]:
# Add virus IDs
adata.obs = adata.obs.merge(bc2virus, how="left", on="barcode", validate="one_to_one").set_index("virus_ID")
adata.obs

Plot heatmap of ID mapping:

In [None]:
# Disabled: drawing the heatmap densifies the entire count matrix (adata.X.todense()),
# which is presumably too memory-hungry for the full PalmDB — confirm before enabling.
# %%time
# fig, ax = plt.subplots(figsize = (20,20))
# im = ax.imshow(adata.X.todense())
# plt.savefig("palmdb2palmdb_heatmap.png", dpi=300, bbox_inches="tight")

### Plot the fraction of sequences in each taxon that are assigned correctly:

In [None]:
# Taxonomy per PalmDB ID; get_acc_fractions reads its "ID" column and one column
# per rank ("phylum", "class", "order", "family", "genus", "species").
tax_df = pd.read_csv(u_tax_dp)
tax_df

In [None]:
# Total count over the whole barcode x gene matrix (sanity check of how much was counted)
adata.X.sum()

For each taxon, count how many of its IDs map correctly, incorrectly, not at all, or ambiguously (multimapped):

In [None]:
def _nonzero_genes_by_virus(adata):
    """Map each obs name (virus ID) of `adata` to the var names with a positive count in its row.

    If an ID occurs more than once in ``adata.obs.index`` only its first row is
    used, matching the previous row-mask + ``[0]`` lookup.
    """
    counts = adata.X.tocsr()
    var_names = adata.var.index
    genes_by_virus = {}
    for row, virus_id in enumerate(adata.obs.index):
        if virus_id in genes_by_virus:
            continue
        start, stop = counts.indptr[row], counts.indptr[row + 1]
        cols = counts.indices[start:stop][counts.data[start:stop] > 0]
        genes_by_virus[virus_id] = var_names[cols]
    return genes_by_virus


def get_acc_fractions(taxon):
    """Count, for every taxon at rank `taxon`, how its PalmDB IDs were assigned.

    Each ID in the global ``tax_df`` is classified from its row in the global
    ``adata`` (rows indexed by virus ID, columns = genes from the t2g file):

    - not_aligned: the ID has no row in ``adata``, or its single mapped gene
      has no entry in ``tax_df``;
    - multimapped: the row does not have exactly one gene with a positive count;
    - correct: the single mapped gene is the ID itself or has the same taxon;
    - incorrect: the single mapped gene has a different taxon.

    Writes ``{out_folder}/{taxon}_mapping.csv``: a "mapping" column with the four
    category names plus one count column per taxon, in
    ``tax_df[taxon].unique()`` order.

    Parameters
    ----------
    taxon
      Name of a taxonomy column in ``tax_df`` (e.g. "phylum", "species").
    """
    # Precompute lookups once: the previous version scanned all of adata.obs and
    # tax_df for every single ID (quadratic in the number of sequences).
    genes_by_virus = _nonzero_genes_by_virus(adata)

    # First taxonomy row per ID (the previous lookup used `.values[0]` of the match).
    first_rows = tax_df.drop_duplicates("ID", keep="first")
    taxon_by_id = dict(zip(first_rows["ID"], first_rows[taxon]))

    # IDs per taxon; a NaN taxon gets no group and therefore an all-zero column,
    # as with the previous `tax_df[taxon] == tax_name` mask.
    ids_by_taxon = {
        name: ids.values for name, ids in tax_df.groupby(taxon, sort=False)["ID"]
    }

    columns = {"mapping": ["correct", "incorrect", "not_aligned", "multimapped"]}
    tax_names = tax_df[taxon].unique()
    for tax_name in tqdm(tax_names, desc=f"Checking IDs by {taxon}", bar_format=TQDM_BAR_FORMAT):
        correct = incorrect = not_aligned = multimapped = 0

        for virus_id in ids_by_taxon.get(tax_name, ()):
            mapped = genes_by_virus.get(virus_id)
            if mapped is None:
                not_aligned += 1
            elif len(mapped) != 1:
                # NOTE(review): rows with no positive count also land here; presumably
                # reads whose equivalence class spans several genes, which
                # `bustools count` drops without --em — confirm.
                multimapped += 1
            else:
                mapped_id = mapped[0]
                if mapped_id == virus_id:
                    correct += 1
                elif mapped_id not in taxon_by_id:
                    # Kept from the original try/except IndexError behaviour.
                    not_aligned += 1
                elif taxon_by_id[mapped_id] == tax_name:
                    correct += 1
                else:
                    incorrect += 1

        columns[tax_name] = [correct, incorrect, not_aligned, multimapped]

    # Build the frame once; inserting thousands of columns one by one fragments it.
    pd.DataFrame(columns).to_csv(f"{out_folder}/{taxon}_mapping.csv", index=False)

In [None]:
# Writes {out_folder}/phylum_mapping.csv (one count column per phylum)
get_acc_fractions("phylum")

In [None]:
# Writes {out_folder}/class_mapping.csv
get_acc_fractions("class")

In [None]:
# Writes {out_folder}/order_mapping.csv
get_acc_fractions("order")

In [None]:
# Writes {out_folder}/family_mapping.csv
get_acc_fractions("family")

In [None]:
# Writes {out_folder}/genus_mapping.csv
get_acc_fractions("genus")

In [None]:
# Writes {out_folder}/species_mapping.csv
get_acc_fractions("species")

#### Plot:

In [None]:
def fractions(df):
    """Turn a mapping-count table into per-taxon counts, totals and fractions.

    Parameters
    ----------
    df
      DataFrame with a "mapping" column of category names and one count
      column per taxon (the layout written by get_acc_fractions).

    Returns
    -------
    DataFrame indexed by taxon with one column per mapping category, a
    "total" column (row sum of the category counts) and a
    "<category>_fraction" column per category (count / total).
    """
    categories = df["mapping"].values

    # Rows become taxa, columns become mapping categories.
    per_taxon = df.set_index("mapping").transpose()
    per_taxon["total"] = per_taxon.sum(axis=1).values

    for category in categories:
        share = per_taxon[category] / per_taxon["total"]
        per_taxon[f"{category}_fraction"] = share.values

    return per_taxon

In [None]:
# Stacked bars of the fraction of sequences in each mapping category, one panel
# per taxonomic rank; genus and species share the last panel.
fig, axs = plt.subplots(figsize=(100, 20), ncols=5)

fontsize = 50
subplt_spacing = 0.04

colors = ["#3e8938", "red", "grey", "lightgrey"]
alphas = [0.75, 1, 1, 1]
m_types = ["correct", "incorrect", "not_aligned", "multimapped"]
legend_labels = ["Correct", "Incorrect", "Not aligned", "Multimapped"]


def _tax_labels(frac_df):
    """Return x labels for the rows of `frac_df`, showing the "." placeholder taxon as "Unknown"."""
    # Compare whole names: str.replace would also rewrite dots inside real taxon names.
    return ["Unknown" if name == "." else name for name in frac_df.index.values]


def _plot_stacked_fractions(ax, frac_df, labels, width):
    """Draw one stacked bar per row of `frac_df`, one layer per mapping type in `m_types`."""
    bottom = 0
    for alpha, color, mt in zip(alphas, colors, m_types):
        y = frac_df[f"{mt}_fraction"].values
        ax.bar(labels, y, width=width, bottom=bottom, color=color, alpha=alpha, label=mt)
        bottom = bottom + y


def _annotate_totals(ax, frac_df, size, **text_kwargs):
    """Write each bar's total number of sequences just above the top of the panel."""
    for index, total in enumerate(frac_df["total"].values):
        ax.text(x=index, y=1.01, s="{:,}".format(total), size=size, **text_kwargs)


## Phylum, class and order: labelled bars, font shrinking as the number of taxa grows
for ax, taxon, size in zip(axs[:3], ["phylum", "class", "order"], [fontsize, fontsize - 15, fontsize - 25]):
    frac_df = fractions(pd.read_csv(f"{out_folder}/{taxon}_mapping.csv"))
    x = _tax_labels(frac_df)
    _plot_stacked_fractions(ax, frac_df, x, width=0.75)
    _annotate_totals(ax, frac_df, size=size, ha="left", rotation=45)
    # Set tick positions together with labels; set_xticklabels alone warns on a categorical axis.
    ax.set_xticks(range(len(x)), x, rotation=45, ha="right")
    ax.tick_params(axis="both", labelsize=size)
    ax.margins(x=subplt_spacing, y=0)
    ax.patch.set_alpha(0)

# Only the first panel keeps a y axis so the row of panels reads as one plot
axs[0].spines[["right"]].set_visible(False)
axs[0].set_ylabel("Fraction", fontsize=fontsize + 2)
for ax in axs[1:3]:
    ax.spines[["left", "right"]].set_visible(False)
    ax.set_yticks([])

## Family: too many taxa for readable labels, so bars only
frac_df = fractions(pd.read_csv(f"{out_folder}/family_mapping.csv"))
x = _tax_labels(frac_df)
_plot_stacked_fractions(axs[3], frac_df, x, width=1)
axs[3].spines[["left", "right"]].set_visible(False)

## Genus and species
genus_df = pd.read_csv(f"{out_folder}/genus_mapping.csv")
species_df = pd.read_csv(f"{out_folder}/species_mapping.csv")

# Combine genus and species dataframe since there are too many taxa in each category to plot separately
combined_df = pd.DataFrame({
    "genus": genus_df.set_index("mapping").sum(axis=1).values,
    "species": species_df.set_index("mapping").sum(axis=1).values,
    "mapping": species_df["mapping"].values,
})

frac_df = fractions(combined_df)
x = _tax_labels(frac_df)
_plot_stacked_fractions(axs[4], frac_df, x, width=0.85)
_annotate_totals(axs[4], frac_df, size=fontsize, ha="center")
axs[4].spines[["left"]].set_visible(False)

# Add legend
axs[4].legend(bbox_to_anchor=(1.001, 1.025), loc="upper left", fontsize=fontsize, labels=legend_labels)

# Family and genus/species panels show no ticks or tick labels
for ax in axs[3:]:
    ax.xaxis.set_tick_params(labelbottom=False)
    ax.yaxis.set_tick_params(labelleft=False)
    ax.margins(x=subplt_spacing, y=0)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.patch.set_alpha(0)

# Remove space between subplots
plt.subplots_adjust(wspace=0)

# Background axes spanning all panels, used only to draw horizontal grid lines
ax3 = fig.add_subplot(111, zorder=-1)
for spine in ax3.spines.values():
    spine.set_visible(False)
ax3.set_xticks([])
ax3.tick_params(labelleft=False, labelbottom=False, left=False, right=False)
# get_shared_y_axes() returns an immutable view in Matplotlib >= 3.8, so `.join(...)`
# raises AttributeError; Axes.sharey gives the grid axes the first panel's y-limits
# and tick positions instead.
ax3.sharey(axs[0])
ax3.grid(color="lightgrey", ls="--", lw=1)

plt.savefig("tax_assignment_validation.png", dpi=300, bbox_inches="tight", transparent=True)

# plt.show() renders with the inline backend; Figure.show() may only warn there
plt.show()

Get the overall percentages at the species level:

In [None]:
# Collapse the per-species counts into a single row of totals per mapping category
species_df = (
    pd.read_csv(f"{out_folder}/species_mapping.csv")
    .set_index("mapping")
    .sum(axis=1)
    .to_frame()
    .transpose()
)
species_df

In [None]:
# Overall number of species-level sequences (sum over all mapping categories)
total = species_df.sum(axis=1).values[0]

# Percentage of all sequences that fall into each mapping category
for column, count in species_df.iloc[0].items():
    print(f"% {column}: {(count / total) * 100}")

print(total)