In [18]:
import numpy as np
import torch
from torch import nn
from tropical import Tropical, TropicalMonomial, TropicalPolynomial, convert_net_to_tropical

In [3]:
# Automatically reload edited modules (e.g. `tropical`) before each cell is executed.
%load_ext autoreload
%autoreload 2

In [4]:
def to_tensor(x):
    """Copy ``x`` (a list, scalar or NumPy array) into a new float32 torch tensor."""
    converted = torch.tensor(x)
    return converted.to(torch.float32)

Let's create a simple network with random integer weights and real-valued biases.

In [29]:
# Layer widths: layer i maps inp_size[i] -> out_size[i].
inp_size = [3,3]
out_size = [3,2]
assert len(inp_size) == len(out_size), "need one (inp, out) width pair per layer"
assert all(n_out == n_in for n_out, n_in in zip(out_size[:-1], inp_size[1:])), \
    "each layer's output width must equal the next layer's input width"

# Seed NumPy's global RNG so the weights, biases and the sample point drawn later
# are reproducible under Restart & Run All.
SEED = 0
np.random.seed(SEED)

# NOTE(review): `t` is not used in the visible cells; remove it if nothing later needs it.
t = Tropical(0)

# Integer weights from [-10, 10) (randint's upper bound is exclusive), one
# (out, in) matrix per layer; real-valued biases drawn from a standard normal.
A = [np.random.randint(-10, 10, size=(n_out, n_in)) for n_in, n_out in zip(inp_size, out_size)]

b = [np.random.randn(n_out) for n_out in out_size]

In [30]:
class Net(torch.nn.Module):
    """Fully connected ReLU network whose parameters are copied from given arrays.

    Layer ``i`` is ``nn.Linear(inp_size[i], out_size[i])``. A ReLU follows every
    layer except the last one, so the final output is not clipped at zero.

    Args:
        inp_size: input width of each layer.
        out_size: output width of each layer (same length as ``inp_size``).
        bias: whether the linear layers have a bias term. When ``False``,
            ``biases`` (and the notebook-global ``b``) are never read.
        weights: sequence of per-layer weight arrays, each of shape
            ``(out_size[i], inp_size[i])``. Defaults to the notebook-global ``A``.
        biases: sequence of per-layer bias arrays, each of shape
            ``(out_size[i],)``. Defaults to the notebook-global ``b``.

    Raises:
        ValueError: if ``inp_size`` and ``out_size`` differ in length, fewer
            weight/bias arrays than layers are supplied, or an array's shape
            does not match its layer's parameter (``copy_`` would otherwise
            broadcast a mis-shaped array silently).
    """

    def __init__(self, inp_size, out_size, bias=True, weights=None, biases=None):
        super().__init__()
        if len(inp_size) != len(out_size):
            raise ValueError(
                f"inp_size and out_size must have the same length, "
                f"got {len(inp_size)} and {len(out_size)}"
            )

        self.linears = nn.ModuleList(
            nn.Linear(n_in, n_out, bias=bias) for n_in, n_out in zip(inp_size, out_size)
        )
        n_layers = len(self.linears)

        # Fall back to the notebook globals A / b (the original behaviour) only
        # when no arrays are passed, so explicit arguments work without them.
        if weights is None:
            weights = A
        if bias and biases is None:
            biases = b
        self._check_count(weights, n_layers, "weights")
        if bias:
            self._check_count(biases, n_layers, "biases")

        # These are initial parameter values, so copy them outside autograd.
        with torch.no_grad():
            for i, layer in enumerate(self.linears):
                layer.weight.copy_(
                    self._as_param_tensor(weights[i], layer.weight, f"weights[{i}]")
                )
                # layer.bias is None when bias=False; there is nothing to copy then.
                if layer.bias is not None:
                    layer.bias.copy_(
                        self._as_param_tensor(biases[i], layer.bias, f"biases[{i}]")
                    )

    @staticmethod
    def _check_count(arrays, n_layers, name):
        """Raise ValueError unless ``arrays`` provides at least one entry per layer."""
        if len(arrays) < n_layers:
            raise ValueError(f"expected {n_layers} {name} arrays, got {len(arrays)}")

    @staticmethod
    def _as_param_tensor(value, param, name):
        """Convert ``value`` to a tensor with ``param``'s dtype/device, checking its shape."""
        tensor = torch.as_tensor(value, dtype=param.dtype, device=param.device)
        if tensor.shape != param.shape:
            raise ValueError(
                f"{name} has shape {tuple(tensor.shape)}, expected {tuple(param.shape)}"
            )
        return tensor

    def forward(self, output):
        """Apply each linear layer in turn, with a ReLU between consecutive layers."""
        last = len(self.linears) - 1
        for i, layer in enumerate(self.linears):
            # Call the module (not layer.forward) so any registered hooks run.
            output = layer(output)
            if i < last:
                output = torch.relu(output)
        return output

In [31]:
# Build the network; its weights and biases are copied from the A and b arrays above.
model = Net(inp_size=inp_size, out_size=out_size)

In [32]:
# Display the module structure (layer widths and bias flags).
model

Net(
  (linears): ModuleList(
    (0): Linear(in_features=3, out_features=3, bias=True)
    (1): Linear(in_features=3, out_features=2, bias=True)
  )
)

In [33]:
# Print every layer's weight matrix (biases are not shown here).
for layer_weight in (layer.weight.data for layer in model.linears):
    print(layer_weight)

tensor([[-10.,  -3.,  -3.],
        [  0.,   3.,   0.],
        [ -3., -10.,   7.]])
tensor([[  6.,   1., -10.],
        [ -8.,   8.,   3.]])


In [34]:
# Random input point with coordinates in [0, 1): a plain Python list `x` for the
# tropical polynomials' evaluate(), and its float32 tensor form `x_t` for the network.
x = np.random.random(size=inp_size[0]).tolist()
x_t = to_tensor(x)

In [35]:
# Call the module itself rather than .forward() so that nn.Module hooks run.
out = model(x_t)

In [36]:
# Network output at x_t; the tropical evaluation at the end of the notebook should reproduce it.
out

tensor([ 2.4397, 16.8617], grad_fn=<AddBackward0>)

Convert this network into a difference of two tropical polynomial maps, `h` and `g` (one polynomial of each per network output).

In [37]:
# Decompose the network into two lists of tropical polynomials (here one entry per
# network output); the final cell checks h[i].evaluate(x) / g[i].evaluate(x) against `out`.
h,g = convert_net_to_tropical(model)

In [38]:
# Numerator polynomials, one per network output.
h

[8.672683462500572⨀a^30⨀b^100 ⨁ 8.22757063806057⨀a^30⨀b^103 ⨁ -0.3366626650094986⨀a^90⨀b^121⨀c^18 ⨁ 0.10845015943050385⨀a^90⨀b^118⨀c^18,
 -1.7880237102508545⨀a^89⨀b^54⨀c^24 ⨁ -5.348926305770874⨀a^89⨀b^78⨀c^24 ⨁ -2.0436301603913307⨀a^80⨀b^24⨀c^45 ⨁ -5.60453275591135⨀a^80⨀b^48⨀c^45]

In [39]:
# Denominator polynomials, one per network output.
g

[a^90⨀b^118⨀c^18 ⨁ -0.8520215004682541⨀a^60⨀b^18⨀c^88,
 11.418977737426758⨀a^9⨀b^30 ⨁ a^89⨀b^54⨀c^24]

In [40]:
# Each network output i should equal h[i].evaluate(x) / g[i].evaluate(x). Compute the
# ratio for every output instead of hard-coding indices 0 and 1, and make sure zip()
# cannot silently drop outputs.
# TODO: also assert numerically against `out` once evaluate()'s return type is
# confirmed to convert to float (e.g. via np.testing.assert_allclose).
assert len(h) == len(g) == out.numel(), "expected one (h, g) pair per network output"
tuple(h_i.evaluate(x) / g_i.evaluate(x) for h_i, g_i in zip(h, g))

(2.439659289677394, 16.861649331724323)