In [86]:
import glob
import numpy as np
import matplotlib.pyplot as plt
import cv2
from PIL import Image

In [None]:
data_folder = '/opt/ml/input/data/train/new_imgs/'
# One sub-folder per subject, images directly inside it. Strip the trailing
# slash so the pattern does not contain '//'.
pattern = data_folder.rstrip('/') + '/*/*.jpg'
# glob returns files in arbitrary, filesystem-dependent order; sort so that
# "the first 10 images" are the same on every run.
filenames = sorted(glob.glob(pattern))
assert filenames, f'no .jpg files found under {data_folder}'
image_paths = filenames[:10]

# Preview the first image; the context manager closes the file handle once
# the pixels have been copied into the numpy array.
with Image.open(image_paths[0]) as preview:
    plt.imshow(np.array(preview))
plt.show()

In [None]:
# Number of sample images to load; the 2x2 grid below shows exactly 4.
n_images = 4
assert len(image_paths) >= n_images, f'need {n_images} images, found {len(image_paths)}'


def encode_label(path):
    """Return the 18-way one-hot label (list of ints) for one image path.

    The parent folder name is split on '_': field 1 is read as the gender
    and the last field as an integer age. The file name is searched for
    'incorrect' / 'normal' to get the mask state. (Folder/file naming is
    assumed from this parsing -- TODO confirm against the dataset spec.)

    Class index = mask_label * 6 + gender_label * 3 + age_label.
    """
    parts = path.split('/')
    folder_fields = parts[-2].split('_')
    file_name = parts[-1]

    # gender: male 0, anything else (female) 1
    gender_label = 0 if folder_fields[1] == 'male' else 1

    # age: young (<30) 0, middle (30-59) 1, old (60+) 2
    age = int(folder_fields[-1])
    if age < 30:
        age_label = 0
    elif age < 60:
        age_label = 1
    else:
        age_label = 2

    # mask: worn 0, incorrectly worn 1, not worn ("normal") 2
    if 'incorrect' in file_name:
        mask_label = 1
    elif 'normal' in file_name:
        mask_label = 2
    else:
        mask_label = 0

    one_hot = [0] * 18
    one_hot[mask_label * 6 + gender_label * 3 + age_label] = 1
    return one_hot


image_batch = []
image_batch_labels = []
for path in image_paths[:n_images]:
    # Context manager closes the file once the pixels are copied.
    with Image.open(path) as img:
        image_batch.append(np.array(img))
    print(path)
    image_batch_labels.append(encode_label(path))

# Stack into arrays: images -> (n, H, W, C) when all images share a shape,
# labels -> (n, 18).
image_batch = np.array(image_batch)
image_batch_labels = np.array(image_batch_labels)

print()
print(f"Image labels: {image_batch_labels}\n")

for i in range(2):
    for j in range(2):
        plt.subplot(2, 2, 2 * i + j + 1)
        plt.imshow(image_batch[2 * i + j])
plt.show()

In [89]:
def vertical_bbox(size):
    """ Return the fixed bounding box covering the left half of an image.

    The box is deterministic (not random): full image height, columns
    ``[0, W // 2)``. For odd widths the box is slightly less than half.

    Args:
        - size: image shape ``(H, W, ...)`` as given by ``ndarray.shape``;
          only ``size[0]`` (height) and ``size[1]`` (width) are read.
    Returns:
        - ``(bbx1, bby1, bbx2, bby2)`` = ``(0, 0, W // 2, H)``: top-left and
          bottom-right corners, usable as exclusive slice ends, e.g.
          ``image[bby1:bby2, bbx1:bbx2]``.
    """
    height = size[0]
    width = size[1]

    bbx1, bby1 = 0, 0
    bbx2, bby2 = width // 2, height

    return bbx1, bby1, bbx2, bby2

In [None]:
# Read an image
image = np.array(Image.open(image_paths[i]))

# Crop a vertical bounding box
size = image.shape
bbox = vertical_bbox(size)

# Draw bounding box on the image
im = image.copy()
x1 = bbox[0]
y1 = bbox[1]
x2 = bbox[2]
y2 = bbox[3]
cv2.rectangle(im, (x1, y1), (x2, y2), (255, 0, 0), 3)
plt.imshow(im)
plt.title('Original image with random bounding box')
plt.show()

# Show cropped image
plt.imshow(image[y1:y2, x1:x2])
plt.title('Cropped image')
plt.show()

In [98]:
def generate_cutmix_image(image_batch, image_batch_labels, rng=None):
    """ Generate a CutMix augmented batch from a batch of images.

    Each image gets the region returned by ``vertical_bbox`` replaced with the
    same region of a partner image, chosen by a random permutation of the
    batch (an image may be paired with itself). Labels are mixed in
    proportion to the pixel area each source contributes.

    Args:
        - image_batch: numpy array of images indexed ``[n, y, x, channel]``
          (the slicing below requires 4 dimensions).
        - image_batch_labels: labels whose first axis matches ``image_batch``
          (e.g. one-hot rows); anything ``np.asarray`` accepts.
        - rng: optional object with a ``permutation(n)`` method such as
          ``np.random.default_rng(seed)``; ``None`` uses the global
          ``np.random`` state (previous behaviour).
    Returns:
        - CutMix image batch (a new array; the input is not modified), and
          float labels ``lam * labels + (1 - lam) * labels[rand_index]``,
          where ``lam`` is the fraction of each image left untouched.
    """
    labels = np.asarray(image_batch_labels)
    if len(image_batch) == 0:
        # Nothing to mix; keep the same return types as the non-empty path.
        return image_batch.copy(), labels * 1.0

    # generate mixed sample
    permutation = np.random.permutation if rng is None else rng.permutation
    rand_index = permutation(len(image_batch))
    target_a = labels
    target_b = labels[rand_index]

    height, width = image_batch[0].shape[:2]
    bbx1, bby1, bbx2, bby2 = vertical_bbox(image_batch[0].shape)
    image_batch_updated = image_batch.copy()
    image_batch_updated[:, bby1:bby2, bbx1:bbx2, :] = image_batch[rand_index, bby1:bby2, bbx1:bbx2, :]

    # Weight labels by the actual pasted area (as in the CutMix paper). The
    # half-width box is exactly 50% only for even widths, because W // 2
    # rounds down, so a hard-coded 0.5 mislabels odd-width images.
    box_area = (bbx2 - bbx1) * (bby2 - bby1)
    lam = 1.0 - box_area / float(height * width)
    label = target_a * lam + target_b * (1.0 - lam)

    return image_batch_updated, label

In [99]:
# Seed the global RNG so the partner permutation, and therefore the CutMix
# images and mixed labels shown below, is reproducible on Restart & Run All.
np.random.seed(42)

# First original image (not used by the CutMix call below).
input_image = image_batch[0]
image_batch_updated, image_batch_labels_updated = generate_cutmix_image(image_batch, image_batch_labels)

In [None]:
def show_image_grid(images, title, n_rows=2, n_cols=2):
    """Plot the first ``n_rows * n_cols`` images on one figure titled ``title``."""
    # squeeze=False keeps `axes` 2-D even for a 1x1 grid, so `.flat` always works.
    fig, axes = plt.subplots(n_rows, n_cols, squeeze=False)
    for ax, img in zip(axes.flat, images):
        ax.imshow(img)
    fig.suptitle(title)
    plt.show()


# Show the original and CutMix batches with the same layout for comparison.
show_image_grid(image_batch, "Original Images")
show_image_grid(image_batch_updated, "CutMix Images")

# Print labels
print('Original labels:')
print(image_batch_labels)
print('Updated labels:')
print(image_batch_labels_updated)