In [6]:
import gzip
import pickle
import numpy as np
from pathlib import Path
import torch
from  acr_bb import Observation, ACRBBenv, DefaultBranchingPolicy, RandomPolicy, LinearObservation
from fcn_training import FCNDataset, FCNBranchingPolicy, FCNNodeSelectionPolicy 

MAX_SAMPLES = 10000

N = 8 # antennas
M = 4 # users
expert_prob = 0.5

def instance_generator(M, N):
    while 1:
        yield np.random.randn(2,N,M)

# instances = np.random.randn(MAX_SAMPLES, 2, N, M)
instances = instance_generator(M,N)

env = ACRBBenv(node_select_policy_path='default')

expert_policy = DefaultBranchingPolicy()
random_policy = RandomPolicy()

episode_counter, sample_counter = 0, 0
negative_counter = 0
Path('positive_node_samples3/').mkdir(exist_ok=True)
Path('negative_node_samples3/').mkdir(exist_ok=True)

# We will solve problems (run episodes) until we have saved enough samples
max_samples_reached = False

while not max_samples_reached:
    episode_counter += 1
    
    observation_list = []
    node_indices = []
    observation, action_set, reward, done, _ = env.reset(next(instances))
    node_indices.append(env.active_node.node_index)
    observation_list.append(observation)
    while not done and reward > -5:
        if np.random.rand(1) > expert_prob:
            action_id = expert_policy.select_variable(observation, action_set)
        else:
            action_id = random_policy.select_variable(observation, action_set)

        observation, action_set, reward, done, _ = env.step(action_id)
    
    for node in env.all_nodes:
        if node.optimal:
            for i in range(len(node_indices)):
                if node_indices[i] == node.node_index:
                    data = [observation_list[i], True]
                    break
        else:
            for i in range(len(node_indices)):
                if node_indices[i] == node.node_index:
                    data = [observation_list[i], False]
                    break

            
        if not max_samples_reached:
            if node.optimal:
                filename = f'positive_node_samples3/sample_{sample_counter}.pkl'
                sample_counter += 1
                
            else:
                filename = f'negative_node_samples3/sample_{negative_counter}.pkl'
                negative_counter += 1
            with gzip.open(filename, 'wb') as f:
                pickle.dump(data, f)
            # If we collected enough samples, we finish the current episode but stop saving samples
            if sample_counter >= MAX_SAMPLES:
                max_samples_reached = True
                break;
    print(f"Episode {episode_counter}, {sample_counter}, {negative_counter} samples collected so far")


Episode 1, 26, 649 samples collected so far
Episode 2, 52, 1910 samples collected so far
Episode 3, 75, 2240 samples collected so far
Episode 4, 89, 2337 samples collected so far
Episode 5, 114, 3127 samples collected so far
Episode 6, 135, 5187 samples collected so far
Episode 7, 161, 6050 samples collected so far
Episode 8, 188, 7796 samples collected so far
Episode 9, 207, 7948 samples collected so far
Episode 10, 233, 8327 samples collected so far
Episode 11, 259, 9058 samples collected so far
Episode 12, 285, 9521 samples collected so far
Episode 13, 307, 10442 samples collected so far
Episode 14, 336, 13270 samples collected so far
Episode 15, 363, 15708 samples collected so far
Episode 16, 387, 17623 samples collected so far
Episode 17, 414, 18039 samples collected so far
Episode 18, 437, 18507 samples collected so far
Episode 19, 465, 19190 samples collected so far
Episode 20, 483, 19405 samples collected so far
Episode 21, 506, 20279 samples collected so far
Episode 22, 529, 2

Episode 168, 4243, 143347 samples collected so far
Episode 169, 4267, 144696 samples collected so far
Episode 170, 4293, 145261 samples collected so far
Episode 171, 4316, 145623 samples collected so far
Episode 172, 4343, 146787 samples collected so far
Episode 173, 4366, 148193 samples collected so far
Episode 174, 4394, 148656 samples collected so far
Episode 175, 4422, 152031 samples collected so far
Episode 176, 4449, 152889 samples collected so far
Episode 177, 4477, 153494 samples collected so far
Episode 178, 4505, 154289 samples collected so far
Episode 179, 4531, 154860 samples collected so far
Episode 180, 4553, 155473 samples collected so far
Episode 181, 4581, 157106 samples collected so far
Episode 182, 4607, 158007 samples collected so far
Episode 183, 4628, 159279 samples collected so far
Episode 184, 4647, 159825 samples collected so far
Episode 185, 4673, 160638 samples collected so far
Episode 186, 4697, 160987 samples collected so far
Episode 187, 4720, 161455 sampl

Episode 329, 8259, 284238 samples collected so far
Episode 330, 8285, 285205 samples collected so far
Episode 331, 8312, 287775 samples collected so far
Episode 332, 8334, 289614 samples collected so far
Episode 333, 8360, 290857 samples collected so far
Episode 334, 8388, 291392 samples collected so far
Episode 335, 8414, 291831 samples collected so far
Episode 336, 8439, 292961 samples collected so far
Episode 337, 8463, 294416 samples collected so far
Episode 338, 8489, 294713 samples collected so far
Episode 339, 8517, 295616 samples collected so far
Episode 340, 8541, 296129 samples collected so far
Episode 341, 8569, 296322 samples collected so far
Episode 342, 8598, 297162 samples collected so far
Episode 343, 8627, 298042 samples collected so far
Episode 344, 8649, 299301 samples collected so far
Episode 345, 8675, 299946 samples collected so far
Episode 346, 8707, 300675 samples collected so far
Episode 347, 8732, 302923 samples collected so far
Episode 348, 8758, 303888 sampl