In [1]:
import pickle
import numpy as np
import pickle
import matplotlib.pyplot as plt
import time
from scipy.stats import pearsonr, wilcoxon
from sklearn.preprocessing import MinMaxScaler

In [3]:
# with open ("protein2.pkl", "rb") as f:
#     protein2 = pickle.load(f)

with open('protein2_full.pkl', 'rb') as f:
    protein2 = pickle.load(f)

protein2 = {k:v for k,v in protein2.items() if v[0]=='chr1'}

In [4]:
len(protein2)

1880

In [5]:
# [2] is txstart, [3] is txend
genes = np.array([[protein2[i][2],protein2[i][3]] for i in protein2])

#genes = np.array([[protein2[i][2],protein2[i][3]] for i in protein2 if protein2[i][5]])

In [6]:
genes.shape

(1880, 2)

In [7]:
def check_left(genes, current_idx, threshold):
    left_neighbor_idx = [current_idx]
    for i in range(current_idx-1, -1, -1):
        if genes[current_idx][0] - genes[i][1] <= threshold:
            left_neighbor_idx.append(i)
        else:
            break
    return left_neighbor_idx

def check_right(genes, current_idx, threshold):
    right_neighbor_idx = [current_idx]
    for i in range(current_idx+1, len(genes)):
        if genes[i][0] - genes[current_idx][1] <= threshold:
            right_neighbor_idx.append(i)
        else:
            break
    return right_neighbor_idx

def get_sets(genes, threshold):
    sets = []
    for i in range(len(genes)):
        left_set = check_left(genes, i, threshold)
        right_set = check_right(genes, i, threshold)
        current_set = sorted(list(set(left_set + right_set)))
        sets.append(current_set)

    return sets
 

In [8]:
s = get_sets(genes, 40000)

In [9]:
# LOAD the gene info to find closest neighbor, strand, info
with open ("filtered_gene_dataframe.pkl", "rb") as f:
    gene_df = pickle.load(f)

In [12]:
def avg_dst2(get_to, genes):
    dcts = []
    # get_to contains the gene set idx
    for path_idx in get_to:
        #print( f'-->path (set index) {path_idx}')
        covered = set()
        dst_dict = {}
        # path contains the actual idx of the gene
        path = [s[p] for p in path_idx]
        #print(f'    -->expanded path {path}')
        for idx, p in enumerate(path):
            #print(f'        current gene set {p}, center gene is {path_idx[idx]}')
            # get center element
            center_idx = p.index(path_idx[idx])
            # if current set contains overlapping elements
            if not set(p) & set(covered):
                for j, gene_idx in enumerate(p):
                    if j != center_idx:
                        dst = (
                            genes[p[center_idx]][0] - genes[gene_idx][1]
                            if j < center_idx
                            else genes[gene_idx][0] - genes[p[center_idx]][1]
                        )
                        # overlap or inclusive
                        dst_dict[gene_idx] = max(0, dst)
                covered.update(p)
            else:
                # if current set contains overlapping elements
                for j, gene_idx in enumerate(p):
                    if j != center_idx:
                        dst = (
                            genes[p[center_idx]][0] - genes[gene_idx][1]
                            if j < center_idx
                            else genes[gene_idx][0] - genes[p[center_idx]][1]
                        )
                        dst = max(0, dst)
                        if gene_idx not in dst_dict or dst < dst_dict[gene_idx]:
                            dst_dict[gene_idx] = dst
                covered.update(p)
        # calculate avg distance
        if dst_dict:
            dsts = sum(list(dst_dict.values()))
            avgs = dsts / len(dst_dict)
            dcts.append(avgs)
        # if single element
        else:
            dcts.append(0)
    print(dcts)
    # get all min elements
    min_idx = [i for i,x in enumerate(dcts) if x == min(dcts)]
    res = [get_to[i] for i in min_idx]
    return res


# def cal_avg_dst(get_to, genes):
#     avg_dst = []
#     for i in get_to:
#         path = [s[p] for p in i]
#         dst = 0
#         count = 0
#         for j in path:
#             if len(j) <= 1:
#                 continue
#             else:
#                 center_idx = len(j)//2
#                 for k in range(len(j)):
#                     if k < center_idx:
#                         dst += abs(genes[j[center_idx]][0] - genes[j[k]][1])
#                         count += 1
#                     elif k > center_idx:
#                         dst += abs(genes[j[k]][0] - genes[j[center_idx]][1])
#                         count += 1
#         if dst == 0:
#             avg_dst.append(0)
#         else:
#             avg_dst.append(dst/count)

#     min_idx = [i for i,x in enumerate(avg_dst) if x == min(avg_dst)]
#     res = [get_to[i] for i in min_idx]
#     return res
    

def check_overlap(get_to, gene_set):
    #find which path contains the least amount of genes
    overlap_counter = [[], float('inf')]
    no_overlap_get_to = []
    for i in range(len(get_to)):
        lst = []
        for j in get_to[i]:
            lst += gene_set[j]
        # no overlap
        #print(lst)
        if len(lst) == len(set(lst)):
            no_overlap_get_to.append(get_to[i])
        else:
            if abs(len(lst) - len(set(lst))) < overlap_counter[1]:
                overlap_counter[0] += (get_to[i])
                overlap_counter[1] = abs(len(lst) - len(set(lst)))
    if len(no_overlap_get_to) != 0:
        return no_overlap_get_to
    else:
        #print(overlap_counter)
        return [overlap_counter[0]]


def solve_gene(gene_set, genes):
    #start_time = time.time()
    # the start state is the first gene
    # 1st idx: the current gene idx (starting at 0-th gene)
    # 2nd idx: path that gets to the current gene. since we will begin at the first gene (for i in range(1,...)), the path that get to the first gene is the 0-th gene.
    # 3rd idx: max() indicates the rightmost gene in the current gene's neighborhood
    # 4th idx: the number of solves needed. Since we start at 0-th gene, we need to solve it.
    states = [0,[[0]],max(gene_set[0]),1]
    #print(states)
    prev_set = []
    # since 0-th gene is already solved, we begin at the 1st gene.
    start_time = time.time()
    for i in range(1, len(gene_set)):
        print(f' i = {i}')
        
        #print(gene_set[i])
        #if gene_set[i] == prev_set:
            #continue
        
        # flag indicate when we need to increment the amount of solves
        increase_solve = False
        # what genes (path) that can get to the current gene
        get_to = []
        # check what genes can get to the previous gene and use that to start, DP
        for j in states[1]:
            #print(f'j = {j}')
            # keeps a record of the path that gets to the previous gene. when starting, 0-th gene can get to 0-th gene.
            prev_path = j.copy()
            # since the path that gets to the previous gene is already recorded, when moving on to the next gene,
            # we need to add the previous gene to the previous path, so it becomes the path that gets to the current_gene.
            prev_path.append(states[0])
            #print(f'prev_path after appending previous gene {prev_path}')
            # these are the new paths that can get to the current gene
            # since there might be repetitive paths, use set.
            get_to.append(sorted(list(set(prev_path))))
            # check overlap sets or direct next set
            if (gene_set[i][0] <= gene_set[j[-1]][-1] <= gene_set[i][-1]) or (gene_set[i][0] == gene_set[j[-1]][-1] + 1): 
                get_to.append(j.copy())
                #print(f'get to after appending duplicate {get_to}')
            # if previous gene's set is a subset of the current, then previous and current are parallel solution
            if set(gene_set[j[-1]]).issubset(gene_set[i]) and gene_set[i][0] == gene_set[0][0]:
                # since previous can get to current, parallel solution added
                get_to.append([i])
                # parallel solution, no need to increase the number of solves
                increase_solve = True
                #print(f'get to after appending parallel {get_to}')

            # eliminate duplicates
            unique_set = {tuple(element) for element in get_to}
            get_to = sorted([list(element) for element in unique_set])

            # only keep the least amount of solves, this is the objective of DP
            min_length = min(len(i) for i in get_to)
            get_to = [i for i in get_to if len(i) == min_length]

            # if competing solution exist, use min_avg dst to break ties
            #reduced_get_to = check_overlap(get_to, gene_set)
            #reduced_get_to = cal_avg_dst(get_to, genes)
            get_to = avg_dst2(get_to, genes)
            #print(f'reduced get to {reduced_get_to}')

            # if parallel solution found, the number of solves is equal to previous number of solves
            if increase_solve:
                current_solves = states[-1]
            # if no parallel solution, add 1 to previous number of solves
            else:
                current_solves = min([len(i) for i in get_to])+1

        # track prev set to determine if completely identical set
        prev_set = gene_set[i]
        
        print(len(get_to))
        print(f'get to is {get_to}')
        #print(f'took {end_time-start_time}s')

        # update the state 
        states = [i, get_to, max(gene_set[i]), current_solves]
        # reset the parallel solution counter
        increase_solve = False
        #print(f'current state {states}')
        
    #print(states)

    solution = []
    #final_states = states[1].copy()
    for i in states[1]:
        # if solution covers to the end
        if gene_set[i[-1]][-1] != states[2]:
            i.append(states[0])
        
        else:
            i.append(gene_set[-1][-1])

        # check if optimal solution
        if len(i) == states[-1]:
            solution.append(i)

    end_time = time.time()
    print(end_time-start_time)
    
    return solution
    #print(f'final solution {states}')


In [13]:
res = solve_gene(s, genes)

 i = 1
[0]
1
get to is [[0]]
 i = 2
[0]
1
get to is [[0, 1]]
 i = 3
[0]
1
get to is [[0, 1, 2]]
 i = 4
[0]
1
get to is [[0, 1, 2]]
 i = 5
[0]
1
get to is [[0, 1, 2]]
 i = 6
[0]
1
get to is [[0, 1, 2]]
 i = 7
[0]
1
get to is [[0, 1, 2]]
 i = 8
[20334.428571428572]
1
get to is [[0, 1, 2, 7]]
 i = 9
[20334.428571428572]
1
get to is [[0, 1, 2, 7]]
 i = 10
[20334.428571428572]
1
get to is [[0, 1, 2, 7]]
 i = 11
[20334.428571428572]
1
get to is [[0, 1, 2, 7]]
 i = 12
[20334.428571428572]
1
get to is [[0, 1, 2, 7]]
 i = 13
[17192.125]
1
get to is [[0, 1, 2, 7, 12]]
 i = 14
[17192.125]
1
get to is [[0, 1, 2, 7, 12]]
 i = 15
[17192.125]
1
get to is [[0, 1, 2, 7, 12]]
 i = 16
[17192.125]
1
get to is [[0, 1, 2, 7, 12]]
 i = 17
[17192.125]
1
get to is [[0, 1, 2, 7, 12]]
 i = 18
[17291.428571428572]
1
get to is [[0, 1, 2, 7, 12, 17]]
 i = 19
[17291.428571428572]
1
get to is [[0, 1, 2, 7, 12, 17]]
 i = 20
[17291.428571428572]
1
get to is [[0, 1, 2, 7, 12, 17]]
 i = 21
[17291.428571428572]
1
get to i