<a href="https://colab.research.google.com/github/shreyasudaya/MajorProjectCS402/blob/master/mp_fs_unlearn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import torch
import numpy as np
from tqdm import tqdm
import torch.nn.functional as F
from torch.utils.data import DataLoader
from diffusers import StableDiffusionPipeline
from transformers import CLIPTextModel, CLIPTokenizer
import tarfile

# --- Dataset extraction ------------------------------------------------------
# Archive holding the "forget" images, and the directory it is unpacked into.
tar_path = "/content/forget_dataset.tar"
extract_path = "/content"

with tarfile.open(tar_path) as tar:
    # Prefer the "data" extraction filter when this interpreter provides it:
    # it refuses absolute paths, ".." traversal and links that escape the
    # destination, so a malformed archive cannot write outside extract_path.
    # Interpreters without extraction filters fall back to the legacy call.
    if hasattr(tarfile, "data_filter"):
        tar.extractall(path=extract_path, filter="data")
    else:
        tar.extractall(path=extract_path)

# --- Run configuration -------------------------------------------------------
# Use the GPU when one is visible to torch, otherwise run on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DATA_PATH = "/content/forget_dataset/Dogs/Siberian_husky"  # Update this to your image path
# Directory that receives the fine-tuned text encoder, deltaP.pt and eval images.
SAVE_PATH = "/content/"
# Concept to forget; interpolated into the forget prompt ("a photo of {CONCEPT}").
CONCEPT = "Siberian Husky"

os.makedirs(SAVE_PATH, exist_ok=True)

In [None]:
from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler
from torch import nn
# Hugging Face repository every Stable Diffusion 2 component below is loaded from.
MODEL_ID = "stabilityai/stable-diffusion-2"

vae = AutoencoderKL.from_pretrained(MODEL_ID, subfolder="vae").to(DEVICE)
unet = UNet2DConditionModel.from_pretrained(MODEL_ID, subfolder="unet").to(DEVICE)
text_encoder = CLIPTextModel.from_pretrained(MODEL_ID, subfolder="text_encoder").to(DEVICE)
tokenizer = CLIPTokenizer.from_pretrained(MODEL_ID, subfolder="tokenizer")
scheduler = DDIMScheduler.from_pretrained(MODEL_ID, subfolder="scheduler")



# Access projection matrix P
# NOTE(review): this is a freshly initialised linear layer, not a weight read
# from the loaded model, so P and the A/B update learned against it are not
# wired into the text encoder or UNet — confirm this is intended.
P = nn.Linear(1024, 1024, bias=False).to(DEVICE)  # if not exposed directly
# P only serves as a fixed reference in the losses and is never given to the
# optimizer; freezing it stops backward() from accumulating an unused
# 1024x1024 gradient that nothing ever zeroes.
P.requires_grad_(False)

# Initialize low-rank matrices A, B
# The learned update is delta_P = A @ B.T, i.e. a 1024x1024 matrix of rank <= r.
# NOTE(review): 1024 presumably matches the text encoder's hidden size — confirm.
r = 6 # rank
A = nn.Parameter(torch.randn(1024, r, device=DEVICE) * 0.01)
B = nn.Parameter(torch.randn(1024, r, device=DEVICE) * 0.01)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [None]:
# The VAE and UNet are never trained: put them in inference mode and switch
# off gradients for all of their parameters.
for frozen_module in (vae, unet):
    frozen_module.eval()
    frozen_module.requires_grad_(False)

# Text encoder: everything is frozen except the parameters whose name contains
# one of the encoder-layer tags below or "final_layer_norm".
unfrozen_tags = tuple(f"encoder.layers.{idx}" for idx in (20, 21, 22))
for name, param in text_encoder.named_parameters():
    in_selected_layer = any(tag in name for tag in unfrozen_tags)
    param.requires_grad = in_selected_layer or "final_layer_norm" in name


In [None]:
from PIL import Image
from torchvision import transforms

# Preprocessing applied to every forget image: resize to 512x512, convert to a
# tensor, then normalise with mean 0.5 / std 0.5 (maps [0, 1] to [-1, 1]).
transform = transforms.Compose([
    transforms.Resize(size=(512, 512)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

def load_images(data_path, max_images=5):
    """Load JPEG images from ``data_path`` into a single stacked tensor.

    Files are visited in sorted filename order so that the selected subset is
    the same on every run (``os.listdir`` returns entries in arbitrary order).
    Each matching file is opened with PIL, converted to RGB and passed through
    the module-level ``transform``.

    Args:
        data_path: Directory to scan (not recursive).
        max_images: Stop once this many images have been loaded.

    Returns:
        The transformed images stacked along a new leading dimension.

    Raises:
        RuntimeError: If the directory contains no ``.jpg`` / ``.jpeg`` file.
    """
    images = []
    for fname in sorted(os.listdir(data_path)):
        # Case-insensitive match so files such as "IMG_001.JPG" are not skipped.
        if not fname.lower().endswith((".jpg", ".jpeg")):
            continue
        # The context manager closes the underlying file handle once the
        # image has been converted and transformed.
        with Image.open(os.path.join(data_path, fname)) as img:
            images.append(transform(img.convert("RGB")))
        if len(images) >= max_images:
            break
    if not images:
        # Fail with an actionable message instead of torch.stack's
        # "expects a non-empty TensorList".
        raise RuntimeError(f"No .jpg/.jpeg images found in {data_path!r}")
    return torch.stack(images)

# Read the forget-set images (load_images' default cap of 5) and move the
# stacked tensor onto the compute device.
images = load_images(DATA_PATH)
images = images.to(DEVICE)


In [None]:
# Prompt for the concept to forget, and prompts whose behaviour should be kept.
forget_prompt = f"a photo of {CONCEPT}"
retain_prompts = ["a photo of a car", "a photo of a flower", "a photo of a person", "a photo of a German Shepherd", "A photo of a frog", "an astronaut"]

x_f = tokenizer(forget_prompt, return_tensors="pt").input_ids.to(DEVICE)
x_r = tokenizer(retain_prompts, padding=True, return_tensors="pt").input_ids.to(DEVICE)

# Mean-pooled (over the token axis) text embeddings of the prompts.
# These are only a pre-training snapshot — the training cell recomputes f_f
# and f_r before using them — so skip autograd here rather than keep the
# encoder's activation graph alive for values that are never back-propagated.
# NOTE(review): the mean for x_r also averages over padding positions.
with torch.no_grad():
    f_f = text_encoder(x_f)[0].mean(dim=1)
    f_r = text_encoder(x_r)[0].mean(dim=1)
# Instead of zero vector
# Random (1024, 1) target that the projected forget embedding is pulled toward.
# No seed is set, so this target differs on every run.
F_forget = torch.randn((1024, 1), device=DEVICE) * 2.0



In [None]:
%%time
optimizer = torch.optim.AdamW([
    {'params': filter(lambda p: p.requires_grad, text_encoder.parameters())},
    {'params': [A, B]}
], lr=1e-5)

text_encoder.train()


for epoch in range(10):
    total_loss = 0
    for i in range(len(images)):
        image = images[i].unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            latents = vae.encode(image).latent_dist.sample() * 0.18215

        noise = torch.randn_like(latents)
        t = torch.randint(0, scheduler.config.num_train_timesteps, (1,), device=DEVICE).long()
        noisy_latents = scheduler.add_noise(latents, noise, t)

        # Forget prompt embedding
        text_input = tokenizer(forget_prompt, return_tensors="pt").input_ids.to(DEVICE)
        f_f = text_encoder(text_input)[0].mean(dim=1)
        text_embed = text_encoder(text_input)[0]  # full sequence for UNet

        # Retain prompt embedding
        x_r = tokenizer(retain_prompts, padding=True, return_tensors="pt").input_ids.to(DEVICE)
        f_r = text_encoder(x_r)[0].mean(dim=1)

        # UNet prediction
        pred_noise = unet(noisy_latents, t, encoder_hidden_states=text_embed).sample
        loss_img = -F.mse_loss(pred_noise, noise)

        # Low-rank projection loss
        delta_P = A @ B.T
        retain_loss = F.mse_loss((P.weight + delta_P) @ f_r.T, P.weight @ f_r.T)
        forget_loss = F.mse_loss((P.weight + delta_P) @ f_f.T, F_forget)
        reg_loss = torch.norm(delta_P, p='fro')
        before = (P.weight @ f_f.T).detach()
        after = ((P.weight + delta_P) @ f_f.T)
        cos = F.cosine_similarity(before.T, after.T).item()
        print(f"Forget prompt projection similarity: {cos:.4f}")

        # Total loss
        loss_total = loss_img + 1.0 * retain_loss + 0.5 * forget_loss + 1e-3 * reg_loss

        optimizer.zero_grad()
        loss_total.backward()
        optimizer.step()

        total_loss += loss_total.item()

        # Optional: print losses

    print(f"Epoch {epoch}: Avg Loss = {total_loss:.4f}")
SO = "/content/model/"
# Save model and projection components
text_encoder.save_pretrained(SAVE_PATH)
torch.save({'A': A, 'B': B}, os.path.join(SAVE_PATH, 'deltaP.pt'))


Forget prompt projection similarity: 0.9999
Forget prompt projection similarity: 0.9999
Forget prompt projection similarity: 0.9999
Forget prompt projection similarity: 0.9999
Forget prompt projection similarity: 0.9999
Epoch 0: Avg Loss = 4.1046
Forget prompt projection similarity: 0.9999
Forget prompt projection similarity: 0.9999
Forget prompt projection similarity: 0.9999
Forget prompt projection similarity: 0.9999
Forget prompt projection similarity: 0.9999
Epoch 1: Avg Loss = 4.4887
Forget prompt projection similarity: 0.9999
Forget prompt projection similarity: 0.9999
Forget prompt projection similarity: 0.9999
Forget prompt projection similarity: 0.9999
Forget prompt projection similarity: 0.9999
Epoch 2: Avg Loss = 3.5552
Forget prompt projection similarity: 0.9999
Forget prompt projection similarity: 0.9999
Forget prompt projection similarity: 0.9999
Forget prompt projection similarity: 0.9999
Forget prompt projection similarity: 0.9999
Epoch 3: Avg Loss = 3.0495
Forget promp

In [None]:
# Convert models and tensors to float16 for compatibility
# Convert models and tensors to float16 for compatibility
# The training cell leaves the text encoder in train() mode, so switch it to
# eval() as well before it is handed to the inference pipeline.
text_encoder = text_encoder.half().eval()
# Cast the low-rank factors and the reference projection in place.
A.data = A.data.half()
B.data = B.data.half()
P.weight.data = P.weight.data.half()


In [None]:

# Then run inference
# Build an fp16 Stable Diffusion 2 pipeline around the fine-tuned text encoder
# (safety checker disabled) and move it to the compute device.
pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2",
    text_encoder=text_encoder,
    safety_checker=None,
    torch_dtype=torch.float16
)
pipe = pipe.to(DEVICE)

# One retained concept and the forgotten concept, each saved under SAVE_PATH.
eval_cases = (
    ("a photo of the frog", "eval_retain.png"),
    (f"a photo of {CONCEPT}", "eval_forgotten.png"),
)
for prompt, out_name in eval_cases:
    image = pipe(prompt).images[0]
    image.save(f"{SAVE_PATH}/{out_name}")



Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

In [None]:
# Prompts for a side-by-side check: the forgotten concept plus a few others.
# NOTE(review): this rebinds retain_prompts to a shorter list than the one
# the training cell reads; re-running training after this cell would use it.
forget_prompt = f"a photo of {CONCEPT}"
retain_prompts = ["a photo of a car", "a flower", "a person"]

# Generate images
from PIL import Image
# Parallel lists: generated_images[k] was produced from generated_prompts[k].
generated_prompts = []
generated_images = []

for prompt in (forget_prompt, *retain_prompts):
    img = pipe(prompt).images[0]
    generated_prompts.append(prompt)
    generated_images.append(img)


  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]