In [None]:
!pip install torchmetrics wandb tqdm albumentations editdistance 1> /dev/null
!unzip '/content/drive/MyDrive/BachelorsWorkspace/raw.zip' -d '/content/raw' 1> /dev/null

In [None]:
from google.colab import drive
import sys
drive.mount('/content/drive')
sys.path.append('/content/drive/MyDrive/BachelorsWorkspace/')

Mounted at /content/drive


In [None]:
import math
from torch.utils.data import DataLoader
from src.IAMDataset import IAMDataset
from src.model import FullPageHTR
# from src.train import ModelTrainer
from src.LabelParser import LabelParser
import torch
from copy import copy

In [None]:
SPLIT_SEED = 42  # fixes the train/val split so runs are comparable

ds = IAMDataset(base_dir="/content/raw/raw", embedding_loader=None, sample_set="train")
# 80/20 split; the two sizes always add up to len(ds).
n_val = math.floor(0.2 * len(ds))
ds_train, ds_val = torch.utils.data.random_split(
    ds,
    [len(ds) - n_val, n_val],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# random_split returns Subset objects that fetch samples via `.dataset`, so the
# copy must be bound there (not to a new `.data` attribute) for the validation
# subset to actually use the "val" transform pipeline instead of the one `ds`
# was built with. ds_train keeps reading from the original `ds`.
# NOTE(review): `copy` is shallow; this assumes set_transform_pipeline rebinds
# attributes on the copy rather than mutating state shared with `ds` — confirm
# in src.IAMDataset.
ds_val.dataset = copy(ds)
ds_val.dataset.set_transform_pipeline("val")
train_len = len(ds_train)
val_len = len(ds_val)

In [None]:
from functools import partial

batch_size = 2
val_batch_size = 2 * batch_size
num_workers = 1

pad_tkn_idx, eos_tkn_idx = ds.embedding_loader.encode_labels(["<PAD>", "<EOS>"])
collate_fn = partial(
    IAMDataset.collate_fn, pad_val=pad_tkn_idx, eos_tkn_idx=eos_tkn_idx
)

# Options shared by both loaders.
loader_kwargs = dict(collate_fn=collate_fn, num_workers=num_workers, pin_memory=True)
dl_train = DataLoader(ds_train, batch_size=batch_size, shuffle=True, **loader_kwargs)
dl_val = DataLoader(ds_val, batch_size=val_batch_size, shuffle=False, **loader_kwargs)

# Batches per epoch. The loaders keep the final partial batch (drop_last=False),
# so this is ceil(len / batch_size); floor division under-counts by one
# whenever the split size is not a multiple of the batch size.
train_len = len(dl_train)
val_len = len(dl_val)

In [None]:
from torch.utils.data import DataLoader
from torch.optim import Optimizer
import wandb
from tqdm import tqdm

import math
from torch.utils.data import DataLoader
from src.IAMDataset import IAMDataset
from src.model import FullPageHTR
import gc
import torch
from copy import copy


class ModelTrainer:
    """Train / validate a FullPageHTR model, optionally logging to Weights & Biases.

    Every per-batch loss, CER and WER value is divided by the batch size
    (``labels.size(0)``) before it is logged or averaged.
    NOTE(review): that division assumes the model returns a summed loss and
    ``calculate_metrics`` returns per-batch sums rather than means — TODO
    confirm against src.model; otherwise reported values are scaled down by
    the batch size.
    """

    def __init__(self, run_name: str,
                 model: FullPageHTR,
                 ds_name: str,
                 train_data: DataLoader,
                 val_data: DataLoader,
                 optimizer: Optimizer,
                 num_epochs: int,
                 device: torch.device,
                 normalization_steps: int):
        """Store the training configuration.

        Args:
            run_name: wandb run name (also recorded in the run config).
            model: model exposing ``forward_teacher_forcing``, ``forward`` and
                ``calculate_metrics``.
            ds_name: dataset name recorded in the wandb config.
            train_data: loader yielding ``(inputs, labels)`` training batches.
            val_data: loader yielding ``(inputs, labels)`` validation batches.
            optimizer: optimizer stepping ``model``'s parameters.
            num_epochs: number of epochs run by ``train``.
            device: device every batch is moved to.
            normalization_steps: number of batches whose gradients are
                accumulated before each optimizer step in ``train_epoch_ga``.

        Raises:
            ValueError: if ``normalization_steps`` is smaller than 1.
        """
        if normalization_steps < 1:
            raise ValueError(
                f"normalization_steps must be >= 1, got {normalization_steps}")
        self.normalization_steps = normalization_steps

        self.model = model
        self.train_data, self.val_data = train_data, val_data
        self.num_epochs = num_epochs
        self.optimizer = optimizer
        self.ds_name = ds_name
        self.run_name = run_name
        self.device = device
        # wandb logging switch; train() overwrites it. Initialised here so the
        # epoch methods can also be called directly without an AttributeError.
        self.wanda = False

    def _init_wandb(self):
        """Start a wandb run and bind each metric to its custom step counter."""
        wandb.init(project="fullpage-htr-base",
                   name=self.run_name,
                   config={
                       "run_name": self.run_name,
                       "learning_rate": self.optimizer.param_groups[0]["lr"],
                       "epochs": self.num_epochs,
                       "dataset": self.ds_name
                   })
        # 'Train' and 'Val' are step counters logged alongside the metrics;
        # define_metric(..., step_metric=...) makes wandb plot against them.
        wandb.define_metric('Train')
        wandb.define_metric('Val')
        for metric in ('Train Loss', 'Train CER', 'Train WER'):
            wandb.define_metric(metric, step_metric='Train')
        for metric in ('Val Loss', 'Val CER', 'Val WER'):
            wandb.define_metric(metric, step_metric='Val')

    def _to_device(self, batch):
        """Unpack an ``(inputs, labels)`` batch and move both to ``self.device``."""
        inputs, labels = batch
        return inputs.to(self.device), labels.to(self.device)

    def _per_sample_metrics(self, logits, labels):
        """Argmax-decode ``logits`` over the last dim and return (CER, WER), each divided by the batch size."""
        _, preds = logits.max(-1)
        res = self.model.calculate_metrics(preds, labels)
        batch_size = labels.size(0)
        return float(res["CER"]) / batch_size, float(res["WER"]) / batch_size

    def train_epoch_ga(self, epoch_nr, ds_size):
        """Run one training epoch with gradient accumulation.

        Gradients of ``normalization_steps`` consecutive batches are summed
        before each optimizer step; a trailing partial group is stepped on the
        last batch. Each step's accumulated loss/CER/WER is logged to wandb.

        Args:
            epoch_nr: index of the current epoch (offsets the wandb step counter).
            ds_size: batches per epoch, used only to offset the wandb step counter.

        Returns:
            ``(loss, cer, wer)`` averaged over the batches of this epoch.
        """
        self.model.train()
        steps = self.normalization_steps
        n_batches = len(self.train_data)
        # Start from clean gradients in case an earlier run stopped mid-group.
        self.optimizer.zero_grad()

        total_loss = total_cer = total_wer = 0.0
        b_loss = b_cer = b_wer = 0.0
        nr_batches = 0
        for idx, batch in enumerate(tqdm(self.train_data)):
            inputs, labels = self._to_device(batch)
            batch_size = labels.size(0)

            outputs, loss = self.model.forward_teacher_forcing(inputs, labels)
            # Scale by the group length so the accumulated gradient is a mean
            # over the group rather than a sum.
            (loss / (steps * batch_size)).backward()

            loss_val = loss.item() / batch_size
            cer, wer = self._per_sample_metrics(outputs, labels)

            b_loss += loss_val / steps
            b_cer += cer / steps
            b_wer += wer / steps
            total_loss += loss_val
            total_cer += cer
            total_wer += wer
            nr_batches += 1

            # idx is 0-based: step after every `steps` batches so each group is
            # exactly `steps` long, and on the last batch so a trailing partial
            # group (or a single-batch loader) is not left unstepped.
            if (idx + 1) % steps == 0 or idx + 1 == n_batches:
                self.optimizer.step()
                self.optimizer.zero_grad()
                if self.wanda:
                    wandb.log({
                        'Train Loss': b_loss,
                        'Train CER': b_cer,
                        'Train WER': b_wer,
                        'Train': idx + ds_size * epoch_nr
                    })
                b_loss = b_cer = b_wer = 0.0

        n = max(nr_batches, 1)
        return total_loss / n, total_cer / n, total_wer / n

    def train_epoch(self, epoch_nr, ds_size):
        """Run one training epoch with an optimizer step after every batch.

        Args:
            epoch_nr: index of the current epoch (offsets the wandb step counter).
            ds_size: batches per epoch, used only to offset the wandb step counter.

        Returns:
            ``(loss, cer, wer)`` averaged over the batches of this epoch.
        """
        self.model.train()
        total_loss = total_cer = total_wer = 0.0
        nr_batches = 0
        for i, mb in enumerate(tqdm(self.train_data)):
            inputs, labels = self._to_device(mb)
            batch_size = labels.size(0)

            output_logits, loss = self.model.forward_teacher_forcing(inputs, labels)
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

            b_loss = loss.item() / batch_size
            b_cer, b_wer = self._per_sample_metrics(output_logits, labels)
            if self.wanda:
                wandb.log({
                    'Train Loss': b_loss,
                    'Train CER': b_cer,
                    'Train WER': b_wer,
                    'Train': i + ds_size * epoch_nr
                })

            total_loss += b_loss
            total_cer += b_cer
            total_wer += b_wer
            nr_batches += 1

        n = max(nr_batches, 1)
        return total_loss / n, total_cer / n, total_wer / n

    def val_epoch(self, epoch_nr, ds_size):
        """Evaluate on ``val_data`` with ``model.forward`` under ``torch.no_grad``.

        Args:
            epoch_nr: index of the current epoch (offsets the wandb step counter).
            ds_size: batches per epoch, used only to offset the wandb step counter.

        Returns:
            ``(loss, cer, wer)`` averaged over the validation batches.
        """
        self.model.eval()
        total_loss = total_cer = total_wer = 0.0
        nr_batches = 0
        with torch.no_grad():
            for i, mb in enumerate(tqdm(self.val_data)):
                inputs, labels = self._to_device(mb)
                batch_size = labels.size(0)

                # forward() returns three values; the middle one is unused here.
                output_logits, _, loss = self.model.forward(inputs, labels)

                b_loss = loss.item() / batch_size
                b_cer, b_wer = self._per_sample_metrics(output_logits, labels)
                if self.wanda:
                    wandb.log({
                        'Val Loss': b_loss,
                        'Val CER': b_cer,
                        'Val WER': b_wer,
                        'Val': i + ds_size * epoch_nr})

                total_loss += b_loss
                total_cer += b_cer
                total_wer += b_wer
                nr_batches += 1

                # Release cached GPU blocks between batches to limit memory pressure.
                torch.cuda.empty_cache()

        n = max(nr_batches, 1)
        return total_loss / n, total_cer / n, total_wer / n

    def train(self, train_len, val_len, wanda=True):
        """Run ``num_epochs`` epochs of gradient-accumulated training plus validation.

        Args:
            train_len: training batches per epoch (offsets the wandb step counter).
            val_len: validation batches per epoch (offsets the wandb step counter).
            wanda: log to Weights & Biases when True.
        """
        self.wanda = wanda
        if wanda:
            self._init_wandb()
        try:
            for i in range(self.num_epochs):
                print(f'#.Epoch {i}')
                torch.cuda.empty_cache()
                train_loss, train_cer, train_wer = self.train_epoch_ga(i, train_len)
                val_loss, val_cer, val_wer = self.val_epoch(i, val_len)
                print(f"Train Loss avg: {train_loss}, Train CER avg: {train_cer}, Train WER avg: {train_wer}")
                print(f"Val Loss avg: {val_loss}, Val CER avg: {val_cer}, Val WER avg: {val_wer}")
        finally:
            # Close the run even when training raises or is interrupted.
            if wanda:
                wandb.finish()

In [None]:
import gc
import traceback

LEARNING_RATE = 2e-4
NUM_EPOCHS = 100
GRAD_ACCUM_STEPS = 56  # batches accumulated per optimizer step (normalization_steps)

# Pre-bind the names so the cleanup below never hits a NameError when
# construction itself fails.
model = optimizer = trainer = None
try:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.cuda.empty_cache()
    gc.collect()

    model = FullPageHTR(ds.embedding_loader).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    trainer = ModelTrainer("Testing_run", model, ds_name="IAM_forms", train_data=dl_train,
                           val_data=dl_val, optimizer=optimizer, num_epochs=NUM_EPOCHS,
                           device=device, normalization_steps=GRAD_ACCUM_STEPS)

    # Close any wandb run left open by a previous execution of this cell.
    wandb.finish()
    trainer.train(train_len, val_len, wanda=True)
except RuntimeError:
    # e.g. CUDA out-of-memory: show what failed, then drop every reference
    # holding GPU tensors so the cell can be re-run in the same kernel.
    traceback.print_exc()
    print("Error time!!")
    wandb.finish()
    del model, optimizer, trainer
    gc.collect()
    torch.cuda.empty_cache()



VBox(children=(Label(value='0.002 MB of 0.002 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Train,▁▂▂▃▄▄▅▅▆▇▇█
Train CER,▁█▇▇█▇▇▇███▇
Train Loss,▁██▇▇▇▇▇▇▇▇▇
Train WER,▁▇██▇▇▇█▇▇▇▇
Val,▁▂▃▄▅▆▇█
Val CER,▃▂▅█▁▁▇▆
Val Loss,▄▄▃█▁█▇▂
Val WER,▁▁▁▁▁▁▁▁

0,1
Train,615.0
Train CER,0.63952
Train Loss,1.71414
Train WER,0.49536
Val,7.0
Val CER,0.39235
Val Loss,0.85373
Val WER,0.25


#.Epoch 0


  self.pid = os.fork()
100%|██████████| 616/616 [01:38<00:00,  6.25it/s]
100%|██████████| 77/77 [05:25<00:00,  4.22s/it]


Train Loss avg: 0.033642098963768644, Train CER avg: 0.012608052231371403, Train WER avg: 0.009617536328732967
Val Loss avg: 0.8808692233604297, Val CER avg: 0.39893606305122375, Val WER avg: 0.25438597798347473
#.Epoch 1


100%|██████████| 616/616 [01:41<00:00,  6.05it/s]
100%|██████████| 77/77 [05:26<00:00,  4.24s/it]


Train Loss avg: 0.031086089246749105, Train CER avg: 0.0122104836627841, Train WER avg: 0.008943618275225163
Val Loss avg: 0.875373188602297, Val CER avg: 0.3911570608615875, Val WER avg: 0.27184948325157166
#.Epoch 2


100%|██████████| 616/616 [01:41<00:00,  6.06it/s]
100%|██████████| 77/77 [05:27<00:00,  4.25s/it]


Train Loss avg: 0.03023223981696677, Train CER avg: 0.010789155960083008, Train WER avg: 0.011885594576597214
Val Loss avg: 0.9597880469079604, Val CER avg: 0.35886499285697937, Val WER avg: 0.5136774778366089
#.Epoch 3


100%|██████████| 616/616 [01:42<00:00,  6.01it/s]
100%|██████████| 77/77 [05:27<00:00,  4.25s/it]


Train Loss avg: 0.028843201219116325, Train CER avg: 0.00980830006301403, Train WER avg: 0.013765827752649784
Val Loss avg: 1.0286805893768345, Val CER avg: 0.35996827483177185, Val WER avg: 0.4800395965576172
#.Epoch 4


100%|██████████| 616/616 [01:39<00:00,  6.22it/s]
100%|██████████| 77/77 [05:24<00:00,  4.21s/it]


Train Loss avg: 0.028094056966508945, Train CER avg: 0.009781633503735065, Train WER avg: 0.012882853858172894
Val Loss avg: 1.045832660114556, Val CER avg: 0.3601241111755371, Val WER avg: 0.4801212251186371
#.Epoch 5


100%|██████████| 616/616 [01:39<00:00,  6.20it/s]
100%|██████████| 77/77 [05:25<00:00,  4.23s/it]


Train Loss avg: 0.02768098936089641, Train CER avg: 0.009749142453074455, Train WER avg: 0.01237027533352375
Val Loss avg: 1.0655400617080824, Val CER avg: 0.36027786135673523, Val WER avg: 0.47713711857795715
#.Epoch 6


100%|██████████| 616/616 [01:38<00:00,  6.26it/s]
100%|██████████| 77/77 [05:24<00:00,  4.21s/it]


Train Loss avg: 0.027443700763336444, Train CER avg: 0.009749077260494232, Train WER avg: 0.01239317748695612
Val Loss avg: 1.0675572083707443, Val CER avg: 0.360395222902298, Val WER avg: 0.47757330536842346
#.Epoch 7


100%|██████████| 616/616 [01:38<00:00,  6.23it/s]
100%|██████████| 77/77 [05:28<00:00,  4.26s/it]


Train Loss avg: 0.02727448778505159, Train CER avg: 0.009760663844645023, Train WER avg: 0.012286421842873096
Val Loss avg: 1.0638550300347178, Val CER avg: 0.36037254333496094, Val WER avg: 0.47591090202331543
#.Epoch 8


100%|██████████| 616/616 [01:42<00:00,  6.01it/s]
100%|██████████| 77/77 [05:25<00:00,  4.23s/it]


Train Loss avg: 0.027133453455935052, Train CER avg: 0.00984896905720234, Train WER avg: 0.012080412358045578
Val Loss avg: 1.0591594169014378, Val CER avg: 0.3603745400905609, Val WER avg: 0.47587358951568604
#.Epoch 9


100%|██████████| 616/616 [01:39<00:00,  6.22it/s]
100%|██████████| 77/77 [05:23<00:00,  4.20s/it]


Train Loss avg: 0.02702066162903491, Train CER avg: 0.009783432818949223, Train WER avg: 0.012083125300705433
Val Loss avg: 1.0648263802653866, Val CER avg: 0.3603745400905609, Val WER avg: 0.47587358951568604
#.Epoch 10


100%|██████████| 616/616 [01:40<00:00,  6.13it/s]
100%|██████████| 77/77 [05:25<00:00,  4.23s/it]


Train Loss avg: 0.026938817618027716, Train CER avg: 0.0098212119191885, Train WER avg: 0.012039403431117535
Val Loss avg: 1.0712628610301436, Val CER avg: 0.3603745400905609, Val WER avg: 0.47587358951568604
#.Epoch 11


100%|██████████| 616/616 [01:46<00:00,  5.78it/s]
100%|██████████| 77/77 [05:26<00:00,  4.24s/it]


Train Loss avg: 0.02686247157912653, Train CER avg: 0.009833108633756638, Train WER avg: 0.01209350023418665
Val Loss avg: 1.0672041622170232, Val CER avg: 0.3603745400905609, Val WER avg: 0.47587358951568604
#.Epoch 12


100%|██████████| 616/616 [01:41<00:00,  6.06it/s]
100%|██████████| 77/77 [05:25<00:00,  4.23s/it]


Train Loss avg: 0.02679716489676918, Train CER avg: 0.009842258878052235, Train WER avg: 0.012180074118077755
Val Loss avg: 1.0692336789348669, Val CER avg: 0.3603745400905609, Val WER avg: 0.47587358951568604
#.Epoch 13


100%|██████████| 616/616 [01:39<00:00,  6.19it/s]
100%|██████████| 77/77 [05:25<00:00,  4.23s/it]


Train Loss avg: 0.026745204816802176, Train CER avg: 0.009807669557631016, Train WER avg: 0.01218715775758028
Val Loss avg: 1.0738438226674731, Val CER avg: 0.3603745400905609, Val WER avg: 0.47587358951568604
#.Epoch 14


100%|██████████| 616/616 [01:40<00:00,  6.12it/s]
 88%|████████▊ | 68/77 [04:46<00:37,  4.17s/it]