# In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader




########################################################################
# ToTensor() converts each PIL image to a float tensor in the range [0, 1];
# Normalize((0.5,), (0.5,)) then maps it to the range [-1, 1].
#
# Both networks below flatten a 16 x 5 x 5 feature map after two 5x5
# convolutions and two 2x2 max-pools, which only matches 32 x 32 inputs,
# so every image is resized to 32 x 32 before conversion.
IMAGE_SIZE = (32, 32)

transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),  # convert to one channel
    transforms.Resize(IMAGE_SIZE),  # match the networks' expected input size
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),  # single-channel normalization
])

# Create two separate datasets for digits and operators
digit_dataset = torchvision.datasets.ImageFolder(
    root='./data/digits',
    transform=transform,
)

operator_dataset = torchvision.datasets.ImageFolder(
    root='./data/operators',
    transform=transform,
)

# Create separate dataloaders
digit_trainloader = DataLoader(digit_dataset, batch_size=32,
                               shuffle=True, num_workers=2)
operator_trainloader = DataLoader(operator_dataset, batch_size=32,
                                  shuffle=True, num_workers=2)

# Class names for each network, indexed by label.
# ImageFolder assigns label i to the i-th class folder in *sorted* order, so
# the names must come from the dataset itself. A hand-written tuple such as
# ('+', '-', '*', '/', '(', ')') does not follow sorted order and would attach
# the wrong name to operator labels. The names are the folder names verbatim;
# if the folders use words (e.g. 'plus'), map them to symbols where needed.
digit_classes = tuple(digit_dataset.classes)
operator_classes = tuple(operator_dataset.classes)

########################################################################
# Let us show some of the training images, for fun.



# functions to show an image
def imshow(img):
    """Display a normalized image tensor with matplotlib.

    Reverses ``Normalize((0.5,), (0.5,))`` (mapping [-1, 1] back to [0, 1])
    and draws the result on the current axes.

    Args:
        img: image tensor shaped (C, H, W). C == 1 is drawn as a 2-D
            grayscale image; otherwise channels are moved last for matplotlib.
            The tensor may live on any device.
    """
    img = img / 2 + 0.5     # unnormalize
    npimg = img.detach().cpu().numpy()  # matplotlib needs a host-side array
    if npimg.shape[0] == 1:
        # matplotlib rejects (H, W, 1) arrays; draw single-channel data as 2-D.
        # Fixed vmin/vmax keep the [0, 1] range instead of auto-rescaling.
        plt.imshow(npimg[0], cmap='gray', vmin=0.0, vmax=1.0)
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))


# get some random training images
dataiter = iter(digit_trainloader)
# DataLoader iterators no longer provide a .next() method; use the builtin.
images, labels = next(dataiter)

# show images
imshow(torchvision.utils.make_grid(images))
# print one label per image shown in the grid (the batch may hold fewer than 4)
print(' '.join('%5s' % digit_classes[label] for label in labels.tolist()))


########################################################################
# 2. Define the Convolutional Neural Networks
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# Two small convolutional networks with the same architecture: one classifies
# digits, the other classifies operators. Both take 1-channel (grayscale)
# images, matching the Grayscale transform applied to the datasets.

import torch.nn as nn
import torch.nn.functional as F


class DigitNet(nn.Module):
    """Small CNN that maps a 1-channel image to ``num_classes`` digit scores.

    Two 5x5 convolutions, each followed by a 2x2 max-pool, reduce a 32 x 32
    input to a 16 x 5 x 5 feature map, so inputs must be (N, 1, 32, 32).

    Args:
        num_classes: number of output logits (default 10, one per digit).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)  # 1 input channel (grayscale)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Return raw class scores (logits) shaped (N, num_classes)."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten per sample. view(-1, 16*5*5) would silently fold a wrongly
        # sized feature map into a different batch size; keeping dim 0 makes
        # fc1 raise a shape error instead.
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

class OperatorNet(nn.Module):
    """Small CNN that maps a 1-channel image to ``num_classes`` operator scores.

    Same architecture as the digit network: two 5x5 convolutions, each
    followed by a 2x2 max-pool, reduce a 32 x 32 input to a 16 x 5 x 5
    feature map, so inputs must be (N, 1, 32, 32).

    Args:
        num_classes: number of output logits (default 6 operators).
    """

    def __init__(self, num_classes=6):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)  # 1 input channel (grayscale)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Return raw class scores (logits) shaped (N, num_classes)."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten per sample. view(-1, 16*5*5) would silently fold a wrongly
        # sized feature map into a different batch size; keeping dim 0 makes
        # fc1 raise a shape error instead.
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

# Prefer the first CUDA GPU when one is visible; otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# One network per symbol family, each moved onto the selected device.
digit_net = DigitNet().to(device)
operator_net = OperatorNet().to(device)

# Each network gets its own cross-entropy criterion and its own optimizer.
digit_criterion = nn.CrossEntropyLoss()
operator_criterion = nn.CrossEntropyLoss()

# SGD with momentum; both networks use the same hyperparameters.
sgd_settings = dict(lr=0.001, momentum=0.9)
digit_optimizer = optim.SGD(digit_net.parameters(), **sgd_settings)
operator_optimizer = optim.SGD(operator_net.parameters(), **sgd_settings)


########################################################################
# 4. Train the network
# ^^^^^^^^^^^^^^^^^^^^
#
# This is when things start to get interesting.
# We simply have to loop over our data iterator, and feed the inputs to the
# network and optimize.


def train_one_epoch(net, loader, criterion, optimizer, tag, epoch, log_every=500):
    """Run one optimization pass of ``net`` over ``loader``.

    Every ``log_every`` mini-batches, the mean loss of those batches is
    printed as ``[<tag> - Epoch <n>, Batch <i>] loss: <x>``.

    Args:
        net: network to train; batches are moved to the device of its parameters.
        loader: iterable of ``(inputs, labels)`` mini-batches.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``net``'s parameters.
        tag: label used in progress messages.
        epoch: zero-based epoch index (printed one-based).
        log_every: number of mini-batches between progress messages.

    Returns:
        Mean loss over all mini-batches of the epoch, or NaN if ``loader`` is empty.
    """
    net_device = next(net.parameters()).device
    net.train()  # make sure training-mode behavior is active
    running_loss = 0.0
    total_loss = 0.0
    num_batches = 0
    for i, (inputs, labels) in enumerate(loader):
        inputs, labels = inputs.to(net_device), labels.to(net_device)

        optimizer.zero_grad()
        loss = criterion(net(inputs), labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        total_loss += batch_loss
        num_batches += 1
        if i % log_every == log_every - 1:    # print every `log_every` mini-batches
            print('[%s - Epoch %d, Batch %5d] loss: %.3f' %
                  (tag, epoch + 1, i + 1, running_loss / log_every))
            running_loss = 0.0
    return total_loss / num_batches if num_batches else float('nan')


for epoch in range(2):  # loop over the dataset multiple times
    for tag, net, loader, criterion, optimizer in (
        ('Digit Net', digit_net, digit_trainloader, digit_criterion, digit_optimizer),
        ('Operator Net', operator_net, operator_trainloader,
         operator_criterion, operator_optimizer),
    ):
        mean_loss = train_one_epoch(net, loader, criterion, optimizer, tag, epoch)
        # Always report the epoch mean: with fewer than `log_every` mini-batches
        # the periodic message inside train_one_epoch never fires.
        print('[%s - Epoch %d] mean loss: %.3f' % (tag, epoch + 1, mean_loss))

    

digit_testloader = DataLoader(
    torchvision.datasets.ImageFolder(root='./data/test/digits', transform=transform),
    batch_size=32, shuffle=False, num_workers=2,
)

operator_testloader = DataLoader(
    torchvision.datasets.ImageFolder(root='./data/test/operators', transform=transform),
    batch_size=32, shuffle=False, num_workers=2,
)

print('Finished Training')

# After training: persist the weights. Saving must come before any reload.
# Loading first would raise FileNotFoundError on a fresh run, and otherwise
# would replace the freshly trained weights with a stale checkpoint.
torch.save(digit_net.state_dict(), 'digit_net.pth')
torch.save(operator_net.state_dict(), 'operator_net.pth')

# To load later. map_location lets a checkpoint saved on a GPU load on a
# CPU-only machine (and vice versa).
digit_net.load_state_dict(torch.load('digit_net.pth', map_location=device))
operator_net.load_state_dict(torch.load('operator_net.pth', map_location=device))


########################################################################
# 5. Test the network on the test data
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# We have trained the network for 2 passes over the training dataset.
# But we need to check if the network has learnt anything at all.
#
# We will check this by predicting the class label that the neural network
# outputs, and checking it against the ground-truth. If the prediction is
# correct, we add the sample to the list of correct predictions.
#
# Okay, first step. Let us display a batch of images from the test set to get familiar.

dataiter = iter(digit_testloader)
# DataLoader iterators no longer provide a .next() method; use the builtin.
images, labels = next(dataiter)

# print images
imshow(torchvision.utils.make_grid(images))
# one label per image in the grid (the batch may hold fewer than 4)
print('GroundTruth: ', ' '.join('%5s' % digit_classes[label] for label in labels.tolist()))

########################################################################
# Okay, now let us see what the neural network thinks these examples above are.
# The network lives on ``device``, so the batch is moved there too (the CPU
# copy in ``images`` is kept for display). No gradients are needed here.

digit_net.eval()
with torch.no_grad():
    outputs = digit_net(images.to(device))

########################################################################
# The outputs are scores (logits) for the 10 digit classes.
# The higher the score for a class, the more the network
# thinks that the image belongs to that class.
# So, let's get the index of the highest score:
_, predicted = torch.max(outputs, 1)

print('Predicted: ', ' '.join('%5s' % digit_classes[idx]
                              for idx in predicted.tolist()))

########################################################################
# Compare the predicted labels with the ground truth above to get a first
# impression of how well the network does.
#
# Let us look at how the network performs on the whole test dataset.

def evaluate_accuracy(net, loader):
    """Count correct top-1 predictions of ``net`` over ``loader``.

    ``net`` is switched to eval mode. Batches are moved to the device of its
    parameters and evaluated without gradient tracking.

    Returns:
        A ``(correct, total)`` tuple of sample counts.
    """
    net_device = next(net.parameters()).device
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(net_device), labels.to(net_device)
            _, predicted = torch.max(net(images), 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct, total


for name, net, loader in (('digit', digit_net, digit_testloader),
                          ('operator', operator_net, operator_testloader)):
    correct, total = evaluate_accuracy(net, loader)
    if total == 0:
        # Avoid ZeroDivisionError when a test folder holds no images.
        print('No %s test images found.' % name)
        continue
    # Report the real sample count rather than a hard-coded figure.
    print('Accuracy of the %s network on the %d test images: %d %%' % (
        name, total, 100 * correct / total))

def predict_symbol(image, model, classes, device=None):
    """Classify one image and return the name of the predicted class.

    Args:
        image: a single image tensor without a batch dimension, e.g.
            (C, H, W); a batch dimension of size 1 is added before calling
            ``model``.
        model: classifier returning one row of class scores per input.
        classes: sequence mapping a label index to its class name.
        device: device to run inference on. Defaults to the device of the
            model's first parameter, or the image's device if the model has
            no parameters.

    Returns:
        ``classes[i]``, where ``i`` is the index of the highest score.

    The model is put in eval mode for the prediction. Its previous
    train/eval mode is restored afterwards, so a caller that is still
    training is not silently switched to eval mode.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else image.device
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            outputs = model(image.unsqueeze(0).to(device))
    finally:
        model.train(was_training)
    _, predicted = torch.max(outputs, 1)
    return classes[predicted[0].item()]

def process_math_expression(image_path):
    """Recognize and evaluate the handwritten expression stored at *image_path*.

    Not implemented yet: this stub always returns ``None``.

    Planned steps (TODO):
        1. Load and preprocess the image.
        2. Segment it into individual symbols.
        3. Classify each symbol with the appropriate network.
        4. Combine the symbols into a mathematical expression.
        5. Evaluate the expression.
    """
    # TODO: implement the pipeline above.
    # NOTE(review): for step 5, parse the recognized expression rather than
    # passing it to eval(), since its contents come from model output.
    return None


########################################################################
# Okay, so what next?
#
# How do we run these neural networks on the GPU?
#
# Training on GPU
# ----------------
# Just as you transfer a tensor onto the GPU with ``.to(device)``, you
# transfer a neural network the same way; this script already does so for
# ``digit_net`` and ``operator_net`` when they are created.
#
# The device is the first visible CUDA device if CUDA is available, and the
# CPU otherwise:

# Pick the first visible CUDA GPU when present; fall back to the CPU otherwise.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# On a CUDA machine this shows a CUDA device; otherwise it shows the CPU.
print(device)