## Denoising Paired-target Count Matrices with PeakVI

In [None]:
import os
import tempfile
import scvi
import scanpy as sc
import torch
import anndata
from scipy.io import mmread
import pandas as pd
import numpy as np
import pickle

In [None]:
# Report whether PyTorch can see a CUDA device; as the last expression of the
# cell, the boolean is shown as the cell output.
torch.cuda.is_available()

In [None]:
# Select PyTorch's "high" float32 matmul precision mode.
torch.set_float32_matmul_precision("high")
# Temporary directory handle; peakviii() saves trained models under
# save_dir.name.
save_dir = tempfile.TemporaryDirectory()

In [None]:
# KNN membership matrix, read from Matrix Market format and converted to CSR.
knnmems = mmread('data/out/pre_cg_output_KNNmembership_1k.mtx').tocsr()
# Values of the 'cell' column of the obs table written with the same prefix.
# NOTE(review): Denoise() selects knnmems columns through these names, so they
# are presumably listed in knnmems column order — confirm against the producer.
ref_obs = pd.read_csv('data/out/pre_cg_output_obs.csv')['cell'].values

In [None]:
def peakviii(ab,indir,nhidden=20,nlatent=5,min_cells=0,min_counts=0):
  """Train a PeakVI model for one target and embed its cells.

  Reads ``<ab>_count.mtx``, ``<ab>_obs.csv`` and ``<ab>_var.csv`` from
  *indir*, builds an AnnData object, trains ``scvi.model.PEAKVI`` with the
  obs column ``rep`` as a categorical covariate, and stores the latent
  representation and a UMAP on the AnnData object.

  The process working directory is not changed: input files are addressed
  by joining *indir* with the file name, so relative *indir* values keep
  working when the function is called repeatedly in a loop.

  Args:
    ab: Target name, used as the file-name prefix.
    indir: Directory containing the three input files.
    nhidden: ``n_hidden`` passed to PEAKVI.
    nlatent: ``n_latent`` passed to PEAKVI.
    min_cells: ``min_cells`` passed to ``sc.pp.filter_genes``.
    min_counts: ``min_counts`` passed to ``sc.pp.filter_cells``.

  Returns:
    Tuple ``(adata, model)``; ``adata.obsm`` holds ``"X_peakvi"`` and the
    UMAP coordinates written by ``sc.tl.umap``.
  """
  prefix = os.path.join(indir, ab)

  # The matrix is transposed after loading so its rows line up with the obs
  # table and its columns with the var table, as AnnData requires.
  matin = mmread(prefix+'_count.mtx').tocsr().transpose()

  obsin = pd.read_csv(prefix+'_obs.csv',index_col=0)
  varin = pd.read_csv(prefix+'_var.csv',index_col=0)

  adata = anndata.AnnData(X = matin, obs = obsin, var = varin)
  sc.pp.filter_genes(adata, min_cells=min_cells)
  sc.pp.filter_cells(adata, min_counts=min_counts)

  scvi.model.PEAKVI.setup_anndata(adata,categorical_covariate_keys=["rep"])
  model = scvi.model.PEAKVI(adata,n_hidden=nhidden,n_latent=nlatent)
  # Only request the GPU when one is actually present, so the function still
  # runs (slowly) on a CPU-only machine instead of raising.
  if torch.cuda.is_available():
    model.to_device("cuda:0")
  model.train(max_epochs=1000)

  # Round-trip the trained model through the notebook's temporary directory
  # (module-level `save_dir`) and continue with the reloaded instance.
  model_dir = os.path.join(save_dir.name, "peakvi_"+ab)
  model.save(model_dir, overwrite=True)
  model = scvi.model.PEAKVI.load(model_dir, adata=adata)

  # Compute the latent embedding once and reuse it for neighbours and UMAP.
  adata.obsm["X_peakvi"] = model.get_latent_representation()
  sc.pp.neighbors(adata, use_rep='X_peakvi',n_neighbors=30)
  sc.tl.umap(adata, min_dist=0.3)

  return adata,model


def saveResults(ab,adata,model,outdir):
  """Write the trained model, embeddings and metadata for one target.

  Creates *outdir* (including missing parent directories) and writes, each
  prefixed with *ab*:

  * ``_model.pkl``  - the pickled model object,
  * ``_latent.csv`` - ``adata.obsm["X_peakvi"]`` via ``np.savetxt``,
  * ``_umap.csv``   - ``adata.obsm["X_umap"]`` via ``np.savetxt``,
  * ``_obs.csv`` / ``_var.csv`` - ``adata.obs`` / ``adata.var``.

  The process working directory is not changed.

  Args:
    ab: Target name, used as the file-name prefix.
    adata: Object exposing ``obsm``, ``obs`` and ``var``.
    model: Picklable model object.
    outdir: Destination directory.
  """
  # makedirs handles nested destinations such as 'out/peakvi_out/' whose
  # parent does not exist yet, and is a no-op when the directory is present.
  os.makedirs(outdir, exist_ok=True)
  prefix = os.path.join(outdir, ab)

  # The model is pickled because fun() reloads it with pickle.load.
  with open(prefix+'_model.pkl', 'wb') as f:
    pickle.dump(model, f)

  # np.savetxt defaults are kept (space-delimited) so existing readers of
  # these files continue to work despite the .csv suffix.
  np.savetxt(prefix+"_latent.csv", adata.obsm["X_peakvi"])
  np.savetxt(prefix+"_umap.csv", adata.obsm["X_umap"])

  adata.obs.to_csv(prefix+"_obs.csv")
  adata.var.to_csv(prefix+"_var.csv")


def Denoise(ab,adata,model,outdir,usereg=None,knn=None,ref_cells=None):
  """Average PeakVI accessibility estimates over KNN groups and write them.

  Calls ``model.get_accessibility_estimates(adata)`` (expected to return a
  cells x regions DataFrame), optionally restricts the regions, keeps the
  reference cells that belong to *ab* and are present in the estimates, and
  writes the per-group mean of every region to
  ``<outdir>/<ab>_denoised_avg.csv`` (regions as rows, no header).

  Args:
    ab: Target name. ``'RNAPII'`` uses every reference cell; any other
      value keeps reference cells whose name contains *ab*.
    adata: Passed through to ``model.get_accessibility_estimates``.
    model: Object providing ``get_accessibility_estimates``.
    outdir: Output directory; created if missing.
    usereg: ``None`` to keep all regions. A path is read as a tab-separated
      table with a ``region`` column. Any other non-None value falls back to
      ``<outdir>/<ab>_usevar.csv``.
    knn: Group x cell membership matrix (sparse or dense). Defaults to the
      notebook-level ``knnmems``.
    ref_cells: Cell names for the columns of *knn*. Defaults to the
      notebook-level ``ref_obs``.
      NOTE(review): assumes these names are in *knn* column order — confirm.
  """
  # Fall back to the notebook-level objects only when nothing was supplied.
  if knn is None:
    knn = knnmems
  if ref_cells is None:
    ref_cells = ref_obs
  ref_cells = np.asarray(ref_cells)

  denoised = model.get_accessibility_estimates(adata)
  if usereg is not None:
    regs = denoised.columns.values
    # Honour an explicitly supplied path instead of silently rebuilding it.
    if isinstance(usereg, (str, os.PathLike)):
      regfile = usereg
    else:
      regfile = os.path.join(outdir, ab+'_usevar.csv')
    keep = pd.read_csv(regfile,delimiter='\t')['region'].values
    # Region names in the model use '_' separators when the first one has an
    # underscore past position 0; convert '-' separated names to match.
    if len(regs) > 0 and str(regs[0]).find('_') > 0:
      keep = np.char.replace(keep.astype(str),"-", "_")
    keep = np.intersect1d(np.unique(keep),regs)
    denoised = denoised[keep]
  denoised = denoised.T  # regions x cells

  # Build the selection mask over ALL reference cells so it has one entry per
  # column of `knn`; a mask over the filtered subset would pick wrong columns.
  if ab == 'RNAPII':
    is_target = np.ones(ref_cells.shape[0], dtype=bool)
  else:
    is_target = np.array([ab in name for name in ref_cells], dtype=bool)
  mask = is_target & np.isin(ref_cells, denoised.columns.values)
  use = ref_cells[mask]

  weights = knn[:, mask]
  denoised = denoised.loc[:, use]

  # Row-normalise memberships so each group contributes a mean, not a sum.
  # Groups without any selected cell yield NaN, as plain division would.
  weights = weights.toarray() if hasattr(weights, "toarray") else np.asarray(weights)
  weights = weights.astype(float)
  row_sums = weights.sum(axis=1, keepdims=True)
  with np.errstate(divide='ignore', invalid='ignore'):
    weights = weights/row_sums

  avgmat = pd.DataFrame(denoised.values @ weights.T, index=denoised.index)
  os.makedirs(outdir, exist_ok=True)
  avgmat.to_csv(os.path.join(outdir, ab+'_denoised_avg.csv'), index=True, header=False)


def fun(ab,indir,outdir,train=True,nhidden=50,nlatent=10,denoise=True,modeldir=None,min_cells=0,min_counts=0,usereg=None):
  """Train or reload the PeakVI model for one target, optionally denoise.

  With ``train=True`` the model is trained by ``peakviii`` and written to
  *outdir* by ``saveResults``. With ``train=False`` the pickled model
  ``<modeldir>/<ab>_model.pkl`` is loaded and its ``adata`` attribute is
  used; this also happens when ``denoise=False``, so the function always has
  an AnnData object to return.

  Args:
    ab: Target name / file-name prefix.
    indir: Input directory forwarded to ``peakviii``.
    outdir: Output directory for saved results and denoised averages.
    train: Train a new model instead of loading a pickled one.
    nhidden, nlatent, min_cells, min_counts: Forwarded to ``peakviii``.
    denoise: Run ``Denoise`` after the model is available.
    modeldir: Directory holding ``<ab>_model.pkl``; defaults to *outdir*.
    usereg: Forwarded to ``Denoise``.

  Returns:
    The AnnData object of the trained or reloaded model.
  """
  if train:
    adata,model = peakviii(ab,indir,nhidden=nhidden,nlatent=nlatent,min_cells=min_cells,min_counts=min_counts)
    saveResults(ab,adata,model,outdir)
  else:
    if modeldir is None:
      modeldir = outdir
    # NOTE(security): pickle.load can execute arbitrary code; only load model
    # files that this notebook (or another trusted source) produced.
    with open(os.path.join(modeldir, ab+'_model.pkl'), 'rb') as f:
      model = pickle.load(f)
    adata = model.adata

  if denoise:
    Denoise(ab,adata,model,outdir,usereg)

  return adata


In [None]:
# Directory holding the per-target count matrices and obs/var tables.
indir = 'data/mtx_filtered/'
# Directory that receives the trained models, embeddings and metadata.
outdir = 'out/peakvi_out/'
# Targets to process; each name is passed to fun() as `ab`.
Abs = ['Brg1', 'H3K27ac', 'H3K4me3', 'MyoD', 'Myog']

# Train one model per target (denoising is deferred) and keep the object
# returned by fun() under the target's name.
adata = {}
for ab in Abs:
  adata[ab] = fun(ab,indir,outdir,denoise=False)

In [None]:
# NOTE(review): the training cell writes models to 'out/peakvi_out/' while
# this cell reads them from 'data/out/peakvi_out/' — confirm the two resolve
# to the same directory.
modeldir = 'data/out/peakvi_out/'
outdir = 'data/out/denoised/'

# Reload each pickled model (train=False) and write its denoised averages,
# restricting regions to the per-target '<ab>_usevar.csv' file in outdir.
# NOTE(review): `adata` is rebound on every iteration, so after the loop it
# holds only the last target's object rather than the per-target dict.
for ab in Abs:
  adata = fun(ab,indir,outdir,modeldir = modeldir,train = False,usereg = outdir+ab+'_usevar.csv')