# Test model training using PyTorch, CUDA, and ResNet-50

### Extract files and masks from directory

Make a backup of the data before performing the following task.

In [None]:
import os
import json
from PIL import Image
import torch
from torchvision import transforms
import numpy as np

# Locations of the raw images and of their JSON annotation files
images_dir = 'path/to/images'
annotations_dir = 'path/to/annotations'

# Collect the image and annotation file names, each in sorted order
image_files = [f for f in os.listdir(images_dir) if f.endswith(('.jpg', '.png'))]
image_files.sort()
annotation_files = [f for f in os.listdir(annotations_dir) if f.endswith('.json')]
annotation_files.sort()

# Function to parse a single annotation
def parse_annotation(json_path):
    """Read one annotation file and return its bounding boxes and labels.

    The file must contain a JSON object with a ``shapes`` list whose entries
    carry a ``label`` and a list of ``[x, y]`` ``points``. Every shape is
    reduced to the axis-aligned box enclosing all of its points, so both
    two-corner rectangles and multi-point polygons are handled.

    Args:
        json_path: Path of the annotation JSON file.

    Returns:
        A ``(boxes, labels)`` tuple. ``boxes`` is a list of
        ``[xmin, ymin, xmax, ymax]`` lists and ``labels`` is the matching list
        of label values, both in file order. Shapes without any points are
        skipped.

    Raises:
        KeyError: If the file has no ``shapes`` entry, or a shape lacks
            ``label`` or ``points``.
    """
    # JSON is UTF-8 by specification; do not depend on the platform default,
    # which would corrupt or reject non-ASCII labels on some systems.
    with open(json_path, encoding='utf-8') as f:
        data = json.load(f)

    boxes = []
    labels = []
    for shape in data['shapes']:
        label = shape['label']
        points = shape['points']  # [[x1, y1], [x2, y2], ...]
        if not points:
            # A shape without points has no extent; min()/max() would raise.
            continue
        x_coords = [p[0] for p in points]
        y_coords = [p[1] for p in points]
        boxes.append([min(x_coords), min(y_coords), max(x_coords), max(y_coords)])
        labels.append(label)

    return boxes, labels


### Create dataset

In [None]:
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    """Object-detection dataset built from image files and JSON annotations.

    Images (``.jpg`` / ``.png``) are read from ``images_dir`` and annotation
    files (``.json``) from ``annotations_dir``. An image is paired with the
    annotation file that has the same file stem (``cat.jpg`` <-> ``cat.json``);
    images without a matching annotation are left out. If no stems match at
    all, the two sorted listings are paired by position instead.

    Each item is ``(image, target)`` where ``target`` is a dict with:
        * ``boxes``: float32 tensor of shape ``(N, 4)``,
        * ``labels``: int64 tensor of shape ``(N,)`` (ids from ``label_map``),
        * ``image_id``: tensor holding the item index.

    Args:
        images_dir: Directory containing the image files.
        annotations_dir: Directory containing the annotation JSON files.
        transforms: Optional callable applied to the loaded PIL image.
    """

    def __init__(self, images_dir, annotations_dir, transforms=None):
        self.images_dir = images_dir
        self.annotations_dir = annotations_dir
        self.transforms = transforms

        image_files = sorted(
            f for f in os.listdir(images_dir) if f.endswith(('.jpg', '.png'))
        )
        annotation_files = sorted(
            f for f in os.listdir(annotations_dir) if f.endswith('.json')
        )

        # Build the label map from every annotation file so that class ids do
        # not depend on which images happen to be present.
        self.annotation_files = annotation_files
        self.label_map = self.create_label_map()

        # Keep the two lists index-aligned; __getitem__ relies on that.
        self.image_files, self.annotation_files = self._pair_files(
            image_files, annotation_files
        )

    @staticmethod
    def _pair_files(image_files, annotation_files):
        """Return index-aligned (images, annotations) lists matched by stem."""
        ann_by_stem = {os.path.splitext(f)[0]: f for f in annotation_files}
        paired_images = []
        paired_annotations = []
        for image_file in image_files:
            ann_file = ann_by_stem.get(os.path.splitext(image_file)[0])
            if ann_file is not None:
                paired_images.append(image_file)
                paired_annotations.append(ann_file)

        if not paired_images:
            # No stem matches: the naming convention is unknown, so fall back
            # to pairing the sorted listings by position.
            return image_files, annotation_files
        return paired_images, paired_annotations

    def create_label_map(self):
        """Map every label name found in the annotations to an integer id.

        Ids start at 1 and follow the sorted order of the label names; 0 is
        left free for the background class.
        """
        labels = set()
        for ann_file in self.annotation_files:
            ann_path = os.path.join(self.annotations_dir, ann_file)
            # JSON is UTF-8 by specification; avoid the platform default.
            with open(ann_path, encoding='utf-8') as f:
                data = json.load(f)
            labels.update(shape['label'] for shape in data['shapes'])
        return {label: idx for idx, label in enumerate(sorted(labels), start=1)}

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.image_files)

    def _build_target(self, boxes, labels, idx):
        """Convert parsed boxes and label names into the target dict of tensors."""
        # reshape(-1, 4) keeps an image without objects at shape (0, 4)
        # instead of (0,), which detection models reject.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        labels = torch.as_tensor(
            [self.label_map[label] for label in labels], dtype=torch.int64
        )
        return {
            'boxes': boxes,
            'labels': labels,
            'image_id': torch.tensor([idx]),
        }

    def __getitem__(self, idx):
        """Load the image and annotation at ``idx`` and return ``(img, target)``."""
        # Load image
        img_path = os.path.join(self.images_dir, self.image_files[idx])
        img = Image.open(img_path).convert("RGB")

        # Load annotations
        ann_path = os.path.join(self.annotations_dir, self.annotation_files[idx])
        boxes, labels = parse_annotation(ann_path)
        target = self._build_target(boxes, labels, idx)

        if self.transforms:
            img = self.transforms(img)

        return img, target


### Prep dataloader

In [None]:
from torchvision import transforms
from torch.utils.data import DataLoader

# Define any transformations (you can add data augmentation here)
# Preprocessing applied to every image (data augmentation steps can be added here)
transform = transforms.Compose([transforms.ToTensor()])


def _collate_detection_batch(batch):
    """Regroup a list of (image, target) samples into (images, targets) tuples."""
    return tuple(zip(*batch))


# Build the dataset and wrap it in a shuffling loader with batches of four samples
dataset = CustomDataset(images_dir, annotations_dir, transforms=transform)
data_loader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=_collate_detection_batch,
)


### Modify model

In [None]:
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# Start from torchvision's Faster R-CNN detector (ResNet-50 FPN backbone)
# with pre-trained weights.
# NOTE(review): recent torchvision releases deprecate `pretrained=` in favour
# of `weights=` - confirm against the installed version.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

# Swap the box-prediction head for one sized to this dataset's classes;
# the extra class is the background (id 0).
in_features = model.roi_heads.box_predictor.cls_score.in_features
num_classes = 1 + len(dataset.label_map)
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)


### Loss function and optimizer

In [None]:
import torch.optim as optim

# Run on the GPU when CUDA is available, otherwise on the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Hand only the parameters that require gradients to the optimizer
params = [param for param in model.parameters() if param.requires_grad]
optimizer = optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005,
)


### Train the model

In [None]:
num_epochs = 10

for epoch in range(num_epochs):
    # Put the model in training mode; the forward pass below is called with
    # targets and is expected to return a dict of loss tensors.
    model.train()
    for i, (images, targets) in enumerate(data_loader):
        # Move every image and every target tensor to the training device
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # Forward pass: add the individual loss terms into one scalar
        loss_dict = model(images, targets)
        losses = sum(loss_dict.values())

        # Backward pass and optimization
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        # Report progress every tenth step
        if i % 10 == 0:
            print(f"Epoch [{epoch}/{num_epochs}], Step [{i}/{len(data_loader)}], Loss: {losses.item():.4f}")

    print(f"Epoch {epoch} finished.")


### Save the model

In [None]:
# Persist only the model's state_dict (its parameters and buffers), not the model object itself
torch.save(model.state_dict(), 'fasterrcnn_resnet50.pth')


### Sanity test

In [None]:
# Put the model in evaluation mode before running inference
model.eval()

# Load one test image and preprocess it with the same transform used for training
img = transform(Image.open('path/to/test/image.jpg').convert("RGB")).to(device)

# Gradients are not needed for a forward-only check
with torch.no_grad():
    prediction = model([img])

print(prediction)
