In [11]:
import sys

sys.path.append("../")

import pandas as pd
from loguru import logger

import src.preprocessing.functions as preprocessing_functions

# Directory holding the input tables and the preprocessed outputs. The value
# already ends with "/", and later cells append "/<filename>" to it, so the
# resulting paths contain a doubled slash.
data_root_dir = "../data/sun_et_al_data/"
# Metadata columns retained from the sample table; "Sample" is later used as
# the index of the metadata frame.
columns_to_keep = ["Sample", "Group", "Project", "Project_1"]
# Values of the "Project_1" column whose rows are dropped from the metadata.
studies_to_remove = ["LiS_2021a", "LiS_2021b"]


def print_full_df(x):
    """Display ``x`` with pandas' truncation limits lifted.

    The row, column, width and column-width display options are set to
    ``None`` only for the duration of the call. ``pd.option_context`` restores
    whatever values were active beforehand (rather than the pandas defaults),
    leaves unrelated options such as ``display.float_format`` untouched, and
    restores the options even if ``display`` raises.

    Args:
        x: Object to render, typically a DataFrame or Series.
    """
    with pd.option_context(
        "display.max_rows", None,
        "display.max_columns", None,
        "display.width", None,
        "display.max_colwidth", None,
    ):
        # `display` is the rich-output helper IPython injects into notebooks.
        display(x)

### Preprocessing before splitting

In [14]:
# Get sample group data: read the per-sample metadata, drop the excluded
# studies, keep only the columns in columns_to_keep and index by sample ID.
sample_group = pd.read_table(f"{data_root_dir}/sample.group", sep="\t", header=0)
logger.info(f"sample_group.shape before removal of studies: {sample_group.shape}")
sample_group = sample_group[~sample_group["Project_1"].isin(studies_to_remove)]
logger.info(f"sample_group.shape after removal of studies: {sample_group.shape}")
logger.info(f"sample_group.shape before column removal: {sample_group.shape}")
sample_group = sample_group[columns_to_keep]
logger.info(f"sample_group_useful.shape after column removal: {sample_group.shape}")
sample_group = sample_group.set_index("Sample")
logger.info(f"sample_group_useful.shape after setting index: {sample_group.shape}")

# Get species profile data. The first file column becomes the index; the
# remaining columns are matched against sample IDs below, i.e. at this point
# the table is oriented features x samples.
mpa4_species_profile = pd.read_table(
    f"{data_root_dir}/mpa4_species.profile", sep="\t", header=0, index_col=0
)
# Keep only the columns whose total is at least 1. Because columns are samples
# here (the table is transposed further down), this drops *samples*.
# NOTE(review): the previous comment read "Remove species with no reads" --
# confirm whether filtering rows (species) was intended instead.
mpa4_species_profile = mpa4_species_profile.loc[
    :, mpa4_species_profile.sum(axis=0) >= 1
]

## Remove repeated samples (keep the first metadata row per sample ID)
logger.info(f"sample_group_useful.shape before removal: {sample_group.shape}")
sample_group = sample_group[~sample_group.index.duplicated(keep="first")]
logger.info(f"sample_group_useful.shape after removal: {sample_group.shape}")

# remove samples not in sample_group
logger.info(
    f"mpa4_species_profile.shape before filtering out samples without metadata: {mpa4_species_profile.shape}"
)
# Samples present in both tables, in metadata order. Iterating a set
# intersection of strings would give a different order in every kernel
# session, making the row order of the saved outputs irreproducible.
profile_samples = set(mpa4_species_profile.columns)
samples_to_keep = [
    sample for sample in sample_group.index if sample in profile_samples
]
mpa4_species_profile = mpa4_species_profile[samples_to_keep]
logger.info(
    f"mpa4_species_profile.shape after filtering out samples without metadata: {mpa4_species_profile.shape}"
)
# From here on rows are samples and columns are species.
mpa4_species_profile = mpa4_species_profile.T
logger.info(
    f"mpa4_species_profile.shape after transposing: {mpa4_species_profile.shape}"
)

# remove samples from sample_group that are not in mpa4_species_profile
# (this also puts both frames in the same row order)
logger.info(
    f"sample_group_useful.shape before filtering out samples not in mpa4_species_profile: {sample_group.shape}"
)
sample_group = sample_group.loc[samples_to_keep]
logger.info(
    f"sample_group_useful.shape after filtering out samples not in mpa4_species_profile: {sample_group.shape}"
)

# Normalize the data. total_sum_scaling is a project helper; presumably it
# rescales each row to sum to 1 -- verify against src.preprocessing.functions.
mpa4_species_profile = preprocessing_functions.total_sum_scaling(mpa4_species_profile)
logger.info(
    f"mpa4_species_profile summation after normalization: {mpa4_species_profile.sum(axis=1)}"
)


# normalize again

# transform

[32m2025-01-29 12:20:38.599[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m3[0m - [1msample_group.shape before removal of studies: (6616, 21)[0m
[32m2025-01-29 12:20:38.609[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m5[0m - [1msample_group.shape after removal of studies: (6463, 21)[0m
[32m2025-01-29 12:20:38.613[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1msample_group.shape before column removal: (6463, 21)[0m
[32m2025-01-29 12:20:38.620[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m8[0m - [1msample_group_useful.shape after column removal: (6463, 4)[0m
[32m2025-01-29 12:20:38.625[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m10[0m - [1msample_group_useful.shape after setting index: (6463, 3)[0m
[32m2025-01-29 12:21:01.342[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m18[0m - [1msample_group_useful.shape before removal: (6463, 3)[0m

In [15]:
# Prevalence and abundance filtering, applied independently within each study
# (grouped by the "Project_1" metadata column).
ABUNDANCE_THRESHOLD = 0.0001  # values at or below this are set to 0
MIN_PREVALENCE = 0.1  # minimum fraction of a study's samples above the threshold

grouped_sample_group = sample_group.groupby("Project_1")
display(mpa4_species_profile)
for project, samples in grouped_sample_group.groups.items():
    logger.info(f"Project: {project}")
    study_rows = mpa4_species_profile.loc[samples]
    # Fraction of this study's samples in which each feature exceeds the
    # abundance threshold.
    feature_prevalence = (
        study_rows > ABUNDANCE_THRESHOLD
    ).sum(axis=0) / study_rows.shape[0]
    low_prevalence_features = feature_prevalence < MIN_PREVALENCE

    # Zero out whole low-prevalence features plus any individual value at or
    # below the abundance threshold.
    df_masked = study_rows.mask(
        low_prevalence_features | (study_rows <= ABUNDANCE_THRESHOLD), 0
    )
    # Write only this study's rows back. DataFrame.update would first align
    # df_masked against the entire frame on every iteration, which is far
    # slower for the same result.
    mpa4_species_profile.loc[samples, :] = df_masked

display(mpa4_species_profile)
# Row sums fall below 1 after filtering; they are renormalized in a later step.
display(mpa4_species_profile.sum(axis=1))

name,s__Phocaeicola_plebeius,s__Faecalibacterium_prausnitzii,s__Ruminococcus_sp_NSJ_71,s__Eubacterium_rectale,s__Bacteroides_uniformis,s__Clostridium_sp_AF15_49,s__Lachnospira_eligens,s__Roseburia_sp_AF02_12,s__Phocaeicola_vulgatus,s__Ruminococcus_bicirculans,...,s__Rodentibacter_myodis,s__Rhodococcus_hoagii,s__Pseudomonas_psychrophila,s__Pseudomonas_sp_DG56_2,s__Providencia_rustigianii,s__Pseudomonas_vranovensis,s__Pseudomonas_taetrolens,s__Pseudomonas_deceptionensis,s__Desulfobulbus_oralis,s__Bacteroides_reticulotermitis
SRR10983017,0.000000,0.000000,0.000000,0.000000,0.001081,0.000000,0.006767,0.000000,0.000969,0.000000,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
CKD_KD-036,0.000000,0.029553,0.000000,0.000487,0.014621,0.000000,0.000000,0.012382,0.024954,0.003701,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
SRR13327437,0.000000,0.091379,0.000000,0.157463,0.046527,0.000000,0.067244,0.009116,0.041449,0.101153,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Yu_71,0.000000,0.024430,0.040396,0.016105,0.076397,0.000000,0.000227,0.000000,0.157059,0.052066,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
SRR341655,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.062516,0.000000,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
SRR8849287,0.000000,0.064171,0.226972,0.008365,0.048112,0.017147,0.010338,0.000000,0.019989,0.000103,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
SRR13327475,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
SRR8845657,0.000229,0.031350,0.000000,0.000000,0.082272,0.000000,0.000000,0.000000,0.226140,0.000000,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
SAMEA2737870,0.000000,0.106169,0.000000,0.014949,0.050147,0.000267,0.000000,0.002332,0.079171,0.004624,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


[32m2025-01-29 12:23:22.092[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1mProject: ChenB_2020[0m
[32m2025-01-29 12:23:32.522[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1mProject: ChuY_2021[0m
[32m2025-01-29 12:23:43.623[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1mProject: HanL_2021[0m
[32m2025-01-29 12:23:53.742[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1mProject: HeQ_2017[0m
[32m2025-01-29 12:24:03.939[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1mProject: HuY_2019[0m
[32m2025-01-29 12:24:13.919[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1mProject: HuangR_2020[0m
[32m2025-01-29 12:24:22.972[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1mProject: JieZ_2017[0m
[32m2025-01-29 12:24:32.068[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m

name,s__Phocaeicola_plebeius,s__Faecalibacterium_prausnitzii,s__Ruminococcus_sp_NSJ_71,s__Eubacterium_rectale,s__Bacteroides_uniformis,s__Clostridium_sp_AF15_49,s__Lachnospira_eligens,s__Roseburia_sp_AF02_12,s__Phocaeicola_vulgatus,s__Ruminococcus_bicirculans,...,s__Rodentibacter_myodis,s__Rhodococcus_hoagii,s__Pseudomonas_psychrophila,s__Pseudomonas_sp_DG56_2,s__Providencia_rustigianii,s__Pseudomonas_vranovensis,s__Pseudomonas_taetrolens,s__Pseudomonas_deceptionensis,s__Desulfobulbus_oralis,s__Bacteroides_reticulotermitis
SRR10983017,0.000000,0.000000,0.000000,0.000000,0.001081,0.000000,0.006767,0.000000,0.000969,0.000000,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
CKD_KD-036,0.000000,0.029553,0.000000,0.000487,0.014621,0.000000,0.000000,0.012382,0.024954,0.003701,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
SRR13327437,0.000000,0.091379,0.000000,0.157463,0.046527,0.000000,0.067244,0.000000,0.041449,0.101153,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Yu_71,0.000000,0.024430,0.040396,0.016105,0.076397,0.000000,0.000227,0.000000,0.157059,0.052066,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
SRR341655,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.062516,0.000000,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
SRR8849287,0.000000,0.064171,0.226972,0.008365,0.048112,0.017147,0.010338,0.000000,0.019989,0.000103,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
SRR13327475,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
SRR8845657,0.000229,0.031350,0.000000,0.000000,0.082272,0.000000,0.000000,0.000000,0.226140,0.000000,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
SAMEA2737870,0.000000,0.106169,0.000000,0.014949,0.050147,0.000267,0.000000,0.002332,0.079171,0.004624,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


SRR10983017     0.985616
CKD_KD-036      0.986571
SRR13327437     0.985775
Yu_71           0.957424
SRR341655       0.904047
                  ...   
SRR8849287      0.961110
SRR13327475     0.931802
SRR8845657      0.978143
SAMEA2737870    0.998919
HD-8            0.790368
Length: 6303, dtype: float64

In [16]:
# Re-normalize the filtered profile, then apply the centered arcsine
# transform, logging the per-sample sums around each step.
row_sums = mpa4_species_profile.sum(axis=1)
logger.info(f"mpa4_species_profile summation before normalization: {row_sums}")

mpa4_species_profile = preprocessing_functions.total_sum_scaling(mpa4_species_profile)
row_sums = mpa4_species_profile.sum(axis=1)
logger.info(f"mpa4_species_profile summation after normalization: {row_sums}")

# Centered arcsine transform (project helper in src.preprocessing.functions).
# The frame is unchanged since the previous log line, so the same sums apply.
logger.info(
    f"mpa4_species_profile summation before centered arcsine transform: {row_sums}"
)
mpa4_species_profile = preprocessing_functions.centered_arcsine_transform(
    mpa4_species_profile
)
row_sums = mpa4_species_profile.sum(axis=1)
logger.info(
    f"mpa4_species_profile summation after centered arcsine transform: {row_sums}"
)

[32m2025-01-29 12:54:02.308[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mmpa4_species_profile summation before normalization: SRR10983017     0.985616
CKD_KD-036      0.986571
SRR13327437     0.985775
Yu_71           0.957424
SRR341655       0.904047
                  ...   
SRR8849287      0.961110
SRR13327475     0.931802
SRR8845657      0.978143
SAMEA2737870    0.998919
HD-8            0.790368
Length: 6303, dtype: float64[0m
[32m2025-01-29 12:54:02.545[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m4[0m - [1mmpa4_species_profile summation after normalization: SRR10983017     1.0
CKD_KD-036      1.0
SRR13327437     1.0
Yu_71           1.0
SRR341655       1.0
               ... 
SRR8849287      1.0
SRR13327475     1.0
SRR8845657      1.0
SAMEA2737870    1.0
HD-8            1.0
Length: 6303, dtype: float64[0m
[32m2025-01-29 12:54:02.615[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m7[0m - [1mmpa4_species_

In [17]:
# Persist the preprocessed profile and its matching metadata as CSV files
# under the data directory.
profile_csv_path = f"{data_root_dir}/mpa4_species_profile_preprocessed.csv"
metadata_csv_path = f"{data_root_dir}/sample_group_preprocessed.csv"

mpa4_species_profile.to_csv(profile_csv_path)
sample_group.to_csv(metadata_csv_path)

---
## Testing dataloader

In [261]:
%load_ext autoreload
%autoreload 2

from torch import manual_seed
from torch.utils.data import DataLoader

manual_seed(0)

from src.data.sun_et_al import BinaryFewShotBatchSampler, MicrobiomeDataset

test_study = ["ChenB_2020"]
val_study = ["ChuY_2021"]

train_df = mpa4_species_profile.loc[
    sample_group.loc[~sample_group["Project_1"].isin(test_study + val_study)].index
]
assert train_df.shape[0] == 5892

test_df = mpa4_species_profile.loc[
    sample_group.loc[sample_group["Project_1"].isin(test_study)].index
]
assert test_df.shape[0] == 231

val_df = mpa4_species_profile.loc[
    sample_group.loc[sample_group["Project_1"].isin(val_study)].index
]
assert val_df.shape[0] == 180

meta_data = sample_group[["Group", "Project_1"]].rename(
    columns={"Project_1": "project", "Group": "label"}
)

train = MicrobiomeDataset(train_df, meta_data.loc[train_df.index])
test = MicrobiomeDataset(test_df, meta_data.loc[test_df.index])
val = MicrobiomeDataset(val_df, meta_data.loc[val_df.index])

sampler = BinaryFewShotBatchSampler(train, 50, True, True)
train_loader = DataLoader(train, batch_sampler=sampler)


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [271]:
# samples, labels = next(iter(train_loader))
# print(samples.shape)
# print(labels.shape)
# print(samples)
# print(labels)


97
157
[np.int64(2862), np.int64(2924), np.int64(2938), np.int64(2893), np.int64(2761), np.int64(2820), np.int64(2953), np.int64(2772), np.int64(2968), np.int64(2816), np.int64(2877), np.int64(2935), np.int64(2768), np.int64(2850), np.int64(2773), np.int64(2769), np.int64(2866), np.int64(2872), np.int64(2837), np.int64(2770), np.int64(2867), np.int64(2843), np.int64(2815), np.int64(2864), np.int64(3009), np.int64(2915), np.int64(2808), np.int64(2925), np.int64(2887), np.int64(2765), np.int64(2838), np.int64(2950), np.int64(2817), np.int64(2895), np.int64(2776), np.int64(2759), np.int64(2771), np.int64(2780), np.int64(2818), np.int64(2868), np.int64(2806), np.int64(2763), np.int64(2918), np.int64(2898), np.int64(2910), np.int64(2891), np.int64(2847), np.int64(2827), np.int64(2873), np.int64(2849), np.int64(3011), np.int64(2783), np.int64(2845), np.int64(3004), np.int64(2781), np.int64(2870), np.int64(3006), np.int64(2931), np.int64(2963), np.int64(2941), np.int64(2932), np.int64(2946), 

In [257]:
import numpy as np

from src.preprocessing.functions import pandas_label_encoder

# Scratch check: collect row positions for one label of one project and fetch
# one of them from the training dataset.
encoded_meta = pandas_label_encoder(meta_data)
# Sort by sample ID and switch to a 0..N-1 positional index before grouping.
labels_by_project = (
    encoded_meta.sort_index().reset_index(drop=True).groupby("project")[["label"]]
)
# Group key 0 -- presumably the first encoded project; verify against
# pandas_label_encoder.
first_project_labels = labels_by_project.get_group(0)
positions_by_label = first_project_labels.groupby("label").groups

candidate_indices = [1, 2, 3]
candidate_indices.extend(np.array(next(iter(positions_by_label.values()))))
print(candidate_indices)

# NOTE(review): these positions index the sorted metadata of *all* splits,
# while `train` was built from train_df in its own row order -- confirm that
# the positions refer to the intended samples.
train[candidate_indices[5]]

(tensor([ 0.4558,  0.1662, -0.0019,  ..., -0.0019, -0.0019, -0.0019]),
 tensor(1.))