In [1]:
import torch, zenkai, torch.nn as nn, sklearn.datasets, torch.utils.data, numpy as np

In [2]:
class SignLearner(zenkai.LearningMachine):

  def __init__(self, in_features: int, out_features: int):

    super().__init__()
    self.linear = nn.Linear(in_features, out_features)
    self.bce = nn.BCELoss(reduction='sum')
    self.sse = nn.MSELoss(reduction='none')

  def accumulate(self, x: zenkai.IO, t: zenkai.IO, state: zenkai.State):

    pre_y = state._pre_y
    # q = torch.sigmoid(pre_y)
    # p = torch.clamp(t.f, -1, 1)
    dy = (state._y.f - t.f)
    t = (pre_y - dy).detach()
    # print(pre_y[0], t[0], state._y.f[0])
    loss = (self.sse(pre_y, t)) * ((pre_y > -1.0) & (pre_y < 1.0))
    (loss.sum() * 0.5).backward()

  def assess_y(self, y: zenkai.IO, t: zenkai.IO, reduction_override: bool=None):
    # loss = (p * torch.log(p / q) + (1 - p) * torch.log((1 - p) / (1 - q)))
    # loss = (p * torch.log(p / q) + (1 - p) * torch.log((1 - p) / (1 - q)))
    loss = (y.f - t.f).abs()

    return zenkai.Reduction[reduction_override or 'mean'].reduce(
        loss
    )

  def step_x(self, x: zenkai.IO, t: zenkai.IO, state: zenkai.State):
    return x.acc_grad()

  def forward_nn(self, x: zenkai.IO, state: zenkai.State) -> torch.Tensor:

    y = self.linear(x.f)
    state._pre_y = y
    return torch.sign(y)

class F(nn.Module):

  def __init__(self, f):
    super().__init__()
    self.f = f

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    return self.f(x)

In [3]:
X, t = sklearn.datasets.make_classification(n_samples=2000, n_features=80, n_informative=20, n_redundant=10)

X = torch.tensor(X).float()
X = (X - X.mean(dim=0, keepdim=True)) / (X - X.std(dim=0, keepdim=True))

t = torch.tensor(t).float().unsqueeze(-1)

# print(t.max().item(), t.min().item())

dataset = torch.utils.data.TensorDataset(X, t)

sign_learner = SignLearner(80, 16)
net = nn.Sequential(
  # nn.Linear(80, 16),
  # nn.BatchNorm1d(16),
  # F(lambda x: zenkai.utils.sign_ste(x)),
  sign_learner,
  nn.Linear(16, 1),
  nn.Sigmoid()
)

optim = torch.optim.Adam(net.parameters(), lr=1e-3)

# zenkai.set_lmode(net, zenkai.LMode.WithStep)

criterion = nn.BCELoss()
for i in range(1000):
  results = []

  for x_i, t_i in torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=128):

    y_i = net(x_i)
    # if ((y_i > 1.0) | (y_i < 0.0)).any():
    #   print('OOB: ', y_i.max().item(), y_i.min().item())
    # if ((t_i > 1.0) | (t_i < 0.0)).any():
    #   print('OOB: ', t_i.max().item(), t_i.min().item())
    loss = criterion(y_i, t_i)
    results.append(loss.item())
    before = zenkai.params.get_params(net[0])
    optim.zero_grad()
    loss.backward()
    optim.step()
    after = zenkai.params.get_params(net[0])

  if (i + 1) % 10 == 0:
    print(i, np.mean(results))

9 0.6456049904227257
19 0.591425497084856
29 0.5482065677642822
39 0.5141094122081995
49 0.4810848142951727
59 0.456650085747242
69 0.44077218510210514
79 0.4185577854514122
89 0.40468826703727245
99 0.38974221609532833
109 0.3738476652652025
119 0.36578717827796936
129 0.35666673816740513
139 0.34994928538799286
149 0.3458299897611141
159 0.33359730057418346
169 0.3270200062543154
179 0.3175435923039913
189 0.3122902223840356
199 0.3106965832412243
209 0.30264001712203026
219 0.28683163970708847
229 0.2860636869445443
239 0.2873776266351342
249 0.28614930529147387
259 0.2842739950865507
269 0.28056158125400543
279 0.28442160971462727
289 0.2758025797083974
299 0.2720751417800784
309 0.27666581608355045
319 0.266975405625999
329 0.26751829870045185
339 0.26949406787753105
349 0.2662418307736516
359 0.26829234324395657
369 0.2621750235557556
379 0.26092547085136175
389 0.2615014659240842
399 0.2565246159210801
409 0.2564963148906827
419 0.2579415310174227
429 0.2545348536223173
439 0.25