In [1]:
import pickle
import numpy as np
import pickle
import time

In [7]:
def check_left(genes, current_idx, threshold):
    """Collect the current gene plus its contiguous neighbours on the left.

    The scan walks leftwards one gene at a time and stops at the first gene
    whose txend lies more than ``threshold`` before the current gene's txstart.

    Args:
        genes (list): Per-gene ``[txstart, txend]`` pairs.
        current_idx (int): Position of the current gene inside ``genes``.
        threshold (int): Largest accepted gap (in base pairs) between a
            neighbour's txend and the current gene's txstart.

    Returns:
        list: ``current_idx`` followed by the accepted neighbour indices,
              in descending order.
    """
    neighbours = [current_idx]
    probe = current_idx - 1
    # Stop as soon as we run off the left edge or hit a gene that is too far.
    while probe >= 0 and genes[current_idx][0] - genes[probe][1] <= threshold:
        neighbours.append(probe)
        probe -= 1
    return neighbours

def check_right(genes, current_idx, threshold):
    """Collect the current gene plus its contiguous neighbours on the right.

    The scan walks rightwards one gene at a time and stops at the first gene
    whose txstart lies more than ``threshold`` after the current gene's txend.

    Args:
        genes (list): Per-gene ``[txstart, txend]`` pairs.
        current_idx (int): Position of the current gene inside ``genes``.
        threshold (int): Largest accepted gap (in base pairs) between the
            current gene's txend and a neighbour's txstart.

    Returns:
        list: ``current_idx`` followed by the accepted neighbour indices,
              in ascending order.
    """
    neighbours = [current_idx]
    gene_count = len(genes)
    probe = current_idx + 1
    # Stop as soon as we run off the right edge or hit a gene that is too far.
    while probe < gene_count and genes[probe][0] - genes[current_idx][1] <= threshold:
        neighbours.append(probe)
        probe += 1
    return neighbours

def get_sets(genes, threshold):
    """Build the vicinity set of every gene for a given threshold.

    For each gene, the indices returned by ``check_left`` and ``check_right``
    are merged, de-duplicated and sorted in ascending order.

    Args:
        genes (list): Per-gene ``[txstart, txend]`` pairs.
        threshold (int): The maximum distance (in base pairs) that defines the vicinity.

    Returns:
        list: One sorted list of gene indices per gene, in gene order.
    """
    return [
        sorted({
            *check_left(genes, gene_idx, threshold),
            *check_right(genes, gene_idx, threshold),
        })
        for gene_idx in range(len(genes))
    ]

def avg_dst2(get_to, genes, gene_set):
    """Keep the candidate paths whose genes lie closest to their reference genes.

    Each candidate path is a list of reference-gene indices. A reference gene
    ``r`` contributes its vicinity ``gene_set[r]``: every other gene in that
    vicinity is assigned the gap between itself and ``r`` (clamped to 0 when
    the two overlap). Genes listed before ``r`` in the vicinity are treated as
    lying to its left, the rest as lying to its right. A gene reachable from
    several reference genes on the same path keeps its smallest gap. The score
    of a path is the mean of those gaps, or 0 when every vicinity on the path
    holds only its reference gene.

    Args:
        get_to (list): Candidate paths, each a list of indices into ``gene_set``.
        genes (list): A list of list containing each gene's txstart and txend information.
        gene_set (list): Vicinity of each gene as a list of gene indices that
            includes the gene itself.

    Returns:
        list: The paths of ``get_to`` (same objects, original order) whose
              score equals the minimum score; empty when ``get_to`` is empty.

    Raises:
        ValueError: If a reference gene is missing from its own vicinity.
    """
    scores = []
    for path in get_to:
        # gene index -> smallest gap to any reference gene on this path
        nearest = {}
        for ref in path:
            vicinity = gene_set[ref]
            ref_pos = vicinity.index(ref)
            for pos, member in enumerate(vicinity):
                if pos == ref_pos:
                    continue
                if pos < ref_pos:
                    # member sits to the left: reference txstart - member txend
                    gap = genes[ref][0] - genes[member][1]
                else:
                    # member sits to the right: member txstart - reference txend
                    gap = genes[member][0] - genes[ref][1]
                # overlapping or nested genes give a negative gap; count it as 0
                gap = max(0, gap)
                if member not in nearest or gap < nearest[member]:
                    nearest[member] = gap
        scores.append(sum(nearest.values()) / len(nearest) if nearest else 0)

    # min() raises on an empty sequence, so handle "no candidates" explicitly.
    if not scores:
        return []
    best = min(scores)
    return [path for path, score in zip(get_to, scores) if score == best]

def solve_gene(gene_set, genes):
    """Given all the gene sets determined by the threshold, solve the vicinity set cover problem.

    Genes are processed from index 0 upwards. For the most recently processed
    gene, ``states`` records the candidate paths (lists of reference-gene
    indices) that survive pruning; each new gene derives its candidates from
    those of the previous gene. Candidates are pruned to the shortest ones and
    ties are broken with ``avg_dst2``.

    Side effects:
        Prints a progress line and the surviving candidate paths for every
        gene after the first.

    Args:
        gene_set (list): gene sets using a specified threshold. ``gene_set[k]``
            is the vicinity of gene ``k``; the code reads its first and last
            elements as the leftmost and rightmost gene, so it is assumed to be
            sorted ascending (as ``get_sets`` produces it).
        genes (list): A list of list containing each gene's txstart and txend
            information. Only forwarded to ``avg_dst2``.

    Returns:
        list: A list of gene sets satisfying the vicinity set cover problem.
              Each entry is a list of reference-gene indices.

    Raises:
        IndexError: If ``gene_set`` is empty (``gene_set[0]`` is read unconditionally).
    """
    # The start state describes gene 0:
    # states[0]: index of the most recently processed gene (starts at gene 0)
    # states[1]: candidate paths that get to that gene; for gene 0 the only path is [0]
    # states[2]: rightmost gene index in that gene's set
    # states[3]: number of solves needed so far; gene 0 itself needs one
    states = [0,[[0]],max(gene_set[0]),1]
    # gene 0 is already handled by the start state, so begin at gene 1
    for i in range(1, len(gene_set)):
        print(f'processing gene #{i}')
        
        # NOTE: despite its name, a True flag means the solve count is NOT
        # incremented for this gene (see the branch on the flag further down).
        increase_solve = False
        # candidate paths that can get to the current gene
        get_to = []
        # derive the candidates from every path that got to the previous gene
        # NOTE(review): current_solves is only bound inside this loop, so the
        # code relies on states[1] never being empty.
        for j in states[1]:
            # copy so the stored path of the previous state is not mutated
            prev_path = j.copy()
            # extend the path with the previous gene, making it a path that
            # gets to the current gene
            prev_path.append(states[0])
            # the previous gene may already be on the path, hence the set();
            # sorted() restores ascending index order
            get_to.append(sorted(list(set(prev_path))))
            # the unchanged path is also a candidate when the set of its last
            # reference gene ends inside the current gene's set, or ends
            # directly before the current set begins
            if (gene_set[i][0] <= gene_set[j[-1]][-1] <= gene_set[i][-1]) or (gene_set[i][0] == gene_set[j[-1]][-1] + 1): 
                get_to.append(j.copy())
            # if the set of the path's last reference gene is a subset of the
            # current set AND the current set starts at the same gene as the
            # very first set, the current gene alone is a candidate path
            if set(gene_set[j[-1]]).issubset(gene_set[i]) and gene_set[i][0] == gene_set[0][0]:
                get_to.append([i])
                # parallel solution: keep the previous solve count
                # (the flag stays True for the remaining paths of this gene)
                increase_solve = True

            # NOTE(review): the pruning steps below sit inside the `for j`
            # loop, so candidates are pruned after every predecessor path
            # rather than once per gene - confirm this is intended.

            # eliminate duplicates
            unique_set = {tuple(element) for element in get_to}
            get_to = sorted([list(element) for element in unique_set])

            # only keep the shortest paths, this is the objective of the DP.
            # The `i` inside these comprehensions is local to them and does
            # not clobber the outer loop variable.
            min_length = min(len(i) for i in get_to)
            get_to = [i for i in get_to if len(i) == min_length]

            # if competing paths remain, break ties by minimum average distance
            get_to = avg_dst2(get_to, genes, gene_set)

            # parallel solution found: the number of solves equals the previous number
            if increase_solve:
                current_solves = states[-1]
            # otherwise: length of the shortest surviving path plus one
            else:
                current_solves = min([len(i) for i in get_to])+1

        print(f'the optimal path set preceding is: {get_to}')

        # update the state 
        states = [i, get_to, max(gene_set[i]), current_solves]
        # redundant with the reset at the top of the loop, kept as-is
        increase_solve = False
        
    solution = []
    # Finalise every surviving path. The lists in states[1] are extended in place.
    for i in states[1]:
        # the set of the path's last reference gene does not end at the
        # rightmost gene of the final state: add the final gene as a reference
        if gene_set[i[-1]][-1] != states[2]:
            i.append(states[0])
        
        # otherwise append the last element of the last gene set
        # NOTE(review): both branches append; for sets built by get_sets the
        # two appended values appear to be the same index - confirm intent.
        else:
            i.append(gene_set[-1][-1])

        # keep only paths whose length matches the recorded number of solves
        if len(i) == states[-1]:
            solution.append(i)

    return solution

def main(threshold, chromosome, data_path='protein2_full.pkl'):
    """Solve the vicinity set cover problem for the genes of one chromosome.

    Loads the gene records, keeps those whose first field equals
    ``chromosome``, builds every gene's vicinity set with ``get_sets`` and
    hands the result to ``solve_gene``.

    Args:
        threshold (int): The maximum distance (in base pairs) that defines the vicinity.
        chromosome (str): Chromosome label to select, compared against field 0
            of each record (e.g. ``'chr1'``).
        data_path (str): Pickle file holding the gene records; fields 2 and 3
            of each record are used as txstart and txend. Defaults to the
            path that was previously hard-coded.

    Returns:
        tuple: ``(reference_gene, vicinities_of_reference_genes)`` where
            ``reference_gene`` is the list of solutions returned by
            ``solve_gene`` and ``vicinities_of_reference_genes`` holds the
            gene set of every reference gene in the first solution (empty
            when no solution was returned).

    Raises:
        ValueError: If no record belongs to ``chromosome``.
    """
    # SECURITY: pickle.load can execute arbitrary code while deserialising;
    # only point data_path at files from a trusted source.
    with open(data_path, 'rb') as f:
        protein_dict = pickle.load(f)

    # select the genes on the requested chromosome (dict order is preserved)
    records = [rec for rec in protein_dict.values() if rec[0] == chromosome]
    if not records:
        # Fail early with a clear message instead of an IndexError deep inside
        # solve_gene.
        raise ValueError(
            f'no genes found for chromosome {chromosome!r} in {data_path!r}'
        )

    # extract each gene's txstart and txend
    # NOTE(review): the neighbour scans stop at the first gene that is too far
    # away, so records are presumably stored in genomic order - confirm.
    genes_tx_info = np.array([[rec[2], rec[3]] for rec in records])

    # get the gene sets using the threshold
    gene_sets = get_sets(genes_tx_info, threshold)

    # solve the vicinity set cover problem
    reference_gene = solve_gene(gene_sets, genes_tx_info)

    # expand the first solution into the gene sets of its reference genes
    if reference_gene:
        vicinities_of_reference_genes = [gene_sets[i] for i in reference_gene[0]]
    else:
        vicinities_of_reference_genes = []

    return reference_gene, vicinities_of_reference_genes

In [None]:
if __name__ == '__main__':

    # Example run: genes on chromosome 1 ('chr1') with a vicinity threshold
    # of 20000 base pairs.
    threshold, chromosome = 20000, 'chr1'

    # solve the vicinity set cover problem for that configuration
    reference_genes, vicinities_of_reference_genes = main(
        threshold=threshold,
        chromosome=chromosome,
    )

processing gene #1
the optimal path set preceding is: [[0]]
processing gene #2
the optimal path set preceding is: [[0, 1]]
processing gene #3
the optimal path set preceding is: [[0, 1, 2]]
processing gene #4
the optimal path set preceding is: [[0, 1, 2]]
processing gene #5
the optimal path set preceding is: [[0, 1, 2]]
processing gene #6
the optimal path set preceding is: [[0, 1, 2, 5]]
processing gene #7
the optimal path set preceding is: [[0, 1, 2, 5]]
processing gene #8
the optimal path set preceding is: [[0, 1, 2, 5]]
processing gene #9
the optimal path set preceding is: [[0, 1, 2, 5]]
processing gene #10
the optimal path set preceding is: [[0, 1, 2, 5, 9]]
processing gene #11
the optimal path set preceding is: [[0, 1, 2, 5, 9]]
processing gene #12
the optimal path set preceding is: [[0, 1, 2, 5, 9]]
processing gene #13
the optimal path set preceding is: [[0, 1, 2, 5, 9, 12]]
processing gene #14
the optimal path set preceding is: [[0, 1, 2, 5, 9, 12]]
processing gene #15
the optima

In [4]:
# Number of reference genes (one gene set each) in the first returned solution.
len(reference_genes[0])

1018

In [5]:
# These are the gene indices in each gene set of the first returned solution.
vicinities_of_reference_genes

[[0],
 [1],
 [2],
 [3, 4, 5, 6, 7],
 [8, 9, 10],
 [11, 12],
 [13, 14, 15, 16, 17, 18],
 [19, 20, 21, 22, 23, 24],
 [25, 26, 27, 28],
 [29, 30, 31, 32, 33],
 [34, 35, 36],
 [36, 37, 38],
 [39, 40, 41, 42],
 [43, 44],
 [45, 46],
 [47, 48, 49, 50],
 [51, 52],
 [53, 54, 55, 56],
 [57, 58, 59],
 [60],
 [61, 62],
 [63],
 [64, 65, 66],
 [66, 67, 68, 69],
 [70, 71, 72, 73],
 [74, 75],
 [76],
 [77],
 [78, 79, 80],
 [81, 82, 83, 84],
 [83, 84, 85],
 [86, 87, 88, 89],
 [90, 91, 92, 93],
 [94, 95, 96],
 [97, 98, 99, 100],
 [101],
 [102],
 [103],
 [104, 105],
 [106],
 [107],
 [108, 109],
 [110],
 [111],
 [112],
 [113, 114],
 [115, 116],
 [117, 118, 119],
 [120, 121],
 [122, 123, 124],
 [124, 125, 126, 127],
 [127, 128],
 [129],
 [130, 131, 132, 133],
 [134, 135],
 [136],
 [137],
 [138, 139, 140, 141, 142],
 [143, 144, 145, 146],
 [146, 147, 148],
 [149, 150, 151],
 [151, 152],
 [153],
 [154],
 [155],
 [156],
 [157],
 [158],
 [159, 160],
 [161, 162, 163],
 [164, 165, 166],
 [166, 167],
 [168, 169, 1