In [None]:
import sys

import time 
import numpy
import scipy
from scipy import sparse
import torch
import pandas
import pyBigWig

from tqdm import tqdm
from torchsummary import summary



# Switch on cuDNN benchmark mode, PyTorch's documented flag that lets cuDNN
# time candidate convolution algorithms and reuse the fastest one.
torch.backends.cudnn.benchmark = True

#@torch.jit.script
def MLLLoss(logps, true_counts):
	"""Mean negative multinomial log-likelihood over the leading dimensions.

	For one example with counts x1..xk (n = x1 + ... + xk) and probabilities
	p1..pk the multinomial log-probability is

		log(n!) - sum_i log(xi!) + sum_i xi * log(pi)

	Factorials are evaluated through `lgamma(v + 1)`.

	Parameters
	----------
	logps : log-probabilities; the last dimension indexes the categories.

	true_counts : observed counts, broadcastable against `logps`.

	Returns
	-------
	loss : scalar tensor
		The negated mean of the per-example log-probabilities.
	"""

	total_counts = torch.sum(true_counts, dim=-1)

	# log(n!) for each example.
	log_n_factorial = torch.lgamma(total_counts + 1)

	# sum_i log(xi!) for each example.
	log_x_factorials = torch.sum(torch.lgamma(true_counts + 1), dim=-1)

	# sum_i xi * log(pi) for each example.
	weighted_logps = torch.sum(true_counts * logps, dim=-1)

	log_likelihood = log_n_factorial - log_x_factorials + weighted_logps
	return -torch.mean(log_likelihood)

def pearson_corr(arr1, arr2):
	"""Pearson correlation of matching vectors along the last axis.

	Only the correlation between each vector in `arr1` and the vector at the
	same position in `arr2` is computed (not the full all-pairs matrix that
	`corrcoef` produces), and the work is fully vectorized. For two
	A x B x L inputs the result is an A x B array.

	Parameters
	----------
	arr1: numpy.ndarray, shape=(A, ..., L)
		First set of samples; the last axis holds the values to correlate.

	arr2: numpy.ndarray, shape=(A, ..., L)
		Second set of samples, the same shape as `arr1`.

	Returns
	-------
	corr : numpy.array, shape=(A, ...)
		The Pearson correlation along the last axis. Entries whose
		denominator is zero (a constant vector on either side) are NaN.
	"""

	centered1 = arr1 - numpy.mean(arr1, axis=-1, keepdims=True)
	centered2 = arr2 - numpy.mean(arr2, axis=-1, keepdims=True)

	# Un-normalized covariance and variances; the shared 1/L factor cancels.
	cross = numpy.sum(centered1 * centered2, axis=-1)
	sumsq1 = numpy.sum(numpy.square(centered1), axis=-1)
	sumsq2 = numpy.sum(numpy.square(centered2), axis=-1)
	scale = numpy.sqrt(sumsq1 * sumsq2)

	# Pre-fill with NaN so positions skipped by `where` stay NaN.
	nan_filled = numpy.full_like(cross, numpy.nan)
	nonzero = scale != 0
	return numpy.divide(cross, scale, out=nan_filled, where=nonzero)

class BPNet(torch.nn.Module):
	"""Dilated convolutional profile model with per-celltype and per-assay heads.

	The network applies an input convolution, a stack of residual dilated
	convolutions, trims the edges, and then runs one celltype-specific and
	one assay-specific convolution per example. The two outputs are
	multiplied element-wise, summed over filters, and passed through a
	log-softmax over positions.

	Parameters
	----------
	n_celltypes : int
		Number of celltype-specific output convolutions.

	n_assays : int
		Number of assay-specific output convolutions.

	n_filters : int, default 64
		Number of filters in every convolution.

	n_layers : int, default 4
		Number of residual dilated convolutions; layer i uses dilation 2**i
		for i in 1..n_layers.

	trimming : int or None, default None
		Positions removed from each end before the output convolutions.
		`None` selects 2 ** n_layers. An explicit 0 disables trimming.
	"""

	def __init__(self, n_celltypes, n_assays, n_filters=64, n_layers=4, trimming=None):
		super(BPNet, self).__init__()
		# Test against None so that an explicit trimming=0 is honoured.
		self.trimming = 2 ** n_layers if trimming is None else trimming
		self.n_filters = n_filters
		self.n_layers = n_layers
		self.n_celltypes = n_celltypes
		self.n_assays = n_assays

		self.iconv = torch.nn.Conv1d(4, n_filters, kernel_size=21, padding=10)
		self.rconvs = torch.nn.ModuleList([
			torch.nn.Conv1d(n_filters, n_filters, kernel_size=3, padding=2**i, dilation=2**i) for i in range(1, self.n_layers+1)
		])

		self.assay_convs = torch.nn.ModuleList([
			torch.nn.Conv1d(n_filters, n_filters, kernel_size=75) for i in range(n_assays)
		])

		self.celltype_convs = torch.nn.ModuleList([
			torch.nn.Conv1d(n_filters, n_filters, kernel_size=75) for i in range(n_celltypes)
		])

		self.relu = torch.nn.ReLU()
		self.logsoftmax = torch.nn.LogSoftmax(dim=-1)

	def forward(self, X, celltype_idxs, assay_idxs):
		"""Compute log-probability profiles for a batch.

		Parameters
		----------
		X : torch.Tensor, shape=(batch, 4, length)
			Input to the first convolution (4 input channels).

		celltype_idxs, assay_idxs : sequence of integer-likes, one per example
			Index of the celltype / assay convolution to apply to each
			example. Single-element tensors are accepted.

		Returns
		-------
		y_profile : torch.Tensor, shape=(batch, length - 2*trimming - 74)
			Log-softmax over positions. The batch axis is always kept, also
			for a batch of one.
		"""

		start, end = self.trimming, X.shape[2] - self.trimming

		X = self.relu(self.iconv(X))
		for rconv in self.rconvs:
			X = torch.add(X, self.relu(rconv(X)))

		X = X[:, :, start:end]

		# The head differs per example, so each example is convolved alone.
		X_celltype, X_assay = [], []
		for i, (celltype_idx, assay_idx) in enumerate(zip(celltype_idxs, assay_idxs)):
			example = X[i:i+1]
			X_celltype.append(self.celltype_convs[int(celltype_idx)](example))
			X_assay.append(self.assay_convs[int(assay_idx)](example))

		y_profile = torch.mul(torch.cat(X_celltype), torch.cat(X_assay))

		# Summing over dim=1 already removes the filter axis. A bare squeeze()
		# here would also drop the batch axis of a one-example batch.
		y_profile = torch.sum(y_profile, dim=1)

		# Only the profile is predicted; there is no counts head.
		return self.logsoftmax(y_profile)

	def predict(self, X, celltype_idxs, assay_idxs, batch_size=64):
		"""Run the model in chunks without gradients and return a numpy array.

		The inputs are sliced into consecutive chunks of `batch_size`
		examples; the per-chunk outputs are moved to the CPU and concatenated
		along the first axis.
		"""

		with torch.no_grad():
			y_hat = []
			for start in range(0, X.shape[0], batch_size):
				end = start + batch_size
				y_hat_ = self(X[start:end], celltype_idxs[start:end],
					assay_idxs[start:end])
				y_hat.append(y_hat_.cpu().detach().numpy())

			return numpy.concatenate(y_hat)

	def fit_generator(self, training_data, optimizer, X_valid=None, 
		celltype_idxs_valid=None, assay_idxs_valid=None, y_valid=None, 
		max_epochs=100, batch_size=64, validation_iter=100, verbose=True,
		save_path=None):
		"""Train on batches from `training_data`, validating periodically.

		Parameters
		----------
		training_data : iterable
			Yields (X, celltype_idxs, assay_idxs, y) batches; it is iterated
			once per epoch.

		optimizer : torch optimizer over this model's parameters.

		X_valid, celltype_idxs_valid, assay_idxs_valid, y_valid : tensors or None
			Validation set. Validation is skipped when `X_valid` or `y_valid`
			is None.

		max_epochs : int
			Number of passes over `training_data`.

		batch_size : int
			Chunk size used when predicting on the validation set.

		validation_iter : int
			Validate every this many training iterations.

		verbose : bool
			As in the original implementation, validation (and therefore
			checkpointing) only runs when this is True.

		save_path : str or None
			Where the model is saved whenever the validation correlation
			improves. None keeps the historical hard-coded location.

		Notes
		-----
		Data is moved to the device that holds the model parameters, so a
		model placed on the GPU with `.cuda()` behaves as before and a CPU
		model trains on the CPU.
		"""

		device = next(self.parameters()).device
		has_validation = X_valid is not None and y_valid is not None

		if has_validation:
			X_valid = X_valid.to(device)
			celltype_idxs_valid = celltype_idxs_valid.to(device)
			assay_idxs_valid = assay_idxs_valid.to(device)

			# Loss and correlation are computed on the CPU against the
			# numpy output of `predict`.
			y_valid = y_valid.detach().cpu()
			y_valid_ = y_valid.numpy()

		if save_path is None:
			save_path = "/mnt/data/imputation_yangyuan/models/bpnet.{}.{}.torch".format(self.n_filters, self.n_layers)

		if verbose:
			print("Epoch\tIteration\tTraining Time\tValidation Time\tTraining MLL\tValidation MLLL\tValidation Correlation")

		start = time.time()
		iteration = 0
		best_corr = 0

		for epoch in range(max_epochs):
			for X, celltype_idxs, assay_idxs, y in training_data:
				X = X.to(device)
				celltype_idxs = celltype_idxs.to(device)
				assay_idxs = assay_idxs.to(device)
				y = y.to(device)

				optimizer.zero_grad()
				self.train()

				y_profile = self(X, celltype_idxs, assay_idxs)

				loss = MLLLoss(y_profile, y)
				train_loss = loss.item()
				loss.backward()

				optimizer.step()

				if verbose and has_validation and iteration % validation_iter == 0:
					self.eval()

					train_time = time.time() - start
					tic = time.time()

					y_profile = self.predict(X_valid, celltype_idxs_valid, assay_idxs_valid, batch_size=batch_size)
					valid_loss = MLLLoss(torch.from_numpy(y_profile), y_valid).item()

					# Correlation is measured on probabilities, not log-probabilities;
					# undefined (NaN) correlations count as 0 in the mean.
					y_profile = numpy.exp(y_profile)
					valid_corrs = numpy.mean(numpy.nan_to_num(pearson_corr(y_profile, y_valid_)))
					valid_time = time.time() - tic

					print("{}\t{}\t{:4.4}\t{:4.4}\t{:6.6}\t{:6.6}\t{:4.4}".format(
						epoch, iteration, train_time, valid_time, train_loss, valid_loss, 
						valid_corrs))
					start = time.time()

					if valid_corrs > best_corr:
						best_corr = valid_corrs

						# Save from the CPU, then return to the training device.
						self.cpu()
						torch.save(self, save_path)
						self.to(device)

				iteration += 1



class PeakGenerator(torch.utils.data.Dataset):
	"""Dataset that samples random windows centred near peaks.

	Every access draws a random track, a random peak of that track, jitters
	the peak midpoint, and returns the sequence window together with the
	trimmed signal for that track.

	Parameters
	----------
	sequence : dict, chrom -> array indexed by position on the first axis
		The window slice is transposed before being returned.

	signal : dict, chrom -> dict, track -> array indexed by position
		Tracks are (celltype, assay) pairs. Track names are taken from the
		first chromosome in `chroms`; presumably every chromosome carries the
		same tracks.

	assays, celltypes : list
		Lookup lists; the returned indices are positions in these lists.

	peakfiles : dict, track -> pandas.DataFrame with 'chrom', 'start', 'end'
		The frames are not modified; filtered copies are stored.

	trimming : int
		Positions cut from each end of the signal window.

	window : int
		Width of the sequence window.

	chroms : list
		Chromosomes to keep.

	reverse_complement : bool, default True
		Whether to flip both axes of X (and reverse y) with probability 1/2.

	random_state : int or None
		Seed for all sampling done by this dataset.
	"""

	def __init__(self, sequence, signal, assays, celltypes, peakfiles, trimming, window, chroms, reverse_complement=True, 
		random_state=None):
		self.trimming = trimming
		self.window = window
		self.chroms = chroms
		self.reverse_complement = reverse_complement
		self.random_state = numpy.random.RandomState(random_state)

		self.signal = {chrom: signal[chrom] for chrom in chroms}
		self.sequence = {chrom: sequence[chrom] for chrom in chroms}

		self.assays = assays
		self.celltypes = celltypes

		# Use the first requested chromosome instead of a fixed 'chr18'.
		self.tracks = list(self.signal[chroms[0]].keys())

		self.peakfiles = {}
		for track, peaks in peakfiles.items():
			kept = peaks[numpy.isin(peaks['chrom'], chroms)].reset_index(drop=True)
			# assign() returns a new frame, so the caller's DataFrame is not changed.
			self.peakfiles[track] = kept.assign(
				mid=(kept['end'] - kept['start']) // 2 + kept['start'])

	def __len__(self):
		"""Total number of retained peaks across all tracks."""
		return sum(map(len, self.peakfiles.values()))

	def __getitem__(self, idx):
		"""Return one random example; `idx` is ignored.

		Returns
		-------
		X : float32 tensor, the transposed sequence window
		celltype_idx, assay_idx : LongTensor of shape (1,)
		y : tensor, the signal window with `trimming` removed from each end
		"""

		rng = self.random_state

		track = self.tracks[rng.choice(len(self.tracks))]
		peaks = self.peakfiles[track]

		i = rng.choice(peaks.shape[0])
		# Read the chromosome from the peak itself so that several
		# chromosomes in `chroms` are handled correctly.
		chrom = peaks['chrom'].iloc[i]
		mid = peaks['mid'].iloc[i] + rng.randint(-128, 129)
		start, end = mid - self.window // 2, mid + self.window // 2

		celltype_idx = self.celltypes.index(track[0])
		assay_idx = self.assays.index(track[1])

		X = self.sequence[chrom][start:end].T
		y = self.signal[chrom][track][start+self.trimming:end-self.trimming]

		if self.reverse_complement and rng.choice(2) == 1:
			X = X[::-1][:, ::-1]
			y = y[::-1]

		# copy() makes the reversed / memory-mapped views contiguous and writable.
		X = torch.tensor(X.copy(), dtype=torch.float32)
		y = torch.tensor(y.copy())
		celltype_idx = torch.LongTensor([celltype_idx])
		assay_idx = torch.LongTensor([assay_idx])
		return X, celltype_idx, assay_idx, y


def validation_data(sequence, signal, assays, celltypes, peakfiles, trimming=2**4, window=1000, chroms=None, size=3000, random_state=None):
	"""Build a fixed validation set of windows centred on peak midpoints.

	For each of `size` examples a track is drawn at random, then one of that
	track's peaks; the window is centred on the peak midpoint with no jitter
	and no reverse-complement augmentation.

	Parameters
	----------
	sequence : dict, chrom -> array indexed by position on the first axis
		A window slice must transpose to shape (4, window).

	signal : dict, chrom -> dict, track -> array indexed by position
		Tracks are (celltype, assay) pairs. Track names are taken from the
		first chromosome used; presumably all chromosomes share the same
		tracks.

	assays, celltypes : list
		Lookup lists; returned indices are positions in these lists.

	peakfiles : dict, track -> pandas.DataFrame with 'chrom', 'start', 'end'
		Neither the dict nor its frames are modified.

	trimming : int
		Positions removed from each end of the signal window.

	window : int
		Width of the sequence window.

	chroms : list or None
		Chromosomes to draw peaks from. None uses every key of `sequence`.

	size : int
		Number of examples.

	random_state : int or None
		Seed controlling both the track and the peak draws.

	Returns
	-------
	X : float32 tensor, shape=(size, 4, window)
	celltype_idxs, assay_idxs : LongTensor, shape=(size,)
	y : float32 tensor, shape=(size, window - 2*trimming)
	"""

	if chroms is None:
		chroms = list(sequence.keys())

	sequence = {chrom: sequence[chrom] for chrom in chroms}
	signal = {chrom: signal[chrom] for chrom in chroms}

	# Work on filtered copies so the caller's dict and frames stay untouched.
	filtered_peaks = {}
	for track, peaks in peakfiles.items():
		kept = peaks[numpy.isin(peaks['chrom'], chroms)].reset_index(drop=True)
		filtered_peaks[track] = kept.assign(
			mid=(kept['end'] - kept['start']) // 2 + kept['start'])

	X = numpy.zeros((size, 4, window), dtype='float32')
	y = numpy.zeros((size, window-trimming*2), dtype='float32')

	track_names = list(signal[chroms[0]].keys())

	random_state = numpy.random.RandomState(random_state)

	track_idxs = random_state.choice(len(track_names), size=size)
	tracks = [track_names[track_idx] for track_idx in track_idxs]

	celltype_idxs = [celltypes.index(celltype) for celltype, _ in tracks]
	assay_idxs = [assays.index(assay) for _, assay in tracks]

	for i, track in enumerate(tracks):
		peaks = filtered_peaks[track]

		# Draw from the seeded generator so the set is reproducible.
		j = random_state.choice(peaks.shape[0])
		chrom = peaks['chrom'].iloc[j]
		mid = peaks['mid'].iloc[j]

		start, end = mid - window // 2, mid + window // 2

		X[i] = sequence[chrom][start:end].T
		y[i] = signal[chrom][track][start+trimming:end-trimming]

	X = torch.from_numpy(X)
	y = torch.from_numpy(y)
	celltype_idxs = torch.LongTensor(celltype_idxs)
	assay_idxs = torch.LongTensor(assay_idxs)
	return X, celltype_idxs, assay_idxs, y


In [None]:
n_filters = 32
n_layers = 8

# Defined in this cell because the DataLoader below needs it; previously it
# was only assigned in the later training cell, so running the notebook top
# to bottom raised a NameError here.
batch_size = 64

# With these values window - 2 * trimming == 1000 signal positions per example.
window = 2114
trimming = (window - 1000) // 2
chroms = ['chr18']

training_metadata = pandas.read_csv("/users/jmschr/oak/proj/2021_sequence_imputation/scripts/metadata_training_data.tsv", sep="\t")
training_tracks = training_metadata[['cell_type_id', 'mark_id']].values

validation_metadata = pandas.read_csv("/users/jmschr/oak/proj/2021_sequence_imputation/scripts/metadata_validation_data.tsv", sep="\t")
validation_tracks = validation_metadata[['cell_type_id', 'mark_id']].values

# The celltype / assay vocabularies come from the training tracks only.
celltypes = list(numpy.unique(training_tracks[:,0]))
assays = list(numpy.unique(training_tracks[:,1]))

n_celltypes = len(celltypes)
n_assays = len(assays)

seq_dir = "/users/jmschr/oak/common/hg38/"


def _load_tracks(tracks, chrom):
	"""Memory-map the signal and read the peak table of every track for one chromosome.

	Returns two dicts keyed by the (cell_type_id, mark_id) tuple: the signal
	arrays and the peak DataFrames (columns 'chrom', 'start', 'end').
	"""
	signals, peak_tables = {}, {}
	for track in tqdm(tracks):
		track = tuple(track)

		# mmap_mode='r' keeps the array on disk until it is sliced.
		signals[track] = numpy.load("/users/jmschr/oak/proj/2021_sequence_imputation/data/tracks/{}{}.{}.npy".format(*track, chrom), mmap_mode='r')

		peak_tables[track] = pandas.read_csv("/users/jmschr/oak/proj/2021_sequence_imputation/data/peaks/{}{}.bed.gz".format(*track), sep="\t", 
			usecols=(0, 1, 2), header=None, index_col=False, names=['chrom', 'start', 'end'])

	return signals, peak_tables


###
# Load the signal tracks, the peak calls and the sequence for each chromosome.

sequence = {}
training_signal, validation_signal = {}, {}
training_peaks, validation_peaks = {}, {}

for chrom in chroms:
	training_signal[chrom], chrom_training_peaks = _load_tracks(training_tracks, chrom)
	training_peaks.update(chrom_training_peaks)

	validation_signal[chrom], chrom_validation_peaks = _load_tracks(validation_tracks, chrom)
	validation_peaks.update(chrom_validation_peaks)

	sequence[chrom] = numpy.load("{}/{}.npy".format(seq_dir, chrom), mmap_mode='r')

###

# NOTE: `training_peaks` is rebound from the dict of peak tables to the dataset.
training_peaks = PeakGenerator(
	sequence=sequence,
	signal=training_signal,
	assays=assays,
	celltypes=celltypes,
	peakfiles=training_peaks, 
	trimming=trimming, 
	window=window, 
	chroms=chroms)

training_data = torch.utils.data.DataLoader(training_peaks, 
	pin_memory=True, 
	batch_size=batch_size)

X_valid, celltype_idxs_valid, assay_idxs_valid, y_valid = validation_data(
	sequence=sequence,
	signal=validation_signal,
	assays=assays,
	celltypes=celltypes,
	peakfiles=validation_peaks,
	trimming=trimming, 
	window=window,
	chroms=chroms)



100%|██████████| 267/267 [02:53<00:00,  1.54it/s]
100%|██████████| 45/45 [00:53<00:00,  1.19s/it]


Epoch	Iteration	Training Time	Validation Time	Training MLL	Validation MLLL	Validation Correlation
0	0	5.194	1.793	13468.8	8681.63	0.0006257


In [None]:
batch_size = 64

# Depths of the residual stack to sweep; one model is trained per entry.
n_layers_list = [7,8,9,10]

for n_layers in n_layers_list:
	# The trimming passed to the model is 37 less than the data trimming;
	# presumably this offsets the 74 positions removed by the kernel-size-75
	# output convolutions so predictions line up with y.
	model = BPNet(
		n_celltypes=n_celltypes,
		n_assays=n_assays,
		n_filters=n_filters,
		n_layers=n_layers,
		trimming=trimming-37,
	)
	model = model.cuda()

	optimizer = torch.optim.Adam(model.parameters(), lr=0.002)

	model.fit_generator(
		training_data,
		optimizer,
		X_valid,
		celltype_idxs_valid,
		assay_idxs_valid,
		y_valid,
		max_epochs=30,
		validation_iter=100,
		batch_size=batch_size,
	)
