# In [2]:  (notebook cell marker)
# Smart City Traffic Control using DQN
# ===================================
# FINAL VERSION – CLEAR ROAD VISUALIZATION + RL LEARNING + LEARNING CURVE

import pygame
import random
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque

# Bring up all pygame modules before any display or font call is made.
pygame.init()

# Run the networks on CUDA when it is usable, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# ---------------- CONFIG ----------------
# Window size in pixels and the frame-rate cap passed to clock.tick().
WIDTH = 1000
HEIGHT = 885
FPS = 60

# Road and vehicle dimensions in pixels; GAP is added to CAR_SIZE to space
# queued cars.
ROAD_WIDTH = 160
CAR_SIZE = 10
GAP = 14

# RGB colours for signals, HUD text and vehicles.
GREEN = (0, 220, 0)
RED = (220, 0, 0)
WHITE = (240, 240, 240)
BLUE = (0, 150, 255)
EMERGENCY_COLOR = (255, 60, 60)

# RGB colours for the scenery.
GRASS = (40, 130, 40)
ASPHALT = (50, 50, 50)
LANE_WHITE = (220, 220, 220)
STOP_LINE = (255, 255, 255)
CURB = (90, 90, 90)

# CENTER is the junction coordinate on BOTH axes; STOP is half the road
# width, i.e. the distance from CENTER to each stop line.
CENTER = WIDTH // 2
STOP = ROAD_WIDTH // 2

# Display surface, frame clock and HUD font shared by the whole script.
screen = pygame.display.set_mode((WIDTH, HEIGHT))
pygame.display.set_caption("Smart Traffic Control – DQN (Clear Roads)")
clock = pygame.time.Clock()
font = pygame.font.SysFont(None, 22)

# ---------------- TRAFFIC PHASES ----------------
# Each phase lists the approaches that get a green light. The index into
# this list is the agent's action: four single-approach phases followed by
# the two opposing-pair phases.
PHASES = [[approach] for approach in "NSEW"] + [['N', 'S'], ['E', 'W']]

# ---------------- DQN ----------------
class DQN(nn.Module):
    """Small fully connected Q-network.

    Maps a state vector to one Q-value per action through two hidden ReLU
    layers. Called with no arguments it builds the same architecture as
    before (4 inputs, 64 hidden units, one output per entry in ``PHASES``).

    Args:
        n_inputs: Length of the state vector.
        n_actions: Number of Q-values to output. ``None`` means
            ``len(PHASES)``.
        hidden: Width of both hidden layers.
    """

    def __init__(self, n_inputs=4, n_actions=None, hidden=64):
        super().__init__()
        if n_actions is None:
            # Resolved lazily so explicit sizes never touch the global table.
            n_actions = len(PHASES)
        # Kept as a single ``net`` Sequential so state-dict keys
        # ("net.0.weight", ...) stay compatible with existing checkpoints.
        self.net = nn.Sequential(
            nn.Linear(n_inputs, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_actions),
        )

    def forward(self, x):
        """Return the Q-values for the state batch ``x``."""
        return self.net(x)

# Online network (the one being trained) and its target copy. The target
# starts out with exactly the same weights as the online network.
policy_net = DQN()
policy_net.to(device)
target_net = DQN()
target_net.to(device)
target_net.load_state_dict(policy_net.state_dict())

# Only the online network's parameters are handed to the optimiser.
optimizer = optim.Adam(policy_net.parameters(), lr=1e-3)

# Bounded replay buffer: once full, the oldest transitions fall off.
memory = deque(maxlen=6_000)

# Exploration schedule and learning hyperparameters.
EPSILON = 1.0          # current probability of taking a random action
EPSILON_MIN = 0.4      # floor for EPSILON
EPSILON_DECAY = 0.995
GAMMA = 0.95           # discount factor used in the TD target
BATCH = 64             # replay minibatch size

# ---------------- CAR ----------------
class Car:
    """A vehicle approaching the junction from one of four directions.

    ``dir`` names the approach ('N', 'S', 'E' or 'W'). The car spawns just
    outside the window on that approach and travels in a straight line.
    """

    # Per approach: the coordinate that changes and the sign of the change.
    _TRAVEL = {'N': ('y', 1), 'S': ('y', -1), 'E': ('x', -1), 'W': ('x', 1)}

    def __init__(self, d):
        self.dir = d
        self.speed = 2.5
        self.wait = 0             # frames spent held back by update_lane
        self.committed = False    # set once the car has entered the junction
        self.emergency = random.random() < 0.05
        self.passed = False

        # Spawn point per approach; an unrecognised direction gets no position.
        spawn_points = {
            'N': (CENTER - 40, -30),
            'S': (CENTER + 30, HEIGHT + 30),
            'E': (WIDTH + 30, CENTER - 40),
            'W': (-30, CENTER + 30),
        }
        start = spawn_points.get(d)
        if start is not None:
            self.x, self.y = start

    def in_intersection(self):
        """Return True while the car's origin is strictly inside the junction box."""
        low, high = CENTER - STOP, CENTER + STOP
        return low < self.x < high and low < self.y < high

    def forward(self):
        """Advance one frame along the car's direction of travel."""
        move = self._TRAVEL.get(self.dir)
        if move is None:
            return
        axis, sign = move
        setattr(self, axis, getattr(self, axis) + sign * self.speed)

    def draw(self):
        """Draw the car as a square; emergency vehicles use their own colour."""
        colour = EMERGENCY_COLOR if self.emergency else BLUE
        pygame.draw.rect(screen, colour, (self.x, self.y, CAR_SIZE, CAR_SIZE))

# ---------------- LANE UPDATE ----------------
def update_lane(cars, d, green):
    """Move every car of one approach by a single frame.

    Args:
        cars: Cars belonging to approach ``d``; the list is sorted in place
            so that the car furthest along its direction of travel is first.
        d: Approach name. 'N', 'S' and 'E' are handled explicitly; anything
            else is treated like 'W'.
        green: Whether this approach currently has a green light.

    Committed cars (already in the junction) always keep moving. Other cars
    move when the light is green, when they are emergency vehicles, or when
    they have not yet reached their limit: the stop position for the first
    uncommitted car, otherwise ``GAP + CAR_SIZE`` behind the car ahead. A
    car that is held back has its ``wait`` counter incremented.
    """
    # 'N'/'S' move along y, everything else along x; 'S' and 'E' move towards
    # smaller coordinates.
    axis = 'y' if d in ('N', 'S') else 'x'
    sign = -1 if d in ('S', 'E') else 1
    stop = CENTER + STOP if sign < 0 else CENTER - STOP - CAR_SIZE

    # Leading car first.
    cars.sort(key=lambda car: -sign * getattr(car, axis))

    leader_pos = None
    for car in cars:
        if not car.committed and car.in_intersection():
            car.committed = True
        if car.committed:
            car.forward()
            continue

        pos = getattr(car, axis)
        if leader_pos is None:
            limit = stop
        else:
            limit = leader_pos - sign * (GAP + CAR_SIZE)

        may_advance = green or car.emergency or sign * pos < sign * limit
        if may_advance:
            car.forward()
        else:
            car.wait += 1

        # The position after this frame's move bounds the next car in line.
        leader_pos = getattr(car, axis)

# ---------------- DQN HELPERS ----------------
def select_action(state):
    """Pick a phase index epsilon-greedily.

    With probability ``EPSILON`` a uniformly random index into ``PHASES`` is
    returned; otherwise the index of the largest Q-value that ``policy_net``
    produces for ``state``.
    """
    explore = random.random() < EPSILON
    if not explore:
        with torch.no_grad():
            return policy_net(state).argmax().item()
    return random.randrange(len(PHASES))

def optimize():
    """Run one gradient step on a random minibatch from the replay buffer.

    Does nothing until ``memory`` holds at least ``BATCH`` transitions. Each
    transition is ``(state, action, reward, next_state)``. The loss is the
    mean squared error between ``Q(s, a)`` from ``policy_net`` and the
    one-step target ``r + GAMMA * max_a' Q_target(s', a')``.
    """
    if len(memory) < BATCH:
        return

    batch = random.sample(memory, BATCH)
    states, actions, rewards, next_states = zip(*batch)

    # Stack with NumPy first: building a tensor straight from a tuple of
    # ndarrays goes element by element and triggers PyTorch's slow-path warning.
    s = torch.as_tensor(np.stack(states), dtype=torch.float32, device=device)
    ns = torch.as_tensor(np.stack(next_states), dtype=torch.float32, device=device)
    # gather() needs int64 indices shaped (batch, 1).
    a = torch.as_tensor(actions, dtype=torch.long, device=device).unsqueeze(1)
    r = torch.as_tensor(rewards, dtype=torch.float32, device=device)

    # squeeze(1) removes only the gathered column, so the result stays 1-D
    # whatever the batch size is.
    q = policy_net(s).gather(1, a).squeeze(1)
    with torch.no_grad():
        # The target must not contribute gradients.
        next_q = target_net(ns).max(1)[0]
    loss = nn.MSELoss()(q, r + GAMMA * next_q)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# ---------------- EPISODE SYSTEM ----------------
# A "day" is 120 seconds' worth of frames at the nominal frame rate.
DAY_LENGTH = 120 * FPS
day = 1

# Accumulators that the main loop clears at the end of every day.
daily_reward = daily_steps = daily_wait = 0

# Accumulators that run for the whole session.
episode_reward = episode_wait = episode_steps = 0

# Totals remembered from the previous reward computation.
prev_total_queue = prev_total_wait = 0

def end_of_day_reset(cars):
    """Zero the ``wait`` counter of every car in ``cars``."""
    for vehicle in cars:
        setattr(vehicle, "wait", 0)

# ---------------- LEARNING CURVE ----------------
# Off-screen surface that the learning-curve plot is drawn onto.
learning_curve_width, learning_curve_height = 300, 120
curve_surface = pygame.Surface((learning_curve_width, learning_curve_height))
curve_surface.fill((30, 30, 30))

# One entry per finished day, appended by the main loop.
reward_history, wait_history = [], []

# ---------------- MAIN LOOP ----------------
DIRECTIONS = ('N', 'S', 'E', 'W')
DECISION_INTERVAL = 60        # frames between two signal decisions
TARGET_SYNC_DECISIONS = 20    # decisions between target-network refreshes
DAILY_EPSILON_DECAY = 0.9     # epsilon multiplier applied once per day
SPAWN_PROBABILITY = 0.05      # per-frame chance that a new car appears

# Screen position of the signal lamp for each approach.
SIGNAL_POSITIONS = {
    'N': (CENTER - 35, CENTER - STOP - 35),
    'S': (CENTER + 35, CENTER + STOP + 35),
    'E': (CENTER + STOP + 35, CENTER - 35),
    'W': (CENTER - STOP - 35, CENTER + 35),
}


def _draw_roads():
    """Paint asphalt, curbs, lane dashes, stop lines and the junction outline."""
    half = ROAD_WIDTH // 2
    quarter = ROAD_WIDTH // 4
    near, far = CENTER - half, CENTER + half

    pygame.draw.rect(screen, ASPHALT, (near, 0, ROAD_WIDTH, HEIGHT))
    pygame.draw.rect(screen, ASPHALT, (0, near, WIDTH, ROAD_WIDTH))

    pygame.draw.rect(screen, CURB, (near - 6, 0, 6, HEIGHT))
    pygame.draw.rect(screen, CURB, (far, 0, 6, HEIGHT))
    pygame.draw.rect(screen, CURB, (0, near - 6, WIDTH, 6))
    pygame.draw.rect(screen, CURB, (0, far, WIDTH, 6))

    dash, gap = 20, 15
    for y in range(0, HEIGHT, dash + gap):
        pygame.draw.rect(screen, LANE_WHITE, (CENTER - quarter, y, 4, dash))
        pygame.draw.rect(screen, LANE_WHITE, (CENTER + quarter, y, 4, dash))
    for x in range(0, WIDTH, dash + gap):
        pygame.draw.rect(screen, LANE_WHITE, (x, CENTER - quarter, dash, 4))
        pygame.draw.rect(screen, LANE_WHITE, (x, CENTER + quarter, dash, 4))

    pygame.draw.rect(screen, STOP_LINE, (near, CENTER - STOP - 6, ROAD_WIDTH, 6))
    pygame.draw.rect(screen, STOP_LINE, (near, CENTER + STOP, ROAD_WIDTH, 6))
    pygame.draw.rect(screen, STOP_LINE, (CENTER - STOP - 6, near, 6, ROAD_WIDTH))
    pygame.draw.rect(screen, STOP_LINE, (CENTER + STOP, near, 6, ROAD_WIDTH))
    pygame.draw.rect(screen, (80, 80, 80),
                     (CENTER - STOP, CENTER - STOP, 2 * STOP, 2 * STOP), 2)


def _has_left_screen(car):
    """Return True once a car has driven out through the far window edge.

    Cars spawn outside the window, so only the exit edge for the car's own
    direction of travel is checked.
    """
    if car.dir == 'N':
        return car.y > HEIGHT
    if car.dir == 'S':
        return car.y < -CAR_SIZE
    if car.dir == 'E':
        return car.x < -CAR_SIZE
    return car.x > WIDTH


def _plot_series(values, color, lo, hi):
    """Draw ``values`` as a polyline on the curve surface, scaled to [lo, hi]."""
    if len(values) < 2:
        return
    span = hi - lo
    if span <= 0:
        # Flat series: avoid dividing by zero, draw it along the bottom edge.
        span = 1.0
    top = learning_curve_height - 1
    x_step = (learning_curve_width - 1) / (len(values) - 1)
    points = [
        (int(i * x_step), int(top - (v - lo) / span * top))
        for i, v in enumerate(values)
    ]
    pygame.draw.lines(curve_surface, color, False, points, 2)


def _draw_learning_curve():
    """Redraw the reward (green) and wait (red) curves and blit them top-right."""
    curve_surface.fill((30, 30, 30))
    if len(reward_history) > 1:
        # Rewards can be negative, so scale them between their own min and max.
        _plot_series(reward_history, (0, 220, 0),
                     min(reward_history), max(reward_history))
        # Waits are non-negative; keep zero as the baseline.
        _plot_series(wait_history, (220, 0, 0), 0.0, max(wait_history))
    pygame.draw.rect(curve_surface, WHITE,
                     (0, 0, learning_curve_width, learning_curve_height), 1)
    screen.blit(curve_surface, (WIDTH - learning_curve_width - 10, 10))


cars = []
step = 0
active_phase = 0
last_reward = 0.0

daily_decisions = 0     # decisions rewarded so far today
decisions_made = 0      # decisions rewarded since start-up
prev_state = None       # observation at the previous decision (None before the first)
prev_action = None      # phase chosen at the previous decision

running = True
while running:
    clock.tick(FPS)
    screen.fill(GRASS)

    for e in pygame.event.get():
        if e.type == pygame.QUIT:
            running = False

    # -------- DRAW ROADS --------
    _draw_roads()

    # -------- SPAWN CARS --------
    if random.random() < SPAWN_PROBABILITY:
        cars.append(Car(random.choice(DIRECTIONS)))

    # Organize lanes
    lanes = {d: [] for d in DIRECTIONS}
    for c in cars:
        lanes[c.dir].append(c)

    # -------- DECISION STEP: REWARD PREVIOUS ACTION, SELECT NEXT PHASE --------
    reward = 0
    if step % DECISION_INTERVAL == 0:
        # Only cars that have not yet entered the junction count as queued.
        approaching = {d: [c for c in lanes[d] if not c.committed]
                       for d in DIRECTIONS}
        state = np.array([len(approaching[d]) for d in DIRECTIONS],
                         dtype=np.float32)
        total_queue = int(state.sum())
        total_wait = sum(c.wait for d in DIRECTIONS for c in approaching[d])

        if prev_state is not None:
            # The previous phase has now run for a full interval, so the
            # change since then is its outcome and `state` is its real
            # successor state.
            queue_reduction = prev_total_queue - total_queue
            wait_reduction = prev_total_wait - total_wait
            reward = (2.0 * queue_reduction + 1.0 * wait_reduction) / max(1, len(cars))
            last_reward = reward

            memory.append((prev_state, prev_action, reward, state))
            daily_decisions += 1
            decisions_made += 1

            # Refresh the bootstrap target periodically; otherwise it would
            # stay at its initial random weights forever.
            if decisions_made % TARGET_SYNC_DECISIONS == 0:
                target_net.load_state_dict(policy_net.state_dict())

        prev_total_queue = total_queue
        prev_total_wait = total_wait

        state_t = torch.tensor(state).unsqueeze(0).to(device)
        active_phase = select_action(state_t)
        prev_state, prev_action = state, active_phase

    # -------- UPDATE LANES --------
    for d in DIRECTIONS:
        update_lane(lanes[d], d, d in PHASES[active_phase])

    # Drop cars that have driven off-screen so the list stays bounded.
    cars = [c for c in cars if not _has_left_screen(c)]

    # -------- TRACK DAILY STATISTICS --------
    frame_wait = sum(c.wait for c in cars)
    avg_wait_now = frame_wait / max(1, len(cars))

    daily_reward += reward
    daily_steps += 1
    daily_wait += avg_wait_now

    episode_reward += reward
    episode_wait += frame_wait
    episode_steps += 1

    optimize()

    # -------- SIGNALS --------
    for d, pos in SIGNAL_POSITIONS.items():
        pygame.draw.circle(screen, GREEN if d in PHASES[active_phase] else RED, pos, 10)

    # -------- PANEL --------
    panel = [
        f"Day: {day}",
        f"Epsilon: {EPSILON:.2f}",
        f"Cars: {len(cars)}",
        f"Avg Wait: {avg_wait_now:.1f}",
        f"Last Reward: {last_reward:.2f}"
    ]
    for i, t in enumerate(panel):
        screen.blit(font.render(t, True, WHITE), (10, 10 + i * 18))

    for c in cars:
        c.draw()

    # -------- DAY RESET & EPSILON DECAY --------
    if step % DAY_LENGTH == 0 and step > 0:
        # Mean reward per decision and mean per-car wait per frame for the day.
        avg_daily_reward = daily_reward / max(1, daily_decisions)
        avg_daily_wait = daily_wait / max(1, daily_steps)
        print(f"Day {day} | Avg Daily Reward: {avg_daily_reward:.2f} | Avg Wait: {avg_daily_wait:.2f} | Epsilon: {EPSILON:.2f}")

        reward_history.append(avg_daily_reward)
        wait_history.append(avg_daily_wait)

        day += 1
        EPSILON = max(EPSILON * DAILY_EPSILON_DECAY, EPSILON_MIN)
        end_of_day_reset(cars)
        # Every wait counter was just zeroed; keep the reward baseline in
        # step so the next decision is not credited with a fake improvement.
        prev_total_wait = 0

        daily_reward = 0
        daily_steps = 0
        daily_wait = 0
        daily_decisions = 0

    # -------- LEARNING CURVE DRAWING --------
    # Must happen before flip(): anything blitted afterwards is wiped by the
    # next frame's screen.fill() and never reaches the display.
    _draw_learning_curve()

    pygame.display.flip()
    step += 1

pygame.quit()

