**Multi Instant Learning (MIL)**  on MNIST


In [1]:
import torch
import inspect
import tensorflow as tf
import copy
import random
from random import shuffle
import numpy as np
import pandas as pd
from torchvision.models.resnet import ResNet, BasicBlock
from torchvision.datasets import MNIST
from tqdm.autonotebook import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from torch import nn, optim
from torchvision.transforms import Compose, ToTensor, Normalize, Resize
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.utils.data.dataset import Dataset

Get MNIST dataset

In [2]:
def get_data_loaders(train_batch_size, val_batch_size):
    mnist = MNIST(download=True, train=True, root=".").data.float()
    
    data_transform = Compose([ Resize((224, 224)),ToTensor(), Normalize((mnist.mean()/255,), (mnist.std()/255,))])

    train_loader = DataLoader(MNIST(download=True, root=".", transform=data_transform, train=True),
                              batch_size=train_batch_size, shuffle=True)
    val_loader = DataLoader(MNIST(download=False, root=".", transform=data_transform, train=False),
                            batch_size=val_batch_size, shuffle=False)
    return train_loader, val_loader


In [3]:
train_batch_size = 512
val_batch_size = 512

train_loader, valid_loader = get_data_loaders(train_batch_size, val_batch_size)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw
Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


Model. Using ResNet

In [4]:
class MnistResNet(ResNet):
    def __init__(self):
        super(MnistResNet, self).__init__(BasicBlock, [2, 2, 2, 2], num_classes=10)
        self.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), 
                                     padding=(3, 3), bias=False)
        
    def forward(self, x):
        return torch.softmax(super(MnistResNet, self).forward(x), dim=-1)

Define eval and scores functions

In [5]:
def calculate_metric(metric_fn, true_y, pred_y):
    if "average" in inspect.getfullargspec(metric_fn).args:
        return metric_fn(true_y, pred_y, average="macro")
    else:
        return metric_fn(true_y, pred_y)
    
def print_scores(p, r, f1, a, batch_size):
    for name, scores in zip(("Accuracy: ","Precision: ", "Recall: ", "F1: "), (a, p, r, f1)):
        print(f"\t{name.rjust(14, ' ')}: {sum(scores)/batch_size:.4f}")

First Train the model on all samples without any bags.

In [8]:
epochs = 3

In [9]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = MnistResNet().to(device)

loss_function = nn.CrossEntropyLoss() 
optimizer = optim.Adadelta(model.parameters())

losses = []
batches = len(train_loader)
val_batches = len(valid_loader)

for epoch in range(epochs):
    total_loss = 0

    # progress bar 
    progress = tqdm(enumerate(train_loader), desc="Loss: ", total=batches)

    # train
    model.train()
    
    for i, data in progress:
        X, y = data[0].to(device), data[1].to(device)
        model.zero_grad()  
        outputs = model(X)                     
        loss = loss_function(outputs, y)       
        loss.backward()                        
        optimizer.step()                       

        current_loss = loss.item()
        total_loss += current_loss

        # update progress bar
        progress.set_description("Loss: {:.4f}".format(total_loss/(i+1)))
        
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    
    # validation
    val_losses = 0
    precision, recall, f1, accuracy = [], [], [], []
    
    model.eval()
    with torch.no_grad():
        for i, data in enumerate(valid_loader):
            X, y = data[0].to(device), data[1].to(device)
            outputs = model(X)                                    
            val_losses += loss_function(outputs, y)
            predicted_classes = torch.max(outputs, 1)[1]          
            
            # calculate metrics 
            for acc, metric in zip((precision, recall, f1, accuracy), 
                                   (precision_score, recall_score, f1_score, accuracy_score)):
                acc.append(
                    calculate_metric(metric, y.cpu(), predicted_classes.cpu())
                )
          
    print(f"Epoch {epoch+1}/{epochs}, Training Loss: {total_loss/batches}, Validation Loss: {val_losses/val_batches}")
    print_scores(precision, recall, f1, accuracy, val_batches)
    losses.append(total_loss/batches) 


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=118.0, style=ProgressStyle(description_width…




  _warn_prf(average, modifier, msg_start, len(result))


Epoch 1/3, Training Loss: 1.746042250576666, Validation Loss: 1.74849534034729
	    Accuracy: : 0.7179
	   Precision: : 0.7167
	      Recall: : 0.7071
	          F1: : 0.6528


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=118.0, style=ProgressStyle(description_width…


Epoch 2/3, Training Loss: 1.495963024891029, Validation Loss: 1.4814677238464355
	    Accuracy: : 0.9840
	   Precision: : 0.9844
	      Recall: : 0.9838
	          F1: : 0.9838


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=118.0, style=ProgressStyle(description_width…


Epoch 3/3, Training Loss: 1.4742911154940976, Validation Loss: 1.8237218856811523
	    Accuracy: : 0.6367
	   Precision: : 0.9139
	      Recall: : 0.6322
	          F1: : 0.6799


Save Model

In [10]:
torch.save(model.state_dict(), 'mnist_resnet.pt')

**Creating Bags for MIL**

Assign '1' as the bag label if one sample has a label '1'. Otherwise assign '0' as the bag label.

In [11]:
# get fetures from the portion of the dataset only
train_count = 10000
test_count = 5000

In [12]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [13]:
x_train = x_train[:train_count]
y_train = y_train[:train_count]
x_test = x_test[:test_count]
y_test = y_test[:test_count]

Convert to float and normalize

In [14]:
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

Build tuples(index, label) from train and test sets

In [15]:
instance_index_label = [(i, y_train[i]) for i in range(x_train.shape[0])]
instance_index_label_test = [(i, y_test[i]) for i in range(x_test.shape[0])]

Find the index if label is 1 in train set

In [16]:
find_index = [instance_index_label[i][0] for i in range(len(instance_index_label)) 
             if instance_index_label[i][1]==1]

Find the index if label is 1 in test set

In [17]:
find_index_test = [instance_index_label_test[i][0] for i in range(len(instance_index_label_test))
                   if instance_index_label_test[i][1]==1]

Check index and lable

In [18]:
print('index:', instance_index_label[0][0])         
print('label:', instance_index_label[0][1])         

index: 0
label: 5


Load pretrained model

In [19]:
model = MnistResNet()
model.load_state_dict(torch.load('mnist_resnet.pt'))
body = nn.Sequential(*list(model.children()))

Remove last layer from the model since we will have binary classification now and not digit number classification.

In [20]:
model = body[:9]
model.eval()

Sequential(
  (0): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Con

Get features from last layer

In [21]:
train_loader, val_loader = get_data_loaders(1,1)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adadelta(model.parameters())
losses = []
batches = len(train_loader)
val_batches = len(val_loader)

Get Train features

In [22]:
meta_table = dict()
feature_result = []

# progress bar
progress = tqdm(enumerate(train_loader), desc="Percent Completed: ", total=batches)

model.eval()

for i, data in progress:
    if i==train_count:
        break
    X, y = data[0], data[1]
    model.zero_grad()
    outputs = model(X)
    feature_result.append(outputs.reshape(-1).tolist())
    meta_table[i] = outputs.reshape(-1).tolist()
    
feature_array = np.array(feature_result)

HBox(children=(FloatProgress(value=0.0, description='Percent Completed: ', max=60000.0, style=ProgressStyle(de…

Get Test features

In [23]:
meta_t_table = dict()
feature_t_result = []

# progress bar
progress = tqdm(enumerate(val_loader), desc="Percent Completed: ", total=batches)

model.eval()

for i, data in progress:
    if i==test_count:
        break
    X, y = data[0], data[1]
    model.zero_grad()
    outputs_t = model(X)
    feature_t_result.append(outputs_t.reshape(-1).tolist())
    meta_t_table[i] = outputs_t.reshape(-1).tolist()

feature_test_array = np.array(feature_t_result)

HBox(children=(FloatProgress(value=0.0, description='Percent Completed: ', max=60000.0, style=ProgressStyle(de…

Create bags by putting several samples in one bag.

Bag labeled 1 if one sample is labeled 1. If all samples not equal to 1 bag is labeled as 0.

In [24]:
from typing import List, Dict, Tuple

In [25]:
def create_bags(instance_index_label: List[Tuple]) -> List[Dict]:
    bag_size = np.random.randint(3,7,size=len(instance_index_label)//5)
    data_cp = copy.copy(instance_index_label)
    np.random.shuffle(data_cp)
    bags = {}
    bags_per_instance_labels = {}
    bags_labels = {}
    for bag_ind, size in enumerate(bag_size):
        bags[bag_ind] = []
        bags_per_instance_labels[bag_ind] = []
        try:
            for _ in range(size):
                inst_ind, lbl = data_cp.pop()
                bags[bag_ind].append(inst_ind)
                bags_per_instance_labels[bag_ind].append(lbl)
            bags_labels[bag_ind] = bag_label_from_instance_labels(bags_per_instance_labels[bag_ind])
        except:
            break
    return bags, bags_labels

def bag_label_from_instance_labels(instance_labels):
    return int(any(((x==1) for x in instance_labels)))

In [26]:
bag_indices, bag_labels = create_bags(instance_index_label)
bag_features = {kk: torch.Tensor(feature_array[inds]) for kk, inds in bag_indices.items()}

Create bags for test data

In [27]:
bag_t_indices, bag_t_labels = create_bags(instance_index_label_test)
bag_t_features = {kk: torch.Tensor(feature_test_array[inds]) for kk, inds in bag_t_indices.items()}

Multi instance learning 

In [28]:
train_data = [(bag_features[i],bag_labels[i]) for i in range(len(bag_features))]

In [29]:
# check bag features
bag_features[0]

tensor([[1.3435, 0.9800, 0.3039,  ..., 0.9513, 0.6300, 3.0026],
        [0.3664, 0.0000, 0.0895,  ..., 0.4118, 1.1030, 7.1619],
        [0.7696, 0.1679, 1.2092,  ..., 0.3501, 0.3473, 2.1230],
        [1.6177, 0.0658, 0.7975,  ..., 1.5696, 0.0861, 3.2724],
        [1.0205, 0.3924, 1.3202,  ..., 0.7862, 0.4345, 2.4315]])

In [30]:
# check train data
train_data[0]

(tensor([[1.3435, 0.9800, 0.3039,  ..., 0.9513, 0.6300, 3.0026],
         [0.3664, 0.0000, 0.0895,  ..., 0.4118, 1.1030, 7.1619],
         [0.7696, 0.1679, 1.2092,  ..., 0.3501, 0.3473, 2.1230],
         [1.6177, 0.0658, 0.7975,  ..., 1.5696, 0.0861, 3.2724],
         [1.0205, 0.3924, 1.3202,  ..., 0.7862, 0.4345, 2.4315]]), 1)

Helper function to pad bag. Each bag has adifferent sizes. We adjust each bag to have the same size 7 (max of bag size).

In [31]:
def pad_tensor(data:list, max_number_instance) -> list:
    new_data = []
    for bag_index in range(len(data)):
        tensor_size = len(data[bag_index][0])
        pad_size = max_number_instance - tensor_size
        p2d = (0,0, 0, pad_size)
        padded = nn.functional.pad(data[bag_index][0], p2d, 'constant', 0)
        new_data.append((padded, data[bag_index][1]))
    return new_data

Padded Train set

In [32]:
max_number_instance = 7
padded_train = pad_tensor(train_data, max_number_instance)

Padded test set

In [33]:
test_data = [(bag_t_features[i],bag_t_labels[i]) for i in range(len(bag_t_features))]
padded_test = pad_tensor(test_data, max_number_instance)

In [34]:
def get_padded_data_loaders(train_data, test_data, train_batch_size, val_batch_size):
    train_loader = DataLoader(train_data, batch_size=train_batch_size, shuffle=True)
    val_loader = DataLoader(test_data, batch_size=val_batch_size, shuffle=False)
    return train_loader, val_loader

Convert train and test sets to Loaders

Create Multi Instance Model

In [35]:
# aggregation function
class Aggregate(torch.nn.Module):
    def __init__(self, a=10, dims=[0]):
        super(Aggregate, self).__init__()
        self.a = a
        self.b = torch.nn.Parameter(torch.tensor(0.01))
        self.dims =dims
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        mean = torch.mean(x, self.dims, False)
        res = (self.sigmoid(self.a * (mean - self.b)) - self.sigmoid(-self.a * self.b)) / (
              self.sigmoid(self.a * (1 - self.b)) - self.sigmoid(-self.a * self.b))
        return res

In [36]:
# linear layer with LeakyRelu activation
class LinearNN(torch.nn.Module):

    def __init__(self, n=7*512, n_mid = 7168, n_out=1, dropout=0.2):
        super(LinearNN, self).__init__()
        self.linear1 = torch.nn.Linear(n, n_mid)
        self.linear2 = torch.nn.Linear(n_mid, n_out)
        self.dropout = torch.nn.Dropout(dropout)
        self.non_linearity = torch.nn.LeakyReLU()
        
    def forward(self, x):
        z = self.linear1(x)
        z = self.non_linearity(z)
        z = self.dropout(z)
        z = self.linear2(z)
        y_pred = torch.sigmoid(z)
        return y_pred

# Milti Instance Learning Model 
class MIL(torch.nn.Module):

    def __init__(self, n=7*512,  n_mid=7168, n_out=1, 
                 n_inst=None, dropout=0.1,
                 noisy_a=4,
                 agg = Aggregate(a=4, dims=[0]),
                ):
        super(MIL, self).__init__()
        if agg is None:
            agg = Aggregate(a=noisy_a, dims=[0])
        if n_inst is None:
            self.mdl_instance = agg
            n_inst = n
        else:
            self.mdl_instance = nn.Sequential(
                            nn.Linear(n, n_inst),
                            nn.LeakyReLU(),
                            agg,
                            )
        if n_mid == 0:
            self.mdl_bag = LogisticRegression(n_inst, n_out)
        else:
            self.mdl_bag = LinearNN(n_inst, n_mid, n_out, dropout=dropout)
        
    def forward(self, bag_feature):
        y_pred = self.mdl_bag(bag_feature)
        return y_pred

Train and test

In [37]:
lr0 = 1e-4
epochs = 10

In [38]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = MIL().to(device)

# padded data
train_loader, val_loader = get_padded_data_loaders(padded_train, padded_test, 1, 1)

loss_function = torch.nn.BCELoss(reduction='mean')     
optimizer = optim.SGD(model.parameters(), lr=lr0, momentum=0.9)

losses = []
batches = len(train_loader)
val_batches = len(val_loader)

for epoch in range(epochs):
    total_loss = 0

    # progress bar
    progress = tqdm(enumerate(train_loader), desc="Loss: ", total=batches)

    # train
    model.train()
    for i, data in progress:
        X, y = data[0].to(device), data[1].to(device)
        X = X.reshape([1,7*512])
        y = y.type(torch.cuda.FloatTensor)
        model.zero_grad() 
       
        outputs = model(X)                             
        loss = loss_function(outputs, y)               
        loss.backward()                                
        optimizer.step()                               
        
        current_loss = loss.item()
        total_loss += current_loss

        progress.set_description("Loss: {:.4f}".format(total_loss/(i+1)))
        
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    
    # validation
    val_losses = 0
    precision, recall, f1, accuracy = [], [], [], []
    
    # validation
    model.eval()
    with torch.no_grad():
        for i, data in enumerate(val_loader):
            X, y = data[0].to(device), data[1].to(device)
            X = X.reshape([1,7*512])
            y = y.type(torch.cuda.FloatTensor)
            outputs = model(X)                         
            prediced_classes =outputs.detach().round()
            val_losses += loss_function(outputs, y)
            
            # calculate metrics
            for acc, metric in zip((precision, recall, f1, accuracy), 
                                   (precision_score, recall_score, f1_score, accuracy_score)):
                acc.append(
                    calculate_metric(metric, y.cpu(), prediced_classes.cpu())
                )
          
    print(f"Epoch {epoch+1}/{epochs}, training loss: {total_loss/batches}, validation loss: {val_losses/val_batches}")
    print_scores(precision, recall, f1, accuracy, val_batches)
    losses.append(total_loss/batches)                  

HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=2000.0, style=ProgressStyle(description_widt…

  return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)





  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 1/10, training loss: 0.7579105719747021, validation loss: 0.6843321919441223
	    Accuracy: : 0.5600
	   Precision: : 0.5600
	      Recall: : 0.5600
	          F1: : 0.5600


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=2000.0, style=ProgressStyle(description_widt…


Epoch 2/10, training loss: 0.6836628786604851, validation loss: 0.674313485622406
	    Accuracy: : 0.5660
	   Precision: : 0.5660
	      Recall: : 0.5660
	          F1: : 0.5660


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=2000.0, style=ProgressStyle(description_widt…


Epoch 3/10, training loss: 0.6613991985283793, validation loss: 0.6612150073051453
	    Accuracy: : 0.5890
	   Precision: : 0.5890
	      Recall: : 0.5890
	          F1: : 0.5890


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=2000.0, style=ProgressStyle(description_widt…


Epoch 4/10, training loss: 0.6479153206292539, validation loss: 0.6688902974128723
	    Accuracy: : 0.6050
	   Precision: : 0.6050
	      Recall: : 0.6050
	          F1: : 0.6050


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=2000.0, style=ProgressStyle(description_widt…


Epoch 5/10, training loss: 0.6293139054160565, validation loss: 0.6437333822250366
	    Accuracy: : 0.6370
	   Precision: : 0.6370
	      Recall: : 0.6370
	          F1: : 0.6370


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=2000.0, style=ProgressStyle(description_widt…


Epoch 6/10, training loss: 0.6166103840358556, validation loss: 0.6559007167816162
	    Accuracy: : 0.6330
	   Precision: : 0.6330
	      Recall: : 0.6330
	          F1: : 0.6330


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=2000.0, style=ProgressStyle(description_widt…


Epoch 7/10, training loss: 0.6097220645584166, validation loss: 0.6752607226371765
	    Accuracy: : 0.5670
	   Precision: : 0.5670
	      Recall: : 0.5670
	          F1: : 0.5670


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=2000.0, style=ProgressStyle(description_widt…


Epoch 8/10, training loss: 0.5891696884036064, validation loss: 0.6924854516983032
	    Accuracy: : 0.6090
	   Precision: : 0.6090
	      Recall: : 0.6090
	          F1: : 0.6090


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=2000.0, style=ProgressStyle(description_widt…


Epoch 9/10, training loss: 0.5845732743283734, validation loss: 0.7248034477233887
	    Accuracy: : 0.5360
	   Precision: : 0.5360
	      Recall: : 0.5360
	          F1: : 0.5360


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=2000.0, style=ProgressStyle(description_widt…


Epoch 10/10, training loss: 0.5775469572227449, validation loss: 0.6637910604476929
	    Accuracy: : 0.6200
	   Precision: : 0.6200
	      Recall: : 0.6200
	          F1: : 0.6200
