In [2]:
import torch.nn as nn
import numpy as np
import torch
import torch.nn.functional as F
import torch_geometric
from gnn_dataset import BipartiteNodeData, GraphDataset, Experience
from gnn_policy import GNNPolicy, BipartiteGraphConvolution, Critic
from tqdm import tqdm
from pathlib import Path
from acr_bb import ACRBBenv, instance_generator
from torch.distributions import Categorical
from tqdm import trange
from torch.autograd import Variable
from tqdm.notebook import tqdm
from helpers import get_graph_from_obs

class Policy(nn.Module):
    def __init__(self, init_policy_net, temperature, gamma=0.99):
        super().__init__()
        self.policy_net = init_policy_net
        self.temperature = temperature
        self.sf = nn.Softmax()
        self.gamma = gamma
        # Episode policy and reward history 
        self.policy_history = torch.Tensor()
        self.reward_episode = []
        
        # Overall reward and loss history
        self.reward_history = []
        self.loss_history = []
        
    def reset(self):
        self.policy_history = torch.Tensor()
        
    def forward(self, state, action_set):
        logits = self.policy_net(state.antenna_features, state.edge_index, state.edge_attr, state.variable_features)
        prob = Categorical(self.sf(logits[state.candidates]/self.temperature))
        action_id = prob.sample()
        # print(logits, action_id)
        # Add log probability of our chosen action to our history    
        if self.policy_history.dim() != 0:
            self.policy_history = torch.cat([self.policy_history, prob.log_prob(action_id.unsqueeze(dim=0))])
        else:
            self.policy_history = (prob.log_prob(action_id))
        return action_id
    

# add policy network
MAX_ITER = 1000
ACTOR_LR = 0.001
CRITIC_LR = 0.001
DEVICE = torch.device('cpu')

punishment_for_incomplete_episode = -1000
temperature = 2
GAMMA = 0.99

d = torch.load('trained_models/trained_params.pkl')
policy_net = GNNPolicy()
# policy_net.load_state_dict( d )
policy = Policy(policy_net, temperature, gamma=GAMMA)

critic = Critic()
critic_loss_fn = nn.SmoothL1Loss()
env = ACRBBenv()
M = 4
N = 8
instances = instance_generator(M,N)
num_episodes = 2000

# def main(num_episodes):
running_reward = -1
avg_length = []

num_target_updates = 10 
num_grad_steps_per_target_update = 10 
num_actor_updates = 1

optimizer_actor = torch.optim.Adam(policy.parameters(), lr=ACTOR_LR)
optimizer_critic = torch.optim.Adam(critic.parameters(), lr=CRITIC_LR)


def update_actor_critic(optimizer_actor, optimizer_critic, experience):

    current_state, action, reward, next_state, terminal = experience.get_batch()
    for i in range(num_target_updates):
        with torch.no_grad():
            vs1 = critic(next_state).reshape([len(action), -1]).sum(dim=1)
            vs1[terminal] = 0
        critic_target = reward + GAMMA*vs1

        
        for j in range(num_grad_steps_per_target_update):
            optimizer_critic.zero_grad()
            vs = critic(current_state).reshape([len(action), -1]).sum(dim=1)
            critic_loss = critic_loss_fn(critic_target.detach(), vs)
            critic_loss.backward()
            optimizer_critic.step()
        
    
    advantage = critic_target - vs
    
    # Update network weights
    optimizer_actor.zero_grad()
    actor_loss = (torch.sum(torch.mul(policy.policy_history, Variable(advantage)).mul(-1), -1))
    actor_loss.backward()
    optimizer_actor.step()
    
    
    
for i in tqdm(range(num_episodes)):
    instance = next(instances)
    policy.reset()
    experience = Experience()
    
    obs, action_set, reward, done, _ = env.reset(instance)

    if done:
        policy.reward_episode.append(reward)
        continue

    for it in tqdm(range(MAX_ITER)):
        state = get_graph_from_obs(obs, action_set)
        action_id = policy(state, action_set)
        # Step through environment using chosen action
        next_obs, action_set, next_reward, done, _ = env.step(action_set[action_id])
        
        experience.push(obs, action_id, next_reward, next_obs, done)
        obs = next_obs
        
        if done:
            print(done)
            break
        if next_reward < -5:
            break
        
    update_actor_critic(optimizer_actor, optimizer_critic, experience )
    
    avg_length.append(it)
    if i % 10 == 0:
        print('Episode {}\tAverage length: {:.2f}'.format(i, np.mean(avg_length)))
        avg_length = []
torch.save( policy.policy_net.state_dict(), 'trained_models/unsupervised2.pkl')

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]



[0.         1.85095605 4.93772087 3.14159265] [6.28318531 1.85095606 4.93772237 4.71238898]
Episode 0	Average length: 486.00


  0%|          | 0/1000 [00:00<?, ?it/s]

[0.         0.03389139 5.28103708 3.14159265] [6.28318531 0.03389438 5.28103709 4.71238898]


  0%|          | 0/1000 [00:00<?, ?it/s]

[0.         2.75806675 3.30111313 0.        ] [6.28318531 2.75806684 3.30111314 3.14159265]


  0%|          | 0/1000 [00:00<?, ?it/s]

[0.         1.57079526 3.14159265 3.52633475] [6.28318531 1.57079526 6.28318531 3.52633476]


  0%|          | 0/1000 [00:00<?, ?it/s]

[0.         0.         2.1267751  6.12453964] [6.28318531 3.14159265 2.12677511 6.12454113]


  0%|          | 0/1000 [00:00<?, ?it/s]

[0.         3.14159265 6.25864161 2.08000266] [6.28318531 6.28318531 6.28318531 2.08000267]


  0%|          | 0/1000 [00:00<?, ?it/s]

[0.         0.01659096 5.22780653 2.94525716] [6.28318531 0.01659096 5.25235022 2.94525734]


  0%|          | 0/1000 [00:00<?, ?it/s]

[0.         5.0656639  4.71468995 3.14159265] [6.28318531 5.0656639  4.71469295 6.28318531]


  0%|          | 0/1000 [00:00<?, ?it/s]

[0.         1.59611001 3.14159265 2.90177198] [6.28318531 1.5961115  6.28318531 2.90177198]


  0%|          | 0/1000 [00:00<?, ?it/s]

[0.         0.         3.25074159 6.28241832] [6.28318531 3.14159265 3.25074159 6.28318531]


  0%|          | 0/1000 [00:00<?, ?it/s]

[0.         0.27984284 5.19691549 0.        ] [6.28318531 0.27984289 5.19691549 3.14159265]
Episode 10	Average length: 216.80


  0%|          | 0/1000 [00:00<?, ?it/s]

[0.         0.         3.28156538 2.69817633] [6.28318531 3.14159265 3.28156538 2.69827221]


  0%|          | 0/1000 [00:00<?, ?it/s]

[0.         3.14159265 6.28270594 2.17587348] [6.28318531 6.28318531 6.28280181 2.17587348]


  0%|          | 0/1000 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [6]:
experience.current_state[-1].antenna_feature

array([[-0.00603163,  0.09213643,  0.09233364],
       [ 0.00320729,  0.08692668,  0.08698582],
       [-0.04042854,  0.08689565,  0.09584008],
       [ 0.02101943,  0.04579527,  0.05038872],
       [-0.05852153,  0.10975336,  0.12438074],
       [ 0.03422417, -0.08681951,  0.0933216 ],
       [ 0.02037475, -0.02240347,  0.03028276],
       [ 0.04327142,  0.02741518,  0.05122507]])

In [7]:
experience.next_state

[<acr_bb.Observation at 0x7f317ada5f98>,
 <acr_bb.Observation at 0x7f317a558278>,
 <acr_bb.Observation at 0x7f317a5604a8>,
 <acr_bb.Observation at 0x7f317a560898>,
 <acr_bb.Observation at 0x7f317a560588>,
 <acr_bb.Observation at 0x7f317a558080>,
 <acr_bb.Observation at 0x7f317a5608d0>,
 <acr_bb.Observation at 0x7f317a560940>,
 <acr_bb.Observation at 0x7f317a560908>,
 <acr_bb.Observation at 0x7f317a560748>,
 <acr_bb.Observation at 0x7f317a560b38>,
 <acr_bb.Observation at 0x7f317a560c50>,
 <acr_bb.Observation at 0x7f317a560828>,
 <acr_bb.Observation at 0x7f317a560128>,
 <acr_bb.Observation at 0x7f317a5603c8>,
 <acr_bb.Observation at 0x7f317a560c88>,
 <acr_bb.Observation at 0x7f317a560ef0>,
 <acr_bb.Observation at 0x7f317a560fd0>,
 <acr_bb.Observation at 0x7f317a560f98>,
 <acr_bb.Observation at 0x7f317a54ceb8>,
 <acr_bb.Observation at 0x7f317a560400>,
 <acr_bb.Observation at 0x7f317a560dd8>,
 <acr_bb.Observation at 0x7f317a57a358>,
 <acr_bb.Observation at 0x7f317a560cf8>,
 <acr_bb.Observa

In [5]:
import numpy as np

H = np.random.randn(5,4) + 1j*np.random.randn(5,4)

In [10]:
np.expand_dims(H[:,1], axis=1).shape

(5, 1)

In [3]:
H.transpose()

array([[-0.43715771, -1.51408521,  1.77973076,  0.76201086,  1.55187541]])

In [11]:
np.real(H)


array([[ 1.20485555, -0.50972262,  1.30158444,  0.73301039],
       [ 0.29003651,  0.33735667,  0.94248079, -0.40048354],
       [ 2.78217538,  0.85814744, -1.2915404 ,  0.04592328],
       [-0.9035769 , -0.44994444,  0.93716975,  0.6344888 ],
       [ 0.56344663,  0.99013184,  0.55635676, -0.57909431]])

In [13]:
import cvxpy as cp
cp.real(H)

Expression(CONSTANT, UNKNOWN, (5, 4))