In [1]:
import os

import torch
import torch.nn as nn

from data_handling import SequenceDataset
from torch.utils.data import DataLoader

from model_definitions import AuAcceptor

import math
import pickle as pk

DATA_PATH = "../data/problem_4_train_dfa_accept_and_reject.dat"
EPOCHS = 120
TRANSFORMER_NAME = "transformer_problem_4.pk" # the transformer that we'll load

ACCEPTOR_NAME = "acceptor_problem_4.pk" # how to save model

DATASET_CONTAINER_PATH = "dataset_problem_4.pk"

## Data loading and preprocessing

In [2]:
dataset = SequenceDataset(DATA_PATH, maxlen=15)
dataset.initialize(DATASET_CONTAINER_PATH)
dataset.encode_sequences()

Alphabet size:  3
Sequences loaded. Some examples: 
[['a', 'b', 'a', 'a', 'a', 'b', 'a', 'b', 'c', 'a', 'a', 'a', 'a'], ['a', 'b', 'c', 'b', 'b', 'b', 'a', 'b', 'c', 'c', 'a', 'c', 'a'], ['c', 'a', 'c']]
The symbol dictionary: {'a': 0, 'b': 1, 'c': 2}


In [3]:
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

## Encoder and Decoder

In [4]:
transformer = torch.load(TRANSFORMER_NAME)

transformer.eval()

AuTransformer(
  (pos_encoding): PositionalEncoding()
  (input_embedding): Embedding(6, 3)
  (output_fnn): Linear(in_features=3, out_features=6, bias=True)
  (gelu): GELU(approximate='none')
  (dropout): Dropout(p=0.2, inplace=False)
  (softmax_output): Softmax(dim=-1)
  (attention_output_layer): Identity()
  (attention_weight_layer): Identity()
  (src_embedding_output_layer): Identity()
)

In [5]:
model = AuAcceptor(transformer_model=transformer, alphabet_size=4, 
                   embedding_dim=transformer.embedding_dim, maxlen_of_sequence=dataset.maxlen, freeze_transformer=False)
model.eval()

AuAcceptor(
  (pos_encoding): PositionalEncoding()
  (input_embedding): Embedding(6, 3)
  (dropout): Dropout(p=0.2, inplace=False)
  (hidden_layer): Linear(in_features=51, out_features=1, bias=True)
  (attention_output_layer): Identity()
  (attention_weight_layer): Identity()
  (src_embedding_output_layer): Identity()
)

## Train

In [6]:
# training loop from https://pytorch.org/tutorials/beginner/introyt/trainingyt.html

optimizer = torch.optim.Adam(model.parameters(), lr=0.00001)# for 3 heads lr=0.00001)
loss_fn = nn.BCELoss()

running_loss = 0.
last_loss = 0.
divisor = 0.

for i in range(EPOCHS):
    print("Epoch: ", i)
    for j, (test_string_ord, test_string_ord_sr, test_string_oh, test_string_oh_sr, labels, sequence_length) in enumerate(train_dataloader):
        optimizer.zero_grad()
        test_string_ord = torch.permute(test_string_ord, dims=[1,0])
        test_string_ord_sr = torch.permute(test_string_ord_sr, dims=[1,0])
        test_string_oh_sr = torch.permute(test_string_oh_sr, dims=[1,0,2])

        #print(sequence_length, list(test_string_ord.size()), labels)
        #break
        
        outputs = model(test_string_ord, test_string_ord_sr, sequence_length)
        loss = loss_fn(torch.squeeze(outputs), torch.squeeze(labels).float())
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        divisor += float( list(test_string_ord.size())[1] )
        if j % 1000 == 0:
            last_loss = running_loss / divisor # loss per batch
            print('  batch {} loss: {}'.format(j, last_loss))
            running_loss = 0.
            divisor = 0.
    #break
#raise Exception("Stop here")

  from .autonotebook import tqdm as notebook_tqdm


Epoch:  0
  batch 0 loss: 0.09485961496829987
  batch 1000 loss: 0.07841320924088359
  batch 2000 loss: 0.07422118492051959
Epoch:  1
  batch 0 loss: 0.07185872994363308
  batch 1000 loss: 0.06991370917856693
  batch 2000 loss: 0.06696430980041623
Epoch:  2
  batch 0 loss: 0.06510196321457624
  batch 1000 loss: 0.06365099645778537
  batch 2000 loss: 0.0608622776158154
Epoch:  3
  batch 0 loss: 0.06008661155775189
  batch 1000 loss: 0.05872839763946831
  batch 2000 loss: 0.0566308329012245
Epoch:  4
  batch 0 loss: 0.05564231513440609
  batch 1000 loss: 0.05442515068873763
  batch 2000 loss: 0.053542224524542686
Epoch:  5
  batch 0 loss: 0.0528509355597198
  batch 1000 loss: 0.0512775364946574
  batch 2000 loss: 0.051006896924227475
Epoch:  6
  batch 0 loss: 0.04959078080207109
  batch 1000 loss: 0.04952894089743495
  batch 2000 loss: 0.04792770919576287
Epoch:  7
  batch 0 loss: 0.04670621825009585
  batch 1000 loss: 0.04694832682795823
  batch 2000 loss: 0.04526688189804554
Epoch:  8


KeyboardInterrupt: 

In [7]:
torch.save(model, ACCEPTOR_NAME)