In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader, random_split
from typing import List, Tuple
import random
from torch.nn.utils.rnn import pad_sequence
import torch.optim as optim
from PIL import Image, ImageDraw, ImageFont
from nltk.corpus import words
import random
import string
from torchvision import transforms
from typing import List, Tuple
from PIL import Image

In [8]:
# Use Apple's Metal backend (MPS) when this PyTorch build reports it as
# available; otherwise everything runs on the CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print("Using GPU: MPS" if device.type == "mps" else "Using CPU")

Using GPU: MPS


In [None]:
class BitSequenceDataset(Dataset):
    """Dataset of bit strings, each paired with one numeric label.

    Items are produced lazily: indexing returns a float32 tensor with one
    element per character of the stored string, plus a float32 scalar label.
    """

    def __init__(self, sequences: List[str], labels: List[int]):
        # Only references are stored; tensor conversion happens in __getitem__.
        self.sequences, self.labels = sequences, labels

    def __len__(self):
        """Return the number of stored sequences."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return ``(bits, label)`` for position ``idx`` as float32 tensors."""
        bits = list(map(int, self.sequences[idx]))
        return (
            torch.tensor(bits, dtype=torch.float32),
            torch.tensor(self.labels[idx], dtype=torch.float32),
        )

def collate_fn(batch):
    """Merge ``(sequence, label)`` samples of differing length into one batch.

    Returns ``(padded, labels, lengths)``: the sequences right-padded with 0
    to the longest one (batch first), the labels stacked into one tensor, and
    the original sequence lengths as a long tensor.
    """
    seqs, targets = [], []
    for sample in batch:
        seqs.append(sample[0])
        targets.append(sample[1])

    padded = pad_sequence(seqs, batch_first=True, padding_value=0)
    seq_lens = torch.tensor(list(map(len, seqs)), dtype=torch.long)
    return padded, torch.stack(targets), seq_lens


def generate_dataset(num_samples, min_len, max_len) :
    """Create random bit strings labelled with their number of '1' characters.

    For each of ``num_samples`` items a length is drawn uniformly from
    ``[min_len, max_len]`` (inclusive) and then each bit is drawn with
    ``random.choice``. Returns ``(sequences, labels)`` as two parallel lists.
    """
    sequences, labels = [], []
    for _ in range(num_samples):
        n_bits = random.randint(min_len, max_len)
        bits = [random.choice('01') for _ in range(n_bits)]
        sequences.append(''.join(bits))
        labels.append(bits.count('1'))
    return sequences, labels


class BitCounterRNN(nn.Module):
    """LSTM regressor that maps a padded batch of bit sequences to one scalar each.

    The sequences are packed so padding never reaches the LSTM; the top-layer
    output at each sequence's last valid timestep is fed through a small MLP
    head that emits a single value per sequence.

    Args:
        hidden_size: Width of the LSTM hidden state and of the MLP head.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout between stacked LSTM layers (ignored for one layer).
    """

    def __init__(self, hidden_size: int, num_layers: int, dropout: float = 0.1):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.rnn = nn.LSTM(
            input_size=1,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            # nn.LSTM applies dropout only between stacked layers and warns
            # when it is non-zero for a single-layer network.
            dropout=dropout if num_layers > 1 else 0,
        )

        self.fc = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, x, lengths):
        """Predict one value per sequence.

        Args:
            x: Padded batch, indexed as (batch, time); a feature axis of size
                1 is appended here.
            lengths: Integer tensor with the true length of every sequence.

        Returns:
            Tensor with one prediction per batch row.
        """
        x = x.unsqueeze(-1)
        # pack_padded_sequence requires the lengths on the CPU.
        packed_x = nn.utils.rnn.pack_padded_sequence(
            x, lengths.cpu(), batch_first=True, enforce_sorted=False
        )

        packed_output, _ = self.rnn(packed_x)
        output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)

        # Pick each row's last valid timestep with one vectorized index
        # instead of a Python loop (which synchronises with the device once
        # per sample).
        batch_index = torch.arange(output.size(0), device=output.device)
        last_step = (lengths - 1).to(device=output.device, dtype=torch.long)
        last_outputs = output[batch_index, last_step]

        count = self.fc(last_outputs)
        return count.squeeze(-1)


def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device):
    """Train ``model`` for ``num_epochs`` epochs, validating after each one.

    Every loader must yield ``(sequences, labels, lengths)`` batches. After
    each epoch the mean per-batch training loss, validation loss and
    validation mean absolute error are printed.

    Returns:
        ``(train_losses, val_losses)``: two lists with one mean per-batch
        loss per epoch.
    """
    history_train, history_val = [], []

    for epoch_idx in range(num_epochs):
        # ---- optimisation pass ----
        model.train()
        running_loss = 0
        for batch_seqs, batch_labels, batch_lens in train_loader:
            batch_seqs = batch_seqs.to(device)
            batch_labels = batch_labels.to(device)
            batch_lens = batch_lens.to(device)

            optimizer.zero_grad()
            batch_loss = criterion(model(batch_seqs, batch_lens), batch_labels)
            batch_loss.backward()
            optimizer.step()

            running_loss += batch_loss.item()

        epoch_train_loss = running_loss / len(train_loader)
        history_train.append(epoch_train_loss)

        # ---- validation pass (no gradients) ----
        model.eval()
        loss_sum = 0
        mae_sum = 0
        with torch.no_grad():
            for batch_seqs, batch_labels, batch_lens in val_loader:
                batch_seqs = batch_seqs.to(device)
                batch_labels = batch_labels.to(device)
                batch_lens = batch_lens.to(device)

                preds = model(batch_seqs, batch_lens)
                loss_sum += criterion(preds, batch_labels).item()
                mae_sum += torch.mean(torch.abs(preds - batch_labels)).item()

        # Both validation metrics are averages of per-batch means.
        epoch_val_loss = loss_sum / len(val_loader)
        epoch_val_mae = mae_sum / len(val_loader)
        history_val.append(epoch_val_loss)

        print(f'Epoch {epoch_idx+1}/{num_epochs}:')
        print(f'Training Loss: {epoch_train_loss:.4f}')
        print(f'Validation Loss: {epoch_val_loss:.4f}')
        print(f'Validation MAE: {epoch_val_mae:.4f}\n')

    return history_train, history_val

def evaluate_generalization(model, device, max_length=32, num_samples=1000, batch_size=64):
    """Measure the model's mean absolute error separately for every length.

    For each length from 1 to ``max_length`` a fresh random dataset of
    ``num_samples`` fixed-length sequences is generated and evaluated.

    Args:
        model: Module called as ``model(sequences, lengths)``.
        device: Device the batches are moved to before the forward pass.
        max_length: Largest sequence length to evaluate (inclusive).
        num_samples: Sequences generated per length; also the MAE divisor.
        batch_size: Evaluation batch size.

    Returns:
        Dict mapping each length to its mean absolute error.

    Raises:
        ValueError: If ``num_samples`` is smaller than 1.
    """
    if num_samples < 1:
        raise ValueError("num_samples must be at least 1")

    model.eval()
    mae_by_length = {}

    for length in range(1, max_length + 1):
        sequences, labels = generate_dataset(num_samples, length, length)
        # NOTE: `collate_fn` is resolved when this function runs. This
        # notebook later rebinds that name to the OCR collate function, so
        # call this before that cell is executed.
        loader = DataLoader(
            BitSequenceDataset(sequences, labels),
            batch_size=batch_size,
            collate_fn=collate_fn,
        )

        abs_error_sum = 0.0
        with torch.no_grad():
            for batch_seqs, batch_labels, batch_lens in loader:
                batch_seqs = batch_seqs.to(device)
                batch_labels = batch_labels.to(device)
                batch_lens = batch_lens.to(device)
                outputs = model(batch_seqs, batch_lens)
                # Summing absolute errors keeps the final division exact even
                # when the last batch is smaller than the others.
                abs_error_sum += torch.sum(torch.abs(outputs - batch_labels)).item()

        mae_by_length[length] = abs_error_sum / num_samples

    return mae_by_length


def random_baseline(test_loader, device):
    """Mean absolute error of guessing a uniform random value in ``[0, length)``.

    Args:
        test_loader: Iterable of ``(sequences, labels, lengths)`` batches.
            Only ``labels`` and ``lengths`` are used.
        device: Device on which the random guesses are drawn and compared.

    Returns:
        The absolute error averaged over every sample in ``test_loader``.
    """
    total_abs_error = 0.0
    total_samples = 0

    # The padded sequences are not needed for a random guess, so they are not
    # copied to the device.
    for _, labels, lengths in test_loader:
        labels = labels.to(device)
        lengths = lengths.to(device)

        batch_size = labels.size(0)

        # A sequence of n bits has a count between 0 and n; scale U[0, 1) by n.
        random_predictions = torch.rand(batch_size, device=device) * lengths.float()

        total_abs_error += torch.sum(torch.abs(random_predictions - labels)).item()
        total_samples += batch_size

    return total_abs_error / total_samples



   
# Seed every RNG used below so a fresh run of this cell is repeatable.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

print(f"Using device: {device}")
print("Generating dataset...")
sequences, labels = generate_dataset(100000, 1, 16)
dataset = BitSequenceDataset(sequences, labels)

# 80 / 10 / 10 split; the test share absorbs any rounding remainder.
total_size = len(dataset)
train_size = int(0.8 * total_size)
val_size = int(0.1 * total_size)
test_size = total_size - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size]
)

# Only the training loader shuffles; validation and test keep a fixed order.
loader_kwargs = dict(batch_size=64, collate_fn=collate_fn)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, **loader_kwargs)
test_loader = DataLoader(test_dataset, **loader_kwargs)

model = BitCounterRNN(hidden_size=64, num_layers=2).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

print("Calculating random baseline...")
random_mae = random_baseline(test_loader, device)
print(f"Random Baseline MAE: {random_mae:.4f}")

print("\nTraining model...")
train_losses, val_losses = train_model(
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    num_epochs=20,
    device=device,
)

print("\nEvaluating generalization...")
mae_by_length = evaluate_generalization(model, device)

# Plot MAE per sequence length against the random baseline and save to disk.
lengths = list(mae_by_length.keys())
maes = list(mae_by_length.values())

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(lengths, maes, marker='o', label='Model MAE')
ax.axhline(y=random_mae, color='r', linestyle='--', label='Random Baseline')
ax.set_xlabel('Sequence Length')
ax.set_ylabel('Mean Absolute Error')
ax.set_title('Model Generalization across Different Sequence Lengths')
ax.legend()
ax.grid(True)
fig.savefig('generalization_plot.png')
plt.close(fig)

# Show the first few predictions from the first test batch only.
print("\nExample predictions:")
model.eval()
with torch.no_grad():
    for sequences, labels, lengths in test_loader:
        sequences, labels, lengths = (
            tensor.to(device) for tensor in (sequences, labels, lengths)
        )

        predictions = model(sequences, lengths)

        for i in range(min(5, len(sequences))):
            # Strip the padding before printing the bit string.
            seq = sequences[i][:lengths[i]].cpu().numpy()
            seq_str = ''.join(str(int(bit)) for bit in seq)
            print(f"Sequence: {seq_str}")
            print(f"True count: {labels[i].item():.0f}")
            print(f"Predicted count: {predictions[i].item():.1f}\n")
        break

Using device: mps
Generating dataset...
Calculating random baseline...
Random Baseline MAE: 2.3707

Training model...
Epoch 1/20:
Training Loss: 0.6624
Validation Loss: 0.0031
Validation MAE: 0.0472



KeyboardInterrupt: 

In [None]:
class WordImageDataset(Dataset):
    """Render words as grayscale images paired with letter-index labels.

    Each item is ``(image_tensor, char_indices, length)``: the rendered word
    after ``ToTensor`` and ``Normalize(0.5, 0.5)``, a long tensor with one
    index per ASCII letter (``a`` -> 1 ... ``z`` -> 26, case-insensitive,
    other characters skipped), and the number of indices.

    Args:
        word_list: Candidate words; those longer than ``max_word_length``
            are dropped.
        image_size: ``(width, height)`` of the rendered image in pixels.
        max_word_length: Longest word that is kept.
        font_path: TrueType font file used for rendering.
        font_size: Font size passed to ``ImageFont.truetype``.

    Raises:
        OSError: If the font file cannot be loaded.
    """

    def __init__(self, word_list: List[str], image_size: Tuple[int, int] = (256, 64), max_word_length=20,
                 font_path: str = "arial.ttf", font_size: int = 32):
        self.words = [word for word in word_list if len(word) <= max_word_length]
        self.image_size = image_size
        self.max_word_length = max_word_length
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5])
        ])

        # The default font is not installed on every machine; fail with a
        # message that says how to point the dataset at another font file.
        try:
            self.font = ImageFont.truetype(font_path, size=font_size)
        except OSError as exc:
            raise OSError(
                f"Could not load TrueType font {font_path!r}; pass font_path=... "
                "pointing to a font file available on this machine."
            ) from exc

    def __len__(self):
        """Return the number of words kept after length filtering."""
        return len(self.words)

    def render_word(self, word: str) -> Image.Image:
        """Draw ``word`` in black on a white single-channel image."""
        img = Image.new('L', self.image_size, color=255)
        draw = ImageDraw.Draw(img)

        # Measure the text so it can be placed around the image centre.
        bbox = draw.textbbox((0, 0), word, font=self.font)
        text_width = bbox[2] - bbox[0]
        text_height = bbox[3] - bbox[1]

        # NOTE(review): the bbox origin (bbox[0], bbox[1]) is not subtracted
        # here, so the text may sit slightly off-centre. Left unchanged so the
        # rendering stays the same as before — confirm before changing.
        x = (self.image_size[0] - text_width) // 2
        y = (self.image_size[1] - text_height) // 2

        draw.text((x, y), word, fill=0, font=self.font)
        return img

    def __getitem__(self, idx):
        """Return ``(image_tensor, char_indices, num_indices)`` for one word."""
        word = self.words[idx]
        image_tensor = self.transform(self.render_word(word))

        # Index 0 is left free for the CTC blank, so letters start at 1.
        char_indices = []
        for ch in word:
            position = string.ascii_lowercase.find(ch.lower())
            if position >= 0:
                char_indices.append(position + 1)
        return image_tensor, torch.tensor(char_indices, dtype=torch.long), len(char_indices)

def collate_fn(batch):
    """Collate ``(image, label, length)`` samples into one batch.

    Samples are ordered by label length, longest first, without modifying
    the caller's list.

    Returns:
        ``(images, padded_labels, lengths)``: the images stacked along a new
        first axis, the labels right-padded with 0 into a long tensor, and
        the label lengths as a tensor.
    """
    # sorted() rather than list.sort() so the input list is left untouched.
    ordered = sorted(batch, key=lambda sample: sample[2], reverse=True)

    images, labels, lengths = zip(*ordered)
    images = torch.stack(images, 0)

    max_len = max(lengths)
    padded_labels = torch.zeros(len(labels), max_len, dtype=torch.long)
    for row, label in enumerate(labels):
        padded_labels[row, :len(label)] = label
    lengths = torch.tensor(lengths)

    return images, padded_labels, lengths

class CNNEncoder(nn.Module):
    """Four conv/pool stages that turn an image into a sequence of 512-d features.

    Each stage is a 3x3 convolution (padding 1), ReLU, batch norm and a 2x2
    max-pool, so height and width are both divided by 16 and the channel
    count grows from 1 to 512.

    ``forward`` returns one 512-dimensional vector per spatial location,
    ordered left to right (and top to bottom within a column).

    NOTE: an earlier version produced the sequence with a bare
    ``x.view(batch, -1, 512)`` on the channel-first tensor, which mixed
    channels and positions inside every timestep. Weights trained with that
    layout load without error but must be retrained.
    """

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(128),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(256),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(512),
            nn.MaxPool2d(2, 2),

            nn.Dropout(0.5)
        )

    def forward(self, x):
        """Encode images ``(B, 1, H, W)`` into features ``(B, (W/16)*(H/16), 512)``."""
        x = self.features(x)  # (B, 512, H/16, W/16)
        batch_size, channels = x.size(0), x.size(1)
        # Put width first and channels last BEFORE flattening, so that every
        # timestep holds all channels of a single spatial location and the
        # sequence advances from the left edge of the image to the right.
        x = x.permute(0, 3, 2, 1).contiguous()  # (B, W/16, H/16, 512)
        return x.view(batch_size, -1, channels)

class RNNDecoder(nn.Module):
    """Bidirectional LSTM followed by a per-timestep linear classifier.

    Args:
        input_size: Feature size of every input timestep.
        hidden_size: Hidden size per LSTM direction.
        num_layers: Number of stacked LSTM layers.
        num_classes: Number of output classes per timestep.
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            # nn.LSTM applies dropout only between stacked layers and warns
            # when it is non-zero for a single layer, so disable it there.
            dropout=0.5 if num_layers > 1 else 0.0,
        )

        # Forward and backward states are concatenated, hence hidden_size * 2.
        self.fc = nn.Linear(hidden_size * 2, num_classes)

    def forward(self, x):
        """Map ``(batch, time, input_size)`` to raw scores ``(batch, time, num_classes)``."""
        output, _ = self.lstm(x)
        return self.fc(output)

class OCRModel(nn.Module):
    """CNN encoder feeding a bidirectional-LSTM decoder.

    ``num_classes`` is the number of scores the decoder emits per timestep;
    the default is 27.
    """

    def __init__(self, num_classes=27):
        super().__init__()
        self.encoder = CNNEncoder()
        # 512 matches the feature size the encoder emits per timestep.
        self.decoder = RNNDecoder(512, 256, 2, num_classes)

    def forward(self, x):
        """Return the decoder's per-timestep scores for a batch of images."""
        return self.decoder(self.encoder(x))


def _ctc_greedy_decode(frame_indices, blank=0):
    """Collapse a per-frame argmax path: merge repeated labels, drop blanks."""
    decoded = []
    previous = blank
    for idx in frame_indices:
        if idx != previous and idx != blank:
            decoded.append(idx)
        previous = idx
    return decoded


def train_model(model, train_loader, val_loader, num_epochs=10, device='mps'):
    """Train an OCR model with CTC loss and keep the best checkpoint.

    Uses Adam (lr 0.001), gradient-norm clipping at 5 and a
    ``ReduceLROnPlateau`` scheduler driven by the validation loss. Whenever
    the validation loss improves, the model's ``state_dict`` is written to
    ``best_ocr_model.pth`` in the current working directory.

    NOTE: this notebook also defines an earlier ``train_model`` with a
    different signature; this definition replaces it once the cell has run.

    Args:
        model: Module mapping images to raw scores ``(batch, time, classes)``.
        train_loader: Iterable of ``(images, targets, target_lengths)``.
        val_loader: Same structure, used for validation; must support ``len``.
        num_epochs: Number of passes over ``train_loader``.
        device: Device for the model and the image/target tensors.
    """
    criterion = nn.CTCLoss(blank=0, reduction='mean')
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2)

    model = model.to(device)
    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        num_train_batches = 0
        for batch_idx, (images, targets, target_lengths) in enumerate(train_loader):
            images = images.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            outputs = model(images)

            # CTCLoss expects log-probabilities shaped (time, batch, classes);
            # the model emits raw scores shaped (batch, time, classes).
            log_probs = outputs.log_softmax(dim=-1).transpose(0, 1)
            # Every sample uses the full output sequence length.
            input_lengths = torch.full((outputs.size(0),), outputs.size(1), dtype=torch.long)
            loss = criterion(log_probs, targets, input_lengths, target_lengths)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5)
            optimizer.step()

            train_loss += loss.item()
            num_train_batches += 1

            if batch_idx % 100 == 0:
                print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.4f}')

        model.eval()
        val_loss = 0.0
        correct_chars = 0
        total_chars = 0

        with torch.no_grad():
            for images, targets, target_lengths in val_loader:
                images = images.to(device)
                targets = targets.to(device)

                outputs = model(images)

                log_probs = outputs.log_softmax(dim=-1).transpose(0, 1)
                input_lengths = torch.full((outputs.size(0),), outputs.size(1), dtype=torch.long)
                loss = criterion(log_probs, targets, input_lengths, target_lengths)
                val_loss += loss.item()

                # Accuracy is measured on the CTC-decoded path (repeats merged,
                # blanks removed), compared position by position with the target.
                frame_paths = outputs.argmax(dim=-1).cpu().tolist()
                for path, target, length in zip(frame_paths, targets, target_lengths):
                    length = int(length)
                    decoded = _ctc_greedy_decode(path)
                    expected = target[:length].tolist()
                    correct_chars += sum(1 for p, t in zip(decoded, expected) if p == t)
                    total_chars += length

        val_loss /= len(val_loader)
        char_accuracy = correct_chars / total_chars if total_chars else 0.0

        print(f'Epoch {epoch}:')
        print(f'Training Loss: {train_loss / max(num_train_batches, 1):.4f}')
        print(f'Validation Loss: {val_loss:.4f}')
        print(f'Character Accuracy: {char_accuracy:.4f}')

        scheduler.step(val_loss)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_ocr_model.pth')

def decode_prediction(pred_indices, blank_label=0, verbose=False):
    """Greedy CTC decoding of a per-frame label sequence into lowercase text.

    Consecutive repeats are merged and ``blank_label`` frames are dropped;
    each remaining index ``k > 0`` becomes the ``k``-th lowercase letter
    (1 -> ``a``).

    Args:
        pred_indices: Iterable of integer class indices, one per frame.
        blank_label: Index of the CTC blank class.
        verbose: When true, print the raw and the collapsed index sequences
            (debugging aid; silent by default).

    Returns:
        The decoded string.
    """
    previous = blank_label
    decoded = []
    for idx in pred_indices:
        if idx != previous and idx != blank_label:
            decoded.append(idx)
        previous = idx

    if verbose:
        print("Raw indices:", pred_indices)
        print("Decoded indices:", decoded)

    return ''.join(string.ascii_lowercase[idx - 1] for idx in decoded if idx > 0)

def visualize_predictions(model, test_loader, device, num_examples=5):
    """Plot images from the first batch with their true and predicted text.

    Args:
        model: Module mapping images to per-frame class scores.
        test_loader: Iterable of ``(images, targets, lengths)``; only the
            first batch is used.
        device: Device the images are moved to for the forward pass.
        num_examples: Maximum number of images to show; capped at the size
            of the first batch.
    """
    model.eval()

    images, targets, lengths = next(iter(test_loader))
    images = images.to(device)

    with torch.no_grad():
        outputs = model(images)
        pred_indices = outputs.argmax(dim=-1)

    # The first batch may hold fewer samples than requested.
    num_shown = min(num_examples, images.size(0))
    if num_shown <= 0:
        return

    # squeeze=False always yields a 2-D axes array, even for a single row.
    fig, axes = plt.subplots(num_shown, 1, figsize=(15, 3*num_shown), squeeze=False)

    for i, ax in enumerate(axes[:, 0]):
        img = images[i].cpu().squeeze().numpy()
        # Map pixel values back through x * 0.5 + 0.5 for display.
        img = (img * 0.5 + 0.5)

        true_indices = targets[i][:lengths[i]].tolist()
        true_text = ''.join(string.ascii_lowercase[idx - 1] for idx in true_indices)
        pred_text = decode_prediction(pred_indices[i].cpu().numpy())

        ax.imshow(img, cmap='gray')
        ax.axis('off')
        ax.set_title(f'True: {true_text}\nPredicted: {pred_text}')

    plt.tight_layout()
    plt.show()

def calculate_accuracy(pred_indices, target_indices, target_length):
    """Compare one greedy-decoded prediction with its target, position by position.

    The per-frame predictions are collapsed (repeats merged, index 0 treated
    as blank and dropped) and both sequences are mapped to lowercase letters
    (1 -> ``a``). Characters are matched at equal positions only.

    Returns:
        ``(correct_chars, total_chars, pred_text, true_text)`` where
        ``total_chars`` is the length of the true text.
    """
    alphabet = string.ascii_lowercase

    def to_text(indices):
        # Non-positive indices contribute no character.
        return ''.join(alphabet[k - 1] for k in indices if k > 0)

    collapsed = []
    last_seen = 0  # start as if the previous frame was blank
    for k in pred_indices:
        if k != last_seen and k != 0:
            collapsed.append(k)
        last_seen = k

    pred_text = to_text(collapsed)
    true_text = to_text(target_indices[:target_length].tolist())

    matches = sum(1 for guess, truth in zip(pred_text, true_text) if guess == truth)
    return matches, len(true_text), pred_text, true_text

def load_and_test_model():
    """Score the saved OCR checkpoint on 1000 randomly sampled dictionary words.

    Loads ``best_ocr_model.pth`` into a fresh ``OCRModel``, prints the first
    five predictions of every batch, reports overall character and word
    accuracy, and finally plots a few examples. Returns nothing.
    """
    candidate_words = [word for word in words.words() if word.isalpha()]
    test_dataset = WordImageDataset(random.sample(candidate_words, 1000))
    test_loader = DataLoader(
        test_dataset,
        batch_size=32,
        shuffle=True,
        num_workers=4,
        collate_fn=collate_fn,
        pin_memory=True,
    )

    model = OCRModel()
    checkpoint = torch.load('best_ocr_model.pth', map_location=device)
    model.load_state_dict(checkpoint)
    model = model.to(device)
    model.eval()

    char_hits = 0
    char_count = 0
    word_hits = 0
    word_count = 0

    print("\nSample predictions:")
    with torch.no_grad():
        for images, targets, target_lengths in test_loader:
            images, targets = images.to(device), targets.to(device)
            frame_labels = model(images).argmax(dim=-1)

            for i, target in enumerate(targets):
                hits, count, pred_text, true_text = calculate_accuracy(
                    frame_labels[i].cpu().numpy(), target, target_lengths[i]
                )

                char_hits += hits
                char_count += count
                # A word counts as correct only on an exact string match.
                word_hits += int(pred_text == true_text)
                word_count += 1

                # Echo the first five samples of every batch.
                if i < 5:
                    print(f"\nTrue: {true_text}")
                    print(f"Pred: {pred_text}")
                    print(f"Character accuracy: {hits}/{count}")

    char_accuracy = char_hits / char_count
    word_accuracy = word_hits / word_count

    print(f'\nOverall Character Accuracy: {char_accuracy:.4f}')
    print(f'Overall Word Accuracy: {word_accuracy:.4f}')
    print(f'Total correct words: {word_hits}/{word_count}')

    # Plot a handful of examples from a fresh pass over the loader.
    visualize_predictions(model, test_loader, device)


def random_baseline_accuracy(test_loader):
    """Accuracy of guessing uniformly random letters of the correct length.

    For every target in ``test_loader`` a random lowercase string with the
    same number of characters is drawn and compared position by position.
    The first five samples of each batch are printed, followed by a summary.

    Returns:
        ``(char_accuracy, word_accuracy)``.
    """
    alphabet = string.ascii_lowercase
    char_hits = 0
    char_count = 0
    word_hits = 0
    word_count = 0

    print("\nRandom Baseline Predictions:")

    for _, targets, target_lengths in test_loader:
        for i, target in enumerate(targets):
            valid = target[:target_lengths[i]]
            true_text = ''.join(alphabet[idx - 1] for idx in valid)
            # One random letter per true character.
            random_pred = ''.join(random.choice(alphabet) for _ in range(len(true_text)))

            matches = sum(1 for guess, truth in zip(random_pred, true_text) if guess == truth)
            char_hits += matches
            char_count += len(true_text)

            word_hits += int(random_pred == true_text)
            word_count += 1

            # Show the first five samples of each batch.
            if i < 5:
                print(f"\nTrue: {true_text}")
                print(f"Random: {random_pred}")
                print(f"Character matches: {matches}/{len(true_text)}")

    char_accuracy = char_hits / char_count
    word_accuracy = word_hits / word_count

    print(f'\nRandom Baseline Results:')
    print(f'Character Accuracy: {char_accuracy:.4f}')
    print(f'Word Accuracy: {word_accuracy:.4f}')
    print(f'Total correct words: {word_hits}/{word_count}')

    return char_accuracy, word_accuracy

def compare_with_baseline():
    """Evaluate the saved OCR model and a random-letter baseline side by side.

    Samples 1000 alphabetic dictionary words, scores the checkpoint stored in
    ``best_ocr_model.pth`` on them, runs ``random_baseline_accuracy`` over the
    same loader and prints both results plus their differences.
    """
    # Test data: 1000 random alphabetic words rendered as images.
    candidate_words = [word for word in words.words() if word.isalpha()]
    test_dataset = WordImageDataset(random.sample(candidate_words, 1000))
    test_loader = DataLoader(
        test_dataset,
        batch_size=32,
        shuffle=True,
        num_workers=4,
        collate_fn=collate_fn,
        pin_memory=True,
    )

    # Model under test, restored from the checkpoint on disk.
    model = OCRModel()
    checkpoint = torch.load('best_ocr_model.pth', map_location=device)
    model.load_state_dict(checkpoint)
    model = model.to(device)
    model.eval()

    char_hits = 0
    char_count = 0
    word_hits = 0
    word_count = 0

    print("Model Predictions:")
    with torch.no_grad():
        for images, targets, target_lengths in test_loader:
            images, targets = images.to(device), targets.to(device)
            frame_labels = model(images).argmax(dim=-1)

            for i, target in enumerate(targets):
                hits, count, pred_text, true_text = calculate_accuracy(
                    frame_labels[i].cpu().numpy(), target, target_lengths[i]
                )

                char_hits += hits
                char_count += count
                word_hits += int(pred_text == true_text)
                word_count += 1

                # Echo the first five samples of every batch.
                if i < 5:
                    print(f"\nTrue: {true_text}")
                    print(f"Pred: {pred_text}")
                    print(f"Character matches: {hits}/{count}")

    model_char_accuracy = char_hits / char_count
    model_word_accuracy = word_hits / word_count

    print(f'\nModel Results:')
    print(f'Character Accuracy: {model_char_accuracy:.4f}')
    print(f'Word Accuracy: {model_word_accuracy:.4f}')
    print(f'Total correct words: {word_hits}/{word_count}')

    # Random baseline over the same loader.
    separator = '\n' + '='*50
    print(separator)
    baseline_char_accuracy, baseline_word_accuracy = random_baseline_accuracy(test_loader)

    # Side-by-side summary.
    char_gain = model_char_accuracy - baseline_char_accuracy
    word_gain = model_word_accuracy - baseline_word_accuracy
    print(separator)
    print('Comparison Summary:')
    print(f'Character Accuracy - Model: {model_char_accuracy:.4f}, Random: {baseline_char_accuracy:.4f}')
    print(f'Character Accuracy Improvement: {char_gain:.4f}')
    print(f'Word Accuracy - Model: {model_word_accuracy:.4f}, Random: {baseline_word_accuracy:.4f}')
    print(f'Word Accuracy Improvement: {word_gain:.4f}')



# Run the model-vs-random comparison. compare_with_baseline() loads
# 'best_ocr_model.pth', the file the OCR train_model writes; no call to that
# train_model is visible in this part of the notebook, so the checkpoint must
# already exist on disk for this line to succeed.
compare_with_baseline()