In [27]:
import torch
# Define device first
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU device: {torch.cuda.get_device_name()}")

PyTorch version: 2.5.1+cu118
CUDA available: True
GPU device: NVIDIA GeForce RTX 4080 SUPER


In [28]:
# Enter your kaggle user name and key
username = str('')
key = str('')

In [29]:
# All imports
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import shutil  # Add this import
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from datetime import datetime
import time
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
from datetime import datetime
from torch.utils.data import random_split

In [30]:
# Download dataset from Kaggle
import opendatasets as od
import pandas
import json

# Create .kaggle directory if it doesn't exist
os.makedirs(os.path.expanduser('~/.kaggle'), exist_ok=True)

try:
    kaggle_token = {
        "username": username,
        "key": key
    }
    
    # Verify the data is string
    print("\nVerifying credentials format:")
    print(f"Username type: {type(username)}")
    print(f"Key type: {type(key)}")
    
    with open(os.path.expanduser('~/.kaggle/kaggle.json'), 'w') as f:
        json.dump(kaggle_token, f)
    
    print("Kaggle credentials saved successfully")
except Exception as e:
    print(f"Error saving Kaggle credentials: {e}")
    raise

# Set permissions
os.chmod(os.path.expanduser('~/.kaggle/kaggle.json'), 0o600)

# Download dataset
dataset_url = "https://www.kaggle.com/datasets/sagyamthapa/handwritten-math-symbols"
od.download(dataset_url)


Verifying credentials format:
Username type: <class 'str'>
Key type: <class 'str'>
Kaggle credentials saved successfully
Skipping, found downloaded files in ".\handwritten-math-symbols" (use force=True to force download)


In [31]:
# Extract dataset from archive with more detailed logging
print("\nStarting extraction process...")
dataset_path = "./handwritten-math-symbols"
archive_path = f"{dataset_path}.zip"
print(f"Looking for archive at: {archive_path}")
print(f"Archive exists: {os.path.exists(archive_path)}")

if os.path.exists(archive_path):
    import zipfile
    with zipfile.ZipFile(archive_path, 'r') as zip_ref:
        print("\nContents of zip file:")
        for file in zip_ref.namelist()[:10]:
            print(f"- {file}")
        print("...")
        zip_ref.extractall("./")
    print("Dataset extracted successfully")

def verify_directory_structure():
    print("\nVerifying directory structure:")
    for root, dirs, files in os.walk('./data'):
        print(f"\nDirectory: {root}")
        print(f"Subdirectories: {dirs}")
        print(f"Number of files: {len(files)}")

# First organize the data
print("\nStarting file organization...")
source_dir = "./handwritten-math-symbols/dataset"
print(f"Looking for source directory at: {source_dir}")
print(f"Source directory exists: {os.path.exists(source_dir)}")

if not os.path.exists(source_dir):
    raise FileNotFoundError(f"Source directory {source_dir} not found!")

# Create destination directories
os.makedirs('./data/combined', exist_ok=True)  # New combined directory
os.makedirs('./data/combined/digits', exist_ok=True)
os.makedirs('./data/combined/operators', exist_ok=True)

# Define which folders belong to digits and operators
digit_folders = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
operator_folders = ['add', 'sub', 'mul', 'div', 'eq', 'dec', 'x', 'y', 'z']

# Move digit folders
print("\nMoving digit folders...")
for folder in digit_folders:
    src = os.path.join(source_dir, folder)
    dst = os.path.join('./data/combined/digits', folder)
    if os.path.exists(src):
        print(f"Copying {src} to {dst}")
        shutil.copytree(src, dst, dirs_exist_ok=True)
    else:
        print(f"Warning: Source folder not found: {src}")

# Move operator folders
print("\nMoving operator folders...")
for folder in operator_folders:
    src = os.path.join(source_dir, folder)
    dst = os.path.join('./data/combined/operators', folder)
    if os.path.exists(src):
        print(f"Copying {src} to {dst}")
        shutil.copytree(src, dst, dirs_exist_ok=True)
    else:
        print(f"Warning: Source folder not found: {src}")

# Verify the directory structure
verify_directory_structure()

# Verify that we have data in the folders
print("\nVerifying data in folders:")
digit_path = './data/combined/digits'
operator_path = './data/combined/operators'

if not os.path.exists(digit_path) or not os.path.exists(operator_path):
    raise FileNotFoundError("Data directories not created properly!")

digit_classes = sorted(os.listdir(digit_path))
operator_classes = sorted(os.listdir(operator_path))

print(f"\nFound digit classes: {digit_classes}")
print(f"Found operator classes: {operator_classes}")

# Create class mappings for the combined model
class_info = {
    'digit_classes': digit_classes,
    'operator_classes': operator_classes,
    'digit_to_idx': {cls: idx for idx, cls in enumerate(digit_classes)},
    'operator_to_idx': {cls: idx for idx, cls in enumerate(operator_classes)},
    'num_digit_classes': len(digit_classes),
    'num_operator_classes': len(operator_classes)
}

print("\nClass mappings created:")
print(f"Digit classes: {class_info['digit_to_idx']}")
print(f"Operator classes: {class_info['operator_to_idx']}")
print(f"Total classes: {class_info['num_digit_classes'] + class_info['num_operator_classes']}")

if not digit_classes:
    raise FileNotFoundError("No digit classes found!")
if not operator_classes:
    raise FileNotFoundError("No operator classes found!")


Starting extraction process...
Looking for archive at: ./handwritten-math-symbols.zip
Archive exists: False

Starting file organization...
Looking for source directory at: ./handwritten-math-symbols/dataset
Source directory exists: True

Moving digit folders...
Copying ./handwritten-math-symbols/dataset\0 to ./data/combined/digits\0
Copying ./handwritten-math-symbols/dataset\1 to ./data/combined/digits\1
Copying ./handwritten-math-symbols/dataset\2 to ./data/combined/digits\2
Copying ./handwritten-math-symbols/dataset\3 to ./data/combined/digits\3
Copying ./handwritten-math-symbols/dataset\4 to ./data/combined/digits\4
Copying ./handwritten-math-symbols/dataset\5 to ./data/combined/digits\5
Copying ./handwritten-math-symbols/dataset\6 to ./data/combined/digits\6
Copying ./handwritten-math-symbols/dataset\7 to ./data/combined/digits\7
Copying ./handwritten-math-symbols/dataset\8 to ./data/combined/digits\8
Copying ./handwritten-math-symbols/dataset\9 to ./data/combined/digits\9

Moving

In [32]:
# Create transforms
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Create datasets with error handling
try:
    digit_dataset = torchvision.datasets.ImageFolder(
        root='./data/digits',
        transform=transform
    )
    print(f"\nDigit dataset created successfully with {len(digit_dataset)} images")
except Exception as e:
    print(f"Error creating digit dataset: {e}")
    raise

try:
    operator_dataset = torchvision.datasets.ImageFolder(
        root='./data/operators',
        transform=transform
    )
    print(f"Operator dataset created successfully with {len(operator_dataset)} images")
except Exception as e:
    print(f"Error creating operator dataset: {e}")
    raise

# Create dataset splits
print("\nCreating dataset splits...")

# For Digit Dataset
total_digit_size = len(digit_dataset)
train_digit_size = int(0.8 * total_digit_size)
val_digit_size = int(0.1 * total_digit_size)
test_digit_size = total_digit_size - train_digit_size - val_digit_size

digit_train_dataset, digit_val_dataset, digit_test_dataset = random_split(
    digit_dataset, 
    [train_digit_size, val_digit_size, test_digit_size]
)

# For Operator Dataset
total_operator_size = len(operator_dataset)
train_operator_size = int(0.8 * total_operator_size)
val_operator_size = int(0.1 * total_operator_size)
test_operator_size = total_operator_size - train_operator_size - val_operator_size

operator_train_dataset, operator_val_dataset, operator_test_dataset = random_split(
    operator_dataset, 
    [train_operator_size, val_operator_size, test_operator_size]
)

# Create combined dataloaders
digit_trainloader = torch.utils.data.DataLoader(
    digit_train_dataset,
    batch_size=32,
    shuffle=True, 
    num_workers=2
)

operator_trainloader = torch.utils.data.DataLoader(
    operator_train_dataset, 
    batch_size=32,
    shuffle=True, 
    num_workers=2
)

# Create validation dataloaders
digit_valloader = torch.utils.data.DataLoader(
    digit_val_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=2
)

operator_valloader = torch.utils.data.DataLoader(
    operator_val_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=2
)

# Create test dataloaders
digit_testloader = torch.utils.data.DataLoader(
    digit_test_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=2
)

operator_testloader = torch.utils.data.DataLoader(
    operator_test_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=2
)

print(f"Digit - Train: {len(digit_train_dataset)}, Val: {len(digit_val_dataset)}, Test: {len(digit_test_dataset)}")
print(f"Operator - Train: {len(operator_train_dataset)}, Val: {len(operator_val_dataset)}, Test: {len(operator_test_dataset)}")

# Print dataset information
print("\nDatasets and dataloaders created successfully!")
print(f"Number of digit classes: {len(digit_dataset.classes)}")
print(f"Number of operator classes: {len(operator_dataset.classes)}")

# Store class information for the combined model
num_digit_classes = len(digit_dataset.classes)
num_operator_classes = len(operator_dataset.classes)
digit_class_to_idx = digit_dataset.class_to_idx
operator_class_to_idx = operator_dataset.class_to_idx

print("\nClass mappings:")
print(f"Digit classes: {digit_class_to_idx}")
print(f"Operator classes: {operator_class_to_idx}")


Digit dataset created successfully with 5304 images
Operator dataset created successfully with 4767 images

Creating dataset splits...
Digit - Train: 4243, Val: 530, Test: 531
Operator - Train: 3813, Val: 476, Test: 478

Datasets and dataloaders created successfully!
Number of digit classes: 10
Number of operator classes: 9

Class mappings:
Digit classes: {'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6, '7': 7, '8': 8, '9': 9}
Operator classes: {'add': 0, 'dec': 1, 'div': 2, 'eq': 3, 'mul': 4, 'sub': 5, 'x': 6, 'y': 7, 'z': 8}


In [33]:
# Define the classes based on the folders
digit_classes = sorted(os.listdir('./data/digits'))
operator_classes = sorted(os.listdir('./data/operators'))

class CombinedNet(nn.Module):
    def __init__(self):
        super(CombinedNet, self).__init__()
        # First conv block
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)  # 32x32 -> 32x32
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 32, 3, padding=1)  # 32x32 -> 32x32
        self.bn2 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(2, 2)  # 32x32 -> 16x16
        
        # Second conv block
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)  # 16x16 -> 16x16
        self.bn3 = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 64, 3, padding=1)  # 16x16 -> 16x16
        self.bn4 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(2, 2)  # 16x16 -> 8x8
        
        # Third conv block
        self.conv5 = nn.Conv2d(64, 128, 3, padding=1)  # 8x8 -> 8x8
        self.bn5 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(2, 2)  # 8x8 -> 4x4
        
        # Shared features end here
        
        # Type classification (digit vs operator)
        self.type_fc = nn.Linear(128 * 4 * 4, 2)
        
        # Digit classification (0-9)
        self.digit_fc = nn.Linear(128 * 4 * 4, len(digit_classes))
        
        # Operator classification
        self.operator_fc = nn.Linear(128 * 4 * 4, len(operator_classes))

    def forward(self, x):
        # First block
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool1(x)
        
        # Second block
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = self.pool2(x)
        
        # Third block
        x = F.relu(self.bn5(self.conv5(x)))
        x = self.pool3(x)
        
        # Flatten
        features = x.view(-1, 128 * 4 * 4)
        
        # Get all predictions
        type_out = self.type_fc(features)      # Is it digit or operator?
        digit_out = self.digit_fc(features)    # If digit, which one?
        operator_out = self.operator_fc(features)  # If operator, which one?
        
        return type_out, digit_out, operator_out


# Initialize the combined network
combined_net = CombinedNet().to(device)

# Single optimizer for the combined network
optimizer = optim.Adam(combined_net.parameters(), lr=0.0001)

def predict_with_threshold(image, model, device, type_threshold=0.8):
    model.eval()  # Set to evaluation mode
    with torch.no_grad():
        # Ensure image is on the correct device
        if not isinstance(image, torch.Tensor):
            image = transform(image).unsqueeze(0)  # Add batch dimension
        image = image.to(device)
        
        # Get predictions
        type_out, digit_out, operator_out = model(image)
        
        # Get all probabilities
        type_prob = F.softmax(type_out, dim=1)
        digit_prob = F.softmax(digit_out, dim=1)
        operator_prob = F.softmax(operator_out, dim=1)
        
        type_conf, type_pred = torch.max(type_prob, 1)
        digit_conf, digit_pred = torch.max(digit_prob, 1)
        operator_conf, operator_pred = torch.max(operator_prob, 1)
        
        # Only trust type prediction if confident enough
        if type_conf >= type_threshold:
            if type_pred == 0:  # Digit
                return 'digit', digit_pred.item(), digit_conf.item()
            else:  # Operator
                return 'operator', operator_pred.item(), operator_conf.item()
        else:
            # If not confident about type, use highest confidence between digit and operator
            if digit_conf > operator_conf:
                return 'digit', digit_pred.item(), digit_conf.item()
            else:
                return 'operator', operator_pred.item(), operator_conf.item()


# Diagnostic code to check dimensions and classes
print("\nDiagnostic Information:")
print(f"Digit classes: {digit_classes}")
print(f"Number of digit classes: {len(digit_classes)}")
print(f"Operator classes: {operator_classes}")
print(f"Number of operator classes: {len(operator_classes)}")

# Verify network structure
print("\nNetwork Architecture Check:")
print("CombinedNet outputs:")
print(f"Type classification: 2 classes (digit/operator)")
print(f"Digit classification: {len(digit_classes)} classes")
print(f"Operator classification: {len(operator_classes)} classes")

# Try moving to GPU with error handling
try:
    print("\nMoving network to GPU...")
    combined_net = combined_net.to(device)
    print("Successfully moved network to GPU")
except RuntimeError as e:
    print(f"Error moving network to GPU: {e}")
    print("Falling back to CPU")
    device = torch.device("cpu")
    combined_net = combined_net.to(device)

# Now verify data and label shapes
print("\nVerifying data shapes:")
sample_digit_batch = next(iter(digit_trainloader))
sample_operator_batch = next(iter(operator_trainloader))

print(f"Digit batch - Images: {sample_digit_batch[0].shape}, Labels: {sample_digit_batch[1].shape}")
print(f"Operator batch - Images: {sample_operator_batch[0].shape}, Labels: {sample_operator_batch[1].shape}")

# Verify network output dimensions
with torch.no_grad():
    type_out, digit_out, operator_out = combined_net(sample_digit_batch[0].to(device))
    print("\nNetwork output dimensions:")
    print(f"Type output: {type_out.shape}")
    print(f"Digit output: {digit_out.shape}")
    print(f"Operator output: {operator_out.shape}")


Diagnostic Information:
Digit classes: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
Number of digit classes: 10
Operator classes: ['add', 'dec', 'div', 'eq', 'mul', 'sub', 'x', 'y', 'z']
Number of operator classes: 9

Network Architecture Check:
CombinedNet outputs:
Type classification: 2 classes (digit/operator)
Digit classification: 10 classes
Operator classification: 9 classes

Moving network to GPU...
Successfully moved network to GPU

Verifying data shapes:
Digit batch - Images: torch.Size([32, 1, 32, 32]), Labels: torch.Size([32])
Operator batch - Images: torch.Size([32, 1, 32, 32]), Labels: torch.Size([32])

Network output dimensions:
Type output: torch.Size([32, 2])
Digit output: torch.Size([32, 10])
Operator output: torch.Size([32, 9])


In [34]:
# Training parameters
num_epochs = 50
eval_interval = 5
early_stopping_patience = 15
best_accuracy = 0
epochs_without_improvement = 0

# Get current date for logging
training_date = datetime.now().strftime("%Y-%m-%d")
start_time = time.time()

# Add learning rate scheduler
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=5, verbose=True)

# Add criterion definitions
criterion = nn.CrossEntropyLoss()

# Create weights directory if it doesn't exist
os.makedirs('./weights', exist_ok=True)
best_model_path = './weights/combined_net_best.pth'

# Training loop
for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    
    combined_net.train()
    running_loss = 0.0
    epoch_start = time.time()
    
    # Calculate number of batches
    n_batches = min(len(digit_trainloader), len(operator_trainloader))
    
    # Add progress bar with known length
    pbar = tqdm(zip(digit_trainloader, operator_trainloader), 
                total=n_batches,  # Set total number of batches
                desc=f"Training")
    
    for batch_idx, (digit_data, operator_data) in enumerate(pbar):
        # Prepare digit data
        digit_inputs, digit_labels = digit_data
        digit_inputs, digit_labels = digit_inputs.to(device), digit_labels.to(device)
        digit_type = torch.zeros(digit_inputs.size(0), dtype=torch.long).to(device)
        
        # Prepare operator data
        operator_inputs, operator_labels = operator_data
        operator_inputs, operator_labels = operator_inputs.to(device), operator_labels.to(device)
        operator_type = torch.ones(operator_inputs.size(0), dtype=torch.long).to(device)
        
        # Combined batch
        inputs = torch.cat([digit_inputs, operator_inputs])
        type_labels = torch.cat([digit_type, operator_type])
        
        optimizer.zero_grad()
        
        # Forward pass
        type_out, digit_out, operator_out = combined_net(inputs)
        
        # Calculate losses
        type_loss = criterion(type_out, type_labels)
        digit_loss = criterion(digit_out[:len(digit_type)], digit_labels)
        operator_loss = criterion(operator_out[len(digit_type):], operator_labels)
        
        # Combined loss
        loss = type_loss + digit_loss + operator_loss
        
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        
        # Update progress bar
        pbar.set_postfix({
            'loss': running_loss/(batch_idx + 1),
            'type_loss': type_loss.item(),
            'digit_loss': digit_loss.item(),
            'operator_loss': operator_loss.item()
        })
    
    # Evaluation every eval_interval epochs
    if epoch % eval_interval == 0:
        combined_net.eval()
        
        correct_type = 0
        correct_digit = 0
        correct_operator = 0
        total = 0
        
        with torch.no_grad():
            # Evaluate on validation set
            for digit_data, operator_data in zip(digit_valloader, operator_valloader):
                # Process digit data
                digit_inputs, digit_labels = digit_data
                digit_inputs = digit_inputs.to(device)
                digit_type = torch.zeros(digit_inputs.size(0), dtype=torch.long).to(device)
                
                # Process operator data
                operator_inputs, operator_labels = operator_data
                operator_inputs = operator_inputs.to(device)
                operator_type = torch.ones(operator_inputs.size(0), dtype=torch.long).to(device)
                
                # Combined validation
                inputs = torch.cat([digit_inputs, operator_inputs])
                type_labels = torch.cat([digit_type, operator_type])
                
                type_out, digit_out, operator_out = combined_net(inputs)
                
                # Calculate accuracies
                _, predicted_type = torch.max(type_out.data, 1)
                correct_type += (predicted_type == type_labels).sum().item()
                
                _, predicted_digit = torch.max(digit_out[:len(digit_type)].data, 1)
                correct_digit += (predicted_digit == digit_labels.to(device)).sum().item()
                
                _, predicted_operator = torch.max(operator_out[len(digit_type):].data, 1)
                correct_operator += (predicted_operator == operator_labels.to(device)).sum().item()
                
                total += type_labels.size(0)
        
        type_accuracy = 100 * correct_type / total
        digit_accuracy = 100 * correct_digit / (total/2)
        operator_accuracy = 100 * correct_operator / (total/2)
        
        print(f"\nValidation Accuracies:")
        print(f"Type Classification: {type_accuracy:.2f}%")
        print(f"Digit Recognition: {digit_accuracy:.2f}%")
        print(f"Operator Recognition: {operator_accuracy:.2f}%")
        
        # Update scheduler based on type accuracy
        scheduler.step(type_accuracy)
        
        # Save best model based on type accuracy
        if type_accuracy > best_accuracy:
            best_accuracy = type_accuracy
            torch.save(combined_net.state_dict(), best_model_path)
            epochs_without_improvement = 0
            print("New best accuracy! Saved model.")
        else:
            epochs_without_improvement += eval_interval
        
        # Early stopping check
        if epochs_without_improvement >= early_stopping_patience:
            print("\nEarly stopping triggered!")
            break
    
    # Print epoch summary
    epoch_time = time.time() - epoch_start
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print(f"Time taken: {epoch_time:.2f} seconds")
    print(f"Best Accuracy: {best_accuracy:.2f}%")

print('Finished Training')

# Save final weights
print("\nSaving final model weights...")
try:
    final_model_path = './weights/combined_net_final.pth'
    torch.save(combined_net.state_dict(), final_model_path)
    print(f"Final weights saved successfully: {final_model_path}")
except Exception as e:
    print(f"Error saving final weights: {e}")

if os.path.exists(best_model_path):
    print(f"\nBest weights were saved during training: {best_model_path}")
else:
    print("\nWarning: Best weight file not found!")


Epoch 1/50


Training: 100%|██████████| 120/120 [00:02<00:00, 50.62it/s, loss=3.16, type_loss=0.384, digit_loss=1.17, operator_loss=0.537]



Validation Accuracies:
Type Classification: 90.17%
Digit Recognition: 65.27%
Operator Recognition: 84.73%
New best accuracy! Saved model.

Epoch 1/50
Time taken: 17.74 seconds
Best Accuracy: 90.17%

Epoch 2/50


Training: 100%|██████████| 120/120 [00:01<00:00, 60.14it/s, loss=1.53, type_loss=0.36, digit_loss=0.912, operator_loss=0.542] 



Epoch 2/50
Time taken: 9.04 seconds
Best Accuracy: 90.17%

Epoch 3/50


Training: 100%|██████████| 120/120 [00:02<00:00, 55.02it/s, loss=0.963, type_loss=0.223, digit_loss=0.561, operator_loss=0.37]  



Epoch 3/50
Time taken: 9.16 seconds
Best Accuracy: 90.17%

Epoch 4/50


Training: 100%|██████████| 120/120 [00:02<00:00, 54.78it/s, loss=0.674, type_loss=0.13, digit_loss=0.288, operator_loss=0.194]  



Epoch 4/50
Time taken: 9.19 seconds
Best Accuracy: 90.17%

Epoch 5/50


Training: 100%|██████████| 120/120 [00:01<00:00, 60.48it/s, loss=0.495, type_loss=0.241, digit_loss=0.19, operator_loss=0.0826]  



Epoch 5/50
Time taken: 9.06 seconds
Best Accuracy: 90.17%

Epoch 6/50


Training: 100%|██████████| 120/120 [00:02<00:00, 55.97it/s, loss=0.369, type_loss=0.0777, digit_loss=0.127, operator_loss=0.0111]



Validation Accuracies:
Type Classification: 95.92%
Digit Recognition: 92.26%
Operator Recognition: 96.86%
New best accuracy! Saved model.

Epoch 6/50
Time taken: 17.15 seconds
Best Accuracy: 95.92%

Epoch 7/50


Training: 100%|██████████| 120/120 [00:02<00:00, 54.80it/s, loss=0.287, type_loss=0.0723, digit_loss=0.161, operator_loss=0.0288] 



Epoch 7/50
Time taken: 9.19 seconds
Best Accuracy: 95.92%

Epoch 8/50


Training: 100%|██████████| 120/120 [00:02<00:00, 55.53it/s, loss=0.226, type_loss=0.1, digit_loss=0.1, operator_loss=0.0352]      



Epoch 8/50
Time taken: 9.13 seconds
Best Accuracy: 95.92%

Epoch 9/50


Training: 100%|██████████| 120/120 [00:02<00:00, 56.23it/s, loss=0.181, type_loss=0.0996, digit_loss=0.0792, operator_loss=0.0274]



Epoch 9/50
Time taken: 9.13 seconds
Best Accuracy: 95.92%

Epoch 10/50


Training: 100%|██████████| 120/120 [00:02<00:00, 57.06it/s, loss=0.137, type_loss=0.0685, digit_loss=0.124, operator_loss=0.0316]  



Epoch 10/50
Time taken: 9.16 seconds
Best Accuracy: 95.92%

Epoch 11/50


Training: 100%|██████████| 120/120 [00:02<00:00, 55.32it/s, loss=0.115, type_loss=0.103, digit_loss=0.0654, operator_loss=0.0177]  



Validation Accuracies:
Type Classification: 98.01%
Digit Recognition: 97.07%
Operator Recognition: 98.33%
New best accuracy! Saved model.

Epoch 11/50
Time taken: 17.18 seconds
Best Accuracy: 98.01%

Epoch 12/50


Training: 100%|██████████| 120/120 [00:02<00:00, 54.91it/s, loss=0.0911, type_loss=0.0251, digit_loss=0.0139, operator_loss=0.0123] 



Epoch 12/50
Time taken: 9.20 seconds
Best Accuracy: 98.01%

Epoch 13/50


Training: 100%|██████████| 120/120 [00:02<00:00, 56.21it/s, loss=0.0767, type_loss=0.0908, digit_loss=0.0324, operator_loss=0.0394] 



Epoch 13/50
Time taken: 9.21 seconds
Best Accuracy: 98.01%

Epoch 14/50


Training: 100%|██████████| 120/120 [00:02<00:00, 54.44it/s, loss=0.063, type_loss=0.0547, digit_loss=0.0631, operator_loss=0.0119]  



Epoch 14/50
Time taken: 9.21 seconds
Best Accuracy: 98.01%

Epoch 15/50


Training: 100%|██████████| 120/120 [00:02<00:00, 55.11it/s, loss=0.0521, type_loss=0.0564, digit_loss=0.028, operator_loss=0.00955]  



Epoch 15/50
Time taken: 9.21 seconds
Best Accuracy: 98.01%

Epoch 16/50


Training: 100%|██████████| 120/120 [00:02<00:00, 56.33it/s, loss=0.0449, type_loss=0.0586, digit_loss=0.0331, operator_loss=0.0126]  



Validation Accuracies:
Type Classification: 98.54%
Digit Recognition: 97.70%
Operator Recognition: 98.54%
New best accuracy! Saved model.

Epoch 16/50
Time taken: 17.09 seconds
Best Accuracy: 98.54%

Epoch 17/50


Training: 100%|██████████| 120/120 [00:02<00:00, 55.82it/s, loss=0.0385, type_loss=0.0733, digit_loss=0.0187, operator_loss=0.0516]   



Epoch 17/50
Time taken: 9.17 seconds
Best Accuracy: 98.54%

Epoch 18/50


Training: 100%|██████████| 120/120 [00:02<00:00, 58.12it/s, loss=0.0336, type_loss=0.0325, digit_loss=0.0222, operator_loss=0.0202]  



Epoch 18/50
Time taken: 9.10 seconds
Best Accuracy: 98.54%

Epoch 19/50


Training: 100%|██████████| 120/120 [00:02<00:00, 54.67it/s, loss=0.0283, type_loss=0.0335, digit_loss=0.0137, operator_loss=0.00473]  



Epoch 19/50
Time taken: 9.21 seconds
Best Accuracy: 98.54%

Epoch 20/50


Training: 100%|██████████| 120/120 [00:02<00:00, 54.98it/s, loss=0.0245, type_loss=0.0167, digit_loss=0.0201, operator_loss=0.00184] 



Epoch 20/50
Time taken: 9.42 seconds
Best Accuracy: 98.54%

Epoch 21/50


Training: 100%|██████████| 120/120 [00:02<00:00, 54.88it/s, loss=0.0212, type_loss=0.0378, digit_loss=0.0142, operator_loss=0.0249]   



Validation Accuracies:
Type Classification: 98.85%
Digit Recognition: 98.54%
Operator Recognition: 98.74%
New best accuracy! Saved model.

Epoch 21/50
Time taken: 18.17 seconds
Best Accuracy: 98.85%

Epoch 22/50


Training: 100%|██████████| 120/120 [00:02<00:00, 56.33it/s, loss=0.0195, type_loss=0.0138, digit_loss=0.00958, operator_loss=0.0012]  



Epoch 22/50
Time taken: 9.31 seconds
Best Accuracy: 98.85%

Epoch 23/50


Training: 100%|██████████| 120/120 [00:02<00:00, 54.51it/s, loss=0.0166, type_loss=0.0446, digit_loss=0.00633, operator_loss=0.0148]  



Epoch 23/50
Time taken: 9.39 seconds
Best Accuracy: 98.85%

Epoch 24/50


Training: 100%|██████████| 120/120 [00:02<00:00, 54.50it/s, loss=0.0161, type_loss=0.0516, digit_loss=0.018, operator_loss=0.00654]   



Epoch 24/50
Time taken: 9.31 seconds
Best Accuracy: 98.85%

Epoch 25/50


Training: 100%|██████████| 120/120 [00:02<00:00, 54.85it/s, loss=0.0161, type_loss=0.0151, digit_loss=0.0052, operator_loss=0.00711]  



Epoch 25/50
Time taken: 9.24 seconds
Best Accuracy: 98.85%

Epoch 26/50


Training: 100%|██████████| 120/120 [00:02<00:00, 54.60it/s, loss=0.0124, type_loss=0.023, digit_loss=0.00751, operator_loss=0.0131]    



Validation Accuracies:
Type Classification: 98.74%
Digit Recognition: 98.54%
Operator Recognition: 98.95%

Epoch 26/50
Time taken: 17.16 seconds
Best Accuracy: 98.85%

Epoch 27/50


Training: 100%|██████████| 120/120 [00:02<00:00, 54.62it/s, loss=0.0117, type_loss=0.00433, digit_loss=0.00614, operator_loss=0.00214] 



Epoch 27/50
Time taken: 9.24 seconds
Best Accuracy: 98.85%

Epoch 28/50


Training: 100%|██████████| 120/120 [00:02<00:00, 59.14it/s, loss=0.0102, type_loss=0.0122, digit_loss=0.00633, operator_loss=0.00469]  



Epoch 28/50
Time taken: 9.28 seconds
Best Accuracy: 98.85%

Epoch 29/50


Training: 100%|██████████| 120/120 [00:02<00:00, 51.94it/s, loss=0.00945, type_loss=0.0353, digit_loss=0.00976, operator_loss=0.0116]   



Epoch 29/50
Time taken: 9.42 seconds
Best Accuracy: 98.85%

Epoch 30/50


Training: 100%|██████████| 120/120 [00:02<00:00, 52.97it/s, loss=0.0117, type_loss=0.023, digit_loss=0.0157, operator_loss=0.000968]   



Epoch 30/50
Time taken: 9.79 seconds
Best Accuracy: 98.85%

Epoch 31/50


Training: 100%|██████████| 120/120 [00:02<00:00, 54.98it/s, loss=0.0107, type_loss=0.023, digit_loss=0.0111, operator_loss=0.000847]   



Validation Accuracies:
Type Classification: 98.95%
Digit Recognition: 98.12%
Operator Recognition: 98.95%
New best accuracy! Saved model.

Epoch 31/50
Time taken: 17.36 seconds
Best Accuracy: 98.95%

Epoch 32/50


Training: 100%|██████████| 120/120 [00:02<00:00, 55.87it/s, loss=0.00821, type_loss=0.00969, digit_loss=0.0023, operator_loss=0.00267]  



Epoch 32/50
Time taken: 9.31 seconds
Best Accuracy: 98.95%

Epoch 33/50


Training: 100%|██████████| 120/120 [00:02<00:00, 53.43it/s, loss=0.00739, type_loss=0.00812, digit_loss=0.00428, operator_loss=0.0012]   



Epoch 33/50
Time taken: 9.43 seconds
Best Accuracy: 98.95%

Epoch 34/50


Training: 100%|██████████| 120/120 [00:02<00:00, 53.65it/s, loss=0.00698, type_loss=0.00659, digit_loss=0.00435, operator_loss=0.000843] 



Epoch 34/50
Time taken: 9.36 seconds
Best Accuracy: 98.95%

Epoch 35/50


Training: 100%|██████████| 120/120 [00:02<00:00, 54.41it/s, loss=0.00543, type_loss=0.00978, digit_loss=0.00387, operator_loss=0.00105]  



Epoch 35/50
Time taken: 9.40 seconds
Best Accuracy: 98.95%

Epoch 36/50


Training: 100%|██████████| 120/120 [00:02<00:00, 54.49it/s, loss=0.0053, type_loss=0.0117, digit_loss=0.00541, operator_loss=0.00512]    



Validation Accuracies:
Type Classification: 98.95%
Digit Recognition: 97.49%
Operator Recognition: 98.95%

Epoch 36/50
Time taken: 17.27 seconds
Best Accuracy: 98.95%

Epoch 37/50


Training: 100%|██████████| 120/120 [00:02<00:00, 55.88it/s, loss=0.00783, type_loss=0.0416, digit_loss=0.00321, operator_loss=0.00413]    



Epoch 37/50
Time taken: 9.12 seconds
Best Accuracy: 98.95%

Epoch 38/50


Training: 100%|██████████| 120/120 [00:02<00:00, 55.55it/s, loss=0.0136, type_loss=0.00708, digit_loss=0.00269, operator_loss=0.00878]  



Epoch 38/50
Time taken: 9.16 seconds
Best Accuracy: 98.95%

Epoch 39/50


Training: 100%|██████████| 120/120 [00:02<00:00, 54.01it/s, loss=0.00588, type_loss=0.0327, digit_loss=0.00255, operator_loss=0.000633]  



Epoch 39/50
Time taken: 9.17 seconds
Best Accuracy: 98.95%

Epoch 40/50


Training: 100%|██████████| 120/120 [00:02<00:00, 56.72it/s, loss=0.00717, type_loss=0.0122, digit_loss=0.00155, operator_loss=0.00707]    



Epoch 40/50
Time taken: 9.02 seconds
Best Accuracy: 98.95%

Epoch 41/50


Training: 100%|██████████| 120/120 [00:02<00:00, 56.59it/s, loss=0.00546, type_loss=0.00653, digit_loss=0.00402, operator_loss=0.000314] 



Validation Accuracies:
Type Classification: 98.64%
Digit Recognition: 97.28%
Operator Recognition: 98.54%

Epoch 41/50
Time taken: 16.94 seconds
Best Accuracy: 98.95%

Epoch 42/50


Training: 100%|██████████| 120/120 [00:02<00:00, 55.76it/s, loss=0.0043, type_loss=0.0173, digit_loss=0.00144, operator_loss=0.00147]     



Epoch 42/50
Time taken: 9.08 seconds
Best Accuracy: 98.95%

Epoch 43/50


Training: 100%|██████████| 120/120 [00:02<00:00, 55.50it/s, loss=0.00528, type_loss=0.00899, digit_loss=0.00311, operator_loss=0.0015]    



Epoch 43/50
Time taken: 9.18 seconds
Best Accuracy: 98.95%

Epoch 44/50


Training: 100%|██████████| 120/120 [00:02<00:00, 55.07it/s, loss=0.00492, type_loss=0.0117, digit_loss=0.00194, operator_loss=0.00128]    



Epoch 44/50
Time taken: 9.17 seconds
Best Accuracy: 98.95%

Epoch 45/50


Training: 100%|██████████| 120/120 [00:02<00:00, 55.27it/s, loss=0.00415, type_loss=0.00878, digit_loss=0.00147, operator_loss=0.000655]  



Epoch 45/50
Time taken: 9.13 seconds
Best Accuracy: 98.95%

Epoch 46/50


Training: 100%|██████████| 120/120 [00:02<00:00, 54.32it/s, loss=0.00374, type_loss=0.00408, digit_loss=0.00195, operator_loss=0.00123]   



Validation Accuracies:
Type Classification: 98.85%
Digit Recognition: 98.12%
Operator Recognition: 98.95%

Early stopping triggered!
Finished Training

Saving final model weights...
Final weights saved successfully: ./weights/combined_net_final.pth

Best weights were saved during training: ./weights/combined_net_best.pth


In [35]:
print("\nPerforming final test evaluation...")
combined_net.eval()

# Create test dataloaders
digit_testloader = DataLoader(digit_test_dataset, batch_size=32, shuffle=False)
operator_testloader = DataLoader(operator_test_dataset, batch_size=32, shuffle=False)

with torch.no_grad():
    correct_type = 0
    correct_digit = 0
    correct_operator = 0
    total = 0
    
    # Test on both digit and operator data
    for digit_data, operator_data in zip(digit_testloader, operator_testloader):
        # Process digit data
        digit_inputs, digit_labels = digit_data
        digit_inputs = digit_inputs.to(device)
        digit_type = torch.zeros(digit_inputs.size(0), dtype=torch.long).to(device)  # 0 for digits
        
        # Process operator data
        operator_inputs, operator_labels = operator_data
        operator_inputs = operator_inputs.to(device)
        operator_type = torch.ones(operator_inputs.size(0), dtype=torch.long).to(device)  # 1 for operators
        
        # Combined test batch
        inputs = torch.cat([digit_inputs, operator_inputs])
        type_labels = torch.cat([digit_type, operator_type])
        
        # Get predictions
        type_out, digit_out, operator_out = combined_net(inputs)
        
        # Calculate accuracies
        _, predicted_type = torch.max(type_out.data, 1)
        correct_type += (predicted_type == type_labels).sum().item()
        
        _, predicted_digit = torch.max(digit_out[:len(digit_type)].data, 1)
        correct_digit += (predicted_digit == digit_labels.to(device)).sum().item()
        
        _, predicted_operator = torch.max(operator_out[len(digit_type):].data, 1)
        correct_operator += (predicted_operator == operator_labels.to(device)).sum().item()
        
        total += type_labels.size(0)

    # Calculate final accuracies
    test_type_accuracy = 100 * correct_type / total
    test_digit_accuracy = 100 * correct_digit / (total/2)
    test_operator_accuracy = 100 * correct_operator / (total/2)
    
    print("\nFinal Test Results:")
    print(f"Type Classification Accuracy: {test_type_accuracy:.2f}%")
    print(f"Digit Recognition Accuracy: {test_digit_accuracy:.2f}%")
    print(f"Operator Recognition Accuracy: {test_operator_accuracy:.2f}%")

print("\nTesting complete!")



Performing final test evaluation...

Final Test Results:
Type Classification Accuracy: 99.16%
Digit Recognition Accuracy: 98.75%
Operator Recognition Accuracy: 98.75%

Testing complete!


In [36]:
# Second: Add threshold-based evaluation
def predict_with_threshold(image, type_threshold=0.8):
    """Predict with confidence threshold"""
    with torch.no_grad():
        # Get predictions
        type_out, digit_out, operator_out = combined_net(image)
        
        # Get all probabilities
        type_prob = F.softmax(type_out, dim=1)
        digit_prob = F.softmax(digit_out, dim=1)
        operator_prob = F.softmax(operator_out, dim=1)
        
        type_conf, type_pred = torch.max(type_prob, 1)
        digit_conf, digit_pred = torch.max(digit_prob, 1)
        operator_conf, operator_pred = torch.max(operator_prob, 1)
        
        # Only trust type prediction if confident enough
        if type_conf >= type_threshold:
            if type_pred == 0:  # Digit
                return 'digit', digit_pred.item(), digit_conf.item()
            else:  # Operator
                return 'operator', operator_pred.item(), operator_conf.item()
        else:
            # If not confident about type, use highest confidence
            if digit_conf > operator_conf:
                return 'digit', digit_pred.item(), digit_conf.item()
            else:
                return 'operator', operator_pred.item(), operator_conf.item()

print("\nPerforming threshold-based evaluation...")
# Test with different thresholds
thresholds = [0.7, 0.8, 0.9]
for threshold in thresholds:
    print(f"\nTesting with threshold {threshold}:")
    correct_predictions = 0
    total_predictions = 0
    
    with torch.no_grad():
        # Test digits
        for images, labels in digit_testloader:
            images = images.to(device)
            labels = labels.to(device)
            for i in range(len(images)):
                pred_type, pred_class, conf = predict_with_threshold(
                    images[i].unsqueeze(0),  # Add batch dimension
                    threshold
                )
                if pred_type == 'digit' and pred_class == labels[i].item():
                    correct_predictions += 1
                total_predictions += 1
        
        # Test operators
        for images, labels in operator_testloader:
            images = images.to(device)
            labels = labels.to(device)
            for i in range(len(images)):
                pred_type, pred_class, conf = predict_with_threshold(
                    images[i].unsqueeze(0),  # Add batch dimension
                    threshold
                )
                if pred_type == 'operator' and pred_class == labels[i].item():
                    correct_predictions += 1
                total_predictions += 1
    
    accuracy = 100 * correct_predictions / total_predictions
    print(f"Overall Accuracy with threshold {threshold}: {accuracy:.2f}%")

print("\nThreshold-based testing complete!")


Performing threshold-based evaluation...

Testing with threshold 0.7:
Overall Accuracy with threshold 0.7: 98.22%

Testing with threshold 0.8:
Overall Accuracy with threshold 0.8: 98.32%

Testing with threshold 0.9:
Overall Accuracy with threshold 0.9: 98.32%

Threshold-based testing complete!
