<a href="https://colab.research.google.com/github/pachterlab/LSCHWCP_2023/blob/main/Notebooks/Figure_3/Figure_3a/4_macaque_ZEBOV_validation_seqwell.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q kb_python
import kb_python.utils as kb_utils
import numpy as np
from scipy import stats
import anndata
import pandas as pd
import json
import os
import glob
import matplotlib.pyplot as plt
import matplotlib as mpl
%config InlineBackend.figure_format='retina'

def nd(arr):
    """
    Flatten an array-like (e.g. a numpy matrix) into a 1-D nd array.
    """
    as_array = np.asarray(arr)
    return as_array.reshape(-1)

### Load data
The count matrix was generated [here](https://github.com/pachterlab/LSCHWCP_2023/tree/main/Notebooks/align_macaque_PBMC_data/4_virus_dlist_cdna_dna). Here, we load the count matrix that was generated without cell barcode correction: this validation does not require single-cell resolution, and skipping the correction lets us retain all cells. (With barcode correction, the results would depend on the threshold that was set when generating the barcode onlist.)

In [None]:
# Load alignment results from Caltech Data
# !wget
!unzip virus_dlist_cdna_dna_alignment_results.zip

In [None]:
# Filepath to counts (matrix written by bustools count, without barcode correction)
X = "virus_dlist_cdna_dna_alignment_results/no_barcode_correction/bustools_count/output.mtx"
# Filepath to gene (feature) metadata -> becomes adata.var
var_path = "virus_dlist_cdna_dna_alignment_results/no_barcode_correction/bustools_count/output.genes.txt"
# Filepath to barcode metadata -> becomes adata.obs
obs_path = "virus_dlist_cdna_dna_alignment_results/no_barcode_correction/bustools_count/output.barcodes.txt"

# Create AnnData object (note the argument order: matrix, barcodes, genes)
palmdb_adata = kb_utils.import_matrix_as_anndata(X, obs_path, var_path)
# Display a summary of the loaded object
palmdb_adata

In [None]:
# Add sample barcodes to adata
# The prefix file is read line by line and assigned positionally to obs below,
# so it must contain one line per row of obs, in the same order.
# `with` guarantees the file handle is closed even if reading fails.
with open("virus_dlist_cdna_dna_alignment_results/no_barcode_correction/bustools_count/output.barcodes.prefix.txt") as sb_file:
    sample_barcodes = sb_file.read().splitlines()

# Only use last 16 bases because sample barcode is always 16 bases in length
palmdb_adata.obs["sample_barcode"] = [bc[-16:] for bc in sample_barcodes]

In [None]:
# Create barcode to sample lookup
# Both files are read line by line and paired positionally: line i of
# matrix.sample.barcodes is matched with line i of matrix.cells.
# `with` guarantees each file handle is closed even if reading fails.
with open("virus_dlist_cdna_dna_alignment_results/matrix.sample.barcodes") as b_file:
    barcodes = b_file.read().splitlines()

with open("virus_dlist_cdna_dna_alignment_results/matrix.cells") as s_file:
    samples = s_file.read().splitlines()

# Build the lookup table in one step (raises ValueError if the two files
# do not have the same number of lines)
bc2sample_df = pd.DataFrame({"sample_barcode": barcodes, "srr": samples})
bc2sample_df

In [None]:
# Keep the cell barcode as a regular column: merging on a column discards the
# index of the left frame, so it is restored from this column afterwards
palmdb_adata.obs["barcode"] = palmdb_adata.obs.index.values
palmdb_adata.obs = palmdb_adata.obs.merge(
    bc2sample_df,
    on="sample_barcode",
    how="left",
    # Each sample barcode must map to exactly one sample; a duplicated key in
    # the lookup would otherwise silently duplicate cell rows
    validate="many_to_one",
).set_index("barcode", drop=False)
palmdb_adata.obs

### Add timepoints from SRR metadata:

In [None]:
# Load library metadata
!wget https://raw.githubusercontent.com/pachterlab/LSCHWCP_2023/main/Notebooks/Supp_Fig_3/Supp_Fig_3abc/PRJNA665227_SraRunTable.txt
srr_meta = pd.read_csv("PRJNA665227_SraRunTable.txt", sep=",")

# Only keep relevent data
srr_meta["mdck_spike_in"] = srr_meta["mdck_spike_in"].fillna(False).values
srr_meta = srr_meta[["Run", "donor_animal", "Experiment", "mdck_spike_in", "hours_post_innoculation", "day_post_infection"]]
srr_meta

In [None]:
# Attach the per-run metadata to every cell via its SRR accession.
# The left merge keeps all cells; the barcode index is restored afterwards
# because merging on columns discards the index of the left frame.
palmdb_adata.obs = (
    palmdb_adata.obs
    .merge(
        srr_meta,
        left_on="srr",
        right_on="Run",
        how="left",
        # Each run accession must appear once in the metadata; a duplicated
        # "Run" would otherwise silently duplicate cell rows
        validate="many_to_one",
    )
    .set_index("barcode", drop=False)
)

palmdb_adata.obs

Create clean dpi column:

In [None]:
# Join day_post_infection and hours_post_innoculation columns:
# the day value is used where present, the hour value otherwise
dpi_days = palmdb_adata.obs["day_post_infection"]
dpi_hours = palmdb_adata.obs["hours_post_innoculation"]
palmdb_adata.obs["dpi"] = dpi_days.fillna(dpi_hours).astype(int)

# Add h/d suffix to denote hours/days
# NOTE(review): the suffix is derived from the hours column alone, while the
# value above prefers the day column. If a row had both columns set, value and
# unit would disagree - confirm the two columns are mutually exclusive.
palmdb_adata.obs["dpi_accessions"] = np.where(dpi_hours.notna(), "h", "d")
palmdb_adata.obs["dpi_clean"] = (
    palmdb_adata.obs["dpi"].astype(str)
    + palmdb_adata.obs["dpi_accessions"].astype(str)
)

palmdb_adata.obs

In [None]:
# Merge 7d and 8d timepoints into a single "7d/8d" label;
# every other timepoint label is kept unchanged
new_tps = [
    "7d/8d" if tp == "7d" or tp == "8d" else tp
    for tp in palmdb_adata.obs["dpi_clean"].values
]

palmdb_adata.obs["dpi_clean_merged"] = new_tps

___
### Correlation between qPCR viral load and kb EBOV detection

In [None]:
# Load virus ID to taxonomy mapping
# (the CSV is downloaded into the working directory and read with default settings;
# its "species" column is used further below to look up Zaire ebolavirus entries)
!wget https://raw.githubusercontent.com/pachterlab/LSCHWCP_2023/main/PalmDB/ID_to_taxonomy_mapping.csv
phylogeny_data = pd.read_csv("ID_to_taxonomy_mapping.csv")
# Display the mapping table
phylogeny_data

In [None]:
# Raw viral load data provided by authors
!wget https://raw.githubusercontent.com/pachterlab/LSCHWCP_2023/main/Notebooks/Figure_3/Figure_3a/macaque_ZEBOV_validation_viral_loads.tsv
viralload = pd.read_csv("macaque_ZEBOV_validation_viral_loads.tsv", sep="\t")

# Change naming of animal IDs to match naming in scseq data
viralload = viralload.replace("NHP01", "NHP1")
viralload = viralload.replace("NHP02", "NHP2")
viralload = viralload.set_index("Unnamed: 0")

viralload = viralload.rename(columns={'BL':0})
viralload.head()

In [None]:
# Replace 'UND' entries with 1
# NOTE: despite the "log" in the names below, no log transform is applied in this cell
viralload_log = viralload.replace('UND', 1)

# Reshape to long format: one row per (day, animal) pair, dropping missing measurements
viralload_log_unstack = viralload_log.unstack().reset_index()
viralload_log_unstack.columns = ['Day', 'Animal', 'log_viral_load']
viralload_log_unstack = viralload_log_unstack.dropna()
viralload_log_unstack['Day'] = viralload_log_unstack['Day'].astype(int)

# Per-day summary statistics across animals
# NOTE: the value stored under the "mean" name is actually the median
vl_by_day = viralload_log_unstack.groupby('Day')['log_viral_load']
vl_perday_mean = vl_by_day.median()
vl_perday_min = vl_by_day.min()
vl_perday_max = vl_by_day.max()

vl_day_sumary_stats = pd.concat(
    [vl_perday_mean, vl_perday_min, vl_perday_max],
    axis=1,
    keys=['mean', 'min', 'max'],
)
vl_day_sumary_stats

In [None]:
# Drop rows with samples not contained in scseq data (days 1 and 2)
# NOTE: this reassigns the same name, so running the cell twice raises a KeyError
vl_day_sumary_stats = vl_day_sumary_stats.drop(index=[1, 2])

In [None]:
# Look up the virus IDs annotated as Zaire ebolavirus.
# na=False treats entries without a species annotation as non-matching (a NaN in
# the mask would otherwise make the boolean indexing fail); regex=False matches
# the name literally.
phylogeny_data[phylogeny_data["species"].str.contains("Zaire ebolavirus", na=False, regex=False)]

In [None]:
%%time
# Plot animals individually
samples = ["0d", "3d", "4d", "5d", "6d", "7d", "8d"]
labels = ["0", "3", "4", "5", "6", "7", "8"]

virus_ids = ['u10']

cidx = []
kb_counts = []
kb_counts_norm = []
vloads = []
for i, sample in enumerate(samples):
    for animal in viralload_log_unstack[viralload_log_unstack["Day"]==int(sample.split("d")[0])]["Animal"].values:
        if animal in np.unique(palmdb_adata.obs[palmdb_adata.obs["dpi_clean"] == sample]["donor_animal"].values):
            # Only take into account timpeoints/animals with at least 100k (unfiltered) cells
            num_cells = len(palmdb_adata.obs[(palmdb_adata.obs["dpi_clean"] == sample) & (palmdb_adata.obs["donor_animal"] == animal)])
            if num_cells > 100000:
                count = palmdb_adata[(palmdb_adata.obs["dpi_clean"] == sample) & (palmdb_adata.obs["donor_animal"] == animal), palmdb_adata.var.index.isin(virus_ids)].X.sum()
                kb_counts.append(count)
                kb_counts_norm.append(count / num_cells)

                vloads.append(viralload.replace('UND', 0).loc[animal].values[int(sample.split("d")[0])])
                cidx.append(i)

                print(sample, " ", animal, " ", count, " ", num_cells)

In [None]:
# Scatter plot of kallisto raw counts (x) against RT-qPCR viral load (y),
# one point per animal/timepoint, coloured by timepoint
fig, ax = plt.subplots(figsize=(7,7))

fontsize=16

# Pin the colour scale to the full range of timepoint indices so that colour i
# always corresponds to labels[i], even if a timepoint contributed no samples
sc = ax.scatter(
    kb_counts,
    vloads,
    c=cidx,
    cmap="Blues",
    edgecolors="black",
    s=250,
    zorder=2,
    vmin=0,
    vmax=len(labels) - 1,
)

# Place exactly one tick per timepoint index before assigning the tick labels;
# otherwise the labels are attached to whatever ticks the automatic locator chose
cbar = fig.colorbar(sc, ax=ax, ticks=list(range(len(labels))))
cbar.ax.tick_params(labelsize=fontsize)
cbar.ax.set_yticklabels(labels)
cbar.ax.set_ylabel("Days post-infection", fontsize=fontsize, labelpad=1.1)

# symlog keeps zero values (no counts / undetected viral load) on the axes
ax.set_yscale("symlog")
ax.set_xscale("symlog")
ax.set_ylabel("RT-qPCR (Zaire ebolavirus copies/mL)", fontsize=fontsize)
ax.set_xlabel(
    "kallisto\n(raw counts for Zaire ebolavirus)",
    fontsize=fontsize,
)

# Number of plotted animal/timepoint samples (position is in data coordinates)
ax.text(18, 0.15, f"n={len(kb_counts)}", fontsize=fontsize)

# Vertical reference line at 2 raw counts
# NOTE(review): the rationale for the value 2 is not documented here
ax.axvline(2, ls="--", color="black", lw=1)

ax.tick_params(axis="both", labelsize=fontsize)
ax.set_title("PBMC samples from rhesus macaques\ninfected with Zaire ebolavirus", fontsize=fontsize)

ax.grid(True, which="both", color="lightgray", ls="--", lw=1)
ax.set_axisbelow(True)

plt.savefig("kb_vs_qPCR_scatter.png", dpi=300, bbox_inches="tight")

# plt.show() works with the inline backend; fig.show() warns on non-GUI backends
plt.show()