In [1]:
import csv
import math 
import random
import gzip
import torch
from sklearn import metrics
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import os 
import warnings
warnings.filterwarnings("ignore")
import copy
import torch.nn.functional as F
import pandas as pd
import matplotlib.pyplot as plt
from prettytable import PrettyTable
from torch import einsum
import seaborn as sns
import subprocess
sns.set()

In [2]:
#weblogo_opts = '-X NO -Y NO --errorbars NO --fineprint ""'
weblogo_opts = '-X NO --fineprint ""'
weblogo_opts += ' -C "#CB2026" A A'
weblogo_opts += ' -C "#34459C" C C'
weblogo_opts += ' -C "#FBB116" G G'
# embed
weblogo_opts += ' -C "#0C8040" T T'

# DATA Preprocessing functions

In [3]:
def generate_chrom_train_test(peaks,Chr_dict):
    peak_chromosomes = {}
    n_peaks = 0
    chrom_train = []
    chrom_test = []
    for line in peaks:
        split_line = 'chromosome='+line.split('\t')[0]
        peak_chromosomes[split_line] = peak_chromosomes.get(split_line,0) + 1
        n_peaks+=1
    n_peaks-=peak_chromosomes.pop('chromosome=pombeIII')
    size_test = int(n_peaks/10)
    npeaks_test = 0
    while npeaks_test<size_test:
        chrom = random.choice(list(Chr_dict.keys()))
        npeaks_test += peak_chromosomes[chrom]
        chrom_test.append(chrom)
        if npeaks_test>int(n_peaks*1.3/10):
            npeaks_test = 0
            chrom_test = []
    for chrom in Chr_dict.keys():
        if chrom not in chrom_test:
            chrom_train.append(chrom)
    return chrom_train,chrom_test

In [4]:
def genome_data(data_file):
    data=open(data_file).read()
    chromosomes_data = data.split('>')[1:]
    return (chromosomes_data)

In [5]:
def Read_bed_file(chromosomes_data,peak_file,seq_length,Chr_dict,cross_chrom):
    peaks = open(peak_file).readlines()
    peak_sequences = []
    if cross_chrom:
        peak_sequences_train,peak_sequences_test = [],[]
        chrom_train , chrom_test = generate_chrom_train_test(peaks,Chr_dict)
    for peak in peaks:
        peak_split = peak.split('\t')
        Chr = 'chromosome='+str(peak_split[0])
        if Chr in Chr_dict:
            chrom_seq = Chr_dict[Chr]
            n = len(chrom_seq)
            start_idx = max(int(peak_split[1])-seq_length//2,0)
            end_idx = min(int(len(chrom_seq)), start_idx+seq_length)
            if end_idx == len(chrom_seq):
                start_idx = end_idx - seq_length
            header = Chr+':{0}-{1}'.format(start_idx,end_idx)
            if not cross_chrom:
                peak_sequences.append([header,chrom_seq[start_idx:end_idx]])
            else:
                if Chr in chrom_train:
                    peak_sequences_train.append([header,chrom_seq[start_idx:end_idx]])
                else:
                    peak_sequences_test.append([header,chrom_seq[start_idx:end_idx]])
    if not cross_chrom:
        size=int(len(peak_sequences)/10)
        peak_sequences_train = peak_sequences[:9*size]
        peak_sequences_test = peak_sequences[9*size:]
    return (peak_sequences_train,peak_sequences_test)

In [6]:
def dinucshuffle(sequence):
    b=[sequence[i:i+2] for i in range(0, len(sequence), 2)]
    random.shuffle(b)
    d=''.join([str(x) for x in b])
    return d

In [7]:
def seqtopad(sequence, motif_len):
    rows=len(sequence)+2*motif_len-2
    S=np.empty([rows,4])
    base=['A', 'C', 'G', 'T']
    for i in range(rows):
        for j in range(4):
            if (i-motif_len+1<len(sequence) and sequence[i-motif_len+1]=='N' 
                or i<motif_len-1 or i>len(sequence)+motif_len-2):
                S[i,j]=np.float32(0.25)
            elif sequence[i-motif_len+1]==base[j]:
                S[i,j]=np.float32(1)
            else:
                S[i,j]=np.float32(0)
    return np.transpose(S)

In [8]:
def generate_onehot_data(peak_sequences,motif_length,label,include_dinuc):
    alldata = []
    for header,seq in peak_sequences:
        alldata.append([header,seq,seqtopad(seq,motif_length),[int(label)]])#
        if include_dinuc:
            shuff_seq = dinucshuffle(seq)
            alldata.append([header,shuff_seq,seqtopad(shuff_seq,motif_length),[0]])#
    return (alldata)

In [47]:
def extract_data(data_file,peak_file,motif_length=24,seq_length=150,cross_chrom=False,include_dinuc=True,Chr_dict=None):
    chromosomes_data = genome_data(data_file)
    if Chr_dict==None:
        Chr_dict = {}
        for chrom_data in chromosomes_data:
            ref = chrom_data.split('\n')[0].split(' ')[-1][1:-1]
            Chr_dict[ref]=''.join(chrom_data.split('\n')[1:])
        Chr_dict['chromosome=Mito'] = Chr_dict.pop('top=circular')
    if type(peak_file) == str:
        peak_sequences_train,peak_sequences_test = Read_bed_file(chromosomes_data,peak_file,seq_length,Chr_dict,cross_chrom)
        train_data = generate_onehot_data(peak_sequences_train,motif_length,1,include_dinuc)
        test_data = generate_onehot_data(peak_sequences_test,motif_length,1,include_dinuc)
        random.shuffle(train_data)
        size=int(len(train_data)/10)
        calib_data=train_data[:9*size]
        valid_data=train_data[9*size:]
    elif type(peak_file) == list: #Using the data in each file as samples associated with 1 label.
        peak_sequences_train,peak_sequences_test = [],[]
        train_data,test_data = [],[]
        for i in range (len(peak_file)):
            peak_sequences_train_temp,peak_sequences_test_temp = Read_bed_file(chromosomes_data,peak_file[i],seq_length,Chr_dict,cross_chrom)
            peak_sequences_train.extend(peak_sequences_train_temp)
            peak_sequences_test.extend(peak_sequences_test_temp)
            label = 1 if include_dinuc else i 
            train_data.extend(generate_onehot_data(peak_sequences_train_temp,motif_length,label,include_dinuc))
            test_data.extend(generate_onehot_data(peak_sequences_test_temp,motif_length,label,include_dinuc))
        random.shuffle(train_data)
        random.shuffle(test_data)
        size=int(len(train_data)/10)
        calib_data=train_data[:9*size]
        valid_data=train_data[9*size:]
    return calib_data,valid_data,train_data,test_data,peak_sequences_train,peak_sequences_test

In [10]:
class dataset(Dataset):
    def __init__(self, xy=None):
        self.header=[el[0] for el in xy]
        self.seq =[el[1] for el in xy ]
        self.x_data=np.asarray([el[2] for el in xy],dtype=np.float32)
        self.y_data =np.asarray([el[3] for el in xy ],dtype=np.float32)
        self.x_data = torch.from_numpy(self.x_data)
        self.y_data = torch.from_numpy(self.y_data)
        self.length=len(self.x_data)

    def __getitem__(self, index):
        return self.header[index],self.seq[index],self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.length

# Networks

In [11]:
class AttentionPool(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.to_attn_logits = nn.Parameter(torch.eye(dim)) 

    def forward(self, x):
        attn_logits = einsum('b n d, d e -> b n e', x, self.to_attn_logits) 
        attn = attn_logits.softmax(dim = -2) 
        return (x * attn).sum(dim = -2).squeeze()

In [12]:
class Network(nn.Module):
    def __init__(self, num_motif , motif_len , num_conv_layers , dropprob):
        super(Network, self).__init__()
        self.num_motif = num_motif
        self.conv = [nn.Conv1d(4, num_motif, kernel_size=motif_len),nn.ReLU(inplace=True)]
        in_channels = num_motif
        for i in range (num_conv_layers-1):
            motif_len = motif_len//2
            self.conv.append(nn.MaxPool1d(kernel_size=3))
            self.conv.append(nn.Conv1d(in_channels, int(1.5*in_channels), kernel_size=motif_len))
            self.conv.append(nn.ReLU(inplace=True))
            in_channels = int(1.5*in_channels)
        self.conv_layer = nn.Sequential(*self.conv)
        self.project = AttentionPool(in_channels)
        self.classifier = nn.Sequential(
            nn.Linear(in_channels , in_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropprob, inplace=False),
            nn.Linear(in_channels, 1),
            nn.Sigmoid())
    def forward(self, x):
        x = self.conv_layer(x)
        x= x.permute(0, 2, 1)
        x = self.project(x)
        predict = self.classifier(x)
        return predict

In [13]:
class conv_output(nn.Module):
    def __init__(self,filter_weights,filter_bias,device):
        super(conv_output, self).__init__()
        if type(filter_weights) is np.ndarray:
            self.filter_weights =  torch.from_numpy(filter_weights.astype(np.float32)).to(device)
        else :
            self.filter_weights = filter_weights.to(device)
        if type(filter_bias) is np.ndarray:
            self.filter_bias =  torch.from_numpy(filter_bias.astype(np.float32)).to(device)
        else :
            self.filter_bias = filter_bias.to(device)
        
    def forward(self,x):
        x=F.conv1d(x, self.filter_weights, bias=self.filter_bias, stride=1, padding=0)
        out=x.clamp(min=0)
        return (out)

In [14]:
### printing parameters ------------------
def count_parameters(model):
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad: continue
        params = parameter.numel()
        table.add_row([name, params])
        total_params+=params
    print(table.get_string())
    print(f"\n Total Trainable Params: {total_params}")

# Calib - Train - Test functions 

In [15]:
def Train_model(model,train_loader,valid_loader, l_rate=0.01 , maxepochs=100,epochs_for_early_stop=0,save_model=False):
    if not new:
        try:
            best_model = torch.load(model_dir+'/best_model.pkl')
            return(best_model,maxepochs)
        except:
            print('Pretrained model not found. Training the model')
    best_model = None
    best_loss = np.inf
    counter = 0
    nepochs=0
    valid_losses =[]
    train_losses = []
    optimizer = torch.optim.AdamW(model.parameters(),lr=l_rate,weight_decay=1e-05)
    criterion = nn.BCELoss(reduction='mean')
    while nepochs<maxepochs:
        model.train()
        train_loss=0
        for i, (header, seq, data, target) in enumerate(train_loader):
            data = data.to(device)
            target = target.to(device)
            output = model(data)
            loss = criterion(output, target)#
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss+=loss.item()
        if verbose:
            print('Model trained for {0} epochs out of {2}. Training loss is {1}'.format(nepochs+1,loss.item(),maxepochs))
        train_losses.append(train_loss/(i+1))
        with torch.no_grad():
            model.eval() 
            valid_loss=0
            for i, (header, seq, data, target) in enumerate(valid_loader):
                data = data.to(device)
                target = target.to(device)
                output = model(data)
                loss = F.binary_cross_entropy(output, target)#
                valid_loss+=loss.item()
            valid_losses.append(valid_loss/(i+1))
        counter+=1
        nepochs +=1
        if epochs_for_early_stop>0:
            if valid_losses[-1]<best_loss:
                if verbose:
                    print('Validation loss decreased from {0} to {1}'.format(best_loss,valid_losses[-1]))
                best_loss = valid_losses[-1]
                best_model = model
                counter = 0
            else:
                if verbose:
                    print('Counter for early stopping: {0} out of {1}'.format(counter,epochs_for_early_stop))
                if counter == epochs_for_early_stop:
                    print('early stopping at epoch ', nepochs-counter)
                    if save_model:
                        torch.save(best_model,model_dir+'/best_model.pkl')
                        count_parameters(best_model)
                    return (best_model,nepochs-counter)
    print('no early stopping')
    if save_model:
        torch.save(model,model_dir+'/best_model.pkl')
        count_parameters(model)
    return (model,nepochs)

In [16]:
def Test_model(model,test_loader):
    with torch.no_grad():
        model.eval()
        pred_list = []
        labels_list = []
        for i, (header, seq, data, target) in enumerate(test_loader):
            data = data.to(device)
            target = target.to(device)
            output = model(data)
            pred=output.cpu().detach().numpy().reshape(output.shape[0])
            labels=target.cpu().numpy().reshape(output.shape[0])
            pred_list.append(pred)
            labels_list.append(labels)
        labels = np.concatenate(labels_list)
        predictions = np.concatenate(pred_list)
    auc = metrics.roc_auc_score(labels, predictions)
    precision, recall, thresholds = metrics.precision_recall_curve(labels, predictions)
    prc = (metrics.auc(recall, precision))
    file = open('%s/results.txt'%results_dir, 'w')
    header_cols = ('N_test_samples','AUROC', 'AURPC')
    print('%5s  %10s  %10s' % header_cols, file=file)
    row_cols = (len(predictions),auc,prc)
    print('%5i  %19.4f  %10.4f'%row_cols,file=file)
    if verbose:
        print('AUROC on test data ', auc)
        print('AUPRC on test data ', prc)
    return (auc,prc)

In [17]:
def Calibrate_model(calib_loader,valid_loader, num_motif_list, num_conv_layers_list , dropprob_list,learning_rate_list, 
                    max_num_models=40, maxepochs=100,epochs_for_early_stop=0 , motif_len=24 ):
    if not new:
        try:
            best_hyperparameters=torch.load(model_dir+'/best_hyperpamarameters.pth')
            results = pd.read_csv(model_dir+'/calibration_df.csv')
            return(best_hyperparameters,results)
        except:
            print('Calibration results not found. Calibrating the model')
    results=pd.DataFrame(columns=['num_conv_layers','num_motif','Dropout','Learning Rate','epochs','AUROC','AUPRC'])
    best_AUC = 0
    if verbose:
        print('Training on ',device)
    for number in range(max_num_models):
        print('model {0} out of {1}'.format(number+1,max_num_models))
        # hyper-parameters
        num_motif = random.choice(num_motif_list)
        num_conv_layers = random.choice(num_conv_layers_list)
        dropprob = random.choice(dropprob_list)
        l_rate = random.choice(learning_rate_list)
        while ((results['num_conv_layers']==num_conv_layers) & (results['num_motif']==num_motif) 
               & (results['Dropout']==dropprob) & (results['Learning Rate']==l_rate)).any(): 
            #if hyperparameters exist in the results dataframe then randomly choose other parameters
            num_motif = random.choice(num_motif_list)
            num_conv_layers = random.choice(num_conv_layers_list)
            dropprob = random.choice(dropprob_list)
            l_rate = random.choice(learning_rate_list)
        model = Network(num_motif , motif_len , num_conv_layers , dropprob).to(device)#num_conv_layers,dropprob
        best_model,epochs = Train_model(model,calib_loader,valid_loader,l_rate ,maxepochs,epochs_for_early_stop)
        auc,prc = Test_model(best_model,valid_loader)
        results=pd.concat([results,pd.DataFrame({'num_conv_layers':num_conv_layers,'num_motif':num_motif,'Dropout':dropprob,
                                                 'Learning Rate':l_rate,'epochs':epochs,'AUROC':[auc],'AUPRC':[prc]})])
        if auc > best_AUC :
            best_AUC = auc
            best_epochs = epochs
            best_num_motif = num_motif
            best_num_conv_layers = num_conv_layers
            best_dropprob = dropprob
            best_l_rate = l_rate
    best_hyperparameters = {'best_epochs': best_epochs,'best_num_motif':best_num_motif,
                            'best_num_conv_layers':best_num_conv_layers,'best_dropprob':best_dropprob,'best_l_rate':best_l_rate}
    torch.save(best_hyperparameters, model_dir+'/best_hyperpamarameters.pth')
    results.sort_values(by='AUROC',ascending=False,inplace=True)
    results.to_csv(model_dir+'/calibration_df.csv',index=False)
    return best_hyperparameters,results

In [18]:
def return_filter_outputs(model,test_loader):
    with torch.no_grad():
        best_model.eval()
        pred_list = []
        for i, (header, seq, data, target) in enumerate(test_loader):
            data = data.to(device)
            output = model(data)
            pred=output.cpu().detach().numpy()
            pred_list.append(pred)
        predictions = np.concatenate(pred_list)
    return predictions

# Motif extraction functions

In [19]:
def info_content(pwm, transpose=False, bg_gc=0.415):
    ''' Compute PWM information content.
    In the original analysis, I used a bg_gc=0.5. For any
    future analysis, I ought to switch to the true hg19
    value of 0.415.
    '''
    pseudoc = 1e-9
    if transpose:
        pwm = np.transpose(pwm)

    bg_pwm = [1-bg_gc, bg_gc, bg_gc, 1-bg_gc]

    ic = 0
    for i in range(pwm.shape[0]):
        for j in range(4):
            # ic += 0.5 + pwm[i][j]*np.log2(pseudoc+pwm[i][j])
            ic += -bg_pwm[j]*np.log2(bg_pwm[j]) + pwm[i][j]*np.log2(pseudoc+pwm[i][j])
    return ic

In [20]:
def meme_intro(meme_file, seqs):
    ''' Open MEME motif format file and print intro
    Attrs:
        meme_file (str) : filename
        seqs [str] : list of strings for obtaining background freqs
    Returns:
        mem_out : open MEME file
    '''
    nts = {'A':0, 'C':1, 'G':2, 'T':3}

    # count
    nt_counts = [1]*4
    for i in range(len(seqs)):
        for nt in seqs[i][1]:
            try:
                nt_counts[nts[nt]] += 1
            except KeyError:
                pass

    # normalize
    nt_sum = float(sum(nt_counts))
    nt_freqs = [nt_counts[i]/nt_sum for i in range(4)]

    # open file for writing
    meme_out = open(meme_file, 'w')

    # print intro material
    print( 'MEME version 4', file=meme_out)
    print( '', file=meme_out)
    #embd

    print( 'ALPHABET= ACGT', file=meme_out)        
    
    print( '', file=meme_out)
    print( 'Background letter frequencies:', file=meme_out)
    #embd
    print( 'A %.4f C %.4f G %.4f T %.4f' % tuple(nt_freqs), file=meme_out)
    print( '', file=meme_out)
    return meme_out


In [21]:
def make_filter_pwm(filter_fasta):
    ''' Make a PWM for this filter from its top hits '''
    nts = {'A':0, 'C':1, 'G':2, 'T':3}
    #embd
    pwm_counts = []
    nsites = 4 # pseudocounts
    for line in open(filter_fasta):
        if line[0] != '>':
            seq = line.rstrip()
            nsites += 1
            if len(pwm_counts) == 0:
                # initialize with the length
                for i in range(len(seq)):
                    pwm_counts.append(np.array([1.0]*4))

            # count
            for i in range(len(seq)):
                try:
                    pwm_counts[i][nts[seq[i]]] += 1
                except KeyError:
                    pwm_counts[i] += np.array([0.25]*4)

    # normalize
    pwm_freqs = []
    for i in range(len(pwm_counts)):
        pwm_freqs.append([pwm_counts[i][j]/float(nsites) for j in range(4)])

    return np.array(pwm_freqs), nsites-4

In [22]:
def plot_score_density(f_scores, out_pdf):
    sns.set(font_scale=1.3)
    plt.figure()
    sns.distplot(f_scores, kde=False)
    plt.xlabel('ReLU output')
    plt.savefig(out_pdf)
    plt.close()

    return f_scores.mean(), f_scores.std()

In [23]:
def filter_motif(param_matrix):
    nts = 'ACGT'
    motif_list = []
    for v in range(param_matrix.shape[1]):
        max_n = 0
        for n in range(1,4):
            if param_matrix[n,v] > param_matrix[max_n,v]:
                max_n = n

        if param_matrix[max_n,v] > 0:
            motif_list.append(nts[max_n])
        else:
            motif_list.append('N')

    return ''.join(motif_list)

In [24]:
def plot_filter_heat(param_matrix, out_pdf):
    param_range = abs(param_matrix).max()

    sns.set(font_scale=2)
    plt.figure(figsize=(param_matrix.shape[1], 4))
    sns.heatmap(param_matrix, cmap='PRGn', linewidths=0.2, vmin=-param_range, vmax=param_range)
    ax = plt.gca()
    ax.set_xticklabels(range(1,param_matrix.shape[1]+1))
    ax.set_yticklabels('ACGT', rotation='horizontal') # , size=10)
    plt.savefig(out_pdf)
    plt.close()

In [25]:
def plot_filter_logo(filter_outs, filter_size, seqs, out_prefix, filter_num, raw_t=0, maxpct_t=None):
    if maxpct_t:
        all_outs = np.ravel(filter_outs)
        all_outs_mean = all_outs.mean()
        all_outs_norm = all_outs - all_outs_mean
        raw_t = maxpct_t * all_outs_norm.max() + all_outs_mean
    # print fasta file of positive outputs
    filter_fasta_out = open('%s.fa' % out_prefix, 'w')
    filter_count = 0
    for i in range(filter_outs.shape[0]):
        for j in range(filter_outs.shape[1]):
            if filter_outs[i,j] > raw_t:
                kmer = seqs[i][1][j:j+filter_size]
                chrom = motif_sequences[i][0].split(':')[0]
                pos = int(motif_sequences[i][0].split(':')[1].split('-')[0])+j
                try:
                    filter_hits[filter_num][chrom].append(pos)
                except:
                    filter_hits[filter_num][chrom] = [pos]
                #kmer = kmer.replace('T','U')
                incl_kmer = len(kmer) - kmer.count('N')
                if incl_kmer <filter_size:
                    continue
                print('>%d_%d' % (i,j), file=filter_fasta_out)
                print(kmer, file=filter_fasta_out)
                filter_count += 1
    filter_fasta_out.close()
    # make weblogo
    if filter_count > 0:
        weblogo_cmd = 'weblogo %s < %s.fa > %s.eps' % (weblogo_opts, out_prefix, out_prefix)
        subprocess.call(weblogo_cmd, shell=True)

In [26]:
def meme_add(meme_out, f, filter_pwm, nsites, trim_filters=False):
    ''' Print a filter to the growing MEME file
    Attrs:
        meme_out : open file
        f (int) : filter index #
        filter_pwm (array) : filter PWM array
        nsites (int) : number of filter sites
    '''
    if not trim_filters:
        ic_start = 0
        ic_end = filter_pwm.shape[0]-1
    else:
        ic_t = 0.2

        # trim PWM of uninformative prefix
        ic_start = 0
        while ic_start < filter_pwm.shape[0] and info_content(filter_pwm[ic_start:ic_start+1]) < ic_t:
            ic_start += 1

        # trim PWM of uninformative suffix
        ic_end = filter_pwm.shape[0]-1
        while ic_end >= 0 and info_content(filter_pwm[ic_end:ic_end+1]) < ic_t:
            ic_end -= 1

    if ic_start < ic_end:
        print('MOTIF filter%d' % f, file=meme_out)
        print('letter-probability matrix: alength= 4 w= %d nsites= %d' % (ic_end-ic_start+1, nsites), file=meme_out)

        for i in range(ic_start, ic_end+1):
            print( '%.4f %.4f %.4f %.4f' % tuple(filter_pwm[i]), file=meme_out)


        print( '', file=meme_out)

In [27]:
def plot_filter_heat(param_matrix, out_pdf):
    param_range = abs(param_matrix).max()

    sns.set(font_scale=2)
    plt.figure(figsize=(param_matrix.shape[1], 4))
    sns.heatmap(param_matrix, cmap='PRGn', linewidths=0.2, vmin=-param_range, vmax=param_range)
    ax = plt.gca()
    ax.set_xticklabels(range(1,param_matrix.shape[1]+1))
    ax.set_yticklabels('ACGT', rotation='horizontal') # , size=10)
    plt.savefig(out_pdf)
    plt.close()

In [28]:
def name_filters(num_filters, tomtom_file, meme_db_file):
    ''' Name the filters using Tomtom matches.
    Attrs:
        num_filters (int) : total number of filters
        tomtom_file (str) : filename of Tomtom output table.
        meme_db_file (str) : filename of MEME db
    Returns:
        filter_names [str] :
    '''
    # name by number
    filter_names = ['f%d'%fi for fi in range(num_filters)]

    # name by protein
    if tomtom_file is not None and meme_db_file is not None:
        print(tomtom_file, meme_db_file)
        motif_protein = get_motif_proteins(meme_db_file)
        # hash motifs and q-value's by filter
        filter_motifs = {}

        tt_in = open(tomtom_file)
        tt_in.readline()
        for line in tt_in:
            a = line.split()
            if a== []:
                break
            fi = int(a[0][6:])
            motif_id = a[1]
            qval = float(a[5])

            filter_motifs.setdefault(fi,[]).append((qval,motif_id))

        tt_in.close()
        # assign filter's best match
        for fi in filter_motifs:
            top_motif = sorted(filter_motifs[fi])[0][1]
            filter_names[fi] += '_%s' % motif_protein[top_motif]

    return np.array(filter_names)

In [29]:
def get_motif_proteins(meme_db_file):
    ''' Hash motif_id's to protein names using the MEME DB file '''
    motif_protein = {}
    for line in open(meme_db_file):
        a = line.split()
        if len(a) > 0 and a[0] == 'MOTIF':
            if a[2][0] == '(':
                motif_protein[a[1]] = a[2][1:a[2].find(')')]
            else:
                motif_protein[a[1]] = a[2]
    return motif_protein

In [30]:
def generate_filter_bed(filter_num,out_dir):
    bed = open(out_dir+'/'+filter_num+'.bed','w')
    for key in filter_hits[filter_num].keys():
        for peak in filter_hits[filter_num][key]:
            print('{0} \t {1} \t {2}'.format(key.split('=')[1],peak,peak+motif_length),file=bed) 
    bed.close()

In [31]:
def get_motif(filter_weights_old, filter_outs, seqs, out_dir):
    global filter_hits
    filter_hits = {}
    filter_weights = []
    for x in filter_weights_old:
        x = x - np.mean(x,axis = 0)
        filter_weights.append(x)
        
    filter_weights = np.array(filter_weights)
    num_filters = filter_weights.shape[0]
    filter_size = filter_weights.shape[2]
    filters_ic = []
    nsites_list = []
    meme_out = meme_intro('%s/filters_meme.txt'%(out_dir), seqs)
    for f in range(num_filters):
        filter_hits['filter_%i'%f] = {}
        # plot filter parameters as a heatmap
        plot_filter_heat(filter_weights[f,:,:filter_size], '%s/filter%d_heat.pdf' % (out_dir,f))
        # plot weblogo of high scoring outputs
        plot_filter_logo(filter_outs[:,f,:], filter_size, seqs, '%s/filter%d_logo'%(out_dir,f), 'filter_%i'%f, maxpct_t=0.8)
        generate_filter_bed('filter_%i'%f,out_dir)
        # make a PWM for the filter
        filter_pwm, nsites = make_filter_pwm('%s/filter%d_logo.fa'%(out_dir,f))
        nsites_list.append(nsites)
        if nsites < 10:
            # no information
            filters_ic.append(0)
        else:
            # compute and save information content
            filters_ic.append(info_content(filter_pwm))

            # add to the meme motif file
            meme_add(meme_out, f, filter_pwm, nsites, False)
    pd.DataFrame(filter_hits).to_csv('%s/indices.csv'%out_dir)
    meme_out.close()    
    subprocess.call('../meme-5.5.2/src/tomtom -dist pearson -thresh 0.05 -oc %s/tomtom %s/filters_meme.txt %s' % (out_dir, out_dir, 'Motif_database/YEASTRACT_20130918.meme'), shell=True)
    subprocess.call('cp %s/tomtom/tomtom.tsv %s/tomtom/tomtom.txt' %(out_dir, out_dir), shell=True)
    filter_names = name_filters(num_filters, '%s/tomtom/tomtom.txt'%out_dir, 'Motif_database/YEASTRACT_20130918.meme')
    
    table_out = open('%s/table.txt'%out_dir, 'w')

    # print header for later panda reading
    table = PrettyTable(["Filter", "consensus","annotation","ic",'mean', 'std',"nsites"])
    header_cols = ('', 'consensus', 'annotation', 'ic', 'mean', 'std',"nsites")
    print('%3s  %24s  %10s  %5s  %6s  %6s  %6s' % header_cols, file=table_out)
    
    for f in range(num_filters):
        # collapse to a consensus motif
        consensus = filter_motif(filter_weights[f,:,:])

        # grab annotation
        annotation = '.'
        name_pieces = filter_names[f].split('_')
        if len(name_pieces) > 1:
            annotation = name_pieces[1]
        # plot density of filter output scores
        fmean, fstd = plot_score_density(np.ravel(filter_outs[: , f , :]), '%s/filter%d_dens.pdf' % (out_dir,f))

        row_cols = (f, consensus, annotation, filters_ic[f], fmean, fstd,nsites_list[f])
        table.add_row(list(row_cols))
        print( '%-3d  %19s  %10s  %5.2f  %6.4f  %6.4f %4i' % row_cols, file=table_out)
        
    table_out.close()
    print(table.get_string())

# Main

In [48]:
global data_dir
data_dir = 'Data'
data_file = data_dir+'/S288C_reference_sequence_R64-3-1_20210421.fsa'
peak_file = [data_dir+'/Condensin_peaks_Log.bed',data_dir+'/Condensin_peaks_quiescence.bed']
# peak_file can be a list of bed files, it will depend on include_dinuc variable whether to assign to each file a different label or to assign them label 1 and the shuffeled seq will be labeled 0

In [49]:
global device
global file_extension
global model_dir
global results_dir
global verbose
global new

device = 'cuda' if torch.cuda.is_available() else 'cpu'

file_extension = peak_file.split('.')[0].split('_')[-1] if type(peak_file)==str else 'mix'
model_dir = 'model_'+file_extension
results_dir = 'results_'+file_extension
if not os.path.exists(model_dir):
    os.mkdir(model_dir)
if not os.path.exists(results_dir):
    os.mkdir(results_dir)
if not os.path.exists(results_dir+'/tomtom'):
    os.mkdir(results_dir+'/tomtom')
    
verbose = False
new = True  #assign False if a pretrained model already exists

In [50]:
seq_length = 500
motif_length = 24
cross_chrom = True
include_dinuc = True
calib_data,valid_data,train_data,test_data,peak_sequences_train,peak_sequences_test = extract_data(data_file,peak_file,motif_length,
                                                                                                   seq_length,cross_chrom,include_dinuc)
len(calib_data),len(valid_data),len(train_data),len(test_data)

(4248, 472, 4720, 610)

In [51]:
calib_dataset=dataset(calib_data)
valid_dataset=dataset(valid_data)
train_dataset=dataset(train_data)
test_dataset=dataset(test_data)
batch_size = 64
calib_loader = DataLoader(dataset=calib_dataset,
                          batch_size=batch_size,shuffle=False)
valid_loader = DataLoader(dataset=valid_dataset,
                          batch_size=batch_size,shuffle=False)
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,shuffle=False)
test_loader = DataLoader(dataset=test_dataset,
                          batch_size=batch_size,shuffle=False)

In [52]:
num_motif_list = [30,40,60]
num_conv_layers_list = [1,2]
dropprob_list = [0, 0.15, 0.3]
learning_rate_list = [10**-5,10**-4,10**-3,10**-2]
max_num_models = 30
maxepochs = 500
epochs_for_early_stop = 50
best_hyperparameters,results = Calibrate_model(calib_loader,valid_loader, num_motif_list, num_conv_layers_list , dropprob_list,
                                               learning_rate_list, max_num_models=max_num_models, maxepochs=maxepochs,
                                               epochs_for_early_stop=epochs_for_early_stop,motif_len=motif_length)

model 1 out of 30
early stopping at epoch  16
model 2 out of 30
early stopping at epoch  162
model 3 out of 30
no early stopping
model 4 out of 30
early stopping at epoch  273
model 5 out of 30
early stopping at epoch  138
model 6 out of 30
no early stopping
model 7 out of 30
early stopping at epoch  7
model 8 out of 30
no early stopping
model 9 out of 30
early stopping at epoch  170
model 10 out of 30
early stopping at epoch  145
model 11 out of 30
early stopping at epoch  24
model 12 out of 30
early stopping at epoch  16
model 13 out of 30
early stopping at epoch  29
model 14 out of 30
early stopping at epoch  18
model 15 out of 30
no early stopping
model 16 out of 30
early stopping at epoch  179
model 17 out of 30
early stopping at epoch  31
model 18 out of 30
early stopping at epoch  5
model 19 out of 30
no early stopping
model 20 out of 30
early stopping at epoch  254
model 21 out of 30
early stopping at epoch  20
model 22 out of 30
early stopping at epoch  4
model 23 out of 30
ea

In [53]:
results.head()

Unnamed: 0,num_conv_layers,num_motif,Dropout,Learning Rate,epochs,AUROC,AUPRC
0,2,60,0.0,0.0001,138,0.950697,0.950778
0,1,40,0.15,0.0001,254,0.947195,0.952806
0,1,30,0.0,0.0001,273,0.943856,0.949412
0,2,40,0.3,0.0001,145,0.942366,0.94657
0,1,60,0.0,0.0001,162,0.940786,0.943991


In [54]:
best_hyperparameters

{'best_epochs': 138,
 'best_num_motif': 60,
 'best_num_conv_layers': 2,
 'best_dropprob': 0,
 'best_l_rate': 0.0001}

In [55]:
maxepochs = best_hyperparameters['best_epochs']
num_motif = best_hyperparameters['best_num_motif']
num_conv_layers = best_hyperparameters['best_num_conv_layers']
dropprob = best_hyperparameters['best_dropprob']
l_rate = best_hyperparameters['best_l_rate']
epochs_for_early_stop = 0
model = Network(num_motif , motif_length , num_conv_layers , dropprob).to(device)
best_model,epochs = Train_model(model,train_loader,valid_loader,l_rate ,maxepochs,
                                                          epochs_for_early_stop,save_model=True)

no early stopping
+------------------------+------------+
|        Modules         | Parameters |
+------------------------+------------+
|  conv_layer.0.weight   |    5760    |
|   conv_layer.0.bias    |     60     |
|  conv_layer.3.weight   |   64800    |
|   conv_layer.3.bias    |     90     |
| project.to_attn_logits |    8100    |
|  classifier.0.weight   |    8100    |
|   classifier.0.bias    |     90     |
|  classifier.3.weight   |     90     |
|   classifier.3.bias    |     1      |
+------------------------+------------+

 Total Trainable Params: 87091


In [56]:
auroc,auprc = Test_model(best_model,test_loader)
auroc,auprc

(0.9756409567320613, 0.9768832059548929)

# Motif extraction

In [57]:
motif_sequences=peak_sequences_train.copy()
motif_sequences.extend(peak_sequences_test)

In [58]:
filter_Weights = best_model.conv[0].weight.detach().cpu().numpy()
filter_bias = best_model.conv[0].bias.detach().cpu().numpy()
motif_data=generate_onehot_data(motif_sequences,motif_length=1,label=1,include_dinuc=False)
motif_dataset=dataset(motif_data)
motif_loader = DataLoader(dataset=motif_dataset,batch_size=batch_size,shuffle=False)
out_model = conv_output(filter_Weights,filter_bias,device)
filter_output = return_filter_outputs(out_model,motif_loader)

In [59]:
get_motif(filter_Weights,filter_output,motif_sequences,results_dir)

The output directory 'results_mix/tomtom' already exists.
Its contents will be overwritten.
Processing query 1 out of 60 
# Computing q-values.
#   Estimating pi_0 from all 1464 observed p-values.
#   Estimating pi_0.
# Minimal pi_zero = 1.00434
#   Estimated pi_0=1
Processing query 2 out of 60 
# Computing q-values.
#   Estimating pi_0 from all 1464 observed p-values.
#   Estimating pi_0.
# Minimal pi_zero = 1.00434
#   Estimated pi_0=1
Processing query 3 out of 60 
# Computing q-values.
#   Estimating pi_0 from all 1464 observed p-values.
#   Estimating pi_0.
# Minimal pi_zero = 1.00022
#   Estimated pi_0=1
Processing query 4 out of 60 
# Computing q-values.
#   Estimating pi_0 from all 1464 observed p-values.
#   Estimating pi_0.
# Minimal pi_zero = 1.00276
#   Estimated pi_0=1
Processing query 5 out of 60 
# Computing q-values.
#   Estimating pi_0 from all 1464 observed p-values.
#   Estimating pi_0.
# Minimal pi_zero = 1.00365
#   Estimated pi_0=1
Processing query 6 out of 60 
# C

# Computing q-values.
#   Estimating pi_0 from all 1464 observed p-values.
#   Estimating pi_0.
# Minimal pi_zero = 1.00503
#   Estimated pi_0=1
Processing query 48 out of 60 
# Computing q-values.
#   Estimating pi_0 from all 1464 observed p-values.
#   Estimating pi_0.
# Minimal pi_zero = 1.00434
#   Estimated pi_0=1
Processing query 49 out of 60 
# Computing q-values.
#   Estimating pi_0 from all 1464 observed p-values.
#   Estimating pi_0.
# Minimal pi_zero = 1.00159
#   Estimated pi_0=1
Processing query 50 out of 60 
# Computing q-values.
#   Estimating pi_0 from all 1464 observed p-values.
#   Estimating pi_0.
# Minimal pi_zero = 0.995414
#   Estimated pi_0=0.995414
Processing query 51 out of 60 
# Computing q-values.
#   Estimating pi_0 from all 1464 observed p-values.
#   Estimating pi_0.
# Minimal pi_zero = 1.00228
#   Estimated pi_0=1
Processing query 52 out of 60 
# Computing q-values.
#   Estimating pi_0 from all 1464 observed p-values.
#   Estimating pi_0.
# Minimal pi_zer

results_mix/tomtom/tomtom.txt Motif_database/YEASTRACT_20130918.meme
+--------+--------------------------+------------+--------------------+-------------+------------+--------+
| Filter |        consensus         | annotation |         ic         |     mean    |    std     | nsites |
+--------+--------------------------+------------+--------------------+-------------+------------+--------+
|   0    | GGGCATACGTTAACCTGCGAGCAT |     .      | 11.350989127442071 |  0.23894928 | 0.25270543 |   38   |
|   1    | TTAACACGCGGTAAGCGCTTAACC |     .      | 10.181426068901843 |  0.16367914 | 0.19221641 |  147   |
|   2    | ACACATACTTATCGTTTTTTTTGC |   Hsf1p    | 19.859827827416787 |   0.29639   | 0.31235033 |  457   |
|   3    | AGTGCGCCGCGTACTCACTGATTA |     .      | 9.610126126135519  |  0.12056698 | 0.18135278 |   97   |
|   4    | CTGCTGTTGGAGCAGCAAAAGCTG |     .      | 14.195094272291023 |  0.4280269  | 0.29944867 |   97   |
|   5    | GGCGTAACGCTCGGCTAAAGACGC |     .      | 9.38027305456453