In [61]:
import numpy as np
from partition_decode.models import ReluNetClassifier
import torch

In [62]:
# Build the classifier with its default constructor arguments (none are passed).
net = ReluNetClassifier()

In [63]:
# Synthetic toy data: 100 points from a 2-D standard normal, with labels drawn
# uniformly from {0, 1} independently of X. A seeded generator is used so the
# cell produces the same data on every Restart & Run All.
rng = np.random.default_rng(0)
X = rng.normal(0, 1, (100, 2))
y = rng.choice(2, 100)

In [64]:
# Train on the toy data; whatever fit() returns is bound back to `net`.
# Training progress is printed as the cell output.
net = net.fit(X, y)

Results for epoch 1, bce_loss=0.69, 01_error=0.44


In [65]:
def get_polytopes(model, train_x, penultimate=False):
    """
    Assign every input point an integer ID for the activation region
    (polytope) it falls into.

    The forward pass is replayed in NumPy using only the ``torch.nn.Linear``
    layers of ``model``, in the order ``model.modules()`` yields them:

    * every layer but the last contributes one bit per unit
      (pre-activation > 0) and is followed by a ReLU gate;
    * the last layer is passed through a sigmoid and contributes one bit per
      unit (sigmoid output > 0.5).

    Points whose concatenated bit patterns are identical receive the same ID.

    Parameters
    ----------
    model : torch.nn.Module
        Network whose ``Linear`` layers are applied one after another.
    train_x : np.ndarray or torch.Tensor
        Input points, one row per sample.
    penultimate : bool, default=False
        Selects which captured activation array is returned (see the review
        note in the body).

    Returns
    -------
    polytope_memberships : np.ndarray of shape (n_samples,)
        Bit pattern of each sample packed into one integer (unit ``k`` of the
        concatenated pattern has weight ``2**k``). The dtype is int64 while
        the pattern fits in 63 bits; for wider networks an object array of
        Python ints is returned so that IDs never overflow.
    activations : np.ndarray
        Gated output of the last linear layer (sigmoid value times its bit).

    Raises
    ------
    ValueError
        If ``model`` contains no ``torch.nn.Linear`` layer.
    """
    # Accept tensors as well as arrays; everything below is plain NumPy.
    if isinstance(train_x, torch.Tensor):
        train_x = train_x.detach().cpu().numpy()
    last_activations = np.asarray(train_x)
    penultimate_act = None

    layers = [m for m in model.modules() if isinstance(m, torch.nn.Linear)]
    if not layers:
        raise ValueError("model contains no torch.nn.Linear layers")
    last_id = len(layers) - 1

    bit_blocks = []
    for layer_id, layer in enumerate(layers):
        weights = layer.weight.detach().cpu().numpy()
        preactivation = np.matmul(last_activations, weights.T)
        if layer.bias is not None:  # Linear(..., bias=False) has no bias
            preactivation = preactivation + layer.bias.detach().cpu().numpy()

        if layer_id == last_id:
            # A single sigmoid: applying it twice maps every value into
            # (0.5, 0.74), which would make the 0.5 threshold always true.
            preactivation = 1 / (1 + np.exp(-preactivation))
            binary_preactivation = (preactivation > 0.5).astype("int")
        else:
            binary_preactivation = (preactivation > 0).astype("int")
        bit_blocks.append(binary_preactivation)
        last_activations = preactivation * binary_preactivation

        # NOTE(review): despite the name, this captures the activations of the
        # LAST linear layer, i.e. the same values as the default return path.
        # Left unchanged so existing callers keep receiving the same array;
        # confirm whether `len(layers) - 2` was intended.
        if penultimate and layer_id == last_id:
            penultimate_act = last_activations

    patterns = np.concatenate(bit_blocks, axis=1)
    n_bits = patterns.shape[1]
    if n_bits <= 63:
        # Largest weight is 2**62, so the packed ID fits in int64.
        powers = 2 ** np.arange(n_bits, dtype=np.int64)
        polytope_memberships = np.dot(patterns.astype(np.int64), powers)
    else:
        # 2**k overflows int64 for k >= 63 and would make distinct patterns
        # collide; fall back to arbitrary-precision Python ints.
        powers = np.array([1 << k for k in range(n_bits)], dtype=object)
        polytope_memberships = np.dot(patterns.astype(object), powers)

    if penultimate:
        return polytope_memberships, penultimate_act
    return np.asarray(polytope_memberships), last_activations

def get_internal_representation(model, X, penultimate=True, n_splits=50):
    """
    Return the binary ReLU activation pattern of every sample in ``X``.

    The input is pushed through the leaf modules of ``model`` one after
    another, in the order ``model.modules()`` yields them. After each
    ``torch.nn.ReLU`` the output is binarized (> 0) and recorded.

    Container modules are skipped on purpose: ``model.modules()`` yields the
    model itself first, and calling it would run the whole network once and
    then feed its *output* back into the first layer.

    Parameters
    ----------
    model : torch.nn.Module
        Assumed to be a plain feed-forward chain, i.e. applying its leaf
        modules sequentially reproduces the forward pass.
    X : array-like
        Input samples, one row per sample.
    penultimate : bool, default=True
        If True, keep only the pattern of the last ReLU; otherwise the
        patterns of all ReLUs are stacked column-wise.
    n_splits : int, default=50
        Number of chunks ``X`` is split into with ``np.array_split`` (this is
        a chunk count, not a batch size).

    Returns
    -------
    np.ndarray
        Integer 0/1 matrix with one row per sample.

    Raises
    ------
    ValueError
        If the model contains no ``torch.nn.ReLU`` module.
    """
    leaf_modules = [m for m in model.modules() if not list(m.children())]

    irm = []
    # No gradients are needed; this replaces the deprecated Variable wrapper
    # and the per-module detach() calls.
    with torch.no_grad():
        for batch in np.array_split(np.asarray(X), n_splits):
            batch_irm = []
            x_pred = torch.from_numpy(batch).float()
            for module in leaf_modules:
                x_pred = module(x_pred)
                if isinstance(module, torch.nn.ReLU):
                    batch_irm.append((x_pred.cpu().numpy() > 0).astype(int))
            if not batch_irm:
                raise ValueError("model contains no torch.nn.ReLU modules")
            if penultimate:
                irm.append(batch_irm[-1])
            else:
                irm.append(np.hstack(batch_irm))

    return np.vstack(irm)

In [66]:
# penultimate=False: keep the binarized outputs of every ReLU, stacked
# column-wise, instead of only the last one.
irm = get_internal_representation(net.model_, X, penultimate=False)

In [67]:
# Sanity check: one row per sample, one column per recorded ReLU unit.
irm.shape

(100, 200)

In [25]:
# No cell above assigns `polytope_memberships` at notebook level, so on a
# fresh kernel this cell would raise NameError. Compute it here first.
polytope_memberships, last_activations = get_polytopes(net.model_, X)
polytope_memberships.shape

(100,)

In [26]:
# No cell above assigns `last_activations` at notebook level, so on a fresh
# kernel this cell would raise NameError. Compute it here first.
polytope_memberships, last_activations = get_polytopes(net.model_, X)
last_activations.shape

(100, 2)

In [30]:
# Materialize the module iterator so individual entries can be indexed.
# Per the PyTorch docs, modules() yields the network itself as well as its
# sub-modules.
mods = list(net.model_.modules())

In [39]:
# Inspect the class of the third entry in the module list.
type(mods[2])

torch.nn.modules.activation.ReLU

In [70]:
# parameters() returns a lazy generator, so only its repr is displayed here;
# wrap it in list(...) to look at the tensors themselves.
net.model_.parameters()

<generator object Module.parameters at 0x7f2c645f7510>