In [None]:
%pip install torch torchvision torchaudio

In [None]:
# 1. SETUP AND IMPORTS
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from datasets import load_dataset
from tqdm import tqdm

# TF_CPP_MIN_LOG_LEVEL is TensorFlow's C++ log threshold ('2' hides INFO and
# WARNING). Only TensorFlow reads it, and only if it initialises after this
# line; it does not control Hugging Face `datasets` logging.
os.environ.update({'TF_CPP_MIN_LOG_LEVEL': '2'})

# 2. DEVICE CONFIGURATION
print("--- Device Check ---")
# Use the current CUDA device when PyTorch can see one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)
print("--------------------")


# 3. DATASET LOADING AND PREPARATION
print("\n--- Loading IAM Dataset ---")
try:
    # Only the download/load call can fail here; keep the try body minimal.
    iam_dataset = load_dataset("Teklia/IAM-line")
except Exception as e:
    print(f"Failed to load dataset. Please check your internet connection. Error: {e}")
    # Bare `exit()` (a `site` convenience builtin) terminates with status 0,
    # i.e. reports success. Signal failure explicitly and keep the cause chained.
    raise SystemExit(1) from e
print("Dataset loaded successfully.")
print(iam_dataset)

# Split the dataset
train_hf_dataset = iam_dataset["train"]
val_hf_dataset = iam_dataset["validation"]
test_hf_dataset = iam_dataset["test"]
print("--------------------------")


# 4. PREPROCESSING
print("\n--- Preprocessing Data ---")

# --- Text Preprocessing ---
# The character set is built from the training transcriptions only.
# - Selecting just the 'text' column avoids loading every line image, which
#   iterating whole rows would do even though only the text is needed.
# - Characters that appear only in validation/test text get no id. For those
#   lines the `char_to_int[char]` lookup in IAMDataset raises, and the sample
#   is skipped.
characters = sorted(set("".join(train_hf_dataset['text'])))
VOCAB = "".join(characters)

# Class ids start at 1 because id 0 is reserved for the CTC blank.
char_to_int = {char: i for i, char in enumerate(VOCAB, start=1)}
int_to_char = {i: char for char, i in char_to_int.items()}
CTC_BLANK = 0

print(f"Vocabulary Size: {len(VOCAB)}")
print(f"Characters: {VOCAB}")

# --- Image Preprocessing ---
# Fixed model input size. Every line image is stretched/squeezed to exactly
# this height and width, so the aspect ratio is not preserved.
IMG_HEIGHT = 64
IMG_WIDTH = 512

# The pipeline applies these steps in order:
#   1. resize to IMG_HEIGHT x IMG_WIDTH
#   2. convert to a single channel
#   3. convert to a float tensor in [0, 1]
#   4. apply (x - 0.5) / 0.5
# The result is a (1, IMG_HEIGHT, IMG_WIDTH) tensor with values in [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize(size=(IMG_HEIGHT, IMG_WIDTH)),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

class IAMDataset(Dataset):
    """Torch ``Dataset`` adapter over one split of the Hugging Face IAM lines data.

    Each record must expose ``'image'`` (an object with ``.convert``,
    presumably a PIL image) and ``'text'`` (the transcription string).

    Indexing yields ``(image, label)``. ``image`` is the RGB-converted image,
    passed through ``transform`` when one is set. ``label`` is an int64 tensor
    of ``char_to_int`` ids.

    Any exception while loading a record (bad image, a character missing from
    ``char_to_int``, out-of-range index, ...) prints a warning. The sample is
    then returned as ``None``, to be filtered out downstream.
    """

    def __init__(self, hf_dataset, transform=None):
        self.hf_dataset = hf_dataset
        self.transform = transform  # optional callable applied to the RGB image

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        try:
            return self._load_sample(idx)
        except Exception as e:
            # Best-effort: skip the broken sample instead of aborting the epoch.
            print(f"Warning: Error processing item at index {idx}. Error: {e}. Skipping.")
            return None

    def _load_sample(self, idx):
        """Fetch record ``idx`` and encode it; any exception propagates to ``__getitem__``."""
        record = self.hf_dataset[idx]
        sample_image = record['image'].convert("RGB")
        transcript = record['text']
        if self.transform:
            sample_image = self.transform(sample_image)
        token_ids = [char_to_int[ch] for ch in transcript]
        return sample_image, torch.tensor(token_ids, dtype=torch.long)

# --- Collate Function for DataLoader ---
def collate_fn(batch):
    """Merge ``(image, label)`` samples into one padded batch.

    ``None`` entries (samples that failed to load) are dropped first.

    Returns:
        ``(images, padded_labels, label_lengths)``:
        - ``images``: the surviving images stacked along a new dim 0.
        - ``padded_labels``: ``(N, max_len)``, right-padded with 0.
        - ``label_lengths``: each label's true length, as int64.
        If no sample survives, three empty tensors are returned instead.
    """
    kept = [sample for sample in batch if sample is not None]
    if len(kept) == 0:
        return torch.tensor([]), torch.tensor([]), torch.tensor([])

    image_list = [image for image, _ in kept]
    label_list = [label for _, label in kept]

    stacked = torch.stack(image_list, dim=0)
    lengths = torch.tensor(list(map(len, label_list)), dtype=torch.long)
    # 0 doubles as the CTC blank id. The padding is harmless because the loss
    # is told each row's true length via `lengths`.
    padded = nn.utils.rnn.pad_sequence(label_list, batch_first=True, padding_value=0)
    return stacked, padded, lengths


# 5. MODEL BUILDING (CRNN)
class CRNN(nn.Module):
    """Small convolutional-recurrent network producing per-column CTC log-probs.

    Input:  ``(B, 1, H, W)`` single-channel images. After the two pooling
            stages H must reduce to ``IMG_HEIGHT // 4``, i.e. IMG_HEIGHT-tall
            inputs.
    Output: ``(W // 4, B, num_chars)`` log-probabilities, time-major as
            ``nn.CTCLoss`` expects. There is one time step per 4-pixel-wide
            column strip.

    Args:
        num_chars: number of output classes, including the CTC blank.
    """

    def __init__(self, num_chars):
        super().__init__()
        # Two conv/ReLU/max-pool stages. The convs preserve size
        # (k=3, padding=1); each pool halves height and width.
        conv_layers = [
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.cnn = nn.Sequential(*conv_layers)
        # Each time step sees 64 channels x (IMG_HEIGHT // 4) rows, projected to 64 features.
        pooled_height = IMG_HEIGHT // 4
        self.map_to_seq = nn.Linear(64 * pooled_height, 64)
        self.rnn = nn.LSTM(input_size=64, hidden_size=128, num_layers=2, bidirectional=True, dropout=0.25)
        # The bidirectional LSTM concatenates both directions: 2 * 128 = 256 features.
        self.fc = nn.Linear(2 * 128, num_chars)

    def forward(self, x):
        """Map ``(B, 1, H, W)`` images to ``(W', B, num_chars)`` log-probabilities."""
        feature_map = self.cnn(x)  # (B, 64, H', W')
        # Treat each feature-map column as a time step: (B, W', 64, H') -> (B, W', 64 * H').
        columns = feature_map.permute(0, 3, 1, 2).flatten(start_dim=2)
        # (W', B, 64): the LSTM is not batch_first, so it wants time-major input.
        seq = self.map_to_seq(columns).transpose(0, 1)
        recurrent, _ = self.rnn(seq)
        return self.fc(recurrent).log_softmax(dim=-1)

# 6. TRAINING AND VALIDATION FUNCTIONS
def train_one_epoch(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    Args:
        model: network whose output is passed to ``criterion`` as log-probs.
            ``log_probs.size(0)`` is used as every sample's input length, so
            the output is expected to be time-major ``(T, B, C)``.
        dataloader: iterable of ``(images, padded_labels, label_lengths)``
            batches. A batch whose ``images`` is empty along dim 0 is skipped.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss called as
            ``criterion(log_probs, labels, input_lengths, label_lengths)``.
        device: device the batch tensors are moved to.

    Returns:
        Mean of the per-batch losses over the batches actually processed, or
        ``float('nan')`` if no batch was processed.
    """
    model.train()
    epoch_loss = 0.0
    num_batches = 0
    for images, labels, label_lengths in tqdm(dataloader, desc="Training"):
        # An empty batch means every sample in it failed to load.
        if images.size(0) == 0:
            continue
        images, labels, label_lengths = images.to(device), labels.to(device), label_lengths.to(device)
        optimizer.zero_grad()
        log_probs = model(images)
        # The images are stacked to one width, so every sample uses the full output length T.
        input_lengths = torch.full(size=(images.size(0),), fill_value=log_probs.size(0), dtype=torch.long)
        loss = criterion(log_probs, labels, input_lengths, label_lengths)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        num_batches += 1
    # Average over the batches that contributed, not len(dataloader).
    # Otherwise skipped batches bias the mean low, and an empty loader
    # raises ZeroDivisionError.
    return epoch_loss / num_batches if num_batches else float('nan')

def validate_one_epoch(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without updating it; return the mean batch loss.

    Runs with ``model.eval()`` and under ``torch.no_grad()``.

    Args:
        model: network whose output is passed to ``criterion`` as log-probs.
            ``log_probs.size(0)`` is used as every sample's input length, so
            the output is expected to be time-major ``(T, B, C)``.
        dataloader: iterable of ``(images, padded_labels, label_lengths)``
            batches. A batch whose ``images`` is empty along dim 0 is skipped.
        criterion: loss called as
            ``criterion(log_probs, labels, input_lengths, label_lengths)``.
        device: device the batch tensors are moved to.

    Returns:
        Mean of the per-batch losses over the batches actually processed, or
        ``float('nan')`` if none were.
        NaN (rather than 0.0) keeps a fully-skipped validation pass from
        comparing as an "improvement" (``nan < x`` is False).
    """
    model.eval()
    epoch_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for images, labels, label_lengths in tqdm(dataloader, desc="Validating"):
            # An empty batch means every sample in it failed to load.
            if images.size(0) == 0:
                continue
            images, labels, label_lengths = images.to(device), labels.to(device), label_lengths.to(device)
            log_probs = model(images)
            # The images are stacked to one width, so every sample uses the full output length T.
            input_lengths = torch.full(size=(images.size(0),), fill_value=log_probs.size(0), dtype=torch.long)
            loss = criterion(log_probs, labels, input_lengths, label_lengths)
            epoch_loss += loss.item()
            num_batches += 1
    # Average over processed batches only; see the Returns note above.
    return epoch_loss / num_batches if num_batches else float('nan')

# 7. INFERENCE FUNCTION
def ctc_decode(log_probs, blank=None, idx_to_char=None):
    """Greedy (best-path) CTC decoding of a batch of log-probabilities.

    The CTC collapsing rule is applied in the correct order:
    1. Merge consecutive repeated labels.
    2. Then remove blanks.
    A blank between two identical labels therefore keeps both, so
    ``l, l, <blank>, l`` decodes to ``"ll"``. Removing blanks first, as the
    previous version did, wrongly merged real double letters.

    Args:
        log_probs: tensor of shape ``(T, B, C)``; the best class per time step
            is taken with argmax over the last dim.
        blank: class id of the CTC blank. Defaults to the module-level ``CTC_BLANK``.
        idx_to_char: mapping from class id to character. Defaults to the
            module-level ``int_to_char``. Ids missing from it decode to ``''``.

    Returns:
        A list of ``B`` decoded strings.
    """
    if blank is None:
        blank = CTC_BLANK
    if idx_to_char is None:
        idx_to_char = int_to_char
    # (T, B) -> (B, T): one best path per sample, as plain Python ints.
    best_paths = log_probs.argmax(dim=2).permute(1, 0).tolist()
    decoded_texts = []
    for path in best_paths:
        chars = []
        prev = None
        for label in path:
            # Emit only the first of a run of identical labels, and never a blank.
            if label != prev and label != blank:
                chars.append(idx_to_char.get(label, ''))
            prev = label
        decoded_texts.append(''.join(chars))
    return decoded_texts

# 8. main block

if __name__ == '__main__':
    print("\n--- Initializing DataLoaders ---")
    BATCH_SIZE = 32

    train_dataset = IAMDataset(train_hf_dataset, transform=transform)
    val_dataset = IAMDataset(val_hf_dataset, transform=transform)

    # num_workers=0 keeps data loading in the main process. This was set as a
    # Windows-compatibility fix; presumably it avoids spawned workers having
    # to re-import code defined in the notebook.
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn, num_workers=0)

    print("Preprocessing complete. PyTorch DataLoaders created.")

    print("\n--- Building CRNN Model ---")
    # One extra output class for the CTC blank (id 0).
    model = CRNN(num_chars=len(VOCAB) + 1).to(device)
    print(model)

    print("\n--- Training Model ---")
    # zero_infinity=True zeroes infinite losses, e.g. a transcription too long
    # to align with the model's output sequence.
    criterion = nn.CTCLoss(blank=CTC_BLANK, zero_infinity=True)
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    epochs = 50
    best_val_loss = float('inf')
    patience = 10  # stop after this many consecutive epochs without a new best val loss
    patience_counter = 0

    epoch_models_dir = 'epoch_models'
    # exist_ok=True replaces the exists()/makedirs() pair and its check-then-create race.
    os.makedirs(epoch_models_dir, exist_ok=True)

    for epoch in range(epochs):
        train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device)
        val_loss = validate_one_epoch(model, val_loader, criterion, device)

        print(f"\nEpoch {epoch+1}/{epochs} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

        # Keep a checkpoint for every epoch, regardless of validation loss.
        epoch_save_path = os.path.join(epoch_models_dir, f'handwriting_recognizer_epoch_{epoch+1}.pth')
        torch.save(model.state_dict(), epoch_save_path)
        print(f"Model saved after epoch {epoch+1} to {epoch_save_path}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'handwriting_recognizer_best.pth')
            print(f"Model improved. Saved best model to handwriting_recognizer_best.pth")
            patience_counter = 0
        else:
            patience_counter += 1
            print(f"No improvement. Patience: {patience_counter}/{patience}")

        if patience_counter >= patience:
            print("Early stopping triggered.")
            break

    print("Training finished.")

    print("\n--- Evaluating Model and Running Inference ---")
    # Reload the best checkpoint into a fresh model, so the preview reflects
    # the lowest validation loss rather than the last epoch.
    prediction_model = CRNN(num_chars=len(VOCAB) + 1).to(device)
    prediction_model.load_state_dict(torch.load('handwriting_recognizer_best.pth', map_location=device))
    prediction_model.eval()

    # Preview on the first validation batch (val_loader does not shuffle).
    data_iter = iter(val_loader)
    images, labels, _ = next(data_iter)
    images = images.to(device)

    with torch.no_grad():
        log_probs = prediction_model(images)

    pred_texts = ctc_decode(log_probs)

    # Ground truth: drop the 0 padding and map ids back to characters.
    orig_texts = []
    for label_tensor in labels:
        text = "".join([int_to_char.get(c.item(), '') for c in label_tensor if c != 0])
        orig_texts.append(text)

    _, axes = plt.subplots(4, 4, figsize=(15, 12))

    num_shown = min(16, BATCH_SIZE, images.size(0))
    for i, ax in enumerate(axes.flat):
        # Turn off every cell first, so grid slots with no sample (batch smaller
        # than 16) render blank instead of as empty framed axes.
        ax.axis("off")
        if i >= num_shown:
            continue
        img = images[i].cpu().numpy().squeeze()
        # Invert Normalize((0.5,), (0.5,)): [-1, 1] -> [0, 255].
        img = (img * 0.5 + 0.5) * 255
        img = np.clip(img, 0, 255).astype(np.uint8)

        ax.imshow(img, cmap="gray")
        ax.set_title(f"True: {orig_texts[i]}\nPred: {pred_texts[i]}", fontsize=9)

    plt.suptitle("Model Predictions on Validation Set (PyTorch)", fontsize=16)
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()


    print("\n--- Saving Final Model Locally ---")
    # `model` holds the weights from the last epoch trained, which may differ from the best checkpoint.
    torch.save(model.state_dict(), "handwriting_recognizer_final.pth")
    print("Final model state dict saved as 'handwriting_recognizer_final.pth'")
    print("Best performing model state dict saved as 'handwriting_recognizer_best.pth'")
    print(f"All epoch-wise models are saved in the '{epoch_models_dir}/' directory.")

In [None]:
print("\n--- Evaluating Model and Running Inference ---")
# Standalone re-run of the evaluation preview. It relies on val_loader,
# BATCH_SIZE, model and epoch_models_dir created by the training cell above.
prediction_model = CRNN(num_chars=len(VOCAB) + 1).to(device)
prediction_model.load_state_dict(torch.load('handwriting_recognizer_best.pth', map_location=device))
prediction_model.eval()

# Preview on the first validation batch (val_loader does not shuffle).
data_iter = iter(val_loader)
images, labels, _ = next(data_iter)
images = images.to(device)

with torch.no_grad():
    log_probs = prediction_model(images)

pred_texts = ctc_decode(log_probs)

# Ground truth: drop the 0 padding and map ids back to characters.
orig_texts = []
for label_tensor in labels:
    text = "".join([int_to_char.get(c.item(), '') for c in label_tensor if c != 0])
    orig_texts.append(text)

_, axes = plt.subplots(4, 4, figsize=(15, 12))

num_shown = min(16, BATCH_SIZE, images.size(0))
for i, ax in enumerate(axes.flat):
    # Turn off every cell first, so grid slots with no sample (batch smaller
    # than 16) render blank instead of as empty framed axes.
    ax.axis("off")
    if i >= num_shown:
        continue
    img = images[i].cpu().numpy().squeeze()
    # Invert Normalize((0.5,), (0.5,)): [-1, 1] -> [0, 255].
    img = (img * 0.5 + 0.5) * 255
    img = np.clip(img, 0, 255).astype(np.uint8)

    ax.imshow(img, cmap="gray")
    ax.set_title(f"True: {orig_texts[i]}\nPred: {pred_texts[i]}", fontsize=9)

plt.suptitle("Model Predictions on Validation Set (PyTorch)", fontsize=16)
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()


print("\n--- Saving Final Model Locally ---")
torch.save(model.state_dict(), "handwriting_recognizer_final.pth")
print("Final model state dict saved as 'handwriting_recognizer_final.pth'")
print("Best performing model state dict saved as 'handwriting_recognizer_best.pth'")
print(f"All epoch-wise models are saved in the '{epoch_models_dir}/' directory.")