# Extract all the features

Similar to notebook2 but we package everything inside a for loop to exctract features for all tissues based on all pretrained models.

Further analysis of embeddings produced by all models can be found in the code to reproduce figures from the paper.

In [1]:
# TO REMOVE when notebook is stable

%load_ext autoreload
%autoreload 2

### Common Imports

In [2]:
import torch
import tarfile
import os
from anndata import read_h5ad

# tissue_purifier import
import tissue_purifier as tp

### Download the example dataset

In [3]:
## replace with your own path
data_destination_folder = "../../TissueMosaic_Figures/TissueMosaic_data/testis_anndata_corrected_doubletmode_annotated/"

# Make a list of all the h5ad files in the data_destination_folder
fname_list = []
for f in os.listdir(data_destination_folder):
    if f.endswith('.h5ad'):
        fname_list.append(f)
print(fname_list)

['wt3_dm.h5ad', 'wt1_dm.h5ad', 'diabetes2_dm.h5ad', 'wt2_dm.h5ad', 'diabetes1_dm.h5ad', 'diabetes3_dm.h5ad']


### copy the data into a new folder for further processing

In [4]:
import shutil

new_data_destination_folder = "testis_anndata_featurized/"
shutil.copytree(data_destination_folder, new_data_destination_folder)

'testis_anndata_featurized/'

### Download all the checkpoint files

In [14]:

all_ckpts = ["testis_dino.pt", "testis_barlow.pt", "testis_simclr.pt", "testis_vae.pt"]
all_models = ["dino", "barlow", "simclr", "vae"] 
## replace with your own path
ckpt_path = os.path.abspath("../../model_checkpoints/testis/")
all_ckpts_dest = []
for ckpt in all_ckpts:
    ckpt_dest= os.path.join(ckpt_path, ckpt)
    all_ckpts_dest.append(ckpt_dest)
    
# print(all_ckpts_dest)

### Extract features with all the models (Barlow, Simclr, Dino, Vae) and ncv_k for multiple k

In [15]:
from tissue_purifier.data import AnndataFolderDM
from tissue_purifier.models.ssl_models import *
# now you have access to Barlow, SImclr, Dino, Vae

n_patches_max = 1000 # cover each tissue with this many overlapping patches

for ckpt_path, model_name in zip(all_ckpts_dest, all_models):
    
    print("----------")
    print("Model --->", model_name, ckpt_path)
    print("----------")
    
    # load the model from checkpoint
    if model_name == "barlow":
        model = tp.models.ssl_models.Barlow.load_from_checkpoint(checkpoint_path=ckpt_path, strict=False)
    elif model_name == "simclr":
        model = tp.models.ssl_models.Simclr.load_from_checkpoint(checkpoint_path=ckpt_path, strict=False)
    elif model_name == "dino":
        model = tp.models.ssl_models.Dino.load_from_checkpoint(checkpoint_path=ckpt_path, strict=False)
    elif model_name == "vae":
        model = tp.models.ssl_models.Vae.load_from_checkpoint(checkpoint_path=ckpt_path, strict=False)
    else:
        raise Exception("Model name not recongnized {}".format(model_name))
        
    # create the datamodule associated with the pretrained model
    dm = tp.data.AnndataFolderDM(**model._hparams) 
    
    # put the model on GPU if available
    if torch.cuda.is_available():
        model = model.cuda()
    
    # process all the anndata with the model-datamodule pair
    for fname in fname_list:
        
        # open adata and convert to sparse_image
        adata_path = os.path.join(new_data_destination_folder, fname)
        adata = read_h5ad(adata_path)
        sp_img = dm.anndata_to_sparseimage(adata)
                
        # put sparse image on GPU if available
        if torch.cuda.is_available():
            sp_img = sp_img.cuda()
            
        # compute nvc with different k
        if model_name == "barlow":
            for k in 10, 20, 50, 100, 200, 500:
                sp_img.compute_ncv(feature_name="ncv_k{}".format(k), k=k)
        
        # compute the patch-feature (internally it crops sparse image and feed crops to pretrained model)
        sp_img.compute_patch_features(
            feature_name=model_name, 
            datamodule=dm, 
            model=model, 
            batch_size=64,
            strategy='random',
            remove_overlap=False,
            n_patches_max=n_patches_max,
            overwrite=True)
        
        # transfer the patch-level annotation to the spot-level
        sp_img.transfer_patch_to_spot(keys_to_transfer=model_name, overwrite=True)
        
        # write the new adata to disk
        new_adata = sp_img.to_anndata()
        new_adata.write(filename=adata_path) # overwrite the file but with extra annotations 
        
        # free memory by erasing the sparse_image
        del sp_img
    
    # after loop over anndata erase the model and free memory
    del model
        
        

----------
Model ---> dino /home/skambha6/chenlab/tissue_purifier/model_checkpoints/testis/testis_dino.pt
----------




number of elements ---> 29178
mean and median spacing 15.90507495709278, 15.497339152935078
The dense shape of the image is -> torch.Size([9, 1178, 1175])
number of elements ---> 27840
mean and median spacing 16.009033744023068, 15.768961335552781
The dense shape of the image is -> torch.Size([9, 1160, 1143])
number of elements ---> 29607
mean and median spacing 15.810478612949094, 15.727658385209352
The dense shape of the image is -> torch.Size([9, 1180, 855])
number of elements ---> 30132
mean and median spacing 16.353857684013548, 15.931447916615909
The dense shape of the image is -> torch.Size([9, 1180, 1180])
number of elements ---> 34868
mean and median spacing 15.821949004591055, 15.638433550603624
The dense shape of the image is -> torch.Size([9, 1180, 1181])
number of elements ---> 34868
mean and median spacing 15.821949004591055, 15.638433550603624
The dense shape of the image is -> torch.Size([9, 1180, 1181])
----------
Model ---> barlow /home/skambha6/chenlab/tissue_purifie



number of elements ---> 29178
mean and median spacing 15.90507495709278, 15.497339152935078
The dense shape of the image is -> torch.Size([9, 1178, 1175])
Key ncv_k10 already present in spot dictionary. Set overwrite to True to overwrite
Key ncv_k20 already present in spot dictionary. Set overwrite to True to overwrite
Key ncv_k50 already present in spot dictionary. Set overwrite to True to overwrite
Key ncv_k100 already present in spot dictionary. Set overwrite to True to overwrite
Key ncv_k200 already present in spot dictionary. Set overwrite to True to overwrite
Key ncv_k500 already present in spot dictionary. Set overwrite to True to overwrite
number of elements ---> 27840
mean and median spacing 16.009033744023068, 15.768961335552781
The dense shape of the image is -> torch.Size([9, 1160, 1143])
Key ncv_k10 already present in spot dictionary. Set overwrite to True to overwrite
Key ncv_k20 already present in spot dictionary. Set overwrite to True to overwrite
Key ncv_k50 already pr



number of elements ---> 29178
mean and median spacing 15.90507495709278, 15.497339152935078
The dense shape of the image is -> torch.Size([9, 1178, 1175])
number of elements ---> 27840
mean and median spacing 16.009033744023068, 15.768961335552781
The dense shape of the image is -> torch.Size([9, 1160, 1143])
number of elements ---> 29607
mean and median spacing 15.810478612949094, 15.727658385209352
The dense shape of the image is -> torch.Size([9, 1180, 855])
number of elements ---> 30132
mean and median spacing 16.353857684013548, 15.931447916615909
The dense shape of the image is -> torch.Size([9, 1180, 1180])
number of elements ---> 34868
mean and median spacing 15.821949004591055, 15.638433550603624
The dense shape of the image is -> torch.Size([9, 1180, 1181])
number of elements ---> 34868
mean and median spacing 15.821949004591055, 15.638433550603624
The dense shape of the image is -> torch.Size([9, 1180, 1181])
----------
Model ---> vae /home/skambha6/chenlab/tissue_purifier/m

  net = ResNetDecoder(DecoderBlock, [3, 4, 6, 3], latent_dim=1,
  resize_conv1x1(self.inplanes, planes * block.expansion, scale),
  return nn.Sequential(Interpolate(scale_factor=scale), conv1x1(in_planes, out_planes))
  return nn.Sequential(Interpolate(scale_factor=scale), conv1x1(in_planes, out_planes))
  layers.append(block(self.inplanes, planes, scale, upsample))
  self.conv1 = resize_conv3x3(inplanes, inplanes)
  return conv3x3(in_planes, out_planes)


number of elements ---> 29178
mean and median spacing 15.90507495709278, 15.497339152935078
The dense shape of the image is -> torch.Size([9, 1178, 1175])
number of elements ---> 27840
mean and median spacing 16.009033744023068, 15.768961335552781
The dense shape of the image is -> torch.Size([9, 1160, 1143])
number of elements ---> 29607
mean and median spacing 15.810478612949094, 15.727658385209352
The dense shape of the image is -> torch.Size([9, 1180, 855])
number of elements ---> 30132
mean and median spacing 16.353857684013548, 15.931447916615909
The dense shape of the image is -> torch.Size([9, 1180, 1180])
number of elements ---> 34868
mean and median spacing 15.821949004591055, 15.638433550603624
The dense shape of the image is -> torch.Size([9, 1180, 1181])
number of elements ---> 34868
mean and median spacing 15.821949004591055, 15.638433550603624
The dense shape of the image is -> torch.Size([9, 1180, 1181])


### check that the anndata object have the new annotations stored in .obsm

In [16]:
for fname in fname_list:
    anndata = read_h5ad(os.path.join(new_data_destination_folder, fname))
    print("----")
    print(fname)
    print(anndata)

----
wt3_dm.h5ad
AnnData object with n_obs × n_vars = 29178 × 24450
    obs: 'x', 'y', 'UMI', 'cell_type'
    uns: 'status'
    obsm: 'barlow', 'cell_type_proportions', 'dino', 'ncv_k10', 'ncv_k100', 'ncv_k20', 'ncv_k200', 'ncv_k50', 'ncv_k500', 'simclr', 'vae'
----
wt1_dm.h5ad
AnnData object with n_obs × n_vars = 27840 × 23514
    obs: 'x', 'y', 'UMI', 'cell_type'
    uns: 'status'
    obsm: 'barlow', 'cell_type_proportions', 'dino', 'ncv_k10', 'ncv_k100', 'ncv_k20', 'ncv_k200', 'ncv_k50', 'ncv_k500', 'simclr', 'vae'
----
diabetes2_dm.h5ad
AnnData object with n_obs × n_vars = 29607 × 23741
    obs: 'x', 'y', 'UMI', 'cell_type'
    uns: 'status'
    obsm: 'barlow', 'cell_type_proportions', 'dino', 'ncv_k10', 'ncv_k100', 'ncv_k20', 'ncv_k200', 'ncv_k50', 'ncv_k500', 'simclr', 'vae'
----
wt2_dm.h5ad
AnnData object with n_obs × n_vars = 30132 × 24263
    obs: 'x', 'y', 'UMI', 'cell_type'
    uns: 'status'
    obsm: 'barlow', 'cell_type_proportions', 'dino', 'ncv_k10', 'ncv_k100', 'ncv_k20