In [None]:
from typing import Final
import torch
import torch.nn as nn

class RelNwRegr(nn.Module):
    """Nadaraya-Watson kernel regression with an additive relevance term.

    Training sample ``i`` receives the un-normalised weight

        k_i = exp(-||x_i - x_test||_2 / sigma + r_scale * r_i)

    and the prediction is the weighted average ``sum_i k_i * y_i / sum_i k_i``.

    Args:
        init_sigma: initial kernel bandwidth (distances are divided by it).
        init_r_scale: initial multiplier applied to the relevance scores ``r``.
    """

    def __init__(self, init_sigma: float, init_r_scale: float) -> None:
        super().__init__()
        self.sigma = nn.Parameter(torch.ones(1) * init_sigma)
        self.r_scale = nn.Parameter(torch.ones(1) * init_r_scale)

    def _log_kernel(self, dist: torch.Tensor, r) -> torch.Tensor:
        """Log of the un-normalised kernel weight for distance(s) ``dist`` and relevance ``r``."""
        return -dist / self.sigma + self.r_scale * r

    def k(self, xi: torch.Tensor, xj: torch.Tensor, rij: float) -> float:
        """Un-normalised kernel weight between ``xi`` and ``xj`` with relevance ``rij``.

        Returned as a Python float (via ``.item()``), i.e. detached from autograd;
        intended for inspection, not for training.
        """
        # dim=None flattens, matching the 2-norm over all elements of xi - xj.
        dist = torch.linalg.vector_norm(xi - xj, ord=2)
        return torch.exp(self._log_kernel(dist, rij)).item()

    def forward(self, x_train: torch.Tensor, y_train: torch.Tensor, x_test: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
        """
        x_train: (n_train, n_features)
        y_train: (n_train,) or (n_train, ...) -- trailing dims are averaged element-wise
        x_test: (n_features,)
        r[i] = rel(x_train[i], x_test) (n_train,)

        Returns the kernel-weighted average of ``y_train`` (shape ``y_train.shape[1:]``).

        Raises:
            ValueError: if ``x_train`` contains no samples.
        """
        n_train = x_train.shape[0]
        if n_train == 0:
            raise ValueError("x_train must contain at least one sample")

        # One row per training sample; the norm runs over everything else,
        # exactly like the per-sample 2-norm of (x_i - x_test).
        diff = (x_train - x_test).reshape(n_train, -1)
        dist = torch.linalg.vector_norm(diff, ord=2, dim=1)  # (n_train,)
        logits = self._log_kernel(dist, torch.as_tensor(r, device=dist.device))

        # softmax(logits) == exp(logits) / sum(exp(logits)), but evaluated with the
        # log-sum-exp trick: it stays finite when every raw kernel would underflow
        # to 0 (tiny sigma / far query) or overflow to inf (large r_scale * r).
        # NOTE(review): .detach() mirrors the previous .item()-based implementation,
        # which never let gradients reach sigma / r_scale. Drop it if these parameters
        # are meant to be trained -- callers that hand outputs to numpy/matplotlib
        # will then need .item() or torch.no_grad().
        weights = torch.softmax(logits, dim=0).detach()

        # Broadcast the per-sample weights over any trailing target dimensions.
        weights = weights.reshape(-1, *([1] * (y_train.dim() - 1)))
        return (weights * y_train).sum(dim=0)

In [None]:
import matplotlib.pyplot as plt

n_train = 50
torch.random.manual_seed(42)
x_train = torch.rand((n_train, 1)) * 2 - 1
y_train = x_train.flatten() ** 2

x_grid = torch.linspace(start=-1, end=1, steps=50).unsqueeze(1)
# r = 0 for every sample, so the relevance term is switched off and this is
# plain Nadaraya-Watson regression; only the bandwidth sigma varies.
r = torch.zeros((n_train,))

plt.scatter(x_train.flatten(), y_train, color="black", label="train")
for sigma in (.01, .05, .1, .3,):
    regr = RelNwRegr(init_sigma=sigma, init_r_scale=1)
    # Plain floats for matplotlib; no autograd graph is needed just to plot.
    with torch.no_grad():
        y_grid = [regr(x_train, y_train, x, r).item() for x in x_grid]
    # Raw f-string: "\s" is an invalid escape sequence in a normal string literal.
    plt.plot(x_grid.flatten(), y_grid, "--", label=rf"$\sigma={sigma}$")

plt.xlabel("x")
plt.ylabel("y")
plt.title("Nadaraya-Watson fit of $y = x^2$ for several bandwidths")
plt.legend();

In [None]:
torch.random.manual_seed(42)
clusters = torch.randint(0, 3, (n_train,))
# Shift every sample up by its integer cluster id -> three offset parabolas.
# Built as a new tensor instead of mutating y_train in place (same float dtype).
y_train = x_train.flatten() ** 2 + clusters

sigma = .1
for r_scale in (.5, 2.5, 10):
    plt.figure()
    plt.scatter(x_train.flatten(), y_train, color="black")

    regr = RelNwRegr(init_sigma=sigma, init_r_scale=r_scale)
    for i_c in torch.unique(clusters):
        # Relevance is 1 for members of cluster i_c and 0 otherwise; it does not
        # depend on the query point, so compute it once per cluster.
        r = (clusters == i_c).int()
        with torch.no_grad():
            y_grid = [regr(x_train, y_train, x, r).item() for x in x_grid]
        plt.plot(x_grid.flatten(), y_grid, "--", label=f"cluster {int(i_c)}")

    plt.xlabel("x")
    plt.ylabel("y")
    plt.legend()
    plt.title(f"r_scale={r_scale}")