In [443]:
%reload_ext autoreload
%autoreload 2
%matplotlib widget

In [444]:
import torch

In [445]:
import bnn.data
import bnn.functions
import bnn.layer
import bnn.loss
import bnn.network
import bnn.optimizer
import bnn.random

In [446]:
# Prefer the first CUDA device when one is available, otherwise fall back to CPU.
# NOTE(review): `device` is not used by the cells shown here; the networks and
# input tensors are never moved to it.
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda', index=0)

### Setup

In [447]:
# Forward (binarisation) function shared by both networks below.
# Other options available in bnn.functions.forward:
#   LayerMeanBinarise(), LayerMedianBinarise()
forward_func = bnn.functions.forward.SignBinarise()

In [448]:
# Backward (gradient-projection) function, used only by proj_TBNN; noproj_TBNN
# uses bnn.functions.backward.ActualGradient() instead.
# Only one backward function is constructed here: the previous version built
# SignTernarise() and then immediately overwrote it, which was dead code.
# Other options available in bnn.functions.backward:
#   SignTernarise()
#   StochasticTernarise()
#   LayerMeanStdTernarise(half_range_stds=0.5)
#   LayerQuantileTernarise(lo=0.25, hi=0.75)
backward_func = bnn.functions.backward.LayerQuantileSymmetricTernarise(prop_zero=0.33)

In [449]:
# Width of the network input and of its (single-unit) output.
INPUT_DIM, OUTPUT_DIM = 128, 1

In [450]:
NUM_LAYERS = 10  # number of TernBin layers; dims below has NUM_LAYERS + 1 entries
HIDDEN_DIM = INPUT_DIM  # hidden layers keep the input width

In [451]:
# Layer widths: input, (NUM_LAYERS - 1) hidden widths, output.
dims = [INPUT_DIM, *([HIDDEN_DIM] * (NUM_LAYERS - 1)), OUTPUT_DIM]

In [452]:
# Network whose backward pass goes through the projection `backward_func`.
proj_TBNN = bnn.network.TernBinNetwork(
    dims,
    backward_func=backward_func,
    forward_func=forward_func,
)

In [453]:
# Reference network: same architecture and forward function, but the backward
# pass uses ActualGradient() rather than a projection.
noproj_TBNN = bnn.network.TernBinNetwork(
    dims,
    backward_func=bnn.functions.backward.ActualGradient(),
    forward_func=forward_func,
)

### Helper functions

In [454]:
def copy_weights(source: bnn.network.TernBinNetwork, target: bnn.network.TernBinNetwork):
    """Copy each layer's weight tensor ``W`` from ``source`` into ``target``.

    Layers are matched by name: every key of ``source.layers`` is looked up in
    ``target.layers``, so a name missing from ``target`` fails on that lookup.

    The weights are cloned. Assigning ``layer.W.data`` directly would make both
    networks share one storage, so any in-place update of one network's weights
    would silently change the other's as well. It would also make
    ``assert_same_weights`` pass trivially.
    """
    for name, layer in source.layers.items():
        target.layers[name].W.data = layer.W.data.clone()

In [455]:
def assert_same_weights(source: bnn.network.TernBinNetwork, target: bnn.network.TernBinNetwork):
    """Assert that ``source`` and ``target`` have identical weights layer by layer.

    Checks that both networks expose the same set of layer names, and that each
    layer's ``W`` tensor is equal (``torch.equal``: same shape and values).

    Raises:
        AssertionError: if the layer names differ, or if any layer's weights
            differ. The message names the offending layer(s).
    """
    # Iterating `.layers` yields layer names (the same pattern grad_sign_confmats uses).
    source_names = set(source.layers)
    target_names = set(target.layers)
    assert source_names == target_names, (
        f"layer names differ: only in source={sorted(source_names - target_names)}, "
        f"only in target={sorted(target_names - source_names)}"
    )
    for name, layer in source.layers.items():
        assert torch.equal(target.layers[name].W.data, layer.W.data), (
            f"weights differ for layer {name!r}"
        )

In [456]:
def grad_sign_confmats(
    input: torch.Tensor,
    grad: torch.Tensor,
    net1: bnn.network.TernBinNetwork,
    net2: bnn.network.TernBinNetwork,
) -> dict[str, torch.Tensor]:
    """Compare per-layer gradient signs of two networks on the same data.

    Calls ``forward(input)`` on ``net1`` then ``net2``, followed by
    ``backward(grad)`` on ``net1`` then ``net2``. It then builds, for every layer
    name in ``net1.layers``, the 3x3 sign confusion matrix
    ``_grad_sign_confmat(net1.grad[name], net2.grad[name])``. In that matrix the
    first index is the sign in ``net1`` and the second index is the sign in ``net2``.

    Returns:
        dict mapping layer name -> normalised 3x3 confusion matrix.
    """
    # Both forward passes run before either backward pass.
    for net in (net1, net2):
        net.forward(input)
    for net in (net1, net2):
        net.backward(grad)

    return {
        name: _grad_sign_confmat(net1.grad[name], net2.grad[name])
        for name in net1.layers
    }


def _grad_sign_confmat(out1, out2):
    sign_out1 = torch.sign(out1)
    sign_out2 = torch.sign(out2)

    out = torch.empty(size=[3, 3])

    SYMBOLS = (-1, 0, 1)
    for i in SYMBOLS:
        for j in SYMBOLS:
            out[i, j] = torch.sum(sign_out2[sign_out1 == i] == j)

    out /= out.sum()
    return out


In [457]:
def print_confmat(confmat: torch.Tensor):
    """Print a 3x3 sign confusion matrix as a tab-separated table.

    The table's columns are the FIRST index of ``confmat`` (headed
    "Actual grad sign") and its rows are the SECOND index. Both are shown in
    sign order 0, 1, -1, i.e. tensor positions 0, 1, 2 via negative indexing.
    The last column holds row totals, and the final "sum" row holds column
    totals plus the grand total.
    """
    sign_order = (0, 1, -1)
    rule = "----------------------------------\n"

    def _row(label, values, total):
        cells = " \t ".join(f"{value:.3f}" for value in values)
        return f"{label} |\t {cells} \t| {total:.3f}\n"

    body = "".join(
        _row(label, [confmat[col, row] for col in sign_order], confmat[:, row].sum())
        for label, row in zip((" 0 ", " 1 ", "-1 "), sign_order)
    )
    totals = _row("sum", [confmat[col, :].sum() for col in sign_order], confmat.sum())

    print(
        "\t Actual grad sign\n"
        "\t 0 \t 1 \t -1 \t sum\n"
        + rule
        + body
        + rule
        + totals
    )

In [458]:
def _sign_error_masses(confmat):
    """Split a 3x3 sign confmat into the masses the sign-error metrics need.

    The first index is the reference sign (net1 in grad_sign_confmats) and the
    second is the compared sign; position -1 is sign -1.

    Returns ``(nonzero, flipped, zeroed)``:
        nonzero -- mass where the reference sign is +1 or -1
        flipped -- reference sign nonzero, compared sign is the opposite nonzero sign
        zeroed  -- reference sign nonzero, compared sign is 0
    """
    ref_pos, ref_neg = confmat[1], confmat[-1]
    nonzero = ref_pos.sum() + ref_neg.sum()
    flipped = ref_pos[-1] + ref_neg[1]
    zeroed = ref_pos[0] + ref_neg[0]
    return nonzero, flipped, zeroed


def confmat_nonzero_sign_error(confmat) -> float:
    """Fraction of reference-nonzero elements whose compared sign is flipped.

    Assumes ``confmat`` is a floating tensor; a zero nonzero-mass then gives NaN.
    """
    nonzero, flipped, _ = _sign_error_masses(confmat)
    return (flipped / nonzero).item()


def confmat_nonzero_sign_error_inc0(confmat) -> float:
    """Fraction of reference-nonzero elements whose compared sign is flipped OR zero.

    Assumes ``confmat`` is a floating tensor; a zero nonzero-mass then gives NaN.
    """
    nonzero, flipped, zeroed = _sign_error_masses(confmat)
    return ((flipped + zeroed) / nonzero).item()

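### Initialise

Initialise the projected network's weights, then copy them into the unprojected network so that both start from identical weights.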

In [459]:
# Zero-probability for the weights initialised by `_initialise`.
# TODO(review): document why 0.99 * (1 - 1 / HIDDEN_DIM) is the "stable" choice.
stable_zero_prob = 0.99 * (1 - 1 / HIDDEN_DIM)
proj_TBNN._initialise(W_mean=0, W_zero_prob=stable_zero_prob)

# Re-initialise the last layer's weights with a different zero-probability.
list(proj_TBNN.layers.values())[-1]._initialise_W(
    mean=0,
    zero_prob=0.33,
)

In [460]:
# Give the unprojected network the same starting weights, then verify the copy.
copy_weights(target=noproj_TBNN, source=proj_TBNN)
assert_same_weights(target=noproj_TBNN, source=proj_TBNN)

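### Check gradients

Compare the per-layer gradient signs of the unprojected network (`noproj_TBNN`, the table columns headed "Actual grad sign") with those of the projected network (`proj_TBNN`, the table rows).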

In [461]:
NUM_SAMPLES = 1024

# Random binary inputs and upstream gradients with mean 0.
# TODO(review): confirm the value set produced by generate_random_binary_tensor.
input = bnn.random.generate_random_binary_tensor(
    shape=[NUM_SAMPLES, INPUT_DIM],
    mean=0,
)
grad = bnn.random.generate_random_binary_tensor(
    shape=[NUM_SAMPLES, OUTPUT_DIM],
    mean=0,
)

# net1 is the reference (ActualGradient) network; net2 is the projected one.
confmats = grad_sign_confmats(
    input=input,
    grad=grad,
    net1=noproj_TBNN,
    net2=proj_TBNN,
)

for name, confmat in confmats.items():
    print("layer: ", name)
    print_confmat(confmat)

layer:  TernBinLayer0
	 Actual grad sign
	 0 	 1 	 -1 	 sum
----------------------------------
 0  |	 0.558 	 0.067 	 0.070 	| 0.695
 1  |	 0.004 	 0.125 	 0.043 	| 0.172
-1  |	 0.000 	 0.035 	 0.098 	| 0.133
----------------------------------
sum |	 0.562 	 0.227 	 0.211 	| 1.000

layer:  TernBinLayer1
	 Actual grad sign
	 0 	 1 	 -1 	 sum
----------------------------------
 0  |	 0.569 	 0.055 	 0.063 	| 0.687
 1  |	 0.004 	 0.106 	 0.031 	| 0.141
-1  |	 0.012 	 0.020 	 0.141 	| 0.172
----------------------------------
sum |	 0.585 	 0.180 	 0.235 	| 1.000

layer:  TernBinLayer2
	 Actual grad sign
	 0 	 1 	 -1 	 sum
----------------------------------
 0  |	 0.569 	 0.043 	 0.047 	| 0.659
 1  |	 0.012 	 0.157 	 0.023 	| 0.192
-1  |	 0.004 	 0.027 	 0.117 	| 0.149
----------------------------------
sum |	 0.585 	 0.227 	 0.188 	| 1.000

layer:  TernBinLayer3
	 Actual grad sign
	 0 	 1 	 -1 	 sum
----------------------------------
 0  |	 0.577 	 0.051 	 0.082 	| 0.710
 1  |	 0.008 	 0.1

In [462]:
# Per-layer summary of sign disagreement, computed from the confmats above.
for name, confmat in confmats.items():
    print("layer: ", name)
    print(
        "nonzero wrong sign prop: "
        f"{confmat_nonzero_sign_error(confmat):.3f}"
    )
    print(
        "nonzero wrong sign (inc0) prop: "
        f"{confmat_nonzero_sign_error_inc0(confmat):.3f}"
    )

layer:  TernBinLayer0
nonzero wrong sign prop: 0.179
nonzero wrong sign (inc0) prop: 0.491
layer:  TernBinLayer1
nonzero wrong sign prop: 0.123
nonzero wrong sign (inc0) prop: 0.406
layer:  TernBinLayer2
nonzero wrong sign prop: 0.123
nonzero wrong sign (inc0) prop: 0.340
layer:  TernBinLayer3
nonzero wrong sign prop: 0.075
nonzero wrong sign (inc0) prop: 0.396
layer:  TernBinLayer4
nonzero wrong sign prop: 0.053
nonzero wrong sign (inc0) prop: 0.287
layer:  TernBinLayer5
nonzero wrong sign prop: 0.022
nonzero wrong sign (inc0) prop: 0.311
layer:  TernBinLayer6
nonzero wrong sign prop: 0.011
nonzero wrong sign (inc0) prop: 0.213
layer:  TernBinLayer7
nonzero wrong sign prop: 0.000
nonzero wrong sign (inc0) prop: 0.072
layer:  TernBinLayer8
nonzero wrong sign prop: 0.000
nonzero wrong sign (inc0) prop: 0.000
layer:  TernBinLayer9
nonzero wrong sign prop: 0.000
nonzero wrong sign (inc0) prop: 0.000
