In [3]:
import numpy as np
import matplotlib.pyplot as plt
import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.optim.lr_scheduler as lr_scheduler
import math
import random

# Inclusive bounds for the list length sampled by SortingEnv.reset().
MIN_LIST_LEN = 16
MAX_LIST_LEN = 16
# Step budget per episode (SortingEnv.max_steps).
MAX_STEPS = 640

# Rewards returned by SortingEnv.step().
SUCCESS_REWARD = 0.5            # a swap that leaves the list sorted
STEP_REWARD = -0.3              # comparison / neutral swap; doubled for an 'equal' comparison
# NOTE(review): not used in this cell; presumably scales compute_min_delta_entropy elsewhere.
COMPARISON_ENTROPY_MULTIPLIER = -0.0
SWAP_REWARD = 1.0               # +SWAP_REWARD removes an inversion, -SWAP_REWARD creates one
INVALID_ACTION_REWARD = -10.0   # malformed action; ends the episode

# NOTE(review): the settings below are not used in this cell; presumably they are
# consumed by the training loop in another cell.
LONGTERM_GAMMA = 0.99
SHORTTERM_GAMMA = 0.7

EPS_START = 0.5
EPS_END = 0.05
EPS_DECAY = 1000
LR_SCHEDULER_GAMMA = 0.96
NUM_EPISODES = 200000
EPISODES_SAVE = 1000
# Overridable via the RL_SORT_OUTPUT_DIR environment variable, so the notebook
# can run on machines without the original author's home directory.
OUTPUT_DIR = os.environ.get(
    'RL_SORT_OUTPUT_DIR',
    '/home/mcwave/code/autocode/datasets/rl_sort_transformer_easy/list16_transformer4_192_gamma07_step640_v3',
)

# Define the vocabulary. Ids are assigned in this order:
#   0-4   action / response tokens,
#   5-20  list positions '0'..'15',
#   21-36 list-length markers 'len1'..'len16'.
vocab = {'Comparison': 0, 'Swap': 1, 'less': 2, 'equal': 3, 'more': 4}
vocab.update({str(position): 5 + position for position in range(16)})
vocab.update({f'len{size}': 20 + size for size in range(1, 17)})
# Reverse lookup: token id -> token string.
inv_vocab = dict(zip(vocab.values(), vocab.keys()))

def compute_entropy(N, alpha=1):
    """Return the integers ``0 .. 2**N - 1`` and their code lengths in bits.

    The integers get an exponential prior ``p(v) ∝ exp(-alpha * v)``; the
    second array is ``-log2 p(v)``, the information content of each value.

    Args:
        N: number of bits; the support is ``range(2**N)``.
        alpha: decay rate of the prior (larger -> small integers are cheaper).

    Returns:
        ``(values, bits)``, two NumPy arrays of length ``2**N``.
    """
    K = 2**N
    values = np.arange(K)
    unnormalized_probs = np.exp(-alpha * values)
    Z = unnormalized_probs.sum()
    probs = unnormalized_probs / Z
    return values, -np.log2(probs)

# Code length (bits) of each integer 0..15 under the prior above.
_, int_entropy = compute_entropy(4)

def get_entropy_of_integer(x):
    """Return the code length (bits) of integer ``x``.

    The sign is ignored, and magnitudes past the end of ``int_entropy`` are
    clamped to its last entry (15 for the table built above).
    """
    x = min(len(int_entropy) - 1, abs(x))
    return int_entropy[x]

def _tuple_entropy(t):
    """Total code length (bits) of all integer components of ``t``."""
    return sum(get_entropy_of_integer(x) for x in t)

def compute_min_delta_entropy(comparisons):
    """Return the cheapest code length (bits) for the latest comparison.

    The last ``(x, y)`` pair of ``comparisons`` is encoded as a 3-tuple of
    integers. The result is the minimal total code length over these
    candidate encodings:

    * simple: ``(x, y', 0)`` where ``y'`` is whichever of ``y`` and ``y - x``
      has the smaller magnitude (and hence the smaller code length);
    * first delta: difference from the previous pair;
    * second delta: difference of consecutive differences (needs >= 3 pairs);
    * arbitrary position: difference from any earlier pair ``j`` plus the
      offset ``min(j, i - j)``.

    With a single comparison only the simple encoding applies and its cost
    is tripled (NOTE(review): presumably a penalty for the context-free first
    comparison -- confirm intent).

    Args:
        comparisons: non-empty sequence of ``(index1, index2)`` integer pairs;
            an empty sequence raises ``IndexError``.

    Returns:
        The code length in bits.
    """
    i = len(comparisons) - 1
    xi, yi = comparisons[i]

    # key=abs: for non-negative indices plain min(yi, yi - xi) is always
    # yi - xi, which defeats the purpose of choosing the cheaper form.
    simple = (xi, min(yi, yi - xi, key=abs), 0)
    if i == 0:
        return 3 * _tuple_entropy(simple)

    options = [simple]

    xi_prev, yi_prev = comparisons[i - 1]
    options.append((xi - xi_prev, yi - yi_prev, 0))

    if i > 1:
        xi_prev2, yi_prev2 = comparisons[i - 2]
        options.append((
            (xi - xi_prev) - (xi_prev - xi_prev2),
            (yi - yi_prev) - (yi_prev - yi_prev2),
            0,
        ))

    # Relative to any earlier comparison (valid whenever i >= 1).
    for j in range(i):
        xj, yj = comparisons[j]
        options.append((xi - xj, yi - yj, min(j, i - j)))

    return min(_tuple_entropy(t) for t in options)

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Define the environment
class SortingEnv:
    """Environment in which an agent sorts a random integer list with tokens.

    Actions are lists of token ids from ``vocab``:

    * ``[Comparison, <pos1>, <pos2>]`` compares two positions and remembers
      them; the response token is ``less`` / ``equal`` / ``more``.
    * ``[Swap]`` swaps the two positions remembered by the last comparison.

    A malformed or unknown action ends the episode with
    ``INVALID_ACTION_REWARD``. The episode also ends when a swap leaves the
    list sorted, or once ``MAX_STEPS`` steps have been counted.
    """

    def __init__(self):
        self.max_steps = MAX_STEPS

    def reset(self):
        """Start a new episode with a fresh random, unsorted list.

        Returns:
            ``(initial_token_id, list_copy)``; the token is ``len<length>``.
        """
        self.length = random.randint(MIN_LIST_LEN, MAX_LIST_LEN)
        self.list = [random.randint(1, 100) for _ in range(self.length)]
        # Resample until unsorted. A list shorter than 2 is always sorted,
        # so resampling it would never terminate.
        while self.length > 1 and self.list == sorted(self.list):
            self.list = [random.randint(1, 100) for _ in range(self.length)]
        self.indices = None
        self.current_step = 0
        self.done = False
        initial_token = 'len{}'.format(self.length)
        return vocab[initial_token], self.list.copy()

    def get_list(self):
        """Return the current list (the live object, not a copy)."""
        return self.list

    def get_list_len(self):
        """Return the length of the current list."""
        return len(self.list)

    def _invalid(self):
        """End the episode for a malformed action (the step counter is not advanced)."""
        self.done = True
        return None, INVALID_ACTION_REWARD, self.done, self.list.copy()

    def _apply_swap(self, index1, index2):
        """Swap two positions in place and return the reward for the swap."""
        before1, before2 = self.list[index1], self.list[index2]
        self.list[index1], self.list[index2] = before2, before1
        if self.list == sorted(self.list):
            self.done = True
            return SUCCESS_REWARD
        # Pre-swap values in positional (left, right) order.
        if index1 < index2:
            left, right = before1, before2
        else:
            left, right = before2, before1
        if left > right:
            return SWAP_REWARD    # the swap removed an inversion
        if left < right:
            return -SWAP_REWARD   # the swap created an inversion
        return STEP_REWARD        # equal values (or same position): no change

    def step(self, action_tokens):
        """Apply one action.

        Returns:
            ``(response_token, reward, done, list_copy)``. ``response_token`` is
            ``None`` except after a valid comparison.
        """
        action = action_tokens[0]
        response_token = None

        if action == vocab['Comparison']:
            if len(action_tokens) != 3:
                return self._invalid()
            index1 = action_tokens[1] - vocab['0']
            index2 = action_tokens[2] - vocab['0']
            if not (0 <= index1 < self.length and 0 <= index2 < self.length):
                return self._invalid()
            self.indices = (index1, index2)
            first, second = self.list[index1], self.list[index2]
            if first < second:
                response_token = vocab['less']
                reward = STEP_REWARD
            elif first == second:
                response_token = vocab['equal']
                reward = STEP_REWARD * 2
            else:
                response_token = vocab['more']
                reward = STEP_REWARD
        elif action == vocab['Swap']:
            if self.indices is None:
                return self._invalid()
            reward = self._apply_swap(*self.indices)
            self.indices = None
        else:
            # Unknown action token: ends the episode, but (unlike the
            # malformed cases above) still counts as a step.
            reward = INVALID_ACTION_REWARD
            self.done = True

        self.current_step += 1
        if self.current_step >= self.max_steps:
            self.done = True
        return response_token, reward, self.done, self.list.copy()

def decode(input_tokens, inv_vocab):
    """Render token ids as their vocabulary names joined by single spaces."""
    names = map(inv_vocab.__getitem__, input_tokens)
    return ' '.join(names)

# Quicksort Algorithm using the environment
def run_quicksort_episode(verbose=False):
    """Sort one random ``SortingEnv`` list with a scripted Lomuto quicksort.

    Every comparison and swap goes through ``SortingEnv.step``. A swap is
    always preceded by a comparison of the same two positions, because the
    environment only swaps the positions of the latest comparison.

    Args:
        verbose: print every comparison and swap.

    Returns:
        ``(success, total_steps, total_reward, final_list)``:

        * ``success`` is True if the environment's list ends up sorted;
        * ``total_steps`` counts ``env.step`` calls;
        * ``final_list`` is the environment's list object.
    """
    env = SortingEnv()
    env.reset()
    total_steps = 0
    total_reward = 0.0

    def act(action_tokens):
        """Send one action to ``env``, accumulate reward and step count.

        Returns ``(response_token, done)``.
        """
        nonlocal total_steps, total_reward
        response_token, reward, done, _ = env.step(action_tokens)
        total_reward += reward
        total_steps += 1
        return response_token, done

    def compare(a, b):
        """Compare positions ``a`` and ``b``; returns ``(response_token, done)``."""
        return act([vocab['Comparison'], vocab[str(a)], vocab[str(b)]])

    def swap(a, b, purpose):
        """Swap positions ``a`` and ``b`` (compare, then Swap); True if the episode ended."""
        response_token, done = compare(a, b)
        if verbose:
            print(f"Comparison between indices {a} and {b} {purpose}, result: {inv_vocab[response_token]}")
        if done:
            return True
        _, done = act([vocab['Swap']])
        if verbose:
            print(f"Swapped indices {a} and {b}")
        return done

    def partition(low, high):
        """Partition ``[low, high]`` around ``high``; returns the pivot's slot, or None if the episode ended."""
        if env.done:
            return None
        pivot_index = high
        i = low - 1
        for j in range(low, high):
            response_token, done = compare(j, pivot_index)
            if verbose:
                print(f"Comparison between indices {j} and {pivot_index}, result: {inv_vocab[response_token]}")
            if done:
                return None
            # 'less' or 'equal' means list[j] <= pivot: grow the left partition.
            if response_token in (vocab['less'], vocab['equal']):
                i += 1
                if i != j and swap(i, j, "for swap setup"):
                    return None
        # Move the pivot between the two partitions.
        if i + 1 != pivot_index and swap(i + 1, pivot_index, "for final swap"):
            return None
        return i + 1

    def quicksort(low, high):
        """Recursively sort ``[low, high]``, stopping as soon as the episode ends."""
        if env.done or low >= high:
            return
        pi = partition(low, high)
        if pi is None:
            return
        quicksort(low, pi - 1)
        quicksort(pi + 1, high)

    quicksort(0, env.length - 1)

    success = env.list == sorted(env.list)
    return success, total_steps, total_reward, env.list

if __name__ == "__main__":
    # Evaluate the scripted quicksort baseline. A dedicated episode count is
    # used so the training hyperparameter NUM_EPISODES is not rebound: in a
    # notebook this guard is true, and rebinding would leak into later cells.
    NUM_EVAL_EPISODES = 100
    total_successes = 0
    total_steps = 0
    total_reward = 0.0

    for episode in range(NUM_EVAL_EPISODES):
        success, steps, reward, sorted_list = run_quicksort_episode(verbose=False)
        total_steps += steps
        total_reward += reward
        if success:
            total_successes += 1
        status = "Success" if success else "Fail"
        print(f"Episode {episode}: {status} in {steps} steps, reward: {reward:.2f}")

    print(f"Success rate: {total_successes / NUM_EVAL_EPISODES * 100:.2f}%")
    print(f"Average steps per episode: {total_steps / NUM_EVAL_EPISODES:.2f}")
    print(f"Average reward per episode: {total_reward / NUM_EVAL_EPISODES:.2f}")


Using device: cuda
Episode 0: Success in 81 steps, reward: -5.30
Episode 1: Success in 86 steps, reward: -0.90
Episode 2: Success in 82 steps, reward: -3.60
Episode 3: Success in 86 steps, reward: 1.00
Episode 4: Success in 100 steps, reward: -7.10
Episode 5: Success in 79 steps, reward: -1.10
Episode 6: Success in 100 steps, reward: 4.00
Episode 7: Success in 79 steps, reward: -3.40
Episode 8: Success in 76 steps, reward: -3.10
Episode 9: Success in 77 steps, reward: -2.40
Episode 10: Success in 101 steps, reward: -2.20
Episode 11: Success in 91 steps, reward: -6.30
Episode 12: Success in 100 steps, reward: 0.40
Episode 13: Success in 79 steps, reward: -5.00
Episode 14: Success in 83 steps, reward: 0.30
Episode 15: Success in 76 steps, reward: -3.80
Episode 16: Success in 80 steps, reward: -3.00
Episode 17: Success in 93 steps, reward: 2.20
Episode 18: Success in 99 steps, reward: -0.60
Episode 19: Success in 88 steps, reward: 1.10
Episode 20: Success in 97 steps, reward: -1.90
Episod