In [16]:
# !pip3 install transformers
# !pip3 install tensorflow
# !pip3 install torch
# !pip3 install weaviate-client
# !pip3 install nltk

## Import the BERT transformer model and pytorch

We are using the `distilbert-base-uncased` model in this example, but any BERT-style model that `AutoModel` can load should work. Change `MODEL_NAME` in the cell below to use a different one.

In [5]:
import torch
from transformers import AutoModel, AutoTokenizer
from nltk.tokenize import sent_tokenize



# Turn autograd off for the whole notebook: the model is only used for
# forward passes here, so no gradients need to be tracked.
torch.set_grad_enabled(False)

# Name of the pretrained checkpoint to load; change this to use a different model.
MODEL_NAME = "distilbert-base-uncased"

# Load the model and its matching tokenizer from the same checkpoint name so
# that the token ids produced by the tokenizer match what the model expects.
model = AutoModel.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


## Initialize Weaviate Client
This assumes you have Weaviate running locally on `:8080`. Adjust URL accordingly. You could also enter the WCS URL here, for example, if you are running a WCS cloud instance instead of running Weaviate locally.

In [6]:
import weaviate

# Client for the Weaviate instance on localhost:8080. The import helper further
# down posts to the same host, so keep the two URLs in sync when changing this.
client = weaviate.Client("http://localhost:8080")

## Load dataset from disk
Create some helper functions to load the dataset (20-newsgroup text posts) from disk. These functions are specific to the structure of this dataset; adjust them to match your own.

In [73]:
import os
import random

def get_post_filenames(limit_objects=100, data_dir="./data/20news-bydate-test"):
    """Return a random selection of file paths found under ``data_dir``.

    The directory tree is walked recursively, every file path is collected,
    the list is shuffled with ``random.shuffle`` and the first
    ``limit_objects`` entries are returned.

    Args:
        limit_objects: Maximum number of paths to return. Values larger than
            the number of files return all files; zero or negative values
            return an empty list.
        data_dir: Root directory to scan. Defaults to the 20-newsgroups test
            split used by this notebook.

    Returns:
        A list of at most ``limit_objects`` file paths in random order. The
        list is empty when ``data_dir`` does not exist or contains no files.
    """
    file_names = [
        os.path.join(root, filename)
        for root, _dirs, files in os.walk(data_dir)
        for filename in files
    ]

    random.shuffle(file_names)

    # Clamp to [0, len(file_names)]. Without the lower bound a negative limit
    # would slice from the end and silently return "all but the last n".
    limit_objects = max(0, min(len(file_names), limit_objects))

    return file_names[:limit_objects]

def read_posts(filenames=[]):
    """Read newsgroup posts from disk and return their cleaned-up bodies.

    For every file the header section (everything before the first blank
    line) is dropped, newlines and tabs are replaced by spaces, and the text
    is cut to at most 1000 characters.

    Args:
        filenames: Iterable of file paths to read. The default list is never
            mutated.

    Returns:
        A list of post bodies, in the order of ``filenames``. Files that have
        no blank line separating headers from body, or whose body has fewer
        than 10 space-separated chunks, are skipped, so the result can be
        shorter than ``filenames``.
    """
    posts = []
    for filename in filenames:
        # The context manager closes the handle even if read() raises, so no
        # file descriptor is left open per post.
        with open(filename, encoding="utf-8", errors='ignore') as f:
            post = f.read()

        # The headers end at the first occurrence of two newlines.
        header_end = post.find('\n\n')
        if header_end == -1:
            # No header/body separator: there is no body to keep. (Slicing
            # with -1 would leave only the last character, which the length
            # filter below would reject anyway.)
            continue
        post = post[header_end:]

        # Remove posts with less than 10 words to remove some of the noise.
        # The count is taken on single spaces before newlines are replaced.
        if len(post.split(' ')) < 10:
            continue

        post = post.replace('\n', ' ').replace('\t', ' ')
        # Slicing is a no-op for posts that are already short enough.
        posts.append(post[:1000])

    return posts


## Vectorize Dataset using BERT

The following helper function vectorizes all posts (using our BERT transformer). It takes the posts as a list and returns a list containing one vector per post, in the same order. BERT is optimized to run on GPUs; if you're using CPUs this might take a while.

In [74]:
import time

def vectorize_posts(posts=[]):
    """Turn each post into a single vector using the notebook's model.

    Each post is split into sentences with ``sent_tokenize``; the sentences
    are tokenized together as one padded batch and passed through ``model``.
    The first model output is averaged over its first dimension twice to
    produce one vector per post.

    Progress is printed every 25 posts, and the total time at the end.

    Args:
        posts: List of post texts. The default list is never mutated.

    Returns:
        A list with one detached tensor per post, in the same order as
        ``posts``.
    """
    print("Vectorize your posts with BERT. If you are using CPUs this might take a while...")
    started_at = time.time()
    embeddings = []

    for idx, text in enumerate(posts):
        encoded = tokenizer(
            sent_tokenize(text),
            padding=True,
            truncation=True,
            max_length=500,
            add_special_tokens=True,
            return_tensors="pt",
        )
        # NOTE(review): presumably output [0] is (sentences, tokens, hidden),
        # so the two means average over sentences and then tokens - confirm.
        # If so, padded positions are included in the average as well.
        hidden_states = model(**encoded)[0]
        embeddings.append(hidden_states.mean(0).mean(0).detach())

        if idx != 0 and idx % 25 == 0:
            print("So far {} objects vectorized in {}s".format(idx, time.time() - started_at))

    finished_at = time.time()
    print("Vectorized {} items in {}s".format(len(posts), finished_at - started_at))

    return embeddings

### Run everything we have so far

It is now time to run the functions we defined before. Let's load up to 4000 random posts from disk, then vectorize them using BERT.

In [None]:
# Pick up to 4000 random files and read them; read_posts skips very short
# posts, so len(posts) can be smaller than the number of files picked.
posts = read_posts(get_post_filenames(4000))
# One vector per post, in the same order as `posts`.
vectors = vectorize_posts(posts)
print(len(vectors))

Vectorize your posts with BERT. If you are using CPUs this might take a while...
So far 25 objects vectorized in 17.35766339302063s
So far 50 objects vectorized in 30.970993280410767s
So far 75 objects vectorized in 48.50780916213989s
So far 100 objects vectorized in 58.79331040382385s
So far 125 objects vectorized in 74.6122612953186s
So far 150 objects vectorized in 89.20110416412354s
So far 175 objects vectorized in 101.08796119689941s
So far 200 objects vectorized in 112.2345552444458s
So far 225 objects vectorized in 128.05660319328308s
So far 250 objects vectorized in 144.34292316436768s
So far 275 objects vectorized in 159.30127835273743s
So far 300 objects vectorized in 175.7815442085266s
So far 325 objects vectorized in 190.5825710296631s
So far 350 objects vectorized in 206.8574640750885s
So far 375 objects vectorized in 220.1920862197876s
So far 400 objects vectorized in 233.265465259552s
So far 425 objects vectorized in 246.8308732509613s
So far 450 objects vectorized in 26

## Initialize Weaviate

Now that we have vectors we can import both the posts and the vectors into Weaviate, so we can then search through them.

### Init a simple schema
Our schema is very simple, we just have one object class, the "Post". A post class has just a single property, which we call "content" and is of type "text".

Each class in the schema creates one index, so by running the cell below we tell Weaviate to create one brand new vector index, ready for us to import data.

In [76]:
def init_weaviate_schema(client):
    """Reset the schema and create the single "Post" class.

    Calls ``client.schema.delete_all()`` first, then creates a schema with
    one class, "Post", that has one property, "content", of type "text".

    Args:
        client: Object exposing ``schema.delete_all()`` and
            ``schema.create(schema)``.
    """
    content_property = {
        "name": "content",
        "dataType": ["text"],
    }

    post_class = {
        "class": "Post",
        # Explicitly tell Weaviate not to vectorize anything; we provide the
        # vectors ourselves through our BERT model.
        "vectorizer": "none",
        "properties": [content_property],
    }

    # Cleanup from previous runs.
    client.schema.delete_all()

    client.schema.create({"classes": [post_class]})

In [77]:
# Destructive: this calls client.schema.delete_all() before creating the
# "Post" class, so do not run it against an instance whose schema you need.
init_weaviate_schema(client)

In [78]:
## doing this manually until the client is updated
import requests

def import_posts_with_vectors(posts, vectors, objects_url='http://localhost:8080/v1/objects'):
    """Create one "Post" object per post, each with its own vector.

    Every post is sent as a separate HTTP POST request. A request answered
    with a status code of 400 or above is reported with ``print`` and the
    import continues with the next post.

    Args:
        posts: List of post texts.
        vectors: List of vectors, one per post and in the same order. Each
            item must provide ``tolist()``.
        objects_url: Endpoint the objects are posted to. Defaults to the
            local Weaviate instance used throughout this notebook.

    Raises:
        Exception: If ``posts`` and ``vectors`` differ in length.
    """
    if len(posts) != len(vectors):
        raise Exception("len of posts ({}) and vectors ({}) does not match".format(len(posts), len(vectors)))

    for post, vector in zip(posts, vectors):
        r = requests.post(objects_url, json={
            "class": "Post",
            "vector": vector.tolist(),
            "properties": {
                "content": post,
            }
        })

        if r.status_code > 399:
            # Report the failed response itself. Printing `res` here would
            # raise NameError, because that name is not defined in this
            # function or anywhere earlier in the notebook.
            print("Import failed with HTTP {}: {}".format(r.status_code, r.text))
        
    
# Send every post together with its vector to Weaviate, one request per post.
import_posts_with_vectors(posts, vectors)

In [93]:
# Free-text search query; it is embedded with the same model as the posts.
query = "camera lenses"
# Maximum number of results to request from Weaviate.
limit = 3

def query_to_vector(query=""):
    """Embed a query string with the notebook's tokenizer and model.

    The query is tokenized with the same settings used for the posts, passed
    through ``model``, and the first model output is averaged over its first
    dimension twice.

    Args:
        query: Text to embed.

    Returns:
        The detached tensor that results from the two averaging steps.
    """
    encoded = tokenizer(
        query,
        padding=True,
        truncation=True,
        max_length=500,
        add_special_tokens=True,
        return_tensors="pt",
    )
    hidden_states = model(**encoded)[0]
    pooled = hidden_states.mean(0).mean(0)
    return pooled.detach()

# Weaviate expects the search vector as a plain list under the "vector" key.
search_vec = {"vector": query_to_vector(query).tolist()}

# Time only the round trip to Weaviate, not the embedding of the query.
before = time.time()
res = (
    client.query
    .get("Post", ["content", "_additional {certainty}"])
    .with_near_vector(search_vec)
    .with_limit(limit)
    .do()
)
after = time.time()

print("Query \"{}\" with {} results took {}s".format(query, limit, after-before))
# Each hit carries the requested "content" plus its certainty under "_additional".
for post in res["data"]["Get"]["Post"]:
    print('---')
    print("{}: {}".format(post["_additional"]["certainty"], post["content"]))

Query "camera lenses" with 3 results took 0.014058113098144531s
---
0.8743142:    Nikon L35 Af camera. 35/2.8 lens and camera case. Package $50  Send e-mail 
---
0.7892951:        I have a few 12" composite monochrome monitors for sale.  Magnovax Computer Monitor 80, Model number BM7650 074B.  RCA type input for video only. (no audio).  Power, Brightness and Contrast dials in front, V and H hold and position controls on the back. Nice little monitor that can be used for  PCs, Amigas, your VCR, security monitor.  Excellent condition. I am asking for $40 plus shipping and COD (not to exceed $10) if applicable.      
---
0.7800093:   my 14" compacq vga monitor id dead due to the transformer's failure.  if you have this part and would like to get rid of it, pls let me know.  thanks.  eric 
