In [1]:
import os
from os import listdir

import cv2
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Rectangle
from PIL import Image, ImageOps, ImageFilter, ImageEnhance

class_ = 'poisonous'

# Raw (pre-augmentation) images and their annotation files for the selected class.
DATA_ROOT = 'drive/MyDrive/mushroom-photos'
img_folder = '{}/original-data/{}'.format(DATA_ROOT, class_)
bbox_folder = '{}/original-data/{}-bbox'.format(DATA_ROOT, class_)

# Seed NumPy's global RNG: the noise, the translation offsets and the later
# shuffle all draw from it, so seeding here makes a re-run reproducible.
SEED = 42
np.random.seed(SEED)

In [16]:
# os.listdir returns entries in arbitrary order; sort so the processing order
# (and therefore the output of any seeded randomness) is reproducible.
imgFiles = sorted(file for file in listdir(img_folder) if file.endswith('.jpg'))

In [3]:
def get_bbox(txt_file):
  """Read the bounding boxes stored in an annotation ``.txt`` file.

  Each non-blank line holds whitespace-separated numbers; in this notebook
  they are unpacked as ``label, x, y, w, h`` (see ``draw_bounding_box`` /
  ``translate``).

  Args:
    txt_file: path to the annotation file.

  Returns:
    np.ndarray with one float row per box. An empty file yields an empty
    array.

  Raises:
    FileNotFoundError: if ``txt_file`` does not exist (the driver cells catch
      this to skip images without annotations).
  """
  bbox_ls = []

  with open(txt_file, 'r') as f:
    for ln in f:
      fields = ln.split()
      # Skip blank / whitespace-only lines rather than stopping at the first
      # one, so a stray empty line does not silently drop the boxes after it.
      if not fields:
        continue
      bbox_ls.append(np.array(fields, dtype=float))

  return np.array(bbox_ls)

In [4]:
def draw_bounding_box(img, bbox_ls):
  """Display ``img`` with its bounding boxes overlaid as red rectangles.

  Args:
    img: PIL Image to display.
    bbox_ls: rows ``(label, x, y, w, h)`` where (x, y) is the box centre and
      all four coordinates are fractions of the image width/height (they are
      scaled by ``img.size`` below).

  Returns:
    None; the figure is rendered with ``plt.show()``.
  """
  pic_width, pic_height = img.size
  plt.imshow(img)
  ax = plt.gca()  # every box goes on the same axes; look it up once

  for label, x, y, w, h in bbox_ls:
    # Normalised centre/size -> pixel units.
    x, y, w, h = pic_width*x, pic_height*y, pic_width*w, pic_height*h
    # True division: ``//`` floors the float half-size and shifts the drawn
    # box by up to one pixel.
    upper_left_corner = (x - w/2, y - h/2)

    rect = Rectangle(upper_left_corner, w, h, linewidth=1,
                     edgecolor='r', facecolor='None')
    ax.add_patch(rect)

  plt.axis('off')
  plt.show()

# Data augmentation

In [5]:
def add_noise(pic, amplitude=200, rng=None):
  """Return a copy of ``pic`` with uniform integer noise added to each value.

  Args:
    pic: PIL Image (converted with ``np.array``; 8-bit channels assumed —
      the result is clipped to [0, 255] and cast to uint8).
    amplitude: noise is drawn independently per pixel/channel from the
      integers in ``[-amplitude, amplitude)``. Must be positive. The default
      of 200 reproduces the original (very strong) augmentation.
    rng: optional ``np.random.Generator`` for reproducible noise. When None,
      NumPy's global RNG is used, as before.

  Returns:
    New PIL Image built from the noisy uint8 array.
  """
  pic_np = np.array(pic)
  if rng is None:
    noise = np.random.randint(-amplitude, amplitude, pic_np.shape)
  else:
    noise = rng.integers(-amplitude, amplitude, pic_np.shape)
  # uint8 + int64 noise is promoted to a signed type, so negatives survive
  # until the clip brings everything back into the 8-bit range.
  noisy_pic = np.clip(pic_np + noise, 0, 255)
  return Image.fromarray(noisy_pic.astype('uint8'))

In [7]:
def translate(pic, bbox_ls):
  """Randomly shift ``pic`` and move/clip its bounding boxes to match.

  The shift along each axis has a magnitude drawn uniformly from
  {40, 60, ..., 200} pixels and a random sign, both from NumPy's global RNG.

  Args:
    pic: PIL Image.
    bbox_ls: rows ``(label, x, y, w, h)`` (box centre and size as fractions
      of the image width/height).

  Returns:
    ``(translated_img, new_bbox_ls)``. ``new_bbox_ls`` is an np.ndarray that
    holds the boxes still overlapping the image, clipped to it (with a
    1-pixel margin) and re-normalised. Boxes with no area left are dropped.
  """
  pic_width, pic_height = pic.size

  # Random shift magnitude and direction for each axis.
  shift = np.arange(40, 201, 20)
  x_shift = np.random.choice(shift) * np.random.choice([1, -1])
  y_shift = np.random.choice(shift) * np.random.choice([1, -1])

  # PIL's AFFINE samples input pixel (x + x_shift, y + y_shift) for output
  # pixel (x, y), so the image content moves by (-x_shift, -y_shift).
  translated_img = pic.transform(pic.size, Image.AFFINE,
                                 (1, 0, x_shift, 0, 1, y_shift))

  new_bbox_ls = []
  for label, x, y, w, h in bbox_ls:
    # To pixel units, then move the box together with the content.
    x = pic_width*x - x_shift
    y = pic_height*y - y_shift
    w, h = pic_width*w, pic_height*h

    # Clip the corners to the image, keeping the original 1-pixel margin.
    x1, y1 = max(1, x - w/2), max(1, y - h/2)
    x2, y2 = min(pic_width - 1, x + w/2), min(pic_height - 1, y + h/2)

    new_w, new_h = x2 - x1, y2 - y1
    # Drop boxes with no area inside the image. The old check compared the
    # unclipped corners with 0/width and could keep boxes whose clipped
    # width/height was negative (e.g. a box ending between pixels 0 and 1).
    if new_w <= 0 or new_h <= 0:
      continue

    new_x, new_y = (x1 + x2)/2, (y1 + y2)/2
    new_bbox_ls.append(np.array([label,
                                 new_x/pic_width, new_y/pic_height,
                                 new_w/pic_width, new_h/pic_height]))

  return translated_img, np.array(new_bbox_ls)

In [8]:
def transform(pic, bbox_ls, transform_type):
  """Apply one augmentation to ``pic`` and return the matching boxes.

  Args:
    pic: PIL Image.
    bbox_ls: rows ``(label, x, y, w, h)`` (box centre and size as fractions
      of the image width/height).
    transform_type: 1 = left-right mirror, 2 = top-bottom flip,
      3 = additive noise (boxes unchanged), 4 = random translation.

  Returns:
    ``(new_pic, new_bbox_ls)``.

  Raises:
    ValueError: for any other ``transform_type``. Previously the input was
      returned unchanged, which silently produced an un-augmented duplicate.
  """
  if transform_type not in (1, 2, 3, 4):
    raise ValueError('unknown transform_type: {!r}'.format(transform_type))

  if transform_type == 1:  # mirror: x-centre -> 1 - x
    pic = ImageOps.mirror(pic)
    bbox_ls = np.array([np.array([label, 1-x, y, w, h])
                        for label, x, y, w, h in bbox_ls])

  elif transform_type == 2:  # flip: y-centre -> 1 - y
    pic = ImageOps.flip(pic)
    bbox_ls = np.array([np.array([label, x, 1-y, w, h])
                        for label, x, y, w, h in bbox_ls])

  elif transform_type == 3:  # noise: geometry unchanged
    pic = add_noise(pic)

  else:  # transform_type == 4: translation
    pic, bbox_ls = translate(pic, bbox_ls)

  return pic, bbox_ls

In [None]:
# Build 4 augmented copies (mirrored, flipped, noisy, translated) of every
# image that has an annotation file.
full_list = []
types = ['mirrored', 'flipped', 'noisy', 'translated']  # transform_type j -> types[j-1]
skipped = []  # images with no matching annotation file

for path in imgFiles:
  # splitext keeps dots inside the stem ("a.b.jpg" -> "a.b"); split('.')[0]
  # truncated such names, so their .txt lookup failed and they were dropped.
  name = os.path.splitext(path)[0]

  txt_file_path = bbox_folder + '/' + name + '.txt'
  img_path = img_folder + '/' + path

  try:
    bbox_ls = get_bbox(txt_file_path)
  except FileNotFoundError:
    skipped.append(path)
    continue

  # Every transform returns a new image, so the source file can be closed.
  with Image.open(img_path) as img:
    for j, type_name in enumerate(types, start=1):
      transformed_img, new_bbox_ls = transform(img, bbox_ls, j)
      full_list.append(('{}-{}'.format(name, type_name),
                        transformed_img, new_bbox_ls))

print('augmented {} images, skipped {} without annotations'.format(
    len(imgFiles) - len(skipped), len(skipped)))


In [None]:
# Shuffle the augmented samples in place (NumPy's global RNG; seed it earlier
# for a reproducible run) so the slice taken below is a random mix of source
# images and augmentation types.
np.random.shuffle(full_list)

In [None]:
# Number of source images that were augmented: each one contributed exactly
# one entry per augmentation type to full_list.
m = len(full_list) // len(types)
m

160

In [None]:
# Keep just enough augmented samples that, together with the m originals saved
# later, the class ends up with TARGET_TOTAL images. Clamp at 0: a negative
# bound would slice from the end of the list instead of selecting nothing.
TARGET_TOTAL = 400
selected = full_list[:max(0, TARGET_TOTAL - m)]

In [None]:
# Numeric class id for each mushroom class, in a fixed order.
correct_labels = {label_name: class_idx
                  for class_idx, label_name in enumerate(('edible', 'inedible', 'poisonous'))}

In [None]:
# Class id that will be written as the first field of every label line for
# the current class.
correct_labels[class_]

2

In [None]:
new_bbox_folder = 'drive/MyDrive/mushroom-photos/{}-bbox'.format(class_)
new_img_folder = 'drive/MyDrive/mushroom-photos/{}'.format(class_)
# Create the output folders on a fresh run instead of failing in open()/save().
os.makedirs(new_bbox_folder, exist_ok=True)
os.makedirs(new_img_folder, exist_ok=True)

class_id = correct_labels[class_]
for name, transformed_img, new_bbox_ls in selected:
  img_name = new_img_folder + '/' + name + '.jpg'
  txt_name = new_bbox_folder + '/' + name + '.txt'

  # Every box is rewritten with this folder's class id; the label read from
  # the source annotation is discarded.
  with open(txt_name, 'w') as f:
    for _label, x, y, w, h in new_bbox_ls:
      f.write(' '.join(str(v) for v in (class_id, x, y, w, h)) + '\n')

  transformed_img.save(img_name)

In [None]:
# Save the original images and rewrite their labels with this class's id.

class_id = correct_labels[class_]
os.makedirs(new_bbox_folder, exist_ok=True)
os.makedirs(new_img_folder, exist_ok=True)

skipped_originals = []
for path in imgFiles:
  # splitext keeps dots inside the stem; split('.')[0] truncated such names.
  name = os.path.splitext(path)[0]

  og_txt_name = bbox_folder + '/' + name + '.txt'
  og_img_name = img_folder + '/' + path
  img_name = new_img_folder + '/' + path
  txt_name = new_bbox_folder + '/' + name + '.txt'

  # Read the annotation before opening the image: a missing .txt used to
  # raise after Image.open, leaving that image's file handle open.
  try:
    bbox_ls = get_bbox(og_txt_name)
  except FileNotFoundError:
    skipped_originals.append(path)
    continue

  with open(txt_name, 'w') as f:
    for _label, x, y, w, h in bbox_ls:
      f.write(' '.join(str(v) for v in (class_id, x, y, w, h)) + '\n')

  with Image.open(og_img_name) as og_img:
    og_img.save(img_name)

print('saved {} originals, skipped {} without annotations'.format(
    len(imgFiles) - len(skipped_originals), len(skipped_originals)))

In [None]:
# Sanity check: count the .jpg files now present in the output image folder.
finalFiles = list(filter(lambda fname: fname.endswith('.jpg'), listdir(new_img_folder)))
len(finalFiles)

400