# Goal of Notebook
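In this notebook, we prepare the Visium HD objects (annotating tissue cores and patient metadata), apply QC filtering to each TMA slide, and merge the filtered slides into a single object.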

# Load Modules

In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np
import scanpy as sc
import pandas as pd
import squidpy as sq
import importlib
import anndata as an 
import seaborn as sns

from spatialdata_io import visium_hd

In [None]:
from spatialdata_io import visium_hd

In [None]:
os.getcwd()
os.chdir("/home/plopez/data/Projects/SC_placenta/Visium_workspace")

In [None]:
## import functions
sys.path.append("/home/plopez/data/Projects/SC_placenta/Visium_workspace/Code_workspace/helpful_functions")
import prepping_core as func
importlib.reload(func)

# Import object

## TMA2 - 140

### load spatialdata obj

In [None]:
sdata = visium_hd("raw_data/TMA2-140/outs/",bin_size=8,dataset_id = "TMA2_140",
                  fullres_image_file="raw_data/TMA2-140/outs/binned_outputs/square_008um/spatial/tissue_hires_image.png")

In [None]:
for table in sdata.tables.values():
    table.var_names_make_unique()

In [None]:
adata = sdata.tables

In [None]:
sq.pl.spatial_scatter(
            adata['square_008um'],
            shape=None,            # must be positional
            library_id="spatial",
            use_raw=False,
            color=["ACTA2"]
        )

### merge with coordinates

In [None]:
# Attach tissue-core annotations to the 8 um table using this slide's barcode
# CSV; drop_nan=True is forwarded to the helper (presumably drops unmatched
# bins — confirm in prepping_core).
adata_fixed = func.annotate_adata_with_tissue_core(
    adata,
    obs_key="square_008um",
    barcode_csv_path="data/TMA_input_coordinates/TMA2140_cell_coordinates.csv",
    drop_nan=True,
)


### merge with metadata

In [None]:
# Add patient/condition metadata on top of the coordinate-annotated tables from
# the previous cell (adata_fixed), as is done for every other slide. Passing
# the raw `adata` would discard that annotation if the helper returns a new
# object rather than modifying its input in place.
adata = func.add_patient_condition_from_metadata(
    adata_fixed,
    obs_key='square_008um',
    metadata_path='Code_workspace/visium_metadata_V1.xlsx'
)


### Save Obj

In [None]:
adata['square_008um'].write_h5ad("data/0_preped_h5ad_obj_metadata/TMA2140_obj_V1.h5ad")

In [None]:
adata['square_008um']

In [None]:
sq.pl.spatial_scatter(
            adata['square_008um'],
            shape=None,            # must be positional
            library_id="spatial",
            use_raw=False,
            color=["PSG5"]
        )

## TMA2 - 145

### Load obj 

In [None]:
sdata = visium_hd("raw_data/TMA2-145/outs/",bin_size=8,dataset_id = "TMA2_145",
                  fullres_image_file="raw_data/TMA2-145/outs/binned_outputs/square_008um/spatial/tissue_hires_image.png")

In [None]:
for table in sdata.tables.values():
    table.var_names_make_unique()

In [None]:
adata = sdata.tables

In [None]:
adata

### merge with coordinate

In [None]:
# Attach tissue-core annotations to the 8 um table from this slide's barcode CSV.
adata_fixed = func.annotate_adata_with_tissue_core(
    adata,
    obs_key="square_008um",
    barcode_csv_path="data/TMA_input_coordinates/TMA2145.csv",
    drop_nan=True,
)


### merge with metadata

In [None]:
adata = func.add_patient_condition_from_metadata(
    adata_fixed,
    obs_key='square_008um',
    metadata_path='Code_workspace/visium_metadata_V1.xlsx'
)


In [None]:
sq.pl.spatial_scatter(
            adata['square_008um'],
            shape=None,            # must be positional
            library_id="spatial",
            use_raw=False,
            color=["Patient"]
        )

### save obj 

In [None]:
# Persist the annotated 8 um table as the pre-QC (V1) object for TMA2-145.
adata["square_008um"].write_h5ad("data/0_preped_h5ad_obj_metadata/TMA2145_obj_V1.h5ad")

## TMA4

### load obj 

In [None]:
sdata = visium_hd("raw_data/TMA4/outs/",bin_size=8,dataset_id = "TMA4",
                  fullres_image_file="raw_data/TMA4/outs/binned_outputs/square_008um/spatial/tissue_hires_image.png")

In [None]:
for table in sdata.tables.values():
    table.var_names_make_unique()

In [None]:
adata = sdata.tables
adata

### merge with coordinate 

In [None]:
# Attach tissue-core annotations to the 8 um table from this slide's barcode CSV.
adata_fixed = func.annotate_adata_with_tissue_core(
    adata,
    obs_key="square_008um",
    barcode_csv_path="data/TMA_input_coordinates/TMA4.csv",
    drop_nan=True,
)


### merge with metadata

In [None]:
adata = func.add_patient_condition_from_metadata(
    adata_fixed,
    obs_key='square_008um',
    metadata_path='Code_workspace/visium_metadata_V1.xlsx'
)


In [None]:
sq.pl.spatial_scatter(
            adata['square_008um'],
            shape=None,            # must be positional
            library_id="spatial",
            use_raw=False,
            color=["Condition","Patient"]
        )

### save obj 

In [None]:
# Persist the annotated 8 um table as the pre-QC (V1) object for TMA4.
adata["square_008um"].write_h5ad("data/0_preped_h5ad_obj_metadata/TMA4_obj_V1.h5ad")

## TMA 9

### load obj 

In [None]:
sdata = visium_hd("raw_data/TMA9/outs/",bin_size=8,dataset_id = "TMA9",
                  fullres_image_file="raw_data/TMA9/outs/binned_outputs/square_008um/spatial/tissue_hires_image.png")

In [None]:
for table in sdata.tables.values():
    table.var_names_make_unique()

In [None]:
adata = sdata.tables
adata

### merge with coordinate 

In [None]:
# Attach tissue-core annotations to the 8 um table from this slide's barcode CSV.
adata_fixed = func.annotate_adata_with_tissue_core(
    adata,
    obs_key="square_008um",
    barcode_csv_path="data/TMA_input_coordinates/TMA9.csv",
    drop_nan=True,
)


### merge with metadata 

In [None]:
adata = func.add_patient_condition_from_metadata(
    adata_fixed,
    obs_key='square_008um',
    metadata_path='Code_workspace/visium_metadata_V1.xlsx'
)


In [None]:
sq.pl.spatial_scatter(
            adata['square_008um'],
            shape=None,            # must be positional
            library_id="spatial",
            use_raw=False,
            color=["Condition","Patient"]
        )

In [None]:
# Persist the annotated 8 um table as the pre-QC (V1) object for TMA9.
adata["square_008um"].write_h5ad("data/0_preped_h5ad_obj_metadata/TMA9_obj_V1.h5ad")

# QC of objects

## TMA2 140

### load obj

In [None]:
tma2_140_obj = an.read_h5ad("data/0_preped_h5ad_obj_metadata/TMA2140_obj_V1.h5ad")
tma2_140_obj.obs["TMA_num"] = tma2_140_obj.obs["Row_num"].str.extract(r'^(TMA\d+)')

In [None]:
sq.pl.spatial_scatter(
            tma2_140_obj,
            shape=None,            # must be positional
            library_id="spatial",
            use_raw=False,
            color=["Patient","Tissue_core","Condition"]
        )

In [None]:
tma2_140_obj.var_names_make_unique()
# Flag genes whose symbol starts with "MT-" so calculate_qc_metrics also
# reports pct_counts_mt; metrics are written into .obs/.var in place.
is_mito = tma2_140_obj.var_names.str.startswith("MT-")
tma2_140_obj.var["mt"] = is_mito
sc.pp.calculate_qc_metrics(tma2_140_obj, qc_vars=["mt"], inplace=True)

### plot QC

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(12, 3))
sns.histplot(tma2_140_obj.obs["total_counts"], kde=False, ax=axs[0])
sns.histplot(tma2_140_obj.obs["n_genes_by_counts"], kde=False, bins=60, ax=axs[1])
sc.pl.violin(tma2_140_obj, "pct_counts_mt",ax = axs[2])

#sns.histplot(tma2_140_obj.obs["n_genes_by_counts"], kde=False, bins=60, ax=axs[2])


In [None]:
sq.pl.spatial_scatter(
            tma2_140_obj,
            shape=None,            # must be positional
            library_id="spatial",
            use_raw=False,
            color=["total_counts","n_genes_by_counts","pct_counts_mt"]
        )

### filter 

In [None]:
# QC filter: keep bins with >= 200 total counts and < 20 % mitochondrial counts.
print(f"adata size {tma2_140_obj.shape}")
sc.pp.filter_cells(tma2_140_obj, min_counts=200)
# .copy() materialises the boolean slice. Without it the result is an AnnData
# view, and later cells that may write to it (e.g. plotting helpers storing
# colours in .uns) would trigger an implicit view->copy with a warning.
tma2_140_obj = tma2_140_obj[tma2_140_obj.obs['pct_counts_mt'] < 20, :].copy()
print(f"adata size {tma2_140_obj.shape}")

### Replot 

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(12, 3))
sns.histplot(tma2_140_obj.obs["total_counts"], kde=False, ax=axs[0])
sns.histplot(tma2_140_obj.obs["n_genes_by_counts"], kde=False, bins=60, ax=axs[1])
sc.pl.violin(tma2_140_obj, "pct_counts_mt",ax = axs[2])



In [None]:
sq.pl.spatial_scatter(
            tma2_140_obj,
            shape=None,            # must be positional
            library_id="spatial",
            use_raw=False,
            color=["Patient","Tissue_core","Condition"]
        )

In [None]:
sq.pl.spatial_scatter(
            tma2_140_obj,
            shape=None,            # must be positional
            library_id="spatial",
            use_raw=False,
            color=["total_counts","n_genes_by_counts","pct_counts_mt"]
        )

### Save 

In [None]:
# Persist the QC-filtered (V2) TMA2-140 object.
tma2_140_obj.write_h5ad(
    "data/0_preped_h5ad_obj_metadata/TMA2140_obj_V2.h5ad"
)

## TMA2 145

### load obj 

In [None]:
tma2_145_obj = an.read_h5ad("data/0_preped_h5ad_obj_metadata/TMA2140_obj_V1.h5ad")
tma2_145_obj.obs["TMA_num"] = tma2_145_obj.obs["Row_num"].str.extract(r'^(TMA\d+)')

In [None]:
sq.pl.spatial_scatter(
            tma2_145_obj,
            shape=None,            # must be positional
            library_id="spatial",
            use_raw=False,
            color=["Patient","Tissue_core","Condition"]
        )

### plot QC

In [None]:
tma2_145_obj.var_names_make_unique()
# Flag genes whose symbol starts with "MT-" so calculate_qc_metrics also
# reports pct_counts_mt; metrics are written into .obs/.var in place.
is_mito = tma2_145_obj.var_names.str.startswith("MT-")
tma2_145_obj.var["mt"] = is_mito
sc.pp.calculate_qc_metrics(tma2_145_obj, qc_vars=["mt"], inplace=True)

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(12, 3))
sns.histplot(tma2_145_obj.obs["total_counts"], kde=False, ax=axs[0])
sns.histplot(tma2_145_obj.obs["n_genes_by_counts"], kde=False, bins=60, ax=axs[1])
sc.pl.violin(tma2_145_obj, "pct_counts_mt",ax = axs[2])

In [None]:
sq.pl.spatial_scatter(
            tma2_145_obj,
            shape=None,            # must be positional
            library_id="spatial",
            use_raw=False,
            color=["total_counts","n_genes_by_counts","pct_counts_mt"]
        )

### Filter 

In [None]:
# QC filter: keep bins with >= 200 total counts and < 20 % mitochondrial counts.
print(f"adata size {tma2_145_obj.shape}")
sc.pp.filter_cells(tma2_145_obj, min_counts=200)
# .copy() materialises the boolean slice instead of keeping an AnnData view that
# later plotting cells may write to (implicit view->copy with a warning).
tma2_145_obj = tma2_145_obj[tma2_145_obj.obs['pct_counts_mt'] < 20, :].copy()
print(f"adata size {tma2_145_obj.shape}")

### Replot 

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(12, 3))
sns.histplot(tma2_145_obj.obs["total_counts"], kde=False, ax=axs[0])
sns.histplot(tma2_145_obj.obs["n_genes_by_counts"], kde=False, bins=60, ax=axs[1])
sc.pl.violin(tma2_145_obj, "pct_counts_mt",ax = axs[2])

In [None]:
sq.pl.spatial_scatter(
            tma2_145_obj,
            shape=None,            # must be positional
            library_id="spatial",
            use_raw=False,
            color=["total_counts","n_genes_by_counts","pct_counts_mt"]
        )

### Save 

In [None]:
# Persist the QC-filtered (V2) TMA2-145 object.
tma2_145_obj.write_h5ad(
    "data/0_preped_h5ad_obj_metadata/TMA2145_obj_V2.h5ad"
)

## TMA 4

### load obj

In [None]:
tma4_obj = an.read_h5ad("data/0_preped_h5ad_obj_metadata/TMA4_obj_V1.h5ad")
# TMA_num = leading "TMA<digits>" of Row_num; expand=False yields a Series.
tma4_obj.obs["TMA_num"] = tma4_obj.obs["Row_num"].str.extract(r'^(TMA\d+)', expand=False)


### Get QC

In [None]:
tma4_obj.var_names_make_unique()
# Flag genes whose symbol starts with "MT-" so calculate_qc_metrics also
# reports pct_counts_mt; metrics are written into .obs/.var in place.
is_mito = tma4_obj.var_names.str.startswith("MT-")
tma4_obj.var["mt"] = is_mito
sc.pp.calculate_qc_metrics(tma4_obj, qc_vars=["mt"], inplace=True)

### plot QC

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(12, 3))
sns.histplot(tma4_obj.obs["total_counts"], kde=False, ax=axs[0])
sns.histplot(tma4_obj.obs["n_genes_by_counts"], kde=False, bins=60, ax=axs[1])
sc.pl.violin(tma4_obj, "pct_counts_mt",ax = axs[2])

In [None]:
sq.pl.spatial_scatter(
            tma4_obj,
            shape=None,            # must be positional
            library_id="spatial",
            use_raw=False,
            color=["total_counts","n_genes_by_counts","pct_counts_mt"]
        )

### Filter 

In [None]:
# QC filter: keep bins with >= 200 total counts and < 20 % mitochondrial counts.
print(f"adata size {tma4_obj.shape}")
sc.pp.filter_cells(tma4_obj, min_counts=200)
# .copy() materialises the boolean slice instead of keeping an AnnData view that
# later plotting cells may write to (implicit view->copy with a warning).
tma4_obj = tma4_obj[tma4_obj.obs['pct_counts_mt'] < 20, :].copy()
print(f"adata size {tma4_obj.shape}")

### replot

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(12, 3))
sns.histplot(tma4_obj.obs["total_counts"], kde=False, ax=axs[0])
sns.histplot(tma4_obj.obs["n_genes_by_counts"], kde=False, bins=60, ax=axs[1])
sc.pl.violin(tma4_obj, "pct_counts_mt",ax = axs[2])

In [None]:
sq.pl.spatial_scatter(
            tma4_obj,
            shape=None,            # must be positional
            library_id="spatial",
            use_raw=False,
            color=["total_counts","n_genes_by_counts","pct_counts_mt"]
        )

In [None]:
sq.pl.spatial_scatter(
            tma4_obj,
            shape=None,            # must be positional
            library_id="spatial",
            use_raw=False,
            color=["Patient","Tissue_core","Condition"]
        )

### save obj 

In [None]:
# Persist the QC-filtered (V2) TMA4 object.
tma4_obj.write_h5ad(
    "data/0_preped_h5ad_obj_metadata/TMA4_obj_V2.h5ad"
)


## TMA 9

### load obj 

In [None]:
tma9_obj = an.read_h5ad("data/0_preped_h5ad_obj_metadata/TMA9_obj_V1.h5ad")
# TMA_num = leading "TMA<digits>" of Row_num; expand=False yields a Series.
tma9_obj.obs["TMA_num"] = tma9_obj.obs["Row_num"].str.extract(r'^(TMA\d+)', expand=False)



### get QC

In [None]:
tma9_obj.var_names_make_unique()
# Flag genes whose symbol starts with "MT-" so calculate_qc_metrics also
# reports pct_counts_mt; metrics are written into .obs/.var in place.
is_mito = tma9_obj.var_names.str.startswith("MT-")
tma9_obj.var["mt"] = is_mito
sc.pp.calculate_qc_metrics(tma9_obj, qc_vars=["mt"], inplace=True)

### plot QC

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(12, 3))
sns.histplot(tma9_obj.obs["total_counts"], kde=False, ax=axs[0])
sns.histplot(tma9_obj.obs["n_genes_by_counts"], kde=False, bins=60, ax=axs[1])
sc.pl.violin(tma9_obj, "pct_counts_mt",ax = axs[2])

In [None]:
sq.pl.spatial_scatter(
            tma9_obj,
            shape=None,            # must be positional
            library_id="spatial",
            use_raw=False,
            color=["total_counts","n_genes_by_counts","pct_counts_mt"]
        )

### filter

In [None]:
# QC filter: keep bins with >= 200 total counts and < 20 % mitochondrial counts.
print(f"adata size {tma9_obj.shape}")
sc.pp.filter_cells(tma9_obj, min_counts=200)
# .copy() materialises the boolean slice instead of keeping an AnnData view that
# later plotting cells may write to (implicit view->copy with a warning).
tma9_obj = tma9_obj[tma9_obj.obs['pct_counts_mt'] < 20, :].copy()
print(f"adata size {tma9_obj.shape}")

### replot 

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(12, 3))
sns.histplot(tma9_obj.obs["total_counts"], kde=False, ax=axs[0])
sns.histplot(tma9_obj.obs["n_genes_by_counts"], kde=False, bins=60, ax=axs[1])
sc.pl.violin(tma9_obj, "pct_counts_mt",ax = axs[2])

In [None]:
sq.pl.spatial_scatter(
            tma9_obj,
            shape=None,            # must be positional
            library_id="spatial",
            use_raw=False,
            color=["total_counts","n_genes_by_counts","pct_counts_mt"]
        )

In [None]:
sq.pl.spatial_scatter(
            tma9_obj,
            shape=None,            # must be positional
            library_id="spatial",
            use_raw=False,
            color=["Patient","Tissue_core","Condition"]
        )

### save obj 

In [None]:
# Persist the QC-filtered (V2) TMA9 object.
tma9_obj.write_h5ad(
    "data/0_preped_h5ad_obj_metadata/TMA9_obj_V2.h5ad"
)


# Merge 

## Load objects

In [None]:
tma2_140_obj = an.read_h5ad("data/0_preped_h5ad_obj_metadata/TMA2140_obj_V2.h5ad")
tma2_140_obj

In [None]:
tma2_145_obj = an.read_h5ad("data/0_preped_h5ad_obj_metadata/TMA2145_obj_V2.h5ad")
tma2_145_obj

In [None]:
tma4_obj = an.read_h5ad("data/0_preped_h5ad_obj_metadata/TMA4_obj_V2.h5ad")
tma4_obj

In [None]:
tma9_obj = an.read_h5ad("data/0_preped_h5ad_obj_metadata/TMA9_obj_V2.h5ad")
tma9_obj

## merge into dictionary

In [None]:
tma_merge_obj = {"tma2_140":tma2_140_obj, "tma2_145":tma2_145_obj, "tma4": tma4_obj,"tma9":tma9_obj}


In [None]:
for tma_slide in tma_merge_obj:
#for batch in ["Batch12", "Batch13", "Batch14", "Batch15"]:
    print(tma_slide)
    obj = tma_merge_obj[tma_slide]
    obj.obs_names = obj.obs_names +"-"+ tma_slide


In [None]:
tma_merge_obj['tma2_140'].obs_names

In [None]:
tma_merged_obj = sc.concat(tma_merge_obj.values(), join="outer", label="TMA_num", keys=tma_merge_obj.keys(), index_unique=None)


In [None]:
tma_merged_obj

In [None]:
tma_merged_obj.var

In [None]:
tma2_145_obj.var

## merge var names 

In [None]:
# sc.concat with its default merge strategy does not keep .var columns, so
# re-attach gene_ids. Prefer TMA2-145's ids, but fill genes it lacks from the
# other slides; with the outer join those genes would otherwise get NaN.
gene_ids = tma2_145_obj.var["gene_ids"]
for slide_adata in tma_merge_obj.values():
    if "gene_ids" in slide_adata.var:
        gene_ids = gene_ids.combine_first(slide_adata.var["gene_ids"])
gene_ids_df = gene_ids.to_frame("gene_ids")

# Now join this into the merged .var dataframe. Drop any gene_ids column left by
# a previous run of this cell so the join does not fail on overlapping columns.
tma_merged_obj.var = tma_merged_obj.var.drop(columns="gene_ids", errors="ignore").join(gene_ids_df, how='left')

## save obj 

In [None]:
tma_merged_obj.write_h5ad("data/1_merged_TMA_objs/TMA_merged_V1.h5ad")

In [None]:
tma_merged_obj

In [None]:
tma_merged_obj.X

In [None]:
sq.pl.spatial_scatter(
            tma_merged_obj,
            shape=None,            # must be positional
            library_id="spatial",
            use_raw=False,
            color=["Patient"]
        )