In [None]:
# ⬛ Step 1: Setup
!pip install ipywidgets --quiet

import torch
from torchvision import transforms
from PIL import Image
from IPython.display import display
import io
import ipywidgets as widgets
from models import BayesianMLP

# Run inference on the CUDA device when PyTorch reports one is available; otherwise use the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ⬛ Step 2: Load image and preprocess
def preprocess_image(image):
    """Convert a PIL image into a batch of one tensor for the classifier.

    The image is resized to exactly 224x224 and converted with
    ``transforms.ToTensor``; no normalization step is applied.
    NOTE(review): confirm this matches the preprocessing used at training time.

    Args:
        image: A PIL image (this notebook passes an RGB-converted image).

    Returns:
        The image tensor with an extra leading batch dimension of size 1.
    """
    resize_to_tensor = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ]
    )
    tensor = resize_to_tensor(image)
    # Prepend a batch axis of size 1.
    return tensor.unsqueeze(dim=0)

# ⬛ Step 3: Load trained model
def load_model(path="plusdiff_model.pt"):
    """Build a BayesianMLP, load its saved weights, and prepare it for inference.

    Args:
        path: Location of the saved ``state_dict``; defaults to
            ``"plusdiff_model.pt"`` relative to the working directory.

    Returns:
        The model, moved to ``device`` and switched to eval mode.
    """
    model = BayesianMLP()
    # weights_only=True limits unpickling to tensors and plain containers, so a
    # tampered checkpoint cannot run arbitrary code (needs torch >= 1.13).
    state_dict = torch.load(path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    # load_state_dict copies values into the module's existing parameters, so
    # map_location alone does not relocate the model. Move it explicitly, or
    # its weights stay on the CPU while inputs are sent to `device`.
    model.to(device)
    model.eval()
    return model

# ⬛ Step 4: Classify image
def classify_uploaded_image(image_bytes, model):
    """Classify an uploaded image as real or AI-generated and show the result.

    Prints the predicted label and displays the decoded image.

    Args:
        image_bytes: Encoded image data (any format PIL can open).
        model: The classifier (e.g. from ``load_model``). Class index 0 is
            reported as real; any other index is reported as AI-generated.

    Returns:
        The label string that was printed, so callers can reuse the result.

    Raises:
        PIL.UnidentifiedImageError: If ``image_bytes`` is not a readable image.
    """
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    batch = preprocess_image(image).to(device)
    # Only the forward pass needs gradient tracking disabled.
    with torch.no_grad():
        scores = model(batch)
    prediction = torch.argmax(scores, dim=1).item()
    label = "REAL 🐶" if prediction == 0 else "AI-GENERATED 🤖"
    print(f"Prediction: {label}")
    display(image)
    return label

# ⬛ Step 5: Upload widget
upload = widgets.FileUpload(accept='image/*', multiple=False)
# Output produced inside the upload callback is captured here so it renders
# beneath the upload button (plain print/display from widget callbacks is not
# reliably shown in the cell, e.g. in JupyterLab).
result_output = widgets.Output()
# The classifier is loaded on the first upload and reused afterwards instead
# of re-reading the checkpoint from disk for every file.
_cached_model = None


def _get_model():
    """Return the classifier, calling ``load_model`` only on first use."""
    global _cached_model
    if _cached_model is None:
        _cached_model = load_model()
    return _cached_model


def _first_upload_content(value):
    """Return the bytes of the first file in a FileUpload value, or None if empty.

    ipywidgets 7 stores ``value`` as a dict mapping filename -> info dict;
    ipywidgets 8 stores a tuple of info dicts. Both keep the file data under
    the ``'content'`` key (bytes in 7, a memoryview in 8).
    """
    entries = list(value.values()) if isinstance(value, dict) else list(value)
    if not entries:
        return None
    return bytes(entries[0]['content'])


def on_upload_change(change):
    """Classify the newly uploaded image whenever the widget's value changes."""
    content = _first_upload_content(change['new'])
    if content is None:
        return
    with result_output:
        classify_uploaded_image(content, _get_model())


upload.observe(on_upload_change, names='value')
display(upload, result_output)
