In [1]:
import os
import sys

sys.path.append("../")

In [2]:
import time

from accelerate import Accelerator

import torch
import torch.nn.functional as F
import numpy as np

from utils import open_pickle

from cube3_game import Cube3Game
from models import Pilgrim
from g_datasets import get_torch_scrambles_3, reverse_actions
from g_datasets import scrambles_collate_fn
from g_datasets import Cube3Dataset3 
from utils import set_seed

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
set_seed(0)
game = Cube3Game("../assets/envs/qtm_cube3.pickle")
generators = torch.tensor(game.actions, dtype=torch.int64)

state_size = game.actions.shape[1]
hash_vec = torch.randint(0, 1_000_000_000_000, (state_size,))  

scrambles, actions, lengths = get_torch_scrambles_3(
    N=1,
    n=26,
    generators=generators,
    hash_vec=hash_vec,
    device="cpu"
)

In [4]:
model = Pilgrim(
    hidden_dim1 = 500, 
    hidden_dim2  = 300, 
    num_residual_blocks = 3,    
)
model.load_state_dict(torch.load(
    "../assets/models/Cube3ResnetModel_policy.pt"
))

<All keys matched successfully>

In [5]:
scrambles.shape

torch.Size([26, 54])

In [6]:
actions

tensor([11,  8, 11,  4,  8,  8,  4,  6,  7,  0,  7,  6,  3,  7,  8,  1,  1,  1,
         4,  7, 10,  8,  1,  2,  6,  4])

In [7]:
reversed_actions = reverse_actions(actions, n_gens=len(generators))

In [8]:
reversed_actions

tensor([ 5.,  2.,  5., 10.,  2.,  2., 10.,  0.,  1.,  6.,  1.,  0.,  9.,  1.,
         2.,  7.,  7.,  7., 10.,  1.,  4.,  2.,  7.,  8.,  0., 10.])

In [9]:
i = 0
s = scrambles[i, :]
for a in reversed_actions.tolist()[:i+1][::-1]:
    a = int(a)
    s = s[generators[a]]
s

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53])

In [10]:
reversed_actions.tolist()

[5.0,
 2.0,
 5.0,
 10.0,
 2.0,
 2.0,
 10.0,
 0.0,
 1.0,
 6.0,
 1.0,
 0.0,
 9.0,
 1.0,
 2.0,
 7.0,
 7.0,
 7.0,
 10.0,
 1.0,
 4.0,
 2.0,
 7.0,
 8.0,
 0.0,
 10.0]

In [11]:
actions

tensor([11,  8, 11,  4,  8,  8,  4,  6,  7,  0,  7,  6,  3,  7,  8,  1,  1,  1,
         4,  7, 10,  8,  1,  2,  6,  4])

In [12]:
model = model.eval()

In [19]:
i = 8
with torch.no_grad():
    value, policy = model(scrambles[i, :].unsqueeze(dim=0))
    policy = policy.softmax(dim=1)
    
print(policy)
print("argmax:", torch.argmax(policy))
print("ra:", reversed_actions[i])
print("a:", actions[i])

tensor([[0.1656, 0.0731, 0.0650, 0.0549, 0.1025, 0.0927, 0.0357, 0.0469, 0.1084,
         0.1387, 0.0619, 0.0547]])
argmax: tensor(0)
ra: tensor(1.)
a: tensor(7)


In [None]:
training_dataset = Cube3Dataset3(
    n = 32,
    N = 400,
    size = 1_000_000,
    generators = torch.tensor(game.actions, dtype=torch.int64, device="mps"),
    device="mps"
)
training_dataloader = torch.utils.data.DataLoader(
    training_dataset, 
    batch_size=32,
    shuffle=True, 
    num_workers=0,
    collate_fn=scrambles_collate_fn
)

In [None]:
for data in training_dataloader:
    states, actions, targets = data
    break

In [None]:
states.shape

In [None]:
torch.max(actions)

In [None]:
torch.min(actions)

In [None]:
r_actions = reverse_actions(actions, n_gens=12)

In [None]:
torch.max(r_actions)

In [None]:
torch.min(actions)

In [None]:
len(generators)

In [None]:
targets.shape