Этот ноутбук настроен на обучение нейронной сети U-net в Google Colab.

Предварительно надо поместить `segments.zip` на Google Диск по пути `binarization/segments.zip` (после монтирования диска это `/content/drive/MyDrive/binarization/segments.zip`).

In [None]:
# Install keras_unet; later cells import its models, metrics and losses modules.
!pip install keras_unet

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt

In [None]:
from google.colab import drive
# Mount Google Drive into the Colab filesystem at /content/drive.
drive.mount('/content/drive')
# Copy the dataset archive from Drive into the working directory and unpack it there.
# NOTE(review): the following cell expects the archive to yield segments/original/*.bmp — confirm.
!cp /content/drive/MyDrive/binarization/segments.zip segments.zip
!unzip segments.zip

In [None]:
import glob
import os


def gt_path_for(original_path):
    """Return the ground-truth path paired with ``original_path``.

    Only the directory that directly contains the file is swapped for ``gt``;
    the file name is kept untouched, so a name that itself contains the word
    "original" is not corrupted.

    :param original_path: path of the form ``<root>/<dir>/<name>``
    :return: ``<root>/gt/<name>``
    """
    segments_root = os.path.dirname(os.path.dirname(original_path))
    return os.path.join(segments_root, "gt", os.path.basename(original_path))


# glob returns matches in arbitrary, filesystem-dependent order; sorting makes
# the seeded train/validation split reproducible between runs and machines.
original_paths = sorted(glob.glob("segments/original/*.bmp"))
gt_paths = [gt_path_for(path) for path in original_paths]

print(len(original_paths), len(gt_paths))

In [None]:
# Sanity check: print the first few (original, ground-truth) path pairs.
first_count = 5
head_pairs = zip(original_paths[:first_count], gt_paths[:first_count])
for src_path, mask_path in head_pairs:
    print(src_path, mask_path)

In [None]:
from sklearn.model_selection import train_test_split

# Hold out 10% of the pairs for validation; the fixed seed keeps the split stable.
x_train, x_val, y_train, y_val = train_test_split(
    original_paths,
    gt_paths,
    test_size=0.1,
    random_state=0,
)

# Report the size of every part of the split.
for part_label, part in (
    ("x_train: ", x_train),
    ("y_train: ", y_train),
    ("x_val: ", x_val),
    ("y_val: ", y_val),
):
    print(part_label, len(part))

In [None]:
# imcollect.py

from tensorflow.keras.utils import Sequence
from tensorflow.keras.preprocessing.image import load_img


class PairsGenerator(Sequence):
    """Batch iterator over (original, ground-truth) image pairs for Keras training.

    Images are read from disk one batch at a time, so the whole dataset never
    has to be loaded into memory.
    """

    def __init__(self, batch_size, img_size, original_img_paths, gt_img_paths):
        """
        :param batch_size: number of image pairs in one batch
        :param img_size: size every image is resized to on load; either a
          ``(height, width)`` pair or a single int for a square image
        :param original_img_paths: paths of the input images
        :param gt_img_paths: paths of the ground-truth images, in the same order
        """
        super().__init__()
        self.batch_size = batch_size
        # The batch shape is built by tuple concatenation, so normalise to a tuple.
        try:
            self.img_size = tuple(img_size)
        except TypeError:
            self.img_size = (img_size, img_size)
        self.original_paths = original_img_paths
        self.gt_paths = gt_img_paths

    def __len__(self):
        """Number of full batches; a trailing partial batch is dropped."""
        return len(self.original_paths) // self.batch_size

    def __getitem__(self, idx):
        """Return batch number ``idx`` as a pair of arrays ``(original, gt)``.

        ``original`` is float32 with shape ``(batch, height, width, 3)`` and
        values divided by 255; ``gt`` is uint8 with shape
        ``(batch, height, width, 1)``.
        """
        i = idx * self.batch_size
        batch_original_img_paths = self.original_paths[i: i + self.batch_size]
        batch_gt_img_paths = self.gt_paths[i: i + self.batch_size]
        original = np.zeros((self.batch_size,) + self.img_size + (3,), dtype="float32")
        for j, path in enumerate(batch_original_img_paths):
            img = load_img(path, target_size=self.img_size)
            original[j] = np.array(img) / 255
        gt = np.zeros((self.batch_size,) + self.img_size + (1,), dtype="uint8")
        for j, path in enumerate(batch_gt_img_paths):
            img = load_img(path, target_size=self.img_size, color_mode="grayscale")
            # Storing the 0..1 quotient in a uint8 array truncates it, so only
            # pixels equal to 255 become 1 and everything else becomes 0.
            gt[j] = np.expand_dims(img, 2) / 255
        return original, gt


In [None]:
# Segment size handed to the generators (the model cell builds a 256x256 input)
# and the number of image pairs per batch.
segment_size = 256
batch_size = 32
# Batch generators over the training and validation path lists.
pairgen = PairsGenerator(batch_size, segment_size, x_train, y_train)
val_pairgen = PairsGenerator(batch_size, segment_size, x_val, y_val)

In [None]:
# Pull the first training batch and print, one line each: array shapes, dtypes,
# the maxima of the first sample, and the number of train / validation batches.
x, y = pairgen[0]
batch_report = [
    (x.shape, y.shape),
    (x.dtype, y.dtype),
    (x[0].max(), y[0].max()),
    (len(pairgen), len(val_pairgen)),
]
for report_row in batch_report:
    print(*report_row)

In [None]:
from keras_unet.models import vanilla_unet, custom_unet

# Build the U-Net for square 3-channel inputs. The side length comes from
# segment_size (defined with the generators) instead of a second hard-coded 256,
# so the model input cannot drift away from the segment size used for training.
model = custom_unet(input_shape=(segment_size, segment_size, 3))

In [None]:
from keras.callbacks import ModelCheckpoint


# Checkpoint file that receives the model whenever val_loss improves.
model_filename = 'segm_model_v3.h5'
checkpoint_options = dict(
    monitor='val_loss',
    save_best_only=True,
    verbose=1,
)
callback_checkpoint = ModelCheckpoint(model_filename, **checkpoint_options)

In [None]:
from tensorflow.keras.optimizers import Adam, SGD
from keras_unet.metrics import iou, iou_thresholded
from keras_unet.losses import jaccard_distance

# Train with Adam on binary cross-entropy and track both IoU metrics.
# Alternatives previously kept here as commented-out code:
#   optimizer=SGD(lr=0.01, momentum=0.99), loss=jaccard_distance
tracked_metrics = [iou, iou_thresholded]
model.compile(
    loss='binary_crossentropy',
    optimizer=Adam(),
    metrics=tracked_metrics,
)

In [None]:
# Train for 10 epochs, one pass over each generator per epoch; the checkpoint
# callback runs during training.
train_steps = len(pairgen)
val_steps = len(val_pairgen)
history = model.fit(
    pairgen,
    epochs=10,
    steps_per_epoch=train_steps,
    validation_data=val_pairgen,
    validation_steps=val_steps,
    callbacks=[callback_checkpoint],
)

In [None]:
# Print the Keras summary of the model.
model.summary()

In [None]:
# List the names recorded in the training history; the plotting cell indexes by them.
print(history.history.keys())

In [None]:
# One figure per tracked quantity: the training curve first, then the
# validation curve stored under the same key with a 'val_' prefix.
for metric_key, axis_label in (
    ('iou', 'iou'),
    ('iou_thresholded', 'iou thresholded'),
    ('loss', 'loss'),
):
    plt.plot(history.history[metric_key])
    plt.plot(history.history['val_' + metric_key])
    plt.title('model ' + axis_label)
    plt.ylabel(axis_label)
    plt.xlabel('epoch')
    plt.legend(['train', 'test'], loc='upper left')
    plt.show()

In [None]:
import time

# Save the trained model to Drive under a name suffixed with the current Unix time.
save_path = f'/content/drive/MyDrive/binarization/models/model.{int(time.time())}'
model.save(save_path)

Далее модель применяется к одному из изображений

In [None]:
# Загрузка модели. В Google Colab можно не запускать
from keras_unet.metrics import iou, iou_thresholded

# Plain string: the path has no placeholders, so the f-prefix was dropped.
model_path = "models/model.1655852129"
# The model was compiled with the keras_unet IoU metrics, so they are passed
# in by name for deserialisation.
model = tf.keras.models.load_model(
    model_path,
    custom_objects={"iou_thresholded": iou_thresholded, "iou": iou},
)

In [None]:
import cv2 as cv
path = '/content/drive/MyDrive/binarization/images/original/2image.png'

test_image = cv.imread(path)
if test_image is None:
    # cv.imread reports a missing or undecodable file by returning None instead
    # of raising; fail here with a clear message rather than on the slice below.
    raise FileNotFoundError(f'could not read image: {path}')
# OpenCV returns channels in BGR order: reverse the last axis to get RGB,
# then divide by 255 as is done for the training images.
test_image = test_image[..., ::-1] / 255

plt.imshow(test_image)
print(test_image.shape)

In [None]:
# imsplit.py

def _range_borders(start, finish, distance, step=None, full_cover=True):
    """Return the (begin, end) borders of windows of length ``distance`` in [start, finish).

    :param start: first coordinate of the range
    :param finish: end of the range (exclusive)
    :param distance: length of every window
    :param step: offset between consecutive window starts; defaults to
      ``distance`` (windows that touch but do not overlap)
    :param full_cover: when a regular window would run past ``finish``, add one
      last window aligned to ``finish`` so the tail of the range is covered
    :return: list of ``(begin, end)`` pairs; empty when the range is shorter
      than ``distance``
    """
    if finish - start < distance:
        return []
    if step is None:
        step = distance
    pairs = []
    for start_border in range(start, finish, step):
        finish_border = start_border + distance
        if finish_border > finish:
            last_window = (finish - distance, finish)
            # Skip the closing window when the previous regular window already
            # ends exactly at ``finish``; otherwise it would be listed twice.
            if full_cover and (not pairs or pairs[-1] != last_window):
                pairs.append(last_window)
            break
        # Record this window's own borders, not the end-aligned ones.
        pairs.append((start_border, finish_border))
    return pairs


def replace_segments(crops, new_segments):
    """Pair every new segment with the borders of the crop at the same position.

    :param crops: sequence of ``(segment, borders)`` pairs
    :param new_segments: replacement segments, in the same order as ``crops``
    :return: list of ``(new_segment, borders)``; as long as the shorter input
    """
    replaced = []
    for new_segment, crop in zip(new_segments, crops):
        _, borders = crop
        replaced.append((new_segment, borders))
    return replaced


def imsplit(image, size, step=None, full_cover=True):
    """Split an image into square segments.

    :param image: the image (anything ``np.asarray`` accepts)
    :param size: side of a segment in pixels
    :param step: offset between neighbouring segments; set it when segments
      should overlap. If None, segments do not overlap (except possibly the
      last ones at the edges, which is controlled by ``full_cover``)
    :param full_cover: whether to add edge segments even though they overlap
      their neighbour. E.g. for a 100 x 100 image and 30 x 30 segments, either
      the edge segments overlap, or the rightmost and bottom strips are
      skipped (``full_cover=False``)
    :return: list of ``(segment, borders)`` pairs; borders use the PIL
      convention ``(left, top, right, bottom)``
    """
    pixels = np.asarray(image)
    stride = size if step is None else step
    height, width = pixels.shape[0], pixels.shape[1]
    # Row-major order: all columns of the first row band, then the next band.
    return [
        (pixels[top:bottom, left:right], (left, top, right, bottom))
        for top, bottom in _range_borders(0, height, size, stride, full_cover=full_cover)
        for left, right in _range_borders(0, width, size, stride, full_cover=full_cover)
    ]


def get_shape(crops):
    """Compute the shape of the image that the given crops cover.

    Height and width are the largest ``bottom`` and ``right`` borders found;
    a trailing channel axis of 1 or 3 is carried over from the segments.

    :param crops: non-empty sequence of ``(segment, (left, top, right, bottom))``
    :return: ``(height, width)`` for 2-D segments, otherwise
      ``(height, width, channels)``
    :raises AssertionError: if ``crops`` is empty or segment shapes differ
    :raises Exception: if the channel axis is neither 1 nor 3
    """
    assert len(crops) > 0, 'пустой массив сегментов'
    segment_shape = None
    max_right = max_bottom = 0
    for segment, (_, _, right, bottom) in crops:
        if segment_shape is None:
            # The first segment defines the shape all others must match.
            segment_shape = segment.shape
        assert segment.shape == segment_shape, 'все сегменты должны иметь одинаковый размер'
        max_right = right if max_right < right else max_right
        max_bottom = bottom if max_bottom < bottom else max_bottom
    if len(segment_shape) == 2:
        return max_bottom, max_right
    if segment_shape[2] == 1:
        return max_bottom, max_right, 1
    if segment_shape[2] == 3:
        return max_bottom, max_right, 3
    raise Exception('неправильный атрибут shape у сегментов', segment_shape)


def imjoin_max(crops):
    """Join crops into one image, keeping the per-pixel maximum where they overlap.

    The canvas starts at zero, so uncovered pixels stay 0 and values below
    zero are clamped to 0.

    :param crops: sequence of ``(segment, (left, top, right, bottom))``
    :return: float64 array of the shape given by ``get_shape(crops)``
    """
    canvas = np.zeros(get_shape(crops), dtype=np.float64)
    for segment, (left, top, right, bottom) in crops:
        window = (slice(top, bottom), slice(left, right))
        canvas[window] = np.maximum(canvas[window], segment)
    return canvas


def imjoin_min(crops):
    """Join crops into one image, keeping the per-pixel minimum where they overlap.

    :param crops: sequence of ``(segment, (left, top, right, bottom))``
    :return: float64 array of the shape given by ``get_shape(crops)``; pixels
      not covered by any segment are 0
    """
    shape = get_shape(crops)
    # Start from +inf so the first segment written to a pixel always wins; a
    # zero-filled canvas would turn every non-negative minimum into 0.
    min_image = np.full(shape, np.inf, dtype=np.float64)
    covered = np.zeros(shape, dtype=bool)
    for segment, (left, top, right, bottom) in crops:
        min_image[top:bottom, left:right] = np.minimum(min_image[top:bottom, left:right], segment)
        covered[top:bottom, left:right] = True
    # Pixels no segment touched are reported as 0 rather than inf.
    min_image[~covered] = 0
    return min_image


def imjoin_average(crops):
    """Join crops into one image, averaging the values where they overlap.

    :param crops: sequence of ``(segment, (left, top, right, bottom))``
    :return: float64 array of the shape given by ``get_shape(crops)``; pixels
      not covered by any segment are 0
    """
    shape = get_shape(crops)
    sum_image = np.zeros(shape, dtype=np.float64)
    count_image = np.zeros(shape, dtype=np.float64)
    for segment, (left, top, right, bottom) in crops:
        sum_image[top:bottom, left:right] += segment
        count_image[top:bottom, left:right] += 1
    # Divide only where at least one segment contributed; uncovered pixels keep
    # the 0 from ``average`` instead of becoming NaN through 0 / 0.
    average = np.zeros(shape, dtype=np.float64)
    np.divide(sum_image, count_image, out=average, where=count_image > 0)
    return average

In [None]:
# Cut the test image into 256-pixel segments taken every 64 pixels.
crops = imsplit(test_image, size=256, step=64)

len(crops)

In [None]:
# Stack the segment arrays (borders dropped) into a single array for the model.
segments_only = [tile for tile, _borders in crops]
batch = np.array(segments_only)

batch.shape, batch.dtype, batch.max()

In [None]:
# Run the model on all segments of the test image in one predict call.
result_batch = model.predict(batch)

In [None]:
# Put the predictions back at the borders of their source segments and merge
# them into one image, averaging the overlaps.
result_crops = replace_segments(crops, result_batch)
result_image = imjoin_average(result_crops)

squeezed_result = np.squeeze(result_image)
plt.imshow(squeezed_result, cmap='gray')
plt.show()

In [None]:
from PIL import Image

path = 'image.png'
# Scale to 0..255, cast to uint8 and drop singleton axes to build the image.
# Image.save() returns None, so it is called separately to keep the Image in `im`.
im = Image.fromarray(np.squeeze((result_image * 255).astype(np.uint8)))
im.save(path)