In [None]:
import random

from skimage import io as skio
from skimage import transform

import tensorflow as tf
import tensorflow.keras as keras
import numpy as np
import matplotlib.pyplot as plt

import pathlib
from dataset_utils import mk_dataset
from model import residual_unet, losses


In [None]:
def build_model(pret=None, freeze_enc=False, freeze_dec=False,
                input_shape=(256, 256, 3), learning_rate=0.001):
    """Build and compile the parallel-dilated residual U-Net.

    Args:
        pret: optional weights path passed to ``model.load_weights`` before any
            layers are frozen.
        freeze_enc: if True, every layer whose name contains "down" is made
            non-trainable.
        freeze_dec: if True, every layer whose name contains "up" is made
            non-trainable.
        input_shape: shape of a single input sample handed to
            ``residual_unet.unet`` (default: 256x256x3, as before).
        learning_rate: Adam learning rate (default 0.001, as before).

    Returns:
        The compiled model, using Adam, the project's ``losses.DICELoss`` and
        mean-IoU / precision / recall metrics.
    """
    model = residual_unet.unet(
        input_shape,
        name="unet",
        parallel_dilated=True,
    )
    if pret is not None:
        model.load_weights(pret)

    # Freeze before compile() so the trainable flags are honoured by fit().
    # NOTE(review): plain substring match -- any layer name that merely contains
    # "down"/"up" is caught too; confirm against residual_unet's layer naming.
    for layer in model.layers:
        if (freeze_enc and "down" in layer.name) or (freeze_dec and "up" in layer.name):
            layer.trainable = False

    optimizer = keras.optimizers.Adam(learning_rate=learning_rate, name="adam")
    loss = losses.DICELoss(name="dice")
    metrics = [
        # NOTE(review): MeanIoU appears to treat y_pred as integer class ids, so
        # un-thresholded probabilities may all count as class 0 -- verify, and
        # consider thresholding or a binary IoU metric if available.
        keras.metrics.MeanIoU(num_classes=2, name="mean_iou"),
        # Previously misspelled "presision"; logged keys are now "precision"/"val_precision".
        keras.metrics.Precision(name="precision"),
        keras.metrics.Recall(name="recall"),
    ]

    model.compile(loss=loss, optimizer=optimizer, metrics=metrics)
    return model


# Restore the fine-tuned checkpoint from the Google Drive volume and re-export
# only its weights under a local name, which the next cell loads back.
pret = "/Volumes/GoogleDrive/マイドライブ/卒研/220111/checkpoints/DA_FRZ-DICE-E50-MR1.0/DA_FRZ-DICE-E50-MR1.0"
model = build_model(pret=pret)
model.save_weights("DA_FRZ-DICE-E50-MR1.0")


In [None]:
# Rebuild the bare (uncompiled) U-Net, load the locally exported weights into it,
# then save the complete model so it can be restored with keras.models.load_model.
model = None  # release the previous model before constructing a new one
input_shape = (256, 256, 3)
model = residual_unet.unet(input_shape, name="unet", parallel_dilated=True)
model.load_weights("DA_FRZ-DICE-E50-MR1.0")
model.save("DA_FRZ-DICE-E50-MR1.0")


In [33]:
model = None  # release the previous instance before reloading
# Restore the complete model written by the cell above.
model = keras.models.load_model(filepath="DA_FRZ-DICE-E50-MR1.0")





In [None]:
model = keras.models.load_model("basemodel_DA_FZHB-DICE-MR-E50")

# Same freezing rule as build_model: "down" layers follow freeze_enc,
# "up" layers follow freeze_dec.
freeze_enc = True
freeze_dec = False
for layer in model.layers:
    hits_frozen_part = ("down" in layer.name and freeze_enc) or ("up" in layer.name and freeze_dec)
    if hits_frozen_part:
        layer.trainable = False
# NOTE(review): if this model is trained afterwards, recompile it so the new
# trainable flags take effect.
[(layer.name, layer.trainable) for layer in model.layers]


In [None]:
def gends(ds_root, suffix, batch_size=32, shuffle_buffer=1000, seed=None):
    """Build a shuffled evaluation dataset of paired ``sat/`` and ``map/`` tiles.

    Every ``map/*.{suffix}`` file under ``ds_root`` is paired with the ``sat/``
    file of the same name, and the pairs are passed to ``mk_dataset.mkds`` in
    test mode.

    Args:
        ds_root: dataset root (str or Path) containing ``sat/`` and ``map/``.
        suffix: file extension to match, without the dot (e.g. ``"png"``).
        batch_size: forwarded to ``mk_dataset.mkds`` (default 32, as before).
        shuffle_buffer: buffer size for ``Dataset.shuffle`` (default 1000).
        seed: optional shuffle seed for a reproducible order; ``None`` keeps
            the previous nondeterministic shuffling.

    Returns:
        The dataset from ``mkds`` wrapped in ``.shuffle(shuffle_buffer, seed=seed)``.

    Raises:
        FileNotFoundError: if no ``map/*.{suffix}`` file exists. Otherwise the
            empty dataset would only fail later as an opaque ``StopIteration``.
    """
    ds_root = pathlib.Path(ds_root)
    # The map/ directory defines the sample list; sat/ is assumed to hold same-named files.
    names = sorted(p.name for p in ds_root.glob(f"map/*.{suffix}"))
    if not names:
        raise FileNotFoundError(f"no '*.{suffix}' files found under {ds_root / 'map'}")
    # One sorted name list drives both lists, so sat[i] and map[i] stay paired.
    sat_pathlist = [str(ds_root / "sat" / name) for name in names]
    map_pathlist = [str(ds_root / "map" / name) for name in names]
    test_ds = mk_dataset.mkds(sat_pathlist, map_pathlist, batch_size=batch_size, test=True)
    # NOTE(review): if mkds already batches, this shuffles whole batches.
    return test_ds.shuffle(shuffle_buffer, seed=seed)


def show_results(images, titles, figsize=(30, 30)):
    """Plot images side by side in one row, each with its own title.

    Args:
        images: sequence of array-likes. A trailing axis of size 1 is dropped
            so single-channel masks render as 2-D grayscale.
        titles: panel titles, paired with ``images`` by position.
        figsize: matplotlib figure size in inches.
    """
    n_cols = len(images)
    fig = plt.figure(figsize=figsize)
    for col, (img, caption) in enumerate(zip(images, titles), start=1):
        squeeze_channel = len(img.shape) > 0 and img.shape[-1] == 1
        if squeeze_channel:
            img = img[..., 0]
        ax = fig.add_subplot(1, n_cols, col)
        ax.set_title(caption)
        ax.imshow(img, cmap="gray")
    plt.show()


In [None]:
path = pathlib.Path("../../../Datasets/mass_roads9/valid")
# A bare `path.exists()` that is not the cell's last line is never displayed,
# so the old check did nothing; fail loudly instead.
if not path.is_dir():
    raise FileNotFoundError(f"validation dataset directory not found: {path.resolve()}")
ds = gends(path, "png")


In [None]:
# Visual spot-check: predict one batch and show a random sample as
# (prediction, input, ground truth).
im, tar = next(iter(ds))
pred = model.predict(im)
# The batch may hold fewer than 32 samples (e.g. a final partial batch), so
# draw the index from the batch actually returned instead of hard-coding 32.
idx = random.randrange(len(pred))
print(idx)
show_results((pred[idx], im[idx], tar[idx]), ("_", "_", "_"))


In [None]:
class DICELoss(keras.losses.Loss):
    """
    Soft Dice loss, equivalent to the Tversky loss with `alpha`=0.5 but
    implemented standalone rather than by inheriting from it.

    ``y_true`` and ``y_pred`` are flattened together, so the result is a single
    batch-wide Dice coefficient rather than a mean of per-sample scores.
    """

    def __init__(self, name=None, smooth=1e-10):
        """
        Args:
            name: optional Keras loss name.
            smooth: small constant added to numerator and denominator to guard
                against division by zero (default 1e-10, as before).
        """
        super().__init__(name=name)
        self.smooth = smooth

    def call(self, y_true, y_pred):
        y_pred = tf.convert_to_tensor(y_pred)
        # Cast labels to the prediction dtype so callers need not pre-convert them.
        y_true = tf.cast(y_true, y_pred.dtype)
        y_true_pos = tf.reshape(y_true, [-1])
        y_pred_pos = tf.reshape(y_pred, [-1])
        tp_mul = tf.math.reduce_sum(y_true_pos * y_pred_pos)
        tp_sum = tf.math.reduce_sum(y_true_pos + y_pred_pos)
        # Smooth 2*intersection, not the intersection alone. The old form
        # 2*(tp + s)/(sum + s) gave dc=2 (loss -1) for an all-empty pair and
        # could exceed 1 for near-empty ones; this form gives dc=1 there.
        dc = (2.0 * tp_mul + self.smooth) / (tp_sum + self.smooth)
        return 1.0 - dc


dice = DICELoss()

In [None]:
# Dice score and visual check on growing top-left crops (10%..100%) of one validation tile.
# NOTE(review): absolute per-user paths; consider deriving them from a configurable dataset root.
sample_root = pathlib.Path("/Users/hagayuya/Datasets/mass_roads9/valid")
sample_name = "10978735_15_3_1_2.png"
org_im = skio.imread(str(sample_root / "sat" / sample_name))
org_ta = skio.imread(str(sample_root / "map" / sample_name), as_gray=True)[..., None]

h = org_im.shape[0]
for tenths in range(1, 11):
    # Exact integer arithmetic instead of int(h * 0.1 * k), which depends on float rounding.
    crop_size = h * tenths // 10

    im = transform.resize(org_im[:crop_size, :crop_size], (256, 256))
    # NOTE(review): resize interpolates the mask as well, so `tar` becomes a soft
    # label; pass order=0 if a strictly binary target is intended.
    tar = transform.resize(org_ta[:crop_size, :crop_size], (256, 256))
    pred = model.predict(im[None])[0]
    pred = tf.convert_to_tensor(pred, dtype="float64")
    tar = tf.convert_to_tensor(tar, dtype="float64")
    print(dice(pred[None], tar[None]).numpy())
    # The input + prediction overlay can exceed 1; clip explicitly so imshow
    # does not warn and clip on its own.
    overlay = np.clip(im + pred, 0.0, 1.0)
    show_results((pred, overlay, tar), ("_", "_", "_"), (10, 10))
