In [2]:
import torch
import torchvision.models as models
import torch.nn as nn
import numpy as np
from sklearn.svm import SVC
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.metrics import (accuracy_score, precision_recall_fscore_support,
                             confusion_matrix as sklearn_confusion_matrix, classification_report)
import time
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

from data_loader import get_cifar100_loaders
from evaluation_utils import get_cifar100_class_names


In [5]:
def plot_sklearn_confusion_matrix(cm, class_names, figsize=(20,20), filename='svm_confusion_matrix.png', normalize=False):
    if normalize:
        cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        cm_to_plot = cm_normalized
        fmt = '.2f'
        title = 'Normalized Confusion Matrix (SVM)'
    else:
        cm_to_plot = cm
        fmt = 'd'
        title = 'Confusion Matrix, without normalization (SVM)'

    plt.style.use('seaborn-v0_8-whitegrid')
    plt.figure(figsize=figsize)
    sns.heatmap(cm_to_plot, annot=False, fmt=fmt, cmap='Blues',
                xticklabels=class_names, yticklabels=class_names) # annot=False for 100 classes
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.xticks(rotation=90)
    plt.yticks(rotation=0)
    plt.title(title)
    plt.tight_layout()
    try:
        plt.savefig(filename)
        print(f"Confusion matrix saved to {filename}")
    except Exception as e:
        print(f"Error saving SVM confusion matrix: {e}")
    plt.show()


In [7]:
def extract_features(data_loader, model, device):
    model.eval()
    features_list = []
    labels_list = []
    print("Extracting features...")
    with torch.no_grad():
        for inputs, labels in tqdm(data_loader):
            inputs = inputs.to(device)
            outputs = model(inputs)
            features_list.append(outputs.cpu().numpy())
            labels_list.append(labels.cpu().numpy())

    features = np.concatenate(features_list, axis=0)
    labels = np.concatenate(labels_list, axis=0)
    return features, labels


In [8]:
BATCH_SIZE = 128
IMG_SIZE = 224
DATA_DIR = './data_cifar100'
NUM_WORKERS = 4
FEATURE_EXTRACTOR_MODEL_NAME = "resnet18"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Sử dụng thiết bị: {device} for feature extraction")

class_names = get_cifar100_class_names(DATA_DIR)
num_classes = len(class_names)

print(f"Đang tải dữ liệu CIFAR-100 (resize: {IMG_SIZE}x{IMG_SIZE})...")
train_loader, test_loader, _ = get_cifar100_loaders(
    batch_size=BATCH_SIZE,
    data_dir=DATA_DIR,
    img_size=IMG_SIZE,
    use_augmentation=True,
    num_workers=NUM_WORKERS
)
if train_loader is None:
    print("Không thể tải dữ liệu. Kết thúc chương trình.")
    exit()

# --- 2. Load Pre-trained Model as Feature Extractor ---
print(f"Loading {FEATURE_EXTRACTOR_MODEL_NAME} as feature extractor...")
if FEATURE_EXTRACTOR_MODEL_NAME == "resnet18":
    weights = models.ResNet18_Weights.IMAGENET1K_V1
    feature_extractor = models.resnet18(weights=weights)
    # Remove the final fully connected layer (the classifier)
    feature_extractor.fc = nn.Identity() # Replace with an identity layer
elif FEATURE_EXTRACTOR_MODEL_NAME == "resnet50":
    weights = models.ResNet50_Weights.IMAGENET1K_V2 # Or V1
    feature_extractor = models.resnet50(weights=weights)
    feature_extractor.fc = nn.Identity()
else:
    raise ValueError(f"Unsupported feature extractor: {FEATURE_EXTRACTOR_MODEL_NAME}")

feature_extractor.to(device)
feature_extractor.eval() # Set to evaluation mode

# --- 3. Extract Features ---
start_time = time.time()
train_features, train_labels = extract_features(train_loader, feature_extractor, device)
test_features, test_labels = extract_features(test_loader, feature_extractor, device)
extraction_time = time.time() - start_time
print(f"Feature extraction completed. Time taken: {extraction_time:.2f}s")
print(f"Shape of training features: {train_features.shape}") # Should be (50000, num_output_features_from_cnn)
print(f"Shape of test features: {test_features.shape}")     # Should be (10000, num_output_features_from_cnn)

# --- 4. Train SVM Classifier ---
print("Training SVM classifier...")
start_time = time.time()

svm_pipeline = make_pipeline(
    StandardScaler(),
    SVC(kernel='rbf', C=1.0, gamma='scale', decision_function_shape='ovr', random_state=42, verbose=True)
)

svm_pipeline.fit(train_features, train_labels)
training_time = time.time() - start_time
print(f"SVM training completed. Time taken: {training_time:.2f}s")

# --- 5. Evaluate SVM ---
print("Evaluating SVM on the test set...")
start_time = time.time()
test_predictions = svm_pipeline.predict(test_features)
evaluation_time = time.time() - start_time
print(f"SVM evaluation completed. Time taken: {evaluation_time:.2f}s")

# Calculate metrics
accuracy = accuracy_score(test_labels, test_predictions)
# For Top-5 with SVM, we'd need decision_function scores and then sort, which is more complex
# For now, focusing on Top-1 as SVM doesn't directly give Top-N probabilities like NNs.

precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
    test_labels, test_predictions, average='macro', zero_division=0
)
precision_weighted, recall_weighted, f1_weighted, _ = precision_recall_fscore_support(
    test_labels, test_predictions, average='weighted', zero_division=0
)

print("\nSVM Classification Report:")
print(classification_report(test_labels, test_predictions, target_names=class_names, zero_division=0))

print(f"\nSVM Evaluation Summary:")
print(f"  Top-1 Accuracy: {accuracy:.4f}")
print(f"  Macro Avg - Precision: {precision_macro:.4f}, Recall: {recall_macro:.4f}, F1-Score: {f1_macro:.4f}")
print(f"  Weighted Avg - Precision: {precision_weighted:.4f}, Recall: {recall_weighted:.4f}, F1-Score: {f1_weighted:.4f}")

# Confusion Matrix
cm = sklearn_confusion_matrix(test_labels, test_predictions, labels=range(num_classes))
plot_sklearn_confusion_matrix(cm, class_names,
                                filename=f'svm_{FEATURE_EXTRACTOR_MODEL_NAME}_features_cm.png',
                                fiCgsize=(25,25)) # Adjust size if needed

Sử dụng thiết bị: cpu for feature extraction
Files already downloaded and verified
Đang tải dữ liệu CIFAR-100 (resize: 224x224)...
Files already downloaded and verified
Files already downloaded and verified
Đã tải xong CIFAR-100.
Số lượng ảnh Train: 50000
Số lượng ảnh Test: 10000
Kích thước ảnh: 224x224
Sử dụng Data Augmentation: True
Loading resnet18 as feature extractor...
Extracting features...


  0%|          | 0/391 [00:00<?, ?it/s]Traceback (most recent call last):
  File "<string>", line 1, in <module>
  0%|          | 0/391 [00:08<?, ?it/s]  File "/opt/anaconda3/lib/python3.12/multiprocessing/spawn.py", line 122, in spawn_main
    exitcode = _main(fd, parent_sentinel)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/lib/python3.12/multiprocessing/spawn.py", line 132, in _main
    self = reduction.pickle.load(from_parent)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/lib/python3.12/site-packages/torchvision/__init__.py", line 6, in <module>
    from torchvision import _meta_registrations, datasets, io, models, ops, transforms, utils
  File "/opt/anaconda3/lib/python3.12/site-packages/torchvision/models/__init__.py", line 2, in <module>
    from .convnext import *
  File "/opt/anaconda3/lib/python3.12/site-packages/torchvision/models/convnext.py", line 8, in <module>
    from ..ops.misc import Conv2dNormActivation, Permute
  File "/opt

KeyboardInterrupt: 