<a href="https://colab.research.google.com/github/ventus550/ShapeCorrection/blob/main/notebooks/train-regressor.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Init

In [16]:
try:
    from google.colab import drive
    drive.mount('/content/mount', force_remount=True)
    %cd mount/MyDrive/MachineLearningProjects/ShapeCorrection/notebooks
except ModuleNotFoundError: ...

Mounted at /content/mount
[Errno 107] Transport endpoint is not connected: 'mount/MyDrive/MachineLearningProjects/ShapeCorrection/notebooks'
/content/ShapeCorrection


In [17]:
import sys

# Make the parent of the working directory importable — presumably the
# repository root, where the local `utils` and `architectures` modules
# imported below live. Guarded so re-running the cell does not keep
# appending duplicate entries to sys.path.
if ".." not in sys.path:
    sys.path.append("..")

In [18]:
import keras
from keras.layers import *
import tensorflow as tf

import numpy as np
from scipy import ndimage
from pathlib import Path
from matplotlib import pyplot as plt

import utils
import architectures

TypeError: ignored

## Configuration

In [None]:
# Dataset to train on (read from data/regression/<shape>.npz) and the number
# of polygon vertices the regressor predicts.
shape, vertices = "ellipse", 4

In [19]:
# Project layout, resolved from the working directory (presumably notebooks/,
# see the Colab cell above), so its parent is the repository root.
BASEDIR = Path.cwd().parent
DESTDIR = BASEDIR / "data"      # datasets
MODLDIR = BASEDIR / "models"    # saved models
DATASET = DESTDIR.joinpath("regression", f"{shape}.npz")

NameError: ignored

## Data augmentation

In [None]:
def rotation_matrix(angle):
    """Return the 2x2 rotation matrix for an angle given in degrees.

    The matrix is [[cos, -sin], [sin, cos]] as a NumPy array.
    """
    rad = np.deg2rad(angle)
    cos_a = np.cos(rad)
    sin_a = np.sin(rad)
    return np.array([
        [cos_a, -sin_a],
        [sin_a, cos_a],
    ])

def rotate(x, y, angle=30):
    """Rotate an image and its target points together by `angle` degrees.

    The image is rotated about its centre with scipy.ndimage.rotate, keeping
    its shape, and clipped back to [0, 1] because spline interpolation can
    overshoot. The points are rotated about (0.5, 0.5) with the same angle —
    they are assumed to be normalised to the unit square (TODO confirm).

    Returns:
        (rotated image, rotated points)
    """
    rotated_image = ndimage.rotate(x, angle=angle, reshape=False)
    rotated_image = np.clip(rotated_image, 0, 1)

    center = np.full(2, 0.5)
    rotated_points = (y - center) @ rotation_matrix(angle) + center
    return rotated_image, rotated_points

def fliplr(x, y):
    """Mirror an image left-right and mirror its target points to match.

    Args:
        x: image array; axis 1 (columns) is reversed.
        y: points of shape (..., 2) whose column 0 is the horizontal
           coordinate, assumed normalised to [0, 1] (TODO confirm).
           Lists are accepted as well as arrays.

    Returns:
        (flipped image, new points array). `y` itself is not modified.
    """
    x = np.flip(x, axis=1)
    y = np.array(y, copy=True)
    # Horizontal mirror in normalised coordinates: u -> 1 - u, v unchanged.
    y[..., 0] = 1 - y[..., 0]
    return x, y

def flipud(x, y):
    """Mirror an image top-bottom and mirror its target points to match.

    Args:
        x: image array; axis 0 (rows) is reversed.
        y: points of shape (..., 2) whose column 1 is the vertical
           coordinate, assumed normalised to [0, 1] (TODO confirm).
           Lists are accepted as well as arrays.

    Returns:
        (flipped image, new points array). `y` itself is not modified.
    """
    x = np.flip(x, axis=0)
    y = np.array(y, copy=True)
    # Vertical mirror in normalised coordinates: v -> 1 - v, u unchanged.
    y[..., 1] = 1 - y[..., 1]
    return x, y

def augment(x, y):
    """Randomly augment one (image, points) training pair.

    fliplr and then flipud are each applied independently with probability
    0.5. The pair is then rotated by a random integer angle in [0, 360)
    degrees. Draws from NumPy's global RNG (np.random), so results depend
    on its current state.
    """
    for flip in (fliplr, flipud):
        if np.random.rand() < 0.5:
            x, y = flip(x, y)
    return rotate(x, y, np.random.randint(360))

## Data loading

In [None]:
# Load the paired images / targets and carve out fixed held-out splits:
# the first N_HOLDOUT samples are the test set, the next N_HOLDOUT the
# validation set, and the rest is training data (passed through `augment`).
N_HOLDOUT = 100

# np.load on an .npz returns an NpzFile that keeps the archive open; the
# context manager closes it. Unpacking inside the block reads both arrays
# before the file is closed. Relies on the archive storing exactly two
# arrays, images first and targets second.
with np.load(DATASET) as archive:
    X, Y = archive.values()
assert len(X) == len(Y), "images and targets must be paired one-to-one"

test  = utils.DataGenerator([X[:N_HOLDOUT], Y[:N_HOLDOUT]])
val   = utils.DataGenerator([X[N_HOLDOUT:2 * N_HOLDOUT], Y[N_HOLDOUT:2 * N_HOLDOUT]])
train = utils.DataGenerator([X[2 * N_HOLDOUT:], Y[2 * N_HOLDOUT:]], transform=augment)

In [None]:
def _pairwise_distances(a, b, eps=1e-12):
    """Euclidean distance between every point of `a` and every point of `b`.

    For rank-3 inputs (batch, n, d) and (batch, m, d), broadcasting
    (batch, n, 1, d) - (batch, 1, m, d) yields a (batch, n, m) matrix.
    Computes sqrt(max(|a_i - b_j|^2, eps)) instead of tf.norm. The gradient
    of tf.norm is NaN at exactly zero distance, e.g. when a predicted vertex
    lands on a target vertex, and that would poison training. The value only
    differs from the true norm for distances below sqrt(eps) = 1e-6.
    """
    diff = tf.expand_dims(a, axis=2) - tf.expand_dims(b, axis=1)
    squared = tf.reduce_sum(tf.square(diff), axis=-1)
    return tf.sqrt(tf.maximum(squared, eps))


def chamfer_distance_loss(points1, points2):
    """Symmetric Chamfer distance between two point sets, summed over the batch.

    Used as `model.compile(loss=...)`, so Keras calls it as (y_true, y_pred).
    Assumes both inputs are point sets shaped (batch, n_points, d). TODO
    confirm against utils.DataGenerator and the model head: a flat
    (batch, n_points * d) layout would silently compute something else.

    Returns:
        Scalar tensor: the sum over the batch of
        mean_i min_j |p1_i - p2_j| + mean_j min_i |p2_j - p1_i|.
        NOTE(review): Keras usually expects per-sample losses and averages
        them; returning the batch sum makes the loss scale with batch size.
        Kept as-is so training dynamics do not change.
    """
    distances1 = _pairwise_distances(points1, points2)
    distances2 = _pairwise_distances(points2, points1)

    # For every point, the distance to its nearest neighbour in the other set.
    min_distances1 = tf.reduce_min(distances1, axis=-1)
    min_distances2 = tf.reduce_min(distances2, axis=-1)

    chamfer_loss = tf.reduce_mean(min_distances1, axis=-1) + tf.reduce_mean(min_distances2, axis=-1)

    # Total loss over the entire batch.
    return tf.reduce_sum(chamfer_loss)

In [None]:
# Build the regressor from the project's `architectures` module, parameterised
# by the number of polygon vertices, and attach the Chamfer loss defined above.
# No optimizer is passed, so compile()'s default is used — presumably Keras'
# "rmsprop", assuming CNN returns a keras.Model (verify).
model = architectures.CNN(vertices)
model.compile(loss=chamfer_distance_loss)

In [None]:
# One epoch over the augmented training split. The held-out `val` generator
# (same type as `train`, which fit() already accepts) is evaluated at the end
# of each epoch, so over/under-fitting shows up in `history`. Previously `val`
# was built but never used.
history = model.fit(train, validation_data=val, epochs=1)

In [None]:
# Persist the trained model to models/new.h5. `frozen=True` is forwarded to
# utils.save; its exact meaning is defined there (verify before relying on it).
utils.save(model, MODLDIR.joinpath("new.h5"), frozen=True)

In [None]:
# Reload the model from models/new.h5, so the evaluation below runs on the
# model as it was written to disk.
model = utils.load(MODLDIR.joinpath("new.h5"))

In [None]:
from shapely.geometry import Polygon

def dice(label, pred, nb_vertices=4):
    """Dice coefficient 2|A ∩ B| / (|A| + |B|) between two polygons.

    Args:
        label: ground-truth vertices, reshaped to (nb_vertices, 2).
        pred: predicted vertices, reshaped to (nb_vertices, 2).
        nb_vertices: number of polygon vertices.

    Returns:
        Area-based Dice score; 1.0 means the polygons coincide.
    """
    true_poly, pred_poly = (Polygon(pts.reshape(nb_vertices, 2)) for pts in (label, pred))
    overlap = true_poly.intersection(pred_poly).area
    return 2 * overlap / (true_poly.area + pred_poly.area)

In [None]:
# Qualitative check: draw predictions for a random sample of held-out images
# and annotate each panel with its Dice / IoU against the ground truth.
image_size = 70
nb_pictures = 32
nb_columns = 4
nb_rows = -(-nb_pictures // nb_columns)  # ceiling division: never drop a panel

# Sample only from the first 100 images, which form the `test` split built in
# the data-loading cell, so the network never trained on them. Sampling is
# without replacement, so no image is shown twice.
nb_test = min(100, X.shape[0])
sample = np.random.choice(nb_test, size=min(nb_pictures, nb_test), replace=False)

fig = plt.figure(figsize=(13, 21))
fig.subplots_adjust(hspace=0.13, wspace=0.01, left=0, right=1, bottom=0, top=1.2)
for panel, ipic in enumerate(sample, start=1):
    ax = fig.add_subplot(nb_rows, nb_columns, panel, xticks=[], yticks=[])
    image = X[ipic].reshape([1, image_size, image_size, 1])
    pred = model(image)[0].numpy()
    utils.draw_data_point(X[ipic], Y[ipic], pred, ax)
    try:
        # Distinct names: assigning to `dice` here would rebind the global and
        # clobber the `dice` function defined earlier in the notebook.
        dice_score = utils.dice(Y[ipic], pred, vertices)
        iou_score = utils.IoU(Y[ipic], pred, vertices)
        ax.set_title(f"dice {dice_score:5.3f}, IoU {iou_score:5.3f}")
    except Exception:
        # Best effort: the polygon metrics may fail on degenerate or
        # self-intersecting predictions (presumably a shapely error raised
        # inside utils). Label the panel instead of aborting the figure.
        ax.set_title("n/a")
plt.show()