In [50]:
import sys
# sys.path.append('/home/ajhnam/plinko_nn/src')
sys.path.append('/home/plinkoproj/plinko_nn/src')
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
from tqdm import tqdm_notebook as tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from plinko.misc import data_utils
from plinko.misc.simulation_dataset import SimulationDataset
from plinko.model.predictor_gru import GRUPredictor

In [60]:
# Run on CUDA when PyTorch reports an available GPU, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Fix the global NumPy and PyTorch RNG seeds.
np.random.seed(0)
torch.manual_seed(0)

# Make float32 the default dtype for newly created floating-point tensors.
# torch.set_default_tensor_type() is deprecated; torch.set_default_dtype() is the
# supported way to pin the default dtype.
torch.set_default_dtype(torch.float32)

# Machine epsilon of Python's built-in float.
# NOTE(review): this is the float64 epsilon while the tensors above default to
# float32 (see torch.finfo(torch.float32).eps) — confirm this is the intended value.
epsilon = sys.float_info.epsilon

In [52]:
# Root of the plinko_nn checkout; every data file below is resolved against it.
repo_path = '/home/plinkoproj/plinko_nn'

# Read the ball, environment and collision feather files (in that order) into
# DataFrames.
df_ball, df_env, df_col = [
    pd.read_feather(repo_path + relative_path)
    for relative_path in (
        '/data/simulations/sim_ball.feather',
        '/data/simulations/sim_environment.feather',
        '/data/simulations/sim_collisions.feather',
    )
]

In [53]:
# Build the per-run summary table from the ball and collision frames.
sim_data = data_utils.get_sim_data(df_ball, df_col)

# Keep only runs with run index at most 20, duration below 100 and exactly one
# collision.
selected_runs = sim_data.loc[
    (sim_data['run'] <= 20)
    & (sim_data['duration'] < 100)
    & (sim_data['num_collisions'] == 1)
]

# Derive the task frames for the selected runs, then hand them to
# data_utils.to_tensors together with the target device.
simulations, environments = data_utils.create_task_df(
    selected_runs, df_ball, df_env
)
states, envs = data_utils.to_tensors(simulations, environments, device)

In [54]:
# Show the filtered run table (simulation, run, duration, num_collisions) as a
# quick check of the selection above.
print(selected_runs)
# print(states)

      simulation  run  duration  num_collisions
20         sim_1    0        45               1
21         sim_1    1        45               1
22         sim_1    2        45               1
23         sim_1    3        45               1
24         sim_1    4        45               1
...          ...  ...       ...             ...
19995    sim_999   15        45               1
19996    sim_999   16        45               1
19997    sim_999   17        45               1
19998    sim_999   18        45               1
19999    sim_999   19        45               1

[4500 rows x 4 columns]


In [62]:
def get_logp_loss(gm, targets):
    """Return the negated mean log-probability of ``targets`` under ``gm``.

    Args:
        gm: object exposing ``log_p(targets)``; the value it returns must
            support ``.mean()``.
        targets: passed unchanged to ``gm.log_p``.

    Returns:
        ``-(gm.log_p(targets).mean())``.
    """
    mean_log_prob = gm.log_p(targets).mean()
    return -mean_log_prob

def get_mu_mse_loss(gm, targets):
    """Return the mean-squared error between ``gm.mu[:, :, 0]`` and ``targets``.

    Args:
        gm: object with a ``mu`` attribute indexable as ``mu[:, :, 0]``
            (presumably index 0 picks the first mixture component — confirm).
        targets: tensor compared against that slice by ``F.mse_loss``.

    Returns:
        The value of ``F.mse_loss`` with its default reduction.
    """
    predicted = gm.mu[:, :, 0]
    return F.mse_loss(predicted, targets)

def get_mu_position(gm):
    """Return ``gm.mu`` with index 0 selected along its third axis.

    Args:
        gm: object with a ``mu`` attribute indexable as ``mu[:, :, 0]``
            (presumably index 0 picks the first mixture component — confirm).

    Returns:
        ``gm.mu[:, :, 0]``.
    """
    first_slice = gm.mu[:, :, 0]
    return first_slice

In [56]:
# Build the predictor, its optimizer and the shuffled, batched training data.
# NOTE(review): env_size=11 / state_size=2 presumably match the feature widths of
# `envs` and `states` — confirm against data_utils.to_tensors.
model = GRUPredictor(env_size=11, state_size=2, num_gaussians=1).to(device)
optimizer = optim.Adam(model.parameters(), weight_decay=.001)
dataset = SimulationDataset(envs, states)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

# Sanity check: print the first sample of the first batch.
# Stop after that batch; walking the whole loader only to print batch 0 is wasted
# work.
for batch_i, batch in enumerate(dataloader):
    print(batch_i)
    print(batch['states'][0])
    print(batch['targets'][0])
    break

0
tensor([[2.5464, 8.9030],
        [2.5419, 8.8714],
        [2.5373, 8.8339],
        [2.5327, 8.7905],
        [2.5281, 8.7411],
        [2.5235, 8.6857],
        [2.5189, 8.6244],
        [2.5144, 8.5571],
        [2.5098, 8.4839],
        [2.5052, 8.4047],
        [2.5006, 8.3196],
        [2.4960, 8.2286],
        [2.4915, 8.1315],
        [2.4869, 8.0286],
        [2.4823, 7.9196],
        [2.4777, 7.8047],
        [2.4731, 7.6839],
        [2.4686, 7.5571],
        [2.4640, 7.4244],
        [2.4594, 7.2857],
        [2.4548, 7.1410],
        [2.4502, 6.9904],
        [2.4456, 6.8339],
        [2.4411, 6.6714],
        [2.4365, 6.5029],
        [2.4319, 6.3285],
        [2.4273, 6.1482],
        [2.4227, 5.9619],
        [2.4182, 5.7696],
        [2.4136, 5.5714],
        [2.4090, 5.3672],
        [2.4044, 5.1571],
        [2.3998, 4.9410],
        [2.3952, 4.7190],
        [2.3907, 4.4910],
        [2.3861, 4.2571],
        [2.3815, 4.0172],
        [2.3769, 3.7714],
        [2

In [84]:
max_t = simulations.t.max()
epochs = 100
MSE_WEIGHT = 10  # multiplier applied to the MSE term before adding it to the log-p term
losses = []  # one (epoch, batch index, total loss) tuple per optimisation step

# NOTE(review): only one epoch is run and `epochs` is unused in this cell; use
# `range(epochs + 1)` for the full schedule.
for epoch in tqdm(range(1)):
    # Running sums are plain Python floats. Accumulating the loss tensors
    # themselves would keep every batch's autograd graph alive for the whole epoch.
    epoch_loss = 0.0
    epoch_mse_loss = 0.0
    epoch_logp_loss = 0.0
    for batch_i, batch in enumerate(dataloader):
        optimizer.zero_grad()

        gm = model(batch['envs'], batch['states'], 0)
        targets = batch['targets']

        logp_loss = get_logp_loss(gm, targets)
        mse_loss = MSE_WEIGHT * get_mu_mse_loss(gm, targets)
        loss = logp_loss + mse_loss
        # NOTE(review): retain_graph=True is kept from the original; nothing visible
        # here backpropagates through this graph a second time — confirm whether the
        # model needs it and drop it otherwise.
        loss.backward(retain_graph=True)
        optimizer.step()

        loss_value = loss.item()
        epoch_loss += loss_value
        epoch_logp_loss += logp_loss.item()
        epoch_mse_loss += mse_loss.item()
        losses.append((epoch, batch_i, loss_value))

    # Show the prediction and target for the first sample of the epoch's last batch.
    gm_mu = get_mu_position(gm)
    print('batch', batch_i, '\n', gm_mu[0], '\n', targets[0])

    # Report the summed per-batch losses every fifth epoch.
    if epoch%5 == 0:
        print('Epoch {} | logp: {} | mse: {} | total: {}'.format(epoch,
                                                                 round(epoch_logp_loss, 4),
                                                                 round(epoch_mse_loss, 4),
                                                                 round(epoch_loss, 4)))

# torch.save(model.state_dict(), 'gru.model')

HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

batch 70 
 tensor([[7.4655, 8.7820],
        [7.4656, 8.8072],
        [7.4694, 8.7627],
        [7.4781, 8.7125],
        [7.4882, 8.6537],
        [7.4970, 8.5934],
        [7.5046, 8.5271],
        [7.5116, 8.4546],
        [7.5177, 8.3755],
        [7.5231, 8.2889],
        [7.5279, 8.1943],
        [7.5322, 8.0919],
        [7.5362, 7.9819],
        [7.5399, 7.8652],
        [7.5437, 7.7424],
        [7.5477, 7.6145],
        [7.5520, 7.4821],
        [7.5567, 7.3457],
        [7.5617, 7.2056],
        [7.5672, 7.0616],
        [7.5731, 6.9135],
        [7.5792, 6.7610],
        [7.5850, 6.6040],
        [7.5899, 6.4424],
        [7.5932, 6.2757],
        [7.5947, 6.1027],
        [7.5948, 5.9225],
        [7.5942, 5.7352],
        [7.5939, 5.5422],
        [7.5949, 5.3449],
        [7.5975, 5.1431],
        [7.6012, 4.9349],
        [7.6045, 4.7178],
        [7.6059, 4.4922],
        [7.6054, 4.2617],
        [7.6058, 4.0256],
        [7.6115, 3.7759],
        [7.6176, 3.5280],
 

ValueError: Expected 1D or 2D array, got 3D array instead

In [10]:
dataset = SimulationDataset(envs, states)
dataloader = DataLoader(dataset, batch_size=len(envs), shuffle=True)
for batch in dataloader:
    with torch.no_grad():
        inter_gm, extra_gm, samples = model(batch['envs'], batch['states'][:, :1], 100)
        targets = batch['targets'][:,1:101]