In [1]:
import pandas as pd
import numpy as np
import scipy
import scipy.sparse
import scipy.stats
import os
import scipy.io as sio
%matplotlib inline
from pylab import *

import warnings
warnings.filterwarnings(action='ignore')

from keras.models import load_model
from pkg_resources import resource_filename
from utils2 import one_hot_encode

import datetime as dt
import h5py
import numpy as np
import fnmatch
import time
import re
from sklearn.metrics import average_precision_score
from pandas import DataFrame


bases = list('ATCG')
# Base -> integer code, in the same A, T, C, G order as `bases`.
dna_dict = {base: code for code, base in enumerate(bases)}
# Complement of each symbol; 'N', '.' and '*' map to themselves.
watsoncrick = dict(zip('N.CGAT*', 'N.GCTA*'))

def add_base(li):
    """Return a new list with every string of `li` extended by one base.

    For each prefix in `li` (outer loop) and each base in the module-level
    `bases` list (inner loop), the result contains ``prefix + base``.
    """
    return [prefix + base for prefix in li for base in bases]

def make_mer_list(mer_len):
    """Return all k-mers of length `mer_len` over the module-level `bases`.

    The list is ordered with the last position varying fastest, e.g. for
    mer_len=2: ['AA', 'AT', 'AC', 'AG', 'TA', ...]. For ``mer_len <= 1`` the
    single bases are returned.

    A fresh list is always returned (a copy of `bases` when mer_len <= 1), so
    callers may mutate the result without altering the module-level `bases`.
    """
    mers = list(bases)
    for _ in range(mer_len - 1):
        mers = add_base(mers)
    return mers

def _row_log_normalizer(B):
    """Return the per-row log-sum-exp of 2-D array `B` as an (n_rows, 1) column.

    Each row is shifted by its own maximum before exponentiating, so a row whose
    scores are far below another row's cannot underflow to log(0).
    """
    row_max = np.max(B, axis=1, keepdims=True)
    return row_max + np.log(np.sum(np.exp(B - row_max), axis=1, keepdims=True))


def get_energy(X,w,w0):
    """Row-wise log-softmax of the linear scores ``X @ w + w0``.

    Parameters
    ----------
    X : dense array or scipy sparse matrix, one row per sample.
    w : weight matrix such that ``X @ w`` is defined.
    w0 : bias broadcast-added to every row of ``X @ w``.

    Returns
    -------
    numpy.matrix of log-probabilities; exponentiating any row gives values
    summing to 1.
    """
    # `@` is matrix multiplication for dense arrays, np.matrix and every scipy
    # sparse type (`*` is element-wise for ndarrays and sparse arrays).
    B = np.asarray(X @ w + w0)
    # The normaliser is an explicit (n, 1) column, so it always subtracts per
    # row, even when the number of rows equals the number of classes.
    return np.asmatrix(B - _row_log_normalizer(B))


def predict(W,b,X):
    """Row-wise softmax probabilities of ``X @ W + b`` as a plain 2-D ndarray.

    Parameters
    ----------
    W : weight matrix such that ``X @ W`` is defined.
    b : bias broadcast-added to every row.
    X : dense array or scipy sparse matrix, one row per sample.

    Returns
    -------
    ndarray of shape (n_rows, n_classes); every row sums to 1.
    """
    B = np.asarray(X @ W + b)
    return np.exp(B - _row_log_normalizer(B))


def make_mer_matrix_no_pos(seqs,mer_len):
    """Count overlapping k-mers in each sequence, ignoring their positions.

    Parameters
    ----------
    seqs : iterable of str
        Upper-case sequences over A/T/C/G; any other character inside a k-mer
        raises ``KeyError``. Iterated by value, so a pandas Series with an
        arbitrary index is accepted.
    mer_len : int
        k-mer length (>= 1).

    Returns
    -------
    scipy.sparse.csr_matrix, float64, shape ``(len(seqs), 4**mer_len)``.
    Entry ``[i, j]`` is how often the j-th k-mer (in ``make_mer_list(mer_len)``
    order) occurs in sequence ``i``.
    """
    import itertools

    seqs = list(seqs)
    n_mers = 4 ** mer_len
    # itertools.product varies the last position fastest, which reproduces the
    # A, T, C, G ordering of make_mer_list(mer_len).
    mer_index = {''.join(mer): j
                 for j, mer in enumerate(itertools.product('ATCG', repeat=mer_len))}

    rows, cols = [], []
    for row, seq in enumerate(seqs):
        for start in range(len(seq) - mer_len + 1):
            rows.append(row)
            cols.append(mer_index[seq[start:start + mer_len]])

    # Duplicate (row, col) pairs are summed when converting to CSR, giving counts.
    # An explicit shape keeps all 4**mer_len columns and handles empty input.
    vals = np.ones(len(cols), dtype=np.float64)
    return scipy.sparse.csr_matrix(
        (vals, (np.asarray(rows, dtype=np.int64), np.asarray(cols, dtype=np.int64))),
        shape=(len(seqs), n_mers),
        dtype=np.float64,
    )

In [2]:
data = sio.loadmat('rosen_weight/Reads.mat')

# A5SS: densify the sparse read matrix (one row per minigene).
A5SS_data = np.array(data['A5SS'].todense())
# Get minigenes with reads
A5SS_nn = A5SS_data.sum(axis=1) > 0
A5SS_data = A5SS_data[A5SS_nn]
# Scale every remaining row so its entries sum to 1.
A5SS_data = A5SS_data / A5SS_data.sum(axis=1, keepdims=True)
A5SS_seqs = pd.read_csv('rosen_weight/A5SS_Seqs.csv', index_col=0).Seq[A5SS_nn]

In [3]:
# A3SS
# A3SS
A3SS_data = data['A3SS']

# Only look at SA_1 usage: reads in column 235 divided by each minigene's total
# reads. Minigenes without any reads give 0/0 = NaN and are removed below.
A3SS_data = (
    np.array(A3SS_data[:, 235].todense()).reshape(-1)
    / np.array(A3SS_data.sum(axis=1), dtype=np.float64).reshape(-1)
)
# Get minigenes with reads (non-NaN usage)
A3SS_nn = pd.notnull(A3SS_data)
A3SS_data = A3SS_data[A3SS_nn]
A3SS_seqs = pd.read_csv('rosen_weight/A3SS_Seqs.csv', index_col=0).Seq[A3SS_nn]

In [4]:

# A3SS targets: one row per minigene, columns [1 - p, p] with p the SA_1 usage.
# np.matrix is used directly (scipy.matrix is only a deprecated alias of it);
# row indexing of a matrix stays 2-D, which the comparison helpers rely on.
Y = A3SS_data
Y3 = np.matrix(np.column_stack((1 - Y, Y)))

# A5SS targets: p = usage at column 44 relative to columns 0 and 44.
Y = A5SS_data[:, 44] / (A5SS_data[:, 0] + A5SS_data[:, 44])
# Drop minigenes where both columns are 0 (0/0 -> NaN); keep seqs aligned.
nn = pd.notnull(Y)
A5SS_data = A5SS_data[nn]
A5SS_seqs = A5SS_seqs[nn]
Y5 = np.matrix(np.column_stack((1 - Y[nn], Y[nn])))

In [5]:
mlr_weight_data = sio.loadmat('rosen_weight/param_mlr.mat')
# MLR weight matrix and bias, stored under the keys 'W' and 'b'.
mlr_W, mlr_b = mlr_weight_data['W'], mlr_weight_data['b']

In [6]:
# NOTE(review): resource_filename('spliceai', ...) looks these files up relative
# to the installed `spliceai` package, while the .mat/.csv data above is read
# relative to the working directory -- confirm the .h5 files live in the package.
# Paths are kept as lists (not one-shot generators) so they can be re-used.

# SpliceAI-400 ensemble; test_with_400_model averages its members' outputs.
paths1 = ['rosen_weight/SpliceAI400_g{}.h5'.format(x) for x in [1, 2, 3, 4, 5]]
models_400 = [load_model(resource_filename('spliceai', x), compile=False) for x in paths1]

# Single model reported as "Reg (Ours)" in the comparison tables.
paths2 = ['rosen_weight/joint_spliceai14.h5']
models_reg = [load_model(resource_filename('spliceai', x), compile=False) for x in paths2]

In [7]:
def test_with_mlr(index,data_type,W=mlr_W,b=mlr_b):
    """Predict splice-site usage for one library member with the MLR (hexamer) model.

    Parameters
    ----------
    index : int
        Position (not pandas label) of the sequence in ``A5SS_seqs`` /
        ``A3SS_seqs`` -- the same row numbering as ``Y5`` / ``Y3``.
    data_type : {'A5SS', 'A3SS'}
    W, b : MLR weights and bias (default: values loaded from param_mlr.mat).

    Returns
    -------
    2-D ndarray from ``predict``: one row of class probabilities.

    Raises
    ------
    ValueError
        If ``data_type`` is not 'A5SS' or 'A3SS'.
    """
    if data_type == 'A5SS':
        # Hexamers over characters 7..31; extra feature column set to 0.0.
        seq_window = pd.Series(A5SS_seqs.iloc[index]).str.slice(7, 32)
        extra_feature = 0.0
    elif data_type == 'A3SS':
        # Hexamers over the last 22 characters; extra feature column set to 1.0.
        seq_window = pd.Series(A3SS_seqs.iloc[index]).str.slice(-22)
        extra_feature = 1.0
    else:
        raise ValueError("data_type must be 'A5SS' or 'A3SS', got %r" % (data_type,))

    feat_vec = make_mer_matrix_no_pos(seq_window, 6)
    # NOTE(review): the appended 0/1 column looks like a library indicator
    # expected by the MLR weights -- confirm against how W was trained.
    feat_vec = scipy.sparse.hstack((feat_vec, extra_feature))
    return predict(W, b, feat_vec)

In [8]:
def test_with_400_model(index,data_type,models=models_400):
    """Predict splice-site probabilities for one library member with the SpliceAI-400 ensemble.

    The library sequence is inserted into a fixed reporter template, padded
    with 200 'N' on each side (400 nt of context), one-hot encoded, and fed to
    every model in ``models``. The per-model outputs are averaged.

    Parameters
    ----------
    index : int
        Position (not pandas label) of the sequence in ``A5SS_seqs`` /
        ``A3SS_seqs`` -- the same row numbering as ``Y5`` / ``Y3``.
    data_type : {'A5SS', 'A3SS'}
    models : iterable of Keras models (default: ``models_400``).

    Returns
    -------
    ndarray of length 2.
        'A5SS': output channel 2 at positions 518 and 562 (SD1, SD2).
        'A3SS': channel 0 at position 518 and channel 1 at position 754 (SA1).
        Positions index the template-embedded sequence (assumes the model's
        output is aligned with the unpadded input -- TODO confirm).

    Raises
    ------
    ValueError
        If ``data_type`` is unknown or ``models`` is empty.
    """
    if data_type == 'A5SS':
        a5ss_full='atggtgtccaagggcgaggagctgttcaccggggtggtgcccatcctggtcgagctggacggcgacgtaaacggccacaagttcagcgtcagcggcgagggcgagggcgatgccacctacggcaaactgaccctgaagttcatctgcaccaccggcaagctgcccgtgccctggcccaccctcgtgaccaccttcggctacggcctgatgtgcttcgcccgctaccccgaccacatgaagcagcacgacttcttcaagtccgccatgcccgaaggctacgtccaggagcgcaccatcttcttcaaggacgacggcaactacaagacccgcgccgaagtgaagttcgagggcgacaccctcgtgaaccgcatcgagctaaagggcatcgacttcaaggaggacggcaacatcctggggcacaagctggagtacaactacaacagccacaacgtctatatcatggccgacaagcagaagaacggcatcaaagtgaacttcaagatccgccacaacatcgaggtgcttggnnnnnnnnnnnnnnnnnnnnnnnnnggtcgacccaggttcgtgnnnnnnnnnnnnnnnnnnnnnnnnngaggtattcttatcaccttcgtggctacagagtttccttatttgtctctgttgccggcttatatggacaagcatatcacagccatttatcggagcgcctccgtacacgctattatcggacgcctcgcgagatcaatacgtataccagctgccctcgatacatgtcttggcatcgtttgcttctcgagtactacctggttcctcttctttctttctcttctctttcaggacggcagcgtgcagctcgccgaccactaccagcagaacacccccatcggcgacggccccgtgctgctgcccgacaaccactacctgagctaccagtccgccctgagcaaagaccccaacgagaagcgcgatcacatggtcctgctggagttcgtgaccgccgccgggatcactctcggcatggacgagctgtacaagga'.upper()
        # The library sequence replaces template characters [520, 621).
        input_sequence = a5ss_full[:520] + A5SS_seqs.iloc[index] + a5ss_full[621:]
    elif data_type == 'A3SS':
        a3ss_full='atggtgtccaagggcgaggagctgttcaccggggtggtgcccatcctggtcgagctggacggcgacgtaaacggccacaagttcagcgtcagcggcgagggcgagggcgatgccacctacggcaaactgaccctgaagttcatctgcaccaccggcaagctgcccgtgccctggcccaccctcgtgaccaccttcggctacggcctgatgtgcttcgcccgctaccccgaccacatgaagcagcacgacttcttcaagtccgccatgcccgaaggctacgtccaggagcgcaccatcttcttcaaggacgacggcaactacaagacccgcgccgaagtgaagttcgagggcgacaccctcgtgaaccgcatcgagctaaagggcatcgacttcaaggaggacggcaacatcctggggcacaagctggagtacaactacaacagccacaacgtctatatcatggccgacaagcagaagaacggcatcaaagtgaacttcaagatccgccacaacatcgaggtaagttatcaccttcgtggctacagagtttccttatttgtctctgttgccggcttatatggacaagcatatcacagccatttatcggagcgcctccgtacacgctattatcggacgcctcgcgagatcaatacgtataccagctgccctcgatacatgtcttggacggggtcggtgttgatatcgtatNNNNNNNNNNNNNNNNNNNNNNNNNGCTTGGATCTGATCTCAACAGGGTNNNNNNNNNNNNNNNNNNNNNNNNNatgattacacatatagacacgcgagcacccatcttttatagaatgggtagaacccgtcctaaggactcagattgagcatcgtttgcttctcgagtactacctggtacagatgtctcttcaaacaggacggcagcgtgcagctcgccgaccactaccagcagaacacccccatcggcgacggccccgtgctgctgcccgacaaccactacctgagctaccagtccgccctgagcaaagaccccaacgagaagcgcgatcacatggtcctgctggagttcgtgaccgccgccgggatcactctcggcatggacgagctgtacaag'.upper()
        # The library sequence replaces the 80 template characters [708, 788).
        input_sequence = a3ss_full[:708] + A3SS_seqs.iloc[index] + a3ss_full[6+708+25+24+25:]
    else:
        raise ValueError("data_type must be 'A5SS' or 'A3SS', got %r" % (data_type,))

    context = 400

    x = one_hot_encode('N'*(context//2) + input_sequence + 'N'*(context//2))[None, :]
    # Average over every supplied model rather than a hard-coded count.
    predictions = [model.predict(x) for model in models]
    if not predictions:
        raise ValueError('models must contain at least one model')
    y = np.mean(predictions, axis=0)

    if data_type == 'A5SS':
        return np.array([y[0][518+0][2], y[0][518+44][2]])
    return np.array([y[0][518][0], y[0][519+235][1]])

In [9]:
def test_with_reg_model(index,data_type,models=models_reg):
    """Predict splice-site probabilities for one library member with the regression model(s).

    The library sequence is inserted into a fixed reporter template, padded
    with 200 'N' on each side (400 nt of context), one-hot encoded, and fed to
    every model in ``models``. The per-model outputs are averaged.

    Parameters
    ----------
    index : int
        Position (not pandas label) of the sequence in ``A5SS_seqs`` /
        ``A3SS_seqs`` -- the same row numbering as ``Y5`` / ``Y3``.
    data_type : {'A5SS', 'A3SS'}
    models : iterable of Keras models (default: ``models_reg``).

    Returns
    -------
    ndarray of length 2.
        'A5SS': output channel 2 at positions 518 and 562 (SD1, SD2).
        'A3SS': channel 0 at position 518 and channel 1 at position 754 (SA1).
        Positions index the template-embedded sequence (assumes the model's
        output is aligned with the unpadded input -- TODO confirm).

    Raises
    ------
    ValueError
        If ``data_type`` is unknown or ``models`` is empty.
    """
    if data_type == 'A5SS':
        a5ss_full='atggtgtccaagggcgaggagctgttcaccggggtggtgcccatcctggtcgagctggacggcgacgtaaacggccacaagttcagcgtcagcggcgagggcgagggcgatgccacctacggcaaactgaccctgaagttcatctgcaccaccggcaagctgcccgtgccctggcccaccctcgtgaccaccttcggctacggcctgatgtgcttcgcccgctaccccgaccacatgaagcagcacgacttcttcaagtccgccatgcccgaaggctacgtccaggagcgcaccatcttcttcaaggacgacggcaactacaagacccgcgccgaagtgaagttcgagggcgacaccctcgtgaaccgcatcgagctaaagggcatcgacttcaaggaggacggcaacatcctggggcacaagctggagtacaactacaacagccacaacgtctatatcatggccgacaagcagaagaacggcatcaaagtgaacttcaagatccgccacaacatcgaggtgcttggnnnnnnnnnnnnnnnnnnnnnnnnnggtcgacccaggttcgtgnnnnnnnnnnnnnnnnnnnnnnnnngaggtattcttatcaccttcgtggctacagagtttccttatttgtctctgttgccggcttatatggacaagcatatcacagccatttatcggagcgcctccgtacacgctattatcggacgcctcgcgagatcaatacgtataccagctgccctcgatacatgtcttggcatcgtttgcttctcgagtactacctggttcctcttctttctttctcttctctttcaggacggcagcgtgcagctcgccgaccactaccagcagaacacccccatcggcgacggccccgtgctgctgcccgacaaccactacctgagctaccagtccgccctgagcaaagaccccaacgagaagcgcgatcacatggtcctgctggagttcgtgaccgccgccgggatcactctcggcatggacgagctgtacaagga'.upper()
        # The library sequence replaces template characters [520, 621).
        input_sequence = a5ss_full[:520] + A5SS_seqs.iloc[index] + a5ss_full[621:]
    elif data_type == 'A3SS':
        a3ss_full='atggtgtccaagggcgaggagctgttcaccggggtggtgcccatcctggtcgagctggacggcgacgtaaacggccacaagttcagcgtcagcggcgagggcgagggcgatgccacctacggcaaactgaccctgaagttcatctgcaccaccggcaagctgcccgtgccctggcccaccctcgtgaccaccttcggctacggcctgatgtgcttcgcccgctaccccgaccacatgaagcagcacgacttcttcaagtccgccatgcccgaaggctacgtccaggagcgcaccatcttcttcaaggacgacggcaactacaagacccgcgccgaagtgaagttcgagggcgacaccctcgtgaaccgcatcgagctaaagggcatcgacttcaaggaggacggcaacatcctggggcacaagctggagtacaactacaacagccacaacgtctatatcatggccgacaagcagaagaacggcatcaaagtgaacttcaagatccgccacaacatcgaggtaagttatcaccttcgtggctacagagtttccttatttgtctctgttgccggcttatatggacaagcatatcacagccatttatcggagcgcctccgtacacgctattatcggacgcctcgcgagatcaatacgtataccagctgccctcgatacatgtcttggacggggtcggtgttgatatcgtatNNNNNNNNNNNNNNNNNNNNNNNNNGCTTGGATCTGATCTCAACAGGGTNNNNNNNNNNNNNNNNNNNNNNNNNatgattacacatatagacacgcgagcacccatcttttatagaatgggtagaacccgtcctaaggactcagattgagcatcgtttgcttctcgagtactacctggtacagatgtctcttcaaacaggacggcagcgtgcagctcgccgaccactaccagcagaacacccccatcggcgacggccccgtgctgctgcccgacaaccactacctgagctaccagtccgccctgagcaaagaccccaacgagaagcgcgatcacatggtcctgctggagttcgtgaccgccgccgggatcactctcggcatggacgagctgtacaag'.upper()
        # The library sequence replaces the 80 template characters [708, 788).
        input_sequence = a3ss_full[:708] + A3SS_seqs.iloc[index] + a3ss_full[6+708+25+24+25:]
    else:
        raise ValueError("data_type must be 'A5SS' or 'A3SS', got %r" % (data_type,))

    context = 400

    x = one_hot_encode('N'*(context//2) + input_sequence + 'N'*(context//2))[None, :]
    # Average over every supplied model rather than a hard-coded count.
    predictions = [model.predict(x) for model in models]
    if not predictions:
        raise ValueError('models must contain at least one model')
    y = np.mean(predictions, axis=0)

    if data_type == 'A5SS':
        return np.array([y[0][518+0][2], y[0][518+44][2]])
    return np.array([y[0][518][0], y[0][519+235][1]])

In [10]:
def compare_three_model(index,data_type,models_400=models_400,models_reg=models_reg,W=mlr_W,b=mlr_b):
    """Compare the MLR, SpliceAI-400 and regression models on one library member.

    Prints the reporter sequence around the splice site(s) of interest with
    markers underneath. Then prints the squared error of each model (MLR, 400,
    Reg; tab-separated) on the second output entry (SD2 for A5SS, SA1 for
    A3SS).

    Parameters
    ----------
    index : int
        Position of the library member in ``A5SS_seqs``/``Y5`` or
        ``A3SS_seqs``/``Y3``.
    data_type : {'A5SS', 'A3SS'}
    models_400, models_reg : model lists forwarded to the respective predictors.
    W, b : MLR weights and bias forwarded to ``test_with_mlr``.

    Returns
    -------
    One-row DataFrame with the ground-truth value and the three predictions.

    Raises
    ------
    ValueError
        If ``data_type`` is not 'A5SS' or 'A3SS'.
    """
    if data_type not in ('A5SS', 'A3SS'):
        raise ValueError("data_type must be 'A5SS' or 'A3SS', got %r" % (data_type,))

    # Forward the caller-supplied W and b (not the globals) so custom MLR
    # weights actually take effect.
    result_mlr = test_with_mlr(index, data_type, W=W, b=b)[0]
    result_400 = test_with_400_model(index, data_type, models=models_400)
    result_reg = test_with_reg_model(index, data_type, models=models_reg)

    table_for_comparison = pd.DataFrame()

    if data_type == 'A5SS':
        ground_truth = np.array(Y5[index])[0]
        table_for_comparison['Predicted_value'] = ['Prob(splicing at SD2)']
        a5ss_full='atggtgtccaagggcgaggagctgttcaccggggtggtgcccatcctggtcgagctggacggcgacgtaaacggccacaagttcagcgtcagcggcgagggcgagggcgatgccacctacggcaaactgaccctgaagttcatctgcaccaccggcaagctgcccgtgccctggcccaccctcgtgaccaccttcggctacggcctgatgtgcttcgcccgctaccccgaccacatgaagcagcacgacttcttcaagtccgccatgcccgaaggctacgtccaggagcgcaccatcttcttcaaggacgacggcaactacaagacccgcgccgaagtgaagttcgagggcgacaccctcgtgaaccgcatcgagctaaagggcatcgacttcaaggaggacggcaacatcctggggcacaagctggagtacaactacaacagccacaacgtctatatcatggccgacaagcagaagaacggcatcaaagtgaacttcaagatccgccacaacatcgaggtgcttggnnnnnnnnnnnnnnnnnnnnnnnnnggtcgacccaggttcgtgnnnnnnnnnnnnnnnnnnnnnnnnngaggtattcttatcaccttcgtggctacagagtttccttatttgtctctgttgccggcttatatggacaagcatatcacagccatttatcggagcgcctccgtacacgctattatcggacgcctcgcgagatcaatacgtataccagctgccctcgatacatgtcttggcatcgtttgcttctcgagtactacctggttcctcttctttctttctcttctctttcaggacggcagcgtgcagctcgccgaccactaccagcagaacacccccatcggcgacggccccgtgctgctgcccgacaaccactacctgagctaccagtccgccctgagcaaagaccccaacgagaagcgcgatcacatggtcctgctggagttcgtgaccgccgccgggatcactctcggcatggacgagctgtacaagga'.upper()
        input_sequence = a5ss_full[:520] + A5SS_seqs.iloc[index] + a5ss_full[621:]

        # 80 nt window starting at 500; markers sit under 518 (SD1) and 562 (SD2).
        print(input_sequence[500:580])
        print(' '*18+'S'+' '*43+'S')
        print(' '*18+'D'+' '*43+'D')
        print(' '*18+'1'+' '*43+'2')
    else:
        ground_truth = np.array(Y3[index])[0]
        table_for_comparison['Predicted_value'] = ['Prob(splicing at SA1)']
        a3ss_full='atggtgtccaagggcgaggagctgttcaccggggtggtgcccatcctggtcgagctggacggcgacgtaaacggccacaagttcagcgtcagcggcgagggcgagggcgatgccacctacggcaaactgaccctgaagttcatctgcaccaccggcaagctgcccgtgccctggcccaccctcgtgaccaccttcggctacggcctgatgtgcttcgcccgctaccccgaccacatgaagcagcacgacttcttcaagtccgccatgcccgaaggctacgtccaggagcgcaccatcttcttcaaggacgacggcaactacaagacccgcgccgaagtgaagttcgagggcgacaccctcgtgaaccgcatcgagctaaagggcatcgacttcaaggaggacggcaacatcctggggcacaagctggagtacaactacaacagccacaacgtctatatcatggccgacaagcagaagaacggcatcaaagtgaacttcaagatccgccacaacatcgaggtaagttatcaccttcgtggctacagagtttccttatttgtctctgttgccggcttatatggacaagcatatcacagccatttatcggagcgcctccgtacacgctattatcggacgcctcgcgagatcaatacgtataccagctgccctcgatacatgtcttggacggggtcggtgttgatatcgtatNNNNNNNNNNNNNNNNNNNNNNNNNGCTTGGATCTGATCTCAACAGGGTNNNNNNNNNNNNNNNNNNNNNNNNNatgattacacatatagacacgcgagcacccatcttttatagaatgggtagaacccgtcctaaggactcagattgagcatcgtttgcttctcgagtactacctggtacagatgtctcttcaaacaggacggcagcgtgcagctcgccgaccactaccagcagaacacccccatcggcgacggccccgtgctgctgcccgacaaccactacctgagctaccagtccgccctgagcaaagaccccaacgagaagcgcgatcacatggtcctgctggagttcgtgaccgccgccgggatcactctcggcatggacgagctgtacaag'.upper()
        input_sequence = a3ss_full[:708] + A3SS_seqs.iloc[index] + a3ss_full[6+708+25+24+25:]

        # 40 nt window starting at 734; marker sits under 754 (SA1).
        print(input_sequence[519+215:519+255])
        print(' '*20+'S')
        print(' '*20+'A')
        print(' '*20+'1')

    # Only the second entry of each 2-element result is reported and scored.
    table_for_comparison['GT value'] = ground_truth[1:]
    table_for_comparison['MLR'] = result_mlr[1:]
    table_for_comparison['400 model'] = result_400[1:]
    table_for_comparison['Reg (Ours)'] = result_reg[1:]

    loss_mlr = (result_mlr[1] - ground_truth[1])**2
    loss_400 = (result_400[1] - ground_truth[1])**2
    loss_reg = (result_reg[1] - ground_truth[1])**2
    print()
    print()
    print()
    print()
    print(loss_mlr, loss_400, loss_reg, sep='\t\t')
    return table_for_comparison

In [30]:
compare_three_model(index=230, data_type='A5SS')

GATCCGCCACAACATCGAGGTGCTTGGGTGCAATTAGGAAAACAGAGAGATAGGTCGACCCAGGTTCGTGTGTGGCGGTG
                  S                                           S
                  D                                           D
                  1                                           2




0.001251994085161016		0.20592411273068562		0.024667417444087514


Unnamed: 0,Predicted_value,GT value,MLR,400 model,Reg (Ours)
0,Prob(splicing at SD2),0.621622,0.657005,0.167833,0.77868





<br/><br/><br/>



In [28]:
compare_three_model(index=2000, data_type='A3SS')

CTTGGATCTGATCTCAACAGGGTTAAGTATAAAACAGCTT
                    S
                    A
                    1




0.02855732910840038		0.0013381397398868344		0.0009411041603451555


Unnamed: 0,Predicted_value,GT value,MLR,400 model,Reg (Ours)
0,Prob(splicing at SA1),0.0,0.168989,0.036581,0.030677


compare_three_model 실행 전 설정해야 할 것은 1) index, 2) data_type입니다.

data_type은 'A5SS' 또는 'A3SS'이며,
index는 data_type이 'A5SS'인 경우 0부터 256478까지,
data_type이 'A3SS'인 경우 0부터 1686095까지 지정이 가능합니다.

In [31]:
# Inspect the A5SS target matrix: one row per minigene, columns [1 - p, p].
Y5

matrix([[0.        , 1.        ],
        [0.20833333, 0.79166667],
        [0.5       , 0.5       ],
        ...,
        [1.        , 0.        ],
        [0.        , 1.        ],
        [0.02      , 0.98      ]])

In [12]:
make_mer_list(mer_len=6)

['AAAAAA',
 'AAAAAT',
 'AAAAAC',
 'AAAAAG',
 'AAAATA',
 'AAAATT',
 'AAAATC',
 'AAAATG',
 'AAAACA',
 'AAAACT',
 'AAAACC',
 'AAAACG',
 'AAAAGA',
 'AAAAGT',
 'AAAAGC',
 'AAAAGG',
 'AAATAA',
 'AAATAT',
 'AAATAC',
 'AAATAG',
 'AAATTA',
 'AAATTT',
 'AAATTC',
 'AAATTG',
 'AAATCA',
 'AAATCT',
 'AAATCC',
 'AAATCG',
 'AAATGA',
 'AAATGT',
 'AAATGC',
 'AAATGG',
 'AAACAA',
 'AAACAT',
 'AAACAC',
 'AAACAG',
 'AAACTA',
 'AAACTT',
 'AAACTC',
 'AAACTG',
 'AAACCA',
 'AAACCT',
 'AAACCC',
 'AAACCG',
 'AAACGA',
 'AAACGT',
 'AAACGC',
 'AAACGG',
 'AAAGAA',
 'AAAGAT',
 'AAAGAC',
 'AAAGAG',
 'AAAGTA',
 'AAAGTT',
 'AAAGTC',
 'AAAGTG',
 'AAAGCA',
 'AAAGCT',
 'AAAGCC',
 'AAAGCG',
 'AAAGGA',
 'AAAGGT',
 'AAAGGC',
 'AAAGGG',
 'AATAAA',
 'AATAAT',
 'AATAAC',
 'AATAAG',
 'AATATA',
 'AATATT',
 'AATATC',
 'AATATG',
 'AATACA',
 'AATACT',
 'AATACC',
 'AATACG',
 'AATAGA',
 'AATAGT',
 'AATAGC',
 'AATAGG',
 'AATTAA',
 'AATTAT',
 'AATTAC',
 'AATTAG',
 'AATTTA',
 'AATTTT',
 'AATTTC',
 'AATTTG',
 'AATTCA',
 'AATTCT',
 'AATTCC',

In [35]:
def loss_three_model(index,data_type,models_400=models_400,models_reg=models_reg,W=mlr_W,b=mlr_b):
    """Squared error of each model on the second target entry for one library member.

    The second entry is the SD2 usage for 'A5SS' and the SA1 usage for 'A3SS'.

    Parameters
    ----------
    index : int
        Position of the library member in ``A5SS_seqs``/``Y5`` or
        ``A3SS_seqs``/``Y3``.
    data_type : {'A5SS', 'A3SS'}
    models_400, models_reg : model lists forwarded to the respective predictors.
    W, b : MLR weights and bias forwarded to ``test_with_mlr``.

    Returns
    -------
    (loss_mlr, loss_400, loss_reg)

    Raises
    ------
    ValueError
        If ``data_type`` is not 'A5SS' or 'A3SS'.
    """
    if data_type == 'A5SS':
        targets = Y5
    elif data_type == 'A3SS':
        targets = Y3
    else:
        raise ValueError("data_type must be 'A5SS' or 'A3SS', got %r" % (data_type,))
    ground_truth = np.array(targets[index])[0]

    # Forward the caller-supplied W and b (not the globals).
    result_mlr = test_with_mlr(index, data_type, W=W, b=b)[0]
    result_400 = test_with_400_model(index, data_type, models=models_400)
    result_reg = test_with_reg_model(index, data_type, models=models_reg)

    loss_mlr = (result_mlr[1] - ground_truth[1])**2
    loss_400 = (result_400[1] - ground_truth[1])**2
    loss_reg = (result_reg[1] - ground_truth[1])**2

    return loss_mlr, loss_400, loss_reg

In [43]:
total_loss_mlr_A5SS = total_loss_400_A5SS = total_loss_reg_A5SS = 0.0

# Running totals of each model's squared error over the first 10005 A5SS
# members, printed every 200 members.
for i in range(10005):
    l1, l2, l3 = loss_three_model(i, 'A5SS')
    total_loss_mlr_A5SS += l1
    total_loss_400_A5SS += l2
    total_loss_reg_A5SS += l3

    if (i + 1) % 200 == 0:
        print(str(i + 1) + ' : ' + '\t'.join(map(str, (total_loss_mlr_A5SS, total_loss_400_A5SS, total_loss_reg_A5SS))))

200 : 15.211634325365829	36.26701621534211	7.920761035785103
400 : 27.451716385014514	72.17944958586234	15.662369405208961
600 : 38.80259052405552	105.08625736757443	23.470922829645662
800 : 50.797470926528916	143.068790487604	30.71276877888431
1000 : 63.302545200597365	175.43511081576966	39.340366749364414
1200 : 77.39306356074455	209.57049138111984	49.99552130875865
1400 : 94.3897542869743	243.15684905957602	61.06788110456046
1600 : 108.34215608144389	278.30736513500347	71.51971141899838
1800 : 120.92029059356187	309.07891530120315	78.95042043729549
2000 : 137.3112132740962	345.9904569390897	90.66173312429181
2200 : 151.7194789088601	374.51172368177174	99.3453330411633
2400 : 164.8951816025877	406.4589881108545	109.32081082004352
2600 : 179.72245501753466	442.2973958979876	118.6025626859351
2800 : 193.95082590537476	477.75558995413763	130.1098857306747
3000 : 208.80596218864636	515.0576458444696	142.00777664736128
3200 : 222.01265398323795	544.9411991843127	154.48220314238196
3400 : 

In [44]:
total_loss_mlr_A3SS = total_loss_400_A3SS = total_loss_reg_A3SS = 0.0

# Running totals of each model's squared error over the first 10005 A3SS
# members, printed every 200 members.
for i in range(10005):
    l1, l2, l3 = loss_three_model(i, 'A3SS')
    total_loss_mlr_A3SS += l1
    total_loss_400_A3SS += l2
    total_loss_reg_A3SS += l3

    if (i + 1) % 200 == 0:
        print(str(i + 1) + ' : ' + '\t'.join(map(str, (total_loss_mlr_A3SS, total_loss_400_A3SS, total_loss_reg_A3SS))))

200 : 6.975147851527837	6.13760503728897	5.199875048534687
400 : 8.758340236528225	7.863559132362075	7.19353230398935
600 : 14.1548261205811	12.137327024224307	11.487695741091974
800 : 17.74356295847219	15.152755484427647	14.31757193229365
1000 : 20.99761806625819	17.748699948137396	17.54142811673083
1200 : 27.77220306313384	24.361938646893215	23.40711815008055
1400 : 32.94074348401881	29.13485469285761	27.378294063465532
1600 : 37.3717796674579	33.12134251207307	31.021745433982524
1800 : 40.78080010765584	36.26937801189369	33.95483098417979
2000 : 42.4726892122847	38.012871845822815	35.87755776939405
2200 : 45.32012388975555	40.5159810142853	38.158160327758026
2400 : 48.55666019481435	43.71141379567839	41.358677118209755
2600 : 53.975588323663395	48.01186097679334	44.89275009511866
2800 : 61.45231031122926	55.02872891393231	51.287148296242854
3000 : 65.98417384806028	59.16321534281176	55.57728591475167
3200 : 69.9844201316308	62.85215147929305	58.698568407590756
3400 : 75.710142231211

In [32]:
def compare_three_model_loss(index,all_seqs=A5SS_seqs,models_400=models_400,models_reg=models_reg,wfull_mlr=wfull_mlr,w0_mlr=w0_mlr):
    """Per-position squared-error comparison of MLR, SpliceAI-400 and Reg on one A5SS member.

    NOTE(review): this cell appears to target an older API and does not run
    against the definitions visible in this notebook:
      * the defaults ``wfull_mlr`` / ``w0_mlr`` are not defined anywhere here,
        so the ``def`` itself raises NameError on a fresh kernel;
      * ``test_with_mlr`` / ``test_with_400_model`` / ``test_with_reg_model``
        are called with keyword arguments (``all_seqs``, ``wfull_mlr``,
        ``w0_mlr``) they do not accept, and without ``data_type``;
      * the 400/Reg results are sliced as per-position arrays
        (``[518:518+80]``), while the visible predictors return 2 values.
    The ``all_seqs`` parameter is also ignored: ``A5SS_seqs`` is always used.
    TODO: port to the current helper signatures, or delete this cell.

    Returns (loss_of_mlr, loss_of_400, loss_of_reg). For each model this is
    the squared error on the "no splicing" entry (last column of
    ``A5SS_data[index]``) plus the summed squared error over the first 80
    entries of ``A5SS_data[index]``.
    """
    result_mlr = test_with_mlr(index,all_seqs=A5SS_seqs,wfull_mlr=wfull_mlr,w0_mlr=w0_mlr)
    result_400 = test_with_400_model(index,all_seqs=A5SS_seqs,models=models_400)
    result_reg = test_with_reg_model(index,all_seqs=A5SS_seqs,models=models_reg)
    
    # MLR: every output except the last (the last is used below as "no splicing").
    result_mlr_crop=np.array(result_mlr)[0][:-1]
    # 400/Reg: the 80 outputs starting at position 518.
    result_400_crop=result_400[518:518+80]
    result_reg_crop=result_reg[518:518+80]
    
    ground_truth_label = A5SS_data[index]
    target_no_splicing = float(A5SS_data[index][-1])
    
    # (fraction, position) pairs for positions holding more than 0.1% of reads.
    gt_list = []
    for i in range(80):
        if A5SS_data[index][i]>0.001:
            gt_list.append([A5SS_data[index][i],i])
    
    # Ascending sort; the loop below walks it backwards (largest fraction first).
    gt_list.sort()

    list_for_index = []
    list_for_prob = []
    
    
    pred_of_mlr=[]
    pred_of_400=[]
    pred_of_reg=[]
    
    # Probability mass not assigned to the listed positions; used as the
    # 400/Reg "no splicing" estimate.
    remain_400 = 1.0
    remain_reg = 1.0
    
    
    for i in range(len(gt_list)-1,-1,-1):
        list_for_index.append(gt_list[i][1])
        list_for_prob.append(gt_list[i][0])
        
        remain_400-=result_400_crop[gt_list[i][1]]
        remain_reg-=result_reg_crop[gt_list[i][1]]
        
        pred_of_mlr.append(result_mlr_crop[gt_list[i][1]])
        pred_of_400.append(result_400_crop[gt_list[i][1]])
        pred_of_reg.append(result_reg_crop[gt_list[i][1]])

        
    pred_of_mlr.append(float(np.array(result_mlr)[0][-1]))
    pred_of_400.append(remain_400)
    pred_of_reg.append(remain_reg)
        
    list_for_index.append('no splicing')
    list_for_prob.append(target_no_splicing)
    
    

    # Built for inspection only; not returned.
    table_for_comparison=pd.DataFrame()
    
    table_for_comparison['GT index'] = list_for_index
    table_for_comparison['GT prob'] = list_for_prob
        
    table_for_comparison['MLR'] = pred_of_mlr
    table_for_comparison['400'] = pred_of_400
    table_for_comparison['Reg'] = pred_of_reg
    
    loss_of_mlr = (target_no_splicing-pred_of_mlr[-1])**2
    loss_of_400 = (target_no_splicing-pred_of_400[-1])**2
    loss_of_reg = (target_no_splicing-pred_of_reg[-1])**2
    
    for i in range(80):
        loss_of_mlr += (result_mlr_crop[i]-ground_truth_label[i])**2
        loss_of_400 += (result_400_crop[i]-ground_truth_label[i])**2
        loss_of_reg += (result_reg_crop[i]-ground_truth_label[i])**2
    
    return loss_of_mlr,loss_of_400,loss_of_reg

In [33]:
# Running loss totals for the per-position comparison loop.
loss_mlr = loss_400 = loss_reg = 0.0

In [34]:
# Accumulate the per-position losses of the three models over the A5SS library,
# printing running totals every 200 members.
for i in range(265044):
    losses = compare_three_model_loss(i)
    loss_mlr += losses[0]
    loss_400 += losses[1]
    loss_reg += losses[2]

    if (i + 1) % 200 == 0:
        print(str(i + 1) + ' : ' + ', '.join(map(str, (loss_mlr, loss_400, loss_reg))))


200 : 14.034362739025031, 53.870041360677035, 10.327790633518767
400 : 31.27900068271353, 110.948312118355, 23.351025289971734
600 : 47.69568887284781, 165.29272975364836, 34.14324705077355
800 : 63.45252073881033, 221.36201482632674, 44.962304772366394
1000 : 79.41643205717467, 266.61665434277546, 57.205967268236215
1200 : 95.40796434105665, 317.07767085236003, 71.04337881361691
1400 : 113.04432143028748, 371.2019594311144, 83.50086954826428
1600 : 128.12318720366795, 423.12193716752626, 94.83987944604561


KeyboardInterrupt: 

In [None]:
# (empty cell)