diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml index d2f2476..b391a5b 100644 --- a/.github/workflows/format.yml +++ b/.github/workflows/format.yml @@ -12,7 +12,7 @@ jobs: name: Format runs-on: ubuntu-latest steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - uses: actions/setup-python@v6 with: python-version: "3.x" diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index f978673..cbe6f26 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -20,7 +20,7 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - uses: actions/setup-python@v6 with: diff --git a/BCC-Examples/container-monitor/README.md b/BCC-Examples/container-monitor/README.md new file mode 100644 index 0000000..dc27049 --- /dev/null +++ b/BCC-Examples/container-monitor/README.md @@ -0,0 +1,49 @@ +# Container Monitor TUI + +A beautiful terminal-based container monitoring tool that combines syscall tracking, file I/O monitoring, and network traffic analysis using eBPF. + +## Features + +- 🎯 **Interactive Cgroup Selection** - Navigate and select cgroups with arrow keys +- 📊 **Real-time Monitoring** - Live graphs and statistics +- 🔥 **Syscall Tracking** - Total syscall count per cgroup +- 💾 **File I/O Monitoring** - Read/write operations and bytes with graphs +- 🌐 **Network Traffic** - RX/TX packets and bytes with live graphs +- ⚡ **Efficient Caching** - Reduced /proc lookups for better performance +- 🎨 **Beautiful TUI** - Clean, colorful terminal interface + +## Requirements + +- Python 3.7+ +- pythonbpf +- Root privileges (for eBPF) + +## Installation + +```bash +# Ensure you have pythonbpf installed +pip install pythonbpf + +# Run the monitor +sudo $(which python) container_monitor.py +``` + +## Usage + +1. **Selection Screen**: Use ↑↓ arrow keys to navigate through cgroups, press ENTER to select +2. 
**Monitoring Screen**: View real-time graphs and statistics, press ESC or 'b' to go back +3. **Exit**: Press 'q' at any time to quit + +## Architecture + +- `container_monitor.py` - Main BPF program combining all three tracers +- `data_collection.py` - Data collection, caching, and history management +- `tui.py` - Terminal user interface with selection and monitoring screens + +## BPF Programs + +- **vfs_read/vfs_write** - Track file I/O operations +- **__netif_receive_skb/__dev_queue_xmit** - Track network traffic +- **raw_syscalls/sys_enter** - Count all syscalls + +All programs filter by cgroup ID for per-container monitoring. diff --git a/BCC-Examples/container-monitor/container_monitor.py b/BCC-Examples/container-monitor/container_monitor.py new file mode 100644 index 0000000..7904060 --- /dev/null +++ b/BCC-Examples/container-monitor/container_monitor.py @@ -0,0 +1,220 @@ +"""Container Monitor - TUI-based cgroup monitoring combining syscall, file I/O, and network tracking.""" + +from pythonbpf import bpf, map, section, bpfglobal, struct, BPF +from pythonbpf.maps import HashMap +from pythonbpf.helper import get_current_cgroup_id +from ctypes import c_int32, c_uint64, c_void_p +from vmlinux import struct_pt_regs, struct_sk_buff + +from data_collection import ContainerDataCollector +from tui import ContainerMonitorTUI + + +# ==================== BPF Structs ==================== + + +@bpf +@struct +class read_stats: + bytes: c_uint64 + ops: c_uint64 + + +@bpf +@struct +class write_stats: + bytes: c_uint64 + ops: c_uint64 + + +@bpf +@struct +class net_stats: + rx_packets: c_uint64 + tx_packets: c_uint64 + rx_bytes: c_uint64 + tx_bytes: c_uint64 + + +# ==================== BPF Maps ==================== + + +@bpf +@map +def read_map() -> HashMap: + return HashMap(key=c_uint64, value=read_stats, max_entries=1024) + + +@bpf +@map +def write_map() -> HashMap: + return HashMap(key=c_uint64, value=write_stats, max_entries=1024) + + +@bpf +@map +def net_stats_map() -> 
HashMap: + return HashMap(key=c_uint64, value=net_stats, max_entries=1024) + + +@bpf +@map +def syscall_count() -> HashMap: + return HashMap(key=c_uint64, value=c_uint64, max_entries=1024) + + +# ==================== File I/O Tracing ==================== + + +@bpf +@section("kprobe/vfs_read") +def trace_read(ctx: struct_pt_regs) -> c_int32: + cg = get_current_cgroup_id() + count = c_uint64(ctx.dx) + ptr = read_map.lookup(cg) + if ptr: + s = read_stats() + s.bytes = ptr.bytes + count + s.ops = ptr.ops + 1 + read_map.update(cg, s) + else: + s = read_stats() + s.bytes = count + s.ops = c_uint64(1) + read_map.update(cg, s) + + return c_int32(0) + + +@bpf +@section("kprobe/vfs_write") +def trace_write(ctx1: struct_pt_regs) -> c_int32: + cg = get_current_cgroup_id() + count = c_uint64(ctx1.dx) + ptr = write_map.lookup(cg) + + if ptr: + s = write_stats() + s.bytes = ptr.bytes + count + s.ops = ptr.ops + 1 + write_map.update(cg, s) + else: + s = write_stats() + s.bytes = count + s.ops = c_uint64(1) + write_map.update(cg, s) + + return c_int32(0) + + +# ==================== Network I/O Tracing ==================== + + +@bpf +@section("kprobe/__netif_receive_skb") +def trace_netif_rx(ctx2: struct_pt_regs) -> c_int32: + cgroup_id = get_current_cgroup_id() + skb = struct_sk_buff(ctx2.di) + pkt_len = c_uint64(skb.len) + + stats_ptr = net_stats_map.lookup(cgroup_id) + + if stats_ptr: + stats = net_stats() + stats.rx_packets = stats_ptr.rx_packets + 1 + stats.tx_packets = stats_ptr.tx_packets + stats.rx_bytes = stats_ptr.rx_bytes + pkt_len + stats.tx_bytes = stats_ptr.tx_bytes + net_stats_map.update(cgroup_id, stats) + else: + stats = net_stats() + stats.rx_packets = c_uint64(1) + stats.tx_packets = c_uint64(0) + stats.rx_bytes = pkt_len + stats.tx_bytes = c_uint64(0) + net_stats_map.update(cgroup_id, stats) + + return c_int32(0) + + +@bpf +@section("kprobe/__dev_queue_xmit") +def trace_dev_xmit(ctx3: struct_pt_regs) -> c_int32: + cgroup_id = get_current_cgroup_id() + skb = 
struct_sk_buff(ctx3.di) + pkt_len = c_uint64(skb.len) + + stats_ptr = net_stats_map.lookup(cgroup_id) + + if stats_ptr: + stats = net_stats() + stats.rx_packets = stats_ptr.rx_packets + stats.tx_packets = stats_ptr.tx_packets + 1 + stats.rx_bytes = stats_ptr.rx_bytes + stats.tx_bytes = stats_ptr.tx_bytes + pkt_len + net_stats_map.update(cgroup_id, stats) + else: + stats = net_stats() + stats.rx_packets = c_uint64(0) + stats.tx_packets = c_uint64(1) + stats.rx_bytes = c_uint64(0) + stats.tx_bytes = pkt_len + net_stats_map.update(cgroup_id, stats) + + return c_int32(0) + + +# ==================== Syscall Tracing ==================== + + +@bpf +@section("tracepoint/raw_syscalls/sys_enter") +def count_syscalls(ctx: c_void_p) -> c_int32: + cgroup_id = get_current_cgroup_id() + count_ptr = syscall_count.lookup(cgroup_id) + + if count_ptr: + new_count = count_ptr + c_uint64(1) + syscall_count.update(cgroup_id, new_count) + else: + syscall_count.update(cgroup_id, c_uint64(1)) + + return c_int32(0) + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +# ==================== Main ==================== + +if __name__ == "__main__": + print("🔥 Loading BPF programs...") + + # Load and attach BPF program + b = BPF() + b.load() + b.attach_all() + + # Get map references and enable struct deserialization + read_map_ref = b["read_map"] + write_map_ref = b["write_map"] + net_stats_map_ref = b["net_stats_map"] + syscall_count_ref = b["syscall_count"] + + read_map_ref.set_value_struct("read_stats") + write_map_ref.set_value_struct("write_stats") + net_stats_map_ref.set_value_struct("net_stats") + + print("✅ BPF programs loaded and attached") + + # Setup data collector + collector = ContainerDataCollector( + read_map_ref, write_map_ref, net_stats_map_ref, syscall_count_ref + ) + + # Create and run TUI + tui = ContainerMonitorTUI(collector) + tui.run() diff --git a/BCC-Examples/container-monitor/data_collection.py b/BCC-Examples/container-monitor/data_collection.py new file 
mode 100644 index 0000000..05b4990 --- /dev/null +++ b/BCC-Examples/container-monitor/data_collection.py @@ -0,0 +1,208 @@ +"""Data collection and management for container monitoring.""" + +import os +import time +from pathlib import Path +from typing import Dict, List, Set, Optional +from dataclasses import dataclass +from collections import deque, defaultdict + + +@dataclass +class CgroupInfo: + """Information about a cgroup.""" + + id: int + name: str + path: str + + +@dataclass +class ContainerStats: + """Statistics for a container/cgroup.""" + + cgroup_id: int + cgroup_name: str + + # File I/O + read_ops: int = 0 + read_bytes: int = 0 + write_ops: int = 0 + write_bytes: int = 0 + + # Network I/O + rx_packets: int = 0 + rx_bytes: int = 0 + tx_packets: int = 0 + tx_bytes: int = 0 + + # Syscalls + syscall_count: int = 0 + + # Timestamp + timestamp: float = 0.0 + + +class ContainerDataCollector: + """Collects and manages container monitoring data from BPF.""" + + def __init__( + self, read_map, write_map, net_stats_map, syscall_map, history_size: int = 100 + ): + self.read_map = read_map + self.write_map = write_map + self.net_stats_map = net_stats_map + self.syscall_map = syscall_map + + # Caching + self._cgroup_cache: Dict[int, CgroupInfo] = {} + self._cgroup_cache_time = 0 + self._cache_ttl = 5.0 + 0 # Refresh cache every 5 seconds + + # Historical data for graphing + self._history_size = history_size + self._history: Dict[int, deque] = defaultdict( + lambda: deque(maxlen=history_size) + ) + + def get_all_cgroups(self) -> List[CgroupInfo]: + """Get all cgroups with caching.""" + current_time = time.time() + + # Use cached data if still valid + if current_time - self._cgroup_cache_time < self._cache_ttl: + return list(self._cgroup_cache.values()) + + # Refresh cache + self._refresh_cgroup_cache() + return list(self._cgroup_cache.values()) + + def _refresh_cgroup_cache(self): + """Refresh the cgroup cache from /proc.""" + cgroup_map: Dict[int, Set[str]] = 
defaultdict(set) + + # Scan /proc to find all cgroups + for proc_dir in Path("/proc").glob("[0-9]*"): + try: + cgroup_file = proc_dir / "cgroup" + if not cgroup_file.exists(): + continue + + with open(cgroup_file) as f: + for line in f: + parts = line.strip().split(":") + if len(parts) >= 3: + cgroup_path = parts[2] + cgroup_mount = f"/sys/fs/cgroup{cgroup_path}" + + if os.path.exists(cgroup_mount): + stat_info = os.stat(cgroup_mount) + cgroup_id = stat_info.st_ino + cgroup_map[cgroup_id].add(cgroup_path) + + except (PermissionError, FileNotFoundError, OSError): + continue + + # Update cache with best names + new_cache = {} + for cgroup_id, paths in cgroup_map.items(): + # Pick the most descriptive path + best_path = self._get_best_cgroup_path(paths) + name = self._get_cgroup_name(best_path) + + new_cache[cgroup_id] = CgroupInfo(id=cgroup_id, name=name, path=best_path) + + self._cgroup_cache = new_cache + self._cgroup_cache_time = time.time() + + def _get_best_cgroup_path(self, paths: Set[str]) -> str: + """Select the most descriptive cgroup path.""" + path_list = list(paths) + + # Prefer paths with more components (more specific) + # Prefer paths containing docker, podman, etc. 
+ for keyword in ["docker", "podman", "kubernetes", "k8s", "systemd"]: + for path in path_list: + if keyword in path.lower(): + return path + + # Return longest path (most specific) + return max(path_list, key=lambda p: (len(p.split("/")), len(p))) + + def _get_cgroup_name(self, path: str) -> str: + """Extract a friendly name from cgroup path.""" + if not path or path == "/": + return "root" + + # Remove leading/trailing slashes + path = path.strip("/") + + # Try to extract container ID or service name + parts = path.split("/") + + # For Docker: /docker/ + if "docker" in path.lower(): + for i, part in enumerate(parts): + if part.lower() == "docker" and i + 1 < len(parts): + container_id = parts[i + 1][:12] # Short ID + return f"docker:{container_id}" + + # For systemd services + if "system.slice" in path: + for part in parts: + if part.endswith(".service"): + return part.replace(".service", "") + + # For user slices + if "user.slice" in path: + return f"user:{parts[-1]}" if parts else "user" + + # Default: use last component + return parts[-1] if parts else path + + def get_stats_for_cgroup(self, cgroup_id: int) -> ContainerStats: + """Get current statistics for a specific cgroup.""" + cgroup_info = self._cgroup_cache.get(cgroup_id) + cgroup_name = cgroup_info.name if cgroup_info else f"cgroup-{cgroup_id}" + + stats = ContainerStats( + cgroup_id=cgroup_id, cgroup_name=cgroup_name, timestamp=time.time() + ) + + # Get file I/O stats + read_stat = self.read_map.lookup(cgroup_id) + if read_stat: + stats.read_ops = int(read_stat.ops) + stats.read_bytes = int(read_stat.bytes) + + write_stat = self.write_map.lookup(cgroup_id) + if write_stat: + stats.write_ops = int(write_stat.ops) + stats.write_bytes = int(write_stat.bytes) + + # Get network stats + net_stat = self.net_stats_map.lookup(cgroup_id) + if net_stat: + stats.rx_packets = int(net_stat.rx_packets) + stats.rx_bytes = int(net_stat.rx_bytes) + stats.tx_packets = int(net_stat.tx_packets) + stats.tx_bytes = 
int(net_stat.tx_bytes) + + # Get syscall count + syscall_cnt = self.syscall_map.lookup(cgroup_id) + if syscall_cnt is not None: + stats.syscall_count = int(syscall_cnt) + + # Add to history + self._history[cgroup_id].append(stats) + + return stats + + def get_history(self, cgroup_id: int) -> List[ContainerStats]: + """Get historical statistics for graphing.""" + return list(self._history[cgroup_id]) + + def get_cgroup_info(self, cgroup_id: int) -> Optional[CgroupInfo]: + """Get cached cgroup information.""" + return self._cgroup_cache.get(cgroup_id) diff --git a/BCC-Examples/container-monitor/tui.py b/BCC-Examples/container-monitor/tui.py new file mode 100644 index 0000000..f006137 --- /dev/null +++ b/BCC-Examples/container-monitor/tui.py @@ -0,0 +1,527 @@ +"""Terminal User Interface for container monitoring.""" + +import time +import curses +from typing import Optional, List +from data_collection import ContainerDataCollector + + +class ContainerMonitorTUI: + """TUI for container monitoring with cgroup selection and live graphs.""" + + def __init__(self, collector: ContainerDataCollector): + self.collector = collector + self.selected_cgroup: Optional[int] = None + self.current_screen = "selection" # "selection" or "monitoring" + self.selected_index = 0 + self.scroll_offset = 0 + + def run(self): + """Run the TUI application.""" + curses.wrapper(self._main_loop) + + def _main_loop(self, stdscr): + """Main curses loop.""" + # Configure curses + curses.curs_set(0) # Hide cursor + stdscr.nodelay(True) # Non-blocking input + stdscr.timeout(100) # Refresh every 100ms + + # Initialize colors + curses.start_color() + curses.init_pair(1, curses.COLOR_CYAN, curses.COLOR_BLACK) + curses.init_pair(2, curses.COLOR_GREEN, curses.COLOR_BLACK) + curses.init_pair(3, curses.COLOR_YELLOW, curses.COLOR_BLACK) + curses.init_pair(4, curses.COLOR_RED, curses.COLOR_BLACK) + curses.init_pair(5, curses.COLOR_MAGENTA, curses.COLOR_BLACK) + curses.init_pair(6, curses.COLOR_WHITE, 
curses.COLOR_BLUE) + curses.init_pair(7, curses.COLOR_BLUE, curses.COLOR_BLACK) + curses.init_pair(8, curses.COLOR_WHITE, curses.COLOR_CYAN) + + while True: + stdscr.clear() + + try: + if self.current_screen == "selection": + self._draw_selection_screen(stdscr) + elif self.current_screen == "monitoring": + self._draw_monitoring_screen(stdscr) + + stdscr.refresh() + + # Handle input + key = stdscr.getch() + if key != -1: + if not self._handle_input(key): + break # Exit requested + + except KeyboardInterrupt: + break + except Exception as e: + # Show error + stdscr.addstr(0, 0, f"Error: {str(e)}") + stdscr.refresh() + time.sleep(2) + + def _draw_selection_screen(self, stdscr): + """Draw the cgroup selection screen.""" + height, width = stdscr.getmaxyx() + + # Draw fancy header box + self._draw_fancy_header( + stdscr, "🐳 CONTAINER MONITOR", "Select a Cgroup to Monitor" + ) + + # Instructions + instructions = "↑↓: Navigate | ENTER: Select | q: Quit | r: Refresh" + stdscr.attron(curses.color_pair(3)) + stdscr.addstr(3, (width - len(instructions)) // 2, instructions) + stdscr.attroff(curses.color_pair(3)) + + # Get cgroups + cgroups = self.collector.get_all_cgroups() + + if not cgroups: + msg = "No cgroups found. Waiting for activity..." 
+ stdscr.attron(curses.color_pair(4)) + stdscr.addstr(height // 2, (width - len(msg)) // 2, msg) + stdscr.attroff(curses.color_pair(4)) + return + + # Sort cgroups by name + cgroups.sort(key=lambda c: c.name) + + # Adjust selection bounds + if self.selected_index >= len(cgroups): + self.selected_index = len(cgroups) - 1 + if self.selected_index < 0: + self.selected_index = 0 + + # Calculate visible range + list_height = height - 8 + if self.selected_index < self.scroll_offset: + self.scroll_offset = self.selected_index + elif self.selected_index >= self.scroll_offset + list_height: + self.scroll_offset = self.selected_index - list_height + 1 + + # Draw cgroup list with fancy borders + start_y = 5 + stdscr.attron(curses.color_pair(1)) + stdscr.addstr(start_y, 2, "╔" + "═" * (width - 6) + "╗") + stdscr.attroff(curses.color_pair(1)) + + for i in range(list_height): + idx = self.scroll_offset + i + y = start_y + 1 + i + + stdscr.attron(curses.color_pair(1)) + stdscr.addstr(y, 2, "║") + stdscr.addstr(y, width - 3, "║") + stdscr.attroff(curses.color_pair(1)) + + if idx >= len(cgroups): + continue + + cgroup = cgroups[idx] + + if idx == self.selected_index: + # Highlight selected with better styling + stdscr.attron(curses.color_pair(8) | curses.A_BOLD) + line = f" ► {cgroup.name:<35} │ ID: {cgroup.id} " + stdscr.addstr(y, 3, line[: width - 6]) + stdscr.attroff(curses.color_pair(8) | curses.A_BOLD) + else: + stdscr.attron(curses.color_pair(7)) + line = f" {cgroup.name:<35} │ ID: {cgroup.id} " + stdscr.addstr(y, 3, line[: width - 6]) + stdscr.attroff(curses.color_pair(7)) + + # Bottom border + bottom_y = start_y + 1 + list_height + stdscr.attron(curses.color_pair(1)) + stdscr.addstr(bottom_y, 2, "╚" + "═" * (width - 6) + "╝") + stdscr.attroff(curses.color_pair(1)) + + # Footer with count and scroll indicator + footer = f"Total: {len(cgroups)} cgroups" + if len(cgroups) > list_height: + footer += f" │ Showing {self.scroll_offset + 1}-{min(self.scroll_offset + list_height, 
len(cgroups))}" + stdscr.attron(curses.color_pair(1)) + stdscr.addstr(height - 2, (width - len(footer)) // 2, footer) + stdscr.attroff(curses.color_pair(1)) + + def _draw_monitoring_screen(self, stdscr): + """Draw the monitoring screen for selected cgroup.""" + height, width = stdscr.getmaxyx() + + if self.selected_cgroup is None: + return + + # Get current stats + stats = self.collector.get_stats_for_cgroup(self.selected_cgroup) + history = self.collector.get_history(self.selected_cgroup) + + # Draw fancy header + self._draw_fancy_header( + stdscr, f"📊 {stats.cgroup_name}", "Live Performance Metrics" + ) + + # Instructions + instructions = "ESC/b: Back to List | q: Quit" + stdscr.attron(curses.color_pair(3)) + stdscr.addstr(3, (width - len(instructions)) // 2, instructions) + stdscr.attroff(curses.color_pair(3)) + + # Calculate metrics for rate display + rates = self._calculate_rates(history) + + y = 5 + + # Syscall count in a fancy box + self._draw_metric_box( + stdscr, + y, + 2, + width - 4, + "⚡ SYSTEM CALLS", + f"{stats.syscall_count:,}", + f"Rate: {rates['syscalls_per_sec']:.1f}/sec", + curses.color_pair(5), + ) + + y += 4 + + # Network I/O Section + self._draw_section_header(stdscr, y, "🌐 NETWORK I/O", 1) + y += 1 + + # RX graph with legend + rx_label = f"RX: {self._format_bytes(stats.rx_bytes)}" + rx_rate = f"{self._format_bytes(rates['rx_bytes_per_sec'])}/s" + rx_pkts = f"{stats.rx_packets:,} pkts ({rates['rx_pkts_per_sec']:.1f}/s)" + + self._draw_labeled_graph( + stdscr, + y, + 2, + width - 4, + 4, + rx_label, + rx_rate, + rx_pkts, + [s.rx_bytes for s in history], + curses.color_pair(2), + "Received Traffic (last 100 samples)", + ) + + y += 6 + + # TX graph with legend + tx_label = f"TX: {self._format_bytes(stats.tx_bytes)}" + tx_rate = f"{self._format_bytes(rates['tx_bytes_per_sec'])}/s" + tx_pkts = f"{stats.tx_packets:,} pkts ({rates['tx_pkts_per_sec']:.1f}/s)" + + self._draw_labeled_graph( + stdscr, + y, + 2, + width - 4, + 4, + tx_label, + tx_rate, + 
tx_pkts, + [s.tx_bytes for s in history], + curses.color_pair(3), + "Transmitted Traffic (last 100 samples)", + ) + + y += 6 + + # File I/O Section + self._draw_section_header(stdscr, y, "💾 FILE I/O", 1) + y += 1 + + # Read graph with legend + read_label = f"READ: {self._format_bytes(stats.read_bytes)}" + read_rate = f"{self._format_bytes(rates['read_bytes_per_sec'])}/s" + read_ops = f"{stats.read_ops:,} ops ({rates['read_ops_per_sec']:.1f}/s)" + + self._draw_labeled_graph( + stdscr, + y, + 2, + width - 4, + 4, + read_label, + read_rate, + read_ops, + [s.read_bytes for s in history], + curses.color_pair(4), + "Read Operations (last 100 samples)", + ) + + y += 6 + + # Write graph with legend + write_label = f"WRITE: {self._format_bytes(stats.write_bytes)}" + write_rate = f"{self._format_bytes(rates['write_bytes_per_sec'])}/s" + write_ops = f"{stats.write_ops:,} ops ({rates['write_ops_per_sec']:.1f}/s)" + + self._draw_labeled_graph( + stdscr, + y, + 2, + width - 4, + 4, + write_label, + write_rate, + write_ops, + [s.write_bytes for s in history], + curses.color_pair(5), + "Write Operations (last 100 samples)", + ) + + def _draw_fancy_header(self, stdscr, title: str, subtitle: str): + """Draw a fancy header with title and subtitle.""" + height, width = stdscr.getmaxyx() + + # Top border + stdscr.attron(curses.color_pair(6) | curses.A_BOLD) + stdscr.addstr(0, 0, "═" * width) + + # Title + stdscr.addstr(0, (width - len(title)) // 2, f" {title} ") + stdscr.attroff(curses.color_pair(6) | curses.A_BOLD) + + # Subtitle + stdscr.attron(curses.color_pair(1)) + stdscr.addstr(1, (width - len(subtitle)) // 2, subtitle) + stdscr.attroff(curses.color_pair(1)) + + # Bottom border + stdscr.attron(curses.color_pair(6)) + stdscr.addstr(2, 0, "═" * width) + stdscr.attroff(curses.color_pair(6)) + + def _draw_metric_box( + self, + stdscr, + y: int, + x: int, + width: int, + label: str, + value: str, + detail: str, + color_pair: int, + ): + """Draw a fancy box for displaying a metric.""" 
+ # Top border + stdscr.attron(color_pair | curses.A_BOLD) + stdscr.addstr(y, x, "┌" + "─" * (width - 2) + "┐") + + # Label + stdscr.addstr(y + 1, x, "│") + stdscr.addstr(y + 1, x + 2, label) + stdscr.addstr(y + 1, x + width - 1, "│") + + # Value (large) + stdscr.addstr(y + 2, x, "│") + stdscr.attroff(color_pair | curses.A_BOLD) + stdscr.attron(curses.color_pair(2) | curses.A_BOLD) + stdscr.addstr(y + 2, x + 4, value) + stdscr.attroff(curses.color_pair(2) | curses.A_BOLD) + stdscr.attron(color_pair | curses.A_BOLD) + stdscr.addstr(y + 2, x + width - 1, "│") + + # Detail + stdscr.addstr(y + 2, x + width - len(detail) - 3, detail) + + # Bottom border + stdscr.addstr(y + 3, x, "└" + "─" * (width - 2) + "┘") + stdscr.attroff(color_pair | curses.A_BOLD) + + def _draw_section_header(self, stdscr, y: int, title: str, color_pair: int): + """Draw a section header.""" + height, width = stdscr.getmaxyx() + stdscr.attron(curses.color_pair(color_pair) | curses.A_BOLD) + stdscr.addstr(y, 2, title) + stdscr.addstr(y, len(title) + 3, "─" * (width - len(title) - 5)) + stdscr.attroff(curses.color_pair(color_pair) | curses.A_BOLD) + + def _draw_labeled_graph( + self, + stdscr, + y: int, + x: int, + width: int, + height: int, + label: str, + rate: str, + detail: str, + data: List[float], + color_pair: int, + description: str, + ): + """Draw a graph with labels and legend.""" + # Header with metrics + stdscr.attron(curses.color_pair(1) | curses.A_BOLD) + stdscr.addstr(y, x, label) + stdscr.attroff(curses.color_pair(1) | curses.A_BOLD) + + stdscr.attron(curses.color_pair(2)) + stdscr.addstr(y, x + len(label) + 2, rate) + stdscr.attroff(curses.color_pair(2)) + + stdscr.attron(curses.color_pair(7)) + stdscr.addstr(y, x + len(label) + len(rate) + 4, detail) + stdscr.attroff(curses.color_pair(7)) + + # Draw the graph + if len(data) > 1: + self._draw_bar_graph_enhanced( + stdscr, y + 1, x, width, height, data, color_pair + ) + else: + stdscr.attron(curses.color_pair(7)) + stdscr.addstr(y + 
2, x + 2, "Collecting data...") + stdscr.attroff(curses.color_pair(7)) + + # Graph legend at bottom + stdscr.attron(curses.color_pair(7)) + stdscr.addstr(y + height + 1, x, f"└─ {description}") + stdscr.attroff(curses.color_pair(7)) + + def _draw_bar_graph_enhanced( + self, + stdscr, + y: int, + x: int, + width: int, + height: int, + data: List[float], + color_pair: int, + ): + """Draw an enhanced bar graph with axis and scale.""" + if not data or width < 2: + return + + # Calculate statistics + max_val = max(data) if max(data) > 0 else 1 + min_val = min(data) + avg_val = sum(data) / len(data) + + # Take last 'width - 10' data points (leave room for Y-axis) + graph_width = width - 12 + recent_data = data[-graph_width:] if len(data) > graph_width else data + + # Draw Y-axis labels + stdscr.attron(curses.color_pair(7)) + stdscr.addstr(y, x, f"│{self._format_bytes(max_val):>9}") + stdscr.addstr(y + height // 2, x, f"│{self._format_bytes(avg_val):>9}") + stdscr.addstr(y + height - 1, x, f"│{self._format_bytes(min_val):>9}") + stdscr.attroff(curses.color_pair(7)) + + # Draw bars + for row in range(height): + threshold = (height - row) / height + bar_line = "" + + for val in recent_data: + normalized = val / max_val if max_val > 0 else 0 + if normalized >= threshold: + bar_line += "█" + elif normalized >= threshold - 0.15: + bar_line += "▓" + elif normalized >= threshold - 0.35: + bar_line += "▒" + elif normalized >= threshold - 0.5: + bar_line += "░" + else: + bar_line += " " + + stdscr.attron(color_pair) + stdscr.addstr(y + row, x + 11, bar_line) + stdscr.attroff(color_pair) + + # Draw X-axis + stdscr.attron(curses.color_pair(7)) + stdscr.addstr(y + height, x + 10, "├" + "─" * len(recent_data)) + stdscr.addstr(y + height, x + 10 + len(recent_data), "→ time") + stdscr.attroff(curses.color_pair(7)) + + def _calculate_rates(self, history: List) -> dict: + """Calculate per-second rates from history.""" + if len(history) < 2: + return { + "syscalls_per_sec": 0.0, + 
"rx_bytes_per_sec": 0.0, + "tx_bytes_per_sec": 0.0, + "rx_pkts_per_sec": 0.0, + "tx_pkts_per_sec": 0.0, + "read_bytes_per_sec": 0.0, + "write_bytes_per_sec": 0.0, + "read_ops_per_sec": 0.0, + "write_ops_per_sec": 0.0, + } + + # Calculate delta between last two samples + recent = history[-1] + previous = history[-2] + time_delta = recent.timestamp - previous.timestamp + + if time_delta <= 0: + time_delta = 1.0 + + return { + "syscalls_per_sec": (recent.syscall_count - previous.syscall_count) + / time_delta, + "rx_bytes_per_sec": (recent.rx_bytes - previous.rx_bytes) / time_delta, + "tx_bytes_per_sec": (recent.tx_bytes - previous.tx_bytes) / time_delta, + "rx_pkts_per_sec": (recent.rx_packets - previous.rx_packets) / time_delta, + "tx_pkts_per_sec": (recent.tx_packets - previous.tx_packets) / time_delta, + "read_bytes_per_sec": (recent.read_bytes - previous.read_bytes) + / time_delta, + "write_bytes_per_sec": (recent.write_bytes - previous.write_bytes) + / time_delta, + "read_ops_per_sec": (recent.read_ops - previous.read_ops) / time_delta, + "write_ops_per_sec": (recent.write_ops - previous.write_ops) / time_delta, + } + + def _format_bytes(self, bytes_val: float) -> str: + """Format bytes into human-readable string.""" + if bytes_val < 0: + bytes_val = 0 + for unit in ["B", "KB", "MB", "GB", "TB"]: + if bytes_val < 1024.0: + return f"{bytes_val:.2f}{unit}" + bytes_val /= 1024.0 + return f"{bytes_val:.2f}PB" + + def _handle_input(self, key: int) -> bool: + """Handle keyboard input. 
Returns False to exit.""" + if key == ord("q") or key == ord("Q"): + return False # Exit + + if self.current_screen == "selection": + if key == curses.KEY_UP: + self.selected_index = max(0, self.selected_index - 1) + elif key == curses.KEY_DOWN: + cgroups = self.collector.get_all_cgroups() + self.selected_index = min(len(cgroups) - 1, self.selected_index + 1) + elif key == ord("\n") or key == curses.KEY_ENTER or key == 10: + # Select cgroup + cgroups = self.collector.get_all_cgroups() + if cgroups and 0 <= self.selected_index < len(cgroups): + cgroups.sort(key=lambda c: c.name) + self.selected_cgroup = cgroups[self.selected_index].id + self.current_screen = "monitoring" + elif key == ord("r") or key == ord("R"): + # Force refresh cache + self.collector._cgroup_cache_time = 0 + + elif self.current_screen == "monitoring": + if key == 27 or key == ord("b") or key == ord("B"): # ESC or 'b' + self.current_screen = "selection" + self.selected_cgroup = None + + return True # Continue running diff --git a/pythonbpf/expr/__init__.py b/pythonbpf/expr/__init__.py index ac3a975..dfd2128 100644 --- a/pythonbpf/expr/__init__.py +++ b/pythonbpf/expr/__init__.py @@ -1,6 +1,6 @@ from .expr_pass import eval_expr, handle_expr, get_operand_value from .type_normalization import convert_to_bool, get_base_type_and_depth -from .ir_ops import deref_to_depth +from .ir_ops import deref_to_depth, access_struct_field from .call_registry import CallHandlerRegistry from .vmlinux_registry import VmlinuxHandlerRegistry @@ -10,6 +10,7 @@ "convert_to_bool", "get_base_type_and_depth", "deref_to_depth", + "access_struct_field", "get_operand_value", "CallHandlerRegistry", "VmlinuxHandlerRegistry", diff --git a/pythonbpf/expr/expr_pass.py b/pythonbpf/expr/expr_pass.py index 9f3bfa4..d34dff5 100644 --- a/pythonbpf/expr/expr_pass.py +++ b/pythonbpf/expr/expr_pass.py @@ -6,11 +6,11 @@ from pythonbpf.type_deducer import ctypes_to_ir, is_ctypes from .call_registry import CallHandlerRegistry +from .ir_ops 
import deref_to_depth, access_struct_field from .type_normalization import ( convert_to_bool, handle_comparator, get_base_type_and_depth, - deref_to_depth, ) from .vmlinux_registry import VmlinuxHandlerRegistry from ..vmlinux_parser.dependency_node import Field @@ -77,89 +77,6 @@ def _handle_attribute_expr( logger.info( f"Variable type: {var_type}, Variable ptr: {var_ptr}, Variable Metadata: {var_metadata}" ) - # Check if this is a pointer to a struct (from map lookup) - if ( - isinstance(var_type, ir.PointerType) - and var_metadata - and isinstance(var_metadata, str) - ): - if var_metadata in structs_sym_tab: - logger.info( - f"Handling pointer to struct {var_metadata} from map lookup" - ) - - if func is None: - raise ValueError( - f"func parameter required for null-safe pointer access to {var_name}.{attr_name}" - ) - - # Load the pointer value (ptr) - struct_ptr = builder.load(var_ptr) - - # Create blocks for null check - null_check_block = builder.block - not_null_block = func.append_basic_block( - name=f"{var_name}_not_null" - ) - merge_block = func.append_basic_block(name=f"{var_name}_merge") - - # Check if pointer is null - null_ptr = ir.Constant(struct_ptr.type, None) - is_not_null = builder.icmp_signed("!=", struct_ptr, null_ptr) - logger.info(f"Inserted null check for pointer {var_name}") - - builder.cbranch(is_not_null, not_null_block, merge_block) - - # Not-null block: Access the field - builder.position_at_end(not_null_block) - - # Get struct metadata - metadata = structs_sym_tab[var_metadata] - struct_ptr = builder.bitcast( - struct_ptr, metadata.ir_type.as_pointer() - ) - - if attr_name not in metadata.fields: - raise ValueError( - f"Field '{attr_name}' not found in struct '{var_metadata}'" - ) - - # GEP to field - field_gep = metadata.gep(builder, struct_ptr, attr_name) - - # Load field value - field_val = builder.load(field_gep) - field_type = metadata.field_type(attr_name) - - logger.info( - f"Loaded field {attr_name} from struct pointer, type: 
{field_type}" - ) - - # Branch to merge - not_null_after_load = builder.block - builder.branch(merge_block) - - # Merge block: PHI node for the result - builder.position_at_end(merge_block) - phi = builder.phi(field_type, name=f"{var_name}_{attr_name}") - - # If null, return zero/default value - if isinstance(field_type, ir.IntType): - zero_value = ir.Constant(field_type, 0) - elif isinstance(field_type, ir.PointerType): - zero_value = ir.Constant(field_type, None) - elif isinstance(field_type, ir.ArrayType): - # For arrays, we can't easily create a zero constant - # This case is tricky - for now, just use undef - zero_value = ir.Constant(field_type, ir.Undefined) - else: - zero_value = ir.Constant(field_type, ir.Undefined) - - phi.add_incoming(zero_value, null_check_block) - phi.add_incoming(field_val, not_null_after_load) - - logger.info(f"Created PHI node for {var_name}.{attr_name}") - return phi, field_type if ( hasattr(var_metadata, "__module__") and var_metadata.__module__ == "vmlinux" @@ -180,13 +97,23 @@ def _handle_attribute_expr( ) return None - # Regular user-defined struct - metadata = structs_sym_tab.get(var_metadata) - if metadata and attr_name in metadata.fields: - gep = metadata.gep(builder, var_ptr, attr_name) - val = builder.load(gep) - field_type = metadata.field_type(attr_name) - return val, field_type + if var_metadata in structs_sym_tab: + return access_struct_field( + builder, + var_ptr, + var_type, + var_metadata, + expr.attr, + structs_sym_tab, + func, + ) + else: + logger.error(f"Struct metadata for '{var_name}' not found") + else: + logger.error(f"Undefined variable '{var_name}' for attribute access") + else: + logger.error("Unsupported attribute base expression type") + return None diff --git a/pythonbpf/expr/ir_ops.py b/pythonbpf/expr/ir_ops.py index f6835e2..df6f503 100644 --- a/pythonbpf/expr/ir_ops.py +++ b/pythonbpf/expr/ir_ops.py @@ -17,41 +17,100 @@ def deref_to_depth(func, builder, val, target_depth): # dereference with null 
check pointee_type = cur_type.pointee - null_check_block = builder.block - not_null_block = func.append_basic_block(name=f"deref_not_null_{depth}") - merge_block = func.append_basic_block(name=f"deref_merge_{depth}") - null_ptr = ir.Constant(cur_type, None) - is_not_null = builder.icmp_signed("!=", cur_val, null_ptr) - logger.debug(f"Inserted null check for pointer at depth {depth}") + def load_op(builder, ptr): + return builder.load(ptr) - builder.cbranch(is_not_null, not_null_block, merge_block) + cur_val = _null_checked_operation( + func, builder, cur_val, load_op, pointee_type, f"deref_{depth}" + ) + cur_type = pointee_type + logger.debug(f"Dereferenced to depth {depth}, type: {pointee_type}") + return cur_val - builder.position_at_end(not_null_block) - dereferenced_val = builder.load(cur_val) - logger.debug(f"Dereferenced to depth {depth - 1}, type: {pointee_type}") - builder.branch(merge_block) - builder.position_at_end(merge_block) - phi = builder.phi(pointee_type, name=f"deref_result_{depth}") +def _null_checked_operation(func, builder, ptr, operation, result_type, name_prefix): + """ + Generic null-checked operation on a pointer. 
+ """ + curr_block = builder.block + not_null_block = func.append_basic_block(name=f"{name_prefix}_not_null") + merge_block = func.append_basic_block(name=f"{name_prefix}_merge") - zero_value = ( - ir.Constant(pointee_type, 0) - if isinstance(pointee_type, ir.IntType) - else ir.Constant(pointee_type, None) - ) - phi.add_incoming(zero_value, null_check_block) + null_ptr = ir.Constant(ptr.type, None) + is_not_null = builder.icmp_signed("!=", ptr, null_ptr) + builder.cbranch(is_not_null, not_null_block, merge_block) - phi.add_incoming(dereferenced_val, not_null_block) + builder.position_at_end(not_null_block) + result = operation(builder, ptr) + not_null_after = builder.block + builder.branch(merge_block) - # Continue with phi result - cur_val = phi - cur_type = pointee_type - return cur_val + builder.position_at_end(merge_block) + phi = builder.phi(result_type, name=f"{name_prefix}_result") + + if isinstance(result_type, ir.IntType): + null_val = ir.Constant(result_type, 0) + elif isinstance(result_type, ir.PointerType): + null_val = ir.Constant(result_type, None) + else: + null_val = ir.Constant(result_type, ir.Undefined) + + phi.add_incoming(null_val, curr_block) + phi.add_incoming(result, not_null_after) + + return phi -def deref_struct_ptr( - func, builder, struct_ptr, struct_metadata, field_name, structs_sym_tab +def access_struct_field( + builder, var_ptr, var_type, var_metadata, field_name, structs_sym_tab, func=None ): - """Dereference a pointer to a struct type.""" - return deref_to_depth(func, builder, struct_ptr, 1) + """ + Access a struct field - automatically returns value or pointer based on field type. 
+ """ + metadata = ( + structs_sym_tab.get(var_metadata) + if isinstance(var_metadata, str) + else var_metadata + ) + if not metadata or field_name not in metadata.fields: + raise ValueError(f"Field '{field_name}' not found in struct") + + field_type = metadata.field_type(field_name) + is_ptr_to_struct = isinstance(var_type, ir.PointerType) and isinstance( + var_metadata, str + ) + + # Get struct pointer + struct_ptr = builder.load(var_ptr) if is_ptr_to_struct else var_ptr + + should_load = not isinstance(field_type, ir.ArrayType) + + def field_access_op(builder, ptr): + typed_ptr = builder.bitcast(ptr, metadata.ir_type.as_pointer()) + field_ptr = metadata.gep(builder, typed_ptr, field_name) + return builder.load(field_ptr) if should_load else field_ptr + + # Handle null check for pointer-to-struct + if is_ptr_to_struct: + if func is None: + raise ValueError("func required for null-safe struct pointer access") + + if should_load: + result_type = field_type + else: + result_type = field_type.as_pointer() + + result = _null_checked_operation( + func, + builder, + struct_ptr, + field_access_op, + result_type, + f"field_{field_name}", + ) + return result, field_type + + field_ptr = metadata.gep(builder, struct_ptr, field_name) + result = builder.load(field_ptr) if should_load else field_ptr + return result, field_type diff --git a/pythonbpf/helper/__init__.py b/pythonbpf/helper/__init__.py index 1730635..dcbfe24 100644 --- a/pythonbpf/helper/__init__.py +++ b/pythonbpf/helper/__init__.py @@ -16,6 +16,7 @@ smp_processor_id, uid, skb_store_bytes, + get_current_cgroup_id, get_stack, XDP_DROP, XDP_PASS, @@ -79,6 +80,7 @@ def helper_call_handler( "handle_helper_call", "emit_probe_read_kernel_str_call", "emit_probe_read_kernel_call", + "get_current_cgroup_id", "ktime", "pid", "deref", diff --git a/pythonbpf/helper/bpf_helper_handler.py b/pythonbpf/helper/bpf_helper_handler.py index f52e87a..e59898f 100644 --- a/pythonbpf/helper/bpf_helper_handler.py +++ 
b/pythonbpf/helper/bpf_helper_handler.py @@ -30,6 +30,7 @@ class BPFHelperID(Enum): BPF_SKB_STORE_BYTES = 9 BPF_GET_CURRENT_PID_TGID = 14 BPF_GET_CURRENT_UID_GID = 15 + BPF_GET_CURRENT_CGROUP_ID = 80 BPF_GET_CURRENT_COMM = 16 BPF_PERF_EVENT_OUTPUT = 25 BPF_GET_STACK = 67 @@ -68,6 +69,33 @@ def bpf_ktime_get_ns_emitter( return result, ir.IntType(64) +@HelperHandlerRegistry.register( + "get_current_cgroup_id", + param_types=[], + return_type=ir.IntType(64), +) +def bpf_get_current_cgroup_id( + call, + map_ptr, + module, + builder, + func, + local_sym_tab=None, + struct_sym_tab=None, + map_sym_tab=None, +): + """ + Emit LLVM IR for bpf_get_current_cgroup_id helper function call. + """ + # func is an arg to just have a uniform signature with other emitters + helper_id = ir.Constant(ir.IntType(64), BPFHelperID.BPF_GET_CURRENT_CGROUP_ID.value) + fn_type = ir.FunctionType(ir.IntType(64), [], var_arg=False) + fn_ptr_type = ir.PointerType(fn_type) + fn_ptr = builder.inttoptr(helper_id, fn_ptr_type) + result = builder.call(fn_ptr, [], tail=False) + return result, ir.IntType(64) + + @HelperHandlerRegistry.register( "lookup", param_types=[ir.PointerType(ir.IntType(64))], diff --git a/pythonbpf/helper/helper_utils.py b/pythonbpf/helper/helper_utils.py index aecb5e9..d6a76e0 100644 --- a/pythonbpf/helper/helper_utils.py +++ b/pythonbpf/helper/helper_utils.py @@ -5,6 +5,7 @@ from pythonbpf.expr import ( get_operand_value, eval_expr, + access_struct_field, ) logger = logging.getLogger(__name__) @@ -135,7 +136,7 @@ def get_or_create_ptr_from_arg( and field_type.element.width == 8 ): ptr, sz = get_char_array_ptr_and_size( - arg, builder, local_sym_tab, struct_sym_tab + arg, builder, local_sym_tab, struct_sym_tab, func ) if not ptr: raise ValueError("Failed to get char array pointer from struct field") @@ -266,7 +267,9 @@ def get_buffer_ptr_and_size(buf_arg, builder, local_sym_tab, struct_sym_tab): ) -def get_char_array_ptr_and_size(buf_arg, builder, local_sym_tab, struct_sym_tab): 
+def get_char_array_ptr_and_size( + buf_arg, builder, local_sym_tab, struct_sym_tab, func=None +): """Get pointer to char array and its size.""" # Struct field: obj.field @@ -277,11 +280,11 @@ def get_char_array_ptr_and_size(buf_arg, builder, local_sym_tab, struct_sym_tab) if not (local_sym_tab and var_name in local_sym_tab): raise ValueError(f"Variable '{var_name}' not found") - struct_type = local_sym_tab[var_name].metadata - if not (struct_sym_tab and struct_type in struct_sym_tab): - raise ValueError(f"Struct type '{struct_type}' not found") + struct_ptr, struct_type, struct_metadata = local_sym_tab[var_name] + if not (struct_sym_tab and struct_metadata in struct_sym_tab): + raise ValueError(f"Struct type '{struct_metadata}' not found") - struct_info = struct_sym_tab[struct_type] + struct_info = struct_sym_tab[struct_metadata] if field_name not in struct_info.fields: raise ValueError(f"Field '{field_name}' not found") @@ -292,8 +295,24 @@ def get_char_array_ptr_and_size(buf_arg, builder, local_sym_tab, struct_sym_tab) ) return None, 0 - struct_ptr = local_sym_tab[var_name].var - field_ptr = struct_info.gep(builder, struct_ptr, field_name) + # Check if char array + if not ( + isinstance(field_type, ir.ArrayType) + and isinstance(field_type.element, ir.IntType) + and field_type.element.width == 8 + ): + logger.warning("Field is not a char array") + return None, 0 + + field_ptr, _ = access_struct_field( + builder, + struct_ptr, + struct_type, + struct_metadata, + field_name, + struct_sym_tab, + func, + ) # GEP to first element: [N x i8]* -> i8* buf_ptr = builder.gep( diff --git a/pythonbpf/helper/helpers.py b/pythonbpf/helper/helpers.py index c80d57d..253c4b0 100644 --- a/pythonbpf/helper/helpers.py +++ b/pythonbpf/helper/helpers.py @@ -57,6 +57,11 @@ def get_stack(buf, flags=0): return ctypes.c_int64(0) +def get_current_cgroup_id(): + """Get the current cgroup ID""" + return ctypes.c_int64(0) + + XDP_ABORTED = ctypes.c_int64(0) XDP_DROP = ctypes.c_int64(1) 
XDP_PASS = ctypes.c_int64(2) diff --git a/pythonbpf/helper/printk_formatter.py b/pythonbpf/helper/printk_formatter.py index 721213e..4364166 100644 --- a/pythonbpf/helper/printk_formatter.py +++ b/pythonbpf/helper/printk_formatter.py @@ -222,7 +222,7 @@ def _prepare_expr_args(expr, func, module, builder, local_sym_tab, struct_sym_ta # Special case: struct field char array needs pointer to first element if isinstance(expr, ast.Attribute): char_array_ptr, _ = get_char_array_ptr_and_size( - expr, builder, local_sym_tab, struct_sym_tab + expr, builder, local_sym_tab, struct_sym_tab, func ) if char_array_ptr: return char_array_ptr diff --git a/pythonbpf/maps/maps_pass.py b/pythonbpf/maps/maps_pass.py index ac498dc..2d0beb9 100644 --- a/pythonbpf/maps/maps_pass.py +++ b/pythonbpf/maps/maps_pass.py @@ -135,7 +135,7 @@ def process_perf_event_map(map_name, rval, module, structs_sym_tab): logger.info(f"Map parameters: {map_params}") map_global = create_bpf_map(module, map_name, map_params) # Generate debug info for BTF - create_map_debug_info(module, map_global.sym, map_name, map_params) + create_map_debug_info(module, map_global.sym, map_name, map_params, structs_sym_tab) return map_global diff --git a/tests/c-form/xdp_test.bpf.c b/tests/c-form/xdp_test.bpf.c index e553c37..c039e11 100644 --- a/tests/c-form/xdp_test.bpf.c +++ b/tests/c-form/xdp_test.bpf.c @@ -1,19 +1,18 @@ -#include "vmlinux.h" -#include #include #include #include +#include struct fake_iphdr { - unsigned short useless; - unsigned short tot_len; - unsigned short id; - unsigned short frag_off; - unsigned char ttl; - unsigned char protocol; - unsigned short check; - unsigned int saddr; - unsigned int daddr; + unsigned short useless; + unsigned short tot_len; + unsigned short id; + unsigned short frag_off; + unsigned char ttl; + unsigned char protocol; + unsigned short check; + unsigned int saddr; + unsigned int daddr; }; SEC("xdp") @@ -21,16 +20,14 @@ int xdp_prog(struct xdp_md *ctx) { unsigned long data = 
ctx->data; unsigned long data_end = ctx->data_end; - if (data + sizeof(struct ethhdr) + sizeof(struct fake_iphdr) <= data_end) { - struct fake_iphdr *iph = (void *)data + sizeof(struct ethhdr); - - bpf_printk("%d", iph->saddr); - - return XDP_PASS; - } else { + if (data + sizeof(struct ethhdr) + sizeof(struct fake_iphdr) > data_end) { return XDP_ABORTED; } - struct task_struct * a = btf_bpf_get_current_task_btf(); + struct fake_iphdr *iph = (void *)data + sizeof(struct ethhdr); + + bpf_printk("%d", iph->saddr); + + return XDP_PASS; } char _license[] SEC("license") = "GPL"; diff --git a/tests/passing_tests/hash_map_struct.py b/tests/passing_tests/hash_map_struct.py index 9f6cbac..4252dd0 100644 --- a/tests/passing_tests/hash_map_struct.py +++ b/tests/passing_tests/hash_map_struct.py @@ -1,6 +1,6 @@ from pythonbpf import bpf, section, struct, bpfglobal, compile, map from pythonbpf.maps import HashMap -from pythonbpf.helper import pid +from pythonbpf.helper import pid, comm from ctypes import c_void_p, c_int64 @@ -9,6 +9,7 @@ class val_type: counter: c_int64 shizzle: c_int64 + comm: str(16) @bpf @@ -22,6 +23,7 @@ def last() -> HashMap: def hello_world(ctx: c_void_p) -> c_int64: obj = val_type() obj.counter, obj.shizzle = 42, 96 + comm(obj.comm) t = last.lookup(obj) if t: print(f"Found existing entry: counter={obj.counter}, pid={t}") diff --git a/tests/passing_tests/struct_pylib.py b/tests/passing_tests/struct_pylib.py new file mode 100644 index 0000000..0fe9018 --- /dev/null +++ b/tests/passing_tests/struct_pylib.py @@ -0,0 +1,93 @@ +""" +Test struct values in HashMap. + +This example stores a struct in a HashMap and reads it back, +testing the new set_value_struct() functionality in pylibbpf. 
+""" + +from pythonbpf import bpf, map, struct, section, bpfglobal, BPF +from pythonbpf.helper import ktime, smp_processor_id, pid, comm +from pythonbpf.maps import HashMap +from ctypes import c_void_p, c_int64, c_uint32, c_uint64 +import time +import os + + +@bpf +@struct +class task_info: + pid: c_uint64 + timestamp: c_uint64 + comm: str(16) + + +@bpf +@map +def cpu_tasks() -> HashMap: + return HashMap(key=c_uint32, value=task_info, max_entries=256) + + +@bpf +@section("tracepoint/sched/sched_switch") +def trace_sched_switch(ctx: c_void_p) -> c_int64: + cpu = smp_processor_id() + + # Create task info struct + info = task_info() + info.pid = pid() + info.timestamp = ktime() + comm(info.comm) + + # Store in map + cpu_tasks.update(cpu, info) + + return 0 # type: ignore + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +# Compile and load +b = BPF() +b.load() +b.attach_all() + +print("Testing HashMap with Struct Values") + +cpu_map = b["cpu_tasks"] +cpu_map.set_value_struct("task_info") # Enable struct deserialization + +print("Listening for context switches.. .\n") + +num_cpus = os.cpu_count() or 16 + +try: + while True: + time.sleep(1) + + print(f"--- Snapshot at {time.strftime('%H:%M:%S')} ---") + + for cpu in range(num_cpus): + try: + info = cpu_map.lookup(cpu) + + if info: + comm_str = ( + bytes(info.comm).decode("utf-8", errors="ignore").rstrip("\x00") + ) + ts_sec = info.timestamp / 1e9 + + print( + f" CPU {cpu}: PID={info.pid}, comm={comm_str}, ts={ts_sec:.3f}s" + ) + except KeyError: + # No data for this CPU yet + pass + + print() + +except KeyboardInterrupt: + print("\nStopped")