***Load Data and Libraries***

In [None]:
import os
import torch
import numpy as np
from tqdm import tqdm
import torch.nn.functional as F
from torch.utils.data import DataLoader
from diffusers import StableDiffusionPipeline
from transformers import CLIPTextModel, CLIPTokenizer
import tarfile

tar_path = "/content/forget-data.tar"
extract_path = "/content"

# Global seed so the random initialisation of P/A/B/F_forget and the diffusion
# noise/timestep sampling are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

with tarfile.open(tar_path) as tar:
    # The "data" extraction filter refuses members that would land outside
    # extract_path, links pointing outside it, and device files. It is available on
    # Python 3.12+ and on security-backport releases; fall back for older interpreters.
    if hasattr(tarfile, "data_filter"):
        tar.extractall(path=extract_path, filter="data")
    else:
        tar.extractall(path=extract_path)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
CONCEPT = "Taj Mahal"
# Forget-set images live in a folder named after the concept inside the archive.
DATA_PATH = os.path.join(extract_path, "forget-data", CONCEPT)
SAVE_PATH = "/content/"
os.makedirs(SAVE_PATH, exist_ok=True)

***Load the Model and Initialize the Projection Matrices***


In [None]:
from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler
from torch import nn
MODEL_ID = "stabilityai/stable-diffusion-2"
vae = AutoencoderKL.from_pretrained(MODEL_ID, subfolder="vae").to(DEVICE)
unet = UNet2DConditionModel.from_pretrained(MODEL_ID, subfolder="unet").to(DEVICE)
text_encoder = CLIPTextModel.from_pretrained(MODEL_ID, subfolder="text_encoder").to(DEVICE)
tokenizer = CLIPTokenizer.from_pretrained(MODEL_ID, subfolder="tokenizer")
scheduler = DDIMScheduler.from_pretrained(MODEL_ID, subfolder="scheduler")

# Width of the text-encoder hidden states; P, A and B act on pooled text features
# of this size (expected to be 1024 for SD-2).
hidden_dim = text_encoder.config.hidden_size

# Reference projection P. NOTE: this is a freshly, randomly initialised linear map,
# not a weight taken from the pretrained model. It is never optimised, so freeze it:
# otherwise every backward pass accumulates gradients in P.weight that the optimizer
# never zeroes.
P = nn.Linear(hidden_dim, hidden_dim, bias=False).to(DEVICE)
P.requires_grad_(False)

# Trainable low-rank factors: the projection update is delta_P = A @ B.T (rank r).
r = 8  # rank
A = nn.Parameter(torch.randn(hidden_dim, r, device=DEVICE) * 0.01)
B = nn.Parameter(torch.randn(hidden_dim, r, device=DEVICE) * 0.01)

**Freezing and Unfreezing Process**

During unlearning we modify only the CLIP text encoder, so we freeze every parameter of the other components: the U-Net and the variational autoencoder (VAE). Setting `requires_grad = False` stops gradients from being computed for their weights, so neither the U-Net nor the VAE is updated during fine-tuning. Gradients still flow *through* the frozen U-Net to reach the text encoder.

Within the CLIP text encoder, we unfreeze only the high-level transformer layers (indices 20, 21 and 22) and the final layer normalization. All other text-encoder parameters stay frozen.

In [None]:
# Only these text-encoder transformer blocks (plus the final layer norm) are trained.
TRAINABLE_TEXT_LAYERS = (20, 21, 22)

vae.eval()
unet.eval()
# Freeze every component; selected text-encoder parameters are re-enabled below.
vae.requires_grad_(False)
unet.requires_grad_(False)
text_encoder.requires_grad_(False)

# Match "encoder.layers.<i>." including the trailing dot so a short index can never
# also select longer ones (e.g. layer 2 must not match layers 20-22).
trainable_prefixes = tuple(f"encoder.layers.{i}." for i in TRAINABLE_TEXT_LAYERS)
for name, param in text_encoder.named_parameters():
    if any(prefix in name for prefix in trainable_prefixes) or "final_layer_norm" in name:
        param.requires_grad = True

num_trainable = sum(p.numel() for p in text_encoder.parameters() if p.requires_grad)
assert num_trainable > 0, "no text-encoder parameters were unfrozen; check TRAINABLE_TEXT_LAYERS"


***Load and Transform Images***

In [None]:
from PIL import Image
from torchvision import transforms

# Preprocessing for the forget images: resize to 512x512, convert to a float tensor
# in [0, 1], then normalise each channel as (x - 0.5) / 0.5, i.e. into [-1, 1].
_preprocess_steps = [
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
]
transform = transforms.Compose(_preprocess_steps)

def load_images(data_path, max_images=5, extensions=(".jpg",)):
    """Load up to ``max_images`` images from ``data_path`` as one preprocessed tensor.

    Files are chosen in sorted filename order (``os.listdir`` order is arbitrary, so
    sorting makes the selected subset reproducible) and matched case-insensitively
    against ``extensions``. Each image is converted to RGB and passed through the
    notebook-level ``transform``.

    Args:
        data_path: Directory containing the images.
        max_images: Maximum number of images to load.
        extensions: Accepted filename suffixes (default ``(".jpg",)``).

    Returns:
        The transformed images stacked along a new leading dimension.

    Raises:
        ValueError: If no matching image files are found.
    """
    suffixes = tuple(ext.lower() for ext in extensions)
    fnames = sorted(f for f in os.listdir(data_path) if f.lower().endswith(suffixes))
    fnames = fnames[:max_images]
    if not fnames:
        raise ValueError(f"No images with extensions {suffixes} found in {data_path!r}")

    images = []
    for fname in fnames:
        # The context manager closes the file handle once the pixels are decoded.
        with Image.open(os.path.join(data_path, fname)) as img:
            images.append(transform(img.convert("RGB")))
    return torch.stack(images)

# Preprocessed forget-set images, moved to the training device.
images = load_images(DATA_PATH, max_images=5).to(DEVICE)


***Set up the Text Embeddings***

Create the forget and retain prompts and run them through the tokenizer, which returns a tensor of token IDs for each prompt. The text encoder turns these IDs into per-token embeddings. Averaging over the token dimension then gives one embedding vector per prompt: `f_f` for the forget prompt, and one row of `f_r` per retain prompt. The exact number of retain prompts matters less than their diversity; they should cover as wide a range of concepts as possible.

In [None]:
forget_prompt = f"a photo of {CONCEPT}"
retain_prompts = ["a photo of a car", "a photo of a flower", "a photo of a person", "a photo of a German Shepherd", "A photo of a frog", "an astronaut"]

x_f = tokenizer(forget_prompt, return_tensors="pt").input_ids.to(DEVICE)
x_r = tokenizer(retain_prompts, padding=True, return_tensors="pt").input_ids.to(DEVICE)

# Pooled (token-averaged) features for inspection; the training loop recomputes them.
# no_grad avoids building and holding an autograd graph through the unfrozen
# text-encoder layers while these tensors stay alive.
with torch.no_grad():
    f_f = text_encoder(x_f)[0].mean(dim=1)
    f_r = text_encoder(x_r)[0].mean(dim=1)

# Target for the forget projection: a random (scaled) vector rather than a zero vector.
F_forget = torch.randn((text_encoder.config.hidden_size, 1), device=DEVICE) * 2.0



## ***Unlearning Process***

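Forgetting is driven mainly by gradient ascent. The U-Net noise-prediction loss on the forget images is negated, so minimizing the total loss pushes the text encoder *away* from embeddings that let the U-Net denoise the forget concept. Low-rank projection losses and a small norm penalty are added on top (see the steps below). We train the unfrozen text-encoder layers for 10 epochs: fewer epochs may not remove the forget concept entirely, while more epochs may overfit and also degrade images generated from the retain prompts.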

**The following is the process for each epoch:**

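1.   Each image from the forget dataset is passed through the variational autoencoder (VAE) to obtain its latent representation, which is multiplied by the VAE latent scaling factor.
2.   We mimic the Stable Diffusion training process by adding Gaussian noise to the latents at a randomly sampled timestep.
3.   The tokenized forget and retain prompts are embedded with the text encoder, and each prompt's token embeddings are averaged into a single vector.
4.   The U-Net, conditioned on the forget-prompt embeddings, predicts the added noise, and the prediction is compared with the true noise. This MSE is negated, so training *increases* it.
5. A regularization term, the Frobenius norm of the low-rank update $\Delta P = AB^\top$, keeps that update small.
6. Compute the low-rank projection losses. For the retain prompts, an MSE keeps $(P + \Delta P)\, f_r$ close to $P f_r$. For the forget prompt, an MSE pulls $(P + \Delta P)\, f_f$ towards a fixed random target vector $F_{\text{forget}}$.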

7. Compute total loss.




In [None]:
%%time
optimizer = torch.optim.AdamW([
    {'params': filter(lambda p: p.requires_grad, text_encoder.parameters())},
    {'params': [A, B]}
], lr=1e-5)

text_encoder.train()

for epoch in range(10):
    total_loss = 0
    for i in range(len(images)):
        image = images[i].unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            latents = vae.encode(image).latent_dist.sample() * 0.18215

        noise = torch.randn_like(latents)
        t = torch.randint(0, scheduler.config.num_train_timesteps, (1,), device=DEVICE).long()
        noisy_latents = scheduler.add_noise(latents, noise, t)

        # Forget prompt embedding
        text_input = tokenizer(forget_prompt, return_tensors="pt").input_ids.to(DEVICE)
        f_f = text_encoder(text_input)[0].mean(dim=1)
        text_embed = text_encoder(text_input)[0]  # full sequence for UNet

        # Retain prompt embedding
        x_r = tokenizer(retain_prompts, padding=True, return_tensors="pt").input_ids.to(DEVICE)
        f_r = text_encoder(x_r)[0].mean(dim=1)

        # UNet prediction
        pred_noise = unet(noisy_latents, t, encoder_hidden_states=text_embed).sample
        loss_img = -F.mse_loss(pred_noise, noise)

        # Low-rank projection loss
        delta_P = A @ B.T
        retain_loss = F.mse_loss((P.weight + delta_P) @ f_r.T, P.weight @ f_r.T)
        forget_loss = F.mse_loss((P.weight + delta_P) @ f_f.T, F_forget)
        reg_loss = torch.norm(delta_P, p='fro')
        before = (P.weight @ f_f.T).detach()
        after = ((P.weight + delta_P) @ f_f.T)
        cos = F.cosine_similarity(before.T, after.T).item()
        print(f"Forget prompt projection similarity: {cos:.4f}")

        # Total loss
        loss_total = loss_img + 1.0 * retain_loss + 0.5 * forget_loss + 1e-3 * reg_loss

        optimizer.zero_grad()
        loss_total.backward()
        optimizer.step()

        total_loss += loss_total.item()

        # Optional: print losses

    print(f"Epoch {epoch}: Avg Loss = {total_loss:.4f}")
SO = "/content/model/"
# Save model and projection components
text_encoder.save_pretrained(SAVE_PATH)
torch.save({'A': A, 'B': B}, os.path.join(SAVE_PATH, 'deltaP.pt'))


Forget prompt projection similarity: 0.9998
Forget prompt projection similarity: 0.9998
Forget prompt projection similarity: 0.9998
Forget prompt projection similarity: 0.9998
Forget prompt projection similarity: 0.9998
Epoch 0: Avg Loss = 5.2838
Forget prompt projection similarity: 0.9998
Forget prompt projection similarity: 0.9998
Forget prompt projection similarity: 0.9998
Forget prompt projection similarity: 0.9998
Forget prompt projection similarity: 0.9998
Epoch 1: Avg Loss = 2.8434
Forget prompt projection similarity: 0.9999
Forget prompt projection similarity: 0.9999
Forget prompt projection similarity: 0.9999
Forget prompt projection similarity: 0.9999
Forget prompt projection similarity: 0.9999
Epoch 2: Avg Loss = 2.5232
Forget prompt projection similarity: 0.9999
Forget prompt projection similarity: 0.9999
Forget prompt projection similarity: 0.9999
Forget prompt projection similarity: 0.9999
Forget prompt projection similarity: 0.9999
Epoch 3: Avg Loss = 3.8847
Forget promp

Convert the text encoder to float16 so that its dtype matches the pipeline, which is loaded below with `torch_dtype=torch.float16`.


In [None]:
# Convert models and tensors to float16 for compatibility with the fp16 pipeline below.
# Also switch the text encoder back to inference mode: the training cell left it in train().
text_encoder = text_encoder.half().eval()
# NOTE(review): A, B and P are not passed to the pipeline, so converting them does not
# affect generation; kept so their dtype matches the text encoder.
A.data = A.data.half()
B.data = B.data.half()
P.weight.data = P.weight.data.half()


In [None]:
# Rebuild the SD-2 pipeline around the fine-tuned text encoder and spot-check one
# retain prompt and one forget prompt. A fixed generator seed makes the two images
# reproducible across runs.
EVAL_SEED = 0

pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2",
    text_encoder=text_encoder,
    safety_checker=None,
    torch_dtype=torch.float16,
).to(DEVICE)

spot_checks = [
    ("a photo of the frog", "eval_retain.png"),
    ("Taj Mahal outside", "eval_forgotten.png"),
]
for prompt, filename in spot_checks:
    generator = torch.Generator(device=DEVICE).manual_seed(EVAL_SEED)
    image = pipe(prompt, generator=generator).images[0]
    # os.path.join avoids the "//" that f"{SAVE_PATH}/..." produces when SAVE_PATH ends in "/".
    image.save(os.path.join(SAVE_PATH, filename))



Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

In [None]:
forget_prompt = f"a photo of {CONCEPT}"
# A separate name so this cell does not overwrite the `retain_prompts` used for
# training (re-running the training cell afterwards would silently use this list).
sample_retain_prompts = ["a photo of a car", "a flower", "a person"]
SAMPLE_SEED = 0  # same starting noise for every prompt -> reproducible, comparable samples

# Generate images
from PIL import Image
generated_images = []
generated_prompts = []

for prompt in [forget_prompt] + sample_retain_prompts:
    generator = torch.Generator(device=DEVICE).manual_seed(SAMPLE_SEED)
    img = pipe(prompt, generator=generator).images[0]
    generated_images.append(img)
    generated_prompts.append(prompt)


  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

In [None]:
from PIL import Image

# Variations of the forgotten concept. No empty prompt here: it would yield an
# unconditional image whose CLIP score against empty text says nothing about forgetting.
test_prompts_forget = [
    "Taj Mahal Dome",
    "a palace like the Taj Mahal",
    "white marble dome of taj mahal",
    "Taj Mahal mausoleum",
]

test_prompts_retain = [
    "a green lizard",
    "a lizard on a rock",
    "a cartoon lizard",
    "lizard in a desert",
    "a lizard climbing a wall",
]

IMAGE_PATH_FORGET = r"/content/test_forget_images"
IMAGE_PATH_RETAIN = r"/content/test_retain_images"
GEN_SEED = 0  # fixed starting noise so evaluation images are reproducible

imgs_retain = []
imgs_forget = []

os.makedirs(IMAGE_PATH_FORGET, exist_ok=True)
os.makedirs(IMAGE_PATH_RETAIN, exist_ok=True)

for i, test_prompt in enumerate(test_prompts_forget, start=1):
    generator = torch.Generator(device=DEVICE).manual_seed(GEN_SEED)
    img = pipe(test_prompt, generator=generator).images[0]
    imgs_forget.append(img)
    img.save(os.path.join(IMAGE_PATH_FORGET, f"image_{i}.png"))



  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

In [None]:
GEN_SEED = 0  # fixed starting noise so evaluation images are reproducible

for i, test_prompt in enumerate(test_prompts_retain, start=1):
    generator = torch.Generator(device=DEVICE).manual_seed(GEN_SEED)
    img = pipe(test_prompt, generator=generator).images[0]
    imgs_retain.append(img)
    img.save(os.path.join(IMAGE_PATH_RETAIN, f"image_{i}.png"))

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

In [None]:
from transformers import CLIPProcessor, CLIPModel
CLIP_SCORER_ID = "openai/clip-vit-base-patch32"  # separate CLIP checkpoint used only for scoring
clip_processor = CLIPProcessor.from_pretrained(CLIP_SCORER_ID)
clip_model = CLIPModel.from_pretrained(CLIP_SCORER_ID).to(DEVICE)

def clip_score(image: "Image.Image", text: str, model=None, processor=None, device=None) -> float:
    """Cosine similarity between the CLIP image embedding and the CLIP text embedding.

    Args:
        image: Generated image to score.
        text: Prompt to compare the image against.
        model: CLIP model whose output exposes ``image_embeds`` and ``text_embeds``;
            defaults to the notebook-level ``clip_model``.
        processor: Matching processor; defaults to the notebook-level ``clip_processor``.
        device: Device for the processed inputs; defaults to ``DEVICE``.

    Returns:
        The cosine similarity as a Python float.
    """
    model = clip_model if model is None else model
    processor = clip_processor if processor is None else processor
    device = DEVICE if device is None else device

    # Score against the `text` argument itself (not a global), truncating long prompts
    # to the CLIP text context length.
    inputs = processor(
        text=[text], images=image, return_tensors="pt", padding=True, truncation=True
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)
        score = F.cosine_similarity(outputs.image_embeds, outputs.text_embeds).item()

    return score


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [None]:
# CLIP score of each generated image against the prompt that produced it.
# zip() silently drops unmatched items, so verify the pairing first.
if len(imgs_forget) != len(test_prompts_forget):
    raise ValueError(
        f"{len(imgs_forget)} forget images for {len(test_prompts_forget)} prompts; re-run the generation cell"
    )
if len(imgs_retain) != len(test_prompts_retain):
    raise ValueError(
        f"{len(imgs_retain)} retain images for {len(test_prompts_retain)} prompts; re-run the generation cell"
    )

clip_scores_retain = []
clip_scores_forget = []

for prompt, image in zip(test_prompts_forget, imgs_forget):
    clip_scores_forget.append(clip_score(image, prompt))

for prompt, image in zip(test_prompts_retain, imgs_retain):
    clip_scores_retain.append(clip_score(image, prompt))

In [None]:
clip_scores_forget

[0.2504916191101074, 0.294971764087677, 0.26772060990333557, 0.2110106348991394]


In [None]:
clip_scores_retain

[0.32222050428390503, 0.3099864721298218, 0.2590012550354004, 0.29753541946411133, 0.35376039147377014]


In [None]:
# Mean CLIP score per prompt group, plus the mean over every scored image.
all_clip_scores = clip_scores_retain + clip_scores_forget
average_retain_clip = sum(clip_scores_retain) / len(clip_scores_retain)
average_forget_clip = sum(clip_scores_forget) / len(clip_scores_forget)
# The whole combined sum is the numerator; dividing only the forget sum (operator
# precedence) yields a meaningless value above 1.
average_overall_clip = sum(all_clip_scores) / len(all_clip_scores)

In [None]:
# Retain mean, forget mean and overall mean, one per line.
print(average_retain_clip, average_forget_clip, average_overall_clip, sep="\n")

0.30850080847740174
0.25604865700006485
1.6563034454981487


In [None]:
!pip install torch-fidelity




In [None]:
# FID below compares the forget-prompt generations against the forget training images
# (DATA_PATH); this replaces the sample generations stored under the same name earlier.
generated_images = imgs_forget

In [None]:
!pip install clean-fid




In [None]:
import os
from PIL import Image
from cleanfid import fid

import shutil

# Use a dedicated, freshly emptied folder: IMAGE_PATH_FORGET already holds the
# image_*.png files written by the generation cell, so saving there again would make
# FID count every generation twice (and pick up stale files from earlier runs).
combined_generated_path = IMAGE_PATH_FORGET.rstrip("/") + "_fid"
if os.path.isdir(combined_generated_path):
    shutil.rmtree(combined_generated_path)
os.makedirs(combined_generated_path)

# Save each image with a unique name; PNG is lossless, so JPEG compression artifacts
# do not shift the FID.
for i, img in enumerate(generated_images):
    img.save(os.path.join(combined_generated_path, f"gen_{i}.png"))

# NOTE(review): FID compares feature means/covariances; with only a handful of images
# per folder the value is very noisy and best read as a rough relative indicator.
# Here a higher FID vs. the forget training images suggests the concept is reproduced less.
print("Evaluating...")
fid_value = fid.compute_fid(DATA_PATH, combined_generated_path, batch_size=16, num_workers=4)
print(f"FID Score: {fid_value}")


Evaluating...




compute FID between two folders
Found 5 images in the folder /content/forget-data/Taj Mahal


FID Taj Mahal : 100%|██████████| 1/1 [00:01<00:00,  1.49s/it]


Found 8 images in the folder /content/test_forget_images


FID test_forget_images : 100%|██████████| 1/1 [00:06<00:00,  6.93s/it]


FID Score: 339.2599015762849
