In [1]:
import torch, zenkai, torch.nn as nn, sklearn.datasets, torch.utils.data, numpy as np

In [2]:
class SignLearner(zenkai.LearningMachine):

  def __init__(self, in_features: int, out_features: int):

    super().__init__()
    self.linear = nn.Linear(in_features, out_features)
    self.batch_norm = nn.BatchNorm1d(out_features)
    self.bce = nn.BCELoss(reduction='sum')
    self.sse = nn.MSELoss(reduction='none')

  def accumulate(self, x: zenkai.IO, t: zenkai.IO, state: zenkai.State):

    # print(t.f)
    pre_y = state._pre_y
    # q = torch.sigmoid(pre_y)
    # p = torch.clamp(t.f, -1, 1)
    dy = (state._y.f - t.f)
    print(dy[0])
    # dy = (state._y.f - pre_y).clamp(-1.0, 1.0)
    t = (pre_y - dy).detach()
    # print(pre_y[0], t[0], state._y.f[0])
    loss = (self.sse(pre_y, t)) # * ((pre_y > -1.0) & (pre_y < 1.0))
    # print(state._y.f.shape)

    # print(state._pre_y[0], t[0], dy[0], state._y.f[0])
    # t = torch.clamp(t.f, -1.0, 1.0)
    # print(dy.max().item(), dy.min().item())
    # pre_y - (t - )
    # pre_y + t - y
    # loss = self.sse(pre_y, (pre_y + dy).detach())
    # loss = self.sse(state._y.f, t.f)
    # p = (p + 1) / 2
    # print(p.max().item(), p.min().item(), q.max().item(), q.min().item())
    # print((p * torch.log(p / q)).min().item(), ((1 - p) * torch.log((1 - p) / (1 - q))).min().item())
    # loss = -((p * torch.log(p / q) + (1 - p) * torch.log((1 - p) / (1 - q))))
    # loss[loss.isnan()] = 0.0
    # loss = self.bce(q, p) * 0.5
    # print(loss.mean())
    # print(loss.min().item(), loss.max().item())
    # before = zenkai.params.get_params(self)
    (loss.sum() * 0.5).backward()
    # print('Before: ', zenkai.params.get_model_grads(self, flat_cat=True)[0:10])

  def assess_y(self, y: zenkai.IO, t: zenkai.IO, reduction_override: bool=None):

    q = y.f
    # loss = (p * torch.log(p / q) + (1 - p) * torch.log((1 - p) / (1 - q)))

    # loss = (p * torch.log(p / q) + (1 - p) * torch.log((1 - p) / (1 - q)))
    loss = (y.f - t.f).abs()

    return zenkai.Reduction[reduction_override or 'mean'].reduce(
        loss
    )

  def step(self, x: zenkai.IO, t: zenkai.IO, state: zenkai.State):

    self.optim.step()
    self.optim.zero_grad()

  def step_x(self, x: zenkai.IO, t: zenkai.IO, state: zenkai.State):
    return x.acc_grad()

  def forward_nn(self, x: zenkai.IO, state: zenkai.State) -> torch.Tensor:

    y = self.linear(x.f)
    y = self.batch_norm(y)
    state._pre_y = y
    return torch.sign(y)


class F(nn.Module):

  def __init__(self, f):
    super().__init__()
    self.f = f

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    return self.f(x)

In [3]:
X, t = sklearn.datasets.make_classification(n_samples=2000, n_features=80, n_informative=20, n_redundant=10)

X = torch.tensor(X).float()
X = (X - X.mean(dim=0, keepdim=True)) / (X - X.std(dim=0, keepdim=True))

t = torch.tensor(t).float().unsqueeze(-1)

# print(t.max().item(), t.min().item())

dataset = torch.utils.data.TensorDataset(X, t)

sign_learner = SignLearner(80, 16)
net = nn.Sequential(
  # nn.Linear(80, 16),
  # F(lambda x: zenkai.utils.sign_ste(x)),
  sign_learner,
  nn.Linear(16, 1),
  nn.Sigmoid()
)

optim = torch.optim.Adam(net.parameters(), lr=1e-3)

# zenkai.set_lmode(net, zenkai.LMode.WithStep)

criterion = nn.BCELoss()
for i in range(400):
  results = []

  for x_i, t_i in torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=128):

    y_i = net(x_i)
    # print(t_i.max().item(), t_i.min().item(), y_i.max().item(), y_i.min().item())
    # if ((y_i > 1.0) | (y_i < 0.0)).any():
    #   print('OOB: ', y_i.max().item(), y_i.min().item())
    # if ((t_i > 1.0) | (t_i < 0.0)).any():
    #   print('OOB: ', t_i.max().item(), t_i.min().item())
    loss = criterion(y_i, t_i)
    results.append(loss.item())
    # before = zenkai.params.get_params(sign_learner)
    optim.zero_grad()
    loss.backward()
    # print('After: ', zenkai.params.get_model_grads(sign_learner, flat_cat=True)[0:10])
    optim.step()
    # assert (before != zenkai.params.get_params(sign_learner)).any()
  if (i + 1) % 10 == 0:
    print(i, np.mean(results))

tensor([ 3.5289e-04,  3.4124e-04, -4.0181e-04, -8.8949e-06,  2.2080e-04,
         2.8669e-04, -1.8520e-04, -3.8257e-04, -7.1228e-05, -2.0182e-05,
         3.5280e-04,  2.9816e-04,  5.1867e-04,  4.2000e-04,  5.2722e-04,
        -1.0745e-04])
tensor([ 3.5286e-04,  3.4124e-04, -4.0185e-04, -8.8811e-06,  2.2078e-04,
         2.8670e-04, -1.8519e-04, -3.8254e-04, -7.1228e-05, -2.0206e-05,
         3.5286e-04,  2.9814e-04,  5.1868e-04,  4.1997e-04,  5.2726e-04,
        -1.0747e-04], grad_fn=<SelectBackward0>)
tensor([-4.6738e-04, -4.5764e-04,  5.3257e-04,  1.4748e-05, -2.9135e-04,
        -3.7916e-04,  2.4970e-04,  5.0693e-04,  9.2026e-05,  2.4001e-05,
        -4.6727e-04, -4.0024e-04, -6.8831e-04, -5.5682e-04, -6.9970e-04,
         1.4030e-04])
tensor([-4.6736e-04, -4.5764e-04,  5.3263e-04,  1.4782e-05, -2.9135e-04,
        -3.7920e-04,  2.4974e-04,  5.0688e-04,  9.2030e-05,  2.3961e-05,
        -4.6730e-04, -4.0019e-04, -6.8831e-04, -5.5683e-04, -6.9970e-04,
         1.4031e-04], grad_fn=<

KeyboardInterrupt: 