In [1]:
from data_generator import DataGenerator
from multiprocessing import Pool

In [2]:
# Notebook-wide configuration.
# MAX_NOBS and NUM_SIMULATIONS are passed, in that order, to the DataGenerator
# constructor; their exact meaning is defined in data_generator.py.
MAX_NOBS = 10000
NUM_SIMULATIONS = 100
# Number of generate() calls made to build one batch.
BATCH_SIZE = 64
# Not referenced by any cell visible in this part of the notebook.
NUM_EPOCHS = 1000

In [3]:
# Single module-level generator instance; both generate_batch_loop() and
# generate_batch_pool() call its generate() method.
data_generator = DataGenerator(MAX_NOBS, NUM_SIMULATIONS)

In [4]:
def generate_batch_loop():
    """Build one batch by calling ``data_generator.generate()`` sequentially.

    ``generate()`` is invoked ``BATCH_SIZE`` times in the current process and
    each result is unpacked as a two-element ``(x, y)`` pair.

    Returns:
        A tuple ``(X_train, y_train)`` of two lists, each of length
        ``BATCH_SIZE``: the first holds every ``x``, the second every ``y``,
        both in generation order.
    """
    samples = [data_generator.generate() for _ in range(BATCH_SIZE)]
    # Split the list of (x, y) pairs into two parallel lists.
    inputs = [first for first, _ in samples]
    targets = [second for _, second in samples]
    return inputs, targets

In [5]:
# Run the sequential builder once; as the last expression of the cell, the
# whole (X_train, y_train) tuple is echoed as the cell output.
generate_batch_loop()

([(6019, 7111, 0.3024, 0.6217),
  (9870, 3694, 0.7048, 0.4655),
  (3674, 6911, 0.1655, 0.4103),
  (6226, 7861, 0.4221, 0.7129),
  (1951, 8199, 0.1762, 0.3407),
  (1235, 4639, 0.1034, 0.0211),
  (4545, 4835, 0.265, 0.2269),
  (6204, 3055, 0.7998, 0.4691),
  (5979, 8354, 0.1895, 0.397),
  (9304, 9719, 0.9225, 0.5558),
  (7439, 6900, 0.3828, 0.927),
  (5492, 8324, 0.0076, 0.2215),
  (6263, 6506, 0.3814, 0.8759),
  (5411, 4828, 0.3346, 0.1634),
  (3543, 5886, 0.731, 0.7136),
  (9268, 5319, 0.3944, 0.3905),
  (1255, 4138, 0.1715, 0.4858),
  (2885, 2967, 0.9029, 0.6732),
  (1947, 9059, 0.0031, 0.4824),
  (4565, 9627, 0.1687, 0.0615),
  (9809, 1321, 0.8656, 0.4747),
  (7735, 4094, 0.6031, 0.1169),
  (3908, 1176, 0.254, 0.9963),
  (5957, 8158, 0.4354, 0.7932),
  (7342, 2505, 0.2459, 0.7331),
  (7328, 4851, 0.4106, 0.7769),
  (6022, 2845, 0.376, 0.4615),
  (3382, 6533, 0.3923, 0.0991),
  (6977, 2226, 0.8622, 0.7769),
  (4309, 7635, 0.2083, 0.8487),
  (4106, 2295, 0.749, 0.8304),
  (7410, 3586, 

In [6]:
# Baseline timing for the sequential, single-process batch builder.
%timeit generate_batch_loop()

1.78 s ± 56.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [7]:
def generate_batch_pool():
    """Build one batch by spreading ``generate()`` calls over a process pool.

    A fresh ``multiprocessing.Pool`` (default worker count) is created for
    every call and torn down when the ``with`` block exits. ``starmap`` is fed
    ``BATCH_SIZE`` empty argument tuples, so each task is a zero-argument call
    to ``data_generator.generate()``; results come back in task order.

    Returns:
        A tuple ``(X_train, y_train)`` of two lists, each of length
        ``BATCH_SIZE``, holding the first and second element of every
        ``generate()`` result respectively.
    """
    # One empty tuple per task: starmap unpacks it into a call with no arguments.
    no_arg_tasks = [()] * BATCH_SIZE
    # NOTE(review): DataGenerator is not visible here. If it draws from a
    # process-global RNG, confirm that worker processes do not start from
    # identical RNG state and thereby return correlated samples.
    with Pool() as workers:
        samples = workers.starmap(data_generator.generate, no_arg_tasks)

    X_train, y_train = [], []
    for x, y in samples:
        X_train.append(x)
        y_train.append(y)
    return X_train, y_train

In [8]:
# Run the pool-based builder once; as the last expression of the cell, the
# whole (X_train, y_train) tuple is echoed as the cell output.
generate_batch_pool()

([(6168, 9037, 0.9997, 0.7653),
  (5739, 9428, 0.3148, 0.296),
  (3329, 9145, 0.9616, 0.0206),
  (4833, 9221, 0.4391, 0.0398),
  (7577, 6195, 0.5603, 0.2183),
  (8936, 8705, 0.7321, 0.3871),
  (4219, 8964, 0.2249, 0.6456),
  (1105, 9978, 0.489, 0.1243),
  (4499, 4046, 0.9878, 0.448),
  (3014, 5817, 0.6322, 0.5827),
  (5760, 7007, 0.2081, 0.0046),
  (4572, 3057, 0.1653, 0.3172),
  (8161, 6212, 0.6521, 0.3436),
  (9530, 3211, 0.6956, 0.9735),
  (7751, 6729, 0.4141, 0.9301),
  (9466, 8587, 0.2812, 0.9087),
  (2437, 1872, 0.5953, 0.6248),
  (6385, 9551, 0.6048, 0.1086),
  (546, 7201, 0.7999, 0.2504),
  (1430, 7028, 0.9315, 0.5184),
  (2769, 6482, 0.192, 0.3001),
  (4358, 6546, 0.0979, 0.2709),
  (4835, 2436, 0.4231, 0.187),
  (8180, 7833, 0.6865, 0.4746),
  (2386, 3226, 0.2417, 0.4736),
  (9045, 1935, 0.0404, 0.6913),
  (4165, 8446, 0.4674, 0.404),
  (3670, 2195, 0.1507, 0.9425),
  (708, 1380, 0.8235, 0.7851),
  (5436, 8502, 0.0438, 0.0177),
  (9428, 1823, 0.7706, 0.2335),
  (5886, 1321, 0

In [9]:
# Timing for the multiprocessing builder; each timed call includes creating
# and shutting down its own Pool.
%timeit generate_batch_pool()

352 ms ± 22.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
