## MIP Program for choosing preferred path for ancestor nodes.
1. Program - Gurobi Solver
2. Date - 31 Aug 2022

In [1]:
# libraries
import gurobipy as gp
from gurobipy import abs_,quicksum
from gurobipy import GRB
import time
import json
from collections import defaultdict
from ete3 import Tree

In [2]:
''' Convert the json pogs file to the data structure needed for MIP Model'''
class Pogs:
    def __init__(self,pogs_file,tree_file):
        
        self.json_pogs_file = pogs_file
        tree_file = open(nwk_file_path,"r")
        my_tree = tree_file.read() + ";"
        self.tree = Tree(my_tree, format=1)
          
    def create_neighbor_object(self):
        ''' create neighbor dict '''
        tree_neighbor_dict = defaultdict(list)
        
        for n in self.tree.traverse():
            if n.is_leaf() == False:    
                for c in n.children:
                    tree_neighbor_dict[n.name] += [c.name]
        return tree_neighbor_dict
    
    def create_node_info_dict(self):
        ''' create edge dictionary and other node information'''
        
        '''Example output-- {\
                  1:{0:[1,2],1:[2,3,4],2:[3,4],3:[4],4:[5]} ,\
                  2:{0:[1],1:[2],2:[3],3:[4],4:[5]} \
                  } '''
        
        
        # read the json file
        with open(self.json_pogs_file, 'r') as j:
            pog_all_data = json.loads(j.read())

            node_path_dict = {}
            node_path_reverse_dict = {}
            extant_list = []
            
            # read all ancestors
            for node_type in ['Ancestors','Extants']:
                for pog_data in pog_all_data[node_type]:
                    if node_type == 'Ancestors':
                        node_name = 'N' + pog_data['Name']
                    else:
                        node_name = pog_data['Name']
                        extant_list.append(node_name)  # all extants
                    
                    node_edges_info_dict = defaultdict(list)
                    node_edges_reverse_info_dict = defaultdict(list)
                    
                    # read that node's data
                    nodes = pog_data['Size'] + 2
                    
                    # Edges from special Start node to the start nodes
                    for s in pog_data['Starts']:
                        node_edges_info_dict[0] += [s+1] 
                        node_edges_reverse_info_dict[s+1] += [0] 
            
                    # Edges from last node to the special End node
                    for e in pog_data['Ends']:
                        node_edges_info_dict[e + 1] += [nodes-1]
                        node_edges_reverse_info_dict[nodes-1] += [e + 1]

                    # create the adjency matrix for all nodes except from special node start
                    for ind,node in enumerate(pog_data['Indices']):
                        row_mat = node
                        row_col = pog_data['Adjacent'][ind]

                        for rc in row_col:
                            node_edges_info_dict[row_mat + 1] += [rc + 1]
                            node_edges_reverse_info_dict[rc + 1] += [row_mat + 1]
                            
                    # put all info together in the final dict
                    node_path_dict[node_name] = node_edges_info_dict
                    node_path_reverse_dict[node_name] = node_edges_reverse_info_dict
        
        return node_path_dict,node_path_reverse_dict,nodes,extant_list
    
## testing
# nwk_file_path = '/Users/sanjanatule/Documents/uq/Projects/PreferredPath/data/sample_tree/grasp_ancestors.nwk'
# pogs_file     = '/Users/sanjanatule/Documents/uq/Projects/PreferredPath/data/sample_tree/pogs.json'
# p = Pogs(pogs_file,nwk_file_path)
# p.create_node_info_dict()
# p.create_neighbor_object()

In [51]:
''' MIP Model for preferred path with least parsimonous score '''
class PhyloTree:
    def __init__(self,nodes,sequence_length,neighbor_dict,node_from_edge_dict,node_to_edge_dict,folder_location\
                 ,extant_list,tree_name):

        # 1 - define the configuration  and decision variables for the tree
        self.nodes = nodes
        self.sequence_length = sequence_length
        self.neighbor_dict = neighbor_dict
        self.node_from_edge_dict = node_from_edge_dict
        self.node_to_edge_dict = node_to_edge_dict
        self.folder_location = folder_location
        self.extant_list = extant_list
        self.tree_name = tree_name
        self.edges = {}
        self.positions = {}
        self.penalty = {}
        self.diff = {}
        self.objective = []
        self.M = 999

        # 2 - create a new model
        self.m = gp.Model("PreferredPathSolve")

        # 3 - create variables
        # 3.1 - create variables - positions and edges of the POG
#         for node in range(1,nodes + 1):
#             for pos_from in range(0,sequence_length):

#                 # positions
#                 pos_id = (node,pos_from)
#                 pos = self.m.addVar(vtype=GRB.BINARY, name="n-%d-%d"%pos_id)
#                 self.positions[pos_id] = pos

#                 # edges
#                 for pos_to in range(pos_from + 1, sequence_length):
#                     edge_id = (node,pos_from,pos_to)
#                     e = self.m.addVar(vtype=GRB.BINARY, name='e-%d-%d-%d'%edge_id)
#                     self.edges[edge_id] = e


                    
        
        for node,node_edge_from_list in self.node_from_edge_dict.items():
            for pos_from in range(0,sequence_length):
                # positions
                pos_id = (node,pos_from)
                pos = self.m.addVar(vtype=GRB.BINARY, name="n-%s-%s"%pos_id)
                self.positions[pos_id] = pos
                
                # edges
                for pos_to in range(pos_from + 1, sequence_length):
                    edge_id = (node,pos_from,pos_to)
                    e = self.m.addVar(vtype=GRB.BINARY, name='e-%s-%s-%s'%edge_id)
                    self.edges[edge_id] = e
        
        # 3.2 - create variables - add penalty variable
        for node,node_neighbor in self.neighbor_dict.items():
            for node_neighbor_item in node_neighbor:
                for pos in range(1,sequence_length - 1): # penalty start from 1st position only
                    
                    pen_id = (node,node_neighbor_item,pos)
                    pen = self.m.addVar(vtype=GRB.BINARY, name='p-%s-%s-%s'%pen_id)
                    self.penalty[pen_id] = pen
                
                    # 3.3 - add position difference to the objective (not for start and end node)
                    node_pos_var = self.positions[(node,pos)]
                    node_neighbor_pos_var = self.positions[(node_neighbor_item,pos)]
                    diff_id = (node,node_neighbor_item,pos)
                    diff_pos = self.m.addVar(vtype=GRB.BINARY, name='d-%s-%s-%s'%diff_id)
                    self.diff[diff_id] = diff_pos
                    
                    # abs constraint
                    self.m.addConstr( diff_pos <= node_pos_var + node_neighbor_pos_var,name=\
                                     "pos_diff_constraint_1-%s-%s-%s"%(node,node_neighbor_item,pos))
                    self.m.addConstr( diff_pos >= node_pos_var - node_neighbor_pos_var,name=\
                                     "pos_diff_constraint_2-%s-%s-%s"%(node,node_neighbor_item,pos))
                    self.m.addConstr( diff_pos >= node_neighbor_pos_var - node_pos_var,name=\
                                     "pos_diff_constraint_3-%s-%s-%s"%(node,node_neighbor_item,pos))
                    self.m.addConstr( diff_pos <= 2 - node_neighbor_pos_var - node_pos_var,name=\
                                     "pos_diff_constraint_4-%s-%s-%s"%(node,node_neighbor_item,pos))
                    self.objective.append(diff_pos)


    def add_constraints_n_objective(self):
        
        # 4 - create constraints
        # 4.1 - edges and positions

        for node,node_edge_from_list in self.node_from_edge_dict.items(): # nodes information all
            #print("node",node)
            #print("node_edge_from_list",node_edge_from_list)
            for pos_from,node_edge_from_list_item in node_edge_from_list.items():  # position - edges
                all_edges_from_pos = []
                
                ######## edges ##########
                # for edges not present
                possible_edges = [ r for r in range(pos_from + 1,sequence_length)]
                present_edges = node_edge_from_list_item
                edges_not_present_pos = list(set(possible_edges) - set(present_edges))
                #print("edges_not_present_pos",edges_not_present_pos)
                #print("present_edges",present_edges)
                
                for pos_to in edges_not_present_pos:
                    edge_from_id = (node,pos_from,pos_to)
                    self.m.addConstr(self.edges[edge_from_id] == 0,name=\
                                            "na_edge_constraint-%s-%s-%s"%(node,pos_from,pos_to))
                    #all_edges_from_pos.append(self.edges[edge_from_id])
                
                # for edges present
                for pos_to in present_edges:
                    edge_from_id = (node,pos_from,pos_to)
                    all_edges_from_pos.append(self.edges[edge_from_id])

                if all_edges_from_pos: #list is not empty
                    if node not in self.extant_list:
                    ## all sum <= 1 atmost 1 edge can be chosen if it is an ancestor node
                        self.m.addConstr(quicksum(all_edges_from_pos) <= 1,name="edge_constraint-%s-%s"%(node,pos_from))
                    else: # extant
                        self.m.addConstr(quicksum(all_edges_from_pos) == 1,name="edge_constraint-%s-%s"%(node,pos_from))

                ######## position ##########
                ## position is chosen only there is an edge from position
                pos_id = (node,pos_from)
                if all_edges_from_pos: #list is not empty
                    self.m.addConstr(quicksum(all_edges_from_pos) == self.positions[pos_id],\
                                 name="forward_position_constraint-%s-%s"%(node,pos_from))
                else:
                    self.m.addConstr(self.positions[pos_id] == 0,\
                                 name="forward_position_constraint-%s-%s"%(node,pos_from))
                              
                
                ######## edges coming in == edges going out ##########
                # get all edges to that position
                node_edge_to_list = self.node_to_edge_dict[node] 
                
                if pos_from != 0: # starting position does not have anything coming to it
                    node_edge_to_list_item = node_edge_to_list[pos_from] # all edges ending at the current position
                    all_edges_to_pos = []
                
                    for edge_to_pos in node_edge_to_list_item:
                        edge_to_id = (node,edge_to_pos,pos_from)
                        all_edges_to_pos.append(self.edges[edge_to_id])

                    self.m.addConstr(sum(all_edges_from_pos) == sum(all_edges_to_pos),\
                                     name="edge_recon_constraint-%s-%s"%\
                               (node,pos_from))
        
            # 4.2 - special constraint for start and end node
            end_pos = self.sequence_length - 1
            self.m.addConstr(self.positions[(node,end_pos)] == 1,\
                                 name="end_position_constraint-%s-%s"%(node,end_pos))
            start_pos = 0
            self.m.addConstr(self.positions[(node,start_pos)] == 1,\
                                 name="start_position_constraint-%s-%s"%(node,start_pos))
            
            
        # 4.3 - redundant constraints - edges coming to a position and position selection
        for node,node_edge_to_list in self.node_to_edge_dict.items(): 
            #print("node",node)
            #print("node_edge_to_list",node_edge_to_list)
            for position, edges_to_vals in node_edge_to_list.items(): 
                #print("position",position)
                #print("edges_to_vals",edges_to_vals)
                all_edges_to_pos = [] 
                pos_id = (node,position)
                #or pos_to in edges_to_vals:
                for pos_to in range(0,position):
                    edge_from_id = (node,pos_to,position)
                    all_edges_to_pos.append(self.edges[edge_from_id])
                    
                self.m.addConstr(quicksum(all_edges_to_pos) == self.positions[pos_id],\
                                 name="backward_position_constraint-%s-%s"%(node,position))
                    
                

        # 4.4 - penalty big M constraints
        for node,node_neighbor in self.neighbor_dict.items():
            for node_neighbor_item in node_neighbor:
                for pos in range(1,sequence_length - 1):  # no penalty for start and end
                    diff_id  = (node,node_neighbor_item,pos)
                    pen_id   = (node,node_neighbor_item,pos)
                    diff_var = self.diff[diff_id]
                    pen_var  = self.penalty[pen_id]
                    
                    if pos == 1: # penalty for first position is simple
                        self.m.addConstr(pen_var == diff_var,"penalty_constraint-%s-%s-%s"%\
                                         (node,node_neighbor_item,pos))
                    else:
                        pen_prev_id = (node,node_neighbor_item,pos - 1)
                        prev_pen_var =  self.penalty[pen_prev_id]
                        
                        self.m.addConstr(diff_var - prev_pen_var >= 1 - self.M * (1 - pen_var),\
                                         name="penalty_constraint_1-%s-%s-%s"%\
                                         (node,node_neighbor_item,pos))
                        self.m.addConstr(diff_var - prev_pen_var <= self.M * (pen_var),\
                                         name="penalty_constraint_2-%s-%s-%s"%\
                                         (node,node_neighbor_item,pos))
                    
                    # add penalty to the objective
                    self.objective.append(2 * pen_var)
                        
                        
        
        
    def train(self,n_threads,time_out):
        
        # Params
        self.m.Params.Threads = n_threads
        self.m.Params.TimeLimit = time_out*60
        #self.m.Params.LogFile =  folder_location + 
        self.m.Params.LogToConsole = 0
        self.m.Params.Degenmoves=0
        
        # Optimize
        self.total_objective = sum([o for o in self.objective])
        #print("self.total_objective",self.total_objective)
        self.m.setObjective(self.total_objective, GRB.MINIMIZE)
        self.m.update()
        
        self.m.write((self.folder_location + 'pf_mip_formulation_' + self.tree_name + '.lp'))
        self.m.optimize()
        
        # Is feasible?
        return self.m.SolCount > 0
    
    
    def get_info(self):
        info_all = {}
        info_all["objective"] = self.m.ObjVal
        info_all["bound"] = self.m.ObjBound
        info_all["gap"] = self.m.MIPGap
        info_all["is_optimal"] = (self.m.status == GRB.OPTIMAL)
        info_all["num_nodes"] = self.m.NodeCount
        info_all["num_vars"] = self.m.NumIntVars + self.m.NumBinVars

        if self.m.SolCount > 0:
            print("objective: %0.2f"%info_all["objective"])
            print("bound: %0.2f"%info_all["bound"])
            print("gap: %0.2f"%info_all["gap"])

        return info_all
    
    def get_solution(self):
        all_node_paths = {}
        for node,node_edge_from_list in self.node_from_edge_dict.items():
            preferred_path = []
            for pos_from in range(0,sequence_length):
                pos_id = (node,pos_from)
                preferred_path.append(int(self.positions[pos_id].X))
                
            all_node_paths[node] = preferred_path
        return all_node_paths
               

In [44]:
# TEST EXAMPLE - 1
# folder_location = '/Users/sanjanatule/Documents/uq/Projects/PreferredPath/scripts/mip_files/'
# nodes = 3
# sequence_length = 6 #(start and end pos)
# neighbor_dict = dict({1:[2,3]})
# extant_list = [2,3]
# node_from_edge_dict = dict({\
#                   1:{0:[1,2],1:[2,3,4],2:[3,4],3:[4],4:[5]} ,\
#                   2:{0:[1],1:[2],2:[3],3:[4],4:[5]} ,\
#                   3:{0:[1],1:[2],2:[3],3:[4],4:[5]}})
# node_to_edge_dict = dict({\
#                   1:{1:[0],2:[0,1],3:[1,2],4:[1,2,3],5:[4]} ,\
#                   2:{1:[0],2:[1],3:[2],4:[3],5:[4]} ,\
#                   3:{1:[0],2:[1],3:[2],4:[3],5:[4]}\
#                     })
# tree_name = 'test_example_1'

In [53]:
#TEST EXAMPLE - 2
folder_location = '/Users/sanjanatule/Documents/uq/Projects/PreferredPath/scripts/mip_files/'
nodes = 3
sequence_length = 6 #(start and end pos)
neighbor_dict = dict({1:[2,3]})
extant_list = [2,3]
node_from_edge_dict = dict({\
                  1:{0:[1,2],1:[2,3,4],2:[3,4],3:[4],4:[5]} ,\
                  2:{0:[1],1:[2],2:[3],3:[4],4:[5]} ,\
                  3:{0:[1],1:[4],2:[],3:[],4:[5]}})
node_to_edge_dict = dict({\
                  1:{1:[0],2:[0,1],3:[1,2],4:[1,2,3],5:[4]} ,\
                  2:{1:[0],2:[1],3:[2],4:[3],5:[4]} ,\
                  3:{1:[0],2:[],3:[],4:[1],5:[4]}\
                    })
tree_name = 'test_example_2'

In [55]:
# TEST MIP Models for different trees


## sample tree
nwk_file_path = '/Users/sanjanatule/Documents/uq/Projects/PreferredPath/data/sample_tree/grasp_ancestors.nwk'
pogs_file     = '/Users/sanjanatule/Documents/uq/Projects/PreferredPath/data/sample_tree/pogs.json'
folder_location = '/Users/sanjanatule/Documents/uq/Projects/PreferredPath/scripts/mip_files/'
tree_name = 'sample_tree'
TreePogs = Pogs(pogs_file,nwk_file_path)
node_from_edge_dict,node_to_edge_dict,sequence_length,extant_list = TreePogs.create_node_info_dict()
nodes = len(node_from_edge_dict)
neighbor_dict = TreePogs.create_neighbor_object()
print("TOTAL NODES::",nodes)
print("TOTAL POSITIONS:",sequence_length)
print("EXTANT LIST:",extant_list)
print("NEIGHBOR DICT:",neighbor_dict)
node_from_edge_dict['A6']

TOTAL NODES:: 39
TOTAL POSITIONS: 16
EXTANT LIST: ['A3', 'A9', 'A4', 'A20', 'A8', 'A5', 'A19', 'A1', 'A17', 'A10', 'A7', 'A16', 'A18', 'A13', 'A15', 'A14', 'A11', 'A12', 'A2', 'A6']
NEIGHBOR DICT: defaultdict(<class 'list'>, {'N0': ['N1', 'N9'], 'N1': ['N2', 'N6'], 'N9': ['N10', 'N18'], 'N2': ['N3', 'N5'], 'N6': ['A5', 'N7'], 'N10': ['N11', 'N13'], 'N18': ['A2', 'A6'], 'N3': ['N4', 'A4'], 'N5': ['A20', 'A8'], 'N7': ['N8', 'A17'], 'N11': ['A10', 'N12'], 'N13': ['N14', 'N15'], 'N4': ['A3', 'A9'], 'N8': ['A19', 'A1'], 'N12': ['A7', 'A16'], 'N14': ['A18', 'A13'], 'N15': ['A15', 'N16'], 'N16': ['A14', 'N17'], 'N17': ['A11', 'A12']})


defaultdict(list,
            {0: [1],
             14: [15],
             1: [3],
             3: [4],
             4: [5],
             5: [6],
             6: [7],
             7: [10],
             10: [12],
             12: [14]})

In [56]:
start = time.time()
print("Start Time:",start)
n_threads = 1
time_out = 60

PyTree = PhyloTree(nodes,sequence_length,neighbor_dict,node_from_edge_dict,node_to_edge_dict,folder_location\
                   ,extant_list,tree_name)
PyTree.add_constraints_n_objective()
is_sat = PyTree.train(n_threads, time_out)
print("is_sat",is_sat)
total_time = ((time.time()-start))
print("-----------------------------")
print("Total time = %0.2f[m]"%total_time)
info = PyTree.get_info()
info["total_time"] = total_time
info["is_sat"] = is_sat
print("info",info)

if is_sat:
    all_node_paths = PyTree.get_solution()
    print("all_node_paths",all_node_paths)
else:
    print("Did not find any satisfactory solution to the model")

Start Time: 1662012824.1509252
Set parameter Threads to value 1
Set parameter TimeLimit to value 3600
is_sat True
-----------------------------
Total time = 0.14[m]
objective: 0.00
bound: 0.00
gap: 0.00
info {'objective': 0.0, 'bound': 0.0, 'gap': 0.0, 'is_optimal': True, 'num_nodes': 1.0, 'num_vars': 12736, 'total_time': 0.14310193061828613, 'is_sat': True}
all_node_paths {'N0': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'N1': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'N2': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'N3': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'N4': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'N5': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'N6': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'N7': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'N8': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'N9': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'N10': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'N11':