In [64]:
import os
import cv2
import torch
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.nn as nn
import torch.optim as optim

# Configuración básica
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
chars = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz "
num_chars = len(chars) + 1  # +1 para blank token de CTC
img_size = (64, 256)  # Alto, Ancho

In [65]:
class OCRDataset(Dataset):
    def __init__(self, img_dir, label_dir, transform=None):
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.transform = transform
        self.samples = self._load_samples()
        
    def _load_samples(self):
        samples = []
        for img_name in os.listdir(self.img_dir):
            if img_name.endswith(('.png', '.jpg', '.jpeg')):
                img_path = os.path.join(self.img_dir, img_name)
                label_path = os.path.join(self.label_dir, os.path.splitext(img_name)[0] + '.txt')
                if os.path.exists(label_path):
                    with open(label_path, 'r') as f:
                        for line in f.readlines():
                            parts = line.strip().split()
                            if len(parts) >= 5:
                                text = ' '.join(parts[4:])
                                bbox = list(map(float, parts[:4]))
                                samples.append((img_path, bbox, text))
        return samples
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        img_path, bbox, text = self.samples[idx]
        
        # Leer y procesar imagen
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        h, w = img.shape
        
        # Convertir bbox YOLO a coordenadas absolutas
        x_center, y_center, bw, bh = bbox
        x1 = int((x_center - bw/2) * w)
        y1 = int((y_center - bh/2) * h)
        x2 = int((x_center + bw/2) * w)
        y2 = int((y_center + bh/2) * h)
        
        # Recortar región de texto
        roi = img[max(0,y1):min(h,y2), max(0,x1):min(w,x2)]
        roi = cv2.resize(roi, (img_size[1], img_size[0]))
        roi = Image.fromarray(roi)
        
        if self.transform:
            roi = self.transform(roi)
        
        # Convertir texto a índices
        target = [chars.index(c) for c in text if c in chars]
        return roi, torch.tensor(target, dtype=torch.long), torch.tensor(len(target), dtype=torch.long)

In [66]:
# Transformaciones
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# Función para hacer padding de los batches
def collate_fn(batch):
    images = [item[0] for item in batch]
    targets = [item[1] for item in batch]
    lengths = [item[2] for item in batch]
    
    images = torch.stack(images, 0)
    
    # Padding de targets
    max_len = max([len(t) for t in targets])
    padded_targets = torch.zeros(len(targets), max_len, dtype=torch.long)
    for i, target in enumerate(targets):
        padded_targets[i, :len(target)] = target
    
    lengths = torch.stack(lengths, 0)
    return images, padded_targets, lengths

# Crear datasets y dataloaders
train_dataset = OCRDataset("dataset2/images/train", "dataset2/labels/train", transform)
val_dataset = OCRDataset("dataset2/images/val", "dataset2/labels/val", transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

In [67]:
class CRNN(nn.Module):
    def __init__(self, num_chars):
        super(CRNN, self).__init__()
        
        # CNN
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(),
            nn.Conv2d(256, 256, 3, padding=1), nn.ReLU(),
            nn.MaxPool2d((2, 1)),
            nn.Conv2d(256, 512, 3, padding=1), nn.BatchNorm2d(512), nn.ReLU(),
            nn.MaxPool2d((2, 1)),
            nn.Conv2d(512, 512, 2), nn.BatchNorm2d(512), nn.ReLU()
        )
        
        # RNN
        self.rnn = nn.LSTM(512, 128, bidirectional=True, num_layers=2, dropout=0.3)
        
        # Capa de salida
        self.fc = nn.Linear(256, num_chars + 1)
    
    def forward(self, x):
        # CNN
        conv = self.cnn(x)  # [batch, channels, height, width]
        
        # Preparar para RNN
        conv = conv.squeeze(2)  # Eliminar dimensión de altura [batch, channels, width]
        conv = conv.permute(2, 0, 1)  # [width, batch, channels]
        
        # RNN
        rnn_out, _ = self.rnn(conv)
        
        # Salida
        output = self.fc(rnn_out)
        return output

In [68]:
def train(model, dataloader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    
    for batch_idx, (images, targets, target_lengths) in enumerate(dataloader):
        images = images.to(device)
        targets = targets.to(device)
        target_lengths = target_lengths.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        
        input_lengths = torch.full((outputs.size(1),), outputs.size(0), dtype=torch.long).to(device)
        
        loss = criterion(outputs, targets, input_lengths, target_lengths)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        
        if batch_idx % 10 == 0:
            print(f'Batch {batch_idx}/{len(dataloader)} - Loss: {loss.item():.4f}')
    
    return total_loss / len(dataloader)

def validate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    
    with torch.no_grad():
        for images, targets, target_lengths in dataloader:
            images = images.to(device)
            targets = targets.to(device)
            target_lengths = target_lengths.to(device)
            
            outputs = model(images)
            input_lengths = torch.full((outputs.size(1),), outputs.size(0), dtype=torch.long).to(device)
            
            loss = criterion(outputs, targets, input_lengths, target_lengths)
            total_loss += loss.item()
    
    return total_loss / len(dataloader)

In [69]:
# Inicialización
model = CRNN(len(chars)).to(device)
criterion = nn.CTCLoss(blank=len(chars))
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Bucle de entrenamiento
num_epochs = 10
for epoch in range(num_epochs):
    train_loss = train(model, train_loader, criterion, optimizer, device)
    val_loss = validate(model, val_loader, criterion, device)
    print(f'Epoch {epoch+1}/{num_epochs}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}')
    
    # Guardar el mejor modelo
    if epoch == 0 or val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_ocr_model.pth')
        print('Modelo guardado!')

RuntimeError: permute(sparse_coo): number of dimensions in the tensor input does not match the length of the desired ordering of dimensions i.e. input.dim() = 4 is not equal to len(dims) = 3

In [None]:
def predict(image_path, model, transform, device):
    # Preprocesamiento
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    img = cv2.resize(img, (img_size[1], img_size[0]))
    img = Image.fromarray(img)
    img = transform(img).unsqueeze(0).to(device)
    
    # Predicción
    model.eval()
    with torch.no_grad():
        outputs = model(img)
        _, preds = torch.max(outputs, 2)
        pred_text = decode_ctc(preds.squeeze(), chars)
    
    return pred_text

def decode_ctc(sequence, chars):
    prev_char = None
    text = []
    for idx in sequence:
        if idx < len(chars) and (prev_char != idx or idx == len(chars)):
            if idx == len(chars):  # Blank token
                prev_char = None
            else:
                text.append(chars[idx])
                prev_char = idx
    return ''.join(text)

# Ejemplo de uso
# model.load_state_dict(torch.load('best_ocr_model.pth'))
# text = predict("ejemplo.jpg", model, transform, device)
# print("Texto reconocido:", text)