In [None]:
import itertools
import os
import time

import numpy as np
np.set_printoptions()
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

from context import src
from src.model.data_handler import save_datasets, load_datasets
from src.model.model import RubiksModel
from src.model.train import train

In [None]:
# Generate the training data: NUM_CUBES randomly scrambled cubes written under
# "ignore/" (a folder that is not uploaded to GitHub). This cell is slow; skip it
# when a previously generated dataset already exists on disk.
# NOTE(review): the 5-21 scramble range mentioned previously is not passed here;
# presumably it is fixed inside save_datasets. TODO confirm.
# NOTE(review): the meaning of the third argument (5000) is not visible from this
# notebook. TODO document it.
NUM_CUBES = 2_000_000
os.makedirs("ignore", exist_ok=True)  # save_datasets expects the folder to exist
datapath = save_datasets("ignore/", NUM_CUBES, 5000)

In [None]:
# To train on a dataset that is already on disk, set EXISTING_DATAPATH to its
# path (and skip the generation cell). While it is None, `datapath` is left
# untouched, so Restart & Run All keeps the freshly generated dataset's path
# instead of overwriting it with an empty string.
EXISTING_DATAPATH = None
if EXISTING_DATAPATH is not None:
    datapath = EXISTING_DATAPATH

In [None]:
# Hyperparameter grid search: train one RubiksModel for every combination of
# hidden-layer sizes, dropout rate and learning rate. After each model, the
# accumulated results dict is checkpointed to "models/results_<index>".
internal_dimensions = ((2000, 1500, 500), (2500, 2500, 1000), (3000, 2500, 1500))
dropout_rates = [0, 0.1, 0.2]
learning_rates = [0.0008, 0.0005, 0.0002]
NUM_EPOCHS = 10

# Resuming after a crash: set last_saved to N, the suffix of the newest
# "models/results_N" file. Those results are reloaded, and every configuration
# with index <= N is skipped. Do NOT edit the lists above when resuming. A
# configuration's index depends on their order, so changing them misaligns the
# skip logic with the saved checkpoints.
last_saved = -1

os.makedirs("models", exist_ok=True)  # torch.save does not create missing folders
# Maps (internal_dimension, dropout_rate, learning_rate) ->
#      (train_acc, train_loss, valid_acc, valid_loss) as returned by train().
results = {} if last_saved == -1 else torch.load("models/results_{}".format(last_saved))

# itertools.product yields configurations in the same order as nested loops
# (learning rate varies fastest), so indices match earlier checkpoints.
grid = itertools.product(internal_dimensions, dropout_rates, learning_rates)
for counter, (internal_dimension, dropout_rate, learning_rate) in enumerate(grid):
    if counter <= last_saved:
        continue  # already trained; its results are in the reloaded checkpoint
    model = RubiksModel(internal_dimensions=internal_dimension, dropout_rate=dropout_rate, activation=nn.ReLU)
    train_acc, train_loss, valid_acc, valid_loss = train(
        model, learning_rate=learning_rate, num_epochs=NUM_EPOCHS, data_path=datapath, savepath="models/"
    )
    results[(internal_dimension, dropout_rate, learning_rate)] = (train_acc, train_loss, valid_acc, valid_loss)
    # Checkpoint after every model so a crash loses at most one training run.
    torch.save(results, "models/results_{}".format(counter))

In [None]:
# Inspect a grid-search checkpoint and plot the learning curves of one
# configuration. Replace the placeholder with a real path, e.g. "models/results_26".
RESULTS_PATH = r"FILE_PATH"
results = torch.load(RESULTS_PATH)
assert results, "no results stored in {}".format(RESULTS_PATH)

# Each value is (train_acc, train_loss, valid_acc, valid_loss); list every
# configuration with its lowest validation loss.
for i, (config, curves) in enumerate(results.items()):
    print(i, config, np.min(curves[3]))

# Keys are nested tuples: (internal_dimensions_tuple, dropout_rate, learning_rate),
# e.g. ((2500, 2500, 1000), 0.1, 0.0005). A flat tuple never matches them.
# Default to the configuration with the lowest validation loss. Assign one of the
# keys printed above to inspect a specific configuration instead.
selected_config = min(results, key=lambda config: np.min(results[config][3]))
train_acc, train_loss, valid_acc, valid_loss = results[selected_config]

# One figure per metric, with separately labelled train/valid curves so the
# legend can distinguish them.
for metric, train_values, valid_values in (
    ("Accuracy", train_acc, valid_acc),
    ("Loss", train_loss, valid_loss),
):
    fig, ax = plt.subplots()
    ax.plot(np.arange(len(train_values)), train_values, label="train")
    ax.plot(np.arange(len(valid_values)), valid_values, label="valid")
    ax.set(title="{}: {}".format(metric, selected_config), xlabel="step", ylabel=metric)
    ax.legend()
    plt.show()