In [1]:
!pip install plotly
!pip install ipywidgets



In [2]:
import plotly
import anndata

import pandas as pd
import plotly.graph_objects as go
import ipywidgets as widgets
from IPython.display import display
import numpy as np

# Initialise plotly's offline notebook rendering; connected=True loads the
# plotly.js library from its online CDN instead of embedding it in the notebook.
plotly.offline.init_notebook_mode(connected=True)


In [3]:
# Input AnnData: rows (.obs) are samples with their metadata, columns (.var) are
# OTUs with taxonomy annotations; embeddings live in .obsm (samples) and .varm
# (OTUs). Other input files that have been used in place of DATA_PATH:
#   data/big_table_with_embeddings_1000.h5ad
#   data/americangut_embeddings.h5ad
#   data/delete/big_table_with_embeddings_full.h5ad
DATA_PATH = "data/big_table_with_embeddings.h5ad"
adata = anndata.read_h5ad(DATA_PATH)


In [4]:
# Values treated as "missing" in metadata: numeric NaN and +/-inf, plus the
# lower-case text placeholders (match them after .str.lower()).
nan_vals = [np.nan, np.inf, -np.inf] + ["nan", "n/a", "na"]


In [5]:
# Add F:B ratio: per-sample total Firmicutes abundance divided by total
# Bacteroidetes abundance, using the phylum label in adata.var["taxonomy_1"].
FB_PSEUDOCOUNT = 1e-6  # keeps the division finite for samples with no Bacteroidetes
FB_MAX = 10  # upper clip; presumably so extreme ratios don't dominate the colour scale

f_cols = adata.var[adata.var["taxonomy_1"] == "p__Firmicutes"].index
b_cols = adata.var[adata.var["taxonomy_1"] == "p__Bacteroidetes"].index

# If .X is a scipy sparse matrix, .sum(axis=1) returns an (n_obs, 1) matrix;
# flatten to a 1-D array so the result can be stored as a single .obs column.
# (For a dense .X this is a no-op.)
f_vals = np.asarray(adata[:, f_cols].X.sum(axis=1)).ravel()
b_vals = np.asarray(adata[:, b_cols].X.sum(axis=1)).ravel() + FB_PSEUDOCOUNT

fb_ratio = f_vals / b_vals
adata.obs["fb_ratio"] = np.clip(fb_ratio, 0, FB_MAX)


In [6]:
# Drop uninformative ("boring") metadata columns: those with more than MAX_NANS
# missing entries, and those with at most one distinct non-missing value
# (nothing to colour by).
MAX_NANS = 20000  # absolute count; the full table has ~32.6k samples

for col in reversed(adata.obs.columns):
    raw = adata.obs[col]
    vals = raw
    if raw.dtype.name in ("category", "object"):
        # Compare text case-insensitively ("NaN", "N/A" -> "nan", "n/a").
        # astype(str) first so non-string categories/objects don't break .str.
        vals = raw.astype(str).str.lower()
    # Missing = real nulls, or any sentinel in nan_vals (including +/-inf for
    # numeric columns). Unlike np.isnan, this works for every dtype.
    nan_indicator = raw.isna() | vals.isin(nan_vals)
    n_nans = int(nan_indicator.sum())
    if n_nans > MAX_NANS:
        print(f"Dropping {col} with {n_nans} nan values")
        adata.obs.drop(columns=col, inplace=True)

    elif vals[~nan_indicator].nunique() <= 1:
        # "<= 1" rather than "== 1": an all-missing column has zero unique
        # values and must also go when the table is too small to exceed MAX_NANS.
        print(f"Dropping {col} with at most one unique value")
        adata.obs.drop(columns=col, inplace=True)


Dropping weight_units with 32608 nan values
Dropping weight_kg with 32608 nan values
Dropping weight_cat with 32589 nan values
Dropping weekend_wake_time with 32608 nan values
Dropping weekend_sleep_time with 32608 nan values
Dropping weekday_wake_time with 32608 nan values
Dropping weekday_sleep_time with 32608 nan values
Dropping vioscreen_zinc with 32601 nan values
Dropping vioscreen_xylitol with 32601 nan values
Dropping vioscreen_whole_grain_servings with 32601 nan values
Dropping vioscreen_wgrain with 32601 nan values
Dropping vioscreen_weight with 32608 nan values
Dropping vioscreen_water with 32601 nan values
Dropping vioscreen_vitk with 32601 nan values
Dropping vioscreen_vite_iu with 32601 nan values
Dropping vioscreen_vitd_iu with 32601 nan values
Dropping vioscreen_vitd3 with 32601 nan values
Dropping vioscreen_vitd2 with 32601 nan values
Dropping vioscreen_vitd with 32601 nan values
Dropping vioscreen_vitc with 32601 nan values
Dropping vioscreen_vitb6 with 32601 nan value

In [7]:
import plotly.express as px
from ipywidgets.embed import embed_minimal_html

# Options for the dropdowns: sample embeddings (.obsm), sample metadata columns
# used for colouring (.obs), and OTU embeddings (.varm) that can be overlaid.
obsm_matrices = list(adata.obsm.keys())
metadata_vars = list(adata.obs.keys())
# Offer only the OTU embeddings this AnnData actually has, so choosing one can
# never raise KeyError in plot_on_click; None means "no OTU overlay".
component_vars = [
    key
    for key in (
        "component_embeddings_euclidean_2",
        "component_embeddings_poincare_2",
    )
    if key in adata.varm
]
component_vars.append(None)

# Four empty placeholder traces; plot_on_click() fills them in:
#   fig.data[0] - samples whose metadata value is missing (grey)
#   fig.data[1] - samples coloured by the chosen metadata column
#   fig.data[2] - Poincare disk boundary (drawn only for some embeddings)
#   fig.data[3] - OTU (component) embeddings, coloured by taxonomy_1
def _marker_trace(name, marker):
    """Return an empty Scattergl marker trace with hover text, named `name`."""
    return go.Scattergl(x=[], y=[], mode="markers", text=[], marker=marker, name=name)


pdisk = go.Scattergl(
    x=[], y=[], mode="lines", line=dict(color="black"), name="Poincare disk boundary"
)
# colorbar=None: the grey "missing" trace never shows a colour scale.
scatter_nan = _marker_trace("NaN values", dict(color="gray", colorbar=None, size=4))
scatter = _marker_trace(
    "Valid values",
    dict(color=[], colorscale="Plasma", size=4, colorbar=dict(title="")),
)
scatter_components = _marker_trace(
    "OTUs",
    dict(color=[], colorscale="Viridis", size=4, colorbar=dict(title="")),
)

fig_layout = go.Layout(
    title=dict(text="Plotting Embeddings"),
    xaxis=dict(title="X"),
    yaxis=dict(title="Y"),
    hovermode="closest",
    # Legend sits to the right of the plot area (x > 1), vertically centred,
    # with a transparent background.
    legend=dict(x=1.1, y=0.5, bgcolor="rgba(0,0,0,0)"),
    autosize=False,
    width=1000,
    height=1000,
)
fig = go.FigureWidget(
    data=[scatter_nan, scatter, pdisk, scatter_components], layout=fig_layout
)

# Controls: sample embedding, metadata column used for colour, optional OTU
# embedding overlay, the two embedding dimensions for the x/y axes, and a
# Plot button that triggers the redraw.
embed_dropdown = widgets.Dropdown(
    options=obsm_matrices, value=obsm_matrices[0], description="Embedding:"
)
metadata_dropdown = widgets.Dropdown(
    options=metadata_vars, value=metadata_vars[0], description="Metadata:"
)
component_dropdown = widgets.Dropdown(
    options=component_vars, value=None, description="OTU embeddings:"
)
# Integer-only inputs: the plot callback uses these to index embedding columns,
# so free text (e.g. "x") would otherwise raise ValueError inside the callback.
dim1_textbox = widgets.IntText(value=0, description="Dim 1:")
dim2_textbox = widgets.IntText(value=1, description="Dim 2:")
plot_button = widgets.Button(description="Plot")

# Function to find nans in metadata
def find_nans(col, sentinels=None):
    """Return a boolean mask marking the missing entries of a metadata column.

    An entry counts as missing if pandas considers it null, or if it equals one
    of ``sentinels``. Matching is exact, so callers should lower-case text first.

    Parameters
    ----------
    col : pd.Series
        Metadata column to inspect.
    sentinels : list-like, optional
        Values to treat as missing. Defaults to the module-level ``nan_vals``.

    Returns
    -------
    pd.Series of bool, aligned with ``col``.
    """
    if sentinels is None:
        sentinels = nan_vals
    return col.isin(sentinels) | col.isnull()


# Unit circle sampled at 100 points (first and last coincide, closing the
# curve); drawn as the Poincare disk boundary behind Poincare-disk embeddings.
t = np.linspace(0, 2 * np.pi, 100)
circle_x, circle_y = np.cos(t), np.sin(t)

# Update functions: redraw every trace from the current widget selections.
def _update_component_trace(dim1, dim2):
    """Fill fig.data[3] with the selected OTU embedding, or clear it if none."""
    trace = fig.data[3]
    key = component_dropdown.value
    if key is None:
        trace.x = []
        trace.y = []
        trace.text = []
        trace.marker.color = []
        return
    # varm entries may be ndarrays or DataFrames; asarray handles both.
    ce = np.asarray(adata.varm[key])
    trace.x = ce[:, dim1]
    trace.y = ce[:, dim2]
    # Hover label: OTU id plus all of its .var annotations. astype(str) so
    # non-string values (e.g. NaN) don't make str.join raise; iterating rows
    # positionally also avoids .loc returning a frame for duplicate labels.
    annotations = adata.var.astype(str).apply(" ".join, axis=1)
    trace.text = [
        f"{otu}<br>{ann}" for otu, ann in zip(adata.var.index, annotations)
    ]
    trace.marker.color = adata.var["taxonomy_1"].astype("category").cat.codes


def plot_on_click(b):
    """Redraw ``fig`` from the current widget selections.

    Reads the embedding, metadata, OTU-embedding and dimension widgets, then
    rewrites the traces of ``fig``: [0] samples with missing metadata,
    [1] samples coloured by metadata, [2] Poincare disk boundary,
    [3] OTU embeddings.

    Parameters
    ----------
    b : widgets.Button
        The clicked button (unused; required by ``Button.on_click``).
    """
    embedding = embed_dropdown.value
    dim1 = int(dim1_textbox.value)
    dim2 = int(dim2_textbox.value)
    # obsm entries may be ndarrays or DataFrames; asarray makes [:, dim] work
    # for both (the varm branch is handled the same way).
    coords = np.asarray(adata.obsm[embedding])
    x = coords[:, dim1]
    y = coords[:, dim2]

    # Colour by the selected metadata column.
    metadata = metadata_dropdown.value
    col = adata.obs[metadata]
    if col.dtype.name == "category":
        # Lower-case so labels differing only by case share one colour.
        col = col.str.lower().astype("category")
        color = col.cat.codes.copy()
    else:
        color = col.copy()

    nan_mask = find_nans(col).to_numpy()
    valid_mask = ~nan_mask
    sample_ids = adata.obs.index

    # Apply all trace changes as one front-end update instead of one per property.
    with fig.batch_update():
        if embedding in ["h2", "poi_mix_2"]:
            fig.data[2].x = circle_x
            fig.data[2].y = circle_y
        else:
            fig.data[2].x = []
            fig.data[2].y = []

        fig.data[0].x = x[nan_mask]
        fig.data[0].y = y[nan_mask]
        fig.data[0].text = [
            f"Index: {i}<br>{metadata}: {j}"
            for i, j in zip(sample_ids[nan_mask], col[nan_mask])
        ]

        fig.data[1].x = x[valid_mask]
        fig.data[1].y = y[valid_mask]
        fig.data[1].marker.color = color[valid_mask]
        # colorbar.titleside is deprecated (removed in plotly 6); title.side
        # is the supported way to place the colour bar title.
        fig.data[1].marker.colorbar.title = dict(text=metadata, side="top")
        fig.data[1].text = [
            f"Index: {i}<br>{metadata}: {j}"
            for i, j in zip(sample_ids[valid_mask], col[valid_mask])
        ]

        _update_component_trace(dim1, dim2)


# Redraw the figure whenever Plot is clicked; show the controls above the figure.
plot_button.on_click(plot_on_click)

controls = [
    embed_dropdown,
    metadata_dropdown,
    component_dropdown,
    dim1_textbox,
    dim2_textbox,
    plot_button,
]
vbox = widgets.VBox(controls)
vbox_all = widgets.VBox([vbox, fig])
display(vbox_all)

# To export the controls and figure as a standalone HTML page, uncomment:
# embed_minimal_html("plot.html", views=[vbox_all], title="Mixture embeddings AGP")


VBox(children=(VBox(children=(Dropdown(description='Embedding:', options=('euc_mix_128', 'euc_mix_16', 'euc_mi…

In [15]:
# Samples per host body habitat (missing values excluded). Labels that differ
# only by case (e.g. "UBERON:feces" vs "uberon:feces") are counted separately.
adata.obs["host_body_habitat"].value_counts().sort_index()

host_body_habitat
UBERON:feces            284
nan                    3214
uberon:blood              3
uberon:ear               55
uberon:eye               54
uberon:feces          24873
uberon:hair              19
uberon:nose             297
uberon:oral cavity     2259
uberon:skin            1405
uberon:vagina           145
Name: host_body_habitat, dtype: int64

In [8]:
# For offline troubleshooting: dump one 2-D embedding and the sample metadata to
# plain-text files that can be inspected outside this notebook.
# NOTE: np.savetxt's default delimiter is a single space, so despite the .csv
# extension the embedding file is space-separated and has no header row.
hyperboloid_2d = adata.obsm["geomstats_hyperboloid_mix_CONVERTED_2_100iter"]
np.savetxt("data/2d_hyperboloid_mix.csv", hyperboloid_2d)
adata.obs.to_csv("data/2d_hyperboloid_mix_metadata.csv")
