In [1]:
import os

# Spark distribution bundled one directory above the notebook's working directory.
spark_home = os.path.abspath(os.path.join(os.getcwd(), "..", "spark-3.5.5-bin-hadoop3"))
os.environ["SPARK_HOME"] = spark_home
print(f"SPARK_HOME set to: {spark_home}")

if os.name == "nt":
    # On Windows, point HADOOP_HOME at the local winutils checkout and put its
    # bin/ directory on PATH.
    hadoop_home = os.path.abspath(os.path.join(os.getcwd(), "..", "winutils"))
    os.environ["HADOOP_HOME"] = hadoop_home
    print(f"HADOOP_HOME set to: {hadoop_home}")
    hadoop_bin = os.path.join(hadoop_home, "bin")
    current_path = os.environ.get("PATH", "")
    # Prepend only once so re-running this cell does not keep growing PATH.
    if hadoop_bin not in current_path.split(os.pathsep):
        os.environ["PATH"] = f"{hadoop_bin}{os.pathsep}{current_path}"
    print(f"Added Hadoop bin to PATH: {hadoop_bin}")

# findspark.init makes the pyspark shipped under spark_home importable, so the
# pyspark imports below must come after it.
import findspark
findspark.init(spark_home)

import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, udf, from_json, schema_of_json
from pyspark.sql.types import ArrayType, StringType

spark = SparkSession.builder.appName("ArxivCategoryPrediction").getOrCreate()
sc = spark.sparkContext
print("SparkSession and SparkContext initialized successfully.")


SPARK_HOME set to: C:\spark_project\spark\spark-3.5.5-bin-hadoop3
HADOOP_HOME set to: C:\spark_project\spark\winutils
Added Hadoop bin to PATH: C:\spark_project\spark\winutils\bin
SparkSession and SparkContext initialized successfully.


In [2]:
import torch
from transformers import AutoTokenizer, AutoModel
import joblib

# Pretrained SciBERT encoder, plus the classifier and label binarizer saved from training.
MODEL_NAME = "allenai/scibert_scivocab_uncased"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# eval() puts the encoder in inference mode; .eval() and .to() both return the module itself.
scibert = AutoModel.from_pretrained(MODEL_NAME).eval().to(device)

# SECURITY: joblib.load unpickles arbitrary Python objects -- only load these
# artifacts from a trusted source.
clf = joblib.load("cleaned_classifier.pkl")
mlb = joblib.load("label_binarizer.pkl")

def predict_labels(title, summary, threshold=0.4):
    """Predict arXiv category labels for one paper from its title and abstract.

    The title and summary are joined with a space and encoded with SciBERT. The
    first-token ([CLS]) embedding is then scored by the trained classifier ``clf``.

    Args:
        title: Paper title. Non-string values (``None`` from Spark nulls, NaN
            after ``toPandas``) are treated as missing.
        summary: Paper abstract; handled the same way as ``title``.
        threshold: Minimum per-label probability for a label to be returned.

    Returns:
        list: Labels from ``mlb.classes_`` whose probability is >= ``threshold``.
        Returns an empty list when neither field contains any text.
    """
    # .strip() on None/NaN would raise and abort the whole streaming batch.
    parts = [s.strip() for s in (title, summary) if isinstance(s, str)]
    text = " ".join(parts)
    if not text.strip():
        return []

    with torch.no_grad():
        # max_length=512 is presumably the encoder's maximum sequence length; longer text is truncated.
        tokens = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
        tokens = {k: v.to(device) for k, v in tokens.items()}
        outputs = scibert(**tokens)
        # Token position 0 of the last hidden layer; shape (1, hidden_size) for this single text.
        cls_embedding = outputs.last_hidden_state[:, 0, :].cpu().numpy()
    # NOTE(review): assumes predict_proba returns a (n_samples, n_labels) array whose
    # columns align with mlb.classes_ -- confirm against the training code.
    proba = clf.predict_proba(cls_embedding)[0]
    return [label for label, p in zip(mlb.classes_, proba) if p >= threshold]

# Spark SQL UDF wrapping predict_labels (default threshold) and returning array<string>.
# NOTE(review): not used by the foreachBatch pipeline in the visible cells, which
# calls predict_labels directly on the driver via pandas.
predict_udf = udf(lambda t, s: predict_labels(t, s), returnType=ArrayType(StringType()))




In [3]:
# Remote endpoint that serves the arXiv record stream over a plain TCP socket.
host = "seppe.net"
port = 7778

# Spark's socket source yields one row per received line, in a `value` column.
raw_stream = (
    spark.readStream
    .format("socket")
    .option("host", host)
    .option("port", port)
    .load()
)


In [4]:
# Explicit schema for the JSON records on the socket. It is declared instead of
# inferred from a sample via schema_of_json, so the field types cannot drift with
# the sample's content. Example record:
#   {"aid": "http://arxiv.org/abs/2503.19871v1", "title": "...", "summary": "...",
#    "main_category": "hep-ph", "categories": "hep-ph,hep-ex",
#    "published": "2025-03-25T17:36:54Z"}
schema = (
    "aid STRING, title STRING, summary STRING, main_category STRING, "
    "categories STRING, published STRING"
)

# Decode each socket line's `value` as JSON, then flatten the struct into top-level columns.
parsed_stream = (
    raw_stream
    .select(from_json(col("value"), schema).alias("json"))
    .select("json.*")
)

In [None]:
import json

def process_batch(batch_df, batch_id):
    """Predict labels for one streaming micro-batch and print them.

    Used as the ``foreachBatch`` callback of the streaming query.

    Args:
        batch_df: Spark DataFrame with this micro-batch's parsed records; it must
            have ``aid``, ``title`` and ``summary`` columns.
        batch_id: Micro-batch id supplied by Structured Streaming.
    """
    # Collect once and count the pandas frame. Calling batch_df.count() first
    # would launch a second Spark job over the same micro-batch.
    pdf = batch_df.toPandas()
    print(f"Processing batch_id: {batch_id} with {len(pdf)} records.")

    # One JSON-encoded label list per row.
    pdf["predicted_labels"] = [
        json.dumps(predict_labels(title, summary))
        for title, summary in zip(pdf["title"], pdf["summary"])
    ]

    print(pdf[["aid", "predicted_labels"]])

# Start the streaming query; every micro-batch is handed to process_batch.
query = (
    parsed_stream.writeStream
    .foreachBatch(process_batch)
    .start()
)

# awaitTermination() blocks this cell until the query ends. Interrupting the
# kernel raises KeyboardInterrupt here; stop the query then, rather than leaving
# it running in the background.
try:
    query.awaitTermination()
except KeyboardInterrupt:
    query.stop()
    print("Streaming query stopped.")


Processing batch_id: 0 with 0 records.
Empty DataFrame
Columns: [aid, predicted_labels]
Index: []
Processing batch_id: 1 with 8 records.
                                 aid predicted_labels
0  http://arxiv.org/abs/2505.20078v1    ["gr", "hep"]
1  http://arxiv.org/abs/2505.20079v1           ["cs"]
2  http://arxiv.org/abs/2505.20080v1         ["math"]
3  http://arxiv.org/abs/2505.20081v1           ["cs"]
4  http://arxiv.org/abs/2505.20082v1           ["cs"]
5  http://arxiv.org/abs/2505.20083v1         ["math"]
6  http://arxiv.org/abs/2505.20084v1    ["gr", "hep"]
7  http://arxiv.org/abs/2505.20085v1           ["cs"]
Processing batch_id: 2 with 7 records.
                                 aid     predicted_labels
0  http://arxiv.org/abs/2505.20086v1                   []
1  http://arxiv.org/abs/2505.20087v1               ["cs"]
2  http://arxiv.org/abs/2505.20088v1               ["cs"]
3  http://arxiv.org/abs/2505.20089v1               ["cs"]
4  http://arxiv.org/abs/2505.20090v1       ["cs"

In [None]:
# Stop the streaming query started by writeStream...start(); run this after
# interrupting the blocking awaitTermination() cell.
query.stop()