### Utils

In [11]:
# Side length in pixels; used below as the minimum accepted image size and as
# the default target size when resizing.
image_size = 96
# Number of random images drawn by display_samples.
num_samples = 5
# Not referenced in the visible part of this notebook; presumably the number of
# training epochs used by a later cell -- TODO confirm.
epochs = 50


from typing import List, Tuple, Optional
import cv2
import os
import matplotlib.pyplot as plt
import numpy as np
from glob import glob
from tqdm import tqdm
from scipy.ndimage import rotate
import random
from sys import maxsize as max_int


# Function to load images and labels from a directory
def load_images(folder: str, n: Optional[int] = None) -> Tuple[List[np.ndarray], List[str]]:
    """Load ``*.jpg`` images from a folder-per-label directory tree.

    Every sub-directory of ``folder`` is treated as one label; its name is
    used as the label of each image found directly inside it.

    Args:
        folder: Root directory containing one sub-directory per label.
        n: Maximum number of images to load per label, or ``None`` for all.

    Returns:
        A tuple ``(images, labels)`` of equal length, where ``labels[i]`` is
        the sub-directory name ``images[i]`` was read from.
    """
    images: List[np.ndarray] = []
    labels: List[str] = []
    for label_folder in tqdm(os.listdir(folder)):
        label_path = os.path.join(folder, label_folder)
        if not os.path.isdir(label_path):
            continue

        count = 0
        for img_file in glob(os.path.join(label_path, "*.jpg")):
            # Stop scanning this label as soon as the quota is reached
            # instead of walking over every remaining file.
            if n is not None and count >= n:
                break
            img = cv2.imread(img_file, cv2.IMREAD_COLOR)
            # cv2.imread signals an unreadable/corrupt file by returning None;
            # skip it so callers never receive a None "image".
            if img is None:
                continue
            images.append(img)
            labels.append(label_folder)
            count += 1

    return images, labels


def display_samples(images: List[np.ndarray], labels: List[str]) -> None:
    """Show ``num_samples`` randomly chosen images side by side with their labels.

    Images are drawn with replacement, so the same image may appear twice.
    Nothing is displayed when ``images`` is empty.

    Args:
        images: Images to sample from.
        labels: Labels aligned with ``images``; used as subplot titles.
    """
    # np.random.randint(0, 0) raises ValueError, so bail out on empty input.
    if len(images) == 0:
        return

    plt.figure(figsize=(10, 2))
    for i in range(num_samples):
        rand = np.random.randint(0, len(images))
        plt.subplot(1, num_samples, i + 1)
        # NOTE(review): images loaded through cv2.imread are BGR while imshow
        # interprets 3-channel data as RGB, so colours may look swapped --
        # confirm whether a conversion is wanted here.
        plt.imshow(images[rand], cmap='gray')
        plt.title(f"{labels[rand]}")
        plt.axis('off')
    plt.show()


def downsample_image(image: np.ndarray, max_dim: int = 1000) -> np.ndarray:
    """Shrink ``image`` so that its longest side is at most ``max_dim`` pixels.

    The aspect ratio is preserved. Images that already fit are returned
    unchanged (the very same array object, not a copy).

    Args:
        image: Image array whose first two axes are height and width.
        max_dim: Largest allowed height/width in pixels.

    Returns:
        The original image, or a resized copy when it was too large.
    """
    height, width = image.shape[:2]
    longest = max(height, width)
    if longest <= max_dim:
        return image

    scale = max_dim / longest
    # int() truncates, so a very thin image could end up with a 0-pixel side,
    # which cv2.resize rejects; keep every side at least one pixel long.
    new_size = (max(1, int(width * scale)), max(1, int(height * scale)))
    return cv2.resize(image, new_size, interpolation=cv2.INTER_AREA)


def detect_circles(image: np.ndarray, max: int = max_int, threshold: int = 100, show: bool = False) -> np.ndarray:
    """Detect circles in ``image`` with the Hough gradient transform.

    The image is downsampled and box-blurred before detection. Detection is
    retried with a progressively lower threshold (steps of 5, used for both
    ``param1`` and ``param2``) until at least one circle is found or the
    threshold has dropped to 5 or below.

    Args:
        image: Grayscale image, or a 3-channel image (assumed to be BGR as
            produced by ``cv2.imread``), which is converted to grayscale.
        max: Maximum number of circles to return.
        threshold: Starting value for the Hough ``param1``/``param2``.
        show: When true, plot the original, the blurred image and the
            detected circles.

    Returns:
        Integer array of shape ``(k, 3)`` with ``k <= max`` rows of
        ``(x, y, radius)`` expressed in the coordinates of the *original*
        ``image``. The array is empty (``k == 0``) when nothing was found.
    """
    downsampled_image = downsample_image(image)

    # HoughCircles only accepts 8-bit single-channel input.
    if downsampled_image.ndim == 3 and downsampled_image.shape[2] == 3:
        gray_image = cv2.cvtColor(downsampled_image, cv2.COLOR_BGR2GRAY)
    else:
        gray_image = downsampled_image

    # The blur kernel scales with the image; it must be at least 1 pixel.
    blur = min(gray_image.shape[:2]) // 35
    if blur < 1:
        blur = 1
    blurred_image = cv2.blur(gray_image, (blur, blur))

    while True:
        found = cv2.HoughCircles(
            blurred_image,
            cv2.HOUGH_GRADIENT,
            dp=1,
            minDist=20,
            param1=threshold,
            param2=threshold,
            minRadius=10,
            maxRadius=min(blurred_image.shape[:2])
        )
        if found is not None or threshold <= 5:
            break
        threshold -= 5

    if found is None:
        # Nothing detected even at the lowest threshold.
        circles = np.empty((0, 3), dtype="int")
    else:
        # Detection ran on the downsampled image; map the results back to the
        # original resolution so callers can index into ``image`` directly.
        # (The builtin max() is shadowed by the ``max`` parameter here.)
        scale = float(np.max(image.shape[:2])) / float(np.max(downsampled_image.shape[:2]))
        circles = np.round(found[0, :] * scale).astype("int")

    if show:
        circle_image = image.copy()
        for (x, y, r) in circles:
            cv2.circle(circle_image, (int(x), int(y)), int(r), (0, 255, 0), 4)

        image_list = {
            "Original": image,
            "Blurred": blurred_image,
            "Detected Circles": circle_image
        }

        plt.figure(figsize=(20, 4))
        for i, (name, img) in enumerate(image_list.items()):
            plt.subplot(1, 3, i + 1)
            plt.imshow(img, cmap='gray')
            plt.title(name)
            plt.axis('off')

        plt.show()

    return circles[:max]


def crop_to_circle(img: np.ndarray, circles: Optional[np.ndarray]) -> List[np.ndarray]:
    """Cut the bounding square of every circle out of ``img``.

    Args:
        img: Image to crop from (rows are y, columns are x).
        circles: Iterable of ``(x, y, radius)`` triples, or ``None``.

    Returns:
        One crop per circle, in the same order. Crops are views into ``img``
        and are smaller than ``2 * radius`` on a side when the circle extends
        past an image border. An empty list is returned for ``None``.
    """
    if circles is None:
        return []

    cropped_imgs = []
    for (x, y, r) in circles:
        # A negative slice start would wrap around to the opposite edge and
        # yield an empty or unrelated region, so clamp at the top/left border.
        # The bottom/right end is clamped by slicing itself.
        top = max(y - r, 0)
        left = max(x - r, 0)
        cropped_imgs.append(img[top:y + r, left:x + r])
    return cropped_imgs


def remove_nones(images: List[np.ndarray], labels: List[str]) -> Tuple[np.ndarray, List[str]]:
    """Drop missing and undersized images together with their labels.

    An image is kept only if it is not ``None`` and both its height and its
    width are at least the module-level ``image_size``.

    Args:
        images: Candidate images; entries may be ``None``.
        labels: Labels aligned with ``images``. When ``None``, a placeholder
            ``None`` label is used for every image.

    Returns:
        ``(kept_images, kept_labels)`` where the images are packed into a
        NumPy array and the labels stay a plain list.
    """
    if labels is None:
        labels = [None] * len(images)

    kept_images = []
    kept_labels = []
    for candidate, tag in zip(images, labels):
        if candidate is None:
            continue
        # Reject anything narrower or shorter than the required size.
        if candidate.shape[0] < image_size or candidate.shape[1] < image_size:
            continue
        kept_images.append(candidate)
        kept_labels.append(tag)

    return np.array(kept_images), kept_labels


def preprocess_images(images: List[np.ndarray], labels: List[str], max: int = max_int) -> Tuple[np.ndarray, List[str]]:
    """Crop detected circles out of every image and resize the crops.

    For each image, circles are detected on a grayscale version, the
    corresponding squares are cropped from the original image, and every crop
    inherits the label of its source image. The crops are then passed through
    ``downsample_images``.

    Args:
        images: Source images; ``None`` entries are skipped.
        labels: Labels aligned with ``images``.
        max: Maximum number of circles to use per image.

    Returns:
        ``(processed_images, processed_labels)`` as returned by
        ``downsample_images``.
    """
    cropped_images = []
    cropped_labels = []
    for img, label in tqdm(list(zip(images, labels))):
        if img is None:
            continue
        # Circle detection needs a single-channel image; 3-channel input is
        # assumed to be BGR (as produced by cv2.imread) and converted first.
        if img.ndim == 3 and img.shape[2] == 3:
            grey_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        else:
            grey_img = img
        circles = detect_circles(grey_img, max=max)
        cis = crop_to_circle(img, circles)
        cropped_images.extend(cis)
        cropped_labels.extend([label] * len(cis))

    return downsample_images(cropped_images, cropped_labels)


def downsample_images(images: List[np.ndarray], labels: Optional[List[str]] = None, size: int = image_size) -> Tuple[np.ndarray, List[str]]:
    """Resize every sufficiently large image to ``size`` x ``size`` pixels.

    Images that are ``None`` or smaller than ``size`` in either dimension are
    dropped (never upscaled), together with their labels.

    Args:
        images: Images to resize; entries may be ``None``.
        labels: Labels aligned with ``images``. When ``None``, a placeholder
            ``None`` label is used for every image.
        size: Target side length in pixels.

    Returns:
        ``(resized_images, kept_labels)`` where the images are stacked into a
        NumPy array and the labels stay a plain list.
    """
    if labels is None:
        labels = [None] * len(images)

    # Filter here rather than via remove_nones(): that helper compares against
    # the global image_size, which would discard every image whenever a
    # smaller ``size`` is requested.
    resized_images = []
    kept_labels = []
    for img, label in zip(images, labels):
        if img is None or img.shape[0] < size or img.shape[1] < size:
            continue
        resized_images.append(cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA))
        kept_labels.append(label)

    return np.array(resized_images), kept_labels


def circle_mask_images(images: List[np.ndarray]) -> np.ndarray:
    """Keep only a centred disc of each image.

    For every image a filled circle is drawn on a blank 8-bit mask, centred
    in the image and as large as fits, and the image is combined with itself
    through that mask.

    Args:
        images: Images to mask.

    Returns:
        The masked images packed into a single NumPy array.
    """

    def _mask_centered_disc(picture: np.ndarray) -> np.ndarray:
        # Build the disc mask for this picture and apply it.
        rows, cols = picture.shape[:2]
        cx, cy = cols // 2, rows // 2
        disc = np.zeros((rows, cols), dtype=np.uint8)
        cv2.circle(disc, (cx, cy), min(cx, cy, cols - cx, rows - cy), 255, -1)
        return cv2.bitwise_and(picture, picture, mask=disc)

    return np.array([_mask_centered_disc(picture) for picture in images])


def augment_image(image):
    """Return a randomly rotated and mirrored copy of ``image``.

    The image is rotated by an angle drawn uniformly from [-30, 30] degrees
    without changing its shape, and is then always flipped: left-right or
    up-down, each chosen with equal probability.
    """
    # Draw the rotation first, then the flip direction.
    turned = rotate(image, random.uniform(-30, 30), reshape=False)
    mirror = np.fliplr if random.choice([True, False]) else np.flipud
    return mirror(turned)

### Crop and Save Images

In [13]:
import os
import cv2
import numpy as np
from typing import List
from tqdm import tqdm


def save_to_file(imgs: List[np.ndarray], label: str, num: int,
                 out_dir: str = os.path.join("data", "color_ball"), min_size: int = 48) -> int:
    """Write images as JPEG files into ``<out_dir>/<label>/``.

    The label directory is created if needed. Each image is saved as
    ``image<i>-<num>.jpg`` where ``i`` is its position in ``imgs``. Images
    that are ``None`` or smaller than ``min_size`` in either dimension are
    skipped.

    Args:
        imgs: Images to save; entries may be ``None``.
        label: Name of the sub-directory to write into.
        num: Number embedded in every file name (second number).
        out_dir: Root output directory.
        min_size: Minimum height and width, in pixels, for an image to be saved.

    Returns:
        The number of images for which ``cv2.imwrite`` reported success.
    """
    label_dir = os.path.join(out_dir, label)
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(label_dir, exist_ok=True)

    written = 0
    for i, img in enumerate(imgs):
        if img is None or img.shape[0] < min_size or img.shape[1] < min_size:
            continue

        path = os.path.join(label_dir, f"image{i}-{num}.jpg")
        # imwrite reports failure through its return value, not an exception.
        if cv2.imwrite(path, img):
            written += 1
    return written


# NOTE: the loader call below is commented out, so this cell relies on
# train_images / train_labels already being defined earlier in the session.
# train_images, train_labels = load_images("data/train")
# Pair every image with its index; the index ends up in the saved file name.
images_with_labels = list(enumerate(zip(train_images, train_labels)))
for i, (img, label) in tqdm(images_with_labels):
    # Circles are detected on a grayscale copy, but the crops are taken from
    # the original image.
    grey_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # Use at most 5 circles per image.
    circles = detect_circles(grey_img, max=5)
    cis = crop_to_circle(img, circles)
    save_to_file(cis, label, i)

100%|██████████| 5632/5632 [05:42<00:00, 16.46it/s]


### Swap Numbers

In [None]:
import os
import re

def rename_files(folder: str):
    """Swap the two numbers in every ``image<x>-<y>.jpg`` file name in ``folder``.

    ``image3-7.jpg`` becomes ``image7-3.jpg``. Only names that match the
    pattern exactly are touched; files whose two numbers are identical are
    left alone. The swap is done in two passes through temporary names so
    that a pair such as ``image3-7.jpg`` / ``image7-3.jpg`` exchanges names
    instead of one file overwriting the other.

    Args:
        folder: Directory whose files are renamed (not recursive).
    """
    # Regular expression to match the file names
    pattern = re.compile(r'image(\d+)-(\d+)\.jpg')

    planned = []
    for filename in os.listdir(folder):
        # fullmatch so that e.g. "image1-2.jpg.bak" is not treated as a match.
        match = pattern.fullmatch(filename)
        if not match:
            continue
        x, y = match.groups()
        if x == y:
            continue
        planned.append((filename, f'image{y}-{x}.jpg'))

    # Pass 1: move every file out of the way under a temporary name.
    staged = []
    for index, (filename, new_filename) in enumerate(planned):
        temp_path = os.path.join(folder, f'.swap-{index}-{filename}')
        os.rename(os.path.join(folder, filename), temp_path)
        staged.append((temp_path, filename, new_filename))

    # Pass 2: all targets are free now, so move the files to their final names.
    for temp_path, filename, new_filename in staged:
        os.rename(temp_path, os.path.join(folder, new_filename))
        print(f'Renamed: {filename} -> {new_filename}')

# Example usage: swaps the two numbers in the image file names of this folder.
# This renames files on disk as soon as the cell runs.
rename_files('data/ball/billiard_ball')

### Remove Images

In [14]:
import os

def remove_matching_files(folder1: str, folder2: str):
    """Delete from ``folder2`` every file that also exists in ``folder1``.

    The comparison is done per sub-folder and by file name only: for each
    sub-folder name present in both roots, files in ``folder2/<sub>`` whose
    name also appears in ``folder1/<sub>`` are removed. ``folder1`` is never
    modified.

    Args:
        folder1: Reference root directory.
        folder2: Root directory files are deleted from.
    """
    for subfolder in os.listdir(folder2):
        path1 = os.path.join(folder1, subfolder)
        path2 = os.path.join(folder2, subfolder)

        # Plain files at the top level exist but cannot be listed, so require
        # real directories on both sides.
        if not (os.path.isdir(path1) and os.path.isdir(path2)):
            continue

        matching_files = set(os.listdir(path1)).intersection(os.listdir(path2))

        # Sorted for a reproducible order of deletions and log lines.
        for file in sorted(matching_files):
            file_path = os.path.join(path2, file)
            # os.remove cannot delete directories; leave nested folders alone.
            if not os.path.isfile(file_path):
                continue
            os.remove(file_path)
            print(f'Removed: {file_path}')


def remove_non_matching_files(folder1: str, folder2: str):
    """Delete from ``folder2`` every file that has no counterpart in ``folder1``.

    The comparison is done per sub-folder and by file name only: for each
    sub-folder name present in both roots, files in ``folder2/<sub>`` whose
    name does not appear in ``folder1/<sub>`` are removed. Sub-folders that
    exist only in ``folder2`` are left untouched, and ``folder1`` is never
    modified.

    Args:
        folder1: Reference root directory.
        folder2: Root directory files are deleted from.
    """
    for subfolder in os.listdir(folder2):
        path1 = os.path.join(folder1, subfolder)
        path2 = os.path.join(folder2, subfolder)

        # Plain files at the top level exist but cannot be listed, so require
        # real directories on both sides.
        if not (os.path.isdir(path1) and os.path.isdir(path2)):
            continue

        non_matching_files = set(os.listdir(path2)) - set(os.listdir(path1))

        # Sorted for a reproducible order of deletions and log lines.
        for file in sorted(non_matching_files):
            file_path = os.path.join(path2, file)
            # os.remove cannot delete directories; leave nested folders alone.
            if not os.path.isfile(file_path):
                continue
            os.remove(file_path)
            print(f'Removed: {file_path}')

# remove_matching_files('data/ball', 'data/processed')
# Deletes from data/color_ball every file whose name is missing from the
# same-named sub-folder of data/ball. This removes files on disk as soon as
# the cell runs.
remove_non_matching_files('data/ball', 'data/color_ball')

Removed: data/color_ball\baseball\image0-185.jpg
Removed: data/color_ball\baseball\image4-112.jpg
Removed: data/color_ball\baseball\image1-111.jpg
Removed: data/color_ball\baseball\image0-12.jpg
Removed: data/color_ball\baseball\image1-304.jpg
Removed: data/color_ball\baseball\image1-347.jpg
Removed: data/color_ball\baseball\image1-102.jpg
Removed: data/color_ball\baseball\image4-128.jpg
Removed: data/color_ball\baseball\image0-213.jpg
Removed: data/color_ball\baseball\image0-77.jpg
Removed: data/color_ball\baseball\image4-291.jpg
Removed: data/color_ball\baseball\image0-236.jpg
Removed: data/color_ball\baseball\image1-363.jpg
Removed: data/color_ball\baseball\image2-320.jpg
Removed: data/color_ball\baseball\image1-287.jpg
Removed: data/color_ball\baseball\image0-349.jpg
Removed: data/color_ball\baseball\image0-313.jpg
Removed: data/color_ball\baseball\image1-247.jpg
Removed: data/color_ball\baseball\image2-112.jpg
Removed: data/color_ball\baseball\image0-134.jpg
Removed: data/color_ba