In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm import autonotebook as tqdm
from dataclasses import dataclass
from typing import Callable, Protocol
from abc import ABC, abstractmethod

np.random.seed(123)

  from tqdm import autonotebook as tqdm


In [None]:
class function(Callable, ABC):
    """Abstract base class for a symbolic, differentiable function of one variable.

    Subclasses implement ``__call__`` (evaluation) and ``diff`` (the derivative,
    returned as another ``function``). The arithmetic operators never evaluate
    anything: they build new expression nodes out of the sibling classes
    ``plus``, ``times``, ``inverse``, ``compose``, ``monomial`` and ``constant``.
    Plain ``int``/``float`` operands are wrapped in ``constant`` first, on
    either side of the operator.
    """

    def __init__(self, name: str):
        # Human-readable expression text; returned by both str() and repr().
        self.name = name

    @abstractmethod
    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Evaluate the function at ``x``."""
        raise NotImplementedError()

    @abstractmethod
    def diff(self) -> "function":
        """Return the derivative as a new ``function`` (takes no arguments)."""
        raise NotImplementedError()

    @staticmethod
    def _lift(operand):
        """Wrap a plain ``int``/``float`` in ``constant``; return anything else unchanged."""
        if isinstance(operand, (int, float)):
            return constant(operand)
        return operand

    def __add__(self, other):
        """``self + other``."""
        return plus(self, self._lift(other))

    def __radd__(self, other):
        """``other + self`` for operands that cannot add a ``function`` themselves."""
        return plus(self._lift(other), self)

    def __mul__(self, other):
        """``self * other``."""
        return times(self, self._lift(other))

    def __rmul__(self, other):
        """``other * self`` for operands that cannot multiply a ``function`` themselves."""
        return times(self._lift(other), self)

    def __neg__(self):
        """``-self``, expressed as multiplication by the constant -1."""
        return times(constant(-1), self)

    def __sub__(self, other):
        """``self - other``.

        Negating ``other`` first keeps numbers as numbers (they are wrapped by
        ``__add__``) and sends ``function`` operands through ``__neg__``; the
        previous ``(-1) * other`` form raised ``TypeError`` for the latter.
        """
        return self + (-other)

    def __rsub__(self, other):
        """``other - self``."""
        return plus(self._lift(other), -self)

    def __div__(self, other):
        # Python 2 division hook: the `/` operator never calls this on Python 3
        # (it uses __truediv__). Kept so explicit callers keep working.
        return self * inverse(self._lift(other))

    def __pow__(self, exponent: int):
        """``self ** exponent`` for integer exponents, built as ``monomial(exponent)(self)``.

        Raises:
            ValueError: if ``exponent`` is not an ``int``.
        """
        if not isinstance(exponent, int):
            raise ValueError("Exponent must be an integer")
        return compose(monomial(exponent), self)

    def __truediv__(self, other):
        """``self / other``, built as ``(1 / other) * self``."""
        return times(inverse(self._lift(other)), self)

    def __rtruediv__(self, other):
        """``other / self``, built as ``(1 / self) * other``."""
        return times(inverse(self), self._lift(other))

    def __str__(self):
        return self.name

    def __repr__(self):
        return self.name


class plus(function):
    """Pointwise sum of two functions: ``x -> f1(x) + f2(x)``."""

    def __init__(self, f1: function, f2: function):
        label = f"({f1} + {f2})"
        super().__init__(label)
        self.f1, self.f2 = f1, f2

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Evaluate both summands at ``x`` and add the results."""
        left = self.f1(x)
        right = self.f2(x)
        return left + right

    def diff(self) -> "function":
        """Sum rule: differentiate each summand and add the derivatives."""
        d_left = self.f1.diff()
        d_right = self.f2.diff()
        return d_left + d_right


class times(function):
    """Pointwise product of two functions: ``x -> f1(x) * f2(x)``."""

    def __init__(self, f1: function, f2: function):
        label = f"({f1} * {f2})"
        super().__init__(label)
        self.f1, self.f2 = f1, f2

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Evaluate both factors at ``x`` and multiply the results."""
        left = self.f1(x)
        right = self.f2(x)
        return left * right

    def diff(self) -> "function":
        """Product rule: ``f1' * f2 + f1 * f2'``."""
        first_term = self.f1.diff() * self.f2
        second_term = self.f1 * self.f2.diff()
        return first_term + second_term


class inverse(function):
    """Pointwise reciprocal of a function: ``x -> 1 / f(x)``."""

    def __init__(self, f: function):
        super().__init__(f"1 / {f}")
        self.f = f

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Evaluate ``f`` at ``x`` and return its reciprocal."""
        return 1 / self.f(x)

    def diff(self) -> "function":
        """Derivative of the reciprocal: ``(1 / f)' = -f' / f**2``.

        Negation is spelled as multiplication by ``constant(-1)`` and division
        as multiplication by an ``inverse`` node, so this method only relies
        on the ``*`` and ``**`` operators of ``function``.
        """
        minus_f_prime = constant(-1) * self.f.diff()
        return inverse(self.f ** 2) * minus_f_prime

class compose(function):
    """Composition of two functions: ``x -> f1(f2(x))``."""

    def __init__(self, f1: function, f2: function):
        super().__init__(f"{f1}({f2})")
        self.f1 = f1  # outer function
        self.f2 = f2  # inner function

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Evaluate the inner function at ``x`` and feed the result to the outer one."""
        return self.f1(self.f2(x))

    def diff(self) -> "function":
        """Chain rule: ``(f1 o f2)' = (f1' o f2) * f2'``.

        ``diff`` takes no arguments, so the outer derivative is evaluated at
        the inner function by wrapping it in another ``compose`` node.
        """
        outer_prime_at_inner = compose(self.f1.diff(), self.f2)
        return outer_prime_at_inner * self.f2.diff()

class monomial(function):
    """Power function ``x -> x ** degree``."""

    def __init__(self, degree: int):
        super().__init__(f"x^{degree}")
        self.degree = degree

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Raise ``x`` to the stored degree."""
        return x ** self.degree

    def diff(self) -> "function":
        """Power rule: ``degree * x ** (degree - 1)``.

        Degree 0 is the constant 1, whose derivative is identically zero.
        Returning ``zero()`` for it avoids building ``x^-1 * 0``, which is
        NaN at ``x == 0`` and raises for integer arrays.
        """
        if self.degree == 0:
            return zero()
        return monomial(self.degree - 1) * constant(self.degree)

class zero(function):
    """The function that is identically zero."""

    def __init__(self):
        super().__init__("0")

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Return ``0`` for a scalar input, otherwise zeros shaped and typed like ``x``."""
        return 0 if np.isscalar(x) else np.zeros_like(x)

    def diff(self) -> "function":
        """The derivative of zero is zero; a fresh ``zero`` node is returned."""
        derivative = zero()
        return derivative
    
class constant(function):
    """Constant function ``x -> value``."""

    def __init__(self, value: float):
        self.value = value
        # The display name shows the value rounded to two decimals.
        super().__init__(f"{value:.2f}")

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Return ``value`` for a scalar input, otherwise an array shaped like ``x``.

        The output dtype is the promotion of ``x``'s dtype with ``value``.
        Inheriting ``x``'s dtype unchanged would truncate a fractional
        constant on an integer array (``0.5`` would become ``0``).
        """
        if np.isscalar(x):
            return self.value
        points = np.asanyarray(x)
        return np.full_like(
            points, self.value, dtype=np.result_type(points, self.value)
        )

    def diff(self) -> "function":
        """The derivative of a constant is zero."""
        return zero()
    
class affine(function):
    """Affine function ``x -> alpha * x + beta``."""

    def __init__(self, alpha: float, beta: float):
        self.alpha, self.beta = alpha, beta
        label = f"{alpha:.2f} x + {beta:.2f}"
        super().__init__(label)

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Scale ``x`` by the slope and add the offset."""
        scaled = self.alpha * x
        return scaled + self.beta

    def diff(self) -> "function":
        """The derivative of an affine function is its constant slope."""
        slope = constant(self.alpha)
        return slope





In [10]:
# Kinds of function that sample_basic_function can draw.
FUNCTION_TYPES = [
    "wave", "monomial", "exponential", "affine",
    "plus", "times", "inverse", "composition",
]

# Relative weights, index-aligned with FUNCTION_TYPES, normalised to sum to 1.
PROBABILITIES = np.asarray([2, 2, 2, 3, 2, 2, 1, 1], dtype=float)
PROBABILITIES = PROBABILITIES / PROBABILITIES.sum()

def sample_basic_function():
    fn_type = np.random.choice(FUNCTION_TYPES, p=PROBABILITIES)
    match fn_type:
        case "wave":
            frequency = np.random.normal()
            phase = np.random.uniform(0, 2*np.pi)
            return wave(frequency, phase)
        case "monomial":
            degree = np.random.randint(1, 10)
            return monomial(degree)
        case "exponential":
            alpha = np.random.normal()
            return exponential(alpha)
        case "affine":
            alpha = np.random.normal()
            beta = np.random.normal()
            return affine(alpha, beta)
        case "plus":
            f1 = sample_basic_function()
            f2 = sample_basic_function()
            return plus(f1, f2)
        case "times":
            f1 = sample_basic_function()
            f2 = sample_basic_function()
            return times(f1, f2)
        case "inverse":
            f = sample_basic_function()
            return inverse(f)
        case "composition":
            f1 = sample_basic_function()
            f2 = sample_basic_function()
            return composition(f1, f2)
        case _:
            raise ValueError("Invalid function type")




In [11]:
# Sanity check: print ten independently sampled functions.
for sample_index in range(10):
    drawn = sample_basic_function()
    print(drawn)

RecursionError: maximum recursion depth exceeded

In [None]:
# Example usage: plot 20 randomly sampled functions next to their derivatives.

test_x = np.linspace(-2, 2, 100)
fig, axs = plt.subplots(figsize=(12, 4), ncols=2)
ax1, ax2 = axs

ax1.set_title('f(x)')
ax2.set_title("f'(x)")

for _ in range(20):
    # Named `fn` rather than `function`, so the `function` base class defined
    # earlier in the notebook is not rebound at global scope.
    fn = sample_basic_function()
    ax1.plot(test_x, fn(test_x))
    # diff() takes no arguments: it returns the derivative as a new callable,
    # which is then evaluated on the grid.
    ax2.plot(test_x, fn.diff()(test_x))

# The limits are the same for every curve, so set them once after plotting.
ax1.set_ylim(-10, 10)
ax2.set_ylim(-10, 10)
fig.tight_layout()


In [None]:
np.random.seed(42)

def generate_data(num_points, num_examples):
    """Sample random functions and tabulate them and their derivatives on a grid.

    Args:
        num_points: number of evenly spaced grid points on [-1, 1].
        num_examples: number of functions to sample.

    Returns:
        A tuple ``(xs, y, dy)``:
        ``xs`` is the grid as a column of shape ``(num_points, 1)``;
        ``y`` and ``dy`` have shape ``(num_kept, num_points)`` and hold one
        function / derivative per row. ``num_kept`` can be smaller than
        ``num_examples`` because samples containing NaN are dropped, and the
        arrays have zero rows when nothing is kept.
    """
    xs = np.linspace(-1, 1, num_points).reshape(-1, 1)
    ys = []
    dys = []

    for _ in tqdm.trange(num_examples):
        fn = sample_basic_function()
        y = fn(xs)
        # diff() takes no arguments; it returns the derivative as a callable.
        dy = fn.diff()(xs)
        # Only NaN is rejected; infinite values are kept.
        # NOTE(review): confirm whether infinite samples should be dropped too.
        if np.isnan(y).any() or np.isnan(dy).any():
            continue
        # Reuse the values computed above instead of evaluating a second time.
        ys.append(y)
        dys.append(dy)

    if not ys:
        # np.hstack raises on an empty list, so build the empty result directly.
        return xs, np.empty((0, num_points)), np.empty((0, num_points))

    # Each kept sample is a (num_points, 1) column; stack them side by side
    # and transpose so that every row is one function.
    return xs, np.hstack(ys).T, np.hstack(dys).T

# Request 250,000 sampled functions, each evaluated on a 100-point grid.
xs, ys, dys = generate_data(num_points=100, num_examples=250_000)

# Report the shape of each returned array, one per line.
print(xs.shape, ys.shape, dys.shape, sep="\n")

In [None]:
# Write each array to its own .npy file.
for filename, array in (("xs.npy", xs), ("ys.npy", ys), ("dys.npy", dys)):
    np.save(filename, array)