# Fine-tune a pretrained model

Take an existing model that has already been trained on many datasets and continue training it (fine-tune it) on a new dataset.

In [45]:
import os

import keras

### Set model paths: the pretrained base model to load, and where to save the fine-tuned model so it can be loaded later.

In [53]:
# Directory holding the model files. Set the DONKEY_MODELS_DIR environment
# variable to run this notebook against a different model directory instead
# of editing hard-coded absolute paths.
MODELS_DIR = os.environ.get('DONKEY_MODELS_DIR', '/home/wroscoe/models')
# Pretrained model that fine-tuning starts from.
base_model = os.path.join(MODELS_DIR, 'all_lined_tracks_categorical.h5')
# Output location, passed as saved_model_path to kl.train below.
new_model = os.path.join(MODELS_DIR, 'rally2.h5')

### Train on new data

In [54]:
import donkeycar as dk
from donkeycar.parts.keras import KerasCategorical
from donkeycar.parts.datastore import TubGroup

In [55]:
# Vehicle config; BATCH_SIZE and TRAIN_TEST_SPLIT are read from it when the
# data generators are built and when training runs. Set DONKEY_CONFIG to use a
# config file at a different location.
CONFIG_PATH = os.environ.get('DONKEY_CONFIG', '/home/wroscoe/d2/config.py')
cfg = dk.config.load_config(CONFIG_PATH)

loading config file: /home/wroscoe/d2/config.py
config loaded


### Load base model.

In [56]:
# Create a categorical pilot, load the pretrained base model into it, and
# print the network architecture so it can be checked before fine-tuning.
kl = KerasCategorical()
kl.load(base_model)
kl.model.summary()

____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
img_in (InputLayer)              (None, 120, 160, 3)   0                                            
____________________________________________________________________________________________________
conv2d_1 (Conv2D)                (None, 58, 78, 24)    1824        img_in[0][0]                     
____________________________________________________________________________________________________
conv2d_2 (Conv2D)                (None, 27, 37, 32)    19232       conv2d_1[0][0]                   
____________________________________________________________________________________________________
conv2d_3 (Conv2D)                (None, 12, 17, 64)    51264       conv2d_2[0][0]                   
___________________________________________________________________________________________

In [57]:
# Confirm which layers will have their weights updated while fine-tuning:
# pair every layer's name with its `trainable` flag.
[(lyr.name, lyr.trainable) for lyr in kl.model.layers]

[('img_in', False),
 ('conv2d_1', True),
 ('conv2d_2', True),
 ('conv2d_3', True),
 ('conv2d_4', True),
 ('conv2d_5', True),
 ('flattened', True),
 ('dense_1', True),
 ('dropout_1', True),
 ('dense_2', True),
 ('dropout_2', True),
 ('angle_out', True),
 ('throttle_out', True)]

### Load new dataset

In [58]:
# Tub paths/globs to train on, joined with ',' into the single string that
# TubGroup receives. The glob is expanded to the individual rally_A_NN tubs
# (see the tubpaths TubGroup prints).
tub_names = ','.join([
    '/home/wroscoe/data/rally/rally_A*',
])

# Record field fed to the network as input.
X_keys = ['cam/image_array']
# Record fields used as training targets (angle first, then throttle).
y_keys = ['user/angle', 'user/throttle']

def rt(record):
    """Record transform applied by the training/validation generators.

    Replaces ``record['user/angle']`` with ``dk.utils.linear_bin`` of its
    value and returns the same dict object (it is mutated in place).
    NOTE(review): presumably this bins the angle to match the categorical
    ``angle_out`` layer -- confirm against donkeycar's linear_bin.
    """
    raw_angle = record['user/angle']
    record['user/angle'] = dk.utils.linear_bin(raw_angle)
    return record

# Load every tub matching tub_names as one group; its .df holds the combined
# records (counted in the training cell below).
tubgroup = TubGroup(tub_names)
# Note: a bare `tubgroup.df` in the middle of a cell is never displayed by
# Jupyter (only a cell's last expression is). Inspect the records in a cell
# of their own, e.g. `tubgroup.df.head()`.

# Create the data generators used to train the network: rt is applied to each
# record, and cfg sets the batch size and the train/validation split.
train_gen, val_gen = tubgroup.get_train_val_gen(X_keys, y_keys, record_transform=rt,
                                                batch_size=cfg.BATCH_SIZE,
                                                train_frac=cfg.TRAIN_TEST_SPLIT)


TubGroup:tubpaths: ['/home/wroscoe/data/rally/rally_A_06', '/home/wroscoe/data/rally/rally_A_02', '/home/wroscoe/data/rally/rally_A_08', '/home/wroscoe/data/rally/rally_A_05', '/home/wroscoe/data/rally/rally_A_07', '/home/wroscoe/data/rally/rally_A_03', '/home/wroscoe/data/rally/rally_A_01', '/home/wroscoe/data/rally/rally_A_04']
path_in_tub: /home/wroscoe/data/rally/rally_A_06
Tub exists: /home/wroscoe/data/rally/rally_A_06
path_in_tub: /home/wroscoe/data/rally/rally_A_02
Tub exists: /home/wroscoe/data/rally/rally_A_02
path_in_tub: /home/wroscoe/data/rally/rally_A_08
Tub exists: /home/wroscoe/data/rally/rally_A_08
path_in_tub: /home/wroscoe/data/rally/rally_A_05
Tub exists: /home/wroscoe/data/rally/rally_A_05
path_in_tub: /home/wroscoe/data/rally/rally_A_07
Tub exists: /home/wroscoe/data/rally/rally_A_07
path_in_tub: /home/wroscoe/data/rally/rally_A_03
Tub exists: /home/wroscoe/data/rally/rally_A_03
path_in_tub: /home/wroscoe/data/rally/rally_A_01
Tub exists: /home/wroscoe/data/rally/

In [59]:
# Size one training epoch from the number of records in the loaded tubs.
total_records = len(tubgroup.df)
if total_records == 0:
    raise ValueError('No training records were loaded from the tubs.')
total_train = int(total_records * cfg.TRAIN_TEST_SPLIT)
total_val = total_records - total_train
print('train: %d, validation: %d' % (total_train, total_val))
# Floor division yields 0 steps when there are fewer training records than
# BATCH_SIZE; always run at least one step per epoch.
steps_per_epoch = max(1, total_train // cfg.BATCH_SIZE)
print('steps_per_epoch', steps_per_epoch)

# Continue training the pretrained model on the new data; saved_model_path
# tells kl.train where to write the model. Assigning the returned Keras
# History keeps the per-epoch metrics available and avoids printing its repr.
history = kl.train(train_gen,
                   val_gen,
                   saved_model_path=new_model,
                   steps=steps_per_epoch,
                   train_split=cfg.TRAIN_TEST_SPLIT)

train: 38809, validation: 9703
steps_per_epoch 303
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 00017: early stopping


<keras.callbacks.History at 0x7f20cd0cd5f8>