In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as D
import torch.optim as optim
import numpy as np

**Torch, Numpy stuff**

In [2]:
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

print(device.type)

# ---

torch.manual_seed(42)
np.random.seed(42)

cuda


---
**Training Data**

Demonstration data contains a list of **d** scenes:

- ~~Each scene contains a list of trajectories of **p** people, where p is **not constant**~~
- A trajectory is a list of **t** states, where t = 400
- A state is a **s** = 4 dimensional variable
 - State = (d<sub>goal<sub>x</sub></sub>, d<sub>goal<sub>y</sub></sub>, v<sub>x</sub>, v<sub>y</sub>)
    
> Shape of data is (d, ~~p,~~ t, s) => (d, ~~p,~~ 400, 4)

---

Train, test, val split: 0.8, 0.01, 0.19

---

### Preparing the data

**n <= n<sub>max</sub>** random number of random observations on a random trajectory

**get image**

In [3]:
from PIL import Image
import torchvision.transforms as T

def get_frames(path, demonstration_id, observation_ids):
    frames_path = f'{path}{demonstration_id}/'
    transform = T.Compose([T.ToTensor()])
    frames = []
    for i in observation_ids:
        frames.append(transform(Image.open(f"{frames_path}{i}.jpg")))
                      
    frames = torch.stack(frames, 0)
    return frames

**get all input components**

In [3]:
n_max = 30
d_size = 9500  # nof demonstrations
t_size = 400  # length of trajectories
path="../data/processed/input/train/"  # .../input/train/ contains folders from 0 to 9499

def sample_training_demonstration():
#     d, t, s = X_train.shape
    
    rand_traj_ind = np.random.randint(0, d_size)
    n = np.random.randint(1, n_max+1)

    rand_traj = np.load(f"{path}{rand_traj_ind}/states.npy")

    observation_indices = np.random.choice(np.arange(t_size), n+1, replace=False) # n+1: +1 is for sampling the target
    
    frames = get_frames(path, rand_traj_ind, observation_indices[:-1])
    
    observations = torch.from_numpy(rand_traj[observation_indices[:-1], :])
    targetX = torch.unsqueeze(torch.from_numpy(rand_traj[observation_indices[-1], 0:2]), 0)
    targetY = torch.unsqueeze(torch.from_numpy(rand_traj[observation_indices[-1], 2:]), 0)
    if device.type == 'cuda':
        return frames.float().cuda(), observations.float().cuda(), targetX.float().cuda(), targetY.float().cuda()
    else:
        return frames.float(), observations.float(), targetX.float(), targetY.float()

FileNotFoundError: [Errno 2] No such file or directory: '../data/processed/input/train/994/states.npy'

The same for validation samples

In [5]:
dv_size = 490  # nof validation demonstrations
vpath = "../data/processed/input/val/"  # .../input/val/ contains folders from 9500 to 9989

def get_validation_demonstration(i):
    demostration_ind = i + d_size

    traj = np.load(f"{vpath}{demostration_ind}/states.npy")
    
    n = np.random.randint(1, n_max+1)
    observation_indices = np.random.choice(np.arange(t_size), n+1, replace=False) # n+1: +1 is for sampling the target
    
    frames = get_frames(vpath, demostration_ind, observation_indices[:-1])
    
    observations = torch.from_numpy(traj[observation_indices[:-1], :])
    targetX = torch.unsqueeze(torch.from_numpy(traj[observation_indices[-1], 0:2]), 0)
    targetY = torch.unsqueeze(torch.from_numpy(traj[observation_indices[-1], 2:]), 0)
    
    if device.type == 'cuda':
        return frames.float().cuda(), observations.float().cuda(), targetX.float().cuda(), targetY.float().cuda()
    else:
        return frames.float(), observations.float(), targetX.float(), targetY.float()

---
### Model

In [6]:
class CNP(nn.Module):
    def __init__(self):
        super(CNP, self).__init__()
        
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 8, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(8, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(16, 32, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        
        self.encoder = nn.Sequential(
            nn.Linear(4+64*4*4,1024),
            nn.ReLU(),
            nn.Linear(1024,1024),
            nn.ReLU(),
            nn.Linear(1024,1024)
        )
        
        self.query = nn.Sequential(
            nn.Linear(1024+2,1024),
            nn.ReLU(),
            nn.Linear(1024,1024),
            nn.ReLU(),
            nn.Linear(1024,1024),
            nn.ReLU(),
            nn.Linear(1024,2*2)
        )

    def forward(self, frames, observations, target):
        # n < n_max frames of a scene along with momentary observations are concatenated to constitute input
        scene_encodings = torch.flatten(self.cnn(frames))
        scene_encodings = self.cnn(frames).view(frames.shape[0], 64*4*4)
        encoder_in = torch.cat((observations, scene_encodings), 1)
        r = self.encoder(encoder_in)
        
        r_avg = torch.mean(r, dim=0)
        r_avgs = r_avg.repeat(target.shape[0], 1)  # repeating the same r_avg for each target
        r_avg_target = torch.cat((r_avgs, target), 1)
        query_out = self.query(r_avg_target)
        
        return query_out

    
def log_prob_loss(output, target):
    mean, sigma = output.chunk(2, dim = -1)
    sigma = F.softplus(sigma)
    dist = D.Independent(D.Normal(loc=mean, scale=sigma), 1)
    return -torch.mean(dist.log_prob(target))

def initialize_weights(m):
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_uniform_(m.weight.data,nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight.data, 1)
        nn.init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.Linear):
        nn.init.kaiming_uniform_(m.weight.data)
        nn.init.constant_(m.bias.data, 0)

### Training

#### 1- Methods and Variables for Validation 

In [7]:
def validate():
    vloss=np.zeros(dv_size)
    with torch.no_grad():
        for i in range(dv_size):
            fs, obss, tx, ty = get_validation_demonstration(i)
            ty_pred = model(fs, obss, tx)
            vloss[i] = log_prob_loss(ty, ty_pred)
            
    return np.mean(vloss)

In [8]:
val_after_epoch = 2500

#### 2- Actual Training

In [11]:
from tqdm import tqdm

model = CNP()
model.apply(initialize_weights)
model.to(device)

optimizer = torch.optim.Adam(lr=1e-4, params=model.parameters(), betas=(0.9, 0.999), amsgrad=True)

epoch = 20000000

losses = np.zeros(int(epoch/val_after_epoch))
min_loss = 1e6

for i in range(epoch):
    fs, obss, tx, ty = sample_training_demonstration()
    
    optimizer.zero_grad()
    ty_pred = model(fs, obss, tx)
    loss = log_prob_loss(ty, ty_pred)
    
    loss.backward()
    optimizer.step()
    
    print(i, end="\r")
    if i%val_after_epoch == 0:
        val_loss = validate()
        print(f"{i}: {val_loss}")
    
        if val_loss < min_loss:
            min_loss = val_loss
            torch.save(model.state_dict(), f'{path}../../best_model.pt')


0: 4.075285734084188
2500: 2.2678578089694588
5000: 2.2613347497521614
7500: 2.2547294201899546
10000: 2.3133686401406113
12500: 2.2391293708159
15000: 2.231107589420007
17500: 2.293084923831784
20000: 2.238102569750377
22500: 2.2273242915163234
25000: 2.219854364711411
27500: 2.231999917054663
30000: 2.2133188623554854
32500: 2.2215249553018688
35000: 2.22607035758544
37500: 2.2120101659881826
40000: 2.2119695234055423
42500: 2.2156157277068314
45000: 2.2289829221307014
47500: 2.2485376045411947
50000: 2.225163831759472
52500: 2.227179199578811
55000: 2.2096407909782565
57500: 2.194340138167751
60000: 2.227353668821101
62500: 2.195209777355194
65000: 2.221446590034329
67500: 2.2490262280921547
70000: 2.2210747645825757
72500: 2.1816903952433138
75000: 2.2131019665270437
77500: 2.1954483079666995
80000: 2.2173403680324553
82500: 2.23465002361609
85000: 2.248476248493
87500: 2.274792685922311
90000: 2.2351584418695802
92500: 2.265560816015516
95000: 2.228611462335197
97500: 2.1956218235

1530000: 2.213674840878467
1532500: 2.235715466494463
1535000: 2.2110882652049164
1537500: 2.198549707081853
1540000: 2.2066728084671254
1542500: 2.226641423969853
1545000: 2.195820205430595
1547500: 2.1907332498200085
1550000: 2.212312020331013
1552500: 2.219773230747301
1555000: 2.1948961118046117
1557500: 2.2061255340673487
1560000: 2.1962544104274437
1562500: 2.19483700200003
1565000: 2.2144704386896015
1567500: 2.1975606248086814
1570000: 2.180394594036803
1572500: 2.1851704866302257
1575000: 2.1703689733330083
1577500: 2.1954470198981615
1580000: 2.21357995113548
1582500: 2.198752536214128
1585000: 2.2440083826074795
1587500: 2.1973088894571577
1590000: 2.199623828153221
1592500: 2.2093701161900343
1595000: 2.2197623831885203
1597500: 2.2182949512588737
1600000: 2.1979239522194374
1602500: 2.211290513860936
1605000: 2.23398989232219
1607500: 2.2004966865996924
1610000: 2.2263521256495493
1612500: 2.2162918192999705
1615000: 2.1754634095697987
1617500: 2.1892012498816666
1620000: 

2275000: 2.197050829444613
2277500: 2.213581740126318
2280000: 2.209341602057827
2282500: 2.2158962946765275
2285000: 2.191322331404199
2287500: 2.192608999232857
2290000: 2.1654347737224735
2292500: 2.2102401115456405
2295000: 2.216118237315392
2297500: 2.169810398500793
2300000: 2.2088176462115072
2302500: 2.196545871909784
2305000: 2.1940039913265075
2307500: 2.1877142631277744
2310000: 2.1875886355127605
2312500: 2.1884167255187523
2315000: 2.2040339190132765
2317500: 2.191119027867609
2320000: 2.181717251636544
2322500: 2.230873760033627
2325000: 2.21841046457388
2327500: 2.176453795238417
2330000: 2.1673748747426638
2332500: 2.2152563445422113
2335000: 2.2283927933293946
2337500: 2.1974884124434726
2340000: 2.182648011616298
2342500: 2.1891169314481775
2345000: 2.183554219226448
2347500: 2.222038264298926
2350000: 2.2242424753247474
2352500: 2.1954569474774965
2355000: 2.2019312672469082
2357500: 2.1838267726557596
2360000: 2.2066589702148827
2362500: 2.207317368959894
2365000: 2

3765000: 2.199668620435559
3767500: 2.2006577005191725
3770000: 2.2001952153079363
3772500: 2.2086076901883493
3775000: 2.1972379248969407
3777500: 2.188988781096984
3780000: 2.2046314487651903
3782500: 2.2141689619239497
3785000: 2.218819879877324
3787500: 2.1631650683831194
3790000: 2.227557586528817
3792500: 2.2322453483026856
3795000: 2.1906193604274673
3797500: 2.2039979173212636
3800000: 2.1909678995609285
3802500: 2.2162753622142635
3805000: 2.21643171091469
3807500: 2.2180862420675704
3810000: 2.1941594954656094
3812500: 2.1749520475767095
3815000: 2.173583051988057
3817500: 2.1874511818496547
3820000: 2.2019618842066553
3822500: 2.2250382053608795
3825000: 2.1905882797679124
3827500: 2.202382380743416
3830000: 2.216143370769462
3832500: 2.1981173481260026
3835000: 2.22126795904977
3837500: 2.1858620195972676
3840000: 2.195863873739632
3842500: 2.2212542462105653
3845000: 2.199797985991653
3847500: 2.19601229088647
3850000: 2.222995390089191
3852500: 2.1968755895994145
3855000:

4510000: 2.208449947347446
4512500: 2.183776186558665
4515000: 2.1855805482183186
4517500: 2.195290999023282
4520000: 2.2054405951986507
4522500: 2.1853518391142086
4525000: 2.2231129337330255
4527500: 2.2000789639901144
4530000: 2.1715629462076693
4532500: 2.1770610147592975
4535000: 2.205742978563114
4537500: 2.199880805429147
4540000: 2.2177822049783202
4542500: 2.152935639206244
4545000: 2.169084262361332
4547500: 2.2145632522446768
4550000: 2.229036459022639
4552500: 2.1942186090410973
4555000: 2.2266380109349075
4557500: 2.2120868560002775
4560000: 2.2205495715141295
4562500: 2.199951757095298
4565000: 2.1716635232069055
4567500: 2.17828534926687
4570000: 2.184955117288901
4572500: 2.2080246088456135
4575000: 2.1886166813422223
4577500: 2.190757015773228
4580000: 2.18898687082894
4582500: 2.224252625144258
4585000: 2.2304506826157473
4587500: 2.2030985251981385
4590000: 2.175406799997602
4592500: 2.1692333503645296
4595000: 2.195691876143825
4597500: 2.1948091116486763
4600000: 2

5997500: 2.202641400755668
6000000: 2.2200447288094733
6002500: 2.1472346701184097
6005000: 2.2139043899215
6007500: 2.170636405020344
6010000: 2.201812610699206
6012500: 2.249635247551665
6015000: 2.188224479860189
6017500: 2.1795691482874813
6020000: 2.2148596706439037
6022500: 2.2212282456913774
6025000: 2.1934543733694114
6027500: 2.197493101139458
6030000: 2.2188804499957024
6032500: 2.2089625431566824
6035000: 2.1765147875766364
6037500: 2.19123688279366
6040000: 2.189809217744944
6042500: 2.2084161852087294
6045000: 2.205136469553928
6047500: 2.188752531153815
6050000: 2.182692914957903
6052500: 2.2070649684691914
6055000: 2.195049069730603
6057500: 2.193141368213965
6060000: 2.19638619447241
6062500: 2.189138168704753
6065000: 2.179909179162006
6067500: 2.200084418788248
6070000: 2.213590875085519
6072500: 2.2058315624996108
6075000: 2.2023404137212403
6077500: 2.2306516081702954
6080000: 2.199769086010602
6082500: 2.216645130940846
6085000: 2.178515136606839
6087500: 2.2162170

6742500: 2.1901086675877472
6745000: 2.1844370071985284
6747500: 2.1710217934481952
6750000: 2.2000249565864096
6752500: 2.202564786161695
6755000: 2.1941341263907295
6757500: 2.1895610277750053
6760000: 2.184478695173653
6762500: 2.197909134869673
6765000: 2.2024853475239814
6767500: 2.228927832355305
6770000: 2.222986268997192
6772500: 2.2159301285841027
6775000: 2.163316390465717
6777500: 2.223228020570716
6780000: 2.182314224997345
6782500: 2.1967040918311294
6785000: 2.196633057934897
6787500: 2.1671917498111726
6790000: 2.2192802582468305
6792500: 2.2031340518776252
6795000: 2.2012539586242363
6797500: 2.2059701471912616
6800000: 2.1886169882453217
6802500: 2.1964683733424364
6805000: 2.209450234442341
6807500: 2.1916091091778815
6810000: 2.1831454731980147
6812500: 2.210522721373305
6815000: 2.18941357950775
6817500: 2.209894272745872
6820000: 2.1706991666433764
6822500: 2.1888728027441062
6825000: 2.20979526201073
6827500: 2.188484981473611
6830000: 2.216966154259078
6832500: 2

7487500: 2.1690787117091976
7490000: 2.217876173160514
7492500: 2.1946101048771216
7495000: 2.1878083203520093
7497500: 2.212079415759262
7500000: 2.199685369462383
7502500: 2.205980503072544
7505000: 2.1810820990679214
7507500: 2.1908173141430836
7510000: 2.1729673503612985
7512500: 2.211146952059804
7515000: 2.21182183513836
7517500: 2.218712698805089
7520000: 2.193273173181378
7522500: 2.19985897419404
7525000: 2.186469154455224
7527500: 2.1743435795209844
7530000: 2.164376359204857
7532500: 2.225659038096058
7535000: 2.206638963976685
7537500: 2.1733376191586866
7540000: 2.180913446265824
7542500: 2.2017611216525643
7545000: 2.2004234830943905
7547500: 2.1982269937894783
7550000: 2.1797219822601397
7552500: 2.196499037134404
7555000: 2.199742940129066
7557500: 2.2040926272771797
7560000: 2.2004607621504335
7562500: 2.187921876080182
7565000: 2.191432991198131
7567500: 2.182269267038423
7570000: 2.202788787472005
7572500: 2.18375090275492
7575000: 2.1960595708720536
7577500: 2.20456

8232500: 2.195112026224331
8235000: 2.18757655851695
8237500: 2.212208218842137
8240000: 2.2414908670649236
8242500: 2.1863760492023157
8245000: 2.1861279277168975
8247500: 2.236302804581973
8250000: 2.2097318190701154
8252500: 2.2060679019713887
8255000: 2.2074855468711077
8257500: 2.211163742201669
8260000: 2.1644942753169003
8262500: 2.197504799220027
8265000: 2.2021108793969058
8267500: 2.1695515641144345
8270000: 2.1973421116264498
8272500: 2.2013464313380573
8275000: 2.1827135110388
8277500: 2.1687639128188696
8280000: 2.1778985805657447
8282500: 2.169347562716932
8285000: 2.2080332972565477
8287500: 2.181574450098738
8290000: 2.1898050967527896
8292500: 2.2097368118714313
8295000: 2.2079789053420633
8297500: 2.183301143743554
8300000: 2.209997790443654
8302500: 2.167995946869558
8305000: 2.2171301053494825
8307500: 2.1796620531957975
8310000: 2.197253912200733
8312500: 2.174555990890581
8315000: 2.210818136468226
8317500: 2.194744596311024
8320000: 2.199347575221743
8322500: 2.2

8980000: 2.2093931931622173
8982500: 2.214709013457201
8985000: 2.1804152157841896
8987500: 2.1983827674875456
8990000: 2.198120360350122
8992500: 2.2298376608868034
8995000: 2.175196808576584
8997500: 2.1999878700898616
9000000: 2.198630108152117
9002500: 2.193682554911594
9005000: 2.1828626626608325
9007500: 2.1813585956485904
9010000: 2.182706156433845
9012500: 2.181780126021833
9015000: 2.2003001649768983
9017500: 2.201100687348113
9020000: 2.215549272420455
9022500: 2.170653052232703
9025000: 2.1873298264279657
9027500: 2.207155562055354
9030000: 2.1943650485301505
9032500: 2.1962774678152437
9035000: 2.187574384285479
9037500: 2.204780114791831
9040000: 2.2020238536961223
9042500: 2.2295269456444955
9045000: 2.1980459460190365
9047500: 2.2091994623748623
9050000: 2.2290749908710015
9052500: 2.1866962745481606
9055000: 2.1926430007632898
9057500: 2.1929449150756914
9060000: 2.188481136609097
9062500: 2.194136107697779
9065000: 2.19487798457243
9067500: 2.195967631558983
9070000: 2

9725000: 2.193997092636264
9727500: 2.180430617867684
9730000: 2.1745949326729286
9732500: 2.2100040253327817
9735000: 2.1946562190445102
9737500: 2.202631371726795
9740000: 2.2063086458614896
9742500: 2.180843865993072
9745000: 2.2010926301382026
9747500: 2.2015160988788214
9750000: 2.2346695364738
9752500: 2.199440479156922
9755000: 2.2039025329813664
9757500: 2.191199843737544
9760000: 2.2178319637872734
9762500: 2.187846688591704
9765000: 2.197084269110037
9767500: 2.197978573307699
9770000: 2.221359229452756
9772500: 2.1989785260083723
9775000: 2.204771826705154
9777500: 2.1843981793948584
9780000: 2.1889312336639484
9782500: 2.1977749722344533
9785000: 2.1795051605117566
9787500: 2.2050826536149395
9790000: 2.1968488087459486
9792500: 2.215621027776173
9795000: 2.1702059451414613
9797500: 2.1979504046391467
9800000: 2.2165427540029796
9802500: 2.1884476861175224
9805000: 2.1672010268483843
9807500: 2.2016567934532554
9810000: 2.203126746537734
9812500: 2.2145550439552384
9815000:

10455000: 2.1982629351469933
10457500: 2.208537690128599
10460000: 2.1979884805727976
10462500: 2.189504415283398
10465000: 2.189067206334095
10467500: 2.200019900531185
10470000: 2.1706091703200827
10472500: 2.2093484523345013
10475000: 2.209395616638417
10477500: 2.191588550927688
10480000: 2.182059027467455
10482500: 2.203808276264035
10485000: 2.2123684074197496
10487500: 2.208128643279173
10490000: 2.173094210576038
10492500: 2.19562361641806
10495000: 2.229535412666749
10497500: 2.2036920807799514
10500000: 2.1862161290888884
10502500: 2.2080259736703365
10505000: 2.227828532578994
10507500: 2.2138158478298964
10510000: 2.204256014434659
10512500: 2.1805881522139723
10515000: 2.17770176736676
10517500: 2.1822233864239284
10520000: 2.196924881545865
10522500: 2.2077386380458366
10525000: 2.209971043406701
10527500: 2.2153097697666713
10530000: 2.1824838673581883
10532500: 2.2010312270145027
10535000: 2.1861094821472555
10537500: 2.1945259681769778
10540000: 2.190694427125308
10542

11175000: 2.1860019484344795
11177500: 2.196539395196097
11180000: 2.206783320952435
11182500: 2.2038517853435207
11185000: 2.2246364231012303
11187500: 2.2375727540376236
11190000: 2.200686306369548
11192500: 2.1843839454407594
11195000: 2.168997321566757
11197500: 2.2009563213708447
11200000: 2.1842484065464567
11202500: 2.177344589452354
11205000: 2.1780865299458405
11207500: 2.1935942158407093
11210000: 2.2232118109051062
11212500: 2.1922524381657036
11215000: 2.206054795031645
11217500: 2.1931101599518135
11220000: 2.181093002582083
11222500: 2.2063084269056517
11225000: 2.2092290535265087
11227500: 2.202165926475914
11230000: 2.1958227594288027
11232500: 2.208786155739609
11235000: 2.1968476813666675
11237500: 2.2110507849527865
11240000: 2.1972222219924538
11242500: 2.197362479871633
11245000: 2.190721266975208
11247500: 2.1894797525843797
11250000: 2.182326216478737
11252500: 2.193475055816222
11255000: 2.186434096949441
11257500: 2.1664985421968965
11260000: 2.2002245062468004

11895000: 2.1910718940958684
11897500: 2.1806754951574363
11900000: 2.190909749634412
11902500: 2.2096743672477954
11905000: 2.2044730400552557
11907500: 2.1898951796852812
11910000: 2.2034607686558547
11912500: 2.2042678534984588
11915000: 2.1851284089137097
11917500: 2.1942196097909186
11920000: 2.189907696660684
11922500: 2.1698561445790894
11925000: 2.21247306831029
11927500: 2.1885902146903837
11930000: 2.2031862328247147
11932500: 2.181041748791325
11935000: 2.1842616662687186
11937500: 2.195239959930887
11940000: 2.212951637044245
11942500: 2.186120692321232
11945000: 2.1995031929745967
11947500: 2.1967352801439715
11950000: 2.1925287938847835
11952500: 2.192585057628398
11955000: 2.188512809179267
11957500: 2.204780761198122
11960000: 2.1719008141634415
11962500: 2.2045426624161855
11965000: 2.210787746371055
11967500: 2.21726657955014
11970000: 2.217114094690401
11972500: 2.2175776102104967
11975000: 2.1903830407833564
11977500: 2.2163154451214537
11980000: 2.2104330688106772


12615000: 2.217660295720003
12617500: 2.146147916025045
12620000: 2.1971895764068683
12622500: 2.199636511170134
12625000: 2.228930046850321
12627500: 2.1993762578283036
12630000: 2.2013494699585197
12632500: 2.1868105149998955
12635000: 2.1964396179938803
12637500: 2.1952847617013114
12640000: 2.1924812095505852
12642500: 2.1780345272044745
12645000: 2.2010293286673877
12647500: 2.1853801969362765
12650000: 2.1966851159017913
12652500: 2.182509460862802
12655000: 2.2185028849815835
12657500: 2.2024625120114307
12660000: 2.1879144590728137
12662500: 2.1713448050070783
12665000: 2.2217815383356445
12667500: 2.219511313706028
12670000: 2.1831867326279077
12672500: 2.1867940366268157
12675000: 2.2085118460411928
12677500: 2.206721752273793
12680000: 2.1960250695140995
12682500: 2.225588308670083
12685000: 2.187065566559227
12687500: 2.175695366762122
12690000: 2.1936722824768142
12692500: 2.2009429434124304
12695000: 2.1983428869928634
12697500: 2.1956426535333904
12700000: 2.193144780275

14050000: 2.1935842788949307
14052500: 2.1702514507332626
14055000: 2.193668896689707
14057500: 2.206439838604051
14060000: 2.1877055176666804
14062500: 2.200284654996833
14065000: 2.2018713892722617
14067500: 2.1964905909129553
14070000: 2.218852054707858
14072500: 2.2002813216374846
14075000: 2.1788389583023227
14077500: 2.245842964795171
14080000: 2.204462469840536
14082500: 2.1980284119138913
14085000: 2.2087554701736996
14087500: 2.1936684581698205
14090000: 2.19279858409142
14092500: 2.2122223305458926
14095000: 2.187463733857992
14097500: 2.198925426298258
14100000: 2.213136902390694
14102500: 2.1896010509559085
14105000: 2.1872539266031614
14107500: 2.1885378235456896
14110000: 2.196231744970594
14112500: 2.1945975720882416
14115000: 2.211909764518543
14117500: 2.1727587188993183
14120000: 2.214830623840799
14122500: 2.1921172337872643
14125000: 2.1831351715691234
14127500: 2.1793854803455117
14130000: 2.2073491308153894
14132500: 2.2046993466056124
14135000: 2.2149589525193583

15487500: 2.196060204992489
15490000: 2.1986251099985474
15492500: 2.2161808664701423
15495000: 2.2165669259976366
15497500: 2.189003231452436
15500000: 2.180684159239944
15502500: 2.2201028303224213
15505000: 2.2063393815439576
15507500: 2.205922103293088
15510000: 2.1954694520454017
15512500: 2.207654325816096
15515000: 2.210381681700142
15517500: 2.212384355068207
15520000: 2.1803298108431757
15522500: 2.210944386282746
15525000: 2.1896026796224164
15527500: 2.2092513971182766
15530000: 2.188202385634792
15532500: 2.1836279173286592
15535000: 2.180520824388582
15537500: 2.194939195623203
15540000: 2.192594975476362
15542500: 2.1910361730322547
15545000: 2.214677193213482
15547500: 2.1950149880380048
15550000: 2.2118168431885388
15552500: 2.215134229951975
15555000: 2.188953677367191
15557500: 2.1728335996063386
15560000: 2.194608951344782
15562500: 2.204698813326505
15565000: 2.211399608607195
15567500: 2.213393820426902
15570000: 2.2141266659814485
15572500: 2.1887541050813635
1557

16207500: 2.1758646623212465
16210000: 2.1928406042712076
16212500: 2.2008094942083165
16215000: 2.1860166389115
16217500: 2.215842946938106
16220000: 2.1730453303882054
16222500: 2.223293694428035
16225000: 2.208902869784102
16227500: 2.2198308472730677
16230000: 2.19854214203601
16232500: 2.2100974126737944
16235000: 2.207976046630314
16237500: 2.1880144284695997
16240000: 2.2036834770319413
16242500: 2.194078634588086
16245000: 2.192909027727283
16247500: 2.2090389244410455
16250000: 2.189234514625705
16252500: 2.179793384488748
16255000: 2.211957694681323
16257500: 2.192904990911484
16260000: 2.2061276561143446
16262500: 2.1766408460480826
16265000: 2.1790348982324406
16267500: 2.1905169667029867
16270000: 2.1985273340526894
16272500: 2.201217403460522
16275000: 2.1814046312351616
16277500: 2.181519715031799
16280000: 2.1756101559619516
16282500: 2.205055980049834
16285000: 2.198100186853993
16287500: 2.1731367471266765
16290000: 2.1944083714971736
16292500: 2.2045049999441417
1629

17642500: 2.1972363509693924
17645000: 2.1982118334089007
17647500: 2.217970981646557
17650000: 2.2089786702272844
17652500: 2.187749241079603
17655000: 2.187840494087764
17657500: 2.201826633969132
17660000: 2.225097228312979
17662500: 2.205489421742303
17665000: 2.180030658172101
17667500: 2.1843698055160288
17670000: 2.1937183189148803
17672500: 2.185493577499779
17675000: 2.197784616995831
17677500: 2.1968573443743646
17680000: 2.2187834895386986
17682500: 2.1995877770745023
17685000: 2.1793598228571365
17687500: 2.2009152388086126
17690000: 2.1935349390214802
17692500: 2.215585751071268
17695000: 2.1880380287462353
17697500: 2.1926195020578345
17700000: 2.2083976649508186
17702500: 2.1914727310745086
17705000: 2.1946831608305173
17707500: 2.183110387957826
17710000: 2.194858516600667
17712500: 2.192001062388323
17715000: 2.1947823165630806
17717500: 2.1849885429654803
17720000: 2.216694849729538
17722500: 2.223346305501704
17725000: 2.1927369787984965
17727500: 2.200053378148955
1

18360000: 2.1878721319899266
18362500: 2.197185142672792
18365000: 2.198124520267759
18367500: 2.194431112980356
18370000: 2.213223843063627
18372500: 2.228273796670291
18375000: 2.1934972030775888
18377500: 2.226602835922825
18380000: 2.173256644545769
18382500: 2.1999809122815424
18385000: 2.1888290402840593
18387500: 2.1943219863638586
18390000: 2.19879143031276
18392500: 2.167131693144234
18395000: 2.203285256332281
18397500: 2.198829143509573
18400000: 2.1923667634020045
18402500: 2.1970288628218126
18405000: 2.238988952247464
18407500: 2.1934910670835146
18410000: 2.1902188518825843
18412500: 2.172090831459785
18415000: 2.191171999245274
18417500: 2.18402690704988
18420000: 2.202088968607844
18422500: 2.1901381666562996
18425000: 2.18667251650168
18427500: 2.199278310853608
18430000: 2.192919176573656
18432500: 2.197924996638785
18435000: 2.1823396382283193
18437500: 2.1824599331738996
18440000: 2.201830787926304
18442500: 2.18999012368066
18445000: 2.189524331141491
18447500: 2.

19800000: 2.2042979857143092
19802500: 2.2031468009462163
19805000: 2.1903390727481065
19807500: 2.1919013786072634
19810000: 2.19642585005079
19812500: 2.1774504988777394
19815000: 2.21136386637785
19817500: 2.1833508163082356
19820000: 2.185589562143598
19822500: 2.22252490617791
19825000: 2.1923026897469344
19827500: 2.186195780793015
19830000: 2.1930361871816673
19832500: 2.228275844758871
19835000: 2.205900231794435
19837500: 2.1806386853967394
19840000: 2.21166510569806
19842500: 2.1983759322944953
19845000: 2.213735781883707
19847500: 2.183583147063547
19850000: 2.183036882780036
19852500: 2.229831783138976
19855000: 2.1684717774391173
19857500: 2.2138633471362446
19860000: 2.21723395768477
19862500: 2.191140893770724
19865000: 2.19108657946392
19867500: 2.1910140004693246
19870000: 2.181553053004401
19872500: 2.1801211371713753
19875000: 2.17911059296861
19877500: 2.215697421346392
19880000: 2.204668646564289
19882500: 2.1798168485261957
19885000: 2.2021099584443227
19887500: 2

In [13]:
# import matplotlib.pyplot as plt

# plt.plot(range(len(losses)), losses)
val_loss

2.2049302531748403