# Imports and utils

In [None]:
import math
import torch
import torch.nn as nn
import random
import numpy as np
import torch.nn.functional as F
import argparse
import os
import shutil
import time
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.autograd import Variable
from torch.utils.data import Dataset
from torch.utils.data.dataset import random_split
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import copy

In [None]:
from neural_network import *
from utils import *
from metrics import *
from training_helpers import *


In [None]:

# Run on the first CUDA device when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Synthetic expert configuration.
n_dataset = 10  # dataset size parameter handed to the expert (10 is used for CIFAR-10 here)
k = 5  # number of classes expert can predict
Expert = synth_expert(k, n_dataset)


In [None]:
# ---------------------------------------------------------------------------
# Data configuration: builds the train/validation split and the test set.
# ---------------------------------------------------------------------------
use_data_aug = False  # True: reflect-pad, random-crop and flip the training images
n_dataset = 10  # 10 -> CIFAR-10, 100 -> CIFAR-100

# Per-channel statistics are given on the 0-255 scale and divided by 255 so
# that they match the [0, 1] range produced by ToTensor().
normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                 std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

if use_data_aug:
    transform_train = transforms.Compose([
        transforms.ToTensor(),
        # Reflect-pad 4 pixels on every side; the 32x32 random crop below then
        # acts as a random translation of the image.
        transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),
                                          (4, 4, 4, 4), mode='reflect').squeeze()),
        transforms.ToPILImage(),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
else:
    transform_train = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

if n_dataset == 10:
    dataset = 'cifar10'
elif n_dataset == 100:
    dataset = 'cifar100'
else:
    # Fail here rather than with a NameError on `dataset` further down.
    raise ValueError(f'n_dataset must be 10 or 100, got {n_dataset}')

# Not used by the code in this cell; kept with its final original value in
# case other cells pass it to a DataLoader.
kwargs = {'num_workers': 1, 'pin_memory': True}

# Resolve the torchvision dataset class once so that the training and the
# test data always come from the same dataset.
dataset_class = getattr(datasets, dataset.upper())

train_dataset_all = dataset_class('../data', train=True, download=True,
                                  transform=transform_train)

# 90% of the official training set is used for training, the rest for
# validation.  `test_size` is the size of that validation part.
train_size = int(0.90 * len(train_dataset_all))
test_size = len(train_dataset_all) - train_size

train_dataset, val_dataset = torch.utils.data.random_split(train_dataset_all, [train_size, test_size])

test_dataset = dataset_class('../data', train=False, transform=transform_test, download=True)



In [None]:
class CifarExpertDataset(Dataset):
    """Image dataset whose items also carry an expert prediction.

    Every item is the tuple ``(image, label, expert_pred, index, labeled)``.
    The expert predictions are computed once, at construction time, and are
    replaced by ``-1`` for the samples flagged as not labeled by the expert.
    """

    def __init__(self, images, targets, expert_fn, labeled, indices=None, transform=None):
        """Store the samples and precompute the expert predictions.

        Args:
            images: indexable collection of raw images; ``images[i]`` is fed
                to the transform when item ``i`` is requested.
            targets: sequence of ground-truth labels, one per image.
            expert_fn: callable invoked once as
                ``expert_fn(None, torch.FloatTensor(targets))``; its result is
                converted to a NumPy array with one prediction per sample.
            labeled: sequence of flags, one per sample.  Where the flag is 0
                the expert prediction is overwritten with -1.
            indices: optional sequence (list or NumPy array) holding the index
                of every sample in the original dataset.  Defaults to
                ``0 .. len(targets) - 1``.
            transform: optional callable applied to each raw image.  When
                omitted, the module-level ``transform_test`` is used.
        """
        self.images = images
        self.targets = np.array(targets)
        self.expert_fn = expert_fn
        self.labeled = np.array(labeled)
        self.transform = transform
        self.expert_preds = np.array(expert_fn(None, torch.FloatTensor(targets)))
        # Hide the expert prediction wherever the sample is not labeled by the expert.
        not_labeled = self.labeled[:len(self.expert_preds)] == 0
        self.expert_preds[not_labeled] = -1
        # `is not None` (instead of `!= None`) so that NumPy arrays are accepted:
        # comparing an array with None is element-wise and cannot be used in `if`.
        if indices is not None:
            self.indices = indices
        else:
            self.indices = np.arange(len(self.targets))

    def __getitem__(self, index):
        """Return ``(image, label, expert_pred, index_in_original_dataset, labeled)`` for ``index``."""
        transform = self.transform if self.transform is not None else transform_test
        image = transform(self.images[index])
        label = self.targets[index]
        expert_pred = self.expert_preds[index]
        indice = self.indices[index]
        labeled = self.labeled[index]
        return torch.FloatTensor(image), label, expert_pred, indice, labeled

    def __len__(self):
        """Return the number of samples."""
        return len(self.targets)

In [None]:
def _expert_dataset_from_subset(subset):
    """Wrap a `random_split` subset in a CifarExpertDataset with every sample flagged as expert-labeled."""
    positions = subset.indices
    images = np.array(subset.dataset.data)[positions]
    targets = np.array(subset.dataset.targets)[positions]
    return CifarExpertDataset(images, targets, Expert.predict, [1] * len(positions))


def _make_loader(expert_dataset, shuffle):
    """Build a batch-128, single-process DataLoader over `expert_dataset`."""
    return DataLoader(dataset=expert_dataset, batch_size=128, shuffle=shuffle,
                      num_workers=0, pin_memory=True)


# Train / validation come from the random split; the test set is used as-is.
dataset_train = _expert_dataset_from_subset(train_dataset)
dataset_val = _expert_dataset_from_subset(val_dataset)
dataset_test = CifarExpertDataset(test_dataset.data, test_dataset.targets, Expert.predict,
                                  [1] * len(test_dataset.targets))

# Only the training loader is shuffled.
dataLoaderTrain = _make_loader(dataset_train, shuffle=True)
dataLoaderVal = _make_loader(dataset_val, shuffle=False)
dataLoaderTest = _make_loader(dataset_test, shuffle=False)

# Figure 3

In [None]:
def _shuffled_loader(expert_dataset):
    """Build a shuffled, batch-128, single-process DataLoader over `expert_dataset`."""
    return DataLoader(dataset=expert_dataset, batch_size=128, shuffle=True,
                      num_workers=0, pin_memory=True)


# Images and labels of the training split; position i in these arrays is the
# sample train_dataset.indices[i] of the underlying dataset.
all_data_x = np.array(train_dataset.dataset.data)[train_dataset.indices]
all_data_y = np.array(train_dataset.dataset.targets)[train_dataset.indices]
all_indices = list(range(len(train_dataset.indices)))

# Start with 20 randomly chosen positions that carry expert labels; every
# other position is treated as not labeled by the expert.
intial_random_set = random.sample(all_indices, 20)
indices_labeled = intial_random_set
indices_unlabeled = list(set(all_indices) - set(indices_labeled))

dataset_train_labeled = CifarExpertDataset(
    all_data_x[indices_labeled],
    all_data_y[indices_labeled],
    Expert.predict,
    [1] * len(indices_labeled),
    indices_labeled,
)
dataset_train_unlabeled = CifarExpertDataset(
    all_data_x[indices_unlabeled],
    all_data_y[indices_unlabeled],
    Expert.predict,
    [0] * len(indices_unlabeled),
    indices_unlabeled,
)

dataLoaderTrainLabeled = _shuffled_loader(dataset_train_labeled)
dataLoaderTrainUnlabeled = _shuffled_loader(dataset_train_unlabeled)

In [None]:
# Sample-complexity experiment.
#
# We have a dataset that has all labels.  A fraction `data_size` of it also
# has expert labels.  For every trial and every fraction, three pipelines are
# trained and their 'system accuracy' on the test loader is recorded:
#   * Joint: only train on the expert-labeled part.
#   * Joint semi-supervised: first run `run_reject_class` on the full training
#     loader, then tune jointly on the expert-labeled part.
#   * Separate: train the classifier on all data, train the expert model
#     (rejector) only on the expert-labeled part.
#
# Results are stored as results[trial][i], with i indexing `data_sizes`.

MAX_TRIALS = 10
EPOCHS = 60
EPOCHS_ALPHA = 15
data_sizes = [0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.99]
alpha_grid = [0, 0.1,  0.5, 1]
joint_results = []
seperate_results = []
joint_semisupervised_results = []


def _select_best_alpha(model, start_state, loader):
    """Run the alpha grid search on `model` and return the best run.

    For every alpha in `alpha_grid` the model is reset to `start_state`,
    passed to `run_reject` together with EPOCHS_ALPHA and that alpha (using
    `loader` for both loader arguments), and scored with the 'system accuracy'
    that `metrics_print` reports on `loader`.  On ties the later alpha wins.

    Returns:
        Tuple ``(best_state, best_alpha, best_score)``; ``best_state`` is the
        state dict of the best-scoring run.
    """
    # Snapshot the states: if `run_reject` hands back references to the live
    # parameters, later training would otherwise silently overwrite them.
    start_state = copy.deepcopy(start_state)
    best_score = 0
    best_state = None
    best_alpha = 1
    for alpha in alpha_grid:
        print(f'alpha {alpha}')
        model.load_state_dict(start_state)
        state_alpha = run_reject(model, 10, Expert.predict, EPOCHS_ALPHA, alpha, loader, loader, True, 1)
        model.load_state_dict(state_alpha)
        score = metrics_print(model, Expert.predict, n_dataset, loader)['system accuracy']
        if score >= best_score:
            best_score = score
            best_state = copy.deepcopy(state_alpha)
            best_alpha = alpha
    return best_state, best_alpha, best_score


# The training split never changes, so its arrays are built once instead of
# being re-copied for every (trial, data_size) combination.
all_indices = list(range(len(train_dataset.indices)))
all_data_x = np.array(train_dataset.dataset.data)[train_dataset.indices]
all_data_y = np.array(train_dataset.dataset.targets)[train_dataset.indices]

for trial in range(MAX_TRIALS):
    joint = []
    seperate = []
    joint_semisupervised = []
    for data_size in data_sizes:
        print(f'\n \n datas size {data_size} \n \n')
        print(f' \n Joint \n')

        # Draw a fresh random subset of positions that carry expert labels.
        intial_random_set = random.sample(all_indices, math.floor(data_size*len(all_indices)))
        indices_labeled  = intial_random_set
        indices_unlabeled= list(set(all_indices) - set(indices_labeled))

        dataset_train_labeled = CifarExpertDataset(all_data_x[indices_labeled], all_data_y[indices_labeled], Expert.predict , [1]*len(indices_labeled), indices_labeled)
        dataset_train_unlabeled = CifarExpertDataset(all_data_x[indices_unlabeled], all_data_y[indices_unlabeled], Expert.predict , [0]*len(indices_unlabeled), indices_unlabeled)

        dataLoaderTrainLabeled = DataLoader(dataset=dataset_train_labeled, batch_size=128, shuffle=True,  num_workers=0, pin_memory=True)
        dataLoaderTrainUnlabeled = DataLoader(dataset=dataset_train_unlabeled, batch_size=128, shuffle=True,  num_workers=0, pin_memory=True)

        # --- Joint: everything is trained on the expert-labeled subset. ---
        net_h_params = [10] + [100,100,1000,500]
        net_r_params = [1] + [100,100,1000,500]
        model_2_r = NetSimpleRejector(net_h_params, net_r_params).to(device)
        model_dict = run_reject(model_2_r, 10, Expert.predict, EPOCHS, 1, dataLoaderTrainLabeled, dataLoaderVal, True)
        best_model, best_alpha, best_score = _select_best_alpha(model_2_r, model_dict, dataLoaderTrainLabeled)

        model_2_r.load_state_dict(best_model)
        joint.append(metrics_print(model_2_r, Expert.predict, n_dataset, dataLoaderTest)['system accuracy'])

        # --- Joint semi-supervised: start from a model run through
        # `run_reject_class` on the full training loader. ---
        print(f'\n Joint semi-supervised')
        net_h_params = [10] + [100,100,1000,500]
        net_r_params = [1] + [100,100,1000,500]
        model_2_r = NetSimpleRejector(net_h_params, net_r_params).to(device)
        run_reject_class(model_2_r, EPOCHS, dataLoaderTrain, dataLoaderTrainLabeled)
        model_dict = copy.deepcopy(model_2_r.state_dict())
        best_model, best_alpha, best_score = _select_best_alpha(model_2_r, model_dict, dataLoaderTrainLabeled)

        model_2_r.load_state_dict(best_model)
        joint_semisupervised.append(metrics_print(model_2_r, Expert.predict, n_dataset, dataLoaderTest)['system accuracy'])

        # --- Separate: expert model on the labeled subset, classifier on all data. ---
        print(f' \n Seperate \n')
        model_expert = NetSimple(2,  100,100,1000,500).to(device)
        run_expert(model_expert,EPOCHS, dataLoaderTrainLabeled, dataLoaderVal)

        model_class = NetSimple(n_dataset, 100,100,1000,500).to(device)

        run_reject_class(model_class, EPOCHS, dataLoaderTrain, dataLoaderVal)
        seperate.append(metrics_print_2step(model_class, model_expert, Expert.predict, 10, dataLoaderTest)['system accuracy'])

    joint_results.append(joint)
    seperate_results.append(seperate)
    joint_semisupervised_results.append(joint_semisupervised)

In [None]:



# Plot system accuracy against the fraction of expert-labeled data, with one
# error-bar curve (mean +/- std over the trials) per training pipeline.

# The figure size has to be configured before the figure is created;
# changing it after plotting would leave the saved figure at the default size.
plt.rcParams["figure.figsize"] = [6, 4]
fig_size = plt.rcParams["figure.figsize"]
plt.figure()

curves = [
    (joint_results, "x", 'Joint'),
    (seperate_results, "o", 'Staged'),
    (joint_semisupervised_results, "*", 'Joint-SemiSupervised'),
]
for trial_scores, curve_marker, curve_label in curves:
    # trial_scores[trial][i] is the score of one trial at data_sizes[i].
    avgs_rand = [np.average([trial_scores[triall][i] for triall in range(MAX_TRIALS)]) for i in range(len(data_sizes))]
    stds_rand = [np.std([trial_scores[triall][i] for triall in range(MAX_TRIALS)]) for i in range(len(data_sizes))]
    plt.errorbar(data_sizes, avgs_rand, yerr=stds_rand, marker=curve_marker, label=curve_label)

ax = plt.gca()
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.get_xaxis().tick_bottom()
ax.get_yaxis().tick_left()
plt.grid()
# A single call: each plt.legend() call builds a new legend, so separate calls
# for font size and location would drop the font size.
plt.legend(loc="lower right", fontsize='large')
plt.ylabel('System Accuracy',  fontsize='x-large')
plt.xlabel('Fraction of data Labeled', fontsize='x-large')
plt.savefig("sample_complexity_cifar.pdf", dpi = 1000)
plt.show()