In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

# import numpy as np # linear algebra
# import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

# import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [2]:
import pandas as pd
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from transformers import CLIPVisionModel, CLIPImageProcessor, GPT2LMHeadModel, GPT2Tokenizer
import json
from tqdm import tqdm

# Configuration
device = 'cuda' if torch.cuda.is_available() else 'cpu'
num_gpus = torch.cuda.device_count()

# Paths for Earth VQA dataset
vqa_csv_path = "/kaggle/input/train-dataset-csv/filtered_dataset.csv"  # Update with your path
vqa_image_folder = "/kaggle/input/train-dataset/dataset/Train/Train/images_png"  # Update with your path
pretrained_model_path = "/kaggle/input/vqamodel/pytorch/default/2/projection_pretrained_model"  # From pretraining

# Load Earth VQA dataset
vqa_df = pd.read_csv(vqa_csv_path)
vqa_df["image_path"] = vqa_df["image_name"].apply(lambda x: os.path.join(vqa_image_folder, x))
vqa_df = vqa_df[vqa_df["image_path"].apply(os.path.exists)].reset_index(drop=True)

# Load pretrained components
clip_processor = CLIPImageProcessor.from_pretrained(pretrained_model_path)
gpt2_tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model_path)
gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token

# Load model config
with open(os.path.join(pretrained_model_path, "model_config.json"), "r") as f:
    model_config = json.load(f)

# Recreate model architecture
class VisionTextModel(nn.Module):
    def __init__(self, vision_encoder, llm, projection_in_dim, projection_out_dim, num_image_tokens=16):
        super().__init__()
        self.vision_encoder = vision_encoder
        self.llm = llm
        self.projection = nn.Sequential(
            nn.Linear(projection_in_dim, num_image_tokens * projection_out_dim),
            nn.LayerNorm(num_image_tokens * projection_out_dim)
        )
        self.num_image_tokens = num_image_tokens

    def forward(self, images, input_ids, attention_mask=None, labels=None):
        # Process images
        vision_output = self.vision_encoder(images)
        image_features = vision_output.last_hidden_state.mean(dim=1)
        
        # Project features
        projected = self.projection(image_features)
        projected = projected.view(-1, self.num_image_tokens, self.projection[0].out_features // self.num_image_tokens)
        projected = projected * 0.1  # Stabilize training

        # Get text embeddings
        text_embeds = self.llm.get_input_embeddings()(input_ids)
        
        # Combine embeddings
        inputs_embeds = torch.cat([projected, text_embeds], dim=1)
        
        # Adjust attention mask and labels
        if attention_mask is not None:
            image_mask = torch.ones(attention_mask.shape[0], self.num_image_tokens, device=attention_mask.device)
            attention_mask = torch.cat([image_mask, attention_mask], dim=1)
        
        if labels is not None:
            image_labels = torch.full((labels.shape[0], self.num_image_tokens), -100, device=labels.device)
            labels = torch.cat([image_labels, labels], dim=1)
        
        # Forward through LLM
        outputs = self.llm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels
        )
        return outputs

# Initialize models
clip_vision_model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2")

# Create multimodal model
model = VisionTextModel(
    vision_encoder=clip_vision_model,
    llm=gpt2_model,
    projection_in_dim=model_config["projection_in_dim"],
    projection_out_dim=model_config["projection_out_dim"],
    num_image_tokens=32
)

# Load pretrained weights
model.load_state_dict(torch.load(os.path.join(pretrained_model_path, "pytorch_model.bin")))
model.to(device)

# Enable DataParallel for multi-GPU training
if num_gpus > 1:
    model = nn.DataParallel(model)
    print(f"Using {num_gpus} GPUs for training!")

# Dataset class for VQA
class VQADataset(Dataset):
    def __init__(self, dataframe, image_processor, tokenizer, max_length=256):
        self.data = dataframe
        self.image_processor = image_processor
        self.tokenizer = tokenizer
        self.max_length = max_length
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data.iloc[idx]
        image = Image.open(item["image_path"]).convert("RGB")
        pixel_values = self.image_processor(image, return_tensors="pt")["pixel_values"][0]
        
        # Format: "User: {question}\nAssistant: {answer}"
        text = f"User: {item['question']}\nAssistant: {item['answer']}"
        
        tokenized = self.tokenizer(
            text,
            return_tensors="pt",
            max_length=self.max_length,
            truncation=True,
            padding="max_length"
        )
        
        return {
            "pixel_values": pixel_values,
            "input_ids": tokenized["input_ids"][0],
            "attention_mask": tokenized["attention_mask"][0]
        }

# Collate function
def collate_fn(batch):
    pixel_values = torch.stack([item["pixel_values"] for item in batch])
    input_ids = torch.stack([item["input_ids"] for item in batch])
    attention_mask = torch.stack([item["attention_mask"] for item in batch])
    labels = input_ids.clone()
    
    return {
        "pixel_values": pixel_values,
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }

# Create dataset and dataloader
vqa_dataset = VQADataset(
    dataframe=vqa_df,
    image_processor=clip_processor,
    tokenizer=gpt2_tokenizer,
    max_length=256
)

vqa_dataloader = DataLoader(
    vqa_dataset,
    batch_size=8,  # Adjust based on GPU memory
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=4,
    pin_memory=True
)



2025-07-31 10:33:50.977128: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1753958031.331230      18 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1753958031.434529      18 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/605M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/605M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Using 2 GPUs for training!


In [3]:
# Training setup
model.train()
for param in model.parameters():
    param.requires_grad = True  # Unfreeze all parameters

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
num_epochs = 5
gradient_accumulation_steps = 2
total_steps = num_epochs * len(vqa_dataloader) // gradient_accumulation_steps

# Training loop
progress_bar = tqdm(total=total_steps, desc="Fine-tuning")
global_step = 0

for epoch in range(num_epochs):
    optimizer.zero_grad()
    
    for batch_idx, batch in enumerate(vqa_dataloader):
        # Move batch to device
        pixel_values = batch["pixel_values"].to(device)
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)
        
        # Forward pass
        outputs = model(
            images=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        
        # Handle multi-GPU output
        loss = outputs.loss.mean() if num_gpus > 1 else outputs.loss
        loss = loss / gradient_accumulation_steps
        loss.backward()
        
        # Update progress
        global_step += 1
        progress_bar.set_postfix({"loss": f"{loss.item() * gradient_accumulation_steps:.4f}"})
        
        # Gradient accumulation
        if (batch_idx + 1) % gradient_accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()
            progress_bar.update(1)

progress_bar.close()
print("Fine-tuning completed!")

# Save the fine-tuned model
output_dir = "/kaggle/working/earth_vqa_finetuned_model"
os.makedirs(output_dir, exist_ok=True)

# Save model weights
model_to_save = model.module if num_gpus > 1 else model
torch.save(model_to_save.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))

# Save tokenizer and config
gpt2_tokenizer.save_pretrained(output_dir)
clip_processor.save_pretrained(output_dir)

with open(os.path.join(output_dir, "model_config.json"), "w") as f:
    json.dump(model_config, f)

print(f"Model saved to {output_dir}")


Fine-tuning:   0%|          | 0/3152 [00:00<?, ?it/s]`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.
`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.
Fine-tuning: 100%|█████████▉| 3150/3152 [1:09:15<00:02,  1.32s/it, loss=0.0131]


Fine-tuning completed!
Model saved to /kaggle/working/earth_vqa_finetuned_model
