<a href="https://colab.research.google.com/github/pachterlab/LSCHWCP_2023/blob/main/Notebooks/Supp_Fig_10/Supp_Fig_10e/3_visualize_clustering_taxonomy_heatmap.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Visualize which taxonomies clustered together
If the clustering was done correctly, we expect that sequences with similar taxonomies based on the virus ID to sOTU mapping were clustered together, since those sequences should be more similar to each other.

In [None]:
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
%config InlineBackend.figure_format='retina'

Download PalmDB RdRP sequences reverse translated to nucleotides as shown in [this notebook](https://github.com/pachterlab/LSCHWCP_2023/blob/main/Notebooks/Figure_3/Figure_3b/validate_palmdb2palmdb.ipynb):

In [None]:
# Fetch the zipped nucleotide FASTA from the project's GitHub repository
!wget https://github.com/pachterlab/LSCHWCP_2023/raw/main/Notebooks/Figure_3/Figure_3b/palmdb_rdrp_seqs_nucleotides.fasta.zip
# Extract the archive into the current working directory
!unzip palmdb_rdrp_seqs_nucleotides.fasta.zip
# File name of the FASTA (presumably the file extracted from the archive above -- TODO confirm)
palmdb_nn_fasta = "palmdb_rdrp_seqs_nucleotides.fasta"

Download the original virus ID to sOTU mapping as provided by Edgar et al.
The file was originally downloaded from here: https://github.com/rcedgar/palmdb/blob/main/2021-03-14/u_tax.tsv

In [None]:
# Fetch u_tax.tsv from the project's GitHub repository into the current working directory
!wget https://raw.githubusercontent.com/pachterlab/LSCHWCP_2023/main/Notebooks/create_optimized_palmdb/u_tax.tsv

### Load original taxonomies (pre-clustering):

In [None]:
# Read the tab-separated taxonomy table and use the "Label" column as row index
df_tax = pd.read_csv("u_tax.tsv", sep="\t").set_index("Label")

df_tax

### Load cluster members returned by mmseqs2:

In [None]:
# NOTE(review): `palmdb_clu_tsv_1` is not assigned in any visible cell of this notebook;
# it presumably holds the path to the mmseqs2 cluster TSV -- define it before running this cell.
# The file is read without a header row, so the columns are named explicitly below.
df_clu_1 = pd.read_csv(palmdb_clu_tsv_1, sep="\t", header=None)
# Rename columns
df_clu_1.columns = ["representative", "member"]

In [None]:
# Give every cluster an integer ID: rows sharing the same representative
# receive the same code; `uniques` keeps the distinct representatives
codes, uniques = pd.factorize(df_clu_1["representative"])
df_clu_1["cluster"] = codes
df_clu_1

___

### Build data frame with clustered virus IDs and their taxonomies

In [None]:
# Build the one-hot taxonomy matrix (one indicator column per taxon, rows indexed
# by Label) for a single example level, so that this step-by-step walkthrough runs
# on a fresh kernel. The heatmap loop further below repeats the same steps for
# every column of df_tax.
# NOTE(review): the first taxonomy column is used as the example -- adjust if a
# different level was intended for this walkthrough.
example_level = df_tax.columns[0]
df_tax_pivoted = pd.get_dummies(df_tax[example_level], dtype=int)

# Change order of IDs so IDs that are clustered together appear in succession
df_tax_pivoted = df_tax_pivoted.reindex(df_clu_1["member"].values)

In [None]:
# Attach the cluster ID of every sequence: turn "Label" back into a column,
# join it against the member IDs, then discard the helper columns of the join
df_tax_pivoted = (
    df_tax_pivoted
    .reset_index()
    .merge(df_clu_1, how="outer", left_on="Label", right_on="member")
    .drop(columns=["representative", "member"])
)
df_tax_pivoted

In [None]:
# Move "cluster" and "Label" into a two-level row index so that only the
# taxonomy indicator columns remain as data
df_tax_pivoted = df_tax_pivoted.set_index(
    keys=["cluster", "Label"],
)
df_tax_pivoted

In [None]:
# Order rows by every taxonomy column (left to right, descending) so that rows
# carrying the same annotation end up next to each other
df_tax_pivoted = df_tax_pivoted.sort_values(
    by=df_tax_pivoted.columns.tolist(),
    ascending=False,
)

In [None]:
# Remove the "." column, then keep only the rows that still have at least one
# non-zero entry in the remaining columns
df_tax_pivoted = df_tax_pivoted.drop(columns=["."])
has_annotation = (df_tax_pivoted != 0).any(axis=1)
df_tax_pivoted = df_tax_pivoted.loc[has_annotation]

In [None]:
# Display the frame after dropping "." and the rows left without any non-zero entry
df_tax_pivoted

___

### Generate heatmaps for each taxonomy level:

In [None]:
%%time
fontsize = 12

for i, group_by in enumerate(df_tax.columns.values):
    # Create a copy of df_tax keeping only the index
    df_tax_pivoted = pd.DataFrame(df_tax.reset_index()["Label"])
    df_tax_pivoted = df_tax_pivoted.set_index("Label")

    # Convert to boolean dataframe by grouping
    for group in np.unique(df_tax[f"{group_by}"].values):
        df_tax_pivoted[group] = pd.DataFrame(df_tax[f"{group_by}"] == group)[f"{group_by}"].values.astype(int)

    # Change order of IDs so IDs that are clustered together appear in succession
    df_tax_pivoted = df_tax_pivoted.reindex(df_clu_1["member"].values)

    # Add cluster labels
    df_tax_pivoted = df_tax_pivoted.reset_index()
    df_tax_pivoted = df_tax_pivoted.merge(
        df_clu_1,
        left_on = "Label",
        right_on = "member",
        how = "outer"
    )
    df_tax_pivoted = df_tax_pivoted.drop(["representative", "member"], axis=1)
    df_tax_pivoted = df_tax_pivoted.set_index(["cluster", "Label"])

    # Sort values
    df_tax_pivoted = df_tax_pivoted.sort_values(list(df_tax_pivoted.columns.values), ascending=False)

    # Drop "." and sequences that have no annotation after dropping "."
    df_tax_pivoted = df_tax_pivoted.drop(["."], axis=1)
    df_tax_pivoted = df_tax_pivoted.loc[~(df_tax_pivoted==0).all(axis=1)]

    print(f"Plotting {group_by}...")

    # Plot and save heatmap
    fig, ax = plt.subplots(figsize = (10, 9))

    # clusters = df_tax_pivoted.index.get_level_values(0).values
    # cluster_colors = [plt.cm.Spectral(color_idx) for color_idx in clusters]

    # x = ["∎"] * len(df_tax_pivoted)
    y = df_tax_pivoted.columns
    values = df_tax_pivoted.values.T

    im = ax.imshow(values, cmap="inferno", vmin=0, vmax=1, aspect="auto")

    # Add tick labels
    if group_by != "genus" and group_by != "species":
        ax.set_yticks(np.arange(len(y)), labels=y)
    # ax.set_xticks(np.arange(len(x)), labels=x)
    # for xtick, color in zip(ax.get_xticklabels(), cluster_colors):
    #     xtick.set_color(color)

    # # Add lines to delienate clusters
    # for i, cluster_end in enumerate(df_tax_pivoted.reset_index().groupby("cluster", sort=False).count()["Label"].values):
    #     if i == 0:
    #         counter = cluster_end
    #         ax.axvline(counter, color="white", lw=0.1)
    #     else:
    #         counter += cluster_end
    #         ax.axvline(counter, color="white", lw=0.1)

    ax.set_xlabel("RdRP sequence", fontsize=fontsize)
    ax.set_title(f"Clustered sequences by virus {group_by}", fontsize=fontsize+2)

    plt.savefig(f"{group_by}_seqclusters_heatmaps.png", dpi=300, bbox_inches="tight")

    fig.tight_layout()
    fig.show()

___