This is not your usual import list, so check that all of these libraries are installed in your environment first.

In [1]:
from PIL import Image
import rasterio
import os
import numpy as np
import xml.etree.ElementTree as ET
import copy
import random
from shutil import copyfile, copy2, move
import shutil
from rasterio.warp import reproject, calculate_default_transform, Resampling
from rasterio.crs import CRS

When working with raw satellite imagery, we find that some values are extremely high or extremely low, which makes visualization difficult. As such, we construct a normalization function that scales each band so that its 97th-percentile value maps to 255, and clips anything above that. We do this so that we can construct JPEG images, for which RGB values must be between 0 and 255.

In [2]:
def normalize(numpy_array):
    mid_range = (np.percentile(numpy_array, 97))
    numpy_array = numpy_array / mid_range * 255
    numpy_array = np.clip(numpy_array, a_min = 0, a_max = 255)
    return numpy_array
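
To see what the percentile stretch does, here is a minimal sketch on a made-up band. The array, its value range and the share of outliers are all assumptions chosen for illustration; `normalize` is repeated only so the snippet runs on its own.

```python
import numpy as np

def normalize(numpy_array):
    # same logic as the cell above: scale so the 97th percentile maps to 255, then clip
    mid_range = np.percentile(numpy_array, 97)
    numpy_array = numpy_array / mid_range * 255
    return np.clip(numpy_array, a_min=0, a_max=255)

# hypothetical band: mostly moderate values plus a few extreme outliers
rng = np.random.default_rng(0)
band = rng.integers(200, 600, size=(100, 100)).astype(float)
band[:2, :] = 10000.0  # 2% of the pixels are saturated outliers

scaled = normalize(band)

# the outliers are clipped to 255 instead of squashing everything else towards 0
print(scaled.min(), scaled.max())
print('share of pixels at 255:', (scaled == 255).mean())
```

Without the percentile step, dividing by the raw maximum would leave almost every pixel in this band close to black.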

This function does two things:
- it reads in the .tif file and reprojects it to a CRS measured in meters (here, EPSG:3857, Web Mercator)
- it saves a copy in .jpg format in the same folder as the original image
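
As a rough illustration of what "a CRS measured in meters" means, the sketch below applies the spherical Web Mercator formula to a single longitude / latitude pair. This is not how the function above does it (rasterio handles the whole raster for us); `lonlat_to_web_mercator` is a hypothetical helper written only for this example.

```python
import math

# radius of the sphere used by EPSG:3857, in meters
R = 6378137.0

def lonlat_to_web_mercator(lon_deg, lat_deg):
    """Convert degrees of longitude / latitude to Web Mercator meters."""
    x = R * math.radians(lon_deg)
    y = R * math.log(math.tan(math.pi / 4 + math.radians(lat_deg) / 2))
    return x, y

# the equator at the antimeridian sits half the sphere's circumference east of the origin
x_edge, y_edge = lonlat_to_web_mercator(180.0, 0.0)
print(x_edge, y_edge)
```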

In [38]:
def TifftoJPEGreproj(tiff):
    
    # open the tiff using rasterio, assign to variable 'raw'
    with rasterio.open(tiff) as raw:
        
        # get the raw data as a numpy array with shape (bands, rows, columns)
        arr = raw.read()
        
        # define the destination CRS (EPSG:3857, Web Mercator)
        dest_CRS = CRS.from_epsg(3857)
        
        # we work out here what the destination width, height and affine will be if reprojected to dest_CRS
        # from the existing crs
        dst_affine, dst_width, dst_height = calculate_default_transform(
        raw.crs, dest_CRS, raw.width, raw.height, *raw.bounds)
        
        # we make an empty array of this shape 
        x = np.empty([arr.shape[0],dst_height, dst_width])
        
        # then we go ahead and do the reprojection, using the empty array and the new affine transformation, 'dst_affine'
        reproject(source=arr,
              src_crs = raw.crs, 
              src_transform = raw.transform,
              destination = x,
              dst_transform = dst_affine,
              dst_crs = dest_CRS,
              resampling=Resampling.nearest,
              num_threads=2)
        
        # we pick out the red, green and blue bands from the .tif, which is known ex-ante to be an 8-band 
        # World View 2 or World View 3 image
        r, g, b = x[4], x[2], x[1]
        
        # we apply the previously defined normalize function to these bands so that we can generate our JPEG image 
        r = normalize(r)
        g = normalize(g)
        b = normalize(b)
        
        # we stack the data as unsigned 8 bit integers and pass this off rasterio to generate the JPEG
        outData = np.stack((r,g,b), axis = 0)
        outData = outData.astype(np.uint8)
        
        # file naming - we take the original file path and swap the extension for '.jpg',
        # so the JPEG lands in the same folder as the source .tif
        outFile = os.path.splitext(tiff)[0] + '.jpg'
        
        # generate the JPEG and write the data to it; the 'with' block closes the file for us
        with rasterio.open(outFile, 'w', driver="JPEG", dtype='uint8',
                           count=outData.shape[0], height=outData.shape[1],
                           width=outData.shape[2]) as jpeg:
            jpeg.write(outData)

In this block, we set the source of the raw imagery as the variable 'pth'. Then, we walk through looking for .tif files, and apply the TifftoJPEGreproj function to each. 

In [41]:
pth = r'D:\WWTP\GBDX_images\MechaTurk'

for root, dirs, files in os.walk(pth, topdown=False):
    for name in files:
        fil = os.path.join(root, name)
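        if os.path.splitext(name)[1].lower() == '.tif':
            try:
                TifftoJPEGreproj(fil)
            except Exception as e:
                # report the failure instead of silently skipping the file
                print('Could not convert %s: %s' % (fil, e))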
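
The walk itself can be tried out without any imagery. The sketch below builds a throwaway folder tree and collects the .tif paths; `find_tifs` is a hypothetical helper, and the file names are made up.

```python
import os
import tempfile

def find_tifs(top):
    """Return the full path of every .tif file under 'top'."""
    found = []
    for root, dirs, files in os.walk(top, topdown=False):
        for name in files:
            if os.path.splitext(name)[1].lower() == '.tif':
                found.append(os.path.join(root, name))
    return sorted(found)

# build a small temporary tree: two .tif files (one in a subfolder) and one other file
with tempfile.TemporaryDirectory() as top:
    os.makedirs(os.path.join(top, 'sub'))
    for rel in ('a.tif', 'notes.txt', os.path.join('sub', 'b.tif')):
        open(os.path.join(top, rel), 'w').close()
    tifs = [os.path.relpath(p, top) for p in find_tifs(top)]

print(tifs)
```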

  


In this second phase, we take the newly generated JPEG images and manipulate them to generate additional training data. First we define a class for the images, IMAGE, and two helper classes, BBOX and DIMS, that work within IMAGE.

These classes are more involved than you might expect because of the .xml files. When tagging the base JPEG image, the user identifies a Waste Water Treatment Plant (WWTP) with a bounding box. This is stored in an accompanying .xml file. We do not want the user to have to re-tag rotations of the same original image! To avoid this, when generating a rotated or flipped image, we also clone and adjust the accompanying image xml. This is what is happening in the lines starting self.xmltree ...
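
For orientation, here is a sketch of the kind of annotation file the classes below expect. The sample is hypothetical: its tag layout is inferred from how BBOX and DIMS index into the tree, and the file name and coordinates are made up.

```python
import xml.etree.ElementTree as ET

# hypothetical annotation for one tagged image
sample = """
<annotation>
    <filename>plant_01.jpg</filename>
    <path>plant_01.jpg</path>
    <size>
        <width>1024</width>
        <height>768</height>
        <depth>3</depth>
    </size>
    <object>
        <name>WWTP</name>
        <bndbox>
            <xmin>100</xmin>
            <ymin>200</ymin>
            <xmax>300</xmax>
            <ymax>350</ymax>
        </bndbox>
    </object>
</annotation>
"""

root = ET.fromstring(sample)

# image dimensions live under <size>
size = root.find('size')
width, height = int(size.find('width').text), int(size.find('height').text)

# each <object> is one tagged bounding box
boxes = []
for obj in root.findall('object'):
    bnd = obj.find('bndbox')
    # look tags up by name rather than by position, so extra child tags do not matter
    box = tuple(int(bnd.find(tag).text) for tag in ('xmin', 'ymin', 'xmax', 'ymax'))
    boxes.append((obj.find('name').text, box))

print(width, height, boxes)
```

Note that BBOX reads the children by position (`element[0]`, `e[0]` ... `e[3]`), so it relies on the tags appearing in exactly this order.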

### Image Rotation and Flipping

In [8]:
class BBOX:
    def __init__(self, element):
        self.name = element[0].text
        for e in element:
            if e.tag == 'bndbox':
                self.xmin = e[0].text
                self.ymin = e[1].text
                self.xmax = e[2].text
                self.ymax = e[3].text
    
    def shape(self):
        return 'Class: %s, Bounding box: (%s, %s, %s, %s)' % (self.name, self.xmin, self.ymin, self.xmax, self.ymax)
    
class DIMS:
    def __init__(self, element):
        for e in element:
            if e.tag == 'width':
                self.width = e.text
            elif e.tag == 'height':
                self.height = e.text
            elif e.tag == 'depth':
                self.depth = e.text
        self.dims = ('Shape: %s x %s, depth: %s' % (self.width, self.height, self.depth))

# The main base class for re-imported JPEG images
class IMAGE:
    
    # user is only able to instantiate an instance if given a file name, fname, and filepath, path    
    def __init__(self, fname, path):
        
        # set filename and filepath as attributes 
        self.fname = fname
        self.path = path
        
        # the newly modified images will be saved to a new folder called 'Output' which is a subfolder of the original location
        # if this doesn't already exist, generate it here
        self.outpath = os.path.join(self.path, 'Output')
        os.makedirs(self.outpath, exist_ok=True)
            
        # the xml describing tagged location of WWTPs should have the same name but different extension vs. the JPEG file
        self.xml = os.path.join(path,self.fname.replace('.jpg','.xml'))
        
        # make an output name and location for the modified xml
        self.outxml = os.path.join(self.outpath, self.fname.replace('.jpg','.xml'))
        
        # read in the original xml
        self.xmltree = ET.parse(self.xml)
        self.root = self.xmltree.getroot()
        
        # read in the original image 
        self.image = Image.open(os.path.join(self.path,self.fname))
        
        # save the original image and the original xml to the output folder - so Output folder contains everything necessary
        self.image.save(os.path.join(self.outpath, fname))
        self.xmltree.write(self.outxml)
        
        # open up an empty list for the bounding box info
        self.bbox_list = []
        
        # find all objects in the xml tree called 'object'. These are the bounding boxes of the WWTP elements in the image. 
        bboxes = self.root.findall('object')
        self.dims = DIMS(self.root.find('size'))
        
        # add bounding boxes to the list of bounding boxes. Note the use of the helper BBOX class
        for b in bboxes:
            self.bbox_list.append(BBOX(b))
        
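    # Here we define a function for rotating a single bounding box through 90 degrees
    # (counter-clockwise, to match PIL's rotate). 'count' picks the dimension to measure from:
    # 1 uses the original image width (1st and 3rd turn), anything else uses the original height (2nd turn)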
    def Rotate90(self, bbox, dim, count):
        h1,k1 = int(bbox.xmin), int(bbox.ymin)
        h2,k2 = int(bbox.xmax), int(bbox.ymax)
        if count == 1:
            o = int(dim.width)
        else:
            o = int(dim.height)
        bbox.xmin, bbox.ymin = k1, (o - h1)
        bbox.xmax, bbox.ymax = k2, (o - h2)
        return bbox
    
    # this is a method of the IMAGE class, and so can only be called on an IMAGE instance
    def CreateRotatedImages(self, degrees):
        
        # for the rotated image we deepcopy the xml tree
        newtree = copy.deepcopy(self.xmltree)
        root = newtree.getroot()
        
        ### Adjust JPEG
        
        # Happily, there is an easy PIL function for rotating a PIL image...so we don't reinvent the wheel
        rim = self.image.rotate(degrees, expand = True)
        
        # we rename the image with the number of degrees it has been rotated through
        newname = self.fname.replace('.jpg','_%s.jpg' % degrees)
        
        # ... and save it down
        rim.save(os.path.join(self.outpath, newname))
        
        ### Adjust XML
        
        # There may be many bounding boxes to iterate through - so we set up a loop that goes through all 
        # 'object' objects in the xml tree
        for obj in root.iter('object'):
            i = BBOX(obj)
            
            # we are quite lazy. So, for rotatings of more than 90 degrees, we just call the rotate 90 function more than once!
            if degrees == 90:
                newbbox = self.Rotate90(i, self.dims, 1)
                
            # God I know this is lazy but it works
            elif degrees == 180:
                newbbox = self.Rotate90(i, self.dims, 1)
                newbbox = self.Rotate90(newbbox, self.dims, 2)
                
            # Stop I know already but it works
            elif degrees == 270:
                newbbox = self.Rotate90(i, self.dims, 1)
                newbbox = self.Rotate90(newbbox, self.dims, 2)
                newbbox = self.Rotate90(newbbox, self.dims, 1)
                
            # anything else would leave newbbox undefined, so we fail loudly
            else:
                raise ValueError('degrees must be 90, 180 or 270, got %s' % degrees)
                
            # Here we go ahead and adjust the text in the xml itself using the dimensions of the newbbox
            # (the bit above is calculating the correct coordinates of the new bbox, but not doing anything with it)
            # Here, we deploy it into our new xml
            for t in obj:
                if t.tag == 'bndbox':
                    t[0].text = str(min(newbbox.xmin, newbbox.xmax))
                    t[1].text = str(min(newbbox.ymin, newbbox.ymax))
                    t[2].text = str(max(newbbox.xmax, newbbox.xmin))
                    t[3].text = str(max(newbbox.ymax, newbbox.ymin))
        
        ### Adjust metadata of image as captured in the xml - which has a record of not 
        # just tagged objects, but also the image it relates to...
        
        # adjusted base image dimensions...
        if degrees == 90  or degrees == 270:
            for obj in root.iter("size"):
                w = obj.find("width")
                w.text = str(self.dims.height)
                h = obj.find("height")
                h.text = str(self.dims.width)
                
        # Adjusted base image filename...
        for obj in root.iter("filename"):
            obj.text = str(newname)
        for obj in root.iter("path"):
            obj.text = str(os.path.join(self.outpath, newname))
        
        # Write new file out
        newtree.write(self.outxml.replace(".xml", "_%s.xml" % str(degrees)))
    
    # in addition to rotating images, we also want to flip them on their axes - left:right, and top:bottom
    # for this we need a transpose function for the bounding boxes, which is a bit easier to write than a rotational one
    def Transpose(self, bbox, dim, flip):
        h1,k1 = int(bbox.xmin), int(bbox.ymin)
        h2,k2 = int(bbox.xmax), int(bbox.ymax)
        if flip == 'lr':
            o = int(self.dims.width)
            bbox.xmin, bbox.ymin = (o - h1), k1
            bbox.xmax, bbox.ymax = (o - h2), k2
        elif flip == 'tb':
            o = int(self.dims.height)
            bbox.xmin, bbox.ymin = h1, (o - k1)
            bbox.xmax, bbox.ymax = h2, (o - k2)
        return bbox
    
    # this function deals with creating flipped images using the Transpose function above for associated bounding boxes
    def CreateFlippedImages(self, flip):
        
        newtree = copy.deepcopy(self.xmltree)
        root = newtree.getroot()
        
        ### Adjust JPEG
        # again, PIL has done the hard part for us - so we use their functions for doing the JPEG
        if flip == 'lr':
            flim = self.image.transpose(method = Image.FLIP_LEFT_RIGHT)
        elif flip == 'tb':
            flim = self.image.transpose(method = Image.FLIP_TOP_BOTTOM)
        newname = self.fname.replace('.jpg','_%s.jpg' % flip)
        flim.save(os.path.join(self.outpath, newname))
        
        ### Adjust XML
        # we are left with the headache of adjusting the corresponding xml bounding boxes and image metadata
        for obj in root.iter('object'):
            i = BBOX(obj)
            # Transpose handles both 'lr' and 'tb', so a single call covers either flip
            newbbox = self.Transpose(i, self.dims, flip)
            for t in obj:
                if t.tag == 'bndbox':
                    t[0].text = str(min(newbbox.xmin, newbbox.xmax))
                    t[1].text = str(min(newbbox.ymin, newbbox.ymax))
                    t[2].text = str(max(newbbox.xmax, newbbox.xmin))
                    t[3].text = str(max(newbbox.ymax, newbbox.ymin))
                
        # Adjust filename 
        for obj in root.iter("filename"):
            obj.text = str(newname)
        for obj in root.iter("path"):
            obj.text = str(os.path.join(self.outpath, newname))
        
        # Write new file
        newtree.write(self.outxml.replace(".xml", "_%s.xml" % str(flip)))
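
The bounding box arithmetic is easy to get wrong, so it is worth checking on its own. The sketch below restates the Rotate90 and Transpose logic as two standalone functions working on plain tuples; `rotate90_ccw` and `flip` are hypothetical helpers, and the image size and box are made up. PIL's rotate turns the image counter-clockwise, which is why each corner (x, y) moves to (y, width - x).

```python
def rotate90_ccw(box, width):
    """Rotate (xmin, ymin, xmax, ymax) by 90 degrees counter-clockwise.

    'width' is the width of the image before this turn.
    """
    xmin, ymin, xmax, ymax = box
    # each corner (x, y) moves to (y, width - x)
    x1, y1 = ymin, width - xmin
    x2, y2 = ymax, width - xmax
    # re-sort so that the min corner really is the min corner
    return (min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2))

def flip(box, width, height, how):
    """Mirror a box left:right ('lr') or top:bottom ('tb')."""
    xmin, ymin, xmax, ymax = box
    if how == 'lr':
        return (width - xmax, ymin, width - xmin, ymax)
    if how == 'tb':
        return (xmin, height - ymax, xmax, height - ymin)
    raise ValueError("how must be 'lr' or 'tb'")

W, H = 100, 80            # hypothetical image size
box = (10, 20, 40, 50)    # hypothetical bounding box

r90 = rotate90_ccw(box, W)     # (20, 60, 50, 90); the image is now H wide and W tall
r180 = rotate90_ccw(r90, H)    # (60, 30, 90, 60)
r270 = rotate90_ccw(r180, W)   # (30, 10, 60, 40)
r360 = rotate90_ccw(r270, H)   # back to (10, 20, 40, 50)

print(r90, r180, r270, r360)
```

Two sanity checks fall out of this: four quarter turns return the original box, and a top:bottom flip followed by a left:right flip matches the 180 degree rotation.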

### Run Image creation tool
Here, we iterate through each .jpg that has a matching .xml, rotating it through 90, 180 and 270 degrees, and also flipping it top to bottom and left to right. Together with the copy of the original, each jpg / xml pair therefore yields 12 files in the Output folder (6 images and 6 xmls).

In [10]:
# filepath containing pre-tagged xml / jpeg pairs
pth = r'D:\WWTP\GBDX_images\Mexico'

# we iterate through this and apply our methodologies described above
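for root, dirs, files in os.walk(pth, topdown=False):
    # skip previously generated results, so a re-run does not try to augment the augmented images
    if os.path.basename(root) == 'Output':
        continue
    for name in files:
        stem, ext = os.path.splitext(name)
        # only process JPEGs that have a matching .xml annotation alongside them
        if ext == '.jpg' and os.path.exists(os.path.join(root, stem + '.xml')):
            # pass the folder the file was actually found in
            image = IMAGE(name, root)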
            image.CreateRotatedImages(90)
            image.CreateRotatedImages(180)
            image.CreateRotatedImages(270)
            image.CreateFlippedImages('lr')
            image.CreateFlippedImages('tb')
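
As a quick check on the count, the sketch below lists the file names one tagged image should leave in the Output folder, following the naming used in CreateRotatedImages and CreateFlippedImages. `expected_outputs` is a hypothetical helper and the file name is made up.

```python
def expected_outputs(jpg_name):
    """List the files one tagged image should leave in Output."""
    stem = jpg_name[:-len('.jpg')]
    # '' is the untouched copy of the original, the rest are the rotations and flips
    suffixes = ['', '_90', '_180', '_270', '_lr', '_tb']
    return [stem + s + ext for s in suffixes for ext in ('.jpg', '.xml')]

names = expected_outputs('plant_01.jpg')
print(len(names))
print(names)
```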

### Move files into Tensorflow Suite with random Test / Train split
This is a helper block that walks through the Output folder and copies each image / xml pair into either the training folder or the test folder. The test:train split is determined by a random integer between 1 and 100: a pair goes to the test folder only if the number is greater than 75, so roughly a quarter of the pairs end up in test.

In [2]:
PATH = r'D:\WWTP\GBDX_images\Mexico\Output'
outtrain = r'C:\tensorflow1\models\research\object_detection\images\train'
outtest = r'C:\tensorflow1\models\research\object_detection\images\test'

In [3]:
for root, dirs, files in os.walk(PATH, topdown=False):
    for name in files:
        # name of the xml annotation that pairs with this image
        xml = name.replace('.jpg','.xml')
        if os.path.splitext(name)[1] == '.jpg':
            # skip pairs that were already copied on a previous run
            if os.path.exists(os.path.join(outtest, name)) or os.path.exists(os.path.join(outtrain, name)):
                continue
            # roughly 25% of pairs go to test, the rest to train
            if random.randint(1,100) > 75:
                dest = outtest
            else:
                dest = outtrain
            shutil.copyfile(os.path.join(root, name), os.path.join(dest, name))
            shutil.copyfile(os.path.join(root, xml), os.path.join(dest, xml))
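
To get a feel for the split this rule produces, here is a small simulation that touches no files. `assign_split` is a hypothetical helper, and the seed and sample size are arbitrary choices.

```python
import random

def assign_split(rng):
    """Apply the same rule as above: 76..100 goes to test, 1..75 goes to train."""
    return 'test' if rng.randint(1, 100) > 75 else 'train'

# a seeded generator keeps the simulation repeatable
rng = random.Random(0)
splits = [assign_split(rng) for _ in range(10000)]

test_share = splits.count('test') / len(splits)
print(round(test_share, 3))
```

The share will wobble around 0.25 from run to run, and more so for small datasets, because each pair is assigned independently.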