# In [10]:
import torch
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.transforms.functional import to_tensor
from torchvision.transforms.functional import to_pil_image
from PIL import Image, ImageDraw
import os


# Prefer the second GPU when the host exposes one; fall back to the first GPU
# on single-GPU hosts (where "cuda:1" is an invalid ordinal), then to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

# Load model with custom number of classes
def get_model(num_classes, model_path):
    """Build a Faster R-CNN detector and load checkpoint weights into it.

    Args:
        num_classes: Number of classes passed to the replacement
            ``FastRCNNPredictor`` head.
        model_path: Path to a file that ``torch.load`` turns into a state
            dict accepted by ``model.load_state_dict``.

    Returns:
        The model, moved to the module-level ``device`` and put in eval mode.
    """
    model = fasterrcnn_resnet50_fpn(weights="DEFAULT", box_detections_per_img=400)
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # Size the new head from the argument so the parameter is actually honoured
    # (a literal 2 here would silently ignore what the caller asked for).
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes=num_classes)
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device).eval()
    return model

# Load an image and convert to tensor
def load_image(image_path):
    """Read an image file and return it as a tensor.

    Args:
        image_path: Path of the image file to open.

    Returns:
        The result of ``to_tensor`` applied to the image after it has been
        converted to RGB.
    """
    # The context manager closes the underlying file deterministically;
    # convert() yields an independent image that stays valid afterwards.
    with Image.open(image_path) as source:
        rgb_image = source.convert("RGB")
    return to_tensor(rgb_image)

# Perform detection and return image with annotations
def detect_and_annotate(model, image_tensor, confidence_threshold=0.3):
    """Run the detector on one image and draw the confident detections.

    Args:
        model: Detection model. It is moved to the module-level ``device`` and
            called on a one-image batch; the first element of its output must
            provide ``'boxes'`` and ``'scores'`` entries.
        image_tensor: Unbatched image tensor; the batch dimension is added
            here. Presumably float values in [0, 1] as produced by
            ``load_image`` -- the tensor is scaled by 255 and cast to bytes
            before being turned back into a PIL image.
        confidence_threshold: Only detections whose score is strictly greater
            than this value are drawn.

    Returns:
        A PIL image rebuilt from the input tensor, with a red rectangle and a
        yellow two-decimal score label for every kept detection.
    """
    image_tensor = image_tensor.to(device)
    model.to(device)

    with torch.no_grad():
        # The model is called on a batch, so add a leading batch dimension.
        predictions = model(image_tensor.unsqueeze(0))
    prediction = predictions[0]

    # Threshold on the tensor and bring the survivors to the host in a single
    # transfer, instead of one comparison and one copy per detection.
    scores = prediction['scores']
    keep = scores > confidence_threshold
    kept_boxes = prediction['boxes'][keep].to('cpu').tolist()
    kept_scores = scores[keep].to('cpu').tolist()

    # Rebuild a drawable PIL image from the (unbatched) input tensor.
    pil_image = to_pil_image(image_tensor.mul(255).to('cpu').byte())

    draw = ImageDraw.Draw(pil_image)
    for box, score in zip(kept_boxes, kept_scores):
        draw.rectangle(box, outline="red", width=3)
        # Label is anchored at the first two box coordinates.
        draw.text((box[0], box[1]), f"{score:.2f}", fill="yellow")
    return pil_image

# Main function to process the image
def main(image_path, model_path, output_path):
    """Annotate a single image with detections and write it to disk.

    Args:
        image_path: Path of the image to run detection on.
        model_path: Path of the checkpoint handed to ``get_model``.
        output_path: Destination passed to the annotated image's ``save``.
    """
    # Update according to the number of classes in your dataset.
    class_count = 2
    detector = get_model(class_count, model_path)
    annotated = detect_and_annotate(detector, load_image(image_path))
    annotated.save(output_path)

if __name__ == "__main__":
    # Script entry point: fixed input image, checkpoint and output file.
    model_path = '/home/rohit/AIXI/model1_8/checkpoints/run_11/model_epoch_17.pth'
    image_path = "/home/rohit/AIXI/analysis/single_image_gen/ma2035_03900354.jpg"
    output_path = "output.jpg"
    main(
        image_path=image_path,
        model_path=model_path,
        output_path=output_path,
    )
