<a href="https://colab.research.google.com/github/pachterlab/LSCHWCP_2023/blob/main/Notebooks/Supp_Fig_1/remove_targets_from_ref.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Remove target virus species, genus and family from reference index prior to performing alignment

In [None]:
!pip install -q kb_python biopython

In [None]:
from Bio import SeqIO
import pandas as pd
import numpy as np
from itertools import product
import kb_python.utils as kb_utils
%config InlineBackend.figure_format='retina'
import matplotlib.pyplot as plt
from tqdm import tqdm
# Progress-bar layout string (placeholder syntax as used by tqdm's `bar_format`);
# not referenced by the cells visible in this notebook.
TQDM_BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]"

def nd(arr):
    """
    Flatten an array-like (e.g. a numpy matrix) into a 1-D numpy ndarray.
    """
    as_array = np.asarray(arr)
    return np.ravel(as_array)

In [None]:
# Install kallisto from source
!git clone -q https://github.com/pachterlab/kallisto.git
!cd kallisto && mkdir build && cd build && cmake .. && make

# Install bustools from source
!git clone -q https://github.com/BUStools/bustools.git
!cd bustools && mkdir build && cd build && cmake .. && make

# Define paths to kallisto and bustools binaries
kallisto = "/content/kallisto/build/src/kallisto"
bustools = "/content/bustools/build/src/bustools"

In [None]:
# Download the customized transcripts to gene mapping
!wget https://raw.githubusercontent.com/pachterlab/LSCHWCP_2023/main/PalmDB/palmdb_clustered_t2g.txt
# Download the RdRP amino acid sequences
!wget https://raw.githubusercontent.com/pachterlab/LSCHWCP_2023/main/PalmDB/palmdb_rdrp_seqs.fa

# Local filenames of the two downloads above (wget saves them into the working directory):
# virus_fasta is the source for the filtered reference fastas, virus_t2g is passed to `bustools count -g`
virus_fasta = "palmdb_rdrp_seqs.fa"
virus_t2g = "palmdb_clustered_t2g.txt"

In [None]:
# Load virus ID to taxonomy mapping
# (the cells below filter this table on its ID, rep_ID and taxonomy-rank columns such as
# species, genus and family)
!wget https://raw.githubusercontent.com/pachterlab/LSCHWCP_2023/main/PalmDB/ID_to_taxonomy_mapping.csv
taxmap = pd.read_csv("ID_to_taxonomy_mapping.csv")
taxmap

In [None]:
# Get fasta file containing ZEBOV RdRP sequences
# (this is the query file aligned against each reduced reference index below)
!wget https://raw.githubusercontent.com/pachterlab/LSCHWCP_2023/main/Notebooks/Figure_2/Figure_2c/SRR12698539_2_extracted_u10.fa
nn_fasta = "SRR12698539_2_extracted_u10.fa"

In [None]:
# Number of threads used for alignments (passed to kallisto via -t)
threads = 2

In [None]:
out_folder = "remove_targets_from_ref"

____
# Remove all sequences linked to the genus *Ebolavirus* from the reference index

Create new fasta excluding all sequences associated with the Ebolavirus genus:

In [None]:
# IDs of all sequences whose genus annotation contains "Ebolavirus"
targets_all = taxmap.loc[taxmap["genus"].str.contains("Ebolavirus"), "ID"].values

In [None]:
%%time
records = SeqIO.parse(virus_fasta, "fasta")

test_fasta = f"{out_folder}/uniques_nodup_targets_removed.fa"
with open(test_fasta, "w") as new_fasta:
    for record in records:
        if record.id not in targets_all:
            new_fasta.write(">" + record.id + "\n")
            new_fasta.write(str(record.seq) + "\n")

Generate reference index using this new fasta:

In [None]:
# Index file written by `kallisto index` (-i) for the genus-filtered fasta
virus_index = f"{out_folder}/uniques_nodup_targets_removed.idx"

# Build the index from the filtered fasta using `threads` threads.
# NOTE(review): --aa presumably tells kallisto the reference holds amino acid sequences — confirm in the kallisto docs.
!$kallisto index \
    --aa \
    -t $threads \
    -i $virus_index \
    $test_fasta

Align ZEBOV RdRP nucleotide sequences to the new index:

In [None]:
# Align the sequences in nn_fasta against the genus-filtered index;
# output is written to {out_folder}/kallisto
!$kallisto bus \
        --aa \
        -i $virus_index \
        -o $out_folder/kallisto \
        -x bulk \
        -t $threads \
        $nn_fasta

# Sort the BUS file produced by the alignment
!$bustools sort -o $out_folder/kallisto/output_sorted.bus $out_folder/kallisto/output.bus

# Build the count matrix from the sorted BUS file, using the transcripts-to-gene
# mapping (-g) plus the equivalence-class (-e) and transcript (-t) files written by kallisto bus.
# The results are read back from {out_folder}/kallisto/bustools_count/ further below.
!$bustools count \
    --genecounts \
    --cm \
    -o $out_folder/kallisto/bustools_count/ \
    -g $virus_t2g \
    -e $out_folder/kallisto/matrix.ec \
    -t $out_folder/kallisto/transcripts.txt \
    $out_folder/kallisto/output_sorted.bus

____
# Remove all *Ebolavirus* species sequences from the reference index

In [None]:
# IDs of all sequences annotated with an Ebolavirus species.
# Match case-insensitively: binomial species names are commonly written with a
# lower-case genus part (e.g. "Zaire ebolavirus"), which a case-sensitive search for
# "Ebolavirus" would silently miss, leaving the reference unfiltered.
# NOTE(review): confirm the spelling used in the species column of the mapping file.
# na=False keeps rows without a species annotation out of the target set instead of raising.
targets_all = taxmap[taxmap["species"].str.contains("Ebolavirus", case=False, na=False)]["ID"].values

In [None]:
%%time
records = SeqIO.parse(virus_fasta, "fasta")

test_fasta = f"{out_folder}/uniques_nodup_targets_removed_2.fa"
with open(test_fasta, "w") as new_fasta:
    for record in records:
        if record.id not in targets_all:
            new_fasta.write(">" + record.id + "\n")
            new_fasta.write(str(record.seq) + "\n")

Generate index with new fasta:

In [None]:
# Index file written by `kallisto index` (-i) for the species-filtered fasta
virus_index = f"{out_folder}/uniques_nodup_targets_removed_2.idx"

# Build the index from the filtered fasta using `threads` threads.
# NOTE(review): --aa presumably tells kallisto the reference holds amino acid sequences — confirm in the kallisto docs.
!$kallisto index \
    -t $threads \
    --aa \
    -i $virus_index \
    $test_fasta

Align the ZEBOV RdRP nucleotide sequences to the new index:

In [None]:
# Align the sequences in nn_fasta against the species-filtered index;
# output is written to {out_folder}/kallisto2
!$kallisto bus \
        -i $virus_index \
        -o $out_folder/kallisto2 \
        --aa \
        -x bulk \
        -t $threads \
        $nn_fasta

# Sort the BUS file produced by the alignment
!$bustools sort -o $out_folder/kallisto2/output_sorted.bus $out_folder/kallisto2/output.bus

# Build the count matrix from the sorted BUS file, using the transcripts-to-gene
# mapping (-g) plus the equivalence-class (-e) and transcript (-t) files written by kallisto bus.
# The results are read back from {out_folder}/kallisto2/bustools_count/ further below.
!$bustools count \
    --genecounts \
    --cm \
    -o $out_folder/kallisto2/bustools_count/ \
    -g $virus_t2g \
    -e $out_folder/kallisto2/matrix.ec \
    -t $out_folder/kallisto2/transcripts.txt \
    $out_folder/kallisto2/output_sorted.bus

____
# Remove sequences linked to the family *Filoviridae* from the reference index

In [None]:
# IDs of all sequences whose family annotation contains "Filoviridae"
targets_all = taxmap.loc[taxmap["family"].str.contains("Filoviridae"), "ID"].values

In [None]:
%%time
records = SeqIO.parse(virus_fasta, "fasta")

test_fasta = f"{out_folder}/uniques_nodup_targets_removed_3.fa"
with open(test_fasta, "w") as new_fasta:
    for record in records:
        if record.id not in targets_all:
            new_fasta.write(">" + record.id + "\n")
            new_fasta.write(str(record.seq) + "\n")

Generate index with new fasta:

In [None]:
# Index file written by `kallisto index` (-i) for the family-filtered fasta
virus_index = f"{out_folder}/uniques_nodup_targets_removed_3.idx"

# Build the index from the filtered fasta using `threads` threads.
# NOTE(review): --aa presumably tells kallisto the reference holds amino acid sequences — confirm in the kallisto docs.
!$kallisto index \
    -t $threads \
    --aa \
    -i $virus_index \
    $test_fasta

Align the ZEBOV RdRP nucleotide sequences to the new index:

In [None]:
# Align the sequences in nn_fasta against the family-filtered index;
# output is written to {out_folder}/kallisto3
!$kallisto bus \
        -i $virus_index \
        -o $out_folder/kallisto3 \
        --aa \
        -x bulk \
        -t $threads \
        $nn_fasta

# Sort the BUS file produced by the alignment
!$bustools sort -o $out_folder/kallisto3/output_sorted.bus $out_folder/kallisto3/output.bus

# Build the count matrix from the sorted BUS file, using the transcripts-to-gene
# mapping (-g) plus the equivalence-class (-e) and transcript (-t) files written by kallisto bus.
# The results are read back from {out_folder}/kallisto3/bustools_count/ further below.
!$bustools count \
    --genecounts \
    --cm \
    -o $out_folder/kallisto3/bustools_count/ \
    -g $virus_t2g \
    -e $out_folder/kallisto3/matrix.ec \
    -t $out_folder/kallisto3/transcripts.txt \
    $out_folder/kallisto3/output_sorted.bus

___
# Load results and plot which taxonomies the sequences were aligned to

Load alignment data where all sequences linked to the *Ebolavirus* genus were removed from the reference:

In [None]:
# Filepath to counts
X = f"{out_folder}/kallisto/bustools_count/output.mtx"
# Filepath to gene metadata (passed as the var annotation)
var_path = f"{out_folder}/kallisto/bustools_count/output.genes.txt"
# Filepath to barcode metadata (passed as the obs annotation)
obs_path = f"{out_folder}/kallisto/bustools_count/output.barcodes.txt"

# Create AnnData object for the alignment against the genus-filtered reference
adata = kb_utils.import_matrix_as_anndata(X, obs_path, var_path)
adata

In [None]:
# IDs in adata.var whose counts, summed over all observations, are above zero
ids_seen = adata.var.index[nd(adata.X.sum(axis=0) > 0)].values
ids_seen

In [None]:
# Taxonomy rows whose representative ID received counts
taxmap.loc[taxmap["rep_ID"].isin(ids_seen)]

Load data where all sequences linked to Ebolavirus species were removed:

In [None]:
# Filepath to counts
X = f"{out_folder}/kallisto2/bustools_count/output.mtx"
# Filepath to gene metadata (passed as the var annotation)
var_path = f"{out_folder}/kallisto2/bustools_count/output.genes.txt"
# Filepath to barcode metadata (passed as the obs annotation)
obs_path = f"{out_folder}/kallisto2/bustools_count/output.barcodes.txt"

# Create AnnData object for the alignment against the species-filtered reference
adata2 = kb_utils.import_matrix_as_anndata(X, obs_path, var_path)
adata2

In [None]:
# IDs in adata2.var whose counts, summed over all observations, are above zero
ids_seen2 = adata2.var.index[nd(adata2.X.sum(axis=0) > 0)].values
ids_seen2

In [None]:
# Taxonomy rows whose representative ID received counts
taxmap.loc[taxmap["rep_ID"].isin(ids_seen2)]

Load data where all sequences linked to the Filoviridae family were removed:

In [None]:
# Filepath to counts
X = f"{out_folder}/kallisto3/bustools_count/output.mtx"
# Filepath to gene metadata (passed as the var annotation)
var_path = f"{out_folder}/kallisto3/bustools_count/output.genes.txt"
# Filepath to barcode metadata (passed as the obs annotation)
obs_path = f"{out_folder}/kallisto3/bustools_count/output.barcodes.txt"

# Create AnnData object for the alignment against the family-filtered reference
adata3 = kb_utils.import_matrix_as_anndata(X, obs_path, var_path)
adata3

In [None]:
# IDs in adata3.var whose counts, summed over all observations, are above zero
ids_seen3 = adata3.var.index[nd(adata3.X.sum(axis=0) > 0)].values
ids_seen3

In [None]:
# Taxonomy rows whose representative ID received counts
taxmap.loc[taxmap["rep_ID"].isin(ids_seen3)]

# Plot bar plots

In [None]:
# Total number of sequences; drawn as the dashed reference line in every plot below and
# used to size the hatched remainder (total - correct) on the phylum plot.
# NOTE(review): hard-coded — presumably the number of records in nn_fasta; confirm.
total = 676

In [None]:
# Bar fill colours, indexed by bar position (0, 1, 2) in the plotting functions below
colors = ["#003049", "#4b8eb3", "#8fc0de"]

In [None]:
def barplot(tax_level, expected_tax):
    """
    Plot the counts assigned to the expected taxonomy for the three reduced references.

    One bar is drawn per alignment result, left to right:
    `adata2` (Ebolavirus species removed), `adata` (Ebolavirus genus removed) and
    `adata3` (Filoviridae family removed). A dashed line marks `total`.
    For `tax_level == "phylum"` the remainder up to `total` is added as a hatched bar,
    the count is written inside the bar and the y-axis is labelled; for every other
    level the count is written above the bar (only if non-zero) and the y tick labels
    are blanked.

    Relies on the notebook-level names `adata`, `adata2`, `adata3`, `taxmap`,
    `colors` and `total`.

    Args:
        tax_level: Column of `taxmap` to evaluate (e.g. "phylum", "genus").
        expected_tax: Value in that column that counts as on-target.

    Side effects:
        Saves `targets_removed_ontarget_{tax_level}.png` to the working directory
        and shows the figure.
    """
    fig, ax = plt.subplots(figsize=(5, 7))
    fontsize = 18
    is_phylum = tax_level == "phylum"

    # Representative IDs annotated with the expected taxonomy (same for all three bars)
    on_target_ids = taxmap[taxmap[tax_level] == expected_tax]["rep_ID"].unique()

    # (alignment result, colour of the in-bar label on the phylum plot; None = default colour).
    # The first bar uses the darkest fill, hence the white label.
    panels = [(adata2, "white"), (adata, None), (adata3, None)]
    for x_pos, (result, inner_label_color) in enumerate(panels):
        correct = result[:, result.var.index.isin(on_target_ids)].X.sum()
        ax.bar(x_pos, correct, color=colors[x_pos], edgecolor="black")

        if is_phylum:
            # Hatched remainder: sequences not assigned to the expected phylum
            ax.bar(x_pos, total - correct, bottom=correct, color="white", edgecolor="black", hatch="/")
            label_kwargs = {"color": inner_label_color} if inner_label_color else {}
            # Count label centred vertically inside the bar
            ax.text(x_pos, correct / 2, str(int(correct)), fontsize=fontsize, ha="center", **label_kwargs)
        elif correct > 0:
            # Count label just above the bar
            ax.text(x_pos, correct + 10, str(int(correct)), fontsize=fontsize, ha="center")

    ax.axhline(total, color="black", ls="--")

    ax.set_xticks([])

    if is_phylum:
        ax.set_ylabel("Counts", fontsize=fontsize)
    else:
        # Keep the tick positions (and grid lines) but hide the labels
        ax.set_yticks([0,100,200,300,400,500,600,700], ["","","","","","","",""])

    ax.spines.right.set_visible(False)
    ax.spines.top.set_visible(False)

    ax.tick_params(axis="both", labelsize=fontsize)

    ax.grid(True, which="both", color="lightgray", ls="--", lw=1)
    ax.set_axisbelow(True)

    fig.savefig(f"targets_removed_ontarget_{tax_level}.png", dpi=300, bbox_inches="tight")

    fig.show()

In [None]:
# On-target counts at phylum level
barplot(tax_level="phylum", expected_tax="Negarnaviricota")

In [None]:
# On-target counts at class level
barplot(tax_level="class", expected_tax="Monjiviricetes")

In [None]:
# On-target counts at order level
barplot(tax_level="order", expected_tax="Mononegavirales")

In [None]:
# On-target counts at family level
barplot(tax_level="family", expected_tax="Filoviridae")

In [None]:
# On-target counts at genus level
barplot(tax_level="genus", expected_tax="Ebolavirus")

In [None]:
def barplot_incorrects(tax_level, expected_tax):
    """
    Plot the counts assigned to a given set of taxa for the three reduced references.

    One bar is drawn per alignment result, left to right:
    `adata2` (Ebolavirus species removed), `adata` (Ebolavirus genus removed) and
    `adata3` (Filoviridae family removed). Non-zero counts are written above the
    bar and a dashed line marks `total`.

    Relies on the notebook-level names `adata`, `adata2`, `adata3`, `taxmap`,
    `colors` and `total`.

    Args:
        tax_level: Column of `taxmap` to evaluate (e.g. "species", "genus").
        expected_tax: List of values in that column to count; the first entry
            is also used in the output filename.

    Side effects:
        Creates the `figures/` folder if needed, saves
        `figures/targets_removed_offtarget_{tax_level}_{expected_tax[0]}.png`
        and shows the figure.
    """
    import os

    fig, ax = plt.subplots(figsize=(5, 7))
    fontsize = 18

    # Representative IDs annotated with any of the requested taxa (same for all three bars)
    selected_ids = taxmap[taxmap[tax_level].isin(expected_tax)]["rep_ID"].unique()

    for x_pos, result in enumerate((adata2, adata, adata3)):
        correct = result[:, result.var.index.isin(selected_ids)].X.sum()
        ax.bar(x_pos, correct, color=colors[x_pos], edgecolor="black")
        if correct > 0:
            ax.text(x_pos, correct + 10, str(int(correct)), fontsize=fontsize, ha="center")

    ax.axhline(total, color="black", ls="--")

    ax.set_xticks([])

    if tax_level == "phylum":
        ax.set_ylabel("Counts", fontsize=fontsize)
    else:
        # Keep the tick positions (and grid lines) but hide the labels
        ax.set_yticks([0,100,200,300,400,500,600,700], ["","","","","","","",""])

    ax.spines.right.set_visible(False)
    ax.spines.top.set_visible(False)

    ax.tick_params(axis="both", labelsize=fontsize)

    ax.grid(True, which="both", color="lightgray", ls="--", lw=1)
    ax.set_axisbelow(True)

    # savefig does not create missing folders, so make sure the target folder exists
    os.makedirs("figures", exist_ok=True)
    fig.savefig(f"figures/targets_removed_offtarget_{tax_level}_{expected_tax[0]}.png", dpi=300, bbox_inches="tight")

    fig.show()

In [None]:
# Counts assigned to the species Marburg marburgvirus
barplot_incorrects(tax_level="species", expected_tax=["Marburg marburgvirus"])

In [None]:
# Counts assigned to the genus Marburgvirus
barplot_incorrects(tax_level="genus", expected_tax=["Marburgvirus"])

In [None]:
# Counts assigned to the family Paramyxoviridae
barplot_incorrects(tax_level="family", expected_tax=["Paramyxoviridae"])

In [None]:
def barplot_incorrects2(tax_level1, expected_tax1, tax_level2, expected_tax2):
    """
    Plot the counts assigned to taxa that satisfy two taxonomy conditions at once.

    A `taxmap` row is selected when its `tax_level1` value is in `expected_tax1`
    AND its `tax_level2` value is in `expected_tax2`. One bar is drawn per
    alignment result, left to right: `adata2` (Ebolavirus species removed),
    `adata` (Ebolavirus genus removed) and `adata3` (Filoviridae family removed).
    Non-zero counts are written above the bar and a dashed line marks `total`.

    Relies on the notebook-level names `adata`, `adata2`, `adata3`, `taxmap`,
    `colors` and `total`.

    Args:
        tax_level1: First `taxmap` column to filter on.
        expected_tax1: List of accepted values for `tax_level1`.
        tax_level2: Second `taxmap` column to filter on.
        expected_tax2: List of accepted values for `tax_level2`.

    Side effects:
        Creates the `figures/` folder if needed, saves the figure and shows it.
        The filename contains the first entry of each expected-taxa list in
        addition to the two levels, so calls that share the same levels but
        ask for different taxa no longer overwrite each other's output.
    """
    import os

    fig, ax = plt.subplots(figsize=(5, 7))
    fontsize = 18

    # Representative IDs that satisfy both taxonomy conditions (same for all three bars)
    matches_both = (taxmap[tax_level1].isin(expected_tax1)) & (taxmap[tax_level2].isin(expected_tax2))
    selected_ids = taxmap[matches_both]["rep_ID"].unique()

    for x_pos, result in enumerate((adata2, adata, adata3)):
        correct = result[:, result.var.index.isin(selected_ids)].X.sum()
        ax.bar(x_pos, correct, color=colors[x_pos], edgecolor="black")
        if correct > 0:
            ax.text(x_pos, correct + 10, str(int(correct)), fontsize=fontsize, ha="center")

    ax.axhline(total, color="black", ls="--")

    ax.set_xticks([])

    # Keep the tick positions (and grid lines) but hide the labels
    ax.set_yticks([0,100,200,300,400,500,600,700], ["","","","","","","",""])

    ax.spines.right.set_visible(False)
    ax.spines.top.set_visible(False)

    ax.tick_params(axis="both", labelsize=fontsize)

    ax.grid(True, which="both", color="lightgray", ls="--", lw=1)
    ax.set_axisbelow(True)

    # savefig does not create missing folders, so make sure the target folder exists
    os.makedirs("figures", exist_ok=True)
    fig.savefig(
        f"figures/targets_removed_offtarget_{tax_level1}_{expected_tax1[0]}_{tax_level2}_{expected_tax2[0]}.png",
        dpi=300,
        bbox_inches="tight",
    )

    fig.show()

In [None]:
# Counts assigned to genus "Ebolavirus" entries whose species value is "."
barplot_incorrects2(tax_level1="genus", expected_tax1=["Ebolavirus"], tax_level2="species", expected_tax2=["."])

In [None]:
# Counts assigned to entries whose genus and species values are both "."
barplot_incorrects2(tax_level1="genus", expected_tax1=["."], tax_level2="species", expected_tax2=["."])