In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim




In [21]:


def secret_function(x, l1, l2, l3):
    y = l1(x)
    y = F.relu(y)
    y = l2(y)
    y = torch.exp(2*y)
    y = l3(y)
    return y



In [6]:
x = torch.randn(100,100)

In [14]:
l1 = nn.Linear(100, 50)
l2 = nn.Linear(50, 50)
l3 = nn.Linear(50, 1)

In [22]:
y1 = secret_function(x, l1, l2, l3)

In [23]:
print(y1.min(), y1.max())

tensor(-0.3658, grad_fn=<MinBackward1>) tensor(0.9558, grad_fn=<MaxBackward1>)


In [47]:
class base_net(nn.Module):
    def __init__(self):
        super(base_net, self).__init__()
        self.l1 = nn.Linear(100,100)
        self.l2 = nn.Linear(100,100)
        self.l3 = nn.Linear(100,2)
        self.relu = nn.ReLU()
    def forward(self, x):
        y = self.l1(x)
        y = self.relu(y)
        y = self.l2(y)
        y = self.relu(y)
        out = self.l3(y)
        return out


In [53]:
x = torch.randn(1)

In [55]:
print(x.item())

0.6788895726203918


In [49]:



def custom_loss(output, target, m):
    estimation_loss = torch.mean((output[:,0] - target)**2)
    error_loss = torch.mean((output[:,1] - estimation_loss)**2)
    return (1-m) * estimation_loss + m * error_loss

In [57]:
# train the base net on random data using the custom loss:

base = base_net()
optimizer = optim.Adam(base.parameters(), lr=0.001)
for i in range(100000):
    x = torch.randn(100,100) + 2
    y = torch.sin(x)
    output = base(x)
    loss = custom_loss(output, y, 1/(i+0.00001))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if i % 100 == 0:
        print(f'*****************     {i}    *****************')
        print(loss.item())
        estimation_loss = torch.mean((output[:,0] - y[:,0])**2)
        error_loss = torch.mean((output[:,1] - estimation_loss)**2)
        print(f'estimation loss: {estimation_loss.item()}, error loss: {error_loss.item()}')

*****************     0    *****************
5436050243584.0
estimation loss: 7365.81640625, error loss: 54258288.0
*****************     100    *****************
298808.4375
estimation loss: 5397.2412109375, error loss: 29016094.0
*****************     200    *****************
821025.3125
estimation loss: 12583.6748046875, error loss: 157424736.0
*****************     300    *****************
373060.4375
estimation loss: 10499.0546875, error loss: 109091168.0
*****************     400    *****************
242557.546875
estimation loss: 9601.853515625, error loss: 90888256.0
*****************     500    *****************
55344.0
estimation loss: 5066.5283203125, error loss: 24915480.0
*****************     600    *****************
173362.6875
estimation loss: 9898.74609375, error loss: 96322520.0
*****************     700    *****************
86156.59375
estimation loss: 7403.02734375, error loss: 53470900.0
*****************     800    *****************
101288.3359375
estimation loss:

In [36]:
x = torch.randn(5,100) * (10 ** 0.5) + 5
y = base(x)

In [38]:
print(y[:,0])

tensor([-0.5553, -0.5436, -0.1546, -0.6904, -0.4450],
       grad_fn=<SelectBackward0>)


In [34]:
x = torch.randn(5,100) * (10 ** 0.5) + 5
y = secret_function(x, l1, l2, l3)
output = base(x)
print(y, output)
print()


tensor([[-93.2841],
        [ 33.3999],
        [152.1014],
        [212.0457],
        [ 96.0142]], grad_fn=<AddmmBackward0>) tensor([[ 206.8811, 1194.0159],
        [ 268.5529, 1550.6440],
        [ 261.1402, 1516.1925],
        [ 244.1931, 1411.7697],
        [ 252.1568, 1463.8871]], grad_fn=<AddmmBackward0>)
