In [1]:
import matplotlib.pyplot as plt
import numpy as np

In [2]:
def make_random_data(x_min, x_max, a, b, c, noise_scale):
    x = np.arange(x_min, x_max, 0.1)
    y = a * x**2 + b * x + c

    y_noise = y + np.random.randn(len(x)) * 100
    y_noise[1000:1250] = y_noise[1000:1250] + np.random.randn(250) * noise_scale

    return x, y, y_noise


def quadratic_from_three_points(x1, y1, x2, y2, x3, y3):

    A = np.array([[x1**2, x1, 1], [x2**2, x2, 1], [x3**2, x3, 1]])
    b = np.array([y1, y2, y3])
    x = np.linalg.solve(A, b)

    a = x[0]
    b = x[1]
    c = x[2]

    return a, b, c


def plot_data(x, y, y_noise, model):

    if model is not None:
        y_pred = model[0] * x**2 + model[1] * x + model[2]
        plt.plot(x, y_pred, label="Predicted Data", linestyle="--", color="red")

    # plt.plot(x, y, label="Origin Data", color="orange")
    plt.scatter(x, y_noise, label="Real Data", s=1)

    plt.grid(True)
    plt.legend()
    plt.show()


def ransac(x, y, n, k, t, d):
    best_error = np.inf
    best_model = None
    best_inliers = None

    for i in range(n):
        sample = np.random.choice(len(x), k, replace=False)
        x_sample = x[sample]
        y_sample = y[sample]

        a, b, c = quadratic_from_three_points(
            x_sample[0], y_sample[0], x_sample[1], y_sample[1], x_sample[2], y_sample[2]
        )

        y_pred = a * x**2 + b * x + c
        error = np.abs(y - y_pred)

        inliers = np.where(error < t)[0]
        if len(inliers) > d:
            error = np.sum(error[inliers])
            if error < best_error:
                best_error = error
                best_model = (a, b, c)
                best_inliers = inliers

    return best_model, best_inliers

In [3]:
x, y_real, y_noise = make_random_data(
    x_min=-50, x_max=50, a=0.1, b=-10, c=300, noise_scale=100
)
for i in range(1):
    best_model, best_inliers = ransac(x, y_noise, 10, 3, 1000, 1000)
    print(f"[{i:2d}] Best Model: {best_model}")
    plot_data(x, y_real, y_noise, best_model)

ValueError: operands could not be broadcast together with shapes (0,) (250,) 