In [None]:
# img2text
# https://grok.com/chat/1c87d70c-502b-447f-a53f-b13833c55a61

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BlipProcessor, BlipForConditionalGeneration
from PIL import Image
import pandas as pd
import os
from torch.cuda.amp import autocast, GradScaler
from sklearn.model_selection import train_test_split
from collections import Counter
import nltk
from transformers import AutoTokenizer

# Hyper-parameters and runtime settings
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

BATCH_SIZE = 8  # samples per DataLoader batch
EPOCHS = 5  # full passes over the training split
# Intended input image resolution.
# NOTE(review): BlipImageProcessor reads its resize target from `size`, so this
# value only has an effect if it reaches the processor through that option; an
# `image_size=` kwarg would be ignored — confirm how it is used.
IMG_SIZE = 224
MAX_TEXT_LENGTH = 50  # caption length in tokens (dataset padding/truncation, generation limit)
VOCAB_SIZE = 5000  # cap on the custom word vocabulary, special tokens included
LEARNING_RATE = 1e-5  # AdamW learning rate
VALIDATION_SPLIT = 0.1  # fraction of rows held out for validation


# 1. Custom dataset
class ImageTextDataset(Dataset):
    """Map-style dataset yielding BLIP processor encodings for (image, caption) rows.

    Args:
        data: pandas DataFrame with an ``img_path`` column (relative to
            ``img_dir``) and a ``text`` column holding the caption. Rows are
            addressed positionally via ``iloc``, so a shuffled or
            non-contiguous index (e.g. from ``train_test_split``) is fine.
        img_dir: Directory that the ``img_path`` values are relative to.
        processor: BLIP processor that encodes the image and text together.
        tokenizer: Stored on the instance; not used by this class itself.
    """

    def __init__(self, data, img_dir, processor, tokenizer):
        self.data = data
        self.img_dir = img_dir
        self.processor = processor
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the processor outputs for row ``idx`` with the batch dim removed."""
        row = self.data.iloc[idx]  # one positional lookup for both columns
        img_path = os.path.join(self.img_dir, row["img_path"])

        # The context manager closes the file handle; convert() returns an
        # independent in-memory copy, so `img` stays valid after the block.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")

        # Encode image and caption; the caption is padded/truncated to a fixed
        # length so samples can be stacked by the default collate function.
        encoding = self.processor(
            images=img,
            text=row["text"],
            padding="max_length",
            max_length=MAX_TEXT_LENGTH,
            truncation=True,
            return_tensors="pt",
        )

        # return_tensors="pt" yields a leading batch dimension of size 1; the
        # DataLoader adds its own batch dimension, so drop this one.
        return {key: value.squeeze(0) for key, value in encoding.items()}


# 2. Custom vocabulary
def _ensure_nltk_tokenizer():
    """Download the NLTK Punkt data used by ``word_tokenize`` only if it is missing."""
    # Older NLTK releases load "punkt"; NLTK >= 3.8.2 loads "punkt_tab".
    for resource in ("punkt", "punkt_tab"):
        try:
            nltk.data.find(f"tokenizers/{resource}")
        except LookupError:
            nltk.download(resource, quiet=True)


def build_custom_vocab(texts, vocab_size, tokenize=None):
    """Build a word -> index vocabulary from the most frequent tokens.

    Indices 0-3 are reserved for ``<PAD>``, ``<START>``, ``<END>`` and
    ``<UNK>``. The remaining ``vocab_size - 4`` slots go to the most frequent
    lower-cased tokens in descending frequency order; ties keep first-seen
    order, as ``Counter.most_common`` does.

    Args:
        texts: Iterable of caption strings. Non-string entries (e.g. NaN from
            empty CSV cells) are skipped.
        vocab_size: Maximum total vocabulary size, special tokens included.
            Values <= 4 yield only the special tokens.
        tokenize: Optional callable mapping a string to a list of tokens.
            Defaults to ``nltk.word_tokenize``; NLTK tokenizer data is then
            downloaded on demand.

    Returns:
        dict mapping token string to integer index.
    """
    if tokenize is None:
        _ensure_nltk_tokenizer()
        tokenize = nltk.word_tokenize

    # Simple word-level tokenization of lower-cased captions.
    word_counts = Counter()
    for text in texts:
        if isinstance(text, str):
            word_counts.update(tokenize(text.lower()))

    specials = {"<PAD>": 0, "<START>": 1, "<END>": 2, "<UNK>": 3}
    n_regular = max(vocab_size - len(specials), 0)
    vocab = {
        word: idx + len(specials)
        for idx, (word, _) in enumerate(word_counts.most_common(n_regular))
    }
    # Special tokens are upper-case and can't collide with lower-cased words.
    vocab.update(specials)
    return vocab


# 3. Data loading and preparation
def prepare_data(csv_file, img_dir, num_workers=4):
    """Read the caption CSV, split it, and build BLIP train/validation loaders.

    Args:
        csv_file: Path to a CSV with ``img_path`` and ``text`` columns.
        img_dir: Directory that the ``img_path`` values are relative to.
        num_workers: DataLoader worker processes for both loaders (default 4).
            NOTE(review): under the "spawn" start method (Windows/macOS),
            workers may fail to import ``ImageTextDataset`` when it is defined
            in a notebook; pass 0 in that case.

    Returns:
        ``(train_loader, val_loader, vocab, processor)``.
    """
    data = pd.read_csv(csv_file)

    # Fixed random_state keeps the train/validation split reproducible.
    train_data, val_data = train_test_split(
        data, test_size=VALIDATION_SPLIT, random_state=42
    )

    # Custom word vocabulary, built from all rows (train + validation). It is
    # returned to the caller but not consumed by the BLIP training code here.
    vocab = build_custom_vocab(data["text"].values, VOCAB_SIZE)

    # BLIP processor and tokenizer.
    # No resize override is passed: BlipImageProcessor takes its target from
    # `size`, so an `image_size=` kwarg is silently ignored (and ends up in the
    # saved tokenizer config). Images therefore use the checkpoint's default
    # resolution (presumably 384x384 — confirm in preprocessor_config.json).
    # NOTE(review): shrinking below the pretraining resolution would also
    # misalign the frozen ViT's learned position embeddings; verify before
    # adding `size=...`.
    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    tokenizer = processor.tokenizer

    train_dataset = ImageTextDataset(train_data, img_dir, processor, tokenizer)
    val_dataset = ImageTextDataset(val_data, img_dir, processor, tokenizer)

    train_loader = DataLoader(
        train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers
    )
    val_loader = DataLoader(
        val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers
    )

    return train_loader, val_loader, vocab, processor


# 4. Fine-tuning
def _with_caption_labels(batch):
    """Add captioning ``labels`` (caption tokens, padding -> -100) if absent."""
    # BlipForConditionalGeneration only computes a loss when `labels` is
    # passed. For captioning, the targets are the caption tokens themselves.
    # Padded positions (attention_mask == 0) are set to -100, the ignore_index
    # of the cross-entropy loss.
    if "labels" not in batch:
        batch["labels"] = batch["input_ids"].masked_fill(
            batch["attention_mask"] == 0, -100
        )
    return batch


def _mean_loss(total, num_batches):
    """Average per-batch loss, or NaN when the loader produced no batches."""
    return total / num_batches if num_batches else float("nan")


def train_model(model, train_loader, val_loader, optimizer, scaler, epochs):
    """Fine-tune ``model`` for ``epochs`` epochs, validating after each one.

    Each training step moves the batch to DEVICE and runs the forward pass
    under CUDA autocast (disabled when DEVICE is not CUDA). It then scales the
    loss with ``scaler`` and steps ``optimizer``. After every epoch, the mean
    train and validation losses are printed.

    Args:
        model: BLIP captioning model, expected to already be on DEVICE.
        train_loader: Loader yielding dicts of tensors passed to ``model`` as
            keyword arguments. They must include ``input_ids`` and
            ``attention_mask``; ``labels`` are derived from them when absent.
        val_loader: Same format as ``train_loader``; evaluated without grads.
        optimizer: Optimizer over the trainable parameters.
        scaler: ``GradScaler`` used for mixed-precision loss scaling.
        epochs: Number of passes over ``train_loader``.

    Returns:
        The same ``model`` object, left in train mode.
    """
    use_amp = DEVICE.type == "cuda"  # CUDA autocast is meaningless on CPU

    model.train()
    for epoch in range(epochs):
        # Training
        total_train_loss = 0.0
        for batch in train_loader:
            batch = _with_caption_labels({k: v.to(DEVICE) for k, v in batch.items()})
            optimizer.zero_grad()

            with autocast(enabled=use_amp):  # mixed precision
                loss = model(**batch).loss

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            total_train_loss += loss.item()

        avg_train_loss = _mean_loss(total_train_loss, len(train_loader))

        # Validation
        model.eval()
        total_val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                batch = _with_caption_labels(
                    {k: v.to(DEVICE) for k, v in batch.items()}
                )
                with autocast(enabled=use_amp):
                    loss = model(**batch).loss
                total_val_loss += loss.item()

        avg_val_loss = _mean_loss(total_val_loss, len(val_loader))

        print(
            f"Epoch [{epoch+1}/{epochs}] | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}"
        )
        model.train()

    return model


# 5. Inference (caption generation example)
def generate_caption(
    model, processor, image_path, max_length=MAX_TEXT_LENGTH, num_beams=5
):
    """Generate a caption for a single image with beam search.

    Args:
        model: BLIP captioning model. It is switched to eval mode and left there.
        processor: BLIP processor matching ``model``.
        image_path: Path to the image file.
        max_length: Maximum generated sequence length in tokens.
        num_beams: Beam width for ``model.generate`` (default 5).

    Returns:
        The decoded caption string, with special tokens removed.
    """
    model.eval()

    # Close the file handle; convert() returns an independent in-memory copy.
    with Image.open(image_path) as raw_img:
        img = raw_img.convert("RGB")

    # Send inputs to wherever the model's weights live, so this also works for
    # a model loaded on a different device than the training one.
    device = next(model.parameters()).device
    inputs = processor(images=img, return_tensors="pt").to(device)

    with torch.no_grad():
        output_ids = model.generate(
            pixel_values=inputs.pixel_values,
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=True,
        )

    return processor.decode(output_ids[0], skip_special_tokens=True)


# Main entry point
def main():
    """Fine-tune BLIP on ``dataset.csv`` / ``images/``, save it, and caption one image."""
    # Data locations
    csv_file = "dataset.csv"
    img_dir = "images/"

    # The custom vocabulary is returned by prepare_data but not used below.
    train_loader, val_loader, _vocab, processor = prepare_data(csv_file, img_dir)

    # Load the pretrained captioning model onto the target device.
    model = BlipForConditionalGeneration.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    ).to(DEVICE)

    # Freeze the vision encoder (ViT) to save memory; everything outside
    # model.vision_model stays trainable.
    for param in model.vision_model.parameters():
        param.requires_grad = False

    # Give the optimizer only the parameters that will actually receive grads.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=LEARNING_RATE)
    # Loss scaling only applies to CUDA mixed precision. Disabling it
    # explicitly elsewhere avoids GradScaler's "CUDA not available" warning.
    scaler = GradScaler(enabled=DEVICE.type == "cuda")

    # Trade extra compute for lower activation memory.
    model.gradient_checkpointing_enable()

    print("Starting fine-tuning...")
    model = train_model(model, train_loader, val_loader, optimizer, scaler, EPOCHS)

    # Save model weights and processor side by side.
    model.save_pretrained("blip_finetuned")
    processor.save_pretrained("blip_finetuned")

    # Example inference on the first training image.
    test_image = os.path.join(img_dir, train_loader.dataset.data.iloc[0]["img_path"])
    caption = generate_caption(model, processor, test_image)
    print(f"Generated caption for {test_image}: {caption}")


if __name__ == "__main__":
    main()