In [3]:
import numpy as np

In [4]:
def initial_round(num_items, group_size_list):
    """Build the first-round assignment, laying items out group by group.

    Group ids are 1-based and handed out in list order: group k takes the
    next ``group_size_list[k - 1]`` consecutive item positions. For example,
    with 200 items and ``group_size_list == [100, 50, 50]``, items 0-99 get
    id 1, items 100-149 get id 2 and items 150-199 get id 3.

    Args:
        num_items: Total number of items; must equal ``sum(group_size_list)``.
        group_size_list: Size of each group, in group-id order.

    Returns:
        A numpy array (via ``np.repeat``) giving each item's group id.

    Raises:
        ValueError: If the group sizes do not add up to ``num_items``.
    """
    total = sum(group_size_list)
    if total != num_items:
        raise ValueError("Invalid group sizes")
    ids = [gid for gid, _ in enumerate(group_size_list, start=1)]
    return np.repeat(ids, group_size_list)


In [5]:
def cal_trans(positions, i, j):
    """Penalty weight for items ``i`` and ``j`` under one past assignment.

    Args:
        positions: Sequence mapping item index -> group id for a past round.
        i: First item index, must lie in ``[0, len(positions))``.
        j: Second item index, must lie in ``[0, len(positions))``.

    Returns:
        1 when ``i == j`` (an item compared with itself: the "stayed in the
        same group" weight), 2 when two distinct items shared a group in
        ``positions``, otherwise 0.

    Raises:
        ValueError: If either index is out of range.
    """
    size = len(positions)
    if not (0 <= i < size and 0 <= j < size):
        raise ValueError(f"Invalid inputs: {i}, {j}")
    if i == j:
        return 1
    # Distinct items: penalised only if they were grouped together before.
    return 2 if positions[i] == positions[j] else 0
    

In [6]:
def cal_pair_penalty_one_round(p1, p2, num_groups):
    """Pair penalty of candidate ``p2`` against a single past round ``p1``.

    For each group id 1..``num_groups`` in ``p2``, every unordered pair of
    items placed together is scored with ``cal_trans(p1, a, b)``. Because the
    two items are distinct, that contributes only when the pair also shared a
    group in ``p1``.

    Args:
        p1: Past assignment (item index -> group id).
        p2: Candidate assignment; lists are accepted and converted to arrays.
        num_groups: Number of groups; group ids are taken to be
            ``1..num_groups`` (ids outside that range are not scored).

    Returns:
        Summed penalty over all co-grouped pairs in ``p2``.

    Raises:
        ValueError: If ``p1`` and ``p2`` have different lengths.
    """
    if len(p1) != len(p2):
        raise ValueError("Two position array have different sizes")
    # The elementwise ``p2 == group_id`` below needs an array; on a plain list
    # it compares the whole list to an int and silently finds no members.
    p2 = np.asarray(p2)
    penalty = 0
    for group_id in range(1, num_groups + 1):
        # flatnonzero already yields indices in ascending order; no sort needed.
        members = np.flatnonzero(p2 == group_id).tolist()
        for pos, item_j in enumerate(members):
            for item_k in members[pos + 1:]:
                penalty += cal_trans(p1, item_j, item_k)
    return penalty

def cal_pair_penalty(p_history, p2, num_groups):
    """Pair penalty of candidate ``p2`` summed over every past round.

    Args:
        p_history: Iterable of past assignments.
        p2: Candidate assignment.
        num_groups: Number of groups (ids ``1..num_groups``).

    Returns:
        Sum of ``cal_pair_penalty_one_round`` over ``p_history`` (0 if empty).
    """
    return sum(
        cal_pair_penalty_one_round(past_round, p2, num_groups)
        for past_round in p_history
    )


def cal_stay_penalty(p_history, p2):
    """Penalty for items that keep the same group id as in a past round.

    For every past round and every item whose group id in ``p2`` equals its
    id in that round, the self-transition weight ``cal_trans(p1, j, j)`` is
    added.

    Args:
        p_history: Iterable of past assignments, each the same length as ``p2``.
        p2: Candidate assignment (item index -> group id).

    Returns:
        Total stay penalty summed over all past rounds.

    Raises:
        ValueError: If any past assignment differs in length from ``p2``.
    """
    penalty = 0
    num_items = len(p2)
    for p1 in p_history:
        if len(p1) != num_items:
            raise ValueError("Invalid input sizes")
        for j in range(num_items):
            if p1[j] == p2[j]:
                # Pass the *item* index j: using the history-round index here
                # raises ValueError once the history outgrows the item count.
                penalty += cal_trans(p1, j, j)
    return penalty


def cal_penalty(p_history, p2, num_groups):
    """Total penalty of candidate assignment ``p2`` against all past rounds.

    The total is the pair penalty (co-grouped items meeting again) plus the
    stay penalty (items keeping their previous group id).
    """
    pair_part = cal_pair_penalty(p_history, p2, num_groups)
    stay_part = cal_stay_penalty(p_history, p2)
    return pair_part + stay_part
        

In [7]:
def swap(p, i, j, create_new=True):
    """Exchange the entries at positions ``i`` and ``j``.

    Args:
        p: Mutable sequence providing ``copy()`` (e.g. a numpy array or list).
        i: First position.
        j: Second position.
        create_new: If True, operate on ``p.copy()`` and leave ``p`` intact;
            if False, modify ``p`` in place.

    Returns:
        The swapped sequence (``p`` itself when ``create_new`` is False).
    """
    target = p.copy() if create_new else p
    first, second = target[i], target[j]
    target[i] = second
    target[j] = first
    return target

def random_swap(p, times=1, create_new=True):
    """Apply ``times`` random transpositions of two distinct positions.

    Args:
        p: Sequence to permute (numpy array or list). Needs ``len(p) >= 2``
            whenever ``times > 0``.
        times: Number of random swaps to perform.
        create_new: If True, ``p`` is left untouched and a new sequence is
            returned (also when ``times == 0``); if False, ``p`` is modified
            in place and returned.

    Returns:
        The permuted sequence.
    """
    # Copy at most once up front, then swap in place. This avoids one copy per
    # iteration and guarantees a fresh object even when times == 0.
    result = p.copy() if create_new else p
    for _ in range(times):
        # replace=False picks two distinct positions, so each step really moves items.
        a, b = np.random.choice(len(result), size=2, replace=False)
        result[a], result[b] = result[b], result[a]
    return result

def cal_penalty_delta(p_history, p21, p22, num_groups):
    """Score two candidate assignments against the same history.

    Args:
        p_history: Past assignments.
        p21: Reference (current) assignment.
        p22: Proposed assignment.
        num_groups: Number of groups.

    Returns:
        Tuple ``(penalty(p22) - penalty(p21), penalty(p21), penalty(p22))``.
    """
    before = cal_penalty(p_history, p21, num_groups)
    after = cal_penalty(p_history, p22, num_groups)
    return after - before, before, after


def cal_annealing_prob(T, delta):
    """Acceptance probability for a move changing the penalty by ``delta``.

    - ``delta < 0`` (improvement): returns 1, i.e. always accept.
    - ``delta == 0`` or ``T < 1e-6``: returns 0, i.e. never accept. Note that
      equal-penalty (sideways) moves are rejected here, unlike the textbook
      Metropolis rule where exp(0) == 1. This looks intentional; confirm.
    - otherwise: the Metropolis factor ``exp(-delta / T)``.
    """
    if delta < 0:
        return 1  # strictly better: always take it
    # Worse move with a temperature that is still "warm": probabilistic accept.
    if delta != 0 and not T < 1e-6:
        return np.exp(-delta / T)
    return 0

def decide_annealing(T, delta):
    """Accept or reject a move with penalty change ``delta`` at temperature ``T``.

    Returns:
        ``(accepted, prob)``: ``prob`` is the acceptance probability from
        ``cal_annealing_prob``, and ``accepted`` is True when a uniform draw
        from ``np.random.rand()`` falls below it.
    """
    acceptance = cal_annealing_prob(T, delta)
    draw = np.random.rand()
    accepted = draw < acceptance
    return accepted, acceptance


In [8]:
# Problem setup: group sizes are the single source of truth; the item and
# group counts are derived from them so the three values can never disagree.
group_size_list = [5] * 20
num_items = sum(group_size_list)
num_groups = len(group_size_list)

# Annealing schedule (per round): T restarts at T_init and is multiplied by
# T_decay after every swap attempt.
num_swaps_attemps_per_round = 3000
T_init = 1000
T_decay = 0.99

num_rounds = 7

# Seed numpy's global RNG (used by random_swap and decide_annealing) so a
# fresh "Restart & Run All" reproduces the same groupings.
SEED = 42
np.random.seed(SEED)

p_init = initial_round(num_items, group_size_list)
p_history = [p_init]
p_prev_round = p_init
p = p_init.copy()

print(f"round 1: {p}")
for r in range(num_rounds - 1):
    T = T_init
    last_swap = 0
    # p_history is fixed for the whole round, so the penalty of the current
    # arrangement is computed once and carried forward instead of being
    # re-scored on every attempt.
    penalty1 = cal_penalty(p_history, p, num_groups)
    for i in range(num_swaps_attemps_per_round):
        p_next = random_swap(p)
        penalty2 = cal_penalty(p_history, p_next, num_groups)
        delta = penalty2 - penalty1
        do_swap, prob = decide_annealing(T, delta)
        if do_swap:
            p = p_next
            penalty1 = penalty2
            # Throttle logging to roughly 30 lines per round.
            if i - last_swap > num_swaps_attemps_per_round / 30:
                print(f"round:{r + 2} step:{i} penalty:{penalty2}, delta:{delta}, prob:{prob}, T:{T}")
                last_swap = i
        if T > 1e-9:
            T *= T_decay
        # Stop once the arrangement we actually hold is conflict-free.
        if penalty1 == 0:
            print(f"-round:{r + 2} step:{i} penalty:{penalty1}, delta:{delta}, prob:{prob}, T:{T} early stop")
            break
    p_prev_round = p.copy()
    p_history.append(p_prev_round)
    print(f"round {r + 2}: {p}")

round 1: [ 1  1  1  1  1  2  2  2  2  2  3  3  3  3  3  4  4  4  4  4  5  5  5  5
  5  6  6  6  6  6  7  7  7  7  7  8  8  8  8  8  9  9  9  9  9 10 10 10
 10 10 11 11 11 11 11 12 12 12 12 12 13 13 13 13 13 14 14 14 14 14 15 15
 15 15 15 16 16 16 16 16 17 17 17 17 17 18 18 18 18 18 19 19 19 19 19 20
 20 20 20 20]
round:2 step:101 penalty:44, delta:2, prob:0.9944960127946747, T:362.37201786049667
round:2 step:203 penalty:25, delta:-1, prob:1, T:130.0003445350054
round:2 step:304 penalty:18, delta:2, prob:0.9584334073514639, T:47.10848717170972
round:2 step:405 penalty:25, delta:1, prob:0.9431031936053013, T:17.070797554767765
round:2 step:508 penalty:13, delta:-1, prob:1, T:6.0628783672166895
round:2 step:609 penalty:11, delta:1, prob:0.6343448679232916, T:2.1970174679710657
round:2 step:720 penalty:2, delta:-1, prob:1, T:0.7200126227424993
-round:2 step:723 penalty:0, delta:-1, prob:1, T:0.6916412525560801 early stop
round 2: [17  5 20 19 15 13 11  3  9  6 19 14 18 10  7 16 17  1  7 12