## 牙齿分割模型训练与测试

### 模型训练

#### 数据预处理

使用swfaug.ipynb中的相关程序

#### 生成训练数据

```bash
./run.sh build_data
```

#### 训练模型

```bash
./run.sh train
```

#### 生成模型文件

```bash
./run.sh export
```

#### 生成测试结果

```bash
./run.sh eval -o teeth_swf/result
```

### 生成可视化的结果图

In [2]:
# Package import
import os
from pathlib import Path
from shutil import rmtree
from collections import Counter

from matplotlib import pyplot as plt
import matplotlib as mpl
from tqdm.auto import tqdm
import numpy as np
from PIL import Image, ImageDraw, ImageFont

In [16]:
# File paths define
pwd = Path(os.path.dirname(os.getcwd()))
base_path = pwd / 'Raw-GroundTruth-Testing'
ground_truth = base_path / 'groundtruth'
source_img = base_path / 'image'

teeth_path = base_path / 'teeth'
numpy_result = teeth_path / 'result'
plt_result = teeth_path / 'plt_result'
vis_result = teeth_path / 'result_vis'
vis_result_tumb = teeth_path / 'result_vis_tumb'
matlab_result = teeth_path / 'result_matlab'

In [17]:
vis_result.mkdir(exist_ok=True)
matlab_result.mkdir(exist_ok=True)
vis_result_tumb.mkdir(exist_ok=True)

for result_np_path in tqdm(list(numpy_result.glob('*.npy'))):
    # Get file paths.
    source_img_path =  next(source_img.glob(result_np_path.stem + '.*'))
    gt_img_path = ground_truth / (result_np_path.stem + '.png')
    res_vis_path = vis_result / (result_np_path.stem + '.jpg')
    res_vis_tumb_path = vis_result_tumb / (result_np_path.stem + '.jpg')
    res_matlab_path = matlab_result / (result_np_path.stem + '.png')
    
    # Open images and numpy arrays.
    src_img = Image.open(str(source_img_path))
    gt_img = Image.open(str(gt_img_path))
    res_np = np.load(str(result_np_path))
    
    # Draw result images.
    dst_img = src_img.copy()
    draw_img = ImageDraw.Draw(dst_img, mode='RGBA')
    x, y = np.where(res_np == 2)
    dental_point = np.vstack((y, x)).T.flatten()
    draw_img.point(list(dental_point), fill=(255, 255, 0, 64))
    # Save result images.
    dst_img.save(res_vis_path)
    dst_img.resize((100, 100)).save(res_vis_tumb_path)
    
    # Save result image for matlab
    res_matlab_np = res_np.astype('uint8')
    res_matlab_np[res_matlab_np == 1] = 125
    res_matlab_np[res_matlab_np == 2] = 255
    res_matlab_img = Image.fromarray(res_matlab_np)
    res_matlab_img.save(str(res_matlab_path))

HBox(children=(FloatProgress(value=0.0, max=139.0), HTML(value='')))




### 定义IOU计算函数

In [36]:
import os
import csv
from pathlib import Path
from collections import Counter

import numpy as np
from tqdm.auto import tqdm
from PIL import Image
from docx import Document

In [37]:
def compute_iou(pred_path, gt_path):
    gt_img = Image.open(gt_path).resize((513, 513))
    gt_np = np.asarray(gt_img, dtype='uint8').copy()
    gt_np[gt_np == 125] = 1
    gt_np[gt_np == 255] = 2
    
    res_np = np.load(str(pred_path)).astype('uint8')
    res_np = res_np * (gt_np > 0)
    intersection = res_np * (res_np==gt_np)

    area_pred = Counter(res_np.flatten())
    area_lab = Counter(gt_np.flatten())
    area_intersection = Counter(intersection.flatten())
    area_union = area_pred + area_lab - area_intersection
    sum_iou = lambda x: x[1]+x[2]
    iou = sum_iou(area_intersection) / sum_iou(area_union)
    
    return iou

#### 写入结果到文档

In [43]:
# File paths define
pwd = Path(os.path.dirname(os.getcwd()))
base_path = pwd / 'Raw-GroundTruth-Testing'
ground_truth = base_path / 'groundtruth'
ground_truth_tumb = base_path / 'groundtruth_tumb'
source_img = base_path / 'image_tumb'

teeth_path = base_path / 'teeth_pzn'
teeth_vis_result = teeth_path / 'result_vis_tumb'

swfaug_path = base_path / 'teeth_swf'
swfaug_vis_result = swfaug_path / 'result_vis_tumb'
swfaug_numpy_result = swfaug_path / 'result'

In [44]:
def add_row(row, pic1, pic2, pic3, pic4, iou):
    row.cells[0].paragraphs[0].add_run().add_picture(str(pic1))
    row.cells[1].paragraphs[0].add_run().add_picture(str(pic2))
    row.cells[2].paragraphs[0].add_run().add_picture(str(pic3))
    row.cells[3].paragraphs[0].add_run().add_picture(str(pic4))
    row.cells[4].text = iou

document = Document()
document.add_heading('牙菌斑测试结果', 0)
table = document.add_table(rows=140, cols=5)

iou_sum = 0
count = 0

for i, source_img_path in enumerate(tqdm(sorted(list(source_img.glob('*'))))):
    key = source_img_path.stem
    gt_img_path = ground_truth / (key + '.png')
    gt_img_path_tumb = ground_truth_tumb / (key + '.png')
    teeth_res_vis_path = teeth_vis_result / (key + '.jpg')
    teeth_res_vis_path = teeth_vis_result / (key + '.jpg')
    swfaug_res_vis_path = swfaug_vis_result / (key + '.jpg')
    swfaug_numpy_result_path = swfaug_numpy_result / (key + '.npy')
    iou = compute_iou(swfaug_numpy_result_path, gt_img_path)
    
    count += 1
    iou_sum += iou
    iou = f'{iou:.5f}'
    
    row = table.rows[i+1]
    add_row(row, source_img_path, gt_img_path_tumb, teeth_res_vis_path, swfaug_res_vis_path, iou)
    
document.save(swfaug_path.stem + ".docx")
iou_sum / count

HBox(children=(FloatProgress(value=0.0, max=139.0), HTML(value='')))




0.8356691834862533

In [45]:
# Teeth: 0.8491325351978396
# SWFAug: 0.8356691834862533