In [None]:
# Mount Google Drive at /content/drive so the trained artifacts stored there are reachable.
from google.colab import drive
drive.mount('/content/drive')

# Copy the model checkpoint and the precomputed support embeddings into the current
# working directory ("."), where later cells load them by relative filename.
!cp "/content/drive/MyDrive/FSL/Resnet18_RetrainedV2.pth" .
!cp "/content/drive/MyDrive/FSL/support_embeddings.pt" .


: 

In [None]:
import io
import os

import ipywidgets as widgets
import numpy as np
import torch
import torchvision
from google.colab import files
from IPython.display import clear_output, display
from PIL import Image, UnidentifiedImageError
from torchvision import transforms

In [None]:
# Inference runs on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Trained checkpoint, expected in the current working directory.
MODEL_PATH = "Resnet18_RetrainedV2.pth"

# Human-readable labels, indexed by the prediction index returned from the
# nearest-support search. The order presumably matches the class order used
# when the support embeddings were built — verify before reordering.
CLASS_NAMES = [
    'A&B50',
    'A&C&B10',
    'A&C&B30',
    'A&C10',
    'A&C30',
    'A10',
    'A30',
    'A50',
    'Fan',
    'Noload',
    'Rotor-0',
]

In [None]:
# Report which device inference will run on ("cuda" or "cpu").
print(f"{device}")

cuda


In [None]:
class Resnet18(torch.nn.Module):
    """ResNet-18 backbone plus a projection head that outputs L2-normalised embeddings.

    Args:
        embedding_dim: Size of each output embedding vector.
        pretrained: If True (the default, matching the original behaviour), the
            backbone is initialised from torchvision's ``IMAGENET1K_V1`` weights.
            Pass False when the weights will be replaced by a checkpoint right
            away, which avoids downloading the ImageNet weights.
    """

    def __init__(self, embedding_dim=128, pretrained=True):
        super().__init__()
        weights = torchvision.models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
        backbone = torchvision.models.resnet18(weights=weights)

        # Keep every child module except the final classifier (``fc``).
        self.encoder = torch.nn.Sequential(*list(backbone.children())[:-1])

        # The submodule order (Dropout, Linear, BatchNorm1d) fixes the state-dict
        # keys ("embedding.1.weight", "embedding.2.running_mean", ...); changing it
        # would break loading of existing checkpoints.
        self.embedding = torch.nn.Sequential(
            torch.nn.Dropout(p=0.2, inplace=True),
            # Take the feature width from the backbone instead of hard-coding 512.
            torch.nn.Linear(in_features=backbone.fc.in_features, out_features=embedding_dim, bias=True),
            torch.nn.BatchNorm1d(embedding_dim),
        )

        # Freeze the encoder except for its last four parameter tensors, so only the
        # tail of the backbone (and the projection head) is trainable.
        for param in list(self.encoder.parameters())[:-4]:
            param.requires_grad = False

    def forward(self, x):
        """Embed an image batch ``x`` and L2-normalise each row (dim=1) to unit length."""
        features = torch.flatten(self.encoder(x), 1)
        return torch.nn.functional.normalize(self.embedding(features), p=2, dim=1)


In [None]:
def load_model(path=None, embedding_dim=256):
    """Build the embedding network and load trained weights from a checkpoint.

    Args:
        path: Checkpoint file to load. ``None`` (the default) means ``MODEL_PATH``,
            looked up at call time so later changes to the constant are honoured.
        embedding_dim: Embedding size the checkpoint was trained with; the default
            of 256 is the value this function previously hard-coded.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    Raises:
        KeyError: if the loaded checkpoint dict has no ``"model_state_dict"`` entry.
    """
    checkpoint_path = MODEL_PATH if path is None else path
    model = Resnet18(embedding_dim=embedding_dim)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only runtime.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    # eval() disables dropout and makes BatchNorm use its running statistics.
    model.to(device).eval()
    return model

In [None]:
# %%
def predict(image_path, model, k_shot=5, support_path="support_embeddings.pt", temperature=0.1):
    """Classify one image by its nearest precomputed support embedding.

    The query image is embedded with ``model`` and compared by Euclidean distance
    against every row of the support-embedding tensor. A temperature-scaled
    softmax over the negative distances gives one probability per support row,
    and the closest row wins.

    Args:
        image_path: Filename or file object accepted by ``PIL.Image.open``.
        model: Embedding network; expected to already be on ``device`` and in
            eval mode (``load_model()`` does both).
        k_shot: Currently unused; kept so existing calls keep working.
        support_path: File holding the support embeddings saved with ``torch.save``.
            Row ``i`` must correspond to ``CLASS_NAMES[i]``.
        temperature: Softmax temperature; smaller values sharpen the distribution.

    Returns:
        ``(class_name, confidence)``, where ``confidence`` is the softmax
        probability assigned to the winning support row.

    Raises:
        ValueError: if the support tensor is not 2-D with one row per class.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # The single mean/std value is broadcast over all three RGB channels; it
        # must match the normalisation the model was trained with.
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ])

    # Context manager releases the file handle once the pixels are decoded.
    with Image.open(image_path) as raw_image:
        img = transform(raw_image.convert("RGB")).unsqueeze(0)

    # map_location keeps a GPU-saved tensor loadable on a CPU-only runtime.
    support_embeds = torch.load(support_path, map_location=device)
    if support_embeds.ndim != 2 or support_embeds.shape[0] != len(CLASS_NAMES):
        raise ValueError(
            f"expected support embeddings of shape ({len(CLASS_NAMES)}, D) "
            f"(one row per class), got {tuple(support_embeds.shape)}"
        )

    with torch.no_grad():
        query_embed = model(img.to(device))

        # cdist needs both operands on the same device and dtype; align the
        # support tensor to the query explicitly.
        dists = torch.cdist(query_embed, support_embeds.to(query_embed))
        probs = torch.softmax(-dists / temperature, dim=1)

        # Batch size is 1, so row 0 holds the only query's distribution.
        pred_idx = torch.argmax(probs[0]).item()
        confidence = probs[0, pred_idx].item()

    return CLASS_NAMES[pred_idx], confidence


In [None]:
 # %%
# Prediction UI
model = load_model()

upload_btn = widgets.FileUpload(description="Upload Image")
predict_btn = widgets.Button(description="Predict")
output = widgets.Output()


def _first_upload_bytes(value):
    """Return the raw bytes of the first file held in a ``FileUpload.value``.

    ipywidgets 7 stores uploads as ``{filename: {"content": bytes, ...}}``, while
    ipywidgets 8 stores a tuple of dicts whose ``"content"`` is a memoryview.
    Both layouts are accepted.
    """
    if isinstance(value, dict):
        entry = next(iter(value.values()))
    else:
        entry = value[0]
    return bytes(entry["content"])


def on_predict_click(b):
    """Button callback: classify the first uploaded image and print the result into ``output``."""
    with output:
        clear_output()
        if not upload_btn.value:
            print("Please upload an image first")
            return

        image_bytes = _first_upload_bytes(upload_btn.value)
        try:
            # PIL.Image.open accepts file objects, so decode straight from memory
            # instead of writing a temporary file to the working directory.
            pred, conf = predict(io.BytesIO(image_bytes), model)
        except UnidentifiedImageError:
            print("The uploaded file could not be read as an image")
            return
        print(f"Predicted: {pred} ({conf:.1%} confidence)")


predict_btn.on_click(on_predict_click)
display(upload_btn, predict_btn, output)

FileUpload(value={}, description='Upload Image')

Button(description='Predict', style=ButtonStyle())

Output()