In [None]:
import os
import cv2
import math
import torch
import codecs
import random
import numpy as np
import editdistance
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary

In [None]:
# Output file names: the character list and the vocabulary text written after
# the DataLoader has been built.
fn_char_list = 'charList.txt'
hindi_vocab = 'hindi_vocab.txt'
# Read the index file and show its first five (stripped) lines as a sanity check.
with codecs.open('full.txt', 'r', encoding='utf-8') as f:
    lines = f.readlines()
lines = [x.strip() for x in lines]
chars = set()  # NOTE(review): not used again in this cell
print(*lines[:5])
# Settings handed to DataLoader(batchSize, img_size, max_text_len).
batchSize = 32
img_size = (128, 32)  # unpacked as (width, height) by preprocess()
max_text_len = 32

In [None]:
def preprocess(img, imgSize, dataAugmentation=False):
    """Fit a grayscale image onto a white canvas, transpose it and normalise it.

    :param img: 2-D image array, or None (a black image is used instead)
    :param imgSize: target size as (width, height)
    :param dataAugmentation: accepted but not used by this function
    :return: array of shape (width, height) with the mean subtracted and,
             when the standard deviation is positive, divided by it
    """
    target_w, target_h = imgSize

    # Unreadable files arrive as None; substitute an all-black image.
    if img is None:
        img = np.zeros([target_h, target_w])

    # Scale so the image fits inside the target while keeping its aspect
    # ratio; each side ends up between 1 and the target side.
    src_h, src_w = img.shape
    scale = max(src_w / target_w, src_h / target_h)
    new_w = max(min(target_w, int(src_w / scale)), 1)
    new_h = max(min(target_h, int(src_h / scale)), 1)
    resized = cv2.resize(img, (new_w, new_h))

    # Paste into the top-left corner of a white (255) canvas.
    canvas = np.full([target_h, target_w], 255.0)
    canvas[0:new_h, 0:new_w] = resized

    # Swap the axes so the result is (width, height).
    transposed = cv2.transpose(canvas)

    # Zero-mean, unit-variance normalisation (skipping the division for a
    # constant image).
    mean, std = cv2.meanStdDev(transposed)
    mean_val = mean[0][0]
    std_val = std[0][0]
    centered = transposed - mean_val
    if std_val > 0:
        return centered / std_val
    return centered

In [None]:
class Sample:
    """One dataset entry: a ground-truth text and the path of its image."""

    def __init__(self, gt_text, file_path):
        self.gt_text, self.file_path = gt_text, file_path


class Batch:
    """A batch of images stacked along a new leading axis, with their texts."""

    def __init__(self, gt_texts, imgs):
        self.gt_texts = gt_texts
        # np.stack adds the batch axis in front of the per-image axes.
        self.imgs = np.stack(imgs)

class DataLoader:
    """Reads an index file of ``<image path> <ground-truth text>`` lines and
    serves fixed-size batches of preprocessed images with their texts.

    The samples are split 80/10/10 (in file order) into training, validation
    and test sets; ``train_set``/``validation_set``/``test_set`` select which
    one ``has_next``/``get_next`` iterate over.
    """

    def __init__(self, batch_size, img_size, max_text_len, file_path="full.txt"):
        """
        Loader for the dataset
        :param batch_size: Batch size
        :param img_size: Size of the image as (width, height)
        :param max_text_len: Maximum text length (see ``truncate_label``)
        :param file_path: Path of the index file; each line is an image path
                          followed by the ground-truth text
        :raises OSError: if the index file or a listed image cannot be accessed
        """
        self.data_augmentation = False
        self.cur_idx = 0
        self.batch_size = batch_size
        self.img_size = img_size
        self.samples = []

        with codecs.open(file_path, 'r', encoding='utf-8') as f:
            lines = [x.strip() for x in f.readlines()]
        # Sanity-check print of one index line; guarded so short files do not crash.
        if len(lines) > 5:
            print(lines[5])

        chars = set()
        for line in lines:
            parsed = self._parse_line(line, max_text_len)
            if parsed is None:
                continue
            file_name, gt_text = parsed
            chars.update(gt_text)

            # Check if image not empty
            if not os.path.getsize(file_name):
                continue
            self.samples.append(Sample(gt_text, file_name))

        # Split into training, validation and testing sets
        n1, n2 = int(0.8 * len(self.samples)), int(0.9 * len(self.samples))
        self.train_samples = self.samples[:n1]
        self.validation_samples = self.samples[n1:n2]
        self.test_samples = self.samples[n2:]

        # Put words into lists
        self.train_words = [x.gt_text for x in self.train_samples]
        self.test_words = [x.gt_text for x in self.test_samples]
        self.valid_words = [x.gt_text for x in self.validation_samples]

        # Number of randomly chosen samples per epoch
        self.num_train_samples_per_epoch = 10000

        self.train_set()

        # List of chars in the dataset
        self.char_list = sorted(chars)

    @staticmethod
    def _parse_line(line, max_text_len):
        """Split one stripped index line into (file_name, gt_text).

        Returns None for blank lines, ``#`` comments, lines whose first token
        is a lone byte-order mark, and lines that carry no ground-truth text.
        """
        if not line or line[0] == '#':
            return None
        line_split = line.strip().split(' ')
        if line_split[0] == '\ufeff':
            return None
        if len(line_split) < 2:
            return None
        # Ground truth text starts at column 1 and may contain several
        # space-separated tokens, so join *all* remaining columns.
        gt_text = DataLoader.truncate_label(' '.join(line_split[1:]), max_text_len)
        return line_split[0], gt_text

    @staticmethod
    def truncate_label(text, max_text_len):
        """Cut ``text`` so that its CTC cost does not exceed ``max_text_len``.

        Every character costs 1, and a character equal to its predecessor
        costs 2 (CTC needs a blank between repeated characters).
        """
        cost = 0
        for i in range(len(text)):
            if i != 0 and text[i] == text[i - 1]:
                cost += 2
            else:
                cost += 1
            if cost > max_text_len:
                return text[:i]
        return text

    def train_set(self):
        """
        Switch to randomly chosen subset of training set
        :return: None
        """
        self.data_augmentation = True
        self.cur_idx = 0
        random.shuffle(self.train_samples)
        self.samples = self.train_samples[:self.num_train_samples_per_epoch]

    def validation_set(self):
        """
        Switch to validation set (shuffled in place)
        :return: None
        """
        self.data_augmentation = False
        self.cur_idx = 0
        random.shuffle(self.validation_samples)
        self.samples = self.validation_samples

    def test_set(self):
        """
        Switch to testing set (shuffled in place)
        :return: None
        """
        self.data_augmentation = False
        self.cur_idx = 0
        random.shuffle(self.test_samples)
        self.samples = self.test_samples

    def get_iterator_info(self):
        """
        Current batch index (1-based) and total number of full batches
        :return: tuple (current batch, number of batches)
        """
        return self.cur_idx // self.batch_size + 1, len(self.samples) // self.batch_size

    def has_next(self):
        """True while a full batch is still available in the current set."""
        return self.cur_idx + self.batch_size <= len(self.samples)

    def get_next(self):
        """Load, preprocess and return the next ``batch_size`` samples as a Batch."""
        batch_range = range(self.cur_idx, self.cur_idx + self.batch_size)
        gt_texts = [self.samples[i].gt_text for i in batch_range]
        imgs = [
            preprocess(
                cv2.imread(self.samples[i].file_path, cv2.IMREAD_GRAYSCALE),
                self.img_size,
                self.data_augmentation,
            )
            for i in batch_range
        ]
        self.cur_idx += self.batch_size
        return Batch(gt_texts, imgs)
# Build the loader, then persist the character list and the training/validation
# vocabulary. The files are opened with `with` so the handles are flushed and
# closed deterministically.
dataloader = DataLoader(batchSize, img_size, max_text_len)
print(len(dataloader.char_list))
with open(fn_char_list, 'w', encoding='utf-8') as char_file:
    char_file.write(str().join(dataloader.char_list))
with open(hindi_vocab, 'w', encoding='UTF-8') as vocab_file:
    vocab_file.write(str(' ').join(dataloader.train_words + dataloader.valid_words))

In [None]:
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to a sequence-first tensor.

    The table is stored as the buffer ``pe`` with shape
    ``(max_len, 1, d_model)``, so it broadcasts over the second (batch) axis of
    an input shaped ``(seq_len, batch, d_model)``.

    :param d_model: size of the last (feature) axis; odd values are supported
    :param max_len: number of positions precomputed; inputs must not be longer
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than even columns, so
        # trim the frequencies to match instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(0)`` positions."""
        return x + self.pe[:x.size(0), :]

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Inputs are unpacked as ``(batch_size, seq_length, d_model)``; the feature
    axis is split into ``num_heads`` heads of width ``d_k`` each.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Query / key / value projections followed by the output projection.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Softmax(Q K^T / sqrt(d_k)) V; positions where ``mask == 0`` are suppressed."""
        scale = math.sqrt(self.d_k)
        scores = Q.matmul(K.transpose(-2, -1)) / scale
        if mask is not None:
            # A large negative score makes the softmax weight effectively zero.
            scores = scores.masked_fill(mask == 0, -1e9)
        weights = F.softmax(scores, dim=-1)
        return weights.matmul(V)

    def split_heads(self, x):
        """(batch, seq, d_model) -> (batch, num_heads, seq, d_k)."""
        n_batch, n_tokens, _ = x.size()
        per_head = x.view(n_batch, n_tokens, self.num_heads, self.d_k)
        return per_head.transpose(1, 2)

    def combine_heads(self, x):
        """(batch, num_heads, seq, d_k) -> (batch, seq, d_model)."""
        n_batch, _, n_tokens, _ = x.size()
        merged = x.transpose(1, 2).contiguous()
        return merged.view(n_batch, n_tokens, self.d_model)

    def forward(self, Q, K, V, mask=None):
        """Project the inputs, attend per head, then merge and project the result."""
        q_heads = self.split_heads(self.W_q(Q))
        k_heads = self.split_heads(self.W_k(K))
        v_heads = self.split_heads(self.W_v(V))

        context = self.scaled_dot_product_attention(q_heads, k_heads, v_heads, mask)
        return self.W_o(self.combine_heads(context))

class PositionWiseFeedForward(nn.Module):
    """Two linear layers with a ReLU in between: d_model -> d_ff -> d_model."""

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        """Expand to ``d_ff``, apply ReLU, and project back to ``d_model``."""
        hidden = self.fc1(x)
        activated = F.relu(hidden)
        return self.fc2(activated)

class TransformerEncoderLayer(nn.Module):
    """Post-norm encoder block: self-attention, then a feed-forward network.

    Each sub-layer output goes through dropout, is added to its input, and the
    sum is layer-normalised.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run ``x`` through the attention and feed-forward sub-layers."""
        # Self-attention sub-layer with residual connection.
        attended = self.self_attn(x, x, x, mask)
        after_attn = self.norm1(x + self.dropout(attended))

        # Feed-forward sub-layer with residual connection.
        transformed = self.feed_forward(after_attn)
        return self.norm2(after_attn + self.dropout(transformed))

class ViTHTRModel(nn.Module):
    """Vision-Transformer text recogniser trained with CTC.

    The image is cut into ``patch_size`` x ``patch_size`` patches, each patch
    is embedded, the patch sequence is run through a stack of transformer
    encoder layers, and every patch position is classified into one of
    ``len(char_list) + 1`` classes. The extra, last class is the CTC blank.

    :param char_list: list of characters; a character's position is its label
    :param img_size: expected (H, W) of the input tensor
    :param patch_size: side length of the square patches
    :param embed_dim: width of the patch embeddings
    :param max_text_len: stored on the instance; not used by the network itself
    :param num_heads: attention heads per encoder layer
    :param num_layers: number of encoder layers
    :param dropout: dropout probability
    """

    def __init__(self, char_list, img_size=(128, 32), patch_size=8, embed_dim=128, max_text_len=32, num_heads=8, num_layers=6, dropout=0.1):
        super().__init__()
        self.char_list = char_list
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)
        self.embed_dim = embed_dim
        self.max_text_len = max_text_len
        # The CTC blank gets the index right after the last real character.
        self.blank_idx = len(char_list)
        # Counter used by save() to number snapshots.
        self.snap_id = 0

        # Patch embedding
        self.patch_embed = nn.Conv2d(1, embed_dim, kernel_size=patch_size, stride=patch_size)

        # Learned position embedding
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim))
        self.pos_drop = nn.Dropout(p=dropout)

        # Transformer encoder layers
        self.transformer_layers = nn.ModuleList([
            TransformerEncoderLayer(embed_dim, num_heads, embed_dim * 4, dropout)
            for _ in range(num_layers)
        ])

        # Output layer: one class per character plus one for the CTC blank
        self.fc = nn.Linear(embed_dim, len(char_list) + 1)

        # Fixed sinusoidal positional encoding
        self.pos_encoder = PositionalEncoding(embed_dim)

        self.init_weights()

        # zero_infinity avoids inf losses/NaN gradients when a target cannot be
        # aligned to the available time steps.
        self.criterion = nn.CTCLoss(blank=self.blank_idx, zero_infinity=True)
        self.optimizer = torch.optim.RMSprop(self.parameters(), lr=0.001)

    def init_weights(self):
        """Initialise the learned position embedding with a truncated normal."""
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, x):
        """Compute per-patch log-probabilities.

        :param x: tensor (or array-like) shaped (B, 1, H, W) or (B, H, W)
        :return: log-probabilities shaped (num_patches, B, len(char_list) + 1),
                 i.e. time-major as expected by CTC loss
        """
        if not torch.is_tensor(x):
            x = torch.as_tensor(x, dtype=torch.float32)
        if len(x.shape) != 4:
            x = x.unsqueeze(1)
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

        # Patch embedding: (B, D, H/p, W/p) -> (B, N, D)
        x = self.patch_embed(x)
        x = x.flatten(2).transpose(1, 2)

        # Add learned position embedding
        x = self.pos_drop(x + self.pos_embed)

        # PositionalEncoding works on (N, B, D); go there and straight back so
        # the attention layers below see batch-first input.
        x = self.pos_encoder(x.transpose(0, 1)).transpose(0, 1)

        # Transformer layers operate on (B, N, D) and attend over the patches.
        for layer in self.transformer_layers:
            x = layer(x)

        # Classify every patch position, then switch to time-major for CTC.
        x = self.fc(x)
        return F.log_softmax(x, dim=-1).transpose(0, 1)

    def to_sparse(self, texts):
        """Encode texts as sparse-tensor parts.

        :return: (indices, values, shape) where indices are [row, column]
                 pairs, values the character labels, and shape is
                 [len(texts), longest text length]
        :raises ValueError: if a character is not in ``char_list``
        """
        indices, values = [], []
        shape = [len(texts), 0]

        for (batchElement, text) in enumerate(texts):
            label_str = [self.char_list.index(c) for c in text]
            if len(label_str) > shape[1]:
                shape[1] = len(label_str)
            for (i, label) in enumerate(label_str):
                indices.append([batchElement, i])
                values.append(label)

        return indices, values, shape

    def _encode_targets(self, texts):
        """Build zero-padded dense CTC targets and their lengths from texts."""
        indices, values, shape = self.to_sparse(texts)
        # Keep at least one column so an all-empty batch still yields a 2-D tensor.
        targets = torch.zeros((shape[0], max(shape[1], 1)), dtype=torch.long)
        if indices:
            idx = torch.tensor(indices, dtype=torch.long)
            targets[idx[:, 0], idx[:, 1]] = torch.tensor(values, dtype=torch.long)
        target_lengths = torch.tensor([len(t) for t in texts], dtype=torch.long)
        return targets, target_lengths

    def decoder_output_to_text(self, ctc_output, batch_size):
        """Greedy CTC decoding of best-path labels into strings.

        :param ctc_output: label indices, either 2-D (batch, time) or 1-D and
                           flattened in time-major order (time * batch)
        :param batch_size: number of batch elements
        :return: list of ``batch_size`` decoded strings
        """
        labels = torch.as_tensor(ctc_output)
        if labels.dim() == 1:
            labels = labels.reshape(-1, batch_size).t()
        blank = len(self.char_list)

        texts = []
        for row in labels[:batch_size].tolist():
            decoded = []
            previous = blank
            for label in row:
                # Collapse runs of the same label, then drop blanks.
                if label != previous and label != blank:
                    decoded.append(self.char_list[label])
                previous = label
            texts.append("".join(decoded))
        return texts

    def train_batch(self, batch):
        """Run one optimisation step on a batch and return the loss value."""
        self.train()
        num_batch_elements = len(batch.imgs)
        imgs = torch.as_tensor(batch.imgs, dtype=torch.float32)
        if imgs.dim() == 3:
            imgs = imgs.unsqueeze(1)
        targets, target_lengths = self._encode_targets(batch.gt_texts)

        self.optimizer.zero_grad()
        log_probs = self(imgs)

        # Every sample uses all time steps (= number of patches).
        input_lengths = torch.full((num_batch_elements,), log_probs.size(0), dtype=torch.long)

        loss = self.criterion(log_probs, targets, input_lengths, target_lengths)
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def infer_batch(self, batch, calc_probability=False, probability_of_gt=False):
        """Recognise the texts in a batch.

        :param batch: object with ``imgs`` (and ``gt_texts`` if
                      ``probability_of_gt`` is set)
        :param calc_probability: also return per-sample text probabilities
        :param probability_of_gt: score the ground truth instead of the
                                  recognised text
        :return: (texts, probs); probs is None unless ``calc_probability``,
                 otherwise a NumPy array with one probability per sample
        """
        self.eval()
        num_batch_elements = len(batch.imgs)
        imgs = torch.as_tensor(batch.imgs, dtype=torch.float32)

        with torch.no_grad():
            log_probs = self(imgs)                               # (T, B, C)
            best_path = log_probs.argmax(dim=-1).transpose(0, 1)  # (B, T)

        texts = self.decoder_output_to_text(best_path, num_batch_elements)

        probs = None
        if calc_probability:
            scored_texts = batch.gt_texts if probability_of_gt else texts
            targets, target_lengths = self._encode_targets(scored_texts)
            input_lengths = torch.full((num_batch_elements,), log_probs.size(0), dtype=torch.long)
            with torch.no_grad():
                # Per-sample negative log-likelihood; infeasible alignments
                # stay at inf so their probability becomes 0.
                nll = F.ctc_loss(
                    log_probs, targets, input_lengths, target_lengths,
                    blank=self.blank_idx, reduction='none', zero_infinity=False,
                )
            probs = torch.exp(-nll).cpu().numpy()

        return texts, probs

    def save(self):
        """Save the model weights to the next numbered snapshot file."""
        self.snap_id += 1
        os.makedirs('../model', exist_ok=True)
        torch.save(self.state_dict(), f'../model/snapshot_{self.snap_id}.pth')

    def load(self, path):
        """Load model weights from ``path``.

        NOTE: torch.load unpickles the file; only load snapshots you trust.
        """
        self.load_state_dict(torch.load(path))

# Usage
char_list = dataloader.char_list  # sorted characters collected by the DataLoader
model = ViTHTRModel(char_list)

# Test the model
# Random batch of 32 single-channel 128x32 inputs (the model's default img_size).
dummy_input = torch.randn(32, 1, 128, 32)
output = model(dummy_input)
print("Model output shape:", output.shape)
# Per-layer summary from torchsummary for a (1, 128, 32) input.
summary(model, (1, 128, 32))

#### Training and validating a model

In [None]:
def validate(model, loader):
    """Evaluate the model on the validation set.

    Prints every ground-truth/recognised pair and a summary line.

    :param model: object providing ``eval()`` and ``infer_batch(batch)``
    :param loader: object providing ``validation_set()``, ``has_next()``,
                   ``get_iterator_info()`` and ``get_next()``
    :return: character error rate, i.e. the summed edit distance divided by
             the total number of ground-truth characters
    """
    model.eval()
    loader.validation_set()
    num_char_err = 0
    num_char_total = 0
    num_word_ok = 0
    num_word_total = 0

    while loader.has_next():
        iter_info = loader.get_iterator_info()
        print(f"Batch: {iter_info[0]} / {iter_info[1]}")
        batch = loader.get_next()
        (recognized, _) = model.infer_batch(batch)

        print("Ground Truth -> Recognized")
        for gt_text, rec_text in zip(batch.gt_texts, recognized):
            num_word_ok += 1 if gt_text == rec_text else 0
            num_word_total += 1
            dist = editdistance.eval(rec_text, gt_text)
            num_char_err += dist
            num_char_total += len(gt_text)
            print('[OK]' if dist == 0 else '[ERR:%d]' % dist, '"' + gt_text + '"', '->',
                  '"' + rec_text + '"')

    # Normalise by the amount of text so the value is a rate, not a count.
    if num_char_total:
        char_error_rate = num_char_err / num_char_total
    else:
        char_error_rate = 0.0 if num_char_err == 0 else float("inf")
    word_accuracy = num_word_ok / num_word_total if num_word_total else 0.0
    print(f"Character error rate: {char_error_rate * 100.0:.2f}%. Word accuracy: {word_accuracy * 100.0:.2f}%.")
    return char_error_rate

def train(model, loader, early_stopping=5, max_epochs=None):
    """Train until the validation error stops improving.

    Each epoch trains on every batch of ``loader.train_set()`` and then calls
    ``validate``. Training ends once the validation error has not improved for
    ``early_stopping`` consecutive epochs, or after ``max_epochs`` epochs when
    that is given. The per-epoch mean training losses and validation error
    rates are printed at the end.

    :param model: object providing ``train_batch(batch)``
    :param loader: data loader passed on to ``validate``
    :param early_stopping: epochs without improvement before stopping
    :param max_epochs: optional hard limit on the number of epochs
    :return: None
    """
    val_error_rates, losses = [], []
    best_error_rate = float("inf")
    no_improvement_since = 0
    epoch = 0
    while max_epochs is None or epoch < max_epochs:
        epoch += 1
        print(f"Currently on epoch {epoch}")

        # Training
        print("Training")
        loader.train_set()
        epoch_losses = []
        while loader.has_next():
            batch = loader.get_next()
            epoch_losses.append(model.train_batch(batch))
        # Mean loss of the epoch; NaN marks an epoch that had no batches.
        losses.append(sum(epoch_losses) / len(epoch_losses) if epoch_losses else float("nan"))

        # Validation
        char_error_rate = validate(model, loader)
        val_error_rates.append(char_error_rate)

        # Remember the best error rate so later epochs are compared against it
        if char_error_rate < best_error_rate:
            print("Character error rate improved")
            best_error_rate = char_error_rate
            no_improvement_since = 0
        else:
            print("Error rate not improved")
            no_improvement_since += 1

        # Stop if there's no improvement in error
        if no_improvement_since >= early_stopping:
            print("Training stopped")
            break
    print(losses)
    print(val_error_rates)

# Kick off training of the ViT model on the loaded dataset.
train(model, dataloader)

In [None]:
# Pull one training batch and run it through the model to inspect the output.
dataloader.train_set()
batch = dataloader.get_next()
print(batch.gt_texts)
texts = batch.gt_texts
print(*texts)
print("\n")
print(*dataloader.char_list)
images = batch.imgs
# batch.imgs is a NumPy array; the model needs a float32 tensor.
log_probs = model(torch.as_tensor(images, dtype=torch.float32))

In [None]:
num_batch_elements = len(images)
print(f"There are {num_batch_elements} images in the batch")
texts = batch.gt_texts

# Build the CTC inputs at notebook scope (the same way train_batch does) so
# this cell runs on a fresh kernel instead of relying on leftover variables.
sparse = model.to_sparse(texts)
gt_texts = torch.sparse_coo_tensor(
    torch.LongTensor(sparse[0]).t(),
    torch.LongTensor(sparse[1]),
    torch.Size(sparse[2])
).to_dense()
input_lengths = torch.full((num_batch_elements,), log_probs.size(0), dtype=torch.long)
target_lengths = torch.LongTensor([len(t) for t in texts])

print(f"The Ground truth texts for this batch are of shape {gt_texts.shape}")
print(f"The images are of shape {images.shape}")
model.optimizer.zero_grad()
print(f"The logarithmic probabilities are of shape {log_probs.shape}")
print(f"The input lengths are of shape {input_lengths.shape}")
print(f"The target lengths are of shape {target_lengths.shape}")