## MIP Program for choosing preferred path for ancestor nodes.
1. Program - Gurobi Solver
2. Date - 5 September 2022
3. Add each ancestor POG as constraints

In [1]:
# libraries
import gurobipy as gp
from gurobipy import abs_,quicksum
from gurobipy import GRB
import time
import json
from collections import defaultdict
from ete3 import Tree
import numpy as np
from pysam import FastaFile
import re

## MIP INPUT

In [2]:
''' Convert the json pogs file to the data structure needed for MIP Model'''
''' Information needed 
extant list - list of all extant nodes
sequence length
total nodes in the tree
edges from position in the node
edges to the position in the node
fully connected, start, end, single connected and unconnected positions
'''


class Pogs:
    def __init__(self,pogs_file,tree_file):
        
        self.json_pogs_file = pogs_file
        tree_file = open(nwk_file_path,"r")
        my_tree = tree_file.read() + ";"
        self.tree = Tree(my_tree, format=1)
        self.node_path_dict = {}
        self.node_path_reverse_dict = {}
        self.tree_neighbor_dict = defaultdict(list)
        self.node_type_dict = defaultdict(list)
        self.node_pogs_cnt = {}
        self.node_pogs = {}
        self.extant_list = []
        self.nodes = 0
          
    def count_path(self,a):
        ''' count the number of paths in a graph '''
        a = a + a.T    #add up the transpose
        a = np.clip(a,0,1)
        a = np.triu(a) #only the upper triangle

        nodes = a.shape[0]
        dp = [0] * nodes
        dp[nodes - 1]= 1 #last node

        for i in range(nodes - 1, -1, -1):
            neighbour_nodes = np.where (a[i] == 1)[0]
            for j in neighbour_nodes:
                dp[i] = dp[i] + dp[j]

        return(dp[0])
    
    def create_neighbor_object(self):
        ''' create neighbor dict '''
        
        for n in self.tree.traverse():
            if n.is_leaf() == False:    
                for c in n.children:
                    self.tree_neighbor_dict[n.name] += [c.name]
        return self.tree_neighbor_dict
                
    def create_node_info_dict(self):
        ''' create edge dictionary and other node information'''
        '''Example output-- {\
                  1:{0:[1,2],1:[2,3,4],2:[3,4],3:[4],4:[5]} ,\
                  2:{0:[1],1:[2],2:[3],3:[4],4:[5]} \
                  } '''
        # read the json file
        with open(self.json_pogs_file, 'r') as j:
            pog_all_data = json.loads(j.read())

            # read all ancestors
            for node_type in ['Ancestors','Extants']:
                for pog_data in pog_all_data[node_type]:
                    if node_type == 'Ancestors':
                        node_name = 'N' + pog_data['Name']
                    else:
                        node_name = pog_data['Name']
                        
                    node_edges_info_dict = defaultdict(list)
                    node_edges_reverse_info_dict = defaultdict(list)
                    
                    # read that node's data
                    self.nodes = pog_data['Size'] + 2
                    mat = np.zeros(shape=(self.nodes,self.nodes))
                    
                    # Edges from special Start node to the start nodes
                    for s in pog_data['Starts']:
                        node_edges_info_dict[0] += [s+1] 
                        node_edges_reverse_info_dict[s+1] += [0]
                        mat[0,s + 1] = 1
                       
                    # Edges from last node to the special End node
                    for e in pog_data['Ends']:
                        node_edges_info_dict[e + 1] += [self.nodes-1]
                        node_edges_reverse_info_dict[self.nodes-1] += [e + 1]
                        mat[e + 1,self.nodes-1] = 1
                        
                        
                    # create the adjency matrix for all nodes except from special node start
                    for ind,node in enumerate(pog_data['Indices']):
                        row_mat = node
                        row_col = pog_data['Adjacent'][ind]

                        for rc in row_col:
                            node_edges_info_dict[row_mat + 1] += [rc + 1]
                            node_edges_reverse_info_dict[rc + 1] += [row_mat + 1]
                            mat[row_mat + 1,rc + 1] = 1
                            
                    # put all info together in the final dict
                    self.node_path_dict[node_name] = node_edges_info_dict
                    self.node_path_reverse_dict[node_name] = node_edges_reverse_info_dict
                    
                    # number of paths in pog
                    total_sequences = self.count_path(mat)
                    self.node_pogs_cnt[node_name] = total_sequences
                    self.node_pogs[node_name] = mat
                    
                    # add nodes with only 1 path in pogs in the extant list
                    if total_sequences == 1:
                        self.extant_list.append(node_name)  # all extants
                    if total_sequences == 0:
                        print("ERROR:: There is no path in the POG")
                    
        return self.node_path_dict,self.node_path_reverse_dict,self.nodes,self.extant_list,self.node_pogs_cnt,\
                                                                                        self.node_pogs
    
    def node_type_info(self):
        ''' different node type information , to make it easier for MIP coding'''
        
        for node_name,node_edge_val in self.node_path_dict.items():
            self.node_type_dict[(node_name,'start')] = [0]
            self.node_type_dict[(node_name,'end')] = [self.nodes - 1]
            positions = self.nodes
            forward_edges = self.node_path_dict[node_name]
            backward_edges = self.node_path_reverse_dict[node_name]
            
            for n in range(1,positions - 1):   # do not include start / end positions
                if forward_edges.get(n) and backward_edges.get(n):
                    self.node_type_dict[(node_name,'fc_nodes')]   += [n]
                elif (forward_edges.get(n) and not backward_edges.get(n)):
                    self.node_type_dict[(node_name,'f_sc_nodes')] += [n]
                elif (not forward_edges.get(n) and backward_edges.get(n)):
                    self.node_type_dict[(node_name,'t_sc_nodes')] += [n] 
                else:
                    self.node_type_dict[(node_name,'uc_nodes')]   += [n]
        return self.node_type_dict
            
## testing
#nwk_file_path = '/Users/sanjanatule/Documents/uq/Projects/PreferredPath/data/sample_tree/grasp_ancestors.nwk'
#pogs_file     = '/Users/sanjanatule/Documents/uq/Projects/PreferredPath/data/sample_tree/pogs.json'
#p = Pogs(pogs_file,nwk_file_path)
#p.create_node_info_dict()
#p.create_neighbor_object()
#p.node_type_info()

## MIP MODEL

In [3]:
''' MIP Model for preferred path with least parsimonous score 

CONSTRAINTS

EDGES - (ANCESTOR NODE ONLY)
1. possible_edge_constraint : edge is present and sum of all edges <= 1
2. na_edge_constraint       : edge is not present
3. edge_recon_constraint    : sum of the edges coming in is equal sum of edges going out ( not for start/end position)

POSITION - (ANCESTOR + EXTANTS)
1. extant_position_constraint     : fixed constraint if the position is present in extant
2. start_end_position_constraint  : start and end position contraint fixed and == 1
3. edge_from_constraint           : sum(of all going out edges) == position variable
4. edge_to_constraint             : sum(of all coming in edges) == position variable
5. unconnected_position_constraint: position in the node not connected to any other nodes

DIFFERENCE - (NEIGHBORS)
1. pos_diff_constraint : difference in each position betweeen node and its neighbors

PENALTY - (NEIGHBORS)
1. penalty_constraint  : penalty varaible constraint. it is 1 for every opening difference between positions.

'''

class PhyloTree:
    def __init__(self,nodes,sequence_length,neighbor_dict,node_from_edge_dict,folder_location\
                 ,extant_list,tree_name,position_type_dict,node_from_reverse_edge_dict):

        # Define the configuration  and decision variables for the tree
        self.nodes = nodes
        self.sequence_length = sequence_length
        self.neighbor_dict = neighbor_dict
        self.node_from_edge_dict = node_from_edge_dict
        self.folder_location = folder_location
        self.extant_list = extant_list
        self.tree_name = tree_name
        self.position_type_dict = position_type_dict
        self.node_from_reverse_edge_dict = node_from_reverse_edge_dict
        
        # MIP data structures
        self.edges = {}
        self.positions = {}
        self.penalty = {}
        self.diff = {}
        self.objective = []
        self.M = 999

        # 2 - create a new model
        self.m = gp.Model("PreferredPathSolve")
        
        
    def add_edges_var_constraints(self):
        ''' add edge varaibles and constraints for ancestor nodes only '''  
    
        for node,node_edge_from_list in self.node_from_edge_dict.items():
            
            # node type ( edges for ancestor nodes)
            if node not in self.extant_list:
                
                # fully connected positions
                if self.position_type_dict.get((node,'fc_nodes')):
                    for pos_from in self.position_type_dict.get((node,'fc_nodes')):
                        all_edges_from_pos = []
                        for pos_to in node_edge_from_list[pos_from]:
                            edge_id = (node,pos_from,pos_to)
                            e = self.m.addVar(vtype=GRB.BINARY, name='e-%s-%s-%s'%edge_id)
                            self.edges[edge_id] = e
                            all_edges_from_pos.append(e)
                        
                        # only 1 edge can be used
                        pos_id = (node,pos_from)
                        self.m.addConstr(quicksum(all_edges_from_pos) <= 1,name=\
                                                  "possible_edge_constraint-%s-%s"%pos_id)
                    
                # start position
                if self.position_type_dict.get((node,'start')):
                    for pos_from in self.position_type_dict.get((node,'start')):
                        all_edges_from_pos = []
                        for pos_to in node_edge_from_list[pos_from]:
                            edge_id = (node,pos_from,pos_to)
                            e = self.m.addVar(vtype=GRB.BINARY, name='e-%s-%s-%s'%edge_id)
                            self.edges[edge_id] = e
                            all_edges_from_pos.append(e)
                            
                        # only 1 edge can be used
                        pos_id = (node,pos_from)
                        self.m.addConstr(quicksum(all_edges_from_pos) <= 1,\
                                            name = "possible_edge_constraint-%s-%s"%pos_id)
                        self.m.addConstr(quicksum(all_edges_from_pos) == self.positions[pos_id],\
                                            name = "edge_from_constraint-%s-%s"%pos_id)
                        
                # end position
                if self.position_type_dict.get((node,'end')):
                    for pos_to in self.position_type_dict.get((node,'end')):
                        all_edges_to_pos = []
                        for pos_from in node_from_reverse_edge_dict[node][pos_to]:
                            edge_id = (node,pos_from,pos_to)
                            all_edges_to_pos.append(self.edges[edge_id])
                            
                        # only 1 edge can be used
                        pos_id = (node,pos_to)
                        self.m.addConstr(quicksum(all_edges_to_pos) <= 1,\
                                           name  = "possible_edge_constraint-%s-%s"%pos_id)
                        self.m.addConstr(quicksum(all_edges_to_pos) == self.positions[pos_id],\
                                            name = "edge_to_constraint-%s-%s"%pos_id)
                        
                
                # one way edges - only forward
                if self.position_type_dict.get((node,'f_sc_nodes')):
                    for pos_from in self.position_type_dict.get((node,'f_sc_nodes')):
                        
                        all_edges_from_pos = []
                        for pos_to in node_edge_from_list[pos_from]:
                            edge_id = (node,pos_from,pos_to)
                            e = self.m.addVar(vtype=GRB.BINARY, name='e-%s-%s-%s'%edge_id)
                            self.edges[edge_id] = e
                            all_edges_from_pos.append(e)
                            
                        # only 1 edge can be used
                        pos_id = (node,pos_from)
                        self.m.addConstr(quicksum(all_edges_from_pos) <= 1,name=\
                                                     "possible_edge_constraint-%s-%s"%pos_id)

                        edge_id = (node,0,pos_from)
                        e = self.m.addVar(vtype=GRB.BINARY, name='e-%s-%s-%s'%edge_id)
                        self.edges[edge_id] = e
                        
                        # fake edge
                        self.m.addConstr(self.edges[edge_id] == 0,name=\
                                                 "na_edge_constraint-%s-%s-%s"%edge_id)
                # one way edges - coming in
                if self.position_type_dict.get((node,'t_sc_nodes')):
                    for pos_from in self.position_type_dict.get((node,'t_sc_nodes')):
                        
                        edge_id = (node,pos_from, sequence_length - 1)
                        e = self.m.addVar(vtype=GRB.BINARY, name='e-%s-%s-%s'%edge_id)
                        self.edges[edge_id] = e 
                        
                        # fake edge
                        self.m.addConstr(self.edges[edge_id] == 0,name=\
                                                 "na_edge_constraint-%s-%s-%s"%edge_id)
            
    def add_pos_var_constraints(self):     
        ''' add position, diff and penalty variables. add diff constraints'''
    
        # position constraints
        for node,node_edge_from_list in self.node_from_edge_dict.items():
            for pos_from in range(0,sequence_length):
                pos_id = (node,pos_from)
                pos = self.m.addVar(vtype=GRB.BINARY, name="n-%s-%s"%pos_id)
                self.positions[pos_id] = pos
                
                # special constraint for start and end node
                if pos_from == 0 or pos_from == self.sequence_length - 1:
                    self.m.addConstr(self.positions[pos_id] == 1,\
                                     name="start_end_position_constraint-%s-%s"%pos_id)

        for node,node_neighbor in self.neighbor_dict.items():
            for node_neighbor_item in node_neighbor:
                for pos in range(1,sequence_length - 1): # penalty start from 1st position only
                    
                    # penalty variables
                    pen_id = (node,node_neighbor_item,pos)
                    pen = self.m.addVar(vtype=GRB.BINARY, name='p-%s-%s-%s'%pen_id)
                    self.penalty[pen_id] = pen

                    # add position difference to the objective (not for start and end node)
                    node_pos_var = self.positions[(node,pos)]
                    node_neighbor_pos_var = self.positions[(node_neighbor_item,pos)]
                    diff_id = (node,node_neighbor_item,pos)
                    diff_pos = self.m.addVar(vtype=GRB.BINARY, name='d-%s-%s-%s'%diff_id)
                    self.diff[diff_id] = diff_pos

                    # abs difference constraint
                    self.m.addConstr( diff_pos <= node_pos_var + node_neighbor_pos_var,name=\
                                     "diff_constraint_1-%s-%s-%s"%(node,node_neighbor_item,pos))
                    self.m.addConstr( diff_pos >= node_pos_var - node_neighbor_pos_var,name=\
                                     "diff_constraint_2-%s-%s-%s"%(node,node_neighbor_item,pos))
                    self.m.addConstr( diff_pos >= node_neighbor_pos_var - node_pos_var,name=\
                                     "diff_constraint_3-%s-%s-%s"%(node,node_neighbor_item,pos))
                    self.m.addConstr( diff_pos <= 2 - node_neighbor_pos_var - node_pos_var,name=\
                                     "diff_constraint_4-%s-%s-%s"%(node,node_neighbor_item,pos))
                    self.objective.append(diff_pos)
                    
        
        
    def add_pos_constraints_extants(self):
        ''' function to add constraints for extants to fix them '''

        for node in self.extant_list:
            extant_node_edges_forward = self.node_from_edge_dict[node]
            position_present = list(extant_node_edges_forward.keys())
            position_present.append(self.sequence_length - 1) # last position
            
            for pos_from in range(0,sequence_length):
                pos_id = (node,pos_from)
                if pos_from in position_present:
                    self.m.addConstr(self.positions[pos_id] == 1,\
                                     name="extant_position_constraint-%s-%s"%pos_id)
                else:
                        self.m.addConstr(self.positions[pos_id] == 0,\
                                     name="extant_position_constraint-%s-%s"%pos_id)
                
                        
    def add_constraints_ancestors(self):
        ''' function to add constraints for ancestor node '''
        
        for node,node_edge_from_list in self.node_from_edge_dict.items():
            
            # node type ( edges for ancestor nodes)
            if node not in self.extant_list:
                
                # unconnected positions
                if self.position_type_dict.get((node,'uc_nodes')):
                    for pos_from in self.position_type_dict.get((node,'uc_nodes')):
                        pos_id = (node,pos_from)
                        
                        self.m.addConstr(self.positions[pos_id] == 0,\
                                     name="unconnected_position_constraint-%s-%s"%pos_id)
                        
                # fully connected positions
                if self.position_type_dict.get((node,'fc_nodes')):
                    for pos_from in self.position_type_dict.get((node,'fc_nodes')):
                        
                        # edges coming in
                        edges_coming_in = node_from_reverse_edge_dict[node][pos_from]
                        #edges going out
                        edges_going_out = node_edge_from_list[pos_from]
                    
                        # get all edges going out and coming in
                        edges_coming_in_list = []
                        edges_going_out_list = []

                        for edges_coming_in_item in edges_coming_in:
                            edge_to_id = (node,edges_coming_in_item,pos_from)
                            edges_coming_in_list.append(self.edges[edge_to_id])

                        for edges_going_out_item in edges_going_out:
                            edge_to_id = (node,pos_from,edges_going_out_item)
                            edges_going_out_list.append(self.edges[edge_to_id])

                        pos_id = (node,pos_from)
                        self.m.addConstr(sum(edges_coming_in_list) == sum(edges_going_out_list),\
                                                 name="edge_recon_constraint-%s-%s"%pos_id)
                        
                        self.m.addConstr(sum(edges_going_out_list) == self.positions[pos_id],\
                                            name="edge_from_constraint-%s-%s"%pos_id)
                        
                        self.m.addConstr(sum(edges_coming_in_list) == self.positions[pos_id],\
                                            name="edge_to_constraint-%s-%s"%pos_id)
                        
                        
                # single connected positions
                # one way edges - only forward
                if self.position_type_dict.get((node,'f_sc_nodes')):
                    for pos_from in self.position_type_dict.get((node,'f_sc_nodes')):
                        
                        #edges going out
                        edges_going_out = node_edge_from_list[pos_from]
                        edges_going_out_list = []
                        
                        for edges_going_out_item in edges_going_out:
                            edge_to_id = (node,pos_from,edges_going_out_item)
                            edges_going_out_list.append(self.edges[edge_to_id])
                            
                        edge_id = (node,0,pos_from) 
                        pos_id = (node,pos_from)
                        self.m.addConstr(sum(edges_going_out_list) == self.edges[edge_id],\
                                                 name="edge_recon_constraint-%s-%s"%pos_id)
                        
                        self.m.addConstr(sum(edges_going_out_list) == self.positions[pos_id],\
                                            name="edge_from_constraint-%s-%s"%pos_id)
                        
                        self.m.addConstr(self.edges[edge_id] == self.positions[pos_id],\
                                            name="edge_to_constraint-%s-%s"%pos_id)
                        
                    
                # one way edges - coming in
                if self.position_type_dict.get((node,'t_sc_nodes')):
                    for pos_from in self.position_type_dict.get((node,'t_sc_nodes')):
                        
                        # edges coming in
                        edges_coming_in = node_from_reverse_edge_dict[node][pos_from]
                        edges_coming_in_list = []
                        
                        for edges_coming_in_item in edges_coming_in:
                            edge_to_id = (node,edges_coming_in_item,pos_from)
                            edges_coming_in_list.append(self.edges[edge_to_id])
                            
                        
                        edge_id = (node,pos_from, sequence_length - 1) 
                        pos_id = (node,pos_from)
                        self.m.addConstr(sum(edges_coming_in_list) == self.edges[edge_id],\
                                                 name="edge_recon_constraint-%s-%s"%pos_id)
                        
                        self.m.addConstr(sum(edges_coming_in_list) == self.positions[pos_id],\
                                            name="edge_to_constraint-%s-%s"%pos_id)
                        
                        self.m.addConstr(self.edges[edge_id] == self.positions[pos_id],\
                                            name="edge_from_constraint-%s-%s"%pos_id)       
             
    def penalty_constraints(self):
        ''' function to add penalty constraints '''
        
        for node,node_neighbor in self.neighbor_dict.items():
            for node_neighbor_item in node_neighbor:
                for pos in range(1,sequence_length - 1):  # no penalty for start and end
                    diff_id  = (node,node_neighbor_item,pos)
                    pen_id   = (node,node_neighbor_item,pos)
                    diff_var = self.diff[diff_id]
                    pen_var  = self.penalty[pen_id]
                    
                    if pos == 1: # penalty for first position is simple
                        self.m.addConstr(pen_var == diff_var,"penalty_constraint-%s-%s-%s"%\
                                         (node,node_neighbor_item,pos))
                    else:
                        pen_prev_id = (node,node_neighbor_item,pos - 1)
                        prev_pen_var =  self.penalty[pen_prev_id]
                        prev_diff_var = self.diff[pen_prev_id]
                        
                        self.m.addConstr(diff_var - prev_diff_var >= 1 - self.M * (1 - pen_var),\
                                         name="penalty_constraint_1-%s-%s-%s"%\
                                         (node,node_neighbor_item,pos))
                        self.m.addConstr(diff_var - prev_diff_var <= self.M * (pen_var),\
                                         name="penalty_constraint_2-%s-%s-%s"%\
                                         (node,node_neighbor_item,pos))
                    
                    # add penalty to the objective
                    self.objective.append(2 * pen_var)
                    
        
    def train(self,n_threads,time_out):
        # Params
        self.m.Params.Threads = n_threads
        self.m.Params.TimeLimit = time_out*60
        #self.m.Params.LogFile =  folder_location + 
        self.m.Params.LogToConsole = 0
        self.m.Params.Degenmoves=0
        
        # Optimize
        self.total_objective = sum([o for o in self.objective])
        self.m.setObjective(self.total_objective, GRB.MINIMIZE)
        self.m.update()
        
        self.m.write((self.folder_location + 'pf_mip_formulation_' + self.tree_name + '.lp'))
        self.m.optimize()
        
        # Is feasible?
        return self.m.SolCount > 0
    
    
    def get_info(self):
        info_all = {}
        info_all["objective"] = self.m.ObjVal
        info_all["bound"] = self.m.ObjBound
        info_all["gap"] = self.m.MIPGap
        info_all["is_optimal"] = (self.m.status == GRB.OPTIMAL)
        info_all["num_nodes"] = self.m.NodeCount
        info_all["num_vars"] = self.m.NumIntVars + self.m.NumBinVars

        if self.m.SolCount > 0:
            print("objective: %0.2f"%info_all["objective"])
            print("bound: %0.2f"%info_all["bound"])
            print("gap: %0.2f"%info_all["gap"])

        return info_all
    
    def get_solution(self):
        # get the path
        all_node_paths = {}
        for node,node_edge_from_list in self.node_from_edge_dict.items():
            preferred_path = []
            for pos_from in range(0,sequence_length):
                pos_id = (node,pos_from)
                preferred_path.append(int(self.positions[pos_id].X))
                
            all_node_paths[node] = preferred_path
            
        # get the differnece and penalty solution
        score_dict = {}
        overall_score = 0
        for node,node_neighbor in self.neighbor_dict.items():
            for node_neighbor_item in node_neighbor:
                total_score = 0
                for pos in range(1,sequence_length - 1): # penalty start from 1st position only
                    pen_id = (node,node_neighbor_item,pos)
                    diff_id = (node,node_neighbor_item,pos)
                    
                    total_score = total_score + 2 * int(self.penalty[pen_id].X)
                    total_score = total_score + int(self.diff[diff_id].X)
                    
#                     if node == 'N0' and node_neighbor_item == 'N1':
#                         print("pen_id",pen_id)
#                         print(int(self.penalty[pen_id].X))
#                         print(int(self.diff[diff_id].X))
                  
                #print("total_score between node {} and node-neighbor{} is {}"\
                                  #.format(node,node_neighbor_item,total_score)) 
                score_dict[(node,node_neighbor_item)] = total_score
                overall_score = overall_score + total_score
        print("overall_score",overall_score)
        return all_node_paths,score_dict
               

## RUNNING MIP 

In [4]:
# TEST EXAMPLE - 1
# folder_location = '/Users/sanjanatule/Documents/uq/Projects/PreferredPath/scripts/mip_files/'
# nodes = 3
# sequence_length = 6 #(start and end pos)
# neighbor_dict = dict({1:[2,3]})
# extant_list = [2,3]
# node_from_edge_dict = dict({\
#                   1:{0:[1,2],1:[2,3,4],2:[3,4],3:[4],4:[5]} ,\
#                   2:{0:[1],1:[2],2:[3],3:[4],4:[5]} ,\
#                   3:{0:[1],1:[2],2:[3],3:[4],4:[5]}})
# node_to_edge_dict = dict({\
#                   1:{1:[0],2:[0,1],3:[1,2],4:[1,2,3],5:[4]} ,\
#                   2:{1:[0],2:[1],3:[2],4:[3],5:[4]} ,\
#                   3:{1:[0],2:[1],3:[2],4:[3],5:[4]}\
#                     })
# tree_name = 'test_example_1'

In [5]:
#TEST EXAMPLE - 2
# folder_location = '/Users/sanjanatule/Documents/uq/Projects/PreferredPath/scripts/mip_files/'
# nodes = 3
# sequence_length = 6 #(start and end pos)
# neighbor_dict = dict({1:[2,3]})
# extant_list = [2,3]
# node_from_edge_dict = dict({\
#                   1:{0:[1,2],1:[2,3,4],2:[3,4],3:[4],4:[5]} ,\
#                   2:{0:[1],1:[2],2:[3],3:[4],4:[5]} ,\
#                   3:{0:[1],1:[4],2:[],3:[],4:[5]}})
# node_to_edge_dict = dict({\
#                   1:{1:[0],2:[0,1],3:[1,2],4:[1,2,3],5:[4]} ,\
#                   2:{1:[0],2:[1],3:[2],4:[3],5:[4]} ,\
#                   3:{1:[0],2:[],3:[],4:[1],5:[4]}\
#                     })
# tree_name = 'test_example_2'

In [6]:
# different trees

# TEST MIP Models for different trees

## sample tree
# nwk_file_path = '/Users/sanjanatule/Documents/uq/Projects/PreferredPath/data/sample_tree/grasp_ancestors.nwk'
# pogs_file     = '/Users/sanjanatule/Documents/uq/Projects/PreferredPath/data/sample_tree/pogs.json'
# tree_name = 'sample_tree'

## CYP2U - 165
# nwk_file_path       = '/Users/sanjanatule/Documents/uq/Projects/PreferredPath/data/CYP2U_165/grasp_ancestors.nwk'
# ancestor_fasta_file = "/Users/sanjanatule/Documents/uq/Projects/PreferredPath/data/CYP2U_165/grasp_ancestors.fa"
# pogs_file           = '/Users/sanjanatule/Documents/uq/Projects/PreferredPath/data/CYP2U_165/pogs.json'
# tree_name = 'cyp2u_165'
# mip_ancestor_fasta_file = "/Users/sanjanatule/Documents/uq/Projects/PreferredPath/data/CYP2U_165/mip_grasp_ancestors.fa"

# ## CYP2U - 359
# nwk_file_path       = '/Users/sanjanatule/Documents/uq/Projects/PreferredPath/data/CYP2U_359/grasp_ancestors.nwk'
# ancestor_fasta_file = "/Users/sanjanatule/Documents/uq/Projects/PreferredPath/data/CYP2U_359/grasp_ancestors.fa"
# pogs_file           = '/Users/sanjanatule/Documents/uq/Projects/PreferredPath/data/CYP2U_359/pogs.json'
# tree_name = 'cyp2u_359'

# ## CYP2U - 595
nwk_file_path       = '/Users/sanjanatule/Documents/uq/Projects/PreferredPath/data/CYP2U_595/grasp_ancestors.nwk'
ancestor_fasta_file = "/Users/sanjanatule/Documents/uq/Projects/PreferredPath/data/CYP2U_595/grasp_ancestors.fa"
pogs_file           = '/Users/sanjanatule/Documents/uq/Projects/PreferredPath/data/CYP2U_595/pogs.json'
mip_ancestor_fasta_file = "/Users/sanjanatule/Documents/uq/Projects/PreferredPath/data/CYP2U_595/\
mip_grasp_ancestors.fa"
tree_name = 'cyp2u_595'

# ## DHAD - 585
# nwk_file_path       = '/Users/sanjanatule/Documents/uq/Projects/PreferredPath/data/DHAD_585/grasp_ancestors.nwk'
# ancestor_fasta_file = "/Users/sanjanatule/Documents/uq/Projects/PreferredPath/data/DHAD_585/grasp_ancestors.fa"
# pogs_file           = '/Users/sanjanatule/Documents/uq/Projects/PreferredPath/data/DHAD_585/pogs.json'
# tree_name = 'dhad_585'

# ## DHAD - 1612
# nwk_file_path       = '/Users/sanjanatule/Documents/uq/Projects/PreferredPath/data/DHAD_1612/grasp_ancestors.nwk'
# ancestor_fasta_file = "/Users/sanjanatule/Documents/uq/Projects/PreferredPath/data/DHAD_1612/grasp_ancestors.fa"
# pogs_file           = '/Users/sanjanatule/Documents/uq/Projects/PreferredPath/data/DHAD_1612/pogs.json'
# tree_name = 'dhad_1612' 


# ## KARI - 1176
# nwk_file_path       = '/Users/sanjanatule/Documents/uq/Projects/PreferredPath/data/KARI_1176/grasp_ancestors.nwk'
# ancestor_fasta_file = "/Users/sanjanatule/Documents/uq/Projects/PreferredPath/data/KARI_1176/grasp_ancestors.fa"
# pogs_file           = '/Users/sanjanatule/Documents/uq/Projects/PreferredPath/data/KARI_1176/pogs.json'
# tree_name = 'kari_1176' 

# ## GO - 399
# nwk_file_path       = '/Users/sanjanatule/Documents/uq/Projects/PreferredPath/data/GDH-GOx_399/grasp_ancestors.nwk'
# ancestor_fasta_file = "/Users/sanjanatule/Documents/uq/Projects/PreferredPath/data/GDH-GOx_399/grasp_ancestors.fa"
# pogs_file           = '/Users/sanjanatule/Documents/uq/Projects/PreferredPath/data/GDH-GOx_399/pogs.json'
# tree_name = 'go_399' 



In [7]:
##### GET THE INPUT FOR MIP READY #####

folder_location = '/Users/sanjanatule/Documents/uq/Projects/PreferredPath/scripts/mip_files/'
TreePogs = Pogs(pogs_file,nwk_file_path)
node_from_edge_dict,node_from_reverse_edge_dict,sequence_length,extant_list,node_pogs_cnt,\
                                                        node_pogs = TreePogs.create_node_info_dict()
neighbor_dict = TreePogs.create_neighbor_object()
nodes = len(node_from_edge_dict)
node_type_dict = TreePogs.node_type_info()
print("TOTAL NODES::",nodes)
print("TOTAL POSITIONS:",sequence_length)
#print("EXTANT LIST:",extant_list)
# print("NEIGHBOR DICT:",neighbor_dict)
# print('NODE TYPE DICT:',node_type_dict)

TOTAL NODES:: 1189
TOTAL POSITIONS: 664


In [8]:
#### RUN OPTIMISATION ####

start = time.time()
print("Start Time:",start)
n_threads = 1
time_out = 60
# initialise the class
PyTree = PhyloTree(nodes,sequence_length,neighbor_dict,node_from_edge_dict,folder_location\
                   ,extant_list,tree_name,node_type_dict,node_from_reverse_edge_dict)

PyTree.add_pos_var_constraints()
PyTree.add_edges_var_constraints()
PyTree.add_pos_constraints_extants()
PyTree.add_constraints_ancestors()
PyTree.penalty_constraints()

is_sat = PyTree.train(n_threads, time_out)
print("is_sat",is_sat)
total_time = ((time.time()-start))
print("-----------------------------")
print("Total time = %0.2f[m]"%total_time)
info = PyTree.get_info()
info["total_time"] = total_time
info["is_sat"] = is_sat
print("info",info)

if is_sat:
    all_node_paths,score_dict = PyTree.get_solution()
    #print("all_node_paths",all_node_paths)
else:
    print("Did not find any satisfactory solution to the model")

Start Time: 1662600192.20453
Set parameter Username
Academic license - for non-commercial use only - expires 2023-08-19
Set parameter Threads to value 1
Set parameter TimeLimit to value 3600
is_sat True
-----------------------------
Total time = 211.36[m]
objective: 5166.00
bound: 5166.00
gap: 0.00
info {'objective': 5166.0, 'bound': 5166.0, 'gap': 0.0, 'is_optimal': True, 'num_nodes': 1.0, 'num_vars': 4787626, 'total_time': 211.35943603515625, 'is_sat': True}
overall_score 5166


In [9]:
# convert output file to FASTA file
with open(mip_ancestor_fasta_file,mode='w') as fout:
    for node_name,sequence in all_node_paths.items():
        fout.write('>' + str(node_name) + '\n')
        sequence_str = ''.join([str(s) for s in sequence])
        fout.write(str(sequence_str) + '\n')

## CHECKING THE MIP SOLUTION

In [10]:
''' functions for use in checking solution '''

''' get the sequence string after removing commas and start/end node'''
def get_sequence_string(preferred_path):
    pf_str = ''
    for p in preferred_path:
        pf_str = pf_str + str(p)

    # remove first and last path
    pf_str = pf_str[1:-1]
    return pf_str

''' verify the score given by mip '''
def sequence_distance_score(str1,str2):
    dis = 0
    prev_dis = 0
    
    for i in range(0,len(str1)):
        if str1[i] != str2[i]:  # not matching
            if prev_dis == 0:   # previous unmatched
                dis += 3
                prev_dis = 1
            else:
                dis += 1
        else:
            prev_dis = 0
    return dis

''' Check if path is valid path and not broken '''
def next_one(from_pos,path,seq_length):
    if from_pos == seq_length - 1:  # end pos
        return from_pos
    for p_ind in range(from_pos + 1,seq_length):
        if path[p_ind] == 1:
            return p_ind
        
def check_path_complete(pog_mat,path,seq_length):
    valid = 0
    # first position
    from_p = 0
    to_p = next_one(from_p,path,seq_length)
    
    while(1):
        if pog_mat[from_p][to_p] == 1:
            from_p = to_p
            to_p   = next_one(from_p,path,seq_length) # next position with 1
            #print('to_p',to_p)
            
            if to_p == seq_length - 1: # end position
                valid = 1
                break
        else:
            valid = 0
            break
    return valid

# test
# pog_mat = np.array([(0,0,1),(0,0,0),(0,0,0)])
# path = [1,0,1]
# seq_length = 3
# check_path_complete(pog_mat,path,seq_length)

In [11]:
''' 1 - check whether path is complete and fixed path are not changed '''
# 1 -- check output from mip 
def check_mip_output_node_changes(grasp_fasta_file,all_node_paths,extant_list,node_pogs,node_pogs_cnt):
    
    sequences_info = FastaFile(grasp_fasta_file)
    total_node_change = 0

    for node_name,preferred_path in all_node_paths.items(): # mip path
        grasp_output_seq = sequences_info.fetch(node_name)

        # convert into 1/0
        grasp_output_seq =  re.sub('[a-zA-Z]', '1', grasp_output_seq)
        grasp_output_seq = grasp_output_seq.replace('-','0')
        #print('grasp_output_seq',grasp_output_seq)

        # compare with mip output
        pf_str = ''
        for p in preferred_path:
            pf_str = pf_str + str(p)

        # remove first and last path
        pf_str = pf_str[1:-1]

        #print("preferred_path",pf_str)
        if grasp_output_seq != pf_str:
            print("checking node::",node_name)
            total_node_change += 1

            if node_name in extant_list:
                print("ERROR: path is different for an extant.Should not change.")
            else:
                if node_pogs_cnt[node_name] == 1:
                    print("ERROR: path is different for ancestor.Should not change.")
                else:
                    print("OK: path is different for ancestor.")


        # check if the path is complete
        pog_mat = node_pogs[node_name]
        if check_path_complete(pog_mat,preferred_path,sequence_length) == 0:
            print("Path is not complete. Incorrect results.")

    print("Total nodes changed",total_node_change)

check_mip_output_node_changes(ancestor_fasta_file,all_node_paths,extant_list,node_pogs,node_pogs_cnt)

checking node:: N0
OK: path is different for ancestor.
checking node:: N3
OK: path is different for ancestor.
checking node:: N21
OK: path is different for ancestor.
checking node:: N29
OK: path is different for ancestor.
checking node:: N36
OK: path is different for ancestor.
checking node:: N39
OK: path is different for ancestor.
checking node:: N47
OK: path is different for ancestor.
checking node:: N51
OK: path is different for ancestor.
checking node:: N124
OK: path is different for ancestor.
checking node:: N125
OK: path is different for ancestor.
checking node:: N134
OK: path is different for ancestor.
checking node:: N135
OK: path is different for ancestor.
checking node:: N137
OK: path is different for ancestor.
checking node:: N139
OK: path is different for ancestor.
checking node:: N145
OK: path is different for ancestor.
checking node:: N161
OK: path is different for ancestor.
checking node:: N163
OK: path is different for ancestor.
checking node:: N352
OK: path is differen

In [12]:
''' 2 - compare/check the score function between heuristic and mip method '''
### 2 - check score
def check_score(score_dict,neighbor_dict):
    # compare score with hueristics solution
    for node,node_neighbor_list in neighbor_dict.items():
        node_sequence = get_sequence_string(all_node_paths[node])
        for nn in node_neighbor_list:
            nn_sequence = get_sequence_string(all_node_paths[nn])
            # calculate score between the nodes
            diff_score = sequence_distance_score(node_sequence,nn_sequence)
            # compare with mip score
            mip_score = score_dict[(node,nn)]
            if diff_score != mip_score:
                print('node',node)
                print("node neighbor",nn)
                print("diff_score",diff_score)
                print("mip_score",mip_score)
                print("ERROR: Score is different")
                print("node sequence",node_sequence)
                print("nn sequence",nn_sequence)
check_score(score_dict,neighbor_dict)