# Data loading & exploration

In [0]:
# Load the MPA2 table into a Spark DataFrame and preview it.
mpa2_table_name = "onesource_eu_dev_rni.onebiome.mpa_2"
df = spark.read.table(mpa2_table_name)
display(df)

In [0]:
# Rough size estimate: the total number of characters in the string
# representation of every row, divided by 1024**3. This measures the length of
# the row reprs, not the actual in-memory footprint of the DataFrame.
chars_per_gib = 1024 ** 3
total_repr_chars = df.rdd.map(lambda row: len(str(row))).sum()
memory_size_gb = total_repr_chars / chars_per_gib
print(f"Memory size of df: {memory_size_gb:.2f} GB")

In [0]:
# Print the schema of `df` to see which columns are available.
df.printSchema()

## Column understanding

### IDs

In [0]:
# _id vs userId: print the number of distinct values of each identifier column.
for id_column in ("_id", "userId"):
    print(df.select(id_column).distinct().count())

In [0]:
# Find the userIds that are linked to more than one _id.
# `id_userid_mapper` holds the unique (_id, userId) pairs.
id_userid_mapper = df.select("_id", "userId").distinct()
redundant_user_ids = (
    id_userid_mapper
    .groupBy("userId")
    .count()
    .filter("count > 1")
    .select("userId")
)
# Keep only the pairs belonging to those userIds, sorted so that the _ids of
# one user appear next to each other.
redundant_rows = (
    id_userid_mapper
    .join(redundant_user_ids, on="userId")
    .orderBy("userId")
)
display(redundant_rows)

**Each userId can correspond to multiple _id values --> userId represents an individual, while _id represents a sample**

### Version

In [0]:
# Check the column 'version': compare the number of distinct rows with and
# without it among the key columns.
for key_columns in (["_id", "userId", "version"], ["_id", "userId"]):
    print(df.select(*key_columns).distinct().count())

In [0]:
# Count how many distinct (_id, userId, version) combinations exist per version.
distinct_sample_versions = df.select("_id", "userId", "version").distinct()
display(distinct_sample_versions.groupBy("version").count())

**--> ignore version**

### How many data points does each user have?


In [0]:
# Number of distinct _id values recorded for each userId.
user_sample_pairs = df.select("userId", "_id").distinct()
samples_per_user = user_sample_pairs.groupBy("userId").count()
display(samples_per_user)

In [0]:
# Show the distinct (_id, userId) pairs built in the "check unique _id and userId" cell.
display(id_userid_mapper)

In [0]:
import matplotlib.pyplot as plt

# Number of (_id, userId) pairs per userId, i.e. how many samples each user has.
sample_counts = id_userid_mapper.toPandas()['userId'].value_counts()
# How many users share each samples-per-user value. Computed once and reused
# for both the bars and the labels, instead of being recomputed on every
# loop iteration.
users_per_sample_count = sample_counts.value_counts()

plt.bar(users_per_sample_count.index, users_per_sample_count.values)
plt.xlabel("Number of Samples per User")
plt.ylabel("Count")
plt.title("Distribution of Number of Samples per User")
# Annotate every bar with its height.
for n_samples, n_users in users_per_sample_count.items():
    plt.text(n_samples, n_users, str(n_users), ha='center', va='bottom')
plt.xticks(rotation=45);

### check ID mapping with MPA4


In [0]:
# Barcodes listed in the MPA4 table.
barcode_mpa4 = spark.table("onesource_eu_dev_rni.onebiome.mpa_4").select("barcode")
# Distinct _id values present in the MPA2 data.
mpa2_sample_ids = df.select("_id").distinct()
# Left anti join: keep only the MPA4 barcodes that have no matching MPA2 _id.
missing_barcodes = barcode_mpa4.join(
    mpa2_sample_ids,
    on=barcode_mpa4["barcode"] == df["_id"],
    how="left_anti",
)
display(missing_barcodes)

--> "barcode" in mpa4 corresponds to "_id" in mpa2

  two samples in mpa4 are not found in mpa2

### Abundance

In [0]:
import matplotlib.pyplot as plt

# Plot the distribution of the column 'abundance'.
# The column is collected to the driver as a single-column pandas DataFrame.
abundance_pdf = df.select("abundance").toPandas()
plt.hist(abundance_pdf, bins=100, edgecolor='k')
plt.title('Distribution of Abundance')
plt.xlabel('Abundance')
plt.ylabel('Frequency')
plt.show()

**The values of abundance are between 0 and 100 --> relative abundance**

## Taxonomic level selection

In [0]:
from pyspark.sql import functions as F


def _abundance_at_level(level_col, finer_col=None):
    """Return a column holding `abundance` for rows at one taxonomic level, else 0.

    A row counts as being at `level_col` when that column is non-null and, if
    `finer_col` is given, the finer-level column is null.
    """
    at_level = df[level_col].isNotNull()
    if finer_col is not None:
        at_level = (at_level) & (df[finer_col].isNull())
    return F.when(at_level, df['abundance']).otherwise(0)


# Per sample (_id), sum the abundance of the rows that sit at each taxonomic level.
abundance_sum_by_taxonomic_level = df.groupBy('_id').agg(
    F.sum(_abundance_at_level('family', 'genus')).alias('family_sum'),
    F.sum(_abundance_at_level('genus', 'species')).alias('genus_sum'),
    F.sum(_abundance_at_level('species', 'strain')).alias('species_sum'),
    F.sum(_abundance_at_level('strain')).alias('strain_sum'),
)

# Plot the distribution of each column
level_sums_pdf = abundance_sum_by_taxonomic_level.toPandas()
level_sums_pdf.hist(bins=100)
plt.suptitle('Distribution of Values in abundance_sum_by_taxonomic_level')
plt.show()

**Species level: most samples have sum abundance near 100%**

**Strain level: most samples > 75%**

**--> choose species or genus level (genus to have fewer features)**

# Create, check and save taxa tables

## Species

In [0]:
# Keep the rows that have a species but no strain, then pivot to a wide table:
# one row per _id, one column per species, cells holding the summed abundance.
is_species_level = df['species'].isNotNull() & df['strain'].isNull()
species_wide = (
    df.filter(is_species_level)
    .groupBy("_id")
    .pivot("species")
    .sum("abundance")
)
df_species = species_wide.toPandas()
print(df_species.shape)
display(df_species)

In [0]:
# Histogram of the total species-level abundance per sample.
# `_id` is moved to the index first: it holds string identifiers, and leaving
# it among the columns makes the row-wise sum raise a TypeError on pandas >= 2.0
# (older versions silently dropped it).
species_total_per_sample = df_species.set_index("_id").sum(axis=1)
ax = species_total_per_sample.hist(bins=100)
ax.set(
    xlabel="Sum of species abundance per sample",
    ylabel="Number of samples",
    title="Distribution of total species abundance per sample",
);


In [0]:
# `sns` is used here but is never imported anywhere else in the notebook, so
# the cell fails with a NameError on a fresh kernel.
import seaborn as sns

# Heatmap of the sample x species abundance table. `_id` holds string
# identifiers that cannot be converted to float, so it is moved to the index
# and only the abundance values are passed as heatmap data.
sns.heatmap(df_species.set_index("_id"));

In [0]:
# Convert the pandas species table back to a Spark DataFrame and save it as a
# Delta table, overwriting any existing table with the same name.
df_species_spark = spark.createDataFrame(df_species)
species_table_writer = df_species_spark.write.format("delta").mode("overwrite")
species_table_writer.saveAsTable("onesource_eu_dev_rni.onebiome.mpa_2_species_table")

## Genus

In [0]:
# Keep the rows that have a genus but no species, then pivot to a wide table:
# one row per _id, one column per genus, cells holding the summed abundance.
is_genus_level = df['genus'].isNotNull() & df['species'].isNull()
genus_wide = (
    df.filter(is_genus_level)
    .groupBy("_id")
    .pivot("genus")
    .sum("abundance")
)
df_genus = genus_wide.toPandas()
print(df_genus.shape)
display(df_genus)

In [0]:
import matplotlib.pyplot as plt
# `sns` is used here but is never imported anywhere else in the notebook, so
# the cell fails with a NameError on a fresh kernel.
import seaborn as sns

# Total genus-level abundance per sample. `_id` is moved to the index so the
# string identifiers are not part of the row-wise sum (which raises a
# TypeError on pandas >= 2.0 when a string column is included).
genus_total_per_sample = df_genus.set_index("_id").sum(axis=1)
sns.histplot(genus_total_per_sample)
# Restrict the x-axis to a narrow window around 100.
plt.xlim(99.999, 100.001);

In [0]:
# `sns` is used here but is never imported anywhere else in the notebook, so
# the cell fails with a NameError on a fresh kernel.
import seaborn as sns

# Heatmap of the sample x genus abundance table. `_id` holds string
# identifiers that cannot be converted to float, so it is moved to the index
# and only the abundance values are passed as heatmap data.
sns.heatmap(df_genus.set_index("_id"));

In [0]:
# Convert the pandas genus table back to a Spark DataFrame and save it as a
# Delta table, overwriting any existing table with the same name.
df_genus_spark = spark.createDataFrame(df_genus)
genus_table_writer = df_genus_spark.write.format("delta").mode("overwrite")
genus_table_writer.saveAsTable("onesource_eu_dev_rni.onebiome.mpa_2_genus_table")