In [None]:

import sys
import os
# Check if this is first run

first_run = not os.path.exists('/tmp/deps_installed')

if first_run:
    print("Installing dependencies...")

    # Fix for numpy conflicts
    !pip install -q --force-reinstall "numpy==1.26.4"

    # Install required packages
    !pip install -q transformers==4.36.0 timm==0.9.12 accelerate
    !pip install -q torch torchvision tqdm pillow opencv-python-headless

    # Mark installation complete
    with open('/tmp/deps_installed', 'w') as f:
        f.write('done')

    print("üîÑ Restarting kernel...")
    os.kill(os.getpid(), 9)
else:
    print("‚úÖ Dependencies already installed!")

    
    # Verify versions
    import numpy as np
    import torch
    import transformers
    print(f"   NumPy: {np.__version__}")
    print(f"   PyTorch: {torch.__version__}")
    print(f"   Transformers: {transformers.__version__}")
    
    import shutil
    total, used, free = shutil.disk_usage("/")
    print(f"\nüíæ Free Space: {free // (2**30)} GB")
    
    # Clean up flag for next run
    if os.path.exists('/tmp/deps_installed'):
        os.remove('/tmp/deps_installed')



‚úÖ Dependencies already installed!
   NumPy: 2.2.6
   PyTorch: 2.6.0+cu124
   Transformers: 4.36.0

üíæ Free Space: 1495 GB


  _torch_pytree._register_pytree_node(


In [2]:
import torch
import torch.nn as nn
from torch.optim import AdamW
from transformers import T5ForConditionalGeneration, T5Config, T5Tokenizer
from transformers import BlipProcessor, BlipForConditionalGeneration
from transformers import get_linear_schedule_with_warmup
from torch.utils.data import DataLoader, Dataset
import timm
from PIL import Image
from torchvision import transforms
import json
import os
import pandas as pd
import glob
from tqdm import tqdm
import io

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"üöÄ Using device: {device}")

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


üöÄ Using device: cuda


In [None]:
# ============================================================================
# CELL 1: Load A-OKVQA from Parquet Files
# ============================================================================
import os
import json
import pandas as pd
import glob

for dir_path in ['data/aokvqa', 'data/coco', 'models', 'outputs']:
    os.makedirs(dir_path, exist_ok=True)

print("Loading A-OKVQA from parquet files...")

# Find parquet files
parquet_paths = glob.glob('/kaggle/input/*/train-*.parquet') + glob.glob('/kaggle/input/*/validation-*.parquet')

if not parquet_paths:
    print("‚ùå No parquet files found")
    raise Exception("Parquet files not found")

print(f"Found {len(parquet_paths)} parquet files")

train_files = sorted([f for f in parquet_paths if 'train-' in f])
val_files = sorted([f for f in parquet_paths if 'validation-' in f])

print(f"  Train files: {len(train_files)}")
print(f"  Val files: {len(val_files)}")

# Inspect schema
print("\nüîç Inspecting parquet schema...")
sample_df = pd.read_parquet(parquet_paths[0])
print(f"Columns: {list(sample_df.columns)}")
print(f"Sample image field type: {type(sample_df.iloc[0]['image'])}")

def convert_parquet_to_json(parquet_files, output_file):
    """Read parquet files and convert to JSON format"""
    all_data = []
    
    for pq_file in parquet_files:
        print(f"  Reading {os.path.basename(pq_file)}...")
        df = pd.read_parquet(pq_file)
        
        for idx, row in df.iterrows():
            # Get question_id
            question_id = str(row.get('question_id', idx))
            
            # Get image_id from image field
            image_field = row.get('image', {})
            if isinstance(image_field, dict):
                # Try to get image_id from bytes or metadata
                # For A-OKVQA, we need to extract COCO image_id
                # Usually stored in the dict or we can derive from question_id
                # For now, use a placeholder - will be mapped later
                image_id = hash(question_id) % 1000000  # Temp placeholder
            else:
                image_id = 0
            
            # Get other fields
            question = str(row.get('question', ''))
            choices = list(row.get('choices', []))
            correct_idx = int(row.get('correct_choice_idx', -1))
            
            # Get rationales - handle as list
            rationales_raw = row.get('rationales', [])
            if isinstance(rationales_raw, list):
                rationales = rationales_raw
            elif rationales_raw is None or (isinstance(rationales_raw, float) and pd.isna(rationales_raw)):
                rationales = []
            else:
                rationales = [str(rationales_raw)]
            
            item = {
                'question_id': question_id,
                'image_id': image_id,
                'question': question,
                'choices': choices,
                'correct_choice_idx': correct_idx,
                'rationales': rationales
            }
            all_data.append(item)
    
    # Save as JSON
    with open(output_file, 'w') as f:
        json.dump(all_data, f)
    
    return len(all_data)

# Convert files
if train_files:
    print("\nConverting training data...")
    train_count = convert_parquet_to_json(train_files, 'data/aokvqa/aokvqa_v1p0_train.json')
    print(f"‚úÖ Train: {train_count} samples")

if val_files:
    print("\nConverting validation data...")
    val_count = convert_parquet_to_json(val_files, 'data/aokvqa/aokvqa_v1p0_val.json')
    print(f"‚úÖ Val: {val_count} samples")

print("\n‚úÖ A-OKVQA dataset ready!")

# Now we need to map image_ids to COCO IDs
# The image field contains the actual image bytes
# We'll need to save these images or use COCO dataset mapping
print("\n‚ö†Ô∏è  Note: Images are embedded in parquet as bytes")
print("   We'll extract them during feature extraction")

# Display sample
print("\nüìä Sample data:")
with open('data/aokvqa/aokvqa_v1p0_val.json', 'r') as f:
    sample = json.load(f)[0]
    print(f"  Question: {sample['question']}")
    print(f"  Choices: {sample['choices']}")
    print(f"  Correct: {sample['correct_choice_idx']}")
    print(f"  Rationales: {len(sample['rationales'])} provided")


Loading A-OKVQA from parquet files...
Found 3 parquet files
  Train files: 2
  Val files: 1

üîç Inspecting parquet schema...
Columns: ['image', 'question_id', 'question', 'choices', 'correct_choice_idx', 'direct_answers', 'difficult_direct_answer', 'rationales']
Sample image field type: <class 'dict'>

Converting training data...
  Reading train-00000-of-00002-c1d24de3bacb5e0c.parquet...
  Reading train-00001-of-00002-6b4f3abe2dc385d0.parquet...
‚úÖ Train: 17056 samples

Converting validation data...
  Reading validation-00000-of-00001-b2bd0de231b6326a.parquet...
‚úÖ Val: 1145 samples

‚úÖ A-OKVQA dataset ready!

‚ö†Ô∏è  Note: Images are embedded in parquet as bytes
   We'll extract them during feature extraction

üìä Sample data:
  Question: What is in the motorcyclist's mouth?
  Choices: ['toothpick', 'food', 'popsicle stick', 'cigarette']
  Correct: 3
  Rationales: 1 provided


In [None]:
# ============================================================================
# CELL 2: EXTRACT VISION FEATURES
# ============================================================================
print("\n" + "="*60)
print("STEP 2: EXTRACTING VISION FEATURES")
print("="*60)

def extract_features_fast(data_file, parquet_files, output_file):
    """Fast feature extraction with GPU"""
    print(f"\nüîç Extracting features for: {data_file}")
    
    # Load data
    with open(data_file) as f:
        data = json.load(f)
    
    question_ids = {item['question_id']: item['image_id'] for item in data}
    
    # Load model
    model = timm.create_model('vit_large_patch32_384', pretrained=True, num_classes=0).eval().to(device)
    
    transform = transforms.Compose([
        transforms.Resize((384, 384)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    
    features = {}
    
    for pq_file in parquet_files:
        print(f"  üìÇ Reading: {os.path.basename(pq_file)}")
        df = pd.read_parquet(pq_file)
        
        for idx in tqdm(range(len(df)), desc="Processing"):
            row = df.iloc[idx]
            qid = str(row['question_id'])
            
            if qid not in question_ids:
                continue
            
            img_id = str(question_ids[qid])
            if img_id in features:
                continue
            
            # Extract features
            img = Image.open(io.BytesIO(row['image']['bytes'])).convert('RGB')
            img_tensor = transform(img).unsqueeze(0).to(device)
            
            with torch.no_grad():
                feat = model.forward_features(img_tensor)
            
            features[img_id] = feat[0].cpu().tolist()
    
    # Save
    with open(output_file, 'w') as f:
        json.dump(features, f)
    
    print(f"‚úÖ Extracted {len(features)} features")
    del model
    torch.cuda.empty_cache()

# Get parquet files
train_parquets = sorted(glob.glob('/kaggle/input/*/train-*.parquet'))

# Extract features
if not os.path.exists('data/aokvqa/feat_train_1k.json'):
    extract_features_fast('data/aokvqa/train_1k.json', train_parquets, 'data/aokvqa/feat_train_1k.json')

if not os.path.exists('data/aokvqa/feat_val_200.json'):
    extract_features_fast('data/aokvqa/val_200.json', train_parquets, 'data/aokvqa/feat_val_200.json')

print("‚úÖ All features extracted!")



STEP 2: EXTRACTING VISION FEATURES

üîç Extracting features for: data/aokvqa/train_1k.json


model.safetensors:   0%|          | 0.00/1.23G [00:00<?, ?B/s]

  üìÇ Reading: train-00000-of-00002-c1d24de3bacb5e0c.parquet


Processing: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 8528/8528 [00:47<00:00, 180.99it/s]  


  üìÇ Reading: train-00001-of-00002-6b4f3abe2dc385d0.parquet


Processing: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 8528/8528 [00:00<00:00, 27195.96it/s]


‚úÖ Extracted 1000 features

üîç Extracting features for: data/aokvqa/val_200.json
  üìÇ Reading: train-00000-of-00002-c1d24de3bacb5e0c.parquet


Processing: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 8528/8528 [00:09<00:00, 883.13it/s] 


  üìÇ Reading: train-00001-of-00002-6b4f3abe2dc385d0.parquet


Processing: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 8528/8528 [00:00<00:00, 27000.69it/s]


‚úÖ Extracted 200 features
‚úÖ All features extracted!


In [4]:
!pip install -q --force-reinstall numpy==1.26.4
# Then manually restart kernel before running any other code

[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
bigframes 2.12.0 requires google-cloud-bigquery-storage<3.0.0,>=2.30.0, which is not installed.
datasets 4.4.1 requires pyarrow>=21.0.0, but you have pyarrow 19.0.1 which is incompatible.
cesium 0.12.4 requires numpy<3.0,>=2.0, but you have numpy 1.26.4 which is incompatible.
google-colab 1.0.0 requires notebook==6.5.7, but you have notebook 6.5.4 which is incompatible.
google-colab 1.0.0 requires pandas==2.2.2, but you have pandas 2.2.3 which is incompatible.
google-colab 1.0.0 requires requests==2.32.3, but you have requests 2.32.5 which is incompatible.
google-colab 1.0.0 requires tornado==6.4.2, but you have tornado 6.5.2 which is incompatible.
dopamine-rl 4.1.2 requires gymnasium>=1.0.0, but you have gymnasium 0.29.0 which is incompatible.
bigframes 2.12.0 requires rich<14,>=12.4.4, but you have rich 14.

In [5]:
print(f"   NumPy: {np.__version__}")


   NumPy: 2.2.6


In [None]:
# ============================================================================
# CELL 3: GENERATE CAPTIONS
# ============================================================================
print("\n" + "="*60)
print("STEP 3: GENERATING CAPTIONS")
print("="*60)

def generate_captions_fast(data_file, parquet_files, output_file):
    """Fast caption generation with GPU"""
    print(f"\nüìù Generating captions for: {data_file}")
    
    # Load data
    with open(data_file) as f:
        data = json.load(f)
    
    question_ids = set(item['question_id'] for item in data)
    
    # Load BLIP model
    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").eval().to(device)
    
    captions = {}
    
    for pq_file in parquet_files:
        print(f"  üìÇ Reading: {os.path.basename(pq_file)}")
        df = pd.read_parquet(pq_file)
        
        for idx in tqdm(range(len(df)), desc="Generating"):
            row = df.iloc[idx]
            qid = str(row['question_id'])
            
            if qid not in question_ids:
                continue
            
            if qid in captions:
                continue
            
            # Generate caption
            img = Image.open(io.BytesIO(row['image']['bytes'])).convert('RGB')
            inputs = processor(img, return_tensors="pt").to(device)
            
            with torch.no_grad():
                out = model.generate(**inputs, max_length=50, num_beams=3)
            
            captions[qid] = processor.decode(out[0], skip_special_tokens=True)
    
    # Save
    with open(output_file, 'w') as f:
        json.dump(captions, f)
    
    print(f"‚úÖ Generated {len(captions)} captions")
    del model, processor
    torch.cuda.empty_cache()

# Generate captions
if not os.path.exists('data/aokvqa/cap_train_1k.json'):
    generate_captions_fast('data/aokvqa/train_1k.json', train_parquets, 'data/aokvqa/cap_train_1k.json')

if not os.path.exists('data/aokvqa/cap_val_200.json'):
    generate_captions_fast('data/aokvqa/val_200.json', train_parquets, 'data/aokvqa/cap_val_200.json')

print("‚úÖ All captions generated!")



STEP 3: GENERATING CAPTIONS

üìù Generating captions for: data/aokvqa/train_1k.json




preprocessor_config.json:   0%|          | 0.00/287 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/506 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/990M [00:00<?, ?B/s]

  üìÇ Reading: train-00000-of-00002-c1d24de3bacb5e0c.parquet


Generating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 8528/8528 [06:07<00:00, 23.19it/s]  


  üìÇ Reading: train-00001-of-00002-6b4f3abe2dc385d0.parquet


Generating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 8528/8528 [00:00<00:00, 26480.65it/s]


‚úÖ Generated 1000 captions

üìù Generating captions for: data/aokvqa/val_200.json
  üìÇ Reading: train-00000-of-00002-c1d24de3bacb5e0c.parquet


Generating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 8528/8528 [01:13<00:00, 116.03it/s] 


  üìÇ Reading: train-00001-of-00002-6b4f3abe2dc385d0.parquet


Generating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 8528/8528 [00:00<00:00, 27464.48it/s]

‚úÖ Generated 200 captions
‚úÖ All captions generated!





In [5]:
from typing import Optional, Tuple


In [None]:
# ============================================================================
# CELL 4: ACTUAL MM-COT MODEL FROM AMAZON SCIENCE REPO
# ============================================================================
class UnifiedQAModel(T5ForConditionalGeneration):
    """
    The actual MM-CoT model from Amazon Science's repository.
    Integrates vision features by concatenating them with text embeddings.
    """
    def __init__(self, config: T5Config):
        super().__init__(config)
        self.model_dim = config.d_model
        
        # Vision projection layer - projects ViT features to T5 dimension
        self.vis_proj = nn.Linear(1024, config.d_model)
        
        # LayerNorm for stability
        self.vis_layer_norm = nn.LayerNorm(config.d_model)
        
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        img_features: Optional[torch.FloatTensor] = None,
    ):
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        # Encode input text
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        
        hidden_states = encoder_outputs[0]
        
        # Integrate vision features
        if img_features is not None:
            vis_embeds = self.vis_proj(img_features)
            vis_embeds = self.vis_layer_norm(vis_embeds)
            
            # Concatenate text and vision
            hidden_states = torch.cat([hidden_states, vis_embeds], dim=1)
            
            # Extend attention mask
            if attention_mask is not None:
                batch_size = attention_mask.shape[0]
                num_vis_tokens = vis_embeds.shape[1]
                vis_attention_mask = torch.ones(
                    batch_size, num_vis_tokens,
                    dtype=attention_mask.dtype,
                    device=attention_mask.device
                )
                attention_mask = torch.cat([attention_mask, vis_attention_mask], dim=1)
        
        # Prepare decoder inputs
        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            decoder_input_ids = self._shift_right(labels)
        
        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        
        sequence_output = decoder_outputs[0]
        
        if self.config.tie_word_embeddings:
            sequence_output = sequence_output * (self.model_dim**-0.5)
        
        lm_logits = self.lm_head(sequence_output)
        
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
        
        if not return_dict:
            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output
        
        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
    
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        config = T5Config.from_pretrained(pretrained_model_name_or_path)
        model = cls(config)
        pretrained_dict = T5ForConditionalGeneration.from_pretrained(
            pretrained_model_name_or_path, *model_args, **kwargs
        ).state_dict()
        model_dict = model.state_dict()
        pretrained_dict = {
            k: v for k, v in pretrained_dict.items() 
            if k in model_dict and v.shape == model_dict[k].shape
        }
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
        return model

print("‚úÖ MM-CoT model defined (Amazon Science architecture)")


‚úÖ MM-CoT model defined (Amazon Science architecture)


In [None]:
# ============================================================================
# CELL 5: DATASET
# ============================================================================
print("\n" + "="*60)
print("STEP 4: PREPARING DATASETS")
print("="*60)

class AOKVQADataset(Dataset):
    def __init__(self, data_file, caption_file, feature_file, tokenizer, mode='rationale'):
        with open(data_file) as f:
            data = json.load(f)
        with open(caption_file) as f:
            captions = json.load(f)
        with open(feature_file) as f:
            features = json.load(f)
        
        self.tokenizer = tokenizer
        self.mode = mode
        self.data = []
        
        for item in data:
            qid = item['question_id']
            img_id = str(item['image_id'])
            
            if img_id not in features:
                continue
            
            opts = "\n".join([f"({chr(65+i)}) {c}" for i, c in enumerate(item['choices'])])
            
            self.data.append({
                'question': item['question'],
                'caption': captions.get(qid, ""),
                'options': opts,
                'correct_idx': item.get('correct_choice_idx', -1),
                'rationale': item['rationales'][0] if item.get('rationales') else "",
                'features': features[img_id],
                'generated_rationale': item.get('generated_rationale', "")
            })
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        
        if self.mode == 'rationale':
            inp = f"Question: {item['question']}\nCaption: {item['caption']}\nOptions:\n{item['options']}\nGenerate the rationale:"
            out = item['rationale']
        else:
            inp = f"Question: {item['question']}\nCaption: {item['caption']}\nOptions:\n{item['options']}\nRationale: {item['generated_rationale']}\nThe answer is"
            out = f" ({chr(65 + item['correct_idx'])})"
        
        inp_enc = self.tokenizer(inp, max_length=512, padding='max_length', truncation=True, return_tensors='pt')
        out_enc = self.tokenizer(out, max_length=128, padding='max_length', truncation=True, return_tensors='pt')
        
        return {
            'input_ids': inp_enc['input_ids'].squeeze(0),
            'attention_mask': inp_enc['attention_mask'].squeeze(0),
            'labels': out_enc['input_ids'].squeeze(0),
            'img_features': torch.tensor(item['features'], dtype=torch.float32)
        }

def collate_fn(batch):
    labels = torch.stack([b['labels'] for b in batch])
    labels[labels == 0] = -100
    return {
        'input_ids': torch.stack([b['input_ids'] for b in batch]),
        'attention_mask': torch.stack([b['attention_mask'] for b in batch]),
        'labels': labels,
        'img_features': torch.stack([b['img_features'] for b in batch])
    }

print("‚úÖ Dataset ready")


STEP 4: PREPARING DATASETS
‚úÖ Dataset ready


In [8]:
import torch
import torch.nn as nn
from torch.optim import AdamW
from transformers import T5ForConditionalGeneration, T5Config, T5Tokenizer
from transformers import BlipProcessor, BlipForConditionalGeneration
from transformers import get_linear_schedule_with_warmup
from transformers.modeling_outputs import Seq2SeqLMOutput
from torch.utils.data import DataLoader, Dataset
import timm
from PIL import Image
from torchvision import transforms
import json
import os
import pandas as pd
import glob
from tqdm import tqdm
import io
from typing import Optional, Tuple

In [None]:
# ============================================================================
# CELL 6: TRAIN STAGE 1 - RATIONALE
# ============================================================================
print("\n" + "="*60)
print("STEP 5: TRAINING STAGE 1 - RATIONALE GENERATION")
print("="*60)

tokenizer = T5Tokenizer.from_pretrained('google/flan-t5-base')

train_ds = AOKVQADataset('/kaggle/input/coco-1k/train_1k.json', '/kaggle/input/coco-1k/cap_train_1k.json',
                         '/kaggle/input/coco-1k/feat_train_1k.json', tokenizer, mode='rationale')
val_ds = AOKVQADataset('/kaggle/input/coco-1k/val_200.json', '/kaggle/input/coco-1k/cap_val_200.json',
                       '/kaggle/input/coco-1k/feat_val_200.json', tokenizer, mode='rationale')

# Smaller batch size + gradient accumulation
batch_size = 4
accumulation_steps = 4  # Effective batch size = 4 * 4 = 16

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_ds, batch_size=8, collate_fn=collate_fn)

print(f"Train: {len(train_ds)} samples, {len(train_loader)} batches")
print(f"Val: {len(val_ds)} samples, {len(val_loader)} batches")
print(f"Effective batch size: {batch_size} √ó {accumulation_steps} = {batch_size * accumulation_steps}")

# Clear memory before training
torch.cuda.empty_cache()

# Train with smaller model
model = UnifiedQAModel.from_pretrained('google/flan-t5-small').to(device)  # Use T5-small instead
optimizer = AdamW(model.parameters(), lr=5e-5)
scheduler = get_linear_schedule_with_warmup(optimizer, 100, len(train_loader) * 5)

print(f"\nüíæ GPU Memory after model load:")
print(f"   Allocated: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")
print(f"   Reserved: {torch.cuda.memory_reserved(0) / 1024**3:.2f} GB\n")

best_loss = float('inf')
for epoch in range(1, 6):
    model.train()
    total_loss = 0
    optimizer.zero_grad()
    
    for i, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch}/5")):
        batch = {k: v.to(device) for k, v in batch.items()}
        
        outputs = model(**batch)
        loss = outputs.loss / accumulation_steps  # Scale loss
        loss.backward()
        
        # Update every accumulation_steps
        if (i + 1) % accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            
            # Clear cache every few steps
            if (i + 1) % (accumulation_steps * 4) == 0:
                torch.cuda.empty_cache()
        
        total_loss += loss.item() * accumulation_steps
    
    avg_train = total_loss / len(train_loader)
    
    # Validation
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch in val_loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            val_loss += model(**batch).loss.item()
    
    avg_val = val_loss / len(val_loader)
    print(f"Epoch {epoch}: Train={avg_train:.4f}, Val={avg_val:.4f}")
    
    if avg_val < best_loss:
        best_loss = avg_val
        os.makedirs('models/rationale', exist_ok=True)
        
        # Save with PyTorch format (avoids NumPy 2.x issues)
        try:
            model.save_pretrained('models/rationale', safe_serialization=False)
        except:
            # Fallback: save state dict directly
            torch.save(model.state_dict(), 'models/rationale/pytorch_model.bin')
            model.config.save_pretrained('models/rationale')
        
        tokenizer.save_pretrained('models/rationale')
        print(f"  ‚úÖ Saved!")

print("‚úÖ Stage 1 complete!")


STEP 5: TRAINING STAGE 1 - RATIONALE GENERATION




tokenizer_config.json: 0.00B [00:00, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Train: 1000 samples, 250 batches
Val: 200 samples, 25 batches
Effective batch size: 4 √ó 4 = 16


config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/308M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]


üíæ GPU Memory after model load:
   Allocated: 0.29 GB
   Reserved: 0.32 GB



Epoch 1/5: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 250/250 [00:58<00:00,  4.25it/s]


Epoch 1: Train=20.9390, Val=14.2596
  ‚úÖ Saved!


Epoch 2/5: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 250/250 [01:00<00:00,  4.16it/s]


Epoch 2: Train=11.2713, Val=7.9116
  ‚úÖ Saved!


Epoch 3/5: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 250/250 [01:02<00:00,  4.02it/s]


Epoch 3: Train=7.1446, Val=6.2304
  ‚úÖ Saved!


Epoch 4/5: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 250/250 [01:02<00:00,  4.03it/s]


Epoch 4: Train=6.1326, Val=5.8029
  ‚úÖ Saved!


Epoch 5/5: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 250/250 [01:02<00:00,  4.02it/s]


Epoch 5: Train=5.8111, Val=5.6633
  ‚úÖ Saved!
‚úÖ Stage 1 complete!


In [10]:
!pip install --force-reinstall safetensors


Collecting safetensors
  Downloading safetensors-0.7.0-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.1 kB)
Downloading safetensors-0.7.0-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (507 kB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m507.2/507.2 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: safetensors
  Attempting uninstall: safetensors
    Found existing installation: safetensors 0.5.3
    Uninstalling safetensors-0.5.3:
      Successfully uninstalled safetensors-0.5.3
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
sentence-transformers 4.1.0 requires transformers<5.0.0,>=4.41.0, but you have transformers 4.36.0 which is incompatible.[0m[31m
[0mSuccessfully install

In [15]:
!pip uninstall -y numpy


Found existing installation: numpy 1.26.4
Uninstalling numpy-1.26.4:
  Successfully uninstalled numpy-1.26.4


In [16]:
!pip install numpy==1.26.4


Collecting numpy==1.26.4
  Using cached numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
Using cached numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB)
Installing collected packages: numpy
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
bigframes 2.12.0 requires google-cloud-bigquery-storage<3.0.0,>=2.30.0, which is not installed.
datasets 4.4.1 requires pyarrow>=21.0.0, but you have pyarrow 19.0.1 which is incompatible.
cesium 0.12.4 requires numpy<3.0,>=2.0, but you have numpy 1.26.4 which is incompatible.
google-colab 1.0.0 requires notebook==6.5.7, but you have notebook 6.5.4 which is incompatible.
google-colab 1.0.0 requires pandas==2.2.2, but you have pandas 2.2.3 which is incompatible.
google-colab 1.0.0 requires requests==2.32.3, but you have requests 2.32.5 which is incompatib

In [None]:
# ============================================================================
# CELL 7: GENERATE RATIONALES FOR STAGE 2
# ============================================================================
print("\n" + "="*60)
print("STEP 6: GENERATING RATIONALES FOR STAGE 2")
print("="*60)

model_rat = UnifiedQAModel.from_pretrained('models/rationale').eval().to(device)
tokenizer_rat = T5Tokenizer.from_pretrained('models/rationale')

def add_generated_rationales(data_file, feat_file, cap_file, output_file):
    with open(data_file) as f:
        data = json.load(f)
    with open(feat_file) as f:
        feats = json.load(f)
    with open(cap_file) as f:
        caps = json.load(f)
    
    for item in tqdm(data, desc="Generating"):
        img_id = str(item['image_id'])
        if img_id not in feats:
            item['generated_rationale'] = ""
            continue
        
        opts = "\n".join([f"({chr(65+i)}) {c}" for i, c in enumerate(item['choices'])])
        inp = f"Question: {item['question']}\nCaption: {caps.get(item['question_id'], '')}\nOptions:\n{opts}\nGenerate the rationale:"
        
        inputs = tokenizer_rat(inp, return_tensors='pt', max_length=512, truncation=True).to(device)
        vis_feat = torch.tensor([feats[img_id]], dtype=torch.float32).to(device)
        
        with torch.no_grad():
            out = model_rat.generate(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'],
                                     img_features=vis_feat, max_length=128)
        
        item['generated_rationale'] = tokenizer_rat.decode(out[0], skip_special_tokens=True)
    
    # Save to writable location
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    with open(output_file, 'w') as f:
        json.dump(data, f)
    print(f"  ‚úÖ Saved to: {output_file}")

# Save to data/aokvqa/ (writable location)
add_generated_rationales(
    '/kaggle/input/coco-1k/train_1k.json',
    '/kaggle/input/coco-1k/feat_train_1k.json',
    '/kaggle/input/coco-1k/cap_train_1k.json',
    'data/aokvqa/train_1k_rat.json'  # Writable location
)
add_generated_rationales(
    '/kaggle/input/coco-1k/val_200.json',
    '/kaggle/input/coco-1k/feat_val_200.json',
    '/kaggle/input/coco-1k/cap_val_200.json',
    'data/aokvqa/val_200_rat.json'  # Writable location
)

print("‚úÖ Rationales generated!")

del model_rat, tokenizer_rat
torch.cuda.empty_cache()



STEP 6: GENERATING RATIONALES FOR STAGE 2


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Generating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1000/1000 [07:19<00:00,  2.28it/s]


  ‚úÖ Saved to: data/aokvqa/train_1k_rat.json


Generating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 200/200 [01:35<00:00,  2.10it/s]


  ‚úÖ Saved to: data/aokvqa/val_200_rat.json
‚úÖ Rationales generated!


In [None]:
# ============================================================================
# CELL 8: TRAIN STAGE 2 - ANSWER
# ============================================================================
print("\n" + "="*60)
print("STEP 7: TRAINING STAGE 2 - ANSWER PREDICTION")
print("="*60)

train_ds_ans = AOKVQADataset('data/aokvqa/train_1k_rat.json', '/kaggle/input/coco-1k/cap_train_1k.json',
                             '/kaggle/input/coco-1k/feat_train_1k.json', tokenizer, mode='answer')
val_ds_ans = AOKVQADataset('data/aokvqa/val_200_rat.json', '/kaggle/input/coco-1k/cap_val_200.json',
                           '/kaggle/input/coco-1k/feat_val_200.json', tokenizer, mode='answer')

train_loader_ans = DataLoader(train_ds_ans, batch_size=4, shuffle=True, collate_fn=collate_fn)
val_loader_ans = DataLoader(val_ds_ans, batch_size=8, collate_fn=collate_fn)

print(f"Train: {len(train_ds_ans)} samples")
print(f"Val: {len(val_ds_ans)} samples")

# Clear memory
torch.cuda.empty_cache()

# Train with T5-small
model_ans = UnifiedQAModel.from_pretrained('google/flan-t5-small').to(device)
optimizer = AdamW(model_ans.parameters(), lr=5e-5)
scheduler = get_linear_schedule_with_warmup(optimizer, 100, len(train_loader_ans) * 5)

best_loss = float('inf')
for epoch in range(1, 6):
    model_ans.train()
    total_loss = 0
    optimizer.zero_grad()
    
    for i, batch in enumerate(tqdm(train_loader_ans, desc=f"Epoch {epoch}/5")):
        batch = {k: v.to(device) for k, v in batch.items()}
        
        outputs = model_ans(**batch)
        loss = outputs.loss / accumulation_steps
        loss.backward()
        
        if (i + 1) % accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model_ans.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            
            if (i + 1) % (accumulation_steps * 4) == 0:
                torch.cuda.empty_cache()
        
        total_loss += loss.item() * accumulation_steps
    
    avg_train = total_loss / len(train_loader_ans)
    
    # Validation
    model_ans.eval()
    val_loss = 0
    with torch.no_grad():
        for batch in tqdm(val_loader_ans, desc="Validating"):
            batch = {k: v.to(device) for k, v in batch.items()}
            val_loss += model_ans(**batch).loss.item()
            del batch
            torch.cuda.empty_cache()
    
    avg_val = val_loss / len(val_loader_ans)
    print(f"Epoch {epoch}: Train={avg_train:.4f}, Val={avg_val:.4f}")
    
    if avg_val < best_loss:
        best_loss = avg_val
        os.makedirs('models/answer', exist_ok=True)
        
        # Save with PyTorch format
        try:
            model_ans.save_pretrained('models/answer', safe_serialization=False)
        except:
            torch.save(model_ans.state_dict(), 'models/answer/pytorch_model.bin')
            model_ans.config.save_pretrained('models/answer')
        
        tokenizer.save_pretrained('models/answer')
        print(f"  ‚úÖ Saved!")

print("‚úÖ Stage 2 complete!")



STEP 7: TRAINING STAGE 2 - ANSWER PREDICTION
Train: 1000 samples
Val: 200 samples


Epoch 1/5: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 250/250 [01:00<00:00,  4.12it/s]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 25/25 [00:05<00:00,  4.68it/s]


Epoch 1: Train=21.3064, Val=18.2464
  ‚úÖ Saved!


Epoch 2/5: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 250/250 [01:01<00:00,  4.06it/s]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 25/25 [00:05<00:00,  4.72it/s]


Epoch 2: Train=15.5819, Val=11.1536
  ‚úÖ Saved!


Epoch 3/5: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 250/250 [01:01<00:00,  4.04it/s]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 25/25 [00:05<00:00,  4.74it/s]


Epoch 3: Train=9.6564, Val=6.3703
  ‚úÖ Saved!


Epoch 4/5: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 250/250 [01:01<00:00,  4.04it/s]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 25/25 [00:05<00:00,  4.71it/s]


Epoch 4: Train=5.0319, Val=2.8216
  ‚úÖ Saved!


Epoch 5/5: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 250/250 [01:01<00:00,  4.04it/s]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 25/25 [00:05<00:00,  4.69it/s]


Epoch 5: Train=2.6071, Val=2.0943
  ‚úÖ Saved!
‚úÖ Stage 2 complete!


In [None]:
# ============================================================================
# CELL 9: EVALUATION
# ============================================================================
print("\n" + "="*60)
print("STEP 8: FINAL EVALUATION")
print("="*60)

model_rat = UnifiedQAModel.from_pretrained('models/rationale').eval().to(device)
model_ans = UnifiedQAModel.from_pretrained('models/answer').eval().to(device)
tok_rat = T5Tokenizer.from_pretrained('models/rationale')
tok_ans = T5Tokenizer.from_pretrained('models/answer')

with open('/kaggle/input/coco-1k/val_200.json') as f:
    eval_data = json.load(f)
with open('/kaggle/input/coco-1k/feat_val_200.json') as f:
    eval_feats = json.load(f)
with open('/kaggle/input/coco-1k/cap_val_200.json') as f:
    eval_caps = json.load(f)

correct = 0
total = 0

for item in tqdm(eval_data, desc="Evaluating"):
    img_id = str(item['image_id'])
    if img_id not in eval_feats:
        continue
    
    opts = "\n".join([f"({chr(65+i)}) {c}" for i, c in enumerate(item['choices'])])
    
    # Generate rationale
    inp1 = f"Question: {item['question']}\nCaption: {eval_caps.get(item['question_id'], '')}\nOptions:\n{opts}\nGenerate the rationale:"
    inputs1 = tok_rat(inp1, return_tensors='pt', max_length=512, truncation=True).to(device)
    vis = torch.tensor([eval_feats[img_id]], dtype=torch.float32).to(device)
    
    with torch.no_grad():
        out1 = model_rat.generate(input_ids=inputs1['input_ids'], attention_mask=inputs1['attention_mask'],
                                  img_features=vis, max_length=128)
    rationale = tok_rat.decode(out1[0], skip_special_tokens=True)
    
    # Generate answer
    inp2 = f"Question: {item['question']}\nCaption: {eval_caps.get(item['question_id'], '')}\nOptions:\n{opts}\nRationale: {rationale}\nThe answer is"
    inputs2 = tok_ans(inp2, return_tensors='pt', max_length=512, truncation=True).to(device)
    
    with torch.no_grad():
        out2 = model_ans.generate(input_ids=inputs2['input_ids'], attention_mask=inputs2['attention_mask'],
                                  img_features=vis, max_length=64)
    answer = tok_ans.decode(out2[0], skip_special_tokens=True)
    
    pred = next((c for c in answer if c.isalpha() and c.isupper()), '')
    true = chr(65 + item['correct_choice_idx']) if item['correct_choice_idx'] >= 0 else ''
    
    if pred == true:
        correct += 1
    total += 1

accuracy = (correct / total) * 100
print(f"\nüéØ FINAL ACCURACY: {accuracy:.2f}% ({correct}/{total})")

print("\n" + "="*60)
print("‚úÖ COMPLETE PIPELINE FINISHED!")
print("="*60)
print("\nUsing ACTUAL Amazon Science MM-CoT Architecture")
print("Model: UnifiedQAModel with vision concatenation")


STEP 8: FINAL EVALUATION


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Evaluating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 200/200 [01:48<00:00,  1.85it/s]


üéØ FINAL ACCURACY: 32.00% (64/200)

‚úÖ COMPLETE PIPELINE FINISHED!

Using ACTUAL Amazon Science MM-CoT Architecture
Model: UnifiedQAModel with vision concatenation



