In [1]:
import numpy as geek
import numpy.ma as ma
  
# creating input array 
in_arr = geek.array([1, 2, 3, -1, 5])
print ("Input array : ", in_arr)
  
# Now we are creating a masked array 
# by making third entry as invalid. 
mask_arr = ma.masked_array(in_arr, mask =[0, 0, 1, 0, 0])
print ("Masked array : ", mask_arr)
  
# applying MaskedArray.argmin methods to mask array
out_arr = mask_arr.argmin()
print ("Index of min element in masked array : ", out_arr)

Input array :  [ 1  2  3 -1  5]
Masked array :  [1 2 -- -1 5]
Index of min element in masked array :  3


In [1]:
import torch
import torch.nn as nn
import numpy as np
import time 
from models.gnn_policy import GNNPolicy, GNNNodeSelectionPolicy
from models.fcn_policy import FCNNodeSelectionLinearPolicy
from models.gnn_dataset import get_graph_from_obs
from models.setting import MODEL_PATH
from single_group_as_bm.observation import *
from single_group_as_bm.solve_relaxation import solve_relaxed, solve_relaxed_with_selected_antennas, cvxpy_relaxed
import numpy.ma as ma

min_bound_gap = 0.01

class Node(object):
    def __init__(self,
                z_mask=None, 
                z_sol=None, 
                z_feas=None, 
                w_sol=None, 
                w_feas=None,
                l_angle=None,
                u_angle=None,
                U=False, 
                L=False, 
                depth=0, 
                parent_node=None, 
                node_index = 0):
        """
        @params: 
            z_mask: vector of boolean, True means that the current variable is boolean
            z_sol: value of z at the solution of the cvx relaxation
            z_feas: value of z after making z_sol feasible (i.e. boolean)
            U: True if it is the current global upper bound
            L: True if it is the current global lower bound
            depth: depth of the node from the root of the BB tree
            node_index: unique index assigned to the node in the BB tree
        """
        # related to the discrete variables
        self.z_mask = z_mask.copy()
        self.z_sol = z_sol.copy()
        self.z_feas = z_feas.copy()

        # related to the continuous variables
        self.w_sol = w_sol.copy()
        self.w_feas = w_feas.copy()
        self.l_angle = l_angle.copy()
        self.u_angle = u_angle.copy()

        # BB statistics
        self.U = U
        self.L = L
        self.depth = depth
        self.parent_node = parent_node
        self.node_index = node_index

class DefaultBranchingPolicy(object):
    def __init__(self):
        pass

    def select_variable(self, observation, candidates, env=None):
        """
        Returns two variables to split on resulting in 4 children at one node
        """
        # z branching variable
        z_mask = observation.antenna_features[:,2]
        z_sol = observation.antenna_features[:,0]

        z_sol_rel = (1-z_mask)*(np.abs(z_sol - 0.5))
        if sum(1-z_mask)<1:
            z_branching_var = None
        else:
            z_branching_var = np.argmax(z_sol_rel)

        # maximum channel engergy
        if env is not None:
            z_compare = (1-z_mask)*env.H_energy_per_antenna
            if sum(1-z_mask)<1:
                z_branching_var = None
            else:
                z_branching_var = np.argmax(z_compare)

        # maximum channel sum
        # if env is not None:
        #     z_compare = (1-z_mask)*env.H_sum_per_antenna
        #     if sum(1-z_mask)<1:
        #         z_branching_var = None
        #     else:
        #         z_branching_var = np.argmax(z_compare)


        # w branching variable
        mask = np.zeros(observation.variable_features.shape[0])
        mask[0] = 1 # To make the variable not= 1
        l_angle = observation.variable_features[:,3]
        u_angle = observation.variable_features[:,4]
        
        for i in range(1,len(mask)):
            if u_angle[i] - l_angle[i] < min_bound_gap:
                mask[i] = 1
        
        if sum(mask) < len(mask):
            w_branching_var = np.argmin(observation.variable_features[:, 7]+ mask*999999)
            # print('branching var value {}'.format(u_angle[w_branching_var]-l_angle[w_branching_var]))
        else:
            w_branching_var = None

        return z_branching_var, w_branching_var 


class BBenv(object):
    def __init__(self, observation_function=Observation, node_select_policy_path='default', policy_type='gnn', epsilon=0.001, init_U=None):
        """
        @params: 
            node_select_policy_path: one of {'default', 'oracle', policy_params}
                                     if the value is 'oracle', optimal solution should be provided in the reset function
                                     policy_params refers to the actual state_dict of the policy network
                                     appropriate policy_type should be given according the policy parameters provided in this argument
            policy_type: One of 'gnn' or 'linear'
        """
        self._is_reset = None
        self.epsilon = epsilon # stopping criterion 
        self.H = None
        
        self.nodes = []     # list of problems (nodes)
        self.num_nodes = 0
        self.num_active_nodes = 0
        self.all_nodes = [] # list of all nodes to serve as training data for node selection policy
        self.optimal_nodes = []
        self.node_index_count = 0

        self.L_list = []    # list of lower bounds on the problem
        self.U_list = []    # list of upper bounds on the problem
        
        self.global_L = 0 # global lower bound
        self.global_U = np.inf  # global upper bound        

        self.action_set_indices = None
        # current active node
        self.active_node = None

        self.global_U_ind = None
        self.failed_reward = -2000

        self.node_select_model = None

        self.init_U = 999999
        self.node_select_policy = self.default_node_select        
        
        self.z_incumbent = None
        self.w_incumbent = None
        
        self.current_opt_node = None
        self.min_bound_gap = None

        if node_select_policy_path == 'default':
            self.node_select_policy = self.default_node_select
        elif node_select_policy_path == 'oracle':
            self.node_select_policy = self.oracle_node_select
        else:
            if policy_type=='gnn':
                self.node_select_model = GNNPolicy()
                self.node_select_model.load_state_dict(torch.load(node_select_policy_path))
                self.node_select_policy = self.learnt_node_select
            if policy_type=='linear':
                self.node_select_model = FCNNodeSelectionLinearPolicy()
                self.node_select_model.load_state_dict(torch.load(node_select_policy_path))
                self.node_select_policy = self.learnt_node_select
                
        self.observation_function = observation_function
        self.include_heuristic_solutions = False
        self.heuristic_solutions = []
        

    def set_heuristic_solutions(self, solution):
        """
        Provide antenna selections provided by heuristic methods in order to incorporate them into the BB
        """
        self.include_heuristic_solutions = True
        self.heuristic_solutions.append(solution)


    def reset(self, instance, max_ant,  oracle_opt=None):

        # clear all variables
        self.H = None
        self.nodes = []  # list of problems (nodes)
        self.all_nodes = []
        self.optimal_nodes = []
        self.node_index_count = 0

        self.L_list = []    # list of lower bounds on the problem
        self.U_list = []    # list of upper bounds on the problem
        self.global_L = 0 # global lower bound
        self.global_U = np.inf  # global upper bound        
        self.action_set_indices = None 
        self.active_node = None
        self.global_U_ind = None
        self.num_nodes = 1

        self.H = instance
        self.H_complex = self.H[0,:,:] + self.H[1,:,:]*1j
        
        self.min_bound_gap = np.ones(self.H.shape[-1])*0.01 # smallest size of continuous set to be branched on
        
        self.max_ant = max_ant

        # number of transmitters and users
        _, self.N, self.M = self.H.shape 
        self._is_reset = True
        self.action_set_indices_z = np.arange(1,self.N)
        self.action_set_indices_w = np.arange(1,self.M)

        # boolean vars corresponding to each antenna denoting its selection if True
        z_mask = np.zeros(self.N)
        # values of z (selection var) at the z_mask locations
        # for the root node it does not matter
        z_sol = np.zeros(self.N)

        done = False

        l = np.zeros(self.M)
        u = np.ones(self.M)*2*np.pi

        # initialize the root node 
        # try:
        [z, w, lower_bound, optimal] = cvxpy_relaxed(self.H,
                                            l=l, 
                                            u=u, 
                                            z_mask=z_mask, 
                                            z_sol=z_sol, 
                                            max_ant=self.max_ant)

        self.global_L = lower_bound
        
        # Upper bound method
        self.z_incumbent = self.get_feasible_z(z)
        w_selected, obj, optimal = solve_relaxed_with_selected_antennas(self.H, l=l, u=u, z_sol=self.z_incumbent)
        w_feas = self.get_feasible_w(w_selected, self.z_incumbent)
        self.global_U = np.linalg.norm(w_feas, 2)

        if not self.global_U == np.inf:
            self.w_incumbent = w_feas.copy()
        else:
            self.w_incumbent = np.zeros(self.H_complex.shape[0])

        self.active_node = Node(z_mask=z_mask, 
                                z_sol=z, 
                                z_feas=self.z_incumbent, 
                                w_sol = w, 
                                w_feas = w_feas,
                                l_angle=l,
                                u_angle=u,
                                U=self.global_U, 
                                L=lower_bound, 
                                depth=1, 
                                node_index=self.node_index_count)

        self.current_opt_node = self.active_node
        
        self.active_node_index = 0
        self.nodes.append(self.active_node)
        self.L_list.append(lower_bound)
        self.U_list.append(self.global_U)
        self.all_nodes.append(self.active_node)

        done = False
        if self.is_terminal():
            done = True
        reward = 0

        # TODO: build observation function
        observation = self.observation_function().extract(self)
        # observation = None

        #TODO: re-write this to include both z and w oracle solutions
        self.optimal_angle = None
        if oracle_opt is not None:
            (self.oracle_z, oracle_w) = oracle_opt
            self.optimal_angle = np.angle(np.matmul(self.H_complex.conj().T, oracle_w))
            self.optimal_angle[self.optimal_angle<0] += 2*np.pi

        else:
            self.oracle_z = np.zeros(self.N)
            self.optimal_angle = np.random.randn(self.M, 1)

        self.H_energy_per_antenna = np.linalg.norm(self.H_complex, 2, axis=1)
        self.H_sum_per_antenna = np.abs(np.sum(self.H_complex, axis=1))

        return 

    def push_children(self, action_id, node_id):
        """
        action_id branching variable contains two indices, one for z and another for w
        use action_id branching variables to split the node into (possibly) four children
        if the action_id branching variables contain only one index and the other is None, then branch only on that variable
        e.g., (2,5) refers to second element of z and 5th element fo the continuous variable, c(5) = |h_5^H w|

        the sequence of children nodes are: left, midleft, midright, right
        """
        
        self.delete_node(node_id)

        z_action_id, c_action_id = action_id

        if z_action_id is None and c_action_id is None:
            # print('Both actions None')
            return
        # assert z_action_id is not None or c_action_id is not None, "Both branching variables set to None. Nothing to branch on"

        if z_action_id is not None:
            max_possible_ant = sum(self.active_node.z_mask*self.active_node.z_sol) + sum(1-self.active_node.z_mask)
            if max_possible_ant < self.max_ant:
                # print('less than max ant antennas')
                return 
            elif max_possible_ant == self.max_ant:
                self.active_node.z_mask = np.ones(self.N)
                self.active_node.z_sol = self.active_node.z_mask*self.active_node.z_sol + (1-self.active_node.z_mask)*np.ones(self.N)
                z_action_id = None
            else:
                z_mask_left = self.active_node.z_mask.copy()
                z_mask_left[z_action_id] = 1

                z_mask_right = self.active_node.z_mask.copy()
                z_mask_right[z_action_id] = 1

                z_sol_left = self.active_node.z_sol.copy()
                z_sol_left[z_action_id] = 0

                z_sol_right = self.active_node.z_sol.copy()
                z_sol_right[z_action_id] = 1

        if c_action_id is not None:
            # if np.all(abs(self.active_node.u_angle - self.active_node.l_angle)> self.min_bound_gap):
            if self.active_node.u_angle[c_action_id] - self.active_node.l_angle[c_action_id] > min_bound_gap:
                mid_u_angle = self.active_node.u_angle.copy()
                mid_u_angle[c_action_id] = (mid_u_angle[c_action_id] + self.active_node.l_angle[c_action_id])/2
                mid_l_angle = self.active_node.l_angle.copy()
                mid_l_angle[c_action_id] = mid_u_angle[c_action_id]
            else:
                c_action_id = None
                if z_action_id is None:
                    # print('No children at this node')
                    return
        
        children_sets = []
        if c_action_id is not None and z_action_id is not None:
            children_sets.append(((z_mask_left, z_sol_left), (self.active_node.l_angle, mid_u_angle)))
            children_sets.append(((z_mask_left, z_sol_left), (mid_l_angle, self.active_node.u_angle)))
            children_sets.append(((z_mask_right, z_sol_right), (self.active_node.l_angle, mid_u_angle)))
            children_sets.append(((z_mask_right, z_sol_right), (mid_l_angle, self.active_node.u_angle)))

        elif c_action_id is not None and z_action_id is None:
            children_sets.append(((self.active_node.z_mask.copy(), self.active_node.z_sol.copy()), (self.active_node.l_angle, mid_u_angle)))
            children_sets.append(((self.active_node.z_mask.copy(), self.active_node.z_sol.copy()), (mid_l_angle, self.active_node.u_angle)))
        
        elif c_action_id is None and z_action_id is not None:
            children_sets.append(((z_mask_left, z_sol_left), (self.active_node.l_angle, self.active_node.u_angle)))
            children_sets.append(((z_mask_right, z_sol_right), (self.active_node.l_angle, self.active_node.u_angle)))
            
        children_stats = []
        for subset in children_sets:
            children_stats.append(self.create_children(subset))
        
        if len(self.nodes) == 0:
            return

        # Update the global upper and lower bound 
        # update the incumbent solutions
        min_L_child = min([children_stats[i][1] for i in range(len(children_stats))])
        self.global_L = min(min(self.L_list), min_L_child)

        min_U_index = np.argmin([children_stats[i][0] for i in range(len(children_stats))])
        if self.global_U > children_stats[min_U_index][0]:
            self.global_U = children_stats[min_U_index][0] 
            self.z_incumbent = children_stats[min_U_index][2]
            self.w_incumbent = children_stats[min_U_index][3]
            

    def create_children(self, constraint_set):
        """
        Create the Node with the constraint set
        Compute the local lower and upper bounds 
        return the computed bounds for the calling function to update
        """
        (z_mask, z_sol), (l_angle, u_angle) = constraint_set 

        # check if the maximum number of antennas are already selected or all antennas are already assigned (z is fully assigned)
        if np.sum(z_mask*np.round(z_sol))==self.max_ant:
            z_sol = np.round(z_sol)*z_mask
            print('calling selected antennas method')
            [w, L, optimal] = solve_relaxed_with_selected_antennas(self.H,
                                                                    l=l_angle,
                                                                    u=u_angle,
                                                                    z_sol=z_sol)
            # check this constraint                                                                    
            if not optimal:
                print('antennas: not optimal, may be infeasible')
                print('constraint self', constraint_set)   
                print('parent', self.active_node.z_mask, self.active_node.z_sol, self.active_node.l_angle, self.active_node.u_angle)     

                                     
                return np.inf, np.inf, np.zeros(self.N), np.zeros(self.N)


            # if L < self.active_node.L - self.epsilon or not optimal:
            #     print('constraint self', constraint_set)   
            #     print('parent', self.active_node.z_mask, self.active_node.z_sol, self.active_node.l_angle, self.active_node.u_angle)                                             
            assert L >= self.active_node.L - self.epsilon, 'selected antennas: lower bound of child node less than that of parent'

            z_feas = z_sol.copy()
            w_feas = self.get_feasible_w(w,z_feas)
            U = self.get_objective(w_feas, z_feas)
            # create and append node
            self.node_index_count += 1
            new_node = Node(z_mask=z_mask,
                            z_sol=z_sol,
                            z_feas=z_feas,
                            w_sol=w,
                            w_feas=w_feas,
                            l_angle=l_angle,
                            u_angle=u_angle,
                            U=U,
                            L=L,
                            depth=self.active_node.depth+1,
                            node_index=self.node_index_count
                            )
            self.L_list.append(L)
            self.U_list.append(U)
            self.nodes.append(new_node)
            self.all_nodes.append(new_node)
            return U, L, z_feas, w_feas
        elif np.sum(z_mask*np.round(z_sol))>self.max_ant:
            return np.inf, np.inf, np.zeros(self.N), np.zeros(self.N)
        else:
            # print('solving relaxed with z_mask {}, z_sol {}'.format(z_mask, z_sol))
            # print('now solving relaxed problem')
            [z,w,L, optimal] = cvxpy_relaxed(self.H,
                                    l=l_angle,
                                    u=u_angle,
                                    z_sol=z_sol,
                                    z_mask=z_mask,
                                    max_ant=self.max_ant,
                                    T=min(np.sqrt(self.global_U), 1000))
            
            # check this constraint                                                                    
            if not optimal:
                print('relaxed: not optimal, may be infeasible')
                # print('constraint self', constraint_set, L)   
                # print('parent', self.active_node.z_mask, self.active_node.z_sol, self.active_node.l_angle, self.active_node.u_angle, self.active_node.L)                                  
                return np.inf, np.inf, np.zeros(self.N), np.zeros(self.N)

            # if L < self.active_node.L - self.epsilon:
            #     print('child node', constraint_set, L)   
            #     print('parent node', self.active_node.z_mask, self.active_node.z_sol, self.active_node.l_angle, self.active_node.u_angle, self.active_node.L)                                             
            #     print(self.H)

            assert L >= self.active_node.L - self.epsilon, 'relaxed: lower bound of child node less than that of parent'

            if not L == np.inf:
                # if the z is nearly determined round it
                temp = (1-z_mask)*(np.abs(z - 0.5))
                z_mask[temp>0.499] = 1

                z = np.round(z_mask*z) + (1-z_mask)*z

                z_feas = self.get_feasible_z(z)
                [w_feas_relaxed, L_feas_relaxed, optimal] =  solve_relaxed_with_selected_antennas(self.H,
                                                                    l=l_angle,
                                                                    u=u_angle,
                                                                    z_sol=z_feas)
                if optimal:
                    w_feas = self.get_feasible_w(w_feas_relaxed,z_feas)
                    U = self.get_objective(w_feas, z_feas)
                else:
                    w_feas = np.zeros(self.N)
                    U = np.inf
                # create and append node
                self.node_index_count += 1
                new_node = Node(z_mask=z_mask,
                                z_sol=z_sol,
                                z_feas=z_feas,
                                w_sol=w,
                                w_feas=w_feas,
                                l_angle=l_angle,
                                u_angle=u_angle,
                                U=U,
                                L=L,
                                depth=self.active_node.depth+1,
                                node_index=self.node_index_count
                                )
                self.L_list.append(L)
                self.U_list.append(U)
                self.nodes.append(new_node)
                self.all_nodes.append(new_node)
                                                                                    
                return U, L, z_feas, w_feas
            
            else:
                return np.inf, np.inf, np.zeros(self.N), np.zeros(self.N)


    def get_feasible_w(self, w_selected, z_feas):
        # masked_w = ma.masked_array(abs(np.matmul(self.H_complex.conj().T, w_selected)), mask=z_feas)
        return w_selected/min(abs(np.matmul(self.H_complex.conj().T, w_selected*z_feas)))

    def get_feasible_z(self, z):
        # z_round = np.round(z)
        # if np.sum(z_round) <= self.max_ant:
        #     return z_round
        # else:
        mask = np.zeros(len(z))
        mask[np.argsort(z)[len(z)-self.max_ant:]] = 1
        return mask

    def get_objective(self, w, z_feas):
        return np.linalg.norm(w*z_feas, 2)**2

    def set_node_select_policy(self, node_select_policy_path='default', policy_type='gnn'):
        if node_select_policy_path=='default':
            self.node_select_policy = 'default'
        elif node_select_policy_path == 'oracle':
            self.node_select_policy = 'oracle'
        else:
            if policy_type == 'gnn':
                self.node_select_model = GNNNodeSelectionPolicy()
                # self.node_select_model.load_state_dict(node_select_policy_path.state_dict())
                print('policy path', node_select_policy_path)
                model_state_dict = torch.load(node_select_policy_path)
                self.node_select_model.load_state_dict(model_state_dict)
                self.node_select_policy = 'ml_model'

            elif policy_type == 'linear':
                self.node_select_model = FCNNodeSelectionLinearPolicy()
                # self.node_select_model.load_state_dict(node_select_policy_path.state_dict())
                self.node_select_model.load_state_dict(node_select_policy_path)
                self.node_select_policy = 'ml_model'

    def select_variable_default(self):
        z_sol_rel = (1-self.active_node.z_mask)*(np.abs(self.active_node.z_sol - 0.5))
        return np.argmax(z_sol_rel)

    def select_node(self):
        node_id = 0
        while len(self.nodes)>0:
            node_id = self.rank_nodes()

            self.active_node = self.nodes[node_id]
            break
        return node_id, self.observation_function().extract(self), self.is_optimal(self.active_node)


    def prune(self, observation):
        if isinstance(observation, Observation):
            observation = get_graph_from_obs(observation, self.action_set_indices_w)
        elif isinstance(observation, LinearObservation):
            observation = torch.tensor(observation.observation, dtype=torch.float32).unsqueeze(0)
        if self.node_select_policy == 'oracle':
            return not self.is_optimal(self.active_node)
        elif self.node_select_policy == 'default':
            return False
        else:
            # out = self.node_select_model(observation.antenna_features, observation.edge_index, observation.edge_attr, observation.variable_features) 
            # out = out.sum()
            # out = self.sigmoid(out) 
            if self.include_heuristic_solutions:
                heuristic_match = self.contains_heuristic(self.active_node)
                if heuristic_match:
                    return False

            with torch.no_grad():
                out = self.node_select_model(observation, 1)

            if out < 0.5:
                return True
            else:
                return False

    def rank_nodes(self):
        return np.argmin(self.L_list)

    def fathom_nodes(self):
        del_ind = np.argwhere(np.array(self.L_list) > self.global_U)
        if len(del_ind)>0:
            del_ind = sorted(list(del_ind.squeeze(axis=1)))
            for i in reversed(del_ind):
                self.delete_node(i)
        
    def fathom(self, node_id):
        if self.nodes[node_id].L > self.global_U:
            self.delete_node(node_id)
            return True
        return False

    def delete_node(self, node_id):
        del self.nodes[node_id]
        del self.L_list[node_id]
        del self.U_list[node_id]

    # This needs to be re-written for the current task
    def is_optimal(self, node):
        if np.linalg.norm(node.z_mask*(node.z_sol - self.oracle_z)) < 0.0001 and (self.optimal_angle.squeeze()<=node.u_angle).all() and (self.optimal_angle.squeeze()>=node.l_angle).all():
            return True
        else:
            return False


    # This needs to be re-written for the current task
    def contains_heuristic(self, node):
        contains = False
        for heuristic_sol in self.heuristic_solutions:
            if np.linalg.norm(node.z_mask*(node.z_sol - heuristic_sol)) < 0.0001:
                contains = True
                break
        return contains

    def is_terminal(self):
        if (self.global_U - self.global_L)/abs(self.global_U) < self.epsilon:
            return True
        else:
            return False

    def default_node_select(self):
        """
        Use the node with the lowest lower bound
        """
        return np.argmin(self.L_list)




def solve_bb(instance, max_ant=5, max_iter=10000, policy='default', policy_type='gnn', oracle_opt=None):
    t1 = time.time()
    if policy_type == 'default':
        env = BBenv(observation_function=Observation, epsilon=0.001)
    elif policy_type == 'gnn':
        env = BBenv(observation_function=Observation, epsilon=0.001)
    elif policy_type == 'linear':
        env = BBenv(observation_function=LinearObservation, epsilon=0.001)
    elif policy_type == 'oracle':
        env = BBenv(observation_function=Observation, epsilon=0.001)
        pass

    branching_policy = DefaultBranchingPolicy()

    t1 = time.time()

    env.reset(instance, max_ant=max_ant)
    timestep = 0
    done = False
    ub_list = []
    lb_list = []
    while timestep < max_iter and len(env.nodes)>0 and not done:
        print('timestep', timestep, env.global_U, env.global_L, len(env.nodes))
        env.fathom_nodes()
        if len(env.nodes) == 0:
            break
        node_id, node_feats, label = env.select_node()
        
        if len(env.nodes) == 0:
            break
        # prune_node = env.prune(node_feats)
        # if prune_node:
        #     env.delete_node(node_id)
        #     continue
        # else:
        branching_var = branching_policy.select_variable(node_feats, env.action_set_indices, env=env)
        # print(branching_var)
        # print('selected node z_sol {}, z_mask {}, z_feas {}'.format(env.nodes[node_id].z_sol, env.nodes[node_id].z_mask, env.nodes[node_id].z_feas))

        done = env.push_children(branching_var, node_id)
        timestep = timestep+1

        if env.is_terminal():
            break
        ub_list.append(env.global_U)
        lb_list.append(env.global_L)


    print('ended')
    print('result', env.z_incumbent.copy(), np.linalg.norm(env.w_incumbent,2)**2)
    # returns the solution, objective value, timestep and the time taken
    return (env.z_incumbent.copy(), env.w_incumbent.copy()), env.global_U, timestep , time.time()-t1, ub_list, lb_list

if __name__ == '__main__':
    np.random.seed(seed = 150)
    N = 4
    M = 3
    max_ant = 3
    
    u_avg = 0
    t_avg = 0
    tstep_avg = 0
    for i in range(1):
        H = np.random.randn(N, M) + 1j*np.random.randn(N,M)    
        instance = np.stack((np.real(H), np.imag(H)), axis=0)
        _, global_U, timesteps, t, u_list, l_list = solve_bb(instance, max_ant=max_ant, max_iter = 7000)
        u_avg += global_U
        t_avg += t
        tstep_avg += timesteps

    print(u_avg, t_avg, tstep_avg, u_avg)



ImportError: cannot import name 'solve_relaxed'

In [43]:

def check_integrality(z_sol, z_mask):
    sum_z= np.sum(np.abs(z_mask*(z_sol - np.round(z_sol))))
    return sum_z < 0.001

In [52]:
N = 5

z_sol = np.random.binomial(size=N, n=1, p=0.5)
# z_mask = np.random.binomial(size=N, n=1, p=0.5)
z_sol = z_sol.astype('float64')

check_integrality(z_sol, z_mask)

False

In [38]:
z_mask

array([1, 0, 1, 1, 1])

In [31]:
z_sol

array([0, 0, 1, 0, 1])

In [4]:
z_mask

array([0, 1, 1, 0, 0])

In [5]:
z_sol

array([0, 0, 0, 1, 1])

In [12]:
z_sol[1] = 0.5

In [9]:
check_integrality(z_sol, z_mask)

True

In [13]:
z_sol - np.round(z_sol)

array([0, 0, 0, 0, 0])

In [27]:
z_sol.dtype

dtype('float64')

In [28]:
z_sol.astype('float64')

array([3. , 3. , 4. , 0.1])

In [14]:
z_sol[3] = 0.1

In [16]:
z_sol = np.array(z_sol)

In [25]:
z_sol[3] = 0.1

In [23]:
z_sol =  np.array([3,3,4,5.0])

In [26]:
z_sol

array([3. , 3. , 4. , 0.1])

In [98]:
import cvxpy as cp
import numpy as np


def as_omar(H, max_ant=5):
    lmbda_lb = 0
    lmbda_ub = 1e6
    # global lmbda_lb, lmbda_ub
    N,K = H.shape

    
    U = np.zeros((N,N))
    U_new = np.ones((N,N))
    r = 0
    max_iter = 30
    while r < max_iter:
        print('sparse iteration  {}'.format(r))
        r += 1
        U = U_new.copy()
        W, _ = sparse_iteration(H, U)
        a = np.diag(W)
        mask = (a>0.01)*1
        if mask.sum()<= max_ant:
            print('exiting here')
            break
        U_new = 1/(W + 1e-5)
    prelim_mask = mask.copy()

    # if mask.sum() > max_ant:
    #     return
    before_iter_ant_count = mask.sum()
    if mask.sum() > max_ant:
        return None, mask.copy()
    # step 2
    r = 0
    max_iter = 50
    while mask.sum() != max_ant and r < max_iter:
        r += 1
        # if mask.sum() < max_ant:
        lmbda = lmbda_lb + (lmbda_ub - lmbda_lb)/2
        W, _ = sdp_omar(H, lmbda, U_new)
        a = np.diag(W)
        mask = (a>0.01)*1
        print('iteration {}'.format(r), lmbda, mask.sum(), lmbda_lb, lmbda_ub)
        if mask.sum() == max_ant:
            break
        elif mask.sum() > max_ant:
            lmbda_lb = lmbda
        elif mask.sum() < max_ant:
            lmbda_ub = lmbda
    if mask.sum()>max_ant:
        mask = prelim_mask.copy()    
    print('num selected antennas', mask.sum())

    after_iter_ant_count = mask.sum()

    # step 3
    # solve using SDR and randomization with the selected antennas
    
    # Only use the columns of H that are allowed by mask
    mask_ind = []
    for i in range(4):
        if mask[i]:
            mask_ind.append(i)
    print('before', H.shape)
    print(mask, mask_ind)
    H_reduced = H[mask_ind]

    obj = sdr(H_reduced, num_samples=100)
    
    if mask.sum() > max_ant:
        return None, mask.copy()
    return obj.copy(), mask.copy()

def sparse_iteration(H, U, noise_var=1, min_snr=1):
    """
    Solves the relaxed formulation of Omar et al 2013
    """
    # print('z mask: {},\n z value: {}'.format(z_mask, z_sol))
    N, M = H.shape
    W = cp.Variable((N,N), hermitian=True)

    objective = cp.Minimize(cp.real(cp.trace(cp.multiply(U, W))))

    constraints = [W >> 0]
    for i in range(M):
        HH = np.matmul(H[:,i:i+1], H[:,i:i+1].conj().T)
        constraints += [cp.real(cp.trace(HH @ W)) >= 1]
    prob = cp.Problem(objective, constraints)
    prob.solve()
    
    if prob.status in ['infeasible', 'unbounded']:
        print('sparse iteration infeasible solution')
        return None, np.inf

    return W.value, 0

def sdp_omar(H, lmbda, U, noise_var=1, min_snr=1):
    """
    Solves the relaxed formulation of Omar et al 2013
    """
    # print('z mask: {},\n z value: {}'.format(z_mask, z_sol))
    N, M = H.shape
    W = cp.Variable((N,N), hermitian=True)

    objective = cp.Minimize(cp.real(cp.trace(W)) + lmbda*(cp.real(cp.trace(cp.multiply(U, W)))))

    constraints = [W >> 0]
    for i in range(M):
        HH = np.matmul(H[:,i:i+1], H[:,i:i+1].conj().T)
        constraints += [cp.real(cp.trace(HH @ W)) >= 1]
    prob = cp.Problem(objective, constraints)
    prob.solve()
    
    if prob.status in ['infeasible', 'unbounded']:
        print('infeasible solution')
        return None, np.inf

    return W.value, None

def sdr(H, num_samples=50):
    print(H.shape)
    print(H)
    N, M = H.shape
    
    W = cp.Variable((N,N), hermitian=True)
    constraints = [W >> 0]
    for i in range(M):
        HH = np.matmul(H[:,i:i+1], H[:,i:i+1].conj().T)
        constraints += [cp.real(cp.trace(HH @ W)) >= 1]
    prob = cp.Problem(cp.Minimize(cp.real(cp.trace(W))), constraints)
    prob.solve()

    # Randomization
    W_real = np.real(W.value)
    W_imag = np.imag(W.value)

    # randA
    lmbda, U = np.linalg.eig(W.value)
    lmbda = np.abs(np.real(lmbda))
    randvecs = [np.matmul(U, np.matmul(np.diag(np.sqrt(lmbda)), np.exp(1j*np.random.rand(N,1)*2*np.pi))) for _ in range(num_samples)]
    outvecs = [ vec/min(abs(np.matmul(H.conj().T, vec))) for vec in randvecs]
    # sol_id = np.argmin(norms)


    # randB
    randvecs = [ np.sqrt(np.real(np.expand_dims(np.diag(W.value),axis=1)))*np.exp(1j*np.random.rand(N,1)*2*np.pi) for _ in range(num_samples)]
    outvecs += [ vec/min(abs(np.matmul(H.conj().T, vec))) for vec in randvecs]

    # randC
    randvecs = [np.random.multivariate_normal(np.zeros(N), W_real) + 1j* np.random.multivariate_normal(np.zeros(N), W_imag) for i in range(num_samples)]
    outvecs += [ vec/min(abs(np.matmul(H.conj().T, vec))) for vec in randvecs]
    norms = [np.linalg.norm(vec) for vec in outvecs]
    sol_id = np.argmin(norms)

    return np.linalg.norm(outvecs[sol_id],2)**2

In [99]:
from bb import solve_bb

N,M,L = 4,32,2

# H = np.random.randn(2,N,M)
Hc = H[0,::] + 1j*H[1,::]

obj = as_omar(Hc, max_ant=L)

_, global_U, timesteps, t = solve_bb(H, max_ant=L, max_iter = 7000)

print(obj[0], global_U)


sparse iteration  0
sparse iteration  1
sparse iteration  2
exiting here
num selected antennas 2
before (4, 32)
[1 0 1 0] [0, 2]
(2, 32)
[[ 0.07250534+1.31504205j  1.04478522+1.72936558j -0.22024619-0.6810912j
  -1.1595688 +1.15924914j  1.05512518-0.80747634j -0.08362269+0.81326268j
  -0.73807717-0.09229792j -0.67957337-0.65777943j -1.00253404+0.87670245j
   2.6459702 +1.88203649j  1.30040544+0.19525495j -0.78727805+1.57397358j
  -0.77756864-0.54609683j -0.92659659+0.02130906j -1.03030989-0.20638983j
   0.93377596+1.02825457j -1.65441663+0.52814673j  1.28878258+0.98126907j
  -1.08595216-0.17679441j  0.54962989-0.36448125j  0.30921099-0.52383612j
   1.63909637+1.43816975j -0.54372328-0.81378592j  1.25519859-1.0144964j
  -0.12452708+1.09412503j -0.96323811+1.80165065j -0.36704289+0.79961837j
   1.88121805+1.06968002j -0.69780571+1.55285628j  0.86289018+0.09770638j
   0.56130859-1.59872055j -0.41921535+1.01380202j]
 [-0.53362963+0.32764185j -1.50687444+0.21563993j -0.28175999-2.09865185j




timestep 0, U 2.8834726176390166, L 0.31638857795411474, len_nodes 1, depth_tree 1
timestep 1, U 2.8834726176390166, L 0.3163885758017645, len_nodes 4, depth_tree 1
timestep 2, U 2.8834726176390166, L 0.31811321581320445, len_nodes 7, depth_tree 2
timestep 3, U 2.8834726176390166, L 0.40620222524748156, len_nodes 10, depth_tree 2
timestep 4, U 2.8834726176390166, L 0.40620216082086985, len_nodes 11, depth_tree 3
timestep 5, U 2.5573105966868384, L 0.41747513880239445, len_nodes 12, depth_tree 4
antennas: not optimal, may be infeasible
timestep 6, U 2.5573105966868384, L 0.41747507675536183, len_nodes 11, depth_tree 3
timestep 7, U 2.5573105966868384, L 0.4625171725104916, len_nodes 12, depth_tree 4
timestep 8, U 2.5573105966868384, L 0.4703861613576063, len_nodes 14, depth_tree 3
timestep 9, U 2.5573105966868384, L 0.5222239876577189, len_nodes 15, depth_tree 4
timestep 10, U 2.5573105966868384, L 0.5222239842339638, len_nodes 16, depth_tree 4
timestep 11, U 2.5573105966868384, L 0.522

timestep 87, U 2.04428030208864, L 1.152374101153971, len_nodes 43, depth_tree 7
antennas: not optimal, may be infeasible
timestep 88, U 2.04428030208864, L 1.1523740988130653, len_nodes 42, depth_tree 4
timestep 89, U 2.04428030208864, L 1.1590274489455843, len_nodes 43, depth_tree 5
timestep 90, U 2.04428030208864, L 1.1590312073307123, len_nodes 42, depth_tree 6
timestep 91, U 2.04428030208864, L 1.1590311705276675, len_nodes 41, depth_tree 6
timestep 92, U 2.04428030208864, L 1.1607434867459332, len_nodes 42, depth_tree 7
timestep 93, U 2.04428030208864, L 1.161013635377477, len_nodes 41, depth_tree 7
antennas: not optimal, may be infeasible
timestep 94, U 2.04428030208864, L 1.1681129568565487, len_nodes 40, depth_tree 4
timestep 95, U 2.04428030208864, L 1.1681129494738334, len_nodes 40, depth_tree 6
timestep 96, U 2.04428030208864, L 1.1741196229607733, len_nodes 41, depth_tree 7
antennas: not optimal, may be infeasible
antennas: not optimal, may be infeasible
timestep 97, U 2.0

timestep 170, U 1.9913310954113848, L 1.6009775715283714, len_nodes 38, depth_tree 8
antennas: not optimal, may be infeasible
timestep 171, U 1.9913310954113848, L 1.6393314806172108, len_nodes 37, depth_tree 13
timestep 172, U 1.9913310954113848, L 1.6393313316596634, len_nodes 37, depth_tree 9
antennas: not optimal, may be infeasible
timestep 173, U 1.9913310954113848, L 1.6471669273621967, len_nodes 37, depth_tree 10
antennas: not optimal, may be infeasible
timestep 174, U 1.9913310954113848, L 1.647166927064187, len_nodes 36, depth_tree 9
antennas: not optimal, may be infeasible
timestep 175, U 1.9913310954113848, L 1.650750467268453, len_nodes 36, depth_tree 10
antennas: not optimal, may be infeasible
timestep 176, U 1.9913310954113848, L 1.6507502202764366, len_nodes 36, depth_tree 9
timestep 177, U 1.9913310954113848, L 1.675406656953768, len_nodes 37, depth_tree 10
antennas: not optimal, may be infeasible
timestep 178, U 1.9913310954113848, L 1.6754067283249383, len_nodes 35, d

timestep 253, U 1.9825664114957056, L 1.9302756689801392, len_nodes 11, depth_tree 8
timestep 254, U 1.9825664114957056, L 1.9314898599498762, len_nodes 11, depth_tree 9
timestep 255, U 1.9825664114957056, L 1.9341675421317457, len_nodes 10, depth_tree 16
timestep 256, U 1.9825664114957056, L 1.9390187297169632, len_nodes 10, depth_tree 17
antennas: not optimal, may be infeasible
timestep 257, U 1.9825664114957056, L 1.9393022372595683, len_nodes 9, depth_tree 17
antennas: not optimal, may be infeasible
timestep 258, U 1.9825664114957056, L 1.9404093564746738, len_nodes 9, depth_tree 18
timestep 259, U 1.9825664114957056, L 1.9466873755600769, len_nodes 10, depth_tree 18
antennas: not optimal, may be infeasible
timestep 260, U 1.9825664114957056, L 1.9494881805071103, len_nodes 10, depth_tree 16
antennas: not optimal, may be infeasible
timestep 261, U 1.9825664114957056, L 1.9528606651784344, len_nodes 9, depth_tree 19
antennas: not optimal, may be infeasible
timestep 262, U 1.98256641

In [2]:
import numpy as np
N,M,L = 4,4,2
num_instances = 1
H = np.random.randn(num_instances, N, M) + 1j*np.random.randn(num_instances, N,M)   
print(H) 

[[[ 0.32186097+0.65954155j  0.34281778+0.02350895j
    1.23766527-1.81582227j  1.26169182+0.34187653j]
  [-0.74908359+0.87048862j -0.52191639-0.55621534j
    0.63826856+0.60562038j -0.62769165+0.74986245j]
  [-0.08637567+0.94162341j  1.39708182-0.36261405j
   -0.4287559 +0.44502895j  0.42463397+0.45682093j]
  [ 1.07822437+0.56370311j -1.07990126-0.06373646j
   -1.34596239+0.37911793j  0.39980161-0.31254145j]]]


In [1]:
from numpy import array
import numpy as np

N,M,L = 4,4,2

Hc = array([[-0.25888367-0.28944009j,  3.06937408+2.35793347j,
       -0.470432  +1.88379463j,  1.61017936-1.55863225j],
      [ 0.47142104+0.96118079j,  0.6060375 -0.77919633j,
       -0.39720977-0.10625149j, -0.21177321-0.77096039j],
      [-1.12482453-1.22670005j,  0.54676194-1.85246426j,
        1.27414878-1.30884716j,  0.47142361-0.52206189j],
      [-0.28722866+1.08217707j, -0.32034735-0.97146323j,
        0.11415622-1.1107275j ,  1.12879974-0.05358403j]])
H = np.stack((np.real(Hc), np.imag(Hc)), axis=0)

z_opt, w_opt = (array([1., 0., 1., 0.]), array([-0.02142339-0.29873723j,  0.        +0.j        ,
       -0.40033129-0.37334624j,  0.        +0.j        ]))

(z_mask, z_sol), (l_angle, u_angle) = ((array([1., 1., 1., 1.]), array([0.49909233, 0.        , 1.        , 0.        ])), (array([0.        , 0.        , 3.14159265, 4.71238898]), array([6.28318531, 6.28318531, 6.28318531, 6.28318531])))

In [2]:
from solve_relaxation import *

obj = qp_relaxed(H=H, l=l_angle, u=u_angle, z_mask=z_mask, z_sol=z_sol, max_ant=L)
obj

# check_feasibility(H=H, l=l_angle, u=u_angle, w=w_opt, z=z_opt, T=1000)

ValueError: Rank(A) < p or Rank([P; A; G]) < n

In [8]:
H.shape

(1, 4, 4)

In [51]:
obj

(array([[ 0.10466105-0.17423484j],
        [-0.43578954-0.23513601j],
        [ 0.15122024+0.2044169j ]]),
 array([1, 0, 1, 1]))

In [32]:
H[np.array([1,2,3])]

array([[ 0.12122773-1.07811342j,  0.37751432-0.75202578j,
        -0.09611904+1.73250606j],
       [ 0.10758044-0.56174483j,  0.10528554-0.96291795j,
         0.43782672-1.93754168j],
       [-1.16184502-1.19967152j,  0.56596209-0.01535374j,
        -0.51183316-0.82128922j]])

In [55]:
mask = np.array([1,0,0,1])

In [59]:
mask_ind = []
for i in range(4):
    if mask[i]:
        mask_ind.append(i)
Hr = H[mask_ind]

In [62]:
mask_ind

[0, 3]

In [63]:
Hr

array([[ 0.31217868, -0.72525171,  0.22961232, -0.70575041],
       [ 1.37353867,  0.91961641, -1.28664353, -0.33615614]])

In [57]:
H

array([[ 0.31217868, -0.72525171,  0.22961232, -0.70575041],
       [-1.4111894 ,  0.32737857, -0.29223297,  0.93871917],
       [ 0.46425448,  1.71917706, -0.94451586, -1.23528613],
       [ 1.37353867,  0.91961641, -1.28664353, -0.33615614]])