In [1]:
# Load the autoreload extension and enable it in mode 2 so edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
import sys
# Make the parent directory importable (the project helper modules imported in
# the next cell are presumably located there — confirm).
sys.path.append("../")
import warnings
# NOTE(review): this silences every warning for the whole session, including
# deprecation and convergence warnings raised by the libraries used below.
warnings.filterwarnings("ignore")

In [2]:
import pickle
import glob
from tqdm.notebook import trange, tqdm
import json
import re
import pandas as pd
import numpy as np
from collections import defaultdict

from utills import chunker, cartesian_product, get_num_chunks
from train_utils import generate_unique_pairs, get_random_author_excluding, generate_doc_pairs, fit_transformers, vectorize


from sklearn.linear_model import SGDClassifier
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import roc_curve, auc

from sklearn.utils.fixes import loguniform
from sklearn.model_selection import RandomizedSearchCV

In [3]:

from plotly.offline import init_notebook_mode
import plotly.offline as py
import plotly.graph_objs as go
# Initialise plotly's offline notebook mode before any figure is created.
# NOTE(review): connected=True presumably loads plotly.js from a CDN instead
# of embedding it in the notebook, so figures need network access — confirm.
init_notebook_mode(connected=True)

In [4]:
# --- Experiment configuration -------------------------------------------------
# BASE_PATH = '../data/reddit_2/'

# Directory holding the preprocessed train/test JSONL files and metadata pickle.
COMPUTED_DATA_PATH = '../temp_data/reddit/preprocessed/'

# Chunk size forwarded to fit_transformers/vectorize as `chunk_sz`.
chunk_sz = 20

# Cap forwarded to fit_transformers/vectorize as `max_comments`.
MAX_COMMENTS = 40

# Output directory for this run. Both numbers in the directory name are derived
# from the constants above so the name cannot drift from the settings actually
# used (previously the "40" was typed by hand next to MAX_COMMENTS = 40).
TEMP_DATA = (
    '../temp_data/reddit/multidoc_' + str(chunk_sz)
    + '_limitted_data_' + str(MAX_COMMENTS) + '_capped/'
)

In [5]:
# Load the split metadata: the train/test file lists, `min_count`, and the
# mapping later used to expand a root-user key into its author ids.
# NOTE(review): pickle.load can execute arbitrary code from the file; only load
# a metadata.p that was produced locally by this project.
metadata_path = COMPUTED_DATA_PATH + 'metadata.p'
with open(metadata_path, 'rb') as metadata_file:
    train_files, test_files, min_count, author_mapping_all = pickle.load(metadata_file)

Vectorize Data
===

In [6]:
# Seed for the exclusion sampling so the held-out user sets (which are also
# pickled with the experiment data) can be regenerated on a fresh kernel.
EXCLUDE_SEED = 42
# Share of the per-user files to exclude from each split.
EXCLUDE_FRACTION = 0.5
# The file lists use the local prefix and a '.jsonl' suffix, while the keys
# looked up in `author_mapping_all` use the scratch prefix and no suffix.
LOCAL_DATA_PREFIX = '../data/reddit_2/'
MAPPING_KEY_PREFIX = '/scratch/jnw301/av/data/reddit_2/reddit_2/'


def _sample_excluded_users(files, author_mapping, fraction, rng):
    """Randomly pick a fraction of root users and expand them to author ids.

    Parameters
    ----------
    files : sequence of str
        Per-user data file paths to sample from.
    author_mapping : mapping
        Maps a root-user key (rewritten file path) to an iterable of author ids.
    fraction : float
        Share of `files` to draw, without replacement; the count is truncated
        with ``int``.
    rng : numpy.random.RandomState
        Random source used for the draw.

    Returns
    -------
    tuple (list of str, list)
        The sampled root-user keys and the flattened author ids that belong
        to them.
    """
    sampled = rng.choice(files, size=int(len(files) * fraction), replace=False)
    # Rewrite each sampled file path into the key format of `author_mapping`.
    roots = [
        path.replace(LOCAL_DATA_PREFIX, MAPPING_KEY_PREFIX).replace('.jsonl', '')
        for path in sampled
    ]
    authors = []
    for root in roots:
        authors.extend(author_mapping[root])
    return roots, authors


# A local generator keeps the draw reproducible without touching numpy's
# global random state.
_exclude_rng = np.random.RandomState(EXCLUDE_SEED)

exclude_users_root, exclude_users = _sample_excluded_users(
    train_files, author_mapping_all, EXCLUDE_FRACTION, _exclude_rng
)
exclude_users_test_root, exclude_users_test = _sample_excluded_users(
    test_files, author_mapping_all, EXCLUDE_FRACTION, _exclude_rng
)

In [8]:
# Fit the feature transformer and the two scalers on the training file, using
# `sample_fraction=0.1` and skipping the excluded users.
train_jsonl_for_fit = COMPUTED_DATA_PATH + 'train.jsonl'
transformer, scaler, secondary_scaler = fit_transformers(
    train_jsonl_for_fit,
    author_mapping_all,
    chunk_sz,
    max_comments=MAX_COMMENTS,
    exclude_users=exclude_users,
    sample_fraction=0.1,
)

Sampled: 7271
Reading preprocessed data...



Fitting transformer
Generating pairs











In [9]:
# Checkpoint the fitted preprocessing objects right after fitting.
# NOTE(review): the last cell of this notebook writes a longer tuple (with the
# classifier and the AUC history) to this same path, replacing this 3-tuple.
transformers_checkpoint_path = TEMP_DATA + 'model_' + str(chunk_sz) + '.p'
with open(transformers_checkpoint_path, 'wb') as checkpoint_file:
    pickle.dump((transformer, scaler, secondary_scaler), checkpoint_file)

In [10]:
# Vectorize the training split with the fitted transformer and scaler.
# The call returns the feature matrix, the per-author bounds, the per-author
# subreddit mapping and the matrix shape.
# NOTE(review): `vectorized_x_path` presumably backs the matrix on disk
# (.npy) — confirm in train_utils.vectorize.
train_jsonl_path = COMPUTED_DATA_PATH + 'train.jsonl'
train_matrix_path = TEMP_DATA + 'XX_train_' + str(chunk_sz) + '.npy'
XX_train, author_bounds, author_subreddit, x_shape = vectorize(
    preprocessed_path=train_jsonl_path,
    vectorized_x_path=train_matrix_path,
    transformer=transformer,
    scaler=scaler,
    chunk_sz=chunk_sz,
    exclude_users=exclude_users,
    max_comments=MAX_COMMENTS,
)

Precomputing record size...








In [11]:
# Vectorize the test split with the same fitted transformer and scaler that
# were used for the training split, excluding the sampled test users.
# NOTE(review): `vectorized_x_path` presumably backs the matrix on disk
# (.npy) — confirm in train_utils.vectorize.
test_jsonl_path = COMPUTED_DATA_PATH + 'test.jsonl'
test_matrix_path = TEMP_DATA + 'XX_test_' + str(chunk_sz) + '.npy'
XX_test, author_bounds_test, author_subreddit_test, x_shape_test = vectorize(
    preprocessed_path=test_jsonl_path,
    vectorized_x_path=test_matrix_path,
    transformer=transformer,
    scaler=scaler,
    chunk_sz=chunk_sz,
    exclude_users=exclude_users_test,
    max_comments=MAX_COMMENTS,
)

Precomputing record size...








In [12]:
# Persist everything needed to interpret the vectorized matrices: author
# bounds, author->subreddit mappings, matrix shapes and the excluded users.
# NOTE(review): a later cell rewrites this same file with the pair indices and
# labels appended, so the tuple length on disk depends on how far the notebook
# was run.
vectorization_state = (
    author_bounds,
    author_bounds_test,
    author_subreddit,
    author_subreddit_test,
    x_shape,
    x_shape_test,
    exclude_users,
    exclude_users_test,
)
vectorization_state_path = TEMP_DATA + 'experiment_data' + str(chunk_sz) + '.p'
with open(vectorization_state_path, 'wb') as state_file:
    pickle.dump(vectorization_state, state_file)

Train the classifier
===

In [13]:
# An author id is expected to end in "_<UPPERCASE LETTERS>"; everything before
# that final suffix is treated as the root user.
_AUTHOR_ID_PATTERN = re.compile(r'(.*)_[A-Z]+$')


def _index_authors(bounds, subreddits):
    """Build the lookup tables passed to `generate_doc_pairs` for one split.

    Parameters
    ----------
    bounds : mapping
        Keyed by author id (only the keys are read here).
    subreddits : mapping
        Maps author id to the value the authors are grouped by (a subreddit,
        going by the variable naming).

    Returns
    -------
    tuple (defaultdict(set), dict, defaultdict(list))
        root user -> set of author ids, author id -> root user, and
        subreddit -> list of author ids.

    Raises
    ------
    ValueError
        If an author id does not match the expected "<root>_<SUFFIX>" form.
    """
    root_to_authors = defaultdict(set)
    author_root = {}
    for author in bounds.keys():
        match = _AUTHOR_ID_PATTERN.search(author)
        if match is None:
            # Fail with the offending id instead of an opaque AttributeError
            # on `None.group`.
            raise ValueError('Unexpected author id format: ' + repr(author))
        root = match.group(1)
        root_to_authors[root].add(author)
        author_root[author] = root

    by_subreddit = defaultdict(list)
    for author, subreddit in subreddits.items():
        by_subreddit[subreddit].append(author)

    return root_to_authors, author_root, by_subreddit


# Same indexing for both splits; the resulting names are unchanged.
author_mapping, author_to_root, subreddit_to_author = _index_authors(
    author_bounds, author_subreddit
)
author_mapping_test, author_to_root_test, subreddit_to_author_test = _index_authors(
    author_bounds_test, author_subreddit_test
)

# Generate the document-pair indices and labels for each split.
X_idxs_train, Y_train = generate_doc_pairs(author_mapping, subreddit_to_author, author_to_root, author_bounds, author_subreddit)
X_idxs_test, Y_test = generate_doc_pairs(author_mapping_test, subreddit_to_author_test, author_to_root_test, author_bounds_test, author_subreddit_test)



















In [14]:
# Rewrite the experiment data file, now including the generated pair indices
# and labels for both splits after the vectorization metadata.
# NOTE(review): this replaces the shorter tuple an earlier cell wrote to the
# same path.
experiment_state = (
    author_bounds,
    author_bounds_test,
    author_subreddit,
    author_subreddit_test,
    x_shape,
    x_shape_test,
    exclude_users,
    exclude_users_test,
    X_idxs_train,
    Y_train,
    X_idxs_test,
    Y_test,
)
experiment_state_path = TEMP_DATA + 'experiment_data' + str(chunk_sz) + '.p'
with open(experiment_state_path, 'wb') as experiment_file:
    pickle.dump(experiment_state, experiment_file)

In [15]:
# Number of document pairs per mini-batch; also the size of the fixed
# evaluation sample built here. The training cell reuses this value.
batch_sz = 50000

# Evaluation features: absolute difference between the two documents of each
# of the first `batch_sz` test pairs, passed through the secondary scaler.
# NOTE(review): this takes the first `batch_sz` pairs rather than a random
# sample; if generate_doc_pairs emits pairs in a systematic order the sample
# may not be representative — confirm.
x_test_diff_sample = secondary_scaler.transform(
    np.abs(XX_test[X_idxs_test[:batch_sz, 0]] - XX_test[X_idxs_test[:batch_sz, 1]])
)
# Zero out NaNs exactly as the training loop does for its batches, so the
# evaluation features are preprocessed the same way as the training features.
x_test_diff_sample[np.isnan(x_test_diff_sample)] = 0
y_test_sample = Y_test[:batch_sz]

In [16]:

# Number of full passes over the training pairs.
N_EPOCHS = 100
# Seed for the classifier's internal shuffling so a re-run of this cell
# reproduces the same model and AUC curve.
CLF_SEED = 42

# Logistic-regression classifier trained incrementally with SGD.
# A RandomState instance (rather than an int) is passed so the generator state
# advances across partial_fit calls while the overall run stays reproducible.
# NOTE(review): newer scikit-learn releases spell this loss 'log_loss'; 'log'
# is kept because the installed version is unknown — confirm before upgrading.
clf = SGDClassifier(loss='log', alpha=0.01, random_state=np.random.RandomState(CLF_SEED))

# One entry per epoch: the AUC measured after the epoch's last mini-batch.
aucs = []
for i in range(N_EPOCHS):
    for idxs in chunker(np.arange(len(X_idxs_train)), batch_sz):
        # Pair features: scaled absolute difference of the two documents.
        x_diff = secondary_scaler.transform(np.abs(XX_train[X_idxs_train[idxs, 0]] - XX_train[X_idxs_train[idxs, 1]]))
        x_diff[np.isnan(x_diff)] = 0
        y = Y_train[idxs]
        # `classes` must be supplied because partial_fit never sees the full
        # label set up front.
        clf.partial_fit(x_diff, y, classes=[0, 1])

        # Evaluate on the fixed test sample after every mini-batch.
        probs = clf.predict_proba(x_test_diff_sample)[:, 1]

        fpr, tpr, thresh = roc_curve(y_test_sample, probs)
        roc_auc = auc(fpr, tpr)
        print('AUC:', roc_auc)
    print('~'*20, 'Epoch: ', i)
    aucs.append(roc_auc)

AUC: 0.8487539443493554
AUC: 0.860052592793639
AUC: 0.8672536600168858
AUC: 0.8824803133566155
AUC: 0.8891808136175121
AUC: 0.8940109426437612
AUC: 0.8962740087673351
AUC: 0.8997897065436529
AUC: 0.9021092201238082
AUC: 0.902216088270259
~~~~~~~~~~~~~~~~~~~~ Epoch:  0
AUC: 0.9032108781456596
AUC: 0.9072354266839864
AUC: 0.9070859913929087
AUC: 0.9095282019952348
AUC: 0.9108602552482326
AUC: 0.9103180273195531
AUC: 0.9111943217634021
AUC: 0.9136212771625642
AUC: 0.9123892047804711
AUC: 0.9119611376904271
~~~~~~~~~~~~~~~~~~~~ Epoch:  1
AUC: 0.9134338705170565
AUC: 0.9133827143084632
AUC: 0.9149042814120277
AUC: 0.9158014972773588
AUC: 0.9157209395490541
AUC: 0.9158822569127644
AUC: 0.9165856051053654
AUC: 0.9170155069861259
AUC: 0.9164827094009989
AUC: 0.9161217491805036
~~~~~~~~~~~~~~~~~~~~ Epoch:  2
AUC: 0.9169226681782991
AUC: 0.9173288764210363
AUC: 0.9168273183512121
AUC: 0.9184816986636676
AUC: 0.9179203985261583
AUC: 0.9185704111954283
AUC: 0.9182044097071647
AUC: 0.91874684435025

AUC: 0.9217253298266663
AUC: 0.9216367887556975
AUC: 0.9218253459511917
AUC: 0.9218189137678409
AUC: 0.9217224518492627
AUC: 0.9216331944888174
~~~~~~~~~~~~~~~~~~~~ Epoch:  30
AUC: 0.9217687462630662
AUC: 0.9214833874322845
AUC: 0.921675399482613
AUC: 0.9215866276606719
AUC: 0.9216507475870799
AUC: 0.9218019343421344
AUC: 0.9217966206600239
AUC: 0.9218342619123687
AUC: 0.9216743146325561
AUC: 0.9216231808580851
~~~~~~~~~~~~~~~~~~~~ Epoch:  31
AUC: 0.9217481837875409
AUC: 0.9214799790481322
AUC: 0.9217566398492106
AUC: 0.9217556110844595
AUC: 0.9216779553701179
AUC: 0.9218113422515646
AUC: 0.921827142283413
AUC: 0.9219568083054326
AUC: 0.9217714399601795
AUC: 0.9216474946393465
~~~~~~~~~~~~~~~~~~~~ Epoch:  32
AUC: 0.9217567424051982
AUC: 0.9216154170493319
AUC: 0.9215520711001434
AUC: 0.9218777296267624
AUC: 0.9216925615861726
AUC: 0.921861778965807
AUC: 0.921839168575403
AUC: 0.9218965806992436
AUC: 0.9217320296170475
AUC: 0.9216573287971004
~~~~~~~~~~~~~~~~~~~~ Epoch:  33
AUC: 0.92177

AUC: 0.921931068354968
AUC: 0.9218339574492802
AUC: 0.9217737362528406
~~~~~~~~~~~~~~~~~~~~ Epoch:  60
AUC: 0.9218386894466484
AUC: 0.9217344092364487
AUC: 0.9217587999347008
AUC: 0.9219112462054785
AUC: 0.9217964427894827
AUC: 0.9219077528921489
AUC: 0.9218823286218338
AUC: 0.9219485878020469
AUC: 0.9218344493975336
AUC: 0.9218085171545921
~~~~~~~~~~~~~~~~~~~~ Epoch:  61
AUC: 0.9218735793141368
AUC: 0.9217599568944366
AUC: 0.9217414631654748
AUC: 0.9219361480812315
AUC: 0.9218279723459382
AUC: 0.9219046842247058
AUC: 0.9218966335796746
AUC: 0.9219477625468337
AUC: 0.9218669275968748
AUC: 0.9217811507302607
~~~~~~~~~~~~~~~~~~~~ Epoch:  62
AUC: 0.921874816395738
AUC: 0.9217581221037199
AUC: 0.9217917748896072
AUC: 0.9219090508663679
AUC: 0.9218296324709884
AUC: 0.9218857562352338
AUC: 0.9218775773952181
AUC: 0.9219442019311372
AUC: 0.9218160950806176
AUC: 0.921803789964536
~~~~~~~~~~~~~~~~~~~~ Epoch:  63
AUC: 0.9218744318107843
AUC: 0.9217416570603889
AUC: 0.9217936224998222
AUC: 0.9219

AUC: 0.9218782664432602
AUC: 0.9217465220600534
AUC: 0.9218316627590566
AUC: 0.9219214777676846
AUC: 0.9218412277073429
AUC: 0.9219060046330468
AUC: 0.9218838156836549
AUC: 0.9219274981246475
AUC: 0.92185066125577
AUC: 0.9218361928493238
~~~~~~~~~~~~~~~~~~~~ Epoch:  91
AUC: 0.9218770566030933
AUC: 0.9217993928765653
AUC: 0.921827371431948
AUC: 0.9219208335878871
AUC: 0.9218432932490317
AUC: 0.9219024504270996
AUC: 0.9219015466524584
AUC: 0.9219449550766715
AUC: 0.9218620513801491
AUC: 0.9218524127197468
~~~~~~~~~~~~~~~~~~~~ Epoch:  92
AUC: 0.921892670752214
AUC: 0.9218068490173555
AUC: 0.9218025176293143
AUC: 0.9219056985675211
AUC: 0.9218338388689196
AUC: 0.9219027308536283
AUC: 0.9218947186670927
AUC: 0.9219243477329017
AUC: 0.9218364941075377
AUC: 0.9218436666169243
~~~~~~~~~~~~~~~~~~~~ Epoch:  93
AUC: 0.921884844448406
AUC: 0.921805307472666
AUC: 0.9218245815885963
AUC: 0.9219102142358526
AUC: 0.9218223878519227
AUC: 0.9218966576162343
AUC: 0.9218813158814556
AUC: 0.92192955565415


In [17]:
# Chunk size 20, capped at 80
fig = go.Figure()
fig.add_trace(go.Scatter(y=aucs))

In [16]:
# Chunk size 10, capped at 40
fig = go.Figure()
fig.add_trace(go.Scatter(y=aucs))
# fig.add_trace(go.Histogram/(x=pred_means[labels==False]))

In [18]:
# Display the path that the next cell writes the final model bundle to.
TEMP_DATA + 'model_' + str(chunk_sz) + '.p'

'../temp_data/reddit/multidoc_20_limitted_data_40_capped/model_20.p'

In [19]:
# Save the final bundle: trained classifier, fitted transformer, both scalers
# and the per-epoch AUC history.
# NOTE(review): this overwrites the 3-tuple (transformer, scaler,
# secondary_scaler) an earlier cell wrote to the same path, so readers of this
# file must expect the 5-tuple layout.
model_bundle_path = TEMP_DATA + 'model_' + str(chunk_sz) + '.p'
with open(model_bundle_path, 'wb') as bundle_file:
    pickle.dump((clf, transformer, scaler, secondary_scaler, aucs), bundle_file)