In [20]:
import os
import torch
import cv2
import bounding_box
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class LatexDataset(Dataset):
    def __init__(self, image_dir, label_dir, transform=None):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.transform = transform
        self.image_files = [f for f in os.listdir(image_dir) if f.endswith(".png")]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        img_path = os.path.join(self.image_dir, img_name)
        label_path = os.path.join(self.label_dir, img_name.replace(".png", ".txt"))

        # Load image
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        h, w, _ = image.shape  # Original image size

        # Load label
        with open(label_path, "r") as f:
            label_data = [list(map(float, line.split())) for line in f.readlines()]
        
        class_labels = []
        bboxes = []

        for label in label_data:
            class_id = int(label[0])
            x_center, y_center, width, height = label[1:]

            x_min = (x_center - width / 2) 
            y_min = (y_center - height / 2)
            x_max = (x_center + width / 2) 
            y_max = (y_center + height / 2)

            class_labels.append(class_id)
            bboxes.append([x_min, y_min, x_max, y_max])

        # Convert to tensors
        image = cv2.resize(image, (128, 128)) / 255.0 
        image = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1)  
        class_labels = torch.tensor(class_labels, dtype=torch.long)
        bboxes = torch.tensor(bboxes, dtype=torch.float32)

        return image, class_labels, bboxes
    

def collate_fn(batch):
    """
    Custom collate function to handle variable-length bounding boxes.
    """
    images, class_labels, bboxes = zip(*batch)  # Unpack batch
    images = torch.stack(images, dim=0)  # Stack images normally

    return images, list(class_labels), list(bboxes)  # Keep labels & bboxes as lists

In [21]:
# Load dataset
max_objects = 100
image_dir = "dataset"
label_dir = "dataset"
train_dataset = LatexDataset(image_dir, label_dir, transform=None)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)

In [22]:
class ObjectDetectorSimpleCNN(nn.Module):
    def __init__(self, num_classes, max_objects):
        super(ObjectDetectorSimpleCNN, self).__init__()

        self.num_classes = num_classes
        self.max_objects = max_objects

        # Convolutional layers
        self.conv1 = nn.Conv2d(3, 32, 3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)

        # Fully connected layers
        self.fc1 = nn.Linear(64 * 32 * 32, 256)

        # Class prediction head
        self.fc_class = nn.Linear(256, num_classes*max_objects)

        # Bounding Box prediction head (x_min, y_min, x_max, y_max)
        self.fc_bbox = nn.Linear(256, 4*max_objects)

        self.fc_confidence = nn.Linear(256, max_objects)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(x.shape[0], -1)  # Flatten

        x = F.relu(self.fc1(x))

        class_output = self.fc_class(x).view(x.shape[0], self.max_objects, self.num_classes)  # [batch_size, max_objects, num_classes]
        bbox_output = self.fc_bbox(x).view(x.shape[0], self.max_objects, 4)  # [batch_size, max_objects, 4]
       
        confidence_output = torch.sigmoid(self.fc_confidence(x)) # 0 to 1 range

        return class_output, bbox_output, confidence_output


model = ObjectDetectorSimpleCNN(len(bounding_box.types), max_objects)
model

ObjectDetectorSimpleCNN(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=65536, out_features=256, bias=True)
  (fc_class): Linear(in_features=256, out_features=1000, bias=True)
  (fc_bbox): Linear(in_features=256, out_features=400, bias=True)
  (fc_confidence): Linear(in_features=256, out_features=100, bias=True)
)

In [23]:
class_loss_fn = nn.CrossEntropyLoss()
bbox_loss_fn = nn.MSELoss()
confidence_loss_fn = nn.BCELoss()  # Confidence loss
optimizer = optim.Adam(model.parameters(), lr=0.001)
num_epochs=10

In [24]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_class_loss, running_bbox_loss, running_conf_loss = 0.0, 0.0, 0.0

    for images, class_labels, bboxes in train_loader:
        images = images.to(device)
        optimizer.zero_grad()

        class_preds, bbox_preds, confidence_preds = model(images)

        loss_class, loss_bbox, loss_conf = 0, 0, 0

        for i in range(len(class_labels)):  # Loop over batch
            cls_target = class_labels[i].to(device)
            bbox_target = bboxes[i].to(device)

            num_objects = cls_target.shape[0]
            pred_classes = class_preds[i][:num_objects, :]
            pred_bboxes = bbox_preds[i][:num_objects, :]
            pred_confidences = confidence_preds[i][:num_objects]

            loss_class += class_loss_fn(pred_classes, cls_target)
            loss_bbox += bbox_loss_fn(pred_bboxes, bbox_target)
            loss_conf += confidence_loss_fn(pred_confidences, torch.ones(num_objects, device=device))

        loss = loss_class + loss_bbox + loss_conf
        loss.backward()
        optimizer.step()

        running_class_loss += loss_class.item()
        running_bbox_loss += loss_bbox.item()
        running_conf_loss += loss_conf.item()

    print(f"Epoch {epoch+1}, Class Loss: {running_class_loss:.4f}, BBox Loss: {running_bbox_loss:.4f}, Conf Loss: {running_conf_loss:.4f}")

print("Training complete!")

Epoch 1, Class Loss: 647.2732, BBox Loss: 53839.6027, Conf Loss: 83.0674
Epoch 2, Class Loss: 487.6732, BBox Loss: 29922.0896, Conf Loss: 0.0010
Epoch 3, Class Loss: 375.0172, BBox Loss: 28108.3630, Conf Loss: 0.0000
Epoch 4, Class Loss: 308.6285, BBox Loss: 27451.3144, Conf Loss: 0.0000
Epoch 5, Class Loss: 276.0655, BBox Loss: 26311.0074, Conf Loss: 0.0000
Epoch 6, Class Loss: 250.5802, BBox Loss: 24772.5605, Conf Loss: 0.0000
Epoch 7, Class Loss: 295.2268, BBox Loss: 23746.7567, Conf Loss: 0.0000
Epoch 8, Class Loss: 360.8060, BBox Loss: 21739.9703, Conf Loss: 0.0000
Epoch 9, Class Loss: 354.2438, BBox Loss: 18881.0616, Conf Loss: 0.0000
Epoch 10, Class Loss: 288.0420, BBox Loss: 15933.9975, Conf Loss: 0.0000
Training complete!
