## MIP Program for choosing preferred path for ancestor nodes.
1. Program - Gurobi Solver
2. Date - 31 Aug 2022

In [1]:
# libraries
import gurobipy as gp
from gurobipy import abs_,quicksum
from gurobipy import GRB
import time
import json
from collections import defaultdict
from ete3 import Tree

In [19]:
''' Convert the json pogs file to the data structure needed for MIP Model'''
class Pogs:
    def __init__(self,pogs_file,tree_file):
        
        self.json_pogs_file = pogs_file
        tree_file = open(nwk_file_path,"r")
        my_tree = tree_file.read() + ";"
        self.tree = Tree(my_tree, format=1)
          
    def create_neighbor_object(self):
        ''' create neighbor dict '''
        tree_neighbor_dict = defaultdict(list)
        
        for n in self.tree.traverse():
            if n.is_leaf() == False:    
                for c in n.children:
                    tree_neighbor_dict[n.name] += [c.name]
        return tree_neighbor_dict
    
    def create_node_info_dict(self):
        ''' create edge dictionary and other node information'''
        
        '''Example output-- {\
                  1:{0:[1,2],1:[2,3,4],2:[3,4],3:[4],4:[5]} ,\
                  2:{0:[1],1:[2],2:[3],3:[4],4:[5]} \
                  } '''
        
        
        # read the json file
        with open(self.json_pogs_file, 'r') as j:
            pog_all_data = json.loads(j.read())

            node_path_dict = {}
            node_path_reverse_dict = {}
            extant_list = []
            
            # read all ancestors
            for node_type in ['Ancestors','Extants']:
                for pog_data in pog_all_data[node_type]:
                    if node_type == 'Ancestors':
                        node_name = 'N' + pog_data['Name']
                    else:
                        node_name = pog_data['Name']
                        extant_list.append(node_name)  # all extants
                    
                    node_edges_info_dict = defaultdict(list)
                    node_edges_reverse_info_dict = defaultdict(list)
                    
                    # read that node's data
                    nodes = pog_data['Size'] + 2
                    
                    # Edges from special Start node to the start nodes
                    for s in pog_data['Starts']:
                        node_edges_info_dict[0] += [s+1] 
                        node_edges_reverse_info_dict[s+1] += [0] 
            
                    # Edges from last node to the special End node
                    for e in pog_data['Ends']:
                        node_edges_info_dict[e + 1] += [nodes-1]
                        node_edges_reverse_info_dict[nodes-1] += [e + 1]

                    # create the adjency matrix for all nodes except from special node start
                    for ind,node in enumerate(pog_data['Indices']):
                        row_mat = node
                        row_col = pog_data['Adjacent'][ind]

                        for rc in row_col:
                            node_edges_info_dict[row_mat + 1] += [rc + 1]
                            node_edges_reverse_info_dict[rc + 1] += [row_mat + 1]
                            
                    # put all info together in the final dict
                    node_path_dict[node_name] = node_edges_info_dict
                    node_path_reverse_dict[node_name] = node_edges_reverse_info_dict
                    
        # fill in missing positions ( easier to code in MIP)
        node_edges_info_dict_final = node_path_dict.copy()
        
        for node,all_edges in node_path_dict.items():
            pos_present = []
            for pos_from, pos_edges in all_edges.items():
                pos_present.append(pos_from)
            all_pos = [r for r in range(0,nodes)]
            pos_absent = list(set(all_pos) - set(pos_present))
            for pa in pos_absent:
                node_edges_info_dict_final[node][pa] = []
        return node_edges_info_dict_final,node_path_reverse_dict,nodes,extant_list

    
## testing
# nwk_file_path = '/Users/sanjanatule/Documents/uq/Projects/PreferredPath/data/sample_tree/grasp_ancestors.nwk'
# pogs_file     = '/Users/sanjanatule/Documents/uq/Projects/PreferredPath/data/sample_tree/pogs.json'
# p = Pogs(pogs_file,nwk_file_path)
# p.create_node_info_dict()
# p.create_neighbor_object()

In [34]:
''' MIP Model for preferred path with least parsimonous score '''
class PhyloTree:
    def __init__(self,nodes,sequence_length,neighbor_dict,node_from_edge_dict,folder_location\
                 ,extant_list,tree_name):

        # 1 - define the configuration  and decision variables for the tree
        self.nodes = nodes
        self.sequence_length = sequence_length
        self.neighbor_dict = neighbor_dict
        self.node_from_edge_dict = node_from_edge_dict
        self.folder_location = folder_location
        self.extant_list = extant_list
        self.tree_name = tree_name
        self.edges = {}
        self.positions = {}
        self.penalty = {}
        self.diff = {}
        self.objective = []
        self.M = 999

        # 2 - create a new model
        self.m = gp.Model("PreferredPathSolve")

        # 3 - create variables
        # 3.1 - create variables - positions and edges of the POG
        for node,node_edge_from_list in self.node_from_edge_dict.items():
            for pos_from in range(0,sequence_length):
                # positions
                pos_id = (node,pos_from)
                pos = self.m.addVar(vtype=GRB.BINARY, name="n-%s-%s"%pos_id)
                self.positions[pos_id] = pos
                
                # edges
                for pos_to in range(pos_from + 1, sequence_length):
                    edge_id = (node,pos_from,pos_to)
                    e = self.m.addVar(vtype=GRB.BINARY, name='e-%s-%s-%s'%edge_id)
                    self.edges[edge_id] = e
                    
        # 3.2 - create variables - add penalty variable
        for node,node_neighbor in self.neighbor_dict.items():
            for node_neighbor_item in node_neighbor:
                for pos in range(1,sequence_length - 1): # penalty start from 1st position only
                    
                    pen_id = (node,node_neighbor_item,pos)
                    pen = self.m.addVar(vtype=GRB.BINARY, name='p-%s-%s-%s'%pen_id)
                    self.penalty[pen_id] = pen
                
                    # 3.3 - add position difference to the objective (not for start and end node)
                    node_pos_var = self.positions[(node,pos)]
                    node_neighbor_pos_var = self.positions[(node_neighbor_item,pos)]
                    diff_id = (node,node_neighbor_item,pos)
                    diff_pos = self.m.addVar(vtype=GRB.BINARY, name='d-%s-%s-%s'%diff_id)
                    self.diff[diff_id] = diff_pos
                    
                    # abs constraint
                    self.m.addConstr( diff_pos <= node_pos_var + node_neighbor_pos_var,name=\
                                     "pos_diff_constraint_1-%s-%s-%s"%(node,node_neighbor_item,pos))
                    self.m.addConstr( diff_pos >= node_pos_var - node_neighbor_pos_var,name=\
                                     "pos_diff_constraint_2-%s-%s-%s"%(node,node_neighbor_item,pos))
                    self.m.addConstr( diff_pos >= node_neighbor_pos_var - node_pos_var,name=\
                                     "pos_diff_constraint_3-%s-%s-%s"%(node,node_neighbor_item,pos))
                    self.m.addConstr( diff_pos <= 2 - node_neighbor_pos_var - node_pos_var,name=\
                                     "pos_diff_constraint_4-%s-%s-%s"%(node,node_neighbor_item,pos))
                    self.objective.append(diff_pos)
                    
        
        
    def add_constraints_extants(self):
        ''' function to add constraints for extants '''
        # 4 - constraints for extants
        for e in self.extant_list:
            position_present = [] # track the position having amino acid
            extant_node_edges_forward = self.node_from_edge_dict[e]
            for pos_from,node_edge_from_list_item in extant_node_edges_forward.items():             
                pos_id = (e,pos_from)
                if node_edge_from_list_item:
                    self.m.addConstr(self.positions[pos_id] == 1,\
                                 name="other_position_constraint-%s-%s"%pos_id)
                else:
                    self.m.addConstr(self.positions[pos_id] == 0,\
                                 name="other_position_constraint-%s-%s"%pos_id)
            
                        
    def add_constraints_ancestors(self):
        ''' function to add constraints for ancestor node '''
        for node,node_edge_from_list in self.node_from_edge_dict.items(): # nodes information all
            if node not in self.extant_list:
                for from_pos,edge_from_pos in node_edge_from_list.items():  # position - edges

                    pos_id = (node,from_pos)
                    ######## edges ##########
                    possible_edges = [ r for r in range(from_pos + 1,sequence_length)]
                    present_edges  = edge_from_pos
                    edges_not_present_pos = list(set(possible_edges) - set(present_edges))
                
                    for to_pos in edges_not_present_pos:
                        edge_from_id = (node,from_pos,to_pos)
                        self.m.addConstr(self.edges[edge_from_id] == 0,name=\
                                                "na_edge_constraint-%s-%s-%s"%(node,from_pos,to_pos))
                    all_edges_from_pos = []
                    for to_pos in present_edges:
                        edge_from_id = (node,from_pos,to_pos)
                        all_edges_from_pos.append(self.edges[edge_from_id])
                    if all_edges_from_pos: #list is not empty
                        self.m.addConstr(quicksum(all_edges_from_pos) <= 1,name=\
                                                 "possible_edge_constraint-%s-%s"%pos_id)
                        
                    
                    ###### edges recon ######### arrving edge must leave except for start and end
                    # edges coming in
                    edges_coming_in = [ r for r in range(0,from_pos)]
                    #edges going out
                    edges_going_out = [ r for r in range(from_pos + 1,self.sequence_length)]
                    
                    edges_coming_in_list = []
                    edges_going_out_list = []

                    for edges_coming_in_item in edges_coming_in:
                        edge_to_id = (node,edges_coming_in_item,from_pos)
                        edges_coming_in_list.append(self.edges[edge_to_id])

                    for edges_going_out_item in edges_going_out:
                        edge_to_id = (node,from_pos,edges_going_out_item)
                        edges_going_out_list.append(self.edges[edge_to_id])
                        
                    if from_pos != 0 and from_pos != self.sequence_length - 1:
                        self.m.addConstr(sum(edges_coming_in_list) == sum(edges_going_out_list),\
                                             name="edge_recon_constraint-%s-%s"%pos_id)
                    
            
                    ###### POSITIONS ######
                    if from_pos != 0:
                        self.m.addConstr(sum(edges_coming_in_list) == self.positions[pos_id],\
                                            name="backward_position_constraint-%s-%s"%pos_id)
                    if from_pos != self.sequence_length - 1:
                        self.m.addConstr(sum(edges_going_out_list) == self.positions[pos_id],\
                                            name="forward_position_constraint-%s-%s"%pos_id)
                    
                                     
                
                # special constraint for start and end node
                end_pos = self.sequence_length - 1
                self.m.addConstr(self.positions[(node,end_pos)] == 1,\
                                     name="end_position_constraint-%s-%s"%(node,end_pos))
                start_pos = 0
                self.m.addConstr(self.positions[(node,start_pos)] == 1,\
                                     name="start_position_constraint-%s-%s"%(node,start_pos))
                
        
        
             
    def penalty_constraints(self):
        ''' function to add penalty constraints '''
        
        for node,node_neighbor in self.neighbor_dict.items():
            for node_neighbor_item in node_neighbor:
                for pos in range(1,sequence_length - 1):  # no penalty for start and end
                    diff_id  = (node,node_neighbor_item,pos)
                    pen_id   = (node,node_neighbor_item,pos)
                    diff_var = self.diff[diff_id]
                    pen_var  = self.penalty[pen_id]
                    
                    if pos == 1: # penalty for first position is simple
                        self.m.addConstr(pen_var == diff_var,"penalty_constraint-%s-%s-%s"%\
                                         (node,node_neighbor_item,pos))
                    else:
                        pen_prev_id = (node,node_neighbor_item,pos - 1)
                        prev_pen_var =  self.penalty[pen_prev_id]
                        
                        self.m.addConstr(diff_var - prev_pen_var >= 1 - self.M * (1 - pen_var),\
                                         name="penalty_constraint_1-%s-%s-%s"%\
                                         (node,node_neighbor_item,pos))
                        self.m.addConstr(diff_var - prev_pen_var <= self.M * (pen_var),\
                                         name="penalty_constraint_2-%s-%s-%s"%\
                                         (node,node_neighbor_item,pos))
                    
                    # add penalty to the objective
                    self.objective.append(2 * pen_var)
                    
        
    def train(self,n_threads,time_out):
        # Params
        self.m.Params.Threads = n_threads
        self.m.Params.TimeLimit = time_out*60
        #self.m.Params.LogFile =  folder_location + 
        self.m.Params.LogToConsole = 0
        self.m.Params.Degenmoves=0
        
        # Optimize
        self.total_objective = sum([o for o in self.objective])
        self.m.setObjective(self.total_objective, GRB.MINIMIZE)
        self.m.update()
        
        self.m.write((self.folder_location + 'pf_mip_formulation_' + self.tree_name + '.lp'))
        self.m.optimize()
        
        # Is feasible?
        return self.m.SolCount > 0
    
    
    def get_info(self):
        info_all = {}
        info_all["objective"] = self.m.ObjVal
        info_all["bound"] = self.m.ObjBound
        info_all["gap"] = self.m.MIPGap
        info_all["is_optimal"] = (self.m.status == GRB.OPTIMAL)
        info_all["num_nodes"] = self.m.NodeCount
        info_all["num_vars"] = self.m.NumIntVars + self.m.NumBinVars

        if self.m.SolCount > 0:
            print("objective: %0.2f"%info_all["objective"])
            print("bound: %0.2f"%info_all["bound"])
            print("gap: %0.2f"%info_all["gap"])

        return info_all
    
    def get_solution(self):
        # get the path
        all_node_paths = {}
        for node,node_edge_from_list in self.node_from_edge_dict.items():
            preferred_path = []
            for pos_from in range(0,sequence_length):
                pos_id = (node,pos_from)
                preferred_path.append(int(self.positions[pos_id].X))
                
            all_node_paths[node] = preferred_path
            
        # get the differnece and penalty solution
        score_dict = {}
        overall_score = 0
        for node,node_neighbor in self.neighbor_dict.items():
            for node_neighbor_item in node_neighbor:
                total_score = 0
                for pos in range(1,sequence_length - 1): # penalty start from 1st position only
                    pen_id = (node,node_neighbor_item,pos)
                    total_score = total_score + 2 * int(self.penalty[pen_id].X)
#                     print("penalty",int(self.penalty[pen_id].X))
                    
                    diff_id = (node,node_neighbor_item,pos)
                    total_score = total_score +int(self.diff[diff_id].X)
                #print("total_score between node {} and node-neighbor{} is {}"\
                                  #.format(node,node_neighbor_item,total_score)) 
                score_dict[(node,node_neighbor_item)] = total_score
                overall_score = overall_score + total_score
        #print("overall_score",overall_score)
        return all_node_paths,score_dict
               

In [35]:
# TEST EXAMPLE - 1
# folder_location = '/Users/sanjanatule/Documents/uq/Projects/PreferredPath/scripts/mip_files/'
# nodes = 3
# sequence_length = 6 #(start and end pos)
# neighbor_dict = dict({1:[2,3]})
# extant_list = [2,3]
# node_from_edge_dict = dict({\
#                   1:{0:[1,2],1:[2,3,4],2:[3,4],3:[4],4:[5]} ,\
#                   2:{0:[1],1:[2],2:[3],3:[4],4:[5]} ,\
#                   3:{0:[1],1:[2],2:[3],3:[4],4:[5]}})
# node_to_edge_dict = dict({\
#                   1:{1:[0],2:[0,1],3:[1,2],4:[1,2,3],5:[4]} ,\
#                   2:{1:[0],2:[1],3:[2],4:[3],5:[4]} ,\
#                   3:{1:[0],2:[1],3:[2],4:[3],5:[4]}\
#                     })
# tree_name = 'test_example_1'

In [36]:
#TEST EXAMPLE - 2
# folder_location = '/Users/sanjanatule/Documents/uq/Projects/PreferredPath/scripts/mip_files/'
# nodes = 3
# sequence_length = 6 #(start and end pos)
# neighbor_dict = dict({1:[2,3]})
# extant_list = [2,3]
# node_from_edge_dict = dict({\
#                   1:{0:[1,2],1:[2,3,4],2:[3,4],3:[4],4:[5]} ,\
#                   2:{0:[1],1:[2],2:[3],3:[4],4:[5]} ,\
#                   3:{0:[1],1:[4],2:[],3:[],4:[5]}})
# node_to_edge_dict = dict({\
#                   1:{1:[0],2:[0,1],3:[1,2],4:[1,2,3],5:[4]} ,\
#                   2:{1:[0],2:[1],3:[2],4:[3],5:[4]} ,\
#                   3:{1:[0],2:[],3:[],4:[1],5:[4]}\
#                     })
# tree_name = 'test_example_2'

In [37]:
# TEST MIP Models for different trees

## sample tree
# nwk_file_path = '/Users/sanjanatule/Documents/uq/Projects/PreferredPath/data/sample_tree/grasp_ancestors.nwk'
# pogs_file     = '/Users/sanjanatule/Documents/uq/Projects/PreferredPath/data/sample_tree/pogs.json'
# folder_location = '/Users/sanjanatule/Documents/uq/Projects/PreferredPath/scripts/mip_files/'
# tree_name = 'sample_tree'



## CYP2U - 165
nwk_file_path       = '/Users/sanjanatule/Documents/uq/Projects/PreferredPath/data/CYP2U_165/grasp_ancestors.nwk'
ancestor_fasta_file = "/Users/sanjanatule/Documents/uq/Projects/PreferredPath/data/CYP2U_165/grasp_ancestors.fa"
pogs_file           = '/Users/sanjanatule/Documents/uq/Projects/PreferredPath/data/CYP2U_165/pogs.json'
tree_name = 'cyp2u_165'

folder_location = '/Users/sanjanatule/Documents/uq/Projects/PreferredPath/scripts/mip_files/'
TreePogs = Pogs(pogs_file,nwk_file_path)
node_from_edge_dict,node_to_edge_dict,sequence_length,extant_list = TreePogs.create_node_info_dict()
nodes = len(node_from_edge_dict)
neighbor_dict = TreePogs.create_neighbor_object()
print("TOTAL NODES::",nodes)
print("TOTAL POSITIONS:",sequence_length)
#print("EXTANT LIST:",extant_list)
#print("NEIGHBOR DICT:",neighbor_dict)

TOTAL NODES:: 329
TOTAL POSITIONS: 627


In [38]:
# testing for N53 only
nodes = 3
sequence_length = 627
tree_name = 'cyp2u_165_test'
extant_list_test = ['XP_003226556.2']

node_from_edge_dict_test = {}
node_from_edge_dict_test['N53'] =  node_from_edge_dict['N53']
node_from_edge_dict_test['N54'] =  node_from_edge_dict['N54']
node_from_edge_dict_test['XP_003226556.2'] =  node_from_edge_dict['XP_003226556.2']
neighbor_dict_test = {}
neighbor_dict_test['N53'] = neighbor_dict['N53']
node_from_edge_dict_test

{'N53': defaultdict(list,
             {0: [1],
              604: [626],
              1: [12],
              12: [13],
              13: [14],
              14: [15],
              15: [24],
              24: [25],
              25: [26],
              26: [27],
              27: [28],
              28: [29],
              29: [135],
              135: [136],
              136: [137],
              137: [138],
              138: [139],
              139: [140],
              140: [141],
              141: [144],
              144: [146],
              146: [147],
              147: [148],
              148: [151],
              151: [152],
              152: [153],
              153: [154],
              154: [155],
              155: [156],
              156: [157],
              157: [158],
              158: [159],
              159: [160],
              160: [161],
              161: [162],
              162: [163],
              163: [164],
              164: [165],
            

In [39]:
start = time.time()
print("Start Time:",start)
n_threads = 1
time_out = 5

PyTree = PhyloTree(nodes,sequence_length,neighbor_dict_test,node_from_edge_dict_test,folder_location\
                   ,extant_list_test,tree_name)
PyTree.add_constraints_extants()
PyTree.add_constraints_ancestors()
PyTree.penalty_constraints()

is_sat = PyTree.train(n_threads, time_out)
print("is_sat",is_sat)
total_time = ((time.time()-start))
print("-----------------------------")
print("Total time = %0.2f[m]"%total_time)
info = PyTree.get_info()
info["total_time"] = total_time
info["is_sat"] = is_sat
print("info",info)

if is_sat:
    all_node_paths,score_dict = PyTree.get_solution()
    #print("all_node_paths",all_node_paths)
else:
    print("Did not find any satisfactory solution to the model")

Start Time: 1662124910.752915
Set parameter Threads to value 1
Set parameter TimeLimit to value 300
is_sat True
-----------------------------
Total time = 8.87[m]
objective: 0.00
bound: 0.00
gap: 0.00
info {'objective': 0.0, 'bound': 0.0, 'gap': 0.0, 'is_optimal': True, 'num_nodes': 0.0, 'num_vars': 1186268, 'total_time': 8.871127128601074, 'is_sat': True}


In [40]:
# check output from mip
# read the fasta file
from pysam import FastaFile
import re

grasp_fasta_file = "/Users/sanjanatule/Documents/uq/Projects/PreferredPath/data/CYP2U_165/grasp_ancestors.fa"
sequences_info = FastaFile(grasp_fasta_file)
total_node_change = 0

for node_name,preferred_path in all_node_paths.items():
    grasp_output_seq = sequences_info.fetch(node_name)
    
    # convert into 1/0
    grasp_output_seq =  re.sub('[a-zA-Z]', '1', grasp_output_seq)
    grasp_output_seq = grasp_output_seq.replace('-','0')
    #print('grasp_output_seq',grasp_output_seq)
    
    # compare with mip output
    pf_str = ''
    for p in preferred_path:
        pf_str = pf_str + str(p)
    
    # remove first and last path
    pf_str = pf_str[1:-1]
    #print("preferred_path",pf_str)
    if grasp_output_seq != pf_str:
        print("checking node::",node_name)
        total_node_change += 1
        
        if node_name in extant_list:
            print("ERROR: path is different for an extant")
        else:
            print("path is different for ancestor.Should not change")
            
print("Total nodes changed",total_node_change)

Total nodes changed 0


In [41]:
pf_str = ''
for p in all_node_paths['N53']:
    pf_str = pf_str + str(p)
pf_str[1:-1]

'1000000000011110000000011111100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000011111110010111001111111111111111111111111111111111111111111111111111111111111101111111111111111111111111111111111111111111111111000001110111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111110001111111011111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111000000000000000000000'

In [None]:
'1000000000011110000000011111100000000000110000110000000000010001111100000000001100011000000000000000000000000000000000010000000000000011111110010111001111111111111111111111111111111111111111111111111111111111111101111111111111111111111111111111111111111111111111000001110111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111110001111111011111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111000000000000000000000'
'1000000000011110000000011111100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000011111110010111001111111111111111111111111111111111111111111111111111111111111101111111111111111111111111111111111111111111111111000001110111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111110001111111011111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111000000000000000000000'