## 牙齿分割模型训练与测试

### 模型训练

#### 数据预处理

使用 swfaug.ipynb 中的相关程序完成数据预处理。

#### 生成训练数据

```bash
./run.sh build_data
```

#### 训练模型

```bash
./run.sh train
```

#### 生成模型文件

```bash
./run.sh export
```

#### 生成测试结果

```bash
./run.sh eval -o teeth_swf/result
```

### 生成可视化的结果图

In [1]:
# Package import
import os
from pathlib import Path
from shutil import rmtree
from collections import Counter

from matplotlib import pyplot as plt
import matplotlib as mpl
from tqdm.auto import tqdm
import numpy as np
from PIL import Image, ImageDraw, ImageFont

In [52]:
# Locations of all inputs and outputs, resolved from the parent directory of
# the notebook's current working directory.
pwd = Path.cwd().parent
base_path = pwd / 'eval/full-VOC_adult_20200512'

# Inputs: ground-truth masks, their visualisations and the source photos.
ground_truth = base_path / 'SegmentationClass'
# NOTE(review): identical to `ground_truth`; confirm a different ("new")
# annotation directory was not intended here.
new_ground_truth = base_path / 'SegmentationClass'
source_img = base_path / 'JPEGImages'
gt_vis = base_path / 'newSegmentationClass_vis'

# Thumbnail copies written by the visualisation cell.
source_img_tumb = base_path / 'JPEGImages_tumb'
gt_vis_tumb = base_path / 'newSegmentationClass_vis_tumb'

# Model outputs (.npy arrays from `./run.sh eval`) and their renderings.
teeth_path = base_path
numpy_result = teeth_path / 'result'
vis_result = teeth_path / 'result_vis'
vis_result_tumb = teeth_path / 'result_vis_tumb'
# Unused alternatives kept for reference:
# plt_result = teeth_path / 'plt_result'
# matlab_result = teeth_path / 'result_matlab'
# matlab_result = teeth_path / 'result_matlab'

In [53]:
# Overlay class-2 predictions on each source image, then write the full-size
# overlay plus 100x100 thumbnails of the overlay, the ground-truth
# visualisation and the source image (used later by the Word report).
tumb_size = (100, 100)  # thumbnail size, PIL (width, height)
overlay_color = (255, 255, 0, 64)  # RGBA: yellow blended with alpha 64/255
for out_dir in (vis_result, vis_result_tumb, gt_vis_tumb, source_img_tumb):
    out_dir.mkdir(exist_ok=True)

for result_np_path in tqdm(sorted(numpy_result.glob('*.npy'))):
    key = result_np_path.stem

    # The source photo may have any extension; take the first match.
    source_img_path = next(source_img.glob(key + '.*'), None)
    if source_img_path is None:
        raise FileNotFoundError(
            f'No source image for {result_np_path.name} in {source_img}')
    source_img_tumb_path = source_img_tumb / source_img_path.name
    gt_vis_path = gt_vis / (key + '.png')
    gt_vis_tumb_path = gt_vis_tumb / (key + '.png')
    res_vis_path = vis_result / (key + '.jpg')
    res_vis_tumb_path = vis_result_tumb / (key + '.jpg')

    res_np = np.load(str(result_np_path))

    # Resize the source onto the prediction grid (PIL sizes are (width,
    # height), numpy shapes are (rows, cols)) so pixel coordinates line up.
    # RGB conversion lets the RGBA blend below work and the .jpg save succeed
    # for grayscale/RGBA sources too.
    height, width = res_np.shape[-2:]
    with Image.open(source_img_path) as img:
        src_img = img.convert('RGB').resize((width, height))
    with Image.open(gt_vis_path) as gt_vis_img:
        gt_vis_img.resize(tumb_size).save(gt_vis_tumb_path)

    # Paint every pixel predicted as class 2. np.nonzero yields (rows, cols),
    # while ImageDraw expects (x, y) = (col, row).
    rows, cols = np.nonzero(res_np == 2)
    dst_img = src_img.copy()
    ImageDraw.Draw(dst_img, mode='RGBA').point(
        list(zip(cols.tolist(), rows.tolist())), fill=overlay_color)

    dst_img.save(res_vis_path)
    dst_img.resize(tumb_size).save(res_vis_tumb_path)
    src_img.resize(tumb_size).save(source_img_tumb_path)

    # Optional export for MATLAB (labels 1/2 written as grey 125/255):
    # res_matlab_np = res_np.astype('uint8')
    # res_matlab_np[res_matlab_np == 1] = 125
    # res_matlab_np[res_matlab_np == 2] = 255
    # Image.fromarray(res_matlab_np).save(str(matlab_result / (key + '.png')))

HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))




### 定义 IoU 计算函数

In [3]:
import os
import csv
from pathlib import Path
from collections import Counter

import numpy as np
from tqdm.auto import tqdm
from PIL import Image
from docx import Document

In [15]:
def compute_metric(pred_path, gt_path):
    """Return the foreground IoU and pixel accuracy of one predicted mask.

    The ground-truth image is resized onto the prediction's grid. Its grey
    values 125 and 255 are remapped to labels 1 and 2; other values are kept
    as-is. Predicted labels are zeroed wherever the ground truth is 0, so only
    annotated pixels can count as foreground.

    Args:
        pred_path: Path to the ``.npy`` prediction. This is a label array,
            presumably with values 0/1/2 (TODO confirm against the exporter).
        gt_path: Path to the ground-truth mask image.

    Returns:
        ``(iou, acc)``. ``iou`` is the intersection of labels 1 and 2, summed
        over both classes, divided by their summed union. ``acc`` is the
        fraction of all pixels where the masked prediction equals the ground
        truth.

    Raises:
        ZeroDivisionError: if neither the masked prediction nor the ground
            truth contains a pixel labelled 1 or 2.
    """
    res_np = np.load(pred_path).astype('uint8')
    height, width = res_np.shape[-2:]
    # PIL sizes are (width, height) while numpy shapes are (rows, cols).
    # NEAREST keeps label values exact; an interpolating filter (the default
    # in recent Pillow) would blend 125/255 into values matching no class.
    with Image.open(gt_path) as gt_img:
        gt_np = np.array(
            gt_img.resize((width, height), resample=Image.NEAREST),
            dtype='uint8',
        )
    gt_np[gt_np == 125] = 1
    gt_np[gt_np == 255] = 2

    # Ignore predictions outside the annotated region.
    res_np = res_np * (gt_np > 0)

    # Sum per-class intersection and union over the two foreground labels.
    intersection = 0
    union = 0
    for label in (1, 2):
        pred_mask = res_np == label
        gt_mask = gt_np == label
        intersection += int(np.count_nonzero(pred_mask & gt_mask))
        union += int(np.count_nonzero(pred_mask | gt_mask))
    iou = intersection / union

    acc = int(np.count_nonzero(res_np == gt_np)) / res_np.size
    return iou, acc

#### 写入结果到文档

In [54]:
def add_row(row, pics, iou, acc):
    """Fill one report-table row: pictures first, then the IoU and accuracy.

    Args:
        row: Table row whose ``cells`` are filled left to right (a
            python-docx row in this notebook).
        pics: Image paths, inserted one per cell starting at cell 0.
        iou: Value written (via ``str``) to the cell after the last picture.
        acc: Value written (via ``str``) to the cell after ``iou``.
    """
    pics = list(pics)
    cells = row.cells  # fetch once; python-docx rebuilds this on each access
    for cell, pic in zip(cells, pics):
        cell.paragraphs[0].add_run().add_picture(str(pic))
    # Index from len(pics) rather than the loop variable, which is never bound
    # when `pics` is empty.
    n_pics = len(pics)
    cells[n_pics].text = str(iou)
    cells[n_pics + 1].text = str(acc)

# Build the Word report. Each data row holds the source thumbnail, the
# ground-truth and prediction thumbnails, and that image's IoU and accuracy.
# Row 0 is a header. The directory is listed once, and only regular files are
# kept (glob('*') also returns sub-directories), so the table size matches the
# loop.
source_img_paths = sorted(p for p in source_img.glob('*') if p.is_file())

document = Document()
document.add_heading('牙菌斑测试结果', 0)
header = ('原图', '标注', '预测', 'IoU', 'Acc')
doc_rows = len(source_img_paths) + 1  # +1 for the header row
table = document.add_table(rows=doc_rows, cols=len(header))
for cell, title in zip(table.rows[0].cells, header):
    cell.text = title

iou_sum = 0
acc_sum = 0
count = 0

for i, source_img_path in enumerate(tqdm(source_img_paths), start=1):
    key = source_img_path.stem
    source_img_tumb_path = source_img_tumb / source_img_path.name
    gt_img_path = new_ground_truth / (key + '.png')
    gt_vis_tumb_path = gt_vis_tumb / (key + '.png')
    teeth_res_vis_path = vis_result_tumb / (key + '.jpg')
    teeth_numpy_path = numpy_result / (key + '.npy')
    iou, acc = compute_metric(teeth_numpy_path, gt_img_path)

    count += 1
    acc_sum += acc
    iou_sum += iou

    add_row(table.rows[i],
            [source_img_tumb_path, gt_vis_tumb_path, teeth_res_vis_path],
            f'{iou:.5f}', f'{acc:.5f}')

document.save("result.docx")
# Mean IoU over all evaluated images (NaN when there were none).
iou_sum / count if count else float('nan')

HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))




0.2901441304514237

In [55]:
mean_acc = acc_sum / count  # mean pixel accuracy over the evaluated images
mean_acc

0.6147380580539502

In [34]:
# Teeth_ORG: 0.711884330699641
# Teeth: 0.8491325351978396
# SWFAug: 0.8356691834862533
# SWFAug2: 0.851197442264277

In [None]:
# 2166 121090