In [None]:
import scanpy as sc
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import squidpy as sq
from pathlib import Path
import warnings
# Silence every Python warning for the rest of the session.
# NOTE(review): this is a blanket filter, so genuine problems are hidden as well.
warnings.simplefilter('ignore')

# Scanpy session configuration: logging verbosity and default figure look.
sc.settings.verbosity = 3
sc.settings.set_figure_params(facecolor='white', dpi=80)

## Xenium kidney data

In [None]:
# Read the AnnData object from disk.
input_path = "/shared/data/xenium_kidney.h5ad"
adata = sc.read_h5ad(input_path)

# Report the matrix dimensions, then the names stored in each annotation slot.
print(f"Data shape: {adata.shape}")
for slot_label, slot_names in (
    ("obsm keys", adata.obsm.keys()),
    ("var columns", adata.var.columns),
    ("obs columns", adata.obs.columns),
):
    print(f"Available {slot_label}: {list(slot_names)}")

In [None]:
# Number of cells per sample and the distinct sample labels.
sample_labels = adata.obs['sample']
print("Sample composition:")
print(sample_labels.value_counts())
print("\nUnique samples:", sample_labels.unique())

# Cell-type annotation is optional; summarise it only when the column is present.
if 'cell_type' in adata.obs.columns:
    cell_type_labels = adata.obs['cell_type']
    print(f"\nCell types available: {cell_type_labels.nunique()}")
    # Ten most frequent cell types.
    print(cell_type_labels.value_counts().head(10))

In [None]:
# Find the obsm entry holding per-cell coordinates and make sure it is
# available under the canonical 'spatial' key that the later cells read.
spatial_keys = [key for key in adata.obsm.keys() if 'spatial' in key.lower() or 'coord' in key.lower()]
print(f"Spatial coordinate keys: {spatial_keys}")

if 'spatial' not in adata.obsm.keys():
    if not spatial_keys:
        # Fail fast with a clear message: previously this case only printed a
        # warning and then crashed with a NameError on `coords` (or silently
        # reused a stale `coords` left in the kernel by an earlier run).
        raise KeyError(
            "No spatial coordinates found in adata.obsm "
            f"(available keys: {list(adata.obsm.keys())})"
        )
    # Fall back to the first key that looks like coordinates.
    adata.obsm['spatial'] = adata.obsm[spatial_keys[0]]

# Coerce to an ndarray so positional column indexing works even when the
# obsm entry is stored as a DataFrame.
coords = np.asarray(adata.obsm['spatial'])

print(f"Coordinate range - X: [{coords[:,0].min():.2f}, {coords[:,0].max():.2f}]")
print(f"Coordinate range - Y: [{coords[:,1].min():.2f}, {coords[:,1].max():.2f}]")

## Basic spatial visualization

In [None]:
# One scatter panel of cell positions per sample.
# The grid is sized from the data instead of being fixed at two panels, so the
# cell no longer raises IndexError with more than two samples and leaves no
# empty panel when there is only one.
sample_names = adata.obs['sample'].unique()
n_samples = len(sample_names)

# squeeze=False keeps `axes` two-dimensional even for a single sample.
# 7.5 inches per panel reproduces the original 15x6 figure for two samples.
fig, axes = plt.subplots(1, n_samples, figsize=(7.5 * n_samples, 6), squeeze=False)

for ax, sample in zip(axes[0], sample_names):
    sample_data = adata[adata.obs['sample'] == sample]
    sample_coords = sample_data.obsm['spatial']

    ax.scatter(sample_coords[:, 0], sample_coords[:, 1], s=0.5, alpha=0.6)
    ax.set_title(f'Sample: {sample}')
    ax.set_xlabel('X coordinate')
    ax.set_ylabel('Y coordinate')
    # Equal scaling so tissue geometry is not distorted.
    ax.axis('equal')

plt.tight_layout()
plt.show()

## QC

In [None]:
print(f"Number of genes: {adata.n_vars}")
print(f"Number of cells: {adata.n_obs}")

sc.pl.highest_expr_genes(adata, n_top=20, show=True)

# Flag gene families by name prefix; these boolean columns drive the
# per-family percentages computed below.
adata.var['mt'] = adata.var_names.str.startswith('MT-')
adata.var['ribo'] = adata.var_names.str.startswith(('RPS', 'RPL'))
# Regex: names starting with 'HB' whose next character is not '(', 'P' or ')'.
adata.var['hb'] = adata.var_names.str.contains('^HB[^(P)]')

# Compute the per-cell QC columns that the following cells read
# (n_genes_by_counts, total_counts, pct_counts_mt, ...). Without this call the
# flags above are never used and those columns exist only if they happened to
# be stored in the input file.
# percent_top=None: the default top-N cut-offs can exceed the number of genes
# in a small targeted panel and make the call fail.
# NOTE(review): assumes adata.X still holds raw counts at this point - confirm.
sc.pp.calculate_qc_metrics(
    adata,
    qc_vars=['mt', 'ribo', 'hb'],
    percent_top=None,
    log1p=False,
    inplace=True,
)

In [None]:
# QC overview figure.
#   Top row:    violin plots of per-cell QC metrics, grouped by sample.
#   Bottom row: total counts vs. detected genes, one scatter per sample.
# The grid grows with the number of samples. The original fixed 2x3 layout
# raised IndexError for more than three samples and switched off the third
# sample's scatter panel when there were exactly three.
sample_names = adata.obs['sample'].unique()
n_samples = len(sample_names)

violin_panels = [
    ('n_genes_by_counts', 'Genes per cell'),
    ('total_counts', 'Total counts per cell'),
    ('pct_counts_mt', 'Mitochondrial gene %'),
]

# At least three columns for the violins; 6 inches per column reproduces the
# original 18x12 figure whenever there are three samples or fewer.
n_cols = max(len(violin_panels), n_samples)
fig, axes = plt.subplots(2, n_cols, figsize=(6 * n_cols, 12), squeeze=False)

for col, (metric, title) in enumerate(violin_panels):
    # The mitochondrial percentage is optional; hide its panel when absent.
    if metric == 'pct_counts_mt' and metric not in adata.obs.columns:
        axes[0, col].axis('off')
        continue
    sc.pl.violin(adata, [metric], groupby='sample', ax=axes[0, col], show=False)
    axes[0, col].set_title(title)

# Hide top-row panels beyond the violins (only present with > 3 samples).
for ax in axes[0, len(violin_panels):]:
    ax.axis('off')

for col, sample in enumerate(sample_names):
    sample_mask = adata.obs['sample'] == sample
    ax = axes[1, col]
    ax.scatter(adata.obs.loc[sample_mask, 'total_counts'],
               adata.obs.loc[sample_mask, 'n_genes_by_counts'],
               s=0.5, alpha=0.6)
    ax.set_xlabel('Total counts')
    ax.set_ylabel('Number of genes')
    ax.set_title(f'{sample}: Counts vs Genes')

# Hide bottom-row panels that have no sample assigned.
for ax in axes[1, n_samples:]:
    ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
## Filter genes

n_genes_before = adata.n_vars
print(f"Genes before filtering: {n_genes_before}")

# Drop genes detected in fewer than 10 cells (modifies adata in place).
sc.pp.filter_genes(adata, min_cells=10)
print(f"Genes after min_cells filter: {adata.n_vars}")

# No highly-variable-gene selection is run: every remaining gene is flagged,
# so a later subset on `highly_variable` keeps all genes.
# NOTE(review): presumably intentional for a targeted panel - confirm.
adata.var['highly_variable'] = True

n_flagged = adata.var['highly_variable'].sum()
print(f"Highly variable genes: {n_flagged}")

In [None]:
## Filter cells

print(f"Cells before filtering: {adata.n_obs}")

# Fixed lower bounds.
# NOTE(review): 200 genes / 500 counts may be strict for a targeted panel -
# check how many cells survive.
min_genes = 200
min_counts = 500

# Upper bounds: 98th percentile of each metric, taken over all cells
# *before* the min_genes filter below is applied.
max_genes = adata.obs['n_genes_by_counts'].quantile(0.98)
max_counts = adata.obs['total_counts'].quantile(0.98)

# Step 1: remove cells with too few detected genes (in place).
sc.pp.filter_cells(adata, min_genes=min_genes)
print(f"Cells after min_genes filter: {adata.n_obs}")

# Step 2: build a boolean keep-mask from the remaining thresholds.
obs = adata.obs
genes_ok = obs['n_genes_by_counts'] <= max_genes
# Series.between is inclusive on both ends by default (>= min and <= max).
counts_ok = obs['total_counts'].between(min_counts, max_counts)
cell_filter = genes_ok & counts_ok

# The mitochondrial cut-off applies only when the metric is available.
if 'pct_counts_mt' in obs.columns:
    cell_filter = cell_filter & (obs['pct_counts_mt'] <= 20)

# Materialise the subset as an independent object rather than a view.
adata = adata[cell_filter, :].copy()
print(f"Cells after quality filters: {adata.n_obs}")

In [None]:
# Keep a snapshot of the object as it is before normalisation in adata.raw.
adata.raw = adata

# Library-size normalisation to 10,000 counts per cell, then log(1 + x).
# NOTE: both steps modify adata in place, so re-running this cell would
# transform the data a second time.
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

# Subset to the genes flagged `highly_variable` and scale that copy only;
# values are clipped at 10 after scaling. `adata` itself stays unscaled.
hvg_mask = adata.var['highly_variable']
adata_hvg = adata[:, hvg_mask].copy()
sc.pp.scale(adata_hvg, max_value=10)

print("Normalization and scaling completed")
print(f"Data shape after HVG selection: {adata_hvg.shape}")

In [None]:
# Spatial maps of QC metrics: one column per sample, one row per metric.
# The grid is sized from the data instead of being fixed at 2x2, so the cell
# no longer raises IndexError with more than two samples and leaves no empty
# column when there is only one.
sample_names = adata.obs['sample'].unique()
n_samples = len(sample_names)

# (obs column, colormap, title suffix) for each row of the figure.
metric_rows = [
    ('total_counts', 'viridis', 'Total counts'),
    ('n_genes_by_counts', 'plasma', 'Gene counts'),
]

# squeeze=False keeps `axes` two-dimensional even for a single sample.
# 8 inches per column reproduces the original 16x12 figure for two samples.
fig, axes = plt.subplots(len(metric_rows), n_samples,
                         figsize=(8 * n_samples, 12), squeeze=False)

for col, sample in enumerate(sample_names):
    sample_data = adata[adata.obs['sample'] == sample]
    sample_coords = sample_data.obsm['spatial']

    for row, (metric, cmap, label) in enumerate(metric_rows):
        ax = axes[row, col]
        points = ax.scatter(sample_coords[:, 0], sample_coords[:, 1],
                            c=sample_data.obs[metric],
                            s=0.8, cmap=cmap, alpha=0.7)
        ax.set_title(f'{sample}: {label}')
        ax.axis('equal')
        # Colour bar so the mapped values can actually be read off the plot.
        fig.colorbar(points, ax=ax, label=metric)

plt.tight_layout()
plt.show()

In [None]:
## save data

# Written relative to the current working directory.
output_path = 'xenium_kidney_preprocessed.h5ad'
adata.write(output_path)

print("Preprocessed data saved")
print(f"Full data: {adata.shape}")
saved_samples = adata.obs['sample'].unique()
print(f"Samples: {saved_samples}")