In [None]:
import os
import numpy as np
import tensorflow as tf
from transformers import ViTFeatureExtractor, TFViTForImageClassification, create_optimizer
from datasets import load_dataset #type: ignore
import matplotlib.pyplot as plt

# Configuration
model_checkpoint = "Falconsai/nsfw_image_detection"  # pretrained checkpoint to fine-tune
batch_size = 16
epochs = 10
img_size = 224  # NOTE(review): assumed to match the feature extractor's output size -- confirm
data_dir = "./dataset/newDatasetSplit"

# Load dataset using Hugging Face
dataset = load_dataset("imagefolder", data_dir=data_dir)

# Get class names
class_names = dataset['train'].features['label'].names

# Derive the number of output classes from the data instead of hard-coding
# it, so the classification head can never disagree with the label ids the
# dataset actually produces.
num_classes = len(class_names)

# Preprocessing with Hugging Face feature extractor
feature_extractor = ViTFeatureExtractor.from_pretrained(model_checkpoint)

def preprocess(example):
    """Convert one dataset record into model-ready features.

    Args:
        example: A record holding an image under ``"image"`` (it must support
            ``.convert("RGB")``) and its class id under ``"label"``.

    Returns:
        The feature extractor's output, requested as NumPy arrays, with the
        record's label copied in under ``"label"``.
    """
    rgb_image = example["image"].convert("RGB")
    features = feature_extractor(images=rgb_image, return_tensors="np")
    features["label"] = example["label"]
    return features

# Run preprocess over the dataset; the raw "image" column is dropped because
# only the extracted features and the label are used from here on.
dataset = dataset.map(function=preprocess, remove_columns=["image"])

# Convert to tf.data.Dataset
def to_tf_dataset(split, shuffle=None):
    """Build a batched ``tf.data.Dataset`` of ``(inputs, label)`` pairs for one split.

    Args:
        split: Key of the split in the module-level ``dataset`` to stream from.
        shuffle: Whether to shuffle examples (buffer of 1000) before batching.
            ``None`` (the default) shuffles only the ``"train"`` split, since
            evaluation metrics do not depend on example order.

    Returns:
        A dataset whose elements are ``(inputs, labels)`` tuples, the structure
        Keras ``fit``/``evaluate`` expect when a loss is passed to ``compile``.
        ``inputs`` is ``{"pixel_values": float32}`` with per-example shape
        ``(3, img_size, img_size)``; ``labels`` is an int64 scalar per example.
        Batches hold ``batch_size`` examples and are prefetched.
    """
    if shuffle is None:
        shuffle = split == "train"

    def gen():
        for example in dataset[split]:
            # The stored pixel_values keep a leading axis of size 1 from the
            # per-image feature extractor call; index it away.
            yield {"pixel_values": example["pixel_values"][0]}, example["label"]

    ds = tf.data.Dataset.from_generator(
        gen,
        output_signature=(
            {
                # The ViT feature extractor emits channels-first arrays
                # (channels, height, width), which is also the layout the TF
                # ViT model takes for pixel_values. Images are converted to
                # RGB during preprocessing, hence 3 channels.
                "pixel_values": tf.TensorSpec(shape=(3, img_size, img_size), dtype=tf.float32),
            },
            tf.TensorSpec(shape=(), dtype=tf.int64),
        ),
    )
    if shuffle:
        ds = ds.shuffle(1000)
    return ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)

train_ds = to_tf_dataset("train")
# NOTE(review): the imagefolder builder is understood to map validation
# directory aliases such as "val"/"valid"/"dev" to a split named "validation",
# in which case dataset["val"] would raise KeyError. Use "val" when it exists
# and fall back to "validation" otherwise -- confirm against the actual splits.
val_split = "val" if "val" in dataset else "validation"
val_ds = to_tf_dataset(val_split)
test_ds = to_tf_dataset("test")

# Load model with updated classification head
# Record the human-readable class names in the model config so the saved
# checkpoint reports them instead of generic LABEL_<i> names. They are only
# passed when their count matches the head size, because the config rejects
# an id2label mapping whose length differs from num_labels.
label_maps = {}
if len(class_names) == num_classes:
    label_maps = {
        "id2label": dict(enumerate(class_names)),
        "label2id": {name: idx for idx, name in enumerate(class_names)},
    }

model = TFViTForImageClassification.from_pretrained(
    model_checkpoint,
    num_labels=num_classes,
    # Allows the classification head to be re-initialised when its size
    # differs from the one stored in the checkpoint.
    ignore_mismatched_sizes=True,
    from_pt=True,  # Use if the original model is PyTorch-based
    **label_maps,
)

# Compile model with optimizer, loss, and metrics
# train_ds is built with from_generator, so its cardinality is unknown and
# len(train_ds) raises TypeError. Compute the number of batches per epoch
# from the underlying split instead; the final partial batch is kept, hence
# the ceiling division.
steps_per_epoch = -(-len(dataset["train"]) // batch_size)
num_train_steps = steps_per_epoch * epochs
# create_optimizer returns (optimizer, schedule); only the optimizer is used.
optimizer, _ = create_optimizer(init_lr=3e-5, num_train_steps=num_train_steps, num_warmup_steps=0)

model.compile(
    optimizer=optimizer,
    # The model outputs raw logits and labels are integer class ids.
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"]
)

# Train the model
history = model.fit(train_ds, validation_data=val_ds, epochs=epochs)

# Save the model together with its preprocessing configuration, so the
# output directory can be reloaded on its own for inference.
output_dir = "./vit_finetuned_old_tamil"
model.save_pretrained(output_dir)
feature_extractor.save_pretrained(output_dir)

# Evaluate on test set; evaluate returns the loss followed by the compiled
# metrics, here just accuracy.
test_loss, test_acc = model.evaluate(test_ds)
print(f"Test Accuracy: {test_acc*100:.2f}%")

# Plot training history: loss on the left panel, accuracy on the right,
# each showing the training and validation curves per epoch.
panels = (
    # (train key, validation key, train legend, validation legend, title / y-axis label)
    ("loss", "val_loss", "Train Loss", "Val Loss", "Loss"),
    ("accuracy", "val_accuracy", "Train Acc", "Val Acc", "Accuracy"),
)

plt.figure(figsize=(10, 4))
for position, (train_key, val_key, train_label, val_label, title) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.plot(history.history[train_key], label=train_label)
    plt.plot(history.history[val_key], label=val_label)
    plt.title(title)
    plt.xlabel("Epoch")
    plt.ylabel(title)
    plt.legend()

plt.tight_layout()
plt.show()
