<a href="https://colab.research.google.com/github/rafaelsoStanford/DDPM/blob/master/Copy_of_CurrentFinal_CarRacing_v2_DiffusionPolicy.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount the user's Google Drive into the Colab runtime at /content/drive.
# NOTE(review): this import only exists inside Google Colab and the mount needs
# interactive authorisation; the output saved with this cell is a MessageError,
# i.e. the mount did not succeed in the recorded run.
from google.colab import drive
drive.mount('/content/drive')

MessageError: ignored

In [None]:
#@markdown ### **Installing pip packages**
#@markdown - Diffusion Model: [PyTorch](https://pytorch.org) & [HuggingFace diffusers](https://huggingface.co/docs/diffusers/index)
#@markdown - Dataset Loading: [Zarr](https://zarr.readthedocs.io/en/stable/) & numcodecs
#@markdown - Environment: Gymnasium (Box2D), pygame, pymunk & shapely
# NOTE(review): every installer below sends its output (including errors) to
# /dev/null, so a failed install is silent and only shows up as an ImportError
# in the next cell. gymnasium[box2d] is the one package installed without a
# version pin.
!python --version
# presumably needed to build the Box2D bindings pulled in by gymnasium[box2d] - TODO confirm
!apt install swig &> /dev/null
!pip3 uninstall cvxpy -y > /dev/null
!pip3 install gymnasium[box2d] &> /dev/null
!pip3 install torch==1.13.1 torchvision==0.14.1 diffusers==0.11.1 simple-pid\
scikit-image==0.19.3 scikit-video==1.1.11 zarr==2.12.0 numcodecs==0.10.2 \
&> /dev/null # mute output

In [None]:
#@markdown ### **Imports**
# Diffusion policy import
from typing import Tuple, Sequence, Dict, Union, Optional, Callable
import numpy as np
import math
import torch
import torch.nn as nn
import torchvision
import collections
import zarr
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.training_utils import EMAModel
from diffusers.optimization import get_scheduler
from tqdm.auto import tqdm
from simple_pid import PID

# env imports
import pygame
from pygame import gfxdraw

import shapely.geometry as sg
import cv2
import skimage.transform as st
from skvideo.io import vwrite
from IPython.display import Video
import gdown
import os
import gymnasium as gym
from gymnasium import spaces
from gymnasium.envs.box2d.car_dynamics import Car
from gymnasium.error import DependencyNotInstalled, InvalidAction
from gymnasium.utils import EzPickle
from typing import Optional, Union

import Box2D
from Box2D.b2 import fixtureDef
from Box2D.b2 import polygonShape
from Box2D.b2 import contactListener


In [None]:
#@markdown ### **Environment**
#@markdown Slightly modified version of the CarRacing Environment "CarRacing-v2".
#@markdown Adapted from [Gymnasium](https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/envs/box2d/car_racing.py)
#@markdown - Added new functions which return `velocity` and `on_track_flag` of the car
#@markdown - Observations are now returned as a dictionary and then accessed as needed in subsequent cells

__credits__ = ["Andrea PIERRÉ"]


# Size of the "state_pixels" image and of the declared observation space.
STATE_W = 96  # less than Atari 160x192
STATE_H = 96
# Size of the image returned in "rgb_array" render mode.
VIDEO_W = 600
VIDEO_H = 400
# Size of the pygame surface everything is drawn on (and of the "human" window).
WINDOW_W = 1000
WINDOW_H = 800

SCALE = 6.0  # Track scale
TRACK_RAD = 900 / SCALE  # Track is heavily morphed circle with this radius
PLAYFIELD = 2000 / SCALE  # Game over boundary
FPS = 50  # Frames per second
ZOOM = 2.7  # Camera zoom
ZOOM_FOLLOW = True  # Set to False for fixed view (don't use zoom)


TRACK_DETAIL_STEP = 21 / SCALE  # distance advanced per generated track point
TRACK_TURN_RATE = 0.31  # upper bound on the heading change per generated track point
TRACK_WIDTH = 40 / SCALE  # offset from the centre line to each road edge (half width)
BORDER = 8 / SCALE  # width of the red-white border strip drawn on hard turns
BORDER_MIN_COUNT = 4  # consecutive same-direction turning tiles needed to draw a border
GRASS_DIM = PLAYFIELD / 20.0  # side length of one square grass patch
# Largest on-screen diagonal of any drawn object; used as the margin of the
# off-screen culling test in CarRacing._draw_colored_polygon.
MAX_SHAPE_DIM = (
    max(GRASS_DIM, TRACK_WIDTH, TRACK_DETAIL_STEP) * math.sqrt(2) * ZOOM * SCALE
)


class FrictionDetector(contactListener):
    """Box2D contact listener that tracks which road tiles touch the car.

    Whenever a body carrying a ``tiles`` collection starts touching a road tile
    (a body carrying ``road_friction``), the tile is added to that collection;
    when the contact ends it is removed. The first contact with a tile also
    credits the environment with reward and may flag a completed lap.
    """

    def __init__(self, env, lap_complete_percent):
        contactListener.__init__(self)
        self.lap_complete_percent = lap_complete_percent
        self.env = env

    def BeginContact(self, contact):
        self._contact(contact, True)

    def EndContact(self, contact):
        self._contact(contact, False)

    def _contact(self, contact, begin):
        """Handle the start (``begin=True``) or end of a contact between two bodies."""
        first = contact.fixtureA.body.userData
        second = contact.fixtureB.body.userData

        # Work out which side of the contact is the road tile. Both sides are
        # examined, so if both look like tiles the second one wins.
        tile, obj = None, None
        for candidate, counterpart in ((first, second), (second, first)):
            if candidate and "road_friction" in candidate.__dict__:
                tile, obj = candidate, counterpart
        if not tile:
            return

        # inherit tile color from env
        tile.color[:] = self.env.road_color

        # Only bodies that keep a set of touched tiles are of interest.
        if not obj or "tiles" not in obj.__dict__:
            return

        if not begin:
            obj.tiles.remove(tile)
            return

        obj.tiles.add(tile)
        if tile.road_visited:
            return

        env = self.env
        tile.road_visited = True
        env.reward += 1000.0 / len(env.track)
        env.tile_visited_count += 1

        # Lap is considered completed if enough % of the track was covered
        # by the time the first tile is reached.
        if (
            tile.idx == 0
            and env.tile_visited_count / len(env.track) > self.lap_complete_percent
        ):
            env.new_lap = True


class CarRacing(gym.Env, EzPickle):
    """
    ## Description
    The easiest control task to learn from pixels - a top-down
    racing environment. The generated track is random every episode.

    Some indicators are shown at the bottom of the window along with the
    state RGB buffer. From left to right: true speed, four ABS sensors,
    steering wheel position, and gyroscope.
    To play yourself (it's rather fast for humans), type:
    ```
    python gymnasium/envs/box2d/car_racing.py
    ```
    Remember: it's a powerful rear-wheel drive car - don't press the accelerator
    and turn at the same time.

    ## Action Space
    If continuous there are 3 actions :
    - 0: steering, -1 is full left, +1 is full right
    - 1: gas
    - 2: braking

    If discrete there are 5 actions:
    - 0: do nothing
    - 1: steer left
    - 2: steer right
    - 3: gas
    - 4: brake

    ## Observation Space

    A top-down 96x96 RGB image of the car and race track.

    ## Rewards
    The reward is -0.1 every frame and +1000/N for every track tile visited,
    where N is the total number of tiles visited in the track. For example,
    if you have finished in 732 frames, your reward is
    1000 - 0.1*732 = 926.8 points.

    ## Starting State
    The car starts at rest in the center of the road.

    ## Episode Termination
    The episode finishes when all the tiles are visited. The car can also go
    outside the playfield - that is, far off the track, in which case it will
    receive -100 reward and die.

    ## Arguments
    `lap_complete_percent` dictates the percentage of tiles that must be visited by
    the agent before a lap is considered complete.

    Passing `domain_randomize=True` enables the domain randomized variant of the environment.
    In this scenario, the background and track colours are different on every reset.

    Passing `continuous=False` converts the environment to use discrete action space.
    The discrete action space has 5 actions: [do nothing, left, right, gas, brake].

    ## Reset Arguments
    Passing the option `options["randomize"] = True` will change the current colour of the environment on demand.
    Correspondingly, passing the option `options["randomize"] = False` will not change the current colour of the environment.
    `domain_randomize` must be `True` on init for this argument to work.
    Example usage:
    ```python
    import gymnasium as gym
    env = gym.make("CarRacing-v1", domain_randomize=True)

    # normal reset, this changes the colour scheme by default
    env.reset()

    # reset with colour scheme change
    env.reset(options={"randomize": True})

    # reset with no colour scheme change
    env.reset(options={"randomize": False})
    ```

    ## Version History
    - v1: Change track completion logic and add domain randomization (0.24.0)
    - v0: Original version

    ## References
    - Chris Campbell (2014), http://www.iforce2d.net/b2dtut/top-down-car.

    ## Credits
    Created by Oleg Klimov
    """

    metadata = {
        "render_modes": [
            "human",
            "rgb_array",
            "state_pixels",
        ],
        "render_fps": FPS,
    }

    def __init__(
        self,
        render_mode: Optional[str] = None,
        verbose: bool = False,
        lap_complete_percent: float = 0.95,
        domain_randomize: bool = False,
        continuous: bool = True,
    ):
        """Create the physics world, the spaces and the per-episode bookkeeping.

        The track and the car are not built here; that happens in `reset()`.

        Args:
            render_mode: one of ``metadata["render_modes"]`` or None.
            verbose: print track-generation diagnostics.
            lap_complete_percent: fraction of tiles that must be visited before
                reaching tile 0 again counts as a finished lap.
            domain_randomize: draw random road/background colours.
            continuous: use the 3-dimensional continuous action space instead of
                the 5-way discrete one.

        NOTE(review): `observation_space` is declared as a 96x96x3 uint8 image,
        while `step()` returns a dict with keys "image", "velocity" and
        "on_track" - only the "image" entry matches the declared space.
        """
        EzPickle.__init__(
            self, render_mode, verbose, lap_complete_percent, domain_randomize, continuous
        )

        # --- configuration -------------------------------------------------
        self.continuous = continuous
        self.domain_randomize = domain_randomize
        self.lap_complete_percent = lap_complete_percent
        self.verbose = verbose
        self.render_mode = render_mode
        self._init_colors()

        # --- physics -------------------------------------------------------
        # The listener is stored on self so that a reference to it stays alive.
        self.contactListener_keepref = FrictionDetector(self, self.lap_complete_percent)
        self.world = Box2D.b2World((0, 0), contactListener=self.contactListener_keepref)
        unit_square = [(0, 0), (1, 0), (1, -1), (0, -1)]
        self.fd_tile = fixtureDef(shape=polygonShape(vertices=unit_square))
        self.road = None
        self.car: Optional[Car] = None

        # --- rendering -----------------------------------------------------
        self.screen: Optional[pygame.Surface] = None
        self.surf = None
        self.clock = None
        self.isopen = True
        self.invisible_state_window = None
        self.invisible_video_window = None

        # --- episode bookkeeping --------------------------------------------
        self.reward = 0.0
        self.prev_reward = 0.0
        self.new_lap = False

        # This will throw a warning in tests/envs/test_envs in utils/env_checker.py as the space is not symmetric
        #   or normalised however this is not possible here so ignore
        if not self.continuous:
            # do nothing, left, right, gas, brake
            self.action_space = spaces.Discrete(5)
        else:
            # steer, gas, brake
            self.action_space = spaces.Box(
                low=np.array([-1, 0, 0], dtype=np.float32),
                high=np.array([+1, +1, +1], dtype=np.float32),
            )

        self.observation_space = spaces.Box(
            low=0, high=255, shape=(STATE_H, STATE_W, 3), dtype=np.uint8
        )

    def _destroy(self):
        """Remove all road tiles and the car from the physics world.

        Does nothing when no road has been built yet (``self.road`` is None or empty).
        """
        tiles = self.road
        if tiles:
            for body in tiles:
                self.world.DestroyBody(body)
            self.road = []
            assert self.car is not None
            self.car.destroy()

    def _randomize_colors(self):
        """Draw random road and background colours and derive the grass colour.

        The grass colour is the background colour with one randomly chosen
        channel raised by 20. The order of the random draws (road, background,
        channel index) is the same as before this helper was factored out, so
        seeded runs produce the same colours.
        """
        self.road_color = self.np_random.uniform(0, 210, size=3)
        self.bg_color = self.np_random.uniform(0, 210, size=3)

        self.grass_color = np.copy(self.bg_color)
        idx = self.np_random.integers(3)
        self.grass_color[idx] += 20

    def _init_colors(self):
        """Set the initial colour scheme: random if `domain_randomize`, else the defaults."""
        if self.domain_randomize:
            # domain randomize the bg and grass colour
            self._randomize_colors()
        else:
            # default colours
            self.road_color = np.array([102, 102, 102])
            self.bg_color = np.array([102, 204, 102])
            self.grass_color = np.array([102, 230, 102])

    def _reinit_colors(self, randomize):
        """Re-draw the colour scheme when `randomize` is truthy; otherwise keep it.

        Raises:
            AssertionError: if the environment was not created with
                ``domain_randomize=True``.
        """
        assert (
            self.domain_randomize
        ), "domain_randomize must be True to use this function."

        if randomize:
            # domain randomize the bg and grass colour
            self._randomize_colors()

    def _create_track(self) -> bool:
        """Try to generate one random closed track and build its Box2D tiles.

        Steps:
          1. place 12 checkpoints around the origin at noisy angles and random
             radii (the first and the last one are pinned at 1.5 * TRACK_RAD);
          2. walk from checkpoint to checkpoint in TRACK_DETAIL_STEP increments,
             turning the heading towards the next checkpoint, and record one
             (alpha, beta, x, y) tuple per step;
          3. cut out one closed loop between two passes through `start_alpha`;
          4. mark tiles on hard turns for a red-white border;
          5. create one static sensor body per tile and the polygons to draw.

        Side effects: resets `self.road`, sets `self.start_alpha`, appends to
        `self.road` and `self.road_poly` (the latter must already exist; `reset()`
        creates it) and, on success, sets `self.track`.

        Returns:
            True on success; False when no closed loop was found or when the
            loop's head and tail do not meet closely enough. `reset()` retries
            until this returns True.
        """
        CHECKPOINTS = 12

        # Create checkpoints
        checkpoints = []
        for c in range(CHECKPOINTS):
            noise = self.np_random.uniform(0, 2 * math.pi * 1 / CHECKPOINTS)
            alpha = 2 * math.pi * c / CHECKPOINTS + noise
            rad = self.np_random.uniform(TRACK_RAD / 3, TRACK_RAD)

            # First and last checkpoint get a fixed angle and radius (no noise).
            if c == 0:
                alpha = 0
                rad = 1.5 * TRACK_RAD
            if c == CHECKPOINTS - 1:
                alpha = 2 * math.pi * c / CHECKPOINTS
                self.start_alpha = 2 * math.pi * (-0.5) / CHECKPOINTS
                rad = 1.5 * TRACK_RAD

            checkpoints.append((alpha, rad * math.cos(alpha), rad * math.sin(alpha)))
        self.road = []

        # Go from one checkpoint to another to create track
        x, y, beta = 1.5 * TRACK_RAD, 0, 0
        dest_i = 0
        laps = 0
        track = []
        no_freeze = 2500  # hard cap on the number of generated points
        visited_other_side = False
        while True:
            alpha = math.atan2(y, x)
            # A lap is counted when alpha becomes positive again after having
            # been negative.
            if visited_other_side and alpha > 0:
                laps += 1
                visited_other_side = False
            if alpha < 0:
                visited_other_side = True
                alpha += 2 * math.pi

            while True:  # Find destination from checkpoints
                failed = True

                while True:
                    dest_alpha, dest_x, dest_y = checkpoints[dest_i % len(checkpoints)]
                    if alpha <= dest_alpha:
                        failed = False
                        break
                    dest_i += 1
                    if dest_i % len(checkpoints) == 0:
                        break

                if not failed:
                    break

                # No checkpoint ahead at this angle: wrap alpha and search again.
                alpha -= 2 * math.pi
                continue

            # (r1x, r1y) is the unit vector at angle beta; (p1x, p1y) is that
            # vector rotated by +90 degrees and is the direction the track advances in.
            r1x = math.cos(beta)
            r1y = math.sin(beta)
            p1x = -r1y
            p1y = r1x
            dest_dx = dest_x - x  # vector towards destination
            dest_dy = dest_y - y
            # destination vector projected on rad:
            proj = r1x * dest_dx + r1y * dest_dy
            # Keep beta within 1.5*pi of alpha.
            while beta - alpha > 1.5 * math.pi:
                beta -= 2 * math.pi
            while beta - alpha < -1.5 * math.pi:
                beta += 2 * math.pi
            prev_beta = beta
            proj *= SCALE
            # Turn towards the destination, by at most TRACK_TURN_RATE per step;
            # small projections (|proj| <= 0.3) leave the heading unchanged.
            if proj > 0.3:
                beta -= min(TRACK_TURN_RATE, abs(0.001 * proj))
            if proj < -0.3:
                beta += min(TRACK_TURN_RATE, abs(0.001 * proj))
            x += p1x * TRACK_DETAIL_STEP
            y += p1y * TRACK_DETAIL_STEP
            # The stored heading is the average of the heading before and after the turn.
            track.append((alpha, prev_beta * 0.5 + beta * 0.5, x, y))
            if laps > 4:
                break
            no_freeze -= 1
            if no_freeze == 0:
                break

        # Find closed loop range i1..i2, first loop should be ignored, second is OK
        # (the search runs backwards from the end of the generated points).
        i1, i2 = -1, -1
        i = len(track)
        while True:
            i -= 1
            if i == 0:
                return False  # Failed
            pass_through_start = (
                track[i][0] > self.start_alpha and track[i - 1][0] <= self.start_alpha
            )
            if pass_through_start and i2 == -1:
                i2 = i
            elif pass_through_start and i1 == -1:
                i1 = i
                break
        if self.verbose:
            print("Track generation: %i..%i -> %i-tiles track" % (i1, i2, i2 - i1))
        assert i1 != -1
        assert i2 != -1

        track = track[i1 : i2 - 1]

        first_beta = track[0][1]
        first_perp_x = math.cos(first_beta)
        first_perp_y = math.sin(first_beta)
        # Length of perpendicular jump to put together head and tail
        well_glued_together = np.sqrt(
            np.square(first_perp_x * (track[0][2] - track[-1][2]))
            + np.square(first_perp_y * (track[0][3] - track[-1][3]))
        )
        if well_glued_together > TRACK_DETAIL_STEP:
            return False

        # Red-white border on hard turns
        # A tile qualifies when the last BORDER_MIN_COUNT heading changes are all
        # larger than 20% of TRACK_TURN_RATE and all have the same sign.
        # Negative indices wrap around, which is intended on a closed loop.
        border = [False] * len(track)
        for i in range(len(track)):
            good = True
            oneside = 0
            for neg in range(BORDER_MIN_COUNT):
                beta1 = track[i - neg - 0][1]
                beta2 = track[i - neg - 1][1]
                good &= abs(beta1 - beta2) > TRACK_TURN_RATE * 0.2
                oneside += np.sign(beta1 - beta2)
            good &= abs(oneside) == BORDER_MIN_COUNT
            border[i] = good
        # Extend every border backwards over the tiles that led into the turn.
        for i in range(len(track)):
            for neg in range(BORDER_MIN_COUNT):
                border[i - neg] |= border[i]

        # Create tiles
        # Each tile is the quad between track point i and its predecessor i - 1
        # (for i == 0 the predecessor is the last point, closing the loop).
        for i in range(len(track)):
            alpha1, beta1, x1, y1 = track[i]
            alpha2, beta2, x2, y2 = track[i - 1]
            road1_l = (
                x1 - TRACK_WIDTH * math.cos(beta1),
                y1 - TRACK_WIDTH * math.sin(beta1),
            )
            road1_r = (
                x1 + TRACK_WIDTH * math.cos(beta1),
                y1 + TRACK_WIDTH * math.sin(beta1),
            )
            road2_l = (
                x2 - TRACK_WIDTH * math.cos(beta2),
                y2 - TRACK_WIDTH * math.sin(beta2),
            )
            road2_r = (
                x2 + TRACK_WIDTH * math.cos(beta2),
                y2 + TRACK_WIDTH * math.sin(beta2),
            )
            vertices = [road1_l, road1_r, road2_r, road2_l]
            self.fd_tile.shape.vertices = vertices
            t = self.world.CreateStaticBody(fixtures=self.fd_tile)
            # userData points back at the body itself; FrictionDetector reads the
            # attributes set below (road_friction, road_visited, idx, color) from it.
            t.userData = t
            c = 0.01 * (i % 3) * 255  # small per-tile brightness offset, repeating every 3 tiles
            t.color = self.road_color + c
            t.road_visited = False
            t.road_friction = 1.0
            t.idx = i
            t.fixtures[0].sensor = True
            self.road_poly.append(([road1_l, road1_r, road2_r, road2_l], t.color))
            self.road.append(t)
            if border[i]:
                # The sign of the heading change selects which road edge gets the border.
                side = np.sign(beta2 - beta1)
                b1_l = (
                    x1 + side * TRACK_WIDTH * math.cos(beta1),
                    y1 + side * TRACK_WIDTH * math.sin(beta1),
                )
                b1_r = (
                    x1 + side * (TRACK_WIDTH + BORDER) * math.cos(beta1),
                    y1 + side * (TRACK_WIDTH + BORDER) * math.sin(beta1),
                )
                b2_l = (
                    x2 + side * TRACK_WIDTH * math.cos(beta2),
                    y2 + side * TRACK_WIDTH * math.sin(beta2),
                )
                b2_r = (
                    x2 + side * (TRACK_WIDTH + BORDER) * math.cos(beta2),
                    y2 + side * (TRACK_WIDTH + BORDER) * math.sin(beta2),
                )
                # Alternate white and red segments.
                self.road_poly.append(
                    (
                        [b1_l, b1_r, b2_r, b2_l],
                        (255, 255, 255) if i % 2 == 0 else (255, 0, 0),
                    )
                )
        self.track = track
        return True

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        options: Optional[dict] = None,
    ):
        """Start a new episode on a freshly generated track.

        Args:
            seed: forwarded to `gym.Env.reset` to seed the random generator.
            options: optional dict; with ``domain_randomize=True`` the key
                ``"randomize"`` controls whether the colour scheme is re-drawn
                (default: True).

        Returns:
            ``(observation, info)`` where `observation` is the dict produced by
            `step(None)` (keys "image", "velocity", "on_track") and `info` is an
            empty dict.
        """
        super().reset(seed=seed)
        self._destroy()

        # A fresh listener is installed on every reset and kept referenced on
        # the world object.
        self.world.contactListener_bug_workaround = FrictionDetector(
            self, self.lap_complete_percent
        )
        self.world.contactListener = self.world.contactListener_bug_workaround

        self.reward = 0.0
        self.prev_reward = 0.0
        self.tile_visited_count = 0
        self.t = 0.0
        self.new_lap = False
        self.road_poly = []

        if self.domain_randomize:
            randomize = True
            if isinstance(options, dict) and "randomize" in options:
                randomize = options["randomize"]
            self._reinit_colors(randomize)

        # Track generation can fail; keep trying until it succeeds.
        while not self._create_track():
            if self.verbose:
                # The original message ran "not many" and "instances" together.
                print(
                    "retry to generate track (normal if there are not many "
                    "instances of this message)"
                )
        self.car = Car(self.world, *self.track[0][1:4])

        if self.render_mode == "human":
            self.render()
        # A step without an action produces the first observation.
        return self.step(None)[0], {}

    def return_velocity(self):
        """Return the linear velocity of the car's hull as stored by the physics engine."""
        hull = self.car.hull
        return hull.linearVelocity

    def return_abs_velocity(self):
        """Return the car's speed: the Euclidean norm of the hull's linear velocity."""
        return np.linalg.norm(self.car.hull.linearVelocity)

    def return_track_flag(self):
        """Report whether the car is currently touching the road.

        Each wheel keeps a collection `tiles` of the road tiles it is in contact
        with (maintained by `FrictionDetector`). The car counts as "on track" as
        soon as any wheel touches at least one tile.

        Returns:
            bool: True if at least one wheel has a non-empty `tiles` collection,
            False if no wheel touches any road tile.
        """
        return any(len(wheel.tiles) > 0 for wheel in self.car.wheels)

    def step(self, action: Union[np.ndarray, int]):
        """Apply `action`, advance the simulation by one frame and observe.

        Args:
            action: ``[steer, gas, brake]`` in continuous mode, an integer in
                ``0..4`` in discrete mode, or None for the action-free step that
                `reset()` uses to obtain the first observation (no reward
                bookkeeping happens in that case).

        Returns:
            ``(observation, step_reward, terminated, truncated, info)`` where
            `observation` is a dict with keys "image" (state pixels),
            "velocity" (speed) and "on_track" (bool), and `info` is empty.

        Raises:
            InvalidAction: in discrete mode, when `action` is not in the action space.
        """
        assert self.car is not None
        car = self.car
        acting = action is not None

        if acting:
            if not self.continuous:
                if not self.action_space.contains(action):
                    raise InvalidAction(
                        f"you passed the invalid action `{action}`. "
                        f"The supported action_space is `{self.action_space}`"
                    )
                car.steer(-0.6 * (action == 1) + 0.6 * (action == 2))
                car.gas(0.2 * (action == 3))
                car.brake(0.8 * (action == 4))
            else:
                car.steer(-action[0])
                car.gas(action[1])
                car.brake(action[2])

        # Advance car and world by one frame.
        frame_time = 1.0 / FPS
        car.step(frame_time)
        self.world.Step(frame_time, 6 * 30, 2 * 30)
        self.t += frame_time

        self.state = self._render("state_pixels")

        terminated = False
        truncated = False
        step_reward = 0
        if acting:  # First step without action, called from reset()
            self.reward -= 0.1
            # We actually don't want to count fuel spent, we want car to be faster.
            # self.reward -=  10 * self.car.fuel_spent / ENGINE_POWER
            self.car.fuel_spent = 0.0
            step_reward = self.reward - self.prev_reward
            self.prev_reward = self.reward

            # Finishing the lap is a truncation (like a timeout), not a failure.
            if self.tile_visited_count == len(self.track) or self.new_lap:
                truncated = True

            # Leaving the playfield ends the episode with a fixed penalty.
            pos_x, pos_y = self.car.hull.position
            if abs(pos_x) > PLAYFIELD or abs(pos_y) > PLAYFIELD:
                terminated = True
                step_reward = -100

        if self.render_mode == "human":
            self.render()

        # create observation dict
        observation = {
            "image": self.state,
            "velocity": self.return_abs_velocity(),
            "on_track": self.return_track_flag(),
        }
        return observation, step_reward, terminated, truncated, {}

    def render(self):
        """Render with the mode chosen at construction time.

        Returns whatever `_render` returns for that mode; when no render mode
        was configured, a warning is logged and None is returned.
        """
        if self.render_mode is not None:
            return self._render(self.render_mode)

        assert self.spec is not None
        gym.logger.warn(
            "You are calling render method without specifying any render mode. "
            "You can specify the render_mode at initialization, "
            f'e.g. gym.make("{self.spec.id}", render_mode="rgb_array")'
        )
        return None

    def _render(self, mode: str):
        """Draw the current world state onto a fresh surface and deliver it per `mode`.

        Args:
            mode: one of ``metadata["render_modes"]``.

        Returns:
            - "rgb_array": image array scaled to (VIDEO_W, VIDEO_H);
            - "state_pixels": image array scaled to (STATE_W, STATE_H);
            - "human": None (the frame is shown in the pygame window);
            - None for any mode when `reset()` has not been called yet.
        """
        assert mode in self.metadata["render_modes"]

        # Lazily initialise pygame; a window is only opened in "human" mode.
        pygame.font.init()
        if self.screen is None and mode == "human":
            pygame.init()
            pygame.display.init()
            self.screen = pygame.display.set_mode((WINDOW_W, WINDOW_H))
        if self.clock is None:
            self.clock = pygame.time.Clock()

        if "t" not in self.__dict__:
            return  # reset() not called yet

        self.surf = pygame.Surface((WINDOW_W, WINDOW_H))

        assert self.car is not None
        # computing transformations
        # The view is rotated by the negative hull angle and translated so that
        # the car appears at a fixed spot of the window.
        angle = -self.car.hull.angle
        # Animating first second zoom.
        zoom = 0.1 * SCALE * max(1 - self.t, 0) + ZOOM * SCALE * min(self.t, 1)
        scroll_x = -(self.car.hull.position[0]) * zoom
        scroll_y = -(self.car.hull.position[1]) * zoom
        trans = pygame.math.Vector2((scroll_x, scroll_y)).rotate_rad(angle)
        trans = (WINDOW_W / 2 + trans[0], WINDOW_H / 4 + trans[1])

        self._render_road(zoom, trans, angle)
        # The last argument is False for "state_pixels", True for the other modes.
        self.car.draw(
            self.surf,
            zoom,
            trans,
            angle,
            mode not in ["state_pixels_list", "state_pixels"],
        )

        # Flip vertically before drawing the overlay, so indicators and text
        # are drawn in window coordinates.
        self.surf = pygame.transform.flip(self.surf, False, True)

        # showing stats
        self._render_indicators(WINDOW_W, WINDOW_H)

        # Accumulated reward as a 4-digit number in the bottom-left corner.
        font = pygame.font.Font(pygame.font.get_default_font(), 42)
        text = font.render("%04i" % self.reward, True, (255, 255, 255), (0, 0, 0))
        text_rect = text.get_rect()
        text_rect.center = (60, WINDOW_H - WINDOW_H * 2.5 / 40.0)
        self.surf.blit(text, text_rect)

        if mode == "human":
            pygame.event.pump()
            self.clock.tick(self.metadata["render_fps"])
            assert self.screen is not None
            self.screen.fill(0)
            self.screen.blit(self.surf, (0, 0))
            pygame.display.flip()
        elif mode == "rgb_array":
            return self._create_image_array(self.surf, (VIDEO_W, VIDEO_H))

        elif mode == "state_pixels":
            return self._create_image_array(self.surf, (STATE_W, STATE_H))
        else:
            # NOTE(review): not reachable while the assert at the top is active,
            # because every mode in metadata["render_modes"] is handled above.
            return self.isopen

    def _render_road(self, zoom, translation, angle):
        """Draw the background, the grass patches and the road polygons onto `self.surf`.

        Args:
            zoom, translation, angle: view transform, passed through unchanged
                to `_draw_colored_polygon`.
        """
        # Background: one big square covering the whole playfield (never culled).
        half = PLAYFIELD
        playfield = [(half, half), (half, -half), (-half, -half), (-half, half)]
        self._draw_colored_polygon(
            self.surf, playfield, self.bg_color, zoom, translation, angle, clip=False
        )

        # Grass: square patches of side GRASS_DIM on every second grid cell.
        cells = range(-20, 20, 2)
        patches = []
        for gx in cells:
            left = GRASS_DIM * gx + 0
            right = GRASS_DIM * gx + GRASS_DIM
            for gy in cells:
                low = GRASS_DIM * gy + 0
                high = GRASS_DIM * gy + GRASS_DIM
                patches.append([(right, low), (left, low), (left, high), (right, high)])
        for patch in patches:
            self._draw_colored_polygon(
                self.surf, patch, self.grass_color, zoom, translation, angle
            )

        # Road tiles and border segments, in the order they were created.
        for outline, tile_color in self.road_poly:
            points = [(vertex[0], vertex[1]) for vertex in outline]
            rgb = [int(channel) for channel in tile_color]
            self._draw_colored_polygon(self.surf, points, rgb, zoom, translation, angle)

    def _render_indicators(self, W, H):
        """Draw the status strip along the bottom edge of `self.surf`.

        From left to right: speed bar, one bar per wheel for its angular
        velocity, a horizontal bar for the joint angle of wheel 0 and a
        horizontal bar for the hull's angular velocity. A bar is only drawn
        when the magnitude of its value exceeds 1e-4.

        Args:
            W, H: width and height of the surface; the strip layout is built on
                a 40 x 40 grid of cells of size (W / 40, H / 40).
        """
        s = W / 40.0
        h = H / 40.0

        # Black strip, 5 cells high, across the full width.
        strip = [(W, H), (W, H - 5 * h), (0, H - 5 * h), (0, H)]
        pygame.draw.polygon(self.surf, color=(0, 0, 0), points=strip)

        def vertical_ind(place, val):
            """Corners of a one-cell-wide bar at column `place`, `val` cells tall."""
            top = H - (h + h * val)
            base = H - h
            return [
                (place * s, top),
                ((place + 1) * s, top),
                ((place + 1) * s, base),
                ((place + 0) * s, base),
            ]

        def horiz_ind(place, val):
            """Corners of a bar starting at column `place`, `val` cells long."""
            start = (place + 0) * s
            end = (place + val) * s
            return [
                (start, H - 4 * h),
                (end, H - 4 * h),
                (end, H - 2 * h),
                (start, H - 2 * h),
            ]

        # simple wrapper to render if the indicator value is above a threshold
        def render_if_min(value, points, color):
            if abs(value) > 1e-4:
                pygame.draw.polygon(self.surf, points=points, color=color)

        assert self.car is not None
        car = self.car

        velocity = car.hull.linearVelocity
        true_speed = np.sqrt(np.square(velocity[0]) + np.square(velocity[1]))
        render_if_min(true_speed, vertical_ind(5, 0.02 * true_speed), (255, 255, 255))

        # ABS sensors: (column, colour) for wheels 0..3
        wheel_slots = (
            (7, (0, 0, 255)),
            (8, (0, 0, 255)),
            (9, (51, 0, 255)),
            (10, (51, 0, 255)),
        )
        for wheel_idx, (column, bar_color) in enumerate(wheel_slots):
            omega = car.wheels[wheel_idx].omega
            render_if_min(omega, vertical_ind(column, 0.01 * omega), bar_color)

        steer_angle = car.wheels[0].joint.angle
        render_if_min(steer_angle, horiz_ind(20, -10.0 * steer_angle), (0, 255, 0))

        spin = car.hull.angularVelocity
        render_if_min(spin, horiz_ind(30, -0.8 * spin), (255, 0, 0))

    def _draw_colored_polygon(
        self, surface, poly, color, zoom, translation, angle, clip=True
    ):
        """Transform a polygon into window coordinates and draw it filled on `surface`.

        Args:
            surface: pygame surface to draw on. (Previously this argument was
                ignored and `self.surf` was always used; every call in this
                class passes `self.surf`, so existing behaviour is unchanged.)
            poly: iterable of (x, y) points in world coordinates.
            color: fill colour.
            zoom: scale factor applied after rotation.
            translation: (x, y) offset applied after scaling.
            angle: rotation in radians applied first.
            clip: when True, the polygon is skipped unless at least one of its
                vertices lies within the window enlarged by MAX_SHAPE_DIM.
        """
        poly = [pygame.math.Vector2(c).rotate_rad(angle) for c in poly]
        poly = [
            (c[0] * zoom + translation[0], c[1] * zoom + translation[1]) for c in poly
        ]
        # This checks if the polygon is out of bounds of the screen, and we skip drawing if so.
        # Instead of calculating exactly if the polygon and screen overlap,
        # we simply check if the polygon is in a larger bounding box whose dimension
        # is greater than the screen by MAX_SHAPE_DIM, which is the maximum
        # diagonal length of an environment object
        if not clip or any(
            (-MAX_SHAPE_DIM <= coord[0] <= WINDOW_W + MAX_SHAPE_DIM)
            and (-MAX_SHAPE_DIM <= coord[1] <= WINDOW_H + MAX_SHAPE_DIM)
            for coord in poly
        ):
            # Anti-aliased outline first, then the fill.
            gfxdraw.aapolygon(surface, poly, color)
            gfxdraw.filled_polygon(surface, poly, color)

    def _create_image_array(self, screen, size):
        """Scale `screen` to `size` and return its pixels as an array.

        The pixel data is copied out of the scaled surface and its first two
        axes are swapped before it is returned.
        """
        resized = pygame.transform.smoothscale(screen, size)
        pixels = np.array(pygame.surfarray.pixels3d(resized))
        return np.transpose(pixels, axes=(1, 0, 2))

    def close(self):
        """Shut pygame down if a display window was ever opened; otherwise do nothing."""
        if self.screen is None:
            return
        pygame.display.quit()
        self.isopen = False
        pygame.quit()



In [None]:

#@markdown ### **Env Demo**
#@markdown Standard Gym Env (0.26.0 API)
#@markdown Car Racing Demo

# 1. create env object
env = CarRacing(render_mode='state_pixels')

# 2. must reset before use; reset() returns an (observation, info) pair
obs, info = env.reset(seed=200)

# 3. sample a random action: [steer, gas, brake]
action = env.action_space.sample()

# 4. Standard gym step method: (observation, reward, terminated, truncated, info).
#    `done` holds the `terminated` flag, as before; the truncation flag now has
#    a name of its own instead of overwriting IPython's `_`.
obs, reward, done, truncated, info = env.step(action)

# prints and explains each dimension of the observation and action vectors
with np.printoptions(precision=4, suppress=True, threshold=5):

    print("obs['image'].shape:", obs['image'].shape, "Box(0, 255, (96, 96, 3), uint8)")
    print("obs['velocity']", type(obs['velocity']))
    print("obs['on_track']:", type(obs['on_track']))
    print("Action sample: ", action)



In [65]:
#@markdown ### **Controllers for Driving behaviors**
#@markdown Three different controller strategies:
#@markdown - PID Driver: Safe trajectory following the middle line of track
#@markdown - Sinusoidal Driver (safe): Sinusoidal driving trajectory; stays within the lane
#@markdown - Sinusoidal Driver (unsafe): Sinusoidal driving trajectory; leaves lane and recovers

#@markdown Middle line is estimated using Image Processing (A forced solution, but works quite well).



def findEdges(image):
    """Return a binary edge image of the boundary between green and non-green areas.

    The image is thresholded for green in HSV space, Canny edges of that mask
    are computed, two regions are blanked out, and small gaps are closed by
    dilating and then eroding twice with a 3x3 kernel.

    NOTE(review): the blanked regions are rows 64..77 / columns 44..51
    (presumably the car itself) and rows 83 up to but NOT including the last
    row (presumably the indicator strip) - confirm whether the last row was
    meant to be cleared as well.
    """
    green_mask = cv2.inRange(
        cv2.cvtColor(image, cv2.COLOR_BGR2HSV), (36, 25, 25), (70, 255, 255)
    )
    edge_img = cv2.Canny(green_mask, 100, 255)

    edge_img[64:78, 44:52] = 0
    edge_img[83:-1, :] = 0

    structuring = np.ones((3, 3), np.uint8)
    grown = cv2.dilate(edge_img, structuring, iterations=2)
    return cv2.erode(grown, structuring, iterations=2)

def findClosestEdgePos(edges, carPos = np.array([70, 48])):
    """Return the (row, col) of the non-zero pixel of `edges` nearest to `carPos`.

    Args:
        edges: 2-D array; every non-zero entry counts as an edge pixel.
        carPos: (row, col) reference position.

    Returns:
        Integer array ``[row, col]``. When several pixels are equally close,
        the first one in row-major order is returned.

    Raises:
        ValueError: if `edges` contains no non-zero pixel.
    """
    rows, cols = np.nonzero(edges)
    offsets = np.stack((rows, cols), axis=1) - np.asarray(carPos)
    nearest = np.argmin(np.linalg.norm(offsets, axis=1))
    return np.array([rows[nearest], cols[nearest]])

def findTrackVector(edges, closestEdgePos):
    """Estimate the local track direction around an edge pixel.

    A (2*squareSize+1)-wide window centred on `closestEdgePos` is cut out of
    `edges`; the vector from the first to the last non-zero pixel of that
    window (row-major order) is returned as a (d_row, d_col) integer array.

    The window is clamped to the image: a negative slice start would wrap
    around and yield an empty window (and an IndexError) whenever the edge
    pixel lies within `squareSize` pixels of the top or left border.
    If the window holds no edge pixel at all, a zero vector is returned.
    """
    squareSize = 3
    row, col = int(closestEdgePos[0]), int(closestEdgePos[1])
    row_start = max(row - squareSize, 0)
    col_start = max(col - squareSize, 0)
    # The upper bounds may exceed the image; slicing truncates them safely.
    window = edges[row_start: row + squareSize + 1,
                   col_start: col + squareSize + 1]

    rows, cols = np.nonzero(window)
    if rows.size == 0:
        return np.zeros(2, dtype=int)

    # Only the difference matters, so window-relative coordinates suffice.
    pnt1 = np.array([rows[0], cols[0]])
    pnt2 = np.array([rows[-1], cols[-1]])
    return pnt2 - pnt1

def calculateTargetPoint(image, widthOfTrack, freq, scale_dist, Amplitude, t):
    """Compute the next target pixel of a sinusoidal trajectory along the track.

    Parameters
    ----------
    image : array indexed as image[row, col] -> pixel; used for edge detection
        and for probing one pixel's green channel.
    widthOfTrack : approximate track width in pixels.
    freq : frequency of the sinusoid (cycles per step of `t`).
    scale_dist : how far ahead along the track direction the target is placed.
    Amplitude : lateral amplitude of the sinusoid in pixels.
    t : current step index.

    Returns
    -------
    (targetPoint, estimatedMiddlePoint, vector_track_normalized,
     vector_track_perp_normalized), or four `None`s when no usable track
    direction can be derived (no edge pixels, or a zero-length track vector).
    """
    edges = findEdges(image)  # binary image with the track edges
    if not edges.any():
        # Nothing to anchor on; argmin over an empty set would raise.
        return None, None, None, None
    closestEdgePos = findClosestEdgePos(edges)  # edge pixel nearest to the car
    vector_track = findTrackVector(edges, closestEdgePos)  # local track tangent

    # Orient the tangent so its row component points up the image (-row).
    if np.dot(vector_track, np.array([-1, 0])) < 0:
        vector_track = -vector_track

    # Check the length before dividing instead of testing for NaN afterwards.
    track_length = np.linalg.norm(vector_track)
    if track_length == 0 or not np.isfinite(track_length):
        return None, None, None, None

    vector_track_normalized = vector_track / track_length
    vector_track_perp_normalized = np.array(
        [-vector_track_normalized[1], vector_track_normalized[0]])

    # Probe a pixel 3 px along the perpendicular; if it is green the
    # perpendicular points off the track and has to be flipped.
    controlPixelPos = closestEdgePos + (vector_track_perp_normalized*3).astype(int)
    # Keep the probe inside the image: an index past the border raises and a
    # negative one silently wraps to the opposite side.
    controlPixelPos = np.clip(controlPixelPos, 0, np.array(image.shape[:2]) - 1)
    controlPixel = image[controlPixelPos[0], controlPixelPos[1]]
    if controlPixel[1] > 200:  # green pixel meaning outside of track
        vector_track_perp_normalized = -vector_track_perp_normalized

    # Estimated middle of the track relative to the closest edge point
    estimatedMiddlePoint = (closestEdgePos + vector_track_perp_normalized * widthOfTrack / 2).astype(int)

    # Lateral offset of the sinusoidal trajectory at step t
    sin_coeff = Amplitude * np.sin((t+1) * freq * 2 * np.pi)
    sin_vector = (sin_coeff * vector_track_perp_normalized).astype(int)
    # z-component of the 2D cross product (np.cross on 2-element vectors is
    # deprecated in NumPy 2.0); negative when the perpendicular was flipped.
    cross_z = (vector_track_normalized[0] * vector_track_perp_normalized[1]
               - vector_track_normalized[1] * vector_track_perp_normalized[0])
    if cross_z < 0:
        sin_vector = -sin_vector

    dir_vector = vector_track_normalized * scale_dist
    sinusPoints_pos = estimatedMiddlePoint + dir_vector + sin_vector
    targetPoint = int(sinusPoints_pos[0]), int(sinusPoints_pos[1])

    return targetPoint, estimatedMiddlePoint, vector_track_normalized ,vector_track_perp_normalized


def find_edge_1dStrip(array, direction):
    """Scan a 1D array from its centre and return the first non-zero index.

    `direction` is 'left' (scan centre -> index 0) or 'right'
    (scan centre -> last index). Returns -1 when nothing is found or when
    `direction` is neither of the two.
    """
    centre = int(len(array) // 2)
    if direction == 'left':
        scan_order = range(centre, -1, -1)
    elif direction == 'right':
        scan_order = range(centre, len(array))
    else:
        return -1
    return next((pos for pos in scan_order if array[pos] != 0), -1)

def find_middle_point(strip_1d):
    """Return the index midway between the left and right edges of a 1D strip.

    Edges are searched outwards from the centre with find_edge_1dStrip; a side
    without an edge falls back to the corresponding border of the strip.
    """
    left_hit = find_edge_1dStrip(strip_1d, 'left')
    right_hit = find_edge_1dStrip(strip_1d, 'right')

    # -1 means "no edge on that side": use the strip border instead.
    left_idx = 0 if left_hit == -1 else left_hit
    right_idx = len(strip_1d) - 1 if right_hit == -1 else right_hit

    return int((left_idx + right_idx) / 2)

def calculateDistAngle(idx_middle_upper, idx_middle_lower, strip_width, strip_height):
    """Return the lateral offset and heading error relative to the track middle.

    idx_middle_upper: index of the middle point on the upper row of the strip
    idx_middle_lower: index of the middle point on the lower row of the strip
    strip_width: width of the strip
    strip_height: height of the strip

    Returns (distance, angle): `distance` is the strip centre column minus the
    lower middle index; `angle` is arctan of (strip centre column minus the
    upper middle index) over `strip_height`, in radians.
    """
    centre_col = strip_width // 2
    upper_offset = centre_col - idx_middle_upper
    return centre_col - idx_middle_lower, np.arctan(upper_offset / strip_height)


def processImage(image):
    """Derive (distance, angle) errors to the track middle from one frame.

    A 20-row strip centred on row 65 is cropped, thresholded for green in HSV
    space, and the middle point between the green borders is located on the
    strip's first and last rows. Those two indices are turned into a lateral
    distance and an angular error by calculateDistAngle.
    """
    strip_height, strip_width, middle_height = 20, 96, 65
    half_height = strip_height / 2
    top, bottom = int(middle_height - half_height), int(middle_height + half_height)

    # Horizontal band of the frame, full width
    strip = image[top:bottom, :]

    # 0/255 mask of the green pixels in the band
    green = cv2.inRange(cv2.cvtColor(strip, cv2.COLOR_BGR2HSV), (36, 25, 25), (70, 255,255))
    gray_mask = (green > 0).astype(np.uint8) * 255

    # Only the first and last rows of the band are analysed
    idx_middle_upper = find_middle_point(gray_mask[0, :])
    idx_middle_lower = find_middle_point(gray_mask[strip_height - 1, :])

    return calculateDistAngle(idx_middle_upper, idx_middle_lower, strip_width, strip_height)


def calculateAction(observation , target_velocity):
    """Compute a [steering, acceleration, braking] action that follows the track middle.

    `observation` must provide 'image' and 'velocity'. New PID objects are
    built on every call, so no controller state carries over between calls.
    The distance controller is evaluated but its output is not part of the
    returned action.
    """
    steering_pid = PID(0.5, -0.01, 0.05, setpoint=0)
    lateral_pid = PID(0.5, -0.005, 0.05, setpoint=0)
    speed_pid = PID(0.05, -0.1, 0.2, setpoint=target_velocity)

    frame, speed = observation['image'], observation['velocity']

    # Lateral and angular error w.r.t. the estimated middle line
    error_dist, error_ang = processImage(frame)

    steer_cmd = steering_pid(error_ang)
    lateral_pid(error_dist)  # result currently unused
    throttle_cmd = speed_pid(speed)

    # A negative velocity command is applied as braking instead
    if throttle_cmd < 0:
        return [steer_cmd, 0, -throttle_cmd]
    return [steer_cmd, throttle_cmd, 0]

def action_sinusoidalTrajectory(t, freq, observation, Amplitude, target_velocity):
    """Compute a [steering, acceleration, braking] action tracking a sinusoidal path.

    Parameters
    ----------
    t : current step index, forwarded to calculateTargetPoint.
    freq : frequency of the sinusoid.
    observation : mapping providing 'image' and 'velocity'.
    Amplitude : lateral amplitude of the sinusoid in pixels.
    target_velocity : setpoint of the velocity controller.

    Returns [0, 0, 0] when no target point could be derived from the image.
    New PID objects are built on every call, so no controller state carries
    over between calls.
    """
    image = observation['image']
    velocity = observation['velocity']

    # Environment constants
    carPos = np.array([70, 48])  # position of the car in the image (row, col)
    widthOfTrack = 20  # approx. width of the track in pixels

    # Initialize controllers
    pid_angle = PID(0.5, -0.2, 0.0, setpoint=0)
    pid_velocity = PID(0.05, 0.1, 0.1, setpoint=target_velocity)

    # Next target point of the sinusoidal trajectory
    scale_dist = 10  # how far ahead of the car the target point is placed
    targetPoint, _, _, _ = calculateTargetPoint(
        image, widthOfTrack, freq, scale_dist, Amplitude, t)

    if targetPoint is None:
        # No usable target point: emit a neutral action (no steering, gas or brake).
        return [0, 0, 0]

    # Signed angle between the car's heading (-row direction) and the target
    error = np.asarray(targetPoint) - carPos
    carVector = np.array([-1, 0])
    error_norm = np.linalg.norm(error)
    if error_norm == 0:
        # Target coincides with the car: the angle is undefined (0/0 -> NaN),
        # so request no steering correction.
        angle = 0.0
    else:
        cos_angle = np.dot(error, carVector) / (error_norm * np.linalg.norm(carVector))
        # Guard arccos against round-off pushing the cosine outside [-1, 1].
        angle = np.arccos(np.clip(cos_angle, -1.0, 1.0))
        # Negative angle when the target lies at a larger column than the car
        if error[1] > 0:
            angle = -angle
    steeringAngle = pid_angle(angle)

    # Acceleration, or braking when the controller output is negative
    acc = pid_velocity(velocity)
    breaking = 0
    if acc < 0:
        breaking = -acc
        acc = 0
    action = [steeringAngle, acc, breaking]

    return action
    



In [66]:
#@markdown ### **Controller Demo**
#@markdown This cell is only for demonstration purposes. Skip if not needed

# Get the initial observation; it is the first frame of every video
obs, _ = env.reset(seed=500)
env.render()
imgs0 , imgs1, imgs2 = [obs['image']], [obs['image']], [obs['image']]

max_iter = 200  # maximum number of steps per controller

# (frame buffer, policy) per controller; a policy maps (step, observation) -> action
demo_runs = [
    # PID driver
    (imgs0, lambda step, observation: calculateAction(observation, target_velocity=30)),
    # Sinusoidal driver, safe amplitude
    (imgs1, lambda step, observation: action_sinusoidalTrajectory(step, 1/100, observation, 5, target_velocity=30)),
    # Sinusoidal driver, unsafe amplitude
    (imgs2, lambda step, observation: action_sinusoidalTrajectory(step, 1/100, observation, 12, target_velocity=30)),
]

for frames, policy in demo_runs:
    # Restart from the same seed and use the FRESH observation; otherwise the
    # first action would be computed from the previous controller's last frame.
    obs, _ = env.reset(seed=500)
    for i in range(max_iter):
        action = policy(i, obs)
        obs, reward, done, _ , info = env.step(action)
        frames.append(obs['image'])
        if done:
            break
# Close the environment
env.close()

# Visualize
from IPython.display import HTML, Video

vwrite('vis.mp4', imgs0)
video1 = Video('vis.mp4', embed=True, width=256, height=256)

vwrite('vis1.mp4', imgs1)
video2 = Video('vis1.mp4', embed=True, width=256, height=256)

vwrite('vis2.mp4', imgs2)
video3 = Video('vis2.mp4', embed=True, width=256, height=256)

# Displaying videos side by side using HTML and CSS
html_code = f'''
<div style="display:flex;">
    <div style="margin-right:10px;">{video1._repr_html_()}</div>
    <div style="margin-right:10px;">{video2._repr_html_()}</div>
    <div>{video3._repr_html_()}</div>
</div>
'''
HTML(html_code)

In [67]:
#@markdown ### **Dataset**
#@markdown
#@markdown Defines `CarRacingDataset` and helper functions
#@markdown
#@markdown The dataset class
#@markdown - Load data from a zarr storage
#@markdown - Normalizes each dimension of non-images and actions to [-1,1]
#@markdown - Returns
#@markdown  - All possible segments with length `pred_horizon`
#@markdown  - Pads the beginning and the end of each episode with repetition


#@markdown Planned data structure that is used as zarr file.<br>
#@markdown ├── data <br>
#@markdown │   ├── action (*, 3) float32 <br>
#@markdown │   ├── h_action (*, 3) float32<br>
#@markdown │   ├── img (*, 96, 96, 3) float32<br>
#@markdown │   ├── track (*, 1) float32<br>
#@markdown │   └── velocity (*, 1) float32<br>
#@markdown └── meta<br>
#@markdown    └── episode_ends (*,) int64<br>


def create_sample_indices(
        episode_ends:np.ndarray, sequence_length:int,
        pad_before: int=0, pad_after: int=0):
    """Enumerate every window of `sequence_length` steps inside each episode.

    `episode_ends` holds the exclusive end index of each episode in the flat
    buffer. Windows may start up to `pad_before` steps before an episode and
    run up to `pad_after` steps past its end.

    Returns an array with one row per window:
    [buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx],
    where the buffer indices delimit the data that really exists and the
    sample indices say where that data sits inside the window.
    """
    rows = []
    episode_start = 0
    for episode_end in episode_ends:
        episode_length = episode_end - episode_start

        first_offset = -pad_before
        last_offset = episode_length - sequence_length + pad_after

        for offset in range(first_offset, last_offset + 1):
            # Part of the window that overlaps the episode (episode-relative)
            lo = max(offset, 0)
            hi = min(offset + sequence_length, episode_length)
            rows.append([
                lo + episode_start, hi + episode_start,
                lo - offset, hi - offset])

        episode_start = episode_end
    return np.array(rows)


def sample_sequence(train_data, sequence_length,
                    buffer_start_idx, buffer_end_idx,
                    sample_start_idx, sample_end_idx):
    """Cut one window out of every array in `train_data`, padding by repetition.

    For each key the slice [buffer_start_idx:buffer_end_idx] is taken. When
    the window is not fully covered (sample_start_idx > 0 or
    sample_end_idx < sequence_length), a new array of length
    `sequence_length` is built: the slice is placed at
    [sample_start_idx:sample_end_idx], the front is filled with the slice's
    first element and the back with its last. Otherwise the slice itself is
    returned.
    """
    pad_front = sample_start_idx > 0
    pad_back = sample_end_idx < sequence_length

    result = dict()
    for key, input_arr in train_data.items():
        chunk = input_arr[buffer_start_idx:buffer_end_idx]
        if not (pad_front or pad_back):
            result[key] = chunk
            continue

        padded = np.zeros(
            shape=(sequence_length,) + input_arr.shape[1:],
            dtype=input_arr.dtype)
        if pad_front:
            padded[:sample_start_idx] = chunk[0]
        if pad_back:
            padded[sample_end_idx:] = chunk[-1]
        padded[sample_start_idx:sample_end_idx] = chunk
        result[key] = padded
    return result

# normalize data
def get_data_stats(data):
    """Return per-dimension minimum and maximum over the last axis of `data`.

    All leading axes are flattened first; the result is a dict with the keys
    'min' and 'max'.
    """
    flat = data.reshape(-1, data.shape[-1])
    return {'min': flat.min(axis=0), 'max': flat.max(axis=0)}

def normalize_data(data, stats):
    """Scale `data` per dimension from [stats['min'], stats['max']] to [-1, 1].

    A dimension whose min and max coincide (constant data) has zero range;
    dividing by it would produce NaN/inf. Such dimensions are divided by 1
    instead, which maps them to -1. unnormalize_data maps that value back to
    stats['min'], so the round trip stays exact.
    """
    value_range = stats['max'] - stats['min']
    safe_range = np.where(value_range == 0, np.ones_like(value_range), value_range)
    # normalize to [0, 1]
    ndata = (data - stats['min']) / safe_range
    # normalize to [-1, 1]
    ndata = ndata * 2 - 1
    return ndata

def unnormalize_data(ndata, stats):
    """Map `ndata` from [-1, 1] back to [stats['min'], stats['max']] per dimension."""
    # [-1, 1] -> [0, 1]
    unit_scaled = (ndata + 1) / 2
    # [0, 1] -> [min, max]
    return unit_scaled * (stats['max'] - stats['min']) + stats['min']

# dataset
class CarRacingDataset(torch.utils.data.Dataset):
    """Windows of car-racing demonstrations read from a zarr store.

    Each item is a dict holding `pred_horizon` steps of normalized position,
    velocity and action targets ('*_pred' keys) plus the first `obs_horizon`
    steps of image, action, velocity and position as observations.
    Per-key min/max statistics are kept in `self.stats`.
    """

    def __init__(self,
                 dataset_path: str,
                 pred_horizon: int,
                 obs_horizon: int,
                 action_horizon: int):
        # read from zarr dataset
        root = zarr.open(dataset_path, 'r')
        data_group = root['data']

        # Images are cast to uint8 and moved from channel-last to
        # channel-first layout, i.e. (N, H, W, C) -> (N, C, H, W).
        # NOTE(review): the value range of the stored images is not visible
        # here; a uint8 cast of floats in [0, 1] would truncate them to 0/1 —
        # confirm against the dataset.
        images = np.moveaxis(data_group['img'][:].astype(np.uint8), -1, 1)

        # Low-dimensional signals, used as prediction targets
        raw_lowdim = {
            'positions_pred': data_group['position'][:],
            'velocities_pred': data_group['velocity'][:],
            'actions_pred': data_group['action'][:],
        }
        episode_ends = root['meta']['episode_ends'][:]

        # start/end of every window, including the padded ones at the
        # episode borders
        self.indices = create_sample_indices(
            episode_ends=episode_ends,
            sequence_length=pred_horizon,
            pad_before=obs_horizon-1,
            pad_after=action_horizon-1)

        # per-key statistics and data normalized to [-1, 1]
        self.stats = {
            key: get_data_stats(values) for key, values in raw_lowdim.items()}
        self.normalized_train_data = {
            key: normalize_data(values, self.stats[key])
            for key, values in raw_lowdim.items()}
        # images are stored as-is (no min/max normalization)
        self.normalized_train_data['image'] = images

        self.pred_horizon = pred_horizon
        self.action_horizon = action_horizon
        self.obs_horizon = obs_horizon

    def __len__(self):
        """Number of windows in the dataset."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return the window `idx` as a dict of arrays (see class docstring)."""
        buffer_start_idx, buffer_end_idx, \
            sample_start_idx, sample_end_idx = self.indices[idx]

        # normalized data for this window, padded where necessary
        nsample = sample_sequence(
            train_data=self.normalized_train_data,
            sequence_length=self.pred_horizon,
            buffer_start_idx=buffer_start_idx,
            buffer_end_idx=buffer_end_idx,
            sample_start_idx=sample_start_idx,
            sample_end_idx=sample_end_idx
        )

        # Only the first obs_horizon steps serve as observations; the
        # '*_pred' entries keep the full prediction horizon.
        n_obs = self.obs_horizon
        nsample['image'] = nsample['image'][:n_obs,:]
        for obs_key, pred_key in (('action_obs', 'actions_pred'),
                                  ('velocity_obs', 'velocities_pred'),
                                  ('position_obs', 'positions_pred')):
            nsample[obs_key] = nsample[pred_key][:n_obs,:]
        return nsample


In [76]:
# Preferred location of the dataset: the mounted Google Drive
dataset_path = '/content/drive/MyDrive/ActionPrediction/data/multipleDrivingBehaviours_parallel.zarr.zip'
# download demonstration data from Google Drive if not available
if not os.path.isfile(dataset_path):
    # Saves the file locally (temporary; deleted after runtime)
    print("Downloading File from Google drive")
    # Google Drive file id handed to gdown.
    # NOTE(review): this name shadows the builtin `id` for the rest of the notebook.
    id = "14Jyz9YoqOv57DBt-JJHC5M6EA2OWUQEl"
    # From here on dataset_path points at the local copy, not at the Drive path
    dataset_path = 'multipleDrivingBehaviours.zarr.zip'
    gdown.download(id=id, output=dataset_path, quiet=False)
# Open whichever path was selected above in read-only mode
zarr_array = zarr.open(dataset_path, mode='r')

In [2]:

#@markdown ### **Dataset Demo**
#@markdown A dataset is loaded and later used for training.
#@markdown ***Note:*** If data is not available on your drive, 
#@markdown it is fetched from our Google Drive and saved during this runtime (temporary).
#@markdown For code which generated dataset see: 
#@markdown [Shared Autonomy and Risk Negotiation](https://github.com/rafaelsoStanford/SharedAutonomy_RiskNegotiation/tree/main)



# Horizons (in steps): length of each predicted sequence, number of
# observation steps given as conditioning, and action_horizon, which
# controls the padding at the end of each episode (pad_after = action_horizon - 1)
pred_horizon = 8
obs_horizon = 2
action_horizon = 1

# create dataset from file
dataset = CarRacingDataset(
    dataset_path=dataset_path,
    pred_horizon=pred_horizon,
    obs_horizon=obs_horizon,
    action_horizon=action_horizon
)
# save training data statistics (min, max) for each dim
stats = dataset.stats

# create dataloader
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=128,
    num_workers=4,
    shuffle=True,
    # accelerate cpu-gpu transfer
    pin_memory=True, 
    # don't kill worker processes after each epoch
    persistent_workers=True 
)

# Pull a single batch to inspect its structure
batch = next(iter(dataloader))

print("Visualizing Batch and Data structure")
for key, value in batch.items():
    print()
    print(f'--> Key: {key}')
    # first two dimensions, followed by the full shape
    print(f'Shape: ({len(value)}, {len(value[0])})')
    print(value.shape)

    # value range of the image tensor, to check how images are scaled
    if key == 'image':
        min_value = value.min().item()
        max_value = value.max().item()
        print(f'Min Value: {min_value}')
        print(f'Max Value: {max_value}')
        



ModuleNotFoundError: ignored

In [78]:
#@markdown ### **Vision Encoder**
#@markdown
#@markdown Defines helper functions:
#@markdown - `get_resnet` to initialize standard ResNet vision encoder
#@markdown - `replace_bn_with_gn` to replace all BatchNorm layers with GroupNorm

def get_resnet(name:str, weights=None, **kwargs) -> nn.Module:
    """Build a torchvision ResNet whose final fully connected layer is removed.

    name: attribute of torchvision.models to call, e.g. resnet18, resnet34, resnet50
    weights: forwarded to the constructor, e.g. "IMAGENET1K_V1" or None
    kwargs: further keyword arguments forwarded to the constructor

    The `fc` layer is replaced by an identity, so the network outputs the
    features that would feed it (512-dimensional for resnet18).
    """
    constructor = getattr(torchvision.models, name)
    backbone = constructor(weights=weights, **kwargs)
    backbone.fc = torch.nn.Identity()
    return backbone


def replace_submodules(
        root_module: nn.Module,
        predicate: Callable[[nn.Module], bool],
        func: Callable[[nn.Module], nn.Module]) -> nn.Module:
    """
    Swap every submodule accepted by `predicate` for `func(submodule)`.

    predicate: returns True for modules that must be replaced.
    func: builds the replacement from the module being replaced.

    If `root_module` itself matches, its replacement is returned directly.
    Otherwise the replacement happens in place and `root_module` is returned.
    """
    if predicate(root_module):
        return func(root_module)

    def matching_paths():
        # Dotted module names of all current matches, split into components
        return [name.split('.') for name, module
                in root_module.named_modules(remove_duplicate=True)
                if predicate(module)]

    for *parent_path, leaf in matching_paths():
        holder = root_module
        if len(parent_path) > 0:
            holder = root_module.get_submodule('.'.join(parent_path))
        # nn.Sequential children are addressed by position, others by attribute
        if isinstance(holder, nn.Sequential):
            holder[int(leaf)] = func(holder[int(leaf)])
        else:
            setattr(holder, leaf, func(getattr(holder, leaf)))

    # verify that all modules are replaced
    assert len(matching_paths()) == 0
    return root_module

def replace_bn_with_gn(
    root_module: nn.Module,
    features_per_group: int=16) -> nn.Module:
    """
    Replace all BatchNorm2d layers with GroupNorm.

    Each BatchNorm2d with C features becomes a GroupNorm over C channels with
    C // features_per_group groups, but never fewer than one group (layers
    with fewer than `features_per_group` features would otherwise request
    zero groups).

    Returns the result of replace_submodules: `root_module` itself, modified
    in place, or a new GroupNorm when `root_module` is itself a BatchNorm2d.
    """
    return replace_submodules(
        root_module=root_module,
        predicate=lambda x: isinstance(x, nn.BatchNorm2d),
        func=lambda x: nn.GroupNorm(
            num_groups=max(1, x.num_features//features_per_group),
            num_channels=x.num_features)
    )


In [79]:
#@markdown ### **Network**
#@markdown
#@markdown Defines a 1D UNet architecture `ConditionalUnet1D`
#@markdown as the noise prediction network
#@markdown
#@markdown Components
#@markdown - `SinusoidalPosEmb` Positional encoding for the diffusion iteration k
#@markdown - `Downsample1d` Strided convolution to reduce temporal resolution
#@markdown - `Upsample1d` Transposed convolution to increase temporal resolution
#@markdown - `Conv1dBlock` Conv1d --> GroupNorm --> Mish
#@markdown - `ConditionalResidualBlock1D` Takes two inputs `x` and `cond`. \
#@markdown `x` is passed through 2 `Conv1dBlock` stacked together with residual connection. 
#@markdown `cond` is applied to `x` with [FiLM](https://arxiv.org/abs/1709.07871) conditioning.

class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a batch of scalars (e.g. diffusion steps).

    For input of shape (B,), the output has shape (B, 2 * (dim // 2)): the
    sines of the input scaled by dim // 2 geometrically spaced frequencies
    (from 1 down to 1/10000), followed by the corresponding cosines.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        n_freqs = self.dim // 2
        # Log-spacing between consecutive frequencies
        log_step = math.log(10000) / (n_freqs - 1)
        freqs = torch.exp(torch.arange(n_freqs, device=x.device) * -log_step)
        angles = x[:, None] * freqs[None, :]
        return torch.cat((torch.sin(angles), torch.cos(angles)), dim=-1)


class Downsample1d(nn.Module):
    """Strided 1D convolution (kernel 3, stride 2, padding 1) keeping `dim` channels."""

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels=dim, out_channels=dim,
            kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        downsampled = self.conv(x)
        return downsampled

class Upsample1d(nn.Module):
    """Transposed 1D convolution (kernel 4, stride 2, padding 1) keeping `dim` channels."""

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.ConvTranspose1d(
            in_channels=dim, out_channels=dim,
            kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        upsampled = self.conv(x)
        return upsampled


class Conv1dBlock(nn.Module):
    '''
        Conv1d --> GroupNorm --> Mish

        The convolution pads by kernel_size // 2 on both sides.
    '''

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        conv = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
        norm = nn.GroupNorm(n_groups, out_channels)
        self.block = nn.Sequential(conv, norm, nn.Mish())

    def forward(self, x):
        activated = self.block(x)
        return activated


class ConditionalResidualBlock1D(nn.Module):
    """Two stacked Conv1dBlocks with FiLM conditioning and a residual connection.

    The conditioning vector is projected to a per-channel scale and bias that
    modulate the output of the first block; the block input is added to the
    output of the second block (through a 1x1 convolution when the channel
    counts differ).
    """

    def __init__(self,
            in_channels,
            out_channels,
            cond_dim,
            kernel_size=3,
            n_groups=8):
        super().__init__()

        first_block = Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
        second_block = Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
        self.blocks = nn.ModuleList([first_block, second_block])

        # FiLM modulation https://arxiv.org/abs/1709.07871
        # one scale and one bias per output channel
        self.out_channels = out_channels
        self.cond_encoder = nn.Sequential(
            nn.Mish(),
            nn.Linear(cond_dim, out_channels * 2),
            nn.Unflatten(-1, (-1, 1))
        )

        # project the input only when its channel count differs from the output
        if in_channels != out_channels:
            self.residual_conv = nn.Conv1d(in_channels, out_channels, 1)
        else:
            self.residual_conv = nn.Identity()

    def forward(self, x, cond):
        '''
            x : [ batch_size x in_channels x horizon ]
            cond : [ batch_size x cond_dim]

            returns:
            out : [ batch_size x out_channels x horizon ]
        '''
        hidden = self.blocks[0](x)

        film = self.cond_encoder(cond)
        film = film.reshape(film.shape[0], 2, self.out_channels, 1)
        scale, bias = film[:,0,...], film[:,1,...]
        hidden = scale * hidden + bias

        hidden = self.blocks[1](hidden)
        return hidden + self.residual_conv(x)


class ConditionalUnet1D(nn.Module):
    """1D U-Net over the time axis, conditioned via FiLM on a diffusion-step
    embedding concatenated with an optional global conditioning vector.

    Built from ConditionalResidualBlock1D stages with Downsample1d/Upsample1d
    between levels and skip connections from the down path to the up path.
    """
    def __init__(self, 
        input_dim,
        global_cond_dim,
        diffusion_step_embed_dim=256,
        down_dims=[256,512,1024],
        kernel_size=5,
        n_groups=8
        ):
        """
        input_dim: Dim of actions.
        global_cond_dim: Dim of global conditioning applied with FiLM 
          in addition to diffusion step embedding. This is usually obs_horizon * obs_dim
        diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k
        down_dims: Channel size for each UNet level. 
          The length of this array determines the number of levels.
          (The list default is only read, never mutated.)
        kernel_size: Conv kernel size
        n_groups: Number of groups for GroupNorm
        """

        super().__init__()
        # channel sizes per level, starting with the input channels
        all_dims = [input_dim] + list(down_dims)
        start_dim = down_dims[0]

        # diffusion step -> sinusoidal embedding -> 2-layer MLP
        dsed = diffusion_step_embed_dim
        diffusion_step_encoder = nn.Sequential(
            SinusoidalPosEmb(dsed),
            nn.Linear(dsed, dsed * 4),
            nn.Mish(),
            nn.Linear(dsed * 4, dsed),
        )
        # every residual block is conditioned on [step embedding, global_cond]
        cond_dim = dsed + global_cond_dim

        # (in_channels, out_channels) for each level of the down path
        in_out = list(zip(all_dims[:-1], all_dims[1:]))
        mid_dim = all_dims[-1]
        # bottleneck: two residual blocks at the deepest channel size
        self.mid_modules = nn.ModuleList([
            ConditionalResidualBlock1D(
                mid_dim, mid_dim, cond_dim=cond_dim,
                kernel_size=kernel_size, n_groups=n_groups
            ),
            ConditionalResidualBlock1D(
                mid_dim, mid_dim, cond_dim=cond_dim,
                kernel_size=kernel_size, n_groups=n_groups
            ),
        ])

        # down path: two residual blocks per level, then a downsample
        # (identity on the last level)
        down_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (len(in_out) - 1)
            down_modules.append(nn.ModuleList([
                ConditionalResidualBlock1D(
                    dim_in, dim_out, cond_dim=cond_dim, 
                    kernel_size=kernel_size, n_groups=n_groups),
                ConditionalResidualBlock1D(
                    dim_out, dim_out, cond_dim=cond_dim, 
                    kernel_size=kernel_size, n_groups=n_groups),
                Downsample1d(dim_out) if not is_last else nn.Identity()
            ]))

        # up path: one level fewer than the down path (in_out[1:]); the first
        # block takes dim_out*2 channels because the skip connection is
        # concatenated to its input in forward().
        # NOTE: this loop runs len(in_out) - 1 times, so `ind` never reaches
        # len(in_out) - 1 and `is_last` is always False here: every up level
        # ends with an Upsample1d.
        up_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (len(in_out) - 1)
            up_modules.append(nn.ModuleList([
                ConditionalResidualBlock1D(
                    dim_out*2, dim_in, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                ConditionalResidualBlock1D(
                    dim_in, dim_in, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                Upsample1d(dim_in) if not is_last else nn.Identity()
            ]))
        
        # project back from the first level's channel size to input_dim
        final_conv = nn.Sequential(
            Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
            nn.Conv1d(start_dim, input_dim, 1),
        )

        self.diffusion_step_encoder = diffusion_step_encoder
        self.up_modules = up_modules
        self.down_modules = down_modules
        self.final_conv = final_conv

        print("number of parameters: {:e}".format(
            sum(p.numel() for p in self.parameters()))
        )

    def forward(self, 
            sample: torch.Tensor, 
            timestep: Union[torch.Tensor, float, int], 
            global_cond=None):
        """
        sample: (B,T,input_dim)
        timestep: (B,) or int, diffusion step
        global_cond: (B,global_cond_dim)
        output: (B,T,input_dim)
        """
        # (B,T,C)
        sample = sample.moveaxis(-1,-2)
        # (B,C,T)

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            # 0-d tensor: add a batch axis and move it to the sample's device
            timesteps = timesteps[None].to(sample.device)
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        # 2. conditioning vector shared by all residual blocks
        global_feature = self.diffusion_step_encoder(timesteps)

        if global_cond is not None:
            global_feature = torch.cat([
                global_feature, global_cond
            ], axis=-1)
        
        # 3. down path; keep each level's output (before downsampling) as skip
        x = sample
        h = []
        for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            h.append(x)
            x = downsample(x)

        # 4. bottleneck
        for mid_module in self.mid_modules:
            x = mid_module(x, global_feature)

        # 5. up path; concatenate the most recent skip along the channel axis.
        # There is one up level fewer than down levels, so the first skip
        # stored in `h` is never consumed.
        for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            x = upsample(x)

        x = self.final_conv(x)

        # (B,C,T)
        x = x.moveaxis(-1,-2)
        # (B,T,C)
        return x


In [80]:
#@markdown ### **Network Demo**

# construct ResNet18 encoder
# if you have multiple camera views, use separate encoder weights for each view.
vision_encoder = get_resnet('resnet18')

# IMPORTANT!
# replace all BatchNorm with GroupNorm to work with EMA
# performance will tank if you forget to do this!
vision_encoder = replace_bn_with_gn(vision_encoder)


## Dimensions of Features ##  ########## IMPORTANT TO ADJUST ###########
# ResNet18 has output dim of 512
vision_feature_dim = 512

# Position (2 dim) + Velocity (2 dim) + "Human" action (3 dim)
lowdim_obs_dim = 2 + 2 + 3
# observation feature total per step
obs_dim = vision_feature_dim + lowdim_obs_dim

# Action space is 3 dimensional + Position 2 dim + Velocity 2 dim
action_dim = 3 + 2 + 2

# create network object
noise_pred_net = ConditionalUnet1D(
    input_dim=action_dim,
    global_cond_dim=obs_dim*obs_horizon
)

# the final arch has 2 parts
nets = nn.ModuleDict({
    'vision_encoder': vision_encoder,
    'noise_pred_net': noise_pred_net
})

# demo
with torch.no_grad():
    # example inputs
    image = torch.zeros((1, obs_horizon,3,96,96))
    position = torch.zeros((1, obs_horizon, 2))
    velocity = torch.zeros((1, obs_horizon, 2))
    h_action = torch.zeros((1, obs_horizon, 3))

    # vision encoder: batch and time axes are merged for the CNN
    image_features = nets['vision_encoder'](
        image.flatten(end_dim=1))
    # (obs_horizon, 512)
    print(image_features.shape)
    image_features = image_features.reshape(*image.shape[:2],-1)
    # (1, obs_horizon, 512)
    print(image_features.shape)
    obs = torch.cat([image_features, position, velocity, h_action],dim=-1)
    # (1, obs_horizon, obs_dim)
    print(obs.shape)

    noised_action = torch.randn((1, pred_horizon, action_dim))
    diffusion_iter = torch.zeros((1,))
    # (1, obs_horizon * obs_dim): all observation steps as one conditioning vector
    flat = obs.flatten(start_dim = 1)

    # the noise prediction network
    # takes noisy action, diffusion iteration and observation as input
    # predicts the noise added to action
    noise = nets['noise_pred_net'](
        sample=noised_action, 
        timestep=diffusion_iter,
        global_cond=flat)

    # illustration of removing noise 
    # the actual noise removal is performed by NoiseScheduler 
    # and is dependent on the diffusion noise schedule
    denoised_action = noised_action - noise

    print(denoised_action.shape)

# for this demo, we use DDPMScheduler with 100 diffusion iterations
num_diffusion_iters = 100
noise_scheduler = DDPMScheduler(
    num_train_timesteps=num_diffusion_iters,
    # the choice of beta schedule has big impact on performance
    # we found squared cosine works the best
    beta_schedule='squaredcos_cap_v2',
    # clip output to [-1,1] to improve stability
    clip_sample=True,
    # our network predicts noise (instead of denoised action)
    prediction_type='epsilon'
)

# device transfer: use the GPU when one is available, otherwise fall back to
# the CPU so the notebook still runs on a CPU-only runtime
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
_ = nets.to(device)


number of parameters: 8.009959e+07
torch.Size([2, 512])
torch.Size([1, 2, 512])
torch.Size([1, 2, 519])
torch.Size([1, 8, 7])


In [1]:
# Load the TensorBoard notebook extension
# Install latest Tensorflow build
!pip install -q tf-nightly-2.0-previewfrom tensorflow import summary
%load_ext tensorboard.notebook
import datetime
import tensorflow as tf
from tensorflow import summary

current_time = str(datetime.datetime.now().timestamp())

train_log_dir = 'logs/tensorboard/train/' + current_time
test_log_dir = 'logs/tensorboard/test/' + current_time

train_summary_writer = summary.create_file_writer(train_log_dir)
test_summary_writer = summary.create_file_writer(test_log_dir)

%tensorboard --logdir logs/tensorboard

In [None]:
#@markdown ### **Training**
#@markdown
#@markdown Takes about 2.5 hours. If you don't want to wait, skip to the next cell
#@markdown to load pre-trained weights

num_epochs = 100

# Checkpoint settings. They are defined unconditionally so that the epoch loop
# below never references an undefined name, regardless of whether the
# checkpoint directory already exists or checkpointing is enabled.
createCheckpoints = True
checkpoint_dir = "/content/drive/MyDrive/ActionPrediction/checkpoints/"
checkpoint_frequency = 10  # save model every 10 epochs
checkpoint_counter = 0     # number of checkpoints written so far
if createCheckpoints:
    os.makedirs(checkpoint_dir, exist_ok=True)

# Best mean epoch loss seen at a checkpoint epoch so far
min_loss = float('inf')

# Exponential Moving Average
# accelerates training and improves stability
# holds a copy of the model weights
ema = EMAModel(
    model=nets,
    power=0.75)

# Standard ADAM optimizer
# Note that EMA parameters are not optimized
optimizer = torch.optim.AdamW(
    params=nets.parameters(),
    lr=1e-4, weight_decay=1e-6)

# Cosine LR schedule with linear warmup
lr_scheduler = get_scheduler(
    name='cosine',
    optimizer=optimizer,
    num_warmup_steps=500,
    num_training_steps=len(dataloader) * num_epochs
)

# Global step counter used as the x-axis of the TensorBoard loss curve
globaliter = 0
with tqdm(range(num_epochs), desc='Epoch') as tglobal:
    # epoch loop
    for epoch_idx in tglobal:
        epoch_loss = list()
        # batch loop
        with tqdm(dataloader, desc='Batch', leave=False) as tepoch:
            for nbatch in tepoch:
                # data normalized in dataset
                # device transfer of the observation part of the current batch
                # (read from the loop variable `nbatch`, not a stale global)
                nimage_obs = nbatch['image'][:, :obs_horizon].to(device)
                nVelocity_obs = nbatch['velocity_obs'][:, :obs_horizon].to(device)
                nPosition_obs = nbatch['position_obs'][:, :obs_horizon].to(device)
                nAction_obs = nbatch['action_obs'][:, :obs_horizon].to(device)
                B = nVelocity_obs.shape[0]

                # encoder vision features
                image_features = nets['vision_encoder'](
                    nimage_obs.flatten(end_dim=1))
                image_features = image_features.reshape(
                    *nimage_obs.shape[:2], -1)
                # (B,obs_horizon,D)

                # concatenate vision feature and low-dim obs
                obs_features = torch.cat(
                    [image_features, nVelocity_obs, nPosition_obs, nAction_obs],
                    dim=-1)
                obs_cond = obs_features.flatten(start_dim=1)
                # (B, obs_horizon * obs_dim)

                # Output: actions and future positions and velocities
                nAction = nbatch['actions_pred'].to(device)
                nVelocity = nbatch['velocities_pred'].to(device)
                nPosition = nbatch['positions_pred'].to(device)

                # Concatenate actions and future positions and velocities
                actions_to_pred = torch.cat([nAction, nPosition, nVelocity], dim=-1)

                # sample noise to add to actions
                noise = torch.randn(actions_to_pred.shape, device=device)

                # sample a diffusion iteration for each data point
                timesteps = torch.randint(
                    0, noise_scheduler.config.num_train_timesteps,
                    (B,), device=device
                ).long()

                # add noise to the clean actions according to the noise magnitude
                # at each diffusion iteration (this is the forward diffusion process)
                noisy_actions = noise_scheduler.add_noise(
                    actions_to_pred, noise, timesteps)

                # predict the noise residual
                noise_pred = nets['noise_pred_net'](
                    noisy_actions, timesteps, global_cond=obs_cond)

                # L2 loss
                loss = nn.functional.mse_loss(noise_pred, noise)

                # optimize
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                # step lr scheduler every batch
                # this is different from standard pytorch behavior
                lr_scheduler.step()

                # update Exponential Moving Average of the model weights
                ema.step(nets)

                # logging (read the scalar once and reuse it)
                loss_cpu = loss.item()
                epoch_loss.append(loss_cpu)
                tepoch.set_postfix(loss=loss_cpu)

                with train_summary_writer.as_default():
                    tf.summary.scalar('loss', loss_cpu, step=globaliter)
                globaliter += 1

        # compute mean epoch loss
        mean_epoch_loss = np.mean(epoch_loss)
        tglobal.set_postfix(loss=mean_epoch_loss)

        # `createCheckpoints` is tested first so the rest is skipped when
        # checkpointing is disabled
        if createCheckpoints and epoch_idx > 0 and epoch_idx % checkpoint_frequency == 0:
            # save model if mean epoch loss is smaller than the loss after the previous checkpoint
            if mean_epoch_loss < min_loss:
                min_loss = mean_epoch_loss
                checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_Epoch{epoch_idx}.pt")
                torch.save(ema.averaged_model.state_dict(), checkpoint_path)
                checkpoint_counter += 1
                print(f"Model saved at epoch {epoch_idx}")


# Weights of the EMA model
# is used for inference
ema_net = ema.averaged_model






Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Batch:   0%|          | 0/39 [00:00<?, ?it/s]

Batch:   0%|          | 0/39 [00:00<?, ?it/s]

Batch:   0%|          | 0/39 [00:00<?, ?it/s]

Batch:   0%|          | 0/39 [00:00<?, ?it/s]

Batch:   0%|          | 0/39 [00:00<?, ?it/s]

Batch:   0%|          | 0/39 [00:00<?, ?it/s]

Batch:   0%|          | 0/39 [00:00<?, ?it/s]

Batch:   0%|          | 0/39 [00:00<?, ?it/s]

Batch:   0%|          | 0/39 [00:00<?, ?it/s]

Batch:   0%|          | 0/39 [00:00<?, ?it/s]

Batch:   0%|          | 0/39 [00:00<?, ?it/s]

Batch:   0%|          | 0/39 [00:00<?, ?it/s]

Batch:   0%|          | 0/39 [00:00<?, ?it/s]

Batch:   0%|          | 0/39 [00:00<?, ?it/s]

Batch:   0%|          | 0/39 [00:00<?, ?it/s]

Batch:   0%|          | 0/39 [00:00<?, ?it/s]

Batch:   0%|          | 0/39 [00:00<?, ?it/s]

Batch:   0%|          | 0/39 [00:00<?, ?it/s]

Batch:   0%|          | 0/39 [00:00<?, ?it/s]

Batch:   0%|          | 0/39 [00:00<?, ?it/s]

KeyboardInterrupt: ignored

In [None]:
#@markdown ### **Loading Pretrained Checkpoint**
#@markdown Set `load_pretrained = True` to load pretrained weights.

load_pretrained = False
checkpoint_num = 90

if load_pretrained:
    # NOTE(review): the training cell writes its checkpoints under
    # /content/drive/MyDrive/ActionPrediction/checkpoints/ - confirm that this
    # different directory is the intended source.
    ckpt_path = f"/content/drive/MyDrive/checkpoints/checkpoint_Epoch{checkpoint_num}.pt"

    # Map the stored tensors onto whichever device the networks live on
    state_dict = torch.load(ckpt_path, map_location=device)
    # The checkpoint is loaded into `nets`, which then serves as the inference model
    ema_net = nets
    ema_net.load_state_dict(state_dict)
    print('Pretrained weights loaded.')
else:
    # Keep an `ema_net` that already exists (e.g. the EMA weights assigned at
    # the end of the training cell) instead of overwriting it; only fall back
    # to the raw networks when no `ema_net` has been defined yet.
    if 'ema_net' not in globals():
        ema_net = nets
    print("Skipped pretrained weight loading.")

In [None]:
#@markdown ### **Inference**

# limit environment interaction to max_steps steps before termination
max_steps = 1000
env = CarRacing(render_mode='state_pixels')

# get first observation (first element of the value returned by reset)
observation = env.reset(seed=7000)
observation = observation[0]

# keep a queue holding the last obs_horizon observations,
# initially filled with copies of the first observation
obs_deque = collections.deque(
    [observation] * obs_horizon, maxlen=obs_horizon)

# save visualization and rewards
imgs = [env.render()]
rewards = list()
done = False
step_idx = 0


with tqdm(total=max_steps, desc="Eval CarRacingImageEnv") as pbar:
    while not done:
        B = 1

        # stack the last obs_horizon number of observations
        # (np.stack requires a real sequence, so lists are built explicitly)
        images = np.stack([x['image'] for x in obs_deque])
        images = np.moveaxis(images, -1, 1)  # move the last axis to position 1
        velocity = np.stack([x['velocity'] for x in obs_deque])
        # NOTE(review): h_action is stacked but not fed to the network below,
        # while the training cell conditions on image features, velocity,
        # position and past actions - confirm the intended conditioning.
        h_action = np.stack([x['h_action'] for x in obs_deque])
        # normalize observation
        # NOTE(review): the 'car_vel' stats key is assumed to hold the velocity
        # statistics - verify against the dataset cell.
        ncar_vels = normalize_data(velocity, stats=stats['car_vel'])
        # images are already normalized to [0,1]
        nimages = images

        # device transfer
        nimages = torch.from_numpy(nimages).to(device, dtype=torch.float32)
        # (2,3,96,96)
        ncar_vels = torch.from_numpy(ncar_vels).to(device, dtype=torch.float32)
        # (2,1)
        ncar_vels = ncar_vels.unsqueeze(-1)

        # infer action
        with torch.no_grad():
            # get image features
            image_features = ema_net['vision_encoder'](nimages)
            # (2,512)
            # concat with low-dim observations
            # NOTE(review): this must produce the same per-step feature width
            # as the obs_cond built in the training cell - TODO confirm.
            obs_features = torch.cat([image_features, ncar_vels], dim=-1)

            # reshape observation to (B,obs_horizon*obs_dim)
            obs_cond = obs_features.unsqueeze(0).flatten(start_dim=1)

            # initialize action from Gaussian noise
            noisy_action = torch.randn(
                (B, pred_horizon, action_dim), device=device)
            naction = noisy_action

            # init scheduler
            noise_scheduler.set_timesteps(num_diffusion_iters)

            for k in noise_scheduler.timesteps:
                # predict noise
                noise_pred = ema_net['noise_pred_net'](
                    sample=naction,
                    timestep=k,
                    global_cond=obs_cond
                )

                # inverse diffusion step (remove noise)
                naction = noise_scheduler.step(
                    model_output=noise_pred,
                    timestep=k,
                    sample=naction
                ).prev_sample

        # unnormalize action
        naction = naction.detach().to('cpu').numpy()
        # (B, pred_horizon, action_dim)
        naction = naction[0]
        # NOTE(review): the training cell diffuses actions concatenated with
        # future positions and velocities; confirm that stats['action'] and
        # env.step expect the full action_dim vector and not only the action columns.
        action_pred = unnormalize_data(naction, stats=stats['action'])

        # only take action_horizon number of actions
        start = obs_horizon - 1
        end = start + action_horizon
        action = action_pred[start:end, :]
        # (action_horizon, action_dim)

        # execute action_horizon number of steps
        # without replanning
        for i in range(len(action)):
            # stepping env; the episode is over when it is either
            # terminated or truncated
            obs, reward, terminated, truncated, info = env.step(action[i])
            done = terminated or truncated
            # save observations
            obs_deque.append(obs)
            # and reward/vis
            rewards.append(reward)
            imgs.append(env.render())

            # update progress bar
            step_idx += 1
            pbar.update(1)
            pbar.set_postfix(reward=reward)
            # stop after exactly max_steps environment steps
            if step_idx >= max_steps:
                done = True
            if done:
                break


In [None]:
#@markdown ### **Visualize**
# Write the collected frames to a video file, then embed that file in the notebook
from IPython.display import Video
video_path = 'vis.mp4'
vwrite(video_path, imgs)
Video(video_path, embed=True, width=256, height=256)