# Qwen 2.5 VLM Fine-tuning for Movie Trailer Tag Classification

Multi-label classification of movie trailer scenes into YOLO tags, with a probability score for each tag.

In [None]:
import sys
from pathlib import Path
import torch
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

from tags import YOLO_TAGS, NUM_TAGS
from data_generator import generate_dataset
from model import Qwen2VLClassifier, get_processor
from inference import MovieTrailerPredictor

## 1. Generate Mock Dataset

In [None]:
# Root folder that receives the train/val/test splits of the synthetic dataset.
data_dir = Path('data')

print(f"Total tags: {NUM_TAGS}")
print(f"Sample tags: {YOLO_TAGS[:10]}...")

# Number of synthetic samples to generate per split (generated in this order).
split_sizes = {'train': 120, 'val': 30, 'test': 20}
splits = {
    split_name: generate_dataset(data_dir, num_samples=count, split=split_name)
    for split_name, count in split_sizes.items()
}
train_data, val_data, test_data = splits['train'], splits['val'], splits['test']

print(f"\nGenerated {len(train_data)} train, {len(val_data)} val, {len(test_data)} test samples")

## 2. Visualize Sample Data

In [None]:
# Preview up to six training scenes in a 2x3 grid, titled with their first tags.
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

# zip() stops at the shorter sequence, so fewer than six samples no longer raises IndexError.
for ax, sample in zip(axes, train_data):
    # Context manager closes the file; np.asarray forces the pixel data to load first.
    with Image.open(sample['image_path']) as img:
        ax.imshow(np.asarray(img))
    tags = sample['tags']
    # Only signal truncation when tags were actually dropped.
    shown_tags = ', '.join(tags[:3]) + ('...' if len(tags) > 3 else '')
    ax.set_title(f"Tags: {shown_tags}", fontsize=10)
    ax.axis('off')

# Hide any grid cells left empty when there are fewer samples than axes.
for ax in axes[min(len(axes), len(train_data)):]:
    ax.axis('off')

plt.tight_layout()
plt.show()

## 3. Initialize Model

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Instantiate the classifier here only to inspect its head; `model` is not
# referenced by any later cell in this notebook.
model = Qwen2VLClassifier(
    use_lora=True,
    lora_r=8,
    lora_alpha=16,
    use_gradient_checkpointing=True,
    pooling_strategy='attention',
    loss_type='focal',
)
model = model.to(device)

print("\nModel architecture:")
print(model.classifier)

# Release this kernel's copy before `!python train.py` starts a separate
# process that loads its own model; otherwise both compete for GPU memory.
del model
if device.type == 'cuda':
    torch.cuda.empty_cache()

## 4. Training

Run the training script (this may take a while):

In [None]:
# IPython shell escape: runs train.py in a separate process, so no Python
# objects from this kernel are shared with it. The following cells read
# checkpoints/best_model.pt and checkpoints/history.json, which are presumably
# written by train.py (TODO confirm against train.py).
!python train.py

## 5. Load Trained Model and Inference

In [None]:
# Prefer the best checkpoint from training; otherwise fall back to the
# predictor's default (untrained) construction.
checkpoint_path = Path('checkpoints/best_model.pt')

if not checkpoint_path.exists():
    print("No checkpoint found, using untrained model")
    predictor = MovieTrailerPredictor()
else:
    predictor = MovieTrailerPredictor(checkpoint_path)

## 6. Inference on Test Data

In [None]:
import json

# Reload the test split from disk; this rebinds the `test_data` produced in section 1.
with open(data_dir / 'test' / 'metadata.json', 'r', encoding='utf-8') as f:
    test_data = json.load(f)

sample = test_data[0]

print(f"Text: {sample['text']}")
print(f"Ground Truth: {', '.join(sample['tags'])}\n")

predictions = predictor.predict(sample['image_path'], sample['text'], threshold=0.3, top_k=10)

print("Predictions:")
for pred in predictions:
    print(f"  {pred['tag']:20s}: {pred['probability']:.4f}")

# Only add an ellipsis when the scene text was actually truncated.
scene_text = sample['text']
title_text = scene_text if len(scene_text) <= 50 else f"{scene_text[:50]}..."

plt.figure(figsize=(8, 6))
# Context manager closes the file; np.asarray loads the pixels before it closes.
with Image.open(sample['image_path']) as img:
    plt.imshow(np.asarray(img))
plt.title(f"Scene: {title_text}")
plt.axis('off')
plt.show()

## 7. Batch Prediction

In [None]:
# Score the first three test scenes with a single batched predictor call.
batch_samples = test_data[:3]

image_paths = [entry['image_path'] for entry in batch_samples]
texts = [entry['text'] for entry in batch_samples]

batch_results = predictor.batch_predict(image_paths, texts, threshold=0.3, top_k=5)

for sample_no, result in enumerate(batch_results, start=1):
    print(f"\n=== Sample {sample_no} ===")
    print(f"Text: {result['text']}")
    print("Top predictions:")
    for tag_pred in result['predictions']:
        print(f"  {tag_pred['tag']:20s}: {tag_pred['probability']:.4f}")

## 8. Visualize Training History

In [None]:
# Import locally so this cell does not depend on section 6 having run first.
import json

history_path = Path('checkpoints/history.json')

if history_path.exists():
    with open(history_path, 'r', encoding='utf-8') as f:
        history = json.load(f)

    fig, axes = plt.subplots(1, 2, figsize=(15, 5))

    # Left panel: per-epoch train/val loss curves.
    axes[0].plot(history['train_loss'], label='Train Loss')
    axes[0].plot(history['val_loss'], label='Val Loss')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('Training and Validation Loss')
    axes[0].legend()
    axes[0].grid(True)

    # Right panel: F1 per epoch. Assumes each metrics entry is a dict with an
    # 'f1' key, as written by train.py (TODO confirm).
    train_f1 = [m['f1'] for m in history['train_metrics']]
    val_f1 = [m['f1'] for m in history['val_metrics']]

    axes[1].plot(train_f1, label='Train F1')
    axes[1].plot(val_f1, label='Val F1')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('F1 Score')
    axes[1].set_title('F1 Score')
    axes[1].legend()
    axes[1].grid(True)

    plt.tight_layout()
    plt.show()
else:
    print("No training history found. Run training first.")