In [None]:
import sys
sys.path.append("/kaggle/input")

from starfysh import (utils, plot_utils, post_analysis)
from starfysh import starfysh as sf_model

import scanpy as sc
import os
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import matplotlib.font_manager
from matplotlib import rcParams
import gc

import seaborn as sns
# Plot configuration: select seaborn's 'white' style preset.
sns.set_style('white')

In [None]:
cheap_load_path = '/kaggle/input/breastcancerdataset/BreastCancerDataset'
def cheap_load(path, n_spots=500, mt_thld=100, n_top_genes=1000):
    """Load a Visium sample and return matched raw / log-normalised subsets.

    Only the first ``n_spots`` spots are kept to limit memory use. Spots and
    genes are QC-filtered, the filtered matrix is total-count normalised and
    log1p-transformed, and highly variable genes are flagged on it.

    Parameters
    ----------
    path : str
        Directory passed to ``sc.read_visium``.
    n_spots : int or None, default 500
        Number of leading spots to keep; ``None`` keeps every spot.
    mt_thld : float, default 100
        Spots are kept only if ``pct_counts_mt`` is strictly below this value.
    n_top_genes : int, default 1000
        Number of highly variable genes requested from scanpy.

    Returns
    -------
    raw : AnnData
        Un-normalised counts restricted to the spots/genes that survived
        filtering, carrying the QC ``obs`` columns and the
        ``highly_variable`` flag in ``var``.
    adata : AnnData
        The filtered, normalised and log1p-transformed counterpart of ``raw``.
    """
    adata = sc.read_visium(path)
    adata.var_names_make_unique()

    # Copy the slice so the later in-place steps never operate on a view of
    # the full dataset (and the full dataset can be garbage-collected).
    adata = adata[:n_spots].copy()
    raw = adata.copy()

    # `raw` keeps its own `uns`; the working object does not need it.
    del adata.uns
    gc.collect()

    names = adata.var_names
    adata.var['mt'] = np.logical_or(names.str.startswith('MT-'), names.str.startswith('mt-'))
    # NOTE(review): human ribosomal genes are usually named RPS*/RPL*, so the
    # 'RP-' prefix may match nothing -- confirm this is the intended pattern.
    adata.var['rb'] = (names.str.startswith('RP-') | names.str.startswith('rp-'))

    sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], inplace=True)
    mask_cell = adata.obs['pct_counts_mt'] < mt_thld
    mask_gene = np.logical_and(~adata.var['mt'], ~adata.var['rb'])

    # Explicit copy: normalisation below modifies the object in place, which
    # would otherwise trigger an implicit view-to-copy conversion.
    adata = adata[mask_cell, mask_gene].copy()
    sc.pp.normalize_total(adata, inplace=True)
    sc.pp.log1p(adata)
    sc.pp.highly_variable_genes(adata, flavor='seurat', n_top_genes=n_top_genes, inplace=True)

    # Bring the raw counts onto the same spots/genes as the normalised data.
    raw = raw[adata.obs_names, adata.var_names].copy()
    raw.var['highly_variable'] = adata.var['highly_variable']
    # Copy so the two returned objects do not share one obs DataFrame.
    raw.obs = adata.obs.copy()
    return raw, adata

In [None]:
# Load the sample: `adata` receives the first (raw) object returned by
# cheap_load and `adata_normed` the second (normalised, log1p) one.
adata, adata_normed = cheap_load(cheap_load_path)
# Drop the unstructured annotations (`uns`) carried by the raw object,
# then ask the garbage collector to reclaim memory.
del adata.uns
gc.collect()

In [None]:
# Per-spot nuclei morphology features; rows are labelled by the 'Spot' column.
df = pd.read_csv("/kaggle/input/breastcancerdataset/nuclei_morphology.csv", index_col='Spot')
# Keep the same leading 500 rows that cheap_load keeps from the Visium data.
df = df[:500]

# The mask below is applied to the AnnData objects by position, so the table
# and the expression data must have the same number of rows.
# NOTE(review): row order is assumed to match `adata.obs_names` -- confirm.
if not (len(df) == adata.n_obs == adata_normed.n_obs):
    raise ValueError(
        "Morphology table has %d rows but adata/adata_normed have %d/%d spots; "
        "cannot align them by position."
        % (len(df), adata.n_obs, adata_normed.n_obs)
    )

# Keep only spots with a measured mean nuclei area.
mask = (~df.Mean_filtered_nuclei_AreaShape_Area.isna()).to_numpy()

adata.uns['morph'] = df.values.astype('float64')[mask]
# Copy so downstream code gets real AnnData objects rather than views.
adata = adata[mask].copy()
adata_normed = adata_normed[mask].copy()

In [None]:
# Read the gene signature table and discard the unnamed leading column
# that the CSV carries.
gene_sig = (
    pd.read_csv("/kaggle/input/breastcancerdataset/gene_sig.csv")
    .drop("Unnamed: 0", axis=1)
)

In [None]:
# Location of the dataset and the sample folder name inside it.
data_path = '/kaggle/input/breastcancerdataset'
sample_id = 'BreastCancerDataset'

# Image preprocessing for the spots currently kept in `adata`.
img_metadata = utils.preprocess_img(
    data_path,
    sample_id,
    adata_index=adata.obs.index,
    hchannel=False,
)

# Pull the three entries used later out of the returned mapping.
img, map_info, scalefactor = (
    img_metadata[key] for key in ('img', 'map_info', 'scalefactor')
)

In [None]:
# Bundle expression data, gene signatures and image metadata for Starfysh.
visium_args = utils.VisiumArguments(
    adata,
    adata_normed,
    gene_sig,
    img_metadata,
    n_anchors=20,
    window_size=2,
    sample_id=sample_id,
)

# Replace the local AnnData objects with the ones held by `visium_args`,
# and fetch the anchor table.
adata, adata_normed = visium_args.get_adata()
anchors_df = visium_args.get_anchors()
gc.collect()

In [None]:
# Train on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# NOTE(review): num_features=38 presumably matches the number of morphology
# columns stored in adata.uns['morph'] -- confirm against the CSV.
model, loss = utils.run_starfysh(
    visium_args,
    n_repeats=1,
    num_features=38,
    epochs=200,
    poe=True,
    device=device,
)

In [None]:
# Evaluate the trained model, with the same `poe` and `device` settings
# that were used for training.
inference_outputs, generative_outputs = sf_model.model_eval(
    model,
    adata,
    visium_args,
    poe=True,
    device=device,
)

In [None]:
# Output directory and the files written into it
outdir = '/kaggle/working/'
model_file = os.path.join(outdir, 'starfysh_model.pt')
adata_file = os.path.join(outdir, 'adata.h5ad')
adata_normed_file = os.path.join(outdir, 'adata_normed.h5ad')
qc_m_file = os.path.join(outdir, 'qc_m.pt')

# Model weights
torch.save(model.state_dict(), model_file)

# Both AnnData objects
adata.write(adata_file)
adata_normed.write(adata_normed_file)

# The 'qc_m' entry of the inference outputs
torch.save(inference_outputs['qc_m'], qc_m_file)