In [27]:
import os
import cv2 as cv
from glob import glob

categories = ['clean', 'messy']
raw_dir = {'train': './raw/train', 'val': './raw/val'}
output_dir = './images'
extensions = ('*.jpg', '*.png')


In [28]:
if not os.path.exists(output_dir):
    os.mkdir(output_dir)

In [29]:
def resize(image, img_size = 299):
    h, w, c = image.shape
    cropped = image
    if h < w:
        diff = (w - h) // 2
        cropped = image[:, diff: (diff + h), :]
    elif h > w:
        diff = (h - w) // 2
        cropped = image[diff: (diff + w), :, :]

    h, w, c = cropped.shape
    if h > img_size:    # shrink
        return cv.resize(cropped, (img_size, img_size), interpolation=cv.INTER_AREA)
    elif h < img_size:  # enlarge
        return cv.resize(cropped, (img_size, img_size), interpolation=cv.INTER_CUBIC)
    else:
        return cropped

In [30]:
for dataset, path in raw_dir.items():
    output_set_dir = os.path.join(output_dir, dataset)
    if not os.path.exists(output_set_dir):
        os.mkdir(output_set_dir)
        
    for cat in categories:
        output_cat_dir = os.path.join(output_set_dir, cat)
        if not os.path.exists(output_cat_dir):
            os.mkdir(output_cat_dir)  
            
        input_dir = os.path.join(path, cat)
        filenames = list()
        for ext in extensions:
            print(os.path.join(input_dir, ext))
            filenames.extend(glob(os.path.join(input_dir, ext)))          

        for i, file in enumerate(filenames):
            print('processing:', file)
            img = cv.imread(file)
            resized = resize(img)
            img_name = str(i) + '.png'
            filepath = os.path.join(output_cat_dir, img_name)
            cv.imwrite(filepath, resized)            

./raw/train/clean/*.jpg
./raw/train/clean/*.png
processing: ./raw/train/clean/240_F_629580437_fc9kgoFtMGrGodIU0GRcZrNZpzAuteIp.jpg
processing: ./raw/train/clean/240_F_636126525_j8vSnp0sXg9e1WLRigLA8YX8zMrkTOEu.jpg
processing: ./raw/train/clean/240_F_196997285_5TO5iWGsW4ruIjAPJLbePFnOLyol5OB1.jpg
processing: ./raw/train/clean/240_F_136321570_vdzGyD9wnc618kDmYcLsdsyZJqWUtnZf.jpg
processing: ./raw/train/clean/240_F_49619979_2lyfHBejqIsrgAAQrUXjLZJaFiGbeZWW.jpg
processing: ./raw/train/clean/240_F_217951471_LzKKLeyXZ61bRCCrPYqxwphSOq8oNWLY.jpg
processing: ./raw/train/clean/240_F_461370506_HX1zmYH4dRLPvifCGVgFwLjZ0TO5QyEn.jpg
processing: ./raw/train/clean/240_F_639620917_sB1i01bmQUHqcR7nagr4GVey9J8ESy2j.jpg
processing: ./raw/train/clean/240_F_221509420_5guDmOXwuLKY30dayIl38sMkuCCJpSrl.jpg
./raw/train/messy/*.jpg
./raw/train/messy/*.png
processing: ./raw/train/messy/240_F_676663582_KM213e6b95u4RVNlYJSJKxTeZluIrrA1.jpg
processing: ./raw/train/messy/240_F_660567859_VK2knM8RlK92WMluSilv15a3bQXnK

processing: ./raw/train/messy/240_F_292737061_DpbDurRPiNSpB4lROL98VbgeFAl9tfqw.jpg
processing: ./raw/train/messy/240_F_239446277_oUOsEXCFY4SLuHsmkV22lRxvS35TJQmx.jpg
processing: ./raw/train/messy/240_F_65371097_ufDbBicZy9W2OC5a2imWMnTA4vyxSXBD.jpg
processing: ./raw/train/messy/240_F_1106456_CWGv0EGm9WYZDk6qMFlme1GHn1umfX.jpg
processing: ./raw/train/messy/240_F_667115220_ClfT9KTOe1oGc5A3qgMFlFf6jnqN6eLQ.jpg
./raw/val/clean/*.jpg
./raw/val/clean/*.png
processing: ./raw/val/clean/240_F_336510426_T8r3jTg31kgoyoMGULULyl1iJDCQgIvz.jpg
processing: ./raw/val/clean/240_F_577688677_LdbPdKZquQbRAMNCTW3AefgugpygVvol.jpg
processing: ./raw/val/clean/240_F_221509408_ekUpElITcvwAGKDIanHpIDLuEZ5Jdfy7.jpg
processing: ./raw/val/clean/240_F_712950271_KvoXgEfDnp5RroR3ORkIrGxEZfP5aDT4.jpg
processing: ./raw/val/clean/240_F_333004296_vaaERKXMcpAU4q7GhQKoP3s6YHwkN4n9.jpg
processing: ./raw/val/clean/240_F_95558174_cwyIq9NMleDmHSeZZZwjQdAhHe9MuzsN.jpg
processing: ./raw/val/clean/240_F_124351325_GSNxP0OAkyIlZIIVS

In [1]:
import os
from glob import glob
import cv2 as cv
import numpy as np

categories = ['clean', 'messy']
data_dir = {'train': './images/train', 'val': './images/val'}


def load_data():
    """96 images per class in training set, 10 images per class in validation set"""

    x_train = list()
    x_val = list()
    y_train = list()
    y_val = list()

    for dataset, path in data_dir.items():
        for i, cat in enumerate(categories):
            cur_dir = os.path.join(path, cat)
            filenames = glob(os.path.join(cur_dir, '*.png'))
            for file in filenames:
                img = cv.imread(file)
                if dataset == 'train':
                    x_train.append(img)
                    y_train.append(i)
                else:
                    x_val.append(img)
                    y_val.append(i)

    x_train = np.asarray(x_train)
    y_train = np.asarray(y_train)
    x_val = np.asarray(x_val)
    y_val = np.asarray(y_val)

    return (x_train, y_train), (x_val, y_val)


(x_train, y_train), (x_val, y_val) = load_data()

# normalize data
channel_mean = np.mean(x_train, axis=(0, 1, 2))
channel_std = np.std(x_train, axis=(0, 1, 2))

x_train = x_train.astype('float32')
x_val = x_val.astype('float32')

for i in range(3):
    x_train[:, :, :, i] = (x_train[:, :, :, i] - channel_mean[i]) / channel_std[i]
    x_val[:, :, :, i] = (x_val[:, :, :, i] - channel_mean[i]) / channel_std[i]


In [2]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import Xception, EfficientNetB0, MobileNetV3Large

2024-02-11 21:39:21.047946: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2024-02-11 21:39:21.343550: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-02-11 21:39:21.343605: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-02-11 21:39:21.382308: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-02-11 21:39:21.443481: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2024-02-11 21:39:21.444209: I tensorflow/core/platform/cpu_feature_guard.cc:1

In [3]:
# define data augmentation
datagen = ImageDataGenerator(
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True
)

augmented_data = x_train.copy()
train_labels = y_train.copy()

# flow in advance, get augmented training data and corresponding labels
for i in range(19):
    for img, label in datagen.flow(x_train, y_train, batch_size=192):
        print(i)
        augmented_data = np.vstack((augmented_data, img))
        print(augmented_data.shape)
        train_labels = np.hstack((train_labels, label))
        print(train_labels.shape)
        break

0
(384, 299, 299, 3)
(384,)
1
(576, 299, 299, 3)
(576,)
2
(768, 299, 299, 3)
(768,)
3
(960, 299, 299, 3)
(960,)
4
(1152, 299, 299, 3)
(1152,)
5
(1344, 299, 299, 3)
(1344,)
6
(1536, 299, 299, 3)
(1536,)
7
(1728, 299, 299, 3)
(1728,)
8


: 