In [1]:
%%capture
!pip install wget annoy
!pip install -U sentence-transformers

Collecting wget
  Downloading https://files.pythonhosted.org/packages/47/6a/62e288da7bcda82b935ff0c6cfe542970f04e29c756b0e147251b2fb251f/wget-3.2.zip
Collecting annoy
  Downloading https://files.pythonhosted.org/packages/00/15/5a9db225ebda93a235aebd5e42bbf83ab7035e7e4783c6cb528c635c9afb/annoy-1.16.3.tar.gz (644kB)
[K    100% |################################| 645kB 2.5MB/s eta 0:00:01
[?25hBuilding wheels for collected packages: wget, annoy
  Running setup.py bdist_wheel for wget ... [?25ldone
[?25h  Stored in directory: /home/nbserver/.cache/pip/wheels/40/15/30/7d8f7cea2902b4db79e3fea550d7d7b85ecb27ef992b618f3f
  Running setup.py bdist_wheel for annoy ... [?25ldone
[?25h  Stored in directory: /home/nbserver/.cache/pip/wheels/f3/01/54/6ef760fe9f9fc6ba8c19cebbe6358212b5f3b5b0195c0b813f
Successfully built wget annoy
Installing collected packages: wget, annoy
Successfully installed annoy-1.16.3 wget-3.2
Collecting sentence-transformers
  Downloading https://files.pythonhosted.org/pa

In [7]:
import os
from multiprocessing import Pool

import pandas as pd
import psycopg2
import scipy.spatial
import torch
from annoy import AnnoyIndex
from sentence_transformers import SentenceTransformer
from sqlalchemy import create_engine

In [8]:
# Database URL. Set the DATABASE_URL environment variable to point elsewhere
# instead of editing credentials into the notebook; the default preserves the
# original local setup.
db_string = os.environ.get("DATABASE_URL", "postgresql://postgres:postgres@postgres/postgres")
db = create_engine(db_string)

def query_df(line_query, cell_query=None, conn=db):
    """Run a SQL query and return the result as a pandas DataFrame.

    Uses the same (line, cell) argument shape as an IPython line/cell magic:
    when ``cell_query`` is given it takes precedence over ``line_query``.

    Args:
        line_query: SQL text used when ``cell_query`` is None.
        cell_query: Optional SQL text that overrides ``line_query``.
        conn: Connection/engine passed to ``pd.read_sql``. Defaults to the
            module-level ``db`` engine, bound when this function is defined.

    Returns:
        pandas.DataFrame holding the query result.
    """
    query = line_query if cell_query is None else cell_query
    return pd.read_sql(query, conn)

# Custom notebook magic commands for loading SQL query results as DataFrames.
from IPython.core.magic import register_line_cell_magic
def create_df_sql_magic(magic_name, conn):
    """Register a line/cell magic that runs SQL and returns a pandas DataFrame.

    Args:
        magic_name: Name to register the magic under (``%name`` / ``%%name``).
        conn: Connection or engine the registered magic queries.
    """
    # Bind the factory's `conn` as the default; IPython calls the magic with
    # only (line, cell), so this default is the connection actually used.
    def sql_df(line_query, cell_query=None, conn=conn):
        # As a line magic the cell argument is None; as a cell magic the
        # cell body holds the SQL.
        query = line_query if cell_query is None else cell_query
        return pd.read_sql(query, conn)
    # register_line_cell_magic registers the function under its __name__.
    sql_df.__name__ = magic_name
    register_line_cell_magic(sql_df)
create_df_sql_magic('sql_df', db)

parent_query = 'SELECT * FROM message;'
reply_query = 'SELECT * FROM reply;'

parents = query_df(parent_query)
replies = query_df(reply_query)

# Stack messages and replies into one frame of (message_id, text).
# ignore_index=True gives every row a unique label. Without it both frames
# start at 0 and the combined index has duplicates, which breaks label-based
# alignment such as the later df['cleaned'] = ... assignment.
df = pd.concat([parents, replies], ignore_index=True)[['message_id', 'text']]
assert not df.isna().any().any(), "unexpected missing message_id/text values"
print(df.shape)

(637568, 2)


In [9]:
def no_whitespace(text):
    """Turn tabs and newlines into spaces and delete double quotes from ``text``."""
    # A single translation table: '\t' -> ' ', '\n' -> ' ', '"' -> ''.
    translation = str.maketrans({"\t": " ", "\n": " ", '"': ''})
    return text.translate(translation)

def no_url(text):
    """Replace each whitespace-separated token containing 'http' with '<URL>'.

    Tokens are re-joined with single spaces, so runs of whitespace collapse
    and leading/trailing whitespace is dropped.
    """
    return ' '.join(
        '<URL>' if 'http' in token else token
        for token in text.split()
    )

def no_short_reply(text, min_length=30):
    """Return ``text`` unchanged, or None if it is shorter than ``min_length``.

    Returning None (rather than an empty string) lets callers remove short
    posts afterwards with ``dropna()``.

    Args:
        text: The post text.
        min_length: Minimum number of characters to keep a post (default 30).
    """
    return None if len(text) < min_length else text

def cleaner(series):
    """Run the text-cleaning pipeline over a Series of post strings.

    Steps, applied in order: ``no_whitespace``, then ``no_url``, then
    ``no_short_reply`` (which turns short posts into None).
    """
    for step in (no_whitespace, no_url, no_short_reply):
        series = series.apply(step)
    return series

def fast_clean(df, processes=16):
    """Apply ``cleaner`` to ``df.text`` across a process pool.

    The column is split into at most ``processes`` contiguous chunks. Each
    chunk is a separate pool task, so the work is actually spread across
    workers. The cleaned chunks are concatenated back in their original
    order, so the result keeps ``df.text``'s index and can be assigned back
    to ``df``.

    Args:
        df: DataFrame with a ``text`` column.
        processes: Number of worker processes (and maximum number of chunks).

    Returns:
        pandas.Series of cleaned text aligned with ``df.text``.

    Raises:
        ValueError: if ``processes`` is less than 1.
    """
    if processes < 1:
        raise ValueError("processes must be at least 1")
    text = df.text
    if len(text) == 0:
        # Nothing to parallelize; also avoids pd.concat([]) raising.
        return cleaner(text)
    # Ceiling division, so at most `processes` chunks cover every row.
    chunk_size = (len(text) + processes - 1) // processes
    chunks = [text.iloc[start:start + chunk_size]
              for start in range(0, len(text), chunk_size)]
    with Pool(processes) as pool:
        cleaned_chunks = pool.map(cleaner, chunks)
    return pd.concat(cleaned_chunks)

In [10]:
%%time
df['cleaned'] = fast_clean(df)
df = df.dropna()
df = df.reset_index(drop=True)

CPU times: user 599 ms, sys: 480 ms, total: 1.08 s
Wall time: 3.95 s


In [11]:
%%time
# Drop questions longer than 510 characters.
df = df.loc[df['cleaned'].str.len() < 511]

# Reset index.
df = df.reset_index()

# Get a list of all the posts/messages.
corpus = list(df.cleaned)

CPU times: user 262 ms, sys: 9.02 ms, total: 271 ms
Wall time: 270 ms


In [12]:
# Number of cleaned posts left after filtering.
num_posts = len(corpus)
num_posts

417255

In [None]:
# Directory of the SentenceTransformer model to load (path from the original
# run -- TODO confirm it exists in this environment).
MODEL_DIR = '/content/content/output/test1/'
# Use the GPU when one is visible; otherwise fall back to CPU instead of
# failing on a hard-coded "cuda" device.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

embedder = SentenceTransformer(MODEL_DIR, device=DEVICE)