In [224]:
''' Run MIP for indel inference '''

# libraries
import gurobipy as gp
from gurobipy import abs_,quicksum
from gurobipy import GRB
import time
import json
from collections import defaultdict
import numpy as np
import torch
from torch.utils.data import Dataset
import pickle

In [225]:
class PhyloTreeMIP:
    def __init__(self,extant_data,ancestor_data,tree_name,neighbor_dict,mip_ancestor_fasta_file,mip_model_file):

        # Define the configuration  and decision variables for the tree
        self.extant_data = extant_data
        self.ancestor_data = ancestor_data
        self.tree_name = tree_name
        self.sequence_length = len(list(self.extant_data.values())[0])
        self.neighbor_dict = neighbor_dict
        self.objective = []
        self.mip_ancestor_fasta_file = mip_ancestor_fasta_file
        self.all_node_paths = {}
        self.mip_model_file = mip_model_file

        # MIP data structures
        self.edges = {}
        self.ancestorsequence = {}
        self.penalty = {}
        self.diff = {}
        self.objective = []
        self.M = 999

        # 2 - create a new model
        self.m = gp.Model("PreferredPathSolve")

    def add_pos_constraints_ancestors(self):
        ''' function to add constraints for positions in ancestors  '''

        ancestor_list = self.ancestor_data[0] # all ancestors in a tree
        ancestor_range = range(1,self.sequence_length-1)
        # V - create variable for each position for each ancestor sequence
        for ancestor in ancestor_list:
            an = self.m.addVars(ancestor_range,vtype=GRB.BINARY, name="p-%s"%ancestor)
            self.ancestorsequence[ancestor] = an
#             self.m.addConstr(self.ancestorsequence[ancestor][0] == 1,\
#                                              name = "ancestor_start-%s"%ancestor)
#             self.m.addConstr(self.ancestorsequence[ancestor][self.sequence_length - 1] == 1,\
#                                              name = "ancestor_end-%s"%ancestor)
            

    def add_edge_constraints_ancestors(self):
        ''' function to add constraints for edges in ancestors  '''

        ancestor_list       = self.ancestor_data[0]
        ancestor_fwd_edges  = self.ancestor_data[1]
        ancestor_bkwd_edges = self.ancestor_data[2]
        ancestor_node_type  = self.ancestor_data[3]

        # constraint for each ancestor
        for ancestor in ancestor_list:

            # START NODES
            for fwd_back_pos in ancestor_node_type.get(('ANCESTOR','start')):

                all_edges_from_pos = []
                for pos_to in ancestor_fwd_edges[fwd_back_pos]:
                    edge_id = (ancestor,fwd_back_pos,pos_to)
                    # V - var for each edge from start node
                    e = self.m.addVar(vtype=GRB.BINARY, name='e-%s-%s-%s'%edge_id)
                    self.edges[edge_id] = e
                    all_edges_from_pos.append(e)
                    
                # C - only 1 edge can be used
                pos_id = (ancestor,fwd_back_pos)
                self.m.addConstr(quicksum(all_edges_from_pos) <= 1,name=\
                                          "ancestor_start_edge_constraint-%s-%s"%pos_id)

                # C - sum(edges) = position & only 1 edge can be used
                self.m.addConstr(quicksum(all_edges_from_pos) == 1,\
                                            name = "ancestor_edge_node_recon_constraint-%s"%ancestor)

            # DEAD NODES
            if ancestor_node_type.get(('ANCESTOR','dead_pos')) :
                for fwd_back_pos in ancestor_node_type.get(('ANCESTOR','dead_pos')):
                    pos_id = (ancestor,fwd_back_pos)
                    # C - constraint for dead pos to 0 as they cannot find complete path
                    self.m.addConstr(self.ancestorsequence[ancestor][fwd_back_pos] == 0,\
                                                name = "ancestor_dead_pos-%s-%s"%pos_id)

            # FULLY CONNECTED NODES
            for fwd_back_pos in ancestor_node_type.get(('ANCESTOR','fwd_back_pos')):
                
                pos_id = (ancestor,fwd_back_pos)
                # C - edges going out of the node
                all_edges_from_pos = []
                for pos_to in ancestor_fwd_edges[fwd_back_pos]:
                    edge_id = (ancestor,fwd_back_pos,pos_to)
                    # V - var for each edge going from the node
                    e = self.m.addVar(vtype=GRB.BINARY, name='e-%s-%s-%s'%edge_id)
                    self.edges[edge_id] = e
                    all_edges_from_pos.append(e)

                # C - only 1 edge can be used
                #if len(all_edges_from_pos) > 1: # only if we have more than 1 edge
                self.m.addConstr(quicksum(all_edges_from_pos) <= 1,name=\
                                              "ancestor_edge_constraint1-%s-%s"%pos_id)

                # C - edges coming in into the node
                edges_coming_in_list = []
                for edges_coming_in_item in ancestor_bkwd_edges[fwd_back_pos]:
                    edge_to_id = (ancestor,edges_coming_in_item,fwd_back_pos)
                    edges_coming_in_list.append(self.edges[edge_to_id])
                
                # C - only 1 edge can be used
#                 if len(edges_coming_in_list) > 1: # only if we have more than 1 edge
#                     self.m.addConstr(quicksum(edges_coming_in_list) <= 1,name=\
#                                               "ancestor_edge_constraint2-%s-%s"%pos_id)
                
                # C - sum(edges going in) = sum(edges going out) for each node
                
                self.m.addConstr(quicksum(edges_coming_in_list) == quicksum(all_edges_from_pos),\
                                                 name="ancestor_edge_recon_constraint-%s-%s"%pos_id)
                # C - sum(edges) = position
                self.m.addConstr(quicksum(all_edges_from_pos) == self.ancestorsequence[ancestor][fwd_back_pos],\
                                            name="ancestor_edge_node_recon_constraint1-%s-%s"%pos_id)
                 # C - sum(edges) = position
#                 self.m.addConstr(quicksum(edges_coming_in_list) == self.ancestorsequence[ancestor][fwd_back_pos],\
#                                             name="ancestor_edge_node_recon_constraint2-%s-%s"%pos_id)

            # END NODES
            for fwd_back_pos in ancestor_node_type.get(('ANCESTOR','end')):

                 all_edges_to_pos = []
                 for pos_from in ancestor_bkwd_edges[fwd_back_pos]:
                     edge_id = (ancestor,pos_from,fwd_back_pos)
                     all_edges_to_pos.append(self.edges[edge_id])
                    
                # C - only 1 edge can be used
                 pos_id = (ancestor,fwd_back_pos)
                 self.m.addConstr(quicksum(all_edges_to_pos) <= 1,name=\
                                           "ancestor_end_edge_constraint-%s-%s"%pos_id)

                 # C - sum(edges) = position
                 self.m.addConstr(quicksum(all_edges_to_pos) == 1,\
                                             name = "ancestor_edge_node_recon_constraint-%s"%ancestor)
                    

    # penalty constraint
    def add_penalty_constraint(self):
        ''' function to add penalty constraints for whole tree  '''

        # difference constraints
        for node,node_neighbor in self.neighbor_dict.items():
            for node_neighbor_item in node_neighbor:

                # V - add position difference variable
                node_pos_var = self.ancestorsequence[node]

                # check if it is a extant node or not
                try:
                    #node_neighbor_pos_var = self.ancestorsequence[node_neighbor_item]
                    node_neighbor_pos_var = [int(i) for i in self.extant_data[node_neighbor_item]]
                except:
                    #continue
                    node_neighbor_pos_var = self.ancestorsequence[node_neighbor_item]

                nn_pair = (node,node_neighbor_item)
                diff_pos_temp = self.m.addVars(self.sequence_length,vtype=GRB.INTEGER,lb=-1,ub=1, name='dt_temp-%s-%s'%nn_pair)
                diff_pos = self.m.addVars(self.sequence_length,vtype=GRB.BINARY, name='d-%s-%s'%nn_pair)
                pen = self.m.addVars(self.sequence_length,vtype=GRB.BINARY, name='pe-%s-%s'%nn_pair)
                self.diff[nn_pair]    = diff_pos
                self.penalty[nn_pair] = pen
                
                self.m.addConstr(diff_pos_temp[0] == 0,\
                                             name = "DUMMY")
                self.m.addConstr(diff_pos_temp[self.sequence_length - 1] == 0,\
                                             name = "DUMMY1")
            
                self.m.addConstr(diff_pos[0] == 0,\
                                             name = "DUMMY3")
                self.m.addConstr(diff_pos[self.sequence_length - 1] == 0,\
                                             name = "DUMMY4")
                
    
                for pos in range(1,self.sequence_length - 1):
                    # C  - abs difference constraints
                    self.m.addConstr(diff_pos_temp[pos] == node_pos_var[pos] - node_neighbor_pos_var[pos],name=\
                                  "diff_constraint-%s-%s-%s"%(node,node_neighbor_item,pos))
                
                    self.m.addGenConstrAbs(diff_pos[pos],diff_pos_temp[pos], \
                                name="abs_diff_constraint-%s-%s-%s"%(node,node_neighbor_item,pos))
                    
                    
                    # C - penalty constraits
                    if pos == 1: # penalty for first position is simple
                        self.m.addConstr(pen[pos] == diff_pos[pos],"penalty_constraint-%s-%s-%s"%\
                                         (node,node_neighbor_item,pos))
                    else:
                        self.m.addConstr((pen[pos] == 1) >> (diff_pos[pos] - diff_pos[pos-1] == 1))
                        
                    # O - add difference to the objective
                    self.objective.append(diff_pos[pos])

                    # O - add penalty to the objective
                    self.objective.append(2 * pen[pos])                    
                    
                    

    def train(self,n_threads,time_out):
        # Params
        self.m.Params.Threads = n_threads
        self.m.Params.TimeLimit = time_out*60
        self.m.Params.LogToConsole = 0
        self.m.Params.Degenmoves=0

        # Optimize
        self.total_objective = sum([o for o in self.objective])
        self.m.setObjective(self.total_objective, GRB.MINIMIZE)
        self.m.update()

        self.m.write(('mip_model_' + self.tree_name + '.lp'))
        self.m.optimize()

        #Is feasible?
        return self.m.SolCount > 0

    def get_info(self):
        info_all = {}
        info_all["objective"] = self.m.ObjVal
        info_all["bound"] = self.m.ObjBound
        info_all["gap"] = self.m.MIPGap
        info_all["is_optimal"] = (self.m.status == GRB.OPTIMAL)
        info_all["num_nodes"] = self.m.NodeCount
        info_all["num_vars"] = self.m.NumIntVars + self.m.NumBinVars

        if self.m.SolCount > 0:
            print("objective: %0.2f"%info_all["objective"])
            print("bound: %0.2f"%info_all["bound"])
            print("gap: %0.2f"%info_all["gap"])

        return info_all
        

def main():
    folder_location         = '/Users/sanjanatule/Documents/uq/Projects/MIPIndel/data/'
    #folder_location         = '/media/WorkingSpace/Share/mipindel/data/'

    ## Sample tree 1
    tree_name               =   'st1' #'st1' #'DHAD_1612' #'Anthony' #'DHAD_1612' #  #'CYP2U_359'#'CYP2U_165'
    mip_ancestor_fasta_file = folder_location + tree_name + '/mip_ancestor_indel.fasta'
    mip_model_file = folder_location + tree_name + '/mip_model.lp'

    # read input Dataset
    with open(folder_location + tree_name + '/ancestor_info.pkl', 'rb') as f:
        ancestor_data = pickle.load(f)

    with open(folder_location + tree_name + '/neighbor_dict.pkl', 'rb') as f:
        neighbor_dict = pickle.load(f)

    with open(folder_location + tree_name + '/extant_data.pkl', 'rb') as f:
        extant_data = pickle.load(f)


    # Run MIP
    start = time.time()
    print("Start Time:",start)

    # initialise the class
    PyTree = PhyloTreeMIP(extant_data,ancestor_data,tree_name,neighbor_dict,mip_ancestor_fasta_file,mip_model_file)

    # variable position constraints for ancestors
    print("Adding Ancestor Constraints")
    PyTree.add_pos_constraints_ancestors()
    # edge constraints for ancestors
    print("Adding Edges Constraints")
    PyTree.add_edge_constraints_ancestors()
    # position difference constraints
    print("Adding Penalty Constraints")
    
    PyTree.add_penalty_constraint()

    print('Training MIP Model')
    
    total_time = ((time.time()-start)/60)
    print("-----------------------------")
    print("Total time to create model = %0.2f[mins]"%total_time)
    
    start = time.time()
    print("Start Time:",start)
    n_threads = 1
    time_out = 60 # 1 hour
    
    is_sat = PyTree.train(n_threads, time_out)
    print("is_sat",is_sat)
    total_time = ((time.time()-start)/60)
    print("-----------------------------")
    print("Total time to solve model= %0.2f[mins]"%total_time)
    info = PyTree.get_info()
    info["total_time"] = total_time
    info["is_sat"] = is_sat
    print("info",info)

#     if is_sat:
#         all_node_paths,score_dict = PyTree.get_solution()
#         PyTree.output_fasta()
#     else:
#         print("Did not find any satisfactory solution to the model")

In [226]:
main()

Start Time: 1682848851.1086452
Adding Ancestor Constraints
Adding Edges Constraints
Adding Penalty Constraints
Training MIP Model
-----------------------------
Total time to create model = 0.00[mins]
Start Time: 1682848851.138889
Set parameter Threads to value 1
Set parameter TimeLimit to value 3600
is_sat True
-----------------------------
Total time to solve model= 0.00[mins]
objective: 14.00
bound: 14.00
gap: 0.00
info {'objective': 14.0, 'bound': 14.0, 'gap': 0.0, 'is_optimal': True, 'num_nodes': 1.0, 'num_vars': 4484, 'total_time': 0.0004253983497619629, 'is_sat': True}
