In [1]:
import os
import sys
sys.path.append('../src')
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
from tqdm import tqdm_notebook as tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from plinko.misc import data_utils
from plinko.misc.simulation_dataset import SimulationDataset
from plinko.model.predictor_gru import GRUPredictor

def loaddata(run_indices = range(20), outdf = False):
    """Load the simulation tables and convert the selected runs to tensors.

    Reads the ball, environment and collision feather files from
    ``../data/simulations``, keeps the simulations that have exactly one
    collision, a ``duration`` below 50 and a ``run`` contained in
    ``run_indices``, and hands them to ``data_utils.to_tensors``.

    This function reads the module-level ``device``; it must be defined
    before the function is called.

    Args:
        run_indices: iterable of run indices to keep (any iterable is
            accepted; it is materialised into a list before the lookup).
        outdf: forwarded to ``data_utils.to_tensors``. When true, that call
            is unpacked into four values, which replace the data frames
            produced by ``data_utils.create_task_df``.

    Returns:
        Tuple ``(states, envs, simulations, environments)``.
    """
    df_ball = pd.read_feather('../data/simulations/sim_ball.feather')
    df_env = pd.read_feather('../data/simulations/sim_environment.feather')
    df_col = pd.read_feather('../data/simulations/sim_collisions.feather')

    sim_data = data_utils.get_sim_data(df_ball, df_col)

    # np.isin replaces the deprecated np.in1d; list() makes sets and
    # generators behave like any other collection of run indices.
    wanted_runs = list(run_indices)

    # TODO (from original author): change this to dur < 940 (max, 15s fall)
    selected_runs = sim_data[(sim_data.num_collisions == 1)
                             & (sim_data.duration < 50)
                             & np.isin(sim_data.run, wanted_runs)]
    simulations, environments = data_utils.create_task_df(selected_runs, df_ball, df_env, append_t0 = False)
    if outdf:
        states, envs, simulations, environments = data_utils.to_tensors(simulations, environments, device, outdf)
    else:
        states, envs = data_utils.to_tensors(simulations, environments, device, outdf)
    return states, envs, simulations, environments

def get_logp_loss(gm, targets):
    """Return the negated mean of ``gm.log_p(targets)``.

    Args:
        gm: object exposing a ``log_p(targets)`` method.
        targets: values passed through to ``gm.log_p``.
    """
    log_likelihood = gm.log_p(targets)
    return -log_likelihood.mean()

def get_mu_mse_loss(gm, targets):
    """Mean squared error between ``gm.mu[:, :, 0]`` and ``targets``.

    Args:
        gm: object exposing a ``mu`` tensor that is indexed on its third axis.
        targets: tensor compared against the selected slice of ``gm.mu``.
    """
    predicted = gm.mu[:, :, 0]
    return F.mse_loss(input=predicted, target=targets)



In [2]:
def train_model(model,optimizer,simulations,dataset,savename = 'gru.model', epochs = 100, batch_size = 64):
    """Train ``model`` on ``dataset`` and save its state dict to ``savename``.

    Each batch is scored with ``get_logp_loss`` plus ten times
    ``get_mu_mse_loss``; the summed per-epoch losses are printed every
    fifth epoch.

    Args:
        model: module called as ``model(batch['envs'], batch['states'], 0)``.
        optimizer: optimizer already bound to ``model``'s parameters.
        simulations: unused; kept so existing callers keep working.
        dataset: dataset whose batches provide the keys ``'envs'``,
            ``'states'`` and ``'targets'``.
        savename: path the final ``state_dict`` is written to.
        epochs: the loop runs ``epochs + 1`` passes (0..epochs inclusive),
            matching the original behaviour.
        batch_size: mini-batch size used by the shuffled ``DataLoader``.

    Returns:
        The trained ``model`` (same object that was passed in).
    """
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    for epoch in tqdm(range(epochs + 1)):
        # Accumulate plain Python floats: summing the loss tensors themselves
        # would keep every batch's autograd graph alive until the epoch ends.
        epoch_loss = 0.0
        epoch_mse_loss = 0.0
        epoch_logp_loss = 0.0
        for batch in dataloader:
            optimizer.zero_grad()

            gm = model(batch['envs'], batch['states'], 0)
            targets = batch['targets']

            logp_loss = get_logp_loss(gm, targets)
            mse_loss = 10 * get_mu_mse_loss(gm, targets)
            loss = logp_loss + mse_loss
            # NOTE(review): retain_graph=True is kept from the original; it is
            # only needed if the model reuses graph state across batches.
            loss.backward(retain_graph=True)
            optimizer.step()
            epoch_loss += loss.item()
            epoch_logp_loss += logp_loss.item()
            epoch_mse_loss += mse_loss.item()
        if epoch % 5 == 0:
            print('Epoch {} | logp: {} | mse: {} | total: {}'.format(epoch,
                                                                     round(epoch_logp_loss, 4),
                                                                     round(epoch_mse_loss, 4),
                                                                     round(epoch_loss, 4)))

    torch.save(model.state_dict(), savename)
    return model

In [10]:
def simulate_model(model,dataset,sim_df, env_df,modelname = 'gru1'):
    """Sample from ``model`` for the whole ``dataset`` in a single batch.

    The model is called with the batch environments, the first time step of
    the batch states (``states[:, :1]``) and the constant ``100``; the third
    element of its return value is passed to
    ``data_utils.create_simdata_from_samples``.

    Args:
        model: trained model; called under ``torch.no_grad()``.
        dataset: dataset whose batches provide ``'envs'`` and ``'states'``.
        sim_df: forwarded to ``data_utils.create_simdata_from_samples``.
        env_df: forwarded to ``data_utils.create_simdata_from_samples``.
        modelname: unused; kept so existing callers keep working.

    Returns:
        Tuple ``(df_ball, df_env)``.

    Raises:
        ValueError: if ``dataset`` is empty (previously this path could fall
            through and return ``None``, which breaks tuple unpacking at the
            call site).
    """
    n_items = len(dataset)
    if n_items == 0:
        raise ValueError('simulate_model needs a non-empty dataset')

    # Size the batch from the dataset that was passed in rather than from a
    # notebook-level global, so the function works on any dataset.
    dataloader = DataLoader(dataset, batch_size=n_items, shuffle=True)
    batch = next(iter(dataloader))

    with torch.no_grad():
        _, _, samples = model(batch['envs'], batch['states'][:, :1], 100)
        df_env, df_ball = data_utils.create_simdata_from_samples(samples, batch['envs'],sim_df, env_df)
    return df_ball, df_env

def todf(env_batch):
    """Unfinished stub: ``env_batch`` is never read and nothing is returned."""
    # NOTE(review): the adjacent literals '' '' are joined by Python into one
    # empty string, so this list has a single element -- confirm the intended
    # column names before building on this function.
    columns = ['' '']

In [4]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Seed NumPy and PyTorch so repeated runs draw the same random numbers.
np.random.seed(0)
torch.manual_seed(0)
# torch.set_default_tensor_type is deprecated; set_default_dtype is the
# supported way to make float32 the default floating-point dtype.
torch.set_default_dtype(torch.float32)
epsilon = sys.float_info.epsilon

# Load runs 0-19 as tensors and also keep the data frames that loaddata
# returns alongside them, then wrap the tensors in a dataset.
states, envs, sim_df, env_df = loaddata(
    run_indices=range(20),
    outdf=True,
)
dataset = SimulationDataset(envs, states)


In [5]:
# Build the predictor: the environment input size is the second dimension of
# `envs`; the model is moved to `device` before training.
model = GRUPredictor(
    env_size=envs.shape[1],
    state_size=2,
    num_gaussians=4,
).to(device)

# Train with Adam and a weight decay of 0.001.
# (Earlier alternative: optim.SGD(model.parameters(), lr=.001))
optimizer = optim.Adam(model.parameters(), weight_decay=.001)
model = train_model(model, optimizer, sim_df, dataset)


HBox(children=(IntProgress(value=0, max=101), HTML(value='')))

Epoch 0 | logp: 506.6757 | mse: 4066.1643 | total: 4572.8394
Epoch 5 | logp: -38.1786 | mse: 13.5641 | total: -24.6145
Epoch 10 | logp: -1.6629 | mse: 27.889 | total: 26.2261
Epoch 15 | logp: -120.8749 | mse: 5.5896 | total: -115.2853
Epoch 20 | logp: -45.1586 | mse: 12.8062 | total: -32.3524
Epoch 25 | logp: 1.1531 | mse: 37.1328 | total: 38.2859
Epoch 30 | logp: -132.4912 | mse: 4.6821 | total: -127.8091
Epoch 35 | logp: -70.6819 | mse: 13.9931 | total: -56.6888
Epoch 40 | logp: -110.1565 | mse: 7.6125 | total: -102.544
Epoch 45 | logp: -139.7759 | mse: 4.1942 | total: -135.5817
Epoch 50 | logp: -158.36 | mse: 7.6127 | total: -150.7473
Epoch 55 | logp: -73.8411 | mse: 8.5897 | total: -65.2514
Epoch 60 | logp: -28.7351 | mse: 45.9127 | total: 17.1777
Epoch 65 | logp: -42.1858 | mse: 27.8863 | total: -14.2994
Epoch 70 | logp: -44.1223 | mse: 11.8533 | total: -32.2691
Epoch 75 | logp: -133.9275 | mse: 9.6009 | total: -124.3265
Epoch 80 | logp: -110.3693 | mse: 22.252 | total: -88.1173
E

In [11]:
# Sample from the trained model; the call yields the ball and environment frames.
df_ball, df_env = simulate_model(model, dataset, sim_df=sim_df, env_df=env_df)

TypeError: cannot unpack non-iterable NoneType object

In [8]:
# Cast both position columns to float before inspecting the frames.
for coord in ("px", "py"):
    df_ball[coord] = df_ball[coord].astype(float)

# Summarise the column dtypes and sizes of both frames.
df_ball.info()
df_env.info()

# print(df_ball["t"])
# print(df_ball["py"] )
# print(df_env['triangle_x'])

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 450000 entries, 0 to 449999
Data columns (total 5 columns):
simulation    450000 non-null object
run           450000 non-null int64
t             450000 non-null int64
px            450000 non-null float64
py            450000 non-null float64
dtypes: float64(2), int64(2), object(1)
memory usage: 17.2+ MB
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 4500 entries, 0 to 4499
Data columns (total 9 columns):
triangle_x     4500 non-null float32
triangle_y     4500 non-null float32
triangle_r     4500 non-null float32
rectangle_x    4500 non-null float32
rectangle_y    4500 non-null float32
rectangle_r    4500 non-null float32
pentagon_x     4500 non-null float32
pentagon_y     4500 non-null float32
pentagon_r     4500 non-null float32
dtypes: float32(9)
memory usage: 158.3 KB


In [9]:
# Write the sampled trajectories and their environments to feather files.
samples_path = '../experiments/gru1/batch1_samp.feather'
envs_path = '../experiments/gru1/batch1_envs.feather'
# to_feather does not create missing directories, so make sure the target
# folder exists first (no-op when it is already there).
os.makedirs(os.path.dirname(samples_path), exist_ok=True)
os.makedirs(os.path.dirname(envs_path), exist_ok=True)
df_ball.to_feather(samples_path)
df_env.to_feather(envs_path)