# Process data
> All of the data processing code I use for this project

In [1]:
%load_ext autoreload
%autoreload 2

# NeuroSEED embeddings - hyperbolic/Euclidean

In [2]:
# First things first: gzip all the embeddings so we don't run into LFS limits
import os
import pandas as pd

embed_dir = "../data/otu_embeddings/greengenes"
for file in os.listdir(embed_dir):
    if file.endswith(".csv"):
        old_path = os.path.join(embed_dir, file)
        new_path = os.path.join(
            embed_dir,
            file.replace("embeddings_", "").replace("hyperbolic", "H").replace("euclidean", "E").replace("_", "")
            + ".gz",
        )  # "embeddings_" prefix redundant
        print(f"Compressing {file}")
        pd.read_csv(old_path).to_csv(new_path, compression="gzip")
        print(f"Compressed {file} from {os.path.getsize(old_path)} to {os.path.getsize(new_path)}")
        print()

In [3]:
# Great! Now we can delete the old embeddings to clear some space
for file in os.listdir(embed_dir):
    if file.endswith(".csv"):
        os.remove(os.path.join(embed_dir, file))

I also added these datasets to huggingface:
https://huggingface.co/datasets/pchlenski/greengenes_embeddings

# MLRepo embeddings

In [4]:
import os
import pandas as pd
import numpy as np
import anndata

mlrepo_list = []

# Hardcoded task mappings
task_mapping = {
    "cho": ["control-ct-cecal", "control-ct-fecal", "penicillin-vancomycin-cecal", "penicillin-vancomycin-fecal"],
    "gevers": ["ileum", "pcdai-ileum", "pcdai-rectum", "rectum"],
    "hmp": ["gastro-oral", "sex", "stool-tongue-paired", "sub-supragingivalplaque-paired"],
    "karlsson": ["impaired-diabetes", "normal-diabetes"],
    "ravel": ["black-hispanic", "nugent-category", "nugent-score", "ph", "white-black"],
    "sokol": ["healthy-cd", "healthy-uc"],
    "turnbaugh": ["obese-lean-all"],
    "yatsunenko": ["baby-age", "malawi-venezuela", "sex", "usa-malawi"],
}

# Datasets to drop
to_drop = ["dethlefsen", "karlsson", "qin2012", "qin2014", "ridaura"]

mlrepo_dir = "../data/otu_tables/mlrepo"

for subdir in os.listdir(mlrepo_dir):
    if subdir in to_drop or not os.path.isdir(os.path.join(mlrepo_dir, subdir)):
        continue

    gg_path = os.path.join(mlrepo_dir, subdir, "gg", "otutable.txt")

    if not os.path.exists(gg_path):
        continue

    gg_otu_table = pd.read_table(gg_path, index_col=0).T
    # Convert index to string type to avoid ImplicitModificationWarning
    gg_otu_table.index = gg_otu_table.index.astype(str)
    task_adata = anndata.AnnData(gg_otu_table)

    # Initialize an empty DataFrame for all labels
    all_labels_df = pd.DataFrame(index=gg_otu_table.index)

    task_names = task_mapping.get(subdir, [])

    for task_name in task_names:
        labels_path = os.path.join(mlrepo_dir, subdir, f"task-{task_name}.txt")

        if os.path.exists(labels_path):
            labels = pd.read_table(labels_path, index_col=0)
            # Convert index to string type to avoid ImplicitModificationWarning
            labels.index = labels.index.astype(str)
            if "ControlVar" in labels.columns:
                labels = labels.drop(columns=["ControlVar"])

            labels.columns = [f"{subdir}_{task_name}"]
            all_labels_df = all_labels_df.join(labels, how="outer")

    # Ensure the final labels dataframe aligns with obs indices
    task_adata.obs = task_adata.obs.join(all_labels_df, how="left")
    task_adata.obs["dataset"] = subdir

    mlrepo_list.append(task_adata)

# Concatenate all AnnData objects with consistent vars
mlrepo = anndata.concat(mlrepo_list, join="outer", merge="same")
mlrepo.X = np.nan_to_num(mlrepo.X, nan=0)

# Ensure all column names in obs are strings
mlrepo.obs.columns = mlrepo.obs.columns.astype(str)

print(mlrepo)



ParserError: Error tokenizing data. C error: Calling read(nbytes) on source failed. Try engine='python'.

In [None]:
# Verify that all metadata has some non-nan values
mlrepo.obs.isna().all()

# Print counts for each unique value
for col in mlrepo.obs.columns:
    # Skip dataset col
    if col == "dataset":
        continue

    # Skip regression cols
    if col in ["ravel_nugent-score", "ravel_ph", "gevers_pcdai-ileum", "gevers_pcdai-rectum", "yatsunenko_baby-age"]:
        continue

    print(f"{col}: {mlrepo.obs[col].value_counts().to_dict()}")

ravel_black-hispanic: {'Black': 104, 'Hispanic': 95}
ravel_nugent-category: {'low': 245, 'high': 97}
ravel_white-black: {'Black': 104, 'White': 96}
gevers_ileum: {'CD': 78, 'no': 62}
gevers_rectum: {'no': 92, 'CD': 68}
sokol_healthy-cd: {"Crohn's disease": 59, 'Healthy': 15}
sokol_healthy-uc: {'Ulcerative Colitis': 44, 'Healthy': 15}
yatsunenko_malawi-venezuela: {'GAZ:Venezuela': 33, 'GAZ:Malawi': 21}
yatsunenko_sex: {'female': 92, 'male': 37}
yatsunenko_usa-malawi: {'GAZ:United States of America': 129, 'GAZ:Malawi': 21}
hmp_gastro-oral: {'Oral': 1843, 'Gastrointestinal_tract': 227}
hmp_sex: {'male': 98, 'female': 82}
hmp_stool-tongue-paired: {'Stool': 204, 'Tongue_dorsum': 200}
hmp_sub-supragingivalplaque-paired: {'Supragingival_plaque': 205, 'Subgingival_plaque': 203}
cho_control-ct-cecal: {'Control': 10, 'Chlortetracycline': 7}
cho_control-ct-fecal: {'Control': 10, 'Chlortetracycline': 8}
cho_penicillin-vancomycin-cecal: {'Penicillin': 10, 'Vancomycin': 10}
cho_penicillin-vancomycin

In [None]:
# Add embeddings to varm - need to drop first column and index with second column.
# Reindexing with var_names is critical!
import torch

for geom in ["H", "E"]:
    for dim in [2, 4, 8, 16, 32, 64, 128]:
        embeddings = pd.read_csv(f"../data/otu_embeddings/greengenes/{geom}{dim}.csv.gz", index_col=1)
        embeddings = embeddings.drop(columns=["Unnamed: 0.1"])
        embeddings.index = [str(x) for x in embeddings.index]
        mlrepo.varm[f"{geom}{dim}"] = embeddings.reindex(mlrepo.var_names).values

In [7]:
# Generate mixture embeddings
import geoopt
from tqdm.notebook import tqdm

DEVICE = "cpu"
MAN = geoopt.manifolds.PoincareBall().to(DEVICE)

# Euclidean case
for dim in [2, 4, 8, 16, 32, 64, 128]:
    mlrepo.obsm[f"E{dim}"] = mlrepo.X @ mlrepo.varm[f"E{dim}"]  # (n_obs, n_vars) @ (n_vars, dim) -> (n_obs, dim)

# Hyperbolic case
abundances_tensor = torch.tensor(mlrepo.X, device=DEVICE)
for dim in [2, 4, 8, 16, 32, 64, 128]:
    print(f"H{dim}")

    # Need this workaround to not crash the kernel - it's just the iterative version of the midpoint operation
    # (see following cell for proof of concept with H2)
    embeddings_tensor = torch.tensor(mlrepo.varm[f"H{dim}"], device=DEVICE).unsqueeze(0)
    if dim == 128:
        out = [MAN.weighted_midpoint(embeddings_tensor, row).numpy() for row in tqdm(abundances_tensor)]
        out = np.array(out)

    else:
        out = (
            MAN.weighted_midpoint(xs=embeddings_tensor, weights=abundances_tensor, reducedim=[1]).detach().cpu().numpy()
        )

    # Some validation here
    assert out.shape == (mlrepo.n_obs, dim)
    assert not np.isnan(out).any()
    # assert MAN.assert_check_point_on_manifold(out) # Commented out because making the types work is a pain
    # I do it in a subsequent cell instead
    mlrepo.obsm[f"H{dim}"] = out

H2
H4
H8
H16
H32
H64
H128


  0%|          | 0/10037 [00:00<?, ?it/s]

In [None]:
# Proof of concept: our vectorization works

# Get manifold and other shared objects
p2 = geoopt.manifolds.PoincareBall(2)
h2_tensor = torch.tensor(mlrepo.varm["H2"]).unsqueeze(0)
X_tensor = torch.tensor(mlrepo.X)

# Iterative version we use for H128
out = []
for row in X_tensor:
    out.append(p2.weighted_midpoint(h2_tensor, row).numpy())
out = np.array(out)
print(out.shape)
assert not np.isnan(out).any()

# Now do the vectorized version
out_vectorized = p2.weighted_midpoint(h2_tensor, X_tensor, reducedim=[1]).detach().cpu().numpy()
print(out_vectorized.shape)
assert not np.isnan(out_vectorized).any()

# They match!
assert np.allclose(out, out_vectorized)

(10037, 2)
(10037, 2)


In [9]:
# Check that all points are on the manifold
for k, v in mlrepo.varm.items():
    if k.startswith("H"):
        assert MAN.check_point_on_manifold(torch.tensor(v)), f"{k} is not on the manifold"

for k, v in mlrepo.obsm.items():
    if k.startswith("H"):
        assert MAN.check_point_on_manifold(torch.tensor(v)), f"{k} is not on the manifold"

In [None]:
# Add greengenes taxonomic information

gg_tax_path = "../data/greengenes/gg_13_5_taxonomy.txt.gz"

# Read taxonomy file - separators are tab AND semicolon
gg_tax = pd.read_csv(gg_tax_path, index_col=0, sep="\t|;", header=None)
gg_tax.index = [str(x) for x in gg_tax.index]
gg_tax.columns = ["k", "p", "c", "o", "f", "g", "s"]
for col in gg_tax.columns:
    # Strip whitespace and remove the prefix pattern (e.g., "k__", "p__")
    gg_tax[col] = gg_tax[col].str.strip().str.replace(f"{col}__", "", regex=False)

# Add to mlrepo.var
mlrepo.var = mlrepo.var.join(gg_tax, how="left")

  gg_tax = pd.read_csv(gg_tax_path, index_col=0, sep="\t|;", header=None)


# DNABERT-S embeddings and FASTA sequences

In [6]:
!pip install biopython
!pip install transformers
!pip install einops
!pip uninstall triton -y # triton breaks dnabert-s, idk why

[0m

In [7]:
# Load MLRepo anndata from checkpoint
mlrepo = anndata.read_h5ad("../data/mlrepo.h5ad.gz")

In [8]:
# Add FASTA sequences to mlrepo.var
import gzip
from Bio import SeqIO

# Get FASTA sequences by Greengenes ID
gg_fasta_path = "../data/greengenes/gg_13_5.fasta.gz"

# Read FASTA file - Pearson format to avoid warnings due to comments in the Greengenes file
gg_fasta = {rec.id: str(rec.seq) for rec in SeqIO.parse(gzip.open(gg_fasta_path, "rt"), format="fasta")}

# Add to mlrepo.var
assert not (set(mlrepo.var.index) - set(gg_fasta.keys()))  # mlrepo index subset of gg_fasta keys
mlrepo.var["fasta"] = mlrepo.var.index.map(gg_fasta)

del gg_fasta  # This is a big dict, so we'll delete it to save space

In [12]:
import torch
from transformers import AutoTokenizer, AutoModel
from tqdm.notebook import tqdm

DEVICE = "cuda"

tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNABERT-S", trust_remote_code=True)
model = AutoModel.from_pretrained("zhihan1996/DNABERT-S", trust_remote_code=True)
model.eval()
model = model.to(DEVICE)


# Make embeddings for each fasta
def compute_dnabert_embedding(seq):
    """Mean-pooled embedding of the entire sequence - based on https://github.com/MAGICS-LAB/DNABERT_S README.md"""
    with torch.no_grad():
        inputs = tokenizer(seq, return_tensors="pt")["input_ids"].to(DEVICE)
        hidden_states = model(inputs)[0]  # (1, seq_length, 768)
        embedding_mean = hidden_states[0].mean(dim=0)  # (768,)
    return embedding_mean.cpu().numpy()


mlrepo.varm["dnabert-s"] = np.array([compute_dnabert_embedding(seq) for seq in tqdm(mlrepo.var["fasta"])])
print(mlrepo.varm["dnabert-s"].shape)
assert not np.isnan(mlrepo.varm["dnabert-s"]).any()



  0%|          | 0/27105 [00:00<?, ?it/s]

(27105, 768)


In [15]:
# Add mixture embeddings - same as Euclidean case

mlrepo.obsm["dnabert-s"] = mlrepo.X @ mlrepo.varm["dnabert-s"]

# Save everything

In [16]:
# Final look at mlrepo
mlrepo

AnnData object with n_obs × n_vars = 10037 × 27105
    obs: 'ravel_black-hispanic', 'ravel_nugent-category', 'ravel_nugent-score', 'ravel_ph', 'ravel_white-black', 'dataset', 'gevers_ileum', 'gevers_pcdai-ileum', 'gevers_pcdai-rectum', 'gevers_rectum', 'sokol_healthy-cd', 'sokol_healthy-uc', 'yatsunenko_baby-age', 'yatsunenko_malawi-venezuela', 'yatsunenko_sex', 'yatsunenko_usa-malawi', 'hmp_gastro-oral', 'hmp_sex', 'hmp_stool-tongue-paired', 'hmp_sub-supragingivalplaque-paired', 'cho_control-ct-cecal', 'cho_control-ct-fecal', 'cho_penicillin-vancomycin-cecal', 'cho_penicillin-vancomycin-fecal', 'turnbaugh_obese-lean-all'
    var: 'k', 'p', 'c', 'o', 'f', 'g', 's', 'fasta'
    obsm: 'E128', 'E16', 'E2', 'E32', 'E4', 'E64', 'E8', 'H128', 'H16', 'H2', 'H32', 'H4', 'H64', 'H8', 'dnabert-s'
    varm: 'E128', 'E16', 'E2', 'E32', 'E4', 'E64', 'E8', 'H128', 'H16', 'H2', 'H32', 'H4', 'H64', 'H8', 'dnabert-s'

In [17]:
# Ensure all column names are strings
mlrepo.var.columns = [str(x) for x in mlrepo.var.columns]
mlrepo.obs.columns = [str(x) for x in mlrepo.obs.columns]

# Save mlrepo
mlrepo.write_h5ad("../data/mlrepo.h5ad.gz", compression="gzip")