In [84]:
import random
import time

import matplotlib.pyplot as plt
from matplotlib import patches, animation
from IPython.core.display import HTML

# [随机排列与Fisher-Yates算法 (Random Permutations and the Fisher-Yates Algorithm)](https://www.youtube.com/watch?v=1m68x5Gy5No)

In [235]:
def original_fisher_yates_shuffle(nums):
    """Shuffle ``nums`` using Fisher and Yates' original strike-out method.

    Each step picks one of the not-yet-used elements uniformly at random,
    appends it to the output and removes it from the pool. The input is not
    modified; a new list is returned.

    Removing from the middle of a list is O(n), so the whole shuffle is
    O(n^2) -- compare ``modern_fisher_yates_shuffle``, which swaps in place.
    """
    pool = list(nums)
    result = []
    while pool:
        # randint is inclusive on both ends, so any remaining slot can be picked.
        pick = random.randint(0, len(pool) - 1)
        result.append(pool.pop(pick))
    return result

In [236]:
def modern_fisher_yates_shuffle(nums):
    """Return a uniformly shuffled copy of ``nums`` using the swap-based Fisher-Yates shuffle.

    Walks the positions left to right. Position ``i`` is swapped with a
    position drawn uniformly from ``i .. n-1``; after that, position ``i`` is
    final. Runs in O(n) time.

    The caller's sequence is left untouched, matching
    ``original_fisher_yates_shuffle``; a new list is returned.
    """
    # Copy first: `shuffled_nums = nums` would only alias the caller's list,
    # so every call would silently permute it in place.
    shuffled_nums = list(nums)
    last = len(shuffled_nums) - 1
    # The final position can only swap with itself, so stop one early;
    # this skips a no-op swap and a wasted random draw.
    for i in range(last):
        random_index = random.randint(i, last)
        shuffled_nums[i], shuffled_nums[random_index] = shuffled_nums[random_index], shuffled_nums[i]
    return shuffled_nums

## Simple timing test

Time `num_iter` shuffles of a small list with each implementation.

In [237]:
nums = list(range(1, 4))  # small input for the timing runs: [1, 2, 3]
num_iter = 10 ** 4        # how many shuffles each implementation performs

In [238]:
# Time `num_iter` shuffles with the original (strike-out) algorithm.
# perf_counter is a monotonic, high-resolution clock intended for measuring
# durations; time.time() is wall-clock time and can jump.
start_time = time.perf_counter()
for _ in range(num_iter):
    shuffled_nums = original_fisher_yates_shuffle(nums)
elapsed = time.perf_counter() - start_time
print("{} iterations of permutation cost {} sec.".format(num_iter, elapsed))

10000 iterations of permutation cost 0.06103348731994629 sec.


In [239]:
# Time `num_iter` shuffles with the modern (swap-based) algorithm.
# perf_counter is a monotonic, high-resolution clock intended for measuring
# durations; time.time() is wall-clock time and can jump.
start_time = time.perf_counter()
for _ in range(num_iter):
    shuffled_nums = modern_fisher_yates_shuffle(nums)
elapsed = time.perf_counter() - start_time
print("{} iterations of permutation cost {} sec.".format(num_iter, elapsed))

10000 iterations of permutation cost 0.06199336051940918 sec.


## Plot Each Step

In [240]:
nums = list(range(1, 8))  # [1, 2, ..., 7] for the step-by-step plots
num_iter = 10 ** 4

In [241]:
def plot_nums(nums, iteration, random_index):
    """Draw ``nums`` as a row of boxes and mark one Fisher-Yates step.

    Parameters
    ----------
    nums : sequence
        Values to draw, one per box. Each box's index is written above it.
    iteration : int
        Position currently being filled (green arrow). Boxes ``0 .. iteration-1``
        are shaded grey because they are already final.
    random_index : int
        Position randomly chosen to swap with ``iteration`` (red arrow).

    Shows the figure with ``plt.show()`` and returns None.
    """
    length = len(nums)
    height = 1
    highlight = (0.19607843, 0.80392157, 0.19607843, 0.2)  # translucent green fill
    fixed_shade = (0.2, 0.2, 0.2, 0.2)  # translucent grey fill

    fig, ax = plt.subplots(figsize=(length, 4))
    ax.set_xlim([-1, length + 1])
    ax.set_ylim([-1.5, 2.5])

    # Array outline, one divider per cell, index labels above and values inside.
    ax.add_patch(patches.Rectangle((0, 0), length, height, lw=1, fill=0, ec='steelblue'))
    for i in range(length):
        ax.plot([i, i], [height, 0], c='steelblue', lw=1)
        ax.text(i + 0.5 - 0.08 * len(str(i)), height + 0.08, i, fontsize=10, color='steelblue')
        ax.text(i + 0.5 - 0.09, height / 2 - 0.09, nums[i], fontsize=14, color='k')

    # Highlight both cells involved in the swap. Distinct arrow colours keep them
    # distinguishable: red = randomly selected index, green = current iteration.
    ax.add_patch(patches.Rectangle((random_index, 0), 1, height, lw=0, fill=1, fc=highlight))
    ax.arrow(random_index + 0.5, 1.8, 0, -0.25, head_width=0.1, color='r')
    ax.add_patch(patches.Rectangle((iteration, 0), 1, height, lw=0, fill=1, fc=highlight))
    ax.arrow(iteration + 0.5, 1.8, 0, -0.25, head_width=0.1, color='g')

    # Cells left of the current iteration are already final: shade them grey.
    for i in range(iteration):
        ax.add_patch(patches.Rectangle((i, 0), 1, height, lw=0, fill=1, fc=fixed_shade))

    ax.set_title('iteration = {}\nselected = {}'.format(iteration, random_index))
    ax.axis('off')
    fig.tight_layout()
    plt.show()

In [242]:
def modern_fisher_yates_shuffle_plot(nums):
    """Shuffle a copy of ``nums`` with the swap-based Fisher-Yates algorithm, plotting each step.

    For each iteration ``i`` in ``0 .. n-2``, a random index in ``i .. n-1`` is
    drawn. The array is drawn with ``plot_nums`` once before and once after the
    swap. The caller's sequence is not modified; the shuffled copy is returned.
    """
    # Copy first: `shuffled_nums = nums` would alias the caller's list and
    # permute it in place.
    shuffled_nums = list(nums)
    last = len(shuffled_nums) - 1
    for i in range(last):
        random_index = random.randint(i, last)
        plot_nums(shuffled_nums, i, random_index)  # state before the swap
        shuffled_nums[i], shuffled_nums[random_index] = shuffled_nums[random_index], shuffled_nums[i]
        plot_nums(shuffled_nums, i, random_index)  # state after the swap
    return shuffled_nums

In [245]:
# modern_fisher_yates_shuffle_plot(nums)

## Animation

In [250]:
# Pre-compute the whole shuffle of 1..9, recording two frames per iteration:
# a snapshot of the array just before the swap and one just after it.
nums = list(range(1, 10))
length = len(nums)
height = 1
history_random_index = []
history_nums = []
history_index = []
for i in range(length - 1):
    random_index = random.randint(i, length - 1)
    before_swap = nums.copy()
    nums[i], nums[random_index] = nums[random_index], nums[i]
    history_index += [i, i]
    history_random_index += [random_index, random_index]
    history_nums += [before_swap, nums.copy()]
    
# Create the single figure/axes pair that `animate` redraws on every frame.
fig, ax = plt.subplots(figsize=(int(length * 1.5), int(4 * 1.5)))
ax.set_xlim(-1, length + 1)
ax.set_ylim(-1.5, 2.5)


def animate(frame, *fargs):
    """Redraw animation frame ``frame`` on the module-level ``ax``.

    Frames come in pairs recorded in the ``history_*`` lists. Even frame ``2k``
    shows the array before iteration ``k``'s swap; odd frame ``2k + 1`` shows it
    after. Both frames of a pair describe the same swap, so the caption and
    highlighted cells come from the even (pre-swap) frame.

    Arrow colours follow the values: green marks the value that was at the
    current iteration's position, red the randomly selected value. After the
    swap the two colours trade places.

    ``fargs`` is accepted for ``FuncAnimation`` compatibility but unused. All
    data is read from the globals ``history_nums``, ``history_index``,
    ``history_random_index``, ``length`` and ``height``.
    """
    # ax.clear() already drops every line, patch and text from the previous
    # frame. Assigning `ax.patches = []` is redundant and fails on newer
    # Matplotlib, where Axes.patches is read-only.
    ax.clear()
    ax.set_xlim(-1, length + 1)
    ax.set_ylim(-1.5, 2.5)
    # Hide ticks/spines like plot_nums does. Set on every frame so it holds
    # regardless of what clear() resets.
    ax.axis('off')

    # Cell dividers, index labels above each cell, and the frame's values inside.
    for i in range(length):
        ax.plot([i, i], [height, 0], c='steelblue', lw=1)
        ax.text(i + 0.5 - 0.08 * len(str(i)), 1 + 0.08, i, fontsize=10, color='steelblue')
        ax.text(i + 0.5 - 0.09, 0.5 - 0.09, history_nums[frame][i], fontsize=14, color='k')
    ax.add_patch(patches.Rectangle((0, 0), length, height, lw=1, fill=0, ec='steelblue'))

    # Odd (post-swap) frames reuse the swap description of the preceding even frame.
    if frame % 2 == 0:
        temp_frame = frame
        current_color, selected_color = 'g', 'r'
    else:
        temp_frame = frame - 1
        current_color, selected_color = 'r', 'g'
    current = history_index[temp_frame]
    selected = history_random_index[temp_frame]
    before = history_nums[temp_frame]
    highlight = (0.19607843, 0.80392157, 0.19607843, 0.2)

    # Highlight and point at the randomly selected cell, then the current cell.
    # ax.arrow already adds the arrow to the axes, so no extra add_patch is needed.
    ax.add_patch(patches.Rectangle((selected, 0), 1, height, lw=0, fill=1, fc=highlight))
    ax.arrow(selected + 0.5, 1.8, 0, -0.25, head_width=0.1, color=selected_color)
    ax.add_patch(patches.Rectangle((current, 0), 1, height, lw=0, fill=1, fc=highlight))
    ax.arrow(current + 0.5, 1.8, 0, -0.25, head_width=0.1, color=current_color)

    # Set the title on `ax` explicitly instead of relying on pyplot's current axes.
    ax.set_title(
        'iteration = {0}\nrandomly select number from {1} to {2} = {3}\n'
        'swap nums[{4}] = {5} and nums[{6}] = {7}'.format(
            current,
            current,
            length - 1,
            selected,
            current,
            before[current],
            selected,
            before[selected],
        )
    )

    # Cells left of the current iteration are already final: shade them grey.
    for i in range(current):
        ax.add_patch(patches.Rectangle((i, 0), 1, height, lw=0, fill=1, fc=(0.2, 0.2, 0.2, 0.2)))
    
# Build the animation (one frame per recorded snapshot), save it as a GIF,
# then render it inline as an HTML/JS player.
# NOTE(review): the inline interval is 948.7 ms while the GIF is written at
# fps=1 (1000 ms per frame); the origin of 0.9487 is undocumented -- confirm.
anim = animation.FuncAnimation(
    fig,
    func=animate,
    fargs=(history_nums, history_random_index, history_index),
    frames=len(history_nums),
    interval=0.9487 * 1000,
)
anim.save('modern_fisher_yates_shuffle.gif', writer='pillow', fps=1)
animation_html = HTML(anim.to_jshtml())
# Close the figure so the notebook does not also render its last frame as a static image.
plt.close()
animation_html