In [1]:
import os
import pickle
import time

import torch
from torch.utils.data import DataLoader

from custom_dataset import CustomOCRDataset
from helper import create_data_csv
from model import TextClassificationModel, train, evaluate
from preprocess import preprocess, collate_batch

In [2]:
# Seed PyTorch's global random number generator so repeated runs draw the
# same random numbers. Kept as the last expression so the cell still displays
# the returned Generator.
SEED = 5678
torch.manual_seed(SEED)

<torch._C.Generator at 0x7f7935cb67f0>

In [3]:
# Build the training data with the project's `preprocess` helper.
# NOTE(review): the return type is not visible from this notebook; the training
# cell wraps it with iter(...) and to_map_style_dataset(...), so it has to be
# iterable -- confirm against preprocess.py.
train_dataset = preprocess()

In [4]:
# Load the previously pickled vocabulary object from disk.
# The context manager closes the file even when unpickling raises, which the
# manual open()/close() pair did not guarantee.
# NOTE(review): pickle.load can execute arbitrary code while deserialising;
# only load a vocab.pkl that this project produced itself.
with open('./vocab.pkl', 'rb') as vocab_file:
    vocab = pickle.load(vocab_file)

In [5]:
# Number of entries in the loaded vocabulary.
vocab_size = len(vocab)
# Instantiate the classifier. Only the vocabulary size is passed, so every
# other setting presumably comes from TextClassificationModel's own defaults
# (see model.py).
model = TextClassificationModel(vocab_size)

In [6]:
from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset

# Hyperparameters
EPOCHS = 100  # number of passes over the training split
# NOTE(review): 0.1 is far above AdamW's default learning rate of 1e-3 --
# confirm this value is intentional.
LR = 0.1  # learning rate
BATCH_SIZE = 64  # samples per batch (also reused by the test DataLoader cell)
TRAIN_FRACTION = 0.95  # share of samples used for training; the rest is validation

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

# NOTE(review): `dataset` is no longer used to size the split below. It is
# still constructed in case cells outside this view reference it.
dataset = CustomOCRDataset('./train_dataset.csv')

# Convert the preprocessed data into a map-style dataset so that it supports
# len() and index access, both of which random_split relies on.
train_iter = iter(train_dataset)
train_dataset = to_map_style_dataset(train_iter)

# Size the split from the dataset that is actually being split. Previously the
# count came from `dataset`, a separately constructed object; if its length
# differed from `train_dataset` the ratio was silently wrong, or random_split
# received a negative length.
num_train = int(len(train_dataset) * TRAIN_FRACTION)
split_train_, split_valid_ = random_split(
    train_dataset, [num_train, len(train_dataset) - num_train]
)

train_dataloader = DataLoader(
    split_train_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch
)
valid_dataloader = DataLoader(
    split_valid_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch
)

# Per-epoch history of the values returned by train() and evaluate().
training_accuracy = []
validation_accuracy = []
for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    accu_train = train(train_dataloader, model, optimizer, criterion, epoch)
    accu_val = evaluate(valid_dataloader, model, criterion)
    print("-" * 59)
    print("Training Accuracy: ", accu_train)
    training_accuracy.append(accu_train)
    print(
        "| end of epoch {:3d} | time: {:5.2f}s | "
        "valid accuracy {:8.3f} |".format(
            epoch, time.time() - epoch_start_time, accu_val
        )
    )
    validation_accuracy.append(accu_val)
    print("-" * 59)

-----------------------------------------------------------
Training Accuracy:  0.140625
| end of epoch   1 | time:  0.30s | valid accuracy    0.152 |
-----------------------------------------------------------
-----------------------------------------------------------
Training Accuracy:  0.21875
| end of epoch   2 | time:  0.28s | valid accuracy    0.152 |
-----------------------------------------------------------
-----------------------------------------------------------
Training Accuracy:  0.328125
| end of epoch   3 | time:  0.29s | valid accuracy    0.152 |
-----------------------------------------------------------
-----------------------------------------------------------
Training Accuracy:  0.28125
| end of epoch   4 | time:  0.28s | valid accuracy    0.152 |
-----------------------------------------------------------
-----------------------------------------------------------
Training Accuracy:  0.15625
| end of epoch   5 | time:  0.29s | valid accuracy    0.264 |
--------

In [7]:
# Persist the whole trained model object (not just its state_dict).
# torch.save does not create missing parent directories, so make sure
# ./model/ exists first; otherwise a run on a fresh checkout fails here.
os.makedirs('./model', exist_ok=True)
torch.save(model, './model/trained_model.pth')

In [8]:
# Load the saved model
# Load the saved model
# NOTE(review): the file was written with torch.save(model, ...), i.e. it holds
# a pickled model object, so the TextClassificationModel class must be
# importable when loading. Recent PyTorch releases default torch.load to
# weights_only=True, which rejects such files -- confirm against the installed
# version.
loaded_model = torch.load('./model/trained_model.pth')
# As the last expression of the cell, the returned module is also displayed.
loaded_model.eval()  # Set the model to evaluation mode if needed

TextClassificationModel(
  (embedding): EmbeddingBag(106256, 64, mode='mean')
  (fc): Linear(in_features=64, out_features=10, bias=True)
  (relu): ReLU(inplace=True)
  (fc2): Linear(in_features=10, out_features=5, bias=True)
  (relu2): ReLU(inplace=True)
  (fc3): Linear(in_features=5, out_features=2, bias=True)
  (fc4): Linear(in_features=2, out_features=5, bias=True)
)

In [9]:
# Read the held-out test CSV and materialise it as a map-style dataset.
# The raw CustomOCRDataset is only needed as a source of samples, so it is
# consumed through an iterator instead of being bound to `test_dataset` first.
test_iter = iter(CustomOCRDataset('./test_dataset.csv'))
test_dataset = to_map_style_dataset(test_iter)

In [11]:
# Batch the test samples with the same batch size and collate function used
# during training.
# NOTE(review): this rebinds `valid_dataloader`, which earlier referred to the
# validation split; from this cell on it iterates over the test set.
valid_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_batch,
)

In [13]:
# Pair every batch from the (test) DataLoader with its running index.
# enumerate() is lazy: no batch is loaded until the result is iterated.
# NOTE(review): `l` is easy to misread as the digit 1; renaming it requires
# updating whichever later cells consume it.
l = enumerate(valid_dataloader)