In [1]:
# ============================================================================
# 0. Скачивание данных и чекпоинта
# ============================================================================

# !curl https://storage.yandexcloud.net/hse-ai24-mm-hw3-data/course_vqa_dataset.zip -o data/course_vqa_dataset.zip
!curl https://storage.yandexcloud.net/hse-ai24-mm-hw3-data/vqa-clip-bert-pred-epoch=66-val_acc=0.649.ckpt -o data/VQAClipBERTpred.ckpt
# !unzip -nq data/course_vqa_dataset.zip -d data/course_vqa_dataset

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  694M  100  694M    0     0   101M      0  0:00:06  0:00:06 --:--:--  101M


In [17]:
# Воспроизводимый пайплайн: формирование финального submission.csv

# ============================================================================
# 1. ИМПОРТЫ
# ============================================================================
import os
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

import re
import clip
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import pytorch_lightning as pl
import pandas as pd
from PIL import Image
from pathlib import Path
from collections import Counter
from torchmetrics import Accuracy
from torch.nn import MultiheadAttention
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import Dataset, DataLoader

# ============================================================================
# 2. КОНСТАНТЫ И ПУТИ
# ============================================================================
DATA_ROOT = Path("./data")
DATA_DIR = DATA_ROOT / "course_vqa_dataset"
IMAGES_DIR = DATA_DIR / "images"
TRAIN_PATH = DATA_DIR / "train.csv"
TEST_PATH = DATA_DIR / "test.csv"
SAMPLE_SUB_PATH = DATA_DIR / "sample_submission.csv"
# The checkpoint sits next to the dataset folder (the download cell saves it to
# data/VQAClipBERTpred.ckpt). It is built from DATA_ROOT rather than DATA_DIR / "..":
# the OS resolves ".." after following symlinks, so "course_vqa_dataset/.." would point
# to a different directory if the dataset folder were a symlink.
MODEL_PATH = DATA_ROOT / "VQAClipBERTpred.ckpt"
SUBMISSION_PATH = DATA_DIR / "submission.csv"

# ============================================================================
# 3. ВСПОМОГАТЕЛЬНЫЕ ФУНКЦИИ
# ============================================================================
# Special tokens of the word vocabulary; build_word_vocab assigns them ids 0 and 1.
PAD_TOKEN, UNK_TOKEN = "<pad>", "<unk>"

def tokenize(text: str):
    """Lower-case ``text`` and split it into ASCII alphanumeric tokens.

    Every run of characters outside ``[a-z0-9]`` (after lower-casing) acts as a
    separator, so punctuation and non-ASCII letters are dropped.
    """
    normalized = re.sub(r"[^a-z0-9]+", " ", text.lower())
    # str.split() without arguments already ignores leading/trailing whitespace.
    return normalized.split()

def build_word_vocab(texts, min_freq: int = 1, max_size: int | None = None):
    """Build a word-level vocabulary from an iterable of raw texts.

    Words are ranked by frequency, most frequent first. The sort is stable, so ties
    keep first-seen order. Words seen fewer than ``min_freq`` times are dropped. When
    ``max_size`` is given, only that many words are kept. Ids 0 and 1 are reserved for
    ``PAD_TOKEN`` and ``UNK_TOKEN``.

    Returns:
        (word2id, id2word): forward and inverse mappings.
    """
    freq = Counter(tok for text in texts for tok in tokenize(text))

    kept = sorted(
        ((word, n) for word, n in freq.items() if n >= min_freq),
        key=lambda pair: pair[1],
        reverse=True,
    )
    if max_size is not None:
        kept = kept[:max_size]

    word2id = {PAD_TOKEN: 0, UNK_TOKEN: 1}
    for word, _ in kept:
        # len() is evaluated before insertion, so new words get consecutive ids.
        word2id.setdefault(word, len(word2id))

    id2word = {idx: word for word, idx in word2id.items()}
    return word2id, id2word

def encode_question(text: str, word2id: dict, max_len: int = 20):
    """Encode ``text`` as a LongTensor of word ids with at least ``max_len`` entries.

    Tokens beyond ``max_len`` are truncated. Out-of-vocabulary tokens map to the
    ``UNK_TOKEN`` id, and short questions are right-padded with the ``PAD_TOKEN`` id.
    """
    ids = [word2id.get(tok, word2id[UNK_TOKEN]) for tok in tokenize(text)[:max_len]]
    shortfall = max_len - len(ids)
    if shortfall > 0:
        ids += [word2id[PAD_TOKEN]] * shortfall
    return torch.tensor(ids, dtype=torch.long)

def build_answer_vocab(train_df, min_freq: int = 1, max_word_count: int = 0):
    """Build the answer classification vocabulary from ``train_df["answer"]``.

    Two optional filters collapse answers into the catch-all ``"no_answer"`` label:
      * ``min_freq > 1``: answers seen fewer than ``min_freq`` times. Frequencies are
        counted on the raw answers.
      * ``max_word_count > 0``: answers with more than ``max_word_count``
        whitespace-separated words.

    Returns:
        answer2id: answer -> contiguous id, with answers sorted alphabetically.
        id2answer: the inverse mapping.
        train_df: a filtered copy when any filter is active. Otherwise it is the same
            DataFrame object that was passed in.

    NOTE(review): ``"no_answer"`` is only in the vocabulary when some answer was
    actually collapsed. VQADataset maps unknown answers to it and would raise a
    KeyError otherwise.
    """
    drop_rare = min_freq > 1
    drop_long = max_word_count > 0

    if drop_rare or drop_long:
        raw_counts = Counter(train_df["answer"])
        rare = {ans for ans, c in raw_counts.items() if c < min_freq} if drop_rare else set()

        def _collapse(ans):
            if ans in rare:
                return "no_answer"
            if drop_long and len(ans.split()) > max_word_count:
                return "no_answer"
            return ans

        train_df = train_df.copy()
        train_df["answer"] = train_df["answer"].apply(_collapse)

    vocab = sorted(Counter(train_df["answer"]))
    answer2id = {ans: idx for idx, ans in enumerate(vocab)}
    id2answer = dict(enumerate(vocab))
    return answer2id, id2answer, train_df

def classify_question_type(question: str) -> str:
    """Map a question to a coarse type label using its opening word and keywords.

    Keyword checks are plain substring tests on the lower-cased question, so for
    example "whose" counts as "who". The resulting labels are turned into embedding
    indices for the model (presumably the same labels it was trained with), so the
    heuristics must stay stable.
    """
    q = question.lower().strip()

    def mentions(*words):
        return any(word in q for word in words)

    if q.startswith('what'):
        if mentions('color', 'colour'):
            return 'what_color'
        if mentions('text', 'say', 'read'):
            return 'what_text'
        if mentions('number', 'count', 'many'):
            return 'what_number'
        return 'what_other'

    # These labels are identical to the opening word.
    for opener in ('where', 'when', 'who', 'why'):
        if q.startswith(opener):
            return opener

    if q.startswith('how'):
        return 'how_many' if mentions('many', 'much') else 'how_other'

    yes_no_openers = ('is ', 'are ', 'was ', 'were ', 'do ', 'does ', 'did ', 'can ', 'could ', 'will ', 'would ')
    if q.startswith(yes_no_openers):
        return 'yes_no'
    return 'other'

def categorize_answer(answer: str) -> str:
    """Assign a coarse category label to an answer string.

    Checks are applied in this order:
      * the literal "no_answer" placeholder;
      * a fixed list of colour words;
      * numbers (digits, optionally with dots);
      * yes/no;
      * single word vs. phrase.
    An answer with no words at all falls back to 'other'.
    """
    if answer == 'no_answer':
        return 'no_answer'

    normalized = answer.lower().strip()
    color_words = {'red', 'blue', 'green', 'yellow', 'orange', 'purple', 'pink',
                   'black', 'white', 'grey', 'gray', 'brown', 'tan', 'beige'}
    if normalized in color_words:
        return 'color'
    # Stripping dots first also covers plain integers (they contain no dots).
    if normalized.replace('.', '').isdigit():
        return 'number'
    if normalized in ('yes', 'no'):
        return 'yes_no'

    word_count = len(answer.split())
    if word_count == 0:
        return 'other'
    return 'single_word' if word_count == 1 else 'phrase'


# ============================================================================
# 4. КЛАССЫ ДЛЯ ДАТАСЕТА И МОДЕЛИ
# ============================================================================
class VQADataset(Dataset):
    """Map-style dataset yielding ``(image, question, target, answer_category, question_type)``.

    Each row of ``df`` must provide ``image`` (a file name relative to ``images_dir``)
    and ``question``. In training mode (``is_train=True``) it also needs ``answer``.
    When ``use_bert=True`` it additionally needs the precomputed ``answer_category``
    and ``question_type`` columns. In test mode the target slot carries the row ``ID``
    and the last two slots are empty strings.

    Args:
        df: source DataFrame. It is re-indexed into a new frame; the caller's frame is
            not modified.
        images_dir: directory with the image files (``str`` or ``Path``).
        word2id: word vocabulary for ``encode_question``; unused when ``use_bert=True``.
        answer2id: answer -> class id; required in training mode. Answers missing from
            it are mapped to ``"no_answer"``, which must therefore be in the mapping.
        max_question_len: padded length of the encoded question (non-BERT mode only).
        is_train: return labels (True) or sample IDs (False).
        image_size: output side length after resize/crop.
        use_bert: return the raw question string instead of word ids.
        augment: use random crop/flip/colour jitter instead of deterministic
            resize + centre crop.
        image_normalization: ``"clip"`` for the CLIP mean/std, anything else for the
            ImageNet mean/std.
    """

    def __init__(
        self,
        df,
        images_dir,
        word2id,
        answer2id=None,
        max_question_len: int = 20,
        is_train: bool = True,
        image_size: int = 224,
        use_bert: bool = False,
        augment: bool = False,
        image_normalization: str = "imagenet",
    ):
        self.df = df.reset_index(drop=True)
        # Accept plain strings as well; the "/" join in __getitem__ needs a Path.
        self.images_dir = Path(images_dir)
        self.word2id = word2id
        self.answer2id = answer2id
        self.max_question_len = max_question_len
        self.is_train = is_train
        self.use_bert = use_bert
        self.augment = augment

        if image_normalization == "clip":
            mean = (0.48145466, 0.4578275, 0.40821073)
            std = (0.26862954, 0.26130258, 0.27577711)
        else:
            mean = (0.485, 0.456, 0.406)
            std = (0.229, 0.224, 0.225)

        if self.augment:
            img_tfms = [
                T.RandomResizedCrop(image_size, scale=(0.8, 1.0)),
                T.RandomHorizontalFlip(),
                T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
            ]
        else:
            img_tfms = [
                T.Resize(image_size),
                T.CenterCrop(image_size),
            ]

        self.transform = T.Compose([
            *img_tfms,
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
        ])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        row = self.df.iloc[idx]
        img_path = self.images_dir / row["image"]
        # The context manager closes the file handle immediately. convert() returns a
        # new, fully loaded image that stays valid after the source file is closed.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")
        img_tensor = self.transform(img)
        question = row["question"]

        if self.use_bert:
            question_data = question
        else:
            question_data = encode_question(question, self.word2id, max_len=self.max_question_len)

        if not self.is_train:
            qid = row["ID"]
            return img_tensor, question_data, qid, "", ""

        if self.use_bert:
            assert "answer_category" in row and "question_type" in row, (
                "use_bert=True expects 'answer_category' and 'question_type' columns"
            )
            answer_category = row["answer_category"]
            question_type = row["question_type"]
        else:
            answer_category = ""
            question_type = ""
        answer_str = row["answer"]
        if answer_str not in self.answer2id:
            answer_str = "no_answer"
        answer_id = torch.tensor(self.answer2id[answer_str], dtype=torch.long)
        return img_tensor, question_data, answer_id, answer_category, question_type


class CrossAttentionFusion(nn.Module):
    """Fuse a question vector with an image vector via multi-head attention.

    The projected question is the single query token; the projected image is the single
    key/value token. The result is
    ``out_proj(dropout(LayerNorm(attention + projected question)))``.

    Note: with exactly one key, the attention softmax is identically 1. The attention
    output therefore depends only on the image (through ``v_proj`` and the attention's
    internal projections), not on the query. The question enters the result only
    through the residual branch.

    Args:
        question_dim: size of the question feature vector.
        image_dim: size of the image feature vector.
        hidden_dim: size of the fused output; must be divisible by ``num_heads``
            (an ``nn.MultiheadAttention`` requirement).
        num_heads: number of attention heads.
        dropout: dropout probability applied after the layer norm.
    """

    def __init__(self, question_dim: int, image_dim: int, hidden_dim: int, num_heads: int = 8,
                 dropout: float = 0.1):
        super().__init__()
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.q_proj = nn.Linear(question_dim, hidden_dim)
        self.k_proj = nn.Linear(image_dim, hidden_dim)
        self.v_proj = nn.Linear(image_dim, hidden_dim)
        self.attention = MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            batch_first=True
        )
        self.out_proj = nn.Linear(hidden_dim, hidden_dim)
        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, question_features, image_features):
        """Fuse a batch of question and image vectors.

        Args:
            question_features: ``(B, question_dim)`` tensor (must be 2-D: a
                length-1 sequence axis is added here).
            image_features: ``(B, image_dim)`` tensor.

        Returns:
            ``(B, hidden_dim)`` fused features.
        """
        # Project the question once; it serves as both the query and the residual.
        q = self.q_proj(question_features)
        k = self.k_proj(image_features).unsqueeze(1)  # (B, 1, hidden_dim)
        v = self.v_proj(image_features).unsqueeze(1)
        attn_out, _ = self.attention(q.unsqueeze(1), k, v)
        out = self.layer_norm(attn_out.squeeze(1) + q)
        out = self.dropout(out)
        return self.out_proj(out)

# Base model: the final pipeline does not instantiate it directly, but VQAClipBERTpred
# (below) builds on its encoders, fusion module and classifier.
class VQAClipBERT(pl.LightningModule):
    """Baseline VQA model: CLIP image tower + BERT-family text encoder + answer classifier.

    Image features come from ``clip_model.visual`` and question features from the text
    encoder. The two are fused either by ``CrossAttentionFusion`` or by concatenation
    followed by an MLP, and ``classifier`` produces logits over ``num_answers`` classes.
    Learnable embeddings for answer categories and question types are created here for
    use by subclasses.

    The forward/training/validation/optimizer methods were stripped from this file and
    raise ``NotImplementedError``; ``VQAClipBERTpred`` provides working versions.

    Raises:
        ValueError: if the image feature size for ``clip_model_name`` cannot be determined.
    """

    def __init__(
        self,
        clip_model_name: str = "ViT-L/14",
        bert_model_name: str = "bert-base-uncased",
        num_answers: int = 5524,
        num_answer_categories: int = 10,
        num_question_types: int = 20,
        hidden_dim: int = 1024,
        learning_rate: float = 5e-4,
        freeze_image_encoder: bool = True,
        freeze_bert: bool = True,
        use_cross_attention: bool = False,
    ):
        super().__init__()
        self.save_hyperparameters()

        # NOTE(review): the whole CLIP model, including its unused text tower, is kept as a
        # submodule, so its weights are part of the state dict. Presumably the checkpoint
        # contains them too; do not drop this attribute without migrating checkpoints.
        self.clip_model, self.clip_preprocess = clip.load(clip_model_name)
        self.clip_visual = self.clip_model.visual

        # Output size of the CLIP visual tower for the published model names.
        clip_dimensions = {
            "ViT-B/32": 512,
            "ViT-B/16": 512,
            "ViT-L/14": 768,
            "ViT-L/14@336px": 768,
            "RN50": 1024,
            "RN101": 512,
            "RN50x4": 640,
            "RN50x16": 768,
            "RN50x64": 1024,
        }
        image_feature_dim = clip_dimensions.get(clip_model_name)
        if image_feature_dim is None:
            # clip.load() also accepts a path to a local checkpoint. In that case, ask the
            # visual tower itself; the openai/clip encoders expose `output_dim`.
            image_feature_dim = getattr(self.clip_visual, "output_dim", None)
        if image_feature_dim is None:
            raise ValueError(f"Неизвестная модель изображений: {clip_model_name}")

        if freeze_image_encoder:
            for param in self.clip_visual.parameters():
                param.requires_grad = False

        self.bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
        self.bert_model = AutoModel.from_pretrained(bert_model_name)

        if freeze_bert:
            for param in self.bert_model.parameters():
                param.requires_grad = False

        question_feature_dim = self.bert_model.config.hidden_size

        # Embeddings for answer categories and question types (used by subclasses).
        self.answer_categories = nn.Embedding(num_answer_categories, question_feature_dim)
        self.question_types = nn.Embedding(num_question_types, question_feature_dim)

        # Label string -> embedding index maps; set by the caller after construction.
        self.answer_category2id = None
        self.question_type2id = None

        self.use_cross_attention = use_cross_attention
        if use_cross_attention:
            self.fusion = CrossAttentionFusion(
                question_dim=question_feature_dim,
                image_dim=image_feature_dim,
                hidden_dim=hidden_dim,
                num_heads=8
            )
            fusion_output_dim = hidden_dim
        else:
            fusion_input_dim = image_feature_dim + question_feature_dim
            self.fusion = nn.Sequential(
                nn.Linear(fusion_input_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.ReLU(),
                nn.Dropout(0.3),
            )
            fusion_output_dim = hidden_dim // 2

        self.classifier = nn.Sequential(
            nn.Linear(fusion_output_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim // 2, num_answers),
        )

        self.train_acc = Accuracy(task="multiclass", num_classes=num_answers)
        self.val_acc = Accuracy(task="multiclass", num_classes=num_answers)
        self.learning_rate = learning_rate

    def forward(self, images, questions, answer_categories, question_types):
        # Removed to keep the file short; this code path is not used.
        raise NotImplementedError("forward is not implemented")

    def training_step(self, batch, batch_idx):
        raise NotImplementedError("training_step is not implemented")

    def validation_step(self, batch, batch_idx):
        raise NotImplementedError("validation_step is not implemented")

    def configure_optimizers(self):
        raise NotImplementedError("configure_optimizers is not implemented")


class VQAClipBERTpred(VQAClipBERT):
    """VQAClipBERT with auxiliary heads that predict the answer category and question type.

    When labels are unknown (as at test time), the predicted labels are embedded and
    added to the question features for a second fusion pass. Training optimises the
    answer cross-entropy plus weighted category/type cross-entropies.

    Extra args beyond VQAClipBERT's:
        category_loss_weight: weight of the answer-category loss in ``training_step``.
        type_loss_weight: weight of the question-type loss in ``training_step``.

    ``answer_category2id`` / ``question_type2id`` (inherited, set by the caller after
    construction) map label strings to embedding indices.
    """

    def __init__(
        self,
        clip_model_name: str = "ViT-L/14",
        bert_model_name: str = "bert-base-uncased",
        num_answers: int = 5524,
        num_answer_categories: int = 10,
        num_question_types: int = 20,
        hidden_dim: int = 1024,
        learning_rate: float = 2e-5,
        freeze_image_encoder: bool = True,
        freeze_bert: bool = True,
        use_cross_attention: bool = False,
        category_loss_weight: float = 0.4,
        type_loss_weight: float = 0.15
    ):
        super().__init__(
            clip_model_name=clip_model_name,
            bert_model_name=bert_model_name,
            num_answers=num_answers,
            num_answer_categories=num_answer_categories,
            num_question_types=num_question_types,
            hidden_dim=hidden_dim,
            learning_rate=learning_rate,
            freeze_image_encoder=freeze_image_encoder,
            freeze_bert=freeze_bert,
            use_cross_attention=use_cross_attention,
        )
        # Weights of the auxiliary terms in the multi-task loss.
        self.category_loss_weight = category_loss_weight
        self.type_loss_weight = type_loss_weight

        # Must match the output size of the fusion module built by the base class.
        if use_cross_attention:
            fusion_output_dim = hidden_dim
        else:
            fusion_output_dim = hidden_dim // 2

        # Trainable heads predicting the answer category and the question type.
        self.category_predictor = nn.Sequential(
            nn.Linear(fusion_output_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim // 2, num_answer_categories),
        )

        self.type_predictor = nn.Sequential(
            nn.Linear(fusion_output_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim // 2, num_question_types),
        )

    @staticmethod
    def _stable_bucket(label, num_buckets: int) -> int:
        """Deterministically map ``label`` to an index in ``[0, num_buckets)``.

        The built-in ``hash()`` of ``str`` is salted per interpreter process
        (PYTHONHASHSEED), so it would pick different embedding rows on every run.
        CRC32 is stable.
        """
        import zlib
        return zlib.crc32(str(label).encode("utf-8")) % num_buckets

    def _labels_to_indices(self, labels, mapping, num_classes, device):
        """Convert label strings to a LongTensor of embedding indices for ``forward``.

        Labels missing from ``mapping`` get index 0. Without a mapping, a stable hash
        bucket in ``[0, num_classes)`` is used.
        """
        if mapping is not None:
            indices = [mapping.get(label, 0) for label in labels]
        else:
            indices = [self._stable_bucket(label, num_classes) for label in labels]
        return torch.tensor(indices, dtype=torch.long, device=device)

    @staticmethod
    def _target_indices(labels, mapping, device):
        """Auxiliary-loss targets: mapped label ids (0 if unknown), all zeros without a mapping."""
        if mapping is None:
            return torch.zeros(len(labels), dtype=torch.long, device=device)
        return torch.tensor([mapping.get(label, 0) for label in labels], dtype=torch.long, device=device)

    def _fuse(self, txt_features, img_features):
        """Apply the fusion module (cross-attention, or concat [image, text] + MLP)."""
        if self.use_cross_attention:
            return self.fusion(txt_features, img_features)
        return self.fusion(torch.cat([img_features, txt_features], dim=1))

    def forward(self, images, questions, answer_categories, question_types, predict_auxiliary=False):
        """Score every answer class for a batch.

        Args:
            images: preprocessed image batch for the CLIP visual encoder.
            questions: question strings (tuple, list, or a single str).
            answer_categories / question_types: ground-truth label strings for the batch,
                or an empty sequence / empty strings when they are unknown (test time).
            predict_auxiliary: run the auxiliary heads even when labels are given.

        Returns:
            ``(answer_logits, category_logits, type_logits)`` whenever the auxiliary heads
            run (``predict_auxiliary`` is set or labels are missing); otherwise only
            ``answer_logits``.

        Raises:
            ValueError: if ``questions`` has an unsupported type.
        """
        images = images.to(dtype=self.clip_model.dtype)
        img_features = self.clip_visual(images).float()

        if isinstance(questions, tuple):
            questions_list = list(questions)
        elif isinstance(questions, list):
            questions_list = questions
        elif isinstance(questions, str):
            questions_list = [questions]
        else:
            raise ValueError(f"Неожиданный формат вопросов: {type(questions)}")

        encoded = self.bert_tokenizer(
            questions_list,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=77,
        ).to(images.device)

        # Gradients flow into the text encoder only when it is not frozen.
        with torch.set_grad_enabled(not self.hparams.freeze_bert):
            bert_output = self.bert_model(**encoded)
            # First-token embedding (the [CLS] position for BERT-style tokenizers).
            q_features = bert_output.last_hidden_state[:, 0, :]

        # The auxiliary heads run when explicitly requested or when no labels are available.
        # NOTE(review): training_step and validation_step pass predict_auxiliary=True, so
        # ground-truth label embeddings are never fed to the fusion there; labels only
        # serve as loss targets.
        need_predict = predict_auxiliary or len(answer_categories) == 0 or answer_categories[0] == ""

        # Pass 1: question features (+ ground-truth label embeddings when available).
        if need_predict:
            txt_features = q_features
        else:
            ac_list = list(answer_categories) if isinstance(answer_categories, (tuple, list)) else [answer_categories]
            ac_features = self.answer_categories(self._labels_to_indices(
                ac_list, self.answer_category2id, self.hparams.num_answer_categories, images.device))

            qt_list = list(question_types) if isinstance(question_types, (tuple, list)) else [question_types]
            qt_features = self.question_types(self._labels_to_indices(
                qt_list, self.question_type2id, self.hparams.num_question_types, images.device))

            txt_features = q_features + ac_features + qt_features

        fused_features = self._fuse(txt_features, img_features)

        if not need_predict:
            return self.classifier(fused_features)

        # Pass 2: embed the predicted labels (argmax, not differentiable) and fuse again.
        category_logits = self.category_predictor(fused_features)
        type_logits = self.type_predictor(fused_features)
        txt_features_final = (
            q_features
            + self.answer_categories(category_logits.argmax(dim=1))
            + self.question_types(type_logits.argmax(dim=1))
        )
        logits = self.classifier(self._fuse(txt_features_final, img_features))
        return logits, category_logits, type_logits

    def training_step(self, batch, batch_idx):
        """Multi-task step: answer CE + weighted answer-category CE + weighted question-type CE."""
        images, questions, answers, answer_categories, question_types = batch

        ac_indices = self._target_indices(answer_categories, self.answer_category2id, images.device)
        qt_indices = self._target_indices(question_types, self.question_type2id, images.device)

        logits, category_logits, type_logits = self(
            images, questions, answer_categories, question_types,
            predict_auxiliary=True
        )

        loss_answer = F.cross_entropy(logits, answers)
        loss_category = F.cross_entropy(category_logits, ac_indices)
        loss_type = F.cross_entropy(type_logits, qt_indices)

        total_loss = loss_answer + self.category_loss_weight * loss_category + self.type_loss_weight * loss_type

        batch_size = answers.shape[0]
        self.train_acc(logits, answers)
        self.log("train_loss", total_loss, on_step=True, on_epoch=True, prog_bar=True, batch_size=batch_size)
        self.log("train_loss_answer", loss_answer, on_step=True, on_epoch=True, batch_size=batch_size)
        self.log("train_loss_category", loss_category, on_step=True, on_epoch=True, batch_size=batch_size)
        self.log("train_loss_type", loss_type, on_step=True, on_epoch=True, batch_size=batch_size)
        self.log("train_acc", self.train_acc, on_step=True, on_epoch=True, prog_bar=True, batch_size=batch_size)

        return total_loss

    def validation_step(self, batch, batch_idx):
        """Log answer loss/accuracy and, when label maps are set, auxiliary-head accuracies."""
        images, questions, answers, answer_categories, question_types = batch
        batch_size = answers.shape[0]

        logits, category_logits, type_logits = self(
            images, questions, answer_categories, question_types,
            predict_auxiliary=True
        )

        loss = F.cross_entropy(logits, answers)

        if self.answer_category2id is not None:
            ac_indices = self._target_indices(answer_categories, self.answer_category2id, images.device)
            category_acc = (category_logits.argmax(dim=1) == ac_indices).float().mean()
            self.log("val_category_acc", category_acc, on_step=False, on_epoch=True, batch_size=batch_size)

        if self.question_type2id is not None:
            qt_indices = self._target_indices(question_types, self.question_type2id, images.device)
            type_acc = (type_logits.argmax(dim=1) == qt_indices).float().mean()
            self.log("val_type_acc", type_acc, on_step=False, on_epoch=True, batch_size=batch_size)

        self.val_acc(logits, answers)
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, batch_size=batch_size)
        self.log("val_acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True, batch_size=batch_size)

        return loss

    def configure_optimizers(self):
        """AdamW with cosine annealing warm restarts (first cycle 8 epochs, then doubling)."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=1e-3)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            T_0=8,
            T_mult=2,
            eta_min=1e-6,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",
                "frequency": 1,
            },
        }


def get_submission_final(model, test_df, images_dir, submission_path, id2answer, batch_size=128,
                         num_workers=4, sample_submission_path=None):
    """Run batched inference on ``test_df`` and write an ``ID,answer`` submission CSV.

    Args:
        model: a model whose forward accepts ``(images, questions, [], [])`` and returns
            logits or a tuple whose first element is the answer logits.
        test_df: DataFrame with ``ID``, ``image`` and ``question`` columns.
        images_dir: directory containing the test images.
        submission_path: where to write the CSV.
        id2answer: class id -> answer string.
        batch_size, num_workers: DataLoader settings.
        sample_submission_path: reference submission used for the format check.
            Defaults to the module-level ``SAMPLE_SUB_PATH``.

    Returns:
        The submission DataFrame (also written to ``submission_path``).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Cast to fp32 (presumably because the CLIP weights may be loaded in half precision).
    model = model.to(device).float()
    model.eval()

    # Only the test split is needed; word2id / max_question_len are unused in BERT mode.
    test_dataset = VQADataset(
        df=test_df,
        images_dir=images_dir,
        word2id=None,
        answer2id=None,
        max_question_len=20,
        is_train=False,
        use_bert=True,
        augment=False,
        image_normalization="clip",
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,  # keep predictions aligned with test_df row order
        num_workers=num_workers,
        # Pinned host memory only speeds up host->GPU copies; without CUDA it is useless.
        pin_memory=device.type == "cuda",
    )

    all_predictions = []
    all_ids = []

    with torch.no_grad():
        for images, questions, ids, _, _ in test_loader:
            images = images.to(device)

            # In BERT mode the collated questions arrive as a tuple of strings.
            questions_list = list(questions) if isinstance(questions, tuple) else questions
            # Empty label lists make the model predict category/type itself (test-time path).
            result = model(images, questions_list, [], [])
            logits = result[0] if isinstance(result, tuple) else result

            all_predictions.extend(logits.argmax(dim=1).cpu().tolist())
            # Numeric IDs are collated into a tensor; turn them back into Python values.
            all_ids.extend(ids.tolist() if torch.is_tensor(ids) else ids)

    predicted_answers = [id2answer[pred_id] for pred_id in all_predictions]

    submission_df = pd.DataFrame({
        "ID": all_ids,
        "answer": predicted_answers,
    })

    print(f"Размер submission: {len(submission_df)}")
    print(f"Колонки: {submission_df.columns.tolist()}")
    print("\nПервые несколько строк:")
    print(submission_df.head(10))

    submission_df.to_csv(submission_path, index=False)
    print(f"\nSubmission сохранён в {submission_path}")

    if sample_submission_path is None:
        sample_submission_path = SAMPLE_SUB_PATH
    sample_sub = pd.read_csv(sample_submission_path)
    print("Проверка формата:")
    print(f"  Размер sample_submission: {len(sample_sub)}")
    print(f"  Размер нашего submission: {len(submission_df)}")
    print(f"  Совпадают ID: {set(submission_df['ID']) == set(sample_sub['ID'])}")

    return submission_df


# ============================================================================
# 5. MAIN PIPELINE
# ============================================================================
print("=" * 80)
print("ВОСПРОИЗВОДИМЫЙ ПАЙПЛАЙН: Формирование финального submission.csv")
print("=" * 80)

# 5.1. Load data
print("\n1. Загрузка данных...")
train = pd.read_csv(TRAIN_PATH)
test = pd.read_csv(TEST_PATH)
print(f"   Train: {len(train)} примеров")
print(f"   Test: {len(test)} примеров")

# 5.2. Prepare data.
# The answer vocabulary and the category/type maps built here determine num_answers,
# num_answer_categories and num_question_types passed to the checkpoint loader. These
# settings must therefore reproduce the ones used at training time.
print("\n2. Подготовка данных...")
MIN_ANSWER_FREQ = 2
MAX_WORD_COUNT = 10
answer2id, id2answer, train_processed = build_answer_vocab(
    train, min_freq=MIN_ANSWER_FREQ, max_word_count=MAX_WORD_COUNT
)
print(f"   Размер словаря ответов: {len(answer2id)}")

# Derive answer categories (after rare/long answers were collapsed) and question types.
# assign() returns a new frame, so `train` is never modified even if build_answer_vocab
# returned it unchanged.
train_processed = train_processed.assign(
    answer_category=train_processed['answer'].apply(categorize_answer),
    question_type=train_processed['question'].apply(classify_question_type),
)

# Label -> embedding index maps (sorted for a deterministic order).
answer_categories_unique = sorted(train_processed["answer_category"].unique())
question_types_unique = sorted(train_processed["question_type"].unique())
answer_category2id = {cat: idx for idx, cat in enumerate(answer_categories_unique)}
question_type2id = {qt: idx for idx, qt in enumerate(question_types_unique)}
print(f"   Категорий ответов: {len(answer_category2id)}")
print(f"   Типов вопросов: {len(question_type2id)}")

# 5.3. Load the model
print("\n3. Загрузка модели...")
print(f"   Путь к модели: {MODEL_PATH}")
if not MODEL_PATH.exists():
    raise FileNotFoundError(f"Модель не найдена: {MODEL_PATH}")

# The keyword arguments override the hyperparameters stored in the checkpoint.
model = VQAClipBERTpred.load_from_checkpoint(
    MODEL_PATH,
    # Deserialize on CPU so a GPU-saved checkpoint also loads on CPU-only machines;
    # get_submission_final moves the model to the available device afterwards.
    map_location="cpu",
    clip_model_name="ViT-B/16",
    bert_model_name="distilbert-base-uncased",
    num_answers=len(answer2id),
    num_answer_categories=len(answer_category2id),
    num_question_types=len(question_type2id),
    hidden_dim=1024,
    learning_rate=5e-4,
    freeze_image_encoder=True,
    freeze_bert=True,
    use_cross_attention=True,
    category_loss_weight=0.45,
    type_loss_weight=0.2
)

# Label maps are not part of the checkpoint; attach them explicitly.
model.answer_category2id = answer_category2id
model.question_type2id = question_type2id

# 5.4. Inference and saving
print("\n4. Инференс и формирование submission.csv...")
get_submission_final(model, test, IMAGES_DIR, SUBMISSION_PATH, id2answer, batch_size=128, num_workers=4)

print("\n" + "=" * 80)
print("ПАЙПЛАЙН ЗАВЕРШЁН!")
print("=" * 80)

ВОСПРОИЗВОДИМЫЙ ПАЙПЛАЙН: Формирование финального submission.csv

1. Загрузка данных...
   Train: 19873 примеров
   Test: 4969 примеров

2. Подготовка данных...
   Размер словаря ответов: 1274
   Категорий ответов: 6
   Типов вопросов: 12

3. Загрузка модели...
   Путь к модели: data/course_vqa_dataset/../VQAClipBERTpred.ckpt


'(ProtocolError('Connection aborted.', RemoteDisconnected('Remote end closed connection without response')), '(Request ID: a9ba974e-af8f-4f19-a0eb-284f4fe191b2)')' thrown while requesting HEAD https://huggingface.co/distilbert-base-uncased/resolve/main/tokenizer_config.json
Retrying in 1s [Retry 1/5].



4. Инференс и формирование submission.csv...
Размер submission: 4969
Колонки: ['ID', 'answer']

Первые несколько строк:
                 ID                 answer
0  001506824e76191d              no_answer
1  002677a6d6bac1dd              no_answer
2  0027a003a50877a5  microsoft corporation
3  002bf579cd0bfaed              no_answer
4  0035fa712c369df8              no_answer
5  0042dd305b63ce2d                  money
6  0043d79066103187              no_answer
7  004741a80abd988c               keyboard
8  004c59b16f26a4f7              no_answer
9  005aa83ea31a8235              no_answer

Submission сохранён в data/course_vqa_dataset/submission.csv
Проверка формата:
  Размер sample_submission: 4969
  Размер нашего submission: 4969
  Совпадают ID: True

ПАЙПЛАЙН ЗАВЕРШЁН!
