<a href="https://colab.research.google.com/github/pachterlab/LSCHWCP_2023/blob/main/Notebooks/align_macaque_PBMC_data/7_virus_host_captured_dlist_cdna_dna/2_viral_QC_host_captured_dlist_cdna_dna.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Virus count matrix quality control (QC) and virus categorization

In [None]:
!pip install -q kb_python
import kb_python.utils as kb_utils
import numpy as np
from scipy import stats
import anndata
import pandas as pd
import json
import os
import glob
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.ticker import MaxNLocator
%config InlineBackend.figure_format='retina'

def nd(arr):
    """
    Flatten any array-like (for example the 2-D matrix returned by
    ``.X.sum(axis=0)``) into a one-dimensional numpy ndarray.
    """
    dense = np.asarray(arr)
    return dense.ravel()

### Load data aligned to PalmDB using translated search with the host genome and transcriptome masked in the reference index using D-list and host reads captured:

In [None]:
# Load alignment results from Caltech Data
# !wget
!unzip virus_host_capture_dlist_cdna_dna_alignment_results.zip

Import data from all batches/SRRs:

In [None]:
%%time

VIRUS_RESULTS_DIR = "virus_host_capture_dlist_cdna_dna_alignment_results/virus"


def _load_bustools_counts(count_dir):
    """
    Read one bustools count output directory into an AnnData object.

    count_dir: directory holding output.mtx (counts), output.barcodes.txt
    (barcode metadata -> obs) and output.genes.txt (gene/virus metadata -> var).
    """
    return kb_utils.import_matrix_as_anndata(
        os.path.join(count_dir, "output.mtx"),
        os.path.join(count_dir, "output.barcodes.txt"),
        os.path.join(count_dir, "output.genes.txt"),
    )


palmdb_virus_adatas = []
palmdb_host_adatas = []
# sorted() makes the batch order (and therefore the cell order after concatenation)
# deterministic; glob returns paths in arbitrary filesystem order
for file_path in sorted(glob.glob(f"{VIRUS_RESULTS_DIR}/AAAA*/")):
    # Reads counted against the virus index only
    palmdb_virus_adatas.append(_load_bustools_counts(os.path.join(file_path, "virus", "bustools_count")))
    # Reads that also occurred in the host
    palmdb_host_adatas.append(_load_bustools_counts(os.path.join(file_path, "host", "bustools_count")))

# Fail with a clear message instead of handing empty lists to anndata.concat
assert palmdb_virus_adatas, f"No alignment results found under {VIRUS_RESULTS_DIR}"

Concatenate adata objects:

In [None]:
%%time
# Stack the per-batch matrices; join="outer" keeps the union of viruses across batches
palmdb_adata, palmdb_adata_host = (
    anndata.concat(batch_adatas, join="outer")
    for batch_adatas in (palmdb_virus_adatas, palmdb_host_adatas)
)

In [None]:
# Display the concatenated virus count matrix (counts from the virus/ bustools output)
palmdb_adata

In [None]:
# Total count assigned to virus u10 across all cells
is_u10 = palmdb_adata.var.index == "u10"
palmdb_adata[:, is_u10].X.sum()

In [None]:
# Display the concatenated matrix of virus reads that also occurred in the host
palmdb_adata_host

In [None]:
# Cannot use prefix file here because data was generated with -n flag.
# bustools count pads any additional fields with A's up to 32 characters, so the
# last 28 characters are a 16-character sample barcode followed by the
# 12-character cell barcode.
cell_ids = palmdb_adata.obs.index
palmdb_adata.obs["sample_barcode"] = cell_ids.str[-28:-12]
palmdb_adata.obs["barcode"] = cell_ids.str[-12:]
palmdb_adata.obs

In [None]:
# Create barcode to sample lookup.
# Line i of matrix.sample.barcodes is paired with line i of matrix.cells (the SRR name).
VIRUS_RESULTS_DIR = "virus_host_capture_dlist_cdna_dna_alignment_results/virus"

with open(f"{VIRUS_RESULTS_DIR}/matrix.sample.barcodes") as b_file:
    barcodes = b_file.read().splitlines()

with open(f"{VIRUS_RESULTS_DIR}/matrix.cells") as s_file:
    samples = s_file.read().splitlines()

# pd.DataFrame raises if the two files have a different number of lines
bc2sample_df = pd.DataFrame({"sample_barcode": barcodes, "srr": samples})

# Each sample barcode must map to exactly one SRR; duplicates would duplicate
# cells in the left merges on "sample_barcode"
assert bc2sample_df["sample_barcode"].is_unique, "Duplicate sample barcodes in matrix.sample.barcodes"
bc2sample_df

In [None]:
# Attach the SRR accession to every cell via its sample barcode.
# validate="many_to_one" raises instead of silently duplicating cells if the lookup
# ever contains a repeated sample barcode.
# NOTE(review): the 12-character "barcode" index can repeat across samples; later
# merges therefore key on ("barcode", "srr").
palmdb_adata.obs = (
    palmdb_adata.obs
    .merge(bc2sample_df, on="sample_barcode", how="left", validate="many_to_one")
    .set_index("barcode", drop=False)
)
palmdb_adata.obs

In [None]:
# Same for reads that also occurred in host: split the padded barcode into
# sample barcode (16 chars) and cell barcode (last 12 chars), then attach the SRR
palmdb_adata_host.obs["sample_barcode"] = [bc[-28:-12] for bc in palmdb_adata_host.obs.index.values]
palmdb_adata_host.obs["barcode"] = [bc[-12:] for bc in palmdb_adata_host.obs.index.values]
# validate="many_to_one" guarantees the sample lookup cannot duplicate cells
palmdb_adata_host.obs = (
    palmdb_adata_host.obs
    .merge(bc2sample_df, on="sample_barcode", how="left", validate="many_to_one")
    .set_index("barcode", drop=False)
)
palmdb_adata_host.obs

### Add timepoints from SRR metadata:

In [None]:
!wget https://raw.githubusercontent.com/pachterlab/LSCHWCP_2023/main/Notebooks/Supp_Fig_3/Supp_Fig_3abc/PRJNA665227_SraRunTable.txt
# Keep only the relevant run metadata; runs with no MDCK spike-in annotation are
# treated as not spiked in (False)
srr_meta = (
    pd.read_csv("PRJNA665227_SraRunTable.txt", sep=",")
    .assign(mdck_spike_in=lambda run_table: run_table["mdck_spike_in"].fillna(False).values)
    [["Run", "donor_animal", "Experiment", "mdck_spike_in", "hours_post_innoculation", "day_post_infection"]]
)
srr_meta

In [None]:
# Attach per-run metadata (donor, time point, spike-in) via the SRR accession.
# validate="many_to_one" raises if a Run appears more than once in srr_meta,
# which would otherwise duplicate cells.
palmdb_adata.obs = (
    palmdb_adata.obs
    .merge(srr_meta, left_on="srr", right_on="Run", how="left", validate="many_to_one")
    .set_index("barcode", drop=False)
)

palmdb_adata.obs

In [None]:
# Same per-run metadata for the also-in-host matrix; each Run must be unique
palmdb_adata_host.obs = (
    palmdb_adata_host.obs
    .merge(srr_meta, left_on="srr", right_on="Run", how="left", validate="many_to_one")
    .set_index("barcode", drop=False)
)

palmdb_adata_host.obs

Create clean dpi column:

In [None]:
# Join day_post_infection and hours_post_innoculation into one integer time point;
# the day value is used when a run has both
dpi_days = palmdb_adata.obs["day_post_infection"]
dpi_hours = palmdb_adata.obs["hours_post_innoculation"]
assert (dpi_days.notna() | dpi_hours.notna()).all(), "Some cells have neither a day nor an hour time point"
palmdb_adata.obs["dpi"] = dpi_days.fillna(dpi_hours).astype(int)

# Add h/d suffix to denote hours/days. The suffix follows the column that supplied the
# value, so a run annotated with both is labelled in days instead of a day value
# being mislabelled as hours.
palmdb_adata.obs["dpi_accessions"] = np.where(dpi_days.notna(), "d", "h")
palmdb_adata.obs["dpi_clean"] = palmdb_adata.obs["dpi"].astype(str) + palmdb_adata.obs["dpi_accessions"].astype(str)

palmdb_adata.obs

In [None]:
# Join day_post_infection and hours_post_innoculation into one integer time point;
# the day value is used when a run has both
dpi_days_host = palmdb_adata_host.obs["day_post_infection"]
dpi_hours_host = palmdb_adata_host.obs["hours_post_innoculation"]
assert (dpi_days_host.notna() | dpi_hours_host.notna()).all(), "Some cells have neither a day nor an hour time point"
palmdb_adata_host.obs["dpi"] = dpi_days_host.fillna(dpi_hours_host).astype(int)

# Add h/d suffix to denote hours/days, taken from the column that supplied the value
palmdb_adata_host.obs["dpi_accessions"] = np.where(dpi_days_host.notna(), "d", "h")
palmdb_adata_host.obs["dpi_clean"] = palmdb_adata_host.obs["dpi"].astype(str) + palmdb_adata_host.obs["dpi_accessions"].astype(str)

palmdb_adata_host.obs

___

# Add macaque and canine metadata to virus count matrix:

In [None]:
# Download host count matrices from Caltech Data
# Generated here: https://github.com/pachterlab/LSCHWCP_2023/tree/main/Notebooks/Supp_Fig_3/Supp_Fig_3abc
!wget https://data.caltech.edu/records/sh33z-hrx98/files/macaque_QC_norm_leiden_celltypes.h5ad?download=1
!mv macaque_QC_norm_leiden_celltypes.h5ad?download=1 macaque_QC_norm_leiden_celltypes.h5ad
!wget https://data.caltech.edu/records/sh33z-hrx98/files/host_QC.h5ad?download=1
!mv host_QC.h5ad?download=1 host_QC.h5ad

# Macaque only
mac_adata = anndata.read("macaque_QC_norm_leiden_celltypes.h5ad")
mac_adata

# All host (macaque + MDCK)
host_adata = anndata.read("host_QC.h5ad")
host_adata

In [None]:
# Attach macaque Leiden clusters and cell-type annotations to both virus matrices.
# The "barcode" column is dropped first because "barcode" is also the name of the obs
# index (set via set_index("barcode", drop=False) earlier); presumably pandas rejects a
# merge key that is both a column and an index level as ambiguous. After the drop,
# on=["barcode", "srr"] matches the "barcode" index level plus the "srr" column.
# NOTE(review): assumes mac_adata.obs is indexed by a level named "barcode" and has an
# "srr" column - confirm against the notebook that produced the h5ad.
# how="left" keeps every virus-matrix cell (cells without an annotation get NaN) and
# validate="one_to_one" raises if a (barcode, srr) pair is duplicated on either side.
palmdb_adata.obs = palmdb_adata.obs.drop("barcode", axis=1)
palmdb_adata.obs = palmdb_adata.obs.merge(
    mac_adata.obs[["srr", "leiden", "celltype_clusters", "celltype"]],
    on=["barcode", "srr"],
    how="left",
    validate="one_to_one"
)

# Same for the matrix of reads that also occurred in the host
palmdb_adata_host.obs = palmdb_adata_host.obs.drop("barcode", axis=1)
palmdb_adata_host.obs = palmdb_adata_host.obs.merge(
    mac_adata.obs[["srr", "leiden", "celltype_clusters", "celltype"]],
    on=["barcode", "srr"],
    how="left",
    validate="one_to_one"
)

In [None]:
# Cells that received a macaque cell-type annotation
has_celltype = palmdb_adata.obs["celltype"].notna()
palmdb_adata.obs[has_celltype]

In [None]:
# Attach the species call from the host QC object (host_adata holds both macaque and MDCK cells).
# Keys are the "barcode" index level plus the "srr" column, as in the cell-type merge.
# NOTE(review): assumes host_adata.obs is indexed by a level named "barcode" - confirm.
# Cells absent from host_adata get NaN species and are removed in the filtering step below.
palmdb_adata.obs = palmdb_adata.obs.merge(
    host_adata.obs[["srr", "species"]],
    on=["barcode", "srr"],
    how="left",
    validate="one_to_one"
)

# Same for the matrix of reads that also occurred in the host
palmdb_adata_host.obs = palmdb_adata_host.obs.merge(
    host_adata.obs[["srr", "species"]],
    on=["barcode", "srr"],
    how="left",
    validate="one_to_one"
)

In [None]:
# Cells that received a species call from host QC
has_species = palmdb_adata.obs["species"].notna()
palmdb_adata.obs[has_species]

# Threshold viruses to keep only those seen in at least [virus_threshold] QC'd cells

In [None]:
# Remove all cells that did not pass macaque or canine filtering.
# .copy() materialises the subsets instead of keeping AnnData views; later cells modify
# .X (binarization) and .var (v_type) in place, which on a view would only work
# through an implicit copy with a warning.
palmdb_adata = palmdb_adata[palmdb_adata.obs["species"].notnull(), :].copy()

palmdb_adata_host = palmdb_adata_host[palmdb_adata_host.obs["species"].notnull(), :].copy()

palmdb_adata.obs

In [None]:
# Remove all viruses that were not detected in at least [virus_threshold] cells across the entire dataset
virus_threshold = 1

# nd() flattens the column sums (a 2-D matrix when X is sparse) into a plain
# 1-D boolean mask over viruses; .copy() avoids keeping an AnnData view
keep_virus = nd((palmdb_adata.X > 0).sum(axis=0)) >= virus_threshold
palmdb_adata = palmdb_adata[:, keep_virus].copy()

keep_virus_host = nd((palmdb_adata_host.X > 0).sum(axis=0)) >= virus_threshold
palmdb_adata_host = palmdb_adata_host[:, keep_virus_host].copy()

palmdb_adata

In [None]:
# Display the also-in-host virus matrix after cell and virus filtering
palmdb_adata_host

# Binarize virus matrix

In [None]:
# Store the raw (non-binarized) counts in .raw before binarizing
for counts_adata in (palmdb_adata, palmdb_adata_host):
    counts_adata.raw = counts_adata

In [None]:
# Binarize: every positive count becomes 1 (virus detected in that cell)
for counts_adata in (palmdb_adata, palmdb_adata_host):
    positive = counts_adata.X > 0
    counts_adata.X[positive] = 1

In [None]:
# Sum of the raw (pre-binarization) counts kept in .raw
palmdb_adata.raw.X.sum()

In [None]:
# Largest raw count in any single cell/virus entry
palmdb_adata.raw.X.max()

In [None]:
# Sum of the binarized matrix (number of positive cell/virus entries for non-negative counts)
palmdb_adata.X.sum()

In [None]:
# Sanity check: should be 1 after binarization
palmdb_adata.X.max()

# Merge timepoints

In [None]:
# Merge the 7d and 8d timepoints into a single "7d/8d" group in both matrices
for tp_adata in (palmdb_adata, palmdb_adata_host):
    tp_adata.obs["dpi_clean_merged"] = [
        "7d/8d" if tp in ("7d", "8d") else tp
        for tp in tp_adata.obs["dpi_clean"].values
    ]

# Show viruses that have...
1. Only reads that also mapped to host
2. Only reads that mapped to virus
3. Some reads that mapped to host and some reads that mapped to virus

In [None]:
# Virus-read matrix after filtering and binarization
palmdb_adata

In [None]:
# Also-in-host matrix after filtering and binarization
palmdb_adata_host

In [None]:
# Virus IDs detected with virus-only reads and with reads that also occurred in the host
vir_virs = palmdb_adata.var.index.values
host_virs = palmdb_adata_host.var.index.values
# "Mixed" viruses: detected with both kinds of reads
mixed_virs = list(set(vir_virs) & set(host_virs))

In [None]:
# Number of viruses detected with virus-only reads
len(vir_virs)

In [None]:
n_mixed_virs = len(mixed_virs)
print("Number of 'mixed' viruses: ", n_mixed_virs)

In [None]:
# Viruses seen only with virus reads; computed from the data rather than hard-coded
# (the literal 11176 - 2260 goes stale whenever thresholds or inputs change)
print("Number of vir only viruses: ", len(set(vir_virs) - set(host_virs)))

In [None]:
# Viruses seen only with reads that also occurred in the host; computed from the data
# rather than hard-coded (the literal 5266 - 2260 goes stale when inputs change)
print("Number of host only viruses: ", len(set(host_virs) - set(vir_virs)))

In [None]:
all_detected_virs = set(vir_virs) | set(host_virs)
print("Total number of viruses detected: ", len(all_detected_virs))

In [None]:
# Is u10 detected with both virus-only and also-in-host reads?
"u10" in set(mixed_virs)

# Mark viruses seen in both canine and macaque cells as "shared"

In [None]:
def yex(ax, c="grey", alpha=0.75):
    """
    Draw the y = x identity line on ``ax`` and make the plot square.

    Call after the x and y scales have been set. Both axes receive the same
    limits (smallest and largest of the current x/y limits) and an equal
    aspect ratio. Returns ``ax`` for chaining.
    """
    bounds = np.concatenate([ax.get_xlim(), ax.get_ylim()])
    lims = [bounds.min(), bounds.max()]
    # Diagonal between the shared lower and upper limits, drawn behind the data
    ax.plot(lims, lims, c, alpha=alpha, zorder=0, lw=1)
    ax.set(aspect="equal", xlim=lims, ylim=lims)
    return ax

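Viruses that lie along the y = x line have similar fractions of positive cells in MDCK and macaque cells and are therefore likely shared. We call a virus shared when the larger of its two fractions is at most `fold_change` times the smaller one. On the log–log plots below this is a band of constant width around y = x, so the allowed absolute difference between the MDCK and macaque fractions grows as the fractions increase.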

In [None]:
# Minimum fraction of positive cells per virus per species
min_frac_threshold = 0.0005

# Translate the fraction into an absolute number of macaque cells
n_macaque_cells = len(palmdb_adata.obs[palmdb_adata.obs['species'] == 'macaca_mulatta'])
print(f"This equals to a minimum of {min_frac_threshold * n_macaque_cells} positive cells for macaque")

In [None]:
%%time
# Fold change between fractions of pos cells within which viruses will be considered shared
fold_change = 2

# Maximum number of positive cells allowed in the other species for "_only" assignments
# (despite its name, min_num_cells is an upper bound)
min_num_cells = 7

is_mac = (palmdb_adata.obs["species"] == "macaca_mulatta").values
is_can = (palmdb_adata.obs["species"] == "canis_lupus_familiaris").values
n_mac_cells = is_mac.sum()
n_can_cells = is_can.sum()

# Get fraction of positive cells for each virus for each species
mac_fractions = nd(palmdb_adata[is_mac, :].X.sum(axis=0)) / n_mac_cells
can_fractions = nd(palmdb_adata[is_can, :].X.sum(axis=0)) / n_can_cells

v_types = []
for mac_frac, can_frac in zip(mac_fractions, can_fractions):
    high, low = max(mac_frac, can_frac), min(mac_frac, can_frac)
    if mac_frac < min_frac_threshold and can_frac < min_frac_threshold:
        v_types.append("below_threshold")
    # Shared: the larger fraction is at most fold_change times the smaller one.
    # Checking low > 0 first avoids dividing by zero, and using high/low (instead of a
    # strict ">" between the fractions) classifies viruses with exactly equal
    # fractions as shared rather than "undefined".
    elif low > 0 and high / low <= fold_change:
        v_types.append("shared")
    elif mac_frac > can_frac and can_frac <= min_num_cells / n_can_cells:  # Allow [min_num_cells] positive cells in the other species
        v_types.append("macaca_only")
    elif can_frac > mac_frac and mac_frac <= min_num_cells / n_mac_cells:
        v_types.append("canis_only")
    else:
        v_types.append("undefined")

In [None]:
# Store each virus's category; v_types is in the same order as palmdb_adata.var
palmdb_adata.var["v_type"] = v_types

In [None]:
# Number of viruses in each category (Counter is also used further below)
from collections import Counter
Counter(v_types)

In [None]:
fig, ax = plt.subplots(figsize=(12, 10))
fontsize = 14
vmax = 5

all_v_types = ["macaca_only", "canis_only", "shared", "undefined", "below_threshold"]
labels = ["macaque only", "MDCK only", "shared", "undefined", "below threshold"]

# Viruses circled in black
highlight_virs = ["u202260", "u102324", "u134800"]

# Generate colormaps (removing lightest values for visibility)
edgecolors = ["green", "blue", "red", "grey", "lightgrey"]
cmaps_orig = [plt.cm.Greens, plt.cm.Blues, plt.cm.Oranges, plt.cm.Greys]
min_val, max_val = 0.25, 1.0
n = 20
cmaps = [
    mpl.colors.LinearSegmentedColormap.from_list("mycmap", orig_cmap(np.linspace(min_val, max_val, n)))
    for orig_cmap in cmaps_orig
]

# Percentage of positive macaque / MDCK cells for every virus, computed once instead of
# re-slicing the AnnData object for every category and every highlighted virus
is_mac = (palmdb_adata.obs["species"] == "macaca_mulatta").values
is_can = (palmdb_adata.obs["species"] == "canis_lupus_familiaris").values
pct_pos = pd.DataFrame(
    {
        "macaque": nd(palmdb_adata[is_mac, :].X.sum(axis=0)) / is_mac.sum() * 100,
        "mdck": nd(palmdb_adata[is_can, :].X.sum(axis=0)) / is_can.sum() * 100,
        "v_type": palmdb_adata.var["v_type"].values,
    },
    index=palmdb_adata.var.index,
)

for i, v_type in enumerate(all_v_types):
    subset = pct_pos[pct_pos["v_type"] == v_type]
    if subset.empty:
        # An empty category would make z2.max() below raise; nothing to draw
        continue
    x = subset["macaque"].to_numpy()
    y = subset["mdck"].to_numpy()

    # Histogram data to show point density: colour each point by the number of
    # viruses in the same 2-D bin
    bins = [1500, 1500]
    hh, locx, locy = np.histogram2d(x, y, bins=bins)
    # searchsorted returns the first upper bin edge >= each value, i.e. the point's bin
    z = hh[np.searchsorted(locx[1:], x), np.searchsorted(locy[1:], y)]
    # Draw the densest points last so they sit on top
    idx = z.argsort()
    x2, y2, z2 = x[idx], y[idx], z[idx]

    print(z2.max())

    if v_type != "below_threshold":
        # Normalize colormaps to same range (1..vmax) so the colorbars are comparable
        scatter = ax.scatter(x2, y2, c=z2, cmap=cmaps[i], norm=mpl.colors.Normalize(vmin=1, vmax=vmax), edgecolor=edgecolors[i], lw=0.2, s=50, alpha=1)
        # Add colorbar; only the first one carries tick labels
        cbar = fig.colorbar(scatter, ax=ax, pad=0.01, shrink=0.7)
        cbar.ax.yaxis.set_major_locator(MaxNLocator(integer=True))
        if i == 0:
            cbar.ax.tick_params(axis="both", labelsize=fontsize-2)
        else:
            cbar.set_ticks([])
        cbar.ax.set_ylabel(f"# of {labels[i]} viruses", fontsize=fontsize)
    else:
        ax.scatter(x2, y2, c=edgecolors[i], edgecolor=edgecolors[i], lw=0.2, s=50, alpha=0.2)

# Mark highlighted viruses with a black edge
highlighted = pct_pos[pct_pos.index.isin(highlight_virs)]
ax.scatter(highlighted["macaque"], highlighted["mdck"], facecolor=[0, 0, 0, 0], edgecolor=[0, 0, 0, 1], lw=1, s=50)

ax.set_xlabel("Percentage of positive Macaca mulatta cells", fontsize=fontsize)
ax.set_ylabel("Percentage of positive MDCK cells", fontsize=fontsize)

ax.set_xscale("log")
ax.set_yscale("log")

# Add shared threshold area and y=x line
# Get axis limits
lims = [
    np.min([ax.get_xlim(), ax.get_ylim()]),
    np.max([ax.get_xlim(), ax.get_ylim()]),
]

# Shade the band within fold_change above/below the y=x line (the "shared" region)
x_thres = np.array(lims)
ax.fill_between(x_thres, 10**(np.log10(x_thres)+np.log10(fold_change)), 10**(np.log10(x_thres)-np.log10(fold_change)), alpha=0.1, color="#f77f00")

# Add minimum fraction threshold lines (dashed corner of "below threshold" viruses)
ax.plot([min_frac_threshold*100, min_frac_threshold*100], [min_frac_threshold*100, ax.get_xlim()[0]], c="grey", lw=1, ls="--", zorder=-1)
ax.plot([min_frac_threshold*100, ax.get_xlim()[0]], [min_frac_threshold*100, min_frac_threshold*100], c="grey", lw=1, ls="--", zorder=-1)

ax.plot(ax.get_xlim(), ax.get_xlim(), c="grey", lw=1, ls="-")

ax.set_ylim(lims)
ax.set_xlim(lims)
ax.set_aspect("equal")

# Change fontsize of tick labels
ax.tick_params(axis="both", labelsize=fontsize-2)

# Save figure
# fig.savefig(
# f"virus_shared_thresholds_bus.png", dpi=300, bbox_inches="tight", transparent=True
# )

# plt.show() instead of fig.show(), which warns on the non-GUI inline backend
plt.show()

Adjust the colors:

In [None]:
fig, ax = plt.subplots(figsize=(8, 8))
fontsize = 16

all_v_types = ["macaca_only", "shared", "canis_only",  "undefined", "below_threshold"]
labels = ["Macaque only", "Shared", "MDCK only", "Undefined", "Below threshold"]

edgecolors = ["black", "black", "black", "black", "black"]
# colors = ["#5c2554", "#ab4e9b", "#bb7fb4", "lightgrey", "lightgrey"] # purples
colors = ["#003049", "#4b8eb3", "#8fc0de", "lightgrey", "lightgrey"]
dot_size = 75

# Viruses circled in black
highlight_virs = ["u202260", "u102324", "u134800"]

# Percentage of positive macaque / MDCK cells for every virus, computed once
is_mac = (palmdb_adata.obs["species"] == "macaca_mulatta").values
is_can = (palmdb_adata.obs["species"] == "canis_lupus_familiaris").values
pct_pos = pd.DataFrame(
    {
        "macaque": nd(palmdb_adata[is_mac, :].X.sum(axis=0)) / is_mac.sum() * 100,
        "mdck": nd(palmdb_adata[is_can, :].X.sum(axis=0)) / is_can.sum() * 100,
        "v_type": palmdb_adata.var["v_type"].values,
    },
    index=palmdb_adata.var.index,
)

for i, v_type in enumerate(all_v_types):
    subset = pct_pos[pct_pos["v_type"] == v_type]
    if v_type == "below_threshold":
        # Faint, behind all other points, and without a legend entry
        ax.scatter(subset["macaque"], subset["mdck"], color=colors[i], s=dot_size, alpha=0.5, zorder=-1)
    else:
        ax.scatter(subset["macaque"], subset["mdck"], color=colors[i], edgecolor=edgecolors[i], lw=0.2, s=dot_size, alpha=1, label=labels[i])

# Legend is created before the highlight rings so they do not get an entry
ax.legend(fontsize=fontsize, loc="upper left")

# Mark highlighted viruses with a black edge
highlighted = pct_pos[pct_pos.index.isin(highlight_virs)]
ax.scatter(highlighted["macaque"], highlighted["mdck"], facecolor=[0, 0, 0, 0], edgecolor=[0, 0, 0, 1], lw=1.5, s=dot_size)

ax.set_xlabel("Positive macaque cells (%)", fontsize=fontsize)
ax.set_ylabel("Positive MDCK cells (%)", fontsize=fontsize)

ax.set_xscale("log")
ax.set_yscale("log")

# Add shared threshold area and y=x line
# Get axis limits
lims = [
    np.min([ax.get_xlim(), ax.get_ylim()]),
    np.max([ax.get_xlim(), ax.get_ylim()]),
]

# Shade the band within fold_change above/below the y=x line (the "shared" region)
x_thres = np.array(lims)
ax.fill_between(x_thres, 10**(np.log10(x_thres)+np.log10(fold_change)), 10**(np.log10(x_thres)-np.log10(fold_change)), alpha=0.1, color=colors[1])

# Add minimum fraction threshold lines (dashed corner of "below threshold" viruses)
ax.plot([min_frac_threshold*100, min_frac_threshold*100], [min_frac_threshold*100, ax.get_xlim()[0]], c="grey", lw=1, ls="--", zorder=-1)
ax.plot([min_frac_threshold*100, ax.get_xlim()[0]], [min_frac_threshold*100, min_frac_threshold*100], c="grey", lw=1, ls="--", zorder=-1)

ax.plot(ax.get_xlim(), ax.get_xlim(), c="grey", lw=1, ls="-")

ax.set_ylim(lims)
ax.set_xlim(lims)
ax.set_aspect("equal")

# Change fontsize of tick labels
ax.tick_params(axis="both", labelsize=fontsize-2)

# Save figure
fig.savefig(
    "virus_shared_thresholds_bus.png", dpi=300, bbox_inches="tight", transparent=True
)

# plt.show() instead of fig.show(), which warns on the non-GUI inline backend
plt.show()

In [None]:
# Zoom into area around 0 (linear axes)
fig, ax = plt.subplots(figsize=(8, 8))
fontsize = 16
# Defined here so this cell does not depend on dot_size leaking from an earlier cell
dot_size = 75

all_v_types = ["macaca_only", "shared", "canis_only",  "undefined", "below_threshold"]

edgecolors = ["black", "black", "black", "black", "black"]
# colors = ["#5c2554", "#ab4e9b", "#bb7fb4", "lightgrey", "lightgrey"] # purples
colors = ["#003049", "#4b8eb3", "#8fc0de", "lightgrey", "lightgrey"]

# Viruses circled in black (u10 is Ebolavirus)
highlight_virs = ["u10", "u102540", "u11150", "u39566"]

# Percentage of positive macaque / MDCK cells for every virus, computed once
is_mac = (palmdb_adata.obs["species"] == "macaca_mulatta").values
is_can = (palmdb_adata.obs["species"] == "canis_lupus_familiaris").values
pct_pos = pd.DataFrame(
    {
        "macaque": nd(palmdb_adata[is_mac, :].X.sum(axis=0)) / is_mac.sum() * 100,
        "mdck": nd(palmdb_adata[is_can, :].X.sum(axis=0)) / is_can.sum() * 100,
        "v_type": palmdb_adata.var["v_type"].values,
    },
    index=palmdb_adata.var.index,
)

for i, v_type in enumerate(all_v_types):
    subset = pct_pos[pct_pos["v_type"] == v_type]
    if v_type == "below_threshold":
        ax.scatter(subset["macaque"], subset["mdck"], color=colors[i], lw=0.2, s=dot_size, alpha=0.5)
    else:
        ax.scatter(subset["macaque"], subset["mdck"], color=colors[i], edgecolor=edgecolors[i], lw=0.2, s=dot_size, alpha=1)

# Mark highlighted viruses with a black edge
highlighted = pct_pos[pct_pos.index.isin(highlight_virs)]
ax.scatter(highlighted["macaque"], highlighted["mdck"], facecolor=[0, 0, 0, 0], edgecolor=[0, 0, 0, 1], lw=1.5, s=dot_size)

ax.set_xlabel("Positive macaque cells (%)", fontsize=fontsize)
ax.set_ylabel("Positive MDCK cells (%)", fontsize=fontsize)

ax.set_ylim(bottom=-0.01, top=0.23)
ax.set_xlim(left=-0.01, right=0.23)

# Add shared threshold area and y=x line
# Get axis limits (lower bound kept positive for the log10 below)
lims = [
    0.0001,
    np.max([ax.get_xlim(), ax.get_ylim()]),
]

# Shade the band within fold_change above/below the y=x line (the "shared" region)
x_thres = np.array(lims)
ax.fill_between(x_thres, 10**(np.log10(x_thres)+np.log10(fold_change)), 10**(np.log10(x_thres)-np.log10(fold_change)), alpha=0.1, color=colors[1])

# Add minimum fraction threshold lines
ax.plot([min_frac_threshold*100, min_frac_threshold*100], [min_frac_threshold*100, -1], c="grey", lw=1, ls="--", zorder=-1)
ax.plot([min_frac_threshold*100, -1], [min_frac_threshold*100, min_frac_threshold*100], c="grey", lw=1, ls="--", zorder=-1)

ax.plot(ax.get_xlim(), ax.get_xlim(), c="grey", lw=1, ls="-")

ax.set_aspect("equal")

# Change fontsize of tick labels
ax.tick_params(axis="both", labelsize=fontsize-2)

# # Save figure
# fig.savefig(
#     "virus_shared_thresholds_zero.png", dpi=300, bbox_inches="tight"
# )

# plt.show() instead of fig.show(), which warns on the non-GUI inline backend
plt.show()

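Same zoomed-in view, this time coloured by point density with one colormap per category (darker points = more viruses in the same bin):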

In [None]:
# Zoom into area around 0 (linear axes), coloured by point density
fig, ax = plt.subplots(figsize=(12, 10))
fontsize = 14
vmax = 5

all_v_types = ["macaca_only", "canis_only", "shared", "undefined", "below_threshold"]
labels = ["macaque only", "MDCK only", "shared", "undefined", "below threshold"]

# Viruses circled in black (u10 is Ebolavirus)
highlight_virs = ["u10", "u102540", "u11150", "u39566"]

# Generate colormaps (removing lightest values for visibility)
edgecolors = ["green", "blue", "red", "grey", "lightgrey"]
cmaps_orig = [plt.cm.Greens, plt.cm.Blues, plt.cm.Oranges, plt.cm.Greys]
min_val, max_val = 0.25, 1.0
n = 20
cmaps = [
    mpl.colors.LinearSegmentedColormap.from_list("mycmap", orig_cmap(np.linspace(min_val, max_val, n)))
    for orig_cmap in cmaps_orig
]

# Percentage of positive macaque / MDCK cells for every virus, computed once
is_mac = (palmdb_adata.obs["species"] == "macaca_mulatta").values
is_can = (palmdb_adata.obs["species"] == "canis_lupus_familiaris").values
pct_pos = pd.DataFrame(
    {
        "macaque": nd(palmdb_adata[is_mac, :].X.sum(axis=0)) / is_mac.sum() * 100,
        "mdck": nd(palmdb_adata[is_can, :].X.sum(axis=0)) / is_can.sum() * 100,
        "v_type": palmdb_adata.var["v_type"].values,
    },
    index=palmdb_adata.var.index,
)

for i, v_type in enumerate(all_v_types):
    subset = pct_pos[pct_pos["v_type"] == v_type]
    if subset.empty:
        # An empty category would make z2.max() below raise; nothing to draw
        continue
    x = subset["macaque"].to_numpy()
    y = subset["mdck"].to_numpy()

    # Histogram data to show point density: colour each point by the number of
    # viruses in the same 2-D bin
    bins = [1500, 1500]
    hh, locx, locy = np.histogram2d(x, y, bins=bins)
    # searchsorted returns the first upper bin edge >= each value, i.e. the point's bin
    z = hh[np.searchsorted(locx[1:], x), np.searchsorted(locy[1:], y)]
    # Draw the densest points last so they sit on top
    idx = z.argsort()
    x2, y2, z2 = x[idx], y[idx], z[idx]

    print(z2.max())

    if v_type != "below_threshold":
        # Normalize colormaps to same range (1..vmax) so the colorbars are comparable
        scatter = ax.scatter(x2, y2, c=z2, cmap=cmaps[i], norm=mpl.colors.Normalize(vmin=1, vmax=vmax), edgecolor=edgecolors[i], lw=0.2, s=100, alpha=1)
        # Add colorbar; only the first one carries tick labels
        cbar = fig.colorbar(scatter, ax=ax, pad=0.01, shrink=0.7)
        cbar.ax.yaxis.set_major_locator(MaxNLocator(integer=True))
        if i == 0:
            cbar.ax.tick_params(axis="both", labelsize=fontsize-2)
        else:
            cbar.set_ticks([])
        cbar.ax.set_ylabel(f"# of {labels[i]} viruses", fontsize=fontsize)
    else:
        ax.scatter(x2, y2, c=edgecolors[i], edgecolor=edgecolors[i], lw=0.2, s=100, alpha=0.2)

# Mark highlighted viruses with a black edge
highlighted = pct_pos[pct_pos.index.isin(highlight_virs)]
ax.scatter(highlighted["macaque"], highlighted["mdck"], facecolor=[0, 0, 0, 0], edgecolor=[0, 0, 0, 1], lw=1.5, s=100)

ax.set_xlabel("Percentage of positive Macaca mulatta cells", fontsize=fontsize)
ax.set_ylabel("Percentage of positive MDCK cells", fontsize=fontsize)

ax.set_ylim(bottom=-0.01, top=0.23)
ax.set_xlim(left=-0.01, right=0.23)

# Add shared threshold area and y=x line
# Get axis limits (lower bound kept positive for the log10 below)
lims = [
    0.0001,
    np.max([ax.get_xlim(), ax.get_ylim()]),
]

# Shade the band within fold_change above/below the y=x line (the "shared" region)
x_thres = np.array(lims)
ax.fill_between(x_thres, 10**(np.log10(x_thres)+np.log10(fold_change)), 10**(np.log10(x_thres)-np.log10(fold_change)), alpha=0.1, color="#f77f00")

# Add minimum fraction threshold lines
ax.plot([min_frac_threshold*100, min_frac_threshold*100], [min_frac_threshold*100, -1], c="grey", lw=1, ls="--", zorder=-1)
ax.plot([min_frac_threshold*100, -1], [min_frac_threshold*100, min_frac_threshold*100], c="grey", lw=1, ls="--", zorder=-1)

ax.plot(ax.get_xlim(), ax.get_xlim(), c="grey", lw=1, ls="-")

ax.set_aspect("equal")

# Change fontsize of tick labels
ax.tick_params(axis="both", labelsize=fontsize-2)

# Save figure
fig.savefig(
    "virus_shared_thresholds_zero.png", dpi=300, bbox_inches="tight", transparent=True
)

# plt.show() instead of fig.show(), which warns on the non-GUI inline backend
plt.show()

In [None]:
%%time
# Same for also-in-host matrix

# Get fraction of positive cells for each virus for each species
mac_fractions = nd(palmdb_adata_host[palmdb_adata_host.obs["species"] == "macaca_mulatta", :].X.sum(axis=0)) / len(palmdb_adata_host.obs[palmdb_adata_host.obs["species"] == "macaca_mulatta"])
can_fractions = nd(palmdb_adata_host[palmdb_adata_host.obs["species"] == "canis_lupus_familiaris", :].X.sum(axis=0)) / len(palmdb_adata_host.obs[palmdb_adata_host.obs["species"] == "canis_lupus_familiaris"])

v_types = []
for mac_frac, can_frac in zip(mac_fractions, can_fractions):
    if mac_frac < min_frac_threshold and can_frac < min_frac_threshold:
        v_types.append("below_threshold")
    elif mac_frac > can_frac and mac_frac / can_frac <= fold_change:
        v_types.append("shared")
    elif can_frac > mac_frac and can_frac / mac_frac <= fold_change:
        v_types.append("shared")
    elif mac_frac > can_frac and can_frac <= (min_num_cells / len(palmdb_adata_host.obs[palmdb_adata_host.obs["species"] == "canis_lupus_familiaris"])): # Allow [min_num_cells] positive cells in the other species
        v_types.append("macaca_only")
    elif can_frac > mac_frac and mac_frac <= (min_num_cells / len(palmdb_adata_host.obs[palmdb_adata_host.obs["species"] == "macaca_mulatta"])):
        v_types.append("canis_only")
    else:
        v_types.append("undefined")

palmdb_adata_host.var["v_type"] = v_types

In [None]:
# Number of viruses per category in the also-in-host matrix (v_types from the cell above).
# Imported locally because the notebook's top import cell does not import Counter.
from collections import Counter

Counter(v_types)

# For each virus, plot the fraction of its reads that are virus-only vs. also in host

In [None]:
# Macaque only + shared viruses in order of heatmap
virs = ['u39566', 'u102540', 'u11150', 'u10', 'u288819','u290519','u10240','u183255','u1001','u100291','u103829','u110641','u181379','u202260','u135858','u101227','u100188','u27694','u34159','u100245','u10015','u100733','u100173','u100196','u100599','u100644','u100296','u100017','u100002','u100012','u100024','u100048','u100302','u100074','u100289','u100026','u100111','u100139','u100154','u100251','u100177','u100215','u100049','u100000','u100001','u100007','u100004','u100011','u100093','u100116','u100019','u100076','u100028','u100153','u100031','u100145','u102324','u134800']

# For each virus, split its total raw read count into the fraction from the virus-only
# matrix (palmdb_adata) and the fraction from the also-in-host matrix (palmdb_adata_host)
vir_fracs = []
host_fracs = []
for vir in virs:
    # Build each mask from .raw.var, the axis actually being indexed; a mask built from
    # .var only lines up with .raw when both contain exactly the same variables
    vir_count = palmdb_adata.raw[:, palmdb_adata.raw.var.index == vir].X.sum()
    host_count = palmdb_adata_host.raw[:, palmdb_adata_host.raw.var.index == vir].X.sum()
    total = vir_count + host_count

    # A virus with no reads in either matrix gets NaN instead of dividing by zero
    vir_fracs.append(vir_count / total if total > 0 else np.nan)
    host_fracs.append(host_count / total if total > 0 else np.nan)

df_read_fracs = pd.DataFrame(
    {"vir": virs, "vir_frac": vir_fracs, "host_frac": host_fracs}
).sort_values("vir_frac", ascending=False)

In [None]:
fig, ax = plt.subplots(figsize=(20, 3))
fontsize = 14
width = 0.75

# One stacked bar per virus, in df_read_fracs order: virus-only fraction at the bottom and
# the also-in-host fraction on top. Plotting each series in a single vectorized call keeps
# the rows aligned and avoids re-filtering df_read_fracs once per virus.
bar_positions = np.arange(len(df_read_fracs))
vir_frac_values = df_read_fracs["vir_frac"].values
ax.bar(bar_positions, vir_frac_values, width=width, label="Virus only", color="#f77f00")
ax.bar(
    bar_positions,
    df_read_fracs["host_frac"].values,
    bottom=vir_frac_values,
    width=width,
    label="Also in host",
    color="#808080",
)

ax.set_ylabel("Fraction of reads", fontsize=fontsize+2)
ax.set_xlabel("Virus ID", fontsize=fontsize+2)
ax.set_xticks(bar_positions, df_read_fracs["vir"].values, rotation=45, ha="right", fontsize=fontsize)
ax.margins(x=0.005, y=0)
ax.legend(fontsize=fontsize, loc="upper right")
ax.tick_params(axis='both', labelsize=fontsize)

# Reference line at an even split between virus-only and also-in-host reads
ax.axhline(0.5, color="black", ls="--", lw=1)

ax.grid(True, which="both", color="lightgrey", ls="--", lw=1)
ax.set_axisbelow(True)
ax.xaxis.grid(False)

fig.savefig("frac_per_capture_type.png", dpi=300, bbox_inches="tight")

fig.show()

In [None]:
# Spot check: total raw read count for virus u10 in the also-in-host matrix.
# The mask is built from .raw.var so it aligns with the axis being indexed.
palmdb_adata_host.raw[:, palmdb_adata_host.raw.var.index == "u10"].X.sum()

# Plot raw count histograms (before binarizing) to show that per-cell counts for each virus ID are low compared with host gene counts:

In [None]:
def _nonzero_raw_virus_counts(v_type):
    """
    Return the nonzero raw counts of all `v_type` viruses in cells with an annotated cell type.

    Counts come from palmdb_adata.raw. The v_type column only exists on palmdb_adata.var, so
    viruses are matched by name and the boolean mask is built over palmdb_adata.raw.var, the
    axis actually being indexed. This stays correct even if raw.var differs from var.
    """
    selected_ids = palmdb_adata.var.index[palmdb_adata.var["v_type"] == v_type]
    counts = palmdb_adata.raw[
        palmdb_adata.obs["celltype"].notnull(),
        palmdb_adata.raw.var.index.isin(selected_ids),
    ].X
    return nd(counts[counts > 0])


vir_counts = _nonzero_raw_virus_counts("macaca_only")
vir_counts2 = _nonzero_raw_virus_counts("shared")

# Nonzero raw counts of all macaque genes, for comparison.
# NOTE(review): unlike the virus counts, these are not restricted to cells with an
# annotated celltype; confirm this is intended.
host_counts = nd(mac_adata.raw.X[mac_adata.raw.X > 0])

In [None]:
fig, ax = plt.subplots(figsize=(10, 3))
fontsize = 10

# Overlay three log-scaled, filled step histograms of per-cell counts, drawn back to front:
# host genes, shared viruses, then macaque-only viruses
hist_layers = (
    (host_counts, 160, {"edgecolor": "grey", "facecolor": "lightgrey", "label": "Macaque genes"}),
    (vir_counts2, 10, {"edgecolor": "#4b8eb3", "facecolor": "#4b8eb3", "alpha": 0.8, "label": "Viruses (shared)"}),
    (vir_counts, 2, {"edgecolor": "#003049", "facecolor": "#003049", "alpha": 1, "label": "Viruses (macaque only)"}),
)
for layer_counts, layer_bins, layer_style in hist_layers:
    ax.hist(layer_counts, bins=layer_bins, histtype="stepfilled", log=True, **layer_style)

ax.legend(fontsize=fontsize)
ax.set_xlabel("Count in one cell", fontsize=fontsize)
ax.set_ylabel("Frequency", fontsize=fontsize)
ax.tick_params(axis="both", labelsize=fontsize - 2)  # tick labels slightly smaller than axis labels

# Dashed background grid drawn behind the histograms
ax.grid(True, which="both", color="lightgray", ls="--", lw=1)
ax.set_axisbelow(True)

fig.savefig("count_histogram.png", dpi=300, bbox_inches="tight")
fig.show()

___

In [None]:
# Write both annotated matrices to disk: the virus-only matrix and the also-in-host matrix
# (each carries the per-cell species labels and the per-virus v_type categories)
for out_adata, out_path in (
    (palmdb_adata, "virus_host-captured_dlist_cdna_dna.h5ad"),
    (palmdb_adata_host, "virus_host-captured_dlist_cdna_dna_also-in-host.h5ad"),
):
    out_adata.write(out_path)