## Ant Colony Optimization for a Multi-Processor Scheduling Problem with JAX

In [1]:
import jax
import jax.numpy as jnp

In [25]:
# Problem size: number of ants per epoch, machines to schedule on, and jobs to place.
n_ants = 10
n_machines = 10
n_jobs = 100

# Factor applied to the existing pheromone in every update (passed as `rate`).
eva_coeff = 0.9

# Weights of the speed term (alpha) and the pheromone term (beta) when the
# sampling distribution is built; beta is defined so the two sum to 1.
alpha = 0.01
beta = 1 - alpha

# Number of iterations of the optimisation loop.
epochs = 10_000

In [15]:
def populate(seed, n_ants, n_machines, n_jobs, random_speed=False):
    """Create the problem instance and the initial colony state.

    Parameters
    ----------
    seed : int
        Seed used to build the JAX PRNG key.
    n_ants, n_machines, n_jobs : int
        Sizes of the colony, the machine pool and the job list.
    random_speed : bool, default False
        When False (the historical behaviour) every machine has speed 1.
        When True, speeds are drawn as ``|N(0, 1) * 10 + 0.3|``; such a draw
        can be arbitrarily close to zero, so use with care.

    Returns
    -------
    speed : array, shape (n_machines,)
        Machine speeds.
    time : array, shape (n_jobs,)
        Job durations drawn uniformly between 50 and 200.
    assign_matrices : array, shape (n_ants, n_jobs, n_machines)
        All-ones placeholder for the per-ant assignment matrices.
    pheromones : array, same shape as ``assign_matrices``
        Initial pheromone levels, all ones.
    """
    key = jax.random.key(seed)
    # Always split three ways so the job-duration stream does not depend on
    # whether the speed sub-key is actually consumed.
    _, speed_key, time_key = jax.random.split(key, 3)

    if random_speed:
        speed = jnp.abs(jax.random.normal(speed_key, shape=(n_machines,)) * 10 + 0.3)
    else:
        speed = jnp.ones((n_machines,))

    # The draw is already within [50, 200], so no absolute value is needed.
    time = jax.random.uniform(time_key, shape=(n_jobs,), minval=50, maxval=200)

    assign_matrices = jnp.ones(shape=(n_ants, n_jobs, n_machines))
    pheromones = jnp.ones_like(assign_matrices)
    return speed, time, assign_matrices, pheromones

In [16]:
def assign_path(keys, prob_matrix, n_machines):
    """Sample one machine per (ant, job) pair and return one-hot assignments.

    ``keys`` and ``prob_matrix`` are mapped together over their two leading
    axes. For each pair, the trailing vector of ``prob_matrix`` is handed to
    ``jax.random.categorical`` as logits, and the sampled index is encoded as
    an integer one-hot row of length ``n_machines``.
    """
    # Row k of the identity matrix is the one-hot encoding of machine k.
    one_hot_rows = jnp.eye(n_machines, dtype=int)

    def pick_machine(job_key, job_logits):
        machine = jax.random.categorical(job_key, job_logits)
        return one_hot_rows[machine]

    over_jobs = jax.vmap(pick_machine)
    over_ants = jax.vmap(over_jobs)
    return over_ants(keys, prob_matrix)

In [17]:
def calculate_cmax(assign_matrices, time, speed):
    """Evaluate every ant's schedule and pick the one with the smallest makespan.

    Returns a tuple ``(loads, makespans, best)``:

    * ``loads`` - ``jnp.dot(time, assign_matrices) / speed``, i.e. the job
      durations summed per machine for each ant and divided by machine speed;
    * ``makespans`` - the maximum of ``loads`` along axis 1 (one value per ant);
    * ``best`` - the entry of ``assign_matrices`` whose makespan is minimal.
    """
    machine_loads = jnp.dot(time, assign_matrices) / speed
    makespans = jnp.max(machine_loads, axis=1)
    winner = jnp.argmin(makespans)
    return machine_loads, makespans, assign_matrices[winner]

In [18]:
def multiply_batch_by_element(batch, element):
    """Return ``batch`` scaled by ``element``."""
    scaled = batch * element
    return scaled


# Mapped over the leading axis of both arguments: slice k of the first
# argument is multiplied by entry k of the second.
deposits_pheromone = jax.vmap(multiply_batch_by_element, in_axes=(0, 0))

In [19]:
def update_pheromone(
    pheromones,
    delta_sum_pheromones,
    best_assign_matrix_cur,
    best_assign_matrix_all,
    rate,
):
    """Scale the existing pheromone by ``rate`` and add the new deposits.

    ``delta_sum_pheromones`` is added three times: once unweighted, once
    multiplied by ``best_assign_matrix_all`` and once multiplied by
    ``best_assign_matrix_cur``.
    """
    retained = rate * pheromones
    base_deposit = delta_sum_pheromones * jnp.ones_like(best_assign_matrix_cur)
    overall_best_bonus = delta_sum_pheromones * best_assign_matrix_all
    current_best_bonus = delta_sum_pheromones * best_assign_matrix_cur
    # Summed left to right in the same order as the terms are listed above.
    return retained + base_deposit + overall_best_bonus + current_best_bonus

In [20]:
def update_probability(speed, pheromones, n_ants, n_jobs, a=0.1, b=0.9):
    """Build the log-probabilities used to sample a machine for every job.

    Two distributions over the machine axis are mixed:

    * the speed share ``speed / sum(speed)``, repeated for every ant and job;
    * the pheromone share, i.e. ``pheromones`` divided by its sum along axis 2.

    Parameters
    ----------
    speed : array, shape (n_machines,)
    pheromones : array whose axis 2 indexes machines
    n_ants, n_jobs : int
        Leading dimensions the speed term is expanded to.
    a, b : float
        Weights of the speed term and the pheromone term respectively.

    Returns
    -------
    array
        ``log(a * speed_share + b * pheromone_share)``.
    """
    speed_share = speed / jnp.sum(speed)
    # Broadcasting replaces the explicit tiling; the target shape is the same.
    speed_term = jnp.broadcast_to(speed_share, (n_ants, n_jobs) + speed_share.shape)

    # keepdims lets the per-(ant, job) totals divide every machine entry
    # directly, without a nested vmap.
    pheromone_share = pheromones / jnp.sum(pheromones, axis=2, keepdims=True)

    return jnp.log(a * speed_term + b * pheromone_share)

In [29]:
# Best makespan found so far. Deliberately not called `min`, which would
# shadow the Python builtin for the rest of the kernel session.
best_cmax = jnp.inf

speed, time, assign_matrices, pheromones = populate(42, n_ants, n_machines, n_jobs)
# Placeholder only: the first epoch always beats `inf`, so this is replaced
# before it is first used in a pheromone update.
best_assign_matrix_all = jnp.ones_like(assign_matrices[0])

for epoch in range(1, epochs + 1):

    # Sampling logits from the current pheromone levels and machine speeds.
    probability = update_probability(speed, pheromones, n_ants, n_jobs, a=alpha, b=beta)

    # One PRNG key per (ant, job) pair, derived from the epoch number.
    key = jax.random.key(epoch)
    keys = jax.random.split(key, n_ants * n_jobs).reshape(n_ants, n_jobs)

    assign_matrices = assign_path(keys, probability, n_machines)
    full_L_k, L_k, best_assign_matrix_cur = calculate_cmax(assign_matrices, time, speed)

    # Reduce once and reuse the value for the comparison, the record and the log.
    epoch_cmax = jnp.min(L_k)
    if epoch_cmax < best_cmax:
        best_assign_matrix_all = best_assign_matrix_cur
        best_cmax = epoch_cmax
        print(f"Min: {best_cmax} at step {epoch}\n")

    # Each ant deposits 1 / makespan on the (job, machine) pairs it chose;
    # the deposits are then summed over ants.
    delta_all_pheromones = deposits_pheromone(assign_matrices, 1 / L_k)
    delta_sum_pheromones = jnp.sum(delta_all_pheromones, axis=0)

    pheromones = update_pheromone(
        pheromones,
        delta_sum_pheromones,
        best_assign_matrix_cur,
        best_assign_matrix_all,
        eva_coeff,
    )


Min: 1699.10546875 at step 1

Min: 1680.71044921875 at step 2

Min: 1571.6370849609375 at step 3

Min: 1553.5579833984375 at step 13

Min: 1552.834228515625 at step 16

Min: 1483.0772705078125 at step 27

Min: 1479.481689453125 at step 67

Min: 1478.772705078125 at step 118

Min: 1464.34423828125 at step 131

Min: 1447.124267578125 at step 139

Min: 1420.646484375 at step 147

Min: 1398.9163818359375 at step 213

Min: 1377.923828125 at step 444

Min: 1360.058349609375 at step 492

Min: 1359.8958740234375 at step 755

Min: 1347.81005859375 at step 792

Min: 1340.125732421875 at step 1382

Min: 1318.53662109375 at step 4723

Min: 1296.1610107421875 at step 5877

Min: 1292.5955810546875 at step 5947

Min: 1271.0244140625 at step 5959

Min: 1263.886474609375 at step 6139

Min: 1251.10302734375 at step 6181

Min: 1247.943359375 at step 6319

Min: 1246.190673828125 at step 6834

Min: 1243.21484375 at step 7398

Min: 1238.64111328125 at step 7577

