In [None]:
import os
import matplotlib as mpl
# When the DISPLAY environment variable is unset or empty, switch matplotlib to
# the non-interactive Agg backend. This runs before matplotlib.pyplot is
# imported further down in this cell.
if os.environ.get('DISPLAY','') == '':
    print('no display found. Using non-interactive Agg backend')
    mpl.use('Agg')
import sys
import socket
import re
import numpy as np
import string
from timeit import default_timer as timer
from datetime import datetime
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from pyspark.sql import SparkSession
from pyspark.sql.functions import udf,desc,row_number,col,year,month,dayofmonth,dayofweek,to_timestamp,size,isnan,lower,rand,lit,mean,stddev,count,sqrt,round
import pyspark.sql.functions as F
from pyspark.sql.functions import broadcast
from pyspark.sql.types import MapType, StringType, IntegerType, StructType, StructField, FloatType, ArrayType

In [None]:
# Read the hostname once; it is both reported and used to pick the session flavour.
hostname = socket.gethostname()
print('Hostname:', hostname)

try:
    # If a `spark` name is already bound in this kernel, keep using it as is.
    spark
except NameError:
    session_builder = SparkSession.builder
    if 'samuel' in hostname.lower():
        # Hostnames containing 'samuel' get a session whose driver host is localhost.
        print('Create Local SparkSession')
        session_builder = session_builder.config("spark.driver.host", "localhost")
    else:
        print('Create Cluster SparkSession')
    spark = session_builder.appName("sample-tweets-for-labeling").getOrCreate()
spark

In [None]:
country_code = "US"
print('Country:', country_code)

# Hostnames containing 'samuel' use the relative (local) layout; every other
# host uses the absolute (cluster) layout.
running_locally = 'samuel' in socket.gethostname().lower()
if running_locally:
    data_root, path_to_fig = '../../data/classification', '../../fig'
else:
    data_root, path_to_fig = '/user/spf248/twitter/data/classification', '/home/spf248/twitter/fig'

# Data for a given country sits in a sub-directory named after its code.
path_to_data = os.path.join(data_root, country_code)
print('Path to data:',path_to_data)

In [None]:
print('Import random tweets')
# Parquet dataset 'random-scored' under the country-specific data directory.
random_scored_path = os.path.join(path_to_data, 'random-scored')
random = spark.read.parquet(random_scored_path)
# Mark the frame for caching; it is read repeatedly by the cells below.
random.cache()

In [None]:
# Keyword columns are all columns of `random` except the three bookkeeping
# columns below and any column whose name contains 'target_'.
non_keyword_columns = ['tweet_id', 'text', 'keyword']
keywords = sorted(
    name for name in random.columns
    if not (name in non_keyword_columns or 'target_' in name)
)
print('Keywords:\n')
print(*keywords, sep='\n')

In [None]:
# Target columns are the columns of `random` whose name contains 'target_'.
targets = sorted(filter(lambda name: 'target_' in name, random.columns))
print('Targets:\n')
print(*targets, sep='\n')

In [None]:
# Row limits applied to the score-sorted tweets when computing lifts
# (also used as the per-bar legend labels in the plots).
cutoffs = [50000, 100000, 200000]
print('Cutoffs:\n')
print('\n'.join(map(str, cutoffs)))

In [None]:
print('COMPUTE BASE RATES')
# One aggregate per keyword column: the mean of the column cast to int,
# computed over every row of `random`.
keyword_means = [mean(col(name).cast("int")).alias(name) for name in keywords]
base_rate_row = random.select(*keyword_means).collect()[0]
# Mapping keyword -> mean over the full random sample.
base_rates = base_rate_row.asDict()

In [None]:
print('COMPUTE LIFTS')

# For every target column and every cutoff: sort the tweets by descending
# target value, keep the first `cutoff` rows, and compute for each keyword
#   - the mean of the keyword column (cast to int), and
#   - its standard error, stddev / sqrt(count),
# both divided by the keyword's base rate over the full sample.
#
# Result: `lift` is a DataFrame indexed by keyword with a three-level column
# index (target, cutoff, stat), where stat is 'mean' or 'serr' and the
# 'target_' prefix has been removed from the target names.

# The aggregate expressions depend only on the keyword columns, so build them
# once. Positional aliases keep the output column names unique.
mean_exprs = [
    mean(col(c).cast("int")).alias('mean_%d' % i)
    for i, c in enumerate(keywords)
]
serr_exprs = [
    (stddev(col(c).cast("int")) / sqrt(count(col(c).cast("int")))).alias('serr_%d' % i)
    for i, c in enumerate(keywords)
]

base_rate_series = pd.Series(base_rates)
lift_columns = {}

for target in targets:

    print(target)

    # Sort into a separate name so that `random` itself is left untouched and
    # the cell can be re-run without accumulating state.
    ranked = random.orderBy(desc(target))
    target_name = target.replace('target_', '')

    for cutoff in cutoffs:

        print(cutoff)

        # Mean and standard error come from a single Spark job, so both are
        # computed on exactly the same rows (two separate orderBy/limit jobs
        # may pick different rows when scores tie) and the sort runs once.
        values = list(ranked.limit(cutoff).select(*mean_exprs, *serr_exprs).collect()[0])
        top_mean = pd.Series(dict(zip(keywords, values[:len(keywords)])))
        top_serr = pd.Series(dict(zip(keywords, values[len(keywords):])))

        # Key each series by its explicit (target, cutoff, stat) label instead
        # of relying on insertion order to line up with a label product.
        lift_columns[(target_name, cutoff, 'mean')] = top_mean / base_rate_series
        lift_columns[(target_name, cutoff, 'serr')] = top_serr / base_rate_series

    print()

# `axis` must be passed by keyword: pandas 2.x rejects positional arguments
# after `objs` in pd.concat.
lift = pd.concat(
    list(lift_columns.values()),
    axis=1,
    keys=list(lift_columns.keys()),
    names=['target', 'cutoff', 'stat'],
)
lift.index.name = 'keyword'

In [None]:
# Width of a single bar; the bars of successive cutoffs are shifted by this amount.
width = 0.2

# One figure per target: grouped bars of keyword lift, one bar per cutoff,
# with the standard error drawn as error bars. Each figure is saved as a PDF.
for target_col in sorted(targets):

    # `lift` is labelled with the target name stripped of its 'target_' prefix.
    target = target_col.replace('target_','')
    target_lift = lift.xs(target, level='target', axis=1)

    with sns.axes_style("white"):
        sns.set_style("ticks")
        sns.set_context("talk")

        fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(15,8))
        palette = sns.cubehelix_palette(len(cutoffs))

        for offset, (cutoff, bar_colour) in enumerate(zip(cutoffs, palette)):

            cutoff_stats = target_lift.xs(cutoff, level='cutoff', axis=1)
            heights = cutoff_stats['mean']
            positions = np.arange(len(heights)) + offset*width

            ax.bar(
                positions,
                heights,
                width,
                yerr=cutoff_stats['serr'],
                color=bar_colour,
                error_kw=dict(ecolor='k', lw=1, capsize=1, capthick=1),
                label='{:,}'.format(cutoff)+' tweets',
            )

        ax.legend(loc='best')

        # Reference lines: y=0 (grey) and y=1 (black, drawn behind the bars).
        ax.axhline(y=0, color='grey', linestyle='--', linewidth=1)
        ax.axhline(y=1, color='k', linestyle='--', linewidth=1, zorder=0)

        # Axis cosmetics.
        ax.tick_params(which='both', direction='in', pad=3)
        ax.locator_params(axis='y', nbins=5)
        ax.set_xticks(range(len(keywords)))
        ax.set_xticklabels(
            [keyword.replace('_',' ') for keyword in keywords],
            rotation=30,
            ha='right',
        )
        ax.set_ylim(bottom=0)

        # Labels.
        ax.set_title(target.replace('_',' ').capitalize(), fontweight='bold')
        ax.set_ylabel('Lift Factor', fontweight='bold')

    # Saved after leaving the style context, one PDF per target.
    figure_file = os.path.join(path_to_fig, 'lift_'+target+'.pdf')
    plt.savefig(figure_file, bbox_inches='tight')