In [1]:
import warnings
warnings.filterwarnings('ignore')

import torch
import torchvision
import torch.nn as nn
from torchvision import transforms as T
from torch.utils.data import Dataset, DataLoader

import cv2
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from tqdm.autonotebook import tqdm
from glob import glob

from sklearn.model_selection import train_test_split

!pip install -qq editdistance torchsummary
import editdistance
from torchsummary import summary

# Shared random seed: used to seed torch/numpy and as random_state of the train/validation split.
seed = 42

#### Another, more classic approach: a CNN-LSTM model with an OCR loss

In [2]:
# Seed the torch (CPU and CUDA) and numpy random generators from the shared seed.
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)

# Prefix of the image files; '*.png' is appended to it when the files are listed.
PATH = r'../input/handwritten/synthetic-data/'
# Optimisation settings of the Transformer run.
LR = 0.001
BATCH_SIZE = 32
# Transformer size: model width, encoder/decoder depth, attention heads, dropout.
HIDDEN = 512
ENC_LAYERS = 2
DEC_LAYERS = 2
N_HEADS = 4
DROPOUT = 0.1
# Every image is resized to IMG_HEIGHT x IMG_WIDTH by the transform pipelines.
IMG_WIDTH = 256
IMG_HEIGHT = 64
EPOCHS = 100

# Output vocabulary. 'PAD' has to stay at index 0: label batches are padded with
# zeros and the cross-entropy loss is created with ignore_index=0.
VOCAB = ['PAD', 'SOS', ' ',] + [char for char in 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'] + ['EOS']

# Lookup tables between vocabulary tokens and their integer indices.
char2idx = {char: idx for idx, char in enumerate(VOCAB)}
idx2char = {idx: char for idx, char in enumerate(VOCAB)}

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Currenly using "{device.upper()}" device.')

In [3]:
def text_to_labels(text, char2idx=char2idx):
    """Encode ``text`` as a list of vocabulary indices framed by SOS and EOS.

    Every character is upper-cased before the lookup; characters whose
    upper-case form is not a key of ``char2idx`` are silently dropped.
    """
    encoded = [char2idx['SOS']]
    for symbol in text:
        key = symbol.upper()
        if key in char2idx:
            encoded.append(char2idx[key])
    encoded.append(char2idx['EOS'])
    return encoded

def labels_to_text(text, idx2char=idx2char):
    """Decode a sequence of vocabulary indices into a string, stopping at EOS.

    Args:
        text: iterable of vocabulary indices (list or 1-D integer array).
        idx2char: mapping from index to vocabulary token.

    Returns:
        The concatenated tokens that precede the first EOS token, or all of
        them when no EOS token is present. Other multi-character tokens
        ('PAD', 'SOS') are rendered literally, as before.
    """
    pieces = []
    for index in text:
        token = idx2char[index]
        # Stop on the EOS *token* itself. Searching the joined string for the
        # substring 'EOS' would also fire on the letters E, O, S appearing in a
        # row inside a word (e.g. "VIDEOS") and truncate the text too early.
        if token == 'EOS':
            break
        pieces.append(token)
    return "".join(pieces)

def char_error_rate(p_seq1, p_seq2):
    """Return the edit distance between two sequences, normalised by the longer one.

    Args:
        p_seq1, p_seq2: sequences of hashable symbols (strings or lists).

    Returns:
        ``editdistance / max(len(p_seq1), len(p_seq2))`` as a float, and ``0.0``
        when both sequences are empty (identical inputs, nothing to divide by).
    """
    longest = max(len(p_seq1), len(p_seq2))
    if longest == 0:
        # Two empty sequences are identical; the original expression divided by zero here.
        return 0.0

    # Map every distinct symbol to a single character so that arbitrary
    # hashable symbols can be compared as plain strings.
    p_vocab = set(p_seq1) | set(p_seq2)
    p2c = dict(zip(p_vocab, range(len(p_vocab))))
    c_seq1 = ''.join(chr(p2c[p]) for p in p_seq1)
    c_seq2 = ''.join(chr(p2c[p]) for p in p_seq2)
    return editdistance.eval(c_seq1, c_seq2) / longest

In [4]:
# Training pipeline: resize to the fixed input size, then apply random photometric
# and geometric augmentation before converting to a tensor.
transforms = T.Compose([
                        T.ToPILImage(),
                        T.Resize((IMG_HEIGHT, IMG_WIDTH)),
                        T.ColorJitter(contrast=(0.5,1),saturation=(0.5,1)),
                        # fill=255 pads the corners uncovered by the rotation with white
                        T.RandomRotation(degrees=(-9, 9), fill=255),
                        # positional arguments: degrees, translate, scale range, shear
                        # NOTE(review): `fillcolor` is the older spelling of this argument; newer
                        # torchvision releases expect `fill=` - confirm against the installed version.
                        T.RandomAffine(10 ,None ,[0.6 ,1] ,3 ,fillcolor=255),
                        T.ToTensor()
                        ])
# Validation / inference pipeline: the same resize, without any random augmentation.
valid_transforms = T.Compose([
                              T.ToPILImage(),
                              T.Resize((IMG_HEIGHT, IMG_WIDTH)),
                              T.ToTensor()
                              ])

In [5]:
# Every PNG under PATH, in a deterministic (sorted) order.
images_paths = sorted(str(found) for found in glob(PATH + '*.png'))
# The label is the part of the file name that precedes the first underscore.
images_labels = [found.split('/')[-1].split('_')[0] for found in images_paths]
df = pd.DataFrame({'path': images_paths, 'label': images_labels})
df.sample(3)

In [6]:
# Hold out 25% of the samples for validation; the split is shuffled but reproducible via `seed`.
df_train, df_valid = train_test_split(df, test_size=0.25, shuffle=True, random_state=seed)
print(f'Train size: {df_train.shape[0]}, validation size: {df_valid.shape[0]}')

In [7]:
class CharDataset(Dataset):
    """Image/label dataset for the Transformer recogniser.

    Each row of ``dataframe`` must provide a ``'path'`` (image file) and a
    ``'label'`` (text). Items are returned as ``(image_tensor, label_indices)``.

    Args:
        dataframe: pandas DataFrame with ``path`` and ``label`` columns.
        transforms: callable applied to the RGB image array; it is expected to
            return a tensor (both pipelines in this notebook end with ToTensor).
    """

    def __init__(self, dataframe, transforms=transforms):
        self.dataframe = dataframe
        self.transforms = transforms

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx].squeeze()
        image = cv2.imread(row['path'])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Use the pipeline given to *this* dataset. Calling the module-level
        # `transforms` here applied the training augmentation to every dataset,
        # including the one built with `valid_transforms`.
        image = self.transforms(image)
        label = text_to_labels(row['label'])
        return torch.FloatTensor(image), torch.LongTensor(label)

    def collate_fn(self, batch):
        """Stack images and pad labels into one batch, moved to ``device``.

        Returns:
            ``(images, labels)`` where ``images`` has the batch as first axis
            and ``labels`` is a ``(max_label_len, batch)`` LongTensor whose
            unused positions hold 0 (the 'PAD' index).
        """
        images = torch.stack([image for image, _ in batch])
        max_y_len = max(label.size(0) for _, label in batch)
        y_padded = torch.zeros(max_y_len, len(batch), dtype=torch.long)
        for column, (_, label) in enumerate(batch):
            y_padded[:label.size(0), column] = label
        return images.to(device), y_padded.to(device)

In [8]:
class TransformerModel(nn.Module):
    """CNN backbone followed by an ``nn.Transformer`` encoder/decoder.

    The torchvision model named ``bb_name`` is used as feature extractor; its
    ``fc`` attribute is replaced by a 1x1 convolution from 2048 to
    ``hidden // 2`` channels. Every column of the resulting feature map becomes
    one source token, and the decoder produces vocabulary logits for each
    target position.

    Args:
        bb_name: attribute name of a constructor in ``torchvision.models``.
        outtoken: size of the output vocabulary.
        hidden: Transformer model width (``d_model``).
        enc_layers / dec_layers: number of encoder / decoder layers.
        nhead: number of attention heads.
        dropout: dropout used by the Transformer and positional encodings.
        pretrained: forwarded to the torchvision constructor.
    """

    def __init__(self, bb_name, outtoken, hidden, enc_layers=1, dec_layers=1, nhead=1, dropout=0.1, pretrained=False):
        super(TransformerModel, self).__init__()
        self.backbone = getattr(torchvision.models, bb_name)(pretrained=pretrained)
        # 2048 input channels are hard-coded; per the shape notes in forward()
        # this matches the layer4 output of the backbone used here.
        self.backbone.fc = nn.Conv2d(2048, hidden // 2, 1)

        self.pos_encoder = PositionalEncoding(hidden, dropout)
        self.decoder = nn.Embedding(outtoken, hidden)
        self.pos_decoder = PositionalEncoding(hidden, dropout)
        self.transformer = nn.Transformer(d_model=hidden, nhead=nhead, num_encoder_layers=enc_layers,
                                          num_decoder_layers=dec_layers, dim_feedforward=hidden * 4, dropout=dropout,
                                          activation='relu')

        self.fc_out = nn.Linear(hidden, outtoken)
        self.src_mask = None
        self.trg_mask = None
        self.memory_mask = None

    def generate_square_subsequent_mask(self, sz):
        """Return an ``sz x sz`` causal mask: ``-inf`` above the diagonal, 0 elsewhere."""
        mask = torch.triu(torch.ones(sz, sz), 1)
        mask = mask.masked_fill(mask == 1, float('-inf'))
        return mask

    def make_len_mask(self, inp):
        """Return a ``(batch, length)`` boolean mask that is True where ``inp`` equals 0."""
        return (inp == 0).transpose(0, 1)

    def forward(self, src, trg):
        """Compute logits of shape ``(target_len, batch, outtoken)``.

        Args:
            src: image batch ``(B, C, H, W)``.
            trg: target indices ``(target_len, B)``; index 0 is treated as padding.
        """
        # Rebuild the cached causal mask when the target length or device changes.
        # The mask follows the input's device instead of a module-level global,
        # so the model also works after being moved to another device.
        if (self.trg_mask is None or self.trg_mask.size(0) != len(trg)
                or self.trg_mask.device != trg.device):
            self.trg_mask = self.generate_square_subsequent_mask(len(trg)).to(trg.device)

        # Shape notes below are from the original author (64x256 input, hidden=512).
        x = self.backbone.conv1(src)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)
        x = self.backbone.layer1(x)
        x = self.backbone.layer2(x)
        x = self.backbone.layer3(x)
        x = self.backbone.layer4(x) # [B, 2048, 2, 8] : [B,C,H,W]

        x = self.backbone.fc(x) # [B, 256, 2, 8] : [B,C,H,W]
        x = x.permute(0, 3, 1, 2) # [B, 8, 256, 2] : [B,W,C,H]
        x = x.flatten(2) # [B, 8, 512] : [B,W,CH]
        x = x.permute(1, 0, 2) # [8, B, 512] : [W,B,CH]
        # A source position counts as padding when its first feature is exactly 0.
        src_pad_mask = self.make_len_mask(x[:, :, 0])
        src = self.pos_encoder(x) # [8, B, 512]
        trg_pad_mask = self.make_len_mask(trg)
        trg = self.decoder(trg)
        trg = self.pos_decoder(trg)

        output = self.transformer(src, trg, src_mask=self.src_mask, tgt_mask=self.trg_mask,
                                  memory_mask=self.memory_mask,
                                  src_key_padding_mask=src_pad_mask, tgt_key_padding_mask=trg_pad_mask,
                                  memory_key_padding_mask=src_pad_mask) # [L, B, hidden]
        output = self.fc_out(output) # [L, B, outtoken]

        return output
    
class PositionalEncoding(nn.Module):  # when having sentences
    """Add a scaled sinusoidal position signal to a sequence, then apply dropout.

    The table is precomputed for ``max_len`` positions and stored as the
    non-trainable buffer ``pe`` of shape ``(max_len, 1, d_model)``; ``scale``
    is a single learnable multiplier applied to it.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.scale = nn.Parameter(torch.ones(1))

        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        frequencies = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        angles = positions * frequencies

        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)  # even feature indices
        table[:, 1::2] = torch.cos(angles)  # odd feature indices
        # Insert a singleton batch axis so the table broadcasts over (len, batch, d_model).
        self.register_buffer('pe', table.unsqueeze(1))

    def forward(self, x):
        """Return ``dropout(x + scale * pe[:len(x)])`` for ``x`` with the sequence on axis 0."""
        positional = self.scale * self.pe[:x.size(0), :]
        return self.dropout(x + positional)

In [10]:
def train_one_batch(model, data, optimizer, criterion):
    """Run one teacher-forced optimisation step and return the loss as a float.

    ``data`` is an ``(image, target)`` pair with ``target`` shaped
    ``(length, batch)``. The model is fed every target step but the last and is
    scored against every step but the first.
    """
    model.train()
    image, target = data
    decoder_input = target[:-1, :]
    expected = target[1:, :]

    optimizer.zero_grad()
    logits = model(image, decoder_input)
    flat_logits = logits.view(-1, logits.shape[-1])
    loss = criterion(flat_logits, expected.reshape(-1))
    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()
def validate_one_batch(model, data, criterion):
    """Return the teacher-forced loss of one batch as a float, without gradients.

    The model is switched to eval mode; inputs and targets are shifted by one
    step exactly as in training.
    """
    model.eval()
    image, target = data
    decoder_input = target[:-1, :]
    expected = target[1:, :]

    logits = model(image, decoder_input)
    flat_logits = logits.view(-1, logits.shape[-1])
    return criterion(flat_logits, expected.reshape(-1)).item()

def evaluate(model, dataloader, max_len=30):  # assuming dataloader has batch_size=1
    """Greedy-decode every sample and return ``(CER %, WER %)`` over the loader.

    Args:
        model: callable ``model(src, trg)`` returning ``(length, batch, vocab)`` logits.
        dataloader: yields ``(src, trg)`` batches of size 1, with ``trg`` shaped
            ``(length, 1)`` and starting with the SOS index.
        max_len: maximum number of decoding steps per sample.

    Returns:
        Tuple ``(char_error_rate, word_error_rate)``, both in percent and
        averaged over ``len(dataloader)``.
    """
    model.eval()
    wer_overall = 0
    cer_overall = 0
    with torch.no_grad():
        for src, trg in tqdm(dataloader, leave=False):
            out_indexes = [char2idx['SOS'], ]

            # Greedy autoregressive decoding: feed the tokens produced so far
            # and keep the most likely next token until EOS or max_len.
            for _ in range(max_len):
                trg_tensor = torch.LongTensor(out_indexes).unsqueeze(1).to(src.device)
                output = model(src, trg_tensor)
                out_token = output.argmax(2)[-1].item()
                out_indexes.append(out_token)
                if out_token == char2idx['EOS']:
                    break

            out_char = labels_to_text(out_indexes[1:])
            # Keep the reference in the vocabulary's own (upper) case. Lower-casing
            # only the reference made every prediction containing a letter count
            # as wrong, because decoded predictions are upper-case.
            real_char = labels_to_text(trg[1:, 0].detach().cpu().numpy())
            wer_overall += int(real_char != out_char)
            if out_char:
                cer = char_error_rate(real_char, out_char)
            elif real_char:
                cer = 1  # nothing predicted for a non-empty reference
            else:
                cer = 0  # both empty: a perfect match
            cer_overall += cer

    return cer_overall / len(dataloader) * 100, wer_overall / len(dataloader) * 100

@torch.no_grad()
def prediction(model, filepath='random', max_len=30):
    """Greedy-decode one image, plot it with the result and return the text.

    With the default ``filepath='random'`` a random row of ``df_valid`` is used
    and its label is shown as ground truth; otherwise the given image file is
    read and the title reports that no label is available.
    """
    label = None
    if filepath == 'random':
        row = np.random.randint(len(df_valid))
        filepath, label = df_valid.iloc[row, 0], df_valid.iloc[row, 1]

    model.eval()
    img = cv2.cvtColor(cv2.imread(filepath), cv2.COLOR_BGR2RGB)
    src = torch.FloatTensor(valid_transforms(img)).unsqueeze(0).to(device)

    # Feed the tokens produced so far and keep the arg-max of the last step.
    eos_index = char2idx['EOS']
    out_indexes = [char2idx['SOS']]
    for _ in range(max_len):
        trg_tensor = torch.LongTensor(out_indexes).unsqueeze(1).to(device)
        next_token = model(src, trg_tensor).argmax(2)[-1].item()
        out_indexes.append(next_token)
        if next_token == eos_index:
            break

    preds = labels_to_text(out_indexes[1:], idx2char)

    plt.figure(figsize=(6,4))
    plt.title(f'Prediction: {preds}, Truth: {label if label is not None else "NO label"}')
    plt.imshow(img)
    plt.tight_layout()
    plt.show()
    plt.pause(0.001)

    return preds

In [11]:
# Training data is built with the augmenting pipeline, validation data with the plain one.
train_dataset = CharDataset(df_train, transforms)
valid_dataset = CharDataset(df_valid, valid_transforms)

# Validation uses batch_size=1, which is what evaluate() assumes.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=train_dataset.collate_fn, drop_last=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=1, shuffle=False, collate_fn=valid_dataset.collate_fn, drop_last=True)

In [12]:
# ResNet-50 backbone + Transformer, one output logit per VOCAB entry.
model = TransformerModel('resnet50', len(VOCAB), hidden=HIDDEN, enc_layers=ENC_LAYERS, dec_layers=DEC_LAYERS,   
                         nhead=N_HEADS, dropout=DROPOUT).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-6)
# Index 0 is 'PAD', so padded target positions do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=0)
# Halve the learning rate after 10 epochs without a lower validation loss, down to 1e-6.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, mode='min', patience=10, min_lr=1e-6,)

In [13]:
# Per-epoch loss history, plus error rates recorded on every tenth epoch.
train_losses, valid_losses, valid_cers, valid_wers = [], [], [], []

for epoch in range(EPOCHS):
    print(f'{epoch+1}/{EPOCHS} epoch.')
    epoch_train_losses, epoch_valid_losses = [], []
    is_tenth_epoch = (epoch + 1) % 10 == 0

    # Optimisation pass over the training set.
    for batch in tqdm(train_dataloader, leave=False):
        loss = train_one_batch(model, batch, optimizer, criterion)
        epoch_train_losses.append(loss)
    train_epoch_loss = np.mean(epoch_train_losses)
    train_losses.append(train_epoch_loss)

    # Teacher-forced loss over the validation set.
    for batch in tqdm(valid_dataloader, leave=False):
        loss = validate_one_batch(model, batch, criterion)
        epoch_valid_losses.append(loss)
    valid_epoch_loss = np.mean(epoch_valid_losses)
    valid_losses.append(valid_epoch_loss)
    print(f'Train loss: {train_epoch_loss:.4f}, validation loss: {valid_epoch_loss:.4f}')

    # Greedy decoding is slow, so error rates are only measured every tenth epoch.
    if is_tenth_epoch:
        valid_cer, valid_wer = evaluate(model, valid_dataloader)
        valid_cers.append(valid_cer)
        valid_wers.append(valid_wer)
        print(f'Char_error_rate: {valid_cer:.4f}, Word_error_rate: {valid_wer:.4f}')

    scheduler.step(valid_epoch_loss)

    # Show one random prediction and checkpoint the weights.
    if is_tenth_epoch:
        pred = prediction(model)
        torch.save(model.state_dict(), 'model.pth')

In [None]:
# Plot and decode one random validation image with the trained model.
prediction(model)

### LSTM version. 
#### The Transformer model seems too complex for such a task.

In [None]:
def fname2label(name):
    """Return the label encoded in an image file name.

    The label is the part of the file's base name that precedes the first
    underscore, e.g. ``'some/dir/hello_12.png' -> 'hello'``. The directory part
    is removed *before* splitting on ``'_'`` so that underscores in directory
    names cannot leak into (or truncate) the label.

    Args:
        name: path-like object or string using ``'/'`` as separator.
    """
    return str(name).split('/')[-1].split('_')[0]
# All PNG files under PATH.
# NOTE(review): the visible cells below list the files again themselves and later rebind `images`.
images = glob(PATH+'*.png')

In [None]:
# Characters the CTC model can emit; OCRDataset.str2vec drops label characters outside this set.
vocab = 'QWERTYUIOPASDFGHJKLZXCVBNMqwertyuiopasdfghjklzxcvbnm'
# NOTE(review): this rebinds VOCAB - the token list used by the Transformer cells above - to an
# int, so `len(VOCAB)` in those cells fails once this cell has run. Consider a distinct name.
BATCH, TIMESTEP, VOCAB = 64, 32, len(vocab)
# Image height and width; used as the default `preprocess_shape` of OCRDataset.
H, W = 32, 128

In [None]:
class OCRDataset(Dataset):
    """Word-image dataset for CTC training, with decoding and metric helpers.

    Items are ``(grayscale_image, label)`` pairs where the label is parsed from
    the file name by ``fname2label``. Index 0 of the character table is the CTC
    blank (rendered as a backtick); real characters start at index 1.

    Args:
        items: sequence of image file paths.
        vocab: string of characters the model can emit.
        preprocess_shape: kept for interface compatibility; currently unused
            (``preprocess`` is called with its own default shape).
        timesteps: number of time steps of the model output; also the length
            label vectors are padded to.
    """

    def __init__(self, items, vocab=vocab, preprocess_shape=(H,W), timesteps=TIMESTEP):
        super().__init__()
        self.items = items
        self.charList = {ix+1:ch for ix,ch in enumerate(vocab)}
        self.charList.update({0: '`'})
        self.invCharList = {v:k for k,v in self.charList.items()}
        self.ts = timesteps

    def __len__(self):
        return len(self.items)

    def sample(self):
        """Return one randomly chosen item."""
        return self[np.random.randint(len(self))]

    def __getitem__(self, ix):
        item = self.items[ix]
        image = cv2.imread(str(item), 0)
        label = fname2label(item)
        return image, label

    def collate_fn(self, batch):
        """Build the tensors needed by the model and by the CTC loss.

        Returns:
            ``(images, label_vectors, label_lengths, input_lengths, labels)``:
            images as ``(B, 1, H, W)`` floats, zero-padded label indices,
            the number of real target symbols per sample, the (constant) number
            of model time steps per sample, and the raw label strings.
        """
        images, labels, label_lengths, label_vectors, input_lengths = [], [], [], [], []
        for image, label in batch:
            vector = self.str2vec(label, pad=False)
            images.append(torch.Tensor(self.preprocess(image))[None,None])
            # The CTC target length must count the symbols actually present in
            # the vector. len(label) also counted characters that str2vec drops,
            # which made the loss read padding zeros (the blank) as targets.
            label_lengths.append(len(vector))
            labels.append(label)
            label_vectors.append(vector + [0] * (self.ts - len(vector)))
            input_lengths.append(self.ts)
        images = torch.cat(images).float().to(device)
        label_lengths = torch.Tensor(label_lengths).long().to(device)
        label_vectors = torch.Tensor(label_vectors).long().to(device)
        input_lengths = torch.Tensor(input_lengths).long().to(device)
        return images, label_vectors, label_lengths, input_lengths, labels

    def str2vec(self, string, pad=True):
        """Map a string to character indices, dropping unknown characters.

        With ``pad=True`` the result is right-padded with 0 up to ``self.ts``
        entries (longer results are returned unchanged).
        """
        val = [self.invCharList[s] for s in string if s in self.invCharList]
        if pad:
            val.extend([0] * (self.ts - len(val)))
        return val

    def preprocess(self, img, shape=(32,128)):
        """Fit ``img`` into the top-left corner of a white canvas and invert it.

        The image is scaled by a single factor so that it fits inside ``shape``.
        The result is a float array in [0, 1] where the background is 0. An
        unreadable (``None``), empty or non-2-D image yields an all-zero canvas,
        as before, but unrelated errors are no longer silently swallowed.
        """
        target = np.ones(shape)*255
        if img is None or getattr(img, 'ndim', None) != 2 or 0 in img.shape:
            return (255-target)/255
        H, W = shape
        h, w = img.shape
        f = min(H/h, W/w)
        # Clamp to at least one pixel so extreme aspect ratios still resize.
        _h = min(H, max(1, int(h*f)))
        _w = min(W, max(1, int(w*f)))
        target[:_h,:_w] = cv2.resize(img, (_w,_h))
        return (255-target)/255 # add augmentations?

    def decoder_chars(self, pred):
        """Greedy CTC decoding of one ``(timesteps, classes)`` score matrix.

        Takes the arg-max class per step, collapses consecutive repeats and
        removes blanks (class 0); a blank between two equal characters keeps
        both of them.
        """
        decoded = ""
        last = ""
        pred = pred.cpu().detach().numpy()
        for step in pred:
            k = int(np.argmax(step))
            if k == 0:
                last = ""  # a blank ends the current run of repeats
                continue
            char = self.charList[k]
            if char != last:
                decoded = decoded + char
                last = char
        return decoded.replace(" "," ")

    def wer(self, preds, labels):
        """Fraction of predictions that differ from their label (case-insensitive, stripped)."""
        c = 0
        for p, l in zip(preds, labels):
            c += p.lower().strip() != l.lower().strip()
        return round(c/len(preds), 4)

    def cer(self, preds, labels):
        """Mean edit distance between prediction and label, normalised by label length."""
        c = []
        for p, l in zip(preds, labels):
            # max(..., 1) avoids a division by zero for an empty label.
            c.append(editdistance.eval(p, l) / max(len(l), 1))
        return round(np.mean(c), 4)

    def evaluate(self, model, ims, labels, lower=False):
        """Decode ``ims`` with ``model`` and return error rates and accuracies.

        The model is switched to eval mode and run without gradient tracking.
        ``lower`` is accepted for interface compatibility and is not used.
        """
        model.eval()
        with torch.no_grad():
            preds = model(ims).permute(1,0,2)
        preds = [self.decoder_chars(pred) for pred in preds]
        # Compute each metric once and derive the accuracies from it.
        cer = self.cer(preds, labels)
        wer = self.wer(preds, labels)
        return {'char-error-rate': cer,
                'word-error-rate': wer,
                'char-accuracy' : 1 - cer,
                'word-accuracy' : 1 - wer}

In [None]:
# 80/20 split of the image paths for the CTC model (fixed random_state=1).
train_items, valid_items = train_test_split(glob(PATH+'*.png'), test_size=0.2, random_state=1)
# NOTE: these names rebind the datasets/dataloaders created for the Transformer run.
train_dataset = OCRDataset(train_items)
valid_dataset = OCRDataset(valid_items)
# drop_last=True on both loaders, so every batch holds exactly BATCH samples.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH, collate_fn=train_dataset.collate_fn, drop_last=True, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=BATCH, collate_fn=valid_dataset.collate_fn, drop_last=True)

In [None]:
class BasicBlock(nn.Module):
    """Convolution stage: Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d -> Dropout2d.

    ``ni``/``no`` are the input/output channel counts, ``ks``/``st``/``padding``
    configure the convolution, ``pool`` the max-pooling window and ``drop`` the
    channel dropout probability.
    """

    def __init__(self, ni, no, ks=3, st=1, padding=1, pool=2, drop=0.2):
        super().__init__()
        self.ks = ks
        stages = [
            nn.Conv2d(ni, no, kernel_size=ks, stride=st, padding=padding),
            nn.BatchNorm2d(no, momentum=0.3),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(pool),
            nn.Dropout2d(drop),
        ]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        features = self.block(x)
        return features
    
class Permute(nn.Module):
    """Move the last axis of a 3-D tensor to the front: ``(a, b, c) -> (c, a, b)``."""

    def forward(self, input):
        axis_order = (2, 0, 1)
        return input.permute(axis_order)

class Reshape(nn.Module):
    """View the input as ``(-1, *shape)``; the leading axis is inferred."""

    def __init__(self, shape=(256, 32)):
        super().__init__()
        self.shape = shape

    def forward(self, input):
        target_shape = (-1,) + tuple(self.shape)
        return input.view(target_shape)

class Ocr(nn.Module):
    """CNN + bidirectional LSTM recogniser that outputs per-step log-probabilities.

    Three ``BasicBlock`` stages feed a ``Reshape`` to ``(-1, 256, 32)`` and a
    ``Permute`` that puts the 32-long axis first. A 2-layer bidirectional LSTM
    runs over that sequence and a linear layer with log-softmax maps each step
    to ``vocab + 1`` classes (the extra class being the blank).

    Args:
        vocab: number of characters in the vocabulary (an int).
    """

    def __init__(self, vocab):
        super().__init__()
        feature_stages = [
            BasicBlock( 1, 128),
            BasicBlock(128, 128),
            BasicBlock(128, 256, pool=(4,2)),
            Reshape(),
            Permute(),
        ]
        self.model = nn.Sequential(*feature_stages)

        recurrent = nn.LSTM(256, 256, num_layers=2, dropout=0.2, bidirectional=True)
        self.rnn = nn.Sequential(recurrent)

        # 512 = 2 directions x 256 hidden units.
        self.classification = nn.Sequential(
            nn.Linear(512, vocab+1),
            nn.LogSoftmax(-1),
        )

    def forward(self, x):
        features = self.model(x)
        # The LSTM returns (outputs, states); only the outputs are classified.
        sequence, _ = self.rnn(features)
        return self.classification(sequence)

In [None]:
def ctc(log_probs, target, input_lengths, target_lengths, blank=0):
    """Return the mean CTC loss of ``log_probs`` against ``target``.

    Infinite losses (targets that cannot be aligned to the input length) are
    replaced by zero via ``zero_infinity=True``. ``blank`` is the index of the
    CTC blank class.
    """
    return nn.functional.ctc_loss(
        log_probs,
        target,
        input_lengths,
        target_lengths,
        blank=blank,
        reduction='mean',
        zero_infinity=True,
    )

In [None]:
# CNN-LSTM recogniser sized for the letter vocabulary; this rebinds `model`.
model = Ocr(len(vocab)).to(device)

In [None]:
def train_batch(data, model, optimizer, criterion):
    """Run one CTC optimisation step and score the same batch afterwards.

    ``data`` is the 5-tuple produced by ``OCRDataset.collate_fn``. Returns
    ``(loss, metrics)`` where ``metrics`` comes from ``train_dataset.evaluate``
    on the updated model.
    """
    model.train()
    images, target_vectors, target_lengths, step_counts, texts = data

    optimizer.zero_grad()
    log_probs = model(images)
    batch_loss = criterion(log_probs, target_vectors, step_counts, target_lengths)
    batch_loss.backward()
    optimizer.step()

    metrics = train_dataset.evaluate(model, images.to(device), texts)
    return batch_loss.item(), metrics

@torch.no_grad()
def validate_batch(data, model, criterion):
    """Return ``(loss, metrics)`` for one validation batch, without gradients.

    ``data`` is the 5-tuple produced by ``OCRDataset.collate_fn``; ``metrics``
    comes from ``valid_dataset.evaluate``.
    """
    model.eval()
    images, target_vectors, target_lengths, step_counts, texts = data
    log_probs = model(images)
    loss_value = criterion(log_probs, target_vectors, step_counts, target_lengths).item()
    metrics = valid_dataset.evaluate(model, images.to(device), texts)
    return loss_value, metrics

In [None]:
# CTC loss, AdamW, and a scheduler that halves the learning rate after
# 5 epochs without a lower validation loss (never below 1e-6).
criterion = ctc
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-3,
    weight_decay=3e-6,
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=5,
    min_lr=1e-6,
)

In [None]:
# Per-epoch loss history of the CTC model.
train_losses, valid_losses = [], []

for epoch in range(EPOCHS):
    print(f'{epoch+1}/{EPOCHS} epoch.')
    epoch_train_losses, epoch_valid_losses = [], []

    # Optimisation pass; `train_results` keeps the metrics of the last batch only.
    for batch in tqdm(train_dataloader, leave=False):
        loss, train_results = train_batch(batch, model, optimizer, criterion)
        epoch_train_losses.append(loss)
    train_epoch_loss = np.mean(epoch_train_losses)
    train_losses.append(train_epoch_loss)

    # Validation pass; `valid_results` likewise reflects the last batch only.
    for batch in tqdm(valid_dataloader, leave=False):
        loss, valid_results = validate_batch(batch, model, criterion)
        epoch_valid_losses.append(loss)
    valid_epoch_loss = np.mean(epoch_valid_losses)
    valid_losses.append(valid_epoch_loss)
    print(f'Train loss: {train_epoch_loss:.4f}, validation loss: {valid_epoch_loss:.4f}')

    # Despite their names, these four variables hold accuracies, not error rates.
    train_cer = train_results['char-accuracy']
    train_wer = train_results['word-accuracy']
    valid_cer = valid_results['char-accuracy']
    valid_wer = valid_results['word-accuracy']
    print(f'Train char accuracy: {train_cer:.4f}, train word accuracy: {train_wer:.4f}')
    print(f'Validation char accuracy: {valid_cer:.4f}, validation word accuracy: {valid_wer:.4f}')
    print()

    scheduler.step(valid_epoch_loss)

    # Checkpoint the weights every tenth epoch.
    if (epoch + 1) % 10 == 0:
        torch.save(model.state_dict(), 'model.pth')

In [None]:
# Decode the first validation batch and plot every image with its true and predicted text.
model.eval()
images, label_vectors, label_lengths, input_lengths, labels = next(iter(valid_dataloader))
# Put the batch axis first so that each `pred` is the score matrix of one sample.
preds = model(images).permute(1,0,2) 
preds = [valid_dataset.decoder_chars(pred) for pred in preds]
# Undo the preprocessing inversion/scaling so the images display with a white background.
images = 255. - images.detach().cpu().numpy().squeeze() * 255 
plt.figure(figsize=(15,15))
# A 16 x 4 grid holds the BATCH (= 64) images of one batch.
for i in range(BATCH):
    plt.subplot(16, 4, i+1)
    plt.title(f'True: {labels[i]}, Predicted: {preds[i]}')
    plt.imshow(images[i], cmap='gray')
plt.tight_layout()
plt.show()

## TrOCR pretrained

In [14]:
class CharDataset(Dataset):
    """Image/label dataset that prepares inputs for a TrOCR-style processor.

    Each row of ``df`` must provide a ``'path'`` and a ``'label'``. Items are
    dicts with ``pixel_values`` (from the processor) and ``labels`` (token ids
    padded to ``max_target_length``, with padding positions replaced by -100).
    Note that this definition replaces any earlier class named ``CharDataset``.
    """

    def __init__(self, df, processor, max_target_length=32):
        self.df, self.processor, self.max_target_length = df, processor, max_target_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        record = self.df.iloc[idx].squeeze()
        rgb_image = cv2.cvtColor(cv2.imread(record['path']), cv2.COLOR_BGR2RGB)
        pixel_values = self.processor(rgb_image, return_tensors="pt").pixel_values

        token_ids = self.processor.tokenizer(
            record['label'],
            padding="max_length",
            max_length=self.max_target_length,
        ).input_ids

        # Replace padding ids by -100 in the label sequence.
        labels = [
            -100 if token == self.processor.tokenizer.pad_token_id else token
            for token in token_ids
        ]

        return {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}

In [15]:
from transformers import TrOCRProcessor

# Processor (image preprocessing + tokenizer) of the handwritten TrOCR checkpoint.
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")

# Reuse the earlier train/validation split; these names rebind the previous datasets/loaders.
train_dataset = CharDataset(df=df_train, processor=processor)
eval_dataset = CharDataset(df=df_valid, processor=processor)

# No collate_fn is passed here, so the DataLoader's default batching is used.
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)
eval_dataloader = DataLoader(eval_dataset, batch_size=4)

In [16]:
from transformers import VisionEncoderDecoderModel

# Weights come from the "stage1" checkpoint, while the processor above comes from the
# "handwritten" one. This rebinds `model`.
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-stage1").to(device)

In [17]:
# Special tokens: the decoder starts from the tokenizer's CLS token and pads with its PAD token.
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# Mirror the decoder's vocabulary size on the top-level config.
model.config.vocab_size = model.config.decoder.vocab_size

# Generation settings: stop at the tokenizer's SEP token, at most 32 tokens,
# beam search with 4 beams, no repeated 3-grams, length penalty 2.0.
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.max_length = 32
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4

In [18]:
def compute_cer(pred_ids, label_ids):
    """Return the mean character error rate of a batch of generated sequences.

    Args:
        pred_ids: generated token ids, one row per sample.
        label_ids: reference token ids with padding positions set to -100
            (tensor or array). The argument is not modified.

    Returns:
        Average of ``char_error_rate`` over the decoded prediction/label pairs,
        or ``0.0`` for an empty batch.
    """
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)

    # Work on a copy: writing the pad id into the caller's tensor would
    # silently strip the -100 markers from the batch it still holds.
    label_ids = label_ids.clone() if torch.is_tensor(label_ids) else np.array(label_ids)
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    if len(pred_str) == 0:
        return 0.0

    cer = 0.0
    for pred, label in zip(pred_str, label_str):
        # Two empty strings match perfectly; skip them instead of dividing by zero.
        if pred or label:
            cer += char_error_rate(pred, label)
    cer /= len(pred_str)

    return cer

In [20]:
# NOTE(review): `transformers.AdamW` has been deprecated upstream in favour of
# `torch.optim.AdamW` - confirm it is still importable with the installed version.
from transformers import AdamW

optimizer = AdamW(model.parameters(), lr=5e-5)

for epoch in range(4):  # loop over the dataset multiple times
    # train
    model.train()
    train_loss = 0.0
    for batch in tqdm(train_dataloader, leave=False):
        # move every tensor of the batch to the training device
        for k,v in batch.items():
            batch[k] = v.to(device)

        # the model computes the loss itself from the `labels` entry of the batch
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        train_loss += loss.item()

    print(f"Loss after epoch {epoch+1}:", train_loss/len(train_dataloader))
    
    # evaluate
    model.eval()
    valid_cer = 0.0
    with torch.no_grad():
        for batch in tqdm(eval_dataloader, leave=False):
            # run batch generation
            outputs = model.generate(batch["pixel_values"].to(device))
            # compute metrics
            cer = compute_cer(pred_ids=outputs, label_ids=batch["labels"])  # very slow because of the per-string Python loop
            valid_cer += cer 

    # mean of the per-batch CER values
    print("Validation CER:", valid_cer / len(eval_dataloader))

In [21]:
# Generate text for the first evaluation batch and plot it next to the references.
model.eval()
batch = next(iter(eval_dataloader))

with torch.no_grad():
    preds = model.generate(batch["pixel_values"].to(device))
    
labels = batch["labels"]    
pred_str = processor.batch_decode(preds, skip_special_tokens=True)
# Restore the pad id where the dataset wrote -100 so the labels can be decoded
# (this modifies the batch's label tensor in place).
labels[labels == -100] = processor.tokenizer.pad_token_id
label_str = processor.batch_decode(labels, skip_special_tokens=True)    

plt.figure(figsize=(7,7))
# The evaluation loader uses batch_size=4, shown here as a 2 x 2 grid.
for i in range(4):
    plt.subplot(2, 2, i+1)
    plt.title(f'True: {label_str[i]}, Predicted: {pred_str[i]}')
    # channels-first -> channels-last for imshow
    # NOTE(review): pixel_values are the processor's output, presumably normalised - colours may look off.
    plt.imshow(batch['pixel_values'][i].cpu().detach().numpy().transpose(1,2,0))
plt.tight_layout()
plt.show()