The problem - Retraining from scratch is computationally expensive, because every parameter starts from a random state and has to be re-learned through slow gradient-descent backpropagation.

The approach - Retain the global information already present in the original model, confuse its vision by perturbing its weights, and then fine-tune it so that the vision is properly reconstructed.

Custom loss function -

$loss = -\frac{1}{N} \displaystyle\sum_{i=1}^{N} w_i \log(p_i)$

$p_i - predicted\ probability\ of\ the\ true\ class\ of\ the\ i^{th}\ image$

$w_i - weight\ associated\ with\ the\ target\ class\ of\ the\ i^{th}\ image$

where every class is assigned a weight of 0.5, except *'class 0'*, which is assigned a weight of 1.





In [None]:
!pip install tensorflow



In [None]:
import os
import warnings
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from tqdm import tqdm
from matplotlib import pyplot as plt
from sklearn import linear_model, model_selection


import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import losses
from tensorflow.keras import optimizers
from tensorflow.keras import metrics
from tensorflow.python.client import device_lib

from tensorflow.keras.applications import ResNet50
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam
from datetime import datetime

In [None]:
AUTOTUNE = tf.data.AUTOTUNE
BATCH_SIZE = 128
SEED = 42

# Seed Python, NumPy and TensorFlow in one call so that shuffling, weight
# initialisation and the noise injected during unlearning are repeatable.
keras.utils.set_random_seed(SEED)

# CIFAR-10 ships as a training split plus one held-out split.
(X_train, y_train), held_out = keras.datasets.cifar10.load_data()

# Divide the held-out split: 20% becomes the test set, the remaining 80% the
# validation set.
test_set, val_set = keras.utils.split_dataset(held_out, left_size=0.2)


In [None]:
def normalize(image, label, denorm=False):
    """Standardise an image with fixed per-channel statistics, or undo it.

    Args:
        image: Image tensor/array whose last axis holds the three channels.
        label: Target associated with the image; returned untouched.
        denorm: When False (default) the image is rescaled by 1/255 and then
            standardised. When True the standardisation is inverted instead
            and no rescaling is applied.

    Returns:
        A ``(image, label)`` tuple with the transformed image.
    """
    channel_mean = [0.4914, 0.4822, 0.4465]
    channel_std = [0.2023, 0.1994, 0.2010]

    standardize = keras.layers.Normalization(
        axis=-1,
        mean=channel_mean,
        variance=[np.square(std) for std in channel_std],
        invert=denorm,
    )

    if denorm:
        # Inverse transform only: the input is already in standardised space.
        return standardize(image), label

    scaled = keras.layers.Rescaling(scale=1./255.)(image)
    return standardize(scaled), label

In [None]:
# Training pipeline: normalise every sample, shuffle, then batch and prefetch.
train_ds = (
    tf.data.Dataset.from_tensor_slices((X_train, y_train))
    .map(normalize)
    .shuffle(buffer_size=8*BATCH_SIZE)
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
)

# Evaluation pipelines: same normalisation and batching, but no shuffling.
val_ds, test_ds = (
    split.map(normalize).batch(BATCH_SIZE).prefetch(AUTOTUNE)
    for split in (val_set, test_set)
)

In [None]:
def get_model():
  """Build and compile a 10-class classifier on top of an ImageNet ResNet50.

  The ResNet50 backbone (without its original top) is followed by global
  average pooling, a 128-unit ReLU layer and a 10-way softmax output.

  Returns:
      The compiled Sequential model (Adam optimizer, sparse categorical
      cross-entropy loss, accuracy metric, XLA compilation enabled).
  """
  backbone = ResNet50(weights='imagenet', include_top=False)

  classifier = Sequential()
  classifier.add(backbone)
  classifier.add(GlobalAveragePooling2D())
  classifier.add(Dense(128, activation='relu'))
  classifier.add(Dense(10, activation='softmax'))

  compile_options = dict(
      optimizer='adam',
      loss='sparse_categorical_crossentropy',
      metrics=['accuracy'],
      jit_compile=True,
  )
  classifier.compile(**compile_options)
  return classifier

In [None]:
model = get_model()
# `train_ds` is already batched with BATCH_SIZE, so `batch_size` must not be
# passed to `fit` for dataset inputs.
model.fit(train_ds, validation_data=val_ds, epochs=2)

Epoch 1/2
Epoch 2/2


<keras.src.callbacks.History at 0x7cd3ac1fb850>

In [None]:
# `evaluate` returns [loss, accuracy]; the last entry is the accuracy.
# A single '%' is used because '%%' is not an escape sequence in f-strings.
print(f"Train set accuracy: {100.0 * model.evaluate(train_ds)[-1]:0.1f}%")
print(f"Test set accuracy: {100.0 * model.evaluate(test_ds)[-1]:0.1f}%")

Train set accuracy: 81.1%%
Test set accuracy: 75.1%%


In [None]:
# Shape of one raw (un-normalised) training image.
X_train[0].shape

(32, 32, 3)

In [None]:
# Predict class probabilities for the first training image. The model was
# trained on normalised inputs, so the raw image goes through `normalize`
# first; `predict` (not `evaluate`) is what returns predictions.
sample_image, _ = normalize(X_train[0:1], y_train[0:1])
y_pred = model.predict(sample_image)



In [None]:
# Colab-only: mount Google Drive at /content/gdrive so files can be written
# under that path.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
#model.save('/content/gdrive/res_net.h5')

## Unlearning Algorithm

In [None]:
# Carve the (already normalised) training samples into a forget set (10%)
# and a retain set (90%); the split operates on individual samples, hence
# the `unbatch()`.
forget_set, retain_set = keras.utils.split_dataset(train_ds.unbatch(), left_size=0.1)

# Re-batch both subsets for training / evaluation.
forget_ds, retain_ds = (
    subset.batch(BATCH_SIZE).prefetch(AUTOTUNE)
    for subset in (forget_set, retain_set)
)

# Number of batches in each subset.
tuple(int(batched.cardinality()) for batched in (forget_ds, retain_ds))

(40, 352)

In [None]:
def unlearning(net, retain, forget, validation, noise_std=0.6, epochs=5):
  """Unlearn by perturbing the convolutional weights and fine-tuning.

  The network is re-compiled with a class-weighted cross-entropy loss, Gaussian
  noise is added to its convolutional weights, and it is then fine-tuned on the
  retain set. `net` is modified in place.

  Args:
      net: Compiled or uncompiled Keras model to unlearn; mutated in place.
      retain: Batched dataset of samples the model should keep performing on.
      forget: Dataset of samples to forget. Currently unused; kept so existing
          callers keep working.
      validation: Validation dataset. Currently unused; kept so existing
          callers keep working.
      noise_std: Standard deviation of the Gaussian noise added to the weights.
          NOTE(review): 0.6 is likely large compared with typical convolution
          weight magnitudes; lower it if fine-tuning fails to recover accuracy.
      epochs: Number of fine-tuning epochs on the retain set.

  Returns:
      The same `net` object after perturbation and fine-tuning.
  """

  def custom_cross_entropy_loss(class_weights=None):
    """Return a sparse categorical cross-entropy with optional class weights."""
    def loss(y_true, y_pred):
      ce_loss = tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred)
      if class_weights is not None:
        # Flatten the labels so the gathered weights have the same (batch,)
        # shape as ce_loss even when y_true is (batch, 1). Without this the
        # product broadcasts to (batch, batch) and every sample ends up scaled
        # by the mean weight instead of its own class weight.
        labels = tf.reshape(tf.cast(y_true, dtype=tf.int32), [-1])
        weights = tf.gather(tf.cast(class_weights, ce_loss.dtype), labels)
        ce_loss = tf.reduce_mean(ce_loss * weights)
      return ce_loss
    return loss

  def add_noise_to_weights(layer, std=0.6):
    """Add zero-mean Gaussian noise to all weights of a Conv2D layer named '*conv*'."""
    if isinstance(layer, tf.keras.layers.Conv2D) and 'conv' in layer.name:
      weights = layer.get_weights()
      weights_with_noise = [
          w + np.random.normal(0, std, w.shape).astype(w.dtype) for w in weights
      ]
      layer.set_weights(weights_with_noise)

  def vision_confuser(model, std=0.6):
    """Perturb every matching conv layer, descending into nested models."""
    for layer in model.layers:
      # A nested model (such as a backbone wrapped inside a Sequential) exposes
      # its own `.layers`; without recursing, its Conv2D layers are never
      # visited and no noise is applied at all.
      if getattr(layer, 'layers', None):
        vision_confuser(layer, std)
      else:
        add_noise_to_weights(layer, std)

  # 'class 0' keeps full weight; the nine other classes are down-weighted.
  class_weights = [1.0] + [0.5] * 9

  loss = custom_cross_entropy_loss(class_weights)
  metric = metrics.SparseCategoricalAccuracy(name='accuracy')
  optim = optimizers.SGD(momentum=0.9, weight_decay=5e-4)

  net.compile(
      optimizer=optim,
      loss=loss,
      metrics=[metric],
  )

  vision_confuser(net, std=noise_std)

  # `summary()` prints by itself and returns None, so it is not wrapped in print().
  net.summary()

  net.fit(retain, verbose=1, epochs=epochs)

  return net

In [None]:
# Run the unlearning procedure on the trained model. `unlearning` re-compiles
# and fine-tunes the model it is given and returns that same object, so
# `model_ft` and `model` refer to the same network afterwards.
model_ft = unlearning(model, retain_ds, forget_ds, test_ds)

Model: "sequential_3"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 resnet50 (Functional)       (None, None, None, 2048   23587712  
                             )                                   
                                                                 
 global_average_pooling2d_3  (None, 2048)              0         
  (GlobalAveragePooling2D)                                       
                                                                 
 dense_6 (Dense)             (None, 128)               262272    
                                                                 
 dense_7 (Dense)             (None, 10)                1290      
                                                                 
Total params: 23851274 (90.99 MB)
Trainable params: 23798154 (90.78 MB)
Non-trainable params: 53120 (207.50 KB)
_________________________________________________________________
None
Epo

In [None]:
# `evaluate` returns [loss, accuracy]; the last entry is the accuracy.
# A single '%' is used because '%%' is not an escape sequence in f-strings.
print(f"Retain set accuracy: {100.0 * model_ft.evaluate(retain_ds)[-1]:0.1f}%")
print(f"Test set accuracy: {100.0 * model_ft.evaluate(test_ds)[-1]:0.1f}%")

Retain set accuracy: 92.3%%
Test set accuracy: 80.0%%


In [None]:
# [loss, accuracy] on the forget set. `model` is the object that `unlearning`
# fine-tuned in place and returned as `model_ft`.
model.evaluate(forget_ds)



[0.35386016964912415, 0.8586000204086304]

In [None]:
# [loss, accuracy] on the retain set, for comparison with the forget set.
model.evaluate(retain_ds)



[0.05112976208329201, 0.9691555500030518]