diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 2495cbc..90b14a4 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -12,7 +12,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] + python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4e6e596..cfd2161 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -10,7 +10,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] + python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} diff --git a/README.md b/README.md index be8e18a..443a7e8 100644 --- a/README.md +++ b/README.md @@ -205,6 +205,27 @@ openai-to-sqlite search content.db 'this is my search term' -t documents Add `--count 20` to retrieve 20 results (the default is 10). ``` +### Search for similar content with the similar command + +Having saved the embeddings for content, you can search for similar content with the `similar` command: +```bash +oopenai-to-sqlite similar embeddings.db '' +``` +The output will be a list of cosine similarity scores and content IDs: +``` +% openai-to-sqlite similar embeddings-bjcp-2021.db '23G Gose' +1.000 23G Gose +0.929 24A Witbier +0.921 23A Berliner Weisse +0.909 05B Kölsch +0.907 01D American Wheat Beer +0.906 27 Historical Beer: Lichtenhainer +0.905 23D Lambic +0.905 10A Weissbier +0.904 04B Festbier +0.904 01B American Lager +``` + ## Development To contribute to this tool, first checkout the code. Then create a new virtual environment: diff --git a/openai_to_sqlite/cli.py b/openai_to_sqlite/cli.py index ceb0027..19efd02 100644 --- a/openai_to_sqlite/cli.py +++ b/openai_to_sqlite/cli.py @@ -378,6 +378,44 @@ def batch_rows(rows, batch_size): yield batch +@cli.command() +@click.argument( + "db_path", + type=click.Path(file_okay=True, dir_okay=False, allow_dash=False), +) +@click.argument("entry") +@click.option( + "table_name", + "-t", + "--table", + default="embeddings", + help="Name of the table containing the embeddings", +) +def similar(db_path, entry, table_name): + """ + Display similar entries + """ + db = sqlite_utils.Database(db_path) + table = db[table_name] + if not table.exists(): + raise click.ClickException(f"Table {table_name} does not exist") + # Fetch the embedding for the query + try: + row = table.get(entry) + except sqlite_utils.db.NotFoundError: + raise click.ClickException(f"Entry not found:" + entry) + vector = decode(row["embedding"]) + # Now calculate cosine similarity with everything in the database table + other_vectors = [(row["id"], decode(row["embedding"])) for row in table.rows] + results = [ + (id, cosine_similarity(vector, other_vector)) + for id, other_vector in other_vectors + ] + results.sort(key=lambda r: r[1], reverse=True) + for id, score in results[:10]: + print(f"{score:.3f} {id}") + + encoding = None diff --git a/setup.py b/setup.py index dc8cac3..353270a 100644 --- a/setup.py +++ b/setup.py @@ -33,5 +33,5 @@ def get_long_description(): """, install_requires=["click", "httpx", "sqlite-utils>=3.28", "openai", "tiktoken"], extras_require={"test": ["pytest", "pytest-httpx", "pytest-mock"]}, - python_requires=">=3.7", + python_requires=">=3.8", ) diff --git a/tests/test_openai_to_sqlite.py b/tests/test_openai_to_sqlite.py index f53bc2a..93f9110 100644 --- a/tests/test_openai_to_sqlite.py +++ b/tests/test_openai_to_sqlite.py @@ -258,3 +258,34 @@ def test_query(mocker): {"role": "user", "content": "hello"}, ], ) + + +@pytest.mark.parametrize("table_option", (None, "-t", "--table")) +def test_similar(httpx_mock, tmpdir, table_option): + db_path = str(tmpdir / "embeddings.db") + db = sqlite_utils.Database(db_path) + table = "embeddings" + if table_option: + table = "other_table" + db[table].insert_all( + [ + {"id": 1, "embedding": MOCK_EMBEDDING}, + {"id": 2, "embedding": MOCK_EMBEDDING}, + ], + pk="id", + ) + extra_opts = [] + if table_option: + extra_opts.extend([table_option, "other_table"]) + runner = CliRunner() + result = runner.invoke( + cli, + [ + "similar", + db_path, + "1" + ] + + extra_opts, + ) + assert result.exit_code == 0 + assert result.output == "1.000 1\n1.000 2\n"