In [1]:
import numpy as np
import torch
from torch import nn
from tropical import Tropical, TropicalMonomial, TropicalPolynomial

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
def to_tensor(x):
    """Build a new torch tensor from array-like `x` and cast it to float32."""
    as_given = torch.tensor(x)
    return as_given.to(torch.float32)

In [4]:
# Layer k of the network maps inp_size[k] features to out_size[k] features.
inp_size = [3,3]
out_size = [3,2]

# Tropical constant read as a global by convert_net_to_tropical.
t = Tropical(0)

# One random integer weight matrix per layer, drawn from [-10, 10) and laid out
# as (out_features, in_features). NOTE: no seed is set, so values change per run.
A = [np.random.randint(-10, 10, size=(n_out, n_in)) for n_in, n_out in zip(inp_size, out_size)]

# One standard-normal bias vector per layer, one entry per output feature.
b = [np.random.randn(weight.shape[0]) for weight in A]

In [5]:
class Net(torch.nn.Module):
    """Feed-forward stack of Linear + ReLU blocks with preset parameters.

    Args:
        inp_size: sequence of input feature counts, one per layer.
        out_size: sequence of output feature counts, one per layer.
        weights: optional sequence of array-likes, one per layer, each shaped
            (out_size[k], inp_size[k]). Defaults to the notebook-level `A`.
        biases: optional sequence of array-likes, one per layer, each with
            out_size[k] entries. Defaults to the notebook-level `b`.
    """

    def __init__(self, inp_size, out_size, weights=None, biases=None):
        super().__init__()

        # Fall back to the notebook globals so Net(inp_size, out_size) keeps
        # behaving exactly as before.
        if weights is None:
            weights = A
        if biases is None:
            biases = b

        self.linears = nn.ModuleList([
            nn.Sequential(
                nn.Linear(inp_size[i], out_size[i]),
                nn.ReLU()
            ) for i in range(len(inp_size))])

        # Overwrite the random initialisation with the supplied parameters;
        # no_grad keeps the in-place copy out of the autograd graph.
        with torch.no_grad():
            for i, block in enumerate(self.linears):
                linear = block[0]
                linear.weight.copy_(torch.as_tensor(weights[i], dtype=linear.weight.dtype))
                linear.bias.copy_(torch.as_tensor(biases[i], dtype=linear.bias.dtype))

    def forward(self, output):
        """Apply every Linear + ReLU block in order and return the result."""
        for block in self.linears:
            # Call the module itself (not .forward) so registered hooks run.
            output = block(output)
        return output

In [6]:
# Instantiate the network with the layer sizes configured above.
model = Net(inp_size,out_size)

In [7]:
# Display the module tree to check the layer sizes.
model

Net(
  (linears): ModuleList(
    (0): Sequential(
      (0): Linear(in_features=3, out_features=3, bias=True)
      (1): ReLU()
    )
    (1): Sequential(
      (0): Linear(in_features=3, out_features=2, bias=True)
      (1): ReLU()
    )
  )
)

In [8]:
# Print the weight matrix of each Linear layer, in forward order.
for block in model.linears:
    linear = block[0]
    print(linear.weight.data)

tensor([[ 0., -4.,  4.],
        [ 6., -1., -7.],
        [ 0., -9., -1.]])
tensor([[-3.,  1., -1.],
        [-7.,  9.,  2.]])


In [9]:
def convert_net_to_tropical(net):
    """Build tropical polynomials describing a Linear + ReLU network.

    Walks `net.linears` in order. Each entry must be indexable so that
    `entry[0]` is the Linear layer providing `in_features`, `out_features`,
    `weight` and `bias`. For every layer with weight matrix `a` and bias `b_`,
    and for every output unit i, the polynomials of the previous layer (f, g)
    are combined as

        g_i = prod_j f[j]**max(-a[i][j], 0) * g[j]**max(a[i][j], 0)
        h_i = prod_j f[j]**max(a[i][j], 0) * g[j]**max(-a[i][j], 0) * Tropical(b_[i])
        f_i = h_i + g_i * t

    where products, powers and sums are the operators defined by the tropical
    classes, and `t` is the notebook-level Tropical constant.

    Args:
        net: network exposing a `linears` sequence as described above.

    Returns:
        Tuple (f, g, h) of lists with one polynomial per output unit of the
        last layer.
    """
    d = net.linears[0][0].in_features

    # Starting point for the recursion: f[i] is a single monomial with
    # coefficient 0 and exponent 1 on variable i; g[i] is a single monomial
    # with coefficient 0 and all exponents 0.
    f = [TropicalPolynomial([[0]+np.eye(d)[i].tolist()]) for i in range(d)]
    g = [TropicalPolynomial([[0]+np.zeros(d).tolist()]) for i in range(d)]

    for block in net.linears:
        linear = block[0]
        n = linear.in_features
        m = linear.out_features
        a = linear.weight.detach().cpu().numpy()
        # Split the weights into non-negative positive and negative parts so
        # that every exponent used below is >= 0.
        a_plus = np.maximum(a,0)
        a_minus = np.maximum(-a,0)
        b_ = linear.bias.detach().cpu().numpy()

        new_g = []
        new_h = []
        new_f = []

        for i in range(m):
            g_i = None
            h_i = None
            for j in range(n):
                # The first factor seeds the product; later ones are multiplied in.
                if g_i is None:
                    g_i = f[j]**a_minus[i][j]
                else:
                    g_i *= f[j]**a_minus[i][j]
                g_i *= g[j]**a_plus[i][j]

                if h_i is None:
                    h_i = f[j]**a_plus[i][j]
                else:
                    h_i *= f[j]**a_plus[i][j]
                h_i *= g[j]**a_minus[i][j]

            # The bias enters as a tropical constant factor.
            h_i *= Tropical(b_[i])
            # `t` is the global Tropical constant defined alongside the weights.
            f_i = h_i+g_i*t

            new_g.append(g_i)
            new_h.append(h_i)
            new_f.append(f_i)

        f = new_f
        g = new_g
        h = new_h

    return f,g,h
        

In [13]:
# Random evaluation point with one coordinate per network input, each drawn
# from [0, 1). NOTE: no seed is set, so the point differs on every run.
x = np.random.random(inp_size[0]).tolist()
# Same point as a float tensor for the torch model; the plain list is kept for
# the tropical polynomial evaluation further down.
x_t = to_tensor(x)

In [14]:
# Reference output of the torch network at the test point. The module is
# called directly (rather than via .forward) so that any registered hooks run.
out = model(x_t)

In [15]:
# Network output at x; the tropical quotients computed below are compared to it.
out

tensor([ 3.0852, 19.5080], grad_fn=<ReluBackward0>)

In [10]:
# Tropical polynomials for the network: one (f, g, h) entry per output unit.
f,g,h = convert_net_to_tropical(model)

In [11]:
# Numerator polynomials, one per output unit of the last layer.
f

[1.0264563653618097⨀a^6.0⨀b^21.0⨀c ⨁ 1.0569193363189697⨀b^22.0⨀c^8.0 ⨁ 3.3155518770217896⨀b⨀c^19.0 ⨁ 0.6108950972557068⨀b^13.0⨀c^7.0 ⨁ 2.7046567797660828⨀b^10.0⨀c^20.0,
 2.2012282256036997⨀a^54.0⨀b^28.0 ⨁ 2.4753949642181396⨀b^37.0⨀c^63.0 ⨁ 0.9794380310922861⨀a^54.0⨀b^46.0⨀c^2.0 ⨁ 1.253604769706726⨀b^55.0⨀c^65.0 ⨁ 6.310865819454193⨀b^27.0⨀c^93.0]

In [12]:
# Denominator polynomials, one per output unit of the last layer.
g

[3.3155518770217896⨀b⨀c^19.0 ⨁ 0.6108950972557068⨀b^13.0⨀c^7.0 ⨁ 2.7046567797660828⨀b^10.0⨀c^20.0 ⨁ b^22.0⨀c^8.0,
 6.310865819454193⨀b^27.0⨀c^93.0 ⨁ b^55.0⨀c^65.0]

In [16]:
# Tropical quotient f_i / g_i evaluated at x for every output unit; these
# values should match the entries of the torch output `out`.
tuple(f_i.evaluate(x) / g_i.evaluate(x) for f_i, g_i in zip(f, g))

(3.085188564095496, 19.508027819695464)