In [1]:
import sys
from typing import List

import numpy as np
import torch
import torch.nn.functional as F

root_path = "/teamspace/studios/this_studio/DeepCubeA"
if root_path not in sys.path:
    sys.path.append(root_path)

from utils.pytorch_models import ResnetModel
from utils.nnet_utils import load_nnet
from environments.cube3 import *

In [2]:
#python ctg_approx/avi.py --env cube3 --states_per_update 500 --batch_size 50 --nnet_name cube3 --max_itrs 1201000 --loss_thresh 0.06 --back_max 30 --num_update_procs 30

In [2]:
# !apt-get install graphviz
# !pip install torchviz

In [3]:
from torchviz import make_dot

In [4]:
# Build the network with the layer sizes that the printed model further down
# reports. The per-argument notes are read off that printout, not off the
# ResnetModel signature -- TODO confirm against utils/pytorch_models.py.
cube_len = 3
state_dim: int = 6 * cube_len * cube_len  # 6 faces x 9 stickers = 54
nnet = ResnetModel(
    state_dim,
    6,     # presumably the one-hot depth: fc1 takes 54 * 6 = 324 inputs
    5000,  # width of fc1 / bn1
    1000,  # width of fc2 / bn2 and of every residual-block layer
    4,     # number of residual blocks
    1,     # width of fc_out
    True,  # presumably switches the BatchNorm1d layers on
)

In [5]:
# Hand load_nnet the checkpoint file, the network built above and a CPU device;
# it presumably loads the saved state dict into `nnet` and returns the module.
# The checkpoint path is derived from root_path (set in the first cell) so the
# location of the DeepCubeA checkout is written down in exactly one place.
nnet = load_nnet(
    model_file=f"{root_path}/saved_models/cube3/current/model_state_dict.pt",
    nnet=nnet,
    device=torch.device("cpu"),
)

In [6]:
# Smoke test: push one random state (54 integer values drawn from 0..5) through
# the network and display the shape of what comes back.
nnet(torch.randint(0, 6, (1, 54))).size()

torch.Size([1, 1])

In [7]:
# Number of trainable parameters, expressed in millions.
sum(weight.numel() for weight in nnet.parameters() if weight.requires_grad) / 1e6

14.663001

In [8]:
# Display the module tree of the network (layer types and sizes).
nnet

ResnetModel(
  (blocks): ModuleList(
    (0-3): 4 x ModuleList(
      (0): Linear(in_features=1000, out_features=1000, bias=True)
      (1): BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): Linear(in_features=1000, out_features=1000, bias=True)
      (3): BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (fc1): Linear(in_features=324, out_features=5000, bias=True)
  (bn1): BatchNorm1d(5000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc2): Linear(in_features=5000, out_features=1000, bias=True)
  (bn2): BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc_out): Linear(in_features=1000, out_features=1, bias=True)
)

In [12]:
# Map each of the 54 sticker indices i to i // 9, i.e. six runs of nine.
(np.arange(54) // 9).astype(np.int8)

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4,
       4, 5, 5, 5, 5, 5, 5, 5, 5, 5], dtype=int8)

In [13]:
# 54 state entries x 6 possible values = 324, which matches the in_features of
# fc1 in the printed model.
54 * 6

324

In [14]:
# The 54 sticker indices 0..53 as an int8 array.
np.arange(54, dtype=np.int8)

array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
       34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
       51, 52, 53], dtype=int8)

In [108]:
# # with torch.no_grad():
# make_dot(
#     nnet(torch.randint(low=0, high=6, size=(1, 54))), 
#     params=dict(nnet.named_parameters()),
#     show_attrs=True, 
#     show_saved=True
# )

## Cube3 Env

In [2]:
# Create the Cube3 environment; the class is not imported by name, so it
# presumably arrives through `from environments.cube3 import *` in the first cell.
env = Cube3()

In [8]:
# Request 10 states with the bounds (0, 5). The second return value, shown
# below, is a list of 10 ints in 0..5 -- presumably the number of scramble
# moves applied to each returned state.
states_itr, scramble_nums = env.generate_states(10, (0, 5))

In [17]:
# Inspect the per-state values returned alongside the generated states.
scramble_nums

[3, 0, 3, 5, 3, 3, 1, 1, 1, 4]

In [13]:
# The 54-entry uint8 `colors` array of state 4 (its scramble_nums entry is 3
# in the run shown).
states_itr[4].colors

array([ 6,  3,  0, 30,  4,  1, 33,  5,  2,  9, 10, 17, 12, 13, 25, 38, 37,
       24, 18, 19, 47, 21, 22, 50, 51, 52, 53, 11, 14, 36, 34, 31, 41, 35,
       32, 44,  8,  7, 20, 39, 40, 23, 42, 43, 26, 27, 28, 29, 46, 49, 16,
       45, 48, 15], dtype=uint8)

In [14]:
# The 54-entry uint8 `colors` array of state 5 (its scramble_nums entry is 3
# in the run shown).
states_itr[5].colors

array([45, 46, 38,  5,  4,  3,  2,  1,  0, 53, 43, 42, 12, 13, 14, 15, 16,
       17, 29, 32, 35, 19, 22, 25, 18, 21, 24, 27, 28, 20, 30, 31, 23, 33,
       34, 26, 36, 37, 47, 39, 40, 50,  6,  7,  8,  9, 10, 11, 48, 49, 41,
       51, 52, 44], dtype=uint8)

In [9]:
# Collect the `colors` array of every generated state. In the run shown,
# entry 1 (scramble_nums value 0) is the untouched sequence 0..53.
[state.colors for state in states_itr]

[array([ 0,  1,  2,  3,  4,  5,  9, 12, 15, 11, 14, 17, 10, 13, 16,  6,  7,
         8, 45, 19, 20, 48, 22, 23, 51, 25, 26, 35, 34, 42, 32, 31, 39, 29,
        28, 36, 53, 52, 33, 21, 40, 41, 24, 43, 44, 27, 46, 47, 30, 49, 50,
        38, 37, 18], dtype=uint8),
 array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
        34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
        51, 52, 53], dtype=uint8),
 array([ 2,  5,  8,  1,  4,  7,  0,  3,  6,  9, 10, 11, 12, 13, 14, 15, 16,
        17, 18, 19, 38, 21, 22, 41, 24, 25, 44, 27, 28, 47, 30, 31, 50, 33,
        34, 53, 36, 37, 29, 39, 40, 32, 42, 43, 35, 45, 46, 20, 48, 49, 23,
        51, 52, 26], dtype=uint8),
 array([ 0, 28, 29, 48,  4, 43, 51,  5, 42, 20, 23, 36, 46, 13, 37, 45, 12,
        38,  2, 21, 53,  1, 22, 52, 15,  3, 26,  8, 30, 27, 39, 31, 41, 11,
        32,  9, 24,  7, 18, 25, 40, 19, 44, 10,  6, 33, 50,

In [18]:
# Number of moves the environment reports, then the candidate scramble depths
# 0..5 with both ends included.
num_env_moves: int = env.get_num_moves()
scrambs: List[int] = [depth for depth in range(6)]

In [20]:
# Inspect the move count returned by env.get_num_moves().
num_env_moves

12

In [21]:
# Request 10 goal states with np_format=True; presumably this yields a NumPy
# array rather than state objects -- TODO confirm the shape and dtype.
states_np: np.ndarray = env.generate_goal_states(10, np_format=True)

In [23]:
# Draw 10 values from scrambs (one per state, sampled with replacement) and
# create a matching length-10 array of zeros, presumably a per-state counter of
# moves applied so far.
# The annotations use np.ndarray: np.array is the factory function, not the
# array type, so it is not a meaningful annotation.
# Note: this rebinds scramble_nums, which previously held the list returned by
# env.generate_states.
scramble_nums: np.ndarray = np.random.choice(scrambs, 10)
num_back_moves: np.ndarray = np.zeros(10)

In [25]:
# Inspect the freshly zeroed array (float dtype, as np.zeros defaults to float64).
num_back_moves

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [26]:
# Element-wise mask: True wherever the entry of num_back_moves is smaller than
# the corresponding entry of scramble_nums.
moves_lt: np.ndarray = np.less(num_back_moves, scramble_nums)

In [29]:
# Positions (along the first axis) at which the mask is True.
idxs: np.ndarray = np.nonzero(moves_lt)[0]

In [30]:
# Inspect the selected positions; in the run shown all 10 are selected.
idxs

array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

In [32]:
# Selected indices per move, rounded down but never below 1; with 10 indices
# and 12 moves this evaluates to 1.
subset_size: int = max(len(idxs) // num_env_moves, 1)

In [34]:
# Re-check the divisor used for subset_size.
num_env_moves

12