In [None]:
from environment import Environment
from CFOP_agents import DQN
import yaml
from itertools import count
import torch


# Pick the compute device in order of preference: CUDA, then Apple MPS,
# then plain CPU when no accelerator is available.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [2]:
# Create the cube environment (CFOP method, size 3) on the selected device
# and scramble it once up front.
env = Environment(
    method="CFOP",
    size=3,
    device=device,
)
env.scramble()

# Load the hyperparameters. The context manager guarantees the file handle is
# closed as soon as parsing finishes (the bare `open(...)` call leaked it).
with open("config.yaml", "r") as config_file:
    args = yaml.safe_load(config_file)

# The agent is configured from the "DQN" section of the config file.
agent: DQN = DQN(args["DQN"])

In [3]:
import matplotlib
import matplotlib.pyplot as plt

# Set up matplotlib.
# An "inline" backend name means the figures are rendered inside the notebook,
# in which case IPython's display helpers are needed to refresh the plot.
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

# Turn on interactive mode. The trailing semicolon stops the notebook from
# echoing the call's return value (a contextlib.ExitStack repr) as cell output.
plt.ion();

<contextlib.ExitStack at 0x7d2ed62e4c90>

In [4]:
# Number of steps each finished episode lasted; appended to by the training loop.
episode_durations = []


def plot_durations(show_result=False):
    """Plot the recorded episode durations on matplotlib figure 1.

    The raw durations are always drawn. Once at least 100 episodes have been
    recorded, a 100-episode sliding mean is overlaid; its first 99 positions
    are padded with zeros so both curves share the same x axis.

    Args:
        show_result: When falsy, the figure is cleared first and titled
            'Training...'. When truthy, the figure is not cleared and is
            titled 'Result'.
    """
    plt.figure(1)
    durations = torch.tensor(episode_durations, dtype=torch.float)

    # Only intermediate (training-time) plots wipe the previous drawing.
    if not show_result:
        plt.clf()
    plt.title('Result' if show_result else 'Training...')

    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations.numpy())

    # Overlay the 100-episode moving average once enough data exists.
    if len(durations) >= 100:
        window_means = durations.unfold(0, 100, 1).mean(1).view(-1)
        padded_means = torch.cat((torch.zeros(99), window_means))
        plt.plot(padded_means.numpy())

    plt.pause(0.001)  # short pause so the plot gets updated

    if not is_ipython:
        return

    # In a notebook, push the current figure to the cell output. During
    # training, also schedule the output to be cleared when the next one arrives.
    display.display(plt.gcf())
    if not show_result:
        display.clear_output(wait=True)

In [None]:
# Main training loop: one outer iteration per episode, one inner iteration
# per environment step.
for _ in range(agent.num_episodes):
    state = env.reset()

    # Running (shaped) reward, seeded from the solver status of the fresh cube.
    current_reward = env.algorithm.status(env.cube)

    for t in count():
        action = agent.action(state)
        obs, reward, done = env.step(action.item())

        # Reward shaping:
        #   * running reward is 0            -> add the step reward to it
        #   * step reward is 0, or it equals
        #     the running reward             -> decrement the running reward by 1
        #   * otherwise                      -> replace it with the step reward
        if current_reward == 0:
            current_reward += reward
        elif reward == 0 or current_reward == reward:
            current_reward -= 1
        else:
            current_reward = reward

        print(current_reward)

        reward_tensor = torch.tensor([current_reward], device=device)

        # A finished episode has no successor state.
        next_state = None if done else obs

        # Record the transition, then advance to the next state.
        agent.memory.push(state, action, next_state, reward_tensor)
        state = next_state

        # Presumably one optimisation step on the policy network -- see DQN.optimize.
        agent.optimize()

        # Soft update of the target network weights:
        #   target <- tau * policy + (1 - tau) * target
        target_weights = agent.target_net.state_dict()
        policy_weights = agent.policy_net.state_dict()
        for name, policy_value in policy_weights.items():
            target_weights[name] = policy_value*agent.tau + target_weights[name]*(1-agent.tau)
        agent.target_net.load_state_dict(target_weights)

        if done:
            # t is zero-based, so the episode lasted t + 1 steps.
            episode_durations.append(t + 1)
            plot_durations()
            break