In [43]:
import os
import cv2
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, DataLoader
from ultralytics import YOLO
import torch

In [44]:
augmented_images_dir = '/Users/treyshanks/data_science/Court_Vision/data/augmented/images'
augmented_labels_dir = '/Users/treyshanks/data_science/Court_Vision/data/augmented/labels'

# os.makedirs(augmented_images_dir, exist_ok=True)
# os.makedirs(augmented_labels_dir, exist_ok=True)


In [45]:
# Define transformations
transform = A.Compose([
    A.RandomBrightnessContrast(p=0.8),
    A.HueSaturationValue(p=0.75),
    A.RGBShift(p=0.75),
    A.RandomGamma(p=0.8),
    A.HorizontalFlip(p=0.5),
    A.Rotate(limit=20, p=0.5),
    A.CoarseDropout(max_holes=8, max_height=64, max_width=64, min_holes=1, min_height=8, min_width=8, p=0.5),
    A.ToGray(p=0.85),  
    ToTensorV2()
], bbox_params=A.BboxParams(format='yolo', label_fields=['labels']))

In [46]:
def read_labels(label_path):
    boxes = []
    with open(label_path, 'r') as f:
        for line in f.readlines():
            parts = line.strip().split()
            if len(parts) == 5:
                try:
                    class_id, x_center, y_center, width, height = map(float, parts)
                    boxes.append([x_center, y_center, width, height])
                except ValueError as e:
                    print(f"Error parsing line '{line.strip()}': {e}")
            else:
                print(f"Skipping malformed line in {label_path}: '{line.strip()}'")
    return boxes


In [47]:
train_images_dir = '/Users/treyshanks/data_science/Court_Vision/agg_lebron/train/images'
train_labels_dir = '/Users/treyshanks/data_science/Court_Vision/agg_lebron/train/labels'

# Apply transformations and save augmented images
for img_file in os.listdir(train_images_dir):
    if img_file.endswith('.jpg'):
        img_path = os.path.join(train_images_dir, img_file)
        label_path = os.path.join(train_labels_dir, img_file.replace('.jpg', '.txt'))
        
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        
        # Read labels and apply transformations
        boxes = read_labels(label_path)
        augmented = transform(image=image, bboxes=boxes, labels=[0] * len(boxes))
        augmented_image = augmented['image']
        augmented_boxes = augmented['bboxes']

        # Save augmented image
        augmented_img_path = os.path.join(augmented_images_dir, img_file)
        cv2.imwrite(augmented_img_path, cv2.cvtColor(augmented_image.permute(1, 2, 0).numpy(), cv2.COLOR_RGB2BGR))

        # Save augmented labels
        augmented_label_path = os.path.join(augmented_labels_dir, img_file.replace('.jpg', '.txt'))
        with open(augmented_label_path, 'w') as f:
            for box in augmented_boxes:
                f.write(f"0 {box[0]} {box[1]} {box[2]} {box[3]}\n")



Skipping malformed line in /Users/treyshanks/data_science/Court_Vision/agg_lebron/train/labels/LeBron-James-26-pts-6-rebs-9-asts-vs-Nuggets-2024-PO-G3_mp4-0773_jpg.rf.3be8d6bd6d7f5302dd0e1c483860ed9c.txt: '0 0.208984375 0.6267361109375 0.208984375 0.6024305562500001 0.22265625 0.5642361109375 0.224609375 0.484375 0.2294921875 0.4791666671875 0.2353515625 0.49652777812500004 0.240234375 0.49131944375 0.228515625 0.4635416671875 0.228515625 0.390625 0.2265625 0.36631944375000003 0.2216796875 0.3576388890625 0.212890625 0.36631944375000003 0.212890625 0.390625 0.201171875 0.4045138890625 0.193359375 0.45659722187499996 0.193359375 0.4704861109375 0.197265625 0.4739583328125 0.197265625 0.5052083328125 0.203125 0.5572916671875 0.19921875 0.5920138890625 0.193359375 0.609375 0.185546875 0.6197916671875 0.185546875 0.6302083328125 0.189453125 0.644097221875 0.201171875 0.6579861109375 0.2080078125 0.6493055562500001 0.2177734375 0.6493055562500001 0.220703125 0.644097221875 0.208984375 0.626