In [1]:
from abc import ABC, abstractmethod

import numpy as np
from sklearn.ensemble import RandomForestClassifier

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
from torch.utils.data import DataLoader

In [2]:
class DigitClassificationInterface(ABC):
    # This is an abstract base class (ABC) used to define an interface for digit classification models.
    # It ensures that all subclasses implement the specified methods, providing a consistent approach to training and prediction.

    @abstractmethod
    def train(self, x):
        # Abstract method that should be implemented by subclasses to train the model.
        # The method should take 'x' as input, which could be data used for training the model.
        # This method is expected to handle all the training logic for a digit classification model.
        pass
    
    @abstractmethod
    def predict(self, x):
        # Abstract method that should be implemented by subclasses to make predictions on the input data.
        # The method should take 'x' as input, which could be new, unseen data that the model will make predictions on.
        # This method should return the predicted output (the class labels for the input data).
        pass


class CNNClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        # Initialize the feature extractor part of the network
        # It includes convolutional layers with ReLU activation, batch normalization, and max pooling.
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(kernel_size=2),
            
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(kernel_size=2),
            
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.MaxPool2d(kernel_size=2),

            nn.Flatten()  # Flatten the output of convolutional layers to feed into fully connected layers
        )
        
        # Initialize the classifier part of the network
        # It consists of fully connected layers with ReLU activation and dropout for regularization.
        self.classifier = nn.Sequential(
            nn.Linear(128 * 3 * 3, 256),  # Calculate the correct input size
            nn.ReLU(),
            nn.Dropout(0.5),  # Dropout layer for preventing overfitting
            nn.Linear(256, 10)  # Output layer with 10 units for 10 classes
        )
    
    def forward(self, x):
        # Define the forward pass process, combining feature extractor and classifier
        x = self.feature_extractor(x)  # Process input through the feature extractor
        x = self.classifier(x)  # Process the result through the classifier
        return x

    def train_model(self, train_loader: DataLoader, epochs=1):
        # Train the model using stochastic gradient descent and cross entropy loss
        optimizer = optim.SGD(self.parameters(), lr=0.01, momentum=0.9)
        criterion = nn.CrossEntropyLoss()

        for epoch in range(epochs):
            self.train()  # Set the model to training mode
            total_loss = 0
            for data, target in train_loader:
                optimizer.zero_grad()  # Reset gradients to zero
                output = self(data)  # Compute output
                loss = criterion(output, target)  # Compute loss
                loss.backward()  # Backpropagate the error
                optimizer.step()  # Update weights
                total_loss += loss.item()  # Sum up loss for average calculation

            print(f'Epoch: {epoch+1}, Average Loss: {total_loss / len(train_loader):.6f}')

    def predict(self, x):
        # Predict method to evaluate an individual sample
        self.eval()  # Set the model to evaluation mode
        with torch.no_grad():  # Disable gradient computation
            x = x.view(-1, 1, 28, 28)  # Ensure the input tensor is correctly shaped
            output = self.forward(x)  # Get model output
            _, predicted_label = torch.max(output.data, 1)  # Find the predicted label
            return predicted_label.item()  # Return the predicted class label as an integer


class RFClassifier:
    def __init__(self):
        # Initialize the RandomForestClassifier here to use the same instance for training and predictions
        # This maintains the state of the model across different method calls within the class.
        self.random_forest_classifier = RandomForestClassifier(n_estimators=10, random_state=42)
    
    def train_model(self, trainloader):
        # Train the model using the provided training data.
        # trainloader: DataLoader containing training data
        X_train, y_train = self._prepare_data(trainloader)
        self.random_forest_classifier.fit(X_train, y_train)
    
    def predict(self, X):
        # Use the trained model to predict labels for new data.
        # X: new data features for which predictions are required, expected torch.Size([1, 1, 28, 28])
        # Convert PyTorch tensor to NumPy array, reshape for RandomForest, and predict.
        X = X.numpy().reshape(1, -1)  # Reshape from (1, 1, 28, 28) to (1, 784) for a single image
        predictions = self.random_forest_classifier.predict(X)
        return predictions[0]
    
    def _prepare_data(self, dataloader):
        # Convert data from DataLoader into a format suitable for RandomForest.
        X, y = [], []
        for data in dataloader:
            inputs, labels = data
            # Convert PyTorch tensors to NumPy arrays and reshape for use in RandomForest.
            inputs = inputs.numpy()
            inputs = inputs.reshape(inputs.shape[0], -1)  # Reshape from (B, 1, 28, 28) to (B, 784)
            X.append(inputs)
            y.append(labels.numpy())
        
        # Concatenate all batch data into one array
        X = np.concatenate(X, axis=0)
        y = np.concatenate(y, axis=0)
        return X, y

class RandClassifier(DigitClassificationInterface):
    def train(self, x):
        # This method is implemented to comply with the interface, but it does nothing.
        # RandomModelClassifier does not require training since it generates random predictions.
        pass
    
    def predict(self, x):
        # Takes an input tensor representing an image (typically MNIST) of size [1, 1, 28, 28] and returns a random number between 0 and 9.
        # x: PyTorch tensor representing the image.
        
        # Create a center crop operation to reduce the image size to 10x10 pixels.
        center_crop = transforms.CenterCrop(10)
        
        # Apply the center crop to the input image.
        cropped_image = center_crop(x)
        
        # Return a random number between 0 and 9 using numpy's randint function.
        return np.random.randint(0, 10)

class DigitClassifier:
    def __init__(self, model_name: str):
        # Constructor for the DigitClassifier class.
        # It initializes a classifier based on the provided model name.
        # model_name: A string indicating which type of model to instantiate.
        
        # Internally calls the _get_model method to create an instance of the specified model.
        self.model = self._get_model(model_name)

    def _get_model(self, model_name: str) -> DigitClassificationInterface:
        # Private method to retrieve an instance of a specific classification model.
        # model_name: A string that specifies the model to instantiate.
        # Returns an instance of a class that implements the DigitClassificationInterface.

        # Checks the model_name and returns the appropriate classifier.
        if model_name == "cnn":
            return CNNClassifier()
        elif model_name == "rf":
            return RFClassifier()
        elif model_name == "rand":
            return RandClassifier()
        else:
            # If the provided model name is not recognized, raise an error.
            raise ValueError(f"Unknown model name: {model_name}")

    def predict(self, x):
        # Method to predict the class of the input data using the initialized model.
        # x: Input data for which the predictions need to be made.
        # Returns the prediction results from the model's predict method.
        
        return self.model.predict(x)

    def train(self, x):
        # Method to train the model on the input data.
        # x: Training data on which the model will be trained.
        # Calls the train_model method of the model to perform the training process.
        
        return self.model.train_model(x)

In [3]:
#Define a series of transformations to apply to the MNIST images.
transform = transforms.Compose([
    transforms.ToTensor(),  # Convert image to PyTorch tensor.
    transforms.Normalize((0.5,), (0.5,))  # Normalize the tensor with mean and std deviation.
])

# Load the MNIST training dataset with the defined transformations.
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# Create a DataLoader for the training set, with a batch size of 100 and shuffling enabled.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=100, shuffle=True)

# Load the MNIST test dataset similar to the training dataset but with shuffling disabled.
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# Create a DataLoader for the test set, with a batch size of 1 for individual image processing.
testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False)

In [4]:
# The task required creating an interface with the same structure for any model listed. 
# For this reason, all models take a tensor from torchvision.datasets.MNIST as input.

In [5]:
model = DigitClassifier('cnn')
model.train(trainloader)

Epoch: 1, Average Loss: 0.131169


In [6]:
print('Predicted label:', model.predict(testset[0][0]))
print('Real label:', testset[0][1])

Predicted label: 7
Real label: 7


In [7]:
model = DigitClassifier('rf')
model.train(trainloader)

In [8]:
print('Predicted label:', model.predict(testset[0][0]))
print('Real label:', testset[0][1])

Predicted label: 7
Real label: 7


In [9]:
model = DigitClassifier('rand')

In [10]:
print('Predicted label:', model.predict(testset[0][0]))
print('Real label:', testset[0][1])

Predicted label: 7
Real label: 7
