In [None]:
import itertools
import math
import os
import random

import gym
import numpy as np
from gym import Env
from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Tuple
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import VecFrameStack

In [None]:
# Configurations:
arraySize = 8
# A stage can hold at most arraySize // 2 disjoint comparators; derive it so the two never drift apart.
halfArraySize = arraySize // 2
sortStages = 6

# Reference observation/action spaces (SortingNetworkEnv defines its own int32 spaces for training).
observation_space = Box(low=0, high=arraySize+1, shape=(sortStages * 2, halfArraySize), dtype = np.ushort)
action_space = MultiDiscrete([2,sortStages,arraySize,arraySize],dtype=np.ushort)

# Wire value that marks an empty comparator slot (wire indices themselves are 0 .. arraySize - 1).
notActiveComparator = arraySize

# Helpers for generating the zero-one arrays: iterating itertools.product(*iterables) yields every
# binary input of length arraySize, and sortedVariations[k] is the sorted form of the k-th input.
zeroOne = [0, 1]
iterables = [zeroOne] * arraySize
sortedVariations = [sorted(bits) for bits in itertools.product(*iterables)]

# Rewards: removing comparators from a full network down to minimumComparatorsPossible is worth
# allVariationsCount in total; adding a comparator costs the same amount as removing one earns.
allVariationsCount = len(sortedVariations)
# NOTE(review): 19 matches the known optimal comparator count for 8 inputs; it is not right for
# other arraySize values — TODO derive it per size if arraySize changes.
minimumComparatorsPossible = min(19, arraySize*math.log(arraySize, 2))
availableComparatorSlots = halfArraySize * sortStages
assert availableComparatorSlots > minimumComparatorsPossible, "network has no removable comparator slots"
comparatorRemovalReward = allVariationsCount / (availableComparatorSlots - minimumComparatorsPossible)
comparatorAdditionPenalty = -comparatorRemovalReward
maximumScore = allVariationsCount * 2

print(minimumComparatorsPossible)
print(len(sortedVariations))
print(comparatorRemovalReward)
print(comparatorAdditionPenalty)

In [None]:
def getunsortedcount(sortingNetwork):
    """Count the 0-1 input arrays that `sortingNetwork` fails to sort.

    Every binary input of length arraySize is pushed through the network and
    compared against its sorted form in sortedVariations.

    Args:
        sortingNetwork: 2-D array indexed as [column][row]. Stage s uses columns
            2*s and 2*s + 1; row r of those columns holds the two wire indices of
            one comparator, or notActiveComparator for an empty slot. The smaller
            value is written to the first wire and the larger to the second.

    Returns:
        Number of inputs whose output differs from the sorted array.
    """
    failures = 0
    # product() yields inputs in the same order sortedVariations was built in,
    # so the enumerate index lines each input up with its expected result.
    for variationIndex, inputBits in enumerate(itertools.product(*iterables)):
        wires = list(inputBits)
        for stage in range(sortStages):
            lowColumn = sortingNetwork[2 * stage]
            highColumn = sortingNetwork[2 * stage + 1]
            for slot in range(halfArraySize):
                lowWire = lowColumn[slot]
                highWire = highColumn[slot]
                # An empty slot on either end means there is no comparator here.
                if lowWire == notActiveComparator or highWire == notActiveComparator:
                    continue
                x, y = wires[lowWire], wires[highWire]
                wires[lowWire] = min(x, y)
                wires[highWire] = max(x, y)
        if wires != sortedVariations[variationIndex]:
            failures += 1
    return failures
def turnSortingNetworkToBitonic(sortingNetwork, size=None, stages=None, emptySlot=None):
    """Fill `sortingNetwork` in place with a bitonic sorting network and return it.

    The comparators of a bitonic sort over `size` wires are generated recursively,
    then layered: each comparator is placed in the stage right after the latest
    stage that already uses either of its wires, so no comparator is ever scheduled
    ahead of one it depends on. Stage s is written to columns 2*s and 2*s + 1;
    row r of those columns holds one comparator's two wire indices. Rows a stage
    does not use are set to `emptySlot`.

    Args:
        sortingNetwork: Mutable 2-D array indexed as [column][row] with at least
            2 * stages columns and size // 2 rows.
        size: Number of wires; the bitonic construction only works when this is
            a power of 2. Defaults to the global arraySize.
        stages: Number of stages available. Defaults to the global sortStages.
        emptySlot: Value written into unused rows. Defaults to the global
            notActiveComparator.

    Returns:
        The same `sortingNetwork` object, now holding the bitonic network.

    Raises:
        ValueError: if the bitonic network needs more than `stages` stages.

    Side effects:
        Sets the module globals `comparators` (generated [first, second] pairs in
        generation order) and `comparatorsCnt`, so they stay inspectable in the notebook.
    """
    global comparatorsCnt
    global comparators
    size = arraySize if size is None else size
    stages = sortStages if stages is None else stages
    emptySlot = notActiveComparator if emptySlot is None else emptySlot
    comparatorsCnt = 0
    comparators = []

    def compAndSwap(a, i, j, dire):
        # Records the comparator between wires i and j (no values are swapped here).
        # dire == 1 stores [i, j]; dire == 0 stores [j, i], i.e. the reversed direction.
        global comparatorsCnt
        global comparators
        if comparatorsCnt == 8:
            comparatorsCnt = 0
        comparatorsCnt += 1
        if dire == 0:
            comparators.append([j, i])
        if dire == 1:
            comparators.append([i, j])

    def bitonicMerge(a, low, cnt, dire):
        # Emits the comparators that merge the bitonic run a[low:low+cnt] in direction dire.
        if cnt > 1:
            k = cnt // 2
            for i in range(low, low + k):
                compAndSwap(a, i, i + k, dire)
            bitonicMerge(a, low, k, dire)
            bitonicMerge(a, low + k, k, dire)

    def bitonicSort(a, low, cnt, dire):
        # Sorts the halves in opposite directions (forming a bitonic run), then merges them.
        if cnt > 1:
            k = cnt // 2
            bitonicSort(a, low, k, 1)
            bitonicSort(a, low + k, k, 0)
            bitonicMerge(a, low, cnt, dire)

    bitonicSort(list(range(size)), 0, size, 1)

    # Earliest-stage layering: a comparator goes one stage after the latest stage
    # that touches either of its wires.
    comparatorMap = [[] for _ in range(stages)]
    lastStageOnWire = [-1] * size
    for comparator in comparators:
        first, second = comparator
        stage = max(lastStageOnWire[first], lastStageOnWire[second]) + 1
        if stage >= stages:
            raise ValueError(
                "bitonic network for {} wires needs more than {} stages".format(size, stages))
        comparatorMap[stage].append(comparator)
        lastStageOnWire[first] = stage
        lastStageOnWire[second] = stage

    # Write the stages into the network, padding unused rows with emptySlot.
    for stage in range(stages):
        stageComparators = comparatorMap[stage]
        for row in range(size // 2):
            if row < len(stageComparators):
                first, second = stageComparators[row]
            else:
                first = second = emptySlot
            sortingNetwork[2 * stage][row] = first
            sortingNetwork[2 * stage + 1][row] = second
    return sortingNetwork

In [None]:
class SortingNetworkEnv(Env):
    """Gym environment in which an agent edits a comparator network one comparator at a time.

    Observation: int32 array of shape (sortStages * 2, halfArraySize). Stage s uses
    columns 2*s and 2*s + 1; row r holds the two wire indices of one comparator, or
    notActiveComparator for an empty slot.

    Action: [addOrRemove, stage, firstWire, secondWire]. addOrRemove == 0 removes the
    comparator (firstWire, secondWire) from the stage if it is present. Otherwise the
    comparator is added to the last empty row of the stage, provided neither wire is
    already used in that stage and the two wires differ.

    Reward: comparatorRemovalReward / comparatorAdditionPenalty for an applied edit,
    plus the decrease in the number of 0-1 inputs the network leaves unsorted.
    """

    def __init__(self, maxEpisodeSteps=100):
        """Build the spaces and start from a bitonic network.

        Args:
            maxEpisodeSteps: Steps after which an episode reports done (None disables
                the cap). TODO: tune; the score threshold alone may never be reached.
        """
        self.action_space = MultiDiscrete([2,sortStages,arraySize,arraySize],dtype=np.int32)
        self.observation_space = Box(low=0, high=arraySize, shape=(sortStages * 2, halfArraySize), dtype = np.int32)
        self.maxEpisodeSteps = maxEpisodeSteps
        # turnSortingNetworkToBitonic fills the array in place, so keep our own reference
        # to it instead of relying on the function's return value.
        initialNetwork = self.observation_space.sample()
        turnSortingNetworkToBitonic(initialNetwork)
        self.state = initialNetwork
        self._resetScores()

    def _resetScores(self):
        """Recompute the scores for the current self.state and clear the episode counters."""
        self.unsortedArraysCount = getunsortedcount(self.state)
        # Fraction of 0-1 inputs the network leaves unsorted (lower is better).
        self.validityScore = self.unsortedArraysCount / allVariationsCount
        self.efficiencyScore = 0
        # Cumulative reward of the current episode.
        self.score = 0
        self.stepsTaken = 0

    def step(self, action):
        """Apply one add/remove action and return (observation, reward, done, info)."""
        addOrRemove, columnToChange, firstIndex, secondIndex = (int(value) for value in action)
        xColumn = columnToChange * 2
        yColumn = columnToChange * 2 + 1
        addOrRemoveReward = 0

        if addOrRemove == 0:
            # Remove the comparator, only on an exact (firstIndex, secondIndex) match.
            for i in range(0, halfArraySize):
                if self.state[xColumn][i] == firstIndex and self.state[yColumn][i] == secondIndex:
                    self.state[xColumn][i] = notActiveComparator
                    self.state[yColumn][i] = notActiveComparator
                    addOrRemoveReward += comparatorRemovalReward
                    self.efficiencyScore += 1
                    break
        else:
            # Add the comparator into the last empty row, unless a wire is already in use in this stage.
            rowToAddTheComparator = -1
            additionCanBeApplied = True
            for i in range(0, halfArraySize):
                if (self.state[xColumn][i] == firstIndex or self.state[yColumn][i] == secondIndex or
                        self.state[xColumn][i] == secondIndex or self.state[yColumn][i] == firstIndex):
                    additionCanBeApplied = False
                    break
                if self.state[xColumn][i] == notActiveComparator and self.state[yColumn][i] == notActiveComparator:
                    rowToAddTheComparator = i

            if additionCanBeApplied and rowToAddTheComparator > -1 and firstIndex != secondIndex:
                self.state[xColumn][rowToAddTheComparator] = firstIndex
                self.state[yColumn][rowToAddTheComparator] = secondIndex
                addOrRemoveReward += comparatorAdditionPenalty
                self.efficiencyScore -= 1

        self.stepsTaken += 1

        # Network validity: fewer unsorted inputs than before is an improvement and must raise the reward.
        previousUnsortedCount = self.unsortedArraysCount
        self.unsortedArraysCount = getunsortedcount(self.state)
        self.validityScore = self.unsortedArraysCount / allVariationsCount
        reward = addOrRemoveReward + (previousUnsortedCount - self.unsortedArraysCount)

        # Accumulate the episode reward so render() and the done check see progress.
        self.score += reward
        # TODO: refine the done policy; the step cap guarantees that episodes terminate.
        reachedStepCap = self.maxEpisodeSteps is not None and self.stepsTaken >= self.maxEpisodeSteps
        done = bool(self.score >= maximumScore or reachedStepCap)

        info = {}
        # Return a copy so callers cannot mutate the environment's internal network.
        return self.state.copy(), reward, done, info

    def render(self, mode="human"):
        """Print the current scores."""
        print("Validity Score: ", self.validityScore)
        print("Efficiency Score: ", self.efficiencyScore)
        print("Total score: ", self.score)

    def reset(self):
        """Start a new episode from an empty network and return the initial observation."""
        self.state = np.full((sortStages*2, halfArraySize), notActiveComparator, dtype = np.int32)
        self._resetScores()
        return self.state.copy()

In [None]:
# Directory (relative to the working directory) where PPO writes its TensorBoard logs.
log_path = os.path.join('Training', 'Logs')
# Environment instance the PPO model below is trained on.
env = SortingNetworkEnv()

In [None]:

# PPO agent with a multilayer-perceptron policy on the sorting-network environment.
model = PPO(
    policy="MlpPolicy",
    env=env,
    verbose=1,
    tensorboard_log=log_path,
)


In [None]:
# Training budget in environment steps.
# NOTE(review): PPO gathers whole rollouts of n_steps, so it may run more steps than this — confirm.
trainingTimesteps = 400
model.learn(total_timesteps=trainingTimesteps)

In [None]:
# Inspect the network the trained environment ended with, and its scores.
print(env.state)
currentSummary = (
    ("Unsorted count of the current env: ", getunsortedcount(env.state)),
    ("env validityScore: ", env.validityScore),
    ("env efficiencyScore: ", env.efficiencyScore),
    ("env score: ", env.score),
)
for label, value in currentSummary:
    print(label, value)

# Baseline for comparison: the bitonic network for the same configuration.
bitonic = np.full((sortStages*2, halfArraySize), arraySize, dtype = np.int32)
turnSortingNetworkToBitonic(bitonic)
print(bitonic)
print("Unsorted count of the bitonic env: ", getunsortedcount(bitonic))
      

In [None]:
# Validate the environment against the Gym API on a fresh instance.
env=SortingNetworkEnv()
check_env(env, warn=True)

In [None]:
# Roll out a few episodes with a uniformly random policy as a sanity check of the environment.
# The environment is not guaranteed to report done=True by itself, so each episode is capped.
episodes = 5
maxStepsPerEpisode = 100
for episode in range(1, episodes+1):
    state = env.reset()
    done = False
    score = 0
    steps = 0

    while not done and steps < maxStepsPerEpisode:
        env.render()
        action = env.action_space.sample()
        n_state, reward, done, info = env.step(action)
        score += reward
        steps += 1
    print('Episode:{} Score:{}'.format(episode, score))
env.close()

In [None]:
# Close the environment. NOTE(review): the random-episode cell above already calls env.close(),
# so this second call looks redundant.
env.close()