In [28]:
import torch
import clip
from PIL import Image
import requests
from io import BytesIO
from transformers import GPT2LMHeadModel, GPT2Tokenizer

In [29]:
# CLIP model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# GPT-2 model and tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
gpt_model = GPT2LMHeadModel.from_pretrained("gpt2")
gpt_model.to(device)

# pad_token
tokenizer.pad_token = tokenizer.eos_token

# Load an image from a URL
def load_image(image_url):
    response = requests.get(image_url)
    img = Image.open(BytesIO(response.content))
    return img

image_url = "https://assets.aboutamazon.com/7a/f6/2381a7084f3184ed19fc33d2efae/taylor-swift-eras-ext-pvod-1920x1080.jpeg"
image = load_image(image_url)

# Preprocess the image for CLIP
image_input = preprocess(image).unsqueeze(0).to(device)

In [37]:
with torch.no_grad():
    image_features = model.encode_image(image_input)
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)

# Text descriptions
text_descriptions = [
    "a concert photo of a famous singer",
    "a live performance on stage",
    "a famous singer performing at a concert",
    "a photo of Taylor Swift performing at a concert",
    "a crowd watching a concert",
    "a stage performance with bright lights",
    "a singer performing in front of a huge audience"
]

# Tokenize the descriptions
text_inputs = clip.tokenize(text_descriptions).to(device)

# Calculate similarities between the image and descriptions
with torch.no_grad():
    text_features = model.encode_text(text_inputs)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    # Compare image features with text features
    similarities = (100.0 * image_features @ text_features.T).softmax(dim=-1)

# Find the best matching description
best_match_index = similarities.argmax().item()
best_description = text_descriptions[best_match_index]
print("Best Description: ", best_description)

def clean_caption(caption):
    # Remove irrelevant parts if they exist
    stop_phrases = ['Getty Images', 'Oscar', 'The Artist', 'Photo', '(']
    for phrase in stop_phrases:
        if phrase in caption:
            caption = caption.split(phrase)[0].strip()
    return caption

prompt = f"A concert photo of Taylor Swift performing on stage."
inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(device)

output_sequences = gpt_model.generate(
    input_ids=inputs['input_ids'],
    attention_mask=inputs['attention_mask'], 
    max_length=30,  
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    repetition_penalty=1.5,
    pad_token_id=tokenizer.eos_token_id
)

# Decode and post-process the output to remove nonsense
caption = tokenizer.decode(output_sequences[0], skip_special_tokens=True)

# Clean caption to remove irrelevant parts
cleaned_caption = clean_caption(caption)
print("Generated Caption: ", cleaned_caption)


Best Description:  a photo of Taylor Swift performing at a concert
Generated Caption:  A concert photo of Taylor Swift performing on stage.
