In [8]:
from PIL import Image
import requests
from transformers import AutoImageProcessor, AutoModelForImageClassification

# Fetch (or reuse from the local Hugging Face cache) the preprocessing
# pipeline and the image-classification model of the same checkpoint.
# They live at module level so every get_prediction call reuses them
# instead of reloading the weights per image.
processor, model = (
    AutoImageProcessor.from_pretrained("Organika/sdxl-detector"),
    AutoModelForImageClassification.from_pretrained("Organika/sdxl-detector"),
)

def get_prediction(image_url, convert_to_label=False, timeout=30):
    """Predict whether the image at *image_url* is AI-generated or real.

    The image is downloaded, converted to RGB, preprocessed with the
    module-level ``processor`` and classified by the module-level ``model``.

    Args:
        image_url: URL of the image to classify.
        convert_to_label: If True, return a string label instead of the
            raw class index.
        timeout: Seconds to wait for the HTTP server (connect/read) before
            giving up; forwarded to ``requests.get``.

    Returns:
        ``"fake"`` (class 0) or ``"reliable"`` (class 1) when
        ``convert_to_label`` is True; otherwise the integer class index
        (0 or 1).

    Raises:
        requests.RequestException: On network failure, timeout, or an HTTP
            error status from the server.
        PIL.UnidentifiedImageError: If the downloaded body is not an image
            PIL can decode.
    """
    import io

    import torch

    # Download the whole body and close the connection deterministically.
    # ``response.content`` also undoes gzip/deflate transfer encoding, which
    # reading the raw socket stream does not.
    with requests.get(image_url, timeout=timeout) as response:
        # Fail with a clear HTTP error (404, 500, ...) rather than a
        # confusing "cannot identify image file" from PIL.
        response.raise_for_status()
        image = Image.open(io.BytesIO(response.content)).convert("RGB")

    # Put the tensors on the same device as the model (CPU unless the
    # model has been moved elsewhere).
    inputs = processor(images=image, return_tensors="pt").to(model.device)

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(**inputs).logits

    # A single image was processed, so take row 0. Softmax is monotonic,
    # so argmax over the logits equals argmax over the probabilities.
    class_index = int(logits[0].argmax())

    if not convert_to_label:
        return class_index

    # NOTE(review): assumes the checkpoint's class order is
    # 0 = AI-generated, 1 = real; confirm against model.config.id2label.
    labels = {0: "fake", 1: "reliable"}
    return labels[class_index]


# Demo: classify one sample photo and print the resulting label.
# Point image_url at any other publicly reachable image to try your own.
image_url = (
    "https://images.unsplash.com/photo-1580128660010-fd027e1e587a"
    "?fm=jpg&q=60&w=3000&ixlib=rb-4.0.3"
    "&ixid=M3wxMjA3fDB8MHxzZWFyY2h8Mnx8ZG9uYWxkJTIwdHJ1bXB8ZW58MHx8MHx8fDA%3D"
)
prediction = get_prediction(image_url, convert_to_label=True)
print(prediction)

reliable
