In [1]:
import gzip
import pickle
import numpy as np
import ecole
from pathlib import Path
import torch
from  acr_bb import Observation, ACRBBenv, DefaultBranchingPolicy, RandomPolicy


MAX_SAMPLES = 10000

N = 8 # antennas
M = 4 # users
expert_prob = 0.5

def instance_generator(M, N):
    while 1:
        yield np.random.randn(2,N,M)

In [2]:
# instances = np.random.randn(MAX_SAMPLES, 2, N, M)
instances = instance_generator(M,N)

env = ACRBBenv()

expert_policy = DefaultBranchingPolicy()
random_policy = RandomPolicy()

In [3]:
episode_counter, sample_counter = 0, 0
Path('samples/').mkdir(exist_ok=True)

# We will solve problems (run episodes) until we have saved enough samples
max_samples_reached = False
while not max_samples_reached:
    episode_counter += 1
    
    observation, action_set, _, done, _ = env.reset(next(instances))
    while not done:
        if np.random.rand(1) > expert_prob:
            action_id = expert_policy.select_variable(observation, action_set)
            expert = True
        else:
            action_id = random_policy.select_variable(observation, action_set)
            expert = False
            
        # Only save samples if they are coming from the expert (strong branching)
        if expert and not max_samples_reached:
            sample_counter += 1
            data = [observation, action_id, action_set]
            filename = f'samples/sample_{sample_counter}.pkl'

            with gzip.open(filename, 'wb') as f:
                pickle.dump(data, f)
            
            # If we collected enough samples, we finish the current episode but stop saving samples
            if sample_counter == MAX_SAMPLES:
                max_samples_reached = True

        observation, action_set, _, done, _ = env.step(action_id)

        
    print(f"Episode {episode_counter}, {sample_counter} samples collected so far")


Episode 1, 208 samples collected so far
Episode 2, 546 samples collected so far
Episode 3, 651 samples collected so far
Episode 4, 802 samples collected so far
Episode 5, 1179 samples collected so far
Episode 6, 1406 samples collected so far
Episode 7, 1520 samples collected so far
Episode 8, 1659 samples collected so far
Episode 9, 1749 samples collected so far
Episode 10, 1961 samples collected so far
Episode 11, 2109 samples collected so far
Episode 12, 2313 samples collected so far
Episode 13, 2456 samples collected so far
Episode 14, 2644 samples collected so far
Episode 15, 2822 samples collected so far
Episode 16, 3021 samples collected so far
Episode 17, 3272 samples collected so far
Episode 18, 3619 samples collected so far
Episode 19, 3827 samples collected so far
Episode 20, 3975 samples collected so far
Episode 21, 4135 samples collected so far
Episode 22, 4351 samples collected so far
Episode 23, 4737 samples collected so far
Episode 24, 4946 samples collected so far
Episo

In [1]:
import torch
import torch.nn.functional as F
import torch_geometric

LEARNING_RATE = 0.001
NB_EPOCHS = 50
PATIENCE = 10
EARLY_STOPPING = 20
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')



In [115]:
import torch
import torch.nn.functional as F
import torch_geometric
import gzip
import pickle
import numpy as np
from pathlib import Path

LEARNING_RATE = 0.001
NB_EPOCHS = 50
PATIENCE = 10
EARLY_STOPPING = 20
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class BipartiteNodeData(torch_geometric.data.Data):
    """
    This class encode a node bipartite graph observation as returned by the `ecole.observation.NodeBipartite` 
    observation function in a format understood by the pytorch geometric data handlers.
    """
    def __init__(self, antenna_features, edge_indices, edge_features, variable_features,
                 candidates, candidate_choice):
        super().__init__()
        if antenna_features is not None:
            self.antenna_features = torch.FloatTensor(antenna_features)
            self.edge_index = torch.LongTensor(edge_indices.astype(np.int64))
            self.edge_attr = torch.FloatTensor(edge_features)
            self.variable_features = torch.FloatTensor(variable_features)
            self.candidates = candidates
            self.nb_candidates = len(candidates)
            self.candidate_choices = candidate_choice

    def __inc__(self, key, value, *ags, **kwargs):
        """
        We overload the pytorch geometric method that tells how to increment indices when concatenating graphs 
        for those entries (edge index, candidates) for which this is not obvious.
        """
        if key == 'edge_index':
            return torch.tensor([[self.antenna_features.size(0)], [self.variable_features.size(0)]])
        elif key == 'candidates':
            return self.variable_features.size(0)
        else:
            return super().__inc__(key, value)


class GraphDataset(torch_geometric.data.Dataset):
    """
    This class encodes a collection of graphs, as well as a method to load such graphs from the disk.
    It can be used in turn by the data loaders provided by pytorch geometric.
    """
    def __init__(self, sample_files):
        super().__init__(root=None, transform=None, pre_transform=None)
        self.sample_files = sample_files

    def len(self):
        return len(self.sample_files)

    def get(self, index):
        """
        This method loads a node bipartite graph observation as saved on the disk during data collection.
        """
        with gzip.open(self.sample_files[index], 'rb') as f:
            sample = pickle.load(f)

        sample_observation, sample_action_id, sample_action_set = sample
        
        # We note on which variables we were allowed to branch, the scores as well as the choice 
        # taken by expert branching (relative to the candidates)
        candidates = torch.LongTensor(np.array(sample_action_set, dtype=np.int32))
#         candidate_choice = sample_action_id
        candidate_choice = torch.where(candidates == sample_action_id)[0][0]

        graph = BipartiteNodeData(sample_observation.antenna_features, sample_observation.edge_index, 
                                  sample_observation.edge_features, sample_observation.variable_features,
                                  candidates, candidate_choice)
        
        # We must tell pytorch geometric how many nodes there are, for indexing purposes
        graph.num_nodes = sample_observation.antenna_features.shape[0] + sample_observation.variable_features.shape[0]
        
        return graph

In [91]:
candidates = torch.LongTensor(np.arange(1,4))

In [110]:

sample_files = [str(path) for path in Path('samples/').glob('sample_*.pkl')]
train_files = sample_files[:int(0.8*len(sample_files))]
valid_files = sample_files[int(0.8*len(sample_files)):]

train_data = GraphDataset(train_files)
train_loader = torch_geometric.data.DataLoader(train_data, batch_size=5, shuffle=True)
valid_data = GraphDataset(valid_files)
valid_loader = torch_geometric.data.DataLoader(valid_data, batch_size=5, shuffle=False)

In [136]:
class GNNPolicy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        emb_size = 64
        antenna_nfeats = 3
        edge_nfeats = 3
        var_nfeats = 9

        # CONSTRAINT EMBEDDING
        self.antenna_embedding = torch.nn.Sequential(
            torch.nn.LayerNorm(antenna_nfeats),
            torch.nn.Linear(antenna_nfeats, emb_size),
            torch.nn.ReLU(),
            torch.nn.Linear(emb_size, emb_size),
            torch.nn.ReLU(),
        )

        # EDGE EMBEDDING
        self.edge_embedding = torch.nn.Sequential(
            torch.nn.LayerNorm(edge_nfeats),
            torch.nn.Linear(edge_nfeats, emb_size),
            torch.nn.ReLU(),
            torch.nn.Linear(emb_size, emb_size),
            torch.nn.ReLU(),
            
        )

        # VARIABLE EMBEDDING
        self.var_embedding = torch.nn.Sequential(
            torch.nn.LayerNorm(var_nfeats),
            torch.nn.Linear(var_nfeats, emb_size),
            torch.nn.ReLU(),
            torch.nn.Linear(emb_size, emb_size),
            torch.nn.ReLU(),
        )

        self.conv_v_to_c = BipartiteGraphConvolution()
        self.conv_c_to_v = BipartiteGraphConvolution()

        self.output_module = torch.nn.Sequential(
            torch.nn.Linear(emb_size, emb_size),
            torch.nn.ReLU(),
            torch.nn.Linear(emb_size, 1, bias=False),
        )

    def forward(self, constraint_features, edge_indices, edge_features, variable_features):
        reversed_edge_indices = torch.stack([edge_indices[1], edge_indices[0]], dim=0)
        
        # First step: linear embedding layers to a common dimension (64)
        constraint_features = self.antenna_embedding(constraint_features)
        edge_features = self.edge_embedding(edge_features)
        variable_features = self.var_embedding(variable_features)
        

        # Two half convolutions
#         print('var', variable_features.shape, 'cons', constraint_features.shape, 'edge', reversed_edge_indices.shape, 'edge_f', edge_features.shape)
        constraint_features = self.conv_v_to_c(variable_features, reversed_edge_indices, edge_features, constraint_features)
        variable_features = self.conv_c_to_v(constraint_features, edge_indices, edge_features, variable_features)

        # A final MLP on the variable features
        output = self.output_module(variable_features).squeeze(-1)
        return output
    

class BipartiteGraphConvolution(torch_geometric.nn.MessagePassing):
    """
    The bipartite graph convolution is already provided by pytorch geometric and we merely need 
    to provide the exact form of the messages being passed.
    """
    def __init__(self):
        super().__init__('add')
        emb_size = 64
        
        self.feature_module_left = torch.nn.Sequential(
            torch.nn.Linear(emb_size, emb_size)
        )
        self.feature_module_edge = torch.nn.Sequential(
            torch.nn.Linear(emb_size, emb_size, bias=False)
        )
        self.feature_module_right = torch.nn.Sequential(
            torch.nn.Linear(emb_size, emb_size, bias=False)
        )
        self.feature_module_final = torch.nn.Sequential(
            torch.nn.LayerNorm(emb_size),
            torch.nn.ReLU(),
            torch.nn.Linear(emb_size, emb_size)
        )
        
        self.post_conv_module = torch.nn.Sequential(
            torch.nn.LayerNorm(emb_size)
        )

        # output_layers
        self.output_module = torch.nn.Sequential(
            torch.nn.Linear(2*emb_size, emb_size),
            torch.nn.ReLU(),
            torch.nn.Linear(emb_size, emb_size),
        )

    def forward(self, left_features, edge_indices, edge_features, right_features):
        """
        This method sends the messages, computed in the message method.
        """
        output = self.propagate(edge_indices, size=(left_features.shape[0], right_features.shape[0]), 
                                node_features=(left_features, right_features), edge_features=edge_features)
        return self.output_module(torch.cat([self.post_conv_module(output), right_features], dim=-1))

    def message(self, node_features_i, node_features_j, edge_features):
        output = self.feature_module_final(self.feature_module_left(node_features_i) 
                                           + self.feature_module_edge(edge_features) 
                                           + self.feature_module_right(node_features_j))
        return output
    

policy = GNNPolicy().to(DEVICE)

In [137]:
from tqdm import tqdm

def process(policy, data_loader, optimizer=None):
    """
    This function will process a whole epoch of training or validation, depending on whether an optimizer is provided.
    """
    mean_loss = 0
    mean_acc = 0
    mean_acc_copy = 0
    
    acc_list = []
    n_samples_processed = 0
    batch_count = 0
    with torch.set_grad_enabled(optimizer is not None):
        for batch in tqdm(data_loader):
            batch_count += 5
            batch = batch.to(DEVICE)
            # Compute the logits (i.e. pre-softmax activations) according to the policy on the concatenated graphs
            logits = policy(batch.antenna_features, batch.edge_index, batch.edge_attr, batch.variable_features)
            # Index the results by the candidates, and split and pad them
            logits = pad_tensor(logits[batch.candidates], batch.nb_candidates)

            # Compute the usual cross-entropy classification loss
            loss = F.cross_entropy(logits, torch.LongTensor(batch.candidate_choices))
            if optimizer is not None:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            
            predicted_bestindex = logits.max(dim=-1, keepdims=True).indices
            accuracy = sum(predicted_bestindex.reshape(-1) == batch.candidate_choices)
#             accuracy = (true_scores.gather(-1, predicted_bestindex) == true_bestscore).float().mean().item()

            mean_loss += loss.item() * batch.num_graphs
            mean_acc += float(accuracy)
        
            mean_acc_copy +=float(accuracy)
            if batch_count >= 1000 and optimizer is not None:
                acc_list = mean_acc_copy/batch_count
                batch_count=0
                mean_acc_copy= 0
                break
                
            n_samples_processed += batch.num_graphs

    mean_loss /= n_samples_processed
    mean_acc /= n_samples_processed
    return mean_loss, mean_acc, acc_list


def pad_tensor(input_, pad_sizes, pad_value=-1e8):
    """
    This utility function splits a tensor and pads each split to make them all the same size, then stacks them.
    """
    max_pad_size = pad_sizes.max()
    output = input_.split(pad_sizes.cpu().numpy().tolist())
    output = torch.stack([F.pad(slice_, (0, max_pad_size-slice_.size(0)), 'constant', pad_value)
                          for slice_ in output], dim=0)
    return output

acc_list = []
optimizer = torch.optim.Adam(policy.parameters(), lr=LEARNING_RATE)
for epoch in range(NB_EPOCHS):
    print(f"Epoch {epoch+1}")
    
    train_loss, train_acc, _ = process(policy, train_loader, optimizer)
    print(f"Train loss: {train_loss:0.3f}, accuracy {train_acc:0.3f}" )

    valid_loss, valid_acc, _ = process(policy, valid_loader, None)
    print(f"Valid loss: {valid_loss:0.3f}, accuracy {valid_acc:0.3f}" )

    acc_list.append(valid_acc)
torch.save(policy.state_dict(), 'trained_params.pkl')

Epoch 1


 12%|█████████████████▎                                                                                                                         | 199/1600 [00:01<00:11, 126.41it/s]


Train loss: 0.832, accuracy 0.644


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [00:01<00:00, 265.36it/s]


Valid loss: 0.721, accuracy 0.667
Epoch 2


 12%|█████████████████▍                                                                                                                          | 199/1600 [00:02<00:20, 69.16it/s]


Train loss: 0.682, accuracy 0.720


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [00:01<00:00, 298.51it/s]


Valid loss: 0.588, accuracy 0.755
Epoch 3


 12%|█████████████████▎                                                                                                                         | 199/1600 [00:01<00:11, 122.78it/s]


Train loss: 0.643, accuracy 0.732


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [00:02<00:00, 159.28it/s]


Valid loss: 0.544, accuracy 0.766
Epoch 4


 12%|█████████████████▎                                                                                                                         | 199/1600 [00:01<00:11, 119.80it/s]


Train loss: 0.569, accuracy 0.778


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [00:01<00:00, 307.98it/s]


Valid loss: 0.557, accuracy 0.758
Epoch 5


 12%|█████████████████▍                                                                                                                          | 199/1600 [00:03<00:22, 63.36it/s]


Train loss: 0.575, accuracy 0.776


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [00:01<00:00, 311.90it/s]


Valid loss: 0.477, accuracy 0.806
Epoch 6


 12%|█████████████████▍                                                                                                                          | 199/1600 [00:03<00:21, 66.02it/s]


Train loss: 0.543, accuracy 0.792


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [00:01<00:00, 293.00it/s]


Valid loss: 0.519, accuracy 0.786
Epoch 7


 12%|█████████████████▍                                                                                                                          | 199/1600 [00:02<00:14, 97.05it/s]


Train loss: 0.520, accuracy 0.783


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [00:02<00:00, 159.35it/s]


Valid loss: 0.542, accuracy 0.753
Epoch 8


 12%|█████████████████▎                                                                                                                         | 199/1600 [00:01<00:12, 111.90it/s]


Train loss: 0.521, accuracy 0.794


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [00:02<00:00, 195.67it/s]


Valid loss: 0.549, accuracy 0.752
Epoch 9


 12%|█████████████████▍                                                                                                                          | 199/1600 [00:02<00:17, 79.69it/s]


Train loss: 0.508, accuracy 0.799


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [00:01<00:00, 331.55it/s]


Valid loss: 0.503, accuracy 0.795
Epoch 10


 12%|█████████████████▍                                                                                                                          | 199/1600 [00:02<00:17, 78.32it/s]


Train loss: 0.495, accuracy 0.800


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [00:01<00:00, 208.62it/s]


Valid loss: 0.514, accuracy 0.793
Epoch 11


 12%|█████████████████▎                                                                                                                         | 199/1600 [00:01<00:13, 101.30it/s]


Train loss: 0.439, accuracy 0.798


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [00:02<00:00, 135.91it/s]


Valid loss: 0.470, accuracy 0.802
Epoch 12


 12%|█████████████████▍                                                                                                                          | 199/1600 [00:02<00:16, 86.60it/s]


Train loss: 0.468, accuracy 0.812


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [00:02<00:00, 158.45it/s]


Valid loss: 0.458, accuracy 0.804
Epoch 13


 12%|█████████████████▎                                                                                                                         | 199/1600 [00:01<00:12, 115.21it/s]


Train loss: 0.449, accuracy 0.825


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [00:01<00:00, 200.29it/s]


Valid loss: 0.420, accuracy 0.829
Epoch 14


 12%|█████████████████▍                                                                                                                          | 199/1600 [00:02<00:17, 78.99it/s]


Train loss: 0.484, accuracy 0.796


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [00:01<00:00, 323.26it/s]


Valid loss: 0.425, accuracy 0.844
Epoch 15


 12%|█████████████████▍                                                                                                                          | 199/1600 [00:02<00:19, 73.15it/s]


Train loss: 0.502, accuracy 0.789


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [00:01<00:00, 217.40it/s]


Valid loss: 0.565, accuracy 0.755
Epoch 16


 12%|█████████████████▍                                                                                                                          | 199/1600 [00:02<00:18, 75.08it/s]


Train loss: 0.483, accuracy 0.802


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [00:01<00:00, 210.35it/s]


Valid loss: 0.447, accuracy 0.823
Epoch 17


 12%|█████████████████▍                                                                                                                          | 199/1600 [00:02<00:18, 74.37it/s]


Train loss: 0.472, accuracy 0.809


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [00:01<00:00, 211.47it/s]


Valid loss: 0.451, accuracy 0.770
Epoch 18


 12%|█████████████████▎                                                                                                                         | 199/1600 [00:01<00:11, 118.78it/s]


Train loss: 0.458, accuracy 0.820


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [00:02<00:00, 160.68it/s]


Valid loss: 0.415, accuracy 0.797
Epoch 19


 12%|█████████████████▍                                                                                                                          | 199/1600 [00:02<00:14, 98.30it/s]


Train loss: 0.453, accuracy 0.809


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [00:01<00:00, 240.05it/s]


Valid loss: 0.464, accuracy 0.823
Epoch 20


 12%|█████████████████▍                                                                                                                          | 199/1600 [00:02<00:20, 67.58it/s]


Train loss: 0.432, accuracy 0.814


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [00:01<00:00, 247.06it/s]


Valid loss: 0.438, accuracy 0.818
Epoch 21


 12%|█████████████████▍                                                                                                                          | 199/1600 [00:03<00:22, 61.23it/s]


Train loss: 0.435, accuracy 0.828


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [00:01<00:00, 261.10it/s]


Valid loss: 0.376, accuracy 0.840
Epoch 22


 12%|█████████████████▍                                                                                                                          | 199/1600 [00:02<00:20, 68.55it/s]


Train loss: 0.452, accuracy 0.818


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [00:01<00:00, 240.59it/s]


Valid loss: 0.409, accuracy 0.835
Epoch 23


 12%|█████████████████▍                                                                                                                          | 199/1600 [00:02<00:18, 75.28it/s]


Train loss: 0.456, accuracy 0.802


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [00:02<00:00, 165.42it/s]


Valid loss: 0.392, accuracy 0.842
Epoch 24


 12%|█████████████████▍                                                                                                                          | 199/1600 [00:02<00:16, 83.77it/s]


Train loss: 0.437, accuracy 0.823


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [00:02<00:00, 176.90it/s]


Valid loss: 0.434, accuracy 0.806
Epoch 25


 12%|█████████████████▎                                                                                                                         | 199/1600 [00:01<00:11, 126.90it/s]


Train loss: 0.471, accuracy 0.811


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [00:02<00:00, 175.20it/s]


Valid loss: 0.383, accuracy 0.834
Epoch 26


 12%|█████████████████▍                                                                                                                          | 199/1600 [00:02<00:17, 82.37it/s]


Train loss: 0.436, accuracy 0.817


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [00:01<00:00, 323.34it/s]


Valid loss: 0.416, accuracy 0.817
Epoch 27


 12%|█████████████████▍                                                                                                                          | 199/1600 [00:03<00:24, 56.75it/s]


Train loss: 0.425, accuracy 0.829


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [00:02<00:00, 182.29it/s]


Valid loss: 0.404, accuracy 0.830
Epoch 28


 12%|█████████████████▍                                                                                                                          | 199/1600 [00:03<00:22, 62.59it/s]


Train loss: 0.416, accuracy 0.839


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [00:01<00:00, 247.98it/s]


Valid loss: 0.365, accuracy 0.852
Epoch 29


 12%|█████████████████▍                                                                                                                          | 199/1600 [00:02<00:16, 87.12it/s]


Train loss: 0.428, accuracy 0.821


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [00:02<00:00, 168.92it/s]


Valid loss: 0.481, accuracy 0.803
Epoch 30


 12%|█████████████████▍                                                                                                                          | 199/1600 [00:02<00:15, 93.38it/s]


Train loss: 0.443, accuracy 0.829


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [00:03<00:00, 123.49it/s]


Valid loss: 0.387, accuracy 0.834
Epoch 31


 12%|█████████████████▍                                                                                                                          | 199/1600 [00:02<00:16, 84.79it/s]


Train loss: 0.441, accuracy 0.803


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [00:03<00:00, 125.43it/s]


Valid loss: 0.391, accuracy 0.838
Epoch 32


 12%|█████████████████▍                                                                                                                          | 199/1600 [00:02<00:15, 93.39it/s]


Train loss: 0.409, accuracy 0.827


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [00:02<00:00, 153.18it/s]


Valid loss: 0.398, accuracy 0.816
Epoch 33


 12%|█████████████████▍                                                                                                                          | 199/1600 [00:02<00:14, 96.13it/s]


Train loss: 0.420, accuracy 0.811


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [00:02<00:00, 153.45it/s]


Valid loss: 0.373, accuracy 0.850
Epoch 34


 12%|█████████████████▎                                                                                                                         | 199/1600 [00:01<00:12, 111.77it/s]


Train loss: 0.440, accuracy 0.820


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [00:02<00:00, 171.86it/s]


Valid loss: 0.437, accuracy 0.823
Epoch 35


 12%|█████████████████▍                                                                                                                          | 199/1600 [00:02<00:15, 88.32it/s]


Train loss: 0.396, accuracy 0.836


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [00:01<00:00, 290.64it/s]


Valid loss: 0.417, accuracy 0.810
Epoch 36


  8%|███████████▌                                                                                                                                | 132/1600 [00:02<00:24, 59.91it/s]


KeyboardInterrupt: 

In [145]:
import torch.nn as nn

class Episode(object):
    def __init__(self):
        self.reward_history = []
        self.reward_history = []
        self.loss_history = []
        self.gamma = 0.99
        
episode = Episode()


def get_graph_from_obs(self, sample_observation, sample_action_set):
       
        sample_action_id = sample_action_set[0] # doen't matter won't be used
        # We note on which variables we were allowed to branch, the scores as well as the choice 
        # taken by expert branching (relative to the candidates)
        candidates = torch.LongTensor(np.array(sample_action_set, dtype=np.int32))
        candidate_choice = torch.where(candidates == sample_action_id)[0][0]

        graph = BipartiteNodeData(sample_observation.antenna_features, sample_observation.edge_index, 
                                  sample_observation.edge_features, sample_observation.variable_features,
                                  candidates, candidate_choice)
        
        # We must tell pytorch geometric how many nodes there are, for indexing purposes
        graph.num_nodes = sample_observation.antenna_features.shape[0] + sample_observation.variable_features.shape[0]
        
        return graph

def select_action(policy_net, obs, action_set, temperature):
    sf = nn.Softmax()
    graph = get_graph_from_obs(obs, action_set)
    logits = policy(graph.antenna_features, graph.edge_index, graph.edge_features, graph.variable_features)
    prob = Categorical(sf(logits/temperature))
    action_id = prob.sample()
    return action_id


In [184]:
import torch.nn as nn
from torch.distributions import Categorical

a = nn.Softmax()
b = a(torch.rand(5))
m = Categorical(b)
m.sample()

  b = a(torch.rand(5))


tensor(3)

In [None]:
def main(num_episodes):
    running_reward = -1
    for i in range(num_episodes):
        obs, action_set, reward, done, _ = env.reset()
        