In [1]:
import os
import random

import cv2
import numpy as np
import pygame
import torch
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import MNIST

import detectron2
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.engine import DefaultTrainer
from detectron2.structures import BoxMode
from detectron2.utils.visualizer import Visualizer

OSError: libcurand.so.10: cannot open shared object file: No such file or directory

In [None]:
# Define a function to calculate bounding boxes for each digit in the dataset
def calculate_bounding_boxes(images):
    """Compute the tight bounding box of the non-zero pixels of each image.

    Args:
        images: An iterable of 2-D images, e.g. an ``(N, H, W)`` tensor such as
            ``MNIST.data`` or a list of 2-D numpy arrays. Every pixel whose value
            is non-zero counts as foreground.

    Returns:
        A list with one ``[xmin, ymin, xmax, ymax]`` box per image, as plain
        Python ints. The max coordinates are *inclusive* pixel indices (the last
        foreground column / row), not exclusive bounds.

    Raises:
        ValueError: If an image contains no non-zero pixel, so no box exists.
    """
    bounding_boxes = []
    for idx, img in enumerate(images):
        img_np = np.asarray(img)
        # Collapse each axis to find which rows / columns contain any ink.
        row_idx = np.flatnonzero(np.any(img_np, axis=1))
        col_idx = np.flatnonzero(np.any(img_np, axis=0))
        if row_idx.size == 0:
            # Fail loudly with context rather than an IndexError on an empty array.
            raise ValueError(f"image {idx} is blank; cannot compute a bounding box")
        # int(...) keeps the boxes as plain Python ints instead of numpy scalars.
        bounding_boxes.append(
            [int(col_idx[0]), int(row_idx[0]), int(col_idx[-1]), int(row_idx[-1])]
        )
    return bounding_boxes

# Load MNIST dataset: download both splits if needed. The ToTensor transform only
# applies when the datasets are indexed; the boxes below are computed from the
# raw ``.data`` tensors instead.
transform = transforms.Compose([transforms.ToTensor()])
mnist_train, mnist_test = (
    MNIST(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# One tight box per image, aligned index-for-index with each split.
train_bboxes, test_bboxes = (
    calculate_bounding_boxes(split.data) for split in (mnist_train, mnist_test)
)

# Preview the first few boxes of each split.
for preview in (train_bboxes[:5], test_bboxes[:5]):
    print(preview)


In [None]:
def get_mnist_dicts(images, bboxes, labels, image_dir="mnist_images"):
    """Build detectron2 dataset dicts for MNIST, one annotated digit per image.

    Detectron2's default data mapper loads every image from
    ``record["file_name"]``, so the in-memory digits are exported once as PNG
    files under ``image_dir/<digest>/<idx>.png``. ``<digest>`` is derived from
    the bytes of the whole ``images`` array, so different splits (whose indices
    both start at 0) never overwrite each other. Files that already exist are
    reused on later calls.

    Args:
        images: ``(N, H, W)`` array of grayscale digits (e.g.
            ``MNIST.data.numpy()``, uint8), iterated image by image.
        bboxes: N boxes ``[xmin, ymin, xmax, ymax]`` in absolute pixel coordinates.
        labels: N integer class ids.
        image_dir: Root folder for the exported PNG files.

    Returns:
        A list of N dicts with ``file_name``, ``image_id``, ``height``, ``width``
        and a single-element ``annotations`` list.

    Raises:
        OSError: If a PNG file cannot be written.
    """
    import hashlib

    images = np.ascontiguousarray(images)
    # Per-array digest keeps train and test exports in separate folders.
    digest = hashlib.sha1(images.tobytes()).hexdigest()[:12]
    split_dir = os.path.join(image_dir, digest)
    os.makedirs(split_dir, exist_ok=True)

    dataset_dicts = []
    for idx, (img, bbox, label) in enumerate(zip(images, bboxes, labels)):
        height, width = img.shape
        file_name = os.path.join(split_dir, f"{idx}.png")
        # cv2.imwrite signals failure via its return value, not an exception.
        if not os.path.exists(file_name) and not cv2.imwrite(file_name, img):
            raise OSError(f"could not write {file_name}")

        obj = {
            "bbox": [float(v) for v in bbox],
            "bbox_mode": BoxMode.XYXY_ABS,
            # Plain int rather than a numpy scalar from ``targets.numpy()``.
            "category_id": int(label),
        }
        dataset_dicts.append({
            "file_name": file_name,
            "image_id": idx,
            "height": height,
            "width": width,
            "annotations": [obj],
        })
    return dataset_dicts

# Register the dataset: both MNIST splits under "mnist_train" / "mnist_test".
# Earlier registrations under the same names are removed first, so this cell can
# be re-run without detectron2 rejecting a duplicate name.
for d in ["train", "test"]:
    dataset_name = "mnist_" + d
    if dataset_name in DatasetCatalog.list():
        DatasetCatalog.remove(dataset_name)
    if dataset_name in MetadataCatalog.list():
        MetadataCatalog.remove(dataset_name)
    # `d=d` binds the current split now; a bare closure would only see the last value.
    DatasetCatalog.register(dataset_name, lambda d=d: get_mnist_dicts(
        mnist_train.data.numpy() if d == "train" else mnist_test.data.numpy(),
        train_bboxes if d == "train" else test_bboxes,
        mnist_train.targets.numpy() if d == "train" else mnist_test.targets.numpy()
    ))
    MetadataCatalog.get(dataset_name).set(thing_classes=[str(i) for i in range(10)])

# Configure the Detectron2 model: Faster R-CNN R50-FPN fine-tuned from COCO weights.
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"))
cfg.DATASETS.TRAIN = ("mnist_train",)
cfg.DATASETS.TEST = ("mnist_test",)
cfg.DATALOADER.NUM_WORKERS = 2
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml")  # start from the COCO-pretrained checkpoint
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.SOLVER.BASE_LR = 0.00025
cfg.SOLVER.MAX_ITER = 300
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128  # fewer sampled RoIs per image than the default 512; faster
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 10  # 10 classes (digits 0-9)
# Use the GPU when one is usable; otherwise fall back to CPU (slow, but runs).
cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Training writes checkpoints and metrics into OUTPUT_DIR.
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

# Train the model. resume=False starts from cfg.MODEL.WEIGHTS rather than any
# checkpoint already present in OUTPUT_DIR.
trainer = DefaultTrainer(cfg)
trainer.resume_or_load(resume=False)
trainer.train()


In [None]:

# Save the trained model: persist only the module's state_dict (parameters and
# buffers, no optimizer/scheduler state), creating the folder on the first run.
os.makedirs(name="model_output", exist_ok=True)
torch.save(obj=trainer.model.state_dict(), f="model_output/mnist_digit_detector.pth")

In [None]:
# Initialize PyGame
pygame.init()

# Set up the drawing window: 280x280 is 10x MNIST's 28x28 resolution.
screen = pygame.display.set_mode([280, 280])
pygame.display.set_caption('Draw a digit')

# Set up the brush: white ink on a black canvas, the same polarity as MNIST.
brush_size = 8  # circle radius in canvas pixels
brush_color = (255, 255, 255)

# Run until the user closes the window
running = True
drawing = False
clock = pygame.time.Clock()

# Create a surface for drawing
draw_surface = pygame.Surface((280, 280))
draw_surface.fill((0, 0, 0))

while running:
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            running = False
        elif event.type == pygame.MOUSEBUTTONDOWN:
            drawing = True
            # Paint on the initial press too, so a click without motion leaves a dot.
            pygame.draw.circle(draw_surface, brush_color, event.pos, brush_size)
        elif event.type == pygame.MOUSEBUTTONUP:
            drawing = False
        elif event.type == pygame.MOUSEMOTION and drawing:
            pygame.draw.circle(draw_surface, brush_color, event.pos, brush_size)

    # Draw everything
    screen.blit(draw_surface, (0, 0))
    pygame.display.flip()
    clock.tick(60)  # cap the loop at 60 FPS instead of busy-spinning a CPU core

# Save the drawing and close the window now; inference does not need PyGame.
pygame.image.save(draw_surface, "drawn_digit.png")
pygame.quit()

# Load the drawing as a 3-channel uint8 image (OpenCV's BGR order). The ink is
# white on black, so all three channels are identical and BGR/RGB order is moot.
canvas_bgr = cv2.imread("drawn_digit.png", cv2.IMREAD_COLOR)
if canvas_bgr is None:
    raise FileNotFoundError("could not read drawn_digit.png")
canvas_h, canvas_w = canvas_bgr.shape[:2]

# Downsample to MNIST resolution. The canvas already has MNIST's white-on-black
# polarity, so no color inversion is applied.
digit_28 = cv2.resize(canvas_bgr, (28, 28), interpolation=cv2.INTER_AREA)

# Upscale like detectron2's test-time shortest-edge resize (square input, so both
# sides become the target size), so the model sees the digit at its usual scale.
test_size = min(cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MAX_SIZE_TEST)
model_input = cv2.resize(digit_28, (test_size, test_size), interpolation=cv2.INTER_LINEAR)

# Predict the digit using the trained model
model = trainer.model
model.eval()
with torch.no_grad():
    # Detectron2 models take a list of dicts holding a (C, H, W) float tensor in the
    # 0-255 range; pixel normalization happens inside the model. "height"/"width"
    # make the model rescale its output boxes to canvas coordinates.
    input_tensor = torch.as_tensor(model_input.astype("float32").transpose(2, 0, 1))
    predictions = model([{"image": input_tensor, "height": canvas_h, "width": canvas_w}])

# Process and display predictions on a copy of the full-size canvas.
SCORE_THRESHOLD = 0.5  # consider only high-confidence predictions
image = canvas_bgr.copy()
for prediction in predictions:
    instances = prediction["instances"].to("cpu")
    boxes = instances.pred_boxes.tensor.numpy()
    scores = instances.scores.numpy()
    classes = instances.pred_classes.numpy()

    for box, score, cls in zip(boxes, scores, classes):
        if score > SCORE_THRESHOLD:
            # cv2 drawing functions want plain Python ints for point coordinates.
            x1, y1, x2, y2 = (int(v) for v in box)
            cv2.rectangle(image, (x1, y1), (x2, y2), (255, 0, 0), 2)
            # Keep the label inside the window when the box touches the top edge.
            cv2.putText(image, str(int(cls)), (x1, max(y1 - 10, 20)),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.9, (255, 0, 0), 2)

cv2.imshow('Predicted Digits', image)
cv2.waitKey(0)
cv2.destroyAllWindows()
