In [1]:
# JAX STUFF
from jax import Array, grad, jit, vmap, tree_util
import jax.numpy as jnp

from chex import dataclass

from utils import tree_stack, tree_unstack, tree_index

# NOT JAX STUFF
import pandas as pd

from typing import * 

# For reading packet (.pcap) data files
import dpkt
from dpkt.utils import mac_to_str, inet_to_str

In [2]:
# Upper bound on the number of packets to load. Kept as an int so it can be
# compared against a list length without relying on int/float equality.
max_points = 10_000_000

# NOTE(review): the original comment pointed at "a great example of how to parse
# PCAP files" but the link is missing. The fields read below look like the ones
# in dpkt's documented `print_packets` example -- confirm and restore the link.

with open('anarchy-online-server-side-packet-trace-1hr.pcap', 'rb') as f:
    pcap = dpkt.pcap.Reader(f)
    # One list per output column; every list must receive exactly one value per
    # accepted packet so that the columns stay aligned.
    data = {
        'timestamp': [],
        'i_port': [], # ingress port
        'e_port': [], # egress port
        'eth_type': [], # Ethertype
        'src': [], # packet source
        'dst': [], # packet destination
        'id': [], # Packet ID
        'pkt_len': [], # Length of the packet (in bytes?)
        'ttl': [], # Time to live
        'df': [], # Don't fragment
        'mf': [], # More fragments
        'offset': [], # Offset
        'protocol': [], # Protocol
    }
    for ts, buf in pcap:
        # Get the ethernet frame from the packet
        eth = dpkt.ethernet.Ethernet(buf)

        # The payload of the Ethernet frame is only an IPv4 packet for IPv4
        # traffic. ARP, IPv6 or undecoded payloads do not expose id/len/ttl/df/
        # mf/offset, so skip them *before* appending anything; otherwise the
        # loop would raise AttributeError with the columns half-filled.
        ip = eth.data
        if not isinstance(ip, dpkt.ip.IP):
            continue

        # Save the timestamp and the link-layer fields
        data['timestamp'].append(ts)
        data['i_port'].append(mac_to_str(eth.src))
        data['e_port'].append(mac_to_str(eth.dst))
        data['eth_type'].append(eth.type)

        # Now the fields of the IP packet carried inside the Ethernet frame
        data['src'].append(inet_to_str(ip.src))
        data['dst'].append(inet_to_str(ip.dst))
        data['id'].append(ip.id)
        data['pkt_len'].append(ip.len)
        data['ttl'].append(ip.ttl)
        data['df'].append(ip.df)
        data['mf'].append(ip.mf)
        data['offset'].append(ip.offset)
        data['protocol'].append(ip.p)

        # Stop once the cap is reached (>= so the loop cannot run past it)
        if len(data['timestamp']) >= max_points:
            break

# Put everything in a dataframe, and get the time in a better format.
# Reassign instead of mutating in place so re-running the cell is predictable.
df = pd.DataFrame(data)
df['timestamp'] = pd.to_datetime(df['timestamp'], unit='s')
df = df.set_index('timestamp')

In [7]:
# Get some general stats about the data
print(f"Loaded N {len(df)} packets of data covering {(df.index[-1] - df.index[0]).total_seconds()} seconds starting at {df.index[0]}")

# Distinct source addresses, in order of first appearance
devices = pd.unique(df['src'])
print("Number of devices", len(devices))

# Get the # of ports per device (for this dataset, most devices only have a single port).
# A single groupby pass replaces re-filtering the whole frame once per device.
# sort=False keeps the groups in first-appearance order (the same order as
# `devices`), and dropna=False counts values the same way pd.unique does.
ports_per_device = df.groupby('src', sort=False)['i_port'].nunique(dropna=False).tolist()
print("Max ports on 1 device", max(ports_per_device))

Loaded N 964842 packets of data covering 3918.465169 seconds starting at 2005-09-14 13:21:52.803170944
Number of devices 120
Max ports on a device 2


In [6]:
# TODO: Need to figure out how to map the contents of the PCAP file into the following data format used by the paper. 
# Eventually, we could try to use additional features of the PCAP to get better results

@dataclass
class Packet:
    """Per-packet fields carried alongside a timestamp inside a `Tau`.

    Declared as a chex dataclass; all fields are plain ints.
    """
    pid: int # presumably the packet id -- TODO confirm against the paper
    fid: int # presumably the flow id -- TODO confirm against the paper
    length: int # packet size; link_forward divides it by Link.bandwidth, so presumably bytes -- TODO confirm
    trp: int # NOTE(review): meaning is not evident from this notebook -- document once confirmed
    in_port: int # Extra feature added in pre PFM augmentation

@dataclass
class Tau:
    """A packet paired with a timestamp.

    `link_forward` returns a new `Tau` with a later timestamp and the same packet.
    """
    timestamp: float # timestamp are floats in seconds (same as PCAP)
    packet: Packet # the packet this timestamp refers to

@dataclass
class Link: 
    """Parameters of a link, as consumed by `link_forward`.

    `link_forward` adds `packet.length / bandwidth` and
    `length / propagation_speed` to a timestamp, so the units of these fields
    must be mutually consistent with the packet length and timestamp units.
    """
    bandwidth: float # divisor for the packet length (the example link uses bytes per second)
    propagation_speed: float # divisor for the link length (units unconfirmed; example says "Meters? / seconds")
    length: float # link length (units unconfirmed; example says "Meters?")

# Concrete example instances used to exercise the model.
example_timestamp = 0.0
example_packet = Packet(pid=0, fid=0, length=50, trp=0, in_port=0)

# A tau is simply the example packet stamped with the example time.
example_tau = Tau(timestamp=example_timestamp, packet=example_packet)

example_link = Link(
    bandwidth=100.0,        # Bytes per second
    propagation_speed=2.0,  # Meters? / seconds
    length=1.0,             # Meters?
)

In [8]:
def link_forward(tau: Tau, link: Link):
    """Forward a packet through a link with capacity c and length l.

    The returned timestamp is the incoming timestamp plus two delays:
    `packet.length / link.bandwidth` and `link.length / link.propagation_speed`.

    Args:
        tau: Timestamped packet entering the link.
        link: Link whose bandwidth, length and propagation speed set the delay.

    Returns:
        A new `Tau` with the later timestamp and the same, unmodified packet.
    """
    pkt = tau.packet
    transmission_delay = pkt.length / link.bandwidth
    propagation_delay = link.length / link.propagation_speed
    # Keep the left-to-right summation order so float results are unchanged.
    arrival = tau.timestamp + transmission_delay + propagation_delay
    # Don't modify the packet, since this is only a link.
    return Tau(timestamp=arrival, packet=pkt)

# Single-tau check: push the example tau across the example link (it works!)
example_forwarded_tau = link_forward(example_tau, example_link)
print(example_forwarded_tau) # Timestamp is 1.0

# Because our Taus are written as Chex dataclasses, link_forward can be
# vectorized over a batch of taus automatically (exactly Eq. 5 in the paper).
# The link is shared across the batch (in_axes=None); only the taus are mapped.
batch_size = 256
# The batch is built from the *already forwarded* tau, so each element crosses
# the link a second time below.
batch_taus = tree_stack([example_forwarded_tau for _ in range(batch_size)])
batched_link_forward = vmap(link_forward, in_axes=(0, None))
many_taus = tree_unstack(batched_link_forward(batch_taus, example_link))
print(many_taus[0]) # Timestamp is 2.0 for all taus in the batch.

Tau(timestamp=1.0, packet=Packet(pid=0, fid=0, length=50, trp=0, in_port=0))
Tau(timestamp=Array(2., dtype=float32, weak_type=True), packet=Packet(pid=Array(0, dtype=int32, weak_type=True), fid=Array(0, dtype=int32, weak_type=True), length=Array(50, dtype=int32, weak_type=True), trp=Array(0, dtype=int32, weak_type=True), in_port=Array(0, dtype=int32, weak_type=True)))


In [None]:
def forward_trace(trace_in):
    """Given an ingress stream, forward the packets from ingress to egress.

    Not implemented yet: this placeholder ignores `trace_in` and returns None.

    Args:
        trace_in: The ingress trace (its expected type is not defined yet).

    Returns:
        None, until the steps listed below are implemented.
    """
    # TODO: Implement -- first forward the trace from ingress to egress port
    # TODO: Implement -- then estimate the delay with the PTM model
    return None
