In [26]:
import torch
# Define device first
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU device: {torch.cuda.get_device_name()}")

PyTorch version: 2.5.1+cu118
CUDA available: True
GPU device: NVIDIA GeForce RTX 4080 SUPER


In [27]:
# Enter your kaggle user name and key
username = str('')
key = str('')

In [28]:
# All imports
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import shutil  # Add this import
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from datetime import datetime
import time
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
from datetime import datetime
from torch.utils.data import random_split

In [29]:
# Download dataset from Kaggle
import opendatasets as od
import pandas
import json

# Create .kaggle directory if it doesn't exist
os.makedirs(os.path.expanduser('~/.kaggle'), exist_ok=True)

try:
    kaggle_token = {
        "username": username,
        "key": key
    }
    
    # Verify the data is string
    print("\nVerifying credentials format:")
    print(f"Username type: {type(username)}")
    print(f"Key type: {type(key)}")
    
    with open(os.path.expanduser('~/.kaggle/kaggle.json'), 'w') as f:
        json.dump(kaggle_token, f)
    
    print("Kaggle credentials saved successfully")
except Exception as e:
    print(f"Error saving Kaggle credentials: {e}")
    raise

# Set permissions
os.chmod(os.path.expanduser('~/.kaggle/kaggle.json'), 0o600)

# Download dataset
dataset_url = "https://www.kaggle.com/datasets/sagyamthapa/handwritten-math-symbols"
od.download(dataset_url)


Verifying credentials format:
Username type: <class 'str'>
Key type: <class 'str'>
Kaggle credentials saved successfully
Skipping, found downloaded files in ".\handwritten-math-symbols" (use force=True to force download)


In [30]:
# Extract dataset from archive with more detailed logging
print("\nStarting extraction process...")
dataset_path = "./handwritten-math-symbols"
archive_path = f"{dataset_path}.zip"
print(f"Looking for archive at: {archive_path}")
print(f"Archive exists: {os.path.exists(archive_path)}")

if os.path.exists(archive_path):
    import zipfile
    with zipfile.ZipFile(archive_path, 'r') as zip_ref:
        # List contents of zip file
        print("\nContents of zip file:")
        for file in zip_ref.namelist()[:10]:  # Show first 10 files
            print(f"- {file}")
        print("...")
        
        zip_ref.extractall("./")
    print("Dataset extracted successfully")

def verify_directory_structure():
    """Verify and print the directory structure"""
    print("\nVerifying directory structure:")
    for root, dirs, files in os.walk('./data'):
        print(f"\nDirectory: {root}")
        print(f"Subdirectories: {dirs}")
        print(f"Number of files: {len(files)}")

# First organize the data
print("\nStarting file organization...")
source_dir = "./handwritten-math-symbols/dataset"
print(f"Looking for source directory at: {source_dir}")
print(f"Source directory exists: {os.path.exists(source_dir)}")

if not os.path.exists(source_dir):
    raise FileNotFoundError(f"Source directory {source_dir} not found!")

# Create destination directories
os.makedirs('./data/digits', exist_ok=True)
os.makedirs('./data/operators', exist_ok=True)

# Define which folders belong to digits and operators
digit_folders = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
operator_folders = ['add', 'sub', 'mul', 'div', 'eq', 'dec', 'x', 'y', 'z']

# Move digit folders
print("\nMoving digit folders...")
for folder in digit_folders:
    src = os.path.join(source_dir, folder)
    dst = os.path.join('./data/digits', folder)
    if os.path.exists(src):
        print(f"Copying {src} to {dst}")
        shutil.copytree(src, dst, dirs_exist_ok=True)
    else:
        print(f"Warning: Source folder not found: {src}")

# Move operator folders
print("\nMoving operator folders...")
for folder in operator_folders:
    src = os.path.join(source_dir, folder)
    dst = os.path.join('./data/operators', folder)
    if os.path.exists(src):
        print(f"Copying {src} to {dst}")
        shutil.copytree(src, dst, dirs_exist_ok=True)
    else:
        print(f"Warning: Source folder not found: {src}")

# Verify the directory structure
verify_directory_structure()

# Verify that we have data in the folders
print("\nVerifying data in folders:")
digit_path = './data/digits'
operator_path = './data/operators'

if not os.path.exists(digit_path) or not os.path.exists(operator_path):
    raise FileNotFoundError("Data directories not created properly!")

digit_classes = sorted(os.listdir(digit_path))
operator_classes = sorted(os.listdir(operator_path))

print(f"\nFound digit classes: {digit_classes}")
print(f"Found operator classes: {operator_classes}")

if not digit_classes:
    raise FileNotFoundError("No digit classes found!")
if not operator_classes:
    raise FileNotFoundError("No operator classes found!")



Starting extraction process...
Looking for archive at: ./handwritten-math-symbols.zip
Archive exists: False

Starting file organization...
Looking for source directory at: ./handwritten-math-symbols/dataset
Source directory exists: True

Moving digit folders...
Copying ./handwritten-math-symbols/dataset\0 to ./data/digits\0
Copying ./handwritten-math-symbols/dataset\1 to ./data/digits\1
Copying ./handwritten-math-symbols/dataset\2 to ./data/digits\2
Copying ./handwritten-math-symbols/dataset\3 to ./data/digits\3
Copying ./handwritten-math-symbols/dataset\4 to ./data/digits\4
Copying ./handwritten-math-symbols/dataset\5 to ./data/digits\5
Copying ./handwritten-math-symbols/dataset\6 to ./data/digits\6
Copying ./handwritten-math-symbols/dataset\7 to ./data/digits\7
Copying ./handwritten-math-symbols/dataset\8 to ./data/digits\8
Copying ./handwritten-math-symbols/dataset\9 to ./data/digits\9

Moving operator folders...
Copying ./handwritten-math-symbols/dataset\add to ./data/operators\ad

In [31]:
# Create transforms
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Create datasets with error handling
try:
    digit_dataset = torchvision.datasets.ImageFolder(
        root='./data/digits',
        transform=transform
    )
    print(f"\nDigit dataset created successfully with {len(digit_dataset)} images")
except Exception as e:
    print(f"Error creating digit dataset: {e}")
    raise

try:
    operator_dataset = torchvision.datasets.ImageFolder(
        root='./data/operators',
        transform=transform
    )
    print(f"Operator dataset created successfully with {len(operator_dataset)} images")
except Exception as e:
    print(f"Error creating operator dataset: {e}")
    raise

# Create dataset splits
print("\nCreating dataset splits...")

# For Digit Dataset
total_digit_size = len(digit_dataset)
train_digit_size = int(0.8 * total_digit_size)
val_digit_size = int(0.1 * total_digit_size)
test_digit_size = total_digit_size - train_digit_size - val_digit_size

digit_train_dataset, digit_val_dataset, digit_test_dataset = random_split(
    digit_dataset, 
    [train_digit_size, val_digit_size, test_digit_size]
)

# For Operator Dataset
total_operator_size = len(operator_dataset)
train_operator_size = int(0.8 * total_operator_size)
val_operator_size = int(0.1 * total_operator_size)
test_operator_size = total_operator_size - train_operator_size - val_operator_size

operator_train_dataset, operator_val_dataset, operator_test_dataset = random_split(
    operator_dataset, 
    [train_operator_size, val_operator_size, test_operator_size]
)

# Update the existing dataloaders to use training split
digit_trainloader = torch.utils.data.DataLoader(
    digit_train_dataset,
    batch_size=32,
    shuffle=True, 
    num_workers=2
    )

operator_trainloader = torch.utils.data.DataLoader(
    operator_train_dataset, 
    batch_size=32,
    shuffle=True, 
    num_workers=2
    )
print(f"Digit - Train: {len(digit_train_dataset)}, Val: {len(digit_val_dataset)}, Test: {len(digit_test_dataset)}")
print(f"Operator - Train: {len(operator_train_dataset)}, Val: {len(operator_val_dataset)}, Test: {len(operator_test_dataset)}")

# Create dataloaders
print("\nDatasets and dataloaders created successfully!")
print(f"Number of digit classes: {len(digit_dataset.classes)}")
print(f"Number of operator classes: {len(operator_dataset.classes)}")



Digit dataset created successfully with 5304 images
Operator dataset created successfully with 4767 images

Creating dataset splits...
Digit - Train: 4243, Val: 530, Test: 531
Operator - Train: 3813, Val: 476, Test: 478

Datasets and dataloaders created successfully!
Number of digit classes: 10
Number of operator classes: 9


In [32]:
# Define the classes based on the folders
digit_classes = sorted(os.listdir('./data/digits'))
operator_classes = sorted(os.listdir('./data/operators'))

# Define network architectures first
class DigitNet(nn.Module):
    def __init__(self):
        super(DigitNet, self).__init__()
        # Input: 1x32x32
        self.conv1 = nn.Conv2d(1, 6, 5)  # Output: 6x28x28
        self.pool = nn.MaxPool2d(2, 2)    # Output: 6x14x14
        self.conv2 = nn.Conv2d(6, 16, 5)  # Output: 16x10x10
        # After second pooling: 16x5x5
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, len(digit_classes))

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

class OperatorNet(nn.Module):
    def __init__(self):
        super(OperatorNet, self).__init__()
        # Input: 1x32x32
        self.conv1 = nn.Conv2d(1, 6, 5)  # Output: 6x28x28
        self.pool = nn.MaxPool2d(2, 2)    # Output: 6x14x14
        self.conv2 = nn.Conv2d(6, 16, 5)  # Output: 16x10x10
        # After second pooling: 16x5x5
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, len(operator_classes))

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
# Initialize networks on CPU first
digit_net = DigitNet()
operator_net = OperatorNet()

# Diagnostic code to check dimensions and classes
print("\nDiagnostic Information:")
print(f"Digit classes: {digit_classes}")
print(f"Number of digit classes: {len(digit_classes)}")
print(f"DigitNet output dimension: {digit_net.fc3.out_features}")

print(f"Operator classes: {operator_classes}")
print(f"Number of operator classes: {len(operator_classes)}")
print(f"OperatorNet output dimension: {operator_net.fc3.out_features}")

# Verify network structures before moving to GPU
print("\nNetwork Architecture Check:")
print("DigitNet:")
print(f"Input -> Conv1 (1->6) -> Pool -> Conv2 (6->16) -> Pool -> FC1 (400->120) -> FC2 (120->84) -> FC3 (84->{len(digit_classes)})")
print("\nOperatorNet:")
print(f"Input -> Conv1 (1->6) -> Pool -> Conv2 (6->16) -> Pool -> FC1 (400->120) -> FC2 (120->84) -> FC3 (84->{len(operator_classes)})")

# Try moving to GPU with error handling
try:
    print("\nMoving networks to GPU...")
    digit_net = digit_net.to(device)
    operator_net = operator_net.to(device)
    print("Successfully moved networks to GPU")
except RuntimeError as e:
    print(f"Error moving networks to GPU: {e}")
    print("Falling back to CPU")
    device = torch.device("cpu")
    digit_net = digit_net.to(device)
    operator_net = operator_net.to(device)

# Create optimizers after moving networks to device
digit_optimizer = optim.SGD(digit_net.parameters(), lr=0.001, momentum=0.9)
operator_optimizer = optim.SGD(operator_net.parameters(), lr=0.001, momentum=0.9)

# Now verify data and label shapes
print("\nVerifying data shapes:")
sample_digit_batch = next(iter(digit_trainloader))
sample_operator_batch = next(iter(operator_trainloader))

print(f"Digit batch - Images: {sample_digit_batch[0].shape}, Labels: {sample_digit_batch[1].shape}")
print(f"Operator batch - Images: {sample_operator_batch[0].shape}, Labels: {sample_operator_batch[1].shape}")

# Check label distributions
print("\nLabel distributions:")
print(f"Digit labels unique values: {torch.unique(sample_digit_batch[1])}")
print(f"Operator labels unique values: {torch.unique(sample_operator_batch[1])}")

# Verify network output dimensions match number of classes
with torch.no_grad():
    digit_out = digit_net(sample_digit_batch[0].to(device))
    operator_out = operator_net(sample_operator_batch[0].to(device))
    print("\nNetwork output dimensions:")
    print(f"DigitNet output: {digit_out.shape}")
    print(f"OperatorNet output: {operator_out.shape}")



Diagnostic Information:
Digit classes: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
Number of digit classes: 10
DigitNet output dimension: 10
Operator classes: ['add', 'dec', 'div', 'eq', 'mul', 'sub', 'x', 'y', 'z']
Number of operator classes: 9
OperatorNet output dimension: 9

Network Architecture Check:
DigitNet:
Input -> Conv1 (1->6) -> Pool -> Conv2 (6->16) -> Pool -> FC1 (400->120) -> FC2 (120->84) -> FC3 (84->10)

OperatorNet:
Input -> Conv1 (1->6) -> Pool -> Conv2 (6->16) -> Pool -> FC1 (400->120) -> FC2 (120->84) -> FC3 (84->9)

Moving networks to GPU...
Successfully moved networks to GPU

Verifying data shapes:
Digit batch - Images: torch.Size([32, 1, 32, 32]), Labels: torch.Size([32])
Operator batch - Images: torch.Size([32, 1, 32, 32]), Labels: torch.Size([32])

Label distributions:
Digit labels unique values: tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
Operator labels unique values: tensor([0, 1, 2, 3, 4, 5, 6, 7, 8])

Network output dimensions:
DigitNet output: torch.Si

In [33]:
# Training parameters
num_epochs = 50
eval_interval = 5
early_stopping_patience = 15
best_digit_accuracy = 0
best_operator_accuracy = 0
epochs_without_improvement = 0

# Get current date for logging
training_date = datetime.now().strftime("%Y-%m-%d")
start_time = time.time()

# Add learning rate scheduler
digit_scheduler = ReduceLROnPlateau(digit_optimizer, mode='max', factor=0.1, patience=5, verbose=True)
operator_scheduler = ReduceLROnPlateau(operator_optimizer, mode='max', factor=0.1, patience=5, verbose=True)

# Diagnostic code to check dimensions and classes
print("\nDiagnostic Information:")

# Check digit dataset
digit_classes = sorted(os.listdir('./data/digits'))
print(f"Digit classes: {digit_classes}")
print(f"Number of digit classes: {len(digit_classes)}")
print(f"DigitNet output dimension: {digit_net.fc3.out_features}")

# Check operator dataset
operator_classes = sorted(os.listdir('./data/operators'))
print(f"Operator classes: {operator_classes}")
print(f"Number of operator classes: {len(operator_classes)}")
print(f"OperatorNet output dimension: {operator_net.fc3.out_features}")

# Check a batch of data
digit_batch = next(iter(digit_trainloader))
operator_batch = next(iter(operator_trainloader))

print("\nBatch shapes:")
print(f"Digit batch - Images: {digit_batch[0].shape}, Labels: {digit_batch[1].shape}")
print(f"Operator batch - Images: {operator_batch[0].shape}, Labels: {operator_batch[1].shape}")

print("\nLabel ranges:")
print(f"Digit labels: {digit_batch[1].min().item()} to {digit_batch[1].max().item()}")
print(f"Operator labels: {operator_batch[1].min().item()} to {operator_batch[1].max().item()}")


# Add criterion definitions before training
digit_criterion = nn.CrossEntropyLoss()
operator_criterion = nn.CrossEntropyLoss()

# Add dataset class mapping
digit_dataset.class_to_idx  # Check the mapping of classes to indices
operator_dataset.class_to_idx  # Check the mapping of classes to indices

# Before training loop, add diagnostic prints
print("\nClass mappings:")
print(f"Digit classes: {digit_dataset.class_to_idx}")
print(f"Operator classes: {operator_dataset.class_to_idx}")

# Create weights directory if it doesn't exist
os.makedirs('./weights', exist_ok=True)

# Update paths for saving weights during training
best_digit_path = './weights/digit_net_best.pth'
best_operator_path = './weights/operator_net_best.pth'

# Training loop
for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    
    # Training digit network
    digit_net.train()
    digit_running_loss = 0.0
    epoch_start = time.time()
    
    # Add progress bar
    pbar = tqdm(digit_trainloader, desc=f"Training Digit Net")
    for i, data in enumerate(pbar):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        
        digit_optimizer.zero_grad()
        outputs = digit_net(inputs)
        loss = digit_criterion(outputs, labels)
        loss.backward()
        digit_optimizer.step()
        
        digit_running_loss += loss.item()
        if i % 100 == 99:
            pbar.set_postfix({'loss': digit_running_loss/100})
            digit_running_loss = 0.0
    
    # Training operator network
    operator_net.train()
    operator_running_loss = 0.0
    
    pbar = tqdm(operator_trainloader, desc=f"Training Operator Net")
    for i, data in enumerate(pbar):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        
        operator_optimizer.zero_grad()
        outputs = operator_net(inputs)
        loss = operator_criterion(outputs, labels)
        loss.backward()
        operator_optimizer.step()
        
        operator_running_loss += loss.item()
        if i % 100 == 99:
            pbar.set_postfix({'loss': operator_running_loss/100})
            operator_running_loss = 0.0
    
    # Evaluation every eval_interval epochs
    if epoch % eval_interval == 0:
        digit_net.eval()
        operator_net.eval()
        
         # Create validation dataloaders (only when needed)
        digit_valloader = DataLoader(digit_val_dataset, batch_size=32, shuffle=False)
        operator_valloader = DataLoader(operator_val_dataset, batch_size=32, shuffle=False)
        
        # Evaluate digit network on validation set
        correct = 0
        total = 0
        with torch.no_grad():
            for data in digit_valloader:  # Changed from digit_trainloader
                images, labels = data
                images, labels = images.to(device), labels.to(device)
                outputs = digit_net(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        
        current_digit_accuracy = 100 * correct / total
        print(f"\nDigit Network Validation Accuracy: {current_digit_accuracy:.2f}%")
        
        # Evaluate operator network on validation set
        correct = 0
        total = 0
        with torch.no_grad():
            for data in operator_valloader:  # Changed from operator_trainloader
                images, labels = data
                images, labels = images.to(device), labels.to(device)
                outputs = operator_net(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        
        current_operator_accuracy = 100 * correct / total
        print(f"Operator Network Validation Accuracy: {current_operator_accuracy:.2f}%")
        
        # Update schedulers
        digit_scheduler.step(current_digit_accuracy)
        operator_scheduler.step(current_operator_accuracy)
        
        # Check for improvement
        if current_digit_accuracy > best_digit_accuracy or current_operator_accuracy > best_operator_accuracy:
            best_digit_accuracy = max(best_digit_accuracy, current_digit_accuracy)
            best_operator_accuracy = max(best_operator_accuracy, current_operator_accuracy)
            torch.save(digit_net.state_dict(), best_digit_path)
            torch.save(operator_net.state_dict(), best_operator_path)
            epochs_without_improvement = 0
            print("New best accuracy! Saved models.")
        else:
            epochs_without_improvement += eval_interval
        
        # Early stopping check
        if epochs_without_improvement >= early_stopping_patience:
            print("\nEarly stopping triggered!")
            break
    
    # Print epoch summary
    epoch_time = time.time() - epoch_start
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print(f"Time taken: {epoch_time:.2f} seconds")
    print(f"Best Digit Accuracy: {best_digit_accuracy:.2f}%")
    print(f"Best Operator Accuracy: {best_operator_accuracy:.2f}%")

print('Finished Training')
# Save final weights
print("\nSaving final model weights...")
try:
    final_digit_path = './weights/digit_net_final.pth'
    final_operator_path = './weights/operator_net_final.pth'
    
    torch.save(digit_net.state_dict(), final_digit_path)
    torch.save(operator_net.state_dict(), final_operator_path)
    
    # Verify files were saved
    if os.path.exists(final_digit_path) and os.path.exists(final_operator_path):
        print(f"Final weights saved successfully:")
        print(f"- Digit network: {final_digit_path}")
        print(f"- Operator network: {final_operator_path}")
    else:
        print("Warning: Weight files not found after saving!")
except Exception as e:
    print(f"Error saving final weights: {e}")



if os.path.exists(best_digit_path) and os.path.exists(best_operator_path):
    print("\nBest weights were saved during training:")
    print(f"- Best digit network: {best_digit_path}")
    print(f"- Best operator network: {best_operator_path}")
else:
    print("\nWarning: Best weight files not found!")


Diagnostic Information:
Digit classes: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
Number of digit classes: 10
DigitNet output dimension: 10
Operator classes: ['add', 'dec', 'div', 'eq', 'mul', 'sub', 'x', 'y', 'z']
Number of operator classes: 9
OperatorNet output dimension: 9

Batch shapes:
Digit batch - Images: torch.Size([32, 1, 32, 32]), Labels: torch.Size([32])
Operator batch - Images: torch.Size([32, 1, 32, 32]), Labels: torch.Size([32])

Label ranges:
Digit labels: 0 to 9
Operator labels: 0 to 7

Class mappings:
Digit classes: {'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6, '7': 7, '8': 8, '9': 9}
Operator classes: {'add': 0, 'dec': 1, 'div': 2, 'eq': 3, 'mul': 4, 'sub': 5, 'x': 6, 'y': 7, 'z': 8}

Epoch 1/50


Training Digit Net: 100%|██████████| 133/133 [00:11<00:00, 11.25it/s, loss=2.3]
Training Operator Net: 100%|██████████| 120/120 [00:11<00:00, 10.71it/s, loss=2.19]



Digit Network Validation Accuracy: 10.94%
Operator Network Validation Accuracy: 10.92%
New best accuracy! Saved models.

Epoch 1/50
Time taken: 27.19 seconds
Best Digit Accuracy: 10.94%
Best Operator Accuracy: 10.92%

Epoch 2/50


Training Digit Net: 100%|██████████| 133/133 [00:05<00:00, 22.40it/s, loss=2.3]
Training Operator Net: 100%|██████████| 120/120 [00:06<00:00, 19.87it/s, loss=2.18]



Epoch 2/50
Time taken: 11.98 seconds
Best Digit Accuracy: 10.94%
Best Operator Accuracy: 10.92%

Epoch 3/50


Training Digit Net: 100%|██████████| 133/133 [00:05<00:00, 22.29it/s, loss=2.3]
Training Operator Net: 100%|██████████| 120/120 [00:05<00:00, 20.52it/s, loss=2.17]



Epoch 3/50
Time taken: 11.82 seconds
Best Digit Accuracy: 10.94%
Best Operator Accuracy: 10.92%

Epoch 4/50


Training Digit Net: 100%|██████████| 133/133 [00:05<00:00, 22.45it/s, loss=2.3]
Training Operator Net: 100%|██████████| 120/120 [00:05<00:00, 20.62it/s, loss=2.15]



Epoch 4/50
Time taken: 11.75 seconds
Best Digit Accuracy: 10.94%
Best Operator Accuracy: 10.92%

Epoch 5/50


Training Digit Net: 100%|██████████| 133/133 [00:05<00:00, 22.59it/s, loss=2.29]
Training Operator Net: 100%|██████████| 120/120 [00:05<00:00, 20.40it/s, loss=2.14]



Epoch 5/50
Time taken: 11.77 seconds
Best Digit Accuracy: 10.94%
Best Operator Accuracy: 10.92%

Epoch 6/50


Training Digit Net: 100%|██████████| 133/133 [00:06<00:00, 22.16it/s, loss=2.29]
Training Operator Net: 100%|██████████| 120/120 [00:05<00:00, 20.85it/s, loss=2.12]



Digit Network Validation Accuracy: 9.06%
Operator Network Validation Accuracy: 25.00%
New best accuracy! Saved models.

Epoch 6/50
Time taken: 12.52 seconds
Best Digit Accuracy: 10.94%
Best Operator Accuracy: 25.00%

Epoch 7/50


Training Digit Net: 100%|██████████| 133/133 [00:05<00:00, 22.45it/s, loss=2.28]
Training Operator Net: 100%|██████████| 120/120 [00:05<00:00, 20.84it/s, loss=2.09]



Epoch 7/50
Time taken: 11.69 seconds
Best Digit Accuracy: 10.94%
Best Operator Accuracy: 25.00%

Epoch 8/50


Training Digit Net: 100%|██████████| 133/133 [00:05<00:00, 22.52it/s, loss=2.27]
Training Operator Net: 100%|██████████| 120/120 [00:05<00:00, 20.64it/s, loss=2]



Epoch 8/50
Time taken: 11.72 seconds
Best Digit Accuracy: 10.94%
Best Operator Accuracy: 25.00%

Epoch 9/50


Training Digit Net: 100%|██████████| 133/133 [00:05<00:00, 22.55it/s, loss=2.26]
Training Operator Net: 100%|██████████| 120/120 [00:05<00:00, 20.61it/s, loss=1.77]



Epoch 9/50
Time taken: 11.72 seconds
Best Digit Accuracy: 10.94%
Best Operator Accuracy: 25.00%

Epoch 10/50


Training Digit Net: 100%|██████████| 133/133 [00:05<00:00, 22.32it/s, loss=2.25]
Training Operator Net: 100%|██████████| 120/120 [00:05<00:00, 20.79it/s, loss=1.37]



Epoch 10/50
Time taken: 11.73 seconds
Best Digit Accuracy: 10.94%
Best Operator Accuracy: 25.00%

Epoch 11/50


Training Digit Net: 100%|██████████| 133/133 [00:05<00:00, 22.29it/s, loss=2.22]
Training Operator Net: 100%|██████████| 120/120 [00:05<00:00, 20.68it/s, loss=1.13]



Digit Network Validation Accuracy: 20.57%
Operator Network Validation Accuracy: 67.44%
New best accuracy! Saved models.

Epoch 11/50
Time taken: 12.51 seconds
Best Digit Accuracy: 20.57%
Best Operator Accuracy: 67.44%

Epoch 12/50


Training Digit Net: 100%|██████████| 133/133 [00:06<00:00, 22.05it/s, loss=2.18]
Training Operator Net: 100%|██████████| 120/120 [00:05<00:00, 20.55it/s, loss=0.982]



Epoch 12/50
Time taken: 11.87 seconds
Best Digit Accuracy: 20.57%
Best Operator Accuracy: 67.44%

Epoch 13/50


Training Digit Net: 100%|██████████| 133/133 [00:05<00:00, 22.36it/s, loss=2.1]
Training Operator Net: 100%|██████████| 120/120 [00:05<00:00, 20.80it/s, loss=0.846]



Epoch 13/50
Time taken: 11.72 seconds
Best Digit Accuracy: 20.57%
Best Operator Accuracy: 67.44%

Epoch 14/50


Training Digit Net: 100%|██████████| 133/133 [00:05<00:00, 22.63it/s, loss=1.97]
Training Operator Net: 100%|██████████| 120/120 [00:05<00:00, 20.23it/s, loss=0.732]



Epoch 14/50
Time taken: 11.81 seconds
Best Digit Accuracy: 20.57%
Best Operator Accuracy: 67.44%

Epoch 15/50


Training Digit Net: 100%|██████████| 133/133 [00:05<00:00, 22.32it/s, loss=1.82]
Training Operator Net: 100%|██████████| 120/120 [00:05<00:00, 20.72it/s, loss=0.63]



Epoch 15/50
Time taken: 11.75 seconds
Best Digit Accuracy: 20.57%
Best Operator Accuracy: 67.44%

Epoch 16/50


Training Digit Net: 100%|██████████| 133/133 [00:05<00:00, 22.49it/s, loss=1.65]
Training Operator Net: 100%|██████████| 120/120 [00:05<00:00, 20.93it/s, loss=0.536]



Digit Network Validation Accuracy: 46.60%
Operator Network Validation Accuracy: 82.14%
New best accuracy! Saved models.

Epoch 16/50
Time taken: 12.36 seconds
Best Digit Accuracy: 46.60%
Best Operator Accuracy: 82.14%

Epoch 17/50


Training Digit Net: 100%|██████████| 133/133 [00:06<00:00, 22.03it/s, loss=1.51]
Training Operator Net: 100%|██████████| 120/120 [00:05<00:00, 20.43it/s, loss=0.477]



Epoch 17/50
Time taken: 11.91 seconds
Best Digit Accuracy: 46.60%
Best Operator Accuracy: 82.14%

Epoch 18/50


Training Digit Net: 100%|██████████| 133/133 [00:06<00:00, 22.05it/s, loss=1.33]
Training Operator Net: 100%|██████████| 120/120 [00:05<00:00, 20.72it/s, loss=0.429]



Epoch 18/50
Time taken: 11.83 seconds
Best Digit Accuracy: 46.60%
Best Operator Accuracy: 82.14%

Epoch 19/50


Training Digit Net: 100%|██████████| 133/133 [00:05<00:00, 22.43it/s, loss=1.17]
Training Operator Net: 100%|██████████| 120/120 [00:05<00:00, 20.67it/s, loss=0.403]



Epoch 19/50
Time taken: 11.74 seconds
Best Digit Accuracy: 46.60%
Best Operator Accuracy: 82.14%

Epoch 20/50


Training Digit Net: 100%|██████████| 133/133 [00:06<00:00, 21.74it/s, loss=1.07]
Training Operator Net: 100%|██████████| 120/120 [00:06<00:00, 19.20it/s, loss=0.342]



Epoch 20/50
Time taken: 12.37 seconds
Best Digit Accuracy: 46.60%
Best Operator Accuracy: 82.14%

Epoch 21/50


Training Digit Net: 100%|██████████| 133/133 [00:05<00:00, 22.38it/s, loss=0.943]
Training Operator Net: 100%|██████████| 120/120 [00:05<00:00, 20.75it/s, loss=0.309]



Digit Network Validation Accuracy: 68.11%
Operator Network Validation Accuracy: 90.55%
New best accuracy! Saved models.

Epoch 21/50
Time taken: 12.46 seconds
Best Digit Accuracy: 68.11%
Best Operator Accuracy: 90.55%

Epoch 22/50


Training Digit Net: 100%|██████████| 133/133 [00:05<00:00, 22.64it/s, loss=0.86]
Training Operator Net: 100%|██████████| 120/120 [00:05<00:00, 20.51it/s, loss=0.282]



Epoch 22/50
Time taken: 11.73 seconds
Best Digit Accuracy: 68.11%
Best Operator Accuracy: 90.55%

Epoch 23/50


Training Digit Net: 100%|██████████| 133/133 [00:05<00:00, 22.51it/s, loss=0.751]
Training Operator Net: 100%|██████████| 120/120 [00:05<00:00, 20.72it/s, loss=0.256]



Epoch 23/50
Time taken: 11.70 seconds
Best Digit Accuracy: 68.11%
Best Operator Accuracy: 90.55%

Epoch 24/50


Training Digit Net: 100%|██████████| 133/133 [00:06<00:00, 22.14it/s, loss=0.674]
Training Operator Net: 100%|██████████| 120/120 [00:05<00:00, 20.67it/s, loss=0.254]



Epoch 24/50
Time taken: 11.81 seconds
Best Digit Accuracy: 68.11%
Best Operator Accuracy: 90.55%

Epoch 25/50


Training Digit Net: 100%|██████████| 133/133 [00:05<00:00, 22.59it/s, loss=0.621]
Training Operator Net: 100%|██████████| 120/120 [00:05<00:00, 21.14it/s, loss=0.222]



Epoch 25/50
Time taken: 11.57 seconds
Best Digit Accuracy: 68.11%
Best Operator Accuracy: 90.55%

Epoch 26/50


Training Digit Net: 100%|██████████| 133/133 [00:05<00:00, 22.47it/s, loss=0.578]
Training Operator Net: 100%|██████████| 120/120 [00:05<00:00, 20.96it/s, loss=0.199]



Digit Network Validation Accuracy: 80.00%
Operator Network Validation Accuracy: 91.60%
New best accuracy! Saved models.

Epoch 26/50
Time taken: 12.38 seconds
Best Digit Accuracy: 80.00%
Best Operator Accuracy: 91.60%

Epoch 27/50


Training Digit Net: 100%|██████████| 133/133 [00:06<00:00, 20.97it/s, loss=0.505]
Training Operator Net: 100%|██████████| 120/120 [00:06<00:00, 19.74it/s, loss=0.179]



Epoch 27/50
Time taken: 12.42 seconds
Best Digit Accuracy: 80.00%
Best Operator Accuracy: 91.60%

Epoch 28/50


Training Digit Net: 100%|██████████| 133/133 [00:06<00:00, 21.92it/s, loss=0.45]
Training Operator Net: 100%|██████████| 120/120 [00:05<00:00, 20.60it/s, loss=0.167]



Epoch 28/50
Time taken: 11.90 seconds
Best Digit Accuracy: 80.00%
Best Operator Accuracy: 91.60%

Epoch 29/50


Training Digit Net: 100%|██████████| 133/133 [00:05<00:00, 22.45it/s, loss=0.424]
Training Operator Net: 100%|██████████| 120/120 [00:05<00:00, 20.27it/s, loss=0.165]



Epoch 29/50
Time taken: 11.85 seconds
Best Digit Accuracy: 80.00%
Best Operator Accuracy: 91.60%

Epoch 30/50


Training Digit Net: 100%|██████████| 133/133 [00:06<00:00, 20.39it/s, loss=0.385]
Training Operator Net: 100%|██████████| 120/120 [00:05<00:00, 20.50it/s, loss=0.14]



Epoch 30/50
Time taken: 12.38 seconds
Best Digit Accuracy: 80.00%
Best Operator Accuracy: 91.60%

Epoch 31/50


Training Digit Net: 100%|██████████| 133/133 [00:05<00:00, 22.54it/s, loss=0.356]
Training Operator Net: 100%|██████████| 120/120 [00:05<00:00, 20.38it/s, loss=0.125]



Digit Network Validation Accuracy: 85.85%
Operator Network Validation Accuracy: 93.28%
New best accuracy! Saved models.

Epoch 31/50
Time taken: 12.55 seconds
Best Digit Accuracy: 85.85%
Best Operator Accuracy: 93.28%

Epoch 32/50


Training Digit Net: 100%|██████████| 133/133 [00:06<00:00, 22.09it/s, loss=0.334]
Training Operator Net: 100%|██████████| 120/120 [00:05<00:00, 20.57it/s, loss=0.122]



Epoch 32/50
Time taken: 11.86 seconds
Best Digit Accuracy: 85.85%
Best Operator Accuracy: 93.28%

Epoch 33/50


Training Digit Net: 100%|██████████| 133/133 [00:05<00:00, 22.22it/s, loss=0.29]
Training Operator Net: 100%|██████████| 120/120 [00:06<00:00, 19.32it/s, loss=0.119]



Epoch 33/50
Time taken: 12.20 seconds
Best Digit Accuracy: 85.85%
Best Operator Accuracy: 93.28%

Epoch 34/50


Training Digit Net: 100%|██████████| 133/133 [00:06<00:00, 20.67it/s, loss=0.282]
Training Operator Net: 100%|██████████| 120/120 [00:05<00:00, 20.22it/s, loss=0.103]



Epoch 34/50
Time taken: 12.37 seconds
Best Digit Accuracy: 85.85%
Best Operator Accuracy: 93.28%

Epoch 35/50


Training Digit Net: 100%|██████████| 133/133 [00:06<00:00, 21.73it/s, loss=0.259]
Training Operator Net: 100%|██████████| 120/120 [00:06<00:00, 19.90it/s, loss=0.115]



Epoch 35/50
Time taken: 12.15 seconds
Best Digit Accuracy: 85.85%
Best Operator Accuracy: 93.28%

Epoch 36/50


Training Digit Net: 100%|██████████| 133/133 [00:06<00:00, 22.01it/s, loss=0.231]
Training Operator Net: 100%|██████████| 120/120 [00:06<00:00, 19.91it/s, loss=0.096]



Digit Network Validation Accuracy: 87.74%
Operator Network Validation Accuracy: 94.54%
New best accuracy! Saved models.

Epoch 36/50
Time taken: 12.83 seconds
Best Digit Accuracy: 87.74%
Best Operator Accuracy: 94.54%

Epoch 37/50


Training Digit Net: 100%|██████████| 133/133 [00:06<00:00, 20.85it/s, loss=0.235]
Training Operator Net: 100%|██████████| 120/120 [00:06<00:00, 18.71it/s, loss=0.091]



Epoch 37/50
Time taken: 12.79 seconds
Best Digit Accuracy: 87.74%
Best Operator Accuracy: 94.54%

Epoch 38/50


Training Digit Net: 100%|██████████| 133/133 [00:06<00:00, 21.66it/s, loss=0.227]
Training Operator Net: 100%|██████████| 120/120 [00:05<00:00, 20.01it/s, loss=0.0785]



Epoch 38/50
Time taken: 12.14 seconds
Best Digit Accuracy: 87.74%
Best Operator Accuracy: 94.54%

Epoch 39/50


Training Digit Net: 100%|██████████| 133/133 [00:06<00:00, 21.94it/s, loss=0.193]
Training Operator Net: 100%|██████████| 120/120 [00:05<00:00, 20.10it/s, loss=0.0805]



Epoch 39/50
Time taken: 12.03 seconds
Best Digit Accuracy: 87.74%
Best Operator Accuracy: 94.54%

Epoch 40/50


Training Digit Net: 100%|██████████| 133/133 [00:06<00:00, 21.93it/s, loss=0.168]
Training Operator Net: 100%|██████████| 120/120 [00:05<00:00, 20.56it/s, loss=0.0785]



Epoch 40/50
Time taken: 11.91 seconds
Best Digit Accuracy: 87.74%
Best Operator Accuracy: 94.54%

Epoch 41/50


Training Digit Net: 100%|██████████| 133/133 [00:05<00:00, 22.37it/s, loss=0.164]
Training Operator Net: 100%|██████████| 120/120 [00:05<00:00, 20.48it/s, loss=0.0843]



Digit Network Validation Accuracy: 87.74%
Operator Network Validation Accuracy: 94.75%
New best accuracy! Saved models.

Epoch 41/50
Time taken: 12.51 seconds
Best Digit Accuracy: 87.74%
Best Operator Accuracy: 94.75%

Epoch 42/50


Training Digit Net: 100%|██████████| 133/133 [00:05<00:00, 22.50it/s, loss=0.154]
Training Operator Net: 100%|██████████| 120/120 [00:05<00:00, 20.78it/s, loss=0.0674]



Epoch 42/50
Time taken: 11.69 seconds
Best Digit Accuracy: 87.74%
Best Operator Accuracy: 94.75%

Epoch 43/50


Training Digit Net: 100%|██████████| 133/133 [00:05<00:00, 22.19it/s, loss=0.147]
Training Operator Net: 100%|██████████| 120/120 [00:05<00:00, 20.51it/s, loss=0.057]



Epoch 43/50
Time taken: 11.85 seconds
Best Digit Accuracy: 87.74%
Best Operator Accuracy: 94.75%

Epoch 44/50


Training Digit Net: 100%|██████████| 133/133 [00:05<00:00, 22.23it/s, loss=0.13]
Training Operator Net: 100%|██████████| 120/120 [00:06<00:00, 19.61it/s, loss=0.0684]



Epoch 44/50
Time taken: 12.10 seconds
Best Digit Accuracy: 87.74%
Best Operator Accuracy: 94.75%

Epoch 45/50


Training Digit Net: 100%|██████████| 133/133 [00:06<00:00, 21.25it/s, loss=0.118]
Training Operator Net: 100%|██████████| 120/120 [00:06<00:00, 19.57it/s, loss=0.0641]



Epoch 45/50
Time taken: 12.39 seconds
Best Digit Accuracy: 87.74%
Best Operator Accuracy: 94.75%

Epoch 46/50


Training Digit Net: 100%|██████████| 133/133 [00:06<00:00, 21.13it/s, loss=0.12]
Training Operator Net: 100%|██████████| 120/120 [00:06<00:00, 19.40it/s, loss=0.0514]



Digit Network Validation Accuracy: 88.49%
Operator Network Validation Accuracy: 96.01%
New best accuracy! Saved models.

Epoch 46/50
Time taken: 13.30 seconds
Best Digit Accuracy: 88.49%
Best Operator Accuracy: 96.01%

Epoch 47/50


Training Digit Net: 100%|██████████| 133/133 [00:06<00:00, 22.09it/s, loss=0.0955]
Training Operator Net: 100%|██████████| 120/120 [00:05<00:00, 20.63it/s, loss=0.0481]



Epoch 47/50
Time taken: 11.84 seconds
Best Digit Accuracy: 88.49%
Best Operator Accuracy: 96.01%

Epoch 48/50


Training Digit Net: 100%|██████████| 133/133 [00:06<00:00, 21.42it/s, loss=0.0881]
Training Operator Net: 100%|██████████| 120/120 [00:05<00:00, 20.22it/s, loss=0.0512]



Epoch 48/50
Time taken: 12.15 seconds
Best Digit Accuracy: 88.49%
Best Operator Accuracy: 96.01%

Epoch 49/50


Training Digit Net: 100%|██████████| 133/133 [00:05<00:00, 22.46it/s, loss=0.101]
Training Operator Net: 100%|██████████| 120/120 [00:05<00:00, 20.30it/s, loss=0.0423]



Epoch 49/50
Time taken: 11.83 seconds
Best Digit Accuracy: 88.49%
Best Operator Accuracy: 96.01%

Epoch 50/50


Training Digit Net: 100%|██████████| 133/133 [00:05<00:00, 22.27it/s, loss=0.0804]
Training Operator Net: 100%|██████████| 120/120 [00:05<00:00, 20.05it/s, loss=0.0434]


Epoch 50/50
Time taken: 11.96 seconds
Best Digit Accuracy: 88.49%
Best Operator Accuracy: 96.01%
Finished Training

Saving final model weights...
Final weights saved successfully:
- Digit network: ./weights/digit_net_final.pth
- Operator network: ./weights/operator_net_final.pth

Best weights were saved during training:
- Best digit network: ./weights/digit_net_best.pth
- Best operator network: ./weights/operator_net_best.pth





In [34]:

print("\nPerforming final test evaluation...")
digit_net.eval()
operator_net.eval()

# Create test dataloaders
digit_testloader = DataLoader(digit_test_dataset, batch_size=32, shuffle=False)
operator_testloader = DataLoader(operator_test_dataset, batch_size=32, shuffle=False)

with torch.no_grad():
    # Test digit network
    correct = 0
    total = 0
    for images, labels in digit_testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = digit_net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    test_digit_accuracy = 100 * correct / total
    print(f"\nFinal Digit Test Accuracy: {test_digit_accuracy:.2f}%")
    
    # Test operator network
    correct = 0
    total = 0
    for images, labels in operator_testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = operator_net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    test_operator_accuracy = 100 * correct / total
    print(f"Final Operator Test Accuracy: {test_operator_accuracy:.2f}%")

print("\nTesting complete!")


Performing final test evaluation...

Final Digit Test Accuracy: 90.02%
Final Operator Test Accuracy: 95.82%

Testing complete!
