# Use LORA to finetune caption model

## Blip

### Load model and dataset

In [1]:
import os
! cd "/home/changl25/Image-to-text-of-Stable-Diffusion"
work_dir = "/home/changl25/Image-to-text-of-Stable-Diffusion"
os.chdir(work_dir)

In [None]:
! ls "/data/changl25/Diffusion2DB/part-000002"
data_dir = "/data/changl25/Diffusion2DB"
os.chdir(data_dir)

In [30]:
from urllib.request import urlretrieve
import shutil
part_id = 1
part_url = f'https://huggingface.co/datasets/poloclub/diffusiondb/resolve/main/images/part-{part_id:06}.zip'
shutil.unpack_archive(f'part-{part_id:06}.zip', f'part-{part_id:06}')
urlretrieve(part_url, f'part-{part_id:06}.zip')
part_id = 2
part_url = f'https://huggingface.co/datasets/poloclub/diffusiondb/resolve/main/images/part-{part_id:06}.zip'
urlretrieve(part_url, f'part-{part_id:06}.zip')
shutil.unpack_archive(f'part-{part_id:06}.zip', f'part-{part_id:06}')

('part-000002.zip', <http.client.HTTPMessage at 0x2ab32ecea860>)

In [2]:
import torch
from dataset_diffusionDB import DiffusionDB
from transformers import BlipProcessor, BlipForConditionalGeneration
from peft import LoraConfig, get_peft_model
from torch.utils.data import DataLoader
import numpy as np

from tqdm.auto import tqdm

%load_ext autoreload
%autoreload 2

In [3]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
precision="float32"

  return torch._C._cuda_getDeviceCount() > 0


In [4]:
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float16).to(device) if precision == "float16" \
    else  BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)

In [None]:
model  = model.to()

In [5]:
from transformers import CLIPTextModel, CLIPTokenizer
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
clip_text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)

In [6]:
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

In [7]:
def print_model_structures(model):
    for name, module in model.named_modules():
        print("name: ", name)
        print("module.module: ", list(module.modules()))

In [8]:
model

BlipForConditionalGeneration(
  (vision_model): BlipVisionModel(
    (embeddings): BlipVisionEmbeddings(
      (patch_embedding): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (encoder): BlipEncoder(
      (layers): ModuleList(
        (0-11): 12 x BlipEncoderLayer(
          (self_attn): BlipAttention(
            (dropout): Dropout(p=0.0, inplace=False)
            (qkv): Linear(in_features=768, out_features=2304, bias=True)
            (projection): Linear(in_features=768, out_features=768, bias=True)
          )
          (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): BlipMLP(
            (activation_fn): GELUActivation()
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
          )
          (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
    (post_layernorm): LayerNorm((768,), eps=1e-0

In [8]:
print_trainable_parameters(model)

trainable params: 247414076 || all params: 247414076 || trainable%: 100.0


In [9]:
batch_size = 16
root_path = '/data/changl25/Diffusion2DB/part-000001'
test_dataset = DiffusionDB(root_path, text = "a photo of", transform=processor, test=True)
test_loader = DataLoader(test_dataset,batch_size=batch_size,shuffle=True)
batch_size = 16
root_path = '/data/changl25/Diffusion2DB/part-000002'
train_dataset = DiffusionDB(root_path, text = "a photo of", transform=processor)
train_loader = DataLoader(test_dataset,batch_size=batch_size,shuffle=True)

In [31]:
! ls /data/changl25/img2textModel/blip_model

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


config.json  generation_config.json  model.safetensors


### LoRA config

In [10]:
config = LoraConfig(
    r = 4,
    lora_alpha = 32, 
    lora_dropout = 0.05,
    bias = "none",
    target_modules = ["query", "value"]
)

peft_model = get_peft_model(model, config)
peft_model = peft_model.to(device)

In [11]:
print_trainable_parameters(peft_model)

trainable params: 294912 || all params: 247708988 || trainable%: 0.11905583337169824


In [12]:
optimizer = torch.optim.AdamW(peft_model.parameters(), lr=1e-4)
epochs = 50

In [34]:
import time
def sd_encoder(text, tokenizer, encoder):
    input = tokenizer(text, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
    input_ids = input["input_ids"].to(device)
    embeddings = encoder(input_ids)[0]
    return embeddings

def clip_cos_similarity(output_prompt, prompts):
    caption_embeddings = sd_encoder(output_prompt, clip_tokenizer, clip_text_encoder)
    prompt_embeddings = sd_encoder(prompts, clip_tokenizer, clip_text_encoder)
    prompt_embeds_flat = prompt_embeddings.view(prompt_embeddings.size(0), -1)
    caption_embeds_flat = caption_embeddings.view(caption_embeddings.size(0), -1)
    prompt_embeds = prompt_embeds_flat / prompt_embeds_flat.norm(dim=1, keepdim=True)
    caption_embeds = caption_embeds_flat / caption_embeds_flat.norm(dim=1, keepdim=True)
    similarity = (torch.matmul(prompt_embeds, caption_embeds.t())).mean()
    return similarity

def sentence_cos_similarity(output_prompt, prompts):
    import sys
    sys.path.append("/data/changl25/img2textModel/sentence-transformers/")
    from sentence_transformers import SentenceTransformer, models

    st_model = SentenceTransformer('/data/changl25/img2textModel/all-MiniLM-L6-v2')
    prompt_embedding = st_model.encode(prompts).flatten()
    output_embedding = st_model.encode(output_prompt).flatten()
    similarity = np.dot(prompt_embedding, output_embedding) / (np.linalg.norm(prompt_embedding) * np.linalg.norm(output_embedding))
    return similarity

def evaluate(peft_model, preprocessor, data_loader, text_flag, precision):
    loss = 0
    evaluate_text = []
    caption_text = []
    prompt_text = []
    peft_model.eval()
    
    loss_time = 0
    generate_time = 0
    load_time = 0
    with torch.no_grad():
        for i, data in enumerate(data_loader):
            start_time = time.time()
            inputs, inputs_generator, prompt = data
            for key, value in inputs.items():
                if precision == "float32":
                    inputs[key] = value.to(device)
                    inputs_generator[key] = inputs_generator[key].to(device)
                elif precision == "float16":
                    inputs[key] = value.to(device)
                    inputs_generator[key] = inputs_generator[key].to(device, torch.float16)
            if precision == "float32":
                labels = torch.tensor(preprocessor.tokenizer(text=prompt, padding="max_length")["input_ids"]).to(device)
            elif precision == "float16":
                labels = torch.tensor(preprocessor.tokenizer(text=prompt, padding="max_length")["input_ids"]).to(device, torch.float16)
            end_time3 = time.time()
            if text_flag:
                input_ids = inputs["input_ids"]
                pixel_values = inputs["pixel_values"]
                attention_mask = inputs["attention_mask"]
                outputs = peft_model(input_ids=input_ids, pixel_values=pixel_values,attention_mask=attention_mask, labels=labels)
            else:
                pixel_values = inputs["pixel_values"]
                outputs = peft_model(pixel_values=pixel_values, labels=labels)
            loss += outputs.loss / len(prompt)
            end_time1 = time.time()
            out = peft_model.generate(**inputs_generator)
            out_text = preprocessor.batch_decode(out, skip_special_tokens=True)
            end_time2 = time.time()
            evaluate_text.append((out_text, prompt))
            loss_time += (end_time1  - end_time3) / len(prompt)
            generate_time += (end_time2  - end_time1) / len(prompt)
            load_time += (end_time3 - start_time) / len(prompt)
    start_time = time.time()        
    for i in range(len(evaluate_text)):
        output_text_batch, prompt_batch = evaluate_text[i]
        for j in range(len(output_text_batch)):
            output_text, prompt = output_text_batch[j], prompt_batch[j]
            caption_text.append(output_text)
            prompt_text.append(prompt)
    clip_sim = clip_cos_similarity(caption_text, prompt_text)
    sentence_sim = sentence_cos_similarity(caption_text, prompt_text)
    end_time = time.time()
    print(f"load time: {load_time / len(data_loader):.4f}, loss time: {loss_time / len(data_loader)}, generate time: {generate_time / len(data_loader)}, similar time: {(end_time - start_time) / len(caption_text)}")
    return loss / len(data_loader), clip_sim, sentence_sim


In [35]:
loss, clip_similarity, sentence_similarity = evaluate(model, processor, test_loader, test_dataset.is_text_supervised(), precision)



load time: 0.0004, loss time: 1.6894775756767817, generate time: 1.6949237776654107, similar time: 0.1320671010017395


tensor(1.1335)

In [52]:
def train_manual(peft_model, preprocessor, train_loader, test_loader, epochs, optimizer, precision, text_flag):
    for epoch in range(epochs):
        peft_model.train()
        for i, data in enumerate(train_loader):
            inputs, _, prompt = data
            for key, value in inputs.items():
                if precision == "float32":
                    inputs[key] = value.to(device)
                    labels = torch.tensor(preprocessor.tokenizer(text=prompt, padding="max_length")["input_ids"]).to(device)
                elif precision == "float16":
                    inputs[key] = value.to(device, torch.float16)
                    labels = torch.tensor(preprocessor.tokenizer(text=prompt, padding="max_length")["input_ids"]).to(device, torch.float16)
                    
            if text_flag:
                input_ids = inputs.pop("input_ids")
                pixel_values = inputs.pop("pixel_values")
                attention_mask = inputs.pop("attention_mask")
                outputs = peft_model(input_ids=input_ids, pixel_values=pixel_values,attention_mask=attention_mask,labels=labels)
            else:
                pixel_values = inputs.pop("pixel_values")
                outputs = peft_model(pixel_values=pixel_values, labels=labels)
            loss = outputs.loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if (i + 1) % (len(train_loader) // 3) == 0:
                print(f"Epoch {epoch}: {i} / {len(train_loader)} {loss} ")
        evaluate_loss, similarity = evaluate(peft_model, preprocessor,test_loader, text_flag, precision)
        print(f"Epoch {epoch} final loss: {evaluate_loss}; similarity: {similarity}")

In [68]:
data = next(iter(train_loader))
peft_model.train()
inputs, prompt = data
for key, value in inputs.items():
    print(key)
    if precision == "float32":
        inputs[key] = value.to(device)
    elif precision == "float16":
        inputs[key] = value.to(device, torch.float16)
input_ids = inputs.pop("input_ids")
pixel_values = inputs.pop("pixel_values")
attention_mask = inputs.pop("attention_mask")
outputs = peft_model(input_ids=input_ids, pixel_values=pixel_values,attention_mask=attention_mask)


pixel_values
input_ids
attention_mask


In [None]:
evaluate(peft_model, test_loader, test_dataset.is_text_supervised(), precision)

In [53]:
train_manual(peft_model, processor, train_loader, test_loader, epochs, optimizer, precision, train_dataset.is_text_supervised())

Epoch 0: 1 / 7 12.801506042480469 
Epoch 0: 3 / 7 12.566864967346191 
Epoch 0: 5 / 7 12.68709659576416 
('portrait of saitama, artstation winner by victo ngai, kilian eng and by jake parker, by conrad roset, swirly vibrant color lines, winning award masterpiece, fantastically gaudy, aesthetic, cinematic, octane render, 8 k, hd resolution ', 'a huge worm - man beast realistic cinematic 3 5 mm ', 'a geometrical portrait of a knave, fractal flowering background, digital art, analogous colours, trending on artstation, ', 'film still by 3 4 3 industries ', 'photo cartoon bd illustration comic manga painting of hangar environement : 5 fantasy environement, digital painting : 1 fat brush concept sketch artist bd enki bilal : 1 0 ', 'two vipers entwined, fighting to the death realistic cinematic 3 5 mm ', 'a beautiful painting of the heart of pripyat by nekro and pascal blanche and syd mead and greg rutkowski and sin jong hun and victo ngai and simon stalenhag and chris voy. in style of cg art



('a star is born, in style of john harris, watercolor, artstation ', 'model in baggy colorful 9 0 s jacket by rick owens. magazine ad. pastel background. ', 'anthropomorphic racoon animal, dressed as a raver chick dancer, dance club in the forest, many woodland creatures dancing, concept design, contrast, hot toys, kim jung gi, greg rutkowski, zabrocki, karlkka, jayison devadas, trending on artstation, 8 k, ultra wide angle, pincushion lens effect ', 'render of dreamy beautiful landscape, fantasy dreamy, ice kingdom, artger, large scale, details vintage photo hyper realistic ultra realistic photo realistic photography, unreal engine, high detailed, 8 k ', 'a painting by edward hopper of a busy and bustling city by contrast with the calming blue sky. ', 'by maxfield parrish, greg manchess, mucha ', 'an illustration of the second law of thermodynamics ', 'from the coast of north oregon, usa ( oc ) [ r / earthporn ] ', 'a beautiful painting of chernobyl in autumn by nekro and pascal blanc

: 

Current problems
* What is the exact input? what parameters do we need and their meanings? like
    * input_ids: labels or hint 
    * labels: encoded or not encoded?
    * loss: how is it computed? what kind of loss function is used? if we need to finetune, do we need to change the loss?

In [52]:
model.save_pretrained("/data/changl25/img2textModel/blip_model")

True