In [1]:
import torch
import random
import numpy as np
from models import TFModel
from utils import create_folder, fix_random_seed, Config
from task_generate import gen_default_tokernizer, simpleReasoning
from train import make_scheduler, train_fresh_sample
import yaml
import json

# Load the experiment settings from the YAML file next to the notebook.
# The explicit encoding keeps parsing independent of the platform's locale default.
with open("./config.yaml", "r", encoding="utf-8") as file:
    # safe_load only constructs plain Python objects from the YAML document.
    config_args = yaml.safe_load(file)

# The top-level YAML keys are passed to Config as keyword arguments.
config = Config(**config_args)

# Seed before anything random is created.
# NOTE(review): presumably seeds random / numpy / torch — confirm in utils.fix_random_seed.
# The call returns a value (the recorded output shows a torch Generator); the
# trailing semicolon keeps that repr out of the cell output.
fix_random_seed(config.seed);

<torch._C.Generator at 0x7fbe301b64f0>

In [2]:
# Build the default tokenizer and store its vocabulary size (the number of
# entries in tokenizer["map"]) on the shared config.
# NOTE(review): presumably read later by TFModel(config) — confirm.
tokenizer = gen_default_tokernizer()
config.vocab_size = vocab_size = len(tokenizer["map"])

# Sanity check: generate a handful of examples and inspect them by eye.
task = simpleReasoning(
    tokenizer,
    max_variables=config.max_variables,
    max_parenthesis=config.max_parenthesis,
    max_seq_len=config.max_seq_len,
)
res, info = task.formatted_sample(num_steps=100)

# Pretty-print the task settings as JSON.
task_details = task.get_task_details()
print(json.dumps(task_details, indent=4))

# Decode the first five samples back to text, with a blank line after each.
for ids in res[:5]:
    print(task.map_ids_to_str(ids))
    print("")

{
    "name": "simple_reasoning",
    "description": "simple arithmetic deduction steps",
    "max_variables": 12,
    "max_parenthesis": 7,
    "num_variables_samp_prob": null,
    "num_parentheses_samp_prob": null,
    "variable_samp_prob": null,
    "operator_samp_prob": null,
    "sample_size": 97
}
PPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPP(9-7-5-(1+1)-0)  >>>
(9-7-5-2-0)

PPPPPPPPPPPPPPPPPPPPPPPPPPPPPPP(0-0+5-3+0)-9-2+0  >>>
(0+5-3+0)-9-2+0

PPPPPPPPPPP((3+0+2)-(((2+6+(6-8))+1)))  >>>
((3+2)-(((2+6+(6-8))+1)))

PPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPP5-(1+8+0)-9  >>>
5-(9+0)-9

PPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPP(8+6+3)  >>>
(4+3)



In [3]:
# Instantiate the network from the shared config, then move it to the
# device named in the config.
model = TFModel(config)
model = model.to(config.device)

# Set the print flag only after the model has been constructed.
# NOTE(review): presumably enables the per-epoch log lines during training — confirm.
config.print_output = True

In [4]:
# Weight decay is applied only when the config opts in; otherwise it is 0.
weight_decay = config.wd if config.use_wd else 0

# AdamW hyper-parameters: learning rate from the config, fixed betas and eps.
adamw_kwargs = dict(
    lr=config.lr,
    betas=(0.9, 0.98),
    eps=1e-9,
    weight_decay=weight_decay,
)
optimizer = torch.optim.AdamW(model.parameters(), **adamw_kwargs)

# The learning-rate schedule is built from the optimizer and the same config.
scheduler = make_scheduler(optimizer, config)


In [5]:
# Run training; `model` is rebound to the first value returned by
# train_fresh_sample and the second is kept as `training_info`.
# NOTE(review): among the logged epochs, test loss bottoms out around epoch 61
# (1.388), climbs to about 2.03 by epochs 171-191, and ends near 1.87, while
# test error stays around 0.73 throughout — worth investigating before relying
# on the final weights.
# NOTE(review): since `model` is overwritten, re-running only this cell
# presumably continues from the already-trained weights — confirm.
model, training_info = train_fresh_sample(model, config, optimizer, scheduler)

  0%|          | 1/500 [00:00<07:17,  1.14it/s]

----> Epoch:     1, Test Loss: 2.727, Test Error: 0.889


  0%|          | 2/500 [00:01<05:41,  1.46it/s]

----> Epoch:     2, Test Loss: 2.351, Test Error: 0.907


  1%|          | 3/500 [00:01<04:52,  1.70it/s]

----> Epoch:     3, Test Loss: 2.080, Test Error: 0.928


  1%|          | 4/500 [00:02<04:34,  1.80it/s]

----> Epoch:     4, Test Loss: 1.935, Test Error: 0.939


  1%|          | 5/500 [00:02<04:29,  1.84it/s]

----> Epoch:     5, Test Loss: 1.881, Test Error: 0.944


  1%|          | 6/500 [00:03<04:30,  1.82it/s]

----> Epoch:     6, Test Loss: 1.882, Test Error: 0.922


  1%|▏         | 7/500 [00:04<04:29,  1.83it/s]

----> Epoch:     7, Test Loss: 1.847, Test Error: 0.849


  2%|▏         | 8/500 [00:04<04:18,  1.90it/s]

----> Epoch:     8, Test Loss: 1.762, Test Error: 0.760


  2%|▏         | 9/500 [00:04<04:09,  1.97it/s]

----> Epoch:     9, Test Loss: 1.685, Test Error: 0.741


  2%|▏         | 10/500 [00:05<04:04,  2.01it/s]

----> Epoch:    10, Test Loss: 1.628, Test Error: 0.763


  2%|▏         | 11/500 [00:06<04:13,  1.93it/s]

----> Epoch:    11, Test Loss: 1.593, Test Error: 0.776


  4%|▍         | 21/500 [00:10<03:48,  2.09it/s]

----> Epoch:    21, Test Loss: 1.434, Test Error: 0.749


  6%|▌         | 31/500 [00:15<03:49,  2.04it/s]

----> Epoch:    31, Test Loss: 1.409, Test Error: 0.741


  8%|▊         | 41/500 [00:20<03:39,  2.09it/s]

----> Epoch:    41, Test Loss: 1.400, Test Error: 0.734


 10%|█         | 51/500 [00:25<03:54,  1.92it/s]

----> Epoch:    51, Test Loss: 1.389, Test Error: 0.731


 12%|█▏        | 61/500 [00:30<03:25,  2.14it/s]

----> Epoch:    61, Test Loss: 1.388, Test Error: 0.723


 14%|█▍        | 71/500 [00:35<03:23,  2.11it/s]

----> Epoch:    71, Test Loss: 1.396, Test Error: 0.720


 16%|█▌        | 81/500 [00:40<03:17,  2.12it/s]

----> Epoch:    81, Test Loss: 1.426, Test Error: 0.727


 18%|█▊        | 91/500 [00:44<03:15,  2.10it/s]

----> Epoch:    91, Test Loss: 1.489, Test Error: 0.734


 20%|██        | 101/500 [00:49<03:06,  2.14it/s]

----> Epoch:   101, Test Loss: 1.589, Test Error: 0.731


 22%|██▏       | 111/500 [00:54<03:03,  2.11it/s]

----> Epoch:   111, Test Loss: 1.720, Test Error: 0.735


 24%|██▍       | 121/500 [00:59<03:15,  1.94it/s]

----> Epoch:   121, Test Loss: 1.846, Test Error: 0.739


 26%|██▌       | 131/500 [01:04<02:57,  2.08it/s]

----> Epoch:   131, Test Loss: 1.920, Test Error: 0.739


 28%|██▊       | 141/500 [01:09<02:51,  2.09it/s]

----> Epoch:   141, Test Loss: 1.966, Test Error: 0.736


 30%|███       | 151/500 [01:13<02:46,  2.10it/s]

----> Epoch:   151, Test Loss: 2.008, Test Error: 0.740


 32%|███▏      | 161/500 [01:18<02:47,  2.03it/s]

----> Epoch:   161, Test Loss: 2.015, Test Error: 0.738


 34%|███▍      | 171/500 [01:23<02:36,  2.10it/s]

----> Epoch:   171, Test Loss: 2.028, Test Error: 0.737


 36%|███▌      | 181/500 [01:28<02:50,  1.87it/s]

----> Epoch:   181, Test Loss: 2.026, Test Error: 0.738


 38%|███▊      | 191/500 [01:33<02:29,  2.07it/s]

----> Epoch:   191, Test Loss: 2.029, Test Error: 0.740


 40%|████      | 201/500 [01:38<02:21,  2.12it/s]

----> Epoch:   201, Test Loss: 2.012, Test Error: 0.738


 42%|████▏     | 211/500 [01:43<02:15,  2.14it/s]

----> Epoch:   211, Test Loss: 2.020, Test Error: 0.735


 44%|████▍     | 221/500 [01:48<02:25,  1.92it/s]

----> Epoch:   221, Test Loss: 1.996, Test Error: 0.737


 46%|████▌     | 231/500 [01:53<02:12,  2.04it/s]

----> Epoch:   231, Test Loss: 1.994, Test Error: 0.737


 48%|████▊     | 241/500 [01:58<02:07,  2.03it/s]

----> Epoch:   241, Test Loss: 1.984, Test Error: 0.737


 50%|█████     | 251/500 [02:03<01:59,  2.09it/s]

----> Epoch:   251, Test Loss: 1.979, Test Error: 0.735


 52%|█████▏    | 261/500 [02:08<01:59,  1.99it/s]

----> Epoch:   261, Test Loss: 1.963, Test Error: 0.735


 54%|█████▍    | 271/500 [02:13<01:53,  2.02it/s]

----> Epoch:   271, Test Loss: 1.961, Test Error: 0.734


 56%|█████▌    | 281/500 [02:18<01:48,  2.02it/s]

----> Epoch:   281, Test Loss: 1.937, Test Error: 0.732


 58%|█████▊    | 291/500 [02:22<01:39,  2.09it/s]

----> Epoch:   291, Test Loss: 1.948, Test Error: 0.732


 60%|██████    | 301/500 [02:27<01:33,  2.12it/s]

----> Epoch:   301, Test Loss: 1.938, Test Error: 0.735


 62%|██████▏   | 311/500 [02:32<01:28,  2.14it/s]

----> Epoch:   311, Test Loss: 1.922, Test Error: 0.732


 64%|██████▍   | 321/500 [02:37<01:25,  2.09it/s]

----> Epoch:   321, Test Loss: 1.926, Test Error: 0.736


 66%|██████▌   | 331/500 [02:42<01:20,  2.09it/s]

----> Epoch:   331, Test Loss: 1.922, Test Error: 0.733


 68%|██████▊   | 341/500 [02:47<01:19,  2.00it/s]

----> Epoch:   341, Test Loss: 1.921, Test Error: 0.733


 70%|███████   | 351/500 [02:52<01:09,  2.14it/s]

----> Epoch:   351, Test Loss: 1.907, Test Error: 0.731


 72%|███████▏  | 361/500 [02:56<01:06,  2.08it/s]

----> Epoch:   361, Test Loss: 1.898, Test Error: 0.732


 74%|███████▍  | 371/500 [03:02<01:02,  2.06it/s]

----> Epoch:   371, Test Loss: 1.892, Test Error: 0.733


 76%|███████▌  | 381/500 [03:06<00:58,  2.05it/s]

----> Epoch:   381, Test Loss: 1.896, Test Error: 0.733


 78%|███████▊  | 391/500 [03:12<00:54,  2.01it/s]

----> Epoch:   391, Test Loss: 1.884, Test Error: 0.730


 80%|████████  | 401/500 [03:16<00:47,  2.10it/s]

----> Epoch:   401, Test Loss: 1.894, Test Error: 0.734


 82%|████████▏ | 411/500 [03:21<00:44,  1.99it/s]

----> Epoch:   411, Test Loss: 1.887, Test Error: 0.736


 84%|████████▍ | 421/500 [03:27<00:38,  2.03it/s]

----> Epoch:   421, Test Loss: 1.878, Test Error: 0.732


 86%|████████▌ | 431/500 [03:31<00:33,  2.04it/s]

----> Epoch:   431, Test Loss: 1.885, Test Error: 0.733


 88%|████████▊ | 441/500 [03:36<00:28,  2.10it/s]

----> Epoch:   441, Test Loss: 1.876, Test Error: 0.735


 90%|█████████ | 451/500 [03:41<00:24,  1.97it/s]

----> Epoch:   451, Test Loss: 1.891, Test Error: 0.738


 92%|█████████▏| 461/500 [03:46<00:19,  2.05it/s]

----> Epoch:   461, Test Loss: 1.871, Test Error: 0.733


 94%|█████████▍| 471/500 [03:51<00:15,  1.85it/s]

----> Epoch:   471, Test Loss: 1.881, Test Error: 0.737


 96%|█████████▌| 481/500 [03:57<00:09,  2.03it/s]

----> Epoch:   481, Test Loss: 1.877, Test Error: 0.733


 98%|█████████▊| 491/500 [04:02<00:04,  1.97it/s]

----> Epoch:   491, Test Loss: 1.866, Test Error: 0.736


100%|██████████| 500/500 [04:07<00:00,  2.02it/s]
