# Stage 3 - Tooth Distance Classifiers

The classification outcomes from the previous stages effectively narrow down the search space of tooth numbers. Given the locations of molar and non-molar teeth, one can now estimate their tooth numbers by considering the natural tooth sequence, the tooth sizes, and their relative positions. In this stage, this heuristic estimation is implemented with a number of Random Forest (RF) or Support Vector Machine (SVM) classifiers from the scikit-learn library.

### Calculate the dataset of teeth bounding box distances

In [None]:
import os
import math 
import numpy as np
import random
import collections
import shutil
import torchvision
import copy
from PIL import Image

# Training split: images and their 32-class oriented-box label files.
train_img_path = './4 Private ZOON Anonymised/Bitewing_proc/labels_trial6XX/Stage2_yolo_obb/train/images'
train_label_32cls_path = './4 Private ZOON Anonymised/Bitewing_proc/labels_trial6XX/Stage2_yolo_obb/train/labelTxt_32Cls'

# Validation split: images and their 32-class oriented-box label files.
val_img_path = './4 Private ZOON Anonymised/Bitewing_proc/labels_trial6XX/Stage2_yolo_obb/val/images'
val_label_32cls_path = './4 Private ZOON Anonymised/Bitewing_proc/labels_trial6XX/Stage2_yolo_obb/val/labelTxt_32Cls'

# Tooth labels are 't' + quadrant digit + position digit.
# Positions 1-5 are listed as non-molars, positions 6-8 as molars
# (quadrants enumerated in the order 1, 4, 2, 3).
nonmolar_lbls = [f't{quadrant}{position}'
                 for quadrant in (1, 4, 2, 3)
                 for position in range(1, 6)]

molar_lbls = [f't{quadrant}{position}'
              for quadrant in (1, 4, 2, 3)
              for position in range(6, 9)]

# Every label accepted by the distance dataset: quadrants 1-4, positions 1-8.
valid_lbls = [f't{quadrant}{position}'
              for quadrant in range(1, 5)
              for position in range(1, 9)]

padding_log_path = './4 Private ZOON Anonymised/Bitewing_proc/paddinglogs'


            

In [None]:
from shapely.geometry import Polygon

def _merge_mirrored_quadrants(teeth_distances):
    """Pool per-quadrant tooth-pair samples under mirrored 'U'/'L' pair labels.

    ``teeth_distances`` maps labels such as 't13_t14' to a list of samples.
    The result maps labels such as 'U3_U4' to the concatenated (deep-copied)
    samples of the two source quadrant combinations that mirror each other.
    A merged label is only present when at least one source label exists.
    """
    # (prefix of first tooth, prefix of second tooth,
    #  ((quadrant offset of first tooth, quadrant offset of second tooth), ...))
    merge_specs = (
        ('U', 'U', ((10, 10), (20, 20))),   # both upper quadrants
        ('L', 'L', ((30, 30), (40, 40))),   # both lower quadrants
        ('L', 'U', ((40, 10), (30, 20))),   # lower tooth paired with upper tooth
        ('U', 'L', ((10, 40), (20, 30))),   # upper tooth paired with lower tooth
    )

    merged_teeth_distances = {}
    for n in range(1, 9):
        for m in range(1, 9):
            for first_prefix, second_prefix, quadrant_pairs in merge_specs:
                merged_lbl = f'{first_prefix}{n}_{second_prefix}{m}'
                for first_q, second_q in quadrant_pairs:
                    lst = teeth_distances.get(f't{first_q + n}_t{second_q + m}')
                    if lst is not None:
                        merged_teeth_distances.setdefault(merged_lbl, []).extend(copy.deepcopy(lst))
    return merged_teeth_distances


def CreateTeethDistanceAndWidthData(data_label_path, image_path):
    """Build pairwise tooth features from oriented-bounding-box label files.

    Every regular file in ``data_label_path`` is read as a label file whose
    rows hold eight integer corner coordinates (x1 y1 ... x4 y4) followed by
    the tooth label at index 8. The matching image ``<stem>.png`` in
    ``image_path`` is opened only to read its size; the image diagonal is
    used to normalize lengths.

    For each ordered pair of labelled teeth in one image (a tooth paired with
    itself is included), a sample ``[distance, width_i, width_j, iou]`` is
    recorded, where distance is the normalized distance between box centres,
    width is the shorter of two adjacent box sides (normalized), and iou is the
    intersection-over-union of the two boxes.

    Parameters
    ----------
    data_label_path : str
        Directory containing the label files.
    image_path : str
        Directory containing the corresponding ``.png`` images.

    Returns
    -------
    dict
        Maps merged pair labels such as 'U3_U4' or 'L5_U6' to a list of
        ``[distance, width_i, width_j, iou]`` samples.
    """
    teeth_distances = {}
    for fname in os.listdir(data_label_path):
        label_file = os.path.join(data_label_path, fname)
        # Skip sub-directories and other non-files up front; previously these
        # either crashed on the name split or reused the previous file's boxes.
        if not os.path.isfile(label_file):
            continue

        # splitext tolerates file names that contain more than one dot.
        pf = os.path.splitext(fname)[0]
        # Only the size is needed; close the image file straight away.
        with Image.open(os.path.join(image_path, pf + '.png')) as pil_img:
            img_w, img_h = pil_img.size
        img_diag = (img_w**2 + img_h**2)**0.5

        box_coors = {}
        with open(label_file, 'r') as f:
            for row in f:
                row = row.split()
                if not row:
                    # tolerate blank lines (e.g. a trailing newline)
                    continue
                lbl = row[8]

                corners = np.array([[int(row[0]), int(row[1])], [int(row[2]), int(row[3])],
                                    [int(row[4]), int(row[5])], [int(row[6]), int(row[7])]])

                # Corners are assumed to be listed in ring order, so corners
                # 0 and 2 are opposite and their midpoint is the box centre.
                center = 0.5*(corners[0] + corners[2])

                # Width is the shorter of two adjacent sides.
                w1 = np.linalg.norm(corners[0]-corners[1])/img_diag
                w2 = np.linalg.norm(corners[1]-corners[2])/img_diag
                width = min(w1, w2)

                box_coors[lbl] = [center, width, corners]

        for ln, t in box_coors.items():
            if ln not in valid_lbls:
                continue
            r1 = Polygon(t[2])
            for lm in valid_lbls:
                u = box_coors.get(lm)
                if u is None:
                    continue

                label = ln + '_' + lm
                # Normalize by image diagonal length
                dist = np.linalg.norm(t[0]-u[0])/img_diag

                r2 = Polygon(u[2])
                union_area = r1.union(r2).area
                # Degenerate (zero-area) boxes would otherwise divide by zero.
                iou = r1.intersection(r2).area / union_area if union_area > 0 else 0.0

                # Diagnostic: report neighbouring teeth of one quadrant whose boxes overlap noticeably.
                if iou > 0.12:
                    if abs(int(ln[-1]) - int(lm[-1])) == 1 and ln[0:2] == lm[0:2]:
                        print(f'***{pf} {label} {iou}')

                teeth_distances.setdefault(label, []).append([dist, t[1], u[1], iou])

    # Left and right quadrants mirror each other, so their samples are pooled.
    return _merge_mirrored_quadrants(teeth_distances)

In [None]:
# Build the pairwise distance/width/IoU samples for both splits.
# The "test" set here is computed from the validation labels and images.
train_teeth_distancesAndwidths = CreateTeethDistanceAndWidthData(train_label_32cls_path, train_img_path)
test_teeth_distancesAndwidths = CreateTeethDistanceAndWidthData(val_label_32cls_path, val_img_path)

In [None]:
def CreateTeethWidthData(data_label_path, image_path):
    """Collect normalized bounding-box widths per tooth position.

    Every regular file in ``data_label_path`` is read as a label file whose
    rows hold eight integer corner coordinates followed by the tooth label at
    index 8. The matching image ``<stem>.png`` in ``image_path`` is opened
    only to read its size; widths are divided by the image diagonal.

    Labels are reduced to a jaw prefix plus the position digit. Quadrants 1
    and 2 ('t1x', 't2x') map to 'U' and all other labels map to 'L', which
    matches the mapping used by CreateTeethDistanceAndWidthData. (The earlier
    version of this function had the two prefixes swapped.)

    Parameters
    ----------
    data_label_path : str
        Directory containing the label files.
    image_path : str
        Directory containing the corresponding ``.png`` images.

    Returns
    -------
    dict
        Maps labels such as 'U6' or 'L4' to a list of normalized widths, where
        the width is the shorter of two adjacent box sides.
    """
    teeth_widths = {}
    for fname in os.listdir(data_label_path):
        label_file = os.path.join(data_label_path, fname)
        # Skip sub-directories and other non-files instead of failing on them.
        if not os.path.isfile(label_file):
            continue

        # splitext tolerates file names that contain more than one dot.
        pf = os.path.splitext(fname)[0]
        # Only the size is needed; close the image file straight away.
        with Image.open(os.path.join(image_path, pf + '.png')) as pil_img:
            img_w, img_h = pil_img.size
        img_diag = (img_w**2 + img_h**2)**0.5

        with open(label_file, 'r') as f:
            for row in f:
                row = row.split()
                if not row:
                    # tolerate blank lines (e.g. a trailing newline)
                    continue
                lbl = row[8]

                # Quadrants 1 and 2 are the upper jaw, the rest the lower jaw.
                jaw = 'U' if lbl[0:2] in ('t1', 't2') else 'L'
                lbl = jaw + lbl[2]

                coors = np.array([
                                    [int(row[0]), int(row[1])],
                                    [int(row[2]), int(row[3])],
                                    [int(row[4]), int(row[5])],
                                    [int(row[6]), int(row[7])]
                                ])

                # Two adjacent sides of the box; the shorter one is the width.
                vc1 = coors[1] - coors[0]
                vc2 = coors[2] - coors[1]
                boxwidth = min(np.linalg.norm(vc1), np.linalg.norm(vc2)) / img_diag

                teeth_widths.setdefault(lbl, []).append(boxwidth)

    return teeth_widths
       
                       
# Collect the normalized box widths per tooth label from the training split.
train_teeth_widths =CreateTeethWidthData(train_label_32cls_path, train_img_path)
                    

In [None]:
# plot the boxplots for distribution of adjacent tooth distances
# plot the boxplots for distribution of tooth widths
# train_teeth_distances['L1_L2']

from matplotlib import pyplot as plt

# Gather the normalized centre distances (sample index 0) of neighbouring
# upper teeth 3-4 ... 7-8 from the training and test sets, then box-plot them.
adj_dists = {}
for i in range(3, 8):
    lbl = f'U{i}_U{i + 1}'
    for dataset in (train_teeth_distancesAndwidths, test_teeth_distancesAndwidths):
        lst = dataset.get(lbl)
        if lst is None:
            continue
        adj_dists.setdefault(lbl, []).extend(d[0] for d in lst)

fig, ax = plt.subplots(1, 1)
ax.boxplot(adj_dists.values(), labels=adj_dists.keys())
ax.set_xlabel('Adjacent Tooth Pair (Upper Quadrants)')
ax.set_ylabel('Normalized Distance')
plt.show()


In [None]:
# Gather the normalized centre distances (sample index 0) of neighbouring
# lower teeth 3-4 ... 7-8 from the training and test sets, then box-plot them.
adj_dists = {}
for i in range(3, 8):
    lbl = f'L{i}_L{i + 1}'
    for dataset in (train_teeth_distancesAndwidths, test_teeth_distancesAndwidths):
        lst = dataset.get(lbl)
        if lst is None:
            continue
        adj_dists.setdefault(lbl, []).extend(d[0] for d in lst)

fig, ax = plt.subplots(1, 1)
ax.boxplot(adj_dists.values(), labels=adj_dists.keys())
ax.set_xlabel('Adjacent Tooth Pair (Lower Quadrants)')
ax.set_ylabel('Normalized Distance')
plt.show()


In [None]:
# Gather the bounding-box IoU (sample index 3) of neighbouring lower teeth
# 3-4 ... 7-8 from the training and test sets, then box-plot them.
adj_iou = {}
for i in range(3, 8):
    lbl = f'L{i}_L{i + 1}'
    for dataset in (train_teeth_distancesAndwidths, test_teeth_distancesAndwidths):
        lst = dataset.get(lbl)
        if lst is None:
            continue
        adj_iou.setdefault(lbl, []).extend(d[3] for d in lst)

fig, ax = plt.subplots(1, 1)
ax.boxplot(adj_iou.values(), labels=adj_iou.keys())
ax.set_xlabel('Adjacent Tooth Pair (Lower Quadrants)')
ax.set_ylabel('IoU of Bounding Box')
plt.show()

In [None]:
# Gather the bounding-box IoU (sample index 3) of neighbouring upper teeth
# 3-4 ... 7-8 from the training and test sets, then box-plot them.
adj_iou = {}
for i in range(3,8):
    lbl = 'U' + str(i) + '_U' + str(i+1)
    for dataset in (train_teeth_distancesAndwidths, test_teeth_distancesAndwidths):
        if (lst := dataset.get(lbl, None)) is not None:
            adj_iou.setdefault(lbl, []).extend([d[3] for d in lst])

fig, ax = plt.subplots(1, 1)
ax.boxplot(adj_iou.values(), labels=adj_iou.keys())
# The data plotted here are the 'U*_U*' pairs, so the axis must say "Upper"
# (it previously carried the "Lower Quadrants" label of the preceding plot).
ax.set_xlabel('Adjacent Tooth Pair (Upper Quadrants)')
ax.set_ylabel('IoU of Bounding Box')
plt.show()

### Define a function to fit an SVM/Random Forest classifier

In [None]:
# Fit a SVM classifier. 
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import LabelEncoder
from sklearn import svm, metrics
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import StratifiedKFold
from sklearn.calibration import CalibratedClassifierCV
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import average_precision_score
from sklearn.metrics import roc_auc_score

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.colors import ListedColormap


def FitDistancesClassifier(train_distance_set, test_distance_set, source_labels, retain_labels, model_name):
    """Fit an SVM or Random Forest on tooth-pair features and score it on the test set.

    Parameters
    ----------
    train_distance_set, test_distance_set : dict
        Map a pair label (e.g. 'U3_U4') to a list of feature rows.
    source_labels : iterable of str
        Pair labels whose samples are used.
    retain_labels : container of str
        Labels kept as their own class; any other source label is relabelled
        'OTHERS'.
    model_name : str
        'SVM' or 'RandomForest'.

    Returns
    -------
    tuple
        ``(classifier, label_encoder, classification_report, average_precision)``.
        All four are None when 30 or fewer training samples are available.
        The report and score are None when no test sample can be evaluated.

    Raises
    ------
    ValueError
        If enough training data is available but ``model_name`` is unknown.
    """
    y_train_labels = []
    X_train = []

    y_test_labels = []
    X_test = []

    print('-----------------------------------------------------------')
    print(f'Fit classifier for {source_labels}')
    for lb in source_labels:

        y_lb = lb if lb in retain_labels else 'OTHERS'

        if (lst := train_distance_set.get(lb, None)) is not None:
            y_train_labels.extend([y_lb] * len(lst))
            X_train.extend(lst)
        else:
            print(f'{lb} is not found in training set')

        if (lst := test_distance_set.get(lb, None)) is not None:
            y_test_labels.extend([y_lb] * len(lst))
            X_test.extend(lst)
        else:
            print(f'{lb} is not found in test set')

    if len(X_train) <= 30:
        print(f'*WARNING: Length of X_train is {len(X_train)}. Failed to fit due to insufficient of training data.')
        # Callers unpack four values, so the failure path must return four too.
        return None, None, None, None

    print(f'Length of X_train is {len(X_train)}')

    if model_name == 'SVM':
        classifier = svm.SVC(probability=True, class_weight='balanced', random_state=721)
    elif model_name == 'RandomForest':
        classifier = RandomForestClassifier(class_weight='balanced', random_state=721)
    else:
        raise ValueError(f"Unknown model_name {model_name!r}; expected 'SVM' or 'RandomForest'.")

    # The encoder is fitted on the training labels ONLY. Re-fitting it on the
    # test labels (as before) renumbers the classes whenever the test set
    # lacks a class, which misaligns y_test with the classifier's outputs and
    # corrupts class_le.classes_ for later decoding of predictions.
    class_le = LabelEncoder()
    y_train = class_le.fit_transform(y_train_labels)

    # Test samples of a class never seen in training cannot be encoded or scored.
    seen_labels = set(y_train_labels)
    usable = [(x, y) for x, y in zip(X_test, y_test_labels) if y in seen_labels]
    if len(usable) < len(X_test):
        print(f'*WARNING: {len(X_test) - len(usable)} test samples dropped (class not present in training set).')
    X_test = [x for x, _ in usable]
    y_test_labels = [y for _, y in usable]

    # Scalar features arrive as a flat list; scikit-learn expects 2-D input.
    if len(np.array(X_train).shape) == 1:
        X_train = np.array(X_train).reshape(-1, 1)
        X_test = np.array(X_test).reshape(-1, 1)

    classifier.fit(X_train, y_train)

    if len(X_test) == 0:
        print('*WARNING: No test samples available. Classifier is fitted but not evaluated.')
        return classifier, class_le, None, None

    y_test = class_le.transform(y_test_labels)

    predicted = classifier.predict(X_test)
    predicted_prob = classifier.predict_proba(X_test)

    # NOTE(review): passing multiclass targets with a probability matrix
    # relies on the installed scikit-learn supporting multiclass average
    # precision - confirm against the pinned version.
    ap_scores = average_precision_score(y_test, predicted_prob, average='weighted')

    return classifier, class_le, metrics.classification_report(y_test, predicted), ap_scores
    
        
def PlotToothDistanceDistribution(tooth_distance_set, source_labels, retain_labels, plot_title):
    """Show a 3-D scatter plot of tooth-pair features coloured by pair label.

    Parameters
    ----------
    tooth_distance_set : dict
        Maps a pair label to a list of feature rows. The first three entries
        of each row are plotted as tooth distance, i-th tooth width and j-th
        tooth width.
    source_labels : iterable of str
        Pair labels whose samples are plotted; labels missing from the set
        are skipped.
    retain_labels : container of str
        Labels shown as their own class; any other source label is shown as
        'OTHERS'.
    plot_title : str
        Title of the figure.

    Returns
    -------
    None
        Nothing is plotted when no sample is found for any source label.
    """
    y_labels = []
    X = []

    for lb in source_labels:

        y_lb = lb if lb in retain_labels else 'OTHERS'

        if (lst := tooth_distance_set.get(lb, None)) is not None:
            y_labels.extend([y_lb] * len(lst))
            X.extend(lst)

    if not X:
        print(f'No samples found for {source_labels}. Nothing to plot.')
        return

    # Encode once, after ALL labels are collected. Doing this inside the loop
    # (as before) converted X to an ndarray as soon as one label was missing,
    # which broke the next X.extend call.
    class_le = LabelEncoder()
    Y = class_le.fit_transform(y_labels)

    cmap = ListedColormap(sns.color_palette(as_cmap=True))

    fig = plt.figure(figsize=(8,8))
    ax = fig.add_subplot(111, projection='3d')

    x = [feature[0] for feature in X]
    y = [feature[1] for feature in X]
    z = [feature[2] for feature in X]

    sc = ax.scatter(x, y, z, s=3, c=Y, marker='o', cmap=cmap, alpha=1)

    ax.set_xlabel("Tooth distance")
    ax.set_ylabel("i-th Tooth width")
    ax.set_zlabel("j-th Tooth width")
    ax.set_title(plot_title)

    # Replace the numeric legend entries with the original class names.
    handles, labels = sc.legend_elements()
    labels = class_le.classes_
    plt.legend(handles, labels,  loc="best", fontsize="small", ncols=2)

    plt.show()

    

### Fit classifiers for the teeth in the same quadrant

In [None]:
# cls2_sameQ = {}
# lblenc2_sameQ = {}
# cls2_rpt_sameQ = {}


# ### Fit classifiers for the teeth in the same quadrant
def GenLabelPairs(isUpper,isNonMolar,isMolar):
    """Return the same-jaw pair labels for the requested tooth groups.

    Positions 3-5 form the non-molar group and positions 6-8 the molar group.
    With both flags set, each non-molar is paired with each molar; with a
    single flag set, the pairs are drawn from within that group. Only pairs
    whose second position is greater than the first are produced. With
    neither flag set the result is an empty list.

    Labels are formatted as '<jaw><j>_<jaw><k>' with jaw 'U' when ``isUpper``
    is truthy and 'L' otherwise.
    """
    jaw = 'U' if isUpper else 'L'
    nonmolar_positions = range(3, 6)
    molar_positions = range(6, 9)

    if isMolar:
        first_positions = nonmolar_positions if isNonMolar else molar_positions
        second_positions = molar_positions
    elif isNonMolar:
        first_positions = second_positions = nonmolar_positions
    else:
        return []

    return [f'{jaw}{first}_{jaw}{second}'
            for first in first_positions
            for second in second_positions
            if second > first]


# Results of the same-quadrant classifiers, keyed by the names in `keys`.
cls_sameQ = {}
lblenc_sameQ = {}
cls_rpt_sameQ = {}
ap_scores_sameQ = {}


# Each key is paired with its [isUpper, isNonMolar, isMolar] switches.
keys = ['U_MolarNonmolar', 'U_Molar', 'U_Nonmolar',
        'L_MolarNonmolar', 'L_Molar', 'L_Nonmolar']
switches = [[True, True, True], [True, False, True], [True, True, False],
            [False, True, True], [False, False, True], [False, True, False]]


# Keep only the first three features (distance and the two widths); the IoU is dropped.
train_set = {k: [r[0:3] for r in v] for k, v in train_teeth_distancesAndwidths.items()}
test_set = {k: [r[0:3] for r in v] for k, v in test_teeth_distancesAndwidths.items()}


for key, switch in zip(keys, switches):
    source_labels = GenLabelPairs(isUpper=switch[0], isNonMolar=switch[1], isMolar=switch[2])

    if not len(source_labels) > 1:
        cls_sameQ[key], lblenc_sameQ[key], cls_rpt_sameQ[key], ap_scores_sameQ[key] = None, None, None, None
        print('---------------------------------------------------------')
        print(f'{key} has only single class {source_labels}. No classifier is fitted.')
        continue

    # Every source label is retained as its own class.
    fit_result = FitDistancesClassifier(train_set, test_set, source_labels, source_labels, 'SVM')
    cls_sameQ[key], lblenc_sameQ[key], cls_rpt_sameQ[key], ap_scores_sameQ[key] = fit_result

    if cls_rpt_sameQ[key] is None:
        continue

    print([str(ind) + ": " + val for ind,val in enumerate(lblenc_sameQ[key].classes_)])

    print("Classification report for classifier %s:\n%s\n"
          % (cls_sameQ[key], cls_rpt_sameQ[key]))

    print("Average Precision score for classifier %s:\n%s\n"
          % (cls_sameQ[key], ap_scores_sameQ[key]))
    
    




In [None]:
# Visualize the tooth pair distributions 

%matplotlib notebook
  
# One scatter plot per jaw for the non-molar pairs; every label keeps its own class.
for source_labels, plot_title in (
        (['U3_U4', 'U3_U5', 'U4_U5'], "Scatter Plot of Tooth Pairs in Upper Quadrants"),
        (['L3_L4', 'L3_L5', 'L4_L5'], "Scatter Plot of Tooth Pairs in Lower Quadrants"),
):
    PlotToothDistanceDistribution(train_set, source_labels, source_labels, plot_title)

            

### Fit classifiers for the teeth in the opposite quadrant

In [None]:
#classifier for teeth in opposite quadrants
# Results of the opposite-jaw classifiers, keyed by the anchor tooth ('U3', 'L3', ...).
cls_oppositeQ = {}
lblenc_oppositeQ = {}
cls_rpt_oppositeQ = {}
ap_scores_oppositeQ = {}

source_labels = {}

# For each anchor tooth 3-8, list its pairings with teeth 3-8 of the other jaw.
arange = range(3,9)
for j in arange:
    for own_jaw, other_jaw in (('U', 'L'), ('L', 'U')):
        key = f'{own_jaw}{j}'
        lst = [f'{own_jaw}{j}_{other_jaw}{k}' for k in arange]
        source_labels[key] = lst


# Keep only the first three features (distance and the two widths); the IoU is dropped.
train_set = {k: [r[0:3] for r in v] for k, v in train_teeth_distancesAndwidths.items()}
test_set = {k: [r[0:3] for r in v] for k, v in test_teeth_distancesAndwidths.items()}


### Fit classifiers for the teeth in the opposite quadrant
for key, lblst in source_labels.items():
    src_lbls = lblst
    # Assigned but not passed on: all source labels are retained in the call below.
    retain_lbls = lblst[0]

    print(src_lbls)
    fit_result = FitDistancesClassifier(train_set, test_set, src_lbls, src_lbls, 'SVM')
    cls_oppositeQ[key], lblenc_oppositeQ[key], cls_rpt_oppositeQ[key], ap_scores_oppositeQ[key] = fit_result

    if cls_rpt_oppositeQ[key] is None:
        continue

    print([str(ind) + ": " + val for ind,val in enumerate(lblenc_oppositeQ[key].classes_)])

    print("Classification report for classifier %s:\n%s\n"
          % (cls_oppositeQ[key], cls_rpt_oppositeQ[key]))

    print("Average Precision score for classifier %s:\n%s\n"
              % (cls_oppositeQ[key], ap_scores_oppositeQ[key]))

In [None]:
# Visualize the tooth distributions 

%matplotlib notebook
  
    
# Scatter plots of lower tooth 5 and lower tooth 6 against upper teeth 3-8.
for source_labels in (
        ['L5_U3', 'L5_U4', 'L5_U5', 'L5_U6', 'L5_U7', 'L5_U8'],
        ['L6_U3', 'L6_U4', 'L6_U5', 'L6_U6', 'L6_U7', 'L6_U8'],
):
    PlotToothDistanceDistribution(train_set, source_labels,
                                  source_labels,
                                  "Scatter Plot of Tooth Pairs in Opposite Quadrants")

# Stage 4 - Inference by combining Stage 1 to Stage 3

In [None]:
from shapely.geometry import Polygon
import collections
import sys

# Stage 2 detector output (labels and images) and the Stage 3 output folder.
stage2_inference_label_path = './4 Private ZOON Anonymised/Bitewing_proc/labels_trial6XX/Stage2_yolo_obb/6XX_4cls_scratch_yolo5m_detect_conf0p6_iou0p45_/labels'
inference_img_path = './4 Private ZOON Anonymised/Bitewing_proc/labels_trial6XX/Stage2_yolo_obb/6XX_4cls_scratch_yolo5m_detect_conf0p6_iou0p45_'
stage3_inference_label_path = './4 Private ZOON Anonymised/Bitewing_proc/labels_trial6XX/Stage3/stage3outputlabels'

# Clear the output of any previous run. On a first run the folder does not
# exist yet, and an unconditional rmtree would raise FileNotFoundError.
if os.path.isdir(stage3_inference_label_path):
    shutil.rmtree(stage3_inference_label_path)

# The four classes of detected teeth handled in this stage.
detect_classes = ['U_molar', 'U_nonmolar', 'L_molar', 'L_nonmolar']

# NOTE(review): this rebinds `valid_lbls`, which earlier in the notebook is a
# list of 't11'..'t48' labels read by CreateTeethDistanceAndWidthData. Calling
# that function again after this cell would filter out every label - confirm
# the intended cell order or rename one of the two.
valid_lbls = { 'UpperQ': ['U1', 'U2', 'U3', 'U4', 'U5', 'U6', 'U7', 'U8'],
               'LowerQ': ['L1', 'L2', 'L3', 'L4', 'L5', 'L6', 'L7', 'L8']
             }


#### Define functions and make step-by-step decisions

In [None]:
# Two detected boxes whose intersection-over-union is at or above this value
# are treated as overlapping by check_overlapped_teeth_box.
IoU_THRESHOLD = 0.5

def sort_coordinates(list_of_xy_coords, is_clockwise):
    """Return the points reordered by their angle around the centroid.

    ``list_of_xy_coords`` is an (N, 2) array of x/y points. The angle of each
    point is ``arctan2(x - cx, y - cy)`` with (cx, cy) the mean of the points.
    Points are returned in descending angle when ``is_clockwise`` is truthy
    and in ascending angle otherwise.
    """
    center_x, center_y = list_of_xy_coords.mean(axis=0)
    xs, ys = list_of_xy_coords.T
    angles = np.arctan2(xs - center_x, ys - center_y)

    # Negating the angles before sorting yields the descending order.
    direction = -1 if is_clockwise else 1
    order = np.argsort(direction * angles)

    return list_of_xy_coords[order]

    
def check_overlapped_teeth_box(teeth):
    """Flag the lower-confidence box of every heavily overlapping pair.

    Each entry of ``teeth`` is a dict with 'corners', 'conf' and
    'is_redundant'. For every pair whose IoU is at least IoU_THRESHOLD, the
    entry with the lower 'conf' gets 'is_redundant' set; on equal confidence
    the later entry is flagged. Flags that are already set are kept. The
    entries are modified in place and nothing is returned.
    """
    for idx, first in enumerate(teeth):
        for second in teeth[idx + 1:]:
            first_poly = Polygon(first['corners'])
            second_poly = Polygon(second['corners'])
            overlap = first_poly.intersection(second_poly).area / first_poly.union(second_poly).area

            if overlap >= IoU_THRESHOLD:
                # IoU above threshold. Compare the object confidence
                first_wins = first['conf'] >= second['conf']
                second['is_redundant'] = first_wins or second['is_redundant']
                second_wins = first['conf'] < second['conf']
                first['is_redundant'] = second_wins or first['is_redundant']



# To identify the nonmolar and molar pair with shortest distance
def determine_nonmolar_molar(teeth, tooth_register, is_upperQ):
    """Predict tooth numbers for the closest non-molar/molar pair of one jaw.

    Among the non-redundant teeth of the selected jaw, the non-molar/molar
    pair(s) with the smallest centre distance are fed to the matching
    'MolarNonmolar' classifier in ``cls_sameQ`` using the features
    [normalized distance, non-molar width, molar width]. The predicted pair
    label is split into its two tooth names, which are appended (together
    with the prediction confidence) to each tooth's prediction lists and
    marked 'FOUND' in ``tooth_register``.

    Returns the number of pairs that were predicted.
    """
    if is_upperQ:
        nonmolar_cls, molar_cls, model_key = 'U_nonmolar', 'U_molar', 'U_MolarNonmolar'
    else:
        nonmolar_cls, molar_cls, model_key = 'L_nonmolar', 'L_molar', 'L_MolarNonmolar'

    def candidate_pairs():
        # Every (non-molar, molar) combination of non-redundant teeth with its centre distance.
        for nonmolar in teeth:
            if nonmolar['molar_nonmolar'] == nonmolar_cls and not nonmolar['is_redundant']:
                for molar in teeth:
                    if molar['molar_nonmolar'] == molar_cls and not molar['is_redundant'] and nonmolar is not molar:
                        yield nonmolar, molar, np.linalg.norm(nonmolar['center'] - molar['center'])

    def normalized_width(tooth):
        # Shorter of two adjacent box sides, divided by the image diagonal.
        corners = tooth['corners']
        side_a = np.linalg.norm(corners[0] - corners[1])
        side_b = np.linalg.norm(corners[1] - corners[2])
        return min(side_a, side_b)/tooth['img_diag']

    # First pass: find the minimum distance between a molar and nonmolar
    mindist = sys.float_info.max
    for _, _, pair_dist in candidate_pairs():
        if pair_dist < mindist:
            mindist = pair_dist

    # Second pass: classify every pair that sits exactly at that minimum.
    detect_count = 0
    for nonmolar, molar, pair_dist in candidate_pairs():
        if pair_dist != mindist:
            continue

        nonmolar_width = normalized_width(nonmolar)
        molar_width = normalized_width(molar)

        # img_diag should be the same for both teeth as they are in the same image
        norm_dist = pair_dist/nonmolar['img_diag']

        # the widths and box distance are normalized before they are fed to the classifier
        output_prob = cls_sameQ[model_key].predict_proba(np.array([[norm_dist, nonmolar_width, molar_width]]))
        output_label = lblenc_sameQ[model_key].classes_[np.argmax(output_prob)]
        tootha, toothb = output_label.split('_')
        conf_text = str(round(np.max(output_prob),2))

        nonmolar['tooth_number_pred'].append(tootha)
        nonmolar['tooth_number_pred_conf'].append(conf_text)
        molar['tooth_number_pred'].append(toothb)
        molar['tooth_number_pred_conf'].append(conf_text)

        tooth_register[tootha] = 'FOUND'
        tooth_register[toothb] = 'FOUND'
        detect_count += 1

    return detect_count




# To identify the molar pairs
def determine_molar(teeth, tooth_register, is_upperQ):
    """Predict tooth numbers for the molar boxes of one jaw, pair by pair.

    Only boxes whose ``'molar_nonmolar'`` class is ``'U_molar'`` (upper) or
    ``'L_molar'`` (lower) and that are not flagged ``'is_redundant'`` are used.

    1. If none of these boxes carries a prediction yet, the closest pair is fed
       to the same-quadrant classifier and both boxes receive a label.
    2. Afterwards the closest (labelled, unlabelled) pair is classified again
       and again, the unlabelled box taking the partner label, until every
       molar box has at least one prediction.

    Uses the module-level ``cls_sameQ`` / ``lblenc_sameQ`` dictionaries with the
    key ``'U_Molar'`` or ``'L_Molar'``. Classifier features are
    ``[centre distance, width of x, width of y]``, each divided by the image
    diagonal stored in the box.

    NOTE(review): the stored confidence is always the highest class
    probability, even when the duplicate check selected a lower-ranked label.

    Args:
        teeth: list of box dicts. Reads ``'molar_nonmolar'``, ``'is_redundant'``,
            ``'center'``, ``'corners'`` and ``'img_diag'``; appends to
            ``'tooth_number_pred'`` and ``'tooth_number_pred_conf'`` in place.
        tooth_register: dict indexed by the predicted tooth labels; entries of
            assigned labels are set to ``'FOUND'``.
        is_upperQ: True to process the upper molars, False for the lower ones.

    Returns:
        int: number of pair predictions recorded; 0 when fewer than two molar
        boxes are available.
    """
    box_class = 'U_molar' if is_upperQ else 'L_molar'
    clf_key = 'U_Molar' if is_upperQ else 'L_Molar'

    def is_molar(box):
        # a usable molar box of the requested jaw
        return box['molar_nonmolar'] == box_class and not box['is_redundant']

    def has_label(box):
        return len(box['tooth_number_pred']) > 0

    def short_side(box):
        # shorter side of the box, normalised by the image diagonal
        side_a = np.linalg.norm(box['corners'][0] - box['corners'][1])
        side_b = np.linalg.norm(box['corners'][1] - box['corners'][2])
        return min(side_a, side_b) / box['img_diag']

    def classify_pair(x, y, dist, avoid_duplicates):
        # img_diag should be the same in x and y as they are in the same image
        features = np.array([[dist / x['img_diag'], short_side(x), short_side(y)]])
        output_prob = cls_sameQ[clf_key].predict_proba(features)
        classes = lblenc_sameQ[clf_key].classes_

        if not avoid_duplicates:
            label = classes[np.argmax(output_prob)]
        else:
            # Walk down the ranking until exactly one tooth of the candidate
            # pair is already registered as FOUND; tried entries are marked -1.
            try_prob = output_prob.copy()
            while True:
                best = np.argmax(try_prob)
                label = classes[best]
                cand_a, cand_b = label.split('_')
                if (tooth_register[cand_a] == 'FOUND') != (tooth_register[cand_b] == 'FOUND'):
                    break
                try_prob[0][best] = -1
                # every class has been tried: keep the last candidate
                if try_prob[0][np.argmax(try_prob)] == -1:
                    break

        tootha, toothb = label.split('_')
        return output_prob, tootha, toothb

    detect_count = 0

    molar_cnt = sum(1 for box in teeth if is_molar(box))
    if molar_cnt < 2:
        print(f'There is only {molar_cnt} teeth. No need to predict molar pairs!')
        # return a number (not None) so callers can always print / compare it
        return detect_count

    if not any(is_molar(box) and has_label(box) for box in teeth):
        # No molar tooth is identified before, pick the closest pair to predict
        mindist = sys.float_info.max
        for x in teeth:
            if not is_molar(x):
                continue
            for y in teeth:
                if is_molar(y) and x is not y:
                    mindist = min(mindist, np.linalg.norm(x['center'] - y['center']))

        # NOTE(review): the closest pair is visited as (x, y) and as (y, x), so
        # the classifier runs twice and each box receives two entries.
        for x in teeth:
            if not is_molar(x):
                continue
            for y in teeth:
                if not is_molar(y) or x is y:
                    continue
                dist = np.linalg.norm(x['center'] - y['center'])
                if dist != mindist:
                    continue

                output_prob, tootha, toothb = classify_pair(x, y, dist, avoid_duplicates=False)
                conf = str(round(np.max(output_prob), 2))

                x['tooth_number_pred'].append(tootha)
                x['tooth_number_pred_conf'].append(conf)
                y['tooth_number_pred'].append(toothb)
                y['tooth_number_pred_conf'].append(conf)

                tooth_register[tootha] = 'FOUND'
                tooth_register[toothb] = 'FOUND'
                detect_count += 1

    # a molar tooth is identified now, predict the others from the identified ones
    while True:
        count_before_pass = detect_count

        mindist = sys.float_info.max
        for x in teeth:
            if is_molar(x) and has_label(x):
                for y in teeth:
                    if is_molar(y) and x is not y and not has_label(y):
                        mindist = min(mindist, np.linalg.norm(x['center'] - y['center']))

        for x in teeth:
            if not (is_molar(x) and has_label(x)):
                continue
            for y in teeth:
                # re-checked for every y: boxes labelled earlier in this pass
                # must not be labelled again
                if not (is_molar(y) and x is not y and not has_label(y)):
                    continue
                dist = np.linalg.norm(x['center'] - y['center'])
                if dist != mindist:
                    continue

                # x, y is the closest (labelled, unlabelled) molar pair
                output_prob, tootha, toothb = classify_pair(x, y, dist, avoid_duplicates=True)
                conf = str(round(np.max(output_prob), 2))

                xlbl = x['tooth_number_pred'][-1]
                if xlbl == tootha:
                    y['tooth_number_pred'].append(toothb)
                    y['tooth_number_pred_conf'].append(conf)
                    tooth_register[toothb] = 'FOUND'
                    detect_count += 1
                elif xlbl == toothb:
                    y['tooth_number_pred'].append(tootha)
                    y['tooth_number_pred_conf'].append(conf)
                    tooth_register[tootha] = 'FOUND'
                    detect_count += 1
                else:
                    # the predicted pair contradicts the label x already has:
                    # keep both candidates on both boxes
                    print(f'*****WARNING: Previous stage determined {xlbl}, but now have ({tootha}, {toothb})')
                    for box in (x, y):
                        box['tooth_number_pred'].append(tootha)
                        box['tooth_number_pred_conf'].append(conf)
                        box['tooth_number_pred'].append(toothb)
                        box['tooth_number_pred_conf'].append(conf)

                    tooth_register[tootha] = 'FOUND'
                    tooth_register[toothb] = 'FOUND'
                    detect_count += 1

        if all(has_label(box) for box in teeth if is_molar(box)):
            break

        if detect_count == count_before_pass:
            # Nothing changed in this pass, so the next one would be identical:
            # stop instead of looping forever.
            print('*****WARNING: Unable to determine the remaining molar teeth, stop here')
            break

    return detect_count

    

    
# To identify the nonmolar pairs
def determine_nonmolar(teeth, tooth_register, is_upperQ):
    """Predict tooth numbers for the non-molar boxes of one jaw, pair by pair.

    Only boxes whose ``'molar_nonmolar'`` class is ``'U_nonmolar'`` (upper) or
    ``'L_nonmolar'`` (lower) and that are not flagged ``'is_redundant'`` are
    used.

    1. If none of these boxes carries a prediction yet, the closest pair is
       classified and both boxes receive a label.
    2. Afterwards the closest (labelled, unlabelled) pair is classified again
       and again, the unlabelled box taking the partner label, until every
       non-molar box has at least one prediction.

    Uses the module-level ``cls_sameQ`` / ``lblenc_sameQ`` dictionaries with the
    key ``'U_Nonmolar'`` or ``'L_Nonmolar'``. When the classifier entry is
    ``None`` the pair ``'U4_U5'`` / ``'L4_L5'`` is used with confidence 1.
    Classifier features are ``[centre distance, width of x, width of y]``, each
    divided by the image diagonal stored in the box.

    NOTE(review): the stored confidence is always the highest class
    probability, even when the duplicate check selected a lower-ranked label.

    Args:
        teeth: list of box dicts. Reads ``'molar_nonmolar'``, ``'is_redundant'``,
            ``'center'``, ``'corners'`` and ``'img_diag'``; appends to
            ``'tooth_number_pred'`` and ``'tooth_number_pred_conf'`` in place.
        tooth_register: dict indexed by the predicted tooth labels; entries of
            assigned labels are set to ``'FOUND'``.
        is_upperQ: True to process the upper non-molars, False for the lower.

    Returns:
        int: number of pair predictions recorded; 0 when fewer than two
        non-molar boxes are available.
    """
    box_class = 'U_nonmolar' if is_upperQ else 'L_nonmolar'
    clf_key = 'U_Nonmolar' if is_upperQ else 'L_Nonmolar'
    fallback_label = 'U4_U5' if is_upperQ else 'L4_L5'

    def is_nonmolar(box):
        # a usable non-molar box of the requested jaw
        return box['molar_nonmolar'] == box_class and not box['is_redundant']

    def has_label(box):
        return len(box['tooth_number_pred']) > 0

    def short_side(box):
        # shorter side of the box, normalised by the image diagonal
        side_a = np.linalg.norm(box['corners'][0] - box['corners'][1])
        side_b = np.linalg.norm(box['corners'][1] - box['corners'][2])
        return min(side_a, side_b) / box['img_diag']

    def classify_pair(x, y, dist, avoid_duplicates):
        # img_diag should be the same in x and y as they are in the same image
        features = np.array([[dist / x['img_diag'], short_side(x), short_side(y)]])

        clf = cls_sameQ[clf_key]
        if clf is None:
            # a None classifier is due to insufficient t3 data to fit a
            # classifier, only a workaround here.
            tootha, toothb = fallback_label.split('_')
            return 1, tootha, toothb

        output_prob = clf.predict_proba(features)
        classes = lblenc_sameQ[clf_key].classes_

        if not avoid_duplicates:
            label = classes[np.argmax(output_prob)]
        else:
            # Walk down the ranking until exactly one tooth of the candidate
            # pair is already registered as FOUND; tried entries are marked -1.
            try_prob = output_prob.copy()
            while True:
                best = np.argmax(try_prob)
                label = classes[best]
                cand_a, cand_b = label.split('_')
                if (tooth_register[cand_a] == 'FOUND') != (tooth_register[cand_b] == 'FOUND'):
                    break
                try_prob[0][best] = -1
                # every class has been tried: keep the last candidate
                if try_prob[0][np.argmax(try_prob)] == -1:
                    break

        tootha, toothb = label.split('_')
        return output_prob, tootha, toothb

    detect_count = 0

    nonmolar_cnt = sum(1 for box in teeth if is_nonmolar(box))
    if nonmolar_cnt < 2:
        print(f'There is only {nonmolar_cnt} teeth. No need to predict nonmolar pairs!')
        # return a number (not None) so callers can always print / compare it
        return detect_count

    if not any(is_nonmolar(box) and has_label(box) for box in teeth):
        # No nonmolar tooth is identified before, pick the closest pair to predict
        mindist = sys.float_info.max
        for x in teeth:
            if not is_nonmolar(x):
                continue
            for y in teeth:
                if is_nonmolar(y) and x is not y:
                    mindist = min(mindist, np.linalg.norm(x['center'] - y['center']))

        # NOTE(review): the closest pair is visited as (x, y) and as (y, x), so
        # it is classified twice and each box receives two entries.
        for x in teeth:
            if not is_nonmolar(x):
                continue
            for y in teeth:
                if not is_nonmolar(y) or x is y:
                    continue
                dist = np.linalg.norm(x['center'] - y['center'])
                if dist != mindist:
                    continue

                output_prob, tootha, toothb = classify_pair(x, y, dist, avoid_duplicates=False)
                conf = str(round(np.max(output_prob), 2))

                x['tooth_number_pred'].append(tootha)
                x['tooth_number_pred_conf'].append(conf)
                y['tooth_number_pred'].append(toothb)
                y['tooth_number_pred_conf'].append(conf)

                tooth_register[tootha] = 'FOUND'
                tooth_register[toothb] = 'FOUND'
                detect_count += 1

    # a nonmolar tooth is identified now, predict the others from the identified ones
    while True:
        count_before_pass = detect_count

        mindist = sys.float_info.max
        for x in teeth:
            if is_nonmolar(x) and has_label(x):
                for y in teeth:
                    if is_nonmolar(y) and x is not y and not has_label(y):
                        mindist = min(mindist, np.linalg.norm(x['center'] - y['center']))

        for x in teeth:
            if not (is_nonmolar(x) and has_label(x)):
                continue
            for y in teeth:
                # re-checked for every y: boxes labelled earlier in this pass
                # must not be labelled again
                if not (is_nonmolar(y) and x is not y and not has_label(y)):
                    continue
                dist = np.linalg.norm(x['center'] - y['center'])
                if dist != mindist:
                    continue

                # x, y is the closest (labelled, unlabelled) nonmolar pair
                output_prob, tootha, toothb = classify_pair(x, y, dist, avoid_duplicates=True)
                conf = str(round(np.max(output_prob), 2))

                xlbl = x['tooth_number_pred'][-1]
                if xlbl == tootha:
                    y['tooth_number_pred'].append(toothb)
                    y['tooth_number_pred_conf'].append(conf)
                    tooth_register[toothb] = 'FOUND'
                    detect_count += 1
                elif xlbl == toothb:
                    y['tooth_number_pred'].append(tootha)
                    y['tooth_number_pred_conf'].append(conf)
                    tooth_register[tootha] = 'FOUND'
                    detect_count += 1
                else:
                    # the predicted pair contradicts the label x already has:
                    # keep both candidates on both boxes
                    print(f'*****WARNING: Previous stage determined {xlbl}, but now have ({tootha}, {toothb})')
                    for box in (x, y):
                        box['tooth_number_pred'].append(tootha)
                        box['tooth_number_pred_conf'].append(conf)
                        box['tooth_number_pred'].append(toothb)
                        box['tooth_number_pred_conf'].append(conf)

                    tooth_register[tootha] = 'FOUND'
                    tooth_register[toothb] = 'FOUND'
                    detect_count += 1

        if all(has_label(box) for box in teeth if is_nonmolar(box)):
            break

        if detect_count == count_before_pass:
            # Nothing changed in this pass, so the next one would be identical:
            # stop instead of looping forever.
            print('*****WARNING: Unable to determine the remaining nonmolar teeth, stop here')
            break

    return detect_count




# To identify the teeth of one jaw from the closest labelled tooth of the opposite jaw
def determine_opposite_quad(teeth, tooth_register, is_upperQ):
    """Label the teeth of one jaw from the nearest labelled tooth of the other jaw.

    For every non-redundant box of the requested jaw (molar or non-molar) the
    closest non-redundant box of the opposite jaw that already carries a
    prediction is looked up. The classifier stored in the module-level
    ``cls_oppositeQ`` under that box's best label is evaluated on
    ``[centre distance, width of x, width of y]`` (each divided by the image
    diagonal), and the second half of the predicted ``'a_b'`` label is appended
    to the box with the confidence string ``'9999'``.

    Args:
        teeth: list of box dicts. Reads ``'molar_nonmolar'``, ``'is_redundant'``,
            ``'center'``, ``'corners'``, ``'img_diag'``, ``'tooth_number_pred'``
            and ``'tooth_number_pred_conf'``; the last two are appended to in
            place for the boxes of the requested jaw.
        tooth_register: dict in which each assigned label is set to ``'FOUND'``.
        is_upperQ: True to label the upper jaw from the lower one, False for
            the opposite direction.

    Returns:
        int: number of predictions appended.
    """
    if is_upperQ:
        target_classes = ('U_nonmolar', 'U_molar')
        reference_classes = ('L_nonmolar', 'L_molar')
    else:
        target_classes = ('L_nonmolar', 'L_molar')
        reference_classes = ('U_nonmolar', 'U_molar')

    def is_reference(box):
        # a usable box of the opposite jaw that already has a prediction
        return (box['molar_nonmolar'] in reference_classes
                and not box['is_redundant']
                and len(box['tooth_number_pred']) > 0)

    def short_side(box):
        # shorter side of the box, normalised by the image diagonal
        side_a = np.linalg.norm(box['corners'][0] - box['corners'][1])
        side_b = np.linalg.norm(box['corners'][1] - box['corners'][2])
        return min(side_a, side_b) / box['img_diag']

    detect_count = 0

    for x in teeth:
        if x['molar_nonmolar'] not in target_classes or x['is_redundant']:
            continue

        # distance from x to its closest labelled tooth in the opposite jaw
        mindist = sys.float_info.max
        for y in teeth:
            if is_reference(y):
                mindist = min(mindist, np.linalg.norm(x['center'] - y['center']))

        for y in teeth:
            if not is_reference(y):
                continue
            dist = np.linalg.norm(x['center'] - y['center'])
            if dist != mindist:
                continue

            # x, y is the closest opposite quad teeth. Use the classifier to predict it.
            # img_diag should be the same in x and y as they are in the same image
            features = np.array([[dist / x['img_diag'], short_side(x), short_side(y)]])

            # Confidences are stored as strings; compare them as numbers so the
            # ranking does not depend on lexicographic string order.
            y_confs = [float(c) for c in y['tooth_number_pred_conf']]
            y_tooth_lbl = y['tooth_number_pred'][np.argmax(y_confs)]

            output_prob = cls_oppositeQ[y_tooth_lbl].predict_proba(features)
            output_label = lblenc_oppositeQ[y_tooth_lbl].classes_[np.argmax(output_prob)]

            # only the second half of the label is applied to x
            _, toothb = output_label.split('_')

            x['tooth_number_pred'].append(toothb)
            # fixed, very large confidence so this prediction outranks the others
            x['tooth_number_pred_conf'].append('9999')

            tooth_register[toothb] = 'FOUND'

            detect_count += 1

    return detect_count

        

    
# Stage 3 inference: read every Stage 2 label file, predict the tooth numbers of
# its boxes with the distance classifiers and write one Stage 3 label file per image.
os.makedirs(stage3_inference_label_path, exist_ok=True)

# get the inference label
for label_fname in os.listdir(stage2_inference_label_path):

    label_path = os.path.join(stage2_inference_label_path, label_fname)
    if not os.path.isfile(label_path):
        continue

    # splitext also copes with file names that contain additional dots
    pf, _ = os.path.splitext(label_fname)

    # only the image size is needed, so release the image file right away
    with Image.open(os.path.join(inference_img_path, pf + '.png')) as pil_img:
        img_width, img_height = pil_img.size
    img_diag = (img_width**2 + img_height**2)**0.5

    with open(label_path, 'r') as f:
        lines = f.readlines()

    infered_boxes = []
    for row in lines:
        row = row.split()
        if not row:
            # ignore blank lines
            continue

        box = {}
        # row layout: class code, 4 corner points (x, y), confidence
        box['molar_nonmolar_code'] = row[0]
        box['molar_nonmolar'] = detect_classes[int(row[0])]
        box['conf'] = float(row[-1])

        corners = np.array([[float(row[1]), float(row[2])],
                            [float(row[3]), float(row[4])],
                            [float(row[5]), float(row[6])],
                            [float(row[7]), float(row[8])]])

        corners = sort_coordinates(corners, True)
        box['corners'] = corners

        # center coordinates: midpoint of the corners 0 and 2
        box['center'] = 0.5*(corners[0] + corners[2])

        box['img_diag'] = img_diag

        # a flag to indicate whether the box is overlapped with IoU > threshold. If True, it will be ignored.
        box['is_redundant'] = False

        box['tooth_number_pred'] = []
        box['tooth_number_pred_conf'] = []

        infered_boxes.append(box)

    # a register to hold the detected result of each image. UNK=unknown, FOUND, MISSING
    infered_tooth_register = {}
    for lbl in valid_lbls['UpperQ']:
        infered_tooth_register[lbl] = 'UNK'
    for lbl in valid_lbls['LowerQ']:
        infered_tooth_register[lbl] = 'UNK'

    check_overlapped_teeth_box(infered_boxes)

    # First step is to determine closest {nonmolar, molar} pair
    count_upper_nomolar_molar = determine_nonmolar_molar(teeth=infered_boxes, tooth_register=infered_tooth_register, is_upperQ=True)
    print(f'{pf} has {count_upper_nomolar_molar} Upper (Nonmolar,Molar) detection!')

    count_lower_nomolar_molar = determine_nonmolar_molar(teeth=infered_boxes, tooth_register=infered_tooth_register, is_upperQ=False)
    print(f'{pf} has {count_lower_nomolar_molar} Lower (Nonmolar,Molar) detection!')

    # Second step is to determine molar pairs, also given a molar has been found in first step
    count = determine_molar(teeth=infered_boxes, tooth_register=infered_tooth_register, is_upperQ=True)
    print(f'{pf} has {count} Upper (Molar,Molar) detection!')

    count = determine_molar(teeth=infered_boxes, tooth_register=infered_tooth_register, is_upperQ=False)
    print(f'{pf} has {count} Lower (Molar,Molar) detection!')

    # Third step is to determine nonmolar pairs, also given a nonmolar has been found in first step
    count = determine_nonmolar(teeth=infered_boxes, tooth_register=infered_tooth_register, is_upperQ=True)
    print(f'{pf} has {count} Upper (Nonmolar,Nonmolar) detection!')

    count = determine_nonmolar(teeth=infered_boxes, tooth_register=infered_tooth_register, is_upperQ=False)
    print(f'{pf} has {count} Lower (Nonmolar,Nonmolar) detection!')

    # Finally, if no (nonmolar, molar) pair was identified in a jaw but one was
    # identified in the other jaw, use the opposite quadrant for the prediction.
    # The opposite quad prediction is overriding as it is more accurate.
    if count_upper_nomolar_molar==0 and count_lower_nomolar_molar>0:
        count = determine_opposite_quad(teeth=infered_boxes, tooth_register=infered_tooth_register, is_upperQ=True)
        print(f'{pf} has {count} Upper detection using opposite quadrant tooth!')

    if count_lower_nomolar_molar==0 and count_upper_nomolar_molar>0:
        count = determine_opposite_quad(teeth=infered_boxes, tooth_register=infered_tooth_register, is_upperQ=False)
        print(f'{pf} has {count} Lower detection using opposite quadrant tooth!')

    ############################################################

    # write the result to label files, one line per box:
    # code, 8 corner coordinates, conf, redundant flag, [labels], [label confidences]
    out_lines = []
    for box in infered_boxes:
        box_dup = 'redundant_true' if box['is_redundant'] else 'redundant_false'
        fields = [box['molar_nonmolar_code'],
                  *[str(r) for r in box['corners'].flatten()],
                  str(box['conf']),
                  box_dup,
                  '[' + ','.join(box['tooth_number_pred']) + ']',
                  '[' + ','.join(box['tooth_number_pred_conf']) + ']']
        out_lines.append(' '.join(fields) + '\n')

    with open(os.path.join(stage3_inference_label_path, pf + '.txt'), 'w') as f2:
        f2.writelines(out_lines)
            
            

   

In [None]:
%matplotlib notebook

# try to visualize results: draw the Stage 3 boxes and predicted labels of one image
import cv2
img_name = 'Image221.png'

pf, _ = os.path.splitext(img_name)
img = cv2.imread(os.path.join(inference_img_path, img_name))
if img is None:
    # cv2.imread returns None instead of raising when the file cannot be read
    raise FileNotFoundError(os.path.join(inference_img_path, img_name))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
label_fname = pf+'.txt'
dimg = img.copy()
with open(os.path.join(stage3_inference_label_path, label_fname), 'r') as f:
    rows = f.readlines()

for row in rows:
    if row.strip() == 'YOLO_OBB':
        continue

    # row layout: code, 8 corner coordinates, conf, redundant flag, [labels], [confidences]
    row = row.split(' ')
    x1,y1 = float(row[1]), float(row[2])
    x2,y2 = float(row[3]), float(row[4])
    x3,y3 = float(row[5]), float(row[6])
    x4,y4 = float(row[7]), float(row[8])

    tooth_lbl = row[11]

    # '[]' is what the Stage 3 writer emits for a box without any prediction
    if tooth_lbl not in ('UNK', '[]'):
        dimg = cv2.drawContours(dimg, [np.asarray([[x1,y1], [x2,y2], [x3,y3], [x4,y4]], dtype=int)], 0, color=(0,255,0), thickness=4)
        dimg = cv2.putText(img=dimg, text=tooth_lbl, org=np.array([(x1+x3)/2, (y1+y3)/2], dtype=int), fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                           fontScale=2, color=(0, 200, 0), thickness=2, lineType=cv2.LINE_AA)

plt.imshow(dimg)



### Compute the classification numbering with Stage 1 and 3 results to get Stage 4 predictions

In [None]:
stage1_inference_label_path = './4 Private ZOON Anonymised/Bitewing_proc/labels_trial6XX/Stage1_dataset_LR_classifier/test_preds'
groundtruth_label_path = './4 Private ZOON Anonymised/Bitewing_proc/labels_trial6XX/Stage2_yolo_obb/val/labelTxt_32Cls'
stage4_inference_label_path = './4 Private ZOON Anonymised/Bitewing_proc/labels_trial6XX/Stage4/stage4outputlabels'

# Start from an empty Stage 4 output folder. The folder does not exist on the
# first run, so only remove it when it is there.
if os.path.isdir(stage4_inference_label_path):
    shutil.rmtree(stage4_inference_label_path)
os.makedirs(stage4_inference_label_path, exist_ok=True)

# This decides the left / right sides of the image from stage 1
# (first line of each Stage 1 prediction file, keyed by image name)
image_quad = {}
for fname in os.listdir(stage1_inference_label_path):
    path = os.path.join(stage1_inference_label_path, fname)
    if os.path.isfile(path):
        with open(path, 'r') as f:
            line = f.readline().strip()
        # splitext also copes with file names that contain additional dots
        imgname, _ = os.path.splitext(fname)
        image_quad[imgname] = line


# Stage 3 label (U/L + position) -> tooth label, per image side
teeth_map ={ 'RIGHT': {'U3':'t13' , 'U4':'t14', 'U5':'t15', 'U6':'t16', 'U7':'t17', 'U8':'t18',
                       'L3':'t43' , 'L4':'t44', 'L5':'t45', 'L6':'t46', 'L7':'t47', 'L8':'t48'},

             'LEFT': {'U3':'t23' , 'U4':'t24', 'U5':'t25', 'U6':'t26', 'U7':'t27', 'U8':'t28',
                      'L3':'t33' , 'L4':'t34', 'L5':'t35', 'L6':'t36', 'L7':'t37', 'L8':'t38'}
              }


# teeth_predictions[image]['ground_truth'][k] = [label, 4x2 corner array]
teeth_predictions = {}
for fname in os.listdir(groundtruth_label_path):
    path = os.path.join(groundtruth_label_path, fname)
    if os.path.isfile(path):
        with open(path, 'r') as f:
            rows = f.readlines()

        imgname, _ = os.path.splitext(fname)
        teeth_predictions[imgname] = {}
        teeth_predictions[imgname]['ground_truth'] = {}

        for k, row in enumerate(rows):
            row = row.split()
            if not row:
                # ignore blank lines
                continue
            rect_coor = np.array( [ [float(row[0]), float(row[1])],
                                    [float(row[2]), float(row[3])],
                                    [float(row[4]), float(row[5])],
                                    [float(row[6]), float(row[7])] ] )
            tooth_label = row[8]
            line = [tooth_label, rect_coor]
            teeth_predictions[imgname]['ground_truth'][k] = line


# teeth_predictions[image]['prediction'][k] = [label, 4x2 corner array, box confidence]
for fname in os.listdir(stage3_inference_label_path):
    path = os.path.join(stage3_inference_label_path, fname)
    if os.path.isfile(path):
        with open(path, 'r') as f:
            rows = f.readlines()

        imgname, _ = os.path.splitext(fname)
        # raises KeyError when the image has no ground truth label file
        teeth_predictions[imgname]['prediction'] = {}

        for k, row in enumerate(rows):
            row = row.split()
            if not row:
                # ignore blank lines
                continue

            # last two fields: [labels] and [label confidences]; keep the
            # label with the highest confidence, None when there is none
            lbs = row[-2].strip('[]').split(',')
            confs = [float(x) for x in row[-1].strip('[]').split(',') if len(x)>0]
            lb = lbs[np.argmax(confs)] if len(confs)>0 else None

            # labels missing from the map of this image side become 'unknown'
            tooth_label = teeth_map[image_quad[imgname]].get(lb, 'unknown')

            rect_coor = np.array( [ [float(row[1]), float(row[2])],
                                    [float(row[3]), float(row[4])],
                                    [float(row[5]), float(row[6])],
                                    [float(row[7]), float(row[8])] ] )

            conf = float(row[9])
            line = [tooth_label, rect_coor, conf]
            teeth_predictions[imgname]['prediction'][k] = line
            #print(f'{fname} {line}')
                #print(f'{fname} {line}')
                



In [None]:
def calculate_iou(r1, r2):
    """Return the intersection-over-union of two quadrilaterals.

    Args:
        r1: Sequence of (x, y) corner points accepted by ``shapely.geometry.Polygon``.
        r2: Second quadrilateral, same format as ``r1``.

    Returns:
        Intersection area divided by union area, or ``0.0`` when the union has
        no area (both shapes are degenerate), where the ratio is undefined.
    """
    # NOTE(review): presumably the corners arrive in ring order; a
    # self-intersecting ("bow-tie") ordering yields an invalid polygon —
    # confirm against the label writer.
    poly1 = Polygon(r1)
    poly2 = Polygon(r2)

    union_area = poly1.union(poly2).area
    if union_area == 0:
        # Avoid ZeroDivisionError for zero-area inputs.
        return 0.0

    return poly1.intersection(poly2).area / union_area
            
    
# Export the combined prediction results
def ExportInferenceFiles(export_path, img_names, pred_labels, pred_rects, pred_conf):
    """Write the matched predictions to one ``<image name>.txt`` file per image.

    ``export_path`` is emptied (or created) first. The four sequences are read
    in lockstep; each entry whose label is not ``'unknown'`` appends one line
    ``<label> <flattened rect values> <confidence>`` to its image's file.
    An entry labelled ``'unknown'`` still creates the file but writes nothing.

    Args:
        export_path: Output directory; any existing content is deleted.
        img_names: Image name (without extension) for every entry.
        pred_labels: Predicted label for every entry.
        pred_rects: Box for every entry; must provide ``flatten()`` unless the
            entry is labelled ``'unknown'``.
        pred_conf: Confidence for every entry.
    """
    # Always start from an empty directory so stale files never survive a rerun.
    if os.path.exists(export_path):
        shutil.rmtree(export_path)
    os.makedirs(export_path)

    for name, label, box, score in zip(img_names, pred_labels, pred_rects, pred_conf):
        target = os.path.join(export_path, name + '.txt')
        # Append mode: several entries of the same image share one file.
        with open(target, 'a') as out:
            if label == 'unknown':
                continue
            fields = [label] + [str(value) for value in box.flatten()] + [str(score)]
            out.write(' '.join(fields) + '\n')


# Parallel result lists: one entry per ground-truth box that is kept below.
img_names = []
groud_truth_labels = []
pred_labels = []
pred_conf = []
box_ious = []
pred_rects = []

for img, val in teeth_predictions.items():

    # Not read again in this cell.
    gt_rects = [entry[1] for entry in val['ground_truth'].values()]

    for gt_lbl, gt_rect in val['ground_truth'].values():

        # Defaults describe a ground-truth box that no prediction matched.
        matched_iou, matched_pd_lbl, matched_pd_rect, matched_pd_conf = 0, 'unknown', [], 0

        if val.get('prediction', None) is not None:
            for pd_lbl, pd_rect, pd_conf in val['prediction'].values():
                iou = calculate_iou(gt_rect, pd_rect)
                # Keep the prediction overlapping this box the most; overlaps
                # of 0.5 or less never count as a match.
                if iou > 0.5 and iou > matched_iou:
                    matched_pd_lbl = pd_lbl
                    matched_pd_rect = pd_rect
                    matched_pd_conf = pd_conf
                    matched_iou = iou

        # Ground-truth boxes labelled with tooth positions 1 and 2 are not evaluated.
        if gt_lbl in ['t11', 't12', 't21', 't22', 't31', 't32', 't41', 't42']:
            continue

        img_names.append(img)
        groud_truth_labels.append(gt_lbl)
        pred_labels.append(matched_pd_lbl)
        box_ious.append(matched_iou)
        pred_conf.append(matched_pd_conf)
        pred_rects.append(matched_pd_rect)


# Export the combined prediction results
ExportInferenceFiles(stage4_inference_label_path, img_names, pred_labels, pred_rects, pred_conf)

        
        

### Examine and compute the performance metrics

In YOLO, the confidence score reflects how confident the model is that the bounding box contains an object of that class.
Each grid cell also predicts C conditional class probabilities Pr(Class_i | Object). Only one set of class probabilities is predicted per grid cell, regardless of the number of boxes B. At test time, these conditional class probabilities are multiplied by the individual box confidence predictions, which gives a class-specific confidence score for each box. These scores encode both the probability of that class and how well the box fits the object.
https://towardsdatascience.com/object-detection-part1-4dbe5147ad0a

$$\Pr(\text{Class}_i \mid \text{Object}) \cdot \Pr(\text{Object}) \cdot \text{IoU} = \Pr(\text{Class}_i) \cdot \text{IoU}$$


In [None]:
# Two stacked histograms: matched IoU per evaluated ground-truth box (top)
# and the confidence of the matched prediction (bottom).
fig, ax = plt.subplots(nrows=2, ncols=1, figsize=(8, 5))

ax[0].hist(box_ious, bins=100)
ax[1].hist(pred_conf, bins=100)

plt.show()


In [None]:
# Mean of box_ious over the evaluated ground-truth boxes; a box without a
# matched prediction contributes 0.
# NOTE(review): the message says IoU_Threshold=0, yet the matching loop only
# accepts overlaps with IoU > 0.5 — confirm the intended wording.
print(f'Average IoU at IoU_Threshold=0 : {round(sum(box_ious)/len(box_ious),3)}' )
       

In [None]:
from sklearn.metrics import classification_report

def CustomizeClassificationReport(gt_lbls, pred_lbls):
    """Print per-class precision, recall and F1 as CSV rows, then their macro average.

    Only classes that occur in ``gt_lbls`` are reported and averaged, so labels
    that appear solely among the predictions (and the summary rows that
    ``classification_report`` adds) are left out.

    Args:
        gt_lbls: Ground-truth label of every evaluated box.
        pred_lbls: Predicted label of every evaluated box, aligned with ``gt_lbls``.
    """
    cr_dict = classification_report(gt_lbls, pred_lbls, zero_division=0, output_dict=True)

    # Filter on the labels passed in, not on a module-level list, so the report
    # stays correct for any caller; build the set once instead of per class.
    gt_classes = set(gt_lbls)

    precisions = []
    recalls = []
    f1_scores = []
    for cr_class, cr_val in cr_dict.items():
        if cr_class not in gt_classes:
            continue
        precisions.append(cr_val['precision'])
        recalls.append(cr_val['recall'])
        f1_scores.append(cr_val['f1-score'])
        line = ','.join([cr_class,
                         str(round(cr_val['precision'], 3)),
                         str(round(cr_val['recall'], 3)),
                         str(round(cr_val['f1-score'], 3))])
        print(line)

    line = ','.join(['Macro Avg',
                     str(round(np.mean(precisions), 3)),
                     str(round(np.mean(recalls), 3)),
                     str(round(np.mean(f1_scores), 3))])
    print(line)
    
# Per-class and macro-averaged metrics for the matched predictions collected above.
print("Classification report at IoU_Threshold=0:\n ")
CustomizeClassificationReport(gt_lbls=groud_truth_labels, pred_lbls=pred_labels)

### Examine the case when directly using YOLO to train 32 classes 

In [None]:
from sklearn.metrics import classification_report

# Class-id lookup tables: the position in the list is the class id.
# 32 classes: 't<quadrant><position>' for quadrants 1-4 and positions 1-8.
valid_32cls_lbls = [f't{quadrant}{position}'
                    for quadrant in (1, 2, 3, 4)
                    for position in range(1, 9)]

# 16 classes: upper ('U') then lower ('L') jaw, positions 1-8.
valid_16cls_lbls = [f'{jaw}{position}'
                    for jaw in ('U', 'L')
                    for position in range(1, 9)]


# teeth_map[side][<jaw><position>] -> 't<quadrant><position>':
# RIGHT uses quadrant 1 (upper) and 4 (lower), LEFT uses 2 (upper) and 3 (lower).
teeth_map = {
    side: {f'{jaw}{position}': f't{quadrant}{position}'
           for jaw, quadrant in (('U', upper_quadrant), ('L', lower_quadrant))
           for position in range(1, 9)}
    for side, upper_quadrant, lower_quadrant in (('RIGHT', 1, 4), ('LEFT', 2, 3))
}


def _read_label_files(label_dir):
    """Yield ``(image_name, rows)`` for every regular file in ``label_dir``.

    ``image_name`` is the file name without its final extension; ``rows`` holds
    the whitespace-split fields of each non-blank line.
    """
    for fname in os.listdir(label_dir):
        path = os.path.join(label_dir, fname)
        if not os.path.isfile(path):
            continue
        with open(path, 'r') as f:
            rows = [fields for fields in (text.split() for text in f) if fields]
        yield os.path.splitext(fname)[0], rows


def _corner_array(values):
    """Convert eight coordinate strings (x1 y1 ... x4 y4) into a 4x2 float array."""
    return np.array([float(v) for v in values]).reshape(4, 2)


def ComputeDetection(yolo_label_path, no_classes, inf_label_path):
    """Score a direct YOLO tooth-numbering run against the ground truth and export it.

    Ground-truth boxes are read from the module-level ``groundtruth_label_path``
    (rows ``x1 y1 ... x4 y4 label``) and predictions from ``yolo_label_path``
    (rows ``class_id x1 y1 ... x4 y4 conf``). For every ground-truth box, the
    prediction with the highest confidence among those overlapping it with
    IoU > 0.5 is taken as its match; a box without such a prediction is scored
    as ``'unknown'`` with IoU 0. The average IoU and a classification report are
    printed and the matches are written with ``ExportInferenceFiles``.

    Args:
        yolo_label_path: Directory holding the YOLO prediction label files.
        no_classes: 32 when class ids index ``valid_32cls_lbls``; 16 when they
            index ``valid_16cls_lbls``, in which case the matched label is
            translated through ``teeth_map[image_quad[image]]``.
        inf_label_path: Directory handed to ``ExportInferenceFiles``; its
            previous content is replaced.

    Returns:
        None. An unsupported ``no_classes`` only prints a message.
    """
    if no_classes not in (16, 32):
        print(f'Invalid no_classes: {no_classes}')
        return

    class_names = valid_32cls_lbls if no_classes == 32 else valid_16cls_lbls
    # Ground-truth boxes carrying these labels are not evaluated.
    excluded_gt_lbls = {'t11', 't12', 't21', 't22', 't31', 't32', 't41', 't42'}

    teeth_predictions_yolo_direct = {}
    for imgname, rows in _read_label_files(groundtruth_label_path):
        teeth_predictions_yolo_direct[imgname] = {
            'ground_truth': {k: [row[8], _corner_array(row[0:8])]
                             for k, row in enumerate(rows)}
        }

    for imgname, rows in _read_label_files(yolo_label_path):
        # A prediction file without a ground-truth file gets an empty
        # ground-truth entry instead of raising KeyError; it adds nothing
        # to the metrics.
        entry = teeth_predictions_yolo_direct.setdefault(imgname, {'ground_truth': {}})
        entry['prediction'] = {
            k: [class_names[int(row[0])], _corner_array(row[1:9]), float(row[9])]
            for k, row in enumerate(rows)
        }

    img_names = []
    groud_truth_labels = []
    pred_labels = []
    pred_conf = []
    box_ious = []
    pred_rects = []

    for img, val in teeth_predictions_yolo_direct.items():
        predictions = val.get('prediction')

        for gt_lbl, gt_rect in val['ground_truth'].values():
            # Skip excluded boxes before spending any IoU computation on them.
            if gt_lbl in excluded_gt_lbls:
                continue

            matched_iou = 0
            matched_pd_lbl = 'unknown'
            matched_pd_rect = []
            matched_pd_conf = 0

            if predictions is not None:
                for pd_lbl, pd_rect, pd_conf in predictions.values():
                    iou = calculate_iou(gt_rect, pd_rect)
                    # Only boxes with IoU > 0.5 are candidates; among them the
                    # most confident one wins (first one on ties).
                    if iou > 0.5 and pd_conf > matched_pd_conf:
                        if no_classes == 16:
                            # Correction by using Stage 1 results (image_quad).
                            pd_lbl = teeth_map[image_quad[img]][pd_lbl]
                        matched_pd_lbl = pd_lbl
                        matched_pd_rect = pd_rect
                        matched_pd_conf = pd_conf
                        matched_iou = iou

            img_names.append(img)
            groud_truth_labels.append(gt_lbl)
            pred_labels.append(matched_pd_lbl)
            box_ious.append(matched_iou)
            pred_conf.append(matched_pd_conf)
            pred_rects.append(matched_pd_rect)

    if not box_ious:
        # Nothing to average or report; avoids dividing by zero below.
        print('No ground-truth boxes available for evaluation.')
        return

    print(f'Average IoU at IoU_Threshold=0 : {round(sum(box_ious)/len(box_ious),3)}' )

    print(f"Classification report at IoU_Threshold=0:")
    CustomizeClassificationReport(groud_truth_labels, pred_labels)

    ExportInferenceFiles(inf_label_path, img_names, pred_labels, pred_rects, pred_conf)
    
    
# Label folders produced by the direct YOLO detection runs, one pair
# (32-class / 16-class) per model variant named in the path.
# Only the yolo5m pair is evaluated in this cell.
yolo5n_32cls_inf_path = './4 Private ZOON Anonymised/Bitewing_proc/labels_trial6XX/Stage2_yolo_obb/6XX_32cls_scratch_yolo5n_detect_conf0p6_iou0p45_/labels'
yolo5n_16cls_inf_path = './4 Private ZOON Anonymised/Bitewing_proc/labels_trial6XX/Stage2_yolo_obb/6XX_16cls_scratch_yolo5n_detect_conf0p6_iou0p45_/labels'
    
yolo5m_32cls_inf_path = './4 Private ZOON Anonymised/Bitewing_proc/labels_trial6XX/Stage2_yolo_obb/6XX_32cls_scratch_yolo5m_detect_conf0p6_iou0p45_/labels'
yolo5m_16cls_inf_path = './4 Private ZOON Anonymised/Bitewing_proc/labels_trial6XX/Stage2_yolo_obb/6XX_16cls_scratch_yolo5m_detect_conf0p6_iou0p45_/labels'

yolo5x_32cls_inf_path = './4 Private ZOON Anonymised/Bitewing_proc/labels_trial6XX/Stage2_yolo_obb/6XX_32cls_scratch_yolo5x_detect_conf0p6_iou0p45_/labels'
yolo5x_16cls_inf_path = './4 Private ZOON Anonymised/Bitewing_proc/labels_trial6XX/Stage2_yolo_obb/6XX_16cls_scratch_yolo5x_detect_conf0p6_iou0p45_/labels'


# Folders that receive the matched predictions written by ComputeDetection.
yolo5m_32cls_output_path = './4 Private ZOON Anonymised/Bitewing_proc/labels_trial6XX/Stage4/stage4outputlabels_yolo5m_32cls'
yolo5m_16cls_output_path = './4 Private ZOON Anonymised/Bitewing_proc/labels_trial6XX/Stage4/stage4outputlabels_yolo5m_16cls'



# 32-class run: class ids are looked up in valid_32cls_lbls.
ComputeDetection(yolo5m_32cls_inf_path, 32, yolo5m_32cls_output_path)

# 16-class run: class ids are looked up in valid_16cls_lbls and then translated
# through teeth_map inside ComputeDetection.
ComputeDetection(yolo5m_16cls_inf_path, 16, yolo5m_16cls_output_path)



### Generate predicted images for evaluation
        

In [None]:
# Display the image directory that OutputPredImages reads the '.png' images from.
val_img_path

In [None]:
import cv2

stage4_inference_img_path = './4 Private ZOON Anonymised/Bitewing_proc/labels_trial6XX/Stage4/stage4predictedimages'

# Start every run from an empty output folder: drop an existing one, then (re)create it.
if os.path.exists(stage4_inference_img_path):
    shutil.rmtree(stage4_inference_img_path)
os.makedirs(stage4_inference_img_path)
    

def sort_coordinates(list_of_xy_coords, is_clockwise):
    """Reorder 2-D points by their angle around the points' centroid.

    The angle of each point is ``arctan2(dx, dy)`` of its offset from the
    centroid. Points are returned by descending angle when ``is_clockwise`` is
    true and by ascending angle otherwise.

    Args:
        list_of_xy_coords: Array of shape (N, 2) holding (x, y) rows.
        is_clockwise: Selects descending (True) or ascending (False) angle order.

    Returns:
        The rows of ``list_of_xy_coords`` in the selected order.
    """
    centroid = list_of_xy_coords.mean(axis=0)
    dx, dy = (list_of_xy_coords - centroid).T
    bearings = np.arctan2(dx, dy)

    if is_clockwise:
        order = np.argsort(-1 * bearings)
    else:
        order = np.argsort(bearings)

    return list_of_xy_coords[order]

def OutputPredImages(inf_lbl_path, ind_lbl, ind_coors, img_suffix, lbl_map, annotation_color):
    """Draw the boxes of every label file onto its image and save the annotated copy.

    For each regular file in ``inf_lbl_path`` the image ``<name>.png`` is read
    from the module-level ``val_img_path``. One quadrilateral and its label
    text are drawn per row, and the result is written to the module-level
    ``stage4_inference_img_path`` as ``<name>_<img_suffix>.png``. Blank rows and
    a ``YOLO_OBB`` header row are ignored.

    Args:
        inf_lbl_path: Directory with the label files to visualise.
        ind_lbl: Index of the label field within a whitespace-split row.
        ind_coors: Eight field indices giving x1, y1, ..., x4, y4.
        img_suffix: Text appended to the image name of the saved file.
        lbl_map: ``None`` to print the label field as is, otherwise a lookup
            indexed by the integer value of the label field.
        annotation_color: Colour tuple passed to the OpenCV drawing calls.
    """
    for fname in os.listdir(inf_lbl_path):
        path = os.path.join(inf_lbl_path, fname)
        if not os.path.isfile(path):
            continue

        stem = os.path.splitext(fname)[0]
        img = cv2.imread(os.path.join(val_img_path, stem + '.png'))
        if img is None:
            # cv2.imread signals a missing or unreadable image by returning None.
            print(f'Skipping {fname}: image {stem}.png could not be read')
            continue

        # NOTE(review): the image is converted to RGB here but cv2.imwrite
        # expects BGR, so red and blue of annotation_color are presumably
        # swapped in the saved file — confirm which colour order is intended.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        with open(path, 'r') as f:
            for row in f:
                text = row.strip()
                if not text or text == 'YOLO_OBB':
                    continue

                # split() without an argument drops the trailing newline and
                # repeated blanks, so the label never carries stray whitespace.
                fields = text.split()
                corners = [[float(fields[ind_coors[i]]), float(fields[ind_coors[i + 1]])]
                           for i in range(0, 8, 2)]

                # OpenCV contours must be 32-bit integers on every platform.
                coors = sort_coordinates(np.asarray(corners, dtype=np.int32), True)

                if lbl_map is None:
                    tooth_lbl = fields[ind_lbl]
                else:
                    tooth_lbl = lbl_map[int(fields[ind_lbl])]

                # Text anchor: midpoint of two opposite corners, as plain ints.
                text_coor = tuple(int(v) for v in 0.5 * (coors[0] + coors[2]))

                img = cv2.drawContours(img, [coors], 0, color=annotation_color, thickness=4)
                img = cv2.putText(img=img, text=tooth_lbl, org=text_coor,
                                  fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=2,
                                  color=annotation_color, thickness=2, lineType=cv2.LINE_AA)

        out_name = stem + '_' + img_suffix + '.png'
        cv2.imwrite(os.path.join(stage4_inference_img_path, out_name), img)
    
    
# Ground-truth boxes: rows are "x1 y1 ... x4 y4 label", so the label is field 8.
OutputPredImages(val_label_32cls_path, ind_lbl=8, ind_coors=range(8),
                 img_suffix='gt', lbl_map=None, annotation_color=(255, 0, 100))

# Stage 4 predictions: rows are "label x1 y1 ... x4 y4 conf", so the label is field 0.
OutputPredImages(stage4_inference_label_path, ind_lbl=0, ind_coors=range(1, 9),
                 img_suffix='pred', lbl_map=None, annotation_color=(0, 0, 250))
            


In [None]:
# Output the 32cls detection images (label in field 0, corners in fields 1-8).
OutputPredImages(yolo5m_32cls_output_path, ind_lbl=0, ind_coors=range(1, 9),
                 img_suffix='32cls_direct', lbl_map=None, annotation_color=(50, 250, 10))

In [None]:
# Output the 16cls detection images (label in field 0, corners in fields 1-8).
OutputPredImages(yolo5m_16cls_output_path, ind_lbl=0, ind_coors=range(1, 9),
                 img_suffix='16cls_direct', lbl_map=None, annotation_color=(14, 253, 255))