In [9]:
''' Run MIP for indel inference '''

# libraries
import gurobipy as gp
from gurobipy import abs_,quicksum
from gurobipy import GRB
import time
import json
from collections import defaultdict
import numpy as np
import torch
from torch.utils.data import Dataset
import pickle

In [10]:
class PhyloTreeMIP:
    """Mixed-integer program that infers 0/1 indel patterns for ancestral nodes.

    Each ancestor gets one binary variable per interior alignment position
    (``1 .. sequence_length - 2``). The first and last positions are not
    modelled and are written back as ``'1'`` by :meth:`get_solution`.

    The objective, built by :meth:`add_penalty_constraint`, sums over every
    (node, neighbour) pair:

    * the number of positions where the two sequences differ, plus
    * twice the number of positions where a run of differences starts.

    Args:
        extant_data: mapping extant name -> sequence whose characters are
            ``int``-convertible (presumably ``'0'``/``'1'`` strings - TODO
            confirm against the pickle producer).
        ancestor_data: indexable container where ``[0]`` holds the ancestor
            names, ``[1]``/``[2]`` the forward/backward edge maps
            (position -> iterable of positions), and ``[3]`` a map keyed by
            ``('ANCESTOR', <node type>)`` giving the positions of each type.
        tree_name: tree identifier; used in the fallback LP file name.
        neighbor_dict: mapping node name -> iterable of neighbour names.
        mip_ancestor_fasta_file: path that :meth:`output_fasta` writes to.
        mip_model_file: path that :meth:`train` writes the LP model to.
    """

    def __init__(self, extant_data, ancestor_data, tree_name, neighbor_dict,
                 mip_ancestor_fasta_file, mip_model_file):
        if not extant_data:
            raise ValueError("extant_data must contain at least one sequence")

        # Configuration
        self.extant_data = extant_data
        self.ancestor_data = ancestor_data
        self.tree_name = tree_name
        # The alignment length is taken from the first extant sequence; all
        # sequences are assumed to share it - TODO confirm.
        self.sequence_length = len(next(iter(self.extant_data.values())))
        self.neighbor_dict = neighbor_dict
        self.mip_ancestor_fasta_file = mip_ancestor_fasta_file
        self.mip_model_file = mip_model_file
        self.all_node_paths = {}
        # Only interior positions are decision variables.
        self.sequence_range = range(1, self.sequence_length - 1)
        self.extant_list = self.extant_data.keys()
        self.ancestor_list = self.ancestor_data[0]

        # MIP data structures
        self.ancestorsequence = {}  # ancestor -> tupledict(pos -> binary var)
        self.edges = {}             # (ancestor, pos_from, pos_to) -> binary var
        self.penalty = {}           # (node, neighbour) -> tupledict(pos -> var)
        self.diff = {}              # (node, neighbour) -> tupledict(pos -> var)
        self.objective = []         # terms summed by train()

        # Create the Gurobi model
        self.m = gp.Model("PreferredPathSolve")

    def add_pos_constraints_ancestors(self):
        """Create one binary variable per interior position for each ancestor."""
        for ancestor in self.ancestor_list:
            self.ancestorsequence[ancestor] = self.m.addVars(
                self.sequence_range, vtype=GRB.BINARY, name="an-%s" % ancestor)

    def _add_out_edges(self, ancestor, pos_from, fwd_edges):
        """Create a binary edge variable for each forward edge leaving ``pos_from``; return them."""
        out_edges = []
        for pos_to in fwd_edges[pos_from]:
            edge_id = (ancestor, pos_from, pos_to)
            e = self.m.addVar(vtype=GRB.BINARY, name='e-%s-%s-%s' % edge_id)
            self.edges[edge_id] = e
            out_edges.append(e)
        return out_edges

    def _in_edges(self, ancestor, pos_to, bkwd_edges):
        """Return the already-created edge variables entering ``pos_to``.

        Raises KeyError if an incoming edge was not created earlier. This
        relies on node positions being visited in a forward order - TODO
        confirm the ordering guarantee of the node-type lists.
        """
        return [self.edges[(ancestor, pos_from, pos_to)]
                for pos_from in bkwd_edges[pos_to]]

    def add_edge_constraints_ancestors(self):
        """Add path (flow-conservation) constraints over each ancestor's position graph.

        For every ancestor:

        * start nodes emit exactly one edge;
        * dead positions are forced to 0;
        * each fully connected position has in-flow == out-flow == its variable;
        * the end node receives exactly one edge.

        Must be called after :meth:`add_pos_constraints_ancestors`.
        """
        fwd_edges = self.ancestor_data[1]
        bkwd_edges = self.ancestor_data[2]
        node_type = self.ancestor_data[3]
        # Missing node types are treated as empty instead of crashing on None.
        start_nodes = node_type.get(('ANCESTOR', 'start')) or ()
        dead_nodes = node_type.get(('ANCESTOR', 'dead_pos')) or ()
        inner_nodes = node_type.get(('ANCESTOR', 'fwd_back_pos')) or ()
        end_nodes = node_type.get(('ANCESTOR', 'end')) or ()

        for ancestor in self.ancestor_list:
            anc_seq = self.ancestorsequence[ancestor]

            # START NODES - exactly one outgoing edge is used
            for pos in start_nodes:
                out_edges = self._add_out_edges(ancestor, pos, fwd_edges)
                self.m.addConstr(quicksum(out_edges) == 1,
                                 name="ancestor_start_node_recon_constraint-%s" % ancestor)

            # DEAD NODES - positions that cannot be on a complete path are 0
            for pos in dead_nodes:
                self.m.addConstr(anc_seq[pos] == 0,
                                 name="ancestor_dead_pos-%s-%s" % (ancestor, pos))

            # FULLY CONNECTED NODES - in-flow == out-flow == position value
            for pos in inner_nodes:
                pos_id = (ancestor, pos)
                out_edges = self._add_out_edges(ancestor, pos, fwd_edges)
                in_edges = self._in_edges(ancestor, pos, bkwd_edges)
                self.m.addConstr(quicksum(out_edges) == anc_seq[pos],
                                 name="ancestor_edge_node_recon_constraint1-%s-%s" % pos_id)
                self.m.addConstr(quicksum(in_edges) == anc_seq[pos],
                                 name="ancestor_edge_node_recon_constraint2-%s-%s" % pos_id)

            # END NODES - exactly one incoming edge is used
            for pos in end_nodes:
                in_edges = self._in_edges(ancestor, pos, bkwd_edges)
                self.m.addConstr(quicksum(in_edges) == 1,
                                 name="ancestor_end_node_recon_constraint-%s" % ancestor)

    def add_penalty_constraint(self):
        """Add difference and gap-opening penalty terms for every (node, neighbour) pair.

        ``d[pos]`` is constrained to equal XOR(node[pos], neighbour[pos]).
        ``pe[pos]`` is at least 1 wherever a run of differences starts
        (``pe[1] == d[1]``, otherwise ``pe[pos] >= d[pos] - d[pos-1]``).
        The objective receives ``d[pos] + 2 * pe[pos]`` for each position.

        ``node`` must be an ancestor. A neighbour may be extant (its fixed
        0/1 values are used) or an ancestor (its variables are used).
        """
        for node, neighbours in self.neighbor_dict.items():
            for neighbour in neighbours:
                node_pos_var = self.ancestorsequence[node]

                # Extant neighbours contribute constants, ancestors contribute variables.
                if neighbour in self.extant_list:
                    neighbour_pos_var = [int(c) for c in self.extant_data[neighbour]]
                else:
                    neighbour_pos_var = self.ancestorsequence[neighbour]

                nn_pair = (node, neighbour)
                diff_pos = self.m.addVars(self.sequence_range, vtype=GRB.BINARY,
                                          name='d-%s-%s' % nn_pair)
                pen = self.m.addVars(self.sequence_range, vtype=GRB.BINARY,
                                     name='pe-%s-%s' % nn_pair)
                self.diff[nn_pair] = diff_pos
                self.penalty[nn_pair] = pen

                for pos in self.sequence_range:
                    tag = (node, neighbour, pos)
                    x = node_pos_var[pos]
                    y = neighbour_pos_var[pos]
                    d = diff_pos[pos]

                    # Linearisation of d = x XOR y for binary x, y
                    self.m.addConstr(d <= x + y, name="diff_constraint_1-%s-%s-%s" % tag)
                    self.m.addConstr(d >= x - y, name="diff_constraint_2-%s-%s-%s" % tag)
                    self.m.addConstr(d >= y - x, name="diff_constraint_3-%s-%s-%s" % tag)
                    self.m.addConstr(d <= 2 - y - x, name="diff_constraint_4-%s-%s-%s" % tag)

                    # Gap-opening penalty: the first position has no predecessor.
                    if pos == 1:
                        self.m.addConstr(pen[pos] == d, name="penalty_constraint-%s-%s-%s" % tag)
                    else:
                        self.m.addConstr(pen[pos] >= d - diff_pos[pos - 1],
                                         name="penalty_constraint-%s-%s-%s" % tag)

                    self.objective.append(d)
                    self.objective.append(2 * pen[pos])

    def train(self, n_threads, time_out):
        """Set solver parameters, write the LP model, and optimise.

        Args:
            n_threads: number of Gurobi threads.
            time_out: time limit in minutes.

        Returns:
            True if Gurobi found at least one feasible solution.
        """
        self.m.Params.Threads = n_threads
        self.m.Params.TimeLimit = time_out * 60
        self.m.Params.LogToConsole = 0
        self.m.Params.Degenmoves = 0
        self.m.Params.Method = 1
        self.m.Params.LogFile = "mip_model.log"

        self.total_objective = quicksum(self.objective)
        self.m.setObjective(self.total_objective, GRB.MINIMIZE)
        self.m.update()

        # Honour the model path given to the constructor; fall back to the
        # previous working-directory name when none was supplied.
        model_path = self.mip_model_file or ('mip_model_' + self.tree_name + '.lp')
        self.m.write(model_path)
        self.m.optimize()

        return self.m.SolCount > 0

    def _attr_or_none(self, attr_name):
        """Return model attribute ``attr_name``, or None if Gurobi cannot provide it."""
        try:
            return getattr(self.m, attr_name)
        except (AttributeError, gp.GurobiError):
            return None

    def get_info(self):
        """Collect solve statistics into a dict.

        ``objective``/``bound``/``gap`` are None when Gurobi cannot report
        them (for example, no feasible solution was found). The original
        code raised an exception in that case.
        """
        has_solution = self.m.SolCount > 0
        info_all = {}
        info_all["objective"] = self._attr_or_none("ObjVal") if has_solution else None
        info_all["bound"] = self._attr_or_none("ObjBound")
        info_all["gap"] = self._attr_or_none("MIPGap") if has_solution else None
        info_all["is_optimal"] = (self.m.status == GRB.OPTIMAL)
        info_all["num_nodes"] = self.m.NodeCount
        info_all["num_vars"] = self.m.numVars
        if has_solution:
            for label in ("objective", "bound", "gap"):
                if info_all[label] is not None:
                    print("%s: %0.2f" % (label, info_all[label]))

        return info_all

    def output_fasta(self, all_node_paths):
        """Write ``all_node_paths`` (name -> sequence) to ``mip_ancestor_fasta_file`` as FASTA."""
        with open(self.mip_ancestor_fasta_file, mode='w') as fout:
            for node_name, sequence in all_node_paths.items():
                fout.write('>' + str(node_name) + '\n')
                fout.write(''.join(str(s) for s in sequence) + '\n')

    @staticmethod
    def _as_bit(value):
        """Round a solver value to 0/1.

        Gurobi returns binaries within tolerance, e.g. 0.9999999, so
        ``int()`` truncation would be wrong.
        """
        return int(round(value))

    def get_solution(self):
        """Read the solved model back.

        Returns:
            ``(all_node_paths, score_dict, overall_score)``:

            * ``all_node_paths`` maps each extant (unchanged input) and
              ancestor (``'1'`` + interior bits + ``'1'``) to its sequence.
            * ``score_dict`` maps each (node, neighbour) pair to
              difference + 2 * penalty.
            * ``overall_score`` is the sum of those scores.
        """
        # Extant paths are the inputs themselves.
        all_node_paths = dict(self.extant_data)

        # Ancestor paths: fixed start/end positions around the solved bits.
        for ancestor in self.ancestor_list:
            bits = [str(self._as_bit(self.ancestorsequence[ancestor][pos].X))
                    for pos in self.sequence_range]
            all_node_paths[ancestor] = '1' + ''.join(bits) + '1'

        # Difference and penalty scores per neighbouring pair.
        score_dict = {}
        overall_score = 0
        for node, neighbours in self.neighbor_dict.items():
            for neighbour in neighbours:
                pair = (node, neighbour)
                total_score_1 = sum(2 * self._as_bit(self.penalty[pair][pos].X)
                                    for pos in self.sequence_range)
                total_score_2 = sum(self._as_bit(self.diff[pair][pos].X)
                                    for pos in self.sequence_range)

                print(f"Sequence for node {node} is {all_node_paths[node]}")
                print(f"Sequence for node {neighbour} is {all_node_paths[neighbour]}")
                print(f"Difference Score between {node} and {neighbour} is {total_score_2}")
                print(f"Penalty Score between {node} and {neighbour} is {total_score_1}")
                print(f"Total score is {total_score_1 + total_score_2}")

                score_dict[pair] = total_score_1 + total_score_2
                overall_score = overall_score + total_score_1 + total_score_2

        return all_node_paths, score_dict, overall_score
        

def main(folder_location='/Users/sanjanatule/Documents/uq/Projects/MIPIndel/data/',
         tree_name='CYP2U_165', n_threads=1, time_out=60,
         use_edge_constraints=False):
    """Build, solve and export the indel MIP for one tree.

    Reads ``ancestor_info.pkl``, ``neighbor_dict.pkl`` and ``extant_data.pkl``
    from ``<folder_location>/<tree_name>/``. It then builds a
    ``PhyloTreeMIP``, solves it, and writes ``mip_ancestor_indel.fasta``
    next to the inputs when a solution exists.

    Args:
        folder_location: root data folder (e.g. the alternative
            '/media/WorkingSpace/Share/mipindel/data/').
        tree_name: sub-folder / tree identifier (e.g. 'st1', 'DHAD_1612',
            'CYP2U_359').
        n_threads: Gurobi thread count.
        time_out: solver time limit in minutes.
        use_edge_constraints: also add the ancestor path (edge) constraints.
            Off by default, matching the previous run where the call was
            commented out.
    """
    import os  # local: the notebook's import cell does not import os

    tree_folder = os.path.join(folder_location, tree_name)
    mip_ancestor_fasta_file = os.path.join(tree_folder, 'mip_ancestor_indel.fasta')
    mip_model_file = os.path.join(tree_folder, 'mip_model.lp')

    # Read input data. NOTE: pickle can execute arbitrary code; only load
    # files produced by this project's own pipeline.
    def _load(file_name):
        with open(os.path.join(tree_folder, file_name), 'rb') as f:
            return pickle.load(f)

    ancestor_data = _load('ancestor_info.pkl')
    neighbor_dict = _load('neighbor_dict.pkl')
    extant_data = _load('extant_data.pkl')

    # Build the model
    start = time.time()
    print("Start Time:", start)

    PyTree = PhyloTreeMIP(extant_data, ancestor_data, tree_name, neighbor_dict,
                          mip_ancestor_fasta_file, mip_model_file)

    print("Adding Ancestor Constraints")
    PyTree.add_pos_constraints_ancestors()

    if use_edge_constraints:
        print("Adding Edges Constraints")
        PyTree.add_edge_constraints_ancestors()
    else:
        print("Skipping Edges Constraints")

    print("Adding Penalty Constraints")
    PyTree.add_penalty_constraint()

    print('Training MIP Model')

    total_time = (time.time() - start) / 60
    print("-----------------------------")
    print("Total time to create model = %0.2f[mins]" % total_time)

    # Solve
    start = time.time()
    print("Start Time:", start)

    is_sat = PyTree.train(n_threads, time_out)
    print("is_sat", is_sat)
    total_time = (time.time() - start) / 60
    print("-----------------------------")
    print("Total time to solve model= %0.2f[mins]" % total_time)
    info = PyTree.get_info()
    info["total_time"] = total_time
    info["is_sat"] = is_sat

    print("info", info)

    if is_sat:
        all_node_paths, score_dict, overall_score = PyTree.get_solution()
        PyTree.output_fasta(all_node_paths)
    else:
        print("Did not find any satisfactory solution to the model")

In [11]:
# Run only when executed as a script or notebook cell, not on import.
if __name__ == "__main__":
    main()

Start Time: 1684506801.140521
Adding Ancestor Constraints
Adding Edges Constraints
Adding Penalty Constraints
Training MIP Model
-----------------------------
Total time to create model = 0.18[mins]
Start Time: 1684506811.647425
Set parameter Threads to value 1
Set parameter TimeLimit to value 3600
is_sat True
-----------------------------
Total time to solve model= 0.27[mins]
objective: 3134.00
bound: 3134.00
gap: 0.00
info {'objective': 3134.0, 'bound': 3134.0, 'gap': 0.0, 'is_optimal': True, 'num_nodes': 1.0, 'num_vars': 512500, 'total_time': 0.27403958241144816, 'is_sat': True}
Sequence for node NODE_0 is 11000000000011110000000000000111100001111111111111101100001110001111111111100001110111001111111111111111111111111111110011100010000000011111110010111001111111111111111111111111111111111111111111111111111111111111101111111111111111111111111111111111111111111111111000001110111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111110

Sequence for node NODE_14 is 110000000000111100000000111111100000000011111111111000000011100011111111111000011100110011111111111111111111111111111100111000000000000111111100101110011111111111111111111111111111111111111111111111111111111111111011111111111111111111111111111111111111111111111110000011101111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111100011111110111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111110000000000000000000001
Sequence for node XP_015281737.1 is 1111111111111111001111011111111000000000111111111110000000111000111111111110000111001100111111111111111111111111111111001110001111111111111111001011100111111111111111111111111111111111111111111111111111111111111110111111111111111111111111111111111111111111111111100000111011111111111111111111111111111111111

Sequence for node NODE_67 is 110000000000111100000000000001111000011011111111011001000011100011111111110000011100110011111111111111111111111111111100111000000000000000001100101111111111111111111111111111111111111111111111111111111111111111111011111111111111111111111111111111111111111111111111111011101111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111100111111110111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111000000000000000000001
Sequence for node XP_022621190.1 is 1100000000001111000000000000011110000110111111110110010000111000111111111100000111001100111111111111111111111111111111001110000000000000000011001011111111111111111111111111111111111111111111111111111111111111111110111111111111111111111111111111111111111111111111111110111011111111111111111111111111111111111

Sequence for node NODE_117 is 110000000000111100000000111111000000000000111111111000000011100011111111111000011111110011111111111111111111111111111100111000000000000110111100101110011111111111111111111111111111111111111111111111111111111111111011111111111111111111111111111111111111111111111110000011101111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111100011101110111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111110000000000000000000001
Sequence for node XP_002194737.2 is 110000000000111100000000111111000000000000111111111000000011100011111111111000011111110011111111111111111111111111111100111000000000000110111100101110011111111111111111111111111111111111111111111111111111111111111011111111111111111111111111111111111111111111111110000011101111111111111111111111111111111111