# In [None]:
import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.transforms import functional as F
from torch.utils.data import DataLoader, Dataset
import os
import numpy as np
from PIL import Image
from tqdm import tqdm
# Custom Dataset
class PolypDataset(Dataset):
    """Object-detection dataset of images plus YOLO-style text labels.

    Each label line is ``class_id cx cy w h``, with the box centre and size
    normalised by the image width/height. Boxes are returned in absolute
    ``[x_min, y_min, x_max, y_max]`` pixel coordinates, the format the
    torchvision detection models consume.

    Args:
        image_dir: Directory containing the images.
        label_dir: Directory containing the label files. A label file is
            matched to its image by file stem (``img.jpg`` <-> ``img.txt``),
            as in the YOLO ``images/`` / ``labels/`` layout. An image with no
            matching label file is returned with no boxes (negative sample).
        transforms: Optional callable applied to the PIL image only. It must
            not change image geometry, because the boxes are not transformed.
    """

    def __init__(self, image_dir, label_dir, transforms=None):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.transforms = transforms
        self.image_filenames = sorted(os.listdir(image_dir))
        self.label_filenames = sorted(os.listdir(label_dir))
        # Pair by stem rather than by sorted position: a single image without
        # a label file would otherwise shift every later image/label pairing.
        self._label_by_stem = {
            os.path.splitext(name)[0]: name for name in self.label_filenames
        }

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Return ``(image, target)`` where target has "boxes" [N, 4] and "labels" [N]."""
        img_name = self.image_filenames[idx]
        img_path = os.path.join(self.image_dir, img_name)
        # The context manager closes the file; convert() has already loaded the pixels.
        with Image.open(img_path) as pil_img:
            img = pil_img.convert("RGB")
        width, height = img.size

        boxes = []
        labels = []
        label_name = self._label_by_stem.get(os.path.splitext(img_name)[0])
        if label_name is not None:
            label_path = os.path.join(self.label_dir, label_name)
            with open(label_path, "r") as f:
                for line in f:
                    fields = line.split()
                    if not fields:
                        continue  # tolerate blank / trailing lines
                    label, cx, cy, w, h = map(float, fields)
                    # Normalised centre/size -> absolute corners, clipped to the image.
                    x_min = max(0.0, (cx - w / 2) * width)
                    y_min = max(0.0, (cy - h / 2) * height)
                    x_max = min(float(width), (cx + w / 2) * width)
                    y_max = min(float(height), (cy + h / 2) * height)
                    if x_max <= x_min or y_max <= y_min:
                        continue  # detection models reject zero-area boxes
                    boxes.append([x_min, y_min, x_max, y_max])
                    # torchvision detection models reserve label 0 for
                    # background, so shift the 0-based YOLO class ids by one.
                    labels.append(int(label) + 1)

        # reshape keeps a [0, 4] shape when the image has no boxes; a bare
        # torch.tensor([]) would be 1-D and rejected by the model.
        target = {
            "boxes": torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            "labels": torch.tensor(labels, dtype=torch.int64),
        }

        if self.transforms:
            img = self.transforms(img)

        return img, target

# Data Transformations
def get_transform(train):
    """Build the image transform used by ``PolypDataset``.

    Args:
        train: Whether the transform is for training. Both modes currently
            return the same pipeline; see the note below.

    Returns:
        A ``Compose`` that converts a PIL image to a float tensor.
    """
    # NOTE: no RandomHorizontalFlip here. The dataset applies this transform
    # to the image only, so a flip would mirror the pixels but not the
    # bounding boxes, training the detector on wrong locations. Reintroduce
    # flip augmentation only together with a matching box flip
    # (x_min, x_max -> width - x_max, width - x_min).
    transforms = [F.to_tensor]
    return torchvision.transforms.Compose(transforms)

# Model
def get_model(num_classes):
    """Return a pretrained Faster R-CNN (ResNet-50 FPN) with a resized box head.

    Only the final box predictor is swapped out, so the backbone, RPN and
    box head keep their pretrained weights.

    Args:
        num_classes: Number of output classes, including the background class.

    Returns:
        The detection model with a freshly initialised ``FastRCNNPredictor``.
    """
    detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    roi_heads = detector.roi_heads
    # The new predictor must accept the same feature width as the one it replaces.
    feature_dim = roi_heads.box_predictor.cls_score.in_features
    roi_heads.box_predictor = FastRCNNPredictor(feature_dim, num_classes)
    return detector

# Training
def train_model(model, train_loader, optimizer, device):
    """Run one training epoch and return the mean loss over its batches.

    Args:
        model: Detection model that, in train mode, returns a dict of loss
            tensors when called as ``model(images, targets)``.
        train_loader: Iterable of ``(images, targets)`` batches. ``images``
            is a sequence of tensors; ``targets`` is a sequence of dicts of
            tensors.
        optimizer: Optimizer over the model's parameters.
        device: Device the batch tensors are moved to.

    Returns:
        The average of the per-batch summed losses as a Python float, or
        ``nan`` if the loader yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for images, targets in tqdm(train_loader):
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in target.items()} for target in targets]
        loss_dict = model(images, targets)
        losses = sum(loss_dict.values())
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
        # .item() yields a plain float, so no autograd graph is kept across batches.
        total_loss += losses.item()
        num_batches += 1
    if num_batches == 0:
        # Previously this path raised UnboundLocalError on `losses`.
        return float("nan")
    return total_loss / num_batches

# Main Code
def main():
    """Train the polyp detector on the YOLO-layout dataset and save its weights."""
    # Paths. os.path.join keeps them portable; a raw "dataset\\images\\train"
    # string is a single odd filename on non-Windows systems.
    train_image_dir = os.path.join("dataset", "images", "train")
    train_label_dir = os.path.join("dataset", "labels", "train")
    # TODO: add an evaluation loop over test/images + test/labels. The unused
    # test dataset was removed because building it crashed the run if the
    # directory was missing.

    # Dataset and DataLoader. Each image has a different number of boxes, so
    # batches are kept as tuples instead of being stacked into tensors.
    train_dataset = PolypDataset(train_image_dir, train_label_dir, transforms=get_transform(train=True))
    train_loader = DataLoader(
        train_dataset,
        batch_size=4,
        shuffle=True,
        collate_fn=lambda batch: tuple(zip(*batch)),
    )

    # Model and optimizer
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    # num_classes counts the background class. 3 allows background plus two
    # object label ids. NOTE(review): if the labels contain only the polyp
    # class, 2 would suffice; confirm against the dataset.
    model = get_model(num_classes=3)
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

    # Training loop
    num_epochs = 10
    for epoch in range(num_epochs):
        loss = train_model(model, train_loader, optimizer, device)
        print(f"Epoch {epoch+1}, Loss: {loss}")

    # Save only the weights; rebuild the model with get_model() to reload them.
    torch.save(model.state_dict(), "polyp_detection_model.pth")

if __name__ == "__main__":
    # Start training only when this file is executed directly, not when imported.
    main()
