In [1]:
from process_handler import Process, ProcessManager
from bytes_tokenizer import encode, decode, tokenizer
from transformer_model import create_model, create_mask
from text_dataset import TextDataset
from checkpoint_manager import *
from train_eval_utils import *
from modeling_utils import *

import numpy as np
import time
import torch
import os
import shutil
from torch.utils.data import DataLoader
from torch.optim import Optimizer
from tqdm import tqdm, trange
from typing import List, Tuple, Dict, Any
from torch.nn import Module

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class LearningExperiment:
    """Drive training, evaluation and checkpointing for a single model.

    The class stores the model together with the *names* of the optimizer and
    criterion to use; the actual objects are built by :meth:`setup`.  Training
    and evaluation are delegated to the module-level helpers ``train`` and
    ``evaluate``, and checkpoint I/O to the module-level helpers
    ``load_checkpoint`` and ``save_checkpoint``.  None of these helpers is
    defined in this notebook; they are expected to be provided by the star
    imports in the first cell.

    Note: several methods share a name with the helper they wrap
    (``train``, ``evaluate``, ``load_checkpoint``, ``save_checkpoint``).
    Inside a method body the bare name resolves to the module-level function,
    not to the method, so these calls are not recursive.
    """

    # File name looked up inside ``checkpoint_dir`` when resuming an experiment.
    LAST_CHECKPOINT_NAME = 'last.pth.tar'

    def __init__(self, model: Module, 
                 model_params: Dict[str, Any], 
                 optimizer_function: str = 'adam', 
                 criterion_function: str = 'cross-entropy', 
                 learning_rate: float = 1e-3, 
                 device: str = 'cpu', 
                 checkpoint_dir: str = 'checkpoint'):
        """Store the experiment configuration; no optimizer/criterion is built yet.

        Args:
            model: The network to train and evaluate.
            model_params: Stored on the instance as-is; not read by any method
                of this class.
            optimizer_function: Identifier forwarded to ``get_optimizer``.
            criterion_function: Identifier forwarded to ``get_criterion``.
            learning_rate: Learning rate forwarded to ``get_optimizer``.
            device: Device identifier forwarded to the ``train`` and
                ``evaluate`` helpers.
            checkpoint_dir: Location passed to ``save_checkpoint`` and used by
                :meth:`run_experiment` to look for a checkpoint to resume from.
        """
        self.model = model
        self.model_params = model_params
        self.optimizer_function = optimizer_function
        self.criterion_function = criterion_function
        self.learning_rate = learning_rate
        self.device = device
        self.checkpoint_dir = checkpoint_dir

        # Both stay ``None`` until ``setup()`` is called.
        self.optimizer: Optimizer = None
        # Annotated with the explicitly imported ``Module``; the notebook never
        # imports ``nn`` by name.
        self.criterion: Module = None

    def setup(self):
        """Build the optimizer and the criterion from the stored configuration."""
        self.optimizer = get_optimizer(self.model, self.optimizer_function, self.learning_rate)
        self.criterion = get_criterion(self.criterion_function)

    def load_checkpoint(self, checkpoint: str):
        """Restore state from ``checkpoint`` via the ``load_checkpoint`` helper.

        The current ``self.optimizer`` is passed along, which is ``None`` if
        :meth:`setup` has not been called yet.

        Args:
            checkpoint: Path handed to the helper unchanged.
        """
        load_checkpoint(checkpoint, self.model, self.optimizer)

    def save_checkpoint(self, is_best: bool, checkpoint: str):
        """Save the model and optimizer state dicts via the ``save_checkpoint`` helper.

        Requires :meth:`setup` to have been called, because the optimizer's
        ``state_dict()`` is read.

        Args:
            is_best: Flag forwarded to the helper unchanged.
            checkpoint: Destination forwarded to the helper unchanged.
        """
        state = {
            'state_dict': self.model.state_dict(),
            'optim_dict': self.optimizer.state_dict()
        }
        save_checkpoint(state, is_best, checkpoint)

    def train(self, train_dataset, batch_size: int = 32, epochs: int = 100):
        """Train for ``epochs`` epochs and print the training loss of each epoch.

        Args:
            train_dataset: Dataset wrapped in a shuffling ``DataLoader``.
            batch_size: Batch size of the loader.
            epochs: Number of passes over ``train_dataset``.
        """
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

        for epoch in range(1, epochs + 1):
            avg_loss = train(self.model, train_loader, self.optimizer, self.criterion, self.device)
            print('[Epoch {}] Train loss: {:.4f}'.format(epoch, avg_loss))

    def evaluate(self, test_dataset, batch_size: int = 32):
        """Evaluate once on ``test_dataset`` and print the resulting loss.

        Args:
            test_dataset: Dataset wrapped in a non-shuffling ``DataLoader``.
            batch_size: Batch size of the loader.
        """
        test_loader = DataLoader(test_dataset, batch_size=batch_size)

        avg_loss = evaluate(self.model, test_loader, self.criterion, self.device)
        print('Test loss: {:.4f}'.format(avg_loss))

    def train_and_evaluate(self, train_dataset: TextDataset, 
                           test_dataset: TextDataset, 
                           batch_size: int = 32, 
                           epochs: int = 100):
        """Alternate one training pass and one evaluation pass per epoch.

        Prints one line per epoch with both losses.

        Args:
            train_dataset: Dataset used for the training pass (shuffled).
            test_dataset: Dataset used for the evaluation pass (not shuffled).
            batch_size: Batch size of both loaders.
            epochs: Number of train/evaluate rounds.
        """
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=batch_size)

        for epoch in range(1, epochs + 1):
            train_avg_loss = train(self.model, train_loader, self.optimizer, self.criterion, self.device)
            test_avg_loss = evaluate(self.model, test_loader, self.criterion, self.device)
            print('[Epoch {}] Train loss: {:.4f} | Test loss: {:.4f}'.format(epoch, train_avg_loss, test_avg_loss))

    def _train_evaluate_and_save(self, train_dataset: TextDataset, test_dataset: TextDataset, batch_size: int, epochs: int):
        """Run the train/evaluate loop, then save a checkpoint to ``checkpoint_dir``."""
        self.train_and_evaluate(train_dataset, test_dataset, batch_size=batch_size, epochs=epochs)

        # NOTE(review): no metric is tracked across runs, so every run is
        # flagged as the best one. What the helper does with the flag is not
        # visible here - confirm against its implementation.
        is_best = True
        self.save_checkpoint(is_best, self.checkpoint_dir)

    def run(self, train_dataset: TextDataset, test_dataset: TextDataset, batch_size: int = 32, epochs: int = 100):
        """Start a fresh run: set up, train/evaluate, then save a checkpoint.

        Args:
            train_dataset: Dataset used for training.
            test_dataset: Dataset used for evaluation.
            batch_size: Batch size of both loaders.
            epochs: Number of train/evaluate rounds.
        """
        self.setup()
        self._train_evaluate_and_save(train_dataset, test_dataset, batch_size, epochs)

    def run_from_checkpoint(self, train_dataset: TextDataset, test_dataset: TextDataset, checkpoint: str, batch_size: int = 32, epochs: int = 100):
        """Resume a run: set up, load ``checkpoint``, train/evaluate, then save.

        Args:
            train_dataset: Dataset used for training.
            test_dataset: Dataset used for evaluation.
            checkpoint: Path of the checkpoint to restore before training.
            batch_size: Batch size of both loaders.
            epochs: Number of train/evaluate rounds.
        """
        # The optimizer must exist before loading so that it can be handed to
        # the loading helper.
        self.setup()
        self.load_checkpoint(checkpoint)
        self._train_evaluate_and_save(train_dataset, test_dataset, batch_size, epochs)

    def run_experiment(self, train_dataset: TextDataset, test_dataset: TextDataset, batch_size: int = 32, epochs: int = 100):
        """Resume from the last checkpoint if one exists, otherwise start fresh.

        The decision is based on the checkpoint *file* inside
        ``checkpoint_dir`` rather than on the directory alone, so an existing
        but empty directory leads to a fresh run instead of an attempt to load
        a file that is not there.

        Args:
            train_dataset: Dataset used for training.
            test_dataset: Dataset used for evaluation.
            batch_size: Batch size of both loaders.
            epochs: Number of train/evaluate rounds.
        """
        # No setup() here: both branches below perform it themselves.
        last_checkpoint = os.path.join(self.checkpoint_dir, self.LAST_CHECKPOINT_NAME)
        if os.path.isfile(last_checkpoint):
            self.run_from_checkpoint(train_dataset, test_dataset, checkpoint=last_checkpoint, batch_size=batch_size, epochs=epochs)
        else:
            self.run(train_dataset, test_dataset, batch_size=batch_size, epochs=epochs)

In [None]:
# Experiment configuration: everything a reader might want to tune lives here.
SEED = 42
DEVICE = 'cpu'  # single source of truth for both the model and the experiment
DATA_DIR = 'data'
CHECKPOINT_DIR = 'checkpoint'
MAX_LEN = 100
BATCH_SIZE = 64
EPOCHS = 100
LEARNING_RATE = 1e-3

# Seed before the model is created so that weight initialisation, dropout and
# DataLoader shuffling draw from a fixed random stream on every fresh run.
torch.manual_seed(SEED)
np.random.seed(SEED)  # harmless if the project code never uses NumPy's global RNG

train_dataset = TextDataset(os.path.join(DATA_DIR, 'train_dataset.txt'), max_len=MAX_LEN)
test_dataset = TextDataset(os.path.join(DATA_DIR, 'test_dataset.txt'), max_len=MAX_LEN)
model = create_model(ntoken=len(tokenizer), ninp=512, nhead=1, nhid=1024, nlayers=6, device=DEVICE, dropout=0.5)
exp = LearningExperiment(model, model_params={}, optimizer_function='adam', criterion_function='cross-entropy', learning_rate=LEARNING_RATE, device=DEVICE, checkpoint_dir=CHECKPOINT_DIR)

# Resumes from an existing checkpoint when LearningExperiment finds one,
# otherwise trains from scratch.
exp.run_experiment(train_dataset, test_dataset, batch_size=BATCH_SIZE, epochs=EPOCHS)

  return torch.from_numpy(encoded_string)
 33%|███▎      | 1/3 [00:08<00:17,  8.93s/it]Total Loss: 6.154289245605469
 67%|██████▋   | 2/3 [00:16<00:08,  8.12s/it]Total Loss: 11.515909671783447
100%|██████████| 3/3 [00:19<00:00,  6.57s/it]
Total Loss: 14.134660720825195
100%|██████████| 3/3 [00:05<00:00,  1.94s/it]
[Epoch 1] Train loss: 4.7116 | Test loss: 4.5192
 33%|███▎      | 1/3 [00:06<00:13,  6.90s/it]Total Loss: 3.720729351043701
 67%|██████▋   | 2/3 [00:14<00:07,  7.17s/it]Total Loss: 7.1338889598846436
100%|██████████| 3/3 [00:16<00:00,  5.66s/it]
Total Loss: 9.545804500579834
100%|██████████| 3/3 [00:04<00:00,  1.54s/it]
[Epoch 2] Train loss: 3.1819 | Test loss: 2.2000
 33%|███▎      | 1/3 [00:06<00:12,  6.20s/it]Total Loss: 2.014049768447876
 67%|██████▋   | 2/3 [00:13<00:06,  6.70s/it]Total Loss: 4.454030990600586
100%|██████████| 3/3 [00:15<00:00,  5.32s/it]
Total Loss: 6.668410778045654
100%|██████████| 3/3 [00:04<00:00,  1.60s/it]
[Epoch 3] Train loss: 2.2228 | Test loss:

<a style='text-decoration:none;line-height:16px;display:flex;color:#5B5B62;padding:10px;justify-content:end;' href='https://deepnote.com?utm_source=created-in-deepnote-cell&projectId=c8418618-5b01-4dd8-b931-34351753cb66' target="_blank">
 </img>
Created in <span style='font-weight:600;margin-left:4px;'>Deepnote</span></a>