## LLM-Based BGP Anomaly Detection Pipeline
#### Powered by Scikit-Learn and LLaMA2 via Ollama

### License
This software is licensed under the GNU Affero General Public License v3.0 (AGPLv3).
Copyright (C) 2024 [IP INFUSION INC.]
Author: [Shaji R. Nathan]
Contact: [shaji.nathan@ipinfusion.com]

This program is free software: you can redistribute it and/or modify it under the terms of the GNU Affero General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Affero General Public License for more details.

You should have received a copy of the GNU AGPLv3 along with this program. If not, see https://www.gnu.org/licenses/agpl-3.0.html.

### Description
This software provides an automated pipeline for BGP anomaly detection, integrating:

* Scikit-learn for data preprocessing
* Ollama-hosted LLaMA2 for AI-powered classification
* Pandas & NumPy for dataset handling

It collects BGP update messages from the RIPE RIS (Routing Information Service) Live stream and saves them to CSV. It then preprocesses them with a scikit-learn pipeline and classifies each entry as "normal" or "abnormal" by prompting a locally hosted LLM.

### Features

* Preprocesses BGP data (missing values, scaling, encoding)
* Uses LLaMA2 AI Model for anomaly detection via Ollama
* Fully automated pipeline using Scikit-learn
* LLM inference runs locally via Ollama, so no cloud AI service is required. Data collection uses the public RIPE RIS Live feed.

### Usage 

To run this notebook:

* Install the dependencies: `pip install scikit-learn pandas numpy requests jupyter ipywidgets websocket-client`
* Download Ollama and the LLaMA2 model from https://ollama.com/ (e.g. `ollama pull llama2`)
* Start the Ollama server: `ollama serve`

In [5]:
import json
import websocket
import threading
import ipywidgets as widgets
import pandas as pd
from IPython.display import display, clear_output
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

In [6]:
# Initialize DataFrame to store messages
columns = ['timestamp', 'peer', 'peer_asn', 'id', 'host', 'type',
           'path', 'community', 'origin', 'announcements', 'withdrawals', 'raw']
data_df = pd.DataFrame(columns=columns)

# Global control variables shared by the WebSocket callbacks and button handlers
ws = None           # current websocket.WebSocketApp (None until the first Start)
ws_thread = None    # background thread running the WebSocket loop
message_count = 0   # ris_message rows collected in the current session
max_messages = 10   # Default value; replaced by message_input.value on Start
is_running = False  # set on Start, cleared by Stop or when the socket closes

# Output widget to display logs
output = widgets.Output()

# Widgets for user inputs and controls.
# `min` is not a trait of IntText (it is not enforced there), so use
# BoundedIntText to actually keep the message limit at 1 or more.
message_input = widgets.BoundedIntText(value=10, min=1, max=1_000_000,
                                       description='Messages:')
filename_input = widgets.Text(value='ris_live_data.csv', description='Filename:')
start_button = widgets.Button(description='Start', button_style='success')
stop_button = widgets.Button(description='Stop', button_style='danger')
save_button = widgets.Button(description='Save CSV', button_style='info')

# Function to handle incoming messages
def on_message(ws, message):
    """Handle one RIS Live frame received on the WebSocket.

    ``ris_message`` frames are appended to ``data_df`` as one row with a
    field per entry in ``columns`` (missing/None fields become ''),
    ``ris_error`` frames are logged to the output widget, and every other
    frame type is ignored. Once ``max_messages`` rows have been collected the
    socket is closed.

    Parameters
    ----------
    ws : websocket.WebSocketApp
        Connection that delivered the frame.
    message : str
        JSON text of the frame.
    """
    # Only the counter is rebound here; data_df is mutated in place and
    # max_messages is only read.
    global message_count

    parsed = json.loads(message)
    msg_type = parsed.get("type")

    if msg_type == "ris_error":
        # Surface server-reported errors instead of silently ignoring them.
        with output:
            print("RIS Live error:", parsed.get("data"))
        return

    if msg_type != "ris_message":
        return

    data = parsed.get("data") or {}

    # Build the row from the DataFrame's own schema so the two never drift apart.
    new_row = {}
    for col in columns:
        value = data.get(col)
        new_row[col] = value if value is not None else ''
    data_df.loc[len(data_df)] = new_row

    message_count += 1
    with output:
        print(f"Received message {message_count}")

    if message_count >= max_messages:
        with output:
            print(f"Reached message limit ({max_messages}). Stopping WebSocket.")
        ws.close()

def on_error(ws, error):
    """Print a WebSocket error into the shared log widget; ``ws`` is unused."""
    report = ("WebSocket Error:", error)
    with output:
        print(*report)

def on_close(ws, close_status_code, close_msg):
    """Mark the stream as stopped and log the close code, reason and message total."""
    global is_running
    is_running = False

    summary_lines = (
        f"WebSocket closed (Code: {close_status_code}, Message: {close_msg})",
        f"Total messages received: {message_count}",
    )
    with output:
        for line in summary_lines:
            print(line)

def on_open(ws):
    """Announce the new connection and subscribe to the RIS feed without filters."""
    with output:
        print("WebSocket connection opened.")

    # An empty "data" object requests the unfiltered feed.
    subscription = json.dumps({"type": "ris_subscribe", "data": {}})
    ws.send(subscription)

    with output:
        print("Subscribed to RIS feed.")

# Function to start the WebSocket connection
def start_websocket(b):
    """Start streaming RIS Live messages in a background thread (Start button handler).

    Resets the message counter and ``data_df`` for a fresh session, re-renders
    the UI, and runs the WebSocket loop on a daemon thread. ``b`` is the
    clicked button and is unused.
    """
    global ws, ws_thread, message_count, max_messages, data_df, is_running

    if is_running:
        with output:
            print("WebSocket is already running.")
        return

    # Reset variables. Clamp the limit to at least one message in case the
    # input widget does not enforce its minimum.
    max_messages = max(1, message_input.value)
    message_count = 0
    data_df = pd.DataFrame(columns=columns)
    is_running = True

    clear_output(wait=True)
    display(ui, output)

    ws_url = "wss://ris-live.ripe.net/v1/ws/?client=py-enhanced"

    # Create the app here rather than inside the thread, so `ws` refers to the
    # new connection before Start returns; otherwise an early Stop click would
    # close a stale (or None) app while the new stream keeps running.
    ws = websocket.WebSocketApp(ws_url,
                                on_open=on_open,
                                on_message=on_message,
                                on_error=on_error,
                                on_close=on_close)

    # Daemon thread so an open stream does not block interpreter shutdown.
    ws_thread = threading.Thread(target=ws.run_forever, daemon=True)
    ws_thread.start()

def stop_websocket(b):
    """Close the current RIS Live connection (Stop button handler); ``b`` is unused.

    Prints a notice instead of doing nothing silently when no connection has
    been created yet.
    """
    global ws, is_running
    if ws is None:
        with output:
            print("No WebSocket connection to stop.")
        return

    ws.close()
    is_running = False
    with output:
        print("WebSocket connection stopped.")

def normalize_csv_filename(name, default='ris_live_data.csv'):
    """Return *name* stripped of surrounding whitespace and ending in ``.csv``.

    A blank name falls back to *default*; without this fallback it would
    become the hidden file ``.csv``. The extension check is case-insensitive,
    so ``DATA.CSV`` is kept as-is rather than becoming ``DATA.CSV.csv``.
    """
    name = name.strip()
    if not name:
        return default
    if not name.lower().endswith('.csv'):
        name += '.csv'
    return name


def save_to_csv(b):
    """Write the collected ``data_df`` to the CSV named in ``filename_input``.

    Save button handler; ``b`` is unused. Write failures are reported in the
    output widget, because exceptions raised inside widget callbacks may not
    be visible in the notebook cell.
    """
    filename = normalize_csv_filename(filename_input.value)

    if data_df.empty:
        with output:
            print("No data to save.")
        return

    try:
        data_df.to_csv(filename, index=False)
    except OSError as err:
        with output:
            print(f"Failed to save '{filename}': {err}")
        return

    with output:
        print(f"Data saved to '{filename}'")

# UI layout: parameter inputs on the first row, action buttons on the second
ui = widgets.VBox([
    widgets.HBox([message_input, filename_input]),
    widgets.HBox([start_button, stop_button, save_button]),
])

# Wire every button to its click handler
for _button, _handler in ((start_button, start_websocket),
                          (stop_button, stop_websocket),
                          (save_button, save_to_csv)):
    _button.on_click(_handler)
del _button, _handler

# Render the controls together with the shared log area
display(ui, output)


VBox(children=(HBox(children=(IntText(value=10, description='Messages:'), Text(value='ris_live_data.csv', desc…

Output()

WebSocket connection opened.
Subscribed to RIS feed.
Received message 1
Received message 2
Received message 3
Received message 4
Received message 5
Received message 6
Received message 7
Received message 8
Received message 9
Received message 10
Received message 11
Received message 12
Received message 13
Received message 14
Received message 15
Received message 16
Received message 17
Received message 18
Received message 19
Received message 20
Received message 21
Received message 22
Received message 23
Received message 24
Received message 25
Received message 26
Received message 27
Received message 28
Received message 29
Received message 30
Received message 31
Received message 32
Received message 33
Received message 34
Received message 35
Received message 36
Received message 37
Received message 38
Received message 39
Received message 40
Reached message limit (40). Stopping WebSocket.
WebSocket closed (Code: None, Message: None)
Total messages received: 40


## Load the BGP Data and Set Up a scikit-learn ML Pipeline

In [9]:
import pandas as pd
import numpy as np
import requests
import ast

from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.impute import SimpleImputer
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn import set_config
# Render scikit-learn estimators (such as ml_pipeline below) as interactive HTML diagrams
set_config(display="diagram")

## Specify the LLM in Use

In [10]:
# Ollama model tag and the base URL of the local Ollama server
model = "llama2"
ollama_url = "http://localhost:11434"

## Load Preprocessed RIS BGP Data

In [11]:
# Load the RIS BGP data from CSV.
# NOTE(review): the collector cell above saves to 'ris_live_data.csv' in the
# working directory by default, while this reads from ./data/ — move the file
# or adjust the path.
csv_filename = "./data/ris_live_data.csv"
bgp_data = pd.read_csv(csv_filename)

# Show the first rows so the loaded schema can be eyeballed
print("Sample RIS BGP Data:")
print(bgp_data.head())

# Columns consumed by the ML pipeline, split by preprocessing type
numerical_features = ['timestamp', 'peer_asn']
categorical_features = ['type', 'origin', 'host']

# List-valued fields are stored in the CSV as Python literals (e.g. "[34927, 174]");
# parse them back into objects, treating missing cells as empty lists.
for col in ['path', 'community', 'announcements', 'withdrawals']:
    if col not in bgp_data.columns:
        continue
    bgp_data[col] = bgp_data[col].fillna('[]').map(ast.literal_eval)


Sample RIS BGP Data:
      timestamp           peer  peer_asn                              id  \
0  1.740634e+09  193.148.251.1     34927  193.148.251.1-019545dbaa7a004b   
1  1.740634e+09  193.148.251.1     34927  193.148.251.1-019545dbaa7a004c   
2  1.740634e+09  193.148.251.1     34927  193.148.251.1-019545dbaa7a004d   
3  1.740634e+09  193.148.251.1     34927  193.148.251.1-019545dbaa7a004e   
4  1.740634e+09  193.148.251.1     34927  193.148.251.1-019545dbaa7a004f   

             host    type                                               path  \
0  rrc00.ripe.net  UPDATE            [34927, 52025, 174, 4230, 28283, 19990]   
1  rrc00.ripe.net  UPDATE  [34927, 52025, 174, 13786, 262903, 262591, 262...   
2  rrc00.ripe.net  UPDATE                        [34927, 52025, 174, 394380]   
3  rrc00.ripe.net  UPDATE         [34927, 52025, 174, 265269, 266228, 52936]   
4  rrc00.ripe.net  UPDATE                 [34927, 52025, 174, 17378, 396151]   

        community origin                 

## Data Preprocessing Pipeline

In [12]:
# Numerical columns: fill gaps with the column mean, then standardize
numerical_transformer = Pipeline([
    ("imputer", SimpleImputer(strategy="mean")),
    ("scaler", StandardScaler()),
])

# Categorical columns: fill gaps with the most frequent value, then one-hot
# encode (categories unseen during fit are encoded as all zeros, not an error)
categorical_transformer = Pipeline([
    ("imputer", SimpleImputer(strategy="most_frequent")),
    ("encoder", OneHotEncoder(handle_unknown="ignore")),
])

# Route each column group to its transformer; unlisted columns are dropped
preprocessor = ColumnTransformer([
    ("num", numerical_transformer, numerical_features),
    ("cat", categorical_transformer, categorical_features),
])


## Create a Custom LLaMA2 Classifier

In [13]:
class LLaMA2LocalClassifier(BaseEstimator, ClassifierMixin):
    """scikit-learn classifier that labels rows 'normal' or 'abnormal' via a local Ollama model.

    Nothing is learned: ``fit`` only records feature names, and ``predict``
    sends every row in a prompt to the Ollama ``/api/generate`` endpoint and
    parses the reply.

    NOTE(review): inside this notebook's Pipeline the classifier receives the
    *preprocessed* matrix (scaled numbers, one-hot columns), whereas the prompt
    examples show raw BGP fields — confirm this is the intended LLM input.

    Parameters
    ----------
    model_name : str
        Ollama model tag.
    url : str
        Base URL of the Ollama server.
    timeout : float or None, default 300
        Seconds to wait for each Ollama response; ``None`` waits indefinitely.
    """

    def __init__(self, model_name=model, url=ollama_url, timeout=300):
        self.model_name = model_name
        self.url = url
        self.timeout = timeout
        self.feature_names = None  # column names used to rebuild a DataFrame in predict()

    def fit(self, X, y=None):
        """Record feature names when X is a DataFrame; no training takes place.

        ``y`` is accepted only for scikit-learn API compatibility. When X is an
        array (as produced by a ColumnTransformer), ``feature_names`` is left
        unchanged and must be assigned by the caller.
        """
        if hasattr(X, "columns"):
            self.feature_names = X.columns.tolist()
        return self

    def predict(self, X):
        """Classify each row of X and return an array of 'normal'/'abnormal' labels.

        Raises
        ------
        ValueError
            If X is an array and ``feature_names`` is unset or has the wrong length.
        """
        # A ColumnTransformer may emit a SciPy sparse matrix when the one-hot
        # output is sparse enough; densify it so it can be wrapped below.
        if hasattr(X, "toarray"):
            X = X.toarray()

        if isinstance(X, np.ndarray):
            if self.feature_names is None:
                raise ValueError("Feature names not set. Ensure the classifier is fitted first.")

            # Ensure feature names match transformed data shape
            if X.shape[1] != len(self.feature_names):
                raise ValueError(f"Feature count mismatch: expected {len(self.feature_names)}, got {X.shape[1]}")

            X = pd.DataFrame(X, columns=self.feature_names)

        # One {feature: value} dict per row, each sent to the model separately.
        samples = X.to_dict(orient="records")
        return np.array([self.parse_response(self.query_llama2(sample)) for sample in samples])

    def query_llama2(self, sample):
        """Send one row to the local Ollama ``/api/generate`` endpoint and return the decoded JSON reply.

        Raises
        ------
        RuntimeError
            If the server answers with a non-200 status.
        requests.exceptions.RequestException
            On connection failure, or when ``timeout`` expires.
        """
        prompt = f"""
        You are a network security expert specializing in BGP (Border Gateway Protocol) analysis.
        You need to classify BGP update messages as 'normal' or 'abnormal'. 

        ## Definitions:
        - **Normal BGP Data:** Expected AS paths, no sudden prefix withdrawals, and stable peer behavior.
        - **Abnormal BGP Data:** Large-scale prefix withdrawals, hijacked AS paths, or unexpected peer ASN activity.

        ## Examples:
        1. Normal: Timestamp: 1740523555, Peer ASN: 34019, Type: UPDATE, Origin: IGP, Path: [34019, 3303, 3356]
        2. Abnormal: Timestamp: 1740523556, Peer ASN: 65432, Type: UPDATE, Origin: IGP, Path: [65432, 3356, 12389] (unexpected peer)

        ## Now classify this BGP message:
        {sample}

        Reply only with 'normal' or 'abnormal'.
        """

        payload = {
            "model": self.model_name,
            "prompt": prompt,
            "stream": False
        }

        # Without a timeout a stalled server would block predict() forever.
        response = requests.post(f"{self.url}/api/generate", json=payload,
                                 timeout=self.timeout)
        if response.status_code != 200:
            raise RuntimeError(f"Ollama API Error: {response.status_code}, {response.text}")
        return response.json()

    def parse_response(self, response):
        """Map an Ollama JSON reply to 'normal' or 'abnormal'.

        If the reply starts with one of the two labels, as the prompt asks,
        that label wins, even if an explanation mentioning the other label
        follows. Otherwise any occurrence of 'abnormal' means 'abnormal', and
        everything else is 'normal'.
        """
        text = response.get('response', '').strip().lower()

        words = text.split(maxsplit=1)
        first_word = words[0].strip(".,:;!'\"*`") if words else ''
        if first_word in ('normal', 'abnormal'):
            return first_word

        return 'abnormal' if 'abnormal' in text else 'normal'


## Create the ML Pipeline

In [17]:
# Wrap the local Ollama model in the scikit-learn classifier interface
llama2_local_clf = LLaMA2LocalClassifier(model_name=model, url=ollama_url)

# Chain preprocessing and LLM classification into a single estimator
ml_pipeline = Pipeline([
    ("preprocessor", preprocessor),
    ("classifier", llama2_local_clf),
])

## Display the pipeline (last expression of the cell renders as a diagram)
ml_pipeline

## Train & Run the Pipeline

In [18]:
# Keep only the columns the preprocessor was configured for
bgp_data_filtered = bgp_data[numerical_features + categorical_features]

# The LLM classifier learns nothing, but Pipeline.fit still expects a target
# vector; an all-zero dummy satisfies the scikit-learn API.
dummy_y = np.zeros(len(bgp_data_filtered))

# Fits the imputers, scaler and one-hot encoder
ml_pipeline.fit(bgp_data_filtered, dummy_y)

# The classifier is handed a bare array, so give it the transformed column
# names explicitly; predict() uses them to rebuild a DataFrame per row.
preprocessor = ml_pipeline["preprocessor"]
feature_names = preprocessor.get_feature_names_out()
ml_pipeline[-1].feature_names = list(feature_names)

# Query the LLM for every row
predictions = ml_pipeline.predict(bgp_data_filtered)

# Attach the labels to the full dataset and persist them
bgp_data['classification'] = predictions
bgp_data.to_csv("./data/classified_ris_bgp_data3.csv", index=False)

# Show the labelled rows
print("\nClassified BGP Data:")
print(bgp_data[['timestamp', 'peer', 'peer_asn', 'type', 'classification']])



Classified BGP Data:
       timestamp           peer  peer_asn    type classification
0   1.740634e+09  193.148.251.1     34927  UPDATE       abnormal
1   1.740634e+09  193.148.251.1     34927  UPDATE       abnormal
2   1.740634e+09  193.148.251.1     34927  UPDATE       abnormal
3   1.740634e+09  193.148.251.1     34927  UPDATE       abnormal
4   1.740634e+09  193.148.251.1     34927  UPDATE       abnormal
5   1.740634e+09  193.148.251.1     34927  UPDATE       abnormal
6   1.740634e+09  193.148.251.1     34927  UPDATE       abnormal
7   1.740634e+09  193.148.251.1     34927  UPDATE       abnormal
8   1.740634e+09  193.148.251.1     34927  UPDATE       abnormal
9   1.740634e+09  193.148.251.1     34927  UPDATE       abnormal
10  1.740634e+09  193.148.251.1     34927  UPDATE       abnormal
11  1.740634e+09  193.148.251.1     34927  UPDATE       abnormal
12  1.740634e+09  193.148.251.1     34927  UPDATE       abnormal
13  1.740634e+09  193.148.251.1     34927  UPDATE       abnormal
14 