# LAB2

In Lab 2, we will guide you through the basics of scallopy, the Python API for Scallop.
In this tutorial, you will learn:
1. How to construct and execute a Scallop program in Python using scallopy.
2. How to perform a learning task in Python through scallopy.

## Hello Scallopy

Let's write our first hello world Python program. You can do this purely through the scallopy interface. 

In [None]:
# We start from constructing a scallopy context
import scallopy
ctx = scallopy.ScallopContext()

# We declare a relation type using 'add_relation'. 
# This is equivalent to 'type hello(String)' in a .scl file
ctx.add_relation("hello", str)

# We add the fact hello("Hello World") to the scallopy context 
ctx.add_facts("hello", [("Hello World",)])

# We can execute the context through 'run'
ctx.run()
print(list(ctx.relation("hello")))


## MNIST
<div>
  <img src="img/mnist_example.png" width="300"/>
</div>

### P1: Count 2
Let's first construct the symbolic representation of the MNIST image as input facts of the form `digit(i, d)`, where `i` is the image id and `d` is the numerical value of the corresponding image.
Write a rule that counts how many `2`s there are in the image.
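
As a sanity check for the rule, the expected count can be computed in plain Python first. The sketch below uses a made-up list of `(image_id, digit)` pairs standing in for the `digit(i, d)` facts; the values are illustrative assumptions and are not read from the figure.

```python
# Hypothetical digit facts as (image_id, digit_value) pairs.
# These values are made up for illustration only.
digit_facts = [(0, 1), (1, 2), (2, 2), (3, 7), (4, 2)]

# Count the image ids whose digit value is 2. A correct rule for
# `num_of_2` should derive the same number from the same facts.
num_of_2 = sum(1 for (_, d) in digit_facts if d == 2)
print(num_of_2)
```

Comparing this number against the content of the `num_of_2` relation is a quick way to validate the rule.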

In [None]:
import os
import scallopy

ctx = scallopy.ScallopContext()

# TODO: Write the discrete input facts and rule here

ctx.run()
print(list(ctx.relation("num_of_2")))

### P2: Probabilistic Less than 5
Let's now write a probabilistic symbolic representation of the MNIST images as input facts of the form `digit(i, d)`, where `i` is the image id and `d` is the numerical value of the corresponding image.
The probabilities for the input facts should be randomly generated, ranging from 0 to 1.
Further, for each image, the probabilities over all ten digits (0-9) must sum to 1.
Write a rule that counts how many digits in the image are less than 5.
Use the "minmaxprob" semiring to compute the query result with its probability.
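
One possible way to generate such inputs is sketched below in plain Python: draw ten random weights per image and normalize them so they sum to 1. The helper names are hypothetical, and the `(probability, (image_id, digit))` layout is an assumption about how tagged facts are handed to `add_facts`; please check it against the scallopy documentation before relying on it.

```python
import random

def random_digit_distribution(rng):
    """Return 10 non-negative weights, one per digit 0-9, that sum to 1."""
    # The small offset guards against the (unlikely) all-zero draw.
    raw = [rng.random() + 1e-9 for _ in range(10)]
    total = sum(raw)
    return [w / total for w in raw]

def make_probabilistic_facts(num_images, seed=0):
    """Build facts shaped as (probability, (image_id, digit))."""
    rng = random.Random(seed)
    facts = []
    for image_id in range(num_images):
        probs = random_digit_distribution(rng)
        facts.extend((p, (image_id, d)) for d, p in enumerate(probs))
    return facts

# Three hypothetical images, ten candidate digits each.
facts = make_probabilistic_facts(num_images=3)
print(len(facts))
```

Normalizing per image keeps every probability in the range 0 to 1 and satisfies the sum-to-1 requirement stated above.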

### P3: MNIST Sum 3
In this practice, we will use scallopy to train an MNIST digit recognition network. Given three MNIST numbers and their sum, we want to train a classifier that can identify the digits, and yields a correct sum of the input images.

**Step 1** Dataloader construction. 

First, we want to construct a train data loader, and a test data loader separately. 
Please fill in the `__getitem__` and `collate_fn` methods for the dataloader.
The `__getitem__` method shall take in an index and return a tuple. The first tuple element is a triplet of tensorized images, and the second tuple element is the sum of the three digits.
The `collate_fn` method shall take in a list of tuples returned by `__getitem__`, and return a tuple. The first tuple element is a triplet of batched tensors representing the images, and the second element is a tensor of batched sum values.
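
The data flow described above can be sketched without PyTorch, using strings as stand-ins for image tensors and plain lists as stand-ins for batched tensors. The function names and the rule of reading three consecutive entries of the shuffled index map are assumptions made for illustration; in the real dataset you would stack tensors (for example with `torch.stack`) instead of building lists.

```python
def triplet_item(samples, index_map, idx):
    """Return ((img_a, img_b, img_c), sum_of_labels) for the idx-th triplet."""
    # Assumption: triplet idx uses entries 3*idx, 3*idx+1, 3*idx+2 of index_map.
    picks = [samples[index_map[idx * 3 + offset]] for offset in range(3)]
    images = tuple(img for img, _ in picks)
    total = sum(label for _, label in picks)
    return images, total

def collate(batch):
    """Regroup [((a, b, c), s), ...] into ((all_a, all_b, all_c), all_s)."""
    a_imgs = [imgs[0] for imgs, _ in batch]
    b_imgs = [imgs[1] for imgs, _ in batch]
    c_imgs = [imgs[2] for imgs, _ in batch]
    sums = [s for _, s in batch]
    return (a_imgs, b_imgs, c_imgs), sums

# Toy data: (image placeholder, digit label) pairs and a fixed "shuffled" index map.
samples = [("img0", 3), ("img1", 5), ("img2", 1), ("img3", 9), ("img4", 0), ("img5", 4)]
index_map = [4, 0, 2, 5, 1, 3]

item0 = triplet_item(samples, index_map, 0)
item1 = triplet_item(samples, index_map, 1)
batch = collate([item0, item1])
print(batch)
```

Grouping by position (all first images together, all second images together, and so on) is what lets one digit classifier process each position as a single batch.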

In [None]:
import os
import random
from typing import *
import torch
import torchvision

mnist_img_transform = torchvision.transforms.Compose([
  torchvision.transforms.ToTensor(),
  torchvision.transforms.Normalize(
    (0.1307,), (0.3081,)
  )
])

class MNISTSum3Dataset(torch.utils.data.Dataset):
  def __init__(
    self,
    root: str,
    train: bool = True,
    transform: Optional[Callable] = None,
    target_transform: Optional[Callable] = None,
    download: bool = False,
  ):
    # Contains a MNIST dataset
    self.mnist_dataset = torchvision.datasets.MNIST(
      root,
      train=train,
      transform=transform,
      target_transform=target_transform,
      download=download,
    )
    self.index_map = list(range(len(self.mnist_dataset)))
    random.shuffle(self.index_map)

  def __len__(self):
    return int(len(self.mnist_dataset) / 3)

  # The `__getitem__` method shall take in an index and return a tuple.
  # The first tuple element is a triplet of tensorized images,
  # and the second tuple element is the sum of the three digits.
  def __getitem__(self, idx):
    # TODO: Complete the __getitem__ method
    raise NotImplementedError

  # The `collate_fn` method shall take in a list of tuples returned by `__getitem__`,
  # and return a tuple. The first tuple element is a triplet of batched tensors
  # representing the images, and the second element is a tensor of batched sum values.
  @staticmethod
  def collate_fn(batch):
    # TODO: complete the collate_fn method
    raise NotImplementedError

def mnist_sum_3_loader(data_dir, batch_size_train, batch_size_test):

  train_loader = torch.utils.data.DataLoader(
    MNISTSum3Dataset(
      data_dir,
      train=True,
      download=True,
      transform=mnist_img_transform,
    ),
    collate_fn=MNISTSum3Dataset.collate_fn,
    batch_size=batch_size_train,
    shuffle=True
  )

  test_loader = torch.utils.data.DataLoader(
    MNISTSum3Dataset(
      data_dir,
      train=False,
      download=True,
      transform=mnist_img_transform,
    ),
    collate_fn=MNISTSum3Dataset.collate_fn,
    batch_size=batch_size_test,
    shuffle=True
  )

  return train_loader, test_loader


You can take a look at the dataset with matplotlib.

In [None]:
import matplotlib.pyplot as plt
import os
import random
import torch

# Feel free to modify the parameters below
seed = 1234
batch_size_train = 64
batch_size_test = 64

torch.manual_seed(seed)
random.seed(seed)
data_dir = os.path.abspath(os.path.join(os.path.abspath("__file__"), "../data"))
train_loader, test_loader = mnist_sum_3_loader(data_dir, batch_size_train, batch_size_test)

# Let's take a look at the dataset
# Note: len(train_loader) is the number of batches, not the number of samples
print(f"The number of training batches is: {len(train_loader)}.")
for (x, y) in train_loader:
    # The dataloader will give you batches of three MNIST images and their sum 
    (a_imgs, b_imgs, c_imgs), digits = (x, y)
    print(a_imgs.shape)

    # We can peek at the first MNIST image in the batch
    imgplot = plt.imshow(a_imgs[0].reshape(28, 28), cmap='gray')
    plt.show()
    break

**Step 2** Construct a classifier `MNISTNet` that takes in an MNIST image and returns a tensor of probabilities, one for each digit from 0 to 9. Here is a link to a tutorial: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html

In [1]:
from torch import nn
import torch.nn.functional as F

class MNISTNet(nn.Module):
  def __init__(self):
    super(MNISTNet, self).__init__()
    # TODO:  Complete the __init__ function
    raise NotImplementedError

  def forward(self, x):
    # TODO: Complete the forward function
    raise NotImplementedError
  

**Step 3** Construct a classifier `MNISTSum3Net` that takes in three MNIST images and returns a tensor of the distribution of their sum over 0 to 27.
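
To see why the output has 28 entries, note that three digits in 0-9 can sum to anything from 0 to 27. The sketch below computes the exact distribution of the sum from three digit distributions in plain Python, assuming the three predictions are independent. It is a reference computation for intuition only, not the Scallop reasoning module; a provenance such as `difftopkproofs` may only approximate this result.

```python
def sum_distribution(p_a, p_b, p_c):
    """Exact distribution of a + b + c for independent digit distributions."""
    out = [0.0] * 28  # possible sums: 0 .. 27
    for a, pa in enumerate(p_a):
        for b, pb in enumerate(p_b):
            for c, pc in enumerate(p_c):
                out[a + b + c] += pa * pb * pc
    return out

def one_hot(d):
    """Distribution that puts all probability on digit d."""
    return [1.0 if i == d else 0.0 for i in range(10)]

# Two confident predictions (3 and 4) and one completely uncertain prediction.
uniform = [0.1] * 10
dist = sum_distribution(one_hot(3), one_hot(4), uniform)
print(len(dist), round(sum(dist), 6))
```

With the first two digits fixed at 3 and 4, the probability mass is spread evenly over the sums 7 through 16.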

In [None]:
import scallopy
class MNISTSum3Net(nn.Module):
  def __init__(self, provenance, k):
    super(MNISTSum3Net, self).__init__()
    # TODO: Initialize the neural network here. It should include:
    #       1. MNISTNet
    #       2. Scallop program
    #       3. Forward function
    raise NotImplementedError

  def forward(self, x):
    # TODO: Write the forward function for MNISTSum3Net
    # Then execute the reasoning module; the expected return value is a size 28 tensor
    raise NotImplementedError

**Step 4** Set up the trainer. We will use the BCE loss function for training the model.
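
For intuition, the loss for a single sample can be written out in plain Python: the target sum is turned into a one-hot vector, and the binary cross-entropy is averaged over all output positions. This is a minimal sketch with hypothetical helper names; the clamp of the log terms at -100 mirrors the documented behavior of PyTorch's BCE loss, and the three-entry output is a shortened stand-in for the 28-entry tensor.

```python
import math

def one_hot(target, dim):
    """One-hot encoding of an integer target as a list of floats."""
    return [1.0 if i == target else 0.0 for i in range(dim)]

def bce(output, target):
    """Mean binary cross-entropy between `output` and the one-hot target."""
    gt = one_hot(target, len(output))
    terms = []
    for p, y in zip(output, gt):
        # Clamp log values at -100 so that p == 0 or p == 1 stays finite.
        log_p = max(math.log(p), -100.0) if p > 0 else -100.0
        log_q = max(math.log(1.0 - p), -100.0) if p < 1 else -100.0
        terms.append(-(y * log_p + (1.0 - y) * log_q))
    return sum(terms) / len(terms)

# Toy output over three classes, with class 1 as the ground truth.
loss = bce([0.1, 0.7, 0.2], target=1)
print(round(loss, 4))
```

The loss rewards high probability on the correct sum and penalizes probability placed on every other sum.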

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm

def bce_loss(output, ground_truth):
  (_, dim) = output.shape
  gt = torch.stack([torch.tensor([1.0 if i == t else 0.0 for i in range(dim)]) for t in ground_truth])
  return F.binary_cross_entropy(output, gt)

class Trainer():
  def __init__(self, train_loader, test_loader, learning_rate, k, provenance):
    self.network = MNISTSum3Net(provenance, k)
    self.optimizer = optim.Adam(self.network.parameters(), lr=learning_rate)
    self.train_loader = train_loader
    self.test_loader = test_loader
    self.loss = bce_loss

  def train_epoch(self, epoch):
    self.network.train()
    iter = tqdm(self.train_loader, total=len(self.train_loader))
    train_loss = 0
    correct = 0
    total = 0
    for data_ct, (data, target) in enumerate(iter):
      self.optimizer.zero_grad()
      output = self.network(data)

      loss = self.loss(output, target)
      loss.backward()
      self.optimizer.step()
      train_loss += loss.item()

      pred = output.data.max(1, keepdim=True)[1]
      correct += pred.eq(target.data.view_as(pred)).sum().item()
      total += pred.shape[0]
      perc = 100. * correct / total
      avg_loss = train_loss / (data_ct + 1)
      iter.set_description(f"[Train Epoch {epoch}] Total loss: {avg_loss:.4f}, Accuracy: {correct}/{total} ({perc:.2f}%)")

  def test(self, epoch):
    self.network.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
      iter = tqdm(self.test_loader, total=len(self.test_loader))
      for data_ct, (data, target) in enumerate(iter):
        output = self.network(data)
        test_loss += self.loss(output, target).item()
        pred = output.data.max(1, keepdim=True)[1]
        correct += pred.eq(target.data.view_as(pred)).sum().item()
        total += pred.shape[0]
        perc = 100. * correct / total
        avg_loss = test_loss / (data_ct + 1)
        iter.set_description(f"[Test Epoch {epoch}] Total loss: {avg_loss:.4f}, Accuracy: {correct}/{total} ({perc:.2f}%)")

  def train(self, n_epochs):
    self.test(0)
    for epoch in range(1, n_epochs + 1):
      self.train_epoch(epoch)
      self.test(epoch)


**Step 5** Train the model, and see the performance. :)

In [None]:
# Feel free to modify the parameters here
n_epochs=3
learning_rate=0.001
provenance="difftopkproofs"
k=3

trainer = Trainer(train_loader, test_loader, learning_rate, k, provenance)
trainer.train(n_epochs)

**Step 6** 
Let's plot the confusion matrix for the neural network, and check the performance for single-digit recognition.
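
As a reminder of what the matrix contains, here is a minimal plain-Python sketch that tallies a confusion matrix by hand, with rows indexed by the actual label and columns by the predicted label (the same orientation used in the plot). The function name and the toy labels are made up for illustration.

```python
def confusion_counts(y_true, y_pred, num_classes):
    """Return matrix[t][p] = number of samples with true label t predicted as p."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        matrix[t][p] += 1
    return matrix

# Toy labels over three classes; one sample of class 1 is mistaken for class 2.
toy_true = [0, 1, 2, 2, 1]
toy_pred = [0, 2, 2, 2, 1]
toy_cm = confusion_counts(toy_true, toy_pred, num_classes=3)
for row in toy_cm:
    print(row)
```

Off-diagonal entries show which digits the network confuses with each other.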

In [None]:
from sklearn.metrics import confusion_matrix
import numpy
import seaborn as sn
import pandas as pd

diagnose_batch_size = 32
mnist_diagnose_dataset = torchvision.datasets.MNIST(data_dir, train=False, download=True, transform=mnist_img_transform)
mnist_loader = torch.utils.data.DataLoader(mnist_diagnose_dataset, batch_size=diagnose_batch_size)

# Get prediction result
y_true, y_pred = [], []
with torch.no_grad():
    for (imgs, digits) in mnist_loader:
        pred_digits = torch.argmax(trainer.network.mnist_net(imgs), dim=1)
        y_true += digits.tolist()
        y_pred += pred_digits.tolist()

# Compute confusion matrix
cm = confusion_matrix(y_true, y_pred)

df_cm = pd.DataFrame(cm, index=list(range(10)), columns=list(range(10)))
plt.figure(figsize=(10,7))
sn.heatmap(df_cm, annot=True, cmap=plt.cm.Blues)
plt.ylabel("Actual")
plt.xlabel("Predicted")
plt.show()

### P4. MNIST Sort 2
In this practice, we will learn MNIST digit recognition through the sort-2 task. The task takes in two MNIST digits and returns 0 if the first digit is smaller than the second digit; otherwise, it returns 1.
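
The labeling rule, and the output distribution it induces from two digit distributions, can be sketched in plain Python as follows. This assumes the two digit predictions are independent, and the function names are hypothetical; it is a reference computation for intuition, not the Scallop program you are asked to write.

```python
def sort2_label(a, b):
    """Return 0 if the first digit is smaller than the second, otherwise 1."""
    return 0 if a < b else 1

def sort2_distribution(p_a, p_b):
    """Probability of label 0 and label 1 for independent digit distributions."""
    out = [0.0, 0.0]
    for a, pa in enumerate(p_a):
        for b, pb in enumerate(p_b):
            out[sort2_label(a, b)] += pa * pb
    return out

# With two uniform predictions, 45 of the 100 digit pairs satisfy a < b.
uniform = [0.1] * 10
sort2_dist = sort2_distribution(uniform, uniform)
print([round(p, 2) for p in sort2_dist])
```

Note that ties (equal digits) fall under label 1, which is why the two labels are not equally likely for uniform predictions.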

**Step 1** Dataloader construction. 

First, we want to construct a train data loader, and a test data loader separately. 
The `__getitem__` method shall take in an index and return a tuple. The first tuple element is a pair of tensorized images, and the second tuple element is 0 or 1.
The `collate_fn` method shall take in a list of tuples returned by `__getitem__`, and return a tuple. The first tuple element is a pair of batched tensors representing the images, and the second element is a tensor of batched 0 or 1 labels.

In [None]:

# TODO: Implement the MNISTSort2Dataset and the Dataloaders

**Step 2** Construct a classifier `MNISTSort2Net` that takes in two MNIST images and returns a tensor of the distribution over 0 and 1. You can utilize the previously defined class `MNISTNet`.

In [None]:

# TODO: Implement the MNISTSort2Net

**Step 3** Set up the trainer and loss function. We will use the BCE loss function for training the model.

In [2]:
# TODO: Implement the loss function and Trainer

**Step 4** Train the model with different extended provenance semirings and check the results.

In [3]:
# TODO: Perform model training

**Step 5** 
Please plot the confusion matrix for the neural network with different extended provenance semiring setups.
1. diffminmaxprob
2. difftopkproofs with k = 3
3. difftopkproofs with k = 10

In [None]:
# TODO: Perform error analysis using confusion matrix 
# using the three extended provenance semirings mentioned above.