## Selective Amnesia: A Continual Learning Approach for Forgetting in Deep Generative Models

### Learning Objectives

<center>
<img src="https://github.com/clear-nus/selective-amnesia/raw/main/assets/main_fig.png" width=700px/>
</center>
<br><br>

Figure 1: Qualitative results of our method, Selective Amnesia (SA).

SA can be applied to a variety of models, from forgetting textual prompts such as specific celebrities or nudity in text-to-image models to discrete classes in VAEs and diffusion models (DDPM).


### Description:

Dataset Description:

The MNIST dataset contains 60,000 handwritten-digit training images and 10,000 test images, so each digit appears approximately 6,000 times in the training set and 1,000 times in the test set. Each image is size-normalized and centered, and is 28 × 28 pixels with grayscale values from 0 to 255. Each image can therefore be represented as a 784-dimensional (28 × 28) vector whose values lie in the range 0 to 255.
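To make the 28 × 28 → 784 relationship concrete, here is a small NumPy sketch. It uses a random array as a stand-in for a real MNIST image, so it assumes no dataset download or loader.

```python
import numpy as np

# Stand-in for one MNIST image: 28 x 28 grayscale pixels with values 0..255.
# Random data here; a real image would come from a dataset loader.
rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(28, 28), dtype=np.uint8)

# Flatten row by row into the 784-dimensional vector described above.
vector = image.reshape(-1)
print(vector.shape)  # (784,)

# Many models expect floats in [0, 1]; divide by the maximum gray value.
scaled = vector.astype(np.float32) / 255.0
```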

CIFAR-10 is an established computer-vision dataset used for object recognition. It is a subset of the 80 Million Tiny Images dataset and consists of 60,000 32×32 color images, each containing one of 10 object classes, with 6,000 images per class. It was collected by Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton.

## Grading = 10 Marks

Here is a handy link to Kaggle's competition documentation (https://www.kaggle.com/docs/competitions), which includes, among other things, instructions on submitting predictions (https://www.kaggle.com/docs/competitions#making-a-submission).

### Instructions for downloading train and test data are as follows:

### 1. Create an API key in Kaggle.

To do this, go to the competition site on Kaggle (https://www.kaggle.com/t/c8bda808fac2419d8025370763a90ada), click your user icon, and open your profile as shown below. Then click Account.

![alt text](https://cdn.iisc.talentsprint.com/DLFA/Experiment_related_data/Capture-NLP.PNG)

### 2. Next, scroll down to the API access section and click on **Create New Token** to download an API key (kaggle.json).

![alt text](https://cdn.iisc.talentsprint.com/DLFA/Experiment_related_data/Capture-NLP_1.PNG)

### 3. Upload your kaggle.json file using the following snippet in a code cell:



In [1]:
from google.colab import files
files.upload()

Saving kaggle.json to kaggle.json


{'kaggle.json': b'{"username":"somyabaral","key":"<redacted>"}'}

In [2]:
#If successfully uploaded in the above step, the 'ls' command here should display the kaggle.json file.
%ls

kaggle.json  sample_data/


### 4. Install the Kaggle API using the following command


In [3]:
!pip install -U -q kaggle==1.5.8

     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 59.2/59.2 kB 5.6 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
  Preparing metadata (setup.py) ... done
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 118.8/118.8 kB 10.8 MB/s eta 0:00:00
  Building wheel for kaggle (setup.py) ... done
  Building wheel for slugify (setup.py) ... done
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torchdata 0.11.0 requires urllib3>=1.25, but you have urllib3 1.24.3 which is incompatible.
blobfile 3.0.0 requires urllib3<3,>=1.25.3, but you have urllib3 1.24.3 which is incompatible.
distributed 2024.12.1 requires urllib3>

### 5. Move the kaggle.json file into ~/.kaggle, which is where the API client expects your token to be located:



In [4]:
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/

In [5]:
#Execute the following command to verify whether the kaggle.json is stored in the appropriate location: ~/.kaggle/kaggle.json
!ls ~/.kaggle

kaggle.json


In [6]:
!chmod 600 /root/.kaggle/kaggle.json #run this command to ensure your Kaggle API token is secure on colab
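The mkdir, cp and chmod cells above can also be done from Python, which avoids depending on shell behavior. This is a minimal sketch; `install_kaggle_token` is a hypothetical helper name, and it assumes `kaggle.json` sits in the current working directory (where `files.upload()` saved it).

```python
import os
import shutil
import stat
from pathlib import Path


def install_kaggle_token(src="kaggle.json", home=None):
    """Copy a Kaggle API token to <home>/.kaggle/kaggle.json with mode 600.

    Mirrors the mkdir / cp / chmod cells above. `home` defaults to the
    current user's home directory (/root on Colab).
    """
    home = Path(home) if home is not None else Path.home()
    kaggle_dir = home / ".kaggle"
    kaggle_dir.mkdir(parents=True, exist_ok=True)   # mkdir -p ~/.kaggle
    dest = kaggle_dir / "kaggle.json"
    shutil.copyfile(src, dest)                       # cp kaggle.json ~/.kaggle/
    os.chmod(dest, stat.S_IRUSR | stat.S_IWUSR)      # chmod 600: owner read/write only
    return dest
```

In Colab, calling `install_kaggle_token()` right after the upload cell has the same effect as the three shell cells.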

In [7]:
!git clone https://github.com/somyaranjan84/selective-amnesia.git

Cloning into 'selective-amnesia'...
remote: Enumerating objects: 234, done.
remote: Counting objects: 100% (234/234), done.
remote: Compressing objects: 100% (153/153), done.
remote: Total 234 (delta 67), reused 211 (delta 52), pack-reused 0 (from 0)
Receiving objects: 100% (234/234), 5.61 MiB | 20.08 MiB/s, done.
Resolving deltas: 100% (67/67), done.


In [1]:
import sys
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    !pip install -q condacolab
    import condacolab
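    condacolab.install()  # this restarts the Colab kernel; re-run the cell once it is back

# Create a new conda environment (change environment_name to use a different name)
environment_name = "sa-dddpm"
!conda create --name {environment_name} python=3.8 -y

# Activate the environment, then check the Python version and installed packages.
# Each ! line runs in its own shell, so this activation applies only to this line.
!source activate {environment_name} && python --version && pip list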

✨🍰✨ Everything looks OK!
Channels:
 - conda-forge
Platform: linux-64
Collecting package metadata (repodata.json): done
Solving environment: done


    current version: 24.11.2
    latest version: 25.5.0

Please update conda by running

    $ conda update -n base -c conda-forge conda



## Package Plan ##

  environment location: /usr/local/envs/sa-dddpm

  added / updated specs:
    - python=3.8


The following NEW packages will be INSTALLED:

  _libgcc_mutex      conda-forge/linux-64::_libgcc_mutex-0.1-conda_forge 
  _openmp_mutex      conda-forge/linux-64::_openmp_mutex-4.5-2_gnu 
  bzip2              conda-forge/linux-64::bzip2-1.0.8-h4bc722e_7 
  ca-certificates    conda-forge/noarch::ca-certificates-2025.4.26-hbd8a1cb_0 
  ld_impl_linux-64   conda-forge/linux-64::ld_impl_linux-64-2.43-h712a8e2_4 
  libffi             conda-forge/linux-64::libffi-3.4.6-h2dba641_1 
  libgcc             conda-forge/linux-64::libgcc-15.1.0-h767d61c_2 
  libgcc-ng          c

In [2]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Mon Jun  2 03:06:40 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA L4                      Off |   00000000:00:03.0 Off |                    0 |
| N/A   52C    P8             18W /   72W |       0MiB /  23034MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [1]:
!pip install -r /content/selective-amnesia/ddpm/requirements.txt



In [2]:
!pip install einops



### 1. First, train a conditional DDPM on all 10 CIFAR10/STL10 classes. Specify GPUs with the CUDA_VISIBLE_DEVICES environment variable. The code below runs SA on CIFAR10; the STL10 experiments use the same commands with the config and dataset flags replaced accordingly.

In [3]:
# Shorten the run for this Colab demo (the shipped config trains for 800,000 iterations)
!sed -i 's/n_iters: 800000/n_iters: 5000/g' /content/selective-amnesia/ddpm/configs/cifar10_train.yml
!sed -i 's/snapshot_freq: 5000/snapshot_freq: 2000/g' /content/selective-amnesia/ddpm/configs/cifar10_train.yml
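The `sed` commands only match the exact strings `n_iters: 800000` and `snapshot_freq: 5000`, so they silently do nothing if the file already holds other values (for instance, when the cell is run a second time). A Python alternative that sets a key whatever its current value is sketched below. `set_yaml_scalar` is a hypothetical helper that does line-based substitution, not real YAML parsing.

```python
import re
from pathlib import Path


def set_yaml_scalar(path, key, value):
    """Set `key: <old>` to `key: value` on every matching line of a YAML file.

    Line-based sketch, not a YAML parser: indentation and trailing comments
    are kept, but a key that appears in several sections (e.g. batch_size)
    is changed everywhere. Returns the number of lines changed.
    """
    pattern = re.compile(
        rf"^([ \t]*{re.escape(key)}:[ \t]*)([^#\n]*?)([ \t]*(?:#[^\n]*)?)$",
        re.MULTILINE,
    )
    text = Path(path).read_text()
    new_text, n = pattern.subn(lambda m: f"{m.group(1)}{value}{m.group(3)}", text)
    Path(path).write_text(new_text)
    return n
```

For example, `set_yaml_scalar("/content/selective-amnesia/ddpm/configs/cifar10_train.yml", "n_iters", 5000)` gives the same result as the first `sed` line, and still works if the value was already changed.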

In [4]:
!cat /content/selective-amnesia/ddpm/configs/cifar10_train.yml
!cat /content/selective-amnesia/ddpm/configs/cifar10_fim.yml
!cat /content/selective-amnesia/ddpm/configs/cifar10_forget.yml

data:
    path: ./data
    dataset: CIFAR10
    image_size: 32
    channels: 3
    logit_transform: false
    uniform_dequantization: false
    gaussian_dequantization: false
    random_flip: true
    rescaled: true
    num_workers: 4
    n_classes: 10

model:
    type: simple
    in_channels: 3
    out_ch: 3
    ch: 128
    ch_mult: [1, 2, 2, 2]
    num_res_blocks: 2
    attn_resolutions: [16, ]
    dropout: 0.1
    var_type: fixedlarge
    ema_rate: 0.9999
    ema: True
    resamp_with_conv: True
    cond_drop_prob: 0.1

diffusion:
    beta_schedule: linear
    beta_start: 0.0001
    beta_end: 0.02
    num_diffusion_timesteps: 1000

training:
    batch_size: 128
    n_iters: 5000
    snapshot_freq: 2000
    log_freq: 50
    visualization_samples: 100

sampling:
    batch_size: 128
    last_only: True

optim:
    weight_decay: 0.000
    optimizer: "Adam"
    lr: 0.0002
    beta1: 0.9
    amsgrad: false
    eps: 0.00000001
    grad_clip: 1.0

comments: nil
data:
    image_size: 32
    c
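The `diffusion` block above specifies a linear β schedule from 0.0001 to 0.02 over 1000 steps. The NumPy sketch below shows what those numbers imply under the standard DDPM formulation: the cumulative product ᾱ_t, closed-form noising x_t = √ᾱ_t · x_0 + √(1 − ᾱ_t) · ε, and the noise-prediction loss. `predict_noise` is a hypothetical placeholder for the U-Net, the [-1, 1] input range is an assumption about what `rescaled: true` means, and the repository may reduce the loss differently.

```python
import numpy as np

# Values from the `diffusion` block of cifar10_train.yml.
beta_start, beta_end, T = 0.0001, 0.02, 1000

betas = np.linspace(beta_start, beta_end, T)  # linear schedule beta_1..beta_T
alphas_bar = np.cumprod(1.0 - betas)           # alpha_bar_t = prod_{s<=t} (1 - beta_s)


def q_sample(x0, t, noise):
    """Draw x_t ~ q(x_t | x_0) in closed form for integer timestep index t."""
    a = alphas_bar[t]
    return np.sqrt(a) * x0 + np.sqrt(1.0 - a) * noise


def predict_noise(x_t, t, label):
    """Hypothetical stand-in for the conditional U-Net; always predicts zero."""
    return np.zeros_like(x_t)


rng = np.random.default_rng(0)
x0 = rng.uniform(-1.0, 1.0, size=(3, 32, 32))  # stand-in image, assumed scaled to [-1, 1]
t = rng.integers(0, T)                         # random timestep index
noise = rng.standard_normal(x0.shape)
x_t = q_sample(x0, t, noise)

# Noise-prediction objective: squared error between true and predicted noise.
loss = np.sum((noise - predict_noise(x_t, t, label=3)) ** 2)
print(f"alpha_bar at t=0: {alphas_bar[0]:.4f}, at t=T-1: {alphas_bar[-1]:.2e}")
```

Since ᾱ_t is close to zero at the last step, x_T is almost pure Gaussian noise, which is what sampling starts from.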

In [5]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
!python /content/selective-amnesia/ddpm/train.py --config /content/selective-amnesia/ddpm/configs/cifar10_train.yml --mode train

INFO - train.py - 2025-06-02 03:13:15,032 - Writing log file to ./results/cifar10/2025_06_02_031315/logs
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz
100% 170498071/170498071 [00:13<00:00, 13014561.49it/s]
Extracting ./data/cifar-10-python.tar.gz to ./data
INFO - diffusion.py - 2025-06-02 03:14:11,860 - step: 49, loss: 224.9555206298828, time: 38.105074644088745
INFO - diffusion.py - 2025-06-02 03:14:41,825 - step: 99, loss: 168.1160125732422, time: 29.791980981826782
INFO - diffusion.py - 2025-06-02 03:15:11,862 - step: 149, loss: 198.83755493164062, time: 29.86226487159729
INFO - diffusion.py - 2025-06-02 03:15:41,781 - step: 199, loss: 165.90924072265625, time: 29.745646238327026
INFO - diffusion.py - 2025-06-02 03:16:11,756 - step: 249, loss: 108.78804016113281, time: 29.801692724227905
INFO - diffusion.py - 2025-06-02 03:16:41,753 - step: 299, loss: 149.55972290039062, time: 29.82290506362915
INFO - diffusion.py - 2025-06-02 
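The training cell above sets `CUDA_VISIBLE_DEVICES="0,1"`, while the `nvidia-smi` output earlier shows a single NVIDIA L4. According to the CUDA documentation, the variable lists device indices in order, and only the devices listed before the first invalid index are visible, so on this one-GPU machine `"0,1"` exposes just device 0. The helper below is a hypothetical, simplified model of that rule (integer indices only, no GPU UUIDs), not part of the repository.

```python
def visible_devices(cuda_visible_devices, n_physical):
    """Return the physical GPU indices a CUDA program would see.

    Simplified model of CUDA's rule: indices are used in the order given,
    and the list ends at the first index that is invalid on this machine.
    Only plain integer indices are handled (no GPU UUIDs or MIG names).
    """
    if cuda_visible_devices is None:
        return list(range(n_physical))  # unset: all devices are visible
    visible = []
    for token in cuda_visible_devices.split(","):
        token = token.strip()
        if not token.isdigit() or int(token) >= n_physical:
            break                       # first invalid entry ends the list
        visible.append(int(token))
    return visible


print(visible_devices("0,1", n_physical=1))  # [0]
print(visible_devices("1,0", n_physical=2))  # [1, 0]
```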

In [None]:
!ls /content/results/cifar10/2025_06_02_031315/ckpts/

ls: cannot access '/content/results/cifar10/2025_05_31_062347/ckpts/': No such file or directory


In [None]:
!cat /content/selective-amnesia/ddpm/configs/cifar10_sample.yml

data:
    path: ./data
    dataset: CIFAR10
    image_size: 32
    channels: 3
    logit_transform: false
    uniform_dequantization: false
    gaussian_dequantization: false
    random_flip: true
    rescaled: true
    num_workers: 4
    n_classes: 10

model:
    type: simple
    in_channels: 3
    out_ch: 3
    ch: 128
    ch_mult: [1, 2, 2, 2]
    num_res_blocks: 2
    attn_resolutions: [16, ]
    dropout: 0.1
    var_type: fixedlarge
    ema_rate: 0.9999
    ema: True
    resamp_with_conv: True
    cond_drop_prob: 0.1

diffusion:
    beta_schedule: linear
    beta_start: 0.0001
    beta_end: 0.02
    num_diffusion_timesteps: 1000

training:
    visualization_samples: 100

sampling:
    batch_size: 512
    last_only: True

comments: nil
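`cond_drop_prob: 0.1` in the `model` block suggests the conditional DDPM is trained for classifier-free guidance: the class label is replaced by a "null" label 10% of the time, so one network learns both conditional and unconditional noise predictions. The folder name `fid_samples_without_label_0_guidance_2.0` used later suggests a guidance scale of 2.0 at sampling time. The sketch below shows both pieces under those assumptions. `NULL_LABEL` and the rule `eps_uncond + w * (eps_cond - eps_uncond)` are one common convention, not necessarily the exact one in `runners/diffusion.py`.

```python
import numpy as np

N_CLASSES = 10
NULL_LABEL = N_CLASSES  # hypothetical extra index meaning "no class"


def drop_labels(labels, cond_drop_prob, rng):
    """Training: replace each label by NULL_LABEL with probability cond_drop_prob."""
    labels = np.asarray(labels).copy()
    mask = rng.random(labels.shape) < cond_drop_prob
    labels[mask] = NULL_LABEL
    return labels


def guided_noise(eps_cond, eps_uncond, w):
    """Sampling: extrapolate from the unconditional toward the conditional prediction.

    w = 0 gives the unconditional prediction and w = 1 the conditional one;
    w > 1 strengthens the class conditioning.
    """
    return eps_uncond + w * (eps_cond - eps_uncond)


rng = np.random.default_rng(0)
labels = rng.integers(0, N_CLASSES, size=10_000)
dropped = drop_labels(labels, cond_drop_prob=0.1, rng=rng)
print((dropped == NULL_LABEL).mean())  # close to 0.1
```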

### 2. Next, generate class samples, which are used to calculate the FIM and later as the generative replay (GR) samples.

In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

!python /content/selective-amnesia/ddpm/sample.py --config /content/selective-amnesia/ddpm/configs/cifar10_sample.yml --ckpt_folder /content/results/cifar10/2025_06_02_031315/ --mode sample_classes --n_samples_per_class 50

Generating image samples for class 0 to use as dataset: 100% 1/1 [02:38<00:00, 158.64s/it]
Generating image samples for class 1 to use as dataset: 100% 1/1 [02:37<00:00, 157.09s/it]
Generating image samples for class 2 to use as dataset: 100% 1/1 [02:37<00:00, 157.13s/it]
Generating image samples for class 3 to use as dataset: 100% 1/1 [02:37<00:00, 157.11s/it]
Generating image samples for class 4 to use as dataset: 100% 1/1 [02:37<00:00, 157.07s/it]
Generating image samples for class 5 to use as dataset: 100% 1/1 [02:37<00:00, 157.13s/it]
Generating image samples for class 6 to use as dataset: 100% 1/1 [02:37<00:00, 157.13s/it]
Generating image samples for class 7 to use as dataset: 100% 1/1 [02:37<00:00, 157.08s/it]
Generating image samples for class 8 to use as dataset: 100% 1/1 [02:37<00:00, 157.05s/it]
Generating image samples for class 9 to use as dataset: 100% 1/1 [02:37<00:00, 157.10s/it]
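The forgetting run in step 4 reads the class samples generated here back from `<ckpt_folder>/class_samples` with torchvision's `ImageFolder`, which expects one subdirectory per class (for example `class_samples/0`, as used in step 6). Before moving on, it can help to confirm that the folder exists and that every class has the expected number of images. `check_class_samples` is a hypothetical helper using only the standard library, and the `.png` extension is an assumption about the sample files.

```python
from pathlib import Path


def check_class_samples(ckpt_folder, n_classes=10, expected_per_class=None, suffix=".png"):
    """Count sample images per class under <ckpt_folder>/class_samples/<label>/.

    Returns {label: count}. Raises FileNotFoundError if the root folder is
    missing and ValueError if a class folder is missing or has the wrong count.
    """
    root = Path(ckpt_folder) / "class_samples"
    if not root.is_dir():
        raise FileNotFoundError(f"{root} does not exist; run step 2 with this --ckpt_folder")
    counts = {}
    for label in range(n_classes):
        class_dir = root / str(label)
        if not class_dir.is_dir():
            raise ValueError(f"missing class folder {class_dir}")
        counts[label] = sum(1 for p in class_dir.iterdir() if p.suffix == suffix)
        if expected_per_class is not None and counts[label] != expected_per_class:
            raise ValueError(f"class {label}: {counts[label]} images, expected {expected_per_class}")
    return counts
```

For this run, `check_class_samples("/content/results/cifar10/2025_06_02_031315/", expected_per_class=50)` would check the folder passed to `sample.py` above.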


### 3. Calculate the FIM (Fisher information matrix). Depending on the value of n_samples_per_class in step 2 (the paper uses 500), this step can take a long time, because the ELBO of a diffusion model requires a sum over 1000 timesteps per sample.
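In Elastic Weight Consolidation (EWC)-style methods such as SA, the FIM is usually approximated by its diagonal, estimated as the average squared gradient of the log-likelihood over samples: F_i ≈ (1/N) Σ_n (∂ log p(x_n | θ) / ∂θ_i)². For a diffusion model the log-likelihood is replaced by the ELBO, which is why each sample costs a pass over all 1000 timesteps. The toy sketch below applies the same estimator to a 1-D Gaussian whose gradients are known in closed form; it illustrates the estimator only and does not reproduce `fim.py`.

```python
import numpy as np


def grad_log_likelihood(x, mu, log_sigma):
    """Per-sample gradient of log N(x; mu, sigma^2) w.r.t. (mu, log_sigma)."""
    sigma2 = np.exp(2 * log_sigma)
    d_mu = (x - mu) / sigma2
    d_log_sigma = (x - mu) ** 2 / sigma2 - 1.0
    return np.stack([d_mu, d_log_sigma], axis=1)  # shape (N, 2)


def diagonal_fisher(x, mu, log_sigma):
    """Diagonal FIM estimate: mean over samples of squared per-sample gradients."""
    g = grad_log_likelihood(x, mu, log_sigma)
    return np.mean(g ** 2, axis=0)


# Samples drawn from the model itself, as SA uses samples generated by the model.
mu, sigma = 1.0, 2.0
rng = np.random.default_rng(0)
x = rng.normal(mu, sigma, size=200_000)
F = diagonal_fisher(x, mu, np.log(sigma))
print(F)  # analytic values: [1 / sigma**2, 2.0] = [0.25, 2.0]
```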

In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
!python /content/selective-amnesia/ddpm/fim.py --config /content/selective-amnesia/ddpm/configs/cifar10_fim.yml --ckpt_folder /content/results/cifar10/2025_06_01_053656/ --n_chunks 20


Loading checkpoints /content/results/cifar10/2025_06_01_053656/
Calculating Fisher information matrix: 100% 500/500 [7:03:50<00:00, 50.86s/it]
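The progress bar above shows 500 iterations at about 50.86 s each, which appears to be one iteration per generated sample (10 classes × 50 samples from step 2). Assuming the per-sample cost stays the same, a quick estimate shows what the paper's 500 samples per class would cost on this L4 GPU:

```python
def fim_runtime_hours(n_classes, n_samples_per_class, seconds_per_sample):
    """Estimated wall-clock hours for fim.py, assuming one iteration per sample."""
    return n_classes * n_samples_per_class * seconds_per_sample / 3600


print(f"{fim_runtime_hours(10, 50, 50.86):.2f} h")   # this run: 7.06 h
print(f"{fim_runtime_hours(10, 500, 50.86):.1f} h")  # paper setting: 70.6 h
```

The first figure (25,430 s) matches the 7:03:50 shown in the progress bar.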


In [None]:
!cat /content/selective-amnesia/ddpm/configs/cifar10_forget.yml

data:
    path: ./data
    dataset: CIFAR10
    image_size: 32
    channels: 3
    logit_transform: false
    uniform_dequantization: false
    gaussian_dequantization: false
    random_flip: true
    rescaled: true
    num_workers: 4
    n_classes: 10

model:
    type: simple
    in_channels: 3
    out_ch: 3
    ch: 128
    ch_mult: [1, 2, 2, 2]
    num_res_blocks: 2
    attn_resolutions: [16, ]
    dropout: 0.1
    var_type: fixedlarge
    ema_rate: 0.9999
    ema: True
    resamp_with_conv: True
    cond_drop_prob: 0.1

diffusion:
    beta_schedule: linear
    beta_start: 0.0001
    beta_end: 0.02
    num_diffusion_timesteps: 1000

training:
    batch_size: 128
    n_iters: 20000
    snapshot_freq: 1000
    log_freq: 50
    visualization_samples: 100
    train_embeddings: False
    gamma: 1 # weight of GR term, leave it at 1
    lmbda: 10 # adjust lambda for FIM term

sampling:
    batch_size: 128
    last_only: True

optim:
    weight_decay: 0.000
    optimizer: "Adam"
    lr: 0.00

In [None]:
!sed -i 's/n_iters: 20000/n_iters: 3000/g' /content/selective-amnesia/ddpm/configs/cifar10_forget.yml
!sed -i 's/snapshot_freq: 1000/snapshot_freq: 2000/g' /content/selective-amnesia/ddpm/configs/cifar10_forget.yml
!cat /content/selective-amnesia/ddpm/configs/cifar10_forget.yml

data:
    path: ./data
    dataset: CIFAR10
    image_size: 32
    channels: 3
    logit_transform: false
    uniform_dequantization: false
    gaussian_dequantization: false
    random_flip: true
    rescaled: true
    num_workers: 4
    n_classes: 10

model:
    type: simple
    in_channels: 3
    out_ch: 3
    ch: 128
    ch_mult: [1, 2, 2, 2]
    num_res_blocks: 2
    attn_resolutions: [16, ]
    dropout: 0.1
    var_type: fixedlarge
    ema_rate: 0.9999
    ema: True
    resamp_with_conv: True
    cond_drop_prob: 0.1

diffusion:
    beta_schedule: linear
    beta_start: 0.0001
    beta_end: 0.02
    num_diffusion_timesteps: 1000

training:
    batch_size: 128
    n_iters: 3000
    snapshot_freq: 2000
    log_freq: 50
    visualization_samples: 100
    train_embeddings: False
    gamma: 1 # weight of GR term, leave it at 1
    lmbda: 10 # adjust lambda for FIM term

sampling:
    batch_size: 128
    last_only: True

optim:
    weight_decay: 0.000
    optimizer: "Adam"
    lr: 0.000

In [None]:
!ls /content/selective-amnesia/ddpm/results/cifar10/2025_06_01_053656/

classifier_evaluation.py  fim.py     README.md	       save_base_dataset.py
configs			  functions  requirements.txt  train_classifier.py
datasets		  LICENSE    runners	       train.py
evaluator.py		  models     sample.py


### 4. Forgetting training with SA

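You can vary the lambda weight for the FIM (`lmbda` under `training`) in configs/cifar10_forget.yml. The `--ckpt_folder` passed here must be the folder that holds the `class_samples/` generated in step 2, because the forgetting run loads them from `<ckpt_folder>/class_samples`; the traceback below is raised while loading that folder. Note that step 1 was launched from `/content`, so its results were written under `/content/results/cifar10/`, not under `/content/selective-amnesia/ddpm/results/`.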
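The training log below reports "contrastive and EWC" terms with `Gamma: 1, lambda: 10`, matching `gamma` (weight of the GR term) and `lmbda` (weight of the FIM term) in `cifar10_forget.yml`. The sketch below shows how such weights typically combine into one loss: a forgetting term on the forgotten class, a replay term on generated samples of the remembered classes, and an EWC penalty (1/2) Σ_i F_i (θ_i − θ*_i)² that keeps parameters the FIM marks as important close to the original θ*. The three loss values, the function names, and the exact scaling (e.g. the factor 1/2) are assumptions, not read from `runners/diffusion.py`.

```python
import numpy as np


def ewc_penalty(theta, theta_star, fisher):
    """EWC penalty 0.5 * sum_i F_i * (theta_i - theta*_i)^2 over flattened parameters."""
    return 0.5 * np.sum(fisher * (theta - theta_star) ** 2)


def sa_total_loss(forget_loss, replay_loss, theta, theta_star, fisher, gamma=1.0, lmbda=10.0):
    """Hypothetical combination of the three SA terms into one scalar to minimize.

    forget_loss: loss that pushes the forgotten class toward a surrogate target
    replay_loss: usual diffusion loss on generated samples of remembered classes
    gamma, lmbda: named after the keys in cifar10_forget.yml
    """
    return forget_loss + gamma * replay_loss + lmbda * ewc_penalty(theta, theta_star, fisher)


theta_star = np.array([0.5, -1.0, 2.0])  # original (pre-forgetting) parameters
theta = np.array([0.6, -1.0, 1.0])       # current parameters during forgetting
fisher = np.array([4.0, 1.0, 0.01])      # diagonal FIM from step 3
print(sa_total_loss(0.3, 0.2, theta, theta_star, fisher))
```

The third parameter moves a lot but has a tiny Fisher value, so it contributes little to the penalty; that is the intended effect of weighting by the FIM.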

In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
!python /content/selective-amnesia/ddpm/train.py --config /content/selective-amnesia/ddpm/configs/cifar10_forget.yml --ckpt_folder /content/selective-amnesia/ddpm/results/cifar10/2025_06_01_053656/ --label_to_forget 0 --mode forget

INFO - diffusion.py - 2025-06-01 14:09:28,006 - Training diffusion forget with contrastive and EWC. Gamma: 1, lambda: 10
ERROR - train.py - 2025-06-01 14:09:28,007 - Traceback (most recent call last):
  File "/content/selective-amnesia/ddpm/train.py", line 109, in main
    runner.train_forget()
  File "/content/selective-amnesia/ddpm/runners/diffusion.py", line 253, in train_forget
    D_train_loader = all_but_one_class_path_dataset(config, os.path.join(args.ckpt_folder, "class_samples"), args.label_to_forget)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/content/selective-amnesia/ddpm/datasets/__init__.py", line 103, in all_but_one_class_path_dataset
    train_dataset = ImageFolder(
                    ^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torchvision/datasets/folder.py", line 309, in __init__
    super().__init__(
  File "/usr/local/lib/python3.11/site-packages/torchvi

## Evaluation

### 5. Image Metrics Evaluation on Classes to Remember

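First, generate sample images from the model trained in step 4. Here `--classes_to_generate 'x0'` requests every class except the forgotten class 0, matching the `fid_samples_without_label_0_guidance_2.0` folder read by the evaluator below.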

In [None]:
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
!python /content/selective-amnesia/ddpm/sample.py --config /content/selective-amnesia/ddpm/configs/cifar10_sample.yml --ckpt_folder /content/selective-amnesia/ddpm/results/cifar10/2025_06_01_053656/ --mode sample_fid --n_samples_per_class 500 --classes_to_generate 'x0'


ERROR:root:Traceback (most recent call last):
  File "/content/selective-amnesia/ddpm/sample.py", line 87, in main
    runner.sample()
  File "/content/selective-amnesia/ddpm/runners/diffusion.py", line 380, in sample
    states = torch.load(
             ^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/serialization.py", line 791, in load
    with _open_file_like(f, 'rb') as opened_file:
         ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/serialization.py", line 271, in _open_file_like
    return _open_file(name_or_buffer, mode)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/serialization.py", line 252, in __init__
    super().__init__(open(name, mode))
                     ^^^^^^^^^^^^^^^^
FileNotFoundError: [Errno 2] No such file or directory: '/content/selective-amnesia/ddpm/results/cifar10/2025_05_31_062347/ckpts/ckpt.pth'



In [None]:
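# The script lives in the repository's ddpm folder; calling it by bare name from /content fails (see the error below)
!python /content/selective-amnesia/ddpm/save_base_dataset.py --dataset cifar10 --label_to_forget 0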

python: can't open file '/content/save_base_dataset.py': [Errno 2] No such file or directory


In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
!python  /content/selective-amnesia/ddpm/evaluator.py  /content/selective-amnesia/ddpm/results/cifar10/2025_06_01_053656/fid_samples_without_label_0_guidance_2.0 cifar10_without_label_0

2025-05-31 08:38:20.602305: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-05-31 08:38:20.605235: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2025-05-31 08:38:20.659460: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2025-05-31 08:38:20.659986: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Traceback (most recent call last):
  File "/content/selective-amnesia/ddpm/evaluator.py", line 683, in
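The evaluator compares the generated images for the remembered classes against the real CIFAR-10 images without class 0. Assuming it reports the Fréchet Inception Distance (FID) among its metrics, that quantity is computed from the mean and covariance of Inception features of the two image sets: FID = ‖μ_r − μ_g‖² + Tr(Σ_r + Σ_g − 2(Σ_r Σ_g)^{1/2}). The NumPy sketch below implements the formula on arbitrary feature arrays; random features stand in for real Inception activations, so the numbers only check that matching distributions give roughly zero.

```python
import numpy as np


def _sqrtm_psd(a):
    """Square root of a symmetric positive semi-definite matrix via eigh."""
    w, v = np.linalg.eigh(a)
    return (v * np.sqrt(np.clip(w, 0, None))) @ v.T


def fid(feats_real, feats_gen):
    """Frechet distance between Gaussians fitted to two (N, D) feature arrays."""
    mu_r, mu_g = feats_real.mean(axis=0), feats_gen.mean(axis=0)
    cov_r = np.cov(feats_real, rowvar=False)
    cov_g = np.cov(feats_gen, rowvar=False)
    # Tr((cov_r cov_g)^{1/2}) equals Tr((S cov_g S)^{1/2}) with S = cov_r^{1/2},
    # and S cov_g S is symmetric PSD, so an eigenvalue decomposition suffices.
    s = _sqrtm_psd(cov_r)
    m = s @ cov_g @ s
    eig = np.linalg.eigvalsh((m + m.T) / 2)        # symmetrize against round-off
    tr_sqrt = np.sum(np.sqrt(np.clip(eig, 0, None)))
    return float(np.sum((mu_r - mu_g) ** 2) + np.trace(cov_r) + np.trace(cov_g) - 2 * tr_sqrt)


rng = np.random.default_rng(0)
real = rng.standard_normal((5000, 16))
same = rng.standard_normal((5000, 16))
shifted = rng.standard_normal((5000, 16)) + 1.0
print(fid(real, same), fid(real, shifted))  # small vs. roughly 16
```

Lower is better: the shifted set differs by 1 in each of the 16 feature means, which the mean term picks up as about 16.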

### 6. Classifier Evaluation

In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
!python /content/selective-amnesia/ddpm/train_classifier.py --dataset cifar10

In [None]:
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
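!python /content/selective-amnesia/ddpm/sample.py --config /content/selective-amnesia/ddpm/configs/cifar10_sample.yml --ckpt_folder /content/selective-amnesia/ddpm/results/cifar10/2025_06_01_053656/ --mode sample_classes --classes_to_generate "0" --n_samples_per_class 500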

In [None]:
!python /content/selective-amnesia/ddpm/classifier_evaluation.py --sample_path /content/selective-amnesia/ddpm/results/cifar10/2025_06_01_053656/class_samples/0 --dataset cifar10 --label_of_forgotten_class 0
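Step 6 samples 500 images conditioned on the forgotten class 0 and passes them to a CIFAR-10 classifier. One natural way to read such a check (the exact outputs of `classifier_evaluation.py` are not shown here, so this is an assumption) is to ask how often the classifier still recognises class 0 in those samples and how uncertain it is: after successful forgetting, the probability of class 0 should be low. The sketch below computes these summary statistics from a matrix of softmax probabilities; `forgetting_stats` is a hypothetical helper and the probabilities are illustrative inputs, not results.

```python
import numpy as np


def forgetting_stats(probs, forgotten_class):
    """Summarise classifier softmax outputs (N, n_classes) for samples of one class.

    Returns the mean probability of the forgotten class, the fraction of samples
    whose top-1 prediction is still that class, and the mean predictive entropy.
    """
    probs = np.asarray(probs, dtype=float)
    p_forgotten = probs[:, forgotten_class].mean()
    top1_rate = (probs.argmax(axis=1) == forgotten_class).mean()
    entropy = -np.sum(probs * np.log(np.clip(probs, 1e-12, None)), axis=1).mean()
    return p_forgotten, top1_rate, entropy


# Illustrative inputs only: two confident "class 0" samples and two that lean to class 1.
probs = np.array([
    [0.91] + [0.01] * 9,
    [0.91] + [0.01] * 9,
    [0.05, 0.15] + [0.1] * 8,
    [0.05, 0.15] + [0.1] * 8,
])
print(forgetting_stats(probs, forgotten_class=0))
```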
