In [None]:
import pandas as pd
import numpy as np
import scanpy as sc
from os.path import join as pj
from sklearn.neighbors import BallTree
from torch import nn
import random
import torch as th
import os
from tqdm import tqdm
import time
from scsampler import scsampler
import matplotlib.pyplot as plt
plt.rc('font', size=15)


# define ball tree sampling (proposed)

In [None]:
def BallTreeSubsample(X, target_size, ls=10):
    """Draw a spatially balanced subsample of row indices using a ball tree.

    A ``BallTree`` is built on ``X`` and its levels are visited from the
    deepest one up to the root.  At each level the still-missing quota is
    split evenly (integer division) among that level's nodes, and every node
    contributes that many randomly chosen points that were not selected yet.
    What the integer division leaves over is carried to the next coarser
    level; the root finally absorbs any remainder, so a positive
    ``target_size`` yields ``min(target_size, len(X))`` indices.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Points handed to ``BallTree``.
    target_size : int
        Number of indices to draw.
    ls : int, default 10
        ``leaf_size`` of the tree; also used to derive the deepest level.

    Returns
    -------
    list
        Unique row indices into ``X`` (entries of the tree's index array).
        Randomness comes from the global ``random`` module.
    """
    tree = BallTree(X, leaf_size=ls)
    arrays = tree.get_arrays()
    order = arrays[1]   # permutation of row indices; nodes own contiguous slices
    nodes = arrays[2]   # per-node records; fields 0 and 1 are the slice bounds

    # Index of the deepest level.  max(..., 1) keeps log2 finite when
    # len(X) < ls (the original raised on int(-inf) in that case).
    layer = int(np.log2(max(len(X) // ls, 1)))
    # Level k occupies nodes[2**k - 1 : 2**(k + 1) - 1] (heap-style layout,
    # as the original slicing assumed).
    bounds = [2 ** k - 1 for k in range(layer + 2)]

    target = []
    selected = set()  # mirrors `target` for O(1) membership tests
    # Walk from the deepest level up to AND INCLUDING the root (level 0).
    # Skipping the root left odd remainders unfilled, e.g. 624 of 625.
    for level in range(layer, -1, -1):
        remaining = target_size - len(target)
        if remaining <= 0:
            break
        first, last = bounds[level], bounds[level + 1]
        per_node = remaining // (last - first)
        if per_node == 0:
            # Quota too small to give every node of this level one point;
            # defer it to a coarser level.
            continue
        for node in nodes[first:last]:
            start_id = node[0]
            end_id = node[1]
            available = [idx for idx in order[start_id:end_id] if idx not in selected]
            random.shuffle(available)
            picked = available[:per_node]
            target.extend(picked)
            selected.update(picked)
    return target


# read data to subsample

In [None]:
# Read the joint embeddings of all 42 subsets.  The per-file frames are
# collected in a list and concatenated once: calling pd.concat inside the loop
# re-copies everything read so far on every file (quadratic cost).
frames = []
batch = []   # subset id of every row, in row order
for i in range(42):
    path = '../../../MIRACLE-reproducibility/result/dcm_hcm/offline/default/predict/subset_'+str(i)+'/z/joint/'
    n_rows = 0   # rows contributed by this subset (kept distinct from the global `num`)
    print(i)
    for j in tqdm(sorted(os.listdir(path))):
        d = pd.read_csv(os.path.join(path, j), header=None, index_col=None)
        frames.append(d)
        n_rows += d.shape[0]
    batch.extend([i] * n_rows)

# ignore_index=True yields the same running 0..N-1 row index the original
# built by offsetting each file's index with the running total.
data = pd.concat(frames, ignore_index=True)
total = len(data)

# .loc slicing is label-based and inclusive: this keeps columns 0..31.
adata = sc.AnnData(data.loc[:, :31])
# Reduce to a 10,000-cell reference set (in place, fixed seed).
sc.pp.subsample(adata, n_obs=10000, random_state=42)
adata

In [None]:
# Persist the subsampled reference matrix, keyed by cell name, for later cells.
reference_frame = pd.DataFrame(data=adata.X, index=adata.obs_names)
reference_frame.to_csv('./reference.csv')

In [None]:
# Optional shortcut (disabled): rebuild `adata` from the saved ./reference.csv
# instead of re-reading the raw embeddings.
# data = pd.read_csv('./reference.csv', index_col=0)
# adata = sc.AnnData(data)
# adata

# experiment setting

In [None]:
# Subsample sizes to benchmark: 10000 halved one to six times (floor division).
num = [10000 >> halvings for halvings in range(1, 7)]
# Number of repetitions per subsample size.
repeat_num = 50

# random sampling

In [None]:
task = 'random'
# makedirs(exist_ok=True) replaces the racy exists()/mkdir() pair.
os.makedirs('./%s_result'%task, exist_ok=True)
for n in tqdm(num):
    elapsed = []      # seconds spent in each repetition
    sample_ids = []   # obs_names drawn in each repetition
    for i in range(repeat_num):
        # perf_counter is the monotonic, high-resolution clock meant for intervals.
        start = time.perf_counter()
        adata_sub = sc.pp.subsample(adata, n_obs=n, random_state=i, copy=True)
        end = time.perf_counter()
        elapsed.append(end - start)
        sample_ids.append(adata_sub.obs_names)
    # Write once per size; the original rewrote both files after every
    # repetition, producing the same final files with 50x the I/O.
    pd.DataFrame(sample_ids).to_csv('./%s_result/sample_id_%d.csv'%(task, n))
    pd.DataFrame(elapsed).to_csv('./%s_result/sample_time_%d.csv'%(task, n))

# ball tree sampling

In [None]:
task = 'ball-tree'
# makedirs(exist_ok=True) replaces the racy exists()/mkdir() pair.
os.makedirs('./%s_result'%task, exist_ok=True)
x = adata.X
for n in tqdm(num):
    elapsed = []      # seconds spent in each repetition (tree build included)
    sample_ids = []   # row positions into `x` drawn in each repetition
    for i in range(repeat_num):
        # BallTreeSubsample draws from the global `random` module; seed it per
        # repetition (outside the timed region) so runs are reproducible.
        random.seed(i)
        # perf_counter is the monotonic, high-resolution clock meant for intervals.
        start = time.perf_counter()
        id_sample = BallTreeSubsample(x, n, ls=10)
        end = time.perf_counter()
        elapsed.append(end - start)
        sample_ids.append(id_sample)
    # Write once per size; the original rewrote both files after every
    # repetition, producing the same final files with 50x the I/O.
    pd.DataFrame(sample_ids).to_csv('./%s_result/sample_id_%d.csv'%(task, n))
    pd.DataFrame(elapsed).to_csv('./%s_result/sample_time_%d.csv'%(task, n))

# scSampler

In [None]:
# Expose the data matrix under obsm['X_emb'], the key the scsampler call reads via `obsm=`.
adata.obsm['X_emb'] = adata.X

In [None]:
task = 'scsampler'
# makedirs(exist_ok=True) replaces the racy exists()/mkdir() pair.
os.makedirs('./%s_result'%task, exist_ok=True)
for n in tqdm(num):
    elapsed = []      # seconds spent in each repetition
    sample_ids = []   # obs_names drawn in each repetition
    for i in range(repeat_num):
        # perf_counter is the monotonic, high-resolution clock meant for intervals.
        start = time.perf_counter()
        # Seed with the repetition index, as the random-sampling cell does.
        # random.randint(1, 100) was unreproducible and, with 50 draws from
        # 100 values, routinely reused a seed across repetitions.
        adata_sub = scsampler(adata, n_obs=n, obsm = 'X_emb', copy = True, random_state=i)
        end = time.perf_counter()
        elapsed.append(end - start)
        sample_ids.append(adata_sub.obs_names)
    # Write once per size; the original rewrote both files after every
    # repetition, producing the same final files with 50x the I/O.
    pd.DataFrame(sample_ids).to_csv('./%s_result/sample_id_%d.csv'%(task, n))
    pd.DataFrame(elapsed).to_csv('./%s_result/sample_time_%d.csv'%(task, n))

# seurat sketch

In [None]:
task = 'seurat sketch'
# makedirs also creates the missing parent 'compare_subsampling' (os.mkdir
# raised FileNotFoundError there) and tolerates an existing directory.
# NOTE(review): the metric and timing cells read sketch results from
# './sketch_result/', not from this directory - confirm which location the
# sketch notebook actually writes to.
os.makedirs('compare_subsampling/%s_result'%task, exist_ok=True)

In [None]:
# Run subsample_sketch.ipynb separately to generate the Seurat sketch samples before continuing.

# compute metrics

In [None]:
# Rebuild the reference AnnData from the CSV written earlier (first row is the
# header, first column holds the cell names).
data = pd.read_csv('reference.csv', index_col=0, header=0)
adata = sc.AnnData(data)
# Force string cell names so they match the ids read back from the result CSVs.
obs_names_as_str = np.array(adata.obs_names).astype(str)
adata.obs_names = obs_names_as_str

In [None]:
# https://github.com/ZongxianLee/MMD_Loss.Pytorch (with modifications)

class MMD_loss(nn.Module):
    """Multi-kernel maximum mean discrepancy (MMD) between two sample sets.

    The kernel is the sum of ``kernel_num`` Gaussian kernels whose bandwidths
    form a geometric series with ratio ``kernel_mul``.  The series is centred
    on the mean pairwise squared distance of the pooled samples, unless the
    ``fix_sigma`` attribute is set to a truthy value.
    """

    def __init__(self, kernel_mul=2.0, kernel_num=5):
        super(MMD_loss, self).__init__()
        self.kernel_num = kernel_num
        self.kernel_mul = kernel_mul
        # No fixed bandwidth: it is estimated from the data on every call.
        self.fix_sigma = None

    def guassian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
        """Return the summed Gaussian kernel matrix of the pooled samples.

        Parameters
        ----------
        source, target : 2-D tensors with the same number of columns.
        kernel_mul : ratio between consecutive bandwidths.
        kernel_num : number of kernels that are summed.
        fix_sigma : base bandwidth; estimated from the data when falsy.

        Returns
        -------
        Tensor of shape (n, n) with n = len(source) + len(target); rows and
        columns are ordered source first, then target.
        """
        n_samples = int(source.size()[0]) + int(target.size()[0])
        total = th.cat([source, target], dim=0)

        # Pairwise squared Euclidean distances.  cdist computes them pair by
        # pair, whereas broadcasting (total0 - total1) ** 2 materialised an
        # (n, n, d) tensor - tens of GB for 15,000 x 32 inputs.  The non-matmul
        # mode avoids the cancellation error of the ||a||^2 - 2ab + ||b||^2 form.
        L2_distance = th.cdist(
            total, total, p=2.0, compute_mode='donot_use_mm_for_euclid_dist'
        ).pow(2)

        if fix_sigma:
            bandwidth = fix_sigma
        else:
            # Mean over the n^2 - n off-diagonal pairs (the diagonal is zero).
            bandwidth = th.sum(L2_distance.detach()) / (n_samples**2 - n_samples)
        # Out-of-place division: never mutates a caller-supplied fix_sigma.
        bandwidth = bandwidth / (kernel_mul ** (kernel_num // 2))
        bandwidth_list = [bandwidth * (kernel_mul**i) for i in range(kernel_num)]
        kernel_val = [th.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
        return sum(kernel_val)

    def forward(self, source, target):
        """Return mean(K_xx) + mean(K_yy) - mean(K_xy) - mean(K_yx) as a scalar tensor."""
        batch_size = int(source.size()[0])
        kernels = self.guassian_kernel(source, target, kernel_mul=self.kernel_mul, kernel_num=self.kernel_num, fix_sigma=self.fix_sigma)
        XX = kernels[:batch_size, :batch_size]
        YY = kernels[batch_size:, batch_size:]
        XY = kernels[:batch_size, batch_size:]
        YX = kernels[batch_size:, :batch_size]
        loss = th.mean(XX) + th.mean(YY) - th.mean(XY) - th.mean(YX)
        return loss

In [None]:
# Shared MMD evaluator with the default kernel settings (kernel_mul=2.0, kernel_num=5).
mmd = MMD_loss()

In [None]:
task = 'random'
# `result_dir` / `sample_ids` replace the original names `dir` and `id`,
# which shadowed Python builtins for the rest of the session.
result_dir = './%s_result/'%task
mmd_list = {}
# The full reference set is identical for every comparison: convert it once.
reference = th.from_numpy(adata.X)
# `num` must be the integer size list from the experiment-setting cell; the
# plotting cells rebind it to tick-label strings.
for i in num:
    mmd_list[i] = []
    # One row per repetition, holding the obs_names drawn in that run.
    sample_ids = pd.read_csv(result_dir + 'sample_id_' + str(i)+'.csv', index_col=0)
    for j in tqdm(range(len(sample_ids))):
        d = sample_ids.iloc[j].values.astype('str')
        mmd_list[i].append(mmd(reference, th.from_numpy(adata[d].X)).tolist())
pd.DataFrame(mmd_list).to_csv('./%s_mmd.csv'%task)

In [None]:
task = 'scsampler'
# `result_dir` / `sample_ids` replace the original names `dir` and `id`,
# which shadowed Python builtins for the rest of the session.
result_dir = './%s_result/'%task
mmd_list = {}
# The full reference set is identical for every comparison: convert it once.
reference = th.from_numpy(adata.X)
# `num` must be the integer size list from the experiment-setting cell; the
# plotting cells rebind it to tick-label strings.
for i in num:
    mmd_list[i] = []
    # One row per repetition, holding the obs_names drawn in that run.
    sample_ids = pd.read_csv(result_dir + 'sample_id_' + str(i)+'.csv', index_col=0)
    for j in tqdm(range(len(sample_ids))):
        d = sample_ids.iloc[j].values.astype('str')
        mmd_list[i].append(mmd(reference, th.from_numpy(adata[d].X)).tolist())
pd.DataFrame(mmd_list).to_csv('./%s_mmd.csv'%task)

In [None]:
task = 'sketch'
# `result_dir` / `sample_ids` replace the original names `dir` and `id`,
# which shadowed Python builtins for the rest of the session.
result_dir = './%s_result/'%task
mmd_list = {}
# The full reference set is identical for every comparison: convert it once.
reference = th.from_numpy(adata.X)
# `num` must be the integer size list from the experiment-setting cell; the
# plotting cells rebind it to tick-label strings.
for i in num:
    mmd_list[i] = []
    # The sketch files are transposed on load so that each row is one repetition.
    sample_ids = pd.read_csv(result_dir + 'sample_id_' + str(i)+'.csv', index_col=0, header=0).T
    for j in tqdm(range(len(sample_ids))):
        # Used as positional row indices into adata.X.
        # NOTE(review): assumes the stored ids are 0-based positions - confirm
        # against the notebook that writes them.
        d = sample_ids.iloc[j].values
        mmd_list[i].append(mmd(reference, th.from_numpy(adata.X[d])).tolist())
pd.DataFrame(mmd_list).to_csv('./%s_mmd.csv'%task)

In [None]:
# Violin plot of the per-repetition MMD: ball-tree vs. random sampling.
fig = plt.figure(figsize=(5,4), dpi=150)
# Tick labels = fraction of the 10,000-cell reference kept at each size
# (5000, 2500, 1250, 625, 312, 156).  The previous labels skipped 1/16 and
# ended at 1/128, mislabelling the last three groups.
# NOTE: this rebinds the global `num` (integer sizes) to label strings, which
# the later plotting cells read; re-run the experiment-setting cell before
# re-running any sampling or MMD cell.
num = ['1/2', '1/4', '1/8', '1/16', '1/32', '1/64']
all_mmd = {}
# NOTE(review): no cell in view writes 'ball-tree_mmd.csv' - confirm how it is produced.
for i in ['random', 'ball-tree']:
    all_mmd[i] = pd.read_csv(i+'_mmd.csv', index_col=0)
# Interleave the two methods: ball-tree on even positions, random on odd ones.
h1 = plt.violinplot(all_mmd['ball-tree'] ,positions=list(range(0, 12, 2)), showmeans=True)
h2 = plt.violinplot(all_mmd['random'] ,positions=list(range(1, 12, 2)), showmeans=True)
for pc in h1['bodies']:
    pc.set_facecolor('#e41a1c')
    pc.set_edgecolor('#e41a1c')
for i in ['cbars', 'cmins', 'cmaxes', 'cmeans']:
    h1[i].set_edgecolor('#e41a1c')
for pc in h2['bodies']:
    pc.set_facecolor('#FF7B23')
    pc.set_edgecolor('#FF7B23')
for i in ['cbars', 'cmins', 'cmaxes', 'cmeans']:
    h2[i].set_edgecolor('#FF7B23')
plt.legend(labels=['ball-tree', 'random'],loc='upper left', handles=[h1['bodies'][0],h2['bodies'][0]])

# One tick per size, centred between the paired violins.
plt.xticks([0.5, 2.5, 4.5, 6.5, 8.5, 10.5], num)
plt.tight_layout()
# NOTE(review): another cell in this notebook saves to these same
# MMD_bts_random.* paths; whichever runs last overwrites the other.
plt.savefig('../../../MIRACLE-reproducibility/subsample_fig/MMD_bts_random.pdf')
plt.savefig('../../../MIRACLE-reproducibility/subsample_fig/MMD_bts_random.png')
plt.savefig('../../../MIRACLE-reproducibility/subsample_fig/MMD_bts_random.svg')

In [None]:
# Violin plot of the per-repetition MMD: ball-tree vs. random sampling
# (orangered / dodgerblue colour scheme).
fig = plt.figure(figsize=(5,4), dpi=150)
# Tick labels = fraction of the 10,000-cell reference kept at each size
# (5000, 2500, 1250, 625, 312, 156).  The previous labels skipped 1/16 and
# ended at 1/128, mislabelling the last three groups.
# NOTE: this rebinds the global `num` (integer sizes) to label strings, which
# the later plotting cells read; re-run the experiment-setting cell before
# re-running any sampling or MMD cell.
num = ['1/2', '1/4', '1/8', '1/16', '1/32', '1/64']
all_mmd = {}
# NOTE(review): no cell in view writes 'ball-tree_mmd.csv' - confirm how it is produced.
for i in ['random', 'ball-tree']:
    all_mmd[i] = pd.read_csv(i+'_mmd.csv', index_col=0)
# Interleave the two methods: ball-tree on even positions, random on odd ones.
h1 = plt.violinplot(all_mmd['ball-tree'] ,positions=list(range(0, 12, 2)), showmeans=True)
h2 = plt.violinplot(all_mmd['random'] ,positions=list(range(1, 12, 2)), showmeans=True)
for pc in h1['bodies']:
    pc.set_facecolor('orangered')
    pc.set_edgecolor('orangered')
for i in ['cbars', 'cmins', 'cmaxes', 'cmeans']:
    h1[i].set_edgecolor('orangered')
for pc in h2['bodies']:
    pc.set_facecolor('dodgerblue')
    pc.set_edgecolor('dodgerblue')
for i in ['cbars', 'cmins', 'cmaxes', 'cmeans']:
    h2[i].set_edgecolor('dodgerblue')
plt.legend(labels=['ball-tree', 'random'],loc='upper left', handles=[h1['cbars'],h2['cbars']])

# One tick per size, centred between the paired violins.
plt.xticks([0.5, 2.5, 4.5, 6.5, 8.5, 10.5], num)
plt.tight_layout()
# NOTE(review): another cell in this notebook saves to these same
# MMD_bts_random.* paths; whichever runs last overwrites the other.
plt.savefig('../../../MIRACLE-reproducibility/subsample_fig/MMD_bts_random.pdf')
plt.savefig('../../../MIRACLE-reproducibility/subsample_fig/MMD_bts_random.png')
plt.savefig('../../../MIRACLE-reproducibility/subsample_fig/MMD_bts_random.svg')

In [None]:
# Mean MMD per subsample size for all four methods (log-scaled y axis).
fig = plt.figure(figsize=(5,4),dpi=150)
# Tick labels are defined here instead of read from the global `num`, which
# only held label strings if a violin-plot cell had been run first (and those
# strings skipped 1/16).  Fractions refer to the 10,000-cell reference for the
# sizes 5000, 2500, 1250, 625, 312, 156.
fraction_labels = ['1/2', '1/4', '1/8', '1/16', '1/32', '1/64']
all_mmd = {}
marker = {
    'random':'--',
    'sketch':'-^',
    'scsampler':'-o',
    'ball-tree':'-*'
}
color =  {
    'random':'#FF7B23',
    'sketch':'#377eb8',
    'scsampler':'#570F69',
    'ball-tree':'#C21316'
}

for i in ['random', 'sketch', 'scsampler', 'ball-tree']:
    all_mmd[i] = pd.read_csv(i+'_mmd.csv', index_col=0)
for i in ['random', 'sketch', 'scsampler', 'ball-tree']:
    # Column-wise mean = average over the repetitions of each size.
    plt.plot(list(range(6)), all_mmd[i].mean(), marker[i], color=color[i])

plt.yscale('log')
plt.xticks(list(range(6)), fraction_labels)
# Legend entries follow the plotting order above.
plt.legend(['random', 'sketch', 'scSampler', 'ball-tree'])
plt.tight_layout()
plt.savefig('../../../MIRACLE-reproducibility/subsample_fig/MMD_all.pdf')
plt.savefig('../../../MIRACLE-reproducibility/subsample_fig/MMD_all.png')
plt.savefig('../../../MIRACLE-reproducibility/subsample_fig/MMD_all.svg')

In [None]:
# Mean sampling time per subsample size for all four methods.
n = [5000, 2500, 1250, 625, 312, 156]
# Fraction of the 10,000-cell reference that each size in `n` represents.
# Defined locally rather than read from the global `num`, which only held
# label strings (with 1/16 missing) if a violin-plot cell had been run first.
fraction_labels = ['1/2', '1/4', '1/8', '1/16', '1/32', '1/64']
fig = plt.figure(figsize=(5, 4),dpi=150)
marker = {
    'random':'--',
    'sketch':'-^',
    'scsampler':'-o',
    'ball-tree':'-*'
}

color =  {
    'random':'#FF7B23',
    'sketch':'#377eb8',
    'scsampler':'#570F69',
    'ball-tree':'#C21316'
}

all_time = {}

for task in ['random', 'sketch', 'scsampler', 'ball-tree']:
    # `result_dir` / `timing` replace the original names `dir` and `time`.
    # Rebinding `time` to a DataFrame hid the `time` module, so re-running a
    # sampling cell afterwards failed on its timer calls.
    result_dir = './%s_result/'%task
    time_list = {}
    for i in n:
        timing = pd.read_csv(result_dir + 'sample_time_' + str(i)+'.csv', index_col=0, header=0).T
        if timing.shape==(1, 50):
            # Layout written by the sampling cells above: 50 timings in one row.
            time_list[i] = timing.values.mean()
        else:
            # Any other layout: average the column labelled 1.
            # NOTE(review): presumably the sketch notebook's format - confirm.
            time_list[i] = timing[1].values.mean()
    pd.DataFrame(time_list, index=[0]).to_csv('./%s_time.csv'%task)

for i in ['random', 'sketch', 'scsampler', 'ball-tree']:
    all_time[i] = pd.read_csv(i+'_time.csv', index_col=0)
for i in ['random', 'sketch', 'scsampler', 'ball-tree']:
    plt.plot(list(range(6)), all_time[i].values[0], marker[i], color=color[i])
plt.xticks(list(range(6)), fraction_labels)
# Legend entries follow the plotting order above.
plt.legend(['random', 'sketch', 'scSampler', 'ball-tree'])
plt.tight_layout()
plt.savefig('../../../MIRACLE-reproducibility/subsample_fig/time.pdf')
plt.savefig('../../../MIRACLE-reproducibility/subsample_fig/time.png')
plt.savefig('../../../MIRACLE-reproducibility/subsample_fig/time.svg')