In [None]:
import Nawah
import Nawah.nn as nn
import Nawah.nn.functional as F
import Nawah.optim as optim
import numpy as np
from sklearn.datasets import fetch_openml

In [None]:
class MNISTDataset(Nawah.data.Dataset):
    """MNIST digits served as ``(1, 28, 28)`` float32 images with integer labels.

    The data is fetched from OpenML (``mnist_784``, version 1). Pixel values
    are divided by 255.0 and stored as float32; labels are cast to ``int``.
    The first 60,000 rows form the training split and the remaining rows form
    the test split.

    Args:
        train: If True, expose the first 60,000 samples; otherwise expose the
            rows after them.
        transform: Optional callable applied to each reshaped image inside
            ``__getitem__``; its return value replaces the image.
    """

    # Number of leading rows that make up the training split.
    _TRAIN_SIZE = 60000

    # (images, labels) for the full dataset, shared by all instances so that
    # building both the train and the test split fetches and parses it once.
    _cache = None

    @classmethod
    def _load(cls):
        """Fetch and preprocess the full dataset once; later calls reuse it."""
        if MNISTDataset._cache is None:
            print("Fetching MNIST dataset...")
            mnist = fetch_openml('mnist_784', version=1, as_frame=False, parser='auto')
            print("Dataset fetched.")

            # Cast to float32 *before* scaling and scale in place, so no
            # float64 intermediate of the whole dataset is materialised.
            images = mnist.data.astype(np.float32)
            images /= 255.0

            labels = mnist.target.astype(int)

            MNISTDataset._cache = (images, labels)
        return MNISTDataset._cache

    def __init__(self, train=True, transform=None):
        self.transform = transform

        images, labels = self._load()

        if train:
            split = slice(None, self._TRAIN_SIZE)
        else:
            split = slice(self._TRAIN_SIZE, None)

        self.images = images[split]
        self.labels = labels[split]

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.images)

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``; the image has shape ``(1, 28, 28)``."""
        # Each stored row is a flat image; restore the channel/height/width layout.
        image = self.images[index].reshape(1, 28, 28)
        label = self.labels[index]

        # Compare against None explicitly so a callable that happens to be
        # falsy (e.g. one defining __len__) is still applied.
        if self.transform is not None:
            image = self.transform(image)

        return image, label
    

# Training split and held-out test split, selected via the `train` flag.
trainset = MNISTDataset(train=True)
testset = MNISTDataset(train=False)

In [None]:
# Single batch size shared by both loaders so it can be tuned in one place.
BATCH_SIZE = 64

# Training batches are drawn in shuffled order.
trainloader = Nawah.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)

# Evaluation gains nothing from shuffling; a fixed order keeps any pass over
# the test set reproducible from run to run.
testloader = Nawah.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
class LeNetModernized(nn.Module):
    """LeNet-style CNN: two strided conv + batch-norm stages, then three linear layers.

    ``forward`` views its input as ``(batch, 1, 28, 28)`` and returns the output
    of the final 10-unit linear layer.
    """

    def __init__(self):
        super().__init__()

        # Stage 1: 1 -> 6 channels, 5x5 kernel, stride 2, followed by batch norm.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=2)
        self.bn1 = nn.BatchNorm2d(num_features=6)

        # Stage 2: 6 -> 16 channels, 5x5 kernel, stride 2, followed by batch norm.
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=2)
        self.bn2 = nn.BatchNorm2d(num_features=16)

        # Width of the flattened conv output fed to fc1: 16 maps of 4x4.
        # NOTE(review): 4x4 assumes unpadded convolutions on a 28x28 input
        # (28 -> 12 -> 4) — confirm against Nawah's Conv2d defaults.
        channels, height, width = 16, 4, 4
        fc1_inputs = channels * height * width

        # Classifier head: 256 -> 120 -> 84 -> 10.
        self.fc1 = nn.Linear(in_dim=fc1_inputs, out_dim=120)
        self.fc2 = nn.Linear(in_dim=120, out_dim=84)
        self.output_layer = nn.Linear(in_dim=84, out_dim=10)

    # NOTE(review): each block below is wrapped by `F.relu` used as a
    # decorator; presumably this applies ReLU to the block's result — confirm
    # against Nawah's functional API.

    @F.relu
    def conv_block1(self, x):
        """Pipe ``x`` through ``conv1`` and then ``bn1``."""
        convolved = x >> self.conv1
        return convolved >> self.bn1

    @F.relu
    def conv_block2(self, x):
        """Pipe ``x`` through ``conv2`` and then ``bn2``."""
        convolved = x >> self.conv2
        return convolved >> self.bn2

    @F.relu
    def dense_block1(self, x):
        """Pipe ``x`` through ``fc1``."""
        projected = x >> self.fc1
        return projected

    @F.relu
    def dense_block2(self, x):
        """Pipe ``x`` through ``fc2``."""
        projected = x >> self.fc2
        return projected

    def forward(self, x: Nawah.Tensor) -> Nawah.Tensor:
        """Run the full network on ``x`` and return the output layer's result."""
        # Force the single-channel 28x28 layout regardless of how the batch arrived.
        batch = x.view(x.shape[0], 1, 28, 28)

        # Convolutional feature extractor, flattened for the linear layers.
        features = batch >> self.conv_block1 >> self.conv_block2 >> F.flatten

        # Two hidden dense blocks, then the final linear layer.
        hidden = features >> self.dense_block1 >> self.dense_block2
        logits = hidden >> self.output_layer

        return logits


# Instantiate the network and report the flattened feature width fc1 expects.
model = LeNetModernized()
print("Modernized LeNet model created successfully.")
fc1_input_size = 16 * 4 * 4
print("Input to fc1 will have size:", fc1_input_size)

In [None]:
# Training hyperparameters.
EPOCHS = 5
lr = 1e-2

# Adam over every model parameter, using the learning rate above.
optimizer = optim.Adam(model.parameters(), lr=lr)

# Pass the reduction mode by keyword so the intent is explicit and this
# construction matches the one in the training cell.
criterion = nn.BCEWithLogitLoss(reduction="sum")

In [None]:
def to_one_hot(labels: Nawah.Tensor, num_classes=10):
    """Encode integer class labels as one-hot rows.

    Args:
        labels: Tensor whose ``.data`` array holds class indices; the array is
            flattened and cast to ``int`` before use.
        num_classes: Width of each one-hot row.

    Returns:
        A ``Nawah.Tensor`` wrapping an array with one row per label and
        ``num_classes`` columns, holding 1 in the label's column and 0 elsewhere.
    """
    class_indices = labels.data.flatten().astype(int)

    # Row k of the identity matrix is the one-hot vector for class k, so
    # indexing it with the label array yields one encoded row per label.
    identity = np.eye(num_classes)
    encoded = identity[class_indices]

    return Nawah.Tensor(encoded)

In [None]:
from Nawah import Tensor

# Build the loss and optimizer in this cell so it can be re-run on its own.
criterion = nn.BCEWithLogitLoss(reduction="sum")
optimizer = optim.Adam(model.parameters(), lr=lr)

print("--- Starting Training ---")
for epoch in range(EPOCHS):
    model.train()

    # Per-epoch accumulators: summed loss, correct predictions, samples seen.
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    for batch_idx, (inputs, labels) in enumerate(trainloader):
        # Standard step: clear gradients, forward, loss, backward, update.
        optimizer.zero_grad()

        outputs = model(inputs)

        # The loss is computed against one-hot targets, one column per class.
        one_hot_labels = to_one_hot(labels, num_classes=10)

        loss = criterion(outputs, one_hot_labels)

        loss.backward()

        optimizer.step()

        running_loss += loss.data.item()

        # Sigmoid is strictly increasing, so the arg-max of the logits is the
        # arg-max of the probabilities. Working on the logits directly avoids
        # overflow in exp() for large-magnitude values and avoids false ties
        # when several probabilities saturate to exactly 1.0 (argmax would then
        # return the first of them rather than the largest logit).
        predicted_labels = np.argmax(outputs.data, axis=1)

        true_labels = labels.data.flatten()

        correct_predictions += int((predicted_labels == true_labels).sum())
        total_samples += len(true_labels)

        # Running averages over all samples seen so far in this epoch; the
        # loss uses "sum" reduction, so dividing by the sample count gives a
        # per-sample figure.
        avg_loss = running_loss / total_samples
        avg_acc = correct_predictions / total_samples

        progress_string = (
            f"Epoch {epoch + 1}/{EPOCHS} | "
            f"Batch [{batch_idx + 1}/{len(trainloader)}] | "
            f"Loss: {avg_loss:.4f} | "
            f"Acc: {avg_acc:.2%}"
        )
        # '\r' returns to the start of the line so the next update overwrites it.
        print(progress_string + "  ", end='\r')

    # Move past the in-place progress line before printing the summary.
    print()

    final_epoch_loss = running_loss / total_samples
    final_epoch_acc = correct_predictions / total_samples
    print(f"End of Epoch {epoch + 1} Summary | Average Loss: {final_epoch_loss:.4f} | Accuracy: {final_epoch_acc:.2%}")

print("\n--- Training Finished ---")

In [None]:
!pip install matplotlib
import matplotlib.pyplot as plt

CLASS_NAMES = [str(i) for i in range(10)]

model.eval()

print("--- Making predictions on random test images ---")

num_images_to_show = 5
plt.figure(figsize=(15, 4))

for i in range(num_images_to_show):
    random_index = np.random.randint(0, len(testset))
    
    image_data, true_label_index = testset[random_index]
    true_label_name = CLASS_NAMES[true_label_index]
    
    image_tensor = Nawah.Tensor(image_data)

    input_tensor = image_tensor.unsqueeze(0)
    
    output_logits = model(input_tensor)
    
    probabilities = 1 / (1 + np.exp(-output_logits.data))
    predicted_index = np.argmax(probabilities)
    predicted_name = CLASS_NAMES[predicted_index]
    
    image_to_display = image_data.squeeze()
    
    ax = plt.subplot(1, num_images_to_show, i + 1)
    plt.imshow(image_to_display, cmap='gray')
    
    title_color = 'green' if predicted_name == true_label_name else 'red'
    ax.set_title(f"Predicted: {predicted_name}\nTrue: {true_label_name}", color=title_color)
    
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

plt.show()
