In [1]:
import os
import numpy as np
import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torch.utils.data import Dataset, DataLoader
from PIL import Image

In [14]:
class VolleyballDataset(Dataset):
    def __init__(self, root_dir, annotation_dir, transform=None):
        self.root_dir = root_dir
        self.annotation_dir = annotation_dir
        self.transform = transform
        self.video_dirs = sorted([int(d) for d in os.listdir(root_dir) if d.isdigit()])
        self.sequence_len = 41
        self.sequence_indices = []
        for video_dir in self.video_dirs:
            frames_dir = os.path.join(annotation_dir, str(video_dir))
            frames = sorted([int(f[:-4]) for f in os.listdir(frames_dir) if f.endswith('.txt')])
            for i in range(len(frames) - self.sequence_len + 1):
                self.sequence_indices.append(frames[i + 20])
        
    def __len__(self):
        return 55
    
    def __getitem__(self, idx):
        video_dir_idx = idx // len(self.sequence_indices)
        sequence_idx = self.sequence_indices[idx % len(self.sequence_indices)]
        video_dir = os.path.join(self.root_dir, str(self.video_dirs[video_dir_idx]))
        sequence_file = os.path.join(self.annotation_dir, str(self.video_dirs[video_dir_idx]), str(sequence_idx)+'.txt')
        with open(sequence_file, 'r') as f:
            line = f.readline()
            while line.startswith('0 0'):
                line = f.readline()
            x, y = [float(coord) for coord in line.split()]
        img_dir = os.path.join(video_dir, str(sequence_idx))
        img_files = sorted(os.listdir(img_dir))
        img_files = [os.path.join(img_dir, f) for f in img_files]
        img_sequence = [Image.open(f) for f in img_files]
        if self.transform:
            img_sequence = [self.transform(img) for img in img_sequence]
        center_idx = len(img_sequence) // 2
        img_tensor = img_sequence[center_idx]
        
        boxes = torch.as_tensor([[x-5, y-5, x+5, y+5]], dtype=torch.float32)
        labels = torch.as_tensor([1], dtype=torch.int64)
        image_id = torch.tensor(idx)
        area = (boxes[0, 3] - boxes[0, 1]) * (boxes[0, 2] - boxes[0, 0])
        is_crowd = torch.zeros((1,), dtype=torch.int64)
        
        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = is_crowd

        return img_tensor, target


root_dir = 'C:/Users/salba/Documents/videos'
annotation_dir = 'C:/Users/salba/Documents/volleyball_ball_annotation'
dataset = VolleyballDataset(root_dir, annotation_dir, transform=torchvision.transforms.ToTensor())

train_videos = [1, 3, 6, 7, 10, 13, 15, 16, 18, 22, 23, 31, 32, 36, 38, 39, 40, 41, 42, 48, 50, 52, 53, 54]
val_videos = [0, 2, 8, 12, 17, 19, 24, 26, 27, 28, 30, 33, 46, 49, 51]
test_videos = [4, 5, 9, 11, 14, 20, 21, 25, 29, 34, 35, 37, 43, 44, 45, 47]

train_idx = [i for i in range(len(dataset)) if dataset.video_dirs[i] in train_videos]
val_idx = [i for i in range(len(dataset)) if dataset.video_dirs[i] in val_videos]
test_idx = [i for i in range(len(dataset)) if dataset.video_dirs[i] in test_videos]

train_dataset = torch.utils.data.Subset(dataset, train_idx)
val_dataset = torch.utils.data.Subset(dataset, val_idx)
test_dataset = torch.utils.data.Subset(dataset, test_idx)

In [7]:
print('Number of train samples:', len(train_dataset))
print('Number of val samples:', len(val_dataset))
print('Number of test samples:', len(test_dataset))
print(dataset.video_dirs)

Number of train samples: 24
Number of val samples: 15
Number of test samples: 16
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54]


In [18]:
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.faster_rcnn import FasterRCNN_ResNet50_FPN_Weights
import utils
import torchvision.models.detection as detection
from engine import train_one_epoch, evaluate

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)

num_classes = 2  # ball and background
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=utils.collate_fn)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, collate_fn=utils.collate_fn)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=utils.collate_fn)

num_epochs = 10
for epoch in range(num_epochs):
    train_loss = train_one_epoch(model, optimizer, train_loader, device, epoch, print_freq=1)
    lr_scheduler.step()
    evaluate(model, val_loader, device=device)


Epoch: [0]  [ 0/12]  eta: 0:07:30  lr: 0.000459  loss: 2.3845 (2.3845)  loss_classifier: 0.5452 (0.5452)  loss_box_reg: 0.0040 (0.0040)  loss_objectness: 1.3749 (1.3749)  loss_rpn_box_reg: 0.4604 (0.4604)  time: 37.5280  data: 1.3625
Epoch: [0]  [ 1/12]  eta: 0:06:55  lr: 0.000913  loss: 1.9470 (2.1657)  loss_classifier: 0.5452 (0.5547)  loss_box_reg: 0.0023 (0.0031)  loss_objectness: 1.0355 (1.2052)  loss_rpn_box_reg: 0.3451 (0.4027)  time: 37.7588  data: 1.3481
Epoch: [0]  [ 2/12]  eta: 0:06:17  lr: 0.001367  loss: 1.9470 (1.9707)  loss_classifier: 0.5452 (0.5307)  loss_box_reg: 0.0030 (0.0031)  loss_objectness: 1.0355 (1.0692)  loss_rpn_box_reg: 0.3451 (0.3676)  time: 37.7980  data: 1.3743
Epoch: [0]  [ 3/12]  eta: 0:05:37  lr: 0.001821  loss: 1.5807 (1.8214)  loss_classifier: 0.4829 (0.4784)  loss_box_reg: 0.0030 (0.0031)  loss_objectness: 0.8884 (1.0240)  loss_rpn_box_reg: 0.2974 (0.3159)  time: 37.5426  data: 1.3618
Epoch: [0]  [ 4/12]  eta: 0:04:58  lr: 0.002275  loss: 1.5807 (1

TypeError: 'float' object is not subscriptable