# In [None]:
import io
import logging

import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from PIL import Image
from torchvision import transforms

# Load your PyTorch model.
# The checkpoint is expected to hold a whole pickled nn.Module (not just a
# state_dict), because .eval() is called on the loaded object directly.
model_path = "resnet18_finetuned.pth"  # Replace with your model's file path
# map_location="cpu": the request handler builds its input tensors on the CPU,
# so the weights must live there too. Without it, a checkpoint saved from a GPU
# fails to load on a CPU-only host, or loads onto the GPU and then hits a
# device mismatch at inference time.
# SECURITY NOTE(review): torch.load unpickles arbitrary objects -- only load
# checkpoints from trusted sources. On newer PyTorch releases where
# weights_only defaults to True, loading a full pickled module may also need
# weights_only=False; confirm against the installed torch version.
model = torch.load(model_path, map_location="cpu")
model.eval()  # Switch dropout / batch-norm layers to inference behaviour

# Preprocessing applied to every uploaded image before inference:
# resize to a fixed spatial size, convert the PIL image to a float tensor
# scaled to [0, 1], then normalise each channel with the widely used
# ImageNet mean/std values.
_preprocess_steps = [
    transforms.Resize((224, 224)),  # Replace with your model's input size
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_preprocess_steps)

# Initialize FastAPI app
app = FastAPI()
logger = logging.getLogger(__name__)


def _load_image(content: bytes) -> Image.Image:
    """Decode raw upload bytes into an RGB PIL image.

    Raises:
        OSError: if Pillow cannot identify or fully decode the data
            (``PIL.UnidentifiedImageError`` is an ``OSError`` subclass).
        PIL.Image.DecompressionBombError: if the image exceeds Pillow's
            pixel-count safety limit.
    """
    # Image.open() is lazy; convert() forces the actual pixel decode, so a
    # truncated/corrupt file fails here rather than later inside the transform.
    return Image.open(io.BytesIO(content)).convert("RGB")


def _classify(image: Image.Image) -> int:
    """Run the model on a single image and return the arg-max class index."""
    input_tensor = transform(image).unsqueeze(0)  # Add batch dimension
    with torch.no_grad():
        output = model(input_tensor)
    return torch.argmax(output, dim=1).item()


@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image.

    Returns:
        ``{"predicted_class": int}`` on success. Otherwise a JSON
        ``{"error": ...}`` body with status 400 when the upload is not a
        decodable image, or status 500 for any other failure (details are
        logged server-side rather than returned to the client).
    """
    try:
        content = await file.read()
        try:
            # Decoding and inference are synchronous CPU work; running them in
            # the threadpool keeps this coroutine from blocking the event loop
            # (and with it every other in-flight request).
            image = await run_in_threadpool(_load_image, content)
        except (OSError, Image.DecompressionBombError):
            # Bad input from the client, not a server fault.
            return JSONResponse(
                status_code=400,
                content={"error": "Uploaded file is not a valid image"},
            )
        predicted_class = await run_in_threadpool(_classify, image)
    except Exception:
        # Top-level request boundary: keep the full traceback in the server
        # log, but do not echo internal exception text to the client.
        logger.exception("Prediction failed for upload %r", file.filename)
        return JSONResponse(status_code=500, content={"error": "Prediction failed"})

    # Return prediction result
    return {"predicted_class": predicted_class}
