In [1]:
import os

import torch
import torch.nn as nn

from data_handling import SequenceDataset
from torch.utils.data import DataLoader

from model_definitions import AuTransformer

import math
import pickle as pk

# --- Run configuration -------------------------------------------------------
# Inputs / outputs on disk.
DATA_PATH = "../data/pdfa_problem_1_train.dat"      # training sequences read by SequenceDataset
DATASET_SAVE_PATH = "dataset_pdfa_problem_1.pk"     # target of dataset.save_state(...)
MODEL_NAME = "transformer_pdfa_problem_1.pk"        # file the trained model is written to by torch.save

# Optimisation schedule.
EPOCHS = 35                                         # number of full passes over train_dataloader

## Data loading and preprocessing

In [2]:
# Build the training dataset from the file at DATA_PATH.
# NOTE(review): maxlen=15 presumably caps the sequence length (the longest
# examples printed in this cell's output have 15 symbols) -- confirm in
# data_handling.SequenceDataset.
dataset = SequenceDataset(DATA_PATH, maxlen=15)
# Presumably turns the raw symbol sequences into the numeric encodings that the
# training loop unpacks from the DataLoader -- TODO confirm.
dataset.encode_sequences()
# Persist the dataset state to DATASET_SAVE_PATH.
dataset.save_state(DATASET_SAVE_PATH)

Alphabet size:  3
Sequences loaded. Some examples: 
[['b', 'c', 'a', 'a', 'a', 'a', 'b', 'b', 'b', 'c', 'a', 'a', 'c', 'b', 'c'], ['a', 'c', 'c', 'b', 'c', 'c', 'b', 'b', 'c', 'a', 'c', 'c', 'c', 'a', 'b'], ['c', 'a', 'a']]
The symbol dictionary: {'b': 0, 'c': 1, 'a': 2}


In [3]:
# Iterate over the dataset one sequence at a time, in a freshly shuffled order
# on every pass.
# NOTE(review): the training cell calls torch.squeeze() on the model outputs and
# targets, which looks tailored to batch_size=1 -- confirm before raising it.
train_dataloader = DataLoader(
    dataset=dataset,
    batch_size=1,
    shuffle=True,
)

In [4]:
# Display the symbol -> integer index mapping held by the dataset.
dataset.symbol_dict

{'b': 0, 'c': 1, 'a': 2}

## Encoder and Decoder

In [5]:
# Instantiate the transformer. alphabet_size=3 matches the three entries of
# dataset.symbol_dict and max_len=15 matches the maxlen passed to
# SequenceDataset; both are repeated here by hand.
model = AuTransformer(
    n_encoders=1,
    n_decoders=1,
    alphabet_size=3,
    embedding_dim=3,
    max_len=15,
)
# eval() returns the module itself, so as the last expression of the cell it
# renders the model's repr. It also switches the module to inference mode;
# mode-dependent layers (e.g. the Dropout listed in the repr) stay in that mode
# until model.train() is called.
model.eval()

AuTransformer(
  (pos_encoding): PositionalEncoding()
  (input_embedding): Embedding(6, 3)
  (output_fnn): Linear(in_features=3, out_features=6, bias=True)
  (gelu): GELU(approximate='none')
  (dropout): Dropout(p=0.2, inplace=False)
  (softmax_output): Softmax(dim=-1)
  (attention_output_layer): Identity()
  (attention_weight_layer): Identity()
  (src_embedding_output_layer): Identity()
)

## Train

In [6]:
# Training loop adapted from
# https://pytorch.org/tutorials/beginner/introyt/trainingyt.html
#
# Trains `model` on `train_dataloader` for EPOCHS passes and prints the running
# mean loss every REPORT_EVERY batches. Leaves `optimizer`, `loss_fn` and
# `last_loss` (the most recently printed value) in the notebook namespace.

REPORT_EVERY = 1000  # print the running mean loss every this many batches

# Original author's note on the learning rate: "for 3 heads lr=0.00001".
optimizer = torch.optim.Adam(model.parameters(), lr=0.00001)

# NOTE(review): nn.CrossEntropyLoss expects unnormalised logits, but the model
# repr lists a `softmax_output` layer. If AuTransformer.forward() already
# applies it, the loss is computed on probabilities and will hover around
# ln(6) ~= 1.79 for the 6 output classes -- which is what the logged values
# look like. Confirm in model_definitions before trusting these numbers.
loss_fn = nn.CrossEntropyLoss()

running_loss = 0.
last_loss = 0.
divisor = 0.

# An earlier cell calls model.eval(), and nothing switched the model back, so
# training used to run in inference mode (mode-dependent layers such as Dropout
# inactive). Enable training mode for the duration of the loop and restore
# eval mode afterwards -- even if the loop is interrupted -- so that the model
# saved later is in the same mode as before this fix.
model.train()
try:
    for i in range(EPOCHS):
        print("Epoch: ", i)

        # Fresh accumulators per epoch: otherwise the "batch 0" report of an
        # epoch would mostly consist of the unreported tail of the previous one.
        running_loss = 0.
        divisor = 0.

        # Only three of the six items yielded by the loader are used here.
        for j, (test_string_ord, test_string_ord_sr, _, test_string_oh_sr, _, _) in enumerate(train_dataloader):
            optimizer.zero_grad()

            # Swap the first two axes of every tensor fed to the model / loss.
            # Presumably (batch, seq, ...) -> (seq, batch, ...) -- TODO confirm
            # against SequenceDataset and AuTransformer.forward.
            test_string_ord = torch.permute(test_string_ord, dims=[1, 0])
            test_string_ord_sr = torch.permute(test_string_ord_sr, dims=[1, 0])
            test_string_oh_sr = torch.permute(test_string_oh_sr, dims=[1, 0, 2])

            outputs = model(test_string_ord, test_string_ord_sr)

            # squeeze() drops every size-1 axis. NOTE(review): this looks like
            # it relies on batch_size=1 to remove the batch axis -- confirm
            # before changing the DataLoader batch size.
            loss = loss_fn(torch.squeeze(outputs), torch.squeeze(test_string_oh_sr))
            loss.backward()

            # Adjust learning weights
            optimizer.step()

            # Gather data and report
            running_loss += loss.item()
            # Size of axis 1 after the permute (1 with the current loader).
            divisor += float(test_string_ord.size(1))
            if j % REPORT_EVERY == 0:
                last_loss = running_loss / divisor  # mean loss since the previous report
                print('  batch {} loss: {}'.format(j, last_loss))
                running_loss = 0.
                divisor = 0.
finally:
    model.eval()

  from .autonotebook import tqdm as notebook_tqdm


Epoch:  0
  batch 0 loss: 1.7453734874725342
  batch 1000 loss: 1.8094243041276932
  batch 2000 loss: 1.8054492086172105
  batch 3000 loss: 1.805641761779785
  batch 4000 loss: 1.8020022773742677
  batch 5000 loss: 1.7994120483398437
  batch 6000 loss: 1.797045438170433
  batch 7000 loss: 1.792621690750122
  batch 8000 loss: 1.78820743227005
  batch 9000 loss: 1.7900837712287903
Epoch:  1
  batch 0 loss: 1.7840820249319076
  batch 1000 loss: 1.7858892290592194
  batch 2000 loss: 1.7834451261758804
  batch 3000 loss: 1.77810253739357
  batch 4000 loss: 1.7788764984607697
  batch 5000 loss: 1.771708419084549
  batch 6000 loss: 1.7733616417646407
  batch 7000 loss: 1.7704540235996247
  batch 8000 loss: 1.7662045642137527
  batch 9000 loss: 1.7650806121826172
Epoch:  2
  batch 0 loss: 1.7651108053922653
  batch 1000 loss: 1.7634481309652328
  batch 2000 loss: 1.7580441371202469
  batch 3000 loss: 1.7548397567272187
  batch 4000 loss: 1.7484263614416122
  batch 5000 loss: 1.7506955243349076

In [7]:
# Serialise the whole module object (not just its state_dict) to MODEL_NAME.
# Loading it back with torch.load therefore needs the AuTransformer class to be
# importable, and should only be done with files from a trusted source.
torch.save(obj=model, f=MODEL_NAME)