# Imports

In [23]:
from docarray import DocumentArray
from docarray import Document
from transformers import AutoModel, AutoTokenizer
from pprint import pprint


# Setup

Store a few sample documents in a Weaviate-backed DocumentArray:

In [2]:
# Open a DocumentArray whose storage backend is Weaviate, reached at host 'weaviate'.
da = DocumentArray(storage='weaviate', config={'host': 'weaviate'})

# Add one Document per sample sentence; the writes are performed inside the
# DocumentArray's context manager.
with da:
    da.extend(
        [
            Document(text=sentence)
            for sentence in (
                'Persist Documents with Weaviate.',
                'And enjoy fast nearest neighbor search.',
                'All while using DocArray API.',
            )
        ]
    )


Create embeddings for the imported documents:

In [3]:
# Single checkpoint name shared by the tokenizer and the model, so the two
# cannot drift apart when the checkpoint is changed.
MODEL_NAME = 'bert-base-uncased'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)


def collate_fn(da):
    """Tokenize the texts of the given DocumentArray for the model.

    Args:
        da: object exposing a ``texts`` attribute, which is handed to the
            module-level ``tokenizer``.

    Returns:
        Whatever ``tokenizer`` returns when asked for PyTorch tensors ('pt')
        with truncation and padding enabled.
    """
    tokenizer_options = {'return_tensors': 'pt', 'truncation': True, 'padding': True}
    return tokenizer(da.texts, **tokenizer_options)


# Compute embeddings for the Documents in `da` with the BERT model loaded above;
# collate_fn tokenizes the texts it is handed. The return value is not assigned:
# later cells query `da` directly.
da.embed(model, collate_fn=collate_fn)


Downloading: 100%|██████████| 28.0/28.0 [00:00<00:00, 15.7kB/s]
Downloading: 100%|██████████| 570/570 [00:00<00:00, 445kB/s]
Downloading: 100%|██████████| 232k/232k [00:00<00:00, 27.0MB/s]
Downloading: 100%|██████████| 466k/466k [00:00<00:00, 30.2MB/s]
Downloading: 100%|██████████| 440M/440M [00:14<00:00, 30.8MB/s] 
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from 

# Semantic Search

In [5]:
# Embed the query sentence with the same model and collate function that were
# used for the stored Documents.
query_docs = DocumentArray([Document(text='How to persist Documents')])
embedded_query = query_docs.embed(model, collate_fn=collate_fn)

# Retrieve at most one match for the query (limit=1).
results = da.find(embedded_query, limit=1)

# Show the texts of the matches returned for the first query.
results[0].texts


  batch_inputs[k] = torch.tensor(v, device=device)


['Persist Documents with Weaviate.']

# Query by Conditions

This section is based on the DocArray documentation example of `find` with a filter only: https://docarray.jina.ai/advanced/document-store/weaviate/#example-of-find-with-a-filter-only

In [12]:
n_dim = 3

# Weaviate-backed store whose config declares `price` as a float column.
store_config = {"n_dim": n_dim, "columns": {"price": "float"}, "host": "weaviate"}
da = DocumentArray(storage="weaviate", config=store_config)

# Ten Documents with ids r0..r9 whose `price` tag equals their index.
priced_docs = [Document(id=f"r{i}", tags={"price": i}) for i in range(10)]
with da:
    da.extend(priced_docs)

# List the price tag of every stored Document, then print the store summary.
print("\nIndexed Prices:\n")
for indexed_price in da[:, "tags__price"]:
    print(f"\t price={indexed_price}")

da.summary()



Indexed Prices:

	 price=0
	 price=1
	 price=2
	 price=3
	 price=4
	 price=5
	 price=6
	 price=7
	 price=8
	 price=9


In [13]:
max_price = 3
# NOTE(review): n_limit is not used by this filter-only query; it is kept
# unchanged in case another cell reads it.
n_limit = 4

# Named `price_filter` rather than `filter` so the Python builtin `filter`
# is not shadowed for the rest of the kernel session.
price_filter = {"path": "price", "operator": "LessThanEqual", "valueNumber": max_price}
results = da.find(filter=price_filter)

# The header interpolates max_price so it cannot go stale when the threshold changes.
print(f'\n Returned examples that verify filter "price at most {max_price}":\n')
for price in results[:, "tags__price"]:
    print(f"\t price={price}")



 Returned examples that verify filter "price at most 3":

	 price=0
	 price=1
	 price=2
	 price=3


In [28]:
# Build a small stationery catalog in a Weaviate-backed DocumentArray.
# Each row is (text, weight, tags); the config declares `modality` as a
# string column.
da = DocumentArray(
    [
        Document(text=item_text, weight=item_weight, tags=item_tags)
        for item_text, item_weight, item_tags in (
            ("journal", 25, {"h": 14, "w": 21, "uom": "cm", "modality": "A"}),
            ("notebook", 50, {"h": 8.5, "w": 11, "uom": "in", "modality": "A"}),
            ("paper", 100, {"h": 8.5, "w": 11, "uom": "in", "modality": "D"}),
            ("planner", 75, {"h": 22.85, "w": 30, "uom": "cm", "modality": "D"}),
            ("postcard", 45, {"h": 10, "w": 15.25, "uom": "cm", "modality": "A"}),
        )
    ],
    storage="weaviate",
    config={"host": "weaviate", "columns": {"modality": "str"}},
)

da.summary()


In [31]:
# Named `modality_filter` rather than `filter` so the Python builtin `filter`
# is not shadowed for the rest of the kernel session.
modality_filter = {"path": ["modality"], "operator": "Equal", "valueString": "A"}

modality_a_docs = da.find(filter=modality_filter)

# pprint is used only to make the nested dict output readable.
pprint(modality_a_docs.to_dict(exclude_none=True))
