# Influential data identification - Stable_Diffusion - Style_Transfer

This notebook demonstrates how to efficiently compute influence functions with DataInf and applies them to an **influential data identification** task.

- Model: [Stable Diffusion v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5).
- Fine-tuning dataset: [A style_transfer dataset](https://huggingface.co/datasets/kewu93/three_styles_prompted_250_512x512) that combines three different styles (cartoon, sketch, and pixel-art).

References
- `diffusers` HuggingFace library [[Link]](https://huggingface.co/docs/diffusers).
- DataInf is available at this [ArXiv link](https://arxiv.org/abs/2310.00902).

## Fine-tune a text-to-image model
- We fine-tune a Stable Diffusion v1-5 model on the style-transfer dataset with LoRA adapters (`--rank=2`), using `src/train_text_to_image_lora.py`, which is built on HuggingFace's [example](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py).
- The following code fine-tunes the model. To skip this step, you can instead load the fine-tuned weights from [this link](https://huggingface.co/kewu93/three_styles_lora).

In [1]:
# !accelerate launch /PATH_TO_DataInf/DataInf/src/train_text_to_image_lora.py \
#   --pretrained_model_name_or_path=runwayml/stable-diffusion-v1-5 \
#   --dataset_name=kewu93/three_styles_prompted_250_512x512 \
#   --resolution=512 --center_crop --random_flip \
#   --train_batch_size=1 \
#   --gradient_accumulation_steps=4 \
#   --max_train_steps=10000 \
#   --learning_rate=1e-04 \
#   --max_grad_norm=1 \
#   --lr_scheduler="cosine" --lr_warmup_steps=0 \
#   --output_dir=/PATH_TO_OUTPUT_DIR/three_styles_lora \
#   --checkpointing_steps=1000 \
#   --validation_prompt="A sports car driving down a windy road." \
#   --seed=1337 \
#   --rank=2 \
#   --resume_from_checkpoint="latest"
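The `--rank=2` flag sets the rank of the LoRA adapters that the training script attaches to the UNet. As a rough illustration (a minimal NumPy sketch, not the `diffusers`/PEFT implementation), a LoRA-adapted linear layer keeps the pretrained weight `W0` frozen and learns a low-rank update `lora_B @ lora_A`, where `lora_A` has shape `(rank, in_features)` and `lora_B` has shape `(out_features, rank)`. The layer sizes are hypothetical, the LoRA scaling factor is assumed to be 1, and `lora_B` is zero-initialized as in the usual LoRA formulation.

```python
import numpy as np

rng = np.random.default_rng(0)

in_features, out_features, rank = 8, 6, 2  # hypothetical sizes; rank matches --rank=2

W0 = rng.normal(size=(out_features, in_features))      # frozen pretrained weight
lora_A = rng.normal(size=(rank, in_features)) * 0.01   # trainable down-projection (rank x in)
lora_B = np.zeros((out_features, rank))                # trainable up-projection (out x rank), zero-initialized

def lora_linear(x, W0, lora_A, lora_B, scale=1.0):
    """Apply W0 plus the low-rank update lora_B @ lora_A to inputs x of shape (batch, in_features)."""
    return x @ W0.T + scale * (x @ lora_A.T) @ lora_B.T

x = rng.normal(size=(3, in_features))
y = lora_linear(x, W0, lora_A, lora_B)

# With lora_B at zero, the adapted layer starts out identical to the pretrained layer.
print(np.allclose(y, x @ W0.T))  # True
# Trainable parameters: rank * (in + out) instead of in * out.
print(lora_A.size + lora_B.size, W0.size)  # 28 48
```

Because `lora_B` is stored as `(out_features, rank)`, the gradient cell transposes its gradient so that both LoRA factors have the rank as their leading dimension.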

In [2]:
import random, pickle
import numpy as np
from tqdm import tqdm
import torch
from torchvision import transforms
from datasets import load_dataset
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, DiffusionPipeline
import torch.nn.functional as F

import sys
sys.path.append('../src')
from influence import IFEngineGeneration

## Load a fine-tuned model

In [3]:
model_base = "runwayml/stable-diffusion-v1-5"
tokenizer = CLIPTokenizer.from_pretrained(model_base, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_base, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(model_base, subfolder="vae").cuda()
noise_scheduler = DDPMScheduler.from_pretrained(model_base, subfolder="scheduler")

'''
Load Lora-tuned Unet
'''
pipeline = DiffusionPipeline.from_pretrained(model_base)
pipeline.load_lora_weights("kewu93/three_styles_lora") # publicly available weights!
unet = pipeline.unet

# Enable gradients for all UNet parameters; only the LoRA gradients are collected later.
for _, param in unet.named_parameters():
    param.requires_grad = True

`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["bos_token_id"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["eos_token_id"]` will be overriden.


## Load datasets and data loaders

In [4]:
'''
Load Datasets
'''

def tokenize_captions(examples, is_train=True):
    captions = []
    for caption in examples['text']:
        if isinstance(caption, str):
            captions.append(caption)
        elif isinstance(caption, (list, np.ndarray)):
            # take a random caption if there are multiple
            captions.append(random.choice(caption) if is_train else caption[0])
        else:
            raise ValueError(
                "Caption column `'text'` should contain either strings or lists of strings."
            )
    inputs = tokenizer(
        captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
    )
    return inputs.input_ids

train_transforms = transforms.Compose(
    [
        transforms.Resize((512, 512), interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

def preprocess_train(examples):
    images = [image.convert("RGB") for image in examples['image']]
    examples["pixel_values"] = [train_transforms(image) for image in images]
    examples["input_ids"] = tokenize_captions(examples)
    return examples

dataset_name = 'kewu93/three_styles_prompted_250_512x512'
dataset = load_dataset(dataset_name)

train_dataset = dataset["train"].with_transform(preprocess_train)
val_dataset = dataset["val"].with_transform(preprocess_train)


'''
Create Data Loaders
'''

def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
    input_ids = torch.stack([example["input_ids"] for example in examples])
    return {"pixel_values": pixel_values, "input_ids": input_ids}
    
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    shuffle=False,
    collate_fn=collate_fn,
    batch_size=1,
    num_workers=1,
)

val_dataloader = torch.utils.data.DataLoader(
    val_dataset,
    shuffle=False,
    collate_fn=collate_fn,
    batch_size=1,
    num_workers=1,
)
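In `train_transforms`, `transforms.ToTensor()` scales 8-bit pixel values to `[0, 1]`, and `transforms.Normalize([0.5], [0.5])` then computes `(x - 0.5) / 0.5`, mapping them to `[-1, 1]`, the input range of the Stable Diffusion VAE. A NumPy sketch of the same arithmetic on a few hypothetical pixel values:

```python
import numpy as np

pixels = np.array([0, 64, 128, 255], dtype=np.uint8)  # hypothetical 8-bit pixel values

x = pixels.astype(np.float32) / 255.0  # ToTensor() on uint8 data: [0, 255] -> [0, 1]
x_norm = (x - 0.5) / 0.5               # Normalize([0.5], [0.5]): [0, 1] -> [-1, 1]
recovered = (x_norm + 1.0) / 2.0       # inverse mapping back to [0, 1]

print(x_norm.round(3))
print(np.allclose(recovered, x))
```

The inverse `(x_norm + 1) / 2` is the usual way to map decoded images back to a displayable range.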

## Compute the gradient
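- Influence functions use the first-order gradient of the loss with respect to the trainable parameters. For every training and validation example, the cell below computes the denoising (noise-prediction MSE) loss at six fixed timesteps and keeps only the gradients of the LoRA parameters; for each LoRA layer, the six per-timestep gradients are concatenated into one matrix.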
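For each sample, the gradient cell evaluates the loss at six fixed timesteps and, for every LoRA layer, concatenates the six per-timestep gradients along the first dimension; `lora_B` gradients are transposed first so that both factors have the LoRA rank as their leading dimension. A NumPy sketch of the resulting shapes, using hypothetical layer sizes and random arrays in place of real gradients:

```python
import numpy as np

rng = np.random.default_rng(0)

rank, in_features, out_features = 2, 320, 320  # hypothetical LoRA layer sizes
timesteps = [25, 225, 425, 525, 725, 925]      # the timesteps used in the gradient cell

grads_A, grads_B = [], []
for t in timesteps:
    # Random stand-ins for lora_A.grad (rank x in) and lora_B.grad (out x rank) at timestep t.
    grad_A = rng.normal(size=(rank, in_features))
    grad_B = rng.normal(size=(out_features, rank))
    grads_A.append(grad_A)
    grads_B.append(grad_B.T)  # transpose so the rank is the leading dimension

# Concatenate along axis 0, which is what torch.cat does by default.
stacked_A = np.concatenate(grads_A, axis=0)
stacked_B = np.concatenate(grads_B, axis=0)
print(stacked_A.shape, stacked_B.shape)  # (12, 320) (12, 320)
```

With `--rank=2` and six timesteps, each LoRA factor therefore contributes a `12 x d` gradient matrix per sample.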

In [5]:
name_list = ['train', 'val']
timestep_list = [25, 225, 425, 525, 725, 925]  # fixed diffusion timesteps at which gradients are evaluated
gradient_dict = {}
unet = unet.cuda()
unet.train()
for idx, dataloader_ in enumerate([train_dataloader, val_dataloader]):
    print('-'*30)
    print(name_list[idx])
    print('-'*30)
    grad_dict = {}
    for step, batch in tqdm(enumerate(dataloader_)):
        torch.manual_seed(step)  # make the latent sampling and noise draws reproducible per sample
        grad_dict_one_sample = {}
        for layer_name, layer_weights in unet.named_parameters():
            if 'lora_' in layer_name:
                grad_dict_one_sample[layer_name] = []

        for timestep_ in timestep_list:
            unet.zero_grad()
            # The VAE and text encoder are frozen, so no gradients are needed through them.
            with torch.no_grad():
                latents = vae.encode(batch["pixel_values"].cuda()).latent_dist.sample()
                latents = latents * vae.config.scaling_factor
                encoder_hidden_states = text_encoder(batch["input_ids"])[0].cuda()
            noise = torch.randn_like(latents)
            timesteps = torch.LongTensor([timestep_]).cuda()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
            target = noise  # the UNet is trained to predict the added noise
            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
            loss.backward()
            for layer_name, layer_weights in unet.named_parameters():
                if 'lora_A' in layer_name:
                    grad_dict_one_sample[layer_name].append(layer_weights.grad.cpu())
                elif 'lora_B' in layer_name:
                    # transpose so that the first dimension is the LoRA rank, as for lora_A
                    grad_dict_one_sample[layer_name].append(layer_weights.grad.T.cpu())

        # concatenate the per-timestep gradients of each LoRA layer along the rank dimension
        for layer_name in grad_dict_one_sample:
            grad_dict_one_sample[layer_name] = torch.cat(grad_dict_one_sample[layer_name])

        grad_dict[step] = grad_dict_one_sample
        del latents, noise, timesteps, noisy_latents, encoder_hidden_states, target, model_pred, loss
        torch.cuda.empty_cache()

    gradient_dict[name_list[idx]] = grad_dict

------------------------------
train
------------------------------


600it [48:52,  4.89s/it]

------------------------------
val
------------------------------



150it [12:08,  4.86s/it]
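`noise_scheduler.add_noise(latents, noise, timesteps)` implements the DDPM forward process `x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise`, where `alpha_bar_t` is the cumulative product of `1 - beta_s`. The NumPy sketch below reproduces this for the six timesteps used in the gradient cell, assuming Stable Diffusion v1-5's `scaled_linear` schedule with `beta_start=0.00085`, `beta_end=0.012`, and 1000 training steps (check `noise_scheduler.config` for the actual values). It shows that the six gradients probe progressively noisier inputs.

```python
import numpy as np

# Assumed Stable Diffusion v1-5 scheduler settings ("scaled_linear" betas).
num_train_timesteps, beta_start, beta_end = 1000, 0.00085, 0.012
betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps) ** 2
alphas_cumprod = np.cumprod(1.0 - betas)

def add_noise(x0, noise, t):
    """DDPM forward process: x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise."""
    a = alphas_cumprod[t]
    return np.sqrt(a) * x0 + np.sqrt(1.0 - a) * noise

for t in [25, 225, 425, 525, 725, 925]:
    signal = np.sqrt(alphas_cumprod[t])
    noise_w = np.sqrt(1.0 - alphas_cumprod[t])
    print(f"t={t:3d}  signal weight={signal:.3f}  noise weight={noise_w:.3f}")

rng = np.random.default_rng(0)
x0 = rng.normal(size=(4, 64, 64))  # latent-shaped array: 4 x 64 x 64 for 512 x 512 images
noisy = add_noise(x0, rng.normal(size=x0.shape), 925)
```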


## Compute the influence function

In [6]:
influence_engine = IFEngineGeneration()
influence_engine.preprocess_gradients(gradient_dict['train'], gradient_dict['val'])
influence_engine.compute_hvps()
influence_engine.compute_IF()

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 150/150 [23:48<00:00,  9.52s/it]


Computing IF for method:  identity
Computing IF for method:  proposed
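`compute_hvps()` and `compute_IF()` produce influence values for two methods. As its name suggests, `identity` replaces the inverse Hessian with the identity, while the `proposed` (DataInf) method, as described in the paper, approximates each layer's inverse damped Hessian by the average over training points of `(g_i g_i^T + lambda I)^{-1}`, which has the closed form `(1/lambda) (I - g_i g_i^T / (lambda + ||g_i||^2))` by the Sherman-Morrison formula. The NumPy sketch below illustrates this for a single layer with flattened gradients. It is not the code of `IFEngineGeneration`; the damping value `lam` is an arbitrary assumption, and the sign convention (negative = helpful) is chosen to match the interpretation used in the AUC cell.

```python
import numpy as np

def datainf_hvp(val_grad, train_grads, lam):
    """DataInf-style approximation of H^{-1} v for one layer.

    val_grad: (d,) validation gradient v; train_grads: (n, d) training gradients g_i.
    Returns (1/n) * sum_i (g_i g_i^T + lam I)^{-1} v, computed in closed form via
    Sherman-Morrison: (1/lam) * (v - (g_i . v) / (lam + ||g_i||^2) * g_i).
    """
    dots = train_grads @ val_grad                # g_i . v, shape (n,)
    norms_sq = np.sum(train_grads**2, axis=1)    # ||g_i||^2, shape (n,)
    coef = dots / (lam + norms_sq)               # shape (n,)
    per_sample = (val_grad[None, :] - coef[:, None] * train_grads) / lam
    return per_sample.mean(axis=0)

def influence(hvp, train_grads):
    """Influence of each training point on the validation loss: -hvp . g_k (negative = helpful)."""
    return -(train_grads @ hvp)

rng = np.random.default_rng(0)
n, d, lam = 5, 4, 0.1
train_grads = rng.normal(size=(n, d))
val_grad = rng.normal(size=d)

if_identity = influence(val_grad, train_grads)  # identity method: H^{-1} replaced by I
hvp = datainf_hvp(val_grad, train_grads, lam)
if_datainf = influence(hvp, train_grads)
print(if_identity.shape, if_datainf.shape)

# Check the closed form against explicit matrix inverses.
explicit = np.mean(
    [np.linalg.solve(np.outer(g, g) + lam * np.eye(d), val_grad) for g in train_grads], axis=0
)
print(np.allclose(hvp, explicit))  # True
```

The closed form avoids forming or inverting any `d x d` matrix, which is what makes the approach practical for layers with many parameters.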


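## Application to the influential data detection task
### AUC and Recall

For each validation image, training images of the same style class are treated as influential and all others as non-influential. We report the mean/std over validation points of (1) the ROC AUC of the negated influence values and (2) the fraction of top-ranked (most negative influence) training points that share the validation image's style class.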

In [7]:
import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score

identity_df=influence_engine.IF_dict['identity']
proposed_df=influence_engine.IF_dict['proposed']

In [14]:
train_labels = np.array(dataset["train"]['style_class'])
val_labels = np.array(dataset["val"]['style_class'])

identity_auc_list, proposed_auc_list = [], []
for i in range(len(val_labels)):
    gt_label = val_labels[i]
    # 1 for training points with the same style class as validation point i, 0 otherwise
    gt_array = (train_labels == gt_label).astype(int)

    # Training points of the same class as the validation point are expected to have large negative
    # influence values, because they are likely to be more helpful in reducing the validation loss.
    # We therefore negate the influence values so that higher scores correspond to gt_array == 1.
    identity_auc_list.append(roc_auc_score(gt_array, -(identity_df.iloc[i, :].to_numpy())))
    proposed_auc_list.append(roc_auc_score(gt_array, -(proposed_df.iloc[i, :].to_numpy())))

# mean/std of the per-validation-point AUC
print(f'identity AUC: {np.mean(identity_auc_list):.3f}/{np.std(identity_auc_list):.3f}')
print(f'proposed AUC: {np.mean(proposed_auc_list):.3f}/{np.std(proposed_auc_list):.3f}')

identity AUC: 0.612/0.079
proposed AUC: 0.599/0.077
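`roc_auc_score(gt_array, scores)` equals the probability that a randomly chosen same-style training point receives a higher score than a randomly chosen different-style one, with ties counting one half. Because helpful points are expected to have negative influence, the AUC cell passes the negated influence as the score. A small pairwise NumPy sketch of this definition on hypothetical influence values:

```python
import numpy as np

def pairwise_auc(labels, scores):
    """ROC AUC as P(score_pos > score_neg) + 0.5 * P(tie), over all positive/negative pairs."""
    labels = np.asarray(labels, dtype=bool)
    scores = np.asarray(scores, dtype=float)
    pos, neg = scores[labels], scores[~labels]
    diff = pos[:, None] - neg[None, :]  # all positive-minus-negative score differences
    return np.mean(diff > 0) + 0.5 * np.mean(diff == 0)

# Hypothetical influence values for six training points; 1 = same style as the validation point.
gt_array = np.array([1, 1, 0, 0, 1, 0])
influence_values = np.array([-3.0, -1.0, 0.5, -2.0, -0.2, 1.0])

auc_negated = pairwise_auc(gt_array, -influence_values)  # 7 of 9 pairs ordered correctly
auc_raw = pairwise_auc(gt_array, influence_values)       # without negation: 1 - AUC
print(auc_negated, auc_raw)
```

Without the sign flip, the same ranking would yield `1 - AUC`, which is why the negation matters for interpreting the reported numbers.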


In [9]:
# Recall calculations: for validation point i, take the n_label training points with the most negative
# influence (n_label = number of validation points in the same class) and compute the fraction that share its class.
train_array = np.array(dataset['train']['style_class'])
val_array = np.array(dataset['val']['style_class'])
identity_recall_list, proposed_recall_list = [], []
for i in range(len(val_array)):
    gt_label = val_array[i]
    n_label = np.sum(val_array == gt_label)

    sorted_index = np.argsort(identity_df.iloc[i].values)  # ascending: most negative (most helpful) first
    sorted_array = train_array[sorted_index]
    recall_identity = np.count_nonzero(sorted_array[:n_label] == gt_label) / n_label
    identity_recall_list.append(recall_identity)

    sorted_index = np.argsort(proposed_df.iloc[i].values)  # ascending: most negative (most helpful) first
    sorted_array = train_array[sorted_index]
    recall_proposed = np.count_nonzero(sorted_array[:n_label] == gt_label) / n_label
    proposed_recall_list.append(recall_proposed)

print(f'identity Recall: {np.mean(identity_recall_list):.3f}/{np.std(identity_recall_list):.3f}')
print(f'proposed Recall: {np.mean(proposed_recall_list):.3f}/{np.std(proposed_recall_list):.3f}')

identity Recall: 0.889/0.173
proposed Recall: 0.916/0.109
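The recall cell scores each validation point by the fraction of its `n_label` most negative (most helpful) training points that share its style class, where `n_label` is the number of validation points in that class. The same quantity can be computed for all validation points at once. Here is a vectorized NumPy sketch, assuming an influence matrix of shape `(n_val, n_train)` such as `identity_df.to_numpy()` and class labels stored in arrays; the toy data is hypothetical.

```python
import numpy as np

def topk_same_class_rate(if_matrix, val_labels, train_labels):
    """Fraction of each validation point's k most helpful training points sharing its class,
    with k = number of validation points in that class (as in the recall cell)."""
    if_matrix = np.asarray(if_matrix)
    val_labels = np.asarray(val_labels)
    train_labels = np.asarray(train_labels)
    order = np.argsort(if_matrix, axis=1)              # ascending: most negative influence first
    ranked_labels = train_labels[order]                # (n_val, n_train) labels in ranked order
    classes, counts = np.unique(val_labels, return_counts=True)
    k = counts[np.searchsorted(classes, val_labels)]   # per-row cutoff n_label
    hits = ranked_labels == val_labels[:, None]
    within_k = np.arange(if_matrix.shape[1])[None, :] < k[:, None]
    return (hits & within_k).sum(axis=1) / k

# Hypothetical data: 3 validation points, 4 training points, 2 classes.
if_matrix = np.array([[-2.0, 0.5, -1.0, 0.1],
                      [0.3, -1.5, 0.2, -0.4],
                      [-0.1, -0.5, -3.0, 1.0]])
val_labels = np.array([0, 1, 0])
train_labels = np.array([0, 1, 0, 1])

rates = topk_same_class_rate(if_matrix, val_labels, train_labels)
print(rates, rates.mean(), rates.std())
```

Applied to `identity_df.to_numpy()` or `proposed_df.to_numpy()` with the style-class arrays, this should reproduce the per-point values from the loop.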
