In [None]:
import sys

sys.path.append("../..")

import pandas as pd
from loguru import logger

import src.preprocessing.functions as preprocessing_functions
from src.global_vars import BASE_DATA_DIR

# Directory holding the Sun et al. input/output files. Note the trailing slash:
# paths built below as f"{data_root_dir}/..." therefore contain a double slash.
data_root_dir = f"{BASE_DATA_DIR}/sun_et_al_data/"
# Metadata columns retained from the sample-group table ("Sample" is later
# promoted to the index).
columns_to_keep = ["Sample", "Group", "Project", "Project_1"]
# Values of the "Project_1" column whose samples are dropped before preprocessing.
studies_to_remove = ["LiS_2021a", "LiS_2021b"]


def print_full_df(x):
    """Display ``x`` with pandas' row/column/width truncation disabled.

    The display options are lifted only for the duration of the call.
    ``pd.option_context`` restores whatever values were active before the
    call (rather than forcing pandas defaults), leaves unrelated options such
    as ``display.float_format`` untouched, and also restores the options if
    ``display`` raises.

    Args:
        x: Any object accepted by the notebook's ``display`` function,
            typically a DataFrame or Series.
    """
    with pd.option_context(
        "display.max_rows", None,
        "display.max_columns", None,
        "display.width", None,
        "display.max_colwidth", None,
    ):
        # `display` is the IPython builtin available inside notebooks.
        display(x)

# Pre-processing

### Preprocessing before splitting

In [2]:
# Load the sample metadata table (tab separated, first row is the header).
sample_group = pd.read_table(f"{data_root_dir}/sample.group", sep="\t", header=0)
# Drop every sample that belongs to one of the excluded studies.
logger.info(f"sample_group.shape before removal of studies: {sample_group.shape}")
sample_group = sample_group[~sample_group["Project_1"].isin(studies_to_remove)]
logger.info(f"sample_group.shape after removal of studies: {sample_group.shape}")

# Keep recommended columns
logger.info(f"sample_group.shape before column removal: {sample_group.shape}")
sample_group = sample_group[columns_to_keep]
logger.info(f"sample_group_useful.shape after column removal: {sample_group.shape}")
# Set index to Sample
sample_group = sample_group.set_index("Sample")
logger.info(f"sample_group_useful.shape after setting index: {sample_group.shape}")

# Load the taxonomic profile; the first column becomes the index.
mpa4_profile = pd.read_table(
    f"{data_root_dir}/mpa4_genus.profile", sep="\t", header=0, index_col=0
)
# Keep only columns whose total is >= 1.
# NOTE(review): the columns are matched against sample IDs further down, so at
# this point columns look like samples and this drops *samples* with a total
# below 1 rather than taxa without reads. Confirm which one is intended.
mpa4_profile = mpa4_profile.loc[
    :, mpa4_profile.sum(axis=0) >= 1
]

## Remove repeated samples
logger.info(f"sample_group_useful.shape before removal: {sample_group.shape}")
sample_group = sample_group[~sample_group.index.duplicated(keep="first")]
logger.info(f"sample_group_useful.shape after removal: {sample_group.shape}")

# Keep only samples present in both the metadata and the profile.
logger.info(
    f"mpa4_species_profile.shape before filtering out samples without metadata: {mpa4_profile.shape}"
)
# Preserve the metadata order instead of intersecting two sets: set iteration
# order over strings changes between interpreter sessions, which made the row
# order of the saved CSV files non-reproducible.
profile_samples = set(mpa4_profile.columns.tolist())
samples_to_keep = [
    sample for sample in sample_group.index.tolist() if sample in profile_samples
]
mpa4_profile = mpa4_profile[samples_to_keep]
logger.info(
    f"mpa4_species_profile.shape after filtering out samples without metadata: {mpa4_profile.shape}"
)
# From here on rows are samples and columns are taxa.
mpa4_profile = mpa4_profile.T
logger.info(
    f"mpa4_species_profile.shape after transposing: {mpa4_profile.shape}"
)

# Restrict (and re-order) the metadata to exactly the retained samples.
logger.info(
    f"sample_group_useful.shape before filtering out samples not in mpa4_species_profile: {sample_group.shape}"
)
sample_group = sample_group.loc[samples_to_keep]
logger.info(
    f"sample_group_useful.shape after filtering out samples not in mpa4_species_profile: {sample_group.shape}"
)

# Normalize the data (row-wise; the per-sample sums are logged before/after).
logger.info(
    f"mpa4_species_profile summation before normalization: {mpa4_profile.sum(axis=1)}"
)
mpa4_profile = preprocessing_functions.total_sum_scaling(mpa4_profile)
logger.info(
    f"mpa4_species_profile summation after normalization: {mpa4_profile.sum(axis=1)}"
)


# normalize again

# transform

[32m2025-02-18 14:54:56.210[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m4[0m - [1msample_group.shape before removal of studies: (6616, 21)[0m
[32m2025-02-18 14:54:56.219[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1msample_group.shape after removal of studies: (6463, 21)[0m
[32m2025-02-18 14:54:56.221[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m9[0m - [1msample_group.shape before column removal: (6463, 21)[0m
[32m2025-02-18 14:54:56.223[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m11[0m - [1msample_group_useful.shape after column removal: (6463, 4)[0m
[32m2025-02-18 14:54:56.227[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m14[0m - [1msample_group_useful.shape after setting index: (6463, 3)[0m
[32m2025-02-18 14:54:59.270[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m26[0m - [1msample_group_useful.shape before removal: (6463, 3)[0

In [4]:
# Per-study prevalence and abundance filtering.
# Within each study (grouped on "Project_1") a value is set to 0 when it is at
# or below 0.0001, or when its feature exceeds 0.0001 in fewer than 10% of
# that study's samples.
grouped_sample_group = sample_group.groupby("Project_1")
display(mpa4_profile)
for project, study_samples in grouped_sample_group.groups.items():
    logger.info(f"Project: {project}")
    study_rows = mpa4_profile.loc[study_samples]

    # Fraction of the study's samples in which each feature is above the cut-off.
    n_study_samples = len(study_rows)
    prevalence = (study_rows > 0.0001).sum(axis=0) / n_study_samples
    rare_features = prevalence < 0.1

    # Cells to zero: the whole column of a rare feature, plus any single
    # value at or below the abundance cut-off.
    zero_out = rare_features | (study_rows <= 0.0001)
    filtered_rows = study_rows.where(~zero_out, 0)

    # Write the filtered block back into the full profile (in place).
    mpa4_profile.update(filtered_rows)

display(mpa4_profile)
# Per-sample totals after filtering, first in table order and then ascending.
row_totals = mpa4_profile.sum(axis=1)
display(row_totals)
display(row_totals.sort_values(ascending=True))

# Persist the filtered (not yet re-normalized) profile.
mpa4_profile.to_csv(
    f"{data_root_dir}/mpa4_genus_profile_after_abundane_prevalence_filtering.csv"
)

name,g__Phocaeicola,g__Faecalibacterium,g__Bacteroides,g__Ruminococcus,g__Clostridium,g__Lachnospiraceae_unclassified,g__Lachnospira,g__Roseburia,g__Clostridia_unclassified,g__Phascolarctobacterium,...,g__Sphingobium,g__Agrococcus,g__Candidatus_Sulfotelmatobacter,g__Aspergillus,g__Roseobacter,g__Thermus,g__Buttiauxella,g__Thermobifida,g__Rodentibacter,g__Desulfobulbus
SRR12000208,0.126040,0.104168,0.044533,0.003377,0.033819,0.074033,0.022016,0.044088,0.013315,0.014045,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
LD-25,0.169447,0.092121,0.327071,0.054479,0.010358,0.034820,0.007015,0.027954,0.004289,0.008102,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
SRR11992769,0.064642,0.082530,0.029992,0.004427,0.021684,0.010620,0.012481,0.042354,0.016721,0.027500,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
SRR13327549,0.004882,0.022360,0.003392,0.000614,0.017953,0.001474,0.000000,0.007582,0.000324,0.000000,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
SRR10403508,0.255034,0.132798,0.043195,0.022298,0.059294,0.007523,0.099071,0.133527,0.002949,0.017673,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
SRR9038099,0.011448,0.078311,0.112935,0.004297,0.020476,0.042588,0.007385,0.002047,0.165319,0.000000,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Yu_5,0.019144,0.005669,0.042367,0.003856,0.003608,0.006766,0.051646,0.000526,0.129221,0.008284,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
SRR9075352,0.101948,0.090126,0.123341,0.009415,0.002299,0.016009,0.038844,0.083962,0.001573,0.007370,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
SRR13077794,0.000000,0.000000,0.205626,0.000000,0.031214,0.025321,0.000000,0.016573,0.029801,0.017485,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


[32m2025-02-18 14:55:40.186[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1mProject: ChenB_2020[0m
[32m2025-02-18 14:55:41.768[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1mProject: ChuY_2021[0m
[32m2025-02-18 14:55:43.870[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1mProject: HanL_2021[0m
[32m2025-02-18 14:55:45.862[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1mProject: HeQ_2017[0m
[32m2025-02-18 14:55:47.683[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1mProject: HuY_2019[0m
[32m2025-02-18 14:55:49.599[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1mProject: HuangR_2020[0m
[32m2025-02-18 14:55:51.418[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1mProject: JieZ_2017[0m
[32m2025-02-18 14:55:53.388[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m

name,g__Phocaeicola,g__Faecalibacterium,g__Bacteroides,g__Ruminococcus,g__Clostridium,g__Lachnospiraceae_unclassified,g__Lachnospira,g__Roseburia,g__Clostridia_unclassified,g__Phascolarctobacterium,...,g__Sphingobium,g__Agrococcus,g__Candidatus_Sulfotelmatobacter,g__Aspergillus,g__Roseobacter,g__Thermus,g__Buttiauxella,g__Thermobifida,g__Rodentibacter,g__Desulfobulbus
SRR12000208,0.126040,0.104168,0.044533,0.003377,0.033819,0.074033,0.022016,0.044088,0.013315,0.014045,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
LD-25,0.169447,0.092121,0.327071,0.054479,0.010358,0.034820,0.007015,0.027954,0.004289,0.008102,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
SRR11992769,0.064642,0.082530,0.029992,0.004427,0.021684,0.010620,0.012481,0.042354,0.016721,0.027500,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
SRR13327549,0.004882,0.022360,0.003392,0.000614,0.017953,0.001474,0.000000,0.007582,0.000324,0.000000,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
SRR10403508,0.255034,0.132798,0.043195,0.022298,0.059294,0.007523,0.099071,0.133527,0.002949,0.017673,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
SRR9038099,0.011448,0.078311,0.112935,0.004297,0.020476,0.042588,0.007385,0.002047,0.165319,0.000000,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Yu_5,0.019144,0.005669,0.042367,0.003856,0.003608,0.006766,0.051646,0.000526,0.129221,0.008284,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
SRR9075352,0.101948,0.090126,0.123341,0.009415,0.002299,0.016009,0.038844,0.083962,0.001573,0.007370,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
SRR13077794,0.000000,0.000000,0.205626,0.000000,0.031214,0.025321,0.000000,0.016573,0.029801,0.017485,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


SRR12000208    0.987458
LD-25          0.983398
SRR11992769    0.998567
SRR13327549    0.998785
SRR10403508    0.999278
                 ...   
SRR9038099     0.965188
Yu_5           0.971323
SRR9075352     0.949121
SRR13077794    0.995759
SRR6066175     0.711810
Length: 6303, dtype: float64

ERR1620261     0.066127
SRR13077661    0.092671
SRR16124320    0.235723
SRR13077864    0.240482
SRR10403516    0.246404
                 ...   
ERR1190555     0.999982
Yu_11          0.999985
ERR1620313     1.000000
SRR6504899     1.000000
SRR13077801    1.000000
Length: 6303, dtype: float64

In [5]:
# Re-normalize the filtered profile, then apply the centered arcsine
# transform. Per-sample sums are logged around each step as a sanity check.
rescale_rows = preprocessing_functions.total_sum_scaling
arcsine_center = preprocessing_functions.centered_arcsine_transform

# Step 1: total-sum scaling.
logger.info(f"mpa4_species_profile summation before normalization: {mpa4_profile.sum(axis=1)}")
mpa4_profile = rescale_rows(mpa4_profile)
logger.info(f"mpa4_species_profile summation after normalization: {mpa4_profile.sum(axis=1)}")

# Step 2: centered arcsine transform.
logger.info(f"mpa4_species_profile summation before centered arcsine transform: {mpa4_profile.sum(axis=1)}")
mpa4_profile = arcsine_center(mpa4_profile)
logger.info(f"mpa4_species_profile summation after centered arcsine transform: {mpa4_profile.sum(axis=1)}")

[32m2025-02-18 14:57:46.255[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mmpa4_species_profile summation before normalization: SRR12000208    0.987458
LD-25          0.983398
SRR11992769    0.998567
SRR13327549    0.998785
SRR10403508    0.999278
                 ...   
SRR9038099     0.965188
Yu_5           0.971323
SRR9075352     0.949121
SRR13077794    0.995759
SRR6066175     0.711810
Length: 6303, dtype: float64[0m
[32m2025-02-18 14:57:46.340[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1mmpa4_species_profile summation after normalization: SRR12000208    1.0
LD-25          1.0
SRR11992769    1.0
SRR13327549    1.0
SRR10403508    1.0
              ... 
SRR9038099     1.0
Yu_5           1.0
SRR9075352     1.0
SRR13077794    1.0
SRR6066175     1.0
Length: 6303, dtype: float64[0m
[32m2025-02-18 14:57:46.372[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m11[0m - [1mmpa4_species_profile summation bef

In [6]:
# Write the preprocessed profile and its matching metadata next to the inputs.
preprocessed_outputs = {
    f"{data_root_dir}/mpa4_genus_profile_preprocessed.csv": mpa4_profile,
    f"{data_root_dir}/sample_group_genus_preprocessed.csv": sample_group,
}
for csv_path, frame in preprocessed_outputs.items():
    frame.to_csv(csv_path)

---
# Testing dataloader

In [None]:
%load_ext autoreload
%autoreload 2

from torch import manual_seed
from torch.utils.data import DataLoader

# Fix the torch RNG so the sampled batches are repeatable.
manual_seed(0)

from src.data.sun_et_al import BinaryFewShotBatchSampler, MicrobiomeDataset

# Hold out one study for testing and one for validation; the rest is training.
test_study = ["ChenB_2020"]
val_study = ["ChuY_2021"]

# Boolean row masks over the metadata, one per split.
sample_project = sample_group["Project_1"]
in_train = ~sample_project.isin(test_study + val_study)
in_test = sample_project.isin(test_study)
in_val = sample_project.isin(val_study)

# Pull the matching profile rows; the asserts pin the expected split sizes
# for this dataset.
train_df = mpa4_profile.loc[sample_group.loc[in_train].index]
assert train_df.shape[0] == 5892

test_df = mpa4_profile.loc[sample_group.loc[in_test].index]
assert test_df.shape[0] == 231

val_df = mpa4_profile.loc[sample_group.loc[in_val].index]
assert val_df.shape[0] == 180

# The dataset is given metadata with "label" and "project" columns.
meta_data = sample_group[["Group", "Project_1"]].rename(
    columns={"Group": "label", "Project_1": "project"}
)

train, test, val = (
    MicrobiomeDataset(split_df, meta_data.loc[split_df.index])
    for split_df in (train_df, test_df, val_df)
)

sampler = BinaryFewShotBatchSampler(train, 50, True, True)
train_loader = DataLoader(train, batch_sampler=sampler)


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
# samples, labels = next(iter(train_loader))
# print(samples.shape)
# print(labels.shape)
# print(samples)
# print(labels)


97
157
[np.int64(2862), np.int64(2924), np.int64(2938), np.int64(2893), np.int64(2761), np.int64(2820), np.int64(2953), np.int64(2772), np.int64(2968), np.int64(2816), np.int64(2877), np.int64(2935), np.int64(2768), np.int64(2850), np.int64(2773), np.int64(2769), np.int64(2866), np.int64(2872), np.int64(2837), np.int64(2770), np.int64(2867), np.int64(2843), np.int64(2815), np.int64(2864), np.int64(3009), np.int64(2915), np.int64(2808), np.int64(2925), np.int64(2887), np.int64(2765), np.int64(2838), np.int64(2950), np.int64(2817), np.int64(2895), np.int64(2776), np.int64(2759), np.int64(2771), np.int64(2780), np.int64(2818), np.int64(2868), np.int64(2806), np.int64(2763), np.int64(2918), np.int64(2898), np.int64(2910), np.int64(2891), np.int64(2847), np.int64(2827), np.int64(2873), np.int64(2849), np.int64(3011), np.int64(2783), np.int64(2845), np.int64(3004), np.int64(2781), np.int64(2870), np.int64(3006), np.int64(2931), np.int64(2963), np.int64(2941), np.int64(2932), np.int64(2946), 

: 

In [None]:
from src.preprocessing.functions import pandas_label_encoder

import numpy as np

# Encode the metadata, re-index it positionally (after sorting by sample id)
# and group the "label" column by project.
encoded_meta = pandas_label_encoder(meta_data)
positional_meta = encoded_meta.sort_index().reset_index(drop=True)
m = positional_meta.groupby("project")[["label"]]

# Row positions of project 0, split by label value.
g = m.get_group(0)
g2 = g.groupby("label").groups

# Positions belonging to the first label group of that project.
first_label_positions = np.array(list(g2.values())[0])

l = [1, 2, 3] + list(first_label_positions)
print(l)

# Fetch a single (features, label) item from the training dataset.
train[l[5]]

(tensor([ 0.4558,  0.1662, -0.0019,  ..., -0.0019, -0.0019, -0.0019]),
 tensor(1.))

# Meta-Learning

In [2]:
%load_ext autoreload
%autoreload 2

import os
import sys
from importlib import import_module

sys.path.append("../../")

import pandas as pd
import torch
from sklearn.decomposition import PCA
from sklearn.preprocessing import Normalizer
from torch import nn
from torch.optim import SGD, Adam
from torch.utils.data import DataLoader

import src.data.sun_et_al as hf
import src.models.maml as maml
import src.models.reptile as rp
import wandb


def get_studies_desired_from_sun_et_al(
    data: pd.DataFrame, metadata: pd.DataFrame, study: list
):
    """Select the samples that belong to the requested studies.

    Args:
        data: The data to filter. Index should be samples.
        metadata: The metadata to filter. Index should be samples and it must
            contain a ``Project_1`` column.
        study: The studies to keep, matched against the ``Project_1`` column
            of the metadata. A single study name given as a plain string is
            also accepted and treated as a one-element list.

    Returns:
        tuple: ``(data, metadata)`` restricted to the studies of interest,
        with the data rows in the same order as the filtered metadata.

    Raises:
        KeyError: If a selected metadata sample is missing from ``data``.
    """
    # ``Series.isin`` rejects a bare string, so wrap it instead of failing.
    if isinstance(study, str):
        study = [study]

    # Filter metadata to only include the studies of interest
    selected_metadata = metadata[metadata["Project_1"].isin(study)]

    # Filter data to only include samples that are in the metadata
    selected_data = data.loc[selected_metadata.index]

    return selected_data, selected_metadata


def split_sun_et_al_data(data: pd.DataFrame, metadata: pd.DataFrame, test, val):
    """Partition data and metadata into train/test/validation by study.

    Args:
        data: The data to split. Index should be samples.
        metadata: The metadata to split. Index should be samples.
        test: Study (or list of studies) held out for testing.
        val: Study (or list of studies) held out for validation.

    Returns:
        tuple: ``(train_data, test_data, val_data, train_metadata,
        test_metadata, val_metadata)``. The training parts are whatever
        remains after the test and validation samples are dropped.
    """
    # Anything that is not already a list is treated as a single study.
    test_studies = test if isinstance(test, list) else [test]
    val_studies = val if isinstance(val, list) else [val]

    test_data, test_metadata = get_studies_desired_from_sun_et_al(
        data, metadata, test_studies
    )
    val_data, val_metadata = get_studies_desired_from_sun_et_al(
        data, metadata, val_studies
    )

    # Training set = everything not held out, removed in two successive drops.
    train_data = data.drop(test_data.index).drop(val_data.index)
    train_metadata = metadata.drop(index=test_metadata.index).drop(
        index=val_metadata.index
    )

    return (
        train_data,
        test_data,
        val_data,
        train_metadata,
        test_metadata,
        val_metadata,
    )


def column_rename_for_sun_et_al_metadata(metadata: pd.DataFrame) -> pd.DataFrame:
    """Return only the group/study columns, renamed to ``label``/``project``.

    Args:
        metadata: Frame containing at least ``Group`` and ``Project_1``.

    Returns:
        A new two-column frame (``label``, ``project``) with the same index;
        the input frame is not modified.
    """
    new_names = {"Group": "label", "Project_1": "project"}
    return metadata[list(new_names)].rename(columns=new_names)


def pca_reduction(
    train_data,
    test_data,
    val_data,
    n_components_reduction_factor: int,
    use_cache: bool = False,
):
    """Project the three splits onto principal components fitted on train.

    Args:
        train_data: Training frame; the PCA is fitted on it.
        test_data: Test frame, transformed with the fitted PCA.
        val_data: Validation frame, transformed with the fitted PCA.
        n_components_reduction_factor: The number of components is the
            training feature count floor-divided by this factor.
        use_cache: When True, nothing is computed; the three frames are read
            back from the CSV files a previous uncached call wrote to the
            current working directory.

    Returns:
        tuple: ``(train_data, test_data, val_data)`` in reduced form.
    """
    if use_cache:
        # Reuse the projections stored by an earlier run.
        cached_train = pd.read_csv("train_data_PCA.csv", index_col=0)
        cached_test = pd.read_csv("test_data_PCA.csv", index_col=0)
        cached_val = pd.read_csv("val_data_PCA.csv", index_col=0)
        return cached_train, cached_test, cached_val

    reducer = PCA(
        n_components=int(train_data.shape[1] // n_components_reduction_factor)
    )

    print("fitting and transforming")
    train_reduced = pd.DataFrame(
        reducer.fit_transform(train_data), index=train_data.index
    )
    print("transforming")
    test_reduced = pd.DataFrame(reducer.transform(test_data), index=test_data.index)
    print("transforming")
    val_reduced = pd.DataFrame(reducer.transform(val_data), index=val_data.index)

    # Store the projections so that a later call can pass use_cache=True.
    train_reduced.to_csv("train_data_PCA.csv")
    test_reduced.to_csv("test_data_PCA.csv")
    val_reduced.to_csv("val_data_PCA.csv")

    return train_reduced, test_reduced, val_reduced


def main(
    model_script: str,
    model_name: str,
    abundance_file: str,
    metadata_file: str,
    test_study: list,
    val_study: list,
    outer_lr_range: tuple[float, float],
    inner_lr_range: tuple[float, float],
    inner_rl_reduction_factor: int,
    n_gradient_steps: int,
    n_parallel_tasks: int,
    n_epochs: int,
    train_k_shot: int,
    eval_k_shot: int = None,
    n_components_reduction_factor: int = 0,  # 0 or 1 for no PCA at all
    use_cached_pca: bool = False,
    do_normalization_before_scaling: bool = True,
    scale_factor_before_training: int = 100,
    loss_fn: str = "BCELog",
    use_wandb: bool = True,
):
    """Train a MAML meta-learner on the Sun et al. data, holding out studies.

    The function reads the abundance and metadata CSV files, splits them by
    study, optionally applies PCA and per-sample normalization, builds the
    few-shot data loaders and fits ``maml.MAML``.

    NOTE(review): relies on ``data_root_dir`` and ``logger`` being defined by
    an earlier notebook cell; they are not imported in this cell.

    Args:
        model_script: Importable module path exposing ``get_model(name)``.
        model_name: Name passed to ``get_model``; the result is called with
            the number of features.
        abundance_file: File name (inside ``data_root_dir``) of the abundance
            CSV. The second ``_``-separated token is used as taxonomy level.
        metadata_file: File name (inside ``data_root_dir``) of the metadata CSV.
        test_study: Study or list of studies held out for testing. An empty
            value yields an empty test set.
        val_study: Study or list of studies held out for validation.
        outer_lr_range: Forwarded to ``maml.MAML``.
        inner_lr_range: Forwarded to ``maml.MAML``.
        inner_rl_reduction_factor: Forwarded to ``maml.MAML``.
        n_gradient_steps: Used for both train and eval gradient steps.
        n_parallel_tasks: Forwarded to ``MAML.fit``.
        n_epochs: Forwarded to ``MAML.fit``.
        train_k_shot: Shots used by the training sampler and by MAML.
        eval_k_shot: Shots used by the test/validation samplers; defaults to
            ``train_k_shot`` when None.
        n_components_reduction_factor: PCA keeps ``n_features // factor``
            components; 0 or 1 disables PCA.
        use_cached_pca: Read previously stored PCA projections instead of
            fitting.
        do_normalization_before_scaling: Apply sklearn ``Normalizer`` to each
            split before multiplying by the scale factor.
        scale_factor_before_training: Multiplier applied to every split.
        loss_fn: Loss identifier; only ``"BCELog"`` is recognized.
        use_wandb: When False, wandb is initialised in disabled mode.

    Raises:
        ValueError: If ``loss_fn`` is not a recognized identifier.
    """
    if loss_fn == "BCELog":
        loss_fn = nn.BCEWithLogitsLoss()
    else:
        raise ValueError("Loss function not recognized.")

    sun_et_al_abundance = pd.read_csv(
        f"{data_root_dir}/{abundance_file}",
        index_col=0,
        header=0,
    )

    sun_et_al_metadata = pd.read_csv(
        f"{data_root_dir}/{metadata_file}",
        index_col=0,
        header=0,
    )

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Set up file logging
    # logger_path = get_run_dir_for_experiment(misc_config) / "log.log"
    # logger.add(logger_path, colorize=True, level="DEBUG")
    # logger.info("Setting up everything")

    job_id = os.getenv("SLURM_JOB_ID")
    tax_level = abundance_file.split("_")[1]
    config = {
        "model_script": model_script,
        "model_name": model_name,
        "abundance_file": abundance_file,
        "metadata_file": metadata_file,
        "test_study": test_study,
        "val_study": val_study,
        "outer_lr_range": outer_lr_range,
        "inner_lr_range": inner_lr_range,
        # Previously missing from the logged configuration.
        "inner_rl_reduction_factor": inner_rl_reduction_factor,
        "n_gradient_steps": n_gradient_steps,
        "n_parallel_tasks": n_parallel_tasks,
        "n_epochs": n_epochs,
        "train_k_shot": train_k_shot,
        "eval_k_shot": eval_k_shot,
        "n_components_reduction_factor": n_components_reduction_factor,
        "use_cache_pca": use_cached_pca,
        "do_normalization_before_scaling": do_normalization_before_scaling,
        "scale_factor_before_training": scale_factor_before_training,
        "loss_fn": loss_fn,
        "use_wandb": use_wandb,
        "device": device,
        "job_id": job_id,
    }
    wandb_base_tags = [
        "t_s" + str(test_study),
        "v_s" + str(val_study),
        "m_" + model_name,
        ("j_" + job_id) if job_id else "j_local",
        "tax_" + tax_level,
        "t_k" + str(train_k_shot),
        "e_k" + str(eval_k_shot),
    ]

    wand_name = f"w{model_name}_ts{test_study}_vs{val_study}_j{job_id}_tax{tax_level}_tk{train_k_shot}_ek{eval_k_shot}"

    # Initialize wandb; the only difference when disabled is the extra
    # mode="disabled" argument, so the shared arguments are built once.
    wandb_init_kwargs = {
        "project": "meta-learning",
        "name": wand_name,
        "config": config,
        "group": "MAML",
        "tags": wandb_base_tags,
    }
    if not use_wandb:
        wandb_init_kwargs["mode"] = "disabled"
    wandb.init(**wandb_init_kwargs)

    logger.success("wandb init done")

    sun_et_al_metadata = sun_et_al_metadata.sort_index()
    sun_et_al_abundance = sun_et_al_abundance.sort_index()

    train_data, test_data, val_data, train_metadata, test_metadata, val_metadata = (
        split_sun_et_al_data(
            sun_et_al_abundance, sun_et_al_metadata, test_study, val_study
        )
    )

    if n_components_reduction_factor != 0 and n_components_reduction_factor != 1:
        # The reduction factor is a required argument of pca_reduction; it was
        # previously omitted, which raised TypeError whenever PCA was enabled.
        train_data, test_data, val_data = pca_reduction(
            train_data,
            test_data,
            val_data,
            n_components_reduction_factor,
            use_cache=use_cached_pca,
        )

    # tts
    # train_data = preprocessing_functions.total_sum_scaling(train_data)
    # test_data = preprocessing_functions.total_sum_scaling(test_data)
    # val_data = preprocessing_functions.total_sum_scaling(val_data)

    # centered log ratio transform
    # replace_zero_with = train_data[train_data > 0].min().min() / 100
    # train_data = preprocessing_functions.centered_log_ratio(
    #     train_data, replace_zero_with=replace_zero_with
    # )
    # test_data = preprocessing_functions.centered_log_ratio(
    #     test_data, replace_zero_with=replace_zero_with
    # )
    # val_data = preprocessing_functions.centered_log_ratio(
    #     val_data, replace_zero_with=replace_zero_with
    # )

    # normalize the data for deep learning
    if do_normalization_before_scaling:
        train_data = pd.DataFrame(
            Normalizer().fit_transform(train_data),
            index=train_data.index,
            columns=train_data.columns,
        )
        # Normalizer cannot be applied to a frame without rows, so skip the
        # test split whenever it is empty (e.g. no test study was requested).
        if not test_data.empty:
            test_data = pd.DataFrame(
                Normalizer().fit_transform(test_data),
                index=test_data.index,
                columns=test_data.columns,
            )
        val_data = pd.DataFrame(
            Normalizer().fit_transform(val_data),
            index=val_data.index,
            columns=val_data.columns,
        )

    train_data = train_data * scale_factor_before_training
    test_data = test_data * scale_factor_before_training
    val_data = val_data * scale_factor_before_training

    train_metadata = column_rename_for_sun_et_al_metadata(train_metadata)
    test_metadata = column_rename_for_sun_et_al_metadata(test_metadata)
    val_metadata = column_rename_for_sun_et_al_metadata(val_metadata)

    # Create Datasets for DataLoader
    train = hf.MicrobiomeDataset(train_data, train_metadata)
    test = hf.MicrobiomeDataset(test_data, test_metadata)
    val = hf.MicrobiomeDataset(val_data, val_metadata)

    # Create DataLoaders
    sampler = hf.BinaryFewShotBatchSampler(
        train, train_k_shot, include_query=True, shuffle=True
    )
    train_loader = DataLoader(train, batch_sampler=sampler)

    if eval_k_shot is None:
        eval_k_shot = train_k_shot

    sampler = hf.BinaryFewShotBatchSampler(
        test,
        eval_k_shot,
        include_query=True,
        shuffle=False,
        shuffle_once=False,
        training=False,
    )
    # NOTE(review): test_loader is built but not used further down.
    test_loader = DataLoader(test, batch_sampler=sampler)

    sampler = hf.BinaryFewShotBatchSampler(
        val,
        eval_k_shot,
        include_query=True,
        shuffle=False,
        shuffle_once=False,
        training=False,
    )
    val_loader = DataLoader(val, batch_sampler=sampler)

    # Get model
    model_module = import_module(model_script)
    n_features = train_data.shape[1]
    assert (
        n_features == test_data.shape[1] == val_data.shape[1]
    ), "Number of features of train, test and val must be the same."

    # Simple model to test
    model = model_module.get_model(model_name)(n_features).to(device)

    # Instantiate the Reptile meta-learner.
    # reptile = rp.Reptile(
    #     model=model,
    #     train_n_gradient_steps=n_gradient_steps,
    #     eval_n_gradient_steps=n_gradient_steps,
    #     device=device,
    #     # loss_function=loss_fn,
    #     meta_optimizer=meta_optimizer,
    #     inner_lr=inner_lr,
    #     outer_lr=outer_lr,
    #     k_shot=k_shot,
    # )

    # reptile.fit(
    #     train_dataloader=train_loader,
    #     n_epochs=n_epochs,
    #     n_parallel_tasks=n_parallel_tasks,
    #     evaluate_train=True,
    #     val_dataloader=val_loader,
    # )

    # Do MAML
    MAML = maml.MAML(
        model=model,
        train_n_gradient_steps=n_gradient_steps,
        eval_n_gradient_steps=n_gradient_steps,
        device=device,
        inner_lr_range=inner_lr_range,
        inner_rl_reduction_factor=inner_rl_reduction_factor,
        outer_lr_range=outer_lr_range,
        k_shot=train_k_shot,
        loss_fn=loss_fn,
    )

    MAML.fit(
        train_dataloader=train_loader,
        n_epochs=n_epochs,
        n_parallel_tasks=n_parallel_tasks,
        evaluate_train=True,
        val_dataloader=val_loader,
    )

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
from sklearn.preprocessing import QuantileTransformer

# Load the filtered abundance table and the preprocessed metadata.
sun_et_al_abundance = pd.read_csv(
    f"{data_root_dir}/mpa4_species_profile_after_abundane_prevalence_filtering.csv",
    index_col=0,
    header=0,
)
sun_et_al_metadata = pd.read_csv(
    f"{data_root_dir}/sample_group_preprocessed.csv", index_col=0, header=0
)

# normalization/transformation testing different methods
## study-wise quantile transformation
# Each study's rows are transformed on their own, using as many quantiles as
# the study has samples, and then written back into the full table in place.
grouped_samples = sun_et_al_metadata.groupby(["Project_1"])
for group, study_index in grouped_samples.groups.items():
    study_block = sun_et_al_abundance.loc[study_index, :]
    print(f"qunatile transform for group: {group}")
    transformer = QuantileTransformer(n_quantiles=study_block.shape[0])
    transformed_block = pd.DataFrame(
        transformer.fit_transform(study_block),
        index=study_block.index,
        columns=study_block.columns,
    )
    sun_et_al_abundance.update(transformed_block)

sun_et_al_abundance.to_csv(
    f"{data_root_dir}/mpa4_species_profile_quantile_transformed.csv"
)


KeyboardInterrupt: 

In [9]:
# Launch one MAML run: no test study, JieZ_2017 held out for validation.
main(
    model_script="src.models.models",
    model_name="model2",
    abundance_file="mpa4_species_profile_preprocessed.csv",
    metadata_file="sample_group_species_preprocessed.csv",
    test_study="",
    val_study="JieZ_2017",
    outer_lr_range=(1, 1),
    inner_lr_range=(0.5, 0.001),
    inner_rl_reduction_factor=2,
    # TODO check why more gradient steps and parallel tasks gives nan for loss
    n_gradient_steps=5,
    n_parallel_tasks=5,
    n_epochs=10,
    train_k_shot=10,
    n_components_reduction_factor=0,
    use_cached_pca=False,
    do_normalization_before_scaling=True,
    scale_factor_before_training=100,
    loss_fn="BCELog",
)

# from torch.utils.data import TensorDataset

# # N = 2**10
# # X = torch.randn(N, 2)
# # y = (X.sum(dim=1) > 0).long()  # Label is 1 if sum > 0, else 0.
# # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# # # Define a simple model: 2 input features, one hidden layer with 10 units, 2 outputs.
# # model = torch.nn.Sequential(
# #     torch.nn.Linear(2, 10),
# #     torch.nn.ReLU(),
# #     torch.nn.Linear(10, 2),
# # ).to(device)

# # # Define the loss function.
# # loss_fn = nn.CrossEntropyLoss()

# # # Outer (meta) optimizer: using SGD.
# # meta_lr = 1  # Outer learning rate.

# # # Inner-loop learning rate (α).
# # inner_lr = 0.5

# # # Number of inner-loop gradient steps per task.
# # n_gradient_steps = 5

# # ## testing
# # train_dataset = TensorDataset(X, y)
# # # Use a small batch size so that each batch is treated as a "task".
# # train_dataloader = DataLoader(train_dataset, batch_size=40, shuffle=True)
# # N = 2**10
# # X = torch.randn(N, 2)
# # y = (X.sum(dim=1) > 0).long()  # Label is 1 if sum > 0, else 0.
# # test_dataset = TensorDataset(X, y)
# # test_dataloader = DataLoader(test_dataset, batch_size=40, shuffle=True)
# # meta_optimizer = SGD(model.parameters(), lr=inner_lr)
# # MAML = maml.MAML(
# #     model=model,
# #     train_n_gradient_steps=n_gradient_steps,
# #     eval_n_gradient_steps=n_gradient_steps,
# #     loss_fn=loss_fn,
# #     device=device,
# #     meta_optimizer=meta_optimizer,
# #     outer_lr_range=(1, 0.00001),
# #     inner_lr_range=(0.5, 0.00001),
# #     k_shot=10,
# # )

# # MAML.fit(train_dataloader=train_dataloader, n_epochs=100, n_parallel_tasks=5, evaluate_train=True, val_dataloader=test_dataloader)



[32m2025-02-21 09:37:42.587[0m | [32m[1mSUCCESS [0m | [36m__main__[0m:[36mmain[0m:[36m211[0m - [32m[1mwandb init done[0m
[32m2025-02-21 09:37:43.590[0m | [1mINFO    [0m | [36msrc.models.maml[0m:[36mevaluate[0m:[36m272[0m - [1mEvaluation after epoch 0: Loss = 0.71[0m
[32m2025-02-21 09:37:43.593[0m | [1mINFO    [0m | [36msrc.models.maml[0m:[36mevaluate[0m:[36m273[0m - [1mAccuracy = 0.61, F1 = 0.61, Precision = 0.69, Recall = 0.55, ROC-AUC = 0.67[0m
[32m2025-02-21 09:38:14.987[0m | [1mINFO    [0m | [36msrc.models.maml[0m:[36mfit[0m:[36m128[0m - [1mEpoch 1 complete at iteration 77/770 with 5 parallel tasks and 380 total tasks. Reinitializing DataLoader for next epoch.[0m
[32m2025-02-21 09:38:15.034[0m | [1mINFO    [0m | [36msrc.models.maml[0m:[36mevaluate[0m:[36m272[0m - [1mEvaluation after epoch 1: Loss = 0.65[0m
[32m2025-02-21 09:38:15.035[0m | [1mINFO    [0m | [36msrc.models.maml[0m:[36mevaluate[0m:[36m273[0m - [1m