Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,586 changes: 415 additions & 1,171 deletions docs/examples/bayesian_optimization/00_bayes_study.ipynb

Large diffs are not rendered by default.

12,466 changes: 1,266 additions & 11,200 deletions docs/examples/grid_study/00_grid_study.ipynb

Large diffs are not rendered by default.

222 changes: 72 additions & 150 deletions docs/examples/grid_study/01_custom_grid_study.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,37 @@
"%pip install redis-retrieval-optimizer"
]
},
{
"cell_type": "markdown",
"id": "a498afe9",
"metadata": {},
"source": [
"## Check version"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "5eea1c17",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'0.4.1'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import redis_retrieval_optimizer\n",
"\n",
"redis_retrieval_optimizer.__version__"
]
},
{
"cell_type": "markdown",
"id": "270a4f1b",
Expand All @@ -45,7 +76,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"id": "b66894d7",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -246,12 +277,14 @@
"def gather_pre_filter_results(search_method_input: SearchMethodInput) -> SearchMethodOutput:\n",
" redis_res_vector = {}\n",
"\n",
" for key in search_method_input.raw_queries:\n",
" query_info = search_method_input.raw_queries[key]\n",
" query = pre_filter_query(query_info, 10, search_method_input.emb_model)\n",
" for key, query_info in search_method_input.raw_queries.items():\n",
"\n",
" query = pre_filter_query(query_info, search_method_input.ret_k, search_method_input.emb_model)\n",
"\n",
" res = run_search_w_time(\n",
" search_method_input.index, query, search_method_input.query_metrics\n",
" )\n",
"\n",
" score_dict = make_score_dict_vec(res, id_field_name=\"_id\")\n",
"\n",
" redis_res_vector[key] = score_dict\n",
Expand All @@ -265,12 +298,16 @@
"def gather_vector_results(search_method_input: SearchMethodInput) -> SearchMethodOutput:\n",
" redis_res_vector = {}\n",
"\n",
" for key in search_method_input.raw_queries:\n",
" text_query = search_method_input.raw_queries[key]\n",
" vec_query = vector_query(text_query, 10, search_method_input.emb_model)\n",
" for key, text_query in search_method_input.raw_queries.items():\n",
" # create query\n",
" vec_query = vector_query(text_query, search_method_input.ret_k, search_method_input.emb_model)\n",
"\n",
" # run with timing helper\n",
" res = run_search_w_time(\n",
" search_method_input.index, vec_query, search_method_input.query_metrics\n",
" )\n",
"\n",
" # format scores dict for ranx evaluation\n",
" score_dict = make_score_dict_vec(res, id_field_name=\"_id\")\n",
" redis_res_vector[key] = score_dict\n",
" \n",
Expand Down Expand Up @@ -335,133 +372,18 @@
"id": "cc56171b",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/tyler.hutcherson/Library/Caches/pypoetry/virtualenvs/redis-retrieval-optimizer-Z5sMIYJj-py3.11/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"09:56:39 datasets INFO PyTorch version 2.7.0 available.\n",
"09:56:40 sentence_transformers.SentenceTransformer INFO Use pytorch device_name: mps\n",
"09:56:40 sentence_transformers.SentenceTransformer INFO Load pretrained SentenceTransformer: sentence-transformers/all-MiniLM-L6-v2\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Batches: 100%|██████████| 1/1 [00:00<00:00, 4.18it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Recreating: loading corpus from file\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Batches: 100%|██████████| 1/1 [00:00<00:00, 2.60it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 36.69it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 35.04it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 34.18it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 35.12it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 33.56it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 34.25it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 34.53it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 35.27it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 35.48it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 36.56it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 33.55it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 35.90it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 35.01it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 35.20it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 33.45it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 35.65it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 33.36it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 34.15it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 33.05it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 34.01it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 33.60it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 34.79it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 34.81it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 32.98it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 18.21it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 34.96it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 33.50it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 36.06it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 34.57it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 33.98it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 34.42it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 33.87it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 35.03it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 34.14it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 34.35it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 36.25it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 32.23it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 36.36it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 34.40it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 35.12it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 35.41it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 34.67it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 34.36it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 33.65it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 32.82it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 4.29it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"09:56:43 sentence_transformers.SentenceTransformer INFO Use pytorch device_name: mps\n",
"09:56:43 sentence_transformers.SentenceTransformer INFO Load pretrained SentenceTransformer: sentence-transformers/all-MiniLM-L6-v2\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Batches: 100%|██████████| 1/1 [00:00<00:00, 55.47it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running search method: basic_vector\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Batches: 100%|██████████| 1/1 [00:00<00:00, 9.15it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 13.11it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 13.83it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 13.65it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 85.35it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 13.78it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 76.28it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 82.05it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 73.41it/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 72.11it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running search method: pre_filter_vector\n"
"14:59:00 datasets INFO PyTorch version 2.3.0 available.\n",
"14:59:00 sentence_transformers.SentenceTransformer INFO Use pytorch device_name: mps\n",
"14:59:00 sentence_transformers.SentenceTransformer INFO Load pretrained SentenceTransformer: sentence-transformers/all-MiniLM-L6-v2\n",
"Recreating: loading corpus from file\n",
"14:59:08 sentence_transformers.SentenceTransformer INFO Use pytorch device_name: mps\n",
"14:59:08 sentence_transformers.SentenceTransformer INFO Load pretrained SentenceTransformer: sentence-transformers/all-MiniLM-L6-v2\n",
"Running search method: basic_vector with dtype: float32\n",
"Running search method: pre_filter_vector with dtype: float32\n"
]
}
],
Expand Down Expand Up @@ -490,7 +412,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 10,
"id": "47ef7edc",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -518,57 +440,57 @@
" <th>search_method</th>\n",
" <th>model</th>\n",
" <th>avg_query_time</th>\n",
" <th>recall@k</th>\n",
" <th>recall</th>\n",
" <th>precision</th>\n",
" <th>ndcg@k</th>\n",
" <th>f1</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>pre_filter_vector</td>\n",
" <td>sentence-transformers/all-MiniLM-L6-v2</td>\n",
" <td>0.001590</td>\n",
" <td>1.0</td>\n",
" <td>0.25</td>\n",
" <td>0.914903</td>\n",
" <td>0.000536</td>\n",
" <td>1.000000</td>\n",
" <td>0.416667</td>\n",
" <td>0.553810</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>basic_vector</td>\n",
" <td>sentence-transformers/all-MiniLM-L6-v2</td>\n",
" <td>0.002136</td>\n",
" <td>0.9</td>\n",
" <td>0.23</td>\n",
" <td>0.717676</td>\n",
" <td>0.001578</td>\n",
" <td>0.866667</td>\n",
" <td>0.350000</td>\n",
" <td>0.470476</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" search_method model avg_query_time \\\n",
"1 pre_filter_vector sentence-transformers/all-MiniLM-L6-v2 0.001590 \n",
"0 basic_vector sentence-transformers/all-MiniLM-L6-v2 0.002136 \n",
"1 pre_filter_vector sentence-transformers/all-MiniLM-L6-v2 0.000536 \n",
"0 basic_vector sentence-transformers/all-MiniLM-L6-v2 0.001578 \n",
"\n",
" recall@k precision ndcg@k \n",
"1 1.0 0.25 0.914903 \n",
"0 0.9 0.23 0.717676 "
" recall precision f1 \n",
"1 1.000000 0.416667 0.553810 \n",
"0 0.866667 0.350000 0.470476 "
]
},
"execution_count": 9,
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"metrics[[\"search_method\", \"model\", \"avg_query_time\", \"recall\", \"precision\", \"ndcg\"]].sort_values(by=\"ndcg\", ascending=False)"
"metrics[[\"search_method\", \"model\", \"avg_query_time\", \"recall\", \"precision\", \"f1\"]].sort_values(by=\"f1\", ascending=False)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "redis-retrieval-optimizer-Z5sMIYJj-py3.11",
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
Expand All @@ -582,7 +504,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
"version": "3.11.9"
}
},
"nbformat": 4,
Expand Down
Loading