A simulation that checks whether the proof of ALSA holds under the conditions of SA.

In [67]:
import numpy as np
import matplotlib.pyplot as plt
import itertools

%load_ext autoreload
%autoreload 2

# Build a table of N_USERS users x 3 features. Each column is an independent
# random permutation of 0, ..., N_USERS - 1, so within any column a value
# identifies exactly one user (row).
SEED = 0  # fix the global NumPy RNG so Restart & Run All reproduces the run
np.random.seed(SEED)

N_USERS = 30
base_arange = np.arange(N_USERS)
base_array = np.stack([np.random.permutation(base_arange) for _ in range(3)], axis=1)

K = 25  # only users 0..K-1 are candidates for selection below
m = 7   # number of users in a selected subset
assert m <= K <= N_USERS, "need m <= K <= N_USERS"




The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [68]:
def compute_energy(arr, lst_of_idxs):
    """Score a subset of users (rows of ``arr``).

    score = min(column 0 over the subset)
            + sum(column 1 over the subset) / m
            + sum(column 2 over the subset) / m

    where ``m`` is the module-level subset size.

    Args:
        arr: 2-D array, one row per user, with at least 3 columns.
        lst_of_idxs: row indices of the subset (tuple, list or 1-D integer array).

    Returns:
        The scalar score of the subset.
    """
    # Plain 1-D integer-array indexing returns an (n_idx, n_cols) array directly.
    # The former ``arr[[idxs]].squeeze(0)`` relied on NumPy >= 1.23 treating a
    # list that wraps a sequence as an array index. Earlier NumPy interpreted it
    # as a multidimensional (tuple) index and failed.
    reduced_arr = arr[np.asarray(lst_of_idxs, dtype=np.intp)]
    min_ucb = reduced_arr[:, 0].min()
    g_sum = reduced_arr[:, 1].sum() / m
    p_sum = reduced_arr[:, 2].sum() / m
    return min_ucb + g_sum + p_sum


def groups_substract(arr1, arr2):
    """Return the rows of ``arr1`` that do not appear as whole rows in ``arr2``.

    This is a set difference on rows (users). The order of ``arr1`` is preserved.

    Args:
        arr1: 2-D array, one row per user.
        arr2: array with the same number of columns as ``arr1``; may be empty.

    Returns:
        2-D array of the remaining rows. When no rows remain, the result has
        shape ``(0, arr1.shape[1])`` so that column slicing such as
        ``result[:, idx]`` stays valid.
    """
    arr1 = np.asarray(arr1)
    if len(arr1) == 0:
        return arr1.copy()
    # The reshape lets an empty arr2 (shape (0,) or (0, c)) broadcast cleanly.
    arr2 = np.asarray(arr2).reshape(-1, arr1.shape[1])
    # Compare whole rows. `x not in arr2` evaluates `(arr2 == x).any()`, which
    # treats a row as present if ANY single element matches at the same column.
    row_matches = (arr1[:, None, :] == arr2[None, :, :]).all(axis=2)
    return arr1[~row_matches.any(axis=1)]


# Exhaustive search for the optimal subset s_star: score every m-subset of the
# first K users and keep the best one. The combinations are shuffled so that
# ties between equal scores are broken at random, not by lexicographic order.
users_idxs_comb = list(itertools.combinations(range(K), m))
np.random.shuffle(users_idxs_comb)
winning_comb = None
best_score = -np.inf  # any real score beats this, so the first subset is always recorded
for comb in users_idxs_comb:
    score = compute_energy(base_array, comb)
    if score > best_score:
        best_score = score
        winning_comb = comb

if winning_comb is None:
    raise ValueError(f"no subsets of size m={m} among the first K={K} users")

# Plain 1-D integer-array indexing yields an (m, 3) array directly. The former
# ``base_array[[idxs]].squeeze(0)`` form only works on NumPy >= 1.23.
s_star = base_array[np.asarray(winning_comb)]

# Random starting subset s: m distinct users among the first K.
s = base_array[np.random.choice(np.arange(K), m, replace=False)]

def relaxtion_iter(s, s_star, idx, last_step=False, initial_step = False, verbose = True):
    """One relaxation pass that moves subset ``s`` towards subset ``s_star`` along column ``idx``.

    Both arrays are treated as sets of user rows with 3 columns (rows are
    re-added with ``reshape(1,3)``). Rows are swapped one at a time: a row of
    ``s`` is deleted and a row of ``s_star`` that is not yet in ``s`` is added.

    Stopping rule:
      * ``last_step=False``: stop once min(s[:, idx]) == min(s_star[:, idx]).
      * ``last_step=True``: stop once ``s`` equals ``s_star`` (both sorted by ``idx``).
      * In either mode, also stop as soon as every row of ``s_star`` is in ``s``.

    In ``last_step`` mode, the plain swap (delete the min-``idx`` row of ``s``)
    is used only if the added row's last-column value is >= the deleted row's.
    Otherwise the deleted row is the row of (s minus s_star) with the largest
    column-``idx`` value, located by matching its last-column value.

    Args:
        s: current subset, 2-D array of user rows.
        s_star: target subset, same layout as ``s``.
        idx: column to sort on and relax along.
        last_step: use the stricter "s equals s_star" stopping rule.
        initial_step: not used in the body.
        verbose: print intermediate states. The "Starting process" banner is
            printed regardless of this flag.

    Returns:
        ``(s, s_star)``, both sorted by column ``idx``.

    NOTE(review): row membership here (``row in s`` and ``groups_substract``)
    relies on ``ndarray.__contains__``, i.e. ``(s == row).any()``. That matches a
    row if ANY single element is equal at the same column. It behaves as a
    whole-row test only because every column of ``base_array`` is a
    permutation, so one equal value identifies the row.
    """

    print("*"*10 + f"Starting process with column idx: {idx}" + "*"*10)
    #sorting the arrays:
    s = s[np.argsort(s[:,idx])]
    s_star = s_star[np.argsort(s_star[:,idx])]
    # Rows of s_star that are not yet in s.
    s_star_left = groups_substract(s_star, s)
    if verbose:
        print("s:", s)
        print("s_star:", s_star)
        print("s_star_left:", s_star_left)
        

    # s already contains every row of s_star: nothing to relax.
    if len(s_star_left) == 0:
        return s, s_star
    counter = 0
    # Initial swap: delete the user with the lowest value in column idx from s.

    # Kept only for the verbose log below.
    initial_deleted_row = s[np.argmin(s[:,idx])]
    initial_added_row = s_star[np.argmin(s_star[:,idx])]
    # Fallback candidate: the row of s_star_left with the largest column-idx value.
    backup_added_row = s_star_left[np.argmax(s_star_left[:,idx])]
    s = np.delete(s, np.argmin(s[:,idx]), axis=0)
    
    if initial_added_row not in s:
        s = np.concatenate([s, initial_added_row.reshape(1,3)], axis=0)
    else:
        # The min-idx row of s_star is already in s, so add the argmax row of s_star_left instead.
        s = np.concatenate([s, backup_added_row.reshape(1,3)], axis=0)



    #sorting the arrays:
    s = s[np.argsort(s[:,idx])]
    s_star = s_star[np.argsort(s_star[:,idx])]
    s_star_left = groups_substract(s_star, s)
    if len(s_star_left) == 0:
        return s, s_star
    s_star_left = s_star_left[np.argsort(s_star_left[:,idx])]

    if verbose:
        print("*"*10 + f"iteration {counter}, managing column idx: {idx}" + "*"*10)
        print("s:", s)
        print("s_star:", s_star)
        print("s_star_left:", s_star_left)
        print("deleted_row:", initial_deleted_row)
        print("added_row:", initial_added_row)
    counter += 1
    

    # `~` on the NumPy bool scalar returned by np.all acts as logical NOT.
    condition_1 = ~np.all(s==s_star)
    condition_2 = min(s_star[:,idx]) != min(s[:,idx])

    condition = condition_1 if last_step else condition_2

   
    # NOTE(review): in last_step mode the min-idx row of s may itself belong to
    # s_star, so a swap need not increase |s ∩ s_star|. Termination is not
    # obvious from this code; consider an iteration cap.
    while condition:
        s_star_left = groups_substract(s_star, s)
        s_minus_s_star = groups_substract(s, s_star)
        s_minus_s_star = s_minus_s_star[np.argsort(s_minus_s_star[:,idx])]
        proposed_deleted_row = s[np.argmin(s[:,idx])]
        proposed_added_row = s_star_left[np.argmax(s_star_left[:,idx])]

        if last_step:
            # Accept the plain swap only if it does not lower the last-column value.
            if (proposed_added_row[-1] >= proposed_deleted_row[-1]):
                s = np.delete(s, np.argmin(s[:,idx]), axis=0)
                s = np.concatenate([s, proposed_added_row.reshape(1,3)], axis=0)

            else:
                #find the row in s which has the same value in the last column as s_minus_s_star[-1,-1]
                # i.e. the (s minus s_star) row with the largest column-idx value; this
                # assumes last-column values are unique, as they are for permutation-built data.
                

                proposed_deleted_row = s[np.where(s[:,-1] == s_minus_s_star[-1,-1])[0][0]]
                s = np.delete(s, np.where(s[:,-1] == s_minus_s_star[-1,-1])[0][0], axis=0)
                s = np.concatenate([s, proposed_added_row.reshape(1,3)], axis=0)
        
        else:
            s = np.delete(s, np.argmin(s[:,idx]), axis=0)
            s = np.concatenate([s, proposed_added_row.reshape(1,3)], axis=0)
        

        #sort the arrays once again
        s_star_left = groups_substract(s_star, s)  
        s = s[np.argsort(s[:,idx])]
        s_star = s_star[np.argsort(s_star[:,idx])]
        # groups_substract may return a 1-D empty array, which cannot be column-sliced.
        if len(s_star_left) != 0:
            s_star_left = s_star_left[np.argsort(s_star_left[:,idx])]
        if verbose:
            print("*"*10 + f"iteration {counter}, managaing column idx: {idx}" + "*"*10)
            print("s:", s)
            print("s_star:", s_star)
            print("s_star_left:", s_star_left)
            print("deleted_row:", proposed_deleted_row)
            print("added_row:", proposed_added_row)
        counter += 1
        condition_1 = ~np.all(s==s_star)
        condition_2 = min(s_star[:,idx]) != min(s[:,idx])
        condition = condition_1 if last_step else condition_2
        # Every row of s_star is now in s; the next iteration would have no row to add.
        if len(s_star_left) == 0:
            if verbose:
                print("length of s_star_left is 0, breaking loop")
            break

    if verbose:
        print("Finished with column idx:", idx)

    return s, s_star

# Relax s towards s_star one column at a time: column 0 first (initial step),
# then column 1, and finally column 2 with the stricter last-step rule that
# requires s to become identical to s_star.
s, s_star = relaxtion_iter(s, s_star, 0, initial_step=True)
s, s_star = relaxtion_iter(s, s_star, 1)
s, s_star = relaxtion_iter(s, s_star, 2, last_step=True)

print("Finished process:")
print("s:", s)
print("s_star:", s_star)
if np.all(s == s_star):
    print("The process converged!")
else:
    # Report failure explicitly; otherwise a failed check produces no verdict.
    print("The process did NOT converge: s differs from s_star.")


    
    

**********Starting process with column idx: 0**********
s: [[ 1 20 15]
 [ 7 12 29]
 [ 9 15  1]
 [11 11  7]
 [13 16  3]
 [24  2 13]
 [29 13 20]]
s_star: [[19 21 19]
 [20 28  4]
 [21 25  2]
 [25 23 21]
 [26 14  9]
 [27 22 18]
 [29 13 20]]
s_star_left: [[19 21 19]
 [20 28  4]
 [21 25  2]
 [25 23 21]
 [26 14  9]
 [27 22 18]]
**********iteration 0, managing column idx: 0**********
s: [[ 7 12 29]
 [ 9 15  1]
 [11 11  7]
 [13 16  3]
 [19 21 19]
 [24  2 13]
 [29 13 20]]
s_star: [[19 21 19]
 [20 28  4]
 [21 25  2]
 [25 23 21]
 [26 14  9]
 [27 22 18]
 [29 13 20]]
s_star_left: [[20 28  4]
 [21 25  2]
 [25 23 21]
 [26 14  9]
 [27 22 18]]
deleted_row: [ 1 20 15]
added_row: [19 21 19]
**********iteration 1, managaing column idx: 0**********
s: [[ 9 15  1]
 [11 11  7]
 [13 16  3]
 [19 21 19]
 [24  2 13]
 [27 22 18]
 [29 13 20]]
s_star: [[19 21 19]
 [20 28  4]
 [21 25  2]
 [25 23 21]
 [26 14  9]
 [27 22 18]
 [29 13 20]]
s_star_left: [[20 28  4]
 [21 25  2]
 [25 23 21]
 [26 14  9]]
deleted_row: [ 7 12 

In [69]:
# Scratch check of the `x if flag else y` selection used in relaxtion_iter:
# with c=True the first condition is chosen, and here that one is False.
a, b = 1, 0
c = True

condition_1 = (a == b)
condition_2 = (a == 1)

if c:
    condition = condition_1
else:
    condition = condition_2

if condition:
    print(a)

