# In [None]:  (Jupyter cell marker left over from notebook export)
import os
from snowball_spider import *
import pytz
import psycopg2
import string
from pyspark import SparkContext
from pyspark.streaming import StreamingContext

def grab_comments(stock_code, count):
    """Scrape ``count`` page(s) of Snowball comments for a single stock.

    Returns a dict ``{'stock': stock_code, 'data': comments}`` where ``comments``
    is the list built by ``Snowball.comment_object_list`` (one dictionary per
    comment). Prints the number of raw comments found, followed by a separator.
    """
    spider = Snowball(stock_code)
    pages = spider.get_all_pages(count)  # pages[0][0] is the first post of the first page (a dict)
    raw_posts = spider.find_post_objects(pages)[0]  # raw comment JSON objects across all requested pages

    header = '\nNumber of comments for stock {} in current page(s):'.format(stock_code)
    print(header, len(raw_posts), end='\n\n')
    print('-' * 115)

    return {
        'stock': stock_code,
        'data': spider.comment_object_list(raw_posts),
    }

def look_up_db(db, cur, table_name, option):
    """Run one predefined query against ``table_name`` and return the fetched rows.

    ``option`` selects which ``db`` method is called:

    * ``'find_count'``     -> ``db.find_record_number(table_name)``
    * ``'by_count'``       -> ``db.select_comments_by_count(table_name, 5)``
    * ``'by_period'``      -> ``db.select_comments_by_period(table_name, ['10 hour'])``
    * ``'find_last_time'`` -> ``db.find_time_last_comment(table_name)``

    The rows are then read with ``cur.fetchall()``. ``cur`` is expected to be the
    cursor that ``db`` executes on. For every option except ``'find_last_time'``,
    each row is printed followed by a separator line. An unknown ``option``
    raises ``KeyError``.
    """
    # Maps option -> (db method name, extra positional args after table_name).
    # The table is built per call, so the '10 hour' list is never shared between calls.
    query_table = {
        'find_count': ('find_record_number', ()),
        'by_count': ('select_comments_by_count', (5,)),
        'by_period': ('select_comments_by_period', (['10 hour'],)),
        'find_last_time': ('find_time_last_comment', ()),
    }
    method_name, extra_args = query_table[option]
    getattr(db, method_name)(table_name, *extra_args)

    rows = cur.fetchall()
    # Only the 'find_last_time' lookup returns its rows without echoing them.
    if option != 'find_last_time':
        for row in rows:
            print(row)
            print('-' * 115)
    return rows

'''
TODO: create a function that batch-classifies comment sentiment
'''

def grab_rdd(stock_code):
    """Spark ``map`` helper that fetches a single page of comments for ``stock_code``.

    This is a one-argument wrapper around ``grab_comments(stock_code, 1)``, so it
    can be passed directly to ``RDD.map``.
    """
    pages_to_fetch = 1
    return grab_comments(stock_code, pages_to_fetch)

def _table_name(stock_code):
    """Return the Postgres table name for ``stock_code``.

    The name is lower-cased so that table creation, inserts and lookups all agree.
    """
    return 'stock_' + stock_code.lower()


def _format_shanghai_time(timestamp_ms):
    """Format a millisecond Unix timestamp as 'YYYY/MM/DD HH:MM:SS' in Asia/Shanghai time.

    Returns 'N/A' when ``timestamp_ms`` is None, e.g. when nothing was scraped or stored.
    """
    # Imported explicitly instead of relying on `from snowball_spider import *`.
    from datetime import datetime
    if timestamp_ms is None:
        return 'N/A'
    shanghai = pytz.timezone('Asia/Shanghai')
    return datetime.fromtimestamp(timestamp_ms / 1000.0, tz=shanghai).strftime("%Y/%m/%d %H:%M:%S")


def main():
    """Scrape comments for a fixed list of stocks, store them in Postgres, and report freshness.

    Steps:
    1. Connect to Postgres and initialise one table per stock.
    2. Scrape one page of comments per stock in parallel with Spark.
    3. Batch-insert the comments.
    4. Print the newest scraped comment time and the newest stored comment time
       for the first stock.
    5. Print a sample of its stored comments.

    The database connection and the SparkContext are always released, even when
    an error occurs.
    """
    stock_code_list = ['00700', 'BABA', 'AAPL']  # stocks to scrape comments for
    username, dbname = 'jian', 'postgres'

    try:
        conn = psycopg2.connect(user=username, dbname=dbname)
    except Exception as e:
        # Nothing below can work without a connection. Previously execution
        # continued and crashed with a NameError on `db` / `cur`.
        print(e)
        return

    try:
        conn.autocommit = True
        with conn.cursor() as cur:
            db = DB_Operations(cur)
            # db.drop_table()
            for code in stock_code_list:
                try:
                    db.initialize_database(table_name=_table_name(code))
                except Exception as e:
                    # Best effort, as before: report the error and continue with the remaining stocks.
                    print(e)

            # Sentiment client. It is not used yet because batch sentiment
            # classification is still a TODO, but it is kept so that its setup runs.
            client = AipNlp_API('AipNlp.txt').connect_to_AipNlp()

            # Scrape one page per stock in parallel; always release the SparkContext.
            sc = SparkContext.getOrCreate()
            try:
                stock_list = sc.parallelize(stock_code_list).map(grab_rdd).collect()
            finally:
                sc.stop()

            print(stock_list[0])

            for stock in stock_list:
                # NOTE(review): `sentiment_list` was never defined, so this line always
                # raised NameError. Until sentiment classification exists, pass one
                # placeholder per comment. Confirm that
                # DB_Operations.batch_insert_comments accepts None entries.
                sentiment_list = [None] * len(stock['data'])
                db.batch_insert_comments(_table_name(stock['stock']), stock['data'], sentiment_list)

            first = stock_list[0]
            first_code = first['stock']
            # Assumes data[0] is the most recent comment on the website.
            web_ts = first['data'][0]['timestamp'] if first['data'] else None
            db_rows = look_up_db(db, cur, _table_name(first_code), 'find_last_time')
            db_ts = db_rows[0][0] if db_rows else None  # timestamp to compare with the website's latest comment
            print('The time of last comment for stock {} up to date:'.format(first_code),
                  _format_shanghai_time(web_ts))
            print('The time of last comment for stock {} recorded at:'.format(first_code),
                  _format_shanghai_time(db_ts))
            print('-' * 115)
            look_up_db(db, cur, _table_name(stock_code_list[0]), 'by_count')
    finally:
        conn.close()


if __name__ == "__main__":
    # Run the scrape-and-store pipeline only when executed as a script.
    main()