# Installs

In [1]:
# !pip install torchsummaryX
!pip install torchinfo

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [2]:
!pip install datasets

Collecting datasets
  Downloading datasets-3.0.1-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.17-py310-none-any.whl.metadata (7.2 kB)
INFO: pip is looking at multiple versions of multiprocess to determine which version is compatible with other requirements. This could take a while.
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Downloading datasets-3.0.1-py3-none-any.whl (471 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m471.6/471.6 kB[0m [31m7.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m3.0 MB/s[0m eta [36m0:00:0

# Imports

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torchvision import transforms
from datasets import load_dataset
import editdistance
# from torchsummary import summary
from torchinfo import summary
import torch.nn.functional as F
import gc
from tqdm import tqdm

# Train on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print("Device: ", device)

Device:  cuda


# Hyperparams

In [13]:
config = {
    "lr"         : 0.002,
    "epochs"     : 20,
    "batch_size" : 64,
    'img_height': 128,
    'img_width': 1000,
    'lstm_num_layers': 4,
    'lstm_hidden_size': 256,
    'seed': 42,
}

# Seed torch's RNG (CPU and CUDA) so weight init, DataLoader shuffling and the random
# ColorJitter augmentation are reproducible across Restart & Run All.
torch.manual_seed(config['seed']);

# Dataset

In [23]:
# Reclaim unreachable Python objects before the dataset is loaded; the number of
# objects collected is displayed as this cell's output.
import gc
gc.collect()



425

In [24]:
# Pull the tunables used below out of the shared config dict.
IMAGE_HEIGHT = config['img_height']
IMAGE_WIDTH = config['img_width']
BATCH_SIZE = config['batch_size']
NUM_EPOCHS = config['epochs']
LEARNING_RATE = config['lr']

# IAM handwritten text lines from the Hugging Face Hub.
dataset = load_dataset("Teklia/IAM-line")

# Resize every line image to one fixed size, apply mild brightness/contrast jitter,
# then convert to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
    ]
)

# Character vocabulary built from the training transcriptions only. Real characters get
# indices 1..N in sorted order; index 0 is reserved for '<PAD>' (also the CTC blank, since
# the loss is built with blank=0).
chars = set(''.join(dataset['train']['text']))
char_to_index = {ch: pos for pos, ch in enumerate(sorted(chars), start=1)}
char_to_index['<PAD>'] = 0
index_to_char = {pos: ch for ch, pos in char_to_index.items()}
vocab_size = len(char_to_index)

# Tokenizer functions
def tokenize(text):
    """Map each character of ``text`` to its vocabulary index.

    Raises KeyError for a character that is not in ``char_to_index``.
    """
    return list(map(char_to_index.__getitem__, text))

def detokenize(indices):
    """Inverse of ``tokenize``: join the characters for ``indices``, skipping index 0 ('<PAD>')."""
    kept_chars = (index_to_char[i] for i in indices if i != 0)
    return ''.join(kept_chars)

# Dataset class
class IAMLinesDataset(torch.utils.data.Dataset):
    """One split of a Hugging Face IAM-line dataset exposed as a PyTorch Dataset.

    Each item is ``(image, target)``: the record's ``'image'`` (passed through
    ``transform`` when one is set) and a 1-D LongTensor of character indices produced
    by ``tokenize`` from the record's ``'text'``.
    """

    def __init__(self, hf_dataset, split, transform=None):
        self.transform = transform
        self.dataset = hf_dataset[split]

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        record = self.dataset[idx]
        img = record['image']
        if self.transform:
            img = self.transform(img)
        label = torch.LongTensor(tokenize(record['text']))
        return img, label

def collate_fn(batch):
    """Collate ``(image, target)`` pairs into one batch for the DataLoader.

    Images are stacked, so they must all have the same shape (the Resize transform used
    in this notebook guarantees that). Targets are variable-length 1-D index tensors and
    are right-padded with 0 ('<PAD>') up to the longest target in the batch. The training
    loop recovers each true length by counting the non-zero entries.

    Returns:
        tuple: ``(images, targets)`` where ``images`` is ``torch.stack`` of the inputs and
        ``targets`` has shape ``(batch, max_target_len)``.
    """
    images, targets = zip(*batch)
    images = torch.stack(images)
    targets = pad_sequence(targets, batch_first=True, padding_value=0)
    return images, targets

# Create datasets and dataloaders
# Validation/test must be preprocessed deterministically: same resize as training but
# without the random ColorJitter, so evaluation metrics are reproducible.
eval_transform = transforms.Compose([
    transforms.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)),
    transforms.ToTensor(),
])

train_dataset = IAMLinesDataset(dataset, 'train', transform)
val_dataset = IAMLinesDataset(dataset, 'validation', eval_transform)
test_dataset = IAMLinesDataset(dataset, 'test', eval_transform)

# Create dataloaders with the custom collate_fn (pads targets per batch).
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn)

In [25]:
print(f"Number of samples in train set: {len(train_dataset)}")
print(f"Number of samples in test set: {len(test_dataset)}")
print(f"Number of samples in val set: {len(val_dataset)}")

print(f"Batch size: {BATCH_SIZE}")
print(f"Num batches in train_loader: {len(train_loader)}")
print(f"Num batches in test_loader: {len(test_loader)}")
print(f"Num batches in val_loader: {len(val_loader)}")

# Fetch the sample once: every __getitem__ call re-runs the (random) training transform.
i = 23
sample_image, sample_label = train_dataset[i]
print(f"Shape of {i}-th item in train set:")
print(f"Image: {sample_image.shape}, Label: {sample_label.shape}")

# Images come out as (BATCH_SIZE, 1, IMAGE_HEIGHT, IMAGE_WIDTH). Labels are zero-padded to
# the longest transcription in this particular batch, so their width varies per batch.
batch_images, batch_labels = next(iter(train_loader))
print("Shapes of the first batch from train_loader:")
print(f"Shape of entire batch of images: {batch_images.shape}")
print(f"Shape of entire batch of labels: {batch_labels.shape}")

Number of samples in train set: 6482
Number of samples in test set: 2915
Number of samples in val set: 976
Batch size: 64
Num batches in train_loader: 102
Num batches in test_loader: 46
Num batches in val_loader: 16
Shape of 23-th item in train set:
Image: torch.Size([1, 128, 1000]), Label: torch.Size([39])
Shape of 23-th item in a batch of train loader:
Shape of entire batch of images: torch.Size([64, 1, 128, 1000])
Shape of entire batch of labels: torch.Size([64, 65])


# Model

In [26]:
class CNN_LSTM(nn.Module):
    """CRNN-style line recognizer: CNN feature extractor -> BiLSTM -> per-frame classifier.

    The CNN has two conv blocks, each ending in a 2x2 max-pool. An input of shape
    ``(batch, 1, img_height, W)`` therefore becomes ``(batch, 128, img_height // 4, W // 4)``.
    Each of the ``W // 4`` feature columns becomes one LSTM time step. Its feature vector
    is that column's ``128 * (img_height // 4)`` activations.

    Args:
        hidden_size: LSTM hidden size per direction (the LSTM is bidirectional).
        num_layers: number of stacked LSTM layers.
        num_classes: number of output classes per frame (vocabulary incl. index 0).
        img_height: height of the images passed to ``forward``. It sets the LSTM input
            size. Defaults to 128, the value the model was originally hard-coded for.
    """

    def __init__(self, hidden_size, num_layers, num_classes, img_height=128):
        super(CNN_LSTM, self).__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=(3, 3), stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),  # halve height and width
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(3, 3), stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))  # halve height and width again
        )

        # LSTM input size = channels of the last conv * feature-map HEIGHT after the two
        # 2x2 max-pools (floor(floor(h / 2) / 2) == h // 4).
        cnn_out_channels = 128
        cnn_out_height = img_height // 4
        self.lstm = nn.LSTM(input_size=cnn_out_channels * cnn_out_height, hidden_size=hidden_size,
                            num_layers=num_layers, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(2 * hidden_size, num_classes)

    def forward(self, x):
        """Map images ``(batch, 1, img_height, W)`` to log-probs ``(batch, W // 4, num_classes)``."""
        features = self.cnn(x)  # (batch, C, H', W')
        batch_size, channels, height, width = features.shape

        # Move width to the time axis *before* flattening, so time step t contains only
        # feature column t. A bare reshape of (B, C, H', W') to (B, W', C*H') would mix
        # unrelated image positions into every time step.
        seq = features.permute(0, 3, 1, 2).reshape(batch_size, width, channels * height)

        out, _ = self.lstm(seq)  # (batch, W', 2 * hidden_size)
        out = self.fc(out)       # (batch, W', num_classes)
        return F.log_softmax(out, dim=2)


In [27]:
model = CNN_LSTM(
    hidden_size=config['lstm_hidden_size'],
    num_layers=config['lstm_num_layers'],
    num_classes=vocab_size,
).to(device)

# torchinfo traces the model with a dummy input of shape (batch, channels, height, width).
# Height/width come from config so the summary always matches what the DataLoader produces.
summary(model, (1, 1, config['img_height'], config['img_width']))

Layer (type:depth-idx)                   Output Shape              Param #
CNN_LSTM                                 [1, 250, 80]              --
├─Sequential: 1-1                        [1, 128, 32, 250]         --
│    └─Conv2d: 2-1                       [1, 64, 128, 1000]        640
│    └─BatchNorm2d: 2-2                  [1, 64, 128, 1000]        128
│    └─ReLU: 2-3                         [1, 64, 128, 1000]        --
│    └─MaxPool2d: 2-4                    [1, 64, 64, 500]          --
│    └─Conv2d: 2-5                       [1, 128, 64, 500]         73,856
│    └─BatchNorm2d: 2-6                  [1, 128, 64, 500]         256
│    └─ReLU: 2-7                         [1, 128, 64, 500]         --
│    └─MaxPool2d: 2-8                    [1, 128, 32, 250]         --
├─LSTM: 1-2                              [1, 250, 512]             13,647,872
├─Linear: 1-3                            [1, 250, 80]              41,040
Total params: 13,763,792
Trainable params: 13,763,792
Non-trainabl

# Training

In [28]:
# Initialize model, loss, and optimizer
# CTC loss with index 0 as the blank (the same index the vocabulary reserves for '<PAD>').
# zero_infinity=True: a label that cannot be aligned to the model's output frames gives an
# infinite loss; zero it instead of letting it poison the running loss average.
criterion = nn.CTCLoss(blank=0, zero_infinity=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=config['lr'], weight_decay=0.01)
# Cut the LR 10x once the metric passed to scheduler.step() has not improved for more than
# 3 consecutive epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)
# Mixed-precision loss scaling only applies on CUDA; disable it cleanly on CPU-only machines.
scaler = torch.amp.GradScaler('cuda', enabled=torch.cuda.is_available())

In [29]:
def ctc_decode_tensor(predicted_indices):
    """Greedy CTC decoding of per-frame argmax indices.

    For each row of a 2-D ``(batch, seq_len)`` tensor (or nested list), consecutive
    repeats are collapsed to one occurrence and blanks (index 0) are dropped.

    Returns:
        list[list]: one list of surviving indices per row.
    """
    if isinstance(predicted_indices, torch.Tensor):
        indices = predicted_indices
    else:
        indices = torch.tensor(predicted_indices)

    n_rows, _ = indices.shape
    # A frame starts a new symbol if it is the first frame or differs from its predecessor.
    leading = torch.ones(n_rows, 1, dtype=torch.bool, device=indices.device)
    changed = torch.cat((leading, indices[:, 1:] != indices[:, :-1]), dim=1)
    keep = changed & (indices != 0)

    return [row[row_keep].tolist() for row, row_keep in zip(indices, keep)]

def save_model(model, optimizer, scheduler, metric, epoch, path):
    """Write a training checkpoint to ``path`` with ``torch.save``.

    The checkpoint is a dict holding the model/optimizer/scheduler state dicts, one
    metric stored under the key ``metric[0]`` with value ``metric[1]``, and ``epoch``.
    """
    checkpoint = dict()
    checkpoint['model_state_dict'] = model.state_dict()
    checkpoint['optimizer_state_dict'] = optimizer.state_dict()
    checkpoint['scheduler_state_dict'] = scheduler.state_dict()
    checkpoint[metric[0]] = metric[1]
    checkpoint['epoch'] = epoch
    torch.save(checkpoint, path)

def load_model(path, model, metric= 'valid_acc', optimizer= None, scheduler= None, map_location=None):
    """Restore a checkpoint written by ``save_model``.

    Args:
        path: checkpoint file path.
        model: module whose weights are overwritten in place from ``'model_state_dict'``.
        metric: checkpoint key of the stored metric to return. NOTE: ``train`` in this
            notebook saves the validation CER under ``'valid_dist'``. Pass
            ``metric='valid_dist'`` for those checkpoints, because the default
            ``'valid_acc'`` raises KeyError on them.
        optimizer: if given, its state is restored from ``'optimizer_state_dict'``.
        scheduler: if given, its state is restored from ``'scheduler_state_dict'``.
        map_location: forwarded to ``torch.load``. For example, ``'cpu'`` loads a checkpoint
            saved on GPU on a machine without CUDA. ``None`` keeps torch's default.

    Returns:
        list: ``[model, optimizer, scheduler, epoch, metric_value]``.
    """
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])

    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if scheduler is not None:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    epoch   = checkpoint['epoch']
    metric  = checkpoint[metric]

    return [model, optimizer, scheduler, epoch, metric]

def evaluate(model, data_loader, device):
    """Compute the corpus-level character error rate (CER) of ``model`` on ``data_loader``.

    Predictions are decoded greedily: argmax class per time step, then ``ctc_decode_tensor``
    collapses repeats and drops index 0. CER is the total edit distance divided by the
    total number of target characters over the whole loader, not a per-batch average.

    Args:
        model: network returning per-frame scores shaped (batch, seq_len, num_classes).
        data_loader: yields ``(images, targets)`` with zero-padded targets.
        device: device the images are moved to.

    Returns:
        float: the CER.

    Raises:
        ValueError: if the loader yields no target characters (CER is undefined).
    """
    model.eval()
    total_distance = 0
    total_length = 0

    with torch.no_grad():
        for images, targets in data_loader:
            outputs = model(images.to(device))

            # Greedy CTC decoding: best class per frame -> collapse repeats -> drop blanks.
            predicted_indices = torch.argmax(outputs, dim=2)  # (batch, seq_len)
            decoded_predictions = ctc_decode_tensor(predicted_indices)

            # Targets are only needed as Python ints. Converting once on the CPU avoids a
            # device sync per character that `.item()` on a CUDA tensor would cost.
            predicted_strings = [''.join(index_to_char[idx] for idx in pred if idx != 0)
                                 for pred in decoded_predictions]
            target_strings = [''.join(index_to_char[idx] for idx in tgt if idx != 0)
                              for tgt in targets.tolist()]

            for pred, tgt in zip(predicted_strings, target_strings):
                total_distance += editdistance.eval(pred, tgt)
                total_length += len(tgt)

    if total_length == 0:
        raise ValueError("evaluate(): data_loader produced no target characters; CER is undefined")

    return total_distance / total_length

def train(model, train_loader, val_loader, criterion, optimizer, num_epochs, learning_rate, device,
          scheduler=None, scaler=None):
    """Train ``model`` with a CTC criterion, validating (CER) after every epoch.

    Each epoch makes one pass over ``train_loader``, using mixed precision on CUDA only.
    It prints the mean training loss, computes CER on ``val_loader`` via ``evaluate``, and
    steps the LR scheduler. After the last epoch the model/optimizer/scheduler state is
    written to ``cnn-lstm-model.pth``, with the last validation CER under the key
    ``'valid_dist'``.

    Args:
        model: network returning log-probabilities shaped (batch, seq_len, num_classes).
        train_loader: yields ``(images, targets)``. Targets are zero-padded, and 0 is
            never a real character, so non-zero counts give target lengths.
        val_loader: same format as ``train_loader``; passed to ``evaluate``.
        criterion: called as ``criterion(log_probs_TxNxC, targets, input_lengths, target_lengths)``.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of passes over ``train_loader``.
        learning_rate: unused, since the LR lives in ``optimizer``. Kept for call compatibility.
        device: where the model and batches are placed (str or ``torch.device``).
        scheduler: LR scheduler. ``ReduceLROnPlateau`` is stepped with the validation CER;
            any other scheduler is stepped with ``step()``. If omitted, the notebook-level
            ``scheduler`` is used. It is also passed to ``save_model``, so it must exist for
            the final checkpoint.
        scaler: ``torch.amp.GradScaler``. If omitted, the notebook-level ``scaler`` is
            used, or a new one (disabled off-CUDA) is created.
    """
    # Preserve the original call signature: fall back to the objects created in the
    # notebook's training-setup cell when they are not passed explicitly.
    if scheduler is None:
        scheduler = globals().get('scheduler')
    if scaler is None:
        scaler = globals().get('scaler')

    device_type = torch.device(device).type
    use_amp = device_type == 'cuda'
    if scaler is None:
        scaler = torch.amp.GradScaler('cuda', enabled=use_amp)

    model.to(device)
    val_cer = None
    epoch = 0

    for epoch in range(num_epochs):
        model.train()
        batch_bar = tqdm(total=len(train_loader), dynamic_ncols=True, leave=False, position=0, desc='Train')
        total_loss = 0.0

        for batch_idx, (images, targets) in enumerate(train_loader):
            images = images.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()

            with torch.amp.autocast(device_type, enabled=use_amp):
                outputs = model(images)  # (batch, seq_len, num_classes)

                # Every sample uses the full output sequence as its CTC input length.
                input_lengths = torch.full(size=(outputs.size(0),), fill_value=outputs.size(1), dtype=torch.long)
                # Targets are zero-padded and 0 is not a real character -> non-zeros = length.
                target_lengths = torch.sum(targets != 0, dim=1)

                # CTCLoss expects (seq_len, batch, num_classes).
                loss = criterion(outputs.transpose(0, 1), targets, input_lengths, target_lengths)

            total_loss += loss.item()

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            batch_bar.set_postfix(
                loss="{:.04f}".format(float(total_loss / (batch_idx + 1))),
                lr="{:.06f}".format(float(optimizer.param_groups[0]['lr'])))
            batch_bar.update()

        batch_bar.close()
        if use_amp:
            torch.cuda.empty_cache()

        avg_loss = total_loss / max(len(train_loader), 1)
        print(f'\nEpoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}')

        # Validate after each epoch.
        val_cer = evaluate(model, val_loader, device)
        print(f'Validation CER: {val_cer:.4f}')

        # Without this the plateau scheduler never lowers the LR.
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(val_cer)
        elif scheduler is not None:
            scheduler.step()

    if val_cer is not None:
        save_model(model, optimizer, scheduler, ['valid_dist', val_cer], epoch, "cnn-lstm-model.pth")


In [None]:
# Collect unreachable Python objects first, so any CUDA tensors they held are freed. Then
# hand the now-unused cached GPU blocks back before the long training run.
gc.collect()
torch.cuda.empty_cache()

train(model, train_loader, val_loader, criterion, optimizer, num_epochs=NUM_EPOCHS, learning_rate=LEARNING_RATE, device=device)

Train: 100%|██████████| 102/102 [01:08<00:00,  1.78it/s, loss=4.2282, lr=0.002000]


Epoch [1/20], Average Loss: 4.2282




Validation CER: 1.0000


Train:   5%|▍         | 5/102 [00:03<01:04,  1.50it/s, loss=3.2479, lr=0.002000]