## Exploring Llama-based Prompt Enhancement for SUR-Adapter

In [1]:
# Model id of the chat LLM used to expand short captions; it is passed to
# AutoTokenizer.from_pretrained and transformers.pipeline in the setup cell.
model_name = "meta-llama/Llama-2-7b-chat-hf"

In [18]:
def get_caption_detailing_prompt(caption):
    """Build the few-shot prompt asking the LLM to expand a short caption.

    The template holds an instruction paragraph, four Short/Long example
    pairs, a final ``Short:`` slot that receives ``caption``, and a closing
    instruction about the expected answer format.

    Args:
        caption: The short caption to insert into the template via
            ``%``-formatting.

    Returns:
        The complete prompt string with ``caption`` substituted for ``%s``.
    """
    # The template text is sent to the model verbatim (including its
    # indentation), so it must not be re-wrapped or re-indented.
    few_shot_template = '''
    Please generate the long prompt version of the short one according to the given examples. Long prompt version should consist of 3 to 5 sentences. Long prompt version must specify the color, shape, texture or spatial relation of the included objects. DO NOT generate sentences that describe any atmosphere!!!
    
        Short: A calico cat with eyes closed is perched upon a Mercedes.
        Long: a multicolored cat perched atop a shiny black car. the car is parked in front of a building with wooden walls and a green fence. the reflection of the car and the surrounding environment can be seen on the car's glossy surface.
    
        Short: A boys sitting on a chair holding a video game remote.
        Long: a young boy sitting on a chair, wearing a blue shirt and a baseball cap with the letter 'm'. he has a red medal around his neck and is holding a white game controller. behind him, there are two other individuals, one of whom is wearing a backpack. to the right of the boy, there's a blue trash bin with a sign that reads 'automatic party'.
    
        Short: A man is on the bank of the water fishing.
        Long: a serene waterscape where a person, dressed in a blue jacket and a red beanie, stands in shallow waters, fishing with a long rod. the calm waters are dotted with several sailboats anchored at a distance, and a mountain range can be seen in the background under a cloudy sky.
    
        Short: A kitchen with a cluttered counter and wooden cabinets.
        Long: a well-lit kitchen with wooden cabinets, a black and white checkered floor, and a refrigerator adorned with a floral decal on its side. the kitchen countertop holds various items, including a coffee maker, jars, and fruits.
    
        Short: 
        %s
    Strictly return the output as following "ANSWER: {long prompt for above short}", do not give any other text as output
    '''
    filled_prompt = few_shot_template % caption
    return filled_prompt

In [47]:
from transformers import AutoTokenizer
import transformers
import torch
import accelerate

# NOTE(review): `device` is not referenced by any other visible cell; the
# pipeline below places the model through device_map="auto" instead.
device = torch.device("cuda:4")

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Text-generation pipeline around `model_name`, loaded in bfloat16.
# return_full_text=False is set so that callers read only the generated
# continuation from `generated_text` (see the transformers pipeline docs).
pipeline = transformers.pipeline(
    "text-generation",
    model=model_name,
    tokenizer=tokenizer,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    num_return_sequences=1,
    eos_token_id=tokenizer.eos_token_id,
    return_full_text=False,
)


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [48]:
import time

def _strip_answer_label(text, labels=("ANSWER:", "Long:")):
    """Trim whitespace and any leading answer label from generated text.

    Args:
        text: Raw text produced by the language model.
        labels: Literal prefixes to remove, tried in order.

    Returns:
        ``text`` without surrounding whitespace and without the leading
        labels; the caption body itself is left untouched.
    """
    cleaned = text.strip()
    for label in labels:
        if cleaned.startswith(label):
            # Slice off the literal prefix. str.strip(label) would treat the
            # label as a *set of characters* and also eat caption letters
            # such as a trailing "g", "n" or "o".
            cleaned = cleaned[len(label):].lstrip()
    return cleaned


def get_detailed_caption_with_llama(caption):
    """Expand a short caption into a detailed one with the LLM pipeline.

    Builds the few-shot prompt for ``caption``, runs the module-level
    ``pipeline`` on it, prints the elapsed wall-clock time, and returns the
    first generated sequence with its ``ANSWER:`` / ``Long:`` label removed.

    Args:
        caption: Short caption to expand.

    Returns:
        The generated detailed caption as a string.
    """
    t1 = time.perf_counter()
    sequences = pipeline(get_caption_detailing_prompt(caption))
    print("llama inference took:", time.perf_counter() - t1)
    return _strip_answer_label(sequences[0]['generated_text'])

In [49]:
# Smoke test: expand one short caption and display the generated detailed prompt.
get_detailed_caption_with_llama("An elephant holding an apple in its trunk")

llama inference took: 3.444841541000642


'a majestic elephant with a wrinkled gray hide, standing near a lush green tree. the elephant is holding an apple in its long, flexible trunk, and its tusks glint in the sunlight. in the background, a group of birds can be seen flying overhead, and a distant mountain range can be seen through the trees.'

In [50]:
from transformers import CLIPTextModel, CLIPTokenizer

# Checkpoint that provides both the CLIP tokenizer and the CLIP text encoder;
# each is loaded from its own sub-folder of the same model id.
diff_model_name = "runwayml/stable-diffusion-v1-5"

clip_tokenizer = CLIPTokenizer.from_pretrained(
    diff_model_name, subfolder="tokenizer", revision=None
)
clip_text_encoder = CLIPTextModel.from_pretrained(
    diff_model_name, subfolder="text_encoder", revision=None
)


In [68]:
# Generate a detailed prompt to push through the CLIP tokenizer/encoder below.
# The text differs from the earlier call with the same caption, so generation
# is not deterministic between runs.
complex_prompt = get_detailed_caption_with_llama("An elephant holding an apple in its trunk")
complex_prompt

llama inference took: 3.5084380309999688


'a majestic elephant with a wrinkled gray skin, holding a red apple in its long, curled trunk. the elephant stands in a lush green forest, surrounded by tall trees and a clear blue sky. the sunlight filters through the leaves, casting dappled shadows on the forest floor.'

In [69]:
# Tokenize the detailed prompt, padding/truncating to the tokenizer's maximum
# length so the encoder always receives a fixed-size input.
tokens = clip_tokenizer(
    complex_prompt,
    padding="max_length",
    truncation=True,
    max_length=clip_tokenizer.model_max_length,
    return_tensors="pt",
)

tokens

{'input_ids': tensor([[49406,   320, 15335, 10299,   593,   320, 22201,   912,  7048,  3575,
           267,  5050,   320,   736,  3055,   530,   902,  1538,   267, 43734,
         18347,   269,   518, 10299,  6446,   530,   320, 16263,  1901,  4167,
           267, 13589,   638,  7771,  4682,   537,   320,  3143,  1746,  2390,
           269,   518, 17996, 18385,  1417,   518,  5579,   267,  7087, 46857,
           912, 12971,   525,   518,  4167,  4125,   269, 49407, 49407, 49407,
         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
         49407, 49407, 49407, 49407, 49407, 49407, 49407]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0]])}

In [70]:
# Run the CLIP text encoder on the tokenized detailed prompt and keep the
# per-token hidden states; the displayed shape is (1, 77, 768).
embeddings = clip_text_encoder(**tokens)
clip_out = embeddings.last_hidden_state
clip_out.shape

torch.Size([1, 77, 768])

In [72]:
def _encode_with_clip(prompt):
    """Tokenize ``prompt`` and return the CLIP text encoder's last hidden state."""
    tokens = clip_tokenizer(
        prompt,
        return_tensors='pt',
        max_length=clip_tokenizer.model_max_length,
        padding="max_length",
        truncation=True
    )
    # The encoder is only used to produce a fixed target embedding here, so
    # skip autograd bookkeeping: no graph is built through the text encoder
    # and the returned tensor does not require grad.
    with torch.no_grad():
        return clip_text_encoder(**tokens).last_hidden_state


def inference(caption):
    """Embed the LLM-expanded version of ``caption`` with the CLIP text encoder.

    The caption is first expanded by ``get_detailed_caption_with_llama`` and
    the resulting detailed prompt is tokenized (padded/truncated to the
    tokenizer's maximum length) and encoded.

    Args:
        caption: Short caption to expand and embed.

    Returns:
        The encoder's ``last_hidden_state`` for the detailed prompt, computed
        under ``torch.no_grad()``. In this notebook's runs its shape was
        ``(1, 77, 768)``.
    """
    complex_prompt = get_detailed_caption_with_llama(caption)
    return _encode_with_clip(complex_prompt)

In [73]:
# CLIP embedding of the LLM-expanded caption; compared against the adapter
# output in the final loss_fn(out, c) cell.
c = inference("a racoon holding a red box")

llama inference took: 2.828783167002257


In [74]:
# Sanity check: the detailed-prompt embedding has shape (1, 77, 768).
c.shape

torch.Size([1, 77, 768])

In [75]:
from SUR_adapter import Adapter
# Instantiate the SUR adapter. sd_text_size=768 equals the last dimension of
# the CLIP hidden states displayed in this notebook.
# NOTE(review): the meaning of adapter_weight is defined in SUR_adapter.Adapter — confirm there.
suradapter = Adapter(adapter_weight=1e-4, sd_text_size=768)

In [113]:
# Encode the *short* caption directly with CLIP (no LLM expansion).
# Note: this cell reassigns `tokens`, `embeddings` and `clip_out`, replacing
# the values computed earlier for the detailed prompt.
simple_prompt = "a racoon holding a red box"
tokens = clip_tokenizer(
    simple_prompt,
    padding="max_length",
    truncation=True,
    max_length=clip_tokenizer.model_max_length,
    return_tensors="pt",
)
embeddings = clip_text_encoder(**tokens)
clip_out = embeddings.last_hidden_state
clip_out.shape

torch.Size([1, 77, 768])

In [87]:
# Pass the CLIP hidden states through the adapter; only the first of its three
# return values is used here.
# NOTE(review): this cell's execution count (87) is lower than that of the
# simple-prompt cell above it (113), so the saved `out` may have been computed
# from an earlier `clip_out` — re-run top to bottom to confirm.
out, _, _ = suradapter(clip_out)

In [89]:
# Sanity check: the adapter output has shape (1, 77, 768), the same as the embedding `c`.
out.shape

torch.Size([1, 77, 768])

In [110]:
def loss_fn(x, y):
    """Return the cosine similarity of ``x`` and ``y`` viewed as flat vectors.

    Both tensors are flattened to 1-D before comparison, so a single scalar
    is produced for the whole input rather than one value per row or token.

    NOTE(review): despite the name, the returned value is a similarity
    (1.0 for identical direction), not a distance — confirm the intended
    sign before minimizing it.

    Args:
        x: First tensor.
        y: Second tensor with the same number of elements as ``x``.

    Returns:
        A 0-dimensional tensor holding the cosine similarity.
    """
    flat_x = x.flatten()
    flat_y = y.flatten()
    return torch.nn.functional.cosine_similarity(flat_x, flat_y, dim=0, eps=1e-08)

In [111]:
# Cosine similarity between the adapter output `out` and the detailed-prompt
# embedding `c` (0.3128 in the saved run).
loss_fn(out, c)

tensor(0.3128, grad_fn=<SumBackward1>)