In [83]:
import sys
import os
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import pickle

In [84]:
# Dataset location. The default is the path the notebook was originally run
# with; set the CARS_TRUCKS_FLICKR_ROOT environment variable to point at the
# dataset on another machine without editing the notebook.
root = os.environ.get(
    'CARS_TRUCKS_FLICKR_ROOT',
    '/mnt/c/Users/anany/Desktop/gradual_shift/data/cars-trucks-flickr',
)
# Extensions accepted as image files (leading dot included, matched exactly).
EXTS = ['.png', '.jpg', '.jpeg']
# Side length, in pixels, that images are resized to.
dim = 64

In [85]:
def get_extension(name):
    """Return the extension of ``name``, including the leading dot.

    The extension is everything from the last ``'.'`` to the end of the
    string, exactly as written (no case folding).

    Args:
        name: file name to inspect.

    Returns:
        The extension such as ``'.jpg'``, or ``''`` when ``name`` contains
        no dot at all.
    """
    dot = name.rfind('.')
    # rfind returns -1 when there is no dot; slicing from -1 would wrongly
    # report the last character of the name as its "extension".
    return name[dot:] if dot != -1 else ''

def get_filenames(dir):
    """Return the names of the image files located directly inside ``dir``.

    Only non-directory entries whose extension (as given by
    ``get_extension``) is listed in ``EXTS`` are kept.

    Args:
        dir: path of the directory to list.

    Returns:
        A list of bare file names (no directory prefix), sorted so that the
        result does not depend on the order in which the filesystem happens
        to enumerate entries.
    """
    # This file already requires Python 3.6+ (it uses f-strings), so the old
    # os.listdir fallback for interpreters older than 3.5 could never run.
    # The ``with`` block closes the scandir iterator as soon as we are done.
    with os.scandir(dir) as entries:
        filenames = [entry.name for entry in entries
                     if not entry.is_dir() and get_extension(entry.name) in EXTS]
    filenames.sort()
    return filenames

def check_extensions(file_names):
    """Verify that every name in ``file_names`` has an extension in ``EXTS``.

    Args:
        file_names: iterable of file names to validate.

    Raises:
        ValueError: on the first name whose extension is not listed in
            ``EXTS``; the message includes the offending extension and name.
    """
    for name in file_names:
        # Reuse the shared helper so listing and validation can never
        # disagree about what a file's extension is.
        ext = get_extension(name)
        if ext not in EXTS:
            raise ValueError(f'{ext} is not a valid extension in {EXTS}, filename: {name}')

In [86]:
# Scratch check: list the image files in the 'car 1940' folder.
file_names = get_filenames(f'{root}/car 1940')

In [88]:

def get_car_data(dim, years, class_names, rng):
    """Load, resize and shuffle the images of every (class, year) folder.

    Images are read from ``<root>/<class_name> <year>/`` for each year in
    ``years`` and each name in ``class_names``. Every image is resized to
    ``dim`` x ``dim`` and divided by 255.0. Only images whose resulting array
    has shape ``(dim, dim, 3)`` are kept; files PIL cannot identify and
    images with any other array shape are skipped and counted.

    Within each year the samples of all classes are shuffled together using
    ``rng``. The years themselves stay in the order given, so the returned
    arrays are grouped by year.

    Args:
        dim: side length in pixels of the square output images.
        years: iterable of years; each one is converted with ``str`` to form
            the folder name.
        class_names: sequence of class folder prefixes. A sample's label is
            the index of its class in this sequence.
        rng: object providing ``permutation(n)``, e.g.
            ``np.random.RandomState``.

    Returns:
        A tuple ``(xs, ys)``: ``xs`` is a float array of shape
        ``(N, dim, dim, 3)`` and ``ys`` an integer array of shape ``(N,)``
        holding the class index of each image.

    Raises:
        ValueError: propagated from ``check_extensions``.
    """
    xs, ys = [], []
    num_images = 0
    skipped_images = 0
    for year in years:
        cur_xs, cur_ys = [], []
        for class_idx, class_name in enumerate(class_names):
            folder = os.path.join(root, class_name + ' ' + str(year))
            file_names = get_filenames(folder)
            check_extensions(file_names)
            for name in file_names:
                path = os.path.join(folder, name)
                try:
                    img = Image.open(path)
                except Image.UnidentifiedImageError:
                    skipped_images += 1
                    continue
                # Image.open keeps the underlying file open; the context
                # manager releases it once the resized copy has been made.
                with img:
                    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is
                    # the same filter under its current name.
                    resized = img.resize((dim, dim), Image.LANCZOS)
                img_data = np.array(resized) / 255.0
                if img_data.shape == (dim, dim, 3):
                    cur_xs.append(img_data)
                    cur_ys.append(class_idx)
                    num_images += 1
                else:
                    skipped_images += 1
        # Give an empty year the right shape and dtype explicitly:
        # np.array([]) would be a 1-D float array, which breaks the
        # concatenation of xs and silently turns the labels into floats.
        if cur_xs:
            cur_xs = np.array(cur_xs)
        else:
            cur_xs = np.empty((0, dim, dim, 3))
        cur_ys = np.array(cur_ys, dtype=int)
        perm = rng.permutation(len(cur_xs))
        xs.append(cur_xs[perm])
        ys.append(cur_ys[perm])
    # Report what was dropped so silently discarded files do not go unnoticed.
    print(f'Loaded {num_images} images; skipped {skipped_images} '
          f'unreadable or non-({dim}, {dim}, 3) images')
    if not xs:
        # No years requested: np.concatenate would raise on an empty list.
        return np.empty((0, dim, dim, 3)), np.empty((0,), dtype=int)
    xs = np.concatenate(xs, axis=0)
    ys = np.concatenate(ys, axis=0)
    return xs, ys
    

# Build the dataset for the years 1940-2020 (inclusive) and cache it on disk.
rng = np.random.RandomState(seed=0)  # fixed seed for the per-year shuffles
years = list(range(1940, 2021))
class_names = ['car', 'truck']  # labels are indices into this list
dim = 64
xs, ys = get_car_data(dim=dim, years=years, class_names=class_names, rng=rng)
# Write through a context manager so the file is flushed and closed right
# away instead of leaving the handle returned by open() dangling.
with open("cars-trucks-flickr-" + str(dim) + '.pkl', "wb") as f:
    pickle.dump((xs, ys), f)

In [90]:
# Reload the cached dataset. NOTE: pickle.load can execute arbitrary code, so
# only load files written by this notebook or another trusted source.
with open("cars-trucks-flickr-64.pkl", "rb") as f:
    xs, ys = pickle.load(f)

In [82]:
# Sanity check: show the dimensions of the loaded image array.
print(xs.shape)

(14717, 64, 64, 3)


In [23]:
# Peek at the first ten labels.
ys[:10]

array([1, 0, 1, 0, 1, 0, 0, 1, 1, 0])