In [1]:
import lilac as ll

ll.set_project_dir('data')

# slimorca = ll.get_dataset('local', 'slimorca')
# slimorca.compute_embedding('gte-small')
# slimorca.select_rows()


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
ds = ll.get_dataset('local', 'openorca-10k')


In [3]:
rows = list(ds.select_rows(columns=['question']))


In [8]:
texts = [row['question'] for row in rows[:1000]]


In [5]:
from sentence_transformers import SentenceTransformer


In [6]:
import time
import functools
from lilac.splitters.spacy_splitter import clustering_spacy_chunker
from lilac.embeddings.embedding import compute_split_embeddings
import tqdm


def time_it(model, texts, batch_size, megabatch_size):
  start = time.time()
  embed_fn = functools.partial(model.encode, batch_size=batch_size)
  split_fn = clustering_spacy_chunker
  list(
    tqdm.tqdm(
      compute_split_embeddings(texts, megabatch_size, embed_fn=embed_fn, split_fn=split_fn)))
  end = time.time()
  return end - start


In [9]:
data_rows = []
for model_name in ('thenlper/gte-small', 'thenlper/gte-base', 'all-MiniLM-L6-v2'):
  model = SentenceTransformer(model_name, device='mps')
  for batch_size in (32, 64, 128):
    for megabatch_size in (256, 512, 1024, 2048):
      data_rows.append({
        'model': model_name,
        'batch_size': batch_size,
        'megabatch_size': megabatch_size,
        'time': time_it(model, texts, batch_size, megabatch_size)
      })


1000it [00:09, 101.07it/s]
1000it [00:10, 98.47it/s]
1000it [00:09, 104.17it/s]
1000it [00:09, 104.78it/s]
1000it [00:14, 68.16it/s]
1000it [00:11, 84.71it/s]
1000it [00:11, 89.57it/s]
1000it [00:10, 94.28it/s]
1000it [00:16, 59.55it/s]
1000it [00:14, 71.09it/s]
1000it [00:14, 67.77it/s]
1000it [00:13, 74.63it/s]
Downloading (…)a8668/.gitattributes: 100%|██████████| 1.52k/1.52k [00:00<00:00, 5.95MB/s]
Downloading (…)_Pooling/config.json: 100%|██████████| 190/190 [00:00<00:00, 1.25MB/s]
Downloading (…)10cbba8668/README.md: 100%|██████████| 68.1k/68.1k [00:00<00:00, 78.9MB/s]
Downloading (…)cbba8668/config.json: 100%|██████████| 618/618 [00:00<00:00, 4.69MB/s]
Downloading model.safetensors: 100%|██████████| 219M/219M [00:22<00:00, 9.90MB/s] 
Downloading (…)668/onnx/config.json: 100%|██████████| 630/630 [00:00<00:00, 3.71MB/s]
Downloading model.onnx: 100%|██████████| 436M/436M [00:42<00:00, 10.3MB/s] 
Downloading (…)cial_tokens_map.json: 100%|██████████| 125/125 [00:00<00:00, 438kB/s]
Dow

In [12]:
import pandas as pd

pd.DataFrame(data_rows).pivot(
  columns='batch_size', index=['model', 'megabatch_size'], values='time').round(2)


Unnamed: 0_level_0,batch_size,32,64,128
model,megabatch_size,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
all-MiniLM-L6-v2,256,6.14,6.17,7.26
all-MiniLM-L6-v2,512,5.5,5.87,6.62
all-MiniLM-L6-v2,1024,5.43,5.8,7.0
all-MiniLM-L6-v2,2048,5.38,5.55,6.41
thenlper/gte-base,256,28.12,27.41,36.32
thenlper/gte-base,512,22.15,24.94,31.88
thenlper/gte-base,1024,21.56,24.7,33.0
thenlper/gte-base,2048,21.29,23.96,31.09
thenlper/gte-small,256,9.9,14.67,16.79
thenlper/gte-small,512,10.16,11.81,14.07
