In [1]:
import os

import torch
import torch.nn as nn

from data_handling import SequenceDataset
from torch.utils.data import DataLoader

from model_definitions import AuTransformer

import math
import pickle as pk

DATA_PATH = "../data/problem_1_train_dfa_accept_and_reject.dat"
EPOCHS = 1
MODEL_NAME = "trained_model.pk" # the transformer that we'll load

DATASET_CONTAINER_PATH = "dataset.pk"

## Data loading and preprocessing

In [2]:
dataset = SequenceDataset(DATA_PATH, maxlen=10)
dataset.initialize(DATASET_CONTAINER_PATH)
dataset.encode_sequences()
dataset.save_state()

Alphabet size:  4
Sequences loaded. Some examples: 
[['a', 'a', 'd'], ['a', 'd', 'c', 'd', 'd', 'd', 'c'], ['c', 'c', 'a', 'a', 'c', 'b', 'a', 'c', 'a']]
The symbol dictionary: {'a': 0, 'b': 1, 'c': 2, 'd': 3}


In [3]:
train_dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

## Encoder and Decoder

In [4]:
model = AuTransformer(n_encoders=1, n_decoders=1, alphabet_size=4, embedding_dim=3, max_len=10)
model.eval()

AuTransformer(
  (pos_encoding): PositionalEncoding()
  (input_embedding): Embedding(7, 3)
  (output_fnn): Linear(in_features=3, out_features=7, bias=True)
  (gelu): GELU(approximate='none')
  (dropout): Dropout(p=0.2, inplace=False)
  (softmax_output): Softmax(dim=-1)
  (attention_output_layer): Identity()
  (attention_weight_layer): Identity()
)

## Train

In [7]:
# training loop from https://pytorch.org/tutorials/beginner/introyt/trainingyt.html

optimizer = torch.optim.Adam(model.parameters(), lr=0.00001)# for 3 heads lr=0.00001)
loss_fn = nn.CrossEntropyLoss()

running_loss = 0.
last_loss = 0.
divisor = 0.

for i in range(EPOCHS):
    print("Epoch: ", i)
    for j, (test_string_ord, test_string_ord_sr, test_string_oh, test_string_oh_sr, _, sequence_length) in enumerate(train_dataloader):
        optimizer.zero_grad()
        test_string_ord = torch.permute(test_string_ord, dims=[1,0])
        test_string_ord_sr = torch.permute(test_string_ord_sr, dims=[1,0])
        test_string_oh_sr = torch.permute(test_string_oh_sr, dims=[1,0,2])
        
        outputs = model(test_string_ord, test_string_ord_sr)
        loss = loss_fn(torch.squeeze(outputs), torch.squeeze(test_string_oh_sr))
        loss.backward()
        
        #break

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        divisor += float( list(test_string_ord.size())[1] )
        if j % 1000 == 0:
            last_loss = running_loss / divisor # loss per batch
            print('  batch {} loss: {}'.format(j, last_loss))
            running_loss = 0.
            divisor = 0.
    #break

Epoch:  0
  batch 0 loss: 1.4126192331314087
  batch 1000 loss: 1.4520285028219222
  batch 2000 loss: 1.4544923737049102
  batch 3000 loss: 1.45058392226696
  batch 4000 loss: 1.449556448340416
  batch 5000 loss: 1.4460676217079163
  batch 6000 loss: 1.4481366548538208
  batch 7000 loss: 1.443817057967186
  batch 8000 loss: 1.4421689954996109
  batch 9000 loss: 1.4423229407072067
  batch 10000 loss: 1.4404325652122498
  batch 11000 loss: 1.4386720995903015
  batch 12000 loss: 1.4395964016914369
Epoch:  1
  batch 0 loss: 1.431784289453166
  batch 1000 loss: 1.4392980933189392
  batch 2000 loss: 1.438429934501648
  batch 3000 loss: 1.4341865170001984
  batch 4000 loss: 1.4365210058689117
  batch 5000 loss: 1.435165327191353
  batch 6000 loss: 1.4319866579771041
  batch 7000 loss: 1.4261005415916443
  batch 8000 loss: 1.4286127898693084
  batch 9000 loss: 1.4300851653814315
  batch 10000 loss: 1.4252449017763138
  batch 11000 loss: 1.4292375328540803
  batch 12000 loss: 1.424433528661728


KeyboardInterrupt: 

In [8]:
torch.save(model, MODEL_NAME)