In [1]:
import numpy as np
import random
import itertools as it
from sympy.combinatorics import Permutation
import gymnasium as gym
from gymnasium import spaces



# Sticker colour letter -> integer index used in the numeric state encoding.
COLOR_MAP = {letter: index for index, letter in enumerate("WGRBOY")}

# One four-character string per face, each face showing a single colour.
FACES = [letter * 4 for letter in COLOR_MAP]

# Every ordering of the six single-colour faces, joined into a 24-character
# string (6! = 720 entries).
SOLVED_STATE_COLOR = list(map("".join, it.permutations(FACES)))

# The same 720 layouts as a (720, 24) float64 array of colour indices; row k
# is the numeric encoding of SOLVED_STATE_COLOR[k].
SOLVED_STATE_INDEX = np.array(
    [[COLOR_MAP[letter] for letter in layout] for layout in SOLVED_STATE_COLOR],
    dtype=np.float64,
)

class Cube2x2Env(gym.Env):
    """Gymnasium environment for a 2x2 cube encoded as 24 sticker colours.

    Observation: length-24 ``uint8`` vector of colour indices in ``0..5``.
    Action: index ``0..5`` selecting one of six fixed sticker permutations.
    Reward: ``100`` when the state equals any row of ``SOLVED_STATE_INDEX``,
    otherwise ``-1`` for the move.
    """

    def __init__(self, steps_from_solved=1, max_moves=1000):
        """Build the move table and spaces, then start from a scrambled cube.

        Args:
            steps_from_solved: Number of random moves applied to a solved
                cube when scrambling. Must be at least 1.
            max_moves: The episode is reported as truncated once
                ``move_count`` exceeds this value.
        """
        self.move_count = 0
        # Each action permutes the 24 sticker positions. A leading
        # ``Permutation(23)`` pads the permutation to size 24 for moves whose
        # cycles never mention index 23.
        self._action_to_move = {
            0: Permutation(23)(2, 19, 21, 8)(3, 17, 20, 10)(4, 6, 7, 5),
            1: Permutation(0, 18, 23, 9)(1, 16, 22, 11)(12, 13, 15, 14),
            2: Permutation(1, 5, 21, 14)(3, 7, 23, 12)(8, 10, 11, 9),
            3: Permutation(23)(0, 4, 20, 15)(2, 6, 22, 13)(16, 17, 19, 18),
            4: Permutation(6, 18, 14, 10)(7, 19, 15, 11)(20, 22, 23, 21),
            5: Permutation(23)(0, 1, 3, 2)(4, 16, 12, 8)(5, 17, 13, 9)
        }

        self.steps_from_solved = steps_from_solved
        self.max_moves = max_moves

        self.action_space = spaces.Discrete(6)
        self.observation_space = spaces.Box(0, 5, shape=(24,), dtype=np.uint8)

        self.state = self.scramble()

    @staticmethod
    def _is_solved(state):
        """Return True if ``state`` equals any row of ``SOLVED_STATE_INDEX``."""
        return bool(np.any(np.all(np.asarray(state) == SOLVED_STATE_INDEX, axis=1)))

    def step(self, action):
        """Apply one move and report the outcome.

        Args:
            action: Key into the move table (``0..5``).

        Returns:
            ``(observation, reward, terminated, truncated, info)`` where the
            observation is a fresh ``uint8`` array of the 24 sticker colours.

        Raises:
            KeyError: If ``action`` is not a key of the move table.
        """
        move = self._action_to_move[action]
        # Permutation.__call__ returns a plain list; convert back so the
        # stored state is always a uint8 array.
        self.state = np.array(move(self.state), dtype=np.uint8)
        self.move_count += 1

        terminated = self._is_solved(self.state)
        reward = 100 if terminated else -1
        truncated = self.move_count > self.max_moves

        # Hand out a copy so callers cannot mutate the environment's state.
        return self.state.copy(), reward, terminated, truncated, {}

    def scramble(self):
        """Return a new unsolved state reached by random moves from solved.

        Applies ``steps_from_solved`` random moves to ``SOLVED_STATE_INDEX[0]``
        and retries from scratch whenever the result is still a solved state.

        Returns:
            A ``uint8`` array of 24 colour indices.

        Raises:
            ValueError: If ``steps_from_solved`` is less than 1 (no unsolved
                state could ever be produced).
        """
        if self.steps_from_solved < 1:
            raise ValueError("steps_from_solved must be at least 1")

        n_actions = len(self._action_to_move)
        while True:
            state = SOLVED_STATE_INDEX[0]
            for _ in range(self.steps_from_solved):
                # Draw from the environment's own RNG so reset(seed=...) is honoured.
                action = int(self.np_random.integers(n_actions))
                state = self._action_to_move[action](state)
            if not self._is_solved(state):
                return np.array(state, dtype=np.uint8)

    def reset(self, seed=None, options=None):
        """Start a new episode from a freshly scrambled cube.

        Args:
            seed: Optional seed forwarded to ``gym.Env.reset`` to reseed the
                environment's random generator.
            options: Accepted for Gymnasium API compatibility; unused.

        Returns:
            ``(observation, info)`` with an empty ``info`` dict.
        """
        super().reset(seed=seed)
        self.state = self.scramble()
        self.move_count = 0
        return self.state.copy(), {}

    def __repr__(self):
        """Render the 24 stickers as an unfolded-cube ASCII diagram."""
        template = '''
         +--------+                    
         | {0}    {1} |                    
         |   d1   |                    
         | {2}    {3} |                    
+--------+--------+--------+--------+  
| {16}    {17} | {4}    {5} | {8}    {9} | {12}    {13} |  
|   r1   |   f0   |   r0   |   f1   |  
| {18}    {19} | {6}    {7} | {10}    {11} | {14}    {15} |  
+--------+--------+--------+--------+  
         | {20}    {21} |                    
         |   d0   |                    
         | {22}    {23} |                    
         +--------+                    

      '''
        return template.format(*self.state)
        


In [2]:
# Instantiate the cube environment; the constructor scrambles the cube, so
# env.state already holds an initial state.
env = Cube2x2Env()

# Create the neural network

In [18]:
# PyTorch supplies the tensor operations used by the network cells.
import torch


In [27]:
import torch

# Layer widths: one input per sticker and one output per action.
N_STICKERS = 24
HIDDEN_UNITS = 72
N_ACTIONS = 6

# Seed torch's global RNG so the randomly initialised weights, and therefore
# the displayed result, are the same on every Restart & Run All.
torch.manual_seed(0)

# Current cube state as a float vector with one entry per sticker.
input_layer = torch.tensor(env.state, dtype=torch.float)

# Randomly initialised weights and biases for a 24 -> 72 -> 72 -> 6 network.
W1 = torch.randn(N_STICKERS, HIDDEN_UNITS)
B1 = torch.randn(1, HIDDEN_UNITS)
W2 = torch.randn(HIDDEN_UNITS, HIDDEN_UNITS)
B2 = torch.randn(1, HIDDEN_UNITS)
W3 = torch.randn(HIDDEN_UNITS, N_ACTIONS)
B3 = torch.randn(1, N_ACTIONS)

# Adding the (1, 72) bias to the (72,) product broadcasts to (1, 72), so
# every later activation carries a leading dimension of size 1.
layer1 = torch.tanh(input_layer @ W1 + B1)
layer2 = torch.tanh(layer1 @ W2 + B2)
# NOTE: despite its name, `logits` holds softmax-normalised probabilities
# over the six actions, not raw pre-softmax scores.
logits = torch.softmax(layer2 @ W3 + B3, dim=1)

In [39]:
# Display the network output: a (1, 6) tensor produced by softmax, so the six
# entries are non-negative and sum to 1.
logits

tensor([[3.8544e-05, 6.9670e-01, 1.4711e-02, 1.2641e-02, 1.8972e-04, 2.7572e-01]])