# In [34]:  (Jupyter cell prompt left over from the notebook export)
#Author: Sreejeet Maity

import math
import os
import sys
import time
from dataclasses import dataclass
from typing import List, Tuple, Optional, Dict

import numpy as np

try:
    import pygame
except Exception as e:
    raise SystemExit("Please `pip install pygame`") from e

try:
    from PIL import Image, ImageDraw
except Exception as e:
    raise SystemExit("Please `pip install pillow`") from e


# ----------------------------- Assets -----------------------------

# Directory that holds the sprite images; ensure_assets() creates it when missing.
ASSETS_DIR = "assets"
# Sprite file for the agents (drones).
DRONE_PATH = os.path.join(ASSETS_DIR, "drone.png")
# Sprite file for the goals (buildings).
BUILDING_PATH = os.path.join(ASSETS_DIR, "building.png")

def ensure_assets():
    """
    Make sure the sprite directory and both placeholder sprites exist.

    Each sprite is a 256x256 RGBA image on a transparent background and is
    only drawn when its file is not already on disk, so existing (possibly
    hand-made) artwork is never overwritten.
    """
    os.makedirs(ASSETS_DIR, exist_ok=True)

    # --- drone: round body, two crossing arms, four rotors ---
    if not os.path.exists(DRONE_PATH):
        drone_img = Image.new("RGBA", (256, 256), (0, 0, 0, 0))
        pen = ImageDraw.Draw(drone_img)
        arm_fill = (160, 160, 165, 255)
        pen.ellipse((88, 88, 168, 168), fill=(180, 180, 185, 255))
        pen.rectangle((126, 20, 130, 236), fill=arm_fill)   # vertical arm
        pen.rectangle((20, 126, 236, 130), fill=arm_fill)   # horizontal arm
        rotor_centres = ((128, 32), (128, 224), (32, 128), (224, 128))
        for centre_x, centre_y in rotor_centres:
            rotor_box = (centre_x - 24, centre_y - 24, centre_x + 24, centre_y + 24)
            pen.ellipse(rotor_box,
                        fill=(210, 210, 215, 255),
                        outline=(90, 90, 95, 255),
                        width=6)
        drone_img.save(DRONE_PATH)

    # --- building: facade, roof block, 4x3 windows, door ---
    if not os.path.exists(BUILDING_PATH):
        building_img = Image.new("RGBA", (256, 256), (0, 0, 0, 0))
        pen = ImageDraw.Draw(building_img)
        pen.rectangle((70, 60, 186, 210), fill=(70, 120, 200, 255))
        pen.rectangle((90, 30, 166, 60), fill=(60, 100, 180, 255))
        # Walk the 12 windows row by row (4 rows of 3).
        for window_idx in range(4 * 3):
            row, col = divmod(window_idx, 3)
            left = 82 + col * 28
            top = 74 + row * 32
            pen.rectangle((left, top, left + 20, top + 20), fill=(230, 240, 255, 255))
        pen.rectangle((122, 170, 134, 210), fill=(40, 70, 130, 255))
        building_img.save(BUILDING_PATH)


# ----------------------------- Configs -----------------------------

@dataclass
class WorldCfg:
    """Per-world settings: grid/reward parameters read by World and pixel/colour parameters read by Renderer."""
    # Side length of the square grid, in cells (valid coordinates are 0..grid-1).
    grid: int = 10
    # Number of agents; each agent gets its own goal.
    n_agents: int = 4
    # Amount subtracted from the reward once per step.
    step_penalty: float = 0.01
    # Amount added to the reward for every agent standing on its own goal.
    goal_reward: float = 1.0
    # World.done() reports True once the step counter reaches this value.
    max_steps: int = 200
    # Pixel size of one grid cell when rendered.
    cell_px: int = 36
    # Pixel margin between the panel edge and the grid.
    margin_px: int = 22
    # Panel background colour; Renderer.draw_world overwrites it from its palette on every draw.
    bg_color: Tuple[int, int, int] = (34, 39, 46)   # overridden per world


# ----------------------------- World (Env) -----------------------------

class World:
    """
    Tiny grid world: n_agents drones try to sit on their own goals (buildings).
    Centralized control (we'll use a policy per world). Standalone for the demo.

    State:
      t       -- number of steps taken since the last reset
      agents  -- list of (x, y) cells, one per agent
      goals   -- list of (x, y) cells, goals[i] belongs to agents[i]
    """
    def __init__(self, cfg: "WorldCfg", seed: Optional[int] = None):
        """
        cfg:  world settings (grid, n_agents, step_penalty, goal_reward, max_steps are read here)
        seed: seed for the world's private random generator; None means non-deterministic

        Raises ValueError (via reset) when the grid is too small for the agents and goals.
        """
        self.cfg = cfg
        self.rng = np.random.default_rng(seed)
        self.t = 0
        self.agents: List[Tuple[int, int]] = []
        self.goals:  List[Tuple[int, int]] = []
        self.reset()

    def reset(self):
        """
        Zero the step counter and re-sample every agent and goal onto distinct cells.

        Raises ValueError when the grid cannot hold 2 * n_agents distinct cells.
        """
        self.t = 0
        g = self.cfg.grid
        used = set()
        A, G = [], []
        # Interleave agent/goal sampling; the order fixes the layout produced by a given seed.
        for _ in range(self.cfg.n_agents):
            a = self._sample_empty(used, g)
            used.add(a)
            A.append(a)
            b = self._sample_empty(used, g)
            used.add(b)
            G.append(b)
        self.agents, self.goals = A, G

    def _sample_empty(self, used: set, g: int) -> Tuple[int, int]:
        """
        Return a uniformly sampled cell of the g x g grid that is not in `used`.

        Raises ValueError when every cell is already used; without this check the
        rejection loop below would never terminate.
        """
        if len(used) >= g * g:
            raise ValueError(
                f"no free cell left on a {g}x{g} grid ({len(used)} cells already occupied)"
            )
        # Rejection sampling: keep drawing until a free cell comes up.
        while True:
            p = (int(self.rng.integers(0, g)), int(self.rng.integers(0, g)))
            if p not in used:
                return p

    def step(self, actions: List[int]) -> float:
        """
        actions: list of ints in {0:stay,1:up,2:down,3:left,4:right}, one per agent
        returns reward (float)

        Moves are clamped to the grid. Any other action value leaves the agent in
        place. Each agent that ends the step on its own goal earns goal_reward and
        has that goal moved to a fresh empty cell; step_penalty is subtracted once.
        """
        self.t += 1
        g = self.cfg.grid
        A = list(self.agents)
        for i, a in enumerate(actions):
            x, y = A[i]
            if a == 1:
                y = max(0, y - 1)          # up = smaller y
            elif a == 2:
                y = min(g - 1, y + 1)
            elif a == 3:
                x = max(0, x - 1)
            elif a == 4:
                x = min(g - 1, x + 1)
            A[i] = (x, y)
        self.agents = A

        rew = 0.0
        for i in range(self.cfg.n_agents):
            if self.agents[i] == self.goals[i]:
                rew += self.cfg.goal_reward
                # relocate goal to a new empty cell
                used = set(self.agents) | set(self.goals)
                self.goals[i] = self._sample_empty(used, g)
        rew -= self.cfg.step_penalty
        return rew

    def done(self) -> bool:
        """Return True once the step counter has reached cfg.max_steps."""
        return self.t >= self.cfg.max_steps


# ----------------------------- Policies (dummy) -----------------------------

class RandomPolicy:
    """
    Placeholder policy: picks a uniformly random action per agent and carries a
    single dummy "bias" weight so the federated plumbing has something to move.
    """
    def __init__(self, n_agents: int, rng: Optional[np.random.Generator] = None):
        """
        n_agents: number of actions to emit per call to act()
        rng:      generator to draw from; a fresh unseeded one is created when None
        """
        self.n_agents = n_agents
        self.rng = rng if rng is not None else np.random.default_rng()

    def act(self, obs: np.ndarray) -> List[int]:
        """
        Return n_agents random actions in 0..4 as plain Python ints.

        `obs` is ignored. `.tolist()` converts the NumPy integers to built-in
        ints so the result really is a List[int] as annotated.
        """
        return self.rng.integers(0, 5, size=self.n_agents).tolist()

    def weights(self) -> Dict[str, np.ndarray]:
        """Return a fresh weight dict: a single zero float32 "bias"."""
        return {"bias": np.zeros(1, dtype=np.float32)}

    def delta(self) -> Dict[str, np.ndarray]:
        """Return a fresh update dict: a zero float32 "bias" (nothing is learned)."""
        return {"bias": np.array([0.0], dtype=np.float32)}

    def load(self, weights: Dict[str, np.ndarray]):
        """Accept broadcast weights; a no-op because this policy has no real parameters."""
        pass


# ----------------------------- Federated Orchestration -----------------------------

class FedClient:
    """One federated participant: a world, the policy acting in it, and its latest mean reward."""

    def __init__(self, idx: int, world: World, policy: RandomPolicy):
        self.idx, self.world, self.policy = idx, world, policy
        self.last_reward = 0.0

    def rollout_and_update(self, K: int = 48):
        """
        Run K environment steps with the local policy and return the policy's delta.

        The world is reset whenever it reports done. The mean per-step reward of
        the rollout is stored in `last_reward`.
        """
        env, pol = self.world, self.policy
        reward_sum = 0.0
        for _step in range(K):
            reward_sum += env.step(pol.act(self.observe()))
            if env.done():
                env.reset()
        # max(1, K) guards the division when K is 0.
        self.last_reward = reward_sum / max(1, K)
        return pol.delta()

    def observe(self) -> np.ndarray:
        """
        Flatten the world into a float32 vector.

        Per agent, in order: agent x, agent y, goal x, goal y, each divided by (grid - 1).
        """
        env = self.world
        span = env.cfg.grid - 1
        features = []
        for i in range(env.cfg.n_agents):
            coords = (*env.agents[i], *env.goals[i])
            features.extend(value / span for value in coords)
        return np.array(features, dtype=np.float32)


class FedServer:
    """Central coordinator: averages client deltas, applies them to the global weights, and broadcasts."""

    def __init__(self, clients: List[FedClient]):
        self.clients = clients
        # Seed the global weights from the first client's policy.
        self.global_weights = clients[0].policy.weights()
        self.round = 0

    def aggregate(self, deltas: List[Dict[str, np.ndarray]]):
        """Return the element-wise mean of the deltas, keyed like the first delta."""
        return {
            name: np.stack([update[name] for update in deltas], axis=0).mean(axis=0)
            for name in deltas[0].keys()
        }

    def apply(self, delta: Dict[str, np.ndarray]):
        """Add `delta` onto every global weight (each entry is rebound to a new array)."""
        for name, current in tuple(self.global_weights.items()):
            self.global_weights[name] = current + delta[name]

    def broadcast(self):
        """Hand the current global weights to every client's policy."""
        for client in self.clients:
            client.policy.load(self.global_weights)


# ----------------------------- Rendering (mosaic on top + tower beneath) -----------------------------

class Renderer:
    """
    Pygame view of the demo: one row of square world panels at the top and a
    full-width "central server" tower bar beneath, with arrows between them.

    Creating a Renderer initialises pygame, opens the window, and makes sure the
    sprite files exist.
    """
    def __init__(self, worlds: List[World], cols: int = 3, k_local: int = 48):
        """
        worlds:  worlds to display, left to right (must not be empty)
        cols:    accepted for backward compatibility but ignored; the layout is always one row
        k_local: number of local steps per round, shown in the legend

        Raises ValueError when `worlds` is empty.
        """
        if not worlds:
            raise ValueError("Renderer needs at least one world")
        self.worlds = worlds
        self.cols = len(worlds)  # force one row
        self.rows = 1
        self.k_local = int(k_local)   # used in legend

        # All panels share one square size. Size it for the LARGEST world so that
        # worlds with a bigger grid than the first one are not clipped.
        self.panel_size = max(
            w.cfg.margin_px * 2 + w.cfg.grid * w.cfg.cell_px for w in worlds
        )
        self.tower_side = 220  # height of the tower HUD row
        self.pad = 24

        # Window layout: env row on top, tower row beneath (full width)
        self.win_w = self.cols*self.panel_size + (self.cols+1)*self.pad
        self.win_h = self.panel_size + 3*self.pad + self.tower_side

        pygame.init()
        self.screen = pygame.display.set_mode((self.win_w, self.win_h))
        pygame.display.set_caption("Federated Worlds with Central Broadcasting Tower")
        self.clock = pygame.time.Clock()
        self.font = pygame.font.SysFont("consolas", 16)
        self.bigfont = pygame.font.SysFont("consolas", 20, bold=True)

        ensure_assets()
        # cell_px -> (drone sprite, building sprite), scaled once per cell size
        self._sprite_cache: Dict[int, Tuple[pygame.Surface, pygame.Surface]] = {}

        # world background color palette (cycled)
        self.palette = [
            (28, 35, 43),
            (33, 47, 60),
            (28, 49, 39),
            (48, 36, 54),
            (54, 48, 36),
            (36, 46, 60),
            (60, 36, 46),
            (39, 52, 28),
            (46, 36, 60),
        ]

        # Watermark config
        self.watermark_text = "sreejeetm1729"
        self.watermark_font = pygame.font.SysFont("consolas", 18, bold=True)

    def _get_sprites(self, cell_px: int) -> Tuple[pygame.Surface, pygame.Surface]:
        """Return (drone, building) sprites scaled for `cell_px`, loading and caching on first use."""
        if cell_px in self._sprite_cache:
            return self._sprite_cache[cell_px]

        # Keep sprites well inside the cell to avoid overlap with grid lines
        size = max(8, int(round(cell_px * 0.65)))  # ~65% of cell

        def load_and_scale(path: str) -> pygame.Surface:
            img = pygame.image.load(path).convert_alpha()
            return pygame.transform.smoothscale(img, (size, size))

        drone = load_and_scale(DRONE_PATH)
        building = load_and_scale(BUILDING_PATH)
        self._sprite_cache[cell_px] = (drone, building)
        return drone, building

    def world_rect(self, idx: int) -> pygame.Rect:
        """Return the window rectangle of world panel `idx` (one row at the TOP)."""
        c = idx
        x = self.pad + c*(self.panel_size + self.pad)
        y = self.pad
        return pygame.Rect(x, y, self.panel_size, self.panel_size)

    def tower_rect(self) -> pygame.Rect:
        """Return the window rectangle of the full-width tower row BENEATH the env row."""
        x = self.pad
        y = self.pad + self.panel_size + self.pad
        w = self.win_w - 2*self.pad
        h = self.tower_side
        return pygame.Rect(x, y, w, h)

    def draw_grid(self, surf: pygame.Surface, cfg: WorldCfg):
        """Fill `surf` with the world's background colour and draw its (grid+1) x (grid+1) lines."""
        m = cfg.margin_px
        g = cfg.grid
        cs = cfg.cell_px
        surf.fill(cfg.bg_color)
        grid_color = (65, 73, 82)
        # Lines span the grid's own extent, not the panel's, so a world smaller
        # than the shared panel still gets a closed, correctly sized grid.
        far = m + g*cs
        for i in range(g+1):
            offset = m + i*cs
            pygame.draw.line(surf, grid_color, (m, offset), (far, offset), 1)
            pygame.draw.line(surf, grid_color, (offset, m), (offset, far), 1)

    def cell_rect(self, cfg: WorldCfg, x: int, y: int) -> pygame.Rect:
        """Return the panel-local rectangle of grid cell (x, y)."""
        return pygame.Rect(cfg.margin_px + x*cfg.cell_px,
                           cfg.margin_px + y*cfg.cell_px,
                           cfg.cell_px, cfg.cell_px)

    def draw_world(self, idx: int, surface: pygame.Surface):
        """Draw world `idx` (grid, buildings, drones, step counter) onto its panel surface."""
        world = self.worlds[idx]
        cfg = world.cfg
        # NOTE: this writes the palette colour back into the world's config.
        cfg.bg_color = self.palette[idx % len(self.palette)]
        self.draw_grid(surface, cfg)

        # sprites scaled to the cell size
        drone_sprite, building_sprite = self._get_sprites(cfg.cell_px)

        # buildings (goals) - strict pixel-centered placement
        for (gx, gy) in world.goals:
            cell = self.cell_rect(cfg, gx, gy)
            br = building_sprite.get_rect(center=cell.center)
            surface.blit(building_sprite, br.topleft)

        # drones (agents) - drawn after buildings so they appear on top
        for (ax, ay) in world.agents:
            cell = self.cell_rect(cfg, ax, ay)
            dr = drone_sprite.get_rect(center=cell.center)
            surface.blit(drone_sprite, dr.topleft)

        # HUD
        t_label = self.font.render(f"t={world.t}", True, (220, 220, 220))
        surface.blit(t_label, (6, 6))

    # ----- helper to draw arrow with alpha -----
    def _draw_arrow(self, surf: pygame.Surface, color_rgba: Tuple[int,int,int,int],
                    start: Tuple[int,int], end: Tuple[int,int], width: int = 3, head_len: int = 12, head_w: int = 8):
        """
        Draw a translucent arrow from `start` to `end` on `surf`.

        The arrow is painted on a temporary per-pixel-alpha layer and then
        blitted, so the colour's alpha channel blends with what is underneath.
        """
        layer = pygame.Surface((surf.get_width(), surf.get_height()), pygame.SRCALPHA)
        pygame.draw.line(layer, color_rgba, start, end, width)

        # Unit vector along the shaft.
        dx = end[0] - start[0]
        dy = end[1] - start[1]
        ang = math.atan2(dy, dx)
        ux, uy = math.cos(ang), math.sin(ang)

        # Base of the arrow head, head_len back from the tip.
        bx = end[0] - ux * head_len
        by = end[1] - uy * head_len

        # Perpendicular to the shaft gives the two base corners of the head.
        px, py = -uy, ux
        left = (int(bx + px * head_w/2), int(by + py * head_w/2))
        right = (int(bx - px * head_w/2), int(by - py * head_w/2))

        pygame.draw.polygon(layer, color_rgba, [end, left, right])
        surf.blit(layer, (0, 0))

    def _draw_watermark(self):
        """Draw the semi-transparent watermark text with a soft shadow in the bottom-right corner."""
        text = self.watermark_text
        # Shadow
        shadow = self.watermark_font.render(text, True, (0, 0, 0))
        shadow.set_alpha(120)
        # Main text
        fg = self.watermark_font.render(text, True, (255, 255, 255))
        fg.set_alpha(170)

        # Position: bottom-right with padding
        pad = 10
        sw, sh = shadow.get_size()
        x = self.screen.get_width() - sw - pad
        y = self.screen.get_height() - sh - pad

        # Blit shadow slightly offset
        self.screen.blit(shadow, (x + 2, y + 2))
        # Blit main text
        self.screen.blit(fg, (x, y))

    def draw_tower(self, server: FedServer, broadcast_phase: float):
        """
        Draw the server bar (title, round counter, mast, pulsing rings, legend)
        and the aggregate/broadcast arrows between the tower head and each world.

        broadcast_phase: animation phase in radians; drives ring radius and arrow pulse.
        """
        rect = self.tower_rect()
        tower_surf = pygame.Surface((rect.w, rect.h), pygame.SRCALPHA)
        tower_surf.fill((24, 27, 31))

        # Title
        title = self.bigfont.render("Central Server", True, (240, 240, 240))
        tower_surf.blit(title, (16, 14))
        sub = self.font.render(f"Round: {server.round}", True, (180, 180, 180))
        tower_surf.blit(sub, (16, 40))

        # Tower mast centered within its full-width bar
        cx = rect.w // 2
        base_y = rect.h - 20
        mast_h = 120
        pygame.draw.rect(tower_surf, (200, 200, 200), (cx-6, base_y-mast_h, 12, mast_h))
        pygame.draw.polygon(tower_surf, (180, 180, 180),
                            [(cx-30, base_y), (cx+30, base_y), (cx, base_y-24)])
        pygame.draw.circle(tower_surf, (240, 240, 240), (cx, base_y-mast_h), 8)

        # Pulsing broadcast waves (decorative rings)
        for k in range(1, 5):
            radius = int(8 + 20*k + 10*math.sin(broadcast_phase + k))
            color = (255, 80, 80, max(30, 140 - 28*k))
            pygame.draw.circle(tower_surf, color, (cx, base_y - mast_h), radius, width=2)

        # Legend: both flows
        legend = [f"K = {self.k_local} local steps", "Aggregate (blue)", "Broadcast (red)"]
        for i, line in enumerate(legend):
            txt = self.font.render("• " + line, True, (200, 200, 200))
            tower_surf.blit(txt, (16, 70 + i*18))

        # Blit tower row
        self.screen.blit(tower_surf, rect.topleft)

        # ---- Coordinates for tower head (window space) ----
        tower_head = (rect.x + cx, rect.y + base_y - mast_h)

        # The pulse depends only on the phase, so compute each colour/width once.
        in_pulse = abs(math.sin(broadcast_phase + math.pi * 0.5))
        in_color = (90, 170, 255, int(110 + 110 * in_pulse))  # soft blue
        in_width = 2 + int(2 * in_pulse)

        out_pulse = abs(math.sin(broadcast_phase))
        out_color = (255, 80, 80, int(120 + 100 * out_pulse))  # red
        out_width = 2 + int(2 * out_pulse)

        world_rects = [self.world_rect(i) for i in range(len(self.worlds))]

        # ---- Incoming BLUE arrows: world CENTER -> tower head ----
        for wrect in world_rects:
            source = (wrect.centerx, wrect.centery)
            self._draw_arrow(self.screen, in_color, source, tower_head, width=in_width)

        # ---- Outgoing RED arrows: tower head -> world BOTTOM edge ----
        for wrect in world_rects:
            world_bottom = (wrect.centerx, wrect.bottom)
            self._draw_arrow(self.screen, out_color, tower_head, world_bottom, width=out_width)

    def draw(self, server: FedServer, broadcast_phase: float):
        """Render one full frame (worlds, tower, watermark) and flip the display."""
        self.screen.fill((18, 20, 22))
        # draw worlds (one row at the top)
        for i in range(len(self.worlds)):
            r = self.world_rect(i)
            panel = pygame.Surface((r.w, r.h))
            self.draw_world(i, panel)
            pygame.draw.rect(panel, (70, 78, 88), panel.get_rect(), width=2, border_radius=10)
            self.screen.blit(panel, r.topleft)

        # draw central tower row beneath
        self.draw_tower(server, broadcast_phase)

        # draw watermark on top of everything
        self._draw_watermark()

        pygame.display.flip()


# ----------------------------- Recording helper (fixed) -----------------------------
class PygameRecorder:
    """
    Records the current pygame display surface.
    Order of backends:
      1) OpenCV (MP4)
      2) imageio + imageio-ffmpeg (MP4)
      3) GIF fallback (RAM-heavy)
    Hotkeys: R toggles recording, S saves a PNG snapshot.
    """
    def __init__(self, surface, fps=30, path="recording.mp4", gif_path="recording.gif"):
        """
        surface:  pygame surface to capture (its size fixes the video frame size)
        fps:      frames per second written to the file; clamped to at least 1
        path:     output file for the MP4 backends
        gif_path: output file for the GIF fallback
        """
        self.surface = surface
        # A frame rate below 1 is meaningless to every backend; close() already
        # assumed max(1, fps), so normalise once here.
        self.fps = max(1, int(fps))
        self.path = path
        self.gif_path = gif_path
        self._writer = None      # cv2 VideoWriter or imageio Writer
        self._mode = None        # "cv2" | "imageio-mp4" | "gif"
        self._frames = []        # only used for GIF fallback
        self._handle_keys = False
        self._enabled = True
        self._frame_size = (surface.get_width(), surface.get_height())

    def start(self, handle_events=False):
        """
        Pick the first working backend and begin recording.

        handle_events: when True, process_event() reacts to the R/S hotkeys.
        Leaves `_mode` as None (recording disabled) when no backend is available.
        """
        self._enabled = True
        self._handle_keys = bool(handle_events)

        # --- Try OpenCV first ---
        try:
            import cv2  # type: ignore
            for code in ["avc1", "H264", "mp4v", "MJPG"]:
                fourcc = cv2.VideoWriter_fourcc(*code)
                vw = cv2.VideoWriter(self.path, fourcc, self.fps, self._frame_size)
                if vw.isOpened():
                    self._writer = vw
                    self._mode = "cv2"
                    print(f"[rec] Recording to MP4 via OpenCV → {self.path} [{code}]")
                    return
                # This codec is unusable here; free the writer before trying the next.
                vw.release()
        except ImportError:
            # OpenCV simply is not installed; fall through quietly.
            pass
        except Exception as e:
            # Anything else is unexpected: report it, then try the next backend.
            print(f"[rec] OpenCV writer unavailable: {e}")

        # --- Try imageio (requires: pip install imageio imageio-ffmpeg) ---
        try:
            import imageio  # type: ignore
            self._writer = imageio.get_writer(
                self.path, fps=self.fps, codec="libx264", quality=8
            )
            self._mode = "imageio-mp4"
            print(f"[rec] Recording to MP4 via imageio-ffmpeg → {self.path}")
            return
        except Exception as e:
            print(f"[rec] imageio-ffmpeg unavailable: {e}")

        # --- Fallback: GIF (RAM-heavy) ---
        try:
            import imageio  # type: ignore  # noqa: F401  (availability check only)
            self._mode = "gif"
            print(f"[rec] Falling back to GIF (RAM-heavy) → {self.gif_path}")
            return
        except Exception:
            self._mode = None
            print("[rec] No recording backend available (cv2/imageio missing).")

    def _grab_frame_rgb(self):
        """Return the surface's pixels as an (H, W, 3) RGB uint8 array."""
        arr = pygame.surfarray.array3d(self.surface)   # pygame gives (W, H, 3)
        return np.transpose(arr, (1, 0, 2))  # (H, W, 3), RGB uint8

    def capture(self):
        """Append the current surface contents as one frame; a no-op while paused or without a backend."""
        if not self._enabled or self._mode is None:
            return
        frame = self._grab_frame_rgb()

        if self._mode == "cv2":
            import cv2
            bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)   # OpenCV expects BGR
            self._writer.write(bgr)

        elif self._mode == "imageio-mp4":
            self._writer.append_data(frame)

        elif self._mode == "gif":
            # Frames are held in memory until close() writes the GIF.
            self._frames.append(frame)

    def close(self):
        """Finalise the output file (if any) and reset the recorder to its disabled state."""
        try:
            if self._mode == "cv2" and self._writer is not None:
                self._writer.release()
                print(f"[rec] Saved MP4 to {self.path}")

            elif self._mode == "imageio-mp4" and self._writer is not None:
                self._writer.close()
                print(f"[rec] Saved MP4 to {self.path}")

            elif self._mode == "gif" and self._frames:
                import imageio
                imageio.mimsave(self.gif_path, self._frames, duration=1.0/max(1, self.fps))
                print(f"[rec] Saved GIF to {self.gif_path}")
        finally:
            # Reset even if finalising failed, so a second close() is harmless.
            self._writer = None
            self._frames = []
            self._mode = None
            self._enabled = False

    def process_event(self, event):
        """Handle the R (pause/resume) and S (PNG snapshot) hotkeys when enabled via start()."""
        if not self._handle_keys:
            return
        if event.type == pygame.KEYDOWN:
            if event.key == pygame.K_r:
                self._enabled = not self._enabled
                print(f"[rec] {'RESUMED' if self._enabled else 'PAUSED'} recording.")
            elif event.key == pygame.K_s:
                fname = f"snap_{int(time.time())}.png"
                pygame.image.save(self.surface, fname)
                print(f"[rec] Saved snapshot → {fname}")


# ----------------------------- Main Loop -----------------------------

def main():
    """
    Run the federated-worlds demo: build the client worlds, loop
    rollout -> aggregate -> broadcast -> render -> record until the window is
    closed or MAX_SECONDS elapse, then finalise the recording and exit.

    Always ends by raising SystemExit(0) after a normal run.
    """
    # ------- parameters -------
    M = 4            # number of client worlds (all shown on ONE row)
    LOCAL_STEPS = 48
    FPS = 30
    MAX_SECONDS = 45  # wall-clock cutoff for the demo/recording, in seconds

    ensure_assets()

    # ------- build worlds with heterogeneity -------
    worlds = []
    for i in range(M):
        cfg = WorldCfg(
            grid=10 + (i % 3),
            n_agents=2,
            step_penalty=0.01 + 0.002*i,
            goal_reward=1.0,
            max_steps=99999,
            cell_px=34,
            margin_px=20,
        )
        worlds.append(World(cfg, seed=1234 + i))

    clients = [FedClient(i, worlds[i], RandomPolicy(worlds[i].cfg.n_agents)) for i in range(M)]
    server = FedServer(clients)
    renderer = Renderer(worlds, cols=M, k_local=LOCAL_STEPS)  # legend shows K

    # ---- set up recorder ----
    recorder = PygameRecorder(surface=renderer.screen, fps=FPS,
                              path="federated_worlds.mp4", gif_path="federated_worlds.gif")

    # try/finally so the video file is finalised and pygame is shut down even
    # when the loop raises.
    try:
        recorder.start(handle_events=True)   # R toggles record, S snapshot

        phase = 0.0
        running = True
        # Monotonic clock: the cutoff must not be affected by system clock changes.
        start_time = time.monotonic()

        while running:
            # ------------- events -------------
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    running = False
                recorder.process_event(event)   # hotkeys

            # ------------- local rollouts -------------
            deltas = [c.rollout_and_update(K=LOCAL_STEPS) for c in clients]

            # ------------- server aggregate + broadcast -------------
            delta = server.aggregate(deltas)
            server.apply(delta)
            server.broadcast()
            server.round += 1

            # ------------- render -------------
            phase += 0.25
            renderer.draw(server, phase)

            # ------------- record this frame -------------
            recorder.capture()

            renderer.clock.tick(FPS)

            # ------------- MAX_SECONDS cutoff -------------
            if time.monotonic() - start_time >= MAX_SECONDS:
                running = False
    finally:
        # ---- teardown ----
        recorder.close()
        pygame.quit()

    sys.exit(0)


# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()




# Captured output of the notebook cell above:
# [rec] Recording to MP4 via imageio-ffmpeg → federated_worlds.mp4
# [rec] Saved MP4 to federated_worlds.mp4
#
#
# SystemExit: 0