In [19]:
import os
from bson import ObjectId

from engine.elastic import ElasticEngine
from engine.markov.utils import TextProcessor, PostgresStorage

In [2]:
# Connection settings, read from the environment with local-development defaults.
# Values coming from the environment are strings; the port default is an int.
ELASTIC_HOST = os.getenv('ELASTIC_HOST', 'localhost')
ELASTIC_PORT = os.getenv('ELASTIC_PORT', 9200)
# Credentials are optional: both stay None when the variables are not set.
ELASTIC_USER = os.getenv('ELASTIC_USER')
ELASTIC_PASS = os.getenv('ELASTIC_PASS')

In [60]:
# Elasticsearch client wrapper used for all index operations below.
es = ElasticEngine.connect(
    host=ELASTIC_HOST,
    port=ELASTIC_PORT,
    user=ELASTIC_USER,
    password=ELASTIC_PASS,
)
# Postgres storage holding the source posts. It reuses the ELASTIC_* host,
# user and password values (the user name doubles as the database name) and
# a fixed port of 15432.
pg = PostgresStorage.connect(
    host=ELASTIC_HOST,
    port=15432,
    user=ELASTIC_USER,
    dbname=ELASTIC_USER,
    password=ELASTIC_PASS,
)

# Habr model

In [6]:
# Select the text of every post that belongs to one of the two hubs in `params`.
sql = '''
    SELECT text 
    FROM posts INNER JOIN habs
    ON posts.post_id = habs.post_id 
    WHERE hab in (%s, %s)'''
params = ['Математика', 'Машинное обучение']

# Lazily pull the first (only) column of each result row.
raw_data = (
    row[0]
    for row in pg.exec_query(query=sql, params=params)
)
# NOTE(review): raw_data is a generator, so train_data can presumably be
# consumed only once - confirm against TextProcessor.get_sentences_gens.
train_data = TextProcessor.get_sentences_gens(raw_data)

In [12]:
# Create a fresh index for the Habr corpus; the ObjectId suffix keeps the name unique.
habr_index_id = f't9-index-{ObjectId()}'
es.add_index(name=habr_index_id)
# Display the generated name so it can be reused later.
habr_index_id

't9-index-6065dbecdc9984e75eb4fb07'

In [14]:
import tqdm

# Index the training sentences one at a time, showing a tqdm progress bar.
for sentence in tqdm.tqdm(train_data):
    es.add_doc(index_name=habr_index_id, text=sentence)

12998it [1:09:08,  3.56it/s]

KeyboardInterrupt: 

In [59]:
import uuid
from typing import List, Union, Optional, Dict, Any, Iterable

from elasticsearch import Elasticsearch, helpers
from elasticsearch.exceptions import ConnectionError


class ElasticEngine:
    """Small wrapper around an ``Elasticsearch`` client for prefix-completion indices.

    Indices created by :meth:`add_index` hold a single ``text`` field that is
    indexed with edge n-grams, so that :meth:`get` can find sentences by the
    beginning of a word.
    """

    # Underlying Elasticsearch client; assigned in __init__.
    es: Elasticsearch
    # Maximum number of actions sent in one bulk request by add_many.
    bulk_actions_count: int = 10_000

    def __init__(self, es: Elasticsearch):
        """Wrap an already constructed ``Elasticsearch`` client."""
        self.es = es

    @classmethod
    def connect(cls, host: str, port: Union[int, str], user: Optional[str], password: Optional[str]):
        """Build a client for ``host:port`` and return an engine wrapping it.

        :param host: Elasticsearch host name.
        :param port: Elasticsearch port (int or str; it is only formatted into the URL).
        :param user: optional user name for HTTP auth.
        :param password: optional password for HTTP auth.
        :raises ConnectionError: if the server does not answer ``ping``.
        """
        # Credentials are embedded in the URL only when both parts are present.
        con_str = f'http://{user}:{password}@{host}:{port}/' if user and password else f'{host}:{port}'
        es = Elasticsearch(con_str)
        if not es.ping():
            raise ConnectionError('ping failed')
        return cls(es)

    def add_index(self, name: str, number_of_shards: int = 1, number_of_replicas: int = 2) -> None:
        """Create index ``name`` with the ``t9_analyzer`` mapping for the ``text`` field.

        At index time ``text`` is tokenized with the standard tokenizer,
        lowercased and expanded into edge n-grams of 2 to 10 characters; at
        search time the ``standard`` analyzer is used.

        :param name: name of the index to create.
        :param number_of_shards: value for the ``number_of_shards`` index setting.
        :param number_of_replicas: value for the ``number_of_replicas`` index setting.
        """
        self.es.indices.create(index=name, body={
            "settings": {
                "index": {
                    "number_of_shards": number_of_shards,
                    "number_of_replicas": number_of_replicas,
                    "analysis": {
                        "analyzer": {
                            "t9_analyzer": {
                                "type": "custom",
                                "tokenizer": "standard",
                                "filter": [
                                    "lowercase",
                                    "custom_edge_ngram"
                                ]
                            }
                        },
                        "filter": {
                            "custom_edge_ngram": {
                                "type": "edge_ngram",
                                "min_gram": 2,
                                "max_gram": 10
                            }
                        }
                    }
                }
            },
            "mappings": {
                "properties": {
                    "text": {
                        "type": "text",
                        "analyzer": "t9_analyzer",
                        "search_analyzer": "standard"
                    }
                }
            }
        })

    def add_doc(self, index_name: str, text: str) -> None:
        """Index a single document ``{'text': text}`` into ``index_name``."""
        self.es.index(index=index_name, body={
            'text': text
        })

    @staticmethod
    def _create_action(index_name: str, sentence: str) -> Dict[str, Any]:
        """Build one bulk ``create`` action for ``sentence``."""
        return {
            '_op_type': 'create',
            '_index': index_name,
            '_id': str(uuid.uuid4()),
            # The document body of a create action belongs under `_source`;
            # a `doc` key is only meaningful for update actions and would be
            # stored as a nested field instead of populating `text`.
            '_source': {
                'text': sentence
            }
        }

    def add_many(self, index_name: str, sentences: Iterable) -> None:
        """Bulk-index every sentence from ``sentences`` into ``index_name``.

        Sentences are sent in batches of at most ``bulk_actions_count``
        actions, with a tqdm progress bar over the input. Every sentence is
        indexed, including the one that fills a batch and the final partial
        batch.

        :param index_name: target index.
        :param sentences: iterable of sentence strings; may be a one-shot iterator.
        """
        batch = []
        for sentence in tqdm.tqdm(sentences):
            # Append first, then flush, so the sentence that fills the batch is not lost.
            batch.append(self._create_action(index_name, sentence))
            if len(batch) >= self.bulk_actions_count:
                helpers.bulk(self.es, actions=batch, stats_only=True)
                batch = []
        # Flush whatever is left after the input is exhausted.
        if batch:
            helpers.bulk(self.es, actions=batch, stats_only=True)

    def delete_index(self, index_name: str) -> None:
        """Delete the index ``index_name``."""
        self.es.indices.delete(index=index_name)

    def get_indices_stats(self, index_name: str) -> Dict[str, Any]:
        """Return the ``indices`` section of the stats response for ``index_name``.

        An empty dict is returned when the response has no ``indices`` key.
        """
        return self.es.indices.stats(index=index_name, human=True).get("indices", {})

    @staticmethod
    def _trim_to_phrase(sentence: str, phrase: str) -> str:
        """Return ``sentence`` starting at the first occurrence of ``phrase``.

        The exact phrase is tried first, then a case-insensitive match (the
        index lowercases tokens, so hits may differ from ``phrase`` in case).
        If the phrase cannot be located the whole sentence is returned.
        """
        start = sentence.find(phrase)
        if start < 0:
            lowered = sentence.lower()
            # Only reuse the offset when lowercasing kept character positions aligned.
            if len(lowered) == len(sentence):
                start = lowered.find(phrase.lower())
        # A bare sentence[-1:] would silently return just the last character.
        return sentence[start:] if start >= 0 else sentence

    def get(self, index_name: str, phrase: str, count: int = 10) -> List[str]:
        """Search ``index_name`` for ``phrase`` and return the matching continuations.

        Runs a ``match_phrase`` query on the ``text`` field limited to
        ``count`` hits and returns each hit's text from the position where
        the phrase starts.

        :param index_name: index to search.
        :param phrase: phrase (or word prefix) to look for.
        :param count: maximum number of hits requested.
        :return: list of sentence tails, one per hit.
        """
        response = self.es.search(
            index=index_name,
            body={
                "query": {
                    "match_phrase": {
                        "text": phrase
                    }
                },
                "size": count
            })
        return [
            self._trim_to_phrase(hit['_source']['text'], phrase)
            for hit in response['hits']['hits']
        ]


In [28]:
# Index the sentences from train_data into the Habr index via add_many.
es.add_many(index_name=habr_index_id, sentences=train_data)


0it [00:00, ?it/s][A
1001it [00:02, 447.51it/s][A
2002it [00:04, 461.00it/s][A
3003it [00:06, 448.31it/s][A
4004it [00:08, 468.45it/s][A
5005it [00:10, 476.78it/s][A
6006it [00:13, 449.28it/s][A
7007it [00:16, 378.76it/s][A
8008it [00:19, 391.48it/s][A
9009it [00:21, 400.49it/s][A
10010it [00:23, 410.84it/s][A
11011it [00:26, 406.66it/s][A
12012it [00:28, 420.54it/s][A
13013it [00:31, 373.28it/s][A
14014it [00:34, 383.44it/s][A
15015it [00:36, 411.35it/s][A
16016it [00:38, 430.19it/s][A
17017it [00:40, 460.82it/s][A
18018it [00:41, 495.77it/s][A
19019it [00:43, 513.13it/s][A
20020it [00:45, 519.54it/s][A
21021it [00:47, 510.55it/s][A
22022it [00:49, 474.88it/s][A
23023it [00:54, 372.73it/s][A
24024it [00:56, 380.99it/s][A
25025it [00:58, 412.79it/s][A
26026it [01:00, 431.35it/s][A
27027it [01:03, 416.44it/s][A
28028it [01:05, 403.01it/s][A
29029it [01:08, 394.03it/s][A
30002it [01:08, 553.29it/s][A
30322it [01:10, 330.47it/s][A
31031it [01:12, 326.85it/

255255it [12:20, 392.78it/s][A
256256it [12:23, 406.67it/s][A
257257it [12:26, 381.13it/s][A
258258it [12:27, 428.81it/s][A
259259it [12:30, 420.62it/s][A
260260it [12:32, 449.75it/s][A
261261it [12:34, 458.10it/s][A
262262it [12:36, 478.34it/s][A
263263it [12:38, 482.39it/s][A
264264it [12:41, 392.79it/s][A
265265it [12:43, 420.16it/s][A
266266it [12:46, 397.34it/s][A
267267it [12:47, 464.97it/s][A
268268it [12:49, 491.07it/s][A
269269it [12:51, 532.09it/s][A
270270it [12:53, 543.37it/s][A
271271it [12:54, 536.89it/s][A
272272it [12:56, 584.76it/s][A
273273it [12:57, 605.03it/s][A
274274it [12:59, 638.09it/s][A
275275it [13:01, 594.78it/s][A
276276it [13:02, 600.35it/s][A
277277it [13:04, 635.52it/s][A
278278it [13:06, 567.97it/s][A
279279it [13:07, 613.83it/s][A
280280it [13:09, 591.82it/s][A
281281it [13:10, 626.42it/s][A
282282it [13:12, 654.27it/s][A
283283it [13:14, 612.66it/s][A
284284it [13:15, 645.25it/s][A
285285it [13:18, 511.59it/s][A
286286it

In [58]:
# Query the Habr index for the word prefix 'рекуррентн', requesting count=5 results.
es.get(index_name=habr_index_id, phrase='рекуррентн', count=5)

['рекуррентных сетях такая матрица или матрицы одна и та же для каждого слоя так как слой рекуррентный и свойства сети зависят от входного сигнала',
 'рекуррентные нейронные сети',
 'рекуррентного  кодировщика',
 'рекуррентные  сети в частности ',
 'рекуррентной нейронной сети типа     или  ']

# Wiki model

In [61]:
# Create a separate index for the Wiki corpus, named with a fresh ObjectId suffix.
wiki_index_id = f't9-index-{ObjectId()}'
es.add_index(name=wiki_index_id)
# Display the generated name so it can be reused later.
wiki_index_id

't9-index-6065f39cdc9984e75eb4fb08'