# Task 2: Caption Length Classification

- Load và tiền xử lý dữ liệu
- Chia train/validation (80/20)
- Train RNN, LSTM, Attention
- Đánh giá, visualize kết quả


In [ ]:
import pandas as pd
import numpy as np
from src.data_processor import Flickr8kProcessor, TextPreprocessor
from src.models.classification_models import RNNClassifier, LSTMClassifier, AttentionClassifier
from src.utils.metrics import MetricsCalculator
from src.utils.visualization import DataVisualizer
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
%matplotlib inline


## 1. Load và tiền xử lý dữ liệu

In [ ]:
# Build padded token-id sequences and integer length labels from the captions,
# then hold out a validation split for the models trained later.
DATA_PATH = '../data'
VOCAB_SIZE = 5000     # vocab_size handed to build_vocabulary
MAX_LEN = 20          # maxlen handed to pad_sequences
VAL_FRACTION = 0.2    # share of samples held out for validation
RANDOM_SEED = 42      # fixed seed so the split is reproducible

processor = Flickr8kProcessor(DATA_PATH)
captions_df = processor.load_captions()
captions_df = processor.create_length_labels(captions_df)

text_prep = TextPreprocessor(language='en')
# Materialise the caption list once; it feeds both tokenisation and encoding.
captions = captions_df['caption'].tolist()
tokenized = text_prep.tokenize(captions)
# NOTE(review): the vocabulary is built from every caption, including the ones
# that later land in the validation split -- confirm this is acceptable.
vocab = text_prep.build_vocabulary(tokenized, vocab_size=VOCAB_SIZE)
sequences = text_prep.texts_to_sequences(captions, vocab)
padded = text_prep.pad_sequences(sequences, maxlen=MAX_LEN)

label_map = {'short': 0, 'medium': 1, 'long': 2}
labels = captions_df['length_category'].map(label_map)
# Series.map yields NaN for values missing from label_map; fail loudly instead
# of silently passing float/NaN labels on to training.
if labels.isna().any():
    unknown = sorted(
        captions_df.loc[labels.isna(), 'length_category'].astype(str).unique()
    )
    raise ValueError(f'Unmapped length_category values: {unknown}')
y = labels.astype('int64').values

# Stratify so both splits keep the same short/medium/long proportions.
X_train, X_val, y_train, y_val = train_test_split(
    padded,
    y,
    test_size=VAL_FRACTION,
    random_state=RANDOM_SEED,
    stratify=y,
)


## 2. Train và đánh giá các mô hình

In [ ]:
# Train each classifier on the same split and collect its validation metrics,
# keyed by a short display name.
results = {}
metrics = MetricsCalculator()
model_specs = (
    ('RNN', RNNClassifier),
    ('LSTM', LSTMClassifier),
    ('Attention', AttentionClassifier),
)
for name, ModelClass in model_specs:
    print(f'Training {name}...')
    model = ModelClass(vocab_size=len(vocab), num_classes=3)
    model.build_model()
    history = model.train(
        X_train, y_train, X_val, y_val, epochs=5, batch_size=64
    )
    # presumably one row of per-class scores per validation sample -- TODO
    # confirm; argmax over axis 1 picks the highest-scoring column per row.
    val_outputs = model.model.predict(X_val)
    y_pred = np.argmax(val_outputs, axis=1)
    result = metrics.calculate_classification_metrics(y_val, y_pred)
    print(f'{name} metrics:', result)
    results[name] = result


## 3. Visualize kết quả

In [ ]:
# Print the metrics collected for every trained model.
visualizer = DataVisualizer()  # created here but not used by the lines below
for name, result in results.items():
    print(f'{name} metrics:', result)
# A confusion-matrix visualization could optionally be added here.
