In [0]:
%load_ext autoreload
%autoreload 2

In [0]:
# One-letter rank prefix (first character of a lineage segment such as
# 's__...') -> full taxonomic rank name. Used below to label and select columns.
level_mapping = dict(
    d='domain',
    p='phylum',
    c='class',
    o='order',
    f='family',
    g='genus',
    s='species',
)

# Data loading & exploration

In [0]:
# Read the raw `mpa_4` table through the Spark session and render it with the
# notebook's `display` helper.
source_table = "onesource_eu_dev_rni.onebiome.mpa_4"
df_raw = spark.read.table(source_table)
display(df_raw)

In [0]:
# Dimensions of the raw Spark table: the row count is reported as the number
# of samples, followed by the number of columns.
num_samples, num_columns = df_raw.count(), len(df_raw.columns)
print(f"Number of samples in df_raw: {num_samples}")
print(f"Number of columns in df_raw: {num_columns}")



In [0]:
# Rough size estimate: total number of characters in the string form of every
# Row, expressed in units of 1024**3. This is a proxy based on `str(row)`,
# not a measurement of the memory Spark actually uses.
row_repr_lengths = df_raw.rdd.map(lambda row: len(str(row)))
memory_size_gb = row_repr_lengths.sum() / (1024 ** 3)
print(f"Memory size of df_raw: {memory_size_gb:.2f} GB")

In [0]:
# Show which distinct `version` values occur, then discard that column.
distinct_versions = df_raw.select('version').distinct().collect()
print(distinct_versions)
df_raw = df_raw.drop('version')

### Check the finest taxonomic level

In [0]:
import pandas as pd

# Collect the Spark table into pandas, index it by `barcode`, and coerce every
# column to numeric (values that cannot be parsed become NaN).
# Note: from here on `df_raw` is a pandas DataFrame, no longer a Spark one.
df_raw = (
    df_raw.toPandas()
    .set_index('barcode')
    .apply(pd.to_numeric, errors='coerce')
)
df_raw

In [0]:
# Print the position and name of the columns at positions 26 through 34.
for position, col in enumerate(df_raw.columns[26:35], start=26):
    print(position, col)

In [0]:
# For the columns at positions 26 through 29, print position, name and the
# column total.
for position in range(26, 30):
    column_total = df_raw.iloc[:, position].sum()
    print(position, df_raw.columns[position], column_total)

In [0]:
# Display all rows of the columns at positions 11 through 20.
df_raw.iloc[:, 11:21]

In [0]:
import pandas as pd

# Collect the columns whose name starts with 'barcode' and check that they all
# carry the same values as the first one. Every such column is compared (not
# just the first few), because all of them are dropped in the next cell.
# The result is vacuously True when fewer than two such columns exist.
# NOTE(review): df_raw was passed through pd.to_numeric(errors='coerce')
# earlier, and Series.equals treats NaNs in the same positions as equal; if
# these columns held non-numeric text they are all-NaN here and this check
# cannot detect differences. Confirm against the raw table.
columns_to_check = df_raw.columns[df_raw.columns.str.startswith('barcode')]
print(columns_to_check)
all(df_raw[columns_to_check[0]].equals(df_raw[col]) for col in columns_to_check[1:])


In [0]:
# Remove the redundant 'barcode*' columns found above. Rebinding (instead of
# inplace=True) together with errors='ignore' keeps this cell safe to re-run:
# a second execution is a no-op rather than a KeyError.
df_raw = df_raw.drop(columns=columns_to_check, errors='ignore')

### Bacteria percentage

In [0]:

# Histogram of the `d__bacteria` column over all rows (pandas default binning).
df_raw['d__bacteria'].hist()


### Unique values at each taxonomic level

In [0]:
import matplotlib.pyplot as plt
from collections import Counter

# Tally the columns per taxonomic rank. The rank letter is the first character
# of the last '|'-separated segment of each column name.
rank_names = (level_mapping[col.split('|')[-1][0]] for col in df_raw.columns)
taxa_counts = Counter(rank_names)

# Bar chart of the tallies, with the exact count written just above each bar.
plt.figure(figsize=(10, 6))
plt.bar(taxa_counts.keys(), taxa_counts.values(), edgecolor='k')
for bar_position, count in enumerate(taxa_counts.values()):
    plt.text(bar_position, count + 0.5, str(count), ha='center', va='bottom')
plt.xlabel('Taxa')
plt.ylabel('Count')
plt.title('Count of Each Unique Value in Each Taxonomic level')
plt.xticks(rotation=90)
plt.show()

In [0]:
# Build one abundance table per taxonomic level (order, family, genus, species)
# from df_raw and store them in dict_df_level, keyed by the full level name.
lvls = ['o', 'f', 'g', 's']
dict_df_level = dict()
for lvl in lvls:
    level = level_mapping[lvl]

    # Keep the columns whose last '|'-separated segment starts with this level's
    # letter, then discard columns that are entirely NaN.
    df_level = df_raw[[col for col in df_raw.columns if col.split('|')[-1].startswith(lvl)]].dropna(axis=1, how='all')
    num_col = df_level.shape[1]

    # Remove all-zero columns -> absence.
    # The mask must be computed AFTER filling NaN with 0: `NaN != 0` is True,
    # so a mask built on the unfilled frame would keep columns that contain
    # only zeros and NaNs.
    df_level = df_level.fillna(0)
    df_level = df_level.loc[:, (df_level != 0).any(axis=0)]
    print("number of all zero columns: ", num_col - df_level.shape[1])

    # Rename the columns to short taxon names (segment text after the 3-char prefix).
    if lvl == 's':
        # Species level: use the species segment; when it is empty, fall back to
        # the previous segment (genus) followed by '_' so it can be completed
        # to "<genus>_unknown" below.
        df_level.columns = [
            col.split('|')[-1][3:] if col.split('|')[-1][3:] != '' else col.split('|')[-2][3:] + '_'
            for col in df_level.columns
        ]
        # A bare '_' means both the genus and the species names were empty: drop those.
        num_underscore_cols = (df_level.columns == '_').sum()
        if num_underscore_cols > 0:
            print(f"number of columns lacking both genus and species: {num_underscore_cols}")
            df_level = df_level.loc[:, df_level.columns != '_']
        df_level.columns = [col + 'unknown' if col.endswith('_') else col for col in df_level.columns]
        df_level.columns = df_level.columns.str.capitalize()  # capitalize first letter to get Genus_species

    else:
        # Other levels, only for stats, will not be used in the next steps.
        df_level.columns = [col.split('|')[-1][3:] for col in df_level.columns]
        df_level.columns = df_level.columns.str.lower()

    df_level = df_level.drop(columns=[''], errors='ignore')  # remove if level information absent
    print(f'level: {level}, shape: {df_level.shape}')
    dict_df_level[level] = df_level

In [0]:

import matplotlib.pyplot as plt

# One histogram per level of the row-wise sum of that level's table
# (NaN counted as 0), arranged on a 2x2 grid.
levels = ['order', 'family', 'genus', 'species']
fig, axes = plt.subplots(2, 2, figsize=(8, 6))
axes = axes.flatten()
for ax, level in zip(axes, levels):
    row_totals = dict_df_level[level].fillna(0).sum(axis=1)
    row_totals.hist(ax=ax, bins=100)
    # Show x tick labels with four decimals.
    ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f'{x:.4f}'))
    ax.set_title(f'Distribution of sum for {level} level')
plt.tight_layout()
plt.show()


In [0]:
def save_df_to_table(df, table_name, namespace="onesource_eu_dev_rni.onebiome"):
    """Write a pandas DataFrame to a Delta table, replacing any existing table.

    The DataFrame's index is turned into a regular column before conversion to
    Spark. This is done on a copy, so the caller's DataFrame is left untouched
    and the function can be called repeatedly on the same object.

    Args:
        df: pandas DataFrame to persist.
        table_name: Name of the target table inside ``namespace``.
        namespace: Prefix placed before ``table_name`` (joined with a dot).
            Defaults to the location this notebook has always written to.

    Relies on the notebook-global ``spark`` session.
    """
    full_name = f"{namespace}.{table_name}"
    # reset_index() returns a new frame; the original used inplace=True, which
    # silently modified the caller's data and added an extra 'index' column on
    # a second call.
    df_flat = df.reset_index()
    df_spark = spark.createDataFrame(df_flat)
    # Drop first, then write in overwrite mode, as before.
    spark.sql(f"DROP TABLE IF EXISTS {full_name}")
    df_spark.write.format("delta").mode("overwrite").saveAsTable(full_name)


In [0]:
# Disabled: saving every level's table. Only the species table is written below.
# for level, df in dict_df_level.items():
#     print(level)
#     save_df_to_table(df, f'mpa4_{level}_level')

# Persist the species-level table under the name 'mpa4_species_level_reformated'.
save_df_to_table(dict_df_level['species'], 'mpa4_species_level_reformated')