# Behavior cloning on a pseudo-driving dataset

## Imports

In [1]:
import json
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

In [2]:
import sys
sys.path.append('../bishop1995_notes/modules_for_nn_training')
from learner import Learner
from callbacks import CallbackHandler
from utility_callbacks import LossCallback, AccuracyCallback
from trainer import Trainer

## Load data

In [3]:
# The top-level JSON value must unpack into exactly two items, presumably
# [states, actions]; each is converted to a NumPy array.
# JSON text is UTF-8 (RFC 8259), so decode it explicitly instead of relying on
# the platform's default encoding.
with open('pseudo_driving_dataset.json', 'r', encoding='utf-8') as json_f:
    states, actions = map(np.array, json.load(json_f))

In [4]:
# Flatten every state into a single 200-long feature vector (5 * 40 values).
# NOTE(review): presumably each raw state is a 5x40 grid — confirm against the
# dataset generator.
states = np.reshape(states, (-1, 5 * 40))
# Keep only column 1 of the action array as the target, as an (n, 1) column.
actions = np.reshape(actions[:, 1], (-1, 1))
print(states.shape, actions.shape)

(113, 200) (113, 1)


In [5]:
# np.array on JSON numbers gives float64 (or int64 for integer data), and
# torch.from_numpy keeps that dtype. nn.Linear parameters are float32 by
# default, so cast inputs and targets to float32 to match the model and the
# MSE loss.
train_ds = TensorDataset(
    torch.from_numpy(states).float(),
    torch.from_numpy(actions).float(),
)
train_dl = DataLoader(train_ds, batch_size=10, shuffle=True)

## Define neural net

In [6]:
class NN(nn.Module):
    """Multilayer perceptron with three LeakyReLU hidden layers of equal width.

    Args:
        input_len: number of input features (size of the last input dimension).
        output_len: number of outputs produced per sample.
        num_neurons: width of every hidden layer.
    """

    def __init__(self, input_len, output_len, num_neurons):
        super().__init__()

        widths = [input_len, num_neurons, num_neurons, num_neurons]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.LeakyReLU()]
        # The output projection has no activation after it.
        layers.append(nn.Linear(num_neurons, output_len))

        self.main = nn.Sequential(*layers)

    def forward(self, xb):
        """Apply the network to ``xb``; its last dimension must be ``input_len``."""
        return self.main(xb)

In [7]:
def get_model(*, lr=10, **kwargs):
    """Build an ``NN`` together with an Adam optimizer over its parameters.

    Args:
        lr: learning rate passed to Adam. Defaults to 10, the value that was
            previously hard-coded. NOTE(review): presumably the LR finder
            overrides it — confirm in ``Trainer.find_lr``.
        **kwargs: forwarded unchanged to ``NN`` (``input_len``, ``output_len``,
            ``num_neurons``).

    Returns:
        A ``(model, optimizer)`` tuple.
    """
    # Named ``model`` (not ``nn``) so the local does not shadow ``torch.nn``.
    model = NN(**kwargs)
    return model, optim.Adam(model.parameters(), lr=lr)

## Find a suitable learning rate (learning-rate finder)

In [None]:
# Build a fresh model/optimizer pair and run the learning-rate finder on it.
# 200 inputs = the flattened 5 * 40 state; 1 output = the single action column.
model, opt = get_model(input_len=200, output_len=1, num_neurons=10)
loss = nn.MSELoss()  # treats the action as a continuous regression target
# NOTE(review): train_dl is passed twice — presumably as both the training and
# the validation loader; confirm against Learner's signature. If so, any
# "validation" metric is measured on the training data.
learn = Learner(train_dl, train_dl, model, loss, opt)

loss_cb = LossCallback()
# NOTE(review): an accuracy metric is unusual for an MSE regression target;
# check what AccuracyCallback computes before relying on its numbers.
acc_cb = AccuracyCallback()
cb_handler = CallbackHandler(cbs=[loss_cb, acc_cb])

trainer = Trainer(learn=learn, cb_handler=cb_handler)

# find_lr is defined in the project's trainer module (not visible here).
# Presumably beta smooths the recorded loss, final_value is the largest LR
# tried, and num_itr is the number of steps — TODO confirm.
trainer.find_lr(beta=0.98, final_value=10, num_itr=300)

## Train neural net

## Analyze results