In [1]:
from tqdm.notebook import tqdm
import random
import torch
from torch import nn
from torch import optim
import numpy as np

In [2]:
data = np.load("eu_d.npy", allow_pickle=True)
sq = data ** 2
prod = (data[:, 0] * data[:, 1])[:, None]
uod = (data[:, 0] / data[:, 1])[:, None]
u2od2 = uod ** 2
inv_d2 = sq[:, 1][:, None]
data = np.concatenate([data, sq, prod, inv_d2, uod, u2od2], axis=1)
np.random.shuffle(data)

In [3]:
data.shape

(3116150, 8)

In [4]:
class ResBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim),
            nn.LeakyReLU(),
            nn.BatchNorm1d(dim),
        )
        
    def forward(self, x):
        return x + self.net(x)

In [12]:
net = nn.Sequential(
    nn.Linear(8,16),
    nn.LeakyReLU(),
    nn.BatchNorm1d(16),
    ResBlock(16),
#     ResBlock(128),
#     ResBlock(128),
#     ResBlock(128),
#     ResBlock(128),
#     ResBlock(128),
#     ResBlock(128),
    ResBlock(16),
    nn.Linear(16,1),
    nn.Softplus(),
)

In [13]:
lr = 1e-2
gamma = 0.95

In [14]:
optimizer = optim.AdamW(net.parameters(), lr=lr)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)
criterion = nn.MSELoss()

In [15]:
train_loader = torch.utils.data.DataLoader(data[:-100000], batch_size=10000)

In [16]:
val_loader = torch.utils.data.DataLoader(data[-100000:], batch_size=10000)

In [17]:
bar_info = {}

In [19]:
for epoch in range(1000):
    bar_info.update({"lr": scheduler.get_last_lr()[0]})
    net.train()
    for u_d in (bar := tqdm(train_loader)):
#         eu_d = torch.rand_like(eu_d) * 10
        net.zero_grad()
        sqdiff = torch.tensor((u_d[:, 0] - u_d[:, 1]) / u_d[:, 1])[:, None].square()
        pred = net(u_d)
        loss = criterion(pred, sqdiff)
        loss.backward()
        optimizer.step()
        bar_info.update({"err": np.sqrt(loss.item())})
        bar.set_postfix(bar_info)
    scheduler.step()
    net.eval()
    err = []
    err_ratio = []
    with torch.no_grad():
        for u_d in val_loader:
#             eu_d = torch.rand_like(eu_d) * 10
            sqdiff = torch.tensor((u_d[:, 0] - u_d[:, 1]) / u_d[:, 1])[:, None].square()
            pred = net(u_d)
            err += torch.abs(pred - sqdiff).tolist()
            err_ratio += torch.abs((pred - sqdiff) / sqdiff).tolist()
        print(torch.cat([sqdiff, pred], dim=1)[:10, :])
    print(epoch, np.mean(err), np.median(err_ratio), np.mean(err_ratio))

  0%|          | 0/302 [00:00<?, ?it/s]

  sqdiff = torch.tensor((u_d[:, 0] - u_d[:, 1]) / u_d[:, 1])[:, None].square()
  sqdiff = torch.tensor((u_d[:, 0] - u_d[:, 1]) / u_d[:, 1])[:, None].square()


tensor([[0.0933, 0.0929],
        [0.1077, 0.1065],
        [0.0633, 0.0622],
        [0.0045, 0.0043],
        [0.0238, 0.0236],
        [0.1338, 0.1315],
        [0.0869, 0.0846],
        [0.0073, 0.0077],
        [0.1112, 0.1097],
        [0.0006, 0.0022]])
0 0.0012997072551349993 0.030516049824655056 5489.990356876527


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0930],
        [0.1077, 0.1064],
        [0.0633, 0.0622],
        [0.0045, 0.0043],
        [0.0238, 0.0235],
        [0.1338, 0.1314],
        [0.0869, 0.0846],
        [0.0073, 0.0077],
        [0.1112, 0.1096],
        [0.0006, 0.0022]])
1 0.0013044953651091783 0.03011482860893011 5473.580261356086


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0930],
        [0.1077, 0.1064],
        [0.0633, 0.0623],
        [0.0045, 0.0043],
        [0.0238, 0.0235],
        [0.1338, 0.1314],
        [0.0869, 0.0845],
        [0.0073, 0.0077],
        [0.1112, 0.1096],
        [0.0006, 0.0022]])
2 0.0012939745951461373 0.029807355254888535 5457.392224303472


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0930],
        [0.1077, 0.1064],
        [0.0633, 0.0623],
        [0.0045, 0.0043],
        [0.0238, 0.0236],
        [0.1338, 0.1315],
        [0.0869, 0.0846],
        [0.0073, 0.0077],
        [0.1112, 0.1096],
        [0.0006, 0.0022]])
3 0.0012748675263748737 0.02959490194916725 5439.179993868009


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0930],
        [0.1077, 0.1063],
        [0.0633, 0.0623],
        [0.0045, 0.0043],
        [0.0238, 0.0235],
        [0.1338, 0.1315],
        [0.0869, 0.0847],
        [0.0073, 0.0077],
        [0.1112, 0.1096],
        [0.0006, 0.0022]])
4 0.0012506226428743684 0.029541456140577793 5418.973149839561


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0930],
        [0.1077, 0.1063],
        [0.0633, 0.0623],
        [0.0045, 0.0043],
        [0.0238, 0.0235],
        [0.1338, 0.1316],
        [0.0869, 0.0849],
        [0.0073, 0.0077],
        [0.1112, 0.1096],
        [0.0006, 0.0022]])
5 0.0012270054479083046 0.029357106424868107 5397.449442989108


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0930],
        [0.1077, 0.1063],
        [0.0633, 0.0622],
        [0.0045, 0.0043],
        [0.0238, 0.0234],
        [0.1338, 0.1316],
        [0.0869, 0.0850],
        [0.0073, 0.0077],
        [0.1112, 0.1096],
        [0.0006, 0.0022]])
6 0.001214744281178573 0.029347763396799564 5370.932986956281


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0931],
        [0.1077, 0.1063],
        [0.0633, 0.0622],
        [0.0045, 0.0043],
        [0.0238, 0.0234],
        [0.1338, 0.1317],
        [0.0869, 0.0851],
        [0.0073, 0.0077],
        [0.1112, 0.1096],
        [0.0006, 0.0022]])
7 0.0012027266950788908 0.029208594001829624 5351.534496019575


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0931],
        [0.1077, 0.1064],
        [0.0633, 0.0622],
        [0.0045, 0.0043],
        [0.0238, 0.0234],
        [0.1338, 0.1318],
        [0.0869, 0.0853],
        [0.0073, 0.0077],
        [0.1112, 0.1097],
        [0.0006, 0.0022]])
8 0.0011975639496592339 0.029228288680315018 5331.442654029684


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0931],
        [0.1077, 0.1064],
        [0.0633, 0.0622],
        [0.0045, 0.0043],
        [0.0238, 0.0234],
        [0.1338, 0.1319],
        [0.0869, 0.0854],
        [0.0073, 0.0077],
        [0.1112, 0.1097],
        [0.0006, 0.0022]])
9 0.0011938426008907845 0.028833528980612755 5318.726803152328


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0932],
        [0.1077, 0.1065],
        [0.0633, 0.0622],
        [0.0045, 0.0043],
        [0.0238, 0.0234],
        [0.1338, 0.1320],
        [0.0869, 0.0855],
        [0.0073, 0.0077],
        [0.1112, 0.1098],
        [0.0006, 0.0022]])
10 0.0011949651484959758 0.029006371274590492 5298.2663745517875


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0932],
        [0.1077, 0.1065],
        [0.0633, 0.0623],
        [0.0045, 0.0043],
        [0.0238, 0.0233],
        [0.1338, 0.1321],
        [0.0869, 0.0855],
        [0.0073, 0.0077],
        [0.1112, 0.1098],
        [0.0006, 0.0022]])
11 0.001195645031364984 0.0289413845166564 5276.861337994657


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0932],
        [0.1077, 0.1066],
        [0.0633, 0.0623],
        [0.0045, 0.0043],
        [0.0238, 0.0234],
        [0.1338, 0.1323],
        [0.0869, 0.0856],
        [0.0073, 0.0077],
        [0.1112, 0.1099],
        [0.0006, 0.0022]])
12 0.0011920704624857172 0.028933883644640446 5256.007591278912


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0933],
        [0.1077, 0.1066],
        [0.0633, 0.0623],
        [0.0045, 0.0043],
        [0.0238, 0.0233],
        [0.1338, 0.1323],
        [0.0869, 0.0857],
        [0.0073, 0.0077],
        [0.1112, 0.1099],
        [0.0006, 0.0022]])
13 0.001190980221608188 0.02882008906453848 5236.19116273623


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0933],
        [0.1077, 0.1067],
        [0.0633, 0.0623],
        [0.0045, 0.0043],
        [0.0238, 0.0233],
        [0.1338, 0.1324],
        [0.0869, 0.0857],
        [0.0073, 0.0076],
        [0.1112, 0.1099],
        [0.0006, 0.0021]])
14 0.001188088993608835 0.028710167855024338 5214.184627089848


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0933],
        [0.1077, 0.1067],
        [0.0633, 0.0624],
        [0.0045, 0.0042],
        [0.0238, 0.0233],
        [0.1338, 0.1324],
        [0.0869, 0.0857],
        [0.0073, 0.0076],
        [0.1112, 0.1100],
        [0.0006, 0.0021]])
15 0.0011845992112995008 0.02910929825156927 5193.258508857762


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0933],
        [0.1077, 0.1067],
        [0.0633, 0.0624],
        [0.0045, 0.0042],
        [0.0238, 0.0232],
        [0.1338, 0.1324],
        [0.0869, 0.0857],
        [0.0073, 0.0076],
        [0.1112, 0.1100],
        [0.0006, 0.0021]])
16 0.001180929670196201 0.029271972365677357 5171.931771101667


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0931],
        [0.1077, 0.1063],
        [0.0633, 0.0623],
        [0.0045, 0.0042],
        [0.0238, 0.0232],
        [0.1338, 0.1319],
        [0.0869, 0.0855],
        [0.0073, 0.0076],
        [0.1112, 0.1096],
        [0.0006, 0.0021]])
17 0.0012030703587352764 0.029647070914506912 5133.972719334864


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0934],
        [0.1077, 0.1067],
        [0.0633, 0.0624],
        [0.0045, 0.0042],
        [0.0238, 0.0232],
        [0.1338, 0.1324],
        [0.0869, 0.0858],
        [0.0073, 0.0076],
        [0.1112, 0.1100],
        [0.0006, 0.0021]])
18 0.0011781171708769398 0.029681585729122162 5125.3576510661205


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0934],
        [0.1077, 0.1067],
        [0.0633, 0.0624],
        [0.0045, 0.0042],
        [0.0238, 0.0231],
        [0.1338, 0.1324],
        [0.0869, 0.0859],
        [0.0073, 0.0076],
        [0.1112, 0.1101],
        [0.0006, 0.0021]])
19 0.0011778796429448993 0.02979337517172098 5103.054855318614


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0934],
        [0.1077, 0.1067],
        [0.0633, 0.0624],
        [0.0045, 0.0041],
        [0.0238, 0.0231],
        [0.1338, 0.1324],
        [0.0869, 0.0859],
        [0.0073, 0.0076],
        [0.1112, 0.1100],
        [0.0006, 0.0021]])
20 0.001182282217630418 0.030000044964253902 5080.084610192485


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0933],
        [0.1077, 0.1067],
        [0.0633, 0.0625],
        [0.0045, 0.0041],
        [0.0238, 0.0231],
        [0.1338, 0.1323],
        [0.0869, 0.0859],
        [0.0073, 0.0076],
        [0.1112, 0.1100],
        [0.0006, 0.0021]])
21 0.0011840860416059149 0.030029132030904293 5062.949231712858


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0933],
        [0.1077, 0.1067],
        [0.0633, 0.0625],
        [0.0045, 0.0041],
        [0.0238, 0.0230],
        [0.1338, 0.1323],
        [0.0869, 0.0858],
        [0.0073, 0.0076],
        [0.1112, 0.1100],
        [0.0006, 0.0021]])
22 0.001190305980751582 0.030001269653439522 5044.428648519938


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0933],
        [0.1077, 0.1067],
        [0.0633, 0.0625],
        [0.0045, 0.0041],
        [0.0238, 0.0230],
        [0.1338, 0.1323],
        [0.0869, 0.0858],
        [0.0073, 0.0076],
        [0.1112, 0.1100],
        [0.0006, 0.0021]])
23 0.001197066275133402 0.02987941075116396 5029.108226493414


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0933],
        [0.1077, 0.1067],
        [0.0633, 0.0625],
        [0.0045, 0.0041],
        [0.0238, 0.0230],
        [0.1338, 0.1322],
        [0.0869, 0.0858],
        [0.0073, 0.0076],
        [0.1112, 0.1100],
        [0.0006, 0.0021]])
24 0.0012043604034307645 0.029743455350399017 5014.38409358245


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0933],
        [0.1077, 0.1066],
        [0.0633, 0.0625],
        [0.0045, 0.0042],
        [0.0238, 0.0231],
        [0.1338, 0.1322],
        [0.0869, 0.0857],
        [0.0073, 0.0076],
        [0.1112, 0.1099],
        [0.0006, 0.0021]])
25 0.00121323210276576 0.029388356022536755 5002.886827055541


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0933],
        [0.1077, 0.1066],
        [0.0633, 0.0626],
        [0.0045, 0.0042],
        [0.0238, 0.0231],
        [0.1338, 0.1321],
        [0.0869, 0.0857],
        [0.0073, 0.0076],
        [0.1112, 0.1099],
        [0.0006, 0.0021]])
26 0.0012236419069988186 0.028942995704710484 4991.561715538901


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0932],
        [0.1077, 0.1066],
        [0.0633, 0.0626],
        [0.0045, 0.0042],
        [0.0238, 0.0231],
        [0.1338, 0.1320],
        [0.0869, 0.0857],
        [0.0073, 0.0075],
        [0.1112, 0.1099],
        [0.0006, 0.0021]])
27 0.0012357015262643108 0.028655996546149254 4981.3812914968685


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0932],
        [0.1077, 0.1065],
        [0.0633, 0.0626],
        [0.0045, 0.0042],
        [0.0238, 0.0232],
        [0.1338, 0.1320],
        [0.0869, 0.0857],
        [0.0073, 0.0075],
        [0.1112, 0.1099],
        [0.0006, 0.0021]])
28 0.0012492538239163697 0.02823957707732916 4971.747696713888


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0932],
        [0.1077, 0.1065],
        [0.0633, 0.0627],
        [0.0045, 0.0042],
        [0.0238, 0.0232],
        [0.1338, 0.1320],
        [0.0869, 0.0856],
        [0.0073, 0.0075],
        [0.1112, 0.1099],
        [0.0006, 0.0021]])
29 0.0012603337302434375 0.028294273652136326 4962.712758706531


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0932],
        [0.1077, 0.1065],
        [0.0633, 0.0627],
        [0.0045, 0.0042],
        [0.0238, 0.0232],
        [0.1338, 0.1320],
        [0.0869, 0.0856],
        [0.0073, 0.0075],
        [0.1112, 0.1098],
        [0.0006, 0.0021]])
30 0.0012700995369133306 0.02843872271478176 4953.205022384221


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0932],
        [0.1077, 0.1065],
        [0.0633, 0.0628],
        [0.0045, 0.0042],
        [0.0238, 0.0233],
        [0.1338, 0.1320],
        [0.0869, 0.0856],
        [0.0073, 0.0075],
        [0.1112, 0.1098],
        [0.0006, 0.0021]])
31 0.0012787607989792013 0.028610538691282272 4945.705888667819


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0932],
        [0.1077, 0.1065],
        [0.0633, 0.0628],
        [0.0045, 0.0043],
        [0.0238, 0.0233],
        [0.1338, 0.1320],
        [0.0869, 0.0856],
        [0.0073, 0.0075],
        [0.1112, 0.1098],
        [0.0006, 0.0021]])
32 0.0012847185294519296 0.028797045350074768 4939.752423389661


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0932],
        [0.1077, 0.1065],
        [0.0633, 0.0628],
        [0.0045, 0.0043],
        [0.0238, 0.0233],
        [0.1338, 0.1320],
        [0.0869, 0.0856],
        [0.0073, 0.0075],
        [0.1112, 0.1098],
        [0.0006, 0.0021]])
33 0.0012900997234269744 0.028908714652061462 4935.754267926076


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0932],
        [0.1077, 0.1065],
        [0.0633, 0.0629],
        [0.0045, 0.0043],
        [0.0238, 0.0234],
        [0.1338, 0.1320],
        [0.0869, 0.0856],
        [0.0073, 0.0075],
        [0.1112, 0.1098],
        [0.0006, 0.0021]])
34 0.001293601357769221 0.029174741357564926 4932.907535098852


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0932],
        [0.1077, 0.1065],
        [0.0633, 0.0629],
        [0.0045, 0.0043],
        [0.0238, 0.0234],
        [0.1338, 0.1320],
        [0.0869, 0.0856],
        [0.0073, 0.0075],
        [0.1112, 0.1098],
        [0.0006, 0.0021]])
35 0.0012955426669178996 0.029496455565094948 4928.043586109672


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0933],
        [0.1077, 0.1065],
        [0.0633, 0.0630],
        [0.0045, 0.0043],
        [0.0238, 0.0234],
        [0.1338, 0.1320],
        [0.0869, 0.0856],
        [0.0073, 0.0075],
        [0.1112, 0.1098],
        [0.0006, 0.0021]])
36 0.0012961941492912593 0.029548496939241886 4923.186535608256


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0933],
        [0.1077, 0.1065],
        [0.0633, 0.0630],
        [0.0045, 0.0043],
        [0.0238, 0.0234],
        [0.1338, 0.1320],
        [0.0869, 0.0856],
        [0.0073, 0.0075],
        [0.1112, 0.1098],
        [0.0006, 0.0021]])
37 0.0012949607787939021 0.02963855490088463 4917.843762736327


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0933],
        [0.1077, 0.1065],
        [0.0633, 0.0630],
        [0.0045, 0.0043],
        [0.0238, 0.0235],
        [0.1338, 0.1320],
        [0.0869, 0.0856],
        [0.0073, 0.0075],
        [0.1112, 0.1098],
        [0.0006, 0.0021]])
38 0.001292296729022346 0.029689940623939037 4912.551600944568


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0933],
        [0.1077, 0.1065],
        [0.0633, 0.0630],
        [0.0045, 0.0043],
        [0.0238, 0.0235],
        [0.1338, 0.1320],
        [0.0869, 0.0856],
        [0.0073, 0.0075],
        [0.1112, 0.1098],
        [0.0006, 0.0021]])
39 0.001288782802670903 0.029712487943470478 4907.491764524789


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0933],
        [0.1077, 0.1064],
        [0.0633, 0.0631],
        [0.0045, 0.0043],
        [0.0238, 0.0235],
        [0.1338, 0.1319],
        [0.0869, 0.0855],
        [0.0073, 0.0075],
        [0.1112, 0.1097],
        [0.0006, 0.0021]])
40 0.001298816305285145 0.029933477751910686 4899.832892426504


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0933],
        [0.1077, 0.1065],
        [0.0633, 0.0631],
        [0.0045, 0.0043],
        [0.0238, 0.0235],
        [0.1338, 0.1321],
        [0.0869, 0.0856],
        [0.0073, 0.0075],
        [0.1112, 0.1099],
        [0.0006, 0.0021]])
41 0.0012779525566822849 0.02980582509189844 4895.941434641132


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0933],
        [0.1077, 0.1065],
        [0.0633, 0.0631],
        [0.0045, 0.0043],
        [0.0238, 0.0235],
        [0.1338, 0.1321],
        [0.0869, 0.0856],
        [0.0073, 0.0075],
        [0.1112, 0.1099],
        [0.0006, 0.0021]])
42 0.0012722594004447455 0.029835525900125504 4891.6536716003275


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0934],
        [0.1077, 0.1066],
        [0.0633, 0.0631],
        [0.0045, 0.0043],
        [0.0238, 0.0235],
        [0.1338, 0.1321],
        [0.0869, 0.0856],
        [0.0073, 0.0075],
        [0.1112, 0.1099],
        [0.0006, 0.0021]])
43 0.0012649676906806417 0.029852968640625477 4885.420198552914


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0934],
        [0.1077, 0.1066],
        [0.0633, 0.0631],
        [0.0045, 0.0043],
        [0.0238, 0.0235],
        [0.1338, 0.1322],
        [0.0869, 0.0856],
        [0.0073, 0.0075],
        [0.1112, 0.1099],
        [0.0006, 0.0021]])
44 0.0012578712270673714 0.029868580400943756 4879.509051365188


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0934],
        [0.1077, 0.1066],
        [0.0633, 0.0631],
        [0.0045, 0.0043],
        [0.0238, 0.0235],
        [0.1338, 0.1322],
        [0.0869, 0.0856],
        [0.0073, 0.0075],
        [0.1112, 0.1099],
        [0.0006, 0.0021]])
45 0.0012506813398117083 0.029895581305027008 4874.375549740331


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0934],
        [0.1077, 0.1066],
        [0.0633, 0.0631],
        [0.0045, 0.0043],
        [0.0238, 0.0235],
        [0.1338, 0.1322],
        [0.0869, 0.0857],
        [0.0073, 0.0075],
        [0.1112, 0.1099],
        [0.0006, 0.0021]])
46 0.0012433231058367527 0.02990797907114029 4869.854993355759


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0934],
        [0.1077, 0.1066],
        [0.0633, 0.0631],
        [0.0045, 0.0043],
        [0.0238, 0.0235],
        [0.1338, 0.1323],
        [0.0869, 0.0857],
        [0.0073, 0.0075],
        [0.1112, 0.1100],
        [0.0006, 0.0021]])
47 0.0012356499608478042 0.029922112822532654 4864.221934980903


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0934],
        [0.1077, 0.1066],
        [0.0633, 0.0632],
        [0.0045, 0.0043],
        [0.0238, 0.0235],
        [0.1338, 0.1323],
        [0.0869, 0.0857],
        [0.0073, 0.0075],
        [0.1112, 0.1100],
        [0.0006, 0.0021]])
48 0.0012286002890797682 0.029934094287455082 4859.709272612802


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0934],
        [0.1077, 0.1066],
        [0.0633, 0.0632],
        [0.0045, 0.0043],
        [0.0238, 0.0235],
        [0.1338, 0.1322],
        [0.0869, 0.0857],
        [0.0073, 0.0075],
        [0.1112, 0.1099],
        [0.0006, 0.0021]])
49 0.0012285774607170607 0.029986918903887272 4853.425276966095


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0935],
        [0.1077, 0.1066],
        [0.0633, 0.0632],
        [0.0045, 0.0043],
        [0.0238, 0.0234],
        [0.1338, 0.1323],
        [0.0869, 0.0858],
        [0.0073, 0.0075],
        [0.1112, 0.1100],
        [0.0006, 0.0021]])
50 0.0012148223123612115 0.02993688825517893 4850.091220799489


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0935],
        [0.1077, 0.1066],
        [0.0633, 0.0632],
        [0.0045, 0.0044],
        [0.0238, 0.0234],
        [0.1338, 0.1323],
        [0.0869, 0.0858],
        [0.0073, 0.0075],
        [0.1112, 0.1100],
        [0.0006, 0.0021]])
51 0.001208304152927012 0.029948342591524124 4846.143512084659


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0935],
        [0.1077, 0.1067],
        [0.0633, 0.0632],
        [0.0045, 0.0044],
        [0.0238, 0.0234],
        [0.1338, 0.1324],
        [0.0869, 0.0858],
        [0.0073, 0.0075],
        [0.1112, 0.1100],
        [0.0006, 0.0021]])
52 0.0012020832294534194 0.029935150407254696 4841.824456226817


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0935],
        [0.1077, 0.1067],
        [0.0633, 0.0632],
        [0.0045, 0.0044],
        [0.0238, 0.0234],
        [0.1338, 0.1324],
        [0.0869, 0.0858],
        [0.0073, 0.0075],
        [0.1112, 0.1100],
        [0.0006, 0.0021]])
53 0.001196432365021028 0.029936141334474087 4838.078713174613


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0935],
        [0.1077, 0.1067],
        [0.0633, 0.0632],
        [0.0045, 0.0044],
        [0.0238, 0.0234],
        [0.1338, 0.1324],
        [0.0869, 0.0858],
        [0.0073, 0.0075],
        [0.1112, 0.1100],
        [0.0006, 0.0021]])
54 0.0011909222746483282 0.029913133941590786 4834.212697599842


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0935],
        [0.1077, 0.1067],
        [0.0633, 0.0632],
        [0.0045, 0.0044],
        [0.0238, 0.0234],
        [0.1338, 0.1324],
        [0.0869, 0.0858],
        [0.0073, 0.0075],
        [0.1112, 0.1100],
        [0.0006, 0.0021]])
55 0.0011860088812932373 0.0299066249281168 4830.989978238581


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0935],
        [0.1077, 0.1067],
        [0.0633, 0.0632],
        [0.0045, 0.0044],
        [0.0238, 0.0234],
        [0.1338, 0.1324],
        [0.0869, 0.0858],
        [0.0073, 0.0075],
        [0.1112, 0.1100],
        [0.0006, 0.0021]])
56 0.0011814483792849933 0.029863514937460423 4827.853326242949


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0935],
        [0.1077, 0.1066],
        [0.0633, 0.0632],
        [0.0045, 0.0044],
        [0.0238, 0.0234],
        [0.1338, 0.1324],
        [0.0869, 0.0858],
        [0.0073, 0.0075],
        [0.1112, 0.1100],
        [0.0006, 0.0021]])
57 0.0011807445934796123 0.029931169003248215 4824.6978830767375


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0935],
        [0.1077, 0.1067],
        [0.0633, 0.0632],
        [0.0045, 0.0044],
        [0.0238, 0.0234],
        [0.1338, 0.1324],
        [0.0869, 0.0859],
        [0.0073, 0.0075],
        [0.1112, 0.1100],
        [0.0006, 0.0021]])
58 0.0011732201724662446 0.029858073219656944 4821.688844883857


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0935],
        [0.1077, 0.1067],
        [0.0633, 0.0632],
        [0.0045, 0.0044],
        [0.0238, 0.0234],
        [0.1338, 0.1324],
        [0.0869, 0.0859],
        [0.0073, 0.0075],
        [0.1112, 0.1100],
        [0.0006, 0.0021]])
59 0.0011698109418575768 0.02988116815686226 4818.393595771905


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0935],
        [0.1077, 0.1067],
        [0.0633, 0.0632],
        [0.0045, 0.0044],
        [0.0238, 0.0234],
        [0.1338, 0.1324],
        [0.0869, 0.0859],
        [0.0073, 0.0075],
        [0.1112, 0.1100],
        [0.0006, 0.0021]])
60 0.0011664951359201222 0.029894214123487473 4815.102421331674


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0935],
        [0.1077, 0.1067],
        [0.0633, 0.0632],
        [0.0045, 0.0044],
        [0.0238, 0.0234],
        [0.1338, 0.1324],
        [0.0869, 0.0859],
        [0.0073, 0.0075],
        [0.1112, 0.1100],
        [0.0006, 0.0021]])
61 0.001163393846594845 0.02988656796514988 4811.973294083414


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0935],
        [0.1077, 0.1067],
        [0.0633, 0.0632],
        [0.0045, 0.0044],
        [0.0238, 0.0234],
        [0.1338, 0.1324],
        [0.0869, 0.0859],
        [0.0073, 0.0075],
        [0.1112, 0.1100],
        [0.0006, 0.0021]])
62 0.001160482560757373 0.02985814493149519 4808.926784352459


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0936],
        [0.1077, 0.1067],
        [0.0633, 0.0632],
        [0.0045, 0.0044],
        [0.0238, 0.0234],
        [0.1338, 0.1324],
        [0.0869, 0.0859],
        [0.0073, 0.0075],
        [0.1112, 0.1100],
        [0.0006, 0.0021]])
63 0.0011578911448188592 0.029838048852980137 4806.046725284905


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0936],
        [0.1077, 0.1067],
        [0.0633, 0.0632],
        [0.0045, 0.0044],
        [0.0238, 0.0234],
        [0.1338, 0.1324],
        [0.0869, 0.0859],
        [0.0073, 0.0075],
        [0.1112, 0.1100],
        [0.0006, 0.0021]])
64 0.0011557247970151365 0.029828130267560482 4803.66333284147


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0935],
        [0.1077, 0.1066],
        [0.0633, 0.0632],
        [0.0045, 0.0044],
        [0.0238, 0.0234],
        [0.1338, 0.1324],
        [0.0869, 0.0859],
        [0.0073, 0.0075],
        [0.1112, 0.1100],
        [0.0006, 0.0021]])
65 0.0011558013110459433 0.029865038581192493 4801.347436962677


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0936],
        [0.1077, 0.1067],
        [0.0633, 0.0632],
        [0.0045, 0.0044],
        [0.0238, 0.0234],
        [0.1338, 0.1324],
        [0.0869, 0.0859],
        [0.0073, 0.0075],
        [0.1112, 0.1100],
        [0.0006, 0.0021]])
66 0.0011522831463767215 0.02978996466845274 4799.771647565274


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0936],
        [0.1077, 0.1067],
        [0.0633, 0.0632],
        [0.0045, 0.0044],
        [0.0238, 0.0234],
        [0.1338, 0.1324],
        [0.0869, 0.0859],
        [0.0073, 0.0075],
        [0.1112, 0.1100],
        [0.0006, 0.0021]])
67 0.0011506592500710395 0.02977199386805296 4798.060613615827


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0936],
        [0.1077, 0.1067],
        [0.0633, 0.0632],
        [0.0045, 0.0043],
        [0.0238, 0.0234],
        [0.1338, 0.1324],
        [0.0869, 0.0859],
        [0.0073, 0.0075],
        [0.1112, 0.1100],
        [0.0006, 0.0021]])
68 0.0011491141871915898 0.029741116799414158 4795.836190618885


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0936],
        [0.1077, 0.1066],
        [0.0633, 0.0632],
        [0.0045, 0.0043],
        [0.0238, 0.0234],
        [0.1338, 0.1324],
        [0.0869, 0.0859],
        [0.0073, 0.0075],
        [0.1112, 0.1100],
        [0.0006, 0.0021]])
69 0.0011479182377032702 0.029723734594881535 4794.014984714625


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0936],
        [0.1077, 0.1066],
        [0.0633, 0.0632],
        [0.0045, 0.0043],
        [0.0238, 0.0233],
        [0.1338, 0.1324],
        [0.0869, 0.0859],
        [0.0073, 0.0075],
        [0.1112, 0.1100],
        [0.0006, 0.0021]])
70 0.0011467921368591488 0.029733438044786453 4792.402591877695


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0936],
        [0.1077, 0.1066],
        [0.0633, 0.0632],
        [0.0045, 0.0043],
        [0.0238, 0.0233],
        [0.1338, 0.1324],
        [0.0869, 0.0859],
        [0.0073, 0.0075],
        [0.1112, 0.1100],
        [0.0006, 0.0021]])
71 0.001145776015559386 0.02969621866941452 4790.879186307798


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0936],
        [0.1077, 0.1066],
        [0.0633, 0.0632],
        [0.0045, 0.0043],
        [0.0238, 0.0233],
        [0.1338, 0.1324],
        [0.0869, 0.0859],
        [0.0073, 0.0075],
        [0.1112, 0.1100],
        [0.0006, 0.0021]])
72 0.0011450414210555028 0.0296857925131917 4789.48478734455


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0935],
        [0.1077, 0.1066],
        [0.0633, 0.0632],
        [0.0045, 0.0043],
        [0.0238, 0.0233],
        [0.1338, 0.1324],
        [0.0869, 0.0859],
        [0.0073, 0.0075],
        [0.1112, 0.1100],
        [0.0006, 0.0021]])
73 0.0011453405349893728 0.029710018076002598 4787.926572037561


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0935],
        [0.1077, 0.1066],
        [0.0633, 0.0632],
        [0.0045, 0.0043],
        [0.0238, 0.0233],
        [0.1338, 0.1324],
        [0.0869, 0.0859],
        [0.0073, 0.0075],
        [0.1112, 0.1100],
        [0.0006, 0.0021]])
74 0.0011438132427292293 0.029651891440153122 4787.0268817901115


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0935],
        [0.1077, 0.1066],
        [0.0633, 0.0632],
        [0.0045, 0.0043],
        [0.0238, 0.0233],
        [0.1338, 0.1324],
        [0.0869, 0.0859],
        [0.0073, 0.0075],
        [0.1112, 0.1100],
        [0.0006, 0.0021]])
75 0.0011433425887735211 0.029635055921971798 4786.011892111641


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0935],
        [0.1077, 0.1066],
        [0.0633, 0.0632],
        [0.0045, 0.0043],
        [0.0238, 0.0233],
        [0.1338, 0.1324],
        [0.0869, 0.0859],
        [0.0073, 0.0075],
        [0.1112, 0.1100],
        [0.0006, 0.0021]])
76 0.0011429750137028167 0.029630791395902634 4785.113380429029


  0%|          | 0/302 [00:00<?, ?it/s]

tensor([[0.0933, 0.0935],
        [0.1077, 0.1066],
        [0.0633, 0.0632],
        [0.0045, 0.0043],
        [0.0238, 0.0233],
        [0.1338, 0.1324],
        [0.0869, 0.0859],
        [0.0073, 0.0075],
        [0.1112, 0.1100],
        [0.0006, 0.0021]])
77 0.0011426699579608977 0.02962195035070181 4784.297893704185


  0%|          | 0/302 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [192]:
data[-100000:][np.argmax(err)]

array([2.0000076, 2.       , 4.0000305, 4.       , 4.0000153],
      dtype=float32)

In [439]:
print(net(torch.tensor([[0., 0.],
                        [0., 1],
                        [0., 2],
                        [0., 3],
                        [0., 4],
                        [0., 5],
                        [0., 6],
                        [0., 7],
                        [0., 8],
                        [0., 9]])))

tensor([[ 0.8370],
        [ 1.3064],
        [ 3.7761],
        [ 8.9825],
        [17.3631],
        [27.4014],
        [37.5714],
        [47.5369],
        [57.0769],
        [66.3386]], grad_fn=<AddmmBackward>)


In [112]:
net.eval()
err = []
with torch.no_grad():
    for eu_d in (bar := tqdm(data[-5000:, :])):
        sqdiff = torch.tensor(eu_d[0] - eu_d[1]).square()
        pred = net(torch.tensor(eu_d))
        err.append(np.abs(pred.item() - sqdiff.item()) / sqdiff.item())
        print(pred.item(), sqdiff.item())

  0%|          | 0/10000 [00:00<?, ?it/s]

0.4058016538619995 1.343533992767334
24.172975540161133 28.480993270874023
3.702401876449585 3.4859619140625
0.46427279710769653 0.11884527653455734
0.4227063059806824 1.0068390369415283
0.46118849515914917 0.008778247982263565
0.4758131504058838 0.03400048986077309
0.4633966088294983 0.10225093364715576
0.46653103828430176 0.006437885574996471
0.41779059171676636 0.6680024266242981
0.4586895704269409 0.09229059517383575
0.4427664875984192 0.5741840600967407
0.4328789710998535 0.302799791097641
0.6527474522590637 0.44908860325813293
0.42768287658691406 0.7422394156455994
0.4410039782524109 0.10152947157621384
0.4796168804168701 0.016825322061777115
0.465549111366272 0.2962101399898529
0.4572637677192688 0.03287649154663086
0.47245150804519653 0.0564974844455719
10.757575035095215 6.9920334815979
0.4306986927986145 0.3750700354576111
0.45801401138305664 0.10852419584989548
0.4778231978416443 0.007147960364818573
0.43711429834365845 0.5711864233016968
0.4567188620567322 0.273633509874343

In [113]:
np.median(err)

0.6662024000679129

In [222]:
net.eval()

Sequential(
  (0): Linear(in_features=2, out_features=4, bias=True)
  (1): LeakyReLU(negative_slope=0.01)
  (2): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (3): Linear(in_features=4, out_features=8, bias=True)
  (4): LeakyReLU(negative_slope=0.01)
  (5): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (6): Linear(in_features=8, out_features=8, bias=True)
  (7): LeakyReLU(negative_slope=0.01)
  (8): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (9): Linear(in_features=8, out_features=8, bias=True)
  (10): LeakyReLU(negative_slope=0.01)
  (11): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (12): Linear(in_features=8, out_features=4, bias=True)
  (13): LeakyReLU(negative_slope=0.01)
  (14): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (15): Linear(in_features=4, out_features=2, bias=True)
  (16): LeakyReLU(negative

In [252]:
net(torch.tensor([[0., 0.],[0., .5]]))

tensor([[0.1162],
        [0.1162]], grad_fn=<AddmmBackward>)