In [1]:
import initialize
from torch import nn
import torch
from torch.utils import data as torch_data
from zenkai.utils import apply_to_parameters
import tqdm
import numpy as np

In [23]:
class MaxMinES(nn.Module):
    """Population of two-layer max-min networks evolved by a simple evolution strategy.

    The module holds ``p`` candidate weight sets (``w``, ``w2``) as trainable
    parameters plus the single best candidate found by the last call to
    :meth:`update` (``best_w``, ``best_w2``).  Each layer computes the max-min
    composition ``y[..., j] = max_i min(x[..., i], w[..., i, j])``.

    Args:
        p: Population size (number of candidate networks).
        in_features: Size of the input feature dimension.  The hidden layer
            uses the same width.
        out_features: Size of the output feature dimension.
        k: Number of lowest-assessment candidates averaged to form the parent
            of the next population.  Must satisfy ``1 <= k <= p``.
        mutation_rate: Probability that an individual weight of a spawned
            candidate is replaced with a fresh uniform sample from ``[0, 1)``.

    Raises:
        ValueError: If ``k`` is outside ``[1, p]`` or ``mutation_rate`` is
            outside ``[0, 1]``.
    """

    def __init__(self, p: int, in_features: int, out_features: int, k: int,
                 mutation_rate: float = 0.025):

        super().__init__()
        # k == 0 would average an empty selection (NaN weights) and k > p makes
        # topk fail inside update(); reject both up front.
        if not 1 <= k <= p:
            raise ValueError(f'k must satisfy 1 <= k <= p, got k={k}, p={p}')
        if not 0.0 <= mutation_rate <= 1.0:
            raise ValueError(
                f'mutation_rate must be in [0, 1], got {mutation_rate}'
            )

        self.h1 = in_features
        self.w = nn.parameter.Parameter(torch.rand(
            p, in_features, self.h1
        ))
        self.w2 = nn.parameter.Parameter(torch.rand(
            p, self.h1, out_features
        ))
        self.k = k
        self.p = p
        self.mutation_rate = mutation_rate
        # Registered as buffers so they follow .to()/.double()/.cuda() together
        # with the parameters and are saved in state_dict().
        # NOTE: this adds the keys 'best_w' and 'best_w2' to state_dict();
        # checkpoints written before this change need load_state_dict(strict=False).
        self.register_buffer('best_w', torch.rand(
            in_features, self.h1
        ))
        self.register_buffer('best_w2', torch.rand(
            self.h1, out_features
        ))

    @staticmethod
    def _maxmin(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        """Max-min composition of ``x`` (..., n) with ``w`` (..., n, m) -> (..., m)."""
        # x gains a trailing axis and w gains a batch axis so the element-wise
        # minimum broadcasts to (..., batch, n, m); the max then reduces over n.
        return torch.max(
            torch.min(x.unsqueeze(-1), w.unsqueeze(-3)), dim=-2, keepdim=False
        )[0]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Evaluate every candidate in the population.

        Args:
            x: Input of shape (batch, in_features).

        Returns:
            Tensor of shape (p, batch, out_features).
        """
        hidden = self._maxmin(x, self.w)
        return self._maxmin(hidden, self.w2)

    def best_forward(self, x: torch.Tensor) -> torch.Tensor:
        """Evaluate only the best candidate stored by the last :meth:`update`.

        Before the first update the best weights are random.

        Args:
            x: Input of shape (batch, in_features).

        Returns:
            Tensor of shape (batch, out_features).
        """
        hidden = self._maxmin(x, self.best_w)
        return self._maxmin(hidden, self.best_w2)

    @torch.no_grad()
    def update(self, assessment: torch.Tensor):
        """Select the fittest candidates and spawn the next population.

        The candidate with the lowest mean assessment is stored as the best
        candidate.  The ``k`` lowest-scoring candidates are averaged into a
        parent that occupies slot 0 of the new population; the remaining
        ``p - 1`` slots are copies of the parent in which each weight is
        resampled uniformly with probability ``mutation_rate``.

        Args:
            assessment: Per-candidate errors (lower is better) whose leading
                dimension is the population, e.g. (p, batch, out_features).
        """
        # pop, batch, out -> one scalar score per candidate
        scores = assessment.reshape(self.p, -1).mean(dim=1)
        elite_idx = scores.topk(
            self.k, dim=0, largest=False
        )[1]
        best_idx = scores.argmin(0)
        self.best_w = self.w[best_idx].detach().clone()
        self.best_w2 = self.w2[best_idx].detach().clone()

        parent_w = self.w[elite_idx].mean(dim=0, keepdim=True)
        parent_w2 = self.w2[elite_idx].mean(dim=0, keepdim=True)

        n_spawn = self.p - 1
        shape_w = (n_spawn, parent_w.shape[1], parent_w.shape[2])
        shape_w2 = (n_spawn, parent_w2.shape[1], parent_w2.shape[2])
        # Draw on the weights' own device/dtype so the module also works after
        # being moved off the default device.
        opts_w = dict(device=parent_w.device, dtype=parent_w.dtype)
        opts_w2 = dict(device=parent_w2.device, dtype=parent_w2.dtype)

        keep_w = torch.rand(shape_w, **opts_w) > self.mutation_rate
        keep_w2 = torch.rand(shape_w2, **opts_w2) > self.mutation_rate

        noise_w = torch.rand(shape_w, **opts_w)
        noise_w2 = torch.rand(shape_w2, **opts_w2)

        # Kept entries inherit the parent value, the rest take the fresh sample.
        spawned_w = torch.where(keep_w, parent_w, noise_w)
        spawned_w2 = torch.where(keep_w2, parent_w2, noise_w2)

        # Swap the underlying storage instead of creating new Parameter objects
        # so optimizers that already hold references to w / w2 stay valid.
        self.w.data = torch.cat([parent_w, spawned_w])
        self.w2.data = torch.cat([parent_w2, spawned_w2])

In [24]:

def epoch_str(epoch, n_epochs):
    """Format a zero-based epoch index as a one-based ``Epoch i/n`` label."""
    current = epoch + 1
    return 'Epoch {}/{}'.format(current, n_epochs)


def optim(net: MaxMinES, X: torch.Tensor, T: torch.Tensor, X_VAL: torch.Tensor, T_VAL: torch.Tensor, n_epochs: int=10, batch_size: int=128, epoch_callback=None):
    """Train a population with Adam inside each epoch and evolve it between epochs.

    Per epoch: every mini-batch takes one Adam step on the mean squared error
    of the whole population, after which the parameters are passed through a
    clamp to [0, 1].  At the end of the epoch the population is scored on the
    validation set and ``net.update`` is called with the squared errors.

    Args:
        net: Model providing ``__call__``, ``best_forward`` and ``update``.
        X: Training inputs.
        T: Training targets.
        X_VAL: Validation inputs used to score the population.
        T_VAL: Validation targets.
        n_epochs: Number of epochs (one ``net.update`` call per epoch).
        batch_size: Mini-batch size for the training loader.
        epoch_callback: Optional zero-argument callable run after each epoch.
    """
    train_set = torch_data.TensorDataset(X, T)

    with tqdm.tqdm(total=n_epochs) as progress:
        for _ in range(n_epochs):
            # A new Adam instance is built every epoch, so its running moment
            # estimates restart after each net.update() call.
            optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
            population_losses = []
            best_losses = []

            batches = torch_data.DataLoader(train_set, shuffle=True, batch_size=batch_size)
            for inputs, targets in batches:
                optimizer.zero_grad()
                population_out = net(inputs)

                # Tracked only for reporting; no backward pass is run on it.
                best_out = net.best_forward(inputs)
                best_mse = ((best_out - targets) ** 2).mean()

                # targets gain a leading axis to broadcast over the population.
                population_mse = ((population_out - targets[None]) ** 2).mean()

                population_mse.backward()
                optimizer.step()
                # NOTE(review): presumably writes the clamped values back into
                # the parameters -- confirm against zenkai's apply_to_parameters.
                apply_to_parameters(
                    net.parameters(), lambda x: torch.clamp(x, 0.0, 1.0)
                )
                population_losses.append(population_mse.item())
                best_losses.append(best_mse.item())

            with torch.no_grad():
                val_out = net(X_VAL)
                val_sq_err = (val_out - T_VAL[None]) ** 2
                net.update(val_sq_err)

            if epoch_callback is not None:
                epoch_callback()
            progress.update(1)
            summary = {'loss': np.mean(population_losses), 'best': np.mean(best_losses)}
            progress.set_postfix(summary, refresh=True)

In [26]:
# Experiment: fit the two-layer MaxMinES population to targets produced by a
# single random max-min layer, then report the validation MSE of the best
# candidate.

# Tunable settings, collected in one place.
SEED = 0
POP_SIZE = 32
TOP_K = 4
N_EPOCHS = 2000
BATCH_SIZE = 128
in_features = 16
out_features = 8

# Seed torch's global RNG so the random tensors drawn below are the same on
# every run of this cell.
torch.manual_seed(SEED)


def maxmin_target(inputs: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Single-layer max-min composition: out[b, j] = max_i min(inputs[b, i], weight[i, j])."""
    return torch.max(torch.min(inputs[:, :, None], weight[None]), dim=-2)[0]


x = torch.rand(
    4000, in_features
)
x_val = torch.rand(
    500, in_features
)
w = torch.rand(
    in_features, out_features
)
t = maxmin_target(x, w)
t_val = maxmin_target(x_val, w)

mod = MaxMinES(POP_SIZE, in_features, out_features, TOP_K)

optim(
    mod, x, t, x_val, t_val, N_EPOCHS, BATCH_SIZE
)

# Evaluation only, so gradient tracking is switched off.
with torch.no_grad():
    y_val = mod.best_forward(x_val)
(y_val - t_val).pow(2).mean()
        


 47%|████▋     | 946/2000 [03:37<04:02,  4.35it/s, loss=0.00203, best=0.000796]


KeyboardInterrupt: 