diff --git a/examples/supervised_learning/evaluate.py b/examples/supervised_learning/evaluate.py index 818e99c07..99cee7341 100755 --- a/examples/supervised_learning/evaluate.py +++ b/examples/supervised_learning/evaluate.py @@ -27,17 +27,39 @@ from donkeycar.utils import linear_unbin import conf +class GifCreator(object): + + def __init__(self, filename): + import imageio + self.filename = filename + self.images = [] + self.every_nth_frame = 4 + self.i_frame = 0 + + def add_image(self, image): + self.i_frame += 1 + if self.i_frame % self.every_nth_frame == 0: + self.images.append(image) + + def close(self): + import imageio + if len(self.images) > 0: + print('writing movie', self.filename) + imageio.mimsave(self.filename, self.images) + + class DonkeySimMsgHandler(IMesgHandler): STEERING = 0 THROTTLE = 1 - def __init__(self, model, constant_throttle): + def __init__(self, model, constant_throttle, movie_handler=None): self.model = model self.constant_throttle = constant_throttle self.sock = None self.timer = FPSTimer() self.image_folder = None + self.movie_handler = movie_handler self.fns = {'telemetry' : self.on_telemetry} def on_connect(self, socketHandler): @@ -62,11 +84,9 @@ def on_telemetry(self, data): image_array = np.asarray(image) self.predict(image_array) - # maybe save frame - if self.image_folder is not None: - timestamp = datetime.utcnow().strftime('%Y_%m_%d_%H_%M_%S_%f')[:-3] - image_filename = os.path.join(self.image_folder, timestamp) - image.save('{}.jpg'.format(image_filename)) + # maybe write movie + if self.movie_handler is not None: + self.movie_handler.add_image(image_array) def predict(self, image_array): @@ -112,20 +132,25 @@ def send_control(self, steer, throttle): self.sock.queue_message(msg) - def on_close(self): - pass + def on_disconnect(self): + if self.movie_handler: + self.movie_handler.close() - -def go(filename, address, constant_throttle): +def go(filename, address, constant_throttle, gif): model = load_model(filename) #In this mode, looks like we have to compile it model.compile("sgd", "mse") + + movie_handler = None + + if gif != "none": + movie_handler = GifCreator(gif) #setup the server - handler = DonkeySimMsgHandler(model, constant_throttle) + handler = DonkeySimMsgHandler(model, constant_throttle, movie_handler) server = SimServer(address, handler) try: @@ -140,7 +165,9 @@ def go(filename, address, constant_throttle): parser = argparse.ArgumentParser(description='prediction server') parser.add_argument('--model', type=str, help='model filename') parser.add_argument('--constant_throttle', type=float, default=0.0, help='apply constant throttle') + parser.add_argument('--gif', type=str, default="none", help='make animated gif of evaluation') + args = parser.parse_args() address = ('0.0.0.0', 9091) - go(args.model, address, args.constant_throttle) + go(args.model, address, args.constant_throttle, args.gif)