In [None]:
import os
import time
import concurrent.futures
from tqdm.notebook import tqdm
import numpy as np
from PIL import Image
import random
import matplotlib.pyplot as plt

In [None]:
def is_grayscale_fast(image_path, sample_size=1000):
    '''
    Determines if an image is grayscale by checking the standard deviation of a sample of pixels

    Luminance-only modes are accepted from the mode alone. Every other mode is
    converted to RGB before the channels are compared, so that palette
    indices, an alpha band, or the K band of CMYK are neither mistaken for
    gray values nor mixed into the per-pixel channel comparison.

    Images with more than ``sample_size`` pixels are judged on ``sample_size``
    randomly drawn pixels (with replacement), so a mostly gray image with a
    small coloured region can still be reported as grayscale. Images with at
    most ``sample_size`` pixels are checked exhaustively.

    Args:
        image_path: Path of the image file to inspect.
        sample_size: Maximum number of pixels to test.

    Returns:
        ``image_path`` if the image looks grayscale, otherwise ``None``.
        ``None`` is also returned, after printing the error, when the file
        cannot be opened, decoded or converted.
    '''
    try:
        with Image.open(image_path) as img:
            mode = img.mode
            # Modes that carry a single luminance band (optionally plus alpha).
            if mode in ('1', 'L', 'LA', 'La', 'I', 'F') or mode.startswith('I;16'):
                return image_path
            # Normalise to three colour channels; RGB input needs no conversion.
            rgb = img if mode == 'RGB' else img.convert('RGB')
            img_array = np.array(rgb)

        pixels = img_array.reshape(-1, 3)
        n_pixels = pixels.shape[0]
        if n_pixels == 0:
            return None
        if n_pixels <= sample_size:
            # Small image: sampling with replacement could skip pixels, so look at all of them.
            samples = pixels
        else:
            samples = pixels[np.random.randint(0, n_pixels, sample_size)]
        # For integer channels any std below 0.1 means R == G == B for that pixel.
        if np.max(np.std(samples, axis=1)) < 0.1:
            return image_path
    except Exception as e:
        # Best effort: report unreadable files and carry on with the rest.
        print(f"Error processing {image_path}: {str(e)}")
    return None

def process_chunk(chunk):
    '''
    Run the grayscale check on every image path in ``chunk``.

    Returns a list with one entry per input path, in the same order, holding
    whatever ``is_grayscale_fast`` returned for that path.
    '''
    results = []
    for image_path in chunk:
        results.append(is_grayscale_fast(image_path))
    return results

def find_grayscale_images(root_dir, chunk_size=1000, extensions=(".jpg",)):
    '''
    Scan a directory tree for images and return the ones detected as grayscale.

    Files under ``root_dir`` whose names end with one of ``extensions``
    (case-sensitive) are split into chunks of ``chunk_size`` paths, and each
    chunk is handed to ``process_chunk`` on a thread pool. Progress and a
    summary are printed.

    Args:
        root_dir: Directory to walk recursively.
        chunk_size: Number of image paths submitted to the pool per task.
        extensions: File-name suffix, or iterable of suffixes, to include.

    Returns:
        List of paths flagged as grayscale, in the order the files were
        discovered by ``os.walk``.

    Raises:
        ValueError: If ``chunk_size`` is less than 1.
    '''
    if chunk_size < 1:
        # A negative step would silently produce no chunks at all.
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    # str.endswith accepts a str or a tuple of str; keep a bare string as one suffix.
    suffixes = extensions if isinstance(extensions, str) else tuple(extensions)

    all_images = [
        os.path.join(subdir, file)
        for subdir, _, files in os.walk(root_dir)
        for file in files
        if file.endswith(suffixes)
    ]
    print(f"Total images found: {len(all_images)}")

    start_time = time.time()

    # os.cpu_count() may return None; fall back to a single worker.
    num_workers = os.cpu_count() or 1
    print(f"Using {num_workers} workers")

    chunks = [all_images[i:i + chunk_size] for i in range(0, len(all_images), chunk_size)]

    with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = [executor.submit(process_chunk, chunk) for chunk in chunks]

        # as_completed only drives the progress bar here.
        for _ in tqdm(concurrent.futures.as_completed(futures),
                      total=len(chunks),
                      desc="Processing chunks"):
            pass

    # Read results in submission order so the output does not depend on
    # which worker happened to finish first.
    grayscale_images = [
        img
        for future in futures
        for img in future.result()
        if img is not None
    ]

    end_time = time.time()
    elapsed_time = end_time - start_time

    total = len(all_images)
    # Avoid dividing by zero when the tree contains no matching files.
    percentage = len(grayscale_images) / total * 100 if total else 0.0

    print(f"\nTotal images processed: {total}")
    print(f"Grayscale images detected: {len(grayscale_images)}")
    print(f"Percentage grayscale: {percentage:.2f}%")
    print(f"Time taken: {elapsed_time:.2f} seconds ({elapsed_time/3600:.2f} hours)")

    return grayscale_images

def save_grayscale_list(grayscale_images, output_file):
    '''
    Write the grayscale image paths to ``output_file``, one path per line.

    The file is opened in write mode, so existing content is replaced.
    A confirmation message naming the file is printed afterwards.
    '''
    with open(output_file, 'w') as out:
        out.writelines(f"{path}\n" for path in grayscale_images)
    print(f"\nList of grayscale images saved to: {output_file}")

# Usage: scan root_dir, write the detected paths to output_file, and print
# a shell command that can be used to delete them after review.
root_dir = "img_data"
output_file = "grayscale_images.txt"
chunk_size = 1000

grayscale_images = find_grayscale_images(root_dir, chunk_size)
save_grayscale_list(grayscale_images, output_file)

print("\nTo delete these images, you can use the following command:")
# Plain `xargs` splits input on any whitespace and interprets quotes, so a path
# containing a space or quote would be passed to rm as the wrong arguments.
# -d '\n' makes each line exactly one argument; `--` stops rm from parsing a
# path that starts with '-' as an option.
print(f"xargs -d '\\n' -a {output_file} rm --")
print("\nWARNING: Be careful with the deletion. Make sure to review the list before deleting.")

In [None]:
# Show up to 10 randomly chosen images that were marked as grayscale,
# as a visual sanity check of the detection.
sample_size = 10
# random.sample raises ValueError when asked for more items than the list holds.
sample_images = random.sample(grayscale_images, min(sample_size, len(grayscale_images)))
if sample_images:
    # squeeze=False keeps axs two-dimensional even when only one image is shown.
    fig, axs = plt.subplots(1, len(sample_images), figsize=(20, 20), squeeze=False)
    for ax, img_path in zip(axs[0], sample_images):
        # The context manager closes the file handle once the pixels are copied.
        with Image.open(img_path) as img:
            # imshow maps single-channel data through its default colormap, which
            # would paint true grayscale files in false colour; RGB is shown as-is.
            ax.imshow(img.convert('RGB'))
        ax.axis('off')
else:
    print("No grayscale images to display.")

In [None]:
# Delete the sampled image files from disk. This cannot be undone.
for img_path in sample_images:
    try:
        os.remove(img_path)
    except FileNotFoundError:
        # Re-running the cell must not fail on files it has already deleted.
        print(f"Already removed: {img_path}")