In [1]:
import torch
import numpy
import scipy
from scipy import sparse
import torch
import pandas
import pyBigWig

In [5]:
n_filters = 32
n_layers = 11
batch_size = 64

def MLLLoss(logps, true_counts):
	""" Adapted from Alex.
	"""
	# Multinomial probability = n! / (x1!...xk!) * p1^x1 * ... pk^xk
	# Log prob = log(n!) - (log(x1!) ... + log(xk!)) + x1log(p1) ... + xklog(pk)

	log_fact_sum = torch.lgamma(torch.sum(true_counts, dim=-1) + 1)
	log_prod_fact = torch.sum(torch.lgamma(true_counts + 1), dim=-1)
	log_prod_exp = torch.sum(true_counts * logps, dim=-1) 
	return -torch.mean(log_fact_sum - log_prod_fact + log_prod_exp)

def pearson_corr(arr1, arr2):
	"""The Pearson correlation between two draws of samples.

	This function is more efficient than the built-in `corrcoef` function
	because it only calculates the pairwise correlations between elements
	in arr1 and arr2 rather than the correlation between all elements in arr1
	and arr2, and does so in a vectorized manner.

	Computes the Pearson correlation in the last dimension of `arr1` and `arr2`.
	`arr1` and `arr2` must be the same shape. For example, if they are both
	A x B x L arrays, then the correlation of corresponding L-arrays will be
	computed and returned in an A x B array.

	Parameters
	----------
	arr1: numpy.ndarray, shape=(A, ..., L)
		An array of any dimensionality > 1 as long as the last dimension
		corresponds to a sample from the same distribution.

	arr2: numpy.ndarray, shape=(A, ..., L)
		An array of any dimensionality > 1 as long as the last dimension
		corresponds to a sample from the same distribution.	

	Returns
	-------
	corr : numpy.array, shape=(A, ...)
		The Pearson correlation along the last dimension.
	"""

	mean1 = numpy.mean(arr1, axis=-1, keepdims=True)
	mean2 = numpy.mean(arr2, axis=-1, keepdims=True)
	dev1, dev2 = arr1 - mean1, arr2 - mean2
	sqdev1, sqdev2 = numpy.square(dev1), numpy.square(dev2)
	numer = numpy.sum(dev1 * dev2, axis=-1)  # Covariance
	var1, var2 = numpy.sum(sqdev1, axis=-1), numpy.sum(sqdev2, axis=-1)  # Variances
	denom = numpy.sqrt(var1 * var2)
   
	# Divide numerator by denominator, but use NaN where the denominator is 0
	return numpy.divide(
		numer, denom, out=numpy.full_like(numer, numpy.nan), where=(denom != 0)
	)

class BPNet(torch.nn.Module):
	def __init__(self, n_celltypes, n_assays, n_filters=64, n_layers=4, trimming=None):
		super(BPNet, self).__init__()
		self.trimming = trimming or 2 ** n_layers
		self.n_filters = n_filters
		self.n_layers = n_layers
		self.n_celltypes = n_celltypes
		self.n_assays = n_assays

		self.iconv = torch.nn.Conv1d(4, n_filters, kernel_size=21, padding=10)
		self.rconvs = torch.nn.ModuleList([
			torch.nn.Conv1d(n_filters, n_filters, kernel_size=3, padding=2**i, dilation=2**i) for i in range(1, self.n_layers+1)
		])

		self.assay_convs = torch.nn.ModuleList([
			torch.nn.Conv1d(n_filters, n_filters, kernel_size=75) for i in range(n_assays)
		])

		self.celltype_convs = torch.nn.ModuleList([
			torch.nn.Conv1d(n_filters, n_filters, kernel_size=75) for i in range(n_celltypes)
		])

		#self.fconv = torch.nn.Conv1d(n_filters, (n_celltypes + n_assays) * n_filters, kernel_size=75)
		self.relu = torch.nn.ReLU()
		self.logsoftmax = torch.nn.LogSoftmax(dim=-1)

	def forward(self, X, celltype_idxs, assay_idxs):
		start, end = self.trimming, X.shape[2] - self.trimming

		X = self.relu(self.iconv(X))
		for i in range(self.n_layers):
			X_conv = self.relu(self.rconvs[i](X))
			X = torch.add(X, X_conv)

		X = X[:, :, start:end]
		
		X_celltype, X_assay = [], []
		for i, (celltype_idx, assay_idx) in enumerate(zip(celltype_idxs, assay_idxs)):
			Xc = self.celltype_convs[celltype_idx](X[i:i+1])
			X_celltype.append(Xc)

			Xa = self.assay_convs[assay_idx](X[i:i+1])
			X_assay.append(Xa)

		X_celltype = torch.cat(X_celltype)
		X_assay = torch.cat(X_assay)

		y_profile = torch.mul(X_celltype, X_assay)
		y_profile = torch.sum(y_profile, dim=1).squeeze()
		y_profile = self.logsoftmax(y_profile)

		# counts prediction
		#X_avg = torch.mean(X, axis=2)
		#y_counts = self.linear(X_avg) 
		return y_profile

	def predict(self, X, celltype_idxs, assay_idxs, batch_size=64):
		with torch.no_grad():
			starts = numpy.arange(0, X.shape[0], batch_size)
			ends = starts + batch_size

			y_hat = []
			for start, end in zip(starts, ends):
				y_hat_ = self(X[start:end], celltype_idxs[start:end],
					assay_idxs[start:end]).cpu().detach().numpy()
				y_hat.append(y_hat_)

			y_hat = numpy.concatenate(y_hat)
			return y_hat

	def fit_generator(self, training_data, optimizer, X_valid=None, 
		celltype_idxs_valid=None, assay_idxs_valid=None, y_valid=None, 
		max_epochs=100, batch_size=64, validation_iter=100, verbose=True):

		if X_valid is not None: 
			X_valid = X_valid.cuda()
			celltype_idxs_valid = celltype_idxs_valid.cuda()
			assay_idxs_valid = assay_idxs_valid.cuda()
		
		y_valid_ = y_valid.detach().numpy()

		if verbose:
			print("Epoch\tIteration\tTraining Time\tValidation Time\tTraining MLL\tValidation MLLL\tValidation Correlation")

		start = time.time()
		iteration = 0
		best_corr = 0

		for epoch in range(max_epochs):
			tic = time.time()

			for X, celltype_idxs, assay_idxs, y in training_data:
				X = X.cuda()
				celltype_idxs = celltype_idxs.cuda()
				assay_idxs = assay_idxs.cuda()
				y = y.cuda()

				optimizer.zero_grad()
				self.train()

				y_profile = self(X, celltype_idxs, assay_idxs)

				loss = MLLLoss(y_profile, y)
				train_loss = loss.item()
				loss.backward()

				optimizer.step()

				if verbose and iteration % validation_iter == 0:
					self.eval()

					train_time = time.time() - start
					tic = time.time()

					y_profile = self.predict(X_valid, celltype_idxs_valid, assay_idxs_valid, batch_size=batch_size)
					valid_loss = MLLLoss(y_profile, y_valid).item()

					y_profile = numpy.exp(y_profile)
					valid_corrs = numpy.mean(numpy.nan_to_num(pearson_corr(y_profile, y_valid_)))
					valid_time = time.time() - tic

					print("{}\t{}\t{:4.4}\t{:4.4}\t{:6.6}\t{:6.6}\t{:4.4}".format(
						epoch, iteration, train_time, valid_time, train_loss, valid_loss, 
						valid_corrs))
					start = time.time()

					if valid_corrs > best_corr:
						best_corr = valid_corrs

						self = self.cpu()
						torch.save(self, "/mnt/data/imputation_yangyuan/models/bpnet.{}.{}.torch".format(self.n_filters, self.n_layers))
						self = self.cuda()
				
				iteration += 1

In [3]:
model = torch.load("/mnt/data/imputation_yangyuan/models/bpnet.{}.{}.torch".format(n_filters, n_layers))
X_valid, celltype_idxs_valid, assay_idxs_valid, y_valid = torch.load('/mnt/data/imputation_yangyuan/data/tensor.pth')

In [6]:
model.cpu()
for i in range(1,max(assay_idxs_valid)+1):
    index_list = torch.where(assay_idxs_valid==i)
    X_valid_sel = X_valid[index_list]
    y_valid_sel = y_valid[index_list]
    celltype_idxs_valid_sel = celltype_idxs_valid[index_list]
    assay_idxs_valid_sel = assay_idxs_valid[index_list]
    try:
        y_profile_sel = model.predict(X_valid_sel, celltype_idxs_valid_sel, assay_idxs_valid_sel, batch_size=batch_size)
        valid_loss = MLLLoss(y_profile_sel, y_valid_sel).item()
        y_profile_sel = numpy.exp(y_profile_sel)
        y_valid_sel_ = y_valid_sel.detach().numpy()
        valid_corrs = numpy.mean(numpy.nan_to_num(pearson_corr(y_profile_sel, y_valid_sel_)))
        print(valid_loss, valid_corrs)
    except ValueError:
        print('skip')

16879.48046875 0.6765633
2178.964111328125 0.61263734
1305.018798828125 0.5102154
skip
skip
skip
1436.2794189453125 0.5518573
skip
2322.6044921875 0.5954517
skip
1464.83251953125 0.6632016
1682.96484375 0.64920574
skip
skip
2464.652587890625 0.5142322
1823.1298828125 0.5998121
1801.93896484375 0.5042928
skip
1681.7593994140625 0.43325308
2827.552978515625 0.31931686
3818.188720703125 0.46739948
skip
1586.700927734375 0.46862066
2059.775390625 0.3971115
3414.522216796875 0.49428535
skip
skip
1568.4915771484375 0.3894616
skip
skip
1661.671142578125 0.46577382
skip
skip
2507.698486328125 0.61141855
