In [1]:
import sys
import os

import pandas as pd
import scanpy as sc
import numpy as np
import warnings

import torch
import torch.nn.functional as F
import torch.nn as nn

from torch.utils.data import DataLoader, TensorDataset
from torch.nn import DataParallel
import matplotlib.pyplot as plt
from matplotlib.pyplot import rc_context
import anndata
import seaborn as sns
import matplotlib.font_manager
from matplotlib import rcParams

font_list = []
fpaths = matplotlib.font_manager.findSystemFonts()
for i in fpaths:
    try:
        f = matplotlib.font_manager.get_font(i)
        font_list.append(f.family_name)
    except RuntimeError:
        pass

font_list = set(font_list)
plot_font = 'Helvetica' if 'Helvetica' in font_list else 'FreeSans'
rcParams['font.family'] = plot_font
rcParams.update({'font.size': 10})
rcParams.update({'figure.dpi': 300})
rcParams.update({'figure.figsize': (3,3)})
rcParams.update({'savefig.dpi': 500})
warnings.filterwarnings('ignore')

%load_ext autoreload
%autoreload 2
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [None]:
adata_raw = sc.read_h5ad('/hpc/mydata/siyu.he/Siyu_projects/squidward_study/perturb/data/norman/perturb_processed.h5ad')
#adata = adata[adata.obs['condition']=='ctrl'].copy()#anndata.AnnData(count)
adata_raw.obs['Group']=adata_raw.obs['condition']
adata_raw = adata_raw[adata_raw.obs['Group'].isin(['ctrl',
                                                       'PTPN12+ZBTB25',
                                                       'ZBTB25+ctrl',
                                                       'PTPN12+ctrl'])].copy()

In [None]:
adata = adata_raw.copy()
condi_df = adata.obs['Group'].map({'ctrl':0,
                                     'PTPN12+ctrl':1,
                                     'ZBTB25+ctrl':2,
                                     'PTPN12+ZBTB25':3
                             })
del adata.obs
adata.obs['Group'] = condi_df

In [None]:
sc.tl.pca(adata, svd_solver='arpack')
sc.pp.neighbors(adata, n_neighbors=30, n_pcs=50)
sc.tl.umap(adata,min_dist=0.005)

In [None]:
import matplotlib.colors as mcolors
hex_colors = ['#D4CB92', '#395C6B', '#80A4ED','#BCD3F2']
rgb_colors = [mcolors.to_rgb(color) for color in hex_colors]
cmap = mcolors.LinearSegmentedColormap.from_list("custom_cmap", rgb_colors)
fig,axs=plt.subplots(1,1,figsize=(3,3),dpi=800)
for i in range(4):
    g=axs.scatter(adata.obsm['X_umap'][adata.obs['Group']==i,0],
                adata.obsm['X_umap'][adata.obs['Group']==i,1],
                #c=adata.obs['condition'],
                #cmap=cmap,
                  c=hex_colors[i],
                s=1,
                alpha=0.8
               )
#plt.colorbar(g)
plt.axis('off')

In [None]:
sc.pp.highly_variable_genes(adata, n_top_genes=1000)

for i in ['ALAS2',
          'HBA1',
          'HBA2',
          'HIST1H1C',
          'GYPB',
          'SLC25A37',
          #'DYNLRB1',
          'LGALS1',
          #'TUFM',
          #'PKM'
         ]:
    adata.var.highly_variable[list(adata.var['gene_name']).index(i)]=True
adata = adata[:, adata.var.highly_variable].copy()

In [None]:
adata

In [7]:
adata.X = adata.X.toarray()

In [8]:
adata_ctl = adata[adata.obs['Group'].isin([0,1,2])].copy()
adata_val = adata[adata.obs['Group'].isin([3])].copy()

In [9]:
adata_val

AnnData object with n_obs × n_vars = 257 × 3267
    obs: 'Group'
    var: 'gene_name', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'non_dropout_gene_idx', 'non_zeros_gene_idx', 'rank_genes_groups_cov_all', 'top_non_dropout_de_20', 'top_non_zero_de_20', 'pca', 'neighbors', 'umap', 'hvg'
    obsm: 'X_pca', 'X_umap'
    varm: 'PCs'
    layers: 'counts'
    obsp: 'distances', 'connectivities'

In [10]:
adata_ctl

AnnData object with n_obs × n_vars = 7890 × 3267
    obs: 'Group'
    var: 'gene_name', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'non_dropout_gene_idx', 'non_zeros_gene_idx', 'rank_genes_groups_cov_all', 'top_non_dropout_de_20', 'top_non_zero_de_20', 'pca', 'neighbors', 'umap', 'hvg'
    obsm: 'X_pca', 'X_umap'
    varm: 'PCs'
    layers: 'counts'
    obsp: 'distances', 'connectivities'

In [11]:
adata_ctl.write('datasets/gears_train_data_2000.h5ad')
adata_val.write('datasets/gears_test_data_2000.h5ad')

In [2]:
sc.read_h5ad('datasets/gears_train_data.h5ad')

AnnData object with n_obs × n_vars = 7890 × 203
    obs: 'Group'
    var: 'gene_name', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'hvg', 'neighbors', 'non_dropout_gene_idx', 'non_zeros_gene_idx', 'pca', 'rank_genes_groups_cov_all', 'top_non_dropout_de_20', 'top_non_zero_de_20', 'umap'
    obsm: 'X_pca', 'X_umap'
    varm: 'PCs'
    layers: 'counts'
    obsp: 'connectivities', 'distances'

In [14]:
adata_ctl.to_df()

gene_id,ENSG00000239945,ENSG00000223764,ENSG00000187634,ENSG00000187642,ENSG00000188290,ENSG00000187608,ENSG00000273443,ENSG00000237330,ENSG00000223823,ENSG00000205231,...,ENSG00000198899,ENSG00000198938,ENSG00000198840,ENSG00000212907,ENSG00000198886,ENSG00000198786,ENSG00000198695,ENSG00000198727,ENSG00000273554,ENSG00000278633
cell_barcode,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
AAACCTGCACGAAGCA-1,0.0,0.0,0.0,0.0,0.000000,0.000000,0.0,0.0,0.0,0.0,...,4.380425,5.377383,3.458815,2.162728,4.866675,2.671797,1.072852,3.723229,0.0,0.0
AAACCTGGTATAATGG-1,0.0,0.0,0.0,0.0,0.000000,0.000000,0.0,0.0,0.0,0.0,...,3.887486,5.227680,3.857355,0.000000,4.570333,2.109076,0.895322,3.826288,0.0,0.0
AAACCTGTCCGATATG-1,0.0,0.0,0.0,0.0,0.907872,0.553610,0.0,0.0,0.0,0.0,...,4.505152,5.084357,3.656242,1.820776,4.560787,2.362153,0.553610,3.764745,0.0,0.0
AAACGGGCAATGGACG-1,0.0,0.0,0.0,0.0,0.000000,0.000000,0.0,0.0,0.0,0.0,...,3.979031,4.877613,3.687023,1.478156,4.125962,1.804557,2.976381,3.059135,0.0,0.0
AAAGATGAGATGAGAG-1,0.0,0.0,0.0,0.0,0.714880,0.714880,0.0,0.0,0.0,0.0,...,3.892254,5.213247,3.597147,1.643986,4.461336,1.127485,1.827722,3.442796,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTTGCGCAGGCAGGTT-8,0.0,0.0,0.0,0.0,0.000000,0.700755,0.0,0.0,0.0,0.0,...,3.821486,5.090238,3.776008,0.000000,4.421693,2.210709,1.397684,3.985166,0.0,0.0
TTTGGTTGTTAAGATG-8,0.0,0.0,0.0,0.0,0.000000,0.000000,0.0,0.0,0.0,0.0,...,4.188715,5.134850,3.393106,1.330532,4.510700,2.014076,1.330532,3.664461,0.0,0.0
TTTGGTTTCCTGCCAT-8,0.0,0.0,0.0,0.0,0.634950,0.000000,0.0,0.0,0.0,0.0,...,4.040881,5.158772,3.667091,1.020238,4.486452,1.692793,2.091298,3.494365,0.0,0.0
TTTGTCAGTAGCGTGA-8,0.0,0.0,0.0,0.0,0.000000,0.000000,0.0,0.0,0.0,0.0,...,3.883829,5.095631,3.573060,0.000000,4.282413,2.013872,1.673010,3.632032,0.0,0.0
