# Tutorial: scDCC on the 10X PBMC CITE-seq data

Note that this tutorial was run on a 2019 MacBook Pro, with all code executed on the CPU. The results reported in the manuscript were obtained on an NVIDIA P100 GPU, so numbers may differ slightly from those shown here.

In [1]:
from time import time
import math, os

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import Parameter
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

from scDCC import scDCC
import numpy as np
import collections
from sklearn import metrics
import h5py
import scanpy.api as sc
from preprocess import read_dataset, normalize
from utils import cluster_acc, generate_random_pair_from_proteins, generate_random_pair_from_CD_markers

Read in the RNA count matrix (`X`), the normalized protein data (`ADT_X`), and the normalized counts of the CD proteins (`adt_CD_normalized_counts.txt`, which contains CD4, CD8, CD27 and DR).

In [2]:
DATA_FILE = "CITE_PBMC_counts_top2000.h5"
CD_MARKER_FILE = "adt_CD_normalized_counts.txt"

# Open read-only and use a context manager so the handle is closed even if a key
# lookup fails. Older h5py versions defaulted to append mode ('a'), which could
# create or modify the input file; 'r' makes the intent explicit.
with h5py.File(DATA_FILE, "r") as data_mat:
    x = np.array(data_mat['X'])                    # RNA count matrix (see markdown above)
    y = np.array(data_mat['Y'])                    # reference labels; later passed to model.fit as `y`
    protein_markers = np.array(data_mat['ADT_X'])  # normalized protein (ADT) data

# Normalized counts of CD4, CD8, CD27 and DR (per the markdown above).
# NOTE(review): presumably one row per cell, matching `x` -- confirm.
CD_markers = np.loadtxt(CD_MARKER_FILE, delimiter=',')

# Every matrix indexes the same cells; constraint indices generated later rely on this.
assert x.shape[0] == y.shape[0] == protein_markers.shape[0], "cell counts disagree across inputs"

# preprocessing scRNA-seq read counts matrix
adata = sc.AnnData(x)
adata.obs['Group'] = y

adata = read_dataset(adata,
                     transpose=False,
                     test_split=False,
                     copy=True)

adata = normalize(adata,
                  size_factors=True,
                  normalize_input=True,
                  logtrans_input=True)

input_size = adata.n_vars

print(adata.X.shape)
print(y.shape)
print(protein_markers.shape)

# Per-gene standard deviation after normalization; its median is reported as a sanity check.
x_sd = adata.X.std(0)
x_sd_median = np.median(x_sd)
print("median of gene sd: %.5f" % x_sd_median)

### Autoencoder: Successfully preprocessed 2000 genes and 3762 cells.
(3762, 2000)
(3762,)
(3762, 49)
median of gene sd: 0.99987


Generate pairwise constraints from the full protein panel and from the CD markers. The details are described in the "Constraint Construction" section of the manuscript:

1. We generated 20,000 (`n_pairwise_1`) constraints based on all protein levels.
   - We calculated Euclidean distances for all possible pairs of cells based on the normalized protein data.
   - We chose the 0.5th and 95th percentiles of all pairwise distances as the must-link and cannot-link cutoffs, respectively.
   - We then repeatedly sampled two cells. If the Euclidean distance between them was less than the 0.5th percentile of all pairwise distances, we defined a must-link constraint. If it was greater than the 95th percentile, we defined a cannot-link constraint.

2. To separate CD4 and CD8 T cells, we further added 5,000 (`n_pairwise_2`) constraints based on the following rules.
   - **Cannot-links between CD4 and CD8 T cells.** If one cell has a high CD4 protein level (> 70th percentile) and a low CD8 protein level (< 30th percentile), and another cell has a high CD8 protein level (> 70th percentile) and a low CD4 protein level (< 30th percentile), a cannot-link is constructed between them.
   - **Must-links within subtypes.** To further identify subtypes of CD4 and CD8 T cells (CD8+CD27-, CD8+CD27+, CD4+CD27+, CD4+CD27-DR+, CD4+CD27-DR-), we generate must-links within each subtype.
   - **Example: CD8+CD27+.** Two randomly selected cells form a must-link only if both have high CD8 protein levels (> 85th percentile) and high CD27 protein levels (> 85th percentile).
   - **Example: CD8+CD27-.** The two cells must both have high CD8 protein levels (> 85th percentile) but low CD27 protein levels (< 30th percentile).
   - **Other subtypes.** For CD4+CD27+, CD4+CD27-DR+ and CD4+CD27-DR- cells, we applied similar rules to construct must-links.

In [3]:
# Constraints from the full protein panel: ML=0.005 / CL=0.95 are the distance
# percentiles (0.5th / 95th) used as must-link / cannot-link cutoffs.
# NOTE(review): pair sampling appears to be random and no seed is set here, so the
# counts below may differ between runs -- consider seeding for reproducibility.
n_pairwise_1 = 20000
ml_ind1_1, ml_ind2_1, cl_ind1_1, cl_ind2_1 = generate_random_pair_from_proteins(protein_markers, n_pairwise_1, ML=0.005, CL=0.95)

print("Must link pairs: %d" % ml_ind1_1.shape[0])
print("Cannot link pairs: %d" % cl_ind1_1.shape[0])

# Constraints from CD markers: low1/high1 (30th/70th percentiles) drive the CD4-vs-CD8
# cannot-links; low2/high2 (30th/85th percentiles) drive the subtype must-links.
n_pairwise_2 = 5000
ml_ind1_2, ml_ind2_2, cl_ind1_2, cl_ind2_2 = generate_random_pair_from_CD_markers(CD_markers, n_pairwise_2, low1=0.3, high1=0.7, low2=0.3, high2=0.85)

print("Must link pairs: %d" % ml_ind1_2.shape[0])
print("Cannot link pairs: %d" % cl_ind1_2.shape[0])

# Merge both constraint sets; index arrays stay aligned so ml_ind1[i] pairs with ml_ind2[i].
ml_ind1 = np.append(ml_ind1_1, ml_ind1_2)
ml_ind2 = np.append(ml_ind2_1, ml_ind2_2)
cl_ind1 = np.append(cl_ind1_1, cl_ind1_2)
cl_ind2 = np.append(cl_ind2_1, cl_ind2_2)

assert ml_ind1.shape == ml_ind2.shape, "must-link index arrays are not paired"
assert cl_ind1.shape == cl_ind2.shape, "cannot-link index arrays are not paired"

Must link paris: 1714
Cannot link paris: 18286
Must link paris: 428
Cannot link paris: 4572


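Construct the model. In the full pipeline, the ZINB autoencoder is pretrained for 300 epochs; here we load previously pretrained weights if they are available and only pretrain when the weight file is missing.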

In [4]:
AE_WEIGHTS = "CITE_PBMC_AE_weights.pth.tar"

sd = 2.5     # passed to scDCC as `sigma`
gamma = 1.0  # passed to scDCC as `gamma`

model = scDCC(input_dim=adata.n_vars, z_dim=32, n_clusters=12,
              encodeLayer=[256, 64], decodeLayer=[64, 256], sigma=sd, gamma=gamma,
              ml_weight=1., cl_weight=1.).cpu()

if os.path.exists(AE_WEIGHTS):
    print("==> loading checkpoint '{}'".format(AE_WEIGHTS))
    # map_location="cpu": the model lives on the CPU here, and a checkpoint saved from
    # a GPU run would otherwise fail to load on a machine without CUDA.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(AE_WEIGHTS, map_location="cpu")
    print(model.load_state_dict(checkpoint['ae_state_dict']))
else:
    # No pretrained weights available: pretrain the ZINB autoencoder for 300 epochs.
    # NOTE(review): presumably pretrain_autoencoder saves weights to `ae_weights` -- confirm.
    model.pretrain_autoencoder(x=adata.X, X_raw=adata.raw.X, size_factor=adata.obs.size_factors,
                               batch_size=256, epochs=300, ae_weights=AE_WEIGHTS)

==> loading checkpoint 'CITE_PBMC_AE_weights.pth.tar'


IncompatibleKeys(missing_keys=[], unexpected_keys=[])

Clustering with constraints.

In [5]:
SAVE_DIR = "results"
# exist_ok avoids the check-then-create race and makes the cell safe to re-run.
os.makedirs(SAVE_DIR, exist_ok=True)

# Constrained clustering. `y` is passed so the fit can report ACC/NMI/ARI (see output);
# only the predicted labels are kept from the returned tuple.
y_pred, _, _, _, _ = model.fit(X=adata.X, X_raw=adata.raw.X, sf=adata.obs.size_factors, y=y, batch_size=256, num_epochs=2000,
                               ml_ind1=ml_ind1, ml_ind2=ml_ind2, cl_ind1=cl_ind1, cl_ind2=cl_ind2,
                               update_interval=1, tol=0.001, save_dir=SAVE_DIR)

Clustering stage
Initializing cluster centers with kmeans.




Initializing k-means: ACC= 0.5720, NMI= 0.6422, ARI= 0.4503
Clustering   1: ACC= 0.5720, NMI= 0.6422, ARI= 0.4503
#Epoch   1: Total: 0.4404 Clustering Loss: 0.1584 ZINB Loss: 0.2820
Pairwise Total: 23.281135215759278 ML loss 17.388620376586914 CL loss: 5.891135215759277
Clustering   2: ACC= 0.6914, NMI= 0.6605, ARI= 0.6361
#Epoch   2: Total: 0.4159 Clustering Loss: 0.1215 ZINB Loss: 0.2944
Pairwise Total: 17.739119930267336 ML loss 12.668181419372559 CL loss: 5.069119930267334
Clustering   3: ACC= 0.6933, NMI= 0.6825, ARI= 0.6690
#Epoch   3: Total: 0.4322 Clustering Loss: 0.1623 ZINB Loss: 0.2699
Pairwise Total: 15.697207927703857 ML loss 11.49691390991211 CL loss: 4.197207927703857
Clustering   4: ACC= 0.7065, NMI= 0.7043, ARI= 0.6824
#Epoch   4: Total: 0.4144 Clustering Loss: 0.1409 ZINB Loss: 0.2735
Pairwise Total: 13.983080272674561 ML loss 10.376644134521484 CL loss: 3.6030802726745605
Clustering   5: ACC= 0.7127, NMI= 0.7193, ARI= 0.6949
#Epoch   5: Total: 0.4145 Clustering Loss: