diff --git a/CHANGELOG.md b/CHANGELOG.md
index b73df5af..e1125451 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,8 +1,11 @@
## Changelog
-### v1.5.1 (February 23, 2026)
+### v1.5.1 (March 30, 2026)
- Fix challenge learner
- Update requirements.
+- Update documentation website.
+- Add `rag` argument to `LearnerPipeline`, with documentation and examples.
+- Fix minor bugs in LLM-Augmenter.
### v1.5.0 (February 5, 2026)
- Fix challenge learners
diff --git a/README.md b/README.md
index 8cab52a6..4f4df5eb 100644
--- a/README.md
+++ b/README.md
@@ -134,7 +134,9 @@ print(metrics)
Other available learners:
- [LLM-Based Learner](https://ontolearner.readthedocs.io/learners/llm.html)
+- [Retriever-Based Learner](https://ontolearner.readthedocs.io/learners/retrieval.html)
- [RAG-Based Learner](https://ontolearner.readthedocs.io/learners/rag.html)
+- [LLMs4OL Challenge Learners](https://ontolearner.readthedocs.io/learners/llms4ol.html)
---
diff --git a/docs/source/index.rst b/docs/source/index.rst
index db235ce8..6709df99 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -1,5 +1,3 @@
-
-
.. raw:: html
@@ -109,8 +107,8 @@ Working with OntoLearner is straightforward:
random_state=42
)
- # Initialize a multi-component learning pipeline (retriever + LLM)
- # This configuration enables a Retrieval-Augmented Generation (RAG) setup
+ # RAG can be configured either by passing both IDs (shown here),
+ # or by passing a prebuilt `rag=` learner object.
pipeline = LearnerPipeline(
retriever_id='sentence-transformers/all-MiniLM-L6-v2',
llm_id='Qwen/Qwen2.5-0.5B-Instruct',
diff --git a/docs/source/learners/llm.rst b/docs/source/learners/llm.rst
index 2d732560..a3e444fe 100644
--- a/docs/source/learners/llm.rst
+++ b/docs/source/learners/llm.rst
@@ -93,7 +93,7 @@ You will see a evaluations results.
Pipeline Usage
-----------------------
-The OntoLearner package also offers a streamlined ``LearnerPipeline`` class that simplifies the entire process of initializing, training, predicting, and evaluating a RAG setup into a single call. This is particularly useful for rapid experimentation and deployment.
+The OntoLearner package also offers a streamlined ``LearnerPipeline`` class that wraps initialization, training, prediction, and evaluation in a single call. In this section, we run the pipeline in **LLM-only** mode by setting ``llm_id`` only.
.. code-block:: python
@@ -113,7 +113,7 @@ The OntoLearner package also offers a streamlined ``LearnerPipeline`` class that
# Set up the learner pipeline using a lightweight instruction-tuned LLM
pipeline = LearnerPipeline(
- llm_id='Qwen/Qwen2.5-0.5B-Instruct', # Small-scale LLM for reasoning over term-type assignments
+ llm_id='Qwen/Qwen2.5-0.5B-Instruct', # LLM-only mode
hf_token='...', # Hugging Face access token for loading gated models
batch_size=32 # Batch size for parallel inference (if applicable)
)
diff --git a/docs/source/learners/rag.rst b/docs/source/learners/rag.rst
index 010d453f..c34d3a3e 100644
--- a/docs/source/learners/rag.rst
+++ b/docs/source/learners/rag.rst
@@ -25,8 +25,8 @@ We start by importing necessary components from the ontolearner package, loading
AgrO, # Example agricultural ontology
train_test_split, # Helper function for data splitting
LabelMapper, # Maps ontology labels to/from textual representations
- StandardizedPrompting # Standard prompting strategy across tasks
- evaluation_report
+ StandardizedPrompting, # Standard prompting strategy across tasks
+ evaluation_report,
)
# Load the AgrO ontology (an agricultural domain ontology)
@@ -99,16 +99,24 @@ To build a RAG model, you first initialize its constituent parts: an LLM learner
Pipeline Usage
---------------------
-Similar to LLM and Retrieval learner, RAG Learner is also callable via streamlined ``LearnerPipeline`` class that simplifies the entire learning process.
+Similar to the LLM and Retrieval learners, RAG is callable via ``LearnerPipeline``. You can run RAG in two equivalent ways:
-You initialize the ``LearnerPipeline`` by directly providing the ``retriever_id``, ``llm_id``, and other parameters like ``hf_token``, ``batch_size``, and ``top_k`` (number of top retrievals to include in RAG prompting). Then, you simply call the ``pipeline`` instance with your ``train_data``, ``test_data``, specify ``evaluate=True`` to compute metrics, and define the ``task`` (e.g., `'term-typing'`).
+1. Provide both ``retriever_id`` and ``llm_id`` (pipeline auto-composes an ``AutoRAGLearner``).
+2. Provide a prebuilt ``rag`` learner object for custom configurations.
.. code-block:: python
- # Import core modules from the OntoLearner library
- from ontolearner import LearnerPipeline, AgrO, train_test_split
+ from ontolearner import (
+ LearnerPipeline,
+ AutoLLMLearner,
+ AutoRetrieverLearner,
+ AutoRAGLearner,
+ LabelMapper,
+ StandardizedPrompting,
+ AgrO,
+ train_test_split,
+ )
- # Load the AgrO ontology, which contains concepts related to wines, their properties, and categories
ontology = AgrO()
ontology.load() # Load entities, types, and structured term annotations from the ontology
ontological_data = ontology.extract()
diff --git a/docs/source/learners/retrieval.rst b/docs/source/learners/retrieval.rst
index f2bd41ce..05858c12 100644
--- a/docs/source/learners/retrieval.rst
+++ b/docs/source/learners/retrieval.rst
@@ -81,7 +81,7 @@ When working with large contexts, the retriever model may encounter memory issue
Pipeline Usage
-----------------------
-Similar to LLM learner, Retrieval Learner is also callable via streamlined ``LearnerPipeline`` class that simplifies the entire process learning.
+Similar to the LLM learner, the Retrieval learner is also callable via the streamlined ``LearnerPipeline`` class. In this section, we use **retriever-only** mode by providing ``retriever_id`` only.
.. code-block:: python
@@ -100,7 +100,7 @@ Similar to LLM learner, Retrieval Learner is also callable via streamlined ``Lea
)
# Initialize the learning pipeline using a dense retriever
- # This configuration uses sentence embeddings to match similar relational contexts
+ # This is retriever-only mode (no LLM component)
pipeline = LearnerPipeline(
retriever_id='sentence-transformers/all-MiniLM-L6-v2', # Hugging Face model ID for retrieval
batch_size=10, # Number of samples to process per batch (if batching is enabled internally)
@@ -125,6 +125,10 @@ Similar to LLM learner, Retrieval Learner is also callable via streamlined ``Lea
# Print the full output dictionary (includes predictions)
print(outputs)
+.. note::
+
+   For RAG with ``LearnerPipeline``, see the `RAG learner guide <https://ontolearner.readthedocs.io/learners/rag.html>`_.
+
.. hint::
See `Learning Tasks `_ for possible tasks within Learners.
@@ -372,6 +376,9 @@ Here the ``LLMAugmentedRetrieverLearner`` is the high-level wrapper that orchest
augments = {"config": llm_augmenter_generator.get_config()}
augments[task] = llm_augmenter_generator.augment(ontological_data, task=task)
+ base_retriever = LLMAugmentedRetriever()
+ learner = LLMAugmentedRetrieverLearner(base_retriever=base_retriever)
+
learner.set_augmenter(augments)
learner.load(model_id="Qwen/Qwen3-Embedding-8B")
diff --git a/docs/source/package_reference/pipeline.rst b/docs/source/package_reference/pipeline.rst
index 48099e88..760342d7 100644
--- a/docs/source/package_reference/pipeline.rst
+++ b/docs/source/package_reference/pipeline.rst
@@ -1,6 +1,12 @@
Learner Pipeline
====================
+``LearnerPipeline`` supports:
+
+- Retriever-only mode: set ``retriever_id``
+- LLM-only mode: set ``llm_id``
+- RAG mode: set both ``retriever_id`` and ``llm_id``, or provide a prebuilt ``rag`` learner
+
LearnerPipeline
---------------------
.. autoclass:: ontolearner._learner.LearnerPipeline
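The three modes listed above map directly onto the learner-selection logic patched into ``ontolearner/_learner.py`` in this diff. A minimal sketch of that precedence, using plain-Python stand-ins (hypothetical `select_mode` helper, not part of the library):

```python
def select_mode(retriever=None, llm=None, rag=None):
    """Sketch of LearnerPipeline's learner-selection precedence."""
    if retriever and llm and not rag:
        return "rag"        # auto-composed from retriever + llm
    if rag:
        return "rag"        # prebuilt RAG learner used as-is
    if retriever:
        return "retriever"  # retriever-only mode
    if llm:
        return "llm"        # LLM-only mode
    raise ValueError("configure a retriever, an LLM, or a rag learner")
```

Note that when component learners and a prebuilt ``rag`` are all supplied, the ``and not rag`` guard makes the prebuilt object win, matching the ``elif rag:`` branch in the patched ``__init__``.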
diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst
index 2f7544ed..7ded5acf 100644
--- a/docs/source/quickstart.rst
+++ b/docs/source/quickstart.rst
@@ -137,7 +137,11 @@ To alighn with machine learning follow, once the ontology is loaded, and ontolog
)
-Once the data is split into training and testing sets, you can apply learning models to the ontology learning tasks. OntoLearner supports multiple modeling approaches, including retrieval-based methods, Large Language Model (LLM)-based techniques, and Retrieval-Augmented Generation (RAG) strategies. The ``LearnerPipeline`` within OntoLearner is designed for ease of use, abstracting away the complexities of loading models and preparing datasets or data loaders. You can configure the pipeline with your choice of LLMs, retrievers, or RAG components.
+Once the data is split into training and testing sets, you can apply learning models to the ontology learning tasks. OntoLearner supports multiple modeling approaches, including retrieval-based methods, Large Language Model (LLM)-based techniques, and Retrieval-Augmented Generation (RAG) strategies. The ``LearnerPipeline`` supports all three modes:
+
+- Retriever-only: set ``retriever_id``
+- LLM-only: set ``llm_id``
+- RAG: set both ``retriever_id`` and ``llm_id`` to auto-compose an ``AutoRAGLearner``, or pass a prebuilt ``rag`` learner.
In the example below, we configure a RAG-based learner by specifying the Qwen LLM (`Qwen/Qwen2.5-0.5B-Instruct `_) and a retriever based on a sentence-transformer model (`all-MiniLM-L6-v2 `_):
@@ -165,6 +169,34 @@ In the example below, we configure a RAG-based learner by specifying the Qwen LL
- ``llm_id``: The instruction-following language model used to generate candidate outputs.
- ``top_k``: Number of retrieved examples passed to the LLM (used in RAG setup).
- ``hf_token``: Required for loading gated models from Hugging Face.
+ - ``rag``: Optional prebuilt ``AutoRAGLearner`` (or compatible) object for custom RAG setups.
+
+If you have already created a RAG learner object, you can pass it directly:
+
+.. code-block:: python
+
+ from ontolearner import (
+ LearnerPipeline,
+ AutoLLMLearner,
+ AutoRetrieverLearner,
+ AutoRAGLearner,
+ LabelMapper,
+ StandardizedPrompting,
+ )
+
+ retriever = AutoRetrieverLearner(top_k=3)
+ llm = AutoLLMLearner(
+ prompting=StandardizedPrompting,
+ label_mapper=LabelMapper(),
+ token=''
+ )
+ rag = AutoRAGLearner(retriever=retriever, llm=llm)
+
+ pipeline = LearnerPipeline(
+ rag=rag,
+ retriever_id='sentence-transformers/all-MiniLM-L6-v2',
+ llm_id='Qwen/Qwen2.5-0.5B-Instruct'
+ )
Once configured, the pipeline is executed on the training and test data:
diff --git a/examples/llm_learner_alexbek_rag_term_typing.py b/examples/llm_learner_alexbek_rag_term_typing.py
index 17becc25..44c0a1c4 100644
--- a/examples/llm_learner_alexbek_rag_term_typing.py
+++ b/examples/llm_learner_alexbek_rag_term_typing.py
@@ -27,11 +27,11 @@
output_dir="./results/",
)
-# Build the pipeline and pass raw structured objects end-to-end.
-# We place the RAG learner in the llm slot and set llm_id accordingly.
+# Build the pipeline and pass the dedicated RAG learner explicitly.
pipe = LearnerPipeline(
- llm=rag_learner,
+ rag=rag_learner,
llm_id="Qwen/Qwen2.5-0.5B-Instruct",
+ retriever_id="sentence-transformers/all-MiniLM-L6-v2",
ontologizer_data=True,
)
diff --git a/examples/pipeline.ipynb b/examples/pipeline.ipynb
new file mode 100644
index 00000000..af8488c9
--- /dev/null
+++ b/examples/pipeline.ipynb
@@ -0,0 +1,495 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "508bbd89-d74d-42ed-93c1-90011da9642e",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "HUGGINGFACE_ACCESS_TOKEN=\" \"\n",
+ "OPENAI_KEY=\" \"\n",
+ "\n",
+ "TASK = 'taxonomy-discovery'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "7adfb18a-3d87-4212-baad-333dea72cb90",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2026-03-30 14:52:33.687210: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
+ "2026-03-30 14:52:33.764010: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
+ "To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 AVX512_FP16 AVX_VNNI AMX_TILE AMX_INT8 AMX_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
+ "2026-03-30 14:52:42.196519: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
+ "Failed to initialize disk cache, falling back to memory-only cache: near \"-\": syntax error\n"
+ ]
+ }
+ ],
+ "source": [
+ "\n",
+ "from ontolearner import LearnerPipeline, train_test_split, Conference\n",
+ "from ontolearner.learner import (\n",
+ " AutoLLMLearner,\n",
+ " LLMAugmentedRetrieverLearner,\n",
+ " LLMAugmentedRAGLearner,\n",
+ " StandardizedPrompting,\n",
+ " LabelMapper,\n",
+ ")\n",
+ "\n",
+ "def load_term_typing_split(test_size: float = 0.2):\n",
+ " ontology = load_term_typing()\n",
+ " return train_test_split(ontology, test_size=test_size, random_state=42)\n",
+ "\n",
+ "\n",
+ "def load_term_typing():\n",
+ " ontology = Conference()\n",
+ " ontology.load()\n",
+ " return ontology.extract()\n",
+ "\n",
+ "\n",
+ "def run_auto_llm():\n",
+ " train_data, test_data = load_term_typing_split()\n",
+ " pipeline = LearnerPipeline(\n",
+ " llm_id=\"Qwen/Qwen3-0.6B\",\n",
+ " hf_token=HUGGINGFACE_ACCESS_TOKEN,\n",
+ " batch_size=32,\n",
+ " max_new_tokens=10,\n",
+ " )\n",
+ " return pipeline(train_data=train_data, test_data=test_data, task=TASK, evaluate=True)\n",
+ "\n",
+ "\n",
+ "def run_auto_retriever():\n",
+ " train_data, test_data = load_term_typing_split()\n",
+ " pipeline = LearnerPipeline(\n",
+ " retriever_id=\"sentence-transformers/all-MiniLM-L6-v2\",\n",
+ " top_k=5,\n",
+ " )\n",
+ " return pipeline(train_data=train_data, test_data=test_data, task=TASK, evaluate=True)\n",
+ "\n",
+ "\n",
+ "def run_auto_rag():\n",
+ " train_data, test_data = load_term_typing_split()\n",
+ " pipeline = LearnerPipeline(\n",
+ " retriever_id=\"sentence-transformers/all-MiniLM-L6-v2\",\n",
+ " llm_id=\"Qwen/Qwen3-0.6B\",\n",
+ " hf_token=HUGGINGFACE_ACCESS_TOKEN,\n",
+ " top_k=5,\n",
+ " batch_size=16,\n",
+ " )\n",
+ " return pipeline(train_data=train_data, test_data=test_data, task=TASK, evaluate=True)\n",
+ "\n",
+ "def run_augmented_retriever():\n",
+ " train_data, test_data = load_term_typing_split()\n",
+ " data = load_term_typing()\n",
+ " from ontolearner.learner.retriever import LLMAugmenterGenerator, LLMAugmentedRetriever\n",
+ " llm_augmenter_generator = LLMAugmenterGenerator(model_id='gpt-4.1-mini', token = OPENAI_KEY, top_n_candidate=10)\n",
+ " augments = {\"config\": llm_augmenter_generator.get_config()}\n",
+ " augments[TASK] = llm_augmenter_generator.augment(data, task=TASK)\n",
+ " \n",
+ " retriever = LLMAugmentedRetrieverLearner(base_retriever=LLMAugmentedRetriever(), top_k=5)\n",
+ " retriever.set_augmenter(augments)\n",
+ " \n",
+ " pipeline = LearnerPipeline(retriever=retriever, retriever_id=\"sentence-transformers/all-MiniLM-L6-v2\")\n",
+ " return pipeline(train_data=train_data, test_data=test_data, task=TASK, evaluate=True)\n",
+ "\n",
+ "def run_augmented_rag():\n",
+ " train_data, test_data = load_term_typing_split()\n",
+ " data = load_term_typing()\n",
+ " from ontolearner.learner.retriever import LLMAugmenterGenerator, LLMAugmentedRetriever\n",
+ " llm_augmenter_generator = LLMAugmenterGenerator(model_id='gpt-4.1-mini', token = OPENAI_KEY, top_n_candidate=10)\n",
+ " augments = {\"config\": llm_augmenter_generator.get_config()}\n",
+ " augments[TASK] = llm_augmenter_generator.augment(data, task=TASK)\n",
+ " \n",
+ " retriever = LLMAugmentedRetrieverLearner(base_retriever=LLMAugmentedRetriever(), top_k=5)\n",
+ "\n",
+ " \n",
+ " llm = AutoLLMLearner(\n",
+ " prompting=StandardizedPrompting,\n",
+ " label_mapper=LabelMapper(),\n",
+ " token=HUGGINGFACE_ACCESS_TOKEN,\n",
+ " batch_size=16,\n",
+ " max_new_tokens=10,\n",
+ " )\n",
+ " rag = LLMAugmentedRAGLearner(retriever=retriever, llm=llm)\n",
+ " rag.set_augmenter(augments)\n",
+ " \n",
+ " pipeline = LearnerPipeline(\n",
+ " rag=rag,\n",
+ " retriever_id=\"sentence-transformers/all-MiniLM-L6-v2\",\n",
+ " llm_id=\"Qwen/Qwen3-0.6B\",\n",
+ " ontologizer_data=True,\n",
+ " )\n",
+ " return pipeline(train_data=train_data, test_data=test_data, task=TASK, evaluate=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "a3d17278-f391-4d3f-836f-597d61431f97",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "`torch_dtype` is deprecated! Use `dtype` instead!\n",
+ "/nfs/home/babaeih/onto-leaarner/ontolearner/learner/llm.py:102: UserWarning: No requirement for fiting the taxonomy-discovery model, the predict module will use the input data to do the 'is-a' relationship detection\n",
+ " warnings.warn(\"No requirement for fiting the taxonomy-discovery model, the predict module will use the input data to do the 'is-a' relationship detection\")\n",
+ "100%|██████████| 3/3 [00:13<00:00, 4.40s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Metrics: {'f1_score': 0.042105263157894736, 'precision': 0.023255813953488372, 'recall': 0.2222222222222222, 'total_correct': 2, 'total_predicted': 86, 'total_ground_truth': 9}\n",
+ "Elapsed time: 13.206007719039917\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "outputs = run_auto_llm()\n",
+ "print(\"Metrics:\", outputs.get(\"metrics\"))\n",
+ "print(\"Elapsed time:\", outputs[\"elapsed_time\"])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "4b3e703b-3cac-4e35-97b0-132e5f5c2309",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:sentence_transformers.SentenceTransformer:Use pytorch device_name: cuda:0\n",
+ "INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: sentence-transformers/all-MiniLM-L6-v2\n",
+ "/nfs/home/babaeih/onto-leaarner/ontolearner/learner/retriever/learner.py:80: UserWarning: No requirement for fiting the taxonomy discovery model, the predict module will use the input data to do the fit as well.\n",
+ " warnings.warn(\"No requirement for fiting the taxonomy discovery model, the predict module will use the input data to do the fit as well.\")\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "1f1f59642fc44565942faa2d1731f31b",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Batches: 0%| | 0/1 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "44ec572c9e064acbb0624bddd95555d9",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Batches: 0%| | 0/1 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Metrics: {'f1_score': 0.12631578947368421, 'precision': 0.06976744186046512, 'recall': 0.6666666666666666, 'total_correct': 6, 'total_predicted': 86, 'total_ground_truth': 9}\n",
+ "Elapsed time: 0.13821125030517578\n"
+ ]
+ }
+ ],
+ "source": [
+ "outputs = run_auto_retriever()\n",
+ "print(\"Metrics:\", outputs.get(\"metrics\"))\n",
+ "print(\"Elapsed time:\", outputs[\"elapsed_time\"])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "7b015849-bfb8-4fb5-b297-db415e176b57",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:sentence_transformers.SentenceTransformer:Use pytorch device_name: cuda:0\n",
+ "INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: sentence-transformers/all-MiniLM-L6-v2\n",
+ "/nfs/home/babaeih/onto-leaarner/ontolearner/learner/rag/rag.py:68: UserWarning: No requirement for fiting the taxonomy discovery model, the predict module will use the input data to do the fit as well.\n",
+ " warnings.warn(\"No requirement for fiting the taxonomy discovery model, the predict module will use the input data to do the fit as well.\")\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "0fa818b921234c689716d97dbc1ef267",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Batches: 0%| | 0/1 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "c8e92739e22a43a9822d8b705e568852",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Batches: 0%| | 0/1 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 6/6 [00:14<00:00, 2.33s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Metrics: {'f1_score': 0.1348314606741573, 'precision': 0.075, 'recall': 0.6666666666666666, 'total_correct': 6, 'total_predicted': 80, 'total_ground_truth': 9}\n",
+ "Elapsed time: 14.042728662490845\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "outputs = run_auto_rag()\n",
+ "print(\"Metrics:\", outputs.get(\"metrics\"))\n",
+ "print(\"Elapsed time:\", outputs[\"elapsed_time\"])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "65de59a2-7fa0-4d0b-afc6-5daee62574a3",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "33it [00:52, 1.60s/it]\n",
+ "INFO:sentence_transformers.SentenceTransformer:Use pytorch device_name: cuda:0\n",
+ "INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: sentence-transformers/all-MiniLM-L6-v2\n",
+ "/nfs/home/babaeih/onto-leaarner/ontolearner/learner/retriever/learner.py:80: UserWarning: No requirement for fiting the taxonomy discovery model, the predict module will use the input data to do the fit as well.\n",
+ " warnings.warn(\"No requirement for fiting the taxonomy discovery model, the predict module will use the input data to do the fit as well.\")\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "f58e86975e3c47868a0a91040cca6773",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Batches: 0%| | 0/1 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "3596db316bb84154860fe76b855799d6",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Batches: 0%| | 0/5 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Metrics: {'f1_score': 0.09815950920245399, 'precision': 0.05194805194805195, 'recall': 0.8888888888888888, 'total_correct': 8, 'total_predicted': 154, 'total_ground_truth': 9}\n",
+ "Elapsed time: 0.05592513084411621\n"
+ ]
+ }
+ ],
+ "source": [
+ "outputs = run_augmented_retriever()\n",
+ "print(\"Metrics:\", outputs.get(\"metrics\"))\n",
+ "print(\"Elapsed time:\", outputs[\"elapsed_time\"])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "f4b0cae2-bbae-4647-948a-abf5bd8c5501",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "33it [00:40, 1.23s/it]\n",
+ "INFO:sentence_transformers.SentenceTransformer:Use pytorch device_name: cuda:0\n",
+ "INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: sentence-transformers/all-MiniLM-L6-v2\n",
+ "/nfs/home/babaeih/onto-leaarner/ontolearner/learner/rag/rag.py:68: UserWarning: No requirement for fiting the taxonomy discovery model, the predict module will use the input data to do the fit as well.\n",
+ " warnings.warn(\"No requirement for fiting the taxonomy discovery model, the predict module will use the input data to do the fit as well.\")\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "93f5122640524d03aec7da7f761701c4",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Batches: 0%| | 0/1 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "bc4f8f53c08a45ba96ee6960aacb72a3",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Batches: 0%| | 0/5 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 10/10 [00:22<00:00, 2.26s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Metrics: {'f1_score': 0.10596026490066225, 'precision': 0.056338028169014086, 'recall': 0.8888888888888888, 'total_correct': 8, 'total_predicted': 142, 'total_ground_truth': 9}\n",
+ "Elapsed time: 22.675782442092896\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "outputs = run_augmented_rag()\n",
+ "print(\"Metrics:\", outputs.get(\"metrics\"))\n",
+ "print(\"Elapsed time:\", outputs[\"elapsed_time\"])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e0fa7848-2165-4cdf-9c3d-cce45ce54b83",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "0c405cd3-f056-4b07-a496-0f25df181c1b",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b42da937-bb2b-42ae-a5a8-0d3913eef7cb",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3.12",
+ "language": "python",
+ "name": "py312"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.11"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/ontolearner/_learner.py b/ontolearner/_learner.py
index 1b142e45..23262ee1 100644
--- a/ontolearner/_learner.py
+++ b/ontolearner/_learner.py
@@ -30,13 +30,18 @@
class LearnerPipeline:
"""
- Unified pipeline for ontology learning using LLMs, retrievers, or RAG-based models.
- Supports end-to-end training, prediction, and evaluation in a scikit-learn-like interface.
+ Unified pipeline for ontology learning using retriever-only, LLM-only,
+ or Retrieval-Augmented Generation (RAG) learners.
+
+ RAG can be configured in two ways:
+ 1) pass both ``retriever`` and ``llm`` (or their model IDs), or
+ 2) pass a prebuilt ``rag`` learner.
"""
def __init__(self,
retriever: Optional[Any] = None,
llm: Optional[Any] = None,
+ rag: Optional[Any] = None,
retriever_id: Optional[str] = None,
llm_id: Optional[str] = None,
prompting: Optional[AutoPrompt] = StandardizedPrompting,
@@ -46,20 +51,24 @@ def __init__(self,
top_k: int = 5,
batch_size: int = 10,
device: str = 'cpu',
- max_new_tokens: int=10):
+ max_new_tokens: int = 10):
"""
- Initialize the pipeline for a specific ontology learning task.
+ Initialize the pipeline for ontology learning tasks.
Args:
- task: One of ["term-typing", "taxonomy-discovery", "non-taxonomic-re"]
- prompting: Optional prompting strategy (defaults to StandardizedPrompting)
- retriever: Pre-initialized retriever learner (if any)
- llm: Pre-initialized LLM learner (if any)
- retriever_id: HF model ID for retriever (if not provided explicitly)
- llm_id: HF model ID for LLM (if not provided explicitly)
- hf_token: Hugging Face token (for gated LLM access)
- ontologizer_data: If True, uses Ontologizer-style datasets
- top_k: Number of top examples to retrieve for RAG or Retriever
+ retriever: Pre-initialized retriever learner.
+ llm: Pre-initialized LLM learner.
+ rag: Pre-initialized ``AutoRAGLearner`` (or compatible) instance.
+ retriever_id: Retriever model ID used when loading retriever components.
+ llm_id: LLM model ID used when loading LLM components.
+ prompting: Prompting strategy for AutoLLMLearner initialization.
+ label_mapper: Label mapper for AutoLLMLearner initialization.
+ hf_token: Hugging Face token (for gated model access).
+ ontologizer_data: If True, uses Ontologizer-style datasets by default.
+ top_k: Number of top examples retrieved for retriever/RAG workflows.
+ batch_size: Batch size used by learner backends where applicable.
+ device: Target device for model execution (e.g., 'cpu', 'cuda').
+ max_new_tokens: Max generated tokens for LLM generation.
"""
self.ontologizer_data = ontologizer_data
# Instantiate retriever
@@ -77,10 +86,14 @@ def __init__(self,
max_new_tokens=max_new_tokens)
llm_id = llm_id if llm_id is not None else 'Qwen/Qwen2.5-0.5B-Instruct'
# Determine pipeline strategy
- if retriever and llm:
+ if retriever and llm and not rag:
self.learner = AutoRAGLearner(retriever=retriever, llm=llm)
self.learner.load(retriever_id=retriever_id, llm_id=llm_id)
self.model_type = "rag"
+ elif rag:
+ self.learner = rag
+ self.learner.load(retriever_id=retriever_id, llm_id=llm_id)
+ self.model_type = "rag"
elif retriever:
self.learner = retriever
self.learner.load(model_id=retriever_id)
diff --git a/ontolearner/learner/retriever/augmented_retriever.py b/ontolearner/learner/retriever/augmented_retriever.py
index ede4414e..2f52012d 100644
--- a/ontolearner/learner/retriever/augmented_retriever.py
+++ b/ontolearner/learner/retriever/augmented_retriever.py
@@ -330,7 +330,11 @@ def augmented_retrieve(self, query: List[str], top_k: int = 5, batch_size: int =
augmented_queries, index_map = [], []
for qu_idx, qu in enumerate(query):
- augmented = self.augmenter.transform(qu, task=task)
+ try:
+ augmented = self.augmenter.transform(qu, task=task)
+ except Exception:
+ augmented = self.augmenter[task].get(qu, [qu])
+
for aug in augmented:
augmented_queries.append(aug)
index_map.append(qu_idx)
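The fallback introduced in this hunk — prefer the augmenter's ``transform`` and, if it raises, fall back to a per-task mapping lookup that defaults to the raw query — can be exercised in isolation. A sketch with hypothetical augmenter shapes (not the library's types):

```python
def augment_query(augmenter, query, task):
    """Resolve augmentations for one query, mirroring augmented_retrieve()."""
    try:
        # Augmenter objects expose transform(query, task=...).
        return augmenter.transform(query, task=task)
    except Exception:
        # Plain-mapping fallback: {task: {query: [augmentations]}},
        # defaulting to the raw query when nothing was precomputed.
        return augmenter[task].get(query, [query])
```

A plain dict has no ``transform`` attribute, so the ``AttributeError`` falls through to the mapping branch, as in the patched ``augmented_retrieve``.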
diff --git a/ontolearner/learner/term_typing/rwthdbis.py b/ontolearner/learner/term_typing/rwthdbis.py
index ed18d964..09ff40fd 100644
--- a/ontolearner/learner/term_typing/rwthdbis.py
+++ b/ontolearner/learner/term_typing/rwthdbis.py
@@ -187,12 +187,10 @@ def _term_typing(self, data: Any, test: bool = False) -> Optional[Any]:
)
return None
-################################################################################
-# Data Preprocessing ##########################################################
-################################################################################
-
-### Generate Context Information by GPT(via g4f.Client) ##########################################################
-
+ ########################
+ ### Data Preprocessing
+ ########################
+ ### Generate Context Information by GPT(via g4f.Client)
def _normalize_text(self, raw_text: str, *, drop_questions: bool = False) -> str:
"""
Normalize plain text consistently across the pipeline.
@@ -511,8 +509,7 @@ def run_bucket(bucket_rows: List[dict], out_path: Path) -> int:
)
return remaining_short
-### Extract Context Information from Ontology ##########################################################
-
+ ### Extract Context Information from Ontology
def _extract_terms_from_ontology(self, ontology: Any) -> List[str]:
"""
Collect unique term names from `ontology.type_taxonomies.taxonomies`,
@@ -640,8 +637,7 @@ def preprocess_context_from_ontology(
self.context_json_path = str(merged_path)
return merged_path
-### Process Training / Inference Data - Augmented with Context Information (from Ontology or GPT) ##########################################################
-
+ ### Process Training / Inference Data - Augmented with Context Information (from Ontology or GPT)
def _load_context_map(self) -> None:
"""
Populate in-memory maps from the context JSON (`self.context_json_path`).
@@ -709,8 +705,7 @@ def _lookup_context_info(self, raw_term: str) -> str:
break # one hit per subterm
return ".".join(matched_infos)
-### Process Training Data - for Fine-tuning for Text Classification(FT-TC) ##########################################################
-
+ ### Process Training Data - for Fine-tuning for Text Classification(FT-TC)
def _expand_multilabel_training_rows(
self, term_typings: List[Any]
) -> Tuple[List[str], List[int], Dict[int, str], Dict[str, int]]:
@@ -780,10 +775,9 @@ def _collect_eval_terms(self, eval_data: Any) -> List[str]:
return terms
-################################################################################
-# Model Training ##########################################################
-################################################################################
-
+ ####################
+ # Model Training ##
+ ####################
def _train_from_term_typings(self, train_data: Any, train_method: int = 2) -> None:
"""Train the term-typing classifier from `.term_typings`.
@@ -1041,8 +1035,6 @@ def tokenize_batch(batch: Dict[str, List[str]]):
trainer.save_model(self.output_dir)
self.tokenizer.save_pretrained(self.output_dir)
-
-
def _ensure_loaded_for_inference(self) -> None:
"""Load model/tokenizer for inference if not already loaded.
@@ -1066,10 +1058,9 @@ def _ensure_loaded_for_inference(self) -> None:
self.model.to(self.device).eval()
-################################################################################
-# Model Inference ##########################################################
-################################################################################
-
+ #####################
+ # Model Inference ##
+ #####################
def _predict_label_ids(self, terms: List[str]) -> List[int]:
"""Predict label ids (argmax) for a list of term strings.