In [1]:
from __future__ import annotations
from dataclasses import dataclass
from typing import List, Tuple, Dict, Optional
import random
import math

Action = Tuple[int, int]
State = Tuple[int, int, int, int]

ACTIONS: List[Action] = [(dr, dc) for dr in (-1, 0, 1) for dc in (-1, 0, 1)]

@dataclass
class RacetrackSpec:
  v_max: int = 5
  slip_prob: float = 0.0
  crash_reward: int = -5
  step_reward: int = -1
  seed: Optional[int] = None


class RacetrackEnv:
  def __init__(self,
               grid, spec: RacetrackSpec = RacetrackSpec()):
    self.grid = [list(row) for row in grid]

    self.H = len(self.grid)

    self.W = len(self.grid[0]) if self.H > 0 else 0

    assert all(len(row) == self.W for row in self.grid), "Grid must be rectangular."

    self.spec = spec
    self.rng = random.Random(spec.seed)

    self.start_cells = [(r, c) for r in range(self.H) for c in range(self.W) if self.grid[r][c] == 'S']
    self.finish_cells = {(r, c) for r in range(self.H) for c in range(self.W) if self.grid[r][c] == 'F'}

    assert len(self.start_cells) > 0, "Grid must contain at least one 'S' start cell."
    assert len(self.finish_cells) > 0, "Grid must contain at least one 'F' finish cell."
    self.track_cells = {(r, c) for r in range(self.H) for c in range(self.W)
                            if self.grid[r][c] in ('.', 'S', 'F')}

    self.state: Optional[State] = None
  def reset(self) -> State:
    r, c = self.rng.choice(self.start_cells)
    # Common choice: start with zero velocity (but must not be both zero per rules).
    # We'll start at (0,1) or (1,0) randomly to satisfy "cannot both be zero".
    if self.rng.random() < 0.5:
        vr, vc = 0, 1
    else:
        vr, vc = 1, 0
    self.state = (r, c, vr, vc)
    return self.state

  def in_bounds(self, r: int, c: int) -> bool:
    return 0 <= r < self.H and 0 <= c < self.W

  def is_track(self, r: int, c: int) -> bool:
    return (r, c) in self.track_cells

  def is_finish(self, r: int, c: int) -> bool:
    return (r, c) in self.finish_cells
  def clip_velocity(self, vr: int, vc: int) -> Tuple[int, int]:
    vr = max(0, min(self.spec.v_max - 1, vr))
    vc = max(0, min(self.spec.v_max - 1, vc))
    if vr == 0 and vc == 0:

        if self.rng.random() < 0.5:
            vc = 1
        else:
            vr = 1
    return vr, vc

  def bresenham_cells(self, r0: int, c0: int, r1: int, c1: int) -> List[Tuple[int, int]]:
    """
    Cells traversed by the line segment from (r0,c0) to (r1,c1), inclusive.
    Using integer grid stepping so we detect wall/finish crossing, not just landing.
    """
    cells = []
    dr = abs(r1 - r0)
    dc = abs(c1 - c0)
    sr = 1 if r1 >= r0 else -1
    sc = 1 if c1 >= c0 else -1

    r, c = r0, c0
    cells.append((r, c))

    if dc == 0 and dr == 0:
        return cells

    if dc > dr:
        err = dc / 2.0
        while c != c1:
            c += sc
            err -= dr
            if err < 0:
                r += sr
                err += dc
            cells.append((r, c))
    else:
        err = dr / 2.0
        while r != r1:
            r += sr
            err -= dc
            if err < 0:
                c += sc
                err += dr
            cells.append((r, c))

    return cells

  def apply_slip(self, r: int, c: int) -> Tuple[int, int]:
    """
    With probability slip_prob, displace forward (down) or right by 1 extra cell.
    """
    if self.rng.random() >= self.spec.slip_prob:
        return r, c
    # "forward or to the right" for right turns. forward means increasing row.
    if self.rng.random() < 0.5:
        return r - 1, c
    else:
        return r, c + 1


  def step(self, action: Action) -> Tuple[State, int, bool, Dict]:
        assert self.state is not None, "Call reset() first."
        r, c, vr, vc = self.state
        dvr, dvc = action
        assert (dvr, dvc) in ACTIONS, "Invalid action."

        # Update and clip velocity
        nvr, nvc = self.clip_velocity(vr + dvr, vc + dvc)

        # Proposed move by velocity
        r2, c2 = r - nvr, c + nvc


        # Apply stochastic extra displacement
        r3, c3 = self.apply_slip(r2, c2)

        path = self.bresenham_cells(r, c, r3, c3)

        last_valid = (r, c)

        for (rr, cc) in path[1:]:
            # out of bounds or off-track => crash attempt
            if not self.in_bounds(rr, cc) or not self.is_track(rr, cc):
                # stay at last valid on-track cell
                sr, sc = self.rng.choice(self.start_cells)
                # reset velocity to small nonzero
                if self.rng.random() < 0.5:
                    nvr, nvc = 1, 0
                else:
                    nvr, nvc = 0, 1
                self.state = (sr, sc, nvr, nvc)
                return self.state, self.spec.crash_reward, False, {"event": "crash"}

            # on track
            last_valid = (rr, cc)

            # finish crossing
            if self.is_finish(rr, cc):
                self.state = (rr, cc, nvr, nvc)
                return self.state, self.spec.step_reward, True, {"event": "finish"}

        # Normal on-track step (no finish, no crash)
        self.state = (r3, c3, nvr, nvc)
        return self.state, self.spec.step_reward, False, {"event": "move"}

  def render(self, state: Optional[State] = None) -> str:
    if state is None:
        state = self.state
    if state is None:
        return "\n".join("".join(row) for row in self.grid)

    r, c, vr, vc = state
    out = [row[:] for row in self.grid]
    out[r][c] = 'A'
    return "\n".join("".join(row) for row in out)
def make_example_track() -> List[str]:
    """
    A simple right-turn track you can replace with your own.
    '#' walls, '.' track, S start line, F finish line
    """
    return [
        "####################",
        "###########.....FFFF",
        "###########.....FFFF",
        "###########.....FFFF",
        "###########.........",
        "###########.........",
        "#######.............",
        "#######.............",
        "SSSSSS..............",
        "SSSSSS..............",
        "####################",
    ]



In [2]:
env = RacetrackEnv(make_example_track(), RacetrackSpec(seed =3))
s = env.reset()
print(env.render())

####################
###########.....FFFF
###########.....FFFF
###########.....FFFF
###########.........
###########.........
#######.............
#######.............
SSSASS..............
SSSSSS..............
####################


In [3]:
for _ in range(10):
  a = random.choice(ACTIONS)
  s, r, done, info = env.step(a)
  print("\n", info, "reward:", r, "state:", s)
  print(env.render())



 {'event': 'crash'} reward: -5 state: (9, 3, 1, 0)
####################
###########.....FFFF
###########.....FFFF
###########.....FFFF
###########.........
###########.........
#######.............
#######.............
SSSSSS..............
SSSASS..............
####################

 {'event': 'move'} reward: -1 state: (8, 3, 1, 0)
####################
###########.....FFFF
###########.....FFFF
###########.....FFFF
###########.........
###########.........
#######.............
#######.............
SSSASS..............
SSSSSS..............
####################

 {'event': 'crash'} reward: -5 state: (9, 1, 1, 0)
####################
###########.....FFFF
###########.....FFFF
###########.....FFFF
###########.........
###########.........
#######.............
#######.............
SSSSSS..............
SASSSS..............
####################

 {'event': 'move'} reward: -1 state: (8, 1, 1, 0)
####################
###########.....FFFF
###########.....FFFF
###########.....FFFF
###########......

## Common helpers

In [4]:
from collections import defaultdict
import random

def eps_greedy_action(Q, s, epsilon, rng: random.Random) :
  if rng.random() < epsilon :
    return rng.randrange(9)

  qs = Q[s]
  best = max(range(9) , key = lambda a: qs[a])
  return best

def init_Q() :
  return defaultdict(lambda : [0.0] * 9 )

## Monte Carlo Control (first visit)

In [5]:
def mc_control (env, episodes = 200_000,
                gamma = 1.0, epsilon = 0.1, seed = 0, max_steps = 10_000) :
  rng = random.Random(seed)
  Q = init_Q()

  returns_sum = defaultdict(lambda : [0.0] * 9 )
  returns_cnt = defaultdict(lambda : [0] * 9 )

  for ep in range(episodes) :
    s = env.reset()
    episode = []
    done = False

    for t in range(max_steps) :
      a = eps_greedy_action(Q, s, epsilon, rng)
      s2, r, done, info = env.step(ACTIONS[a])
      episode.append((s, a, r) )

      s = s2
      if done :
        break
    G = 0.0

    seen  = set()
    for (s, a, r) in reversed(episode):
      G = gamma* G + r
      if (s, a) in seen:
        continue
      seen.add((s, a))

      returns_sum[s][a] += G
      returns_cnt[s][a] += 1

      Q[s][a] = returns_sum[s][a] / returns_cnt[s][a]
  return Q



## TD Control (SARSA)

In [6]:
def sarsa_control(env, episodes = 200_000, alpha = 0.1, gamma = 1.0, epsilon = 0.1, seed = 0, max_steps = 10_000):
  rng = random.Random(seed)

  Q = init_Q()

  for ep in range(episodes) :
    s = env.reset()
    a = eps_greedy_action(Q, s, epsilon, rng)
    done = False

    for t in range(max_steps):
      s2, r, done, info = env.step(ACTIONS[a])
      if done:
        Q[s][a] += alpha * (r- Q[s][a])
        break
      a2 = eps_greedy_action(Q, s2, epsilon, rng)
      Q[s][a] += alpha * (r + gamma * Q[s2][a2] - Q[s][a])
      s = s2
      a = a2
  return Q

## TD Control(Q Learning)

In [7]:
def q_learning_control(env, episodes = 200_000, alpha = 0.1, gamma = 1.0, epsilon = 0.1, seed= 0 , max_steps = 15_000):
  rng = random.Random(seed)

  Q = init_Q()
  for episode in range(episodes):
    s = env.reset()
    a = eps_greedy_action(Q, s, epsilon, rng)
    done = False
    for t in range(max_steps):
      s2, r, done, info = env.step(ACTIONS[a])
      if done:
        Q[s][a] += alpha * (r- Q[s][a])
        break
      a2 = eps_greedy_action(Q, s2, epsilon, rng)
      Q[s][a] += alpha * (r + gamma * max(Q[s2]) - Q[s][a])
      s = s2
      a = a2
  return Q

## Extract greedy policies and print tracjectories

In [8]:
def greedy_policy(Q, s):
  qs = Q[s]
  return max(range(9), key = lambda a: qs[a])

def rollout(env, Q, start_state = None, max_steps = 500):
  if start_state is None:
    s = env.reset()
  else:
    env.state = start_state
    s = start_state
  traj = [s]
  total = 0
  for _ in range(max_steps) :
    a = greedy_policy(Q, s)
    s, r, done, info = env.step(ACTIONS[a])
    total += r
    traj.append(s)
    if done:
      break
  return traj, total
def show_trajectory_on_grid(env, traj):
  # just render successive positions
  frames = []
  for s in traj:
      frames.append(env.render(s))
  return "\n\n".join(frames)


In [9]:
grid = make_example_track()
env = RacetrackEnv(grid, RacetrackSpec(seed = 0))

In [10]:
%%time
Q_mc = mc_control(env, episodes = 50_000, epsilon = 0.2, seed = 2)

CPU times: user 2.29 s, sys: 2.31 ms, total: 2.29 s
Wall time: 2.29 s


In [11]:
env = RacetrackEnv(grid, RacetrackSpec(seed=32))
for i in range(3):
  traj, ret = rollout(env, Q_mc, max_steps=200)
  # ok, steps, ret = rollout(env, Q_td, 2000)
  print("Trajectory", i, "return:", ret, "len:", len(traj))
  print(show_trajectory_on_grid(env, traj[:10]))  # first 10 frames only

  # print("finished?", ok, "steps", steps, "return", ret)
  print("----")


Trajectory 0 return: -7 len: 8
####################
###########.....FFFF
###########.....FFFF
###########.....FFFF
###########.........
###########.........
#######.............
#######.............
SASSSS..............
SSSSSS..............
####################

####################
###########.....FFFF
###########.....FFFF
###########.....FFFF
###########.........
###########.........
#######.............
#######.............
SSASSS..............
SSSSSS..............
####################

####################
###########.....FFFF
###########.....FFFF
###########.....FFFF
###########.........
###########.........
#######.............
#######.............
SSSSAS..............
SSSSSS..............
####################

####################
###########.....FFFF
###########.....FFFF
###########.....FFFF
###########.........
###########.........
#######.............
#######.............
SSSSSSA.............
SSSSSS..............
####################

####################
###########.....FFFF

In [12]:
%%time
env = RacetrackEnv(grid, RacetrackSpec(seed=0))
Q_td = sarsa_control(env, episodes=50_000, alpha=0.2, epsilon=0.2, seed=1)

CPU times: user 1.67 s, sys: 2.11 ms, total: 1.67 s
Wall time: 1.67 s


In [13]:
env = RacetrackEnv(grid, RacetrackSpec(seed=123))
for i in range(3):
  traj, ret = rollout(env, Q_td, max_steps=200)
  # ok, steps, ret = rollout(env, Q_td, 2000)
  print("Trajectory", i, "return:", ret, "len:", len(traj))
  print(show_trajectory_on_grid(env, traj[:10]))  # first 10 frames only

  # print("finished?", ok, "steps", steps, "return", ret)
  print("----")


Trajectory 0 return: -5 len: 6
####################
###########.....FFFF
###########.....FFFF
###########.....FFFF
###########.........
###########.........
#######.............
#######.............
ASSSSS..............
SSSSSS..............
####################

####################
###########.....FFFF
###########.....FFFF
###########.....FFFF
###########.........
###########.........
#######.............
#######.............
SSASSS..............
SSSSSS..............
####################

####################
###########.....FFFF
###########.....FFFF
###########.....FFFF
###########.........
###########.........
#######.............
#######.............
SSSSSA..............
SSSSSS..............
####################

####################
###########.....FFFF
###########.....FFFF
###########.....FFFF
###########.........
###########.........
#######.............
#######..A..........
SSSSSS..............
SSSSSS..............
####################

####################
###########.....FFFF

In [14]:
%%time
env = RacetrackEnv(grid, RacetrackSpec(seed=0))
Q_tdq = q_learning_control(env, episodes=50_000, alpha=0.2, epsilon=0.2, seed=1)

CPU times: user 1.94 s, sys: 1.96 ms, total: 1.94 s
Wall time: 1.95 s


In [15]:
env = RacetrackEnv(grid, RacetrackSpec(seed=123))
for i in range(3):
  traj, ret = rollout(env, Q_tdq, max_steps=200)
  # ok, steps, ret = rollout(env, Q_td, 2000)
  print("Trajectory", i, "return:", ret, "len:", len(traj))
  print(show_trajectory_on_grid(env, traj[:10]))  # first 10 frames only

  # print("finished?", ok, "steps", steps, "return", ret)
  print("----")


Trajectory 0 return: -5 len: 6
####################
###########.....FFFF
###########.....FFFF
###########.....FFFF
###########.........
###########.........
#######.............
#######.............
ASSSSS..............
SSSSSS..............
####################

####################
###########.....FFFF
###########.....FFFF
###########.....FFFF
###########.........
###########.........
#######.............
#######.............
SSASSS..............
SSSSSS..............
####################

####################
###########.....FFFF
###########.....FFFF
###########.....FFFF
###########.........
###########.........
#######.............
#######.............
SSSSSA..............
SSSSSS..............
####################

####################
###########.....FFFF
###########.....FFFF
###########.....FFFF
###########.........
###########.........
#######.............
#######.A...........
SSSSSS..............
SSSSSS..............
####################

####################
###########.....FFFF

## Deep Q Networks

In [16]:
import random
from collections import deque
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def encode_state(env, s):
    r, c, vr, vc = s
    return np.array([
        r / (env.H - 1),
        c / (env.W - 1),
        vr / (env.spec.v_max - 1),
        vc / (env.spec.v_max - 1),
    ], dtype=np.float32)

class QNet(nn.Module):
    def __init__(self, in_dim=4, out_dim=9):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, out_dim),
        )

    def forward(self, x):
        return self.net(x)

class ReplayBuffer:
    def __init__(self, capacity=200_000):
        self.buf = deque(maxlen=capacity)

    def push(self, s, a, r, s2, done):
        self.buf.append((s, a, r, s2, done))

    def sample(self, batch_size):
        batch = random.sample(self.buf, batch_size)
        s, a, r, s2, done = zip(*batch)
        return np.stack(s), np.array(a), np.array(r, dtype=np.float32), np.stack(s2), np.array(done, dtype=np.float32)

    def __len__(self):
        return len(self.buf)

@torch.no_grad()
def select_action(qnet, state_vec, epsilon):
    if random.random() < epsilon:
        return random.randrange(9)
    x = torch.from_numpy(state_vec).unsqueeze(0).to(device)
    q = qnet(x)[0]
    return int(torch.argmax(q).item())

def dqn_train(
    env,
    episodes=50_000,
    gamma=1.0,
    lr=1e-3,
    batch_size=256,
    buffer_size=200_000,
    min_buffer=5_000,
    target_update_every=1000,   # steps
    epsilon_start=1.0,
    epsilon_end=0.05,
    epsilon_decay_steps=200_000,
    max_steps_per_ep=10_000,
    seed=0
):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    qnet = QNet().to(device)
    target = QNet().to(device)
    target.load_state_dict(qnet.state_dict())
    target.eval()

    opt = optim.Adam(qnet.parameters(), lr=lr)
    replay = ReplayBuffer(buffer_size)

    steps = 0

    def epsilon_by_step(t):
        if t >= epsilon_decay_steps:
            return epsilon_end
        frac = t / epsilon_decay_steps
        return epsilon_start + frac * (epsilon_end - epsilon_start)

    for ep in range(episodes):
        s = env.reset()
        sv = encode_state(env, s)
        ep_return = 0.0

        for t in range(max_steps_per_ep):
            eps = epsilon_by_step(steps)
            a = select_action(qnet, sv, eps)

            s2, r, done, info = env.step(ACTIONS[a])
            sv2 = encode_state(env, s2)

            replay.push(sv, a, r, sv2, done)

            sv = sv2
            ep_return += r
            steps += 1

            # Learn
            if len(replay) >= min_buffer:
                bs, ba, br, bs2, bdone = replay.sample(batch_size)

                bs_t   = torch.from_numpy(bs).to(device)
                ba_t   = torch.from_numpy(ba).long().to(device)
                br_t   = torch.from_numpy(br).to(device)
                bs2_t  = torch.from_numpy(bs2).to(device)
                done_t = torch.from_numpy(bdone).to(device)

                q_sa = qnet(bs_t).gather(1, ba_t.view(-1, 1)).squeeze(1)

                with torch.no_grad():
                    # Vanilla DQN target
                    max_q_s2 = target(bs2_t).max(dim=1).values
                    y = br_t + gamma * (1.0 - done_t) * max_q_s2

                loss = nn.MSELoss()(q_sa, y)

                opt.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(qnet.parameters(), 10.0)
                opt.step()

                # Target network update
                if steps % target_update_every == 0:
                    target.load_state_dict(qnet.state_dict())

            if done:
                break

        if (ep + 1) % 200 == 0:
            print(f"ep {ep+1}  return {ep_return:.1f}  eps {epsilon_by_step(steps):.3f}  buffer {len(replay)}")

    return qnet


In [17]:
%%time
grid = make_example_track()
env = RacetrackEnv(grid, RacetrackSpec(seed=0))
qnet = dqn_train(env, episodes=20_000)

# greedy rollout
env = RacetrackEnv(grid, RacetrackSpec(seed=123))
s = env.reset()
sv = encode_state(env, s)
for _ in range(60):
    a = select_action(qnet, sv, epsilon=0.0)
    s, r, done, info = env.step(ACTIONS[a])
    sv = encode_state(env, s)
    print(info, r, s)
    print(env.render())
    if done:
        break


ep 200  return -178.0  eps 0.946  buffer 11426
ep 400  return -45.0  eps 0.908  buffer 19283
ep 600  return -33.0  eps 0.881  buffer 25108
ep 800  return -82.0  eps 0.855  buffer 30541
ep 1000  return -56.0  eps 0.832  buffer 35280
ep 1200  return -88.0  eps 0.810  buffer 39901
ep 1400  return -6.0  eps 0.791  buffer 43985
ep 1600  return -13.0  eps 0.773  buffer 47731
ep 1800  return -31.0  eps 0.755  buffer 51545
ep 2000  return -25.0  eps 0.739  buffer 55034
ep 2200  return -29.0  eps 0.722  buffer 58569
ep 2400  return -13.0  eps 0.707  buffer 61697
ep 2600  return -15.0  eps 0.693  buffer 64725
ep 2800  return -10.0  eps 0.678  buffer 67685
ep 3000  return -20.0  eps 0.665  buffer 70616
ep 3200  return -68.0  eps 0.652  buffer 73254
ep 3400  return -10.0  eps 0.640  buffer 75815
ep 3600  return -7.0  eps 0.629  buffer 78160
ep 3800  return -15.0  eps 0.616  buffer 80812
ep 4000  return -15.0  eps 0.605  buffer 83164
ep 4200  return -6.0  eps 0.593  buffer 85590
ep 4400  return -5.