In [None]:
import json
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration
from transformers import AdamW
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch

In [None]:
class CustomImageDataset(Dataset):
    """Image/caption dataset driven by a JSON annotation file.

    Every annotation is read as a mapping with an ``image`` entry (file name
    relative to ``img_dir``) and the text entries ``content``, ``color``,
    ``composition`` and ``quality``. The four text entries are merged into a
    single multi-line caption that is handed to the processor together with
    the image.
    """

    def __init__(self, json_file, img_dir, processor):
        """Load the annotations and remember where the images live.

        Args:
            json_file: Path to the JSON annotation file. The parsed content
                must support ``len()`` and integer indexing.
            img_dir: Directory that holds the image files.
            processor: Callable invoked as
                ``processor(image, text, return_tensors="pt", ...)``.
        """
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(json_file, 'r', encoding='utf-8') as f:
            self.data = json.load(f)
        self.img_dir = img_dir
        self.processor = processor

    def __len__(self):
        """Return the number of annotations."""
        return len(self.data)

    def format_output(self, item):
        """Build the training caption: one ``key: value`` line per field."""
        return (
            f"content: {item['content']}\n"
            f"color: {item['color']}\n"
            f"composition: {item['composition']}\n"
            f"quality: {item['quality']}"
        )

    def __getitem__(self, idx):
        """Return the processor output for the image/caption pair at ``idx``."""
        item = self.data[idx]
        img_path = f"{self.img_dir}/{item['image']}"

        # convert() yields an independent in-memory image, so the underlying
        # file can be released right away instead of waiting for the GC.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")

        combined_caption = self.format_output(item)
        # padding=True on a single sample adds no padding; batches are padded
        # later, when samples are collated.
        inputs = self.processor(
            image,
            combined_caption,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        return inputs

In [None]:
def generate_image_description(image_path, prompt):
    """Generate text for an image conditioned on ``prompt`` and parse it as JSON.

    NOTE(review): a later cell defines ``generate_image_description`` again
    with the signature ``(image_path, max_length)``; once that cell has run,
    this definition is shadowed.

    Relies on the notebook-level ``processor``, ``model`` and ``device``.

    Args:
        image_path: Path of the image file to describe.
        prompt: Text passed to the processor alongside the image.

    Returns:
        Whatever ``json.loads`` produces from the decoded text.

    Raises:
        json.JSONDecodeError: If the decoded text is not valid JSON
            (e.g. the ``key: value`` line format used for training captions).
    """
    # Release the file handle immediately; convert() returns an in-memory copy.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    inputs = processor(image, prompt, return_tensors="pt").to(device)

    # The training loop leaves the model in train mode, which keeps dropout
    # active. Generate in eval mode and restore the previous mode afterwards.
    was_training = model.training
    model.eval()
    try:
        outputs = model.generate(**inputs, max_length=200)
    finally:
        model.train(was_training)

    description = processor.decode(outputs[0], skip_special_tokens=True)
    return json.loads(description)

In [None]:
def collate_fn(batch):
    """Merge per-sample processor outputs into one padded batch.

    Each sample must provide ``input_ids``, ``attention_mask`` and
    ``pixel_values`` tensors. A leading singleton dimension is squeezed away,
    the image tensors are stacked, and the two text tensors are right-padded
    with 0 to the longest sequence in the batch.

    Returns:
        dict with the keys ``input_ids``, ``attention_mask`` and
        ``pixel_values``.
    """
    def _column(key):
        # Gather one field across the batch, dropping the leading batch dim.
        return [sample[key].squeeze(0) for sample in batch]

    token_rows = _column('input_ids')
    mask_rows = _column('attention_mask')
    images = torch.stack(_column('pixel_values'))

    collated = {}
    collated['input_ids'] = pad_sequence(token_rows, batch_first=True, padding_value=0)
    collated['attention_mask'] = pad_sequence(mask_rows, batch_first=True, padding_value=0)
    collated['pixel_values'] = images
    return collated


In [None]:
def generate_image_description(image_path, max_length):
    """Generate an unconditional text evaluation for the image at ``image_path``.

    Relies on the notebook-level ``processor``, ``model`` and ``device``.

    Args:
        image_path: Path of the image file to describe.
        max_length: Forwarded to ``model.generate`` as ``max_length``.

    Returns:
        The decoded text of the first generated sequence, with special tokens
        removed.
    """
    # Release the file handle immediately; convert() returns an in-memory copy.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt").to(device)

    # The training loop leaves the model in train mode, which keeps dropout
    # active. Generate in eval mode and restore the previous mode afterwards
    # so that calling this function never changes the model's state.
    was_training = model.training
    model.eval()
    try:
        output = model.generate(**inputs, max_length=max_length)
    finally:
        model.train(was_training)

    evaluation = processor.decode(output[0], skip_special_tokens=True)
    return evaluation

In [None]:
def parse_output(description):
    """Turn ``key: value`` lines into a dictionary.

    The text is stripped and split on newlines. A line contributes an entry
    only if it contains ``": "``; it is split at the first occurrence, and
    both sides are stripped. Later duplicates of a key overwrite earlier ones.
    """
    fields = {}
    for raw_line in description.strip().split('\n'):
        name, separator, remainder = raw_line.partition(': ')
        if not separator:
            # No "key: value" separator on this line -> skip it.
            continue
        fields[name.strip()] = remainder.strip()
    return fields

In [None]:
def format_with_newlines(text):
    """Put every sentence of ``text`` on its own line.

    The text is split on ``". "``. Each non-blank piece is stripped of
    surrounding whitespace and terminated with a period, unless it already
    ends in ``.``, ``!`` or ``?`` (so the last sentence no longer ends up
    with a doubled period).

    Args:
        text: The text to reformat.

    Returns:
        The sentences joined with newlines; an empty string if there are none.
    """
    sentences = []
    for piece in text.split('. '):
        sentence = piece.strip()
        if not sentence:
            continue
        # Splitting consumed the period of every piece except (usually) the
        # last one, so only add one where it is actually missing.
        if not sentence.endswith(('.', '!', '?')):
            sentence += "."
        sentences.append(sentence)
    return "\n".join(sentences)

In [None]:
# --- Hyperparameters ---------------------------------------------------------
epochs = 100
batch_size = 3
learning_rate = 0.0001
seed = 42

# Seed torch's RNG so shuffling and dropout are repeatable across runs.
torch.manual_seed(seed)

# --- Model, processor and data -----------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
model.to(device)

processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")

dataset = CustomImageDataset("dataset/annotations.json", "dataset/images", processor)

# Shuffle so that batches differ between epochs.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

num_batches = len(dataloader)
if num_batches == 0:
    # Fail early with a clear message instead of a ZeroDivisionError when the
    # average loss is computed.
    raise ValueError("The dataset is empty; there is nothing to train on.")

optimizer = AdamW(model.parameters(), lr=learning_rate)

# Not used by the loop below (the model computes its own loss when `labels`
# is passed); kept so the name stays available to other cells.
loss_fn = CrossEntropyLoss()

# --- Training loop -----------------------------------------------------------
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for step, batch in enumerate(dataloader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        pixel_values = batch['pixel_values'].to(device)

        # Padded positions must not contribute to the loss: -100 is the
        # ignore_index of the cross-entropy loss.
        labels = input_ids.masked_fill(attention_mask == 0, -100)

        outputs = model(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask, labels=labels)

        loss = outputs.loss

        # Clear gradients *before* backward so stale gradients (e.g. from an
        # interrupted earlier run of this cell) can never leak into a step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        print(f"Epoch [{epoch + 1}], Step [{step + 1}/{num_batches}], Loss: {loss.item():.4f}")

    avg_loss = total_loss / num_batches
    print(f"Average Loss: {avg_loss:.4f}")

In [None]:
# Describe one sample image with the fine-tuned model and print the result
# with one sentence per line.
image_path = "dataset/images/ID10.jpeg"
evaluation = generate_image_description(image_path, max_length=320)
output = format_with_newlines(evaluation)
print(f"Evaluation:\n{output}")