# Train A Shape Classifier Model



In [None]:
import json
import os

# Dataset roots, relative to the notebook's working directory.
train_data_root, test_data_root = "../datasets/train", "../datasets/test"

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import os

# Use the GPU when one is available; batches and the model are moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Preprocessing shared by both datasets and by load_image(): single channel,
# 64x64, then Normalize maps ToTensor's [0, 1] range to [-1, 1] (so pixels
# darker than mid-grey become negative).
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),  # Convert to grayscale (black and white images)
    transforms.Resize((64, 64)),  # Resize images to 64x64 pixels
    transforms.ToTensor(),  # Convert the image to a tensor
    transforms.Normalize((0.5,), (0.5,))  # Normalize the images (mean=0.5, std=0.5 for grayscale)
])

# Load the datasets; ImageFolder derives labels from the class sub-folders under each root.
train_dataset = datasets.ImageFolder(root=train_data_root, transform=transform)
test_dataset = datasets.ImageFolder(root=test_data_root, transform=transform)

# ImageFolder assigns label indices separately for each root (sorted folder
# names), so the two splits only share a label space when they contain exactly
# the same class folders. Fail loudly instead of scoring against mismatched labels.
if test_dataset.classes != train_dataset.classes:
    raise ValueError(
        f"Train/test class folders differ: {train_dataset.classes} vs {test_dataset.classes}"
    )

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Check class names (optional)
print(f'Classes: {train_dataset.classes}')

# 2. Define a simple CNN model
class SimpleCNN(nn.Module):
    """Small CNN classifier for single-channel, square shape images.

    Architecture: two (3x3 conv -> ReLU -> 2x2 max-pool) stages with 16 and 32
    channels, then a 128-unit hidden layer and a linear layer that produces one
    logit per class.

    Args:
        num_classes (int): Number of output logits. Defaults to 3 for backward
            compatibility; pass ``len(train_dataset.classes)`` so the head
            matches the dataset (e.g. when a "diamond" class is present).
        image_size (int): Side length of the square input images. Defaults to
            64, matching the ``Resize((64, 64))`` preprocessing; the flattened
            feature size feeding ``fc1`` is derived from it.
    """

    def __init__(self, num_classes=3, image_size=64):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        # padding=1 keeps the spatial size through each 3x3 conv; each of the
        # two 2x2 max-pools halves it (rounding down), giving image_size // 4.
        pooled_size = image_size // 4
        self.fc1 = nn.Linear(32 * pooled_size * pooled_size, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Map a (batch, 1, image_size, image_size) tensor to (batch, num_classes) logits."""
        x = F.relu(self.conv1(x))   # First Conv Layer
        x = F.max_pool2d(x, 2)      # Max Pooling
        x = F.relu(self.conv2(x))   # Second Conv Layer
        x = F.max_pool2d(x, 2)      # Max Pooling
        x = x.view(x.size(0), -1)   # Flatten
        x = F.relu(self.fc1(x))     # Fully Connected Layer 1
        x = self.fc2(x)             # Fully Connected Layer 2 (output logits)
        return x

# Size the classifier head from the dataset's class folders rather than trusting
# a hard-coded output count: if the dataset has more classes than the head has
# outputs, CrossEntropyLoss raises on the out-of-range labels.
num_classes = len(train_dataset.classes)
model = SimpleCNN()
if model.fc2.out_features != num_classes:
    model.fc2 = nn.Linear(model.fc2.in_features, num_classes)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
# Created after the head is finalised so the optimiser tracks the final parameters.
optimizer = optim.Adam(model.parameters(), lr=0.001)

def train_model(model, train_loader, criterion, optimizer, epochs=10, device=None):
    """Train ``model`` in place and print the loss/accuracy after every epoch.

    Args:
        model (nn.Module): The network to train; put into training mode.
        train_loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer (torch.optim.Optimizer): Optimiser over ``model``'s parameters.
        epochs (int): Number of passes over ``train_loader``.
        device (torch.device, optional): Device each batch is moved to.
            Defaults to the device of the model's first parameter (CPU if the
            model has no parameters).

    Returns:
        list[tuple[float, float]]: Per epoch, the mean batch loss and the
        training accuracy in percent. An epoch with no data reports (0.0, 0.0).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    history = []
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        num_batches = 0
        correct = 0
        total = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            # Statistics
            running_loss += loss.item()
            num_batches += 1
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        # Guard the averages so an empty loader doesn't raise ZeroDivisionError.
        avg_loss = running_loss / num_batches if num_batches else 0.0
        accuracy = 100 * correct / total if total else 0.0
        history.append((avg_loss, accuracy))
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%')

    return history


# Train

In [None]:
# Train for 15 epochs; train_model prints the mean loss and accuracy after each one.
train_model(model, train_loader, criterion, optimizer, epochs=15)

# Test 

In [None]:
def test(model, test_loader, device=None):
    """Print the macro-averaged Precision, Recall and F1-score for the trained model.

    Each class is scored one-vs-rest (precision, recall, F1). The per-class
    scores are then averaged with equal weight over every class that appears
    in either the true or the predicted labels. A score whose denominator is
    zero counts as 0.

    Args:
        model (nn.Module): The trained classifier; switched to eval mode.
        test_loader: Iterable of ``(images, labels)`` batches.
        device (torch.device, optional): Device the images are moved to.
            Defaults to the device of the model's first parameter (CPU if the
            model has no parameters).

    Returns:
        tuple[float, float, float]: ``(precision, recall, f1)``; all 0.0 when
        the loader yields no data.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    y_predicted = []
    y_true = []

    # eval() switches layers such as dropout/batch-norm to inference behaviour
    # (it does not freeze weights); no_grad() stops autograd from recording.
    model.eval()
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images.to(device))  # Raw class scores
            _, predictions = torch.max(outputs, 1)
            y_predicted.append(predictions.cpu())
            y_true.append(labels.cpu())

    empty = torch.empty(0, dtype=torch.long)
    y_pred_tensor = torch.cat(y_predicted) if y_predicted else empty
    y_true_tensor = torch.cat(y_true) if y_true else empty

    # Score every class one-vs-rest. Treating only label 1 as "positive" and
    # label 0 as "negative" would ignore all other classes of this multiclass task.
    per_class = []
    for cls in torch.unique(torch.cat([y_true_tensor, y_pred_tensor])).tolist():
        predicted_cls = y_pred_tensor == cls
        actual_cls = y_true_tensor == cls
        tp = (predicted_cls & actual_cls).sum().item()
        fp = (predicted_cls & ~actual_cls).sum().item()
        fn = (~predicted_cls & actual_cls).sum().item()

        cls_precision = tp / (tp + fp) if tp + fp > 0 else 0.0
        cls_recall = tp / (tp + fn) if tp + fn > 0 else 0.0
        cls_f1 = (2 * cls_precision * cls_recall / (cls_precision + cls_recall)
                  if cls_precision + cls_recall > 0 else 0.0)
        per_class.append((cls_precision, cls_recall, cls_f1))

    if per_class:
        precision, recall, f1 = (sum(scores) / len(per_class) for scores in zip(*per_class))
    else:
        precision = recall = f1 = 0.0

    print(f'Precision: {precision}')
    print(f'Recall: {recall}')
    print(f'F1 Score: {f1}')

    return precision, recall, f1


# The name `test` matches pytest's default test-function pattern; mark it so
# pytest never tries to collect it (it needs a model and a loader, not fixtures).
test.__test__ = False

In [None]:
# Evaluate the trained model on the test split and print its precision, recall and F1-score.
test(model, test_loader)

# Show Predictions


### Load image

In [None]:
from PIL import Image

# Source: https://www.geeksforgeeks.org/converting-an-image-to-a-torch-tensor-in-python/
import torchvision.transforms as transforms

def load_image(path):
    """Load an image file as a single-item batch ready to feed to the model.

    The image goes through the same ``transform`` pipeline as the datasets and
    is moved to ``device``.

    Args:
        path (str): Path of the image file.

    Returns:
        Tensor: The transformed image with an extra leading batch dimension of size 1.
    """
    # Open with a context manager so the file handle is released as soon as the
    # transform has read the pixels, instead of whenever the object is collected.
    with Image.open(path) as img:
        image_tensor = transform(img).to(device)
    # The model expects a batch, so wrap the single image in a batch of one.
    return torch.unsqueeze(image_tensor, 0)

### Classify shape

In [None]:
def predict(model, image_tensor, class_names=None):
    """Predict the class of a given image.

    Args:
        model (nn.Module): The CNN Network.
        image_tensor (Tensor): The image to predict, with a leading batch
            dimension (as produced by ``load_image``). Only the first item of
            the batch is classified.
        class_names (Sequence[str], optional): Class names indexed by label.
            Defaults to ``train_dataset.classes``.

    Returns:
        pred_class_name (str): The name of the predicted class for the image.
        confidence (float): Softmax probability of the predicted class (0-1).
    """
    if class_names is None:
        class_names = train_dataset.classes

    # Inference only: don't build an autograd graph for this forward pass.
    with torch.no_grad():
        logits = model(image_tensor)

    # Predict the image's class
    _, pred_class = torch.max(logits, 1)
    pred_index = int(pred_class[0])
    pred_class_name = class_names[pred_index]

    # Confidence = softmax probability of the chosen class (class scores sum to 1).
    confidence = torch.softmax(logits, dim=1)[0, pred_index].item()

    return pred_class_name, confidence

### Get bounding box for shape

In [None]:
# Load one sample image from the eval folder to try the bounding-box helper on.
image_tensor = load_image("../datasets/eval/diamond_21.png")

def get_shape_bounding_box(image_tensor, threshold=0.0):
    """
    Given a greyscale image of a shape, get the bounding box for that shape.

    A pixel belongs to the shape when its value is below ``threshold``. After
    the notebook's Normalize((0.5,), (0.5,)) step, pixel values lie in
    [-1, 1], so the default of 0 selects pixels darker than mid-grey.

    Args:
        image_tensor (Tensor): The image. Singleton dimensions (e.g. batch and
            channel) are squeezed away, leaving a 2-D (rows, columns) grid.
        threshold (float): Values strictly below this count as shape pixels.

    Returns:
        rect (Tuple): The bounding box rectangle (X, Y, width, height), where
            width/height are the differences between the last and first dark
            pixel index (a single dark pixel gives 0). Returns (0, 0, 0, 0)
            when the image has no dark pixels.
    """
    dark = image_tensor.squeeze() < threshold

    # Indices of the rows (Y) and columns (X) that contain any dark pixel.
    dark_rows = torch.nonzero(dark.any(dim=1)).flatten()
    dark_cols = torch.nonzero(dark.any(dim=0)).flatten()

    if dark_rows.numel() == 0:
        # Blank image: no shape to enclose.
        return (0, 0, 0, 0)

    x_min, x_max = int(dark_cols[0]), int(dark_cols[-1])
    y_min, y_max = int(dark_rows[0]), int(dark_rows[-1])

    width, height = x_max - x_min, y_max - y_min
    return (x_min, y_min, width, height)

# Bounding box of the sample loaded above; as the cell's last expression, Jupyter displays the returned tuple.
get_shape_bounding_box(image_tensor)

In [None]:
# Preview where a label offset of (+10, +10) from the sample's bounding-box
# top-left corner would land. Plain ints are used because numpy is only
# imported in a later cell and there is no global `start` in scope here.
box_x, box_y = get_shape_bounding_box(image_tensor)[:2]
print(box_x + 10, box_y + 10)

In [None]:
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle, Ellipse, Polygon

def display_prediction(image_tensor, bounding_box, predicted_shape, confidence):
    img_preview = image_tensor.cpu().numpy().squeeze()

    # Display the image
    plt.imshow(img_preview, cmap='grey')

    # Get the bounding box for the predicted shape
    rect = bounding_box
    start = [rect[0], rect[1]]
    width, height = rect[2], rect[3]

    # Draw the predicted shape

    ax = plt.gca()

    match predicted_shape:
        case "rectangle":
            # https://stackoverflow.com/questions/37435369/how-to-draw-a-rectangle-on-image
            shape = Rectangle( start, width, height, linewidth = 2, edgecolor = 'r', facecolor = 'none')
        
        case "circle":
            # https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.patches.Ellipse.html
            shape = Ellipse( [start[0] + width * 0.5, start[1] + height * 0.5], width, height, linewidth = 2, edgecolor = 'r', facecolor = 'none')
            
        case "triangle":
            points = [
                [int(start[0] + width * 0.5), start[1]],
                [start[0], start[1] + height],
                [start[0] + width, start[1] + height]
            ]

            shape = Polygon(points, linewidth = 2, edgecolor = 'r', facecolor = 'none')
            
        case "diamond":
            points = [
                [int(start[0] + width * 0.5), start[1]],
                [start[0], int(start[1] + height * 0.5)],
                [int(start[0] + width * 0.5), start[1] + height],
                [start[0] + width, int(start[1] + height * 0.5)]
            ]

            shape = Polygon(points, linewidth = 2, edgecolor = 'r', facecolor = 'none')

    ax.add_patch(shape)

    plt.text(*(np.array(start) + [-5,-6]), predicted_shape.title(), color='r', fontsize=12)
    plt.text(*(np.array(start) + [-5,-2]), f"Confidence: {(confidence * 100):.0f}%".title(), color='r', fontsize=12)


In [None]:
def show_prediction(model, image):
    """Classify an image file and display it with the prediction drawn on top.

    Loads the file at ``image`` and classifies it with ``model``. It then
    finds the shape's bounding box and shows the image with the predicted
    shape and confidence overlaid.

    Args:
        model (nn.Module): The trained classifier.
        image (str): Path of the image file.
    """
    tensor = load_image(image)
    shape_name, shape_confidence = predict(model, tensor)
    display_prediction(tensor, get_shape_bounding_box(tensor), shape_name, shape_confidence)
    
    
# Demo: classify one test image and draw the predicted shape and confidence over it.
show_prediction(model, "../datasets/test/diamond_21.png")