In [1]:
import sys
sys.path.insert(0, "../..")
import torch
import gin
from pathlib import Path
from src.data import data_tools 
from src.models import metrics, train_model, rnn_models

In [2]:
# Root folder of the raw gesture recordings, relative to this notebook.
data_dir = Path("../../data/external/gestures-dataset/")

# Only files whose extension appears in `formats` are kept.
formats = [".txt"]
paths = list(filter(lambda p: p.suffix in formats, data_tools.walk_dir(data_dir)))

# Train/test split: the first 80% of `paths` is used for training and the
# remainder for testing. The slices follow the order in which walk_dir
# yields the files; nothing is shuffled here.
# NOTE(review): if walk_dir groups files by class or subject, the test set
# is not a random sample of the data -- confirm this is intended.
split = 0.8
idx = int(len(paths) * split)
trainpaths, testpaths = paths[:idx], paths[idx:]

# One generator per split, both with a batch size of 32.
trainloader = data_tools.Datagenerator(trainpaths, batchsize=32)
testloader = data_tools.Datagenerator(testpaths, batchsize=32)

100%|██████████| 2600/2600 [00:02<00:00, 1162.38it/s]
100%|██████████| 651/651 [00:00<00:00, 997.80it/s] 


In [3]:
# Directory handed to trainloop as `log_dir`.
# NOTE(review): the captured training logs report "Logging to
# ../../models/gestures/..." -- confirm whether this path is overridden
# elsewhere or whether the stored outputs come from an older run.
log_dir = Path("../../models/attention/")

# Loss object for the classification task.
# NOTE(review): `loss_fn` is not passed to trainloop in the visible cells;
# presumably the loss is supplied through the gin config -- confirm.
loss_fn = torch.nn.CrossEntropyLoss()

# Metric reported alongside the losses during training.
accuracy = metrics.Accuracy()

In [None]:
# Load the gin bindings stored next to this notebook.
gin.parse_config_file("gestures.gin")

# Hyperparameters handed to the model constructors in rnn_models.
# NOTE(review): presumably input_size=3 is the number of features per
# timestep and output_size=20 the number of gesture classes -- confirm
# against the dataset.
config = dict(
    input_size=3,
    hidden_size=100,
    dropout=0.05,
    num_layers=3,
    output_size=20,
)

In [4]:
# Baseline: GRU model built from the shared `config` dict.
model = rnn_models.GRUmodel(config)

# Train for 10 epochs. train_steps / eval_steps are set to the length of
# the corresponding loader; trainloop's return value replaces `model`.
model = train_model.trainloop(
    model=model,
    epochs=10,
    train_dataloader=trainloader,
    train_steps=len(trainloader),
    test_dataloader=testloader,
    eval_steps=len(testloader),
    metrics=[accuracy],
    log_dir=log_dir,
)

2022-05-23 15:04:02.230 | INFO     | src.data.data_tools:dir_add_timestamp:209 - Logging to ../../models/gestures/20220523-1504
100%|██████████| 81/81 [00:03<00:00, 20.90it/s]
2022-05-23 15:04:06.756 | INFO     | src.models.train_model:trainloop:156 - Epoch 0 train 2.2276 test 0.1110 metric ['0.1594']
100%|██████████| 81/81 [00:03<00:00, 25.62it/s]
2022-05-23 15:04:10.290 | INFO     | src.models.train_model:trainloop:156 - Epoch 1 train 1.8876 test 0.1084 metric ['0.1516']
100%|██████████| 81/81 [00:04<00:00, 17.95it/s]
2022-05-23 15:04:15.156 | INFO     | src.models.train_model:trainloop:156 - Epoch 2 train 1.4235 test 0.0837 metric ['0.3297']
100%|██████████| 81/81 [00:03<00:00, 24.60it/s]
2022-05-23 15:04:18.829 | INFO     | src.models.train_model:trainloop:156 - Epoch 3 train 0.9249 test 0.0722 metric ['0.4437']
100%|██████████| 81/81 [00:03<00:00, 24.07it/s]
2022-05-23 15:04:22.544 | INFO     | src.models.train_model:trainloop:156 - Epoch 4 train 0.5823 test 0.0522 metric ['0.4969

In [5]:
# Second experiment: AttentionGRU built from the same `config` dict.
# Rebinding `model` here discards whatever it pointed to before this cell.
model = rnn_models.AttentionGRU(config)

# Same training arguments as the GRU run, so the two logs can be compared
# directly: 10 epochs, steps equal to the loader lengths.
model = train_model.trainloop(
    model=model,
    epochs=10,
    train_dataloader=trainloader,
    train_steps=len(trainloader),
    test_dataloader=testloader,
    eval_steps=len(testloader),
    metrics=[accuracy],
    log_dir=log_dir,
)

2022-05-23 15:08:02.757 | INFO     | src.data.data_tools:dir_add_timestamp:209 - Logging to ../../models/gestures/20220523-1508
100%|██████████| 81/81 [00:04<00:00, 17.05it/s]
2022-05-23 15:08:08.553 | INFO     | src.models.train_model:trainloop:156 - Epoch 0 train 1.6450 test 0.1028 metric ['0.2156']
100%|██████████| 81/81 [00:04<00:00, 17.81it/s]
2022-05-23 15:08:13.629 | INFO     | src.models.train_model:trainloop:156 - Epoch 1 train 1.3177 test 0.0813 metric ['0.4359']
100%|██████████| 81/81 [00:04<00:00, 17.03it/s]
2022-05-23 15:08:19.020 | INFO     | src.models.train_model:trainloop:156 - Epoch 2 train 0.4773 test 0.0361 metric ['0.6594']
100%|██████████| 81/81 [00:05<00:00, 14.81it/s]
2022-05-23 15:08:25.150 | INFO     | src.models.train_model:trainloop:156 - Epoch 3 train 0.3313 test 0.0325 metric ['0.8313']
100%|██████████| 81/81 [00:05<00:00, 14.61it/s]
2022-05-23 15:08:31.319 | INFO     | src.models.train_model:trainloop:156 - Epoch 4 train 0.1038 test 0.0147 metric ['0.8953