In [1]:
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch.optim as optim
import torch
from poutyne import Model

In [2]:
import torch
torch.manual_seed(5321)
#torch.set_deterministic(True)

<torch._C.Generator at 0x296ca45fdb0>

In [3]:
# Dimensions of input and output
dim = 4
rs = np.random.RandomState(412112)
hidden_layers = 3
hidden_units = dim

# Learnable backward net (inverse rules of the game that we need to learn)
class BackwardNet(nn.Module):
    def __init__(self):
        super(BackwardNet, self).__init__()
        self.hidden_layers = [nn.Linear(dim if i == 0 else hidden_units, hidden_units) for i in range(hidden_layers)]
        self.output_layer = nn.Linear(hidden_units, dim)

    def forward(self, x):
        for layer in self.hidden_layers:
            x = F.relu(layer(x))
        x = torch.sigmoid(self.output_layer(x))
        return x

# Fixed forward net (rules of the game)
class ForwardNet(nn.Module):
    def __init__(self):
        super(ForwardNet, self).__init__()
        self.w1 = torch.tensor(rs.choice([1, -1.1], size=(dim, dim)), dtype=torch.float, requires_grad=False)
        self.b1 = torch.tensor(rs.choice([0, -0.5], size=dim), dtype=torch.float, requires_grad=False)
        self.w2 = torch.tensor(rs.choice([1, -1.1], size=(dim, dim)), dtype=torch.float, requires_grad=False)
        self.b2 = torch.tensor(rs.choice([0, -0.5], size=dim), dtype=torch.float, requires_grad=False)

    def forward(self, x):
        x = F.relu(x @ self.w1 + self.b1)
        x = torch.sigmoid(x @ self.w2 + self.b2)
        return x

class FullNet(nn.Module):
    def __init__(self):
        super(FullNet, self).__init__()
        self.backward_net = BackwardNet()
        self.forward_net = ForwardNet()

    def forward(self, x):
        x = self.backward_net.forward(x)
        x = self.forward_net.forward(x)
        return x

In [4]:
net = FullNet()
print(net)

criterion = nn.BCELoss()
optimizer = optim.Adam(net.parameters(), lr=0.01)

model = Model(net, optimizer, criterion)

m = 10
X = torch.tensor(rs.choice([1, 0], size=(m, dim)), dtype=torch.float)


FullNet(
  (backward_net): BackwardNet(
    (output_layer): Linear(in_features=4, out_features=4, bias=True)
  )
  (forward_net): ForwardNet()
)


In [5]:
# Example of how game rules make step foward (just to make sure our rules don't output all zeros or all ones etc.)
Y = net.forward_net(X).round()
print('X:')
print(X)
print('Step forward from X:')
print(Y)

X:
tensor([[1., 0., 1., 1.],
        [1., 1., 1., 1.],
        [1., 0., 0., 0.],
        [1., 0., 1., 1.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [1., 0., 0., 1.],
        [1., 0., 1., 1.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])
Step forward from X:
tensor([[1., 1., 1., 1.],
        [1., 0., 1., 1.],
        [1., 0., 0., 0.],
        [1., 1., 1., 1.],
        [1., 0., 0., 0.],
        [1., 1., 0., 1.],
        [1., 0., 1., 1.],
        [1., 1., 1., 1.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])


In [6]:
X_hat = model.predict(X).round()

print('X:')
print(X)

print('X_hat before training:')
print(X_hat)

mae_0 = torch.mean(torch.abs(X - X_hat))
print(f'MAE before training: {mae_0}')

X:
tensor([[1., 0., 1., 1.],
        [1., 1., 1., 1.],
        [1., 0., 0., 0.],
        [1., 0., 1., 1.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [1., 0., 0., 1.],
        [1., 0., 1., 1.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])
X_hat before training:
[[1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 0. 1. 1.]
 [1. 1. 1. 1.]
 [1. 0. 1. 1.]
 [1. 0. 1. 1.]
 [1. 0. 1. 1.]
 [1. 1. 1. 1.]
 [1. 0. 1. 1.]
 [1. 0. 1. 1.]]
MAE before training: 0.44999998807907104


In [7]:
model.fit(X, X, epochs=1000)

[93mEpoch: [94m1/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.13s [93mloss:[96m 0.691700[0m691700
[93mEpoch: [94m2/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.01s [93mloss:[96m 0.690110[0m690110
[93mEpoch: [94m3/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.01s [93mloss:[96m 0.688543[0m688543
[93mEpoch: [94m4/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.01s [93mloss:[96m 0.687000[0m687000
[93mEpoch: [94m5/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.685481[0m685481
[93mEpoch: [94m6/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.683987[0m683987
[93mEpoch: [94m7/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.682518[0m682518
[93mEpoch: [94m8/1

[93mEpoch: [94m59/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.642796[0m642796
[93mEpoch: [94m60/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.642585[0m642585
[93mEpoch: [94m61/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.642378[0m642378
[93mEpoch: [94m62/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.642174[0m642174
[93mEpoch: [94m63/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.641972[0m641972
[93mEpoch: [94m64/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.641773[0m641773
[93mEpoch: [94m65/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.641575[0m641575
[93mEpoch: 

[93mEpoch: [94m116/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.631474[0m631474
[93mEpoch: [94m117/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.631273[0m631273
[93mEpoch: [94m118/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.631071[0m631071
[93mEpoch: [94m119/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.630870[0m630870
[93mEpoch: [94m120/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.630669[0m630669
[93mEpoch: [94m121/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.630467[0m630467
[93mEpoch: [94m122/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.630265[0m630265
[93mE

[93mEpoch: [94m173/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.620006[0m620006
[93mEpoch: [94m174/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.619806[0m619806
[93mEpoch: [94m175/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.619606[0m619606
[93mEpoch: [94m176/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.619406[0m619406
[93mEpoch: [94m177/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.619206[0m619206
[93mEpoch: [94m178/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.619007[0m619007
[93mEpoch: [94m179/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.618807[0m618807
[93mE

[93mEpoch: [94m230/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.608765[0m608765
[93mEpoch: [94m231/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.608571[0m608571
[93mEpoch: [94m232/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.608377[0m608377
[93mEpoch: [94m233/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.608183[0m608183
[93mEpoch: [94m234/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.607990[0m607990
[93mEpoch: [94m235/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.607796[0m607796
[93mEpoch: [94m236/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.607603[0m607603
[93mE

[93mEpoch: [94m287/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.597920[0m597920
[93mEpoch: [94m288/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.597734[0m597734
[93mEpoch: [94m289/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.597547[0m597547
[93mEpoch: [94m290/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.597361[0m597361
[93mEpoch: [94m291/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.597175[0m597175
[93mEpoch: [94m292/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.596990[0m596990
[93mEpoch: [94m293/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.596804[0m596804
[93mE

[93mEpoch: [94m344/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.587532[0m587532
[93mEpoch: [94m345/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.587354[0m587354
[93mEpoch: [94m346/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.587176[0m587176
[93mEpoch: [94m347/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.586998[0m586998
[93mEpoch: [94m348/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.586821[0m586821
[93mEpoch: [94m349/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.586644[0m586644
[93mEpoch: [94m350/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.586467[0m586467
[93mE

[93mEpoch: [94m401/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.577647[0m577647
[93mEpoch: [94m402/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.577479[0m577479
[93mEpoch: [94m403/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.577310[0m577310
[93mEpoch: [94m404/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.577142[0m577142
[93mEpoch: [94m405/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.576973[0m576973
[93mEpoch: [94m406/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.576806[0m576806
[93mEpoch: [94m407/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.576638[0m576638
[93mE

[93mEpoch: [94m458/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.568301[0m568301
[93mEpoch: [94m459/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.568142[0m568142
[93mEpoch: [94m460/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.567983[0m567983
[93mEpoch: [94m461/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.567824[0m567824
[93mEpoch: [94m462/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.567665[0m567665
[93mEpoch: [94m463/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.567507[0m567507
[93mEpoch: [94m464/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.567349[0m567349
[93mE

[93mEpoch: [94m515/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.559502[0m559502
[93mEpoch: [94m516/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.559352[0m559352
[93mEpoch: [94m517/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.559203[0m559203
[93mEpoch: [94m518/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.559054[0m559054
[93mEpoch: [94m519/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.558905[0m558905
[93mEpoch: [94m520/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.558756[0m558756
[93mEpoch: [94m521/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.558607[0m558607
[93mE

[93mEpoch: [94m572/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.551310[0m551310
[93mEpoch: [94m573/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.551267[0m551267
[93mEpoch: [94m574/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.551039[0m551039
[93mEpoch: [94m575/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.550909[0m550909
[93mEpoch: [94m576/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.550778[0m550778
[93mEpoch: [94m577/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.550647[0m550647
[93mEpoch: [94m578/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.550516[0m550516
[93mE

[93mEpoch: [94m629/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.543724[0m543724
[93mEpoch: [94m630/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.543732[0m543732
[93mEpoch: [94m631/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.543474[0m543474
[93mEpoch: [94m632/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.543357[0m543357
[93mEpoch: [94m633/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.543238[0m543238
[93mEpoch: [94m634/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.543118[0m543118
[93mEpoch: [94m635/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.542996[0m542996
[93mE

[93mEpoch: [94m686/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.536730[0m536730
[93mEpoch: [94m687/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.536604[0m536604
[93mEpoch: [94m688/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.536478[0m536478
[93mEpoch: [94m689/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.536405[0m536405
[93mEpoch: [94m690/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.536248[0m536248
[93mEpoch: [94m691/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.536142[0m536142
[93mEpoch: [94m692/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.536034[0m536034
[93mE

[93mEpoch: [94m743/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.530272[0m530272
[93mEpoch: [94m744/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.530159[0m530159
[93mEpoch: [94m745/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.530043[0m530043
[93mEpoch: [94m746/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.529926[0m529926
[93mEpoch: [94m747/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.529809[0m529809
[93mEpoch: [94m748/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.529690[0m529690
[93mEpoch: [94m749/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.529683[0m529683
[93mE

[93mEpoch: [94m800/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.524273[0m524273
[93mEpoch: [94m801/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.524181[0m524181
[93mEpoch: [94m802/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.524085[0m524085
[93mEpoch: [94m803/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.523986[0m523986
[93mEpoch: [94m804/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.523884[0m523884
[93mEpoch: [94m805/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.523780[0m523780
[93mEpoch: [94m806/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.523673[0m523673
[93mE

[93mEpoch: [94m857/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.518762[0m518762
[93mEpoch: [94m858/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.518659[0m518659
[93mEpoch: [94m859/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.518569[0m518569
[93mEpoch: [94m860/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.518480[0m518480
[93mEpoch: [94m861/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.518403[0m518403
[93mEpoch: [94m862/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.518320[0m518320
[93mEpoch: [94m863/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.518232[0m518232
[93mE

[93mEpoch: [94m914/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.513665[0m513665
[93mEpoch: [94m915/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.513592[0m513592
[93mEpoch: [94m916/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.513513[0m513513
[93mEpoch: [94m917/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.513428[0m513428
[93mEpoch: [94m918/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.513339[0m513339
[93mEpoch: [94m919/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.513246[0m513246
[93mEpoch: [94m920/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.513149[0m513149
[93mE

[93mEpoch: [94m971/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.509147[0m509147
[93mEpoch: [94m972/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.509051[0m509051
[93mEpoch: [94m973/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.508983[0m508983
[93mEpoch: [94m974/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.509021[0m509021
[93mEpoch: [94m975/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.508823[0m508823
[93mEpoch: [94m976/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.508777[0m508777
[93mEpoch: [94m977/1000 [93mStep: [94m1/1 [93m100.00% |[92m█████████████████████████[93m|[32m0.00s [93mloss:[96m 0.508719[0m508719
[93mE

[{'epoch': 1, 'loss': 0.691700279712677, 'time': 0.1321707},
 {'epoch': 2, 'loss': 0.6901101469993591, 'time': 0.008543300000000004},
 {'epoch': 3, 'loss': 0.6885432004928589, 'time': 0.005998699999999996},
 {'epoch': 4, 'loss': 0.687000036239624, 'time': 0.006149099999999991},
 {'epoch': 5, 'loss': 0.6854811310768127, 'time': 0.004104099999999999},
 {'epoch': 6, 'loss': 0.6839872598648071, 'time': 0.004164399999999985},
 {'epoch': 7, 'loss': 0.682518482208252, 'time': 0.0030739000000000183},
 {'epoch': 8, 'loss': 0.6810759902000427, 'time': 0.0030941000000000163},
 {'epoch': 9, 'loss': 0.6796598434448242, 'time': 0.003364800000000001},
 {'epoch': 10, 'loss': 0.6782705187797546, 'time': 0.003960899999999989},
 {'epoch': 11, 'loss': 0.6769083738327026, 'time': 0.002792799999999984},
 {'epoch': 12, 'loss': 0.6755732297897339, 'time': 0.003710700000000011},
 {'epoch': 13, 'loss': 0.6742659211158752, 'time': 0.0030347000000000013},
 {'epoch': 14, 'loss': 0.6729861497879028, 'time': 0.00362

In [8]:
model.predict(X)

array([[0.8070521 , 0.17163771, 0.71727157, 0.71727157],
       [0.7803937 , 0.19864918, 0.6830799 , 0.6830799 ],
       [0.616763  , 0.37205413, 0.49395812, 0.49395812],
       [0.8070521 , 0.17163771, 0.71727157, 0.71727157],
       [0.53055763, 0.46639538, 0.40670177, 0.40670177],
       [0.5482085 , 0.44700527, 0.4239533 , 0.4239533 ],
       [0.69649   , 0.28624275, 0.5819148 , 0.5819148 ],
       [0.8070521 , 0.17163771, 0.71727157, 0.71727157],
       [0.50005037, 0.49994457, 0.37758803, 0.37758803],
       [0.50005037, 0.49994457, 0.37758803, 0.37758803]], dtype=float32)

In [9]:
X_hat = model.predict(X).round()
mae_1 = torch.mean(torch.abs(X - X_hat))
print(f'MAE after training: {mae_1} (improvement: {mae_1 - mae_0})')


print('X:')
print(X)

print('X_hat after training:')
print(X_hat)


MAE after training: 0.20000000298023224 (improvement: -0.2499999850988388)
X:
tensor([[1., 0., 1., 1.],
        [1., 1., 1., 1.],
        [1., 0., 0., 0.],
        [1., 0., 1., 1.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [1., 0., 0., 1.],
        [1., 0., 1., 1.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])
X_hat after training:
[[1. 0. 1. 1.]
 [1. 0. 1. 1.]
 [1. 0. 0. 0.]
 [1. 0. 1. 1.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 1. 1.]
 [1. 0. 1. 1.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]]
