In [44]:
import os
import random
import time

import nltk
import torch
import weaviate
from nltk.tokenize import sent_tokenize
from transformers import AutoModel, AutoTokenizer

import json
from pydantic import BaseModel

from typing import Optional

In [39]:
torch.set_grad_enabled(False)

WEAVIATE_URL = 'http://localhost'
WEAVIATE_PORT = '8123'
MODEL_NAME = "distilbert-base-uncased"

wine_filename = '../../data/winemag-data-130k-v2.json'

In [41]:
with open(wine_filename, 'r') as f:
    data = json.loads(f.read())


In [43]:
data[0]

{'points': '87',
 'title': 'Nicosia 2013 Vulkà Bianco  (Etna)',
 'description': "Aromas include tropical fruit, broom, brimstone and dried herb. The palate isn't overly expressive, offering unripened apple, citrus and dried sage alongside brisk acidity.",
 'taster_name': 'Kerin O’Keefe',
 'taster_twitter_handle': '@kerinokeefe',
 'price': None,
 'designation': 'Vulkà Bianco',
 'variety': 'White Blend',
 'region_1': 'Etna',
 'region_2': None,
 'province': 'Sicily & Sardinia',
 'country': 'Italy',
 'winery': 'Nicosia'}

In [56]:
class WineItem(BaseModel):
    """
    A WineItem object contains only the relevant 
    fields for our project from the wine dataset.

    It inherits from the Pydantic BaseModel to 
    allow for schema validation.
    """

    title: str
    description: str
    variety: str
    region_1: str
    region_2: Optional[str] = None
    country: str

    @property
    def region(self):
        region = f'{self.region_1} {self.region_2}' if self.region_2 is not None else self.region_1
        return region

In [57]:
wine_item = WineItem(**data[0])
wine_item.region

'Etna'

In [59]:
data[0]

{'points': '87',
 'title': 'Nicosia 2013 Vulkà Bianco  (Etna)',
 'description': "Aromas include tropical fruit, broom, brimstone and dried herb. The palate isn't overly expressive, offering unripened apple, citrus and dried sage alongside brisk acidity.",
 'taster_name': 'Kerin O’Keefe',
 'taster_twitter_handle': '@kerinokeefe',
 'price': None,
 'designation': 'Vulkà Bianco',
 'variety': 'White Blend',
 'region_1': 'Etna',
 'region_2': None,
 'province': 'Sicily & Sardinia',
 'country': 'Italy',
 'winery': 'Nicosia'}

In [3]:
model = AutoModel.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# initialize nltk (for tokenizing sentences)
nltk.download("punkt")

# initialize weaviate clien{t for importing and searching
client = weaviate.Client(f'{WEAVIATE_URL}:{WEAVIATE_PORT}')

Downloading (…)lve/main/config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


Downloading model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\altoz\AppData\Roaming\nltk_data...
[nltk_data]   Unzipping tokenizers\punkt.zip.
            changes and features of Weaviate >=1.14.x, but you are connected to Weaviate 1.1.0.
            If you want to make use of these new changes/features using this Python Client version, upgrade your
            Weaviate instance.
            Please consider upgrading to the latest version. See https://www.weaviate.io/developers/weaviate for details.


In [21]:
def get_post_filenames(limit_objects=100, posts_directory: str = POSTS_DIRECTORY):
    file_names = []
    i = 0
    for root, dirs, files in os.walk(posts_directory):
        for filename in files:
            path = os.path.join(root, filename)
            file_names += [path]

    random.shuffle(file_names)
    limit_objects = min(len(file_names), limit_objects)

    file_names = file_names[:limit_objects]

    return file_names


def read_posts(filenames=[]):
    posts = []
    for filename in filenames:
        f = open(filename, encoding="utf-8", errors="ignore")
        post = f.read()

        # strip the headers (the first occurrence of two newlines)
        post = post[post.find("\n\n") :]

        # remove posts with less than 10 words to remove some of the noise
        if len(post.split(" ")) < 10:
            continue

        post = post.replace("\n", " ").replace("\t", " ").strip()
        if len(post) > 1000:
            post = post[:1000]
        posts += [post]

    return posts


def text2vec(text):
    tokens_pt = tokenizer(
        text,
        padding=True,
        truncation=True,
        max_length=500,
        add_special_tokens=True,
        return_tensors="pt",
    )
    outputs = model(**tokens_pt)
    return outputs[0].mean(0).mean(0).detach()


def vectorize_posts(posts=[]):
    post_vectors = []
    before = time.time()
    for i, post in enumerate(posts):
        vec = text2vec(sent_tokenize(post))
        post_vectors += [vec]
        if i % 100 == 0 and i != 0:
            print("So far {} objects vectorized in {}s".format(i, time.time() - before))
    after = time.time()

    print("Vectorized {} items in {}s".format(len(posts), after - before))

    return post_vectors


def init_weaviate_schema():
    # a simple schema containing just a single class for our posts
    schema = {
        "classes": [
            {
                "class": "Post",
                "vectorizer": "none",  # explicitly tell Weaviate not to vectorize anything, we are providing the vectors ourselves through our BERT model
                "properties": [
                    {
                        "name": "content",
                        "dataType": ["text"],
                    }
                ],
            }
        ]
    }

    # cleanup from previous runs
    client.schema.delete_all()

    client.schema.create(schema)


def import_posts_with_vectors(posts, vectors, batchsize=256):
    batch = weaviate.ObjectsBatchRequest()

    for i, post in enumerate(posts):
        props = {
            "content": post,
        }
        batch.add(props, "Post", vector=vectors[i])

        # when either batch size is reached or we are at the last object
        if (i != 0 and i % batchsize == 0) or i == len(posts) - 1:
            # send off the batch
            client.batch.create(batch)

            # and reset for the next batch
            batch = weaviate.ObjectsBatchRequest()


def search(query="", limit=3):
    before = time.time()
    vec = text2vec(query)
    vec_took = time.time() - before

    before = time.time()
    near_vec = {"vector": vec.tolist()}
    res = (
        client.query.get("Post", ["content", "_additional {certainty}"])
        .with_near_vector(near_vec)
        .with_limit(limit)
        .do()
    )
    search_took = time.time() - before

    print(
        '\nQuery "{}" with {} results took {:.3f}s ({:.3f}s to vectorize and {:.3f}s to search)'.format(
            query, limit, vec_took + search_took, vec_took, search_took
        )
    )
    for post in res["data"]["Get"]["Post"]:
        print("{:.4f}: {}".format(post["_additional"]["certainty"], post["content"]))
        print("---")

In [22]:
init_weaviate_schema()

In [23]:
posts = read_posts(get_post_filenames(4000))

  f = open(filename, encoding="utf-8", errors="ignore")
  f = open(filename, encoding="utf-8", errors="ignore")
  f = open(filename, encoding="utf-8", errors="ignore")
  f = open(filename, encoding="utf-8", errors="ignore")
  f = open(filename, encoding="utf-8", errors="ignore")
  f = open(filename, encoding="utf-8", errors="ignore")
  f = open(filename, encoding="utf-8", errors="ignore")
  f = open(filename, encoding="utf-8", errors="ignore")
  f = open(filename, encoding="utf-8", errors="ignore")
  f = open(filename, encoding="utf-8", errors="ignore")
  f = open(filename, encoding="utf-8", errors="ignore")
  f = open(filename, encoding="utf-8", errors="ignore")
  f = open(filename, encoding="utf-8", errors="ignore")
  f = open(filename, encoding="utf-8", errors="ignore")
  f = open(filename, encoding="utf-8", errors="ignore")
  f = open(filename, encoding="utf-8", errors="ignore")
  f = open(filename, encoding="utf-8", errors="ignore")
  f = open(filename, encoding="utf-8", errors="i

In [37]:
vectors = vectorize_posts(posts[:500])

So far 100 objects vectorized in 10.208102226257324s
So far 200 objects vectorized in 19.52180314064026s
So far 300 objects vectorized in 29.140095472335815s
So far 400 objects vectorized in 39.7047061920166s
Vectorized 500 items in 49.48389959335327s


In [38]:
import_posts_with_vectors(posts[:500], vectors)

AttributeError: module 'weaviate' has no attribute 'ObjectsBatchRequest'

NameError: name 'vectors' is not defined

In [None]:
search("the best camera lens", 1)
search("which software do i need to view jpeg files", 1)
search("windows vs mac", 1)
