In [1]:
%matplotlib inline
#!pip install fastai -q --upgrade
#!pip install pycocotools
#!pip install -U albumentations

In [2]:
#imports
import math

#from fastai import *
#from fastai.vision import *
from fastai.vision.all import * 

import matplotlib.pyplot as plt
import matplotlib.image as mpim
import matplotlib.patches as patches

from pathlib import Path
import pandas as pd
import json
import os

import torch 
import torch.nn as nn
import torch.nn.functional as F 
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import torchvision.transforms as trans

os.chdir("includes")
from engine import train_one_epoch, evaluate
import utils
os.chdir("..")

import albumentations as A
import cv2

In [3]:
#gather ressources
path=Path('images')
testpath=Path('test')

#get training/validation and test images
images=get_image_files(path)
testimages=get_image_files(testpath)

#recover annotations from csv
annotations=pd.read_csv(path/'wappen.csv')
annotations=annotations.drop(['file_size','file_attributes','region_count','region_id'],axis=1)

In [4]:
#mpl helpers
def printImg(image,ax=None,size=None):
    if ax==None:
        im,ax=plt.subplots(figsize=size)
        
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.imshow(image)
    return ax
        
def drawBB(ax, bl, tr,col):
    height=tr[1]-bl[1]
    width=tr[0]-bl[0]
    bb=patches.Rectangle(bl,width,height,fill=False,color=col,lw=2.0)
    ax.add_patch(bb)
    
def drawText(ax,x,y,text,col):
    ax.text(x,y,text,color=col,fontsize=14)
    
def printImages(images,annotation_dict):
    for image in images:
        ax=printImg(mpim.imread(image),size=[14,14])
        imgname=os.path.normpath(image).split(os.sep)[-1]
        annotations=annotation_dict[imgname]
        for bb,label in annotations:
            color=colordict[label]
            drawBB(ax,(bb[0],bb[1]),(bb[2],bb[3]),color)
            #drawText(ax,(bb[2]+bb[0])//2-8*len(label),(bb[3]+bb[1])//2-14,label,color)


In [5]:
#some global lookup stuff
colordict={"wappen":'magenta',"text":'cyan',"#na#":'green',"objekt":'red'}
clsToId={"wappen":1,"text":2,"objekt":1}
IdToClsSeperated={1:"wappen",2:"text"}
IdToClsMerged={1:"objekt"}

In [6]:
#global structures

#transforms
mytransforms = A.Compose([
    A.RandomBrightnessContrast(p=0.2),
    #A.Rotate(limit=5),
    A.Resize(1024,1024)
], bbox_params=A.BboxParams(format='pascal_voc', min_visibility=0.1, label_fields=['labels']))

testTransforms=trans.Compose([
    trans.Resize((1024,1024)),
    trans.ToTensor()
])

In [7]:
#dataset creation helpers
def selectClosestText(filename,bbox,maxYdown=100,maxYup=800,maxXdiff=300):
    #collect all text boxes that are located closely to the top of the wappen box or None if there is none
    maxY=min(bbox[1],bbox[3])
    medX=(bbox[0]+bbox[2])/2
    closestBox=None
    closestDist=0
    for tbox,cls in imtoann[filename]:
        if cls!='text':
            continue
        mdlY=(tbox[1]+tbox[3])/2
        if (mdlY-maxY)>maxYdown:
            continue
        if (maxY-mdlY)>maxYup:
            continue
        if abs((bbox[0]-tbox[0]))>maxXdiff and abs((bbox[2]-tbox[2]))>maxXdiff:
            continue
        mdlX=(tbox[0]+tbox[2])/2
        dist=(maxY-mdlY)**2 +(medX-mdlX)**2
        if closestBox==None:
            closestBox=(tbox,cls)
            closestDist=dist
        elif dist<closestDist:
            closestBox=(tbox,cls)
            closestDist=dist
            
    return closestBox

#transforms a given dictionary to a dictionary that is needed to create torch datasets
def getDatasetDictionary(inDict):
    imtoannlist={}
    for image in inDict:
        bblist=[]
        lbllist=[]
        for bb,lbl in inDict[image]:
            bblist.append(bb)
            lbllist.append(lbl)
        imtoannlist[image]=(bblist,lbllist)
    return imtoannlist

def createTorchDataset(dataset,val_pct=0.2):
    #split indices
    indices=torch.randperm(len(dataset)).tolist()
    
    lastTrain=len(dataset)-(int(len(dataset)*val_pct))

    #create troch datasets
    trainingSet=torch.utils.data.Subset(dataset, indices[:lastTrain])
    validationSet=torch.utils.data.Subset(dataset, indices[lastTrain:])
    
    return (trainingSet,validationSet)

def summarizeData(inDict):
    numImgs=len(inDict)
    numDict={}
    for key in inDict:
        for j in range(len(inDict[key][0])):
            labels=inDict[key][1]
            for i in range(len(labels)):
                if(labels[i] in numDict):
                    numDict[labels[i]]+=1
                else:
                    numDict[labels[i]]=1
            
    print("Dataset comprises "+str(numImgs)+" images")
    for key in numDict:
        print("Number of "+str(key)+": "+str(numDict[key]))
    

In [8]:
#definition of the dataset class
class EmblemTextSet(torch.utils.data.Dataset):
    def __init__(self,images,annotation_dict,transforms=None):
        self.transforms=transforms
        self.images=images
        self.dict=annotation_dict
    
    def __getitem__(self, idx):
        img = Image.open(images[idx]).convert("RGB")
        imgkey=os.path.normpath(images[idx]).split(os.sep)[-1]
        boxes=torch.tensor(self.dict[imgkey][0])
        labels=torch.tensor([clsToId[label] for label in self.dict[imgkey][1]],dtype=torch.int64)    
       
        if self.transforms!=None:
            img_transformed = self.transforms(image=np.array(img), bboxes=boxes,labels=labels)
        
        imgt=trans.ToTensor()(img_transformed['image'])
        target = {}
        target["boxes"] = torch.tensor(img_transformed['bboxes'],dtype=torch.float32)
        target["labels"] = labels
        target["image_id"] = torch.tensor([idx])
        target["area"]=torch.tensor([(b[3]-b[1])*(b[2]-b[0]) for b in target['boxes']])
        target["iscrowd"]=torch.tensor([0]*len(target['labels']), dtype=torch.int64)
    
        return imgt, target
    
    def __len__(self):
         return len(self.images)

In [9]:
#train functions

#train function that stops training process when the precision on the validation set decreases or stagnates
def train_model_prevent_overfit(model,datasets,bs,epochs,filename="./stored_models/currentModelEpochFinder",optimizer=None,lr=0.005,
                               max_decrease=0.3,max_worsening_epochs=10):
    device="cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    
    #create loaders
    train,valid=datasets
    train_loader=torch.utils.data.DataLoader(train, batch_size=bs, shuffle=True,collate_fn=utils.collate_fn)
    validation_loader=torch.utils.data.DataLoader(valid, batch_size=bs, shuffle=True,collate_fn=utils.collate_fn)
      
    params = [p for p in model.parameters() if p.requires_grad]
    
    if optimizer==None:
        optimizer = torch.optim.SGD(params, lr=lr,momentum=0.9,weight_decay=0.0005)
        #optimizer =torch.optim.Adam(params, lr=lr, betas=(0.9, 0.999), eps=1e-08)
        
    lr_scheduler =torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=4, T_mult=2)
    #lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,step_size=3,gamma=0.1)
    
    bestmAp=0.0
    mAphistory=[]
    strikes=0
    
    #training loop
    for epoch in range(epochs):
        #train on training set
        train_one_epoch(model, optimizer, train_loader, device, epoch,
                   print_freq=10)
        #adjust lr
        lr_scheduler.step()
  
        #check valid set
        coco_evaluator=evaluate(model, validation_loader, device=device)
        
        mAp=0
        for _,evaluator in coco_evaluator.coco_eval.items():
            mAp=sum(evaluator.stats[0:3])/3
            
        mAphistory.append(mAp)
            
        if(mAp<bestmAp*(1-max_decrease)):
            print("precision decreased by more than "+str(max_decrease)+" percent from best value at epoch: "+str(epoch)+" abandoning!")
            break
        
        if(mAp>bestmAp):
            strikes=0
            bestmAp=mAp
            torch.save(model.state_dict(),filename+"-bestmap")
        else:
            strikes+=1
            
        if(strikes>=max_worsening_epochs):
            print("precision on validation has not increased in "+str(max_worsening_epochs)+" epochs in epoch "+str(epoch)+" abandoning!")
            break
        
    #store trained parameters
    torch.save(model.state_dict(),filename)
    #return the epoch number with the best loss on validation set
    print(mAphistory)
    return mAphistory.index(max(mAphistory))+1
    


#default train function
def train_model(model,datasets,bs,epochs,filename="./stored_models/currentModel",optimizer=None,lr=0.005):
    device="cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    
    #create loaders
    train,valid=datasets
    train_loader=torch.utils.data.DataLoader(train, batch_size=bs, shuffle=True,collate_fn=utils.collate_fn)
    validation_loader=torch.utils.data.DataLoader(valid, batch_size=bs, shuffle=True,collate_fn=utils.collate_fn)
      
    params = [p for p in model.parameters() if p.requires_grad]
    
    if optimizer==None:
        optimizer = torch.optim.SGD(params, lr=lr,momentum=0.9,weight_decay=0.0005)
        #optimizer =torch.optim.Adam(params, lr=lr, betas=(0.9, 0.999), eps=1e-08)
        
    lr_scheduler =torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=4, T_mult=2)
    #lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,step_size=3,gamma=0.1)
    
    #training loop
    for epoch in range(epochs):
        #train on training set
        train_one_epoch(model, optimizer, train_loader, device, epoch,
                   print_freq=10)
        #adjust lr
        lr_scheduler.step()
  
        #check valid set
        evaluate(model, validation_loader, device=device)
        
    #store trained parameters
    torch.save(model.state_dict(),filename)

In [10]:
#shameless copy of train_one_epoch from torchvision engine, with slight adjustments 
#such that loss and learning rate values are stored in respective lists

def train_one_epoch_LRfind(model, optimizer, data_loader, device, epoch, 
                           lr_updater,print_freq,lrlist,losslist,stopFactor,lrfactor):
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)

    lr_scheduler = None
    if epoch == 0:
        warmup_factor = 1. / 1000
        warmup_iters = min(1000, len(data_loader) - 1)

        lr_scheduler = utils.warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor)
    
    #this loops all minibatches
    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
        
        
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)

        losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())

        loss_value = losses_reduced.item()

        if not math.isfinite(loss_value):
            #print("Loss is {}, stopping training".format(loss_value))
            #print(loss_dict_reduced)
            #sys.exit(1)
            return

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
        
        #append (smoothed) current loss to losslist
        
        curloss=float(losses.cpu().detach().numpy())
        if len(losslist)>0:
            losslist.append(0.05  * curloss + (1 - 0.05) * losslist[-1])
        else:
            losslist.append(curloss)

        curlr=optimizer.state_dict()["param_groups"][0]["lr"] 
        lrlist.append(curlr)
        
        if lr_scheduler is not None:
            lr_scheduler.step()

        lr_updater.step()
            
        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        
#plots learning rate vs loss
def plotLR(lrates,losses,skipf=1,skipl=1,minimal=None,steepest=None):
    fig, ax = plt.subplots(figsize=(12,6))
    plt.plot(lrates[skipf:-skipl],losses[skipf:-skipl])
    plt.xscale("log")
    plt.xlabel("learning rate")
    plt.ylabel("smoothened loss")
    if(minimal!=None):
        plt.scatter([minimal[0]],[minimal[1]])
    if(steepest!=None):
        plt.scatter([steepest[0]],[steepest[1]])
    plt.show()

#learning rate finder, if optimizer is passed, it is expected to have lowerBound as learning rate
def findStartingLR(data,model,bs,lowerBound=1e-7,upperBound=0.1,stopFactor=10,optimizer=None,steps=100,plot=True):
    #load model to device
    device="cuda" if torch.cuda.is_available() else "cpu" 
    model.to(device)
    
    train_loader=torch.utils.data.DataLoader(data[0], batch_size=bs, shuffle=True,collate_fn=utils.collate_fn)
    
    curLR=lowerBound
    
    lrates=[]
    losses=[]

    #set optimizer with initial lr at lower bound
    params = [p for p in model.parameters() if p.requires_grad]
    
    if optimizer==None:
        optimizer=torch.optim.SGD(params, lr=lowerBound)
        
    #set lr scheduler to a lambda function such that lr gets increased exponentially from lower bound to upper bound
    totalfactor=upperBound/lowerBound
    factor=totalfactor**(1.0/steps)
    multiplyLR= lambda x: factor**x
    multiplyLR_sched=torch.optim.lr_scheduler.LambdaLR(optimizer, multiplyLR)
    
    #set number of epochs to include at least steps number of steps
    batches_per_epoch=math.ceil(len(data[0])/bs)
    num_epochs=math.ceil(steps/batches_per_epoch)
    
    for i in range(num_epochs):
        train_one_epoch_LRfind(model, optimizer, train_loader, device, i, lr_updater=multiplyLR_sched, print_freq=10,
                               lrlist=lrates,losslist=losses,stopFactor=stopFactor,lrfactor=factor)
        
    losses.pop(0)
    lrates.pop(0)
    
    #find min point
    mind=losses.index(min(losses))
    minimal=(lrates[mind],min(losses))
    
    descents=[]
    lrates_descents=[]
    for i in range(1,len(losses)):
        descents.append(losses[i-1]-losses[i])
        lrates_descents.append(lrates[i-1])
    
    steepestind=descents.index(max(descents))
    steepest=(lrates_descents[steepestind],0.5*(losses[steepestind]+losses[steepestind+1]))
    
    #print(lrates)
    #print(losses)
    #print(lrates_descents)
    #print(descents)
    
    #plot LR vs losses
    if plot:
        plotLR(lrates,losses,minimal=minimal,steepest=steepest)
    
    #return minimal and steepest point
    return (steepest[0],minimal[0])
    




In [11]:
#trains a frcnn model with resnet50 backbone on the given data
#if a state dict is given, the model parameters will be initialized according to it
#otherwise a pretrained model on CoCo is used
def singleTrainingSession(data,num_classes,bs,stateDict=None,filename="./stored_models/currentModel"):
    #create trained model
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features,num_classes)
    
    #create the model that is used for finding the lr
    lrfindmodel = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    lrfindmodel.roi_heads.box_predictor = FastRCNNPredictor(in_features,num_classes)
    
    #load previously trained parameters
    if stateDict!=None:
        model.load_state_dict(torch.load(stateDict))
        lrfindmodel.load_state_dict(torch.load(stateDict))
    
    #use lrfindmodel to get a decent learning rate for sgd
    lrsuggestions=findStartingLR(data,lrfindmodel,bs=bs,lowerBound=1e-7,upperBound=1,steps=200,plot=False)
    print(lrsuggestions)
    
    #in cases where the steepest descent is at a higher lr than the minimal loss, no reasonable learning rate could be discovered
    if(lrsuggestions[0]>lrsuggestions[1]):
        print("No reasonable learning rate could be discovered")
        return
    
    #take a value that is close to the steepest descent shifted a bit towards the minimal loss
    lr=(995/1000*lrsuggestions[0]+5/1000*lrsuggestions[1])
    
    #train_model(model,data,bs=bs,epochs=20,filename=filename,lr=lr)
    bestNepochs=train_model_prevent_overfit(model,data,bs,epochs=200,filename=filename,lr=lr,max_decrease=0.1,max_worsening_epochs=10)
    print("best results after "+str(bestNepochs)+" epochs")
    
    #return the trained model
    return model
    




In [12]:
#filters a list of annotations based on their score, by looking at a models confidence in the predictions
def filerConfidenceGap(boxes,labels,confidences,always_include_above=0.9,never_include_below=0.2):
    #find distances between confidences to discover largest gap
    largest_gap=0
    gap_confidence=0
    for i in range(1,len(confidences)):
        gap=confidences[i]-confidences[i-1]
        if(gap>largest_gap):
            largest_gap=gap
            gap_confidence=0.5*(confidences[i]+confidences[i-1])
   
    #if largest gap < lower limit, set it to the lower limit
    #if largest gap > upper limit, set it to the upper limit
    gap_confidence=min(max(gap_confidence,never_include_below),always_include_above)
    
    fboxes=[]
    flabels=[]
    fconfidences=[]
    #filter
    for i in range(len(confidences)):
        if confidences[i]>gap_confidence:
            fboxes.append([float(boxes[i][0]),float(boxes[i][1]),float(boxes[i][2]),float(boxes[i][3])])
            flabels.append(int(labels[i]))
            fconfidences.append(confidences[i])
    
    return (fboxes,flabels,fconfidences)

def evaluateAndPrintData(model,dataset,labelDict,max_num=4,filename="./stored_models/currentModel",score_threshold=0.5):
    #load state dict from path
    model.load_state_dict(torch.load(filename))
    #put model in eval mode
    model.eval()
    
    for i in range(min(max_num,len(dataset))):
        image,gtannotations=dataset[i]
        #print(mpim.imread(images[0]))
        
        img=Image.fromarray(image.mul(255).permute(1, 2,0).byte().numpy())
        ax=printImg(np.asarray(img),size=[14,14])
       
        annotations=(gtannotations['boxes'],gtannotations['labels'])
        #print ground truth
        for j in range(len(annotations[1])):
            label=labelDict[int(annotations[1][j])]
            bb=annotations[0][j]
            color=colordict[label]
            drawBB(ax,(bb[0],bb[1]),(bb[2],bb[3]),color)
            #drawText(ax,(bb[2]+bb[0])//2-8*len(label),(bb[3]+bb[1])//2-14,label,color)
        
        prediction=model([image])
        ax=printImg(np.asarray(img),size=[14,14])
        
        boxes=prediction[0]["boxes"]
        labels=prediction[0]["labels"]
        scores=prediction[0]["scores"]
        #print(scores)
       
        #filter out low scoring predictions
        fboxes,flabels,_=filerConfidenceGap(boxes,labels,scores,always_include_above=0.9,never_include_below=score_threshold)
        
        #print(fboxes)
        #print(flabels)
        annotations=(torch.tensor(fboxes),torch.tensor(flabels))
        
        #print prediction
        for j in range(len(annotations[1])):
            label=labelDict[int(annotations[1][j])]
            bb=annotations[0][j]
            color=colordict[label]
            drawBB(ax,(bb[0],bb[1]),(bb[2],bb[3]),color)
            #drawText(ax,(bb[2]+bb[0])//2-8*len(label),(bb[3]+bb[1])//2-14,label,color)

#evaluate unseen image with model and print the resulting predicted boxes
def evaluateAndPrintUnseenData(model,images,transforms,labelDict,max_num=4,filename="./stored_models/currentModel",score_threshold=0.5):
    #load state dict from path
    model.load_state_dict(torch.load(filename))
    #put model in eval mode
    model.eval()
    
    for i in range(min(max_num,len(images))):
        img=Image.open(images[i]).convert("RGB")
        imgTensor=transforms(img)
        #get prediction
        prediction=model([imgTensor])
        #get image
        image=Image.fromarray(imgTensor.mul(255).permute(1, 2,0).byte().numpy())
        ax=printImg(np.asarray(image),size=[14,14])
        
        boxes=prediction[0]["boxes"]
        labels=prediction[0]["labels"]
        scores=prediction[0]["scores"]
        
        #filter out low scoring predictions
        fboxes=[]
        flabels=[]
        for j in range(len(boxes)):
            if scores[j]>score_threshold:
                fboxes.append([boxes[j][0],boxes[j][1],boxes[j][2],boxes[j][3]])
                flabels.append(labels[j])
                
        annotations=(torch.tensor(fboxes),torch.tensor(flabels))
        
        #print prediction
        for j in range(len(annotations[1])):
            label=labelDict[int(annotations[1][j])]
            bb=annotations[0][j]
            color=colordict[label]
            drawBB(ax,(bb[0],bb[1]),(bb[2],bb[3]),color)
            #drawText(ax,(bb[2]+bb[0])//2-8*len(label),(bb[3]+bb[1])//2-14,label,color)

#prints evaluations on seen and unseen data
def summary(model,images,data,transforms,labelDict,max_n=4,filename="./stored_models/currentModel",score_threshold=0.5):
    evaluateAndPrintData(model,data,labelDict,max_num=max_n,filename=filename,score_threshold=score_threshold)
    evaluateAndPrintUnseenData(model,images,transforms,labelDict,max_num=max_n,filename=filename,score_threshold=score_threshold)

In [13]:
#more functions to display images 
def area(box1,box2):
    x1,y1,x2,y2=box1
    x3,y3,x4,y4=box2
    xdiff=min(max(x1,x2),max(x3,x4))-max(min(x1,x2),min(x3,x4))
    ydiff=min(max(y1,y2),max(y3,y4))-max(min(y1,y2),min(y3,y4))
    
    if(xdiff<0 or ydiff<0):
        return 0
    
    return xdiff*ydiff

def IoU(box1,box2):
    x1,y1,x2,y2=box1
    x3,y3,x4,y4=box2
    
    area1=(max(x1,x2)-min(x1,x2))*(max(y1,y2)-min(y1,y2))
    area2=(max(x3,x4)-min(x3,x4))*(max(y3,y4)-min(y3,y4))
    
    intersectArea=area(box1,box2)
    unitedArea=area1+area2-intersectArea
    
    IoU=0.
    if(unitedArea>0):
            IoU=intersectArea/unitedArea
            
    return IoU

#check weather first box is completely surrounded by second box
def fullSurround(box1,box2):
    x1,y1,x2,y2=box1
    x3,y3,x4,y4=box2
    if x3<x1 and y3<y1 and x4>x2 and y4>y2:
        return True
    else:
        return False
    

#filter boxes of the same class that overlap too much
def filterOverlappingBoxes(boxes,label,scores,threshold=0.9):
    filteredBoxes=[]
    filteredLabels=[]
    filteredOut=[]
    for i in range(len(boxes)):
        conflict=False
        for j in range(len(boxes)):
            if i==j:
                continue
            if label[i]!=label[j]:
                continue
            if j in filteredOut:
                continue       
            #check for overlap with box j
            overlap=IoU(boxes[i],boxes[j])
            if overlap<threshold and not fullSurround(boxes[i],boxes[j]):
                continue
            #if overlap is big enough, and confidence in second box is higher, mark for removal
            if scores[i]<scores[j]:
                conflict=True
                break
        
        if conflict:
            filteredOut.append(i)
        else:
            filteredBoxes.append(boxes[i])
            filteredLabels.append(label[i])
        
    return (filteredBoxes,filteredLabels)



#convert bounding box predictions trained on a 1024x1024 to original image size
def convertBoxesToOriginalSize(boxes,image,transformedSize=(1024,1024)):
    img=Image.open(image).convert("RGB")
    origSize=img.size
    #print(transformedSize[0])
    #print(origSize[0])
    scales=(origSize[0]/transformedSize[0] , origSize[1]/transformedSize[1])
    scaledBoxes=[]
    for box in boxes:
        scaledBoxes.append([scales[0]*box[0],scales[1]*box[1],scales[0]*box[2],scales[1]*box[3]])
    return scaledBoxes


#print predictions from two different models next to each other for comparison
def displayComparison(images,models,labelDicts,size=(6,4),annotations=None,gtLabels=None,score_min=0.3,score_max=0.9,threshold=0.9):
    for model in models:
        model.eval()
    #get number of rows/columns for subplots
    height=len(images)
    cols=len(models)
    
    if(height>=10):
        half=len(images)//2
        displayComparison(images[:half],models,labelDicts,size,annotations,gtLabels,score_min,score_max,threshold)
        displayComparison(images[half:],models,labelDicts,size,annotations,gtLabels,score_min,score_max,threshold)
        return
   
    #set up overlying figure
    #fig=plt.figure(figsize=size)
    fig,axarr=plt.subplots(height,cols,figsize=(size[0]*cols,size[1]*height))
    
    
    
    #loop through all subplots
    for row in range(height):
        image=images[row]
        img=Image.open(image).convert("RGB")
        for col in range(cols):
            index=row*cols+col+1
               
            #print annotated ground truth image
            if(annotations!=None and col==cols-1):
                ax=showImg(mpim.imread(image),fig,height,cols,index,size)
                imgname=os.path.normpath(image).split(os.sep)[-1]
                annList=annotations[imgname]
                for bb,label in annList:
                    color=gtLabels[label]
                    drawBB(ax,(bb[0],bb[1]),(bb[2],bb[3]),color)
            
            #display predicted image
            else:
                model=models[col]
                #transform
                imgTensor=testTransforms(img)
                #image=Image.fromarray(imgTensor.mul(255).permute(1, 2,0).byte().numpy())
                #predict
                prediction=model([imgTensor])
                #filter
                fboxes,flabels,fconfidences=filerConfidenceGap(prediction[0]["boxes"],prediction[0]["labels"],prediction[0]["scores"],always_include_above=score_max,never_include_below=score_min)
                #print(fboxes)
                #print(flabels)
                fboxes,flabels=filterOverlappingBoxes(fboxes,flabels,fconfidences,threshold=threshold)
                
                fboxes=convertBoxesToOriginalSize(fboxes,image,transformedSize=(1024,1024))
                annList=(torch.tensor(fboxes),torch.tensor(flabels))
                
                #print(annList[0])
                ax=axarr[row,col]
                #ax=plt.subplot(height,cols,index)
                ax.get_xaxis().set_visible(False)
                ax.get_yaxis().set_visible(False)
                ax.imshow(img)
                #ax.imshow(np.asarray(image))
                #display
                #ax=showImg(mpim.imread(image),fig,height,cols,index,size)
                for i in range(len(annList[0])):
                    bb=annList[0][i]
                    label=annList[1][i]
                    labeltxt=labelDicts[col][int(label)]
                    color=colordict[labeltxt]
                    drawBB(ax,(bb[0],bb[1]),(bb[2],bb[3]),color)

                    
#slight adaptation to the dataset variants of these methods to allow for      
def selectClosestText2(boxes,label,index,maxYdown=100,maxYup=800,maxXdiff=300):
    #collect all text boxes that are located closely to the top of the wappen box or None if there is none
    bbox=boxes[index]
    
    maxY=min(bbox[1],bbox[3])
    medX=(bbox[0]+bbox[2])/2
    closestBox=None
    closestDist=0
    for i in range(0,len(boxes)):
        tbox=boxes[i]
        cls=label[i]
        if i==index:
            continue
        if cls!=2:
            continue
        mdlY=(tbox[1]+tbox[3])/2
        if (mdlY-maxY)>maxYdown:
            continue
        if (maxY-mdlY)>maxYup:
            continue
        if abs((bbox[0]-tbox[0]))>maxXdiff and abs((bbox[2]-tbox[2]))>maxXdiff:
            continue
        mdlX=(tbox[0]+tbox[2])/2
        dist=(maxY-mdlY)**2 +(medX-mdlX)**2
        if closestBox==None:
            closestBox=tbox
            closestDist=dist
        elif dist<closestDist:
            closestBox=tbox
            closestDist=dist
            
    return closestBox        

#return a list of merged
def mergeBoxes(boxes,labels,maxYdown=100,maxYup=800,maxXdiff=300,mergedList=None):
    mergedBoxes=[]
    mergedLabels=[]
    for i in range(len(boxes)):
        x1,y1,x2,y2=boxes[i]
        if labels[i]==1:
            textbox=selectClosestText2(boxes,labels,i,maxYdown,maxYup,maxXdiff)
            if textbox==None:
                mergedBoxes.append(boxes[i])
            else:
                x3,y3,x4,y4=textbox
                mergedBox=[min(x1,x3),min(y1,y3),max(x2,x4),max(y2,y4)]
                mergedBoxes.append(mergedBox)
            if mergedList!=None:
                mergedList.append((boxes[i],textbox))
            #new structure will always have the label 'object'
            mergedLabels.append(1)
            
    return (mergedBoxes,mergedLabels)       


#prints several steps next to each other from the post processing of seperated images
def summarySeperated(model,images,labelDict,size=(12,10)):
    n_steps=3 #3-steps prediction-overlapping box pruning-wappen,text merging
    height=len(images)
    if(height*size[1]>=2**16):
        height=(2**16 -1)//size[1]
    
    
    fig,axarr=plt.subplots(height,n_steps,figsize=(size[0]*n_steps,size[1]*height))
    
    for row in range(height):
        #get predictions and postprocess
        image=images[row]
        img=Image.open(image).convert("RGB")
        imgTensor=testTransforms(img)
        prediction=model([imgTensor])
        fboxes,flabels,fconfidences=filerConfidenceGap(prediction[0]["boxes"],prediction[0]["labels"],prediction[0]["scores"])
        fboxes=convertBoxesToOriginalSize(fboxes,image,transformedSize=(1024,1024))  
        annList1=(torch.tensor(fboxes),torch.tensor(flabels))
        fboxes,flabels=filterOverlappingBoxes(fboxes,flabels,fconfidences,threshold=0.2)
        annList2=(torch.tensor(fboxes),torch.tensor(flabels))
        fboxes,flabels=mergeBoxes(fboxes,flabels,maxYdown=100,maxYup=800,maxXdiff=300)
        #print(fboxes)
        #print(flabels)
        annList3=(torch.tensor(fboxes),torch.tensor(flabels))
        annLists=[annList1,annList2,annList3]
        for col in range(n_steps):
            #index=row*n_steps+col+1   
            ax=axarr[row,col]
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
            ax.imshow(img)
            annList=annLists[col]
            for i in range(len(annList[0])):
                bb=annList[0][i]
                label=annList[1][i]
                labeltxt=labelDict[int(label)]
                if(col==2):
                    labeltxt="objekt"
                color=colordict[labeltxt]
                drawBB(ax,(bb[0],bb[1]),(bb[2],bb[3]),color)
    

In [14]:
#functions for outputting predicted image data to json
def isIn(box,boxlist):
    x1,y1,x2,y2=box
    for box1,box2 in boxlist:
        x3,y3,x4,y4=box1
        if(x1==x3 and x2==x4 and y1==y3 and y2==y4):
            return True
        if(box2==None):
            continue
        x3,y3,x4,y4=box2
        if(x1==x3 and x2==x4 and y1==y3 and y2==y4):
            return True
    
    return False

#create a json file for image and 
def createJson(image,model,fname):
    
    imgname=os.path.normpath(image).split(os.sep)[-1]
    data={}
    op=open(fname,"w+")
    data["filename"]=imgname
    data["regions"]=[]
      
    img=Image.open(image).convert("RGB")
    imgTensor=testTransforms(img)
    prediction=model([imgTensor])
       
    fboxes,flabels,fconfidences=filerConfidenceGap(prediction[0]["boxes"],prediction[0]["labels"],prediction[0]["scores"])
    fboxes=convertBoxesToOriginalSize(fboxes,image,transformedSize=(1024,1024))  
    fboxes,flabels=filterOverlappingBoxes(fboxes,flabels,fconfidences,threshold=0.2)
    annList1=(torch.tensor(fboxes),torch.tensor(flabels))
    mergedList=[]
    fboxes,flabels=mergeBoxes(fboxes,flabels,maxYdown=100,maxYup=800,maxXdiff=300,mergedList=mergedList)
    annList2=(torch.tensor(fboxes),torch.tensor(flabels))
    
    for i in range (len(annList2[0])):
        curdict={}
        curdict["class"]="merged"
        box=annList2[0][i]
        text=mergedList[i][1]
        symbol=mergedList[i][0]
        curdict["box"]=[float(box[0]),float(box[1]),float(box[2]),float(box[3])]
        if text!=None:
            curdict["textbox"]=[float(text[0]),float(text[1]),float(text[2]),float(text[3])]
        curdict["symbolbox"]=[float(symbol[0]),float(symbol[1]),float(symbol[2]),float(symbol[3])]
        data["regions"].append(curdict)
        
    #take care of the artifacts that were not merged
    for j in range(len(annList1[0])):
        box=annList1[0][i]
        cls=annList1[1][i]
        if not isIn(box,mergedList):
            curdict={}
            if(cls==1):
                curdict["class"]="symbol"
            if(cls==2):
                curdict["class"]="text"
            curdict["box"]=[float(box[0]),float(box[1]),float(box[2]),float(box[3])]
            data["regions"].append(curdict)
    
    #print(data)
    json.dump(data,op)
    
    
#create a distinct json file for each region of intrest in image
def createIndividualJsons(image,model,fname):
    imgname=os.path.normpath(image).split(os.sep)[-1]
    
    img=Image.open(image).convert("RGB")
    imgTensor=testTransforms(img)
    prediction=model([imgTensor])
    
    
    fboxes,flabels,fconfidences=filerConfidenceGap(prediction[0]["boxes"],prediction[0]["labels"],prediction[0]["scores"])
    fboxes=convertBoxesToOriginalSize(fboxes,image,transformedSize=(1024,1024))  
    fboxes,flabels=filterOverlappingBoxes(fboxes,flabels,fconfidences,threshold=0.2)
    annList1=(torch.tensor(fboxes),torch.tensor(flabels))
    mergedList=[]
    fboxes,flabels=mergeBoxes(fboxes,flabels,maxYdown=100,maxYup=800,maxXdiff=300,mergedList=mergedList)
    annList2=(torch.tensor(fboxes),torch.tensor(flabels))
    
    #print(annList2[0])
    #print(mergedList)
    
    coveredParts=[]
    #loop through every merged box
    for i in range (len(annList2[0])):
        bbox=annList2[0][i]
        data={}
        outfile=fname+"-merged-"+str(i)+".json"
        op=open(outfile,"w+")
        data["filename"]=imgname
        data["region"]=[]
        curdict={}
        curdict["class"]="merged"
        curdict["bbox"]=[float(bbox[0]),float(bbox[1]),float(bbox[2]),float(bbox[3])]
        data["region"].append(curdict)
        originSymbol=mergedList[i][0]
        curdict={}
        curdict["class"]="symbol"
        curdict["bbox"]=[float(originSymbol[0]),float(originSymbol[1]),float(originSymbol[2]),float(originSymbol[3])]
        data["region"].append(curdict)
        originText=mergedList[i][1]
        if originText!=None:
            curdict={}
            curdict["class"]="text"
            curdict["bbox"]=[float(originText[0]),float(originText[1]),float(originText[2]),float(originText[3])]
            data["region"].append(curdict)
        #print("")
        #print(data)
        json.dump(data,op)
    
    idx=0
    for i in range (len(annList1[0])):
        box=annList1[0][i]
        cls=annList1[1][i]
        #search if box was already covered as part of the merged boxes
        if not isIn(box,mergedList):
            data={}
            outfile=fname+"-fragment-"+str(idx)+".json"
            op=open(outfile,"w+")
            idx+=1
            data["filename"]=imgname
            data["region"]=[]
            curdict={}
            if(cls==1):
                curdict["class"]="symbol"
            if(cls==2):
                curdict["class"]="text"
            curdict["bbox"]=[float(box[0]),float(box[1]),float(box[2]),float(box[3])]
            data["region"].append(curdict)
            #print("")
            #print(data)
            json.dump(data,op)


#create json files for each image in images
def createJsonsFor(model,images,paramfile="./stored_models/currentModel",splitOnSymbols=False):
    model.load_state_dict(torch.load(paramfile))
    model.eval()
    for image in images:
        imgname=os.path.normpath(image).split(os.sep)[-1]
        fname="./outputfiles/"+imgname.split('.')[0]
        if splitOnSymbols:
            createIndividualJsons(image,model,fname)
        else:
            createJson(image,model,fname+".json")
       
        
        
    

In [15]:
#create dictionary from image to a list of annotations
imtoann={}
for index,row in annotations.iterrows():
    imgname=row['filename']
    #print(imgname)
    bbdata=json.loads(row['region_shape_attributes'])
    #print(bbdata)
    labeldata=json.loads(row['region_attributes'])
    x1=int(bbdata['x'])
    y1=int(bbdata['y'])
    x2=x1+int(bbdata['width'])
    y2=y1+int(bbdata['height'])
    bb=[x1,y1,x2,y2]
    label=labeldata['region']
    annotation=(bb,label)
    if imgname in imtoann:
        imtoann[imgname].append(annotation)
    else:
        imtoann[imgname]=[annotation]
        
imtoannlist=getDatasetDictionary(imtoann)

In [16]:
#create dictionary of merged bbs
imToMergedBB={}
for item in imtoann:
    filename=item
    for item in imtoann[filename]:
        bbox,cls=item
        if cls!='wappen':
            continue
        matchingText=selectClosestText(filename,bbox)
        #print(item)
        #print(matchingText)
        
        annotation=None
        if matchingText==None:
            annotation=(bbox,'objekt')
        else:
            txtbox=matchingText[0]
            minX=min(bbox[0],bbox[2],txtbox[0],txtbox[2])
            minY=min(bbox[1],bbox[3],txtbox[1],txtbox[3])
            maxX=max(bbox[0],bbox[2],txtbox[0],txtbox[2])
            maxY=max(bbox[1],bbox[3],txtbox[1],txtbox[3])
            annotation=([minX,minY,maxX,maxY],'objekt')
        if(filename in imToMergedBB):
            imToMergedBB[filename].append(annotation)
        else:
            imToMergedBB[filename]=[annotation] 
            
imToAnnMerged=getDatasetDictionary(imToMergedBB)

In [17]:
#create datasets
#box and text seperated
datasetSeperated=EmblemTextSet(images,imtoannlist,mytransforms)
dataSeperated=createTorchDataset(datasetSeperated,0.2)
#box and text merged
datasetMerged=EmblemTextSet(images,imToAnnMerged,mytransforms)
dataMerged=createTorchDataset(datasetMerged,0.2)

In [24]:
#findLRSeperated
testmodel = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
in_features = testmodel.roi_heads.box_predictor.cls_score.in_features
testmodel.roi_heads.box_predictor = FastRCNNPredictor(in_features,3)

findStartingLR(dataSeperated,testmodel,bs=4,lowerBound=1e-7,upperBound=1,stopFactor=10,optimizer=None,steps=200,plot=True)

In [23]:
#findLRMerged
testmodel = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
in_features = testmodel.roi_heads.box_predictor.cls_score.in_features
testmodel.roi_heads.box_predictor = FastRCNNPredictor(in_features,2)

findStartingLR(dataMerged,testmodel,bs=4,lowerBound=1e-7,upperBound=1,stopFactor=10,optimizer=None,steps=200,plot=True)

In [22]:
trainedModelSeperated=singleTrainingSession(dataSeperated,3,bs=4)

In [21]:
trainedModelMerged=singleTrainingSession(dataMerged,2,bs=4)

In [18]:
#summarize seperated
testmodel = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
in_features = testmodel.roi_heads.box_predictor.cls_score.in_features
testmodel.roi_heads.box_predictor = FastRCNNPredictor(in_features,3)
summary(testmodel,testimages,datasetSeperated,testTransforms,IdToClsSeperated,max_n=10,filename="./stored_models/modelSeperated",score_threshold=0.8)

In [19]:
#summarize merged
testmodel = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
in_features = testmodel.roi_heads.box_predictor.cls_score.in_features
testmodel.roi_heads.box_predictor = FastRCNNPredictor(in_features,2)
summary(testmodel,testimages,datasetMerged,testTransforms,IdToClsMerged,max_n=10,filename="./stored_models/modelMerged",score_threshold=0.8)

In [20]:
modelSeperated = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
in_features = modelSeperated.roi_heads.box_predictor.cls_score.in_features
modelSeperated.roi_heads.box_predictor = FastRCNNPredictor(in_features,3)
modelSeperated.load_state_dict(torch.load("./stored_models/modelSeperated"))

modelMerged = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
in_features = modelMerged.roi_heads.box_predictor.cls_score.in_features
modelMerged.roi_heads.box_predictor = FastRCNNPredictor(in_features,2)
modelMerged.load_state_dict(torch.load("./stored_models/modelMerged"))

displayComparison(testimages,[modelSeperated,modelMerged],[IdToClsSeperated,IdToClsMerged],size=(14,10),annotations=None,gtLabels=None,threshold=0.2)

<All keys matched successfully>

In [26]:
summarySeperated(modelSeperated,testimages,IdToClsSeperated,size=(12,10))

In [72]:
createJsonsFor(modelSeperated,testimages,paramfile="./stored_models/modelSeperated",splitOnSymbols=False)#for one file per image
#createJsonsFor(modelSeperated,testimages,paramfile="./stored_models/modelSeperated",splitOnSymbols=False) #for one file per symbol

In [25]:
summarizeData(imtoannlist)

Dataset comprises 77 images
Number of wappen: 17935
Number of text: 19748
