In [None]:
from numpy.linalg.linalg import solve
import torch
import torch.nn as nn
import numpy as np
from pathlib import Path
from torch.utils.data import DataLoader
import torch_geometric 
from gnn_dataset import GraphNodeDataset, instance_generator
from gnn_dataset import get_graph_from_obs
from gnn_policy import GNNPolicy, GNNNodeSelectionPolicy
from acr_bb import ACRBBenv, DefaultBranchingPolicy, solve_bb, solve_bb_policy
import gzip
import pickle
from fcn_policy import FCNNodeSelectionLinearPolicy, FCNNodeDataset
from tqdm import tqdm
from observation import LinearObservation, Observation
import time


def test_bb(M,N, num_egs=10, policy=None, policy_type='gnn'):
    instances = instance_generator(M,N)
    ogap_avg = 0
    speedup = 0
    for i in range(num_egs):
        instance = next(instances)
#         w_opt_old = solve_bb(instance)
        w_opt, f_opt, iters_opt, time_taken_opt = solve_bb_policy(instance, max_iter=1000)
        w, f, iters, time_taken = solve_bb_policy(instance, max_iter=1000, policy=policy, policy_type=policy_type)
        ogap_avg += (abs((f_opt-f)/(f_opt*num_egs))*100)
        speedup += time_taken_opt/(time_taken*num_egs)
        print('opt_f: {}, policy_f: {}, iters_opt: {}, iters: {}'.format(f_opt, f, iters_opt, iters))
    return ogap_avg, speedup

train_filepath = '../data/dagger_train_gnn1'
valid_filepath = '../data/dagger_valid_gnn1'

if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

TRAIN_EPOCHS = 100
LEARNING_RATE = 0.001
BATCH_SIZE = 128
POLICY_TYPE = 'gnn'
MODEL_PATH = '../trained_params/gnn_node_prune_without_imitation.model'
train_files = [str(path) for path in Path(train_filepath).glob('sample_*.pkl')]            
valid_files = [str(path) for path in Path(valid_filepath).glob('sample_*.pkl')]

# train_files = train_files[:2]


print(len(train_files), len(valid_files))
sample_obs = pickle.load(gzip.open(train_files[0], 'rb'))[0]
assert isinstance(sample_obs, Observation)
M, N = sample_obs.variable_features.shape[0], sample_obs.antenna_features.shape[0]
print('M,N', M,N)

train_data = GraphNodeDataset(train_files)
train_loader = torch_geometric.loader.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_data = GraphNodeDataset(valid_files)
valid_loader = torch_geometric.loader.DataLoader(valid_data, batch_size=BATCH_SIZE, shuffle=False) 

policy = GNNNodeSelectionPolicy()
policy = policy.train().to(DEVICE)
        
optimizer = torch.optim.Adam(policy.parameters(), lr=LEARNING_RATE)

# training stage
total_data = 0
for _ in (range(TRAIN_EPOCHS)):
    mean_loss = 0
    mean_acc = 0
    n_samples_processed = 0
    targets_list = torch.Tensor([]).to(DEVICE)
    preds_list = torch.Tensor([]).to(DEVICE)
    t1 = time.time()
    for batch_data in tqdm(train_loader):
        # print('rest of the time', time.time() - t1)
        # t1 = time.time()
        batch, target = batch_data
        batch = batch.to(DEVICE)
        target = target.to(DEVICE)*1

        if POLICY_TYPE == 'gnn':
            batch_size = batch.num_graphs
            num_vars = int(batch.variable_features.shape[0]/batch_size)
            wts = torch.tensor([batch.variable_features[i*num_vars, 9] for i in range(batch_size)], dtype=torch.float32)
        else:
            batch_size = batch.shape[0] 
            wts = batch[:,-25]

        wts = 2.68/wts
        wts = wts.to(DEVICE)

        # print([batch.variable_features[i*num_vars, 9].item() for i in range(batch_size)], wts, target)
        wts = ((target)*7 + 1)*wts                   
        out = policy(batch, batch_size)
        bce = nn.BCELoss(weight=wts)        
        loss = bce(out.squeeze(), target.to(torch.float).squeeze())

        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        predicted_bestindex = (out>0.5)*1
        accuracy = sum(predicted_bestindex.reshape(-1) == target)

        targets_list = torch.cat((targets_list, target))
        preds_list = torch.cat((preds_list, predicted_bestindex))

        mean_loss += loss.item() * batch_size
        mean_acc += float(accuracy)
        n_samples_processed += batch_size
        # print(time.time()-t1)
        # t1 = time.time()
    total_data = n_samples_processed
    stacked = torch.stack((targets_list, preds_list.squeeze()), dim=1).to(torch.int)
    cmt = torch.zeros(2,2,dtype=torch.int64)
    for p in stacked:
        tl, pl = p.tolist()
        cmt[tl, pl] = cmt[tl, pl] + 1
    print(cmt)
    precision = cmt[1,1]/(cmt[0,1]+cmt[1,1])
    recall = cmt[1,1]/(cmt[1,0]+cmt[1,1])
    mean_acc = 2* (precision*recall)/(precision+recall)
    mean_loss /= n_samples_processed

    print("Train: precision:{}, recall:{}, f1-score:{}, loss: {}".format(precision, recall, mean_acc, mean_loss))

    # running tests
    ogap, speedup = test_bb(M,N,policy=policy)
    print('Test Results, ogap={}, speedup={}'.format(ogap, speedup))
    torch.save(policy.state_dict(), MODEL_PATH)


Batch(antenna_features=[1024, 9], edge_index=[2, 4096], edge_attr=[4096, 3], variable_features=[512, 10], candidates=[384], nb_candidates=[128], candidate_choices=[128], num_nodes=1536, batch=[1536], ptr=[129])

In [None]:
import cvxpy as cp
import numpy as np

def sdr(H, num_sample_random=50):
    N, M = H.shape
    
    W = cp.Variable((N,N), hermitian=True)
    constraints = [W >> 0]
    for i in range(M):
        HH = np.matmul(H[:,i:i+1], H[:,i:i+1].conj().T)
        constraints += [cp.real(cp.trace(HH @ W)) >= 1]
    prob = cp.Problem(cp.Minimize(cp.real(cp.trace(W))), constraints)
    prob.solve()

    # Randomization
    W_real = np.real(W.value)
    W_imag = np.imag(W.value)

    # randA
    lmbda, U = np.linalg.eig(W.value)
    lmbda = np.real(lmbda)
    randvecs = [np.matmul(U, np.matmul(np.diag(np.sqrt(lmbda)), np.exp(1j*np.random.rand(8,1)*2*np.pi))) for _ in range(num_sample_random)]
    print(randvecs)
    outvecs = [ vec/min(abs(np.matmul(H.conj().T, vec))) for vec in randvecs]
    norms = [np.linalg.norm(vec) for vec in outvecs]
    sol_id = np.argmin(norms)


    # randB

    # randC
    
    # randvecs = [np.random.multivariate_normal(np.zeros(N), W_real) + 1j* np.random.multivariate_normal(np.zeros(N), W_imag) for i in range(num_sample_random)]
    # outvecs = [ vec/min(abs(np.matmul(H.conj().T, vec))) for vec in randvecs]
    # norms = [np.linalg.norm(vec) for vec in outvecs]
    # sol_id = np.argmin(norms)

    return outvecs[sol_id]


if __name__=='__main__':
    from fpp_sca import fpp_sca
    from acr_bb import solve_bb

    N,M = 8,4
    
    sdr_ogap = []
    fpp_ogap = []
    for i in range(1):
        H = np.random.randn(N,M) + 1j*np.random.randn(N,M)
        w_fpp = sdr(H)
        w_sdr = sdr(H, 1000)

        instance = np.stack((np.real(H), np.imag(H)), axis=0)
        w_bb, _ = solve_bb(instance, max_iter=10000)

        sdr_obj = np.linalg.norm(w_sdr)**2
        fpp_obj = np.linalg.norm(w_fpp)**2
        bb_obj = np.linalg.norm(w_bb)**2
        print(min(abs(np.matmul(H.conj().T, w_sdr))) , sdr_obj)
        print(min(abs(np.matmul(H.conj().T, w_fpp))), fpp_obj)
        print(min(abs(np.matmul(H.conj().T, w_bb))), bb_obj)

        sdr_ogap.append(abs((sdr_obj-bb_obj)/bb_obj)*100)
        fpp_ogap.append(abs((fpp_obj-bb_obj)/bb_obj)*100)

    print('sdr ogap is {}'.format(np.mean(sdr_ogap)))
    print('fpp ogap is {}'.format(np.mean(fpp_ogap)))





In [24]:
from fpp_sca import fpp_sca
from acr_bb import solve_bb
import cvxpy as cp
import numpy as np

N,M = 8,4

sdr_ogap = []
fpp_ogap = []
for i in range(1):
    H = np.random.randn(N,M) + 1j*np.random.randn(N,M)

num_sample_random = 100


N, M = H.shape
    
W = cp.Variable((N,N), hermitian=True)
constraints = [W >> 0]
for i in range(M):
    HH = np.matmul(H[:,i:i+1], H[:,i:i+1].conj().T)
    constraints += [cp.real(cp.trace(HH @ W)) >= 1]
prob = cp.Problem(cp.Minimize(cp.real(cp.trace(W))), constraints)
prob.solve()

# Randomization
W_real = np.real(W.value)
W_imag = np.imag(W.value)

# randA
lmbda, U = np.linalg.eig(W.value)
lmbda = np.abs(np.real(lmbda))
print
randvecs = [np.matmul(U, np.matmul(np.diag(np.sqrt(lmbda)), np.exp(1j*np.random.rand(8,1)*2*np.pi))) for _ in range(num_sample_random)]
print(randvecs)
outvecs = [ vec/min(abs(np.matmul(H.conj().T, vec))) for vec in randvecs]
norms = [np.linalg.norm(vec) for vec in outvecs]
sol_id = np.argmin(norms)



WARN: A->p (column pointers) not strictly increasing, column 36 empty
WARN: A->p (column pointers) not strictly increasing, column 45 empty
WARN: A->p (column pointers) not strictly increasing, column 54 empty
WARN: A->p (column pointers) not strictly increasing, column 63 empty
WARN: A->p (column pointers) not strictly increasing, column 72 empty
WARN: A->p (column pointers) not strictly increasing, column 81 empty
WARN: A->p (column pointers) not strictly increasing, column 90 empty
WARN: A->p (column pointers) not strictly increasing, column 99 empty
[array([[-0.09130624+0.13810623j],
       [ 0.05135496+0.01957848j],
       [ 0.03072003+0.02544564j],
       [-0.0943328 -0.06179415j],
       [-0.05330138-0.19608616j],
       [-0.00301132-0.04183625j],
       [ 0.0140068 -0.24722666j],
       [-0.09170384-0.08729824j]]), array([[-0.15203353+0.15766695j],
       [ 0.06483151+0.04418169j],
       [ 0.03797515+0.04456554j],
       [-0.08714137-0.06468971j],
       [ 0.00546342-0.1521071

In [29]:
randvecs = [ np.sqrt(np.real(np.diag(W.value).expand_dims()))*np.exp(1j*np.random.rand(8,1)*2*np.pi) for _ in range(1)]

In [39]:
np.sqrt(np.real(np.expand_dims(np.diag(W.value),axis=1)))*np.exp(1j*np.random.rand(8,1)*2*np.pi)

array([[ 0.19315121+0.00969416j],
       [ 0.06904565-0.01517905j],
       [ 0.02910219-0.04935413j],
       [ 0.10812836-0.00732447j],
       [ 0.13260076-0.1045509j ],
       [-0.07580445+0.10353299j],
       [ 0.10529457-0.15684938j],
       [ 0.13127397+0.09903976j]])

In [32]:
randvecs[0].shape

(8, 8)

In [22]:
w,v = np.linalg.eig(W.value)

In [26]:
np.diag(W.value)

array([0.03740137+0.j, 0.00499771+0.j, 0.00328277+0.j, 0.01174539+0.j,
       0.02851385+0.j, 0.01646539+0.j, 0.03568868+0.j, 0.02704173+0.j])

In [21]:
W.value[1,3]

(-0.004418363307085728-0.023306048202208766j)

In [11]:
np.exp(1j*np.random.rand(8,1)*2*np.pi).shape

(8, 1)

In [13]:
np.matmul(np.diag(np.sqrt(lmbda)), np.exp(1j*np.random.rand(8,1)*2*np.pi))

  """Entry point for launching an IPython kernel.


array([[-0.01794788+0.35832473j],
       [        nan       +nanj],
       [ 0.00112179+0.00108284j],
       [ 0.00044458+0.00065477j],
       [        nan       +nanj],
       [        nan       +nanj],
       [        nan       +nanj],
       [        nan       +nanj]])

In [15]:
np.diag(np.sqrt(lmbda))

  """Entry point for launching an IPython kernel.


array([[0.35877394, 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        ],
       [0.        ,        nan, 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.00155915, 0.        , 0.        ,
        0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.00079144, 0.        ,
        0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        ,        nan,
        0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
               nan, 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        ,        nan, 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        ,        nan]])

In [16]:
lmbda

array([ 1.28718741e-01, -4.78812840e-06,  2.43094424e-06,  6.26371248e-07,
       -3.16393342e-07, -3.25809845e-07, -3.23098462e-07, -3.21091718e-07])