<a href="https://colab.research.google.com/github/pachterlab/LSCHWCP_2023/blob/main/Notebooks/align_macaque_PBMC_data/4_virus_dlist_cdna_dna/2_viral_QC_dlist_cdna_dna.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Virus count matrix quality control (QC) and virus categorization

In [None]:
!pip install -q kb_python
import kb_python.utils as kb_utils
import numpy as np
from scipy import stats
import anndata
import pandas as pd
import json
import os
import glob
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.ticker import MaxNLocator
%config InlineBackend.figure_format='retina'

def nd(arr):
    """
    Flatten any array-like (e.g. the numpy matrix returned by a sparse
    matrix reduction) into a one-dimensional numpy ndarray.
    """
    as_array = np.asarray(arr)
    return as_array.ravel()

### Load data aligned to PalmDB using translated search kb with host genomic and transcriptomic sequences masked using the D-list:

In [None]:
# Load alignment results from Caltech Data
!wget https://data.caltech.edu/records/sh33z-hrx98/files/virus_dlist_cdna_dna_alignment_results.zip?download=1
!mv virus_dlist_cdna_dna_alignment_results.zip?download=1 virus_dlist_cdna_dna_alignment_results.zip
!unzip virus_dlist_cdna_dna_alignment_results.zip

In [None]:
# Filepath to counts
X = "virus_dlist_cdna_dna_alignment_results/bustools_count/output.mtx"
# Filepath to gene metadata (the bustools output.genes.txt file)
var_path = "virus_dlist_cdna_dna_alignment_results/bustools_count/output.genes.txt"
# Filepath to barcode metadata (the bustools output.barcodes.txt file)
obs_path = "virus_dlist_cdna_dna_alignment_results/bustools_count/output.barcodes.txt"

# Create AnnData object (argument order: matrix, barcodes file, genes file)
palmdb_adata = kb_utils.import_matrix_as_anndata(X, obs_path, var_path)
palmdb_adata

In [None]:
# Add sample barcodes to adata
# The context manager closes the file even if reading raises
with open("virus_dlist_cdna_dna_alignment_results/bustools_count/output.barcodes.prefix.txt") as sb_file:
    sample_barcodes = sb_file.read().splitlines()

# Only use last 16 bases because sample barcode is always 16 bases in length
# NOTE(review): assumes the prefix file has one line per cell, in the same order as the barcodes file
palmdb_adata.obs["sample_barcode"] = [bc[-16:] for bc in sample_barcodes]

In [None]:
# Create barcode to sample lookup
# Context managers close both files even if reading raises
with open("virus_dlist_cdna_dna_alignment_results/matrix.sample.barcodes") as b_file:
    barcodes = b_file.read().splitlines()

with open("virus_dlist_cdna_dna_alignment_results/matrix.cells") as s_file:
    samples = s_file.read().splitlines()

# Line i of matrix.sample.barcodes is paired with line i of matrix.cells;
# the constructor raises if the two files do not have the same number of lines
bc2sample_df = pd.DataFrame({"sample_barcode": barcodes, "srr": samples})
bc2sample_df

In [None]:
# Keep the cell barcode as a regular column so it survives the merge (merge resets the index)
palmdb_adata.obs["barcode"] = palmdb_adata.obs.index.values
# Left merge keeps every cell; validate="many_to_one" raises if a sample barcode maps to more
# than one SRR, which would otherwise silently duplicate cells
palmdb_adata.obs = (
    palmdb_adata.obs
    .merge(bc2sample_df, on="sample_barcode", how="left", validate="many_to_one")
    .set_index("barcode", drop=False)
)
palmdb_adata.obs

### Add timepoints from SRR metadata:

In [None]:
# Download the SRA run table for PRJNA665227 and load it as a dataframe
!wget https://raw.githubusercontent.com/pachterlab/LSCHWCP_2023/main/Notebooks/Supp_Fig_3/Supp_Fig_3abc/PRJNA665227_SraRunTable.txt
srr_meta = pd.read_csv("PRJNA665227_SraRunTable.txt", sep=",")

# Only keep relevant data
# Missing mdck_spike_in entries are filled with False before the column subset is taken
srr_meta["mdck_spike_in"] = srr_meta["mdck_spike_in"].fillna(False).values
srr_meta = srr_meta[["Run", "donor_animal", "Experiment", "mdck_spike_in", "hours_post_innoculation", "day_post_infection"]]
srr_meta

In [None]:
# Attach the per-run metadata to every cell via its SRR accession.
# validate="many_to_one" raises if a Run appears more than once in srr_meta,
# which would otherwise silently duplicate cells in the left merge.
palmdb_adata.obs = (
        palmdb_adata
        .obs.merge(srr_meta, left_on="srr", right_on="Run", how="left", validate="many_to_one")
        .set_index("barcode", drop=False)
    )

palmdb_adata.obs

Create clean dpi column:

In [None]:
# Join day_post_infection and hours_post_innoculation columns
# (rows without a day value take their hours value, so "dpi" holds days OR hours; the unit is recorded below)
palmdb_adata.obs["dpi"] = palmdb_adata.obs["day_post_infection"].fillna(palmdb_adata.obs["hours_post_innoculation"]).astype(int)

# Add h/d suffix to denote hours/days: "d" where no hours value is recorded, "h" otherwise
palmdb_adata.obs["dpi_accessions"] = np.where(palmdb_adata.obs["hours_post_innoculation"].isna(), "d", "h")
# Human-readable timepoint label: number followed by its unit suffix
palmdb_adata.obs["dpi_clean"] = palmdb_adata.obs["dpi"].astype(str) + palmdb_adata.obs["dpi_accessions"].astype(str)

palmdb_adata.obs

___

# Add macaque and canine metadata to virus count matrix:

In [None]:
# Download host count matrices from Caltech Data
# Generated here: https://github.com/pachterlab/LSCHWCP_2023/tree/main/Notebooks/Supp_Fig_3/Supp_Fig_3abc
!wget https://data.caltech.edu/records/sh33z-hrx98/files/macaque_QC_norm_leiden_celltypes.h5ad?download=1
!mv macaque_QC_norm_leiden_celltypes.h5ad?download=1 macaque_QC_norm_leiden_celltypes.h5ad
!wget https://data.caltech.edu/records/sh33z-hrx98/files/host_QC.h5ad?download=1
!mv host_QC.h5ad?download=1 host_QC.h5ad

# Macaque only
mac_adata = anndata.read("macaque_QC_norm_leiden_celltypes.h5ad")
mac_adata

# All host (macaque + MDCK)
host_adata = anndata.read("host_QC.h5ad")
host_adata

In [None]:
# The "barcode" column duplicates the index (the earlier set_index("barcode", drop=False) calls
# keep both); drop the column so that "barcode" in `on` below refers to the index level only.
palmdb_adata.obs = palmdb_adata.obs.drop("barcode", axis=1)
# Add the macaque clustering / cell type annotations to the matching cells.
# NOTE(review): presumably mac_adata.obs is also indexed by "barcode" - confirm against the h5ad file.
# how="left" keeps cells without a macaque annotation (their new columns stay empty);
# validate="one_to_one" raises if a (barcode, srr) pair is duplicated on either side.
palmdb_adata.obs = palmdb_adata.obs.merge(
    mac_adata.obs[["srr", "leiden", "celltype_clusters", "celltype"]],
    on=["barcode", "srr"],
    how="left",
    validate="one_to_one"
)

In [None]:
# Show the cells that received a cell type annotation from the macaque merge
palmdb_adata.obs[palmdb_adata.obs["celltype"].notnull()]

In [None]:
# Add the host species label to the matching cells.
# NOTE(review): "barcode" in `on` presumably resolves to the index level of both frames - confirm.
# how="left" keeps cells that are absent from host_adata (their species stays empty);
# validate="one_to_one" raises if a (barcode, srr) pair is duplicated on either side.
palmdb_adata.obs = palmdb_adata.obs.merge(
    host_adata.obs[["srr", "species"]],
    on=["barcode", "srr"],
    how="left",
    validate="one_to_one"
)

In [None]:
# Show the cells that received a species label from the host merge
palmdb_adata.obs[palmdb_adata.obs["species"].notnull()]

# Threshold viruses to keep only those seen in at least [virus_threshold] QC'd cells

In [None]:
# Remove all cells that did not pass macaque or canine filtering
# (i.e. cells that have no species label after the left merge with host_adata.obs)
palmdb_adata = palmdb_adata[palmdb_adata.obs["species"].notnull(), :]
palmdb_adata.obs

In [None]:
# Keep only viruses detected (count > 0) in at least [virus_threshold] cells across the entire dataset
virus_threshold = 1

# Number of cells with a non-zero count, per virus
n_pos_cells_per_virus = (palmdb_adata.X > 0).sum(axis=0)
palmdb_adata = palmdb_adata[:, n_pos_cells_per_virus >= virus_threshold]
palmdb_adata

# Binarize virus matrix

In [None]:
# Save raw counts before binarizing
# (the following cells compare palmdb_adata.raw.X against the binarized palmdb_adata.X)
palmdb_adata.raw = palmdb_adata

In [None]:
# Replace all positive counts with 1, modifying palmdb_adata.X in place
palmdb_adata.X[palmdb_adata.X > 0] = 1

In [None]:
# Sanity check: sum of the stored raw counts
palmdb_adata.raw.X.sum()

In [None]:
# Sanity check: largest stored raw count
palmdb_adata.raw.X.max()

In [None]:
# Sanity check: sum of the binarized matrix (number of positive cell/virus pairs)
palmdb_adata.X.sum()

In [None]:
# Sanity check: largest value of the binarized matrix
palmdb_adata.X.max()

# Merge timepoints

In [None]:
# Collapse the 7d and 8d timepoints into one "7d/8d" label; every other label is kept as is
new_tps = [
    "7d/8d" if tp in ("7d", "8d") else tp
    for tp in palmdb_adata.obs["dpi_clean"].values
]

palmdb_adata.obs["dpi_clean_merged"] = new_tps

# Mark viruses seen in both canine and macaque cells as "shared"

In [None]:
def yex(ax, c="grey", alpha=0.75):
    """
    Function to add the y = x line to a plot and make both axes share the same square range.
    Call after defining x and y scales.

    Parameters
    ----------
    ax : matplotlib Axes to draw on (modified in place).
    c : color of the line; any matplotlib color specification.
    alpha : opacity of the line.

    Returns
    -------
    The same `ax`, to allow chaining.
    """
    lims = [
        np.min([ax.get_xlim(), ax.get_ylim()]),  # min of both axes
        np.max([ax.get_xlim(), ax.get_ylim()]),  # max of both axes
    ]
    # Plot both limits against each other.
    # The color is passed by keyword: as a third positional argument it would be parsed as a
    # matplotlib format string, which only works for plain color names and not for RGB(A) tuples.
    ax.plot(lims, lims, color=c, alpha=alpha, zorder=0, lw=1)
    ax.set(aspect="equal", xlim=lims, ylim=lims)
    return ax

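Viruses along the y=x line have equal fractions of positive cells in MDCK and macaque cells and are likely shared. We therefore consider a virus shared when its MDCK and macaque fractions differ by at most a fixed fold change. On a log scale this is a band of constant width around the y=x line, so the allowed absolute difference between the two fractions increases as the fractions increase.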

In [None]:
# Minimum fraction of positive cells per virus per species
min_frac_threshold = 0.0005

# Translate the fraction into an absolute number of macaque cells for reference
n_macaque_cells = len(palmdb_adata.obs[palmdb_adata.obs['species'] == 'macaca_mulatta'])
print(f"This equals to a minimum of {min_frac_threshold * n_macaque_cells} positive cells for macaque")

In [None]:
%%time
# Fold change between fractions of pos cells within which viruses will be considered shared
fold_change = 2

# Minimum number of cells allowed from other species for "_only" assignments
min_num_cells = 5

# Species masks and cell counts do not change per virus, so compute them once
is_macaque = palmdb_adata.obs["species"] == "macaca_mulatta"
is_canine = palmdb_adata.obs["species"] == "canis_lupus_familiaris"
n_mac_cells = len(palmdb_adata.obs[is_macaque])
n_can_cells = len(palmdb_adata.obs[is_canine])

# Get fraction of positive cells for each virus for each species
mac_fractions = nd(palmdb_adata[is_macaque, :].X.sum(axis=0)) / n_mac_cells
can_fractions = nd(palmdb_adata[is_canine, :].X.sum(axis=0)) / n_can_cells


def classify_virus(mac_frac, can_frac, min_frac, max_fold, max_mac_frac_other, max_can_frac_other):
    """
    Assign one virus to a category based on its fraction of positive cells per species.

    Parameters
    ----------
    mac_frac, can_frac : fraction of positive macaque / canine cells for this virus.
    min_frac : a virus below this fraction in BOTH species is "below_threshold".
    max_fold : largest ratio between the two fractions for which a virus counts as "shared".
    max_mac_frac_other : largest macaque fraction still tolerated for a "canis_only" virus.
    max_can_frac_other : largest canine fraction still tolerated for a "macaca_only" virus.

    Returns
    -------
    One of "below_threshold", "shared", "macaca_only", "canis_only", "undefined".
    """
    if mac_frac < min_frac and can_frac < min_frac:
        return "below_threshold"

    larger, smaller = max(mac_frac, can_frac), min(mac_frac, can_frac)
    # Equal fractions (ratio 1) lie exactly on the y=x line and are shared as well.
    # The smaller > 0 guard avoids dividing by zero when one species has no positive cells.
    if smaller > 0 and larger / smaller <= max_fold:
        return "shared"

    # Allow [min_num_cells] positive cells in the other species
    if mac_frac > can_frac and can_frac <= max_can_frac_other:
        return "macaca_only"
    if can_frac > mac_frac and mac_frac <= max_mac_frac_other:
        return "canis_only"
    return "undefined"


v_types = [
    classify_virus(
        mac_frac,
        can_frac,
        min_frac=min_frac_threshold,
        max_fold=fold_change,
        max_mac_frac_other=min_num_cells / n_mac_cells,
        max_can_frac_other=min_num_cells / n_can_cells,
    )
    for mac_frac, can_frac in zip(mac_fractions, can_fractions)
]

In [None]:
# Store the category of each virus (one entry of v_types per column of palmdb_adata)
palmdb_adata.var["v_type"] = v_types

In [None]:
# Scatter plot (log-log) of the percentage of positive macaque cells vs. positive MDCK cells per virus,
# colored by virus category and by local point density
fig, ax = plt.subplots(figsize=(12, 10))
fontsize = 14
# Upper end of the shared color normalization (point densities above vmax saturate)
vmax = 5.0

all_v_types = ["macaca_only", "canis_only", "shared", "undefined", "below_threshold"]
labels = ["macaque only", "MDCK only", "shared", "undefined", "below threshold"]

# Generate colormaps (removing lightest values for visibility)
# Note: there are only four colormaps; "below_threshold" (index 4) is drawn in a flat color instead
edgecolors = ["green", "blue", "red", "grey", "lightgrey"]
cmaps_orig = [plt.cm.Greens, plt.cm.Blues, plt.cm.Oranges, plt.cm.Greys]
min_val, max_val = 0.25, 1.0
n = 20
cmaps = []
for orig_cmap in cmaps_orig:
    colors = orig_cmap(np.linspace(min_val, max_val, n))
    cmaps.append(mpl.colors.LinearSegmentedColormap.from_list("mycmap", colors))

for i, v_type in enumerate(all_v_types):
    # Plot fraction of canis and macaque positive cells
    x = nd(palmdb_adata[palmdb_adata.obs["species"] == "macaca_mulatta", palmdb_adata.var["v_type"] == v_type].X.sum(axis=0)) / len(palmdb_adata.obs[palmdb_adata.obs["species"] == "macaca_mulatta"])
    y = nd(palmdb_adata[palmdb_adata.obs["species"] == "canis_lupus_familiaris", palmdb_adata.var["v_type"] == v_type].X.sum(axis=0)) / len(palmdb_adata.obs[palmdb_adata.obs["species"] == "canis_lupus_familiaris"])

    # Convert to percentages
    x = x*100
    y = y*100

    # Histogram data to show point density
    bins = [1500, 1500]
    hh, locx, locy = np.histogram2d(x, y, bins=bins)
    # For every point, look up the count of the 2D bin it falls into
    # (argmax finds the first upper bin edge that is >= the coordinate)
    z = np.array([hh[np.argmax(a<=locx[1:]),np.argmax(b<=locy[1:])] for a,b in zip(x,y)])
    # Sort by density so the densest points are drawn last (on top)
    idx = z.argsort()
    x2, y2, z2 = x[idx], y[idx], z[idx]

    # Report the highest bin count for this category
    print(z2.max())

    # Normalize colormaps to same range
    if v_type != "below_threshold":
        scatter = ax.scatter(x2, y2, c=z2, cmap=cmaps[i], norm=mpl.colors.Normalize(vmin=1, vmax=vmax), edgecolor=edgecolors[i], lw=0.2, s=50, alpha=1)
    else:
        # Viruses below the threshold are shown faintly in a single color
        scatter = ax.scatter(x2, y2, c=edgecolors[i], edgecolor=edgecolors[i], lw=0.2, s=50, alpha=0.2)

    # Add colorbar
    if v_type != "below_threshold":
        cbar = fig.colorbar(scatter, ax=ax, pad=0.01, shrink=0.7)
        cbar.ax.yaxis.set_major_locator(MaxNLocator(integer=True))
        # All colorbars share the same normalization, so only the first one shows tick labels
        if i == 0:
            cbar.ax.tick_params(axis="both", labelsize=fontsize-2)
        else:
            cbar.set_ticks([])
        cbar.ax.set_ylabel(f"# of {labels[i]} viruses", fontsize=fontsize)

ax.set_xlabel("Percentage of positive Macaca mulatta cells", fontsize=fontsize)
ax.set_ylabel("Percentage of positive MDCK cells", fontsize=fontsize)

ax.set_xscale("log")
ax.set_yscale("log")

# Add shared threshold area and y=x line
# Get axis limits
lims = [
        np.min([ax.get_xlim(), ax.get_ylim()]),
        np.max([ax.get_xlim(), ax.get_ylim()]),
    ]

# Threshold fold change above/below x=y line
# (10**(log10(x) +/- log10(fold_change)) equals x * fold_change and x / fold_change)
x_thres = np.array(lims)
ax.fill_between(x_thres, 10**(np.log10(x_thres)+np.log10(fold_change)), 10**(np.log10(x_thres)-np.log10(fold_change)), alpha=0.1, color="#f77f00")

# Add minimum fraction threshold lines
# (dashed lines from the threshold point towards the lower end of the axes; values are in percent)
ax.plot([min_frac_threshold*100, min_frac_threshold*100], [min_frac_threshold*100, ax.get_xlim()[0]], c="grey", lw=1, ls="--", zorder=-1)
ax.plot([min_frac_threshold*100, ax.get_xlim()[0]], [min_frac_threshold*100, min_frac_threshold*100], c="grey", lw=1, ls="--", zorder=-1)

# y=x line
ax.plot(ax.get_xlim(), ax.get_xlim(), c="grey", lw=1, ls="-")

# Use the same range on both axes so the y=x line is the diagonal
ax.set_ylim(lims)
ax.set_xlim(lims)
ax.set_aspect("equal")

# Change fontsize of tick labels
ax.tick_params(axis="both", labelsize=fontsize-2)

# Save figure
fig.savefig(
    "virus_shared_thresholds.png", dpi=300, bbox_inches="tight", transparent=True
)

fig.show()

In [None]:
# Zoom into area around 0
# Same scatter plot as the log-log figure, but on linear axes restricted to -0.01 .. 0.23 percent
fig, ax = plt.subplots(figsize=(12, 10))
fontsize = 14
# Upper end of the shared color normalization (point densities above vmax saturate)
vmax = 5.0

all_v_types = ["macaca_only", "canis_only", "shared", "undefined", "below_threshold"]
labels = ["macaque only", "MDCK only", "shared", "undefined", "below threshold"]

# Generate colormaps (removing lightest values for visibility)
# Note: there are only four colormaps; "below_threshold" (index 4) is drawn in a flat color instead
edgecolors = ["green", "blue", "red", "grey", "lightgrey"]
cmaps_orig = [plt.cm.Greens, plt.cm.Blues, plt.cm.Oranges, plt.cm.Greys]
min_val, max_val = 0.25, 1.0
n = 20
cmaps = []
for orig_cmap in cmaps_orig:
    colors = orig_cmap(np.linspace(min_val, max_val, n))
    cmaps.append(mpl.colors.LinearSegmentedColormap.from_list("mycmap", colors))

for i, v_type in enumerate(all_v_types):
    # Plot fraction of canis and macaque positive cells
    x = nd(palmdb_adata[palmdb_adata.obs["species"] == "macaca_mulatta", palmdb_adata.var["v_type"] == v_type].X.sum(axis=0)) / len(palmdb_adata.obs[palmdb_adata.obs["species"] == "macaca_mulatta"])
    y = nd(palmdb_adata[palmdb_adata.obs["species"] == "canis_lupus_familiaris", palmdb_adata.var["v_type"] == v_type].X.sum(axis=0)) / len(palmdb_adata.obs[palmdb_adata.obs["species"] == "canis_lupus_familiaris"])

    # Convert to percentages
    x = x*100
    y = y*100

    # Histogram data to show point density
    bins = [1500, 1500]
    hh, locx, locy = np.histogram2d(x, y, bins=bins)
    # For every point, look up the count of the 2D bin it falls into
    # (argmax finds the first upper bin edge that is >= the coordinate)
    z = np.array([hh[np.argmax(a<=locx[1:]),np.argmax(b<=locy[1:])] for a,b in zip(x,y)])
    # Sort by density so the densest points are drawn last (on top)
    idx = z.argsort()
    x2, y2, z2 = x[idx], y[idx], z[idx]

    # Report the highest bin count for this category
    print(z2.max())

    # Normalize colormaps to same range
    if v_type != "below_threshold":
        scatter = ax.scatter(x2, y2, c=z2, cmap=cmaps[i], norm=mpl.colors.Normalize(vmin=1, vmax=vmax), edgecolor=edgecolors[i], lw=0.2, s=50, alpha=1)
    else:
        # Viruses below the threshold are shown faintly in a single color
        scatter = ax.scatter(x2, y2, c=edgecolors[i], edgecolor=edgecolors[i], lw=0.2, s=50, alpha=0.2)
    # scatter = ax.scatter(x2, y2, c=z2, cmap=cmaps[i], edgecolor=edgecolors[i], lw=0.1, s=20)

    # Add colorbar
    if v_type != "below_threshold":
        cbar = fig.colorbar(scatter, ax=ax, pad=0.01, shrink=0.7)
        cbar.ax.yaxis.set_major_locator(MaxNLocator(integer=True))
        # All colorbars share the same normalization, so only the first one shows tick labels
        if i == 0:
            cbar.ax.tick_params(axis="both", labelsize=fontsize-2)
        else:
            cbar.set_ticks([])
        cbar.ax.set_ylabel(f"# of {labels[i]} viruses", fontsize=fontsize)

# Mark Ebolavirus with black edge
# (the virus with ID "u10" is redrawn as an unfilled circle with a black outline)
x_ebov = (np.array(palmdb_adata[palmdb_adata.obs["species"] == "macaca_mulatta", palmdb_adata.var.index=="u10"].X.sum(axis=0))).flatten() / len(palmdb_adata.obs[palmdb_adata.obs["species"] == "macaca_mulatta"])
y_ebov = (np.array(palmdb_adata[palmdb_adata.obs["species"] == "canis_lupus_familiaris", palmdb_adata.var.index=="u10"].X.sum(axis=0))).flatten() / len(palmdb_adata.obs[palmdb_adata.obs["species"] == "canis_lupus_familiaris"])
ax.scatter(x_ebov*100, y_ebov*100, facecolor=[0, 0, 0, 0], edgecolor=[0, 0, 0, 1], lw=1, s=50)

ax.set_xlabel("Percentage of positive Macaca mulatta cells", fontsize=fontsize)
ax.set_ylabel("Percentage of positive MDCK cells", fontsize=fontsize)

# Restrict the view to the region close to the origin (values are in percent)
ax.set_ylim(bottom=-0.01, top=0.23)
ax.set_xlim(left=-0.01, right=0.23)

# Add shared threshold area and y=x line
# Get axis limits
# (the lower bound is a small positive constant, presumably because the lower axis limit is
# negative here and np.log10 is applied to these values below)
lims = [
        0.0001,
        np.max([ax.get_xlim(), ax.get_ylim()]),
    ]

# Threshold fold change above/below x=y line
# (10**(log10(x) +/- log10(fold_change)) equals x * fold_change and x / fold_change)
x_thres = np.array(lims)
ax.fill_between(x_thres, 10**(np.log10(x_thres)+np.log10(fold_change)), 10**(np.log10(x_thres)-np.log10(fold_change)), alpha=0.1, color="#f77f00")

# Add minimum fraction threshold lines
# (dashed lines from the threshold point to -1, i.e. past the lower axis limits; values are in percent)
ax.plot([min_frac_threshold*100, min_frac_threshold*100], [min_frac_threshold*100, -1], c="grey", lw=1, ls="--", zorder=-1)
ax.plot([min_frac_threshold*100, -1], [min_frac_threshold*100, min_frac_threshold*100], c="grey", lw=1, ls="--", zorder=-1)

# y=x line
ax.plot(ax.get_xlim(), ax.get_xlim(), c="grey", lw=1, ls="-")

ax.set_aspect("equal")

# Change fontsize of tick labels
ax.tick_params(axis="both", labelsize=fontsize-2)

# Save figure
fig.savefig(
    "virus_shared_thresholds_zero.png", dpi=300, bbox_inches="tight", transparent=True
)

fig.show()

# Plot count histogram to show low counts per virus ID (compared to host) before binarizing:

In [None]:
# Non-zero raw (pre-binarization) counts of "macaca_only" viruses in cells with a cell type annotation
palmdb_adata_temp = palmdb_adata.raw[palmdb_adata.obs["celltype"].notnull(), palmdb_adata.var["v_type"] == "macaca_only"].X
vir_counts = nd(palmdb_adata_temp[palmdb_adata_temp>0])

# Same for "shared" viruses (palmdb_adata_temp is reused as a scratch variable)
palmdb_adata_temp = palmdb_adata.raw[palmdb_adata.obs["celltype"].notnull(), palmdb_adata.var["v_type"] == "shared"].X
vir_counts2 = nd(palmdb_adata_temp[palmdb_adata_temp>0])

# Non-zero raw counts of the macaque host matrix, for comparison
host_counts = nd(mac_adata.raw.X[mac_adata.raw.X>0])

In [None]:
# Histogram (log-scaled frequency) of non-zero per-cell counts: host genes vs. viruses
fig, ax = plt.subplots(figsize=(10, 5))
fontsize = 14

# One entry per histogram layer, drawn back to front
hist_layers = [
    dict(x=host_counts, bins=160, edgecolor="grey", facecolor="lightgrey", label="Macaque genes"),
    dict(x=vir_counts2, bins=7, edgecolor="#f77f00", facecolor="#f77f00", alpha=0.8, label="Viruses (shared)"),
    dict(x=vir_counts, bins=2, edgecolor="#0f7c49", facecolor="#0f7c49", alpha=1, label="Viruses (macaque only)"),
]
for layer in hist_layers:
    ax.hist(histtype='stepfilled', log=True, **layer)

ax.legend(fontsize=fontsize)

ax.set_xlabel("Count in one cell", fontsize=fontsize)
ax.set_ylabel("Frequency", fontsize=fontsize)

# Tick labels are slightly smaller than the axis labels
ax.tick_params(axis="both", labelsize=fontsize-2)

# Dashed grid on major and minor ticks, kept behind the bars
ax.grid(True, which="both", color="lightgray", ls="--", lw=1)
ax.set_axisbelow(True)

plt.savefig("count_histogram.png", dpi=300, bbox_inches="tight")

fig.show()

___

In [None]:
# Save new binarized virus matrix including host species information and virus types
# (written to the current working directory)
palmdb_adata.write("virus_dlist_cdna_dna.h5ad")