In [1]:
import sys
sys.path.append('/home/ajhnam/plinko_nn/src')
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
from tqdm import tqdm_notebook as tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from plinko.misc import data_utils
from plinko.misc.simulation_dataset import SimulationDataset
from plinko.model.predictor_gru import GRUPredictor

In [2]:
# Run configuration: device index and RNG seeds for reproducibility.
SEED = 0
device = 0  # index passed to model.to(...) and data_utils.to_tensors(...)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.set_default_tensor_type('torch.FloatTensor')
epsilon = sys.float_info.epsilon  # machine epsilon for Python floats

In [3]:
# Load the three simulation tables (ball, environment, collisions) from SIM_DIR.
SIM_DIR = '/home/ajhnam/plinko_nn/data/simulations'
df_ball = pd.read_feather(SIM_DIR + '/sim_ball.feather')
df_env = pd.read_feather(SIM_DIR + '/sim_environment.feather')
df_col = pd.read_feather(SIM_DIR + '/sim_collisions.feather')

In [4]:
# Keep simulations with exactly one collision, duration < 50 and run <= 2,
# then build the task frames and convert them to tensors (via data_utils,
# using `device`).
sim_data = data_utils.get_sim_data(df_ball, df_col)
is_selected = (
    (sim_data.num_collisions == 1)
    & (sim_data.duration < 50)
    & (sim_data.run <= 2)
)
selected_runs = sim_data[is_selected]
simulations, environments = data_utils.create_task_df(selected_runs, df_ball, df_env)
states, envs = data_utils.to_tensors(simulations, environments, device)

In [5]:
def get_logp_loss(gm, targets):
    """Return the negated mean of ``gm.log_p(targets)``.

    ``gm`` must expose a ``log_p`` method returning a tensor of log-probabilities
    for ``targets`` (presumably the model's predicted mixture -- confirm against
    GRUPredictor). The mean is taken over every element of that tensor.
    """
    log_likelihood = gm.log_p(targets)
    return log_likelihood.mean().neg()

def get_mu_mse_loss(gm, targets):
    """Return the MSE between ``gm.mu`` at index 0 of dim 2 and ``targets``.

    NOTE(review): assumes dim 2 of ``gm.mu`` indexes mixture components, so this
    compares only the first component's means -- confirm against GRUPredictor.
    """
    first_mu = gm.mu.select(2, 0)
    return F.mse_loss(first_mu, targets)

In [8]:
# Train the GRU predictor on the selected simulations and save its weights.
model = GRUPredictor(env_size=11, state_size=2, num_gaussians=1).to(device)
optimizer = optim.Adam(model.parameters(), weight_decay=.001)
dataset = SimulationDataset(envs, states)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

MSE_WEIGHT = 10         # multiplier on the mean-trajectory MSE term
USE_LOGP_LOSS = False   # set True to add the negative log-likelihood term
LOG_EVERY = 10          # print a summary every LOG_EVERY epochs

max_t = simulations.t.max()
epochs = 100
losses = []  # (epoch, batch_index, batch_loss) for every optimizer step
for epoch in tqdm(range(epochs+1)):
    # Running sums are Python floats: accumulating the loss tensors themselves
    # would chain every batch's autograd graph together and keep it alive
    # until the end of the epoch.
    epoch_loss = 0.0
    epoch_mse_loss = 0.0
    epoch_logp_loss = 0.0
    for batch_i, batch in enumerate(dataloader):
        optimizer.zero_grad()

        gm = model(batch['envs'], batch['states'], 0)
        targets = batch['targets']

        logp_loss = get_logp_loss(gm, targets) if USE_LOGP_LOSS else 0
        mse_loss = MSE_WEIGHT * get_mu_mse_loss(gm, targets)
        loss = logp_loss + mse_loss
        # NOTE(review): retain_graph=True is kept from the original. Nothing in
        # this cell backpropagates twice, so it may be removable -- confirm
        # whether GRUPredictor carries graph-attached state across batches.
        loss.backward(retain_graph=True)
        optimizer.step()

        batch_loss = float(loss)
        epoch_loss += batch_loss
        epoch_logp_loss += float(logp_loss)
        epoch_mse_loss += float(mse_loss)
        losses.append((epoch, batch_i, batch_loss))
    if epoch % LOG_EVERY == 0:
        # Reported values are sums over the epoch's batches, not per-batch means.
        print('Epoch {} | logp: {} | mse: {} | total: {}'.format(epoch,
                                                                 round(epoch_logp_loss, 4),
                                                                 round(epoch_mse_loss, 4),
                                                                 round(epoch_loss, 4)))

torch.save(model.state_dict(), 'gru.model')

HBox(children=(IntProgress(value=0, max=101), HTML(value='')))

Epoch 0 | logp: 0.0 | mse: 2028.1729 | total: 2028.1729
Epoch 10 | logp: 0.0 | mse: 4.0664 | total: 4.0664
Epoch 20 | logp: 0.0 | mse: 0.7547 | total: 0.7547
Epoch 30 | logp: 0.0 | mse: 0.4034 | total: 0.4034
Epoch 40 | logp: 0.0 | mse: 0.2526 | total: 0.2526
Epoch 50 | logp: 0.0 | mse: 0.1734 | total: 0.1734
Epoch 60 | logp: 0.0 | mse: 0.1847 | total: 0.1847
Epoch 70 | logp: 0.0 | mse: 0.0965 | total: 0.0965
Epoch 80 | logp: 0.0 | mse: 0.1815 | total: 0.1815
Epoch 90 | logp: 0.0 | mse: 0.1019 | total: 0.1019
Epoch 100 | logp: 0.0 | mse: 0.1164 | total: 0.1164



In [10]:
# Evaluate the trained model (no gradients) on the dataset.
dataset = SimulationDataset(envs, states)
# batch_size=len(envs) presumably yields a single batch covering every
# simulation (assumes len(dataset) == len(envs) -- confirm). shuffle=True only
# permutes the order within that batch.
dataloader = DataLoader(dataset, batch_size=len(envs), shuffle=True)
for batch in dataloader:
    with torch.no_grad():
        # Only the first state step ([:, :1]) is passed in. The third argument
        # (100) and the three return values are defined by GRUPredictor;
        # presumably 100 is the number of predicted steps -- TODO confirm.
        inter_gm, extra_gm, samples = model(batch['envs'], batch['states'][:, :1], 100)
        # Targets for steps 1..100 along dim 1, presumably aligned with the
        # 100 predicted steps -- confirm.
        targets = batch['targets'][:,1:101]