Steps to take for MNIST Test:



1.   Train an MNIST encoder/decoder to accurately reconstruct images
2.   Test the MNIST encoder's accuracy
3.   Implement the triplet function in PyTorch
4.   Create a loss function based on the triplet function
5.   Simulate researcher feedback by generating triplet labels from fixed metacriteria
6.   Repeat steps 1-5 until the triplet labels improve the encoder's accuracy



Questions:


*   Why do we need to reproduce images? Would testing based on prediction accuracy be enough to show results?
  
  * Hud: I was thinking about this some more. For MNIST, you can probably skip the pre-training step because we have plenty of data, and go straight to training on triplet feedback.
  * Hud: we will need pre-training in the real-world scenario because we expect to get only a limited set of labels from our researchers.
*   Will the retraining after feedback be solely based on feedback labels? Or also based on the triplet probability function? Or also based on 

  * Hud: training will involve both the model's confidence prediction for an input triplet (which involves the triplet probability distribution) and the ground-truth labels (which come from the feedback labels).

      1. Pass each sample in the triplet through the encoder (this could even be a simple linear encoder for the first attempt).
      2. Input the encodings into the triplet probability function to compute the model's confidence that the triplet evaluates to "True".
      3. Using the output of step 2 and the known True/False value for the triplet, compute the binary cross-entropy loss.
    * This process is repeated for all of the training samples. 
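
Steps 1-3 can be sketched as follows. The notebook does not yet define the triplet probability function, so this assumes a logistic function of the squared-distance gap, P(True) = sigmoid(‖z_A − z_C‖² − ‖z_A − z_B‖²), which is above 0.5 when A's encoding is closer to B's than to C's. The names `triplet_logit`, `triplet_probability`, and `triplet_bce_loss` are hypothetical, and random tensors stand in for MNIST images.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Step 1 encoder: a simple linear map 784 -> 10 (same shape as NetFlat.encoder)
encoder = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))


def triplet_logit(z_a, z_b, z_c):
    # Hypothetical score (assumption): positive when A is closer to B than to C.
    d_ab = (z_a - z_b).pow(2).sum(dim=1)  # squared distance A-B, shape (batch,)
    d_ac = (z_a - z_c).pow(2).sum(dim=1)  # squared distance A-C, shape (batch,)
    return d_ac - d_ab


def triplet_probability(z_a, z_b, z_c):
    # Step 2: model confidence that the triplet evaluates to True.
    return torch.sigmoid(triplet_logit(z_a, z_b, z_c))


def triplet_bce_loss(model, a, b, c, target):
    # Step 1: pass each image of the triplet through the same encoder.
    z_a, z_b, z_c = model(a), model(b), model(c)
    # Step 3: binary cross-entropy against the known True/False labels.
    # The *_with_logits form equals F.binary_cross_entropy(triplet_probability(...), target)
    # but stays numerically stable when distance gaps are large.
    return F.binary_cross_entropy_with_logits(triplet_logit(z_a, z_b, z_c), target.float())


# One optimisation step on a random batch of 8 triplets (a stand-in for MNIST images).
torch.manual_seed(0)
a, b, c = (torch.randn(8, 1, 28, 28) for _ in range(3))
target = torch.randint(0, 2, (8,))
optimizer = torch.optim.Adadelta(encoder.parameters(), lr=1.0)

optimizer.zero_grad()
loss = triplet_bce_loss(encoder, a, b, c, target)
loss.backward()
optimizer.step()
print(f"triplet BCE loss: {loss.item():.4f}")
```

Repeating this step over batches of labelled triplets gives the "repeat for all training samples" loop; any other differentiable triplet probability function could replace `triplet_logit` without changing the rest.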
    



# Note on training:

Rather than the data being a single image and the target being its correct digit, the data will be three images and the target will indicate which of the last two images is more similar to the first image.

Create a custom dataset with PyTorch using the 3 images as the data.

In [None]:
# @title Installs and imports

# installs
!pip install torchviz

# Library imports
from __future__ import print_function
import numpy as np
import pandas as pd
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from torchviz import make_dot
import matplotlib.pyplot as plt
from PIL import Image
import math
%matplotlib inline

from skimage import io, transform
from torch.utils.data import Dataset, DataLoader
from torchvision import utils



In [None]:
#@title Model and Training Settings
batch_size=64 #input batch size for training (default: 64)
test_batch_size=1000 #input batch size for testing (default: 1000)
epochs=5 #number of epochs to train (default: 14)
lr=1.0 #learning rate (default: 1.0)
gamma=0.7 #Learning rate step gamma (default: 0.7)
no_cuda=False #disables CUDA training (default: True)
seed=42 #random seed (default: 42)
log_interval=70 #how many batches to wait before logging training status (default: 10)
save_model=False #save the trained model (default: False)

# additional derived settings
use_cuda = not no_cuda and torch.cuda.is_available()
torch.manual_seed(seed)
device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

print("Device:", device)

Device: cpu


# Creating a Custom Dataset
The dataset takes as input:
*   A list of index triplets to use for training
*   The original training set
*   A function for evaluating the metacriteria

And outputs, for each triplet:
*   ((A: image, B: image, C: image), target: bool)
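
A minimal sketch of such a dataset, assuming the base dataset returns `(image_tensor, label)` pairs the way torchvision's MNIST does when given a tensor transform. `TripletDataset` and `same_pair_criteria` are hypothetical names, and a four-item toy list stands in for MNIST so the example runs offline.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class TripletDataset(Dataset):
    """Hypothetical triplet dataset over any (image, label) dataset.

    triplet_indices: iterable of (i, j, k) index triples into base_dataset
    base_dataset:    e.g. MNIST with a ToTensor transform
    criteria:        function(label_a, label_b, label_c) -> bool
    Each item is ((A, B, C), target) with target as a float 0./1. tensor.
    """

    def __init__(self, triplet_indices, base_dataset, criteria):
        # Store plain int tuples so NumPy index arrays also work.
        self.triplet_indices = [tuple(int(i) for i in t) for t in triplet_indices]
        self.base = base_dataset
        self.criteria = criteria

    def __len__(self):
        return len(self.triplet_indices)

    def __getitem__(self, idx):
        i, j, k = self.triplet_indices[idx]
        img_a, lab_a = self.base[i]
        img_b, lab_b = self.base[j]
        img_c, lab_c = self.base[k]
        # Simulated researcher feedback: the metacriteria evaluated on the labels
        target = self.criteria(int(lab_a), int(lab_b), int(lab_c))
        return (img_a, img_b, img_c), torch.tensor(float(target))


def same_pair_criteria(a, b, c):
    # Same rule as the notebook's metacriteria: A and B match, C differs.
    return a == b and a != c


# Toy base dataset: four constant "images" with labels 3, 3, 7, 1.
toy = [(torch.full((1, 28, 28), float(lbl)), lbl) for lbl in [3, 3, 7, 1]]
triplets = TripletDataset([(0, 1, 2), (0, 2, 3)], toy, same_pair_criteria)
loader = DataLoader(triplets, batch_size=2)
(a, b, c), target = next(iter(loader))
print(a.shape, target)  # torch.Size([2, 1, 28, 28]) tensor([1., 0.])
```

To use it with the notebook's `indices`, the base dataset needs tensor images, e.g. `datasets.MNIST('../data', train=True, transform=transforms.ToTensor())`; an un-transformed MNIST dataset returns PIL images, which the default `DataLoader` collation cannot batch.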



In [None]:
# @title Loading the data (instantiating DataLoaders)
# define pytorch dataloaders for training and testing
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                    transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.1307,), (0.3081,))
                    ])),
    batch_size=batch_size, shuffle=True, **kwargs)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.1307,), (0.3081,))
                    ])),
    batch_size=test_batch_size, shuffle=True, **kwargs)

In [None]:
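# Raw MNIST training set (no transform needed: only the labels are used here)
dataset = datasets.MNIST('../data')
# Sample 100 random triplets (A, B, C) of training-set indices
indices = []
for i in range(100):
  indices.append(np.random.randint(0, high=len(dataset.data), size = 3))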

In [None]:
def metacriteria(a, b, c):
  """Simulated researcher feedback: True if A and B share a label and C does not."""
  return a == b and a != c
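
Because the metacriteria requires A and B to share a digit while C differs, uniformly sampled triplets are mostly negative. If the three labels are drawn independently with class frequencies p_d, the expected fraction of True triplets is Σ p_d²(1 − p_d), which is 0.09 for ten equally likely digits. The sketch below (function names are illustrative) checks that formula against random sampling; MNIST classes are only roughly balanced, so 0.09 is an approximation for the real data.

```python
import random


def metacriteria(a, b, c):
    # Same rule as the notebook: A and B share a label, C has a different one.
    return a == b and a != c


def expected_positive_rate(class_probs):
    # P(a == b and c != a) = sum_d p_d * p_d * (1 - p_d) for independent draws.
    return sum(p * p * (1 - p) for p in class_probs)


def simulated_positive_rate(n_triplets, n_classes=10, seed=0):
    # Monte Carlo estimate with uniformly random labels.
    rng = random.Random(seed)
    hits = sum(
        metacriteria(rng.randrange(n_classes), rng.randrange(n_classes), rng.randrange(n_classes))
        for _ in range(n_triplets)
    )
    return hits / n_triplets


print(expected_positive_rate([0.1] * 10))  # ≈ 0.09
print(simulated_positive_rate(100_000))
```

This imbalance may matter once these labels feed a binary cross-entropy loss, since roughly 9 in 10 randomly sampled triplets are negatives.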

In [None]:
def output(indices, dataset):
  """Print each triplet's labels and whether it satisfies the metacriteria."""
  for a, b, c in indices:
    la, lb, lc = dataset.targets[a].item(), dataset.targets[b].item(), dataset.targets[c].item()
    verdict = "fits" if metacriteria(la, lb, lc) else "does not fit"
    print(f"Triplet a: {la}, b: {lb}, c: {lc} {verdict} metacriteria")

In [None]:
output(indices, dataset)

Triplet a: 5, b: 6, c: 8 does not fit metacriteria
Triplet a: 4, b: 2, c: 5 does not fit metacriteria
Triplet a: 6, b: 2, c: 6 does not fit metacriteria
Triplet a: 2, b: 1, c: 1 does not fit metacriteria
Triplet a: 3, b: 5, c: 0 does not fit metacriteria
Triplet a: 5, b: 2, c: 4 does not fit metacriteria
Triplet a: 6, b: 2, c: 4 does not fit metacriteria
Triplet a: 2, b: 7, c: 7 does not fit metacriteria
Triplet a: 7, b: 9, c: 7 does not fit metacriteria
Triplet a: 2, b: 3, c: 2 does not fit metacriteria
Triplet a: 9, b: 1, c: 6 does not fit metacriteria
Triplet a: 3, b: 5, c: 9 does not fit metacriteria
Triplet a: 3, b: 9, c: 2 does not fit metacriteria
Triplet a: 2, b: 7, c: 4 does not fit metacriteria
Triplet a: 3, b: 8, c: 3 does not fit metacriteria
Triplet a: 4, b: 4, c: 1 fits metacriteria
Triplet a: 7, b: 8, c: 3 does not fit metacriteria
Triplet a: 2, b: 0, c: 8 does not fit metacriteria
Triplet a: 9, b: 5, c: 3 does not fit metacriteria
Triplet a: 4, b: 8, c: 8 does not fit metacriteria

# Training and Testing

In [None]:
#@title Training and testing functions
def train(model, device, train_loader, optimizer, epoch):
    """One epoch of autoencoder training: reconstruct each input image."""
    model.train()
    for batch_idx, (data, _) in enumerate(train_loader):  # digit labels are unused
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)  # shape (batch, 28, 28)
        # Mean squared reconstruction error between output and input pixels
        loss = F.mse_loss(output.flatten(), data.flatten())
        loss.backward()
        optimizer.step()
        # if batch_idx % log_interval == 0:
        #     print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        #         epoch, batch_idx * len(data), len(train_loader.dataset),
        #         100. * batch_idx / len(train_loader), loss.item()))

def test(model, device, test_loader):
    """Report the average per-image summed squared reconstruction error."""
    model.eval()
    test_loss = 0
    correct = 0  # placeholder: the autoencoder does not classify, so this stays 0
    with torch.no_grad():
        for data, _ in test_loader:
            data = data.to(device)
            output = model(data)
            # Sum the squared error over every pixel in the batch
            test_loss += F.mse_loss(output.flatten(), data.flatten(), reduction='sum').item()

    test_loss /= len(test_loader.dataset)  # average per image

    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [None]:
# Define the flat (fully linear) autoencoder: 784 -> 10 -> 784
class NetFlat(nn.Module):
    def __init__(self):
        super(NetFlat, self).__init__()

        # feature encoder: flatten the 28x28 image, then project to a 10-d code
        self.encoder = nn.Sequential(
            nn.Flatten(), # Convert into tabular data format.
            nn.Linear(784, 10)
        )

        # defined but currently unused (commented out in forward)
        self.nonlinear = nn.ReLU()
        
        # decoder: map the 10-d code back to 784 pixel values
        self.decoder = nn.Sequential(
            nn.Linear(10, 784)
        )

    def forward(self, x):
        x = self.encoder(x) # Flatten + Linear -> 10-d code x_i
        #x = self.nonlinear(x) 
        x = self.decoder(x) # Linear -> c_m^0 + sum_i(W_mi*x_i)
        return x.view((-1, 28, 28))

In [None]:
# Create the flat model
modelLin = NetFlat().to(device)
display(modelLin)

NetFlat(
  (encoder): Sequential(
    (0): Flatten()
    (1): Linear(in_features=784, out_features=10, bias=True)
  )
  (nonlinear): ReLU()
  (decoder): Sequential(
    (0): Linear(in_features=10, out_features=784, bias=True)
  )
)

In [None]:
#@title Get number of free parameters
def get_n_params(model):
    # Total number of scalar entries across all parameter tensors
    return sum(p.numel() for p in model.parameters())

print("Number of parameters in linear model:", get_n_params(modelLin))

Number of parameters in linear model: 16474
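
As a hand check of this count: each `nn.Linear(in, out)` layer holds in × out weights plus out biases, while `Flatten` and `ReLU` have no parameters.

```latex
\underbrace{784 \times 10 + 10}_{\text{encoder}} + \underbrace{10 \times 784 + 784}_{\text{decoder}} = 7850 + 8624 = 16474
```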


In [None]:
# @title Train the linear model
optimizer = optim.Adadelta(modelLin.parameters(), lr=lr)
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

for epoch in range(1, epochs + 1):
    train(modelLin, device, train_loader, optimizer, epoch)
    test(modelLin, device, test_loader)
    scheduler.step()

if save_model:
    torch.save(modelLin.state_dict(), "mnist_flat.pt")

Test set: Average loss: 487.5377, Accuracy: 0/10000 (0.00%)
Test set: Average loss: 435.9507, Accuracy: 0/10000 (0.00%)
Test set: Average loss: 389.4391, Accuracy: 0/10000 (0.00%)
Test set: Average loss: 372.0756, Accuracy: 0/10000 (0.00%)
Test set: Average loss: 362.4109, Accuracy: 0/10000 (0.00%)
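
`test` sums the squared error over all 784 pixels of each image and then divides by the number of test images, so the reported loss is a per-image total. Dividing by 784 gives a per-pixel MSE in the normalized units produced by `transforms.Normalize((0.1307,), (0.3081,))`, and multiplying by 0.3081² converts it back to the original [0, 1] pixel scale. A quick sketch of that arithmetic for the final epoch (the helper name `per_pixel_mse` is illustrative):

```python
PIXELS_PER_IMAGE = 28 * 28  # 784
MNIST_STD = 0.3081          # std used in transforms.Normalize


def per_pixel_mse(per_image_sse, pixels=PIXELS_PER_IMAGE, std=MNIST_STD):
    """Convert the notebook's per-image summed squared error into per-pixel MSE.

    Returns (normalized-units MSE, raw [0, 1]-scale MSE). Normalization divides
    pixel values by std, so squared errors scale by std**2 when it is undone.
    """
    normalized = per_image_sse / pixels
    return normalized, normalized * std ** 2


norm_mse, raw_mse = per_pixel_mse(362.4109)  # last epoch's test loss
print(f"{norm_mse:.4f} {raw_mse:.4f}")       # 0.4623 0.0439
```

Since the normalization gives the training pixels roughly zero mean and unit variance overall, a per-pixel MSE near 1.0 would correspond to always predicting the dataset mean; the 10-dimensional linear bottleneck ends the fifth epoch well below that.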
