In [1]:
import initialize
from mistify import _functional as F
import torch
import typing
import torch.nn as nn
from torch.utils import data as torch_data
from zenkai.utils import apply_to_parameters

from mistify.learn import WrapNeuron, MaxMinRelOut
from mistify import OrF, MulG
from mistify.infer import Or, Inter, UnionOn
import numpy as np
import tqdm

In [2]:
def create_input_dataset1(shape, dropout_p: float=0.5, exp: float=1.0, require_grad: bool=False):
    """Create a random tensor with a random subset of its entries zeroed out.

    Values are drawn uniformly from [0, 1), raised to the power ``exp`` and
    then multiplied by an independently drawn keep-mask.

    Args:
        shape: Size of the tensor to create. Either an iterable of ints or a
            single int (treated as a one-dimensional size).
        dropout_p (float): Threshold for the keep-mask. An entry is kept when
            its uniform draw is ``>= dropout_p`` and set to zero otherwise, so
            ``0.0`` keeps everything and any value ``>= 1.0`` zeroes everything.
        exp (float): Exponent applied to the uniform samples before masking.
        require_grad (bool): If True, the returned tensor is flagged as
            requiring gradients.

    Returns:
        torch.Tensor: The masked random tensor of size ``shape``.
    """
    # Accept a bare int as a convenience; any other iterable is normalised to a tuple.
    shape = (shape,) if isinstance(shape, int) else tuple(shape)

    # The values are drawn first and the mask second; keep this order so that
    # results under a fixed seed do not change.
    data = torch.rand(*shape) ** exp
    keep_mask = torch.rand(*shape) >= dropout_p
    data = data * keep_mask
    if require_grad:
        data.requires_grad_()
        data.retain_grad()
    return data



In [3]:

def epoch_str(epoch, n_epochs):
    """Build a one-based progress label from a zero-based epoch index.

    Args:
        epoch: Zero-based index of the current epoch.
        n_epochs: Total number of epochs.

    Returns:
        str: A label of the form ``'Epoch <epoch + 1>/<n_epochs>'``.
    """
    label = f'Epoch {epoch + 1}/{n_epochs}'
    return label


def optim(net: nn.Module, X: torch.Tensor, T: torch.Tensor, n_epochs: int=10, batch_size: int=128, epoch_callback=None, lr: float=1e-3):
    """Fit ``net`` to map ``X`` onto ``T`` with Adam and a mean-squared-error loss.

    After every optimizer step ``apply_to_parameters`` is called with a
    function that clamps a tensor to [0, 1].
    NOTE(review): presumably this keeps every parameter of ``net`` inside
    [0, 1] - confirm against ``zenkai.utils.apply_to_parameters``.

    Args:
        net (nn.Module): Model to train; it is updated in place.
        X (torch.Tensor): Inputs; indexed along the first dimension.
        T (torch.Tensor): Targets paired with ``X`` along the first dimension.
        n_epochs (int): Number of full passes over the data.
        batch_size (int): Mini-batch size used by the data loader.
        epoch_callback: Optional zero-argument callable invoked once at the
            end of every epoch.
        lr (float): Learning rate passed to Adam.

    Returns:
        None. Progress and the mean epoch loss are shown on a tqdm bar.
    """
    dataset = torch_data.TensorDataset(X, T)
    # Built once: with shuffle=True the loader draws a new permutation every
    # time it is iterated, so there is no need to re-create it per epoch.
    loader = torch_data.DataLoader(dataset, shuffle=True, batch_size=batch_size)
    # Named `optimizer` so it does not shadow this function's own name.
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)

    with tqdm.tqdm(total=n_epochs) as pbar:
        for _ in range(n_epochs):
            epoch_loss = []

            for x_i, t_i in loader:
                optimizer.zero_grad()
                y_i = net(x_i)
                loss = (y_i - t_i).pow(2).mean()
                loss.backward()
                optimizer.step()
                apply_to_parameters(
                    net.parameters(), lambda x: torch.clamp(x, 0.0, 1.0)
                )
                epoch_loss.append(loss.item())
            if epoch_callback is not None:
                epoch_callback()
            pbar.update(1)
            pbar.set_postfix({'loss': np.mean(epoch_loss)}, refresh=True)


In [3]:



# Experiment: fit a single `Or` layer to the outputs of a second `Or` layer.
# NOTE(review): no random seed is set in this cell, so results differ between runs.
# NOTE(review): the recorded output of this cell is a NameError for `optim`,
# so the cell defining `optim` must be executed before this one.

# Function object handed to the trainable layer below (built from mistify's
# WrapNeuron / OrF / MaxMinRelOut).
f = WrapNeuron(OrF('std', 'std'), MaxMinRelOut())

# Reference layer: its outputs on X are used as the regression targets T.
base = Or(16, 8)
X = torch.rand(1000, 16)
T = base(X).detach()


# Layer to be trained, constructed with the wrapped function `f`.
or_ = Or(16, 8, f=f)

# Stage 1: train for 1000 epochs with the wrapped function.
optim(or_, X, T, n_epochs=1000)
# Stage 2: replace the layer's `f` with a plain OrF and train for another 1000 epochs.
or_.f = OrF('std', 'std')
optim(or_, X, T, n_epochs=1000)


NameError: name 'optim' is not defined

In [13]:


# Fix the torch RNG so the data and the layers created below are repeatable.
torch.manual_seed(4)

# NOTE(review): this `base` is overwritten two lines below and never used.
# It is left in place because constructing it may consume random numbers and
# so affect everything created after it under the fixed seed - confirm before removing.
base = Or(16, 8)

# Targets: sigmoid of a randomly initialised linear map of X, detached from autograd.
base = nn.Linear(16, 8)
X = torch.rand(4000, 16)
T = torch.sigmoid(base(X)).detach()

class Net(nn.Module):
    """Two chained `Or` layers, `Or(16, 16)` followed by `Or(16, 8)`.

    Each layer gets its own wrapped function object, which is also stored on
    the module as `f1` / `f2` so it can be reached from outside the network.
    """

    def __init__(self):
        super().__init__()

        def make_f():
            # One independent function object per layer.
            return WrapNeuron(OrF('std', 'std'), MaxMinRelOut())

        self.f1 = make_f()
        self.f2 = make_f()
        self.or1 = Or(16, 16, f=self.f1)
        self.or2 = Or(16, 8, f=self.f2)

    def forward(self, x):
        """Pass `x` through `or1` and feed the result to `or2`."""
        hidden = self.or1(x)
        return self.or2(hidden)

net = Net()

# Module-level flag read by `callback` to decide which wrapped function it updates.
weight1 = True
# Initial x_weight / w_weight values on the two wrapped functions.
# NOTE(review): the meaning of x_weight / w_weight is defined by mistify's
# WrapNeuron and is not visible here - confirm.
net.f1.x_weight = 0.05
net.f1.w_weight = 0.
net.f2.x_weight = 0.0
net.f2.w_weight = 0.05
def callback():
    """Reset the x/w weights on one of `net`'s wrapped functions.

    Reads the module-level `weight1` flag and `net`:

    * `weight1` truthy: updates `net.f1`. `w_weight` becomes 0.1 only when
      `weight1` is exactly `True` (otherwise 0.0); `x_weight` becomes 0.0.
    * `weight1` falsy: updates `net.f2`. `x_weight` becomes 0.0; `w_weight`
      becomes 0.1 only when `weight1` is exactly `False` (otherwise 0.0).

    NOTE(review): nothing in this function changes `weight1`, so the same
    branch runs on every call unless the flag is modified elsewhere.
    """
    if weight1:
        wrapped = net.f1
        wrapped.w_weight = 0.1 if weight1 is True else 0.0
        # A truthy flag can never be `False`, so this side is always zero.
        wrapped.x_weight = 0.0
    else:
        wrapped = net.f2
        # A falsy flag can never be `True`, so this side is always zero.
        wrapped.x_weight = 0.0
        wrapped.w_weight = 0.1 if weight1 is False else 0.0


# Train for 500 epochs; `callback` is passed as `epoch_callback`.
optim(net, X, T, n_epochs=500, epoch_callback=callback)

# net.set_f(OrF('std', 'std'))
# # or_.f = 
# optim(net, X, T, n_epochs=500)
# net.set_f(f)
# optim(net, X, T, n_epochs=500)


 52%|█████▏    | 260/500 [00:23<00:22, 10.89it/s, loss=0.00135]


KeyboardInterrupt: 

In [5]:

# test g
# Fix the torch RNG so the data and the layers created below are repeatable.
torch.manual_seed(1)


# NOTE(review): this `base` is overwritten two lines below and never used.
# It is left in place because constructing it may consume random numbers and
# so affect X and the layers created after it under the fixed seed - confirm before removing.
base = nn.Linear(16, 8)
X = torch.rand(4000, 16)
# Targets: outputs of a reference `Or` layer on X, detached from autograd.
base = Or(16, 8, f=OrF('std', 'std'))
T = base(X).detach()
# Alternative target kept for reference (sigmoid of `base(X)`):
# T = torch.sigmoid(base(X)).detach()

class Net(nn.Module):
    """Two chained `Or` layers, `Or(16, 16)` followed by `Or(16, 8)`.

    Each layer is constructed with `f` set to a pair
    `(Inter(MulG(0.05)), UnionOn(MulG(0.1)))`.
    """

    def __init__(self):
        super().__init__()

        def make_f():
            # A fresh (Inter, UnionOn) pair for each layer.
            return (Inter(MulG(0.05)), UnionOn(MulG(0.1)))

        self.or1 = Or(16, 16, f=make_f())
        self.or2 = Or(16, 8, f=make_f())

    def set_f(self, f):
        """Assign the same `f` to both layers (`or1` first, then `or2`)."""
        for layer in (self.or1, self.or2):
            layer.f = f

    def forward(self, x):
        """Pass `x` through `or1` and feed the result to `or2`."""
        hidden = self.or1(x)
        return self.or2(hidden)

net = Net()

# Module-level flag read by `callback` below.
weight_x = True
# Disabled initial weights, kept for reference:
# net.f.x_weight = 0.05
# net.f.w_weight = 0.05
def callback():
    """Set `net.f.x_weight` / `net.f.w_weight` from the module-level `weight_x` flag.

    `x_weight` becomes 0.05 only when `weight_x` is exactly `True` and
    `w_weight` becomes 0.05 only when it is exactly `False`; each is 0.0 otherwise.

    NOTE(review): the `Net` class defined in this cell does not visibly assign
    an `f` attribute, and the `optim(...)` call in this cell does not pass this
    function as `epoch_callback` - confirm whether it is still needed.
    """
    x_value = 0.05 if weight_x is True else 0.0
    net.f.x_weight = x_value
    w_value = 0.05 if weight_x is False else 0.0
    net.f.w_weight = w_value

# Train for 500 epochs; no `epoch_callback` is passed, so `callback` is not invoked here.
optim(net, X, T, n_epochs=500)

 82%|████████▏ | 410/500 [00:53<00:11,  7.61it/s, loss=0.00272]


KeyboardInterrupt: 