diff --git a/torchvision/datasets/folder.py b/torchvision/datasets/folder.py index bfe8e1ba5eb..ca103080ecf 100644 --- a/torchvision/datasets/folder.py +++ b/torchvision/datasets/folder.py @@ -4,14 +4,20 @@ import os import os.path -IMG_EXTENSIONS = [ - '.jpg', '.JPG', '.jpeg', '.JPEG', - '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', -] +IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm'] def is_image_file(filename): - return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + """Checks if a file is an image. + + Args: + filename (string): path to a file + + Returns: + bool: True if the filename ends with a known image extension + """ + filename_lower = filename.lower() + return any(filename_lower.endswith(ext) for ext in IMG_EXTENSIONS) def find_classes(dir):