In [46]:
import os 
import csv
import copy
import math 
import gzip
import torch
import random
import warnings
import numpy as np
import pandas as pd
import seaborn as sns
import torch.nn as nn
from Motif_Functions import *
from sklearn import metrics
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
warnings.filterwarnings("ignore")

In [22]:
global device
device = 'cuda' if torch.cuda.is_available() else 'cpu' 

In [23]:
def genome_data(data_file):
    data=open(data_file).read()
    chromosomes_data = data.split('>')[1:]
    return (chromosomes_data)

def Read_bed_file(chromosomes_data,peak_file,seq_length,Chr_dict):
    peaks = open(peak_file).readlines()
    if Chr_dict==None:
        Chr_dict = {}
        i=0
        for chrom_data in chromosomes_data:
            ref = chrom_data.split('\n')[0].split(' ')[-1][1:-1]
            Chr_dict[ref]=i
            i+=1
        Chr_dict['chromosome=Mito'] = Chr_dict.pop('top=circular')
    peak_sequences = []
    for peak in peaks:
        peak_split = peak.split('\t')
        Chr = 'chromosome='+str(peak_split[0])
        if Chr in Chr_dict:
            chrom_seq = ''.join(chromosomes_data[Chr_dict[Chr]].split('\n')[1:])
            start_idx = max(int(peak_split[1])-seq_length//2,0)
            end_idx = min(int(len(chrom_seq)), start_idx+seq_length)
            if end_idx == len(chrom_seq):
                start_idx = end_idx - seq_length
            peak_sequences.append(chrom_seq[start_idx:end_idx])
    return (peak_sequences)

def Read_narrow_file(chromosomes_data,peak_file):
    peaks = open(peak_file).readlines()
    if Chr_dict==None:
        Chr_dict = {}
        i=0
        for chrom_data in chromosomes_data:
            ref = chrom_data.split('\n')[0].split(' ')[-1][1:-1]
            Chr_dict[ref]=i
            i+=1
        Chr_dict['chromosome=Mito'] = Chr_dict.pop('top=circular')

    peak_sequences = []
    for peak in peaks:
        peak_split = peak.split('\t')
        Chr = 'chromosome='+str(peak_split[0])
        chrom_seq = ''.join(chromosomes_data[Chr_dict[Chr]].split('\n')[1:])
        if Chr in Chr_dict:
            start_idx = max(int(peak_split[1])-seq_length//2,0)
            end_idx = min(len(chrom_seq),start_idx+seq_length)
            #end_idx = start_idx+seq_length
            peak_sequences.append(chrom_seq[start_idx:end_idx])
    return (peak_sequences)
    return None

def dinucshuffle(sequence):
    b=[sequence[i:i+2] for i in range(0, len(sequence), 2)]
    random.shuffle(b)
    d=''.join([str(x) for x in b])
    return d

def seqtopad(sequence, motif_len):
    rows=len(sequence)+2*motif_len-2
    S=np.empty([rows,4])
    base=['A', 'C', 'G', 'T']
    for i in range(rows):
        for j in range(4):
            if (i-motif_len+1<len(sequence) and sequence[i-motif_len+1]=='N' 
                or i<motif_len-1 or i>len(sequence)+motif_len-2):
                S[i,j]=np.float32(0.25)
            elif sequence[i-motif_len+1]==base[j]:
                S[i,j]=np.float32(1)
            else:
                S[i,j]=np.float32(0)
    return np.transpose(S)

def generate_onehot_data(peak_sequences,motif_length,include_dinuc=True):
    alldata = []
    for seq in peak_sequences:
        alldata.append([seqtopad(seq,motif_length),[1]])#
    return (alldata)

def extract_data(data_file,peak_file,motif_length=24,seq_length=150,Chr_dict=None):
    chromosomes_data = genome_data(data_file)
    if '.bed' in peak_file:
        peak_sequences = Read_bed_file(chromosomes_data,peak_file,seq_length,Chr_dict)
    else :
        peak_sequences = Read_narrow_file(chromosomes_data,peak_file,seq_length,Chr_dict)
    alldata = []
    for seq in peak_sequences:
        alldata.append([seqtopad(seq,motif_length),[1]])#
        alldata.append([seqtopad(dinucshuffle(seq),motif_length),[0]])#
    random.shuffle(alldata)
    size=int(len(alldata)/5)
    train_data=alldata[:4*size]
    valid_data=alldata[4*size:int(4.5*size)]
    test_data = alldata[int(4.5*size):]
    return train_data,valid_data,test_data,alldata,peak_sequences

In [24]:
def Train_model(model,train_loader,valid_loader, l_rate=0.01 ,  maxepochs=100,epochs_for_early_stop=0,verbose=False):
    best_model = None
    best_loss = np.inf
    counter = 0
    nepochs=0
    valid_losses =[]
    train_losses = []
    optimizer = torch.optim.SGD(model.parameters(),lr=l_rate,momentum=0.9,nesterov=True,weight_decay=1e-02)
    criterion = nn.BCELoss(reduction='mean')
    while nepochs<maxepochs:
        model.train()
        train_loss=0
        for i, (data, target) in enumerate(train_loader):
            data = data.to(device)
            target = target.to(device)
            output = model(data)
            loss = criterion(output, target)#
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss+=loss.item()
        if verbose:
            print('Model trained for {0} epochs out of {2}. Training loss is {1}'.format(nepochs+1,loss.item(),maxepochs))
        train_losses.append(train_loss/(i+1))
        with torch.no_grad():
            model.eval() 
            valid_loss=0
            for i, (data, target) in enumerate(valid_loader):
                data = data.to(device)
                target = target.to(device)
                output = model(data)
                loss = F.binary_cross_entropy(output, target)#
                valid_loss+=loss.item()
            valid_losses.append(valid_loss/(i+1))
        counter+=1
        nepochs +=1
        if epochs_for_early_stop>0:
            if valid_losses[-1]<best_loss:
                if verbose:
                    print('Validation loss decreased from {0} to {1}'.format(best_loss,valid_losses[-1]))
                best_loss = valid_losses[-1]
                best_model = model
                counter = 0
            else:
                if verbose:
                    print('Counter for early stopping: {0} out of {1}'.format(counter,epochs_for_early_stop))
                if counter == epochs_for_early_stop:
                    print('early stopping at epoch ', nepochs-counter)
                    return (best_model,nepochs-counter,train_losses,valid_losses)
    print('no early stopping')
    return (best_model,nepochs,train_losses,valid_losses)

def Test_model(best_model,test_loader):
    with torch.no_grad():
        best_model.eval()
        pred_list = []
        labels_list = []
        for i, (data, target) in enumerate(test_loader):
            data = data.to(device)
            target = target.to(device)
            output = best_model(data)
            pred=output.cpu().detach().numpy().reshape(output.shape[0])
            labels=target.cpu().numpy().reshape(output.shape[0])
            pred_list.append(pred)
            labels_list.append(labels)
        labels = np.concatenate(labels_list)
        predictions = np.concatenate(pred_list)
        auc = metrics.roc_auc_score(labels, predictions)
        print('AUC on test data ', auc)
    return (auc)

In [25]:
class dataset(Dataset):
    def __init__(self, xy=None):
        self.x_data=np.asarray([el[0] for el in xy],dtype=np.float32)
        self.y_data =np.asarray([el[1] for el in xy ],dtype=np.float32)
        self.x_data = torch.from_numpy(self.x_data)
        self.y_data = torch.from_numpy(self.y_data)
        self.length=len(self.x_data)

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.length

In [26]:
class Network(nn.Module):
    def __init__(self, num_motif , motif_len):
        super(Network, self).__init__()
        self.num_motif = num_motif
        self.conv = nn.Sequential(
            nn.Conv1d(4, num_motif, kernel_size=motif_len))
        self.classifier = nn.Sequential(
            nn.Linear(num_motif , num_motif),
            nn.ReLU(inplace=True),
            nn.Linear(num_motif, 1),
            nn.Sigmoid())
    def forward(self, x):
        x = self.conv(x)
        x,_ = torch.max(x, dim=2)
        predict = self.classifier(x)
        return predict
    

In [27]:
class conv_output(nn.Module):
    def __init__(self,filter_weights,filter_bias,device):
        super(conv_output, self).__init__()
        if type(filter_weights) is np.ndarray:
            self.filter_weights =  torch.from_numpy(filter_weights.astype(np.float32)).to(device)
        else :
            self.filter_weights = filter_weights.to(device)
        if type(filter_bias) is np.ndarray:
            self.filter_bias =  torch.from_numpy(filter_bias.astype(np.float32)).to(device)
        else :
            self.filter_bias = filter_bias.to(device)
        
    def forward(self,x):
        x=F.conv1d(x, self.filter_weights, bias=self.filter_bias, stride=1, padding=0)
        out=x.clamp(min=0)
        return (out)

In [28]:
def return_filter_outputs(best_model,test_loader):
    with torch.no_grad():
        best_model.eval()
        pred_list = []
        for i, (data, target) in enumerate(test_loader):
            data = data.to(device)
            output = best_model(data)
            pred=output.cpu().detach().numpy()
            pred_list.append(pred)
        predictions = np.concatenate(pred_list)
    return predictions

In [29]:
def info_content(pwm, transpose=False, bg_gc=0.415):
    ''' Compute PWM information content.
    In the original analysis, I used a bg_gc=0.5. For any
    future analysis, I ought to switch to the true hg19
    value of 0.415.
    '''
    pseudoc = 1e-9

    if transpose:
        pwm = np.transpose(pwm)

    bg_pwm = [1-bg_gc, bg_gc, bg_gc, 1-bg_gc]

    ic = 0
    for i in range(pwm.shape[0]):
        for j in range(4):
            # ic += 0.5 + pwm[i][j]*np.log2(pseudoc+pwm[i][j])
            ic += -bg_pwm[j]*np.log2(bg_pwm[j]) + pwm[i][j]*np.log2(pseudoc+pwm[i][j])

    return ic

In [30]:
def make_filter_pwm(filter_fasta):
    ''' Make a PWM for this filter from its top hits '''
    nts = {'A':0, 'C':1, 'G':2, 'T':3}
    #embd
    pwm_counts = []
    nsites = 4 # pseudocounts
    for line in open(filter_fasta):
        if line[0] != '>':
            seq = line.rstrip()
            nsites += 1
            if len(pwm_counts) == 0:
                # initialize with the length
                for i in range(len(seq)):
                    pwm_counts.append(np.array([1.0]*4))

            # count
            for i in range(len(seq)):
                try:
                    pwm_counts[i][nts[seq[i]]] += 1
                except KeyError:
                    pwm_counts[i] += np.array([0.25]*4)

    # normalize
    pwm_freqs = []
    for i in range(len(pwm_counts)):
        pwm_freqs.append([pwm_counts[i][j]/float(nsites) for j in range(4)])

    return np.array(pwm_freqs), nsites-4

In [44]:
def get_motif(filter_weights_old, filter_outs, seqs, out_dir):
    global fw
    filter_weights = []
    for x in filter_weights_old:
        x = x - np.mean(x,axis = 0)
        filter_weights.append(x)
        
    filter_weights = np.array(filter_weights)
    num_filters = filter_weights.shape[0]
    filter_size = filter_weights.shape[2]
    filters_ic = []
    meme_out = meme_intro('%s/filters_meme.txt'%(out_dir), seqs)
    fw = open('indices.txt', 'w')
    for f in range(num_filters):
        print ('Filter %d' % f)
        # plot filter parameters as a heatmap
        #plot_filter_heat(filter_weights[f,:,:filter_size], '%s/filter%d_heat.pdf' % (out_dir,f))

        # plot weblogo of high scoring outputs
        plot_filter_logo(filter_outs[:,f,:], filter_size, seqs, '%s/filter%d_logo'%(out_dir,f), maxpct_t=0.5)

        # make a PWM for the filter
        filter_pwm, nsites = make_filter_pwm('%s/filter%d_logo.fa'%(out_dir,f))

        if nsites < 10:
            # no information
            filters_ic.append(0)
        else:
            # compute and save information content
            filters_ic.append(info_content(filter_pwm))

            # add to the meme motif file
            meme_add(meme_out, f, filter_pwm, nsites, False)

    meme_out.close()
    fw.close()
    table_out = open('%s/table.txt'%out_dir, 'w')

    # print header for later panda reading
    header_cols = ('', 'consensus', 'annotation', 'ic', 'mean', 'std')
    print('%3s  %19s  %10s  %5s  %6s  %6s' % header_cols, file=table_out)

    for f in range(num_filters):
        # collapse to a consensus motif
        consensus = filter_motif(filter_weights[f,:,:])

        # grab annotation
        annotation = '.'

        # plot density of filter output scores
        fmean, fstd = plot_score_density(np.ravel(filter_outs[:,:, f]), '%s/filter%d_dens.pdf' % (out_dir,f))

        row_cols = (f, consensus, annotation, filters_ic[f], fmean, fstd)
        print( '%-3d  %19s  %10s  %5.2f  %6.4f  %6.4f' % row_cols, file=table_out)


    table_out.close()

In [32]:
data_file = 'Data/S288C_reference_sequence_R64-3-1_20210421.fsa'
peak_file = 'Data/Condensin_peaks_Log.bed'
seq_length = 150
motif_length = 24
train_data,valid_data,test_data,alldata,peak_sequences = extract_data(data_file,peak_file,motif_length,seq_length)
len(train_data),len(valid_data),len(test_data)

(1584, 198, 202)

In [33]:
train_dataset=dataset(train_data)
valid_dataset=dataset(valid_data)
test_dataset=dataset(test_data)
batch_size = 64
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,shuffle=False)
valid_loader = DataLoader(dataset=valid_dataset,
                          batch_size=batch_size,shuffle=False)
test_loader = DataLoader(dataset=test_dataset,
                          batch_size=batch_size,shuffle=False)

In [38]:
verbose=True
num_motif=40
motif_len=24
sigma_conv=10**-3
sigma_w=0.1
num_conv_layers = 1
dropprob = 0.2
l_rate=0.001
maxepochs,epochs_for_early_stop = 500,50
model = Network(num_motif, motif_len).to(device)#num_conv_layers,dropprob
best_model,epochs,train_losses,valid_losses = Train_model(model,train_loader,valid_loader,l_rate ,maxepochs,epochs_for_early_stop,verbose)

Model trained for 1 epochs out of 500. Training loss is 0.6936900615692139
Validation loss decreased from inf to 0.6924862712621689
Model trained for 2 epochs out of 500. Training loss is 0.69287109375
Counter for early stopping: 1 out of 50
Model trained for 3 epochs out of 500. Training loss is 0.6923608779907227
Counter for early stopping: 2 out of 50
Model trained for 4 epochs out of 500. Training loss is 0.6920428276062012
Counter for early stopping: 3 out of 50
Model trained for 5 epochs out of 500. Training loss is 0.6918391585350037
Counter for early stopping: 4 out of 50
Model trained for 6 epochs out of 500. Training loss is 0.6917020678520203
Counter for early stopping: 5 out of 50
Model trained for 7 epochs out of 500. Training loss is 0.6915937662124634
Counter for early stopping: 6 out of 50
Model trained for 8 epochs out of 500. Training loss is 0.691506564617157
Counter for early stopping: 7 out of 50
Model trained for 9 epochs out of 500. Training loss is 0.69143080711

In [35]:
auc = Test_model(best_model,test_loader)

AUC on test data  0.5205115592720118


In [39]:
weights = model.conv[0].weight.detach().cpu().numpy()
bias = model.conv[0].bias.detach().cpu().numpy()
motif_sequences=generate_onehot_data(peak_sequences,24,include_dinuc=False)
motif_dataset=dataset(motif_sequences)
batch_size = 100000
motif_loader = DataLoader(dataset=motif_dataset,
                          batch_size=batch_size,shuffle=False)
out_model = conv_output(weights,bias,device)
filter_output = return_filter_outputs(out_model,motif_loader)

In [40]:
motif_sequences

[[array([[0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25,
          0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25,
          0.25, 1.  , 0.  , 0.  , 0.  , 0.  , 1.  , 1.  , 0.  , 1.  , 0.  ,
          1.  , 0.  , 0.  , 0.  , 0.  , 0.  , 1.  , 1.  , 0.  , 0.  , 0.  ,
          1.  , 1.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 1.  ,
          1.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
          1.  , 1.  , 0.  , 0.  , 0.  , 1.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
          0.  , 1.  , 0.  , 0.  , 1.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
          0.  , 0.  , 0.  , 0.  , 1.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
          1.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 1.  ,
          0.  , 0.  , 0.  , 1.  , 1.  , 0.  , 0.  , 1.  , 0.  , 1.  , 0.  ,
          0.  , 1.  , 0.  , 0.  , 0.  , 0.  , 0.  , 1.  , 1.  , 0.  , 0.  ,
          1.  , 0.  , 0.  , 1.  , 0.  , 0.  , 0.  , 1.  , 0.  , 0.  , 0.  ,
          0.

In [41]:
weights.shape

(40, 4, 24)

In [42]:
filter_output.shape

(992, 40, 173)

In [48]:
get_motif(weights,filter_output,peak_sequences,'results')

NameError: name 'data_type' is not defined