# Import Libraries

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
plt.style.use("ggplot")

import torch
from torch import nn
from torch import optim
from torch.utils.data import Dataset, DataLoader, Subset, ConcatDataset
from torch.utils.data.sampler import WeightedRandomSampler

from torchvision.models import resnet18

# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Running on device:", DEVICE.upper())

Running on device: CUDA


# Accuracy Metric

In [2]:
def accuracy(net, loader, device=None):
    """Return the top-1 accuracy of ``net`` over the batches yielded by ``loader``.

    Args:
        net: Callable model; ``net(inputs)`` must return one row of class
            scores per sample (the arg-max over dimension 1 is the prediction).
        loader: Iterable of ``(inputs, targets)`` batches.
        device: Device the batches are moved to before the forward pass.
            Defaults to the notebook-level ``DEVICE``.

    Returns:
        ``correct / total`` as a Python float, or ``nan`` when the loader
        yields no samples (the accuracy is undefined in that case).
    """
    if device is None:
        device = DEVICE

    # Evaluate in eval mode, but hand the model back in the mode it came in.
    was_training = getattr(net, "training", False)
    if was_training:
        net.eval()

    correct = 0
    total = 0
    try:
        # No gradients are needed for scoring; skipping them avoids building
        # an autograd graph for every evaluation batch.
        with torch.no_grad():
            for inputs, targets in loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = net(inputs)
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
    finally:
        if was_training:
            net.train()

    if total == 0:
        return float("nan")
    return correct / total

In [3]:
class HiddenDataset(Dataset):
    """Wrap a dataset so that every item also carries its own index.

    Indexing returns ``(first_field, second_field, idx)`` where the first two
    entries are taken from the wrapped dataset's item at ``idx``.
    """

    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset

    def __len__(self):
        # Same length as the wrapped dataset.
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        # Only the first two fields of the wrapped item are forwarded.
        return (sample[0], sample[1], idx)
    
def sampling(net, train_set, compact_train_set, compact_test_set, net_state_dict): 
    """Accumulate pairwise sign-agreement ("cohesive") scores and save them to ``./tmp``.

    In each of ``epochs`` rounds, ``net`` is reset to ``net_state_dict``, takes a
    single SGD step on one batch from ``train_set``, and the change in softmax
    output caused by that step (after minus before) is reduced to its sign for
    samples drawn from ``compact_test_set`` and ``compact_train_set``. Products of
    those signs over (compact-test, compact-train) pairs are summed into two tables:

    * ``uncond_cohesive_scores`` -- shape ``(len(compact_test_set), len(compact_train_set), 10)``:
      all 10 class outputs of the test sample times the true-class output of the train sample.
    * ``cohesive_scores`` -- shape ``(len(compact_test_set), len(compact_train_set), 1)``:
      true-class output of the test sample times the true-class output of the train sample.

    The number of classes (10) is hard-coded throughout.

    Args:
        net: Model to perturb. Its weights are overwritten from ``net_state_dict``
            before every step; it is left in eval mode on return.
        train_set: Dataset of ``(inputs, targets)`` used for the SGD step.
        compact_train_set: Dataset of ``(inputs, targets)``; second axis of the tables.
        compact_test_set: Dataset of ``(inputs, targets)``; first axis of the tables.
        net_state_dict: State dict loaded into ``net`` at the start of every round.

    Returns:
        Always ``0``. The results are the four ``.npy`` files written under ``./tmp``
        (the two score tables and the two label vectors).
    """
    epochs = 30
    
    # Collect the labels of both compact sets by iterating them item by item.
    compact_train_labels = torch.from_numpy(np.zeros((len(compact_train_set),))).long()
    compact_test_labels = torch.from_numpy(np.zeros((len(compact_test_set),))).long()
    for i, (inputs, targets) in enumerate(compact_train_set):
        compact_train_labels[i] = targets
    for i, (inputs, targets) in enumerate(compact_test_set):
        compact_test_labels[i] = targets
    
    # Test samples come first: mixed indices < len(compact_test_set) are test
    # samples, the remaining indices are train samples.
    mixed_data_set = ConcatDataset([compact_test_set,compact_train_set])
    
    # Per-sample weights chosen so that each of the two groups carries the same
    # total weight (len(mixed_data_set) / 2), regardless of the group sizes.
    sampling_weights = np.zeros((len(mixed_data_set),))
    for i in range(len(mixed_data_set)):
        if i < len(compact_test_set):
            sampling_weights[i] = len(mixed_data_set)/(len(compact_test_set)*2)
        else:
            sampling_weights[i] = len(mixed_data_set)/(len(compact_train_set)*2)
    # 512*512 draws with replacement per pass over mixed_data_loader.
    sampler = WeightedRandomSampler(sampling_weights, 512*512, replacement=True)
    # Wrap so every batch also yields the mixed-set index of each sample.
    mixed_data_set = HiddenDataset(mixed_data_set)
    mixed_data_loader = DataLoader(
        mixed_data_set, batch_size=128, sampler = sampler)
    
    # Source of the single training batch used in each round.
    data_loader = DataLoader(
        train_set, batch_size=128, shuffle=True, drop_last=True)
    
    # Integer accumulators indexed as [test sample, train sample, class / 0].
    uncond_cohesive_scores = torch.from_numpy(np.zeros((len(compact_test_set),len(compact_train_set),10))).long()
    cohesive_scores = torch.from_numpy(np.zeros((len(compact_test_set),len(compact_train_set),1))).long()
    
    # NOTE(review): the optimizer is built once and never re-created, while the
    # weights are reloaded every round -- presumably its internal state (momentum)
    # therefore carries over between rounds; confirm this is intended.
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9, weight_decay=4e-3)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    # Second network that holds the weights from *before* the SGD step.
    test_net = resnet18(pretrained=False, num_classes=10)
    test_net.to(DEVICE)
    
    net.train()
    test_net.eval()
    for epoch in range(epochs):
        # Only the first batch of each epoch is used (see the `break` below).
        for i, (inputs, targets) in enumerate(data_loader):
            # Reset to the reference weights and snapshot them into test_net.
            net.load_state_dict(net_state_dict)
            test_net.load_state_dict(net.state_dict())
            
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
            targets = torch.nn.functional.one_hot(targets, num_classes=10).float()
            
            optimizer.zero_grad()
            output = net(inputs)
            # Cross-entropy written out by hand: negative mean of the
            # log-probability assigned to the one-hot target class.
            loss = -(torch.log_softmax(output,dim=-1).view(-1,1,10) @ targets.view(-1,10,1)).mean()

            loss.backward()
            optimizer.step()
            
            net.eval()
            test_net.eval()
            for fn_inputs, fn_targets, fn_indices in mixed_data_loader:
                fn_inputs, fn_targets, fn_indices = fn_inputs.to(DEVICE), fn_targets.to(DEVICE), fn_indices.to(DEVICE)
                # Split the batch: "f" = compact-test samples, "n" = compact-train
                # samples (re-based so they index into compact_train_set).
                f_masks = fn_indices<len(compact_test_set)
                f_indices = fn_indices[f_masks]
                n_indices = fn_indices[~f_masks]-len(compact_test_set)
                
                with torch.no_grad():
                    # Boolean one-hot of the true class, shape (batch, 10).
                    mask = torch.arange(10).to(DEVICE).reshape(1,-1)
                    mask = (mask == fn_targets.reshape(-1,1))
                    
                    # output0: probabilities after the step; output1: before it.
                    output0 = torch.softmax(net(fn_inputs),dim=1).detach()
                    output1 = torch.softmax(test_net(fn_inputs),dim=1).detach()
                    # Sign of the change for every class, shape (batch, 10).
                    perc = torch.sign(output0-output1).long().reshape(-1,10)
                    f_sc = perc[f_masks,:]
                    
                    # Keep only the true-class probability of each sample.
                    output0 = torch.masked_select(output0,mask)
                    output1 = torch.masked_select(output1,mask)
                    # Sign of the true-class change, shape (batch, 1).
                    perc = torch.sign(output0-output1).long().reshape(-1,1)
                    n_sc = perc[~f_masks,:]

                    # Outer product over (f, n) pairs: all-class sign of f times
                    # true-class sign of n, shape (len(f_sc), len(n_sc), 10).
                    ones = torch.ones(len(f_sc),len(n_sc),10).long().to(DEVICE)
                    f_sq = torch.mul(ones,n_sc.view(1,-1,1))
                    f_sq = torch.mul(f_sq,f_sc.view(-1,1,10))
                    # Matching flat (test index, train index) coordinates per pair.
                    f_indices_sq = f_indices.view(-1,1).expand(len(f_sc),len(n_sc))
                    n_indices_sq = n_indices.view(1,-1).expand(len(f_sc),len(n_sc))
                    f_sq = f_sq.cpu().reshape(-1,10)
                    f_indices_sq = f_indices_sq.cpu().reshape(-1)
                    n_indices_sq = n_indices_sq.cpu().reshape(-1)
                    
                    # NOTE(review): the sampler draws with replacement, so a batch can
                    # contain the same index more than once; with this read-add-write
                    # form a repeated (f, n) pair looks like it is counted once per
                    # batch rather than once per occurrence -- confirm that is intended.
                    uncond_cohesive_scores[f_indices_sq,n_indices_sq,:] = uncond_cohesive_scores[f_indices_sq,n_indices_sq,:]+f_sq
                    
                    #==============================================#
                        
                    # Same accumulation, but using the true-class sign for the
                    # f samples as well (perc now has shape (batch, 1)).
                    f_sc = perc[f_masks,:]

                    ones = torch.ones(len(f_sc),len(n_sc),1).long().to(DEVICE)
                    f_sq = torch.mul(ones,n_sc.view(1,-1,1))
                    f_sq = torch.mul(f_sq,f_sc.view(-1,1,1))
                    f_indices_sq = f_indices.view(-1,1).expand(len(f_sc),len(n_sc))
                    n_indices_sq = n_indices.view(1,-1).expand(len(f_sc),len(n_sc))
                    f_sq = f_sq.cpu().reshape(-1,1)
                    f_indices_sq = f_indices_sq.cpu().reshape(-1)
                    n_indices_sq = n_indices_sq.cpu().reshape(-1)
                    
                    cohesive_scores[f_indices_sq,n_indices_sq,:] = cohesive_scores[f_indices_sq,n_indices_sq,:]+f_sq
            net.train()
            # One training batch per round; leave the inner loop immediately.
            break
            
        scheduler.step()
    net.eval()
    # Persist both score tables and the label vectors; ./tmp must already exist.
    np.save("./tmp/uncond_cohesive_scores.npy", uncond_cohesive_scores.numpy())
    np.save("./tmp/cohesive_scores.npy", cohesive_scores.numpy())
    np.save("./tmp/compact_train_labels.npy", compact_train_labels.numpy())
    np.save("./tmp/compact_test_labels.npy", compact_test_labels.numpy())
    
    return 0


# Sampling

In [4]:
from torch.utils.data import random_split

import torchvision
from torchvision import transforms
from sklearn.model_selection import train_test_split
def load_data():
    """Build the four CIFAR-10 data loaders used in this notebook.

    The CIFAR-10 train and test splits are read from ``./data`` (with
    ``download=True``) and each is divided into a "retain" and a "compact"
    subset using the index arrays stored in ``retain_train_idx.npy``,
    ``compact_train_idx.npy``, ``retain_test_idx.npy`` and
    ``compact_test_idx.npy`` in the working directory.

    Returns:
        dict mapping ``"retain_train"``, ``"compact_train"``, ``"retain_test"``
        and ``"compact_test"`` to a ``DataLoader`` with batch size 128. The two
        training loaders shuffle; the two test loaders do not.
    """
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    # Training split and its retain / compact subsets.
    cifar_train = torchvision.datasets.CIFAR10(
        root="./data", train=True, download=True, transform=preprocess)
    train_subsets = {
        "retain_train": Subset(cifar_train, np.load("retain_train_idx.npy")),
        "compact_train": Subset(cifar_train, np.load("compact_train_idx.npy")),
    }

    # Test split and its retain / compact subsets.
    cifar_test = torchvision.datasets.CIFAR10(
        root="./data", train=False, download=True, transform=preprocess)
    test_subsets = {
        "retain_test": Subset(cifar_test, np.load("retain_test_idx.npy")),
        "compact_test": Subset(cifar_test, np.load("compact_test_idx.npy")),
    }

    loaders = {}
    for name, subset in train_subsets.items():
        loaders[name] = DataLoader(subset, batch_size=128, shuffle=True)
    for name, subset in test_subsets.items():
        loaders[name] = DataLoader(subset, batch_size=128, shuffle=False)
    return loaders

In [5]:
data_loaders = load_data()

# map_location keeps the load working when the checkpoint was written on a
# different device than this session uses (e.g. saved on GPU, opened on CPU).
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
net_state_dict = torch.load("./tmp/checkpoint.pth", map_location=DEVICE)
net = resnet18(pretrained=False, num_classes=10)
net.to(DEVICE)
net.load_state_dict(net_state_dict)

# The SGD step inside sampling() draws from the full training data
# (retain + compact), while the score tables cover the two compact sets.
retain_train_set = data_loaders["retain_train"].dataset
compact_train_set = data_loaders["compact_train"].dataset
compact_test_set = data_loaders["compact_test"].dataset
sampling(
    net,
    ConcatDataset([retain_train_set, compact_train_set]),
    compact_train_set,
    compact_test_set,
    net_state_dict,
)

Files already downloaded and verified
Files already downloaded and verified


0

# Load tables of cohesive degree

In [6]:
# Read back the arrays written to ./tmp by sampling() and expose each one both
# as a NumPy array (``*_np``) and as a torch tensor backed by that array.
uncond_cohesive_scores_np = np.load("./tmp/uncond_cohesive_scores.npy")
uncond_cohesive_scores = torch.from_numpy(uncond_cohesive_scores_np)

cohesive_scores_np = np.load("./tmp/cohesive_scores.npy")
cohesive_scores = torch.from_numpy(cohesive_scores_np)

compact_train_labels_np = np.load("./tmp/compact_train_labels.npy")
compact_train_labels = torch.from_numpy(compact_train_labels_np)

compact_test_labels_np = np.load("./tmp/compact_test_labels.npy")
compact_test_labels = torch.from_numpy(compact_test_labels_np)

# Experiment results

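Experiment 1: predict each compact-test sample's label as the label of the compact-train sample with the highest cohesive score.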

In [7]:
# For every compact-test row, find the compact-train sample with the highest
# cohesive score and use that train sample's label as the prediction.
score_values, uncond_indices = cohesive_scores.max(dim=1)

# One-hot row per test sample marking the position of its best train sample.
train_positions = torch.arange(len(compact_train_labels)).reshape(1, -1)
uncond_masks = uncond_indices.reshape(-1, 1) == train_positions
predicts = compact_train_labels.reshape(1, -1).masked_select(uncond_masks).reshape(-1)

# Fraction of test samples whose borrowed label matches their own label.
acc = predicts.eq(compact_test_labels.reshape(-1)).float()
acc = acc.sum() / len(acc)
acc

tensor(0.9258)

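Experiment 2: zero the unconditional cohesive scores of compact-train samples whose label differs from the class channel, take the best-scoring train sample per class, and predict the label of the one with the highest score across classes.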

In [8]:
# Flattened (train sample, class) one-hot of the train labels, shape (1, n_train * 10).
identical_label_masks = (
    compact_train_labels.reshape(-1, 1) == torch.arange(10).reshape(1, -1)
).reshape(1, -1)

# Zero every score whose class channel is not the train sample's own label,
# then restore the (n_test, n_train, 10) layout.
flat_scores = uncond_cohesive_scores.reshape(uncond_cohesive_scores.size(0), -1)
identical_label_cohesive_scores = flat_scores.masked_fill(~identical_label_masks, 0)
identical_label_cohesive_scores = identical_label_cohesive_scores.reshape(uncond_cohesive_scores.size())

# Best train sample (score and position) per test sample and class.
uncond_score_values, uncond_uncond_indices = identical_label_cohesive_scores.max(dim=1)
train_positions = torch.arange(len(compact_train_labels)).reshape(1, -1)
uncond_uncond_masks = uncond_uncond_indices.reshape(-1, 1) == train_positions
# Label of that best train sample, laid out as (n_test, 10).
uncond_uncond_labels = (
    compact_train_labels.reshape(1, -1).masked_select(uncond_uncond_masks).reshape(-1, 10)
)

# Pick the class channel with the highest best-score and read off its label.
uncond_predict_indices = uncond_score_values.argmax(dim=-1)
row_ids = torch.arange(len(uncond_predict_indices))
uncond_predicts = uncond_uncond_labels[row_ids, uncond_predict_indices]

acc = uncond_predicts.eq(compact_test_labels.reshape(-1)).float()
acc = acc.sum() / len(acc)
acc

tensor(0.7480)

Accuracy of the network's arg-max predictions on the retain-train set and the retain-test set, respectively.

In [9]:
# Score the network on the retain-train split first, then on the retain-test
# split, and show both numbers as a (train, test) tuple.
retain_accuracy, test_accuracy = (
    accuracy(net, data_loaders[split]) for split in ("retain_train", "retain_test")
)
print((retain_accuracy,test_accuracy))

(0.999979793081151, 0.8138701517706577)
