In [None]:
import os
import tensorflow as tf
print(tf.__version__)

import numpy as np
from tqdm.notebook import tqdm as tqdm
from scipy.special import softmax

In [None]:
# Selects which saved prediction files are analysed: it is spliced into
# predictions_cifar10/test_cifar10_cnn<dense_units>_*.npy below.
# Presumably the width of the CNN's dense layer — confirm against the training code.
# Choose from '64', '256', '1024'
dense_units = '256'
# Fail here with a clear message instead of a FileNotFoundError several cells later.
assert dense_units in ('64', '256', '1024'), \
    "dense_units must be one of '64', '256', '1024', got {!r}".format(dense_units)

In [None]:
# Number of CIFAR-10 classes; labels run from 0 to NUM_CLASSES - 1.
NUM_CLASSES = 10

In [None]:
def oh(ys, n):
    """One-hot encode integer class labels.

    Args:
        ys: sequence or array of labels. A column vector of shape ``(N, 1)`` is
            treated like a flat vector of length ``N``.
        n: number of classes (width of the encoding).

    Returns:
        float64 array of shape ``(N, n)``. Row ``i`` has a 1.0 in column
        ``ys[i]`` and zeros elsewhere; labels outside ``[0, n)`` give an
        all-zero row. Empty input yields shape ``(0, n)`` instead of raising.
    """
    # One label per row, then broadcast against the class indices.
    labels = np.asarray(ys).reshape(-1, 1)
    return (labels == np.arange(n)).astype(np.float64)

In [None]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Rescale pixel intensities from [0, 255] to [0, 1].
x_train, x_test = x_train / 255., x_test / 255.

# Keras hands back labels as (N, 1) column vectors; collapse them to flat 1-D vectors.
y_train, y_test = np.concatenate(y_train), np.concatenate(y_test)

# One-hot targets over all CIFAR-10 classes.
y_train_oh, y_test_oh = oh(y_train, NUM_CLASSES), oh(y_test, NUM_CLASSES)

In [None]:
def filter_indices(a, classes):
    """Boolean mask that keeps the entries of ``a`` whose label is not in ``classes``.

    Args:
        a: 1-D array of labels.
        classes: labels to exclude, as a list, tuple, ndarray or scalar. An
            empty collection (or ``None``) keeps every entry.

    Returns:
        bool array that is ``True`` where the entry of ``a`` should be kept.
    """
    # Normalise to a flat array so ndarrays and scalars work as well as lists.
    # ``not classes`` is ambiguous for multi-element arrays and treats 0 as "empty".
    excluded = np.asarray([] if classes is None else classes).ravel()
    if excluded.size == 0:
        return np.repeat(True, a.shape[0])
    return ~np.isin(a, excluded)

In [None]:
def filter_data(classes):
    """Remove every sample whose label is in ``classes`` from both CIFAR-10 splits.

    Reads the module-level ``x_train``/``y_train``/``y_train_oh`` and
    ``x_test``/``y_test``/``y_test_oh`` arrays. Besides dropping the matching
    rows, the one-hot targets also lose the columns of the removed classes.

    Returns:
        ``(x_train, y_train, y_train_oh, x_test, y_test, y_test_oh)`` restricted
        to the kept samples.
    """
    def _restrict(features, labels, labels_oh):
        # Same mask for inputs, labels and one-hot rows of one split.
        keep = filter_indices(labels, classes)
        return features[keep], labels[keep], np.delete(labels_oh[keep], classes, 1)

    train_part = _restrict(x_train, y_train, y_train_oh)
    test_part = _restrict(x_test, y_test, y_test_oh)
    return train_part + test_part

In [None]:
def make_split(data, ratio=.3):
    """Split ``data`` along its first axis into ``(rest, head)``.

    ``head`` holds the first ``int(len(data) * ratio)`` entries; ``rest`` holds
    everything after them. Callers use ``rest`` as the train part and ``head``
    as the held-out part.
    """
    cut = int(data.shape[0] * ratio)
    head = data[:cut]
    rest = data[cut:]
    return rest, head

def flatten(data):
    """Merge the first two axes of ``data`` into one.

    An array of shape ``(a, b, *rest)`` becomes ``(a * b, *rest)``. Order is
    row-major, so row ``i * b + j`` of the result is ``data[i, j]``. Any array
    with at least two axes is accepted; 3-D input behaves exactly as before.
    """
    return data.reshape((data.shape[0] * data.shape[1],) + tuple(data.shape[2:]))

In [None]:
def prepare_data(p1, p2, k=500, l=500, ratio=.3, rng=None):
    """Build a shuffled binary train/test set whose positives come from ``p1``.

    ``p1`` and ``p2`` are indexed ``[model, sample, feature]`` (here: stacked
    per-run prediction arrays such as ``P[key][:, mask]``). Each is split along
    the model axis with ``make_split``: the first ``int(n_models * ratio)`` models
    go to the test side and the rest to the train side. For each side, one random
    subset of sample positions (``k`` for train, ``l`` for test) is kept for all of
    its models, and the model and sample axes are merged with ``flatten``.

    Args:
        p1: array whose rows are labelled 1.0.
        p2: array whose rows are labelled 0.0.
        k: sample positions drawn per model for the training rows.
        l: sample positions drawn per model for the test rows.
        ratio: fraction of models held out for testing.
        rng: optional object with a ``permutation`` method, e.g.
            ``np.random.default_rng(seed)``, for reproducible draws. ``None``
            uses the global ``np.random`` state.

    Returns:
        ``(train_data, train_labels, test_data, test_labels)``; train and test
        are each shuffled independently.
    """
    if rng is None:
        rng = np.random

    def _draw(block, n_keep):
        # Keep the same random subset of sample positions for every model in ``block``.
        return flatten(block[:, rng.permutation(block.shape[1])[:n_keep]])

    def _split_and_draw(p):
        # Split models into train/test, then subsample positions on each side.
        train, test = make_split(p, ratio=ratio)
        return _draw(train, k), _draw(test, l)

    def _stack_and_shuffle(pos, neg):
        # Concatenate positives (label 1) and negatives (label 0), then shuffle rows.
        data = np.concatenate([pos, neg])
        labels = np.concatenate([np.ones(pos.shape[0]), np.zeros(neg.shape[0])])
        perm = rng.permutation(data.shape[0])
        return data[perm], labels[perm]

    # The RNG is consumed in the same order as before:
    # p1 train, p1 test, p2 train, p2 test, train shuffle, test shuffle.
    p1_train, p1_test = _split_and_draw(p1)
    p2_train, p2_test = _split_and_draw(p2)

    train_data, train_labels = _stack_and_shuffle(p1_train, p2_train)
    test_data, test_labels = _stack_and_shuffle(p1_test, p2_test)

    return (train_data, train_labels, test_data, test_labels)

In [None]:
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier

In [None]:
def compare(data, sm=False, silent=False, random_state=None):
    """Fit K-NN, Random Forest and AdaBoost on ``data`` and report accuracies.

    Args:
        data: ``(train_data, train_labels, test_data, test_labels)``, e.g. the
            output of ``prepare_data``.
        sm: if True, apply a softmax over the last axis of both feature
            matrices before fitting.
        silent: suppress all printing.
        random_state: forwarded to the Random Forest and AdaBoost classifiers
            for reproducible runs. ``None`` keeps scikit-learn's unseeded default.

    Returns:
        list of ``(train_acc, test_acc, clf_name)`` tuples, in the order
        K-NN, Random Forest, AdaBoost.
    """
    (train_data, train_labels, test_data, test_labels) = data
    if sm:
        train_data = softmax(train_data, axis=-1)
        test_data = softmax(test_data, axis=-1)

    if not silent:
        print('training classifiers on {} samples, testing on {} samples'.format(
            train_data.shape[0], test_data.shape[0]))

    clfs = {
        'K-NN': KNeighborsClassifier(5),
        'Random Forest': RandomForestClassifier(max_depth=5, n_estimators=10,
                                                random_state=random_state),
        'AdaBoost': AdaBoostClassifier(random_state=random_state),
    }

    r = []
    for clf_name, clf in clfs.items():
        clf.fit(train_data, train_labels)

        # Fraction of correctly classified rows.
        train_acc = np.mean(clf.predict(train_data) == train_labels)
        test_acc = np.mean(clf.predict(test_data) == test_labels)

        if not silent:
            print('{}: {} train accuracy'.format(clf_name, train_acc))
            print('{}: {} test accuracy'.format(clf_name, test_acc))

        r.append((train_acc, test_acc, clf_name))
    return r

In [None]:
# def test_pred_path(name, i, dir):
#     return os.path.join(dir, 'test_' + name + '_' + str(i) + '.npy')

# preds = np.stack([
#     np.load(test_pred_path('cifar10_cnn'+dense_units+'_full_transfer', i, 'predictions_cifar10')) for i in range(100)
# ])

# def test_pred_path_total(name, dir):
#     return os.path.join(dir, 'test_' + name + '.npy')

# np.save(test_pred_path_total('cifar10_cnn'+dense_units+'_full_transfer', 'predictions_cifar10'), preds)

In [None]:
# Saved test-set predictions keyed by setting. Judging by the commented-out cell
# above, each file stacks the predictions of several independent runs.
P = {}
for _key, _fname in (
    ('p_full', '_full.npy'),
    ('p_nocl0', '_del0.npy'),
    ('p_transfer', '_full_transfer.npy'),
):
    P[_key] = np.load(os.path.join('predictions_cifar10', 'test_cifar10_cnn' + dense_units + _fname))
# NOTE(review): 'p_transfer2' is the very same array as 'p_transfer' (no separate
# file is loaded) — confirm this placeholder is intended.
P['p_transfer2'] = P['p_transfer']
del _key, _fname

def filter_helper(key):
    """Classes missing from the predictions stored under ``key`` in ``P``.

    Only ``'p_full'`` (the model before unlearning) keeps every class; every
    other entry is evaluated with class 0 removed.
    """
    return [] if key == 'p_full' else [0]

In [None]:
# Naive baseline: keep the original model and simply drop its output column for class 0.
P.update(p_naive=P['p_full'][:, :, 1:])

In [None]:
import filtration

In [None]:
def sample(x, y, s):
    """Take the first ``s`` rows of ``x`` for each label ``0 .. NUM_CLASSES - 1``.

    ``y`` gives the label of each row of ``x``. The result is grouped by label
    (selected rows of class 0 first, then class 1, ...) and has at most
    ``NUM_CLASSES * s`` rows.
    """
    return np.concatenate([
        x[y == i][:s]
        for i in range(NUM_CLASSES)])


def apply_filtration(preds, mode, s=None):
    """Apply a ``filtration.filtration_matrix`` to every run in ``preds``.

    For each run, the matrix is built from that run's own test predictions and
    labels: all of them when ``s`` is None, otherwise only the first ``s``
    test samples of each class (via ``sample``). Class list ``[0]`` is always
    passed. The matrix then multiplies every prediction row of the run.

    Args:
        preds: array indexed ``[run, test_sample, class]``.
        mode: forwarded to ``filtration.filtration_matrix``
            ('normalization', 'randomization' or 'zeroing' below).
        s: optional per-class size of the reference set.

    Returns:
        ``np.stack`` of the filtered runs.
    """
    filtered = []
    for run in preds:
        if s is None:
            ref_x, ref_y = run, y_test
        else:
            ref_x, ref_y = sample(run, y_test, s), sample(y_test, y_test, s)
        F = filtration.filtration_matrix(ref_x, ref_y, NUM_CLASSES, [0], mode=mode)
        # Mathematically run @ F.T; kept in this form for identical numerics.
        filtered.append(np.transpose(np.matmul(F, np.transpose(run))))
    return np.stack(filtered)


P['p_lin_norm_mean'] = apply_filtration(P['p_full'], 'normalization')
P['p_lin_norm_mean10'] = apply_filtration(P['p_full'], 'normalization', s=10)
P['p_lin_norm_mean100'] = apply_filtration(P['p_full'], 'normalization', s=100)
P['p_lin_random'] = apply_filtration(P['p_full'], 'randomization')
P['p_lin_zero'] = apply_filtration(P['p_full'], 'zeroing')

In [None]:
# Human-readable label for each predictions key, used in the results table and plot legend.
Names = {
    'p_nocl0': 'Retraining',
    'p_nocl0_second': 'Retraining2',
    'p_lin_norm_mean': 'Normalization',
    'p_lin_norm_mean10': 'Normalization (s=10)',
    'p_lin_norm_mean100': 'Normalization (s=100)',
    'p_lin_random': 'Randomization',
    'p_lin_zero': 'Zeroing',
    'p_naive': 'Naive',
    'p_full': 'Before unlearning',
    'p_transfer': 'Transfer',
    'p_transfer2': 'Transfer2',
}

In [None]:
def make_label_map(filter, num_classes=None):
    """Original labels that remain after removing ``filter``, in ascending order.

    Entry ``j`` is the original label of output column ``j`` of predictions
    that omit the classes in ``filter``.

    Args:
        filter: iterable of removed class labels.
        num_classes: total number of classes; defaults to ``NUM_CLASSES``.

    Returns:
        list of the kept labels.
    """
    if num_classes is None:
        num_classes = NUM_CLASSES
    removed = set(filter)
    return [i for i in range(num_classes) if i not in removed]

def evaluate_predictions(predictions, filter=[]):
    """Mean/std accuracy and cross-entropy over a stack of prediction runs.

    Test samples whose label is in ``filter`` are dropped. The remaining
    prediction columns are mapped back to original labels via
    ``make_label_map(filter)``.

    Args:
        predictions: array indexed ``[run, test_sample, class]`` of unnormalised
            scores (a softmax is applied here; presumably logits — confirm
            against how the .npy files were produced).
        filter: classes absent from ``predictions``' columns. Not mutated.

    Returns:
        ``(acc_mean, loss_mean, acc_std, loss_std)``, with accuracy in percent.
    """
    from scipy.special import log_softmax

    l_map = np.asarray(make_label_map(filter))
    f_i = filter_indices(y_test, filter)
    y_kept = y_test[f_i]
    # Equals filter_data(filter)[5], but without copying the whole training set.
    labels_oh = np.delete(y_test_oh[f_i], filter, 1)

    accs = []
    losses = []
    for p in predictions[:, f_i]:
        # Arg-max column -> original label, then compare to ground truth.
        accs.append((l_map[np.argmax(p, axis=1)] == y_kept).mean())

        # Cross-entropy via log_softmax: log(softmax) underflows to log(0) = -inf
        # for very confident wrong predictions (especially with float32 scores).
        losses.append(-np.sum(log_softmax(p, axis=1) * labels_oh, axis=1).mean())

    accs = np.array(accs) * 100
    losses = np.array(losses)
    return accs.mean(), losses.mean(), accs.std(), losses.std()

In [None]:
# One row per setting: accuracy (mean +- std, %), loss (mean +- std), name.
for key, preds in P.items():
    acc, loss, acc_std, loss_std = evaluate_predictions(preds, filter_helper(key))
    row = '{:0.1f} +- {:0.2f} \t {:0.2f} +- {:0.2f} \t {}'.format(acc, acc_std, loss, loss_std, Names[key])
    print(row)

In [None]:
# (candidate, reference) pairs: each candidate's outputs are compared against
# the retrained model's outputs. Other candidates that can be enabled:
# 'p_lin_norm_mean10', 'p_lin_norm_mean100', 'p_lin_random', 'p_lin_zero'.
C = [(candidate, 'p_nocl0') for candidate in (
    'p_naive',
    'p_lin_norm_mean',
    'p_transfer',
    'p_transfer2',
)]

In [None]:
N = 1

# For each candidate: one score array per classifier returned by ``compare``
# (three of them), with one slot per class in range(N).
A = {p1: [np.zeros(N) for _ in range(3)] for (p1, _) in C}

for cls in tqdm(range(N)):
    mask = y_test == cls
    for (candidate, reference) in C:
        # Classifiers try to tell the candidate's outputs (label 1) from the
        # reference's (label 0) on test images of class ``cls``.
        data = prepare_data(P[candidate][:, mask], P[reference][:, mask], k=150, l=150)
        results = compare(data, sm=False, silent=True)
        for clf_idx, (_, test_acc, _) in enumerate(results):
            # Map chance-level accuracy 0.5 to 0 and perfect separation 1.0 to 1.
            A[candidate][clf_idx][cls] += (test_acc - .5) * 2

for (candidate, _) in C:
    scores = ''.join('{:0.3f} \t'.format(A[candidate][clf_idx].mean()) for clf_idx in range(3))
    print(scores + '{}'.format(Names[candidate]))

In [None]:
N = 1

# Same experiment on the retained classes N .. NUM_CLASSES - 1;
# slot ``cls - N`` of each per-classifier array holds class ``cls``.
A = {p1: [np.zeros(NUM_CLASSES - N) for _ in range(3)] for (p1, _) in C}

for cls in tqdm(range(N, NUM_CLASSES)):
    mask = y_test == cls
    for (candidate, reference) in C:
        data = prepare_data(P[candidate][:, mask], P[reference][:, mask], k=150, l=150)
        results = compare(data, sm=False, silent=True)
        for clf_idx, (_, test_acc, _) in enumerate(results):
            # Map chance-level accuracy 0.5 to 0 and perfect separation 1.0 to 1.
            A[candidate][clf_idx][cls - N] += (test_acc - .5) * 2

for (candidate, _) in C:
    scores = ''.join('{:0.3f} \t'.format(A[candidate][clf_idx].mean()) for clf_idx in range(3))
    print(scores + '{}'.format(Names[candidate]))

In [None]:
def make_distributions(x):
    """Per-run histogram of arg-max predictions.

    Args:
        x: 3-D array indexed ``[run, sample, class]`` of scores.

    Returns:
        float array of shape ``(runs, classes)``; entry ``[m, i]`` is the
        fraction of samples in run ``m`` whose arg-max is class ``i``.
    """
    n_classes = x.shape[2]
    predicted = np.argmax(x, axis=-1)                               # (runs, samples)
    hits = predicted[:, :, np.newaxis] == np.arange(n_classes)      # (runs, samples, classes)
    return hits.sum(axis=1) / predicted.shape[1]

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
# import tikzplotlib 

In [None]:
X = [
    'p_nocl0',
    'p_lin_norm_mean',
    'p_transfer',
    'p_naive'
]

# CIFAR-10 names of classes 1..9, i.e. the classes left after removing class 0 (airplane).
remaining_class_names = [
    'autom.',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
]
positions = np.arange(1, len(remaining_class_names) + 1)
colors = ['#D81B60', '#3892E0', '#FFC107', 'tab:green']
offset = .07  # horizontal shift per method so the error bars do not overlap

fig, ax = plt.subplots()

legends = []
for i, (p, c) in enumerate(zip(X, colors)):
    # Fraction of class-0 test images that each run assigns to every remaining
    # class; drawn as mean +- std over runs.
    x = make_distributions(P[p][:, y_test == 0])
    ax.errorbar(positions + offset * i, x.mean(axis=0), yerr=x.std(axis=0),
                fmt='none', capsize=5, ecolor=c)
    legends.append(mpatches.Patch(color=c, label=Names[p]))

ax.legend(handles=legends)
ax.set_ylabel('Probability')
ax.set_xlabel('Predicted class')
ax.set_title('Predictions on test images of the forgotten class')
ax.set_xticks(positions)
ax.set_xticklabels(remaining_class_names, rotation=45)
plt.show()
