In [None]:
import os
import random
from PIL import Image
import torch
from torchvision import transforms
from torchvision.models import efficientnet_v2_s
import matplotlib.pyplot as plt

# --- Configurations ---
# Paths: edit these for your environment.
# NOTE(review): data_dir points at a Kaggle input mount while model_path points
# at a Colab-style /content location; both must resolve in the same runtime, so
# at least one of them will usually need changing.
data_dir = '/kaggle/input/weather-dataset/dataset'  # root folder, one sub-folder per category
model_path = '/content/efficientnet_v2_s_weather.pth'  # saved state_dict of the fine-tuned model

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Category names are the sub-directory names of data_dir, in sorted order; the
# model's output index i is interpreted as class_names[i] when predicting.
# NOTE(review): assumes training used the same sorted folder ordering (as
# torchvision's ImageFolder does) — confirm against the training code.
class_names = sorted(
    entry
    for entry in os.listdir(data_dir)
    if os.path.isdir(os.path.join(data_dir, entry))
)
num_classes = len(class_names)

# -- Load the model --
# Rebuild the EfficientNetV2-S architecture without downloading pretrained
# weights, then replace the final Linear layer of its classifier head so it
# produces one output per category (the shape the saved checkpoint expects).
model = efficientnet_v2_s(weights=None)
head_in_features = model.classifier[1].in_features
model.classifier[1] = torch.nn.Linear(head_in_features, num_classes)

# map_location remaps tensors saved on another device onto `device`.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust
# (newer torch versions also accept weights_only=True for state_dicts).
model.load_state_dict(torch.load(model_path, map_location=device))

# Move the weights to the target device and switch to evaluation mode
# (both calls return the module itself, so they can be chained).
model = model.to(device).eval()

# -- Image transform (should match your train/test pipeline) --
# Standard ImageNet per-channel (R, G, B) mean and standard deviation.
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

preprocess_steps = [
    transforms.Resize((224, 224)),  # force 224x224; aspect ratio is not preserved
    transforms.ToTensor(),          # PIL image -> float tensor scaled to [0, 1]
    transforms.Normalize(norm_mean, norm_std),
]
preprocess = transforms.Compose(preprocess_steps)

# -- Collect random images from all categories --
# Pair every .jpg/.jpeg/.png file (extension matched case-insensitively) found
# directly inside each category folder with that folder's category name.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')
all_paths = [
    (os.path.join(data_dir, category, fname), category)
    for category in class_names
    for fname in os.listdir(os.path.join(data_dir, category))
    if fname.lower().endswith(IMAGE_EXTENSIONS)
]

# Draw 7 distinct images at random; if fewer exist, keep all of them in the
# order they were collected.
SAMPLE_COUNT = 7
if len(all_paths) < SAMPLE_COUNT:
    chosen = all_paths
else:
    chosen = random.sample(all_paths, SAMPLE_COUNT)

# -- Predict and show images --
# Classify each sampled image and plot it with its true and predicted labels
# (green title = correct prediction, red = wrong).
n_shown = len(chosen)
n_cols = 4
# Derive the grid from the number of images instead of hard-coding 2x4, so a
# larger sample cannot overflow the subplot grid (7 images -> 2 rows, as before).
n_rows = max(1, (n_shown + n_cols - 1) // n_cols)

plt.figure(figsize=(18, 3 * n_rows))  # 18x6 for the default two rows
for idx, (img_path, true_cat) in enumerate(chosen):
    # Close the file handle immediately; convert() returns a new image that
    # remains usable after the source file is closed.
    with Image.open(img_path) as raw_img:
        img = raw_img.convert('RGB')
    inp = preprocess(img).unsqueeze(0).to(device)  # add a batch dimension of 1
    with torch.no_grad():
        out = model(inp)
        pred_idx = out.argmax(dim=1).item()  # index of the highest-scoring class
    pred_label = class_names[pred_idx]
    plt.subplot(n_rows, n_cols, idx + 1)
    plt.imshow(img)
    plt.title(
        f"True: {true_cat}\nPred: {pred_label}",
        color='green' if pred_label == true_cat else 'red',
    )
    plt.axis('off')
plt.tight_layout()
plt.show()
