In [1]:
import math
import os
import pickle as pkl
import sys
from collections import deque, Counter,OrderedDict,defaultdict

import click as ck
import numpy as np
import pandas as pd

# Maps each GO namespace name to the short tag used as a dictionary key
# elsewhere in this notebook.
tag_dict = {
    'biological_process': 'bp',
    'cellular_component': 'cc',
    'molecular_function': 'mf',
}

print('done')

done


In [2]:

class Ontology(object):
    """In-memory view of the ``[Term]`` stanzas of an OBO ontology file.

    Attributes:
        ont: dict mapping a term id (and each of its ``alt_id`` values) to the
            term record, a dict with keys ``id``, ``name``, ``namespace``,
            ``is_a``, ``part_of``, ``regulates``, ``alt_ids``,
            ``is_obsolete`` and ``children``.
        alts: dict of the obsolete term records, keyed by their primary id;
            these primary ids are removed from ``ont``.
        ic: dict of per-term information content, or ``None`` until
            ``calculate_ic`` has been called.
    """

    def __init__(self, filename='data/go.obo', with_rels=False):
        """Parse ``filename``; see ``load`` for the meaning of ``with_rels``."""
        self.ont,self.alts = self.load(filename, with_rels)
        self.ic = None

    def has_term(self, term_id):
        """Return True when ``term_id`` is a key of the loaded ontology."""
        return term_id in self.ont

    def calculate_ic(self, annots):
        """Compute the information content of every annotated term.

        Args:
            annots: iterable of iterables of term ids, one inner iterable per
                annotated entity.

        For each term that occurs ``n`` times, the stored value is
        ``log2(min_parent_count / n)``, where ``min_parent_count`` is the
        smallest occurrence count among the term's direct parents (or ``n``
        itself for a term without parents).

        Raises:
            ValueError: from ``math.log`` when a parent never occurs in
                ``annots`` (its count is then 0).
        """
        cnt = Counter()
        for x in annots:
            cnt.update(x)
        self.ic = {}
        for go_id, n in cnt.items():
            parents = self.get_parents(go_id)
            if len(parents) == 0:
                min_n = n
            else:
                min_n = min(cnt[x] for x in parents)
            self.ic[go_id] = math.log(min_n / n, 2)

    def get_ic(self, go_id):
        """Return the information content of ``go_id`` (0.0 if unknown).

        Raises:
            Exception: when ``calculate_ic`` has not been called yet.
        """
        if self.ic is None:
            raise Exception('Not yet calculated')
        if go_id not in self.ic:
            return 0.0
        return self.ic[go_id]

    def load(self, filename, with_rels):
        """Parse an OBO file and return ``(ont, alts)``.

        Args:
            filename: path of the OBO file.
            with_rels: when true, ``relationship: part_of X`` lines are
                treated like ``is_a: X`` (X is appended to the ``is_a`` list).

        Returns:
            A pair ``(ont, alts)`` as described in the class docstring.
        """
        ont = dict()
        alts = dict()
        obj = None
        with open(filename, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                if line == '[Term]':
                    if obj is not None:
                        ont[obj['id']] = obj
                    obj = dict()
                    obj['is_a'] = list()
                    obj['part_of'] = list()
                    obj['regulates'] = list()
                    obj['alt_ids'] = list()
                    obj['is_obsolete'] = False
                    continue
                if line == '[Typedef]':
                    # Store the term that was being read before leaving the
                    # [Term] section; otherwise the last term preceding a
                    # [Typedef] stanza would be lost.
                    if obj is not None:
                        ont[obj['id']] = obj
                    obj = None
                    continue
                if obj is None:
                    # Header lines and the contents of [Typedef] stanzas.
                    continue
                # Split on the first ': ' only, so values that themselves
                # contain ': ' (e.g. term names) are kept whole.
                key, _, value = line.partition(': ')
                if key == 'id':
                    obj['id'] = value
                elif key == 'alt_id':
                    obj['alt_ids'].append(value)
                elif key == 'namespace':
                    obj['namespace'] = value
                elif key == 'is_a':
                    # Drop the trailing " ! human readable name" part.
                    obj['is_a'].append(value.split(' ! ')[0])
                elif with_rels and key == 'relationship':
                    it = value.split()
                    # Only part_of relationships are used; they are merged
                    # into the is_a list.
                    if len(it) >= 2 and it[0] == 'part_of':
                        obj['is_a'].append(it[1])
                elif key == 'name':
                    obj['name'] = value
                elif key == 'is_obsolete' and value == 'true':
                    obj['is_obsolete'] = True
        if obj is not None:
            ont[obj['id']] = obj

        # Iterate over a snapshot of the keys because ont is modified below.
        for term_id in list(ont.keys()):
            for t_id in ont[term_id]['alt_ids']:
                ont[t_id] = ont[term_id]
            if ont[term_id]['is_obsolete']:
                alts[term_id] = ont[term_id]
                del ont[term_id]

        # Build the reverse (parent -> children) links.
        for term_id, val in ont.items():
            if 'children' not in val:
                val['children'] = set()
            for p_id in val['is_a']:
                if p_id in ont:
                    if 'children' not in ont[p_id]:
                        ont[p_id]['children'] = set()
                    ont[p_id]['children'].add(term_id)
        return ont,alts

    def get_anchestors(self, term_id):
        """Return ``term_id`` together with every term reachable via ``is_a``.

        Returns an empty set when ``term_id`` is not in the ontology.
        """
        if term_id not in self.ont:
            return set()
        term_set = set()
        q = deque()
        q.append(term_id)
        while q:
            t_id = q.popleft()
            if t_id not in term_set:
                term_set.add(t_id)
                for parent_id in self.ont[t_id]['is_a']:
                    if parent_id in self.ont:
                        q.append(parent_id)
        return term_set

    def get_parents(self, term_id):
        """Return the set of direct ``is_a`` parents that exist in the ontology.

        Returns an empty set when ``term_id`` is not in the ontology.
        """
        if term_id not in self.ont:
            return set()
        return {
            parent_id
            for parent_id in self.ont[term_id]['is_a']
            if parent_id in self.ont
        }

    def get_namespace_terms(self, namespace):
        """Return the set of ontology keys whose record is in ``namespace``."""
        terms = set()
        for go_id, obj in self.ont.items():
            if obj['namespace'] == namespace:
                terms.add(go_id)
        return terms

    def get_namespace(self, term_id):
        """Return the namespace of ``term_id``; KeyError if it is unknown."""
        return self.ont[term_id]['namespace']

    def get_term_set(self, term_id):
        """Return ``term_id`` together with every term reachable via ``children``.

        Returns an empty set when ``term_id`` is not in the ontology.
        """
        if term_id not in self.ont:
            return set()
        term_set = set()
        q = deque()
        q.append(term_id)
        while q:
            t_id = q.popleft()
            if t_id not in term_set:
                term_set.add(t_id)
                for ch_id in self.ont[t_id]['children']:
                    q.append(ch_id)
        return term_set
    
def read_pkl(input_file):
    """Unpickle and return the object stored in ``input_file``.

    NOTE: unpickling can execute arbitrary code, so only use this on files
    from a trusted source.
    """
    with open(input_file, 'rb') as handle:
        return pkl.load(handle)

def save_pkl(output_file,data):
    """Pickle ``data`` into ``output_file``, overwriting any existing file."""
    with open(output_file, 'wb') as handle:
        pkl.dump(data, handle)
        
def get_label(anations,func_list):
    """Build a 0/1 indicator vector over ``func_list``.

    Position ``i`` of the returned NumPy array is 1 when ``func_list[i]`` is
    contained in ``anations`` and 0 otherwise.
    """
    return np.array([int(label in anations) for label in func_list])

# def getTag(go,annotation):
#     alt_items = {'GO:0030819':'biological_process','GO:0030818':'biological_process','GO:0100036':'biological_process',
#             'GO:1903474':'biological_process','GO:0008565':'molecular_function','GO:0000991':'molecular_function',
#             'GO:0001029':'molecular_function','GO:0005430':'molecular_function','GO:0005395':'molecular_function',
#             'GO:0001191':'molecular_function','GO:0005623':'cellular_component'}
#     if annotation in alt_items:
#         tags = tag_dict[alt_items[annotation]]
#     else:
#         tags = tag_dict[go.get_namespace(annotation)]
        
#     return tags

def getTag(go,annotation):
    """Return the ``tag_dict`` entry for the namespace of ``annotation``.

    ``go`` must provide ``get_namespace(term_id)``; a namespace that is not a
    key of ``tag_dict`` raises KeyError.
    """
    namespace = go.get_namespace(annotation)
    return tag_dict[namespace]
    
# Load the ontology once for the whole notebook; with_rels=True makes the
# loader follow part_of relationships in addition to is_a.
go = Ontology(
    '/home/wbshi/work/swissprot_data/train_test_data_handled_v4/handled_protein_messages/go.obo',
    with_rels=True,
)
print('done')

done


In [6]:
def read_fasta(filename):
    """Read a FASTA file.

    Args:
        filename: path of the FASTA file.

    Returns:
        A pair ``(info, seqs)`` of parallel lists: ``info[i]`` is the header
        line without its leading ``>`` and ``seqs[i]`` the concatenated
        sequence lines of that record. A header that is directly followed by
        another header (no sequence lines) is skipped. A file without any
        content yields two empty lists.
    """
    seqs = list()
    info = list()
    chunks = []          # sequence lines of the record being read
    inf = ''
    seen_content = False  # stays False for a file with no non-blank line
    with open(filename, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            seen_content = True
            if line.startswith('>'):
                if chunks:
                    seqs.append(''.join(chunks))
                    info.append(inf)
                    chunks = []
                inf = line[1:]
            else:
                chunks.append(line)
    # Flush the last record, but do not invent a record for an empty file.
    if seen_content:
        seqs.append(''.join(chunks))
        info.append(inf)
    return info, seqs

print('done')

done


In [10]:
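# The test set has already been searched against the train set with BLAST to find similar proteins; this cell extracts those hits.
# BLAST output formats 6, 7, 10 and 17 allow the reported columns to be customised. Output format 6 (--outfmt 6) is used here;
# format 6 is the tabular format, equivalent to the legacy BLAST m8 format.
# The default columns are: qseqid sseqid pident length mismatch gapopen qstart qend sstart send evalue bitscore.
# These 12 columns mean:

# Query id: identifier of the query sequence
# Subject id: identifier of the matched target sequence
# % identity: percentage of identical positions in the alignment
# alignment length: length of the aligned region
# mismatches: number of mismatches in the aligned region
# gap openings: number of gaps in the aligned region
# q. start: start position of the aligned region on the query sequence (Query id)
# q. end: end position of the aligned region on the query sequence (Query id)
# s. start: start position of the aligned region on the target sequence (Subject id)
# s. end: end position of the aligned region on the target sequence (Subject id)
# e-value: expectation value of the alignment; the query is randomly shuffled and re-aligned against the database, and the more conserved the function, the lower this value;
# a higher E-value suggests that a high alignment score is caused by GC-rich regions or repetitive sequence. These parameters are very useful for judging homology.
# bit score: bit score of the alignment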

def read_blast_result(input_file,output_file,
                      train_data_file='/home/wbshi/work/swissprot_data/train_test_data_handled_v4/train_data_separate.pkl'):
    """Group the rows of a whitespace-separated BLAST result by query and hit.

    Args:
        input_file: path of the BLAST result; the first column of each row is
            used as the query id and the second as the subject id.
        output_file: path the grouped result is pickled to (via ``save_pkl``).
        train_data_file: pickle (read via ``read_pkl``) whose keys are the
            accepted subject ids; rows whose subject id is not a key are
            dropped. Defaults to the path that was previously hard-coded.

    The pickled object is a dict ``{query_id: OrderedDict}`` in which each
    OrderedDict maps ``subject_id`` to a list holding the remaining columns
    (row[2:], as strings) of every row for that query/subject pair, in file
    order.
    """
    train_data = read_pkl(train_data_file)
    result_dict = {}

    with open(input_file, 'r') as fr:
        for raw_line in fr:
            fields = raw_line.split()
            # Blank lines and rows that lack a subject column carry no hit.
            if len(fields) < 2:
                continue
            query_id, subject_id = fields[0], fields[1]
            if subject_id not in train_data:
                continue

            # OrderedDict keeps the subjects in the order BLAST reported them.
            if query_id not in result_dict:
                result_dict[query_id] = OrderedDict()
            hits = result_dict[query_id]
            if subject_id not in hits:
                hits[subject_id] = []
            hits[subject_id].append(fields[2:])

    save_pkl(output_file, result_dict)

    
output_path = './'
if not os.path.exists(output_path):
    os.mkdir(output_path)

# Convert the BLAST result of each test set into its grouped pickle.
for blast_result, output_name, done_message in (
    (
        './test_one_all_sequence.blast',
        'test_one_all_sequence_result_score.pkl',
        'test_one_all_sequence done',
    ),
    (
        './test_two_all_sequence.blast',
        'test_two_all_sequence_result_score.pkl',
        'test_two_all_sequence done',
    ),
):
    output_file = output_path + output_name
    read_blast_result(blast_result,output_file)
    print(done_message)

test_one_all_sequence done
test_two_all_sequence done


In [5]:
# Extract the blast_knn score values
def comput_blast_knn(k,save_path, test_proteins):
    """Score GO terms for each test protein from its first ``k`` BLAST hits.

    Reads the module-level ``test_blast_data`` (query id -> ordered mapping of
    hit id -> list of rows), ``train_data`` (hit id -> dict with the sets
    ``all_bp``, ``all_cc`` and ``all_mf``), ``go`` and ``tag_dict``.

    For every query, the last field of the first row of each of its first
    ``k`` hits is parsed as a float (presumably the bit score; confirm against
    the BLAST column layout). Each term annotated to a hit receives that
    hit's value divided by the sum over the kept hits, accumulated under
    ``'sim_func'`` and under the term's namespace tag.

    Args:
        k: maximum number of hits used per query.
        save_path: format string; ``save_path.format(k)`` is the output file.
        test_proteins: mapping whose keys are all test protein ids; ids with
            no BLAST entry get empty score dicts.
    """
    categories = ('sim_func', 'bp', 'cc', 'mf')
    test_scores = {}

    for protein_id, hits in test_blast_data.items():
        test_scores[protein_id] = {}

        # Collect (hit id, score) for at most the first k hits.
        neighbours = []
        for rank, (hit_id, alignments) in enumerate(hits.items()):
            if rank >= k:
                break
            neighbours.append((hit_id, float(alignments[0][-1])))

        # Normalisation constant: total score of the kept hits.
        all_count = 0
        for _, hit_score in neighbours:
            all_count += hit_score

        for category in categories:
            test_scores[protein_id][category] = defaultdict(float)
        protein_scores = test_scores[protein_id]

        for hit_id, hit_score in neighbours:
            annotations = train_data[hit_id]
            for func in list(annotations['all_bp'] | annotations['all_cc'] | annotations['all_mf']):
                weight = hit_score / all_count
                protein_scores['sim_func'][func] += weight
                inner_tag = tag_dict[go.get_namespace(func)]
                protein_scores[inner_tag][func] += weight

    # Proteins without any BLAST entry still get a (empty) result.
    for p in test_proteins.keys():
        if p in test_scores:
            continue
        test_scores[p] = {category: {} for category in categories}

    save_pkl(save_path.format(k), test_scores)

if not os.path.exists("./predictions"):
    os.mkdir("./predictions")

# One entry per test set: grouped BLAST pickle, output pattern, test protein
# pickle, the k values to evaluate and the message printed when finished.
for input_file, save_path, test_proteins_file, k_values, done_message in (
    (
        './test_one_all_sequence_result_score.pkl',
        './predictions/test_one_all_sequence_blast_knn_{0}_predict_score.pkl',
        '/home/wbshi/work/swissprot_data/train_test_data_handled_v4/test_one_data_separate.pkl',
        [10,20,30,40,50,50000],
        'test_one_all_sequence done',
    ),
    (
        './test_two_all_sequence_result_score.pkl',
        './predictions/test_two_all_sequence_blast_knn_{0}_predict_score.pkl',
        '/home/wbshi/work/swissprot_data/train_test_data_handled_v4/test_two_data_separate.pkl',
        [3,5,10,20,30,40,50,50000],
        'test_two_all_sequence done',
    ),
):
    # comput_blast_knn reads test_blast_data and train_data as globals.
    test_blast_data = read_pkl(input_file)
    train_data = read_pkl('/home/wbshi/work/swissprot_data/train_test_data_handled_v4/train_data_separate.pkl')
    test_proteins = read_pkl(test_proteins_file)

    for k in k_values:
        print(k)
        comput_blast_knn(k,save_path, test_proteins)

    print(done_message)

2
test_one_all_sequence done
2
test_two_all_sequence done
