<a href="https://colab.research.google.com/github/pachterlab/LSCHWCP_2023/blob/main/Notebooks/Figure_6/Figure_6cd/explore_macaque-only_and_shared_virs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Explore distributions of macaque only and shared viruses

In [None]:
!pip install -q gget
import numpy as np
import gget
from scipy import stats
import anndata
import pandas as pd
import json
import os
import glob
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib as mpl
%config InlineBackend.figure_format='retina'

def nd(arr):
    """
    Flatten an array-like (e.g. a numpy matrix) into a 1-D numpy ndarray.
    """
    as_ndarray = np.asarray(arr)
    return as_ndarray.ravel()

# Load data

The count matrix was generated [here](https://github.com/pachterlab/LSCHWCP_2023/tree/main/Notebooks/align_macaque_PBMC_data/7_virus_host_captured_dlist_cdna_dna).



In [None]:
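# Download the virus count matrix (virus_host-captured_dlist_cdna_dna.h5ad) from CaltechDATA
# into the working directory before running the next cell.
# TODO: the command below is missing the download URL of the CaltechDATA record
# !wget virus_host-captured_dlist_cdna_dna.h5ad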



In [None]:
# Read the virus count matrix from the .h5ad file in the working directory.
# `anndata.read` is a deprecated alias; `read_h5ad` is the supported reader.
palmdb_adata = anndata.read_h5ad("virus_host-captured_dlist_cdna_dna.h5ad")
palmdb_adata

In [None]:
# Largest entry of the count matrix (quick check of the value range of X)
palmdb_adata.X.max()

#### Load ID to taxonomy mapping

In [None]:
# Download the ID-to-taxonomy mapping CSV from the project's GitHub repository into the working directory
!wget https://raw.githubusercontent.com/pachterlab/LSCHWCP_2023/main/PalmDB/ID_to_taxonomy_mapping.csv

In [None]:
# Read the ID-to-taxonomy mapping, discard the columns that are not needed
# here and collapse the rows that become identical as a result
phylogeny_data = (
    pd.read_csv("ID_to_taxonomy_mapping.csv")
    .drop(columns=["ID", "strandedness"])
    .drop_duplicates()
)
phylogeny_data

___
# Macaque viruses
### Show fraction of positive cells for each virus per day per animal (after host QC - macaque cells only)

In [None]:
def plot_heatmap(vir_fractions_df, save_name, num_virs=100, norm_axis=1, figsize=(17, 7)):
    """
    Draw a column-clustered heatmap of per-sample virus fractions and save it as PNG.

    Parameters
    ----------
    vir_fractions_df : pandas.DataFrame
        Table with one row per sample. Only columns whose name contains
        "_fraction" are plotted; each one is relabeled with the part of its
        name before the first "_" (the virus ID).
    save_name : str
        Suffix of the output file, which is written to
        "frac_per_virus_{save_name}.png" in the working directory.
    num_virs : int
        Maximum number of fraction columns to plot, taken from the left.
    norm_axis : {0, 1, None}
        Forwarded to seaborn's `standard_scale` (0 scales rows, 1 scales
        columns); None plots the raw fractions.
    figsize : tuple
        Figure size forwarded to seaborn.

    Returns
    -------
    list of matplotlib.text.Text
        Major x tick labels of the heatmap axis, i.e. the virus IDs in the
        order in which the clustered columns are drawn.
    """
    fraction_columns = [column for column in vir_fractions_df.columns if "_fraction" in column]
    fraction_df = vir_fractions_df[fraction_columns].iloc[:, :num_virs]

    # Rename columns to keep only virus ID
    fraction_df.columns = [column.split("_")[0] for column in fraction_df.columns]

    # Reverse row order
    fraction_df = fraction_df.iloc[::-1]

    # 0 is a valid axis for standard_scale, so compare against None explicitly
    # (a plain truthiness test would label row-scaled data as unscaled)
    is_scaled = norm_axis is not None

    grid = sns.clustermap(
        fraction_df,
        figsize=figsize,
        row_cluster=False,
        standard_scale=norm_axis,
        dendrogram_ratio=(.1, .1),
        cbar_pos=(0.1, -0.02, .15, .02),
        cbar_kws={
            "label": "Scaled fraction of positive cells" if is_scaled else "Fraction of positive cells",
            "orientation": "horizontal"
        },
        yticklabels=True,
        xticklabels=True,
    )

    # Save figure
    grid.savefig(
        f"frac_per_virus_{save_name}.png", dpi=300, bbox_inches="tight"
    )

    # Return ordered virus IDs
    return grid.ax_heatmap.xaxis.get_majorticklabels()

Compute fraction of positive cells per virus:

In [None]:
%%time
# Keep cells with a cell type annotation and the viruses flagged as macaque only
adata = palmdb_adata[palmdb_adata.obs["celltype"].notnull(), palmdb_adata.var["v_type"] == "macaca_only"].copy()

timepoints = ['-30d', '-4d', '0d', '4h', '24h', '3d', '4d', '5d', '6d', '7d', '8d']

# Enumerate the samples (timepoint x animal) once, together with the boolean
# row mask selecting their cells, instead of rebuilding both for every virus
samples = []
sample_masks = []
for timepoint in timepoints:
    timepoint_mask = adata.obs["dpi_clean"] == timepoint
    for animal_id in np.sort(adata.obs[timepoint_mask]["donor_animal"].unique()):
        samples.append(str(timepoint + "_" + animal_id))
        sample_masks.append((adata.obs["donor_animal"] == animal_id) & timepoint_mask)

# Get counts for each virus per sample
# NOTE(review): the sum over X equals the number of positive cells only if X is binary — confirm
vir_fractions_df = pd.DataFrame()
for virus_id in adata.var.index.values:
    virus_mask = adata.var.index == virus_id
    virus_counts = [adata.X[sample_mask, virus_mask].sum() for sample_mask in sample_masks]

    # Column label: all taxonomy fields of the virus joined by "_", with "_." segments removed
    # (assumes rep_ID is the first column so that the label starts with the virus ID — TODO confirm)
    utax_label = "_".join(phylogeny_data[phylogeny_data["rep_ID"] == virus_id].values[0]).replace("_.", "")
    vir_fractions_df[utax_label] = virus_counts

vir_fractions_df["sample"] = samples

# Set sample as index
vir_fractions_df = vir_fractions_df.set_index("sample")

# Add total number of cells per sample (taken from the row masks, so the
# sample labels never have to be split back into timepoint and animal ID)
num_cells = [int(sample_mask.sum()) for sample_mask in sample_masks]
vir_fractions_df["sample_total"] = num_cells

# Get fractions for each virus
for virus in vir_fractions_df.columns[:-1]:
    vir_fractions_df[f"{virus}_fraction"] = vir_fractions_df[virus] / vir_fractions_df["sample_total"]

# De-fragment dataframe
vir_fractions_df = vir_fractions_df.copy()

# Sort columns by their maximum value across all samples
vir_fractions_df = vir_fractions_df[vir_fractions_df.max().sort_values(ascending=False).index]

# De-fragment dataframe
vir_fractions_df = vir_fractions_df.copy()

vir_fractions_df.head()

In [None]:
# Show macaque viruses sorted by highest fraction of positive cells across all samples
# Count columns are those without "_fraction" in their name; "sample_total" is
# excluded by name instead of assuming it is sorted into the first position
sorted_mac_virs = [
    col.split("_")[0]
    for col in vir_fractions_df
    if "_fraction" not in col and col != "sample_total"
]
print("Total number of macaque only viruses: ", len(sorted_mac_virs))
sorted_mac_virs

Plot heatmap:

In [None]:
# Heatmap of the unscaled fractions (norm_axis=None), saved as frac_per_virus_per_tp_per_animal_nonorm.png
temp = plot_heatmap(vir_fractions_df, "per_tp_per_animal_nonorm", norm_axis=None, figsize=(3, 5))

In [None]:
# Heatmap with norm_axis=1 (passed on as seaborn's standard_scale); keep the returned x tick labels of the clustered heatmap
sorted_vir_ids_plt = plot_heatmap(vir_fractions_df, "per_tp_per_animal", norm_axis=1, figsize=(3, 7))

In [None]:
# Extract the label text (virus IDs) from the tick label objects returned by plot_heatmap
sorted_vir_ids = [vir_id.get_text() for vir_id in sorted_vir_ids_plt]
sorted_vir_ids

Plot total fraction of positive cells across all samples:

In [None]:
%%time
# Annotated cells x macaque-only viruses
adata = palmdb_adata[
    palmdb_adata.obs["celltype"].notnull(),
    palmdb_adata.var["v_type"] == "macaca_only",
].copy()

timepoints = ['-30d', '-4d', '0d', '4h', '24h', '3d', '4d', '5d', '6d', '7d', '8d']

# One-row table holding, per virus (in heatmap order), the sum of X over all retained cells
totals_df = pd.DataFrame(
    {
        virus_id: [adata.X[:, adata.var.index == virus_id].sum()]
        for virus_id in sorted_vir_ids
    }
)

# Add total number of cells
totals_df["sample_total"] = len(adata.obs)

# Get fractions for each virus (every column except the trailing "sample_total")
virus_columns = totals_df.columns[:-1]
for virus in virus_columns:
    totals_df[f"{virus}_fraction"] = totals_df[virus] / totals_df["sample_total"]

totals_df.head()

In [None]:
# Keep only the per-virus fraction columns of the one-row totals table
totals_df_fractions = totals_df.filter(like="_fraction")

fontsize = 20
cmap = "Blues_r"
fig, ax = plt.subplots(figsize=(3, 5))

# Single-row heatmap; the color scale is capped at 0.1
im = ax.imshow(
    totals_df_fractions.values,
    cmap=cmap,
    vmin=None,
    vmax=0.1,
    aspect=0.45,
)

# Horizontal colorbar
cb = plt.colorbar(im, shrink=0.25, orientation="horizontal", pad=0.4)
cb.set_label(label="Fraction of positive cells across all samples", size=fontsize)
cb.ax.tick_params(axis="both", labelsize=fontsize+6, length=8, width=2, pad=8) # Trying to match sns colorbar

# One x tick per virus, labeled in heatmap order; no y ticks
ax.set_yticks([])
ax.set_xticks(np.arange(len(sorted_vir_ids)))
ax.set_xticklabels(sorted_vir_ids, rotation=45, ha="right")
ax.tick_params(axis="both", labelsize=fontsize, length=5, width=1.5)

fig.savefig("frac_per_virus_totals_bus+d-list.png", dpi=300, bbox_inches="tight")

fig.show()

Save virus IDs and their taxonomies in table:

In [None]:
# Taxonomy rows of the plotted viruses. The explicit copy makes this an
# independent table, so the column assignments below do not operate on a
# slice of phylogeny_data (which triggers SettingWithCopyWarning)
temp_df = phylogeny_data[phylogeny_data["rep_ID"].isin(sorted_vir_ids)].copy()

# Sort the rows in heatmap order by making rep_ID a categorical whose
# categories follow sorted_vir_ids
temp_df["rep_ID"] = temp_df["rep_ID"].astype("category").cat.set_categories(sorted_vir_ids)
temp_df = temp_df.sort_values(["rep_ID"])

# Get number of positive cells per virus (cells whose summed entry for the virus is > 0)
adata = palmdb_adata[palmdb_adata.obs["celltype"].notnull(), palmdb_adata.var["v_type"] == "macaca_only"].copy()
pos_cells = [
    (nd(adata.X[:, adata.var.index == vir_id].sum(axis=1) > 0)).sum()
    for vir_id in temp_df["rep_ID"].values
]

temp_df["# positive macaque cells"] = pos_cells

temp_df.to_csv("macaque_only_virs_ID2tax.csv", index=False)

In [None]:
# Display the taxonomy table ordered by number of positive cells, highest first
temp_df.sort_values("# positive macaque cells", ascending=False)

In [None]:
# Run a BLAST search with gget for the sequence of virus ID u11150
# (the query looks like an amino acid sequence — presumably the u11150 reference sequence; confirm)
gget.blast("YPYDAPSFDHEVGTNEVESVMLEENAWFKRNKDTIHADRITISNAILEKWWNTDIKIEGYPTFKYKHGVGSGIPATYYLDELVNEGRTETIFASIKKAFNLKSEVVSDKSGDDLMA")

___
# Same for shared viruses

In [None]:
%%time
# Keep cells with a cell type annotation and the viruses flagged as shared
adata = palmdb_adata[palmdb_adata.obs["celltype"].notnull(), palmdb_adata.var["v_type"] == "shared"].copy()

timepoints = ['-30d', '-4d', '0d', '4h', '24h', '3d', '4d', '5d', '6d', '7d', '8d']

# Enumerate the samples (timepoint x animal) once, together with the boolean
# row mask selecting their cells, instead of rebuilding both for every virus
samples = []
sample_masks = []
for timepoint in timepoints:
    timepoint_mask = adata.obs["dpi_clean"] == timepoint
    for animal_id in np.sort(adata.obs[timepoint_mask]["donor_animal"].unique()):
        samples.append(str(timepoint + "_" + animal_id))
        sample_masks.append((adata.obs["donor_animal"] == animal_id) & timepoint_mask)

# Get counts for each virus per sample
# NOTE(review): the sum over X equals the number of positive cells only if X is binary — confirm
vir_fractions_df = pd.DataFrame()
for virus_id in adata.var.index.values:
    virus_mask = adata.var.index == virus_id
    virus_counts = [adata.X[sample_mask, virus_mask].sum() for sample_mask in sample_masks]

    # Column label: all taxonomy fields of the virus joined by "_", with "_." segments removed
    # (assumes rep_ID is the first column so that the label starts with the virus ID — TODO confirm)
    utax_label = "_".join(phylogeny_data[phylogeny_data["rep_ID"] == virus_id].values[0]).replace("_.", "")
    vir_fractions_df[utax_label] = virus_counts

vir_fractions_df["sample"] = samples

# De-fragment dataframe
vir_fractions_df = vir_fractions_df.copy()

# Set sample as index
vir_fractions_df = vir_fractions_df.set_index("sample")

# Add total number of cells per sample (taken from the row masks, so the
# sample labels never have to be split back into timepoint and animal ID)
num_cells = [int(sample_mask.sum()) for sample_mask in sample_masks]
vir_fractions_df["sample_total"] = num_cells

# Get fractions for each virus
for virus in vir_fractions_df.columns[:-1]:
    vir_fractions_df[f"{virus}_fraction"] = vir_fractions_df[virus] / vir_fractions_df["sample_total"]

# De-fragment dataframe
vir_fractions_df = vir_fractions_df.copy()

# Sort columns by their maximum value across all samples
vir_fractions_df = vir_fractions_df[vir_fractions_df.max().sort_values(ascending=False).index]

vir_fractions_df.head()

In [None]:
# Show shared viruses sorted by highest fraction of positive cells across all samples
# Count columns are those without "_fraction" in their name; "sample_total" is
# excluded by name instead of assuming it is sorted into the first position
sorted_shared_virs = [
    col.split("_")[0]
    for col in vir_fractions_df
    if "_fraction" not in col and col != "sample_total"
]
print("Total number of shared viruses: ", len(sorted_shared_virs))
sorted_shared_virs

In [None]:
# Heatmap for the shared viruses with the default norm_axis=1; keep the returned x tick labels of the clustered heatmap
shared_virs = plot_heatmap(vir_fractions_df, "per_tp_per_animal_shared", figsize=(20, 7))

In [None]:
# Extract the label text (virus IDs) from the tick label objects returned by plot_heatmap
shared_virs_clean = [vir_id.get_text() for vir_id in shared_virs]
shared_virs_clean

Plot total fraction of positive cells across all samples:

In [None]:
%%time
# Annotated cells x shared viruses
adata = palmdb_adata[
    palmdb_adata.obs["celltype"].notnull(),
    palmdb_adata.var["v_type"] == "shared",
].copy()

# One-row table holding, per virus (in heatmap order), the sum of X over all retained cells
totals_df_shared = pd.DataFrame(
    {
        virus_id: [adata.X[:, adata.var.index == virus_id].sum()]
        for virus_id in shared_virs_clean
    }
)

# Add total number of cells
totals_df_shared["sample_total"] = len(adata.obs)

# Get fractions for each virus (every column except the trailing "sample_total")
virus_columns = totals_df_shared.columns[:-1]
for virus in virus_columns:
    totals_df_shared[f"{virus}_fraction"] = totals_df_shared[virus] / totals_df_shared["sample_total"]

totals_df_shared.head()

In [None]:
# Keep only the per-virus fraction columns of the one-row totals table
totals_df_shared_fractions = totals_df_shared.filter(like="_fraction")

fontsize = 18
cmap = "Blues_r"
fig, ax = plt.subplots(figsize=(30, 5))

# Single-row heatmap; the color scale spans 0 to 0.1
im = ax.imshow(
    totals_df_shared_fractions.values,
    cmap=cmap,
    vmin=0,
    vmax=0.1,
    aspect=0.5,
)

# Horizontal colorbar
cb = plt.colorbar(im, shrink=0.25, orientation="horizontal", pad=0.4)
cb.set_label(label="Fraction of positive cells across all samples", size=fontsize)
cb.ax.tick_params(axis="both", labelsize=fontsize)

# One x tick per virus, labeled in heatmap order; no y ticks
ax.set_yticks([])
ax.set_xticks(np.arange(len(shared_virs_clean)))
ax.set_xticklabels(shared_virs_clean, rotation=45, ha="right")
ax.tick_params(axis="both", labelsize=fontsize)

fig.savefig("frac_per_virus_totals_shared.png", dpi=300, bbox_inches="tight")

fig.show()

In [None]:
# Taxonomy rows of the plotted shared viruses. The explicit copy makes this an
# independent table, so the column assignments below do not operate on a
# slice of phylogeny_data (which triggers SettingWithCopyWarning)
temp_df_shared = phylogeny_data[phylogeny_data["rep_ID"].isin(shared_virs_clean)].copy()

# Sort the rows in heatmap order by making rep_ID a categorical whose
# categories follow shared_virs_clean
temp_df_shared["rep_ID"] = temp_df_shared["rep_ID"].astype("category").cat.set_categories(shared_virs_clean)
temp_df_shared = temp_df_shared.sort_values(["rep_ID"])

# Get number of positive cells per virus (cells whose summed entry for the virus is > 0)
adata = palmdb_adata[palmdb_adata.obs["celltype"].notnull(), :].copy()
pos_cells = [
    (nd(adata.X[:, adata.var.index == vir_id].sum(axis=1) > 0)).sum()
    for vir_id in temp_df_shared["rep_ID"].values
]

temp_df_shared["# positive macaque cells"] = pos_cells

temp_df_shared.to_csv("shared_virs_ID2tax.csv", index=False)

In [None]:
# Display the shared-virus taxonomy table ordered by number of positive cells, highest first
temp_df_shared.sort_values("# positive macaque cells", ascending=False)

___
# Plot shared/macaque_only/canis_only fractions per phylum, order, class etc.

In [None]:
%%time
# Use the full count matrix here (no cell or virus filtering)
adata = palmdb_adata

# Virus categories (values of adata.var["v_type"]) in the column order of the
# output table, with their column masks computed once up front
v_types = ["macaca_only", "shared", "canis_only", "undefined"]
v_type_masks = {v_type: adata.var["v_type"] == v_type for v_type in v_types}

master_df = pd.DataFrame()
for i, tax in enumerate(["order"]):
    # Sum of X per taxon (row) and virus category (column)
    samples = []
    counts = {v_type: [] for v_type in v_types}
    for tax_name in np.sort(phylogeny_data[tax].unique()):
        v_ids = phylogeny_data[phylogeny_data[tax] == tax_name]["rep_ID"].unique()
        # Viruses of this taxon; shared by all four category look-ups
        in_taxon = adata.var.index.isin(v_ids)
        for v_type in v_types:
            counts[v_type].append(adata[:, in_taxon & v_type_masks[v_type]].X.sum())

        samples.append(tax_name)

    vir_fractions_df = pd.DataFrame({"sample": samples, **counts})

    # Set sample as index
    vir_fractions_df = vir_fractions_df.set_index("sample")

    # Total over the four virus categories per taxon
    vir_fractions_df["sample_total"] = vir_fractions_df.sum(axis = 1)

    # Get the fraction of each virus category
    for v_type in v_types:
        vir_fractions_df[f"{v_type}_fraction"] = vir_fractions_df[v_type] / vir_fractions_df["sample_total"]

    # Taxa with a total of 0 have NaN fractions and are dropped here
    vir_fractions_df = vir_fractions_df.dropna()

    vir_fractions_df = vir_fractions_df.sort_values(["macaca_only_fraction", "canis_only_fraction", "undefined_fraction", "shared_fraction"])

    if i == 0:
        master_df = vir_fractions_df.copy()

    else:
        # DataFrame.append was removed in pandas 2.0; pd.concat is the replacement
        master_df = pd.concat([master_df, vir_fractions_df])

master_df.to_csv("fractions_vtype_per_tax.csv")
vir_fractions_df = master_df

vir_fractions_df

In [None]:
# Stacked bar plot: one bar per row of vir_fractions_df, split into the four virus categories
fig, ax = plt.subplots(figsize=(11, 7))
label_color = "black"
fontsize = 16
width = 0.7

samples = vir_fractions_df.index.values
x = np.arange(len(samples))
bottom = np.zeros(len(samples))

vcs = ['macaca_only_fraction', 'canis_only_fraction', 'undefined_fraction', 'shared_fraction']
labels = ["Macaque only", "MDCK only", "Undefined", "Shared"]
colors = ["#003049", "#8fc0de", "lightgrey", "#4b8eb3"]

# Draw the segments in stacking order, each one on top of the previous ones
for label, virus_column, color in zip(labels, vcs, colors):
    heights = vir_fractions_df[virus_column].values
    # Only the "canis_only" (MDCK) segment is hatched
    hatch = "//" if virus_column == 'canis_only_fraction' else None
    ax.bar(x, heights, width, hatch=hatch, label=label, bottom=bottom, color=color)
    bottom = bottom + heights

# Write the value of "sample_total" above each bar. Rows are visited by
# position, so every bar gets its own total even if index labels repeat
y_height = 1.01
for i, sample_total in enumerate(vir_fractions_df["sample_total"].values):
    total_count = "{:,}".format(sample_total.astype(int))
    ax.text(i, y_height, total_count, fontsize=fontsize-2, ha="left", rotation=45, color=label_color)

ax.set_ylabel("Fraction of positive cells", fontsize=fontsize+2, color=label_color)
ax.set_xlabel("Virus order", fontsize=fontsize+2, color=label_color)

leg = ax.legend(fontsize=fontsize, loc='upper left')
for leg_txt in leg.get_texts():
    leg_txt.set_color(label_color)
leg.get_frame().set_alpha(1)

# The taxonomy table uses "." as a placeholder; show it as "Undefined" on the x axis
tick_labels = ["Undefined" if sample == "." else sample for sample in samples]
ax.set_xticks(x, tick_labels, rotation=45, ha="right")

ax.tick_params(axis="both", labelsize=fontsize)

ax.margins(y=0, x=0.005)

# Dashed horizontal grid lines behind the bars; no vertical grid lines
ax.grid(True, which="both", color="lightgrey", ls="--", lw=1)
ax.set_axisbelow(True)
ax.xaxis.grid(False)

# NOTE(review): the file name says "genus" although the x axis is labeled "Virus order" — confirm
plt.savefig("genus_fractions_v_type.png", dpi=300, bbox_inches="tight", transparent=True)

fig.show()