# In [None]:
import tensorflow as tf
from tensorflow.keras.applications import ResNet50V2
from tensorflow.keras.applications.resnet_v2 import preprocess_input
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, GlobalAveragePooling2D, Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
import torch

# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
num_classes = 10  # number of distinct class labels in the targets
epochs = 5  # number of passes over the data in model.fit below

# Declared properties of the dataset file. Neither value is read anywhere
# below in this cell; they serve as documentation only.
num_samples = 2500  # number of images
image_shape = (32, 32, 3)  # shape of one stored image

# NOTE(review): torch.load deserializes with pickle (how much it allows depends
# on the installed torch version's `weights_only` default), so only point it at
# trusted files.
data = torch.load(
    '../dataset/dataset/part_one_dataset/train_data/1_train_data.tar.pth'
)

# Unpack the loaded dataset: images are stored under "data" and integer class
# labels under "targets" (these are real data from disk, not synthetic).
# NOTE(review): assumes images are channels-last (N, 32, 32, 3) with pixel
# values in [0, 255], matching image_shape and what ResNet50V2's
# preprocess_input expects. Confirm against the dataset file.
images = data['data']
targets = data['targets']

# One-hot encode the integer labels so they match the categorical_crossentropy
# loss the model is compiled with.
one_hot_labels = tf.keras.utils.to_categorical(targets, num_classes=num_classes)

# Upsample every image to ResNet50V2's 224x224 input size.
# NOTE: this materializes the whole set at once (bilinear resize yields float32),
# about 1.5 GB for 2500 images, and preprocess_input below makes a second
# tensor of the same size. Consider a batched tf.data pipeline if memory is tight.
resized_images = tf.image.resize(images, (224, 224))

# ResNet50V2's preprocess_input rescales pixel values from [0, 255] to [-1, 1].
processed_images = preprocess_input(resized_images)

# ---------------------------------------------------------------------------
# Model: ImageNet-pretrained ResNet50V2 backbone (frozen) + softmax head
# ---------------------------------------------------------------------------
base_model = ResNet50V2(
    include_top=False,
    weights="imagenet",
    input_shape=(224, 224, 3),
)
# Freeze every backbone weight; the Dense layer below is the only trainable part.
base_model.trainable = False

inputs = Input(shape=(224, 224, 3))
# training=False runs the backbone in inference mode (e.g. BatchNormalization
# uses its stored statistics) even while the head is being fit.
x = base_model(inputs, training=False)
# Collapse the spatial feature map into one 2048-d vector per image.
x = GlobalAveragePooling2D()(x)
outputs = Dense(num_classes, activation="softmax")(x)
model = Model(inputs=inputs, outputs=outputs)

model.compile(
    loss="categorical_crossentropy",
    optimizer=Adam(learning_rate=0.001),
    metrics=["accuracy"],
)

# Both callbacks watch the training loss (there is no validation split).
# NOTE(review): neither patience (10, 5) is smaller than `epochs` (5 as
# configured above), so neither callback can ever fire during the fit. Confirm
# the intended epoch count or patience values.
early_stop_cb = EarlyStopping(monitor="loss", patience=10, restore_best_weights=True)
reduce_lr_cb = ReduceLROnPlateau(monitor="loss", factor=0.1, patience=5, min_lr=1e-6)
callbacks = [early_stop_cb, reduce_lr_cb]

# Fit the softmax head on the preprocessed images. The backbone is frozen, so
# only the Dense layer's weights change. The returned History object keeps the
# per-epoch loss/accuracy for later inspection.
print("\nTraining the model:")
history = model.fit(
    processed_images,
    one_hot_labels,
    epochs=epochs,
    batch_size=32,
    callbacks=callbacks,
    verbose=1,
)

# Build a stand-alone extractor over the ResNet50V2 backbone.
# NOTE: base_model was frozen before fit, so the training above does not modify
# it. These are the pretrained ImageNet features, independent of the head.
print("\nExtracting features using the base model:")
feature_extractor = Model(inputs=base_model.input, outputs=base_model.output)

# Run the backbone over every image. The output is the final convolutional
# feature map, without pooling.
features = feature_extractor.predict(processed_images, batch_size=32)

# For 224x224 inputs each image yields a 7x7x2048 map, e.g. (2500, 7, 7, 2048)
# for 2500 images (about 1 GB as float32).
print(f"Extracted features shape: {features.shape}")