<a href="https://colab.research.google.com/github/zachmurphy1/facemask-faster-rcnn/blob/main/Oversampling.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Oversampling
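This notebook oversamples the under-represented classes in the training data. It first crops every annotated face out of the training images. It then pastes randomly chosen crops of each minority class into other training images (copy-paste augmentation).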

## Input
Images and annotations in the train directory.
```
facemask_data/train/images
facemask_data/train/annotations
```

## Output
Oversampled images and annotations
```
facemask_data/train/oversampling/images
facemask_data/train/oversampling/annotations
```

Instance counts by class, before oversampling
```
facemask_data/train/oversampling/train_instance_counts.pkl
```

(Intermediate) Instance crops with annotations by class
```
facemask_data/train/oversampling/no_mask/images
facemask_data/train/oversampling/no_mask/annotations

facemask_data/train/oversampling/masked/images
facemask_data/train/oversampling/masked/annotations

facemask_data/train/oversampling/incorrect/images
facemask_data/train/oversampling/incorrect/annotations
```

In [None]:
# Imports
import numpy as np
import pickle
import sys, os, shutil

from PIL import Image
from bs4 import BeautifulSoup
import torch, torchvision


In [None]:
# Mount data directory
from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)
%cd /content/gdrive/My\ Drive/facemask-faster-rcnn/

DATADIR = 'facemask_data/train'
ANNDIR = DATADIR + '/annotations'
IMGDIR = DATADIR + '/images'

n = len(next(os.walk(IMGDIR))[2])
print('# imgs:',n)
target_dir = DATADIR + '/' + 'oversampling'

Mounted at /content/gdrive
/content/gdrive/My Drive/DL Final Project
# imgs: 511


## Get instance crops

In [None]:
# Crop every annotated face out of its training image (keeping a small margin).
# Save the crop, plus a one-object annotation whose box is in crop-local
# coordinates, to target_dir/<class>/{images,annotations}/<k>.{png,xml}.
# Afterwards `counts` holds the number of saved instances per class.

# VOC class name -> output sub-folder / key in `counts`
CLASS_TO_PREFIX = {
  'without_mask': 'no_mask',
  'with_mask': 'masked',
  'mask_weared_incorrect': 'incorrect',
}

# Make dirs (makedirs also creates target_dir and target_dir/<class> as parents)
for class_dir in CLASS_TO_PREFIX.values():
  for sub in ('images', 'annotations'):
    os.makedirs(target_dir + '/' + class_dir + '/' + sub, exist_ok=True)

# Margin in pixels kept around each box (clipped at the image edges)
border = 10

counts = {'no_mask':0, 'masked':0, 'incorrect':0}
for i in range(n):
  # Images are expected to be named 0.png ... (n-1).png with matching <i>.xml files
  with Image.open(IMGDIR + '/' + str(i) + '.png') as src:
    img = src.convert('RGB')

  ann_path = ANNDIR + '/' + str(i) + '.xml'
  with open(ann_path) as f:
    ann_parsed = BeautifulSoup(f.read(), 'xml')
  width = int(ann_parsed.find('width').text)
  height = int(ann_parsed.find('height').text)

  for o in ann_parsed.find_all('object'):
    mask_class = o.find('name').text
    prefix = CLASS_TO_PREFIX.get(mask_class)
    if prefix is None:
      # Unknown label: report it and skip the object rather than crash on counts['']
      print('mask label error:', repr(mask_class), 'in', ann_path)
      continue

    xmin = int(o.find('xmin').text)
    ymin = int(o.find('ymin').text)
    xmax = int(o.find('xmax').text)
    ymax = int(o.find('ymax').text)

    # Per-side margin, shrunk so the crop stays inside the image
    left_border = min(border, xmin)
    top_border = min(border, ymin)
    right_border = min(border, width - xmax)
    bottom_border = min(border, height - ymax)

    # crop() returns a new image, so `img` itself is left untouched
    cropped = img.crop((xmin - left_border, ymin - top_border,
                        xmax + right_border, ymax + bottom_border))

    # Rewrite the box in crop-local coordinates
    o.find('xmin').string.replace_with(str(left_border))
    o.find('ymin').string.replace_with(str(top_border))
    o.find('xmax').string.replace_with(str(left_border + xmax - xmin))
    o.find('ymax').string.replace_with(str(top_border + ymax - ymin))

    # Save as <k>.png / <k>.xml where k is the running count for the class.
    # (`out_dir` rather than `dir`, so the builtin is not shadowed notebook-wide.)
    out_dir = target_dir + '/' + prefix
    stem = str(counts[prefix])
    cropped.save(out_dir + '/images/' + stem + '.png', 'PNG')
    with open(out_dir + '/annotations/' + stem + '.xml', 'w') as f:
      f.write(o.prettify())

    counts[prefix] += 1

{'no_mask': 461, 'masked': 1862, 'incorrect': 66}

In [None]:
# Persist the per-class instance counts gathered by the crop step
# (i.e. before any oversampling), then show them.
counts_path = target_dir + '/train_instance_counts.pkl'
with open(counts_path, 'wb') as fh:
  pickle.dump(counts, fh)
print(counts)

## Stitch instances into images

In [None]:
# Seed the oversampled set with fresh copies of the original train images and
# annotations; the stitching cell below edits these copies in place.
# The paths are built from the same relative dirs as the rest of the notebook,
# so crops, backgrounds and outputs all live under the project folder.
orig_img_dir = target_dir + '/images'
orig_ann_dir = target_dir + '/annotations'

# copytree raises FileExistsError if the destination exists, so remove copies
# left by a previous run. This makes the cell re-runnable and guarantees that
# stitching starts from the untouched originals.
for src_dir, dst_dir in ((IMGDIR, orig_img_dir), (ANNDIR, orig_ann_dir)):
  if os.path.isdir(dst_dir):
    shutil.rmtree(dst_dir)
  shutil.copytree(src_dir, dst_dir)

'/content/gdrive/MyDrive/DL Final Project/facemask_data/train/oversampling/annotations'

In [None]:
# Copy-paste oversampling.
# For each class, paste randomly chosen instance crops of that class (from the
# crop step) into randomly chosen background images in orig_img_dir. Stop when
# the class has gained OVERSAMPLE_FRACTION of its gap to the most frequent class.
# Each paste is:
#   - rescaled so its face width is near the mean face width of the
#     background (x0.8-1.2 jitter);
#   - randomly flipped horizontally;
#   - placed at a random position where its padded extent overlaps no
#     existing box.
# Background images and annotations are rewritten in place, so later pastes
# also avoid earlier ones.

OVERSAMPLE_FRACTION = 3/4    # fraction of the gap to the majority class to fill
MIN_SCALE, MAX_SCALE = 0.25, 4
MAX_PLACEMENT_TRIES = 1000   # rejected placements before a crop is skipped
BOX_TAGS = ('xmin', 'ymin', 'xmax', 'ymax')


def _read_voc(path):
  """Parse the XML annotation file at `path` with BeautifulSoup's XML parser."""
  with open(path) as fh:
    return BeautifulSoup(fh.read(), 'xml')


def _box(tag):
  """Return [xmin, ymin, xmax, ymax] of the first box found under `tag`, as ints."""
  return [round(float(tag.find(k).text)) for k in BOX_TAGS]


def _overlaps_any(boxes, candidate):
  """True if box `candidate` has IoU > 0 with any box in `boxes` (all [x0, y0, x1, y1])."""
  if not boxes:
    return False
  iou = torchvision.ops.box_iou(torch.Tensor(boxes), torch.Tensor([candidate]))
  return bool((iou > 0).any())


for mask_class in ['incorrect', 'no_mask', 'masked']:
  # Number of crops to paste for this class
  n_to_augment = int((max(counts.values()) - counts[mask_class]) * OVERSAMPLE_FRACTION)

  paste_img_dir = target_dir + '/' + mask_class + '/images'
  paste_ann_dir = target_dir + '/' + mask_class + '/annotations'
  _, _, paste_paths = next(os.walk(paste_img_dir))
  _, _, orig_paths = next(os.walk(orig_img_dir))
  if n_to_augment > 0 and not paste_paths:
    print('no', mask_class, 'crops found in', paste_img_dir, '; skipping class')
    continue

  # `aug_i` (not `n`) so the global image count `n` is not clobbered
  for aug_i in range(n_to_augment):
    # Randomly select the crop to paste and its face box (crop-local coords)
    paste_idx = np.random.choice(paste_paths)[:-4]
    with Image.open(paste_img_dir + '/' + paste_idx + '.png') as src:
      paste_img = src.copy()
    paste_ann = _read_voc(paste_ann_dir + '/' + paste_idx + '.xml')
    crop_bb = _box(paste_ann)
    crop_face_w = crop_bb[2] - crop_bb[0]
    if crop_face_w <= 0:
      print('degenerate box in', mask_class, 'crop', paste_idx, '; skipping')
      continue

    # Try random backgrounds/positions until one is accepted
    placed = False
    for attempt in range(MAX_PLACEMENT_TRIES):
      orig_idx = np.random.choice(orig_paths)[:-4]
      # Mode names are upper case; 'rgb' is rejected by Pillow
      with Image.open(orig_img_dir + '/' + orig_idx + '.png') as src:
        orig_img = src.convert('RGB')
      orig_ann = _read_voc(orig_ann_dir + '/' + orig_idx + '.xml')
      orig_bbs = [_box(o) for o in orig_ann.find_all('object')]
      if not orig_bbs:
        continue  # no reference face width (np.mean([]) would be nan)

      # Scale to the mean background face width, jittered, within [MIN_SCALE, MAX_SCALE]
      scale = np.mean([b[2] - b[0] for b in orig_bbs]) / crop_face_w
      scale = scale * np.random.uniform(0.8, 1.2)
      if scale > MAX_SCALE or scale < MIN_SCALE:
        continue

      paste_img_w = int(round(paste_img.width * scale))
      paste_img_h = int(round(paste_img.height * scale))
      paste_img_scaled = paste_img.resize((paste_img_w, paste_img_h))
      paste_bb = [int(x * scale) for x in crop_bb]
      paste_width = paste_bb[2] - paste_bb[0]
      paste_height = paste_bb[3] - paste_bb[1]
      face_left = paste_bb[0]

      # Random horizontal flip; mirror the face's left edge inside the crop
      if np.random.choice([True, False]):
        paste_img_scaled = paste_img_scaled.transpose(Image.FLIP_LEFT_RIGHT)
        face_left = paste_img_scaled.width - paste_bb[2]

      # Random top-left corner keeping the whole crop inside the background
      max_x = orig_img.width - paste_img_w
      max_y = orig_img.height - paste_img_h
      if max_x <= 0 or max_y <= 0:
        continue
      pos_x = np.random.randint(0, max_x)
      pos_y = np.random.randint(0, max_y)

      # Face box in background coordinates
      face_box = [pos_x + face_left,
                  pos_y + paste_bb[1],
                  pos_x + face_left + paste_width,
                  pos_y + paste_bb[1] + paste_height]

      # Reject if the crop, padded by width/20 on every side, touches any existing box
      pad = int(orig_img.width / 20)
      padded = [pos_x - pad, pos_y - pad,
                pos_x + paste_img_scaled.width + pad,
                pos_y + paste_img_scaled.height + pad]
      if _overlaps_any(orig_bbs, padded):
        continue

      placed = True
      break

    if not placed:
      print('could not place', mask_class, 'crop', paste_idx,
            'after', MAX_PLACEMENT_TRIES, 'tries; skipping')
      continue

    # Commit: paste into the background and save it
    orig_img.paste(paste_img_scaled, (pos_x, pos_y))
    orig_img.save(orig_img_dir + '/' + orig_idx + '.png', 'PNG')

    # Move the crop's <object> into the background annotation with the new box
    paste_obj = paste_ann.find('object')
    for tag_name, value in zip(BOX_TAGS, face_box):
      paste_obj.find(tag_name).string.replace_with(str(value))
    orig_ann.find('annotation').append(paste_obj)

    with open(orig_ann_dir + '/' + orig_idx + '.xml', 'w') as fh:
      fh.write(orig_ann.prettify())