In [14]:
# 📦 Setup
import torch
from torchvision import transforms
from modular import model_builder
from PIL import Image
import requests
from io import BytesIO
from pathlib import Path
import matplotlib.pyplot as plt


In [15]:
# 🔧 Config
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMAGE_SIZE = (224, 224)  # Change if your model expects something else
CLASS_NAMES = [
    'A&B50', 'A&C&B10', 'A&C&B30', 'A&C10', 'A&C30',
    'A10', 'A30', 'A50', 'Fan', 'Noload', 'Rotor-0']
MODEL_PATH = Path("models/Resnet18_RetrainedV2.pth")


In [16]:
model = model_builder.Resnet18(embedding_dim=256)

# Load checkpoint properly
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)

# If the checkpoint is a dictionary (which it is), extract the state_dict
model.load_state_dict(checkpoint["model_state_dict"])

model = model.to(DEVICE)
model.eval()

# 🔁 Transforms (must match training transforms)
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])  # Adjust if RGB or different stats
])


In [19]:
def predict_image_from_url(image_url):
    try:
        response = requests.get(image_url)
        image = Image.open(BytesIO(response.content)).convert("RGB")
    except Exception as e:
        print(f"Error loading image: {e}")
        return

    input_tensor = transform(image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        outputs = model(input_tensor)
        probs = torch.softmax(outputs, dim=1)
        pred_idx = torch.argmax(probs, dim=1).item()

        # Check if pred_idx is valid
        if pred_idx >= len(CLASS_NAMES):
            print(f"Invalid prediction index: {pred_idx}")
            return

        pred_class = CLASS_NAMES[pred_idx]
        confidence = probs[0][pred_idx].item()

    plt.imshow(image)
    plt.axis('off')
    plt.title(f"{pred_class} ({confidence*100:.2f}%)")
    plt.show()

    print(f"🔮 Predicted class: {pred_class} ({confidence*100:.2f}% confidence)")

In [20]:
predict_image_from_url("https://theengineeringmindset.com/wp-content/uploads/2019/05/Induction-motor-heat-thermal-image.png")

Invalid prediction index: 170


In [23]:
dummy_input = torch.randn(1, 3, 128, 128).to(DEVICE)
print("Output shape:", model(dummy_input).shape)


Output shape: torch.Size([1, 256])
