Skip to content

Commit

Permalink
Merge pull request #832 from readthedocs/similarity-api
Browse files Browse the repository at this point in the history
Add initial similarity API
  • Loading branch information
ericholscher committed Mar 7, 2024
2 parents 9a173cf + e93206b commit 84a4111
Show file tree
Hide file tree
Showing 9 changed files with 135 additions and 4 deletions.
5 changes: 5 additions & 0 deletions .envs/local/django.sample
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,8 @@ METABASE_SECRET_KEY=000000000000000000000000000000000000000000000000000000000000
# This is a workaround for some celery issues that are likely fixed in future versions.
# https://github.com/celery/celery/issues/5761
COLUMNS=80

# Analyzer
# ------------------------------------------------------------------------------
# See ``adserver.analyzer.backends`` for available backends
# ADSERVER_ANALYZER_BACKEND=
14 changes: 12 additions & 2 deletions adserver/analyzer/backends/st.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import os

import trafilatura
from bs4 import BeautifulSoup
from sentence_transformers import SentenceTransformer
from textacy import preprocessing
Expand All @@ -16,10 +17,11 @@ class SentenceTransformerAnalyzerBackend(BaseAnalyzerBackend):
Quick and dirty analyzer that uses the SentenceTransformer library
"""

# Model and cache location are environment-configurable; the defaults match
# what production uses. (The previous hard-coded MODEL_NAME assignment was a
# dead store, immediately overwritten by the getenv() line, and is removed.)
MODEL_NAME = os.getenv("SENTENCE_TRANSFORMERS_MODEL", "multi-qa-MiniLM-L6-cos-v1")
MODEL_HOME = os.getenv("SENTENCE_TRANSFORMERS_HOME", "/tmp/sentence_transformers")

def preprocess_text(self, text):
log.info("Preprocessing text: %s", text)
self.preprocessor = preprocessing.make_pipeline(
preprocessing.normalize.unicode,
preprocessing.remove.punctuation,
Expand All @@ -28,14 +30,22 @@ def preprocess_text(self, text):
return self.preprocessor(text).lower()[: self.MAX_INPUT_LENGTH]

def analyze_response(self, resp):
    """Keyword analysis is currently disabled; always return no keywords."""
    return []

def get_content(self, *args):
    """
    Download ``self.url`` and return its preprocessed main text.

    Uses trafilatura to fetch the page and extract the main content
    (skipping comments and tables), then runs it through
    ``preprocess_text``.

    :returns: the preprocessed text, or ``""`` when the page could not be
        downloaded or no main content could be extracted.
    """
    downloaded = trafilatura.fetch_url(self.url)
    result = trafilatura.extract(
        downloaded, include_comments=False, include_tables=False
    )
    # trafilatura.extract() returns None on failure; feeding None into the
    # preprocessing pipeline would raise, so normalize to an empty string.
    if not result:
        return ""
    return self.preprocess_text(result)

def embed_response(self, resp) -> list:
    """
    Analyze an HTTP response and return the embedding vector for its content.

    :param resp: the HTTP response to analyze (forwarded to ``get_content``).
    :returns: the embedding as a list of floats, or ``None`` when no usable
        text could be extracted from the response.
    """
    text = self.get_content(resp)
    if not text:
        return None
    log.info("Embedding text: %s", text[:100])
    # Construct the model only when there is text to embed: loading a
    # SentenceTransformer is expensive (reads/downloads model weights).
    model = SentenceTransformer(self.MODEL_NAME, cache_folder=self.MODEL_HOME)
    embedding = model.encode(text)
    return embedding.tolist()

Expand Down
93 changes: 92 additions & 1 deletion adserver/analyzer/views.py
Original file line number Diff line number Diff line change
@@ -1 +1,92 @@
"""Intentionally blank."""
from urllib.parse import urlparse

from django.conf import settings
from pgvector.django import CosineDistance
from rest_framework import status
from rest_framework.permissions import AllowAny
from rest_framework.renderers import StaticHTMLRenderer
from rest_framework.response import Response
from rest_framework.views import APIView

from adserver.analyzer.backends.st import SentenceTransformerAnalyzerBackend
from adserver.analyzer.models import AnalyzedUrl


if "adserver.analyzer" in settings.INSTALLED_APPS:

    class EmbeddingViewSet(APIView):
        """
        Returns a list of similar URLs and scores based on querying the AnalyzedURL embedding for an incoming URL.

        Example: http://localhost:5000/api/v1/similar/?url=https://www.gitbook.com/

        .. http:get:: /api/v1/embedding/

            Return a list of similar URLs and scores based on querying the AnalyzedURL embedding for an incoming URL

            :<json string url: **Required**. The URL to query for similar URLs and scores
            :>json int count: The number of similar URLs returned
            :>json array results: An array of similar URLs and scores
        """

        permission_classes = [AllowAny]
        renderer_classes = [StaticHTMLRenderer]

        def get(self, request):
            """Return a list of similar URLs and scores based on querying the AnalyzedURL embedding for an incoming URL."""
            # Local stdlib import keeps the module-level import block unchanged.
            import html

            url = request.query_params.get("url")

            if not url:
                return Response(
                    {"error": "url is required"}, status=status.HTTP_400_BAD_REQUEST
                )

            backend_instance = SentenceTransformerAnalyzerBackend(url)
            response = backend_instance.fetch()
            if not response:
                return Response(
                    {"error": "Not able to fetch content from URL"},
                    status=status.HTTP_400_BAD_REQUEST,
                )
            processed_text = backend_instance.get_content(response)
            analyzed_embedding = backend_instance.embedding(response)

            # 25 nearest neighbors by cosine distance over stored embeddings.
            unfiltered_urls = (
                AnalyzedUrl.objects.filter(publisher__allow_paid_campaigns=True)
                .exclude(embedding=None)
                .annotate(distance=CosineDistance("embedding", analyzed_embedding))
                .order_by("distance")[:25]
            )

            # Keep only the closest result per domain so one site can't
            # dominate the list. (Renamed from `url` to avoid shadowing the
            # query-param `url` above.)
            unique_domains = set()
            similar_urls = []
            for analyzed_url in unfiltered_urls:
                domain = urlparse(analyzed_url.url).netloc
                if domain not in unique_domains:
                    unique_domains.add(domain)
                    similar_urls.append(analyzed_url)

            # The HTML below shows exactly 4 results, so require at least 4.
            if len(similar_urls) < 4:
                return Response(
                    {"error": "No similar URLs found"}, status=status.HTTP_404_NOT_FOUND
                )

            # Escape everything interpolated into the raw HTML response:
            # both the URLs and the extracted page text come from remote,
            # untrusted content (XSS via StaticHTMLRenderer otherwise).
            list_items = "\n".join(
                f'<li><a href="{html.escape(u.url)}">{html.escape(u.url)}</a></li>'
                for u in similar_urls[:4]
            )
            return Response(
                f"""
                <h2>Results:</h2>
                <ul>
                {list_items}
                </ul>
                <h2>
                Text:
                </h2>
                <textarea style="height:100%; width:80%" disabled>
                {html.escape(processed_text or "")}
                </textarea>
                """
            )
8 changes: 8 additions & 0 deletions adserver/api/urls.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""API Urls for the ad server."""
from django.conf import settings
from django.urls import path
from rest_framework import routers

Expand All @@ -14,4 +15,11 @@
# DRF router for the versioned REST API endpoints.
router = routers.SimpleRouter()
router.register(r"advertisers", AdvertiserViewSet, basename="advertisers")
router.register(r"publishers", PublisherViewSet, basename="publishers")

# The analyzer app is optional; only wire up the similarity endpoint when it
# is installed. The view is imported lazily so this URLconf still loads when
# the analyzer (and its heavy ML dependencies) is absent.
if "adserver.analyzer" in settings.INSTALLED_APPS:
    from adserver.analyzer.views import EmbeddingViewSet

    urlpatterns += [path(r"similar/", EmbeddingViewSet.as_view(), name="similar")]


urlpatterns += router.urls
10 changes: 10 additions & 0 deletions config/settings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,15 @@
"simple_history",
"django_slack",
"djstripe",
"corsheaders",
]

MIDDLEWARE = [
"django.middleware.security.SecurityMiddleware",
"enforce_host.EnforceHostMiddleware",
"whitenoise.middleware.WhiteNoiseMiddleware",
"django.contrib.sessions.middleware.SessionMiddleware",
"corsheaders.middleware.CorsMiddleware",
"django.middleware.common.CommonMiddleware",
"django.middleware.csrf.CsrfViewMiddleware",
"django.contrib.auth.middleware.AuthenticationMiddleware",
Expand Down Expand Up @@ -438,6 +440,14 @@
SLACK_FAIL_SILENTLY = env.bool("SLACK_FAIL_SILENTLY", default=True)


# CORS
# https://github.com/adamchainz/django-cors-headers
# --------------------------------------------------------------------------
# Origins allowed to make cross-origin requests; empty by default and
# expected to be set per-environment (see development settings).
CORS_ALLOWED_ORIGINS = env.list("CORS_ALLOWED_ORIGINS", default=[])
# NOTE(review): "*" is emitted as a wildcard Access-Control-Allow-Headers;
# browsers honor it only for requests without credentials — confirm intent.
CORS_ALLOW_HEADERS = ["*"]
# Restrict CORS handling to the similarity API only; every other endpoint
# stays same-origin.
CORS_URLS_REGEX = r"^/api/v1/similar/.*$"


# Metabase
# Graphing and BI tool
# --------------------------------------------------------------------------
Expand Down
2 changes: 2 additions & 0 deletions config/settings/development.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,5 @@
"schedule": crontab(minute="*/5"),
},
}

CORS_ALLOWED_ORIGINS += ["http://localhost:8000", "http://127.0.0.1:8000"]
2 changes: 1 addition & 1 deletion docker-compose/django/start
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ set -o pipefail
set -o nounset

# Reinstall dependencies without rebuilding docker image
# pip install -r /app/requirements/production.txt -r /app/requirements/analyzer.txt
pip install -r /app/requirements/development.txt

# Don't auto-migrate locally because this can cause weird issues when testing migrations
# python manage.py migrate
Expand Down
2 changes: 2 additions & 0 deletions requirements/analyzer.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ networkx<3.0
# Has to be downloaded directly like this (~30MB)
https://github.com/explosion/spacy-models/releases/download/en_core_web_md-3.4.0/en_core_web_md-3.4.0-py3-none-any.whl

# Used to parse web pages and get the "main section" of the page
trafilatura==1.7.0

#######################################################################
# Machine learning production requirements
Expand Down
3 changes: 3 additions & 0 deletions requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,6 @@ PyJWT==2.4.0

# Postgres & Postgres vector support
pgvector==0.2.5

# CORS headers
django-cors-headers==3.8.0

0 comments on commit 84a4111

Please sign in to comment.