In [13]:
%%capture
!pip install wget annoy
!pip install -U sentence-transformers

In [7]:
import os
import tarfile
import tempfile
from multiprocessing import Pool

import numpy as np
import pandas as pd
import psycopg2
import scipy.spatial
import torch
from annoy import AnnoyIndex
from sentence_transformers import SentenceTransformer
from sqlalchemy import create_engine
from transformers import cached_path

In [8]:
# Database connection string. Set the DATABASE_URL environment variable to
# point at another database (or to avoid keeping credentials in the notebook);
# otherwise the previous local default is used.
db_string = os.environ.get(
    "DATABASE_URL", "postgresql://postgres:postgres@postgres/postgres"
)
db = create_engine(db_string)

def query_df(line_query, cell_query=None, conn=db):
    """Run a SQL query and return the result as a pandas DataFrame.

    Follows the IPython line/cell-magic calling convention: when
    ``cell_query`` is given it is executed (and ``line_query`` is ignored);
    otherwise ``line_query`` is executed.

    Parameters
    ----------
    line_query : str
        Query used when no ``cell_query`` is supplied.
    cell_query : str, optional
        Query that takes precedence over ``line_query``.
    conn : optional
        Connection/engine passed to ``pd.read_sql``; defaults to ``db``.

    Returns
    -------
    pandas.DataFrame
        The query result.
    """
    query = line_query if cell_query is None else cell_query
    return pd.read_sql(query, conn)

# Custom notebook magic commands for loading sql.
from IPython.core.magic import register_line_cell_magic
def create_df_sql_magic(magic_name, conn):
    """Register an IPython line/cell magic that runs SQL into a DataFrame.

    The magic can be used as ``%<magic_name> SELECT ...`` (line form) or as
    ``%%<magic_name>`` with the query in the cell body (cell form; the text on
    the magic line itself is then ignored).

    Parameters
    ----------
    magic_name : str
        Name the magic is registered under (taken from the function's
        ``__name__``).
    conn :
        Connection/engine the magic queries via ``pd.read_sql``. The closure
        captures it, so the magic really uses this connection (an inner
        ``conn=db`` default would silently ignore it).
    """
    def sql_df(line_query, cell_query=None):
        # Cell form passes the cell body as cell_query; line form passes None.
        query = line_query if cell_query is None else cell_query
        return pd.read_sql(query, conn)

    # The magic is registered under the function's __name__.
    sql_df.__name__ = magic_name
    register_line_cell_magic(sql_df)
create_df_sql_magic('sql_df', db)

parent_query = 'SELECT * FROM message;'
reply_query = 'SELECT * FROM reply;'

parents = query_df(parent_query)
replies = query_df(reply_query)

# Neither query sets index_col, so each frame is numbered from 0; a plain
# concat would repeat every label. ignore_index=True gives unique labels.
df = pd.concat([parents, replies], ignore_index=True)
df = df[['message_id', 'text']]
assert df.notna().all().all(), 'unexpected nulls in message_id/text'
df.shape

(637568, 2)


In [9]:
def no_whitespace(text):
    r"""Turn tabs and newlines into spaces and delete double quotes.

    Despite the name, ordinary spaces are kept: every ``\t`` and ``\n``
    becomes a single space, and every ``"`` character is removed.
    """
    substitutions = str.maketrans({'\t': ' ', '\n': ' ', '"': ''})
    return text.translate(substitutions)

def no_url(text):
    """Mask URL-like tokens with ``'<URL>'``.

    The text is split on whitespace, and any token containing the substring
    ``'http'`` is replaced by ``'<URL>'``. Tokens are re-joined with single
    spaces, so runs of whitespace collapse and leading/trailing whitespace
    is dropped.
    """
    return ' '.join(
        '<URL>' if 'http' in token else token
        for token in text.split()
    )

def no_short_reply(text, min_length=30):
    """Return ``text``, or ``None`` if it is shorter than ``min_length`` chars.

    Returning ``None`` marks the entry so callers can drop it afterwards
    (e.g. with ``dropna``).

    Parameters
    ----------
    text : str
        Text to check.
    min_length : int, default 30
        Minimum number of characters a text needs to be kept.
    """
    return None if len(text) < min_length else text

def cleaner(series):
    """Run each cleaning step, in order, element-wise over ``series``.

    The steps are ``no_whitespace``, then ``no_url``, then
    ``no_short_reply``. A new Series is returned, with entries that are too
    short replaced by ``None``.
    """
    cleaned = series
    for step in (no_whitespace, no_url, no_short_reply):
        cleaned = cleaned.apply(step)
    return cleaned

def fast_clean(df, processes=16):
    """Apply ``cleaner`` to ``df.text`` using a pool of worker processes.

    The column is split into contiguous chunks (at most one per process). The
    chunks are cleaned in parallel and concatenated back in order;
    ``Pool.map`` returns results in input order. Handing the pool a single
    one-element list would leave every worker but one idle.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with a ``text`` column.
    processes : int, default 16
        Number of worker processes.

    Returns
    -------
    pandas.Series
        Cleaned text with the same index (and order) as ``df.text``.

    Raises
    ------
    ValueError
        If ``processes`` is less than 1.
    """
    if processes < 1:
        raise ValueError('processes must be at least 1')
    text = df.text
    if len(text) == 0:
        # pd.concat cannot concatenate an empty list; nothing to parallelise.
        return cleaner(text)
    chunk_size = -(-len(text) // processes)  # ceiling division
    chunks = [text.iloc[start:start + chunk_size]
              for start in range(0, len(text), chunk_size)]
    with Pool(processes) as pool:
        cleaned_chunks = pool.map(cleaner, chunks)
    return pd.concat(cleaned_chunks)

In [10]:
%%time
df['cleaned'] = fast_clean(df)
# Texts rejected by the cleaning step come back as None; drop exactly those
# rows (not rows with a missing value in some other column) and renumber.
df = df.dropna(subset=['cleaned']).reset_index(drop=True)

CPU times: user 599 ms, sys: 480 ms, total: 1.08 s
Wall time: 3.95 s


In [11]:
%%time
# Keep only posts whose cleaned text is at most MAX_CHARS characters long.
# NOTE(review): presumably chosen with the model's input-length limit in
# mind; this counts characters, not tokens — confirm.
MAX_CHARS = 510
df = df.loc[df['cleaned'].str.len() <= MAX_CHARS]

# Renumber rows 0..n-1 and discard the old labels (drop=True, as in the
# cleaning cell) instead of keeping them as an extra 'index' column.
df = df.reset_index(drop=True)

# Texts to embed; corpus[i] is df['cleaned'] at row i.
corpus = df['cleaned'].tolist()

CPU times: user 262 ms, sys: 9.02 ms, total: 271 ms
Wall time: 270 ms


In [12]:
# Number of posts left in `corpus` after cleaning and the length filter.
len(corpus)

417255

In [24]:
url = 'https://model-2.s3.us-east-2.amazonaws.com/distil-bert-SO.tar.gz'

def download_pretrained_model(archive_url=None):
    """Download and extract the fine-tuned model from S3.

    Adapted from
    https://github.com/huggingface/transfer-learning-conv-ai/blob/master/utils.py

    Parameters
    ----------
    archive_url : str, optional
        Location of the ``.tar.gz`` model archive; defaults to the
        module-level ``url``.

    Returns
    -------
    str
        Directory to hand to ``SentenceTransformer``: the first directory
        (walking top-down from the extraction root) that contains
        ``modules.json``, or the extraction root if none does.
    """
    if archive_url is None:
        archive_url = url
    resolved_archive_file = cached_path(archive_url)
    tempdir = tempfile.mkdtemp()
    with tarfile.open(resolved_archive_file, 'r:gz') as archive:
        if hasattr(tarfile, 'data_filter'):
            # Extraction filters are available: refuse absolute paths, `..`
            # components and special files in the downloaded archive.
            archive.extractall(tempdir, filter='data')
        else:
            archive.extractall(tempdir)
    # SentenceTransformer failed with "No such file ... <tempdir>/modules.json",
    # so the model files are presumably inside a sub-folder of the archive:
    # return the directory that actually holds modules.json.
    for root, _dirs, files in os.walk(tempdir):
        if 'modules.json' in files:
            return root
    return tempdir

# Load the fine-tuned model. Use the GPU when one is available and fall back
# to the CPU otherwise; an unconditional .to("cuda") fails without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
embedder = SentenceTransformer(download_pretrained_model())
# Assigning the result also keeps the (very long) model repr out of the output.
embedder = embedder.to(device)

FileNotFoundError: [Errno 2] No such file or directory: '/tmp/tmpaeauw32w/modules.json'

In [None]:
%%time
# Embed only the first N_EMBED_DOCS posts. The slice starts at 0, so row i of
# `embs` (and Annoy item id i in the next cell) corresponds to corpus[i].
N_EMBED_DOCS = 100_000
corpus_embeddings = embedder.encode(corpus[:N_EMBED_DOCS])
embs = np.asarray(corpus_embeddings)

In [None]:
%%time
num_docs, vec_dim = embs.shape

# Angular-distance index over the embeddings; item id i is embs[i].
indx = AnnoyIndex(vec_dim, 'angular')
for i, vector in enumerate(embs):
    indx.add_item(i, vector)

# Use about ln(num_docs) trees, but at least one: ln(1) rounds to 0.
trees = max(1, int(np.log(num_docs).round(0)))
print(trees)
indx.build(trees)
indx.save('a.ann')

In [None]:
%%time
# Reload the index written by the build cell, which saves it as 'a.ann'.
index = AnnoyIndex(vec_dim, 'angular')
index.load('a.ann')
# The 10 nearest neighbours of item 0. Item ids index into `corpus`; df has
# no 'Title' column.
for i in index.get_nns_by_item(0, 10):
    print(i, corpus[i])

In [None]:
# Embed an unseen query with the same model so it can be looked up in the index.
example = ['convert pandas dataframe column to list']
example_embedding = embedder.encode(example)
# presumably shape (1, vec_dim) for the single sentence — TODO confirm.
emb = np.asarray(example_embedding)

In [None]:
%%time
# The 10 indexed posts closest to the unseen example. ravel() flattens the
# embedding to a 1-D vector; item ids index into `corpus`.
for i in index.get_nns_by_vector(emb.ravel(), 10):
    print(i, corpus[i])