In [None]:
!pip install ipywidgets --quiet

import io

import ipywidgets as widgets
import torch
from IPython.display import display
from torch.utils.data import DataLoader

from models import BayesianMLP

# Compute device for the whole notebook: the GPU when PyTorch can see one, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [None]:
def load_model(model_path="plusdiff_model.pt"):
    """Build a BayesianMLP, load trained weights into it and prepare it for inference.

    Args:
        model_path: Path to the saved weights. The loaded object is passed
            straight to ``load_state_dict``, so it must be a state_dict
            (presumably written with ``torch.save(model.state_dict(), ...)``).

    Returns:
        The model, moved to the notebook-wide ``device`` and set to eval mode.
    """
    model = BayesianMLP()
    # weights_only=True refuses to unpickle arbitrary Python objects; a state_dict
    # holds only tensors, so nothing legitimate is lost.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    # load_state_dict copies values into the existing parameters without moving
    # them, so the module must be moved explicitly; otherwise it stays on the CPU
    # even when `device` is CUDA.
    model.to(device)
    model.eval()
    return model


In [None]:
def classify_feature_tensor(tensor, model, threshold=0.5):
    """Classify every sample in a feature tensor as AI-generated or real and print the result.

    Args:
        tensor: Input features, presumably shaped (batch, n_features) — TODO
            confirm against BayesianMLP's expected input. It is moved to the
            model's device and cast to the model's dtype before inference.
        model: Module (or callable) presumably returning one logit per sample.
        threshold: Probability at or above which a sample is labelled
            AI-generated. Defaults to 0.5, the original cut-off.

    Returns:
        list[float]: The sigmoid probability of every element of the model
        output, in order (one per sample when the model emits one logit per row).
    """
    # Align the input with the model's own weights rather than the global
    # `device`: the model may not have been moved there, and uploaded tensors
    # are often float64 while the weights are float32.
    reference = next(model.parameters(), None) if hasattr(model, "parameters") else None
    if reference is not None:
        tensor = tensor.to(device=reference.device, dtype=reference.dtype)
    else:
        tensor = tensor.to(device)

    with torch.no_grad():
        output = model(tensor)
        # Flatten so batches of any size work (`.item()` only accepts one element).
        probs = torch.sigmoid(output).reshape(-1).tolist()

    for prob in probs:
        prediction = "AI-GENERATED 🤖" if prob >= threshold else "REAL 🐶"
        print(f"Prediction: {prediction}  |  Probability: {prob:.2f}")
    return probs


In [None]:
upload = widgets.FileUpload(accept='.pt', multiple=False)
# Output printed from a widget callback is not rendered under the cell in
# JupyterLab unless it is captured by an Output widget.
upload_output = widgets.Output()


def _uploaded_bytes(value):
    """Return the raw bytes of the single uploaded file, or None when nothing is uploaded.

    Supports both FileUpload value formats: ipywidgets 7 uses a dict
    ``{filename: {"content": bytes, ...}}``, while ipywidgets 8 uses a tuple of
    dicts whose ``"content"`` is a memoryview.
    """
    if not value:
        return None
    entries = value.values() if isinstance(value, dict) else value
    file_info = next(iter(entries))
    return bytes(file_info['content'])


def on_upload_change(change):
    """Load the tensor from a newly uploaded ``.pt`` file and classify it."""
    content = _uploaded_bytes(change['new'])
    if content is None:
        return

    with upload_output:
        upload_output.clear_output()
        # The upload is untrusted: weights_only=True blocks arbitrary pickled
        # objects. Loading onto the CPU lets CUDA-saved tensors open anywhere,
        # and the bytes are read in memory instead of via a temp file on disk.
        tensor = torch.load(io.BytesIO(content), map_location="cpu", weights_only=True)
        if not isinstance(tensor, torch.Tensor):
            print(f"Expected a tensor in the uploaded file, got {type(tensor).__name__}.")
            return
        if tensor.dim() == 1:
            tensor = tensor.unsqueeze(0)  # Add batch dimension if needed

        # Load model and classify
        model = load_model()
        classify_feature_tensor(tensor, model)


upload.observe(on_upload_change, names='value')
display(upload, upload_output)
