In [None]:
from ultralytics import YOLO

# Start from the small pretrained YOLO11 classification checkpoint and
# fine-tune it on the 6-class split dataset (train/val/test sub-folders,
# one directory per class, as expected by Ultralytics classification).
model = YOLO("yolo11s-cls.pt")
results = model.train(
    data="../split_dataset_6_classes/",
    epochs=100,
    imgsz=224,
)

# Echo the training results object as the cell output.
results

In [None]:
from ultralytics import YOLO
import matplotlib.pyplot as plt
import os
from PIL import Image
import math

# File extensions treated as test images. Matching is done on the lower-cased
# filename so e.g. "IMG_001.JPG" is not silently skipped.
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

# Load the trained YOLO classification model.
# NOTE(review): assumes training wrote to the default first run directory;
# repeated runs are saved as runs/classify/train2, train3, ... — confirm.
model = YOLO("runs/classify/train/weights/best.pt")

# Path to the test split: one sub-directory per class. The directory name is
# compared with the model's predicted label to judge correctness.
test_data_path = "../split_dataset_6_classes/test"
output_folder = "prediction_plots_test_set"  # Folder to save plots
os.makedirs(output_folder, exist_ok=True)  # Create folder if it doesn't exist

# Sort so the class order (and the order of images in each grid) is
# reproducible; os.listdir returns entries in arbitrary order.
for class_name in sorted(os.listdir(test_data_path)):
    class_path = os.path.join(test_data_path, class_name)
    if not os.path.isdir(class_path):
        continue  # Skip non-directory entries

    # Collect all images in the current class directory.
    images = [
        os.path.join(class_path, fname)
        for fname in sorted(os.listdir(class_path))
        if fname.lower().endswith(IMAGE_EXTENSIONS)
    ]
    if not images:
        # An empty class would yield a 0-row grid and a zero-height figure.
        print(f"Skipping class '{class_name}': no images found")
        continue

    # Determine grid size (fixed 4 columns, as many rows as needed).
    num_images = len(images)
    num_cols = 4
    num_rows = math.ceil(num_images / num_cols)

    # One figure per class.
    fig = plt.figure(figsize=(15, 5 * num_rows))
    fig.suptitle(f"Class: {class_name}", fontsize=16)

    for idx, img_path in enumerate(images):
        # Open inside a context manager so the file handle is released.
        # convert() returns an independent, fully loaded RGB copy; Ultralytics
        # converts PIL inputs to RGB anyway, and imshow then shows grayscale
        # images in gray rather than with a false-colour colormap.
        with Image.open(img_path) as src:
            img = src.convert("RGB")

        pred = model.predict(source=img, imgsz=224, verbose=False)

        # Top-1 predicted class index, its confidence, and its name.
        probs = pred[0].probs
        pred_class = probs.top1
        pred_conf = probs.top1conf.item()
        pred_label = model.names[pred_class]

        # Green title for a correct prediction, red for a wrong one.
        title_color = "green" if pred_label == class_name else "red"

        plt.subplot(num_rows, num_cols, idx + 1)
        plt.imshow(img)
        plt.title(f"Pred: {pred_label} ({pred_conf:.2f})", fontsize=10, color=title_color)
        plt.axis('off')

    # Leave the top 5% of the figure free for the suptitle, then save and show.
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    output_path = os.path.join(output_folder, f"{class_name}_plot.png")
    fig.savefig(output_path)
    plt.show()
    plt.close(fig)  # Release this figure's memory before the next class