In [1]:
%load_ext autoreload
%autoreload 2
import sys
# Put the parent directory on the import path (presumably where the
# project-local modules imported in the next cell live -- TODO confirm).
sys.path.append("../")
import warnings
# Suppress every warning for the rest of the session.
# NOTE(review): this also hides deprecation warnings from the libraries used below.
warnings.filterwarnings("ignore")

In [2]:
import pickle
import glob
from tqdm.notebook import trange, tqdm
import json
import re
import pandas as pd
import numpy as np
from collections import defaultdict

from utills import chunker, cartesian_product, get_num_chunks
from train_utils import generate_unique_pairs, get_random_author_excluding, generate_doc_pairs, fit_transformers, vectorize


from sklearn.linear_model import SGDClassifier
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import roc_curve, auc

from sklearn.utils.fixes import loguniform
from sklearn.model_selection import RandomizedSearchCV

In [3]:

# Plotly imports used by the AUC plots at the end of the notebook.
from plotly.offline import init_notebook_mode
import plotly.offline as py
import plotly.graph_objs as go
# Enable plotly rendering inside the notebook. With connected=True plotly.js is
# loaded from a CDN rather than embedded, so viewing figures needs network access.
init_notebook_mode(connected=True)

In [4]:
# Run configuration. Both tunables are declared before the output directory so
# that its name is derived from them and cannot drift out of sync with the
# values actually passed to fit_transformers / vectorize.
# BASE_PATH = '../data/reddit_2/'
COMPUTED_DATA_PATH = '../temp_data/reddit/preprocessed/'  # holds metadata.p, train.jsonl, test.jsonl
chunk_sz = 10      # forwarded as the chunk size to fit_transformers / vectorize
MAX_COMMENTS = 40  # forwarded as max_comments to fit_transformers / vectorize
# 'limitted' (sic) is kept byte-for-byte so artifacts from earlier runs are still found.
TEMP_DATA = (
    '../temp_data/reddit/multidoc_' + str(chunk_sz)
    + '_limitted_data_' + str(MAX_COMMENTS) + '_capped/'
)

In [5]:
# Restore the four-element metadata tuple stored next to the preprocessed data.
# NOTE(review): unpickling runs arbitrary code -- only load files this project
# produced itself.
metadata_path = COMPUTED_DATA_PATH + 'metadata.p'
with open(metadata_path, 'rb') as metadata_file:
    train_files, test_files, min_count, author_mapping_all = pickle.load(metadata_file)

Vectorize Data
===

In [6]:
# Share of root-author files selected for exclusion in each split, and the seed
# that makes that selection repeatable across kernel restarts.
EXCLUDE_FRACTION = 0.5
EXCLUDE_SEED = 42


def _sample_excluded_users(files, author_mapping, rng, fraction=EXCLUDE_FRACTION):
    """Randomly select a fraction of files and expand them to author ids.

    Args:
        files: sequence of per-author file paths (``train_files`` / ``test_files``).
        author_mapping: mapping from a rewritten root path to an iterable of
            author ids.
        rng: ``np.random.RandomState`` used to sample without replacement.
        fraction: share of ``files`` to select; the count is truncated with ``int``.

    Returns:
        ``(roots, users)``: the rewritten root keys that were selected, and the
        flat list of every author id those roots map to.

    Raises:
        KeyError: if a rewritten path is not a key of ``author_mapping``.
    """
    picked = rng.choice(files, size=int(len(files) * fraction), replace=False)
    # Each relative '.jsonl' path is rewritten to an absolute, extension-less
    # form before the lookup (presumably the form metadata.p was keyed with on
    # the original machine -- TODO confirm when running elsewhere).
    roots = [
        path.replace('../data/reddit_2/', '/scratch/jnw301/av/data/reddit_2/reddit_2/').replace('.jsonl', '')
        for path in picked
    ]
    users = [user for root in roots for user in author_mapping[root]]
    return roots, users


# One generator drives both draws, so the pair of selections is reproducible.
_exclude_rng = np.random.RandomState(EXCLUDE_SEED)
exclude_users_root, exclude_users = _sample_excluded_users(
    train_files, author_mapping_all, _exclude_rng
)
exclude_users_test_root, exclude_users_test = _sample_excluded_users(
    test_files, author_mapping_all, _exclude_rng
)

In [8]:
# Obtain the transformer and the two scalers from the training split.
# The preprocessed file, author mapping and chunk size go in positionally;
# the remaining tunables are gathered in one dict.
fit_options = {
    'sample_fraction': 0.1,
    'exclude_users': exclude_users,
    'max_comments': MAX_COMMENTS,
}
transformer, scaler, secondary_scaler = fit_transformers(
    COMPUTED_DATA_PATH + 'train.jsonl',
    author_mapping_all,
    chunk_sz,
    **fit_options
)

Sampled: 7411
Reading preprocessed data...



Fitting transformer
Generating pairs











In [9]:
# Checkpoint the fitted preprocessing objects as a 3-tuple.
# NOTE(review): a later cell writes this same path again with the classifier
# and AUC history added, so readers of this file must expect either layout.
preprocessing_checkpoint = (transformer, scaler, secondary_scaler)
preprocessing_checkpoint_path = TEMP_DATA + 'model_' + str(chunk_sz) + '.p'
with open(preprocessing_checkpoint_path, 'wb') as checkpoint_file:
    pickle.dump(preprocessing_checkpoint, checkpoint_file)

In [10]:
# Vectorize the training split with the fitted transformer and scaler.
# The users selected for exclusion from the training files are passed through.
train_vectorize_options = dict(
    preprocessed_path=COMPUTED_DATA_PATH + 'train.jsonl',
    vectorized_x_path=TEMP_DATA + 'XX_train_' + str(chunk_sz) + '.npy',
    transformer=transformer,
    scaler=scaler,
    chunk_sz=chunk_sz,
    max_comments=MAX_COMMENTS,
    exclude_users=exclude_users,
)
XX_train, author_bounds, author_subreddit, x_shape = vectorize(**train_vectorize_options)

Precomputing record size...








In [11]:
# Vectorize the test split with the same fitted transformer and scaler, using
# the exclusion list drawn from the test files.
test_vectorize_options = dict(
    preprocessed_path=COMPUTED_DATA_PATH + 'test.jsonl',
    vectorized_x_path=TEMP_DATA + 'XX_test_' + str(chunk_sz) + '.npy',
    transformer=transformer,
    scaler=scaler,
    chunk_sz=chunk_sz,
    max_comments=MAX_COMMENTS,
    exclude_users=exclude_users_test,
)
XX_test, author_bounds_test, author_subreddit_test, x_shape_test = vectorize(**test_vectorize_options)

Precomputing record size...








In [12]:
# Intermediate checkpoint of the vectorization bookkeeping (8-tuple).
# NOTE(review): a later cell writes this same path again with the generated
# pair indices and labels appended, replacing this version.
vectorization_state = (
    author_bounds,
    author_bounds_test,
    author_subreddit,
    author_subreddit_test,
    x_shape,
    x_shape_test,
    exclude_users,
    exclude_users_test,
)
vectorization_state_path = TEMP_DATA + 'experiment_data' + str(chunk_sz) + '.p'
with open(vectorization_state_path, 'wb') as state_file:
    pickle.dump(vectorization_state, state_file)

Train the classifier
===

In [13]:
# An author key is expected to look like '<root>_<UPPERCASE LETTERS>'; the
# captured group is the root shared by every key derived from one author.
_AUTHOR_KEY_PATTERN = re.compile(r'(.*)_[A-Z]+$')


def _index_authors(bounds, subreddits):
    """Build the lookup tables passed to ``generate_doc_pairs`` for one split.

    Args:
        bounds: dict keyed by author key; only its keys are read here.
        subreddits: dict mapping author key -> subreddit.

    Returns:
        ``(root_to_authors, author_root, subreddit_authors)`` where
        ``root_to_authors`` is a ``defaultdict(set)`` from root to author keys,
        ``author_root`` is a dict from author key to root, and
        ``subreddit_authors`` is a ``defaultdict(list)`` from subreddit to
        author keys.

    Raises:
        ValueError: if an author key does not end in ``_<UPPERCASE LETTERS>``
            (previously this surfaced as an AttributeError on ``None``).
    """
    root_to_authors = defaultdict(set)
    author_root = {}
    for author in bounds.keys():
        match = _AUTHOR_KEY_PATTERN.search(author)
        if match is None:
            raise ValueError(
                'author key %r does not end with "_<UPPERCASE SUFFIX>"' % (author,)
            )
        root = match.group(1)
        root_to_authors[root].add(author)
        author_root[author] = root

    subreddit_authors = defaultdict(list)
    for author, subreddit in subreddits.items():
        subreddit_authors[subreddit].append(author)

    return root_to_authors, author_root, subreddit_authors


# Same indexing for both splits, so train and test cannot diverge.
author_mapping, author_to_root, subreddit_to_author = _index_authors(
    author_bounds, author_subreddit
)
author_mapping_test, author_to_root_test, subreddit_to_author_test = _index_authors(
    author_bounds_test, author_subreddit_test
)

X_idxs_train, Y_train = generate_doc_pairs(author_mapping, subreddit_to_author, author_to_root, author_bounds, author_subreddit)
X_idxs_test, Y_test = generate_doc_pairs(author_mapping_test, subreddit_to_author_test, author_to_root_test, author_bounds_test, author_subreddit_test)



















In [14]:
# Final experiment checkpoint (12-tuple): the vectorization bookkeeping followed
# by the generated pair indices and labels for both splits. This overwrites the
# shorter tuple an earlier cell stored at the same path.
experiment_state = (
    author_bounds,
    author_bounds_test,
    author_subreddit,
    author_subreddit_test,
    x_shape,
    x_shape_test,
    exclude_users,
    exclude_users_test,
    X_idxs_train,
    Y_train,
    X_idxs_test,
    Y_test,
)
experiment_state_path = TEMP_DATA + 'experiment_data' + str(chunk_sz) + '.p'
with open(experiment_state_path, 'wb') as state_file:
    pickle.dump(experiment_state, state_file)

In [15]:
batch_sz = 50000   # pairs per partial_fit call, and size of the fixed test sample
n_epochs = 100     # full passes over the training pairs
CLF_SEED = 42      # fixes the classifier's internal shuffling so runs are repeatable

# NOTE(review): newer scikit-learn releases spell this loss 'log_loss'; confirm
# the installed version before changing the string.
clf = SGDClassifier(loss='log', alpha=0.01, random_state=CLF_SEED)


def _pair_features(XX, left_idxs, right_idxs):
    """Return scaled |XX[left] - XX[right]| features for a batch of pairs.

    NaNs in the scaled output are replaced with 0. Both the training batches
    and the test sample go through this one function so the classifier is
    always evaluated on features prepared exactly like those it was fit on
    (the test sample previously skipped the NaN replacement).

    Args:
        XX: feature matrix indexable by integer arrays.
        left_idxs: row indices of the first document of each pair.
        right_idxs: row indices of the second document of each pair.
    """
    x_diff = secondary_scaler.transform(np.abs(XX[left_idxs] - XX[right_idxs]))
    x_diff[np.isnan(x_diff)] = 0
    return x_diff


# Fixed evaluation set: the first batch_sz test pairs.
x_test_diff_sample = _pair_features(
    XX_test, X_idxs_test[:batch_sz, 0], X_idxs_test[:batch_sz, 1]
)
y_test_sample = Y_test[:batch_sz]

aucs = []  # one entry per epoch: AUC measured after that epoch's last batch
for i in range(n_epochs):
    for idxs in chunker(np.arange(len(X_idxs_train)), batch_sz):
        x_diff = _pair_features(XX_train, X_idxs_train[idxs, 0], X_idxs_train[idxs, 1])
        y = Y_train[idxs]
        clf.partial_fit(x_diff, y, classes=[0, 1])

        # Score the fixed test sample after every batch to watch convergence.
        probs = clf.predict_proba(x_test_diff_sample)[:, 1]

        fpr, tpr, thresh = roc_curve(y_test_sample, probs)
        roc_auc = auc(fpr, tpr)
        print('AUC:', roc_auc)
    print('~'*20, 'Epoch: ', i)
    aucs.append(roc_auc)

AUC: 0.7665618885937436
AUC: 0.7826923280976275
AUC: 0.7866352787606996
AUC: 0.8173270708760466
AUC: 0.8193893421166447
AUC: 0.8228230360764885
AUC: 0.8319298195746292
AUC: 0.833070243844846
AUC: 0.8366660503834911
AUC: 0.8404726303502159
AUC: 0.8404339698637991
AUC: 0.8431527748593364
AUC: 0.8424224591898081
AUC: 0.8465653394078376
AUC: 0.8480789829235569
AUC: 0.8466434261886571
AUC: 0.8478279981677834
AUC: 0.8496910680118542
AUC: 0.8482689449588823
AUC: 0.8525917067439379
AUC: 0.852708362690975
AUC: 0.8526700931190949
AUC: 0.8530177619341397
AUC: 0.8537143235753197
AUC: 0.8524229678433173
AUC: 0.8535062256688645
AUC: 0.8539247413352554
AUC: 0.8552490620570434
AUC: 0.8531894631338521
AUC: 0.8542337817060588
AUC: 0.8545501757560567
AUC: 0.8544678962566202
AUC: 0.8564333728924425
AUC: 0.8555464326676602
AUC: 0.855255830966542
AUC: 0.8575150382218852
AUC: 0.8569836411766885
AUC: 0.8570865109778711
AUC: 0.856828263863102
~~~~~~~~~~~~~~~~~~~~ Epoch:  0
AUC: 0.8574366198045529
AUC: 0.857946

AUC: 0.8634503737384087
AUC: 0.8631085013528471
AUC: 0.863531727218984
AUC: 0.8633242909834453
AUC: 0.863329468998948
AUC: 0.8633567977701262
AUC: 0.8632715255315313
AUC: 0.8632893778294112
AUC: 0.8631091069499571
AUC: 0.8630810652403823
AUC: 0.862849791228647
AUC: 0.8632119110646808
AUC: 0.8629204842775527
AUC: 0.863094282637628
AUC: 0.8632315128653235
AUC: 0.8633259780039664
AUC: 0.8632049867506326
AUC: 0.8631542015051424
~~~~~~~~~~~~~~~~~~~~ Epoch:  8
AUC: 0.8631195831391199
AUC: 0.8632776279637868
AUC: 0.8630959440244093
AUC: 0.8631879002607385
AUC: 0.8634027126049948
AUC: 0.8630532334070545
AUC: 0.8632594520399379
AUC: 0.8632186495341129
AUC: 0.8634747305978425
AUC: 0.8634535603327266
AUC: 0.8632896005225286
AUC: 0.8633866402499439
AUC: 0.863491504676529
AUC: 0.8632780925753264
AUC: 0.8635511047244007
AUC: 0.8632862921679451
AUC: 0.8631716949306608
AUC: 0.8631904812579463
AUC: 0.8632493667653122
AUC: 0.8632548924384154
AUC: 0.8632495269761878
AUC: 0.8635754792070307
AUC: 0.8632911

AUC: 0.8634216655515909
AUC: 0.8635076859749754
AUC: 0.863542733706145
AUC: 0.8633988947798268
AUC: 0.8635784639356451
AUC: 0.8634827683774767
AUC: 0.8636811623111906
AUC: 0.8636621500865704
AUC: 0.8634240174472463
AUC: 0.8636802795492655
AUC: 0.8636428286549603
AUC: 0.8636043476047244
AUC: 0.8636652741986467
AUC: 0.8635200734798851
AUC: 0.8634595842616533
AUC: 0.8634862465555884
AUC: 0.8634668049658212
AUC: 0.8634312637851546
AUC: 0.8634760859818511
AUC: 0.8635728149001678
AUC: 0.8634094526765357
AUC: 0.863622578000271
AUC: 0.8635422482671915
AUC: 0.8635233626091633
AUC: 0.8636258879569632
AUC: 0.8635945346885872
AUC: 0.863528692824998
AUC: 0.8635013496348413
AUC: 0.8634427348838553
AUC: 0.8632752648533701
AUC: 0.8634294325748453
AUC: 0.8633038961389676
AUC: 0.8633172481133491
AUC: 0.8634985523529515
AUC: 0.8634349037762505
AUC: 0.8635126525121224
AUC: 0.8634718211683396
~~~~~~~~~~~~~~~~~~~~ Epoch:  17
AUC: 0.8635904877618669
AUC: 0.863567603240381
AUC: 0.863386970284348
AUC: 0.863531

AUC: 0.8636751800370918
AUC: 0.8636221246034926
AUC: 0.8635422883199105
AUC: 0.8636378332798558
AUC: 0.8636306478220803
AUC: 0.8635971188900124
AUC: 0.8635471090651605
AUC: 0.8635689714412595
AUC: 0.8634374911818933
AUC: 0.8634831705067747
AUC: 0.8634095616199313
AUC: 0.8634616782177986
AUC: 0.8635541022698852
AUC: 0.8635516190013117
AUC: 0.8635700656815406
AUC: 0.8635420800457722
~~~~~~~~~~~~~~~~~~~~ Epoch:  25
AUC: 0.8636098075913683
AUC: 0.863594167805682
AUC: 0.8634879191571307
AUC: 0.8636192263887512
AUC: 0.8635782188130053
AUC: 0.8634573589325898
AUC: 0.8635798866082214
AUC: 0.8635496115590391
AUC: 0.8636784595537175
AUC: 0.8636798437756835
AUC: 0.863560137413573
AUC: 0.8636142294115379
AUC: 0.8636714487257966
AUC: 0.8636109146485196
AUC: 0.8636943861168714
AUC: 0.8636070663832851
AUC: 0.8635762129728415
AUC: 0.8635990494310647
AUC: 0.8635858496570151
AUC: 0.8635636412254248
AUC: 0.8635684795938712
AUC: 0.8636255563204505
AUC: 0.8634741266028411
AUC: 0.863643793124432
AUC: 0.8636

AUC: 0.8635923285848288
AUC: 0.8634888515844272
AUC: 0.863590500578737
AUC: 0.8635504734935504
AUC: 0.8636405728858304
AUC: 0.8636921223371976
AUC: 0.8635974457201989
AUC: 0.863638551024579
AUC: 0.8636388329957202
AUC: 0.8636146940230774
AUC: 0.8636766315476256
AUC: 0.8636463308647033
AUC: 0.8636040271829728
AUC: 0.8636236289836156
AUC: 0.8635998905381621
AUC: 0.8635548504546747
AUC: 0.8635870592491268
AUC: 0.8635905182019332
AUC: 0.8635128159272156
AUC: 0.8636117157028981
AUC: 0.8636195740463515
AUC: 0.8636083865209009
AUC: 0.8636627348562669
AUC: 0.8636416927598515
AUC: 0.8636358899219334
AUC: 0.8635828729389445
AUC: 0.8635678836094136
AUC: 0.8634844313663665
AUC: 0.8635066702380234
AUC: 0.8634019516033353
AUC: 0.8634729586655573
AUC: 0.8635650959401763
AUC: 0.8635654019429487
AUC: 0.8635915051009276
AUC: 0.8635733612192538
~~~~~~~~~~~~~~~~~~~~ Epoch:  34
AUC: 0.8636180584514672
AUC: 0.8636034696491254
AUC: 0.8635390584686563
AUC: 0.8636314584891114
AUC: 0.863586621873436
AUC: 0.8635

AUC: 0.8636180985041861
AUC: 0.8636503457492485
AUC: 0.8636379726633177
AUC: 0.8636370226128246
AUC: 0.8636069686546508
AUC: 0.8635885219744219
AUC: 0.8635114781664035
AUC: 0.8635196761569133
AUC: 0.8634449441918313
AUC: 0.8635029581520334
AUC: 0.8635722221199276
AUC: 0.863583166124847
AUC: 0.8635984886929996
AUC: 0.8635855548690039
~~~~~~~~~~~~~~~~~~~~ Epoch:  42
AUC: 0.8636216840235844
AUC: 0.8636251974480889
AUC: 0.8635694216338201
AUC: 0.8636097611302144
AUC: 0.8635892653528852
AUC: 0.8635345693599191
AUC: 0.863596692729083
AUC: 0.8635332203843455
AUC: 0.8636317612876665
AUC: 0.8636776104360763
AUC: 0.863574692571631
AUC: 0.8636109547012385
AUC: 0.8636378621178133
AUC: 0.8636188915480207
AUC: 0.8636654872791116
AUC: 0.863630221661151
AUC: 0.8636303338087639
AUC: 0.8636417119851566
AUC: 0.8636214725452287
AUC: 0.8635908177962709
AUC: 0.8635778599406437
AUC: 0.8636199521440182
AUC: 0.8635659803042101
AUC: 0.8636667449344857
AUC: 0.863634348693309
AUC: 0.8636231131045958
AUC: 0.863664

AUC: 0.86361609586824
AUC: 0.8635747870960476
AUC: 0.8636584219794927
AUC: 0.8636784611558261
AUC: 0.8635970660204235
AUC: 0.8636262804736086
AUC: 0.863641082356415
AUC: 0.8636195420041763
AUC: 0.8636592262380888
AUC: 0.8636314376616975
AUC: 0.8636384532959448
AUC: 0.8636460825378458
AUC: 0.8636074060303416
AUC: 0.8635759630438753
AUC: 0.8635891083462269
AUC: 0.8636144729320689
AUC: 0.8635453066928088
AUC: 0.8636241208310039
AUC: 0.863619642937028
AUC: 0.8636167559370478
AUC: 0.8636605223440732
AUC: 0.86365325678086
AUC: 0.8636508920683346
AUC: 0.863643041735425
AUC: 0.8636410358952613
AUC: 0.8635678916199573
AUC: 0.8635581299713004
AUC: 0.8634964535904797
AUC: 0.8635309293688229
AUC: 0.863598097778463
AUC: 0.863608167032001
AUC: 0.8636186816717738
AUC: 0.8636081478066959
~~~~~~~~~~~~~~~~~~~~ Epoch:  51
AUC: 0.8636302681223049
AUC: 0.8636352026172766
AUC: 0.8636000988123005
AUC: 0.86364132107062
AUC: 0.8636023994404758
AUC: 0.8635488217194218
AUC: 0.863618547094638
AUC: 0.8635757531676

AUC: 0.8636449081921269
AUC: 0.8636403710201268
AUC: 0.8636256364258883
AUC: 0.8636437018042329
AUC: 0.8635846913323838
AUC: 0.8635600348786125
AUC: 0.8635066846570022
AUC: 0.8635332556307381
AUC: 0.8635870928934107
AUC: 0.8635946244066776
AUC: 0.8636161150935449
AUC: 0.8636062701352326
~~~~~~~~~~~~~~~~~~~~ Epoch:  59
AUC: 0.8636291081955647
AUC: 0.8636342109119559
AUC: 0.8635774241670617
AUC: 0.8636410407015873
AUC: 0.8636083560808345
AUC: 0.8635670633297299
AUC: 0.8636192520224912
AUC: 0.8635861300260477
AUC: 0.8636521064667724
AUC: 0.8636775207179858
AUC: 0.8636060121957229
AUC: 0.8636221262056014
AUC: 0.8636450876283077
AUC: 0.8636249987866028
AUC: 0.86365637608661
AUC: 0.863644135975706
AUC: 0.8636550719700817
AUC: 0.8636633740976609
AUC: 0.8636152916096438
AUC: 0.863593922683042
AUC: 0.8636103138577357
AUC: 0.8636239782433246
AUC: 0.8635618356488556
AUC: 0.8636299957638163
AUC: 0.8636330461788897
AUC: 0.8636253256167894
AUC: 0.8636598590710478
AUC: 0.8636528242114956
AUC: 0.86365

AUC: 0.8635926361897102
AUC: 0.8636496552403742
AUC: 0.8636744478733898
AUC: 0.8635998360664644
AUC: 0.8636189796640026
AUC: 0.863633656582326
AUC: 0.8636237731734039
AUC: 0.8636518949884164
AUC: 0.8636410919690676
AUC: 0.8636581255893727
AUC: 0.863668946231918
AUC: 0.8636505716465832
AUC: 0.8636287236894632
AUC: 0.8636197518804235
AUC: 0.8636386215173641
AUC: 0.863587182611501
AUC: 0.8636550767764082
AUC: 0.8636256188026918
AUC: 0.8636195035535661
AUC: 0.8636514287747682
AUC: 0.8636500942181736
AUC: 0.8636823895264985
AUC: 0.8636670893878687
AUC: 0.8636280668248727
AUC: 0.8635634425639388
AUC: 0.8635698445905321
AUC: 0.8635151325764783
AUC: 0.8635329960891196
AUC: 0.86358552122472
AUC: 0.8635940636686126
AUC: 0.8636201812455705
AUC: 0.8636112671124461
~~~~~~~~~~~~~~~~~~~~ Epoch:  68
AUC: 0.8636297298137624
AUC: 0.8636395299130295
AUC: 0.8635941774183344
AUC: 0.8636566372303374
AUC: 0.8636205096778655
AUC: 0.8635827944356154
AUC: 0.8636261795407569
AUC: 0.8635920065609687
AUC: 0.863652

AUC: 0.8636518068724349
AUC: 0.8636413675317738
AUC: 0.8636330413725632
AUC: 0.8635815143507186
AUC: 0.8635778070710548
AUC: 0.8635286928249981
AUC: 0.8635447892116803
AUC: 0.8635864840920828
AUC: 0.8635952556375279
AUC: 0.86361610868511
AUC: 0.8636077504837243
~~~~~~~~~~~~~~~~~~~~ Epoch:  76
AUC: 0.8636311412715774
AUC: 0.8636345793969701
AUC: 0.8635959589632722
AUC: 0.8636552994695252
AUC: 0.8636154790563684
AUC: 0.863580786993343
AUC: 0.8636279050118882
AUC: 0.8635973463894558
AUC: 0.8636576850094646
AUC: 0.8636785172296328
AUC: 0.8636074925442145
AUC: 0.863618654435925
AUC: 0.8636322018675746
AUC: 0.8636197342572272
AUC: 0.8636476285727964
AUC: 0.8636370642676524
AUC: 0.8636562126715168
AUC: 0.8636661601647895
AUC: 0.8636185118482456
AUC: 0.8635968897884602
AUC: 0.8636068533028203
AUC: 0.8636175009176199
AUC: 0.8635757916182383
AUC: 0.86362437075997
AUC: 0.8636282847116635
AUC: 0.8636280427932413
AUC: 0.8636572236021427
AUC: 0.8636528258136045
AUC: 0.8636856081629914
AUC: 0.8636729

AUC: 0.8636732142496467
AUC: 0.8636026894221607
AUC: 0.863615370112973
AUC: 0.8636254185390972
AUC: 0.8636135036562708
AUC: 0.8636410807543062
AUC: 0.8636362247626636
AUC: 0.8636597773635013
AUC: 0.8636703016159264
AUC: 0.8636189780618938
AUC: 0.863600755676891
AUC: 0.8636058808228048
AUC: 0.8636151682472695
AUC: 0.863570974077206
AUC: 0.8636217449037172
AUC: 0.863624444456973
AUC: 0.8636260770057964
AUC: 0.863648277426843
AUC: 0.8636433573508502
AUC: 0.8636441087398572
AUC: 0.8636420227942554
AUC: 0.8636362439879688
AUC: 0.8635901064599827
AUC: 0.8635848771769996
AUC: 0.863534476437611
AUC: 0.8635474807543921
AUC: 0.8635896819011621
AUC: 0.8635956161119983
AUC: 0.8636146747977723
AUC: 0.8636084137567497
~~~~~~~~~~~~~~~~~~~~ Epoch:  85
AUC: 0.8636297089863487
AUC: 0.8636344239924207
AUC: 0.8635997447462652
AUC: 0.8636540898774137
AUC: 0.8636209935147101
AUC: 0.8635865321553456
AUC: 0.863624226570182
AUC: 0.8635982451724684
AUC: 0.8636552209661961
AUC: 0.8636775319327471
AUC: 0.86360961

AUC: 0.8636335348220605
AUC: 0.863590051988285
AUC: 0.8635881614999517
AUC: 0.8635468062666054
AUC: 0.8635559671244784
AUC: 0.8635938938450844
AUC: 0.8635997847989842
AUC: 0.8636164723637977
AUC: 0.8636112991546212
~~~~~~~~~~~~~~~~~~~~ Epoch:  93
AUC: 0.8636296625251948
AUC: 0.8636341932887597
AUC: 0.8636009158877667
AUC: 0.8636527633313629
AUC: 0.8636193096984065
AUC: 0.8635909011059262
AUC: 0.8636256860912594
AUC: 0.8636006835819968
AUC: 0.863657101841877
AUC: 0.8636728617857202
AUC: 0.8636108409515165
AUC: 0.8636226292677514
AUC: 0.8636339129197272
AUC: 0.8636223537050449
AUC: 0.8636460585062145
AUC: 0.8636382738597639
AUC: 0.8636645051864433
AUC: 0.8636726471031468
AUC: 0.8636444435805873
AUC: 0.8636244829075831
AUC: 0.863613973074137
AUC: 0.8636236530152469
AUC: 0.8635844686392666
AUC: 0.8636303898825705
AUC: 0.8636185182566805
AUC: 0.8636178533815463
AUC: 0.863643020908011
AUC: 0.8636406161427668
AUC: 0.8636774518273092
AUC: 0.8636725509766213
AUC: 0.8636364843042823
AUC: 0.86360

In [17]:
# Per-epoch AUC on the fixed test sample for the current configuration.
# The title is built from the live config values instead of a hand-typed
# comment, so the figure is always labeled with the settings that produced it.
fig = go.Figure(
    layout=go.Layout(
        title='Chunk size {}, capped at {}'.format(chunk_sz, MAX_COMMENTS),
        xaxis=dict(title='Epoch'),
        yaxis=dict(title='AUC on test sample (after last batch of epoch)'),
    )
)
fig.add_trace(go.Scatter(y=aucs))

In [23]:
# Chunk size 20, capped at 80
# NOTE(review): this label disagrees with the configuration cell (chunk_sz=10,
# MAX_COMMENTS=40); the cell looks like it was kept from a run with other
# settings. Re-running it plots whatever `aucs` currently holds -- confirm
# before trusting the label.
auc_trace = go.Scatter(y=aucs)
fig = go.Figure()
fig.add_trace(auc_trace)

In [17]:
# Chunk size 20, capped at 40
# NOTE(review): this label disagrees with the configuration cell (chunk_sz=10);
# the cell looks like it was kept from a run with other settings. Re-running it
# plots whatever `aucs` currently holds -- confirm before trusting the label.
auc_trace = go.Scatter(y=aucs)
fig = go.Figure()
fig.add_trace(auc_trace)

In [19]:
# Persist the trained classifier together with the preprocessing objects and
# the per-epoch AUC history (5-tuple). This replaces the 3-tuple that an
# earlier cell stored at the same path.
trained_model_bundle = (clf, transformer, scaler, secondary_scaler, aucs)
trained_model_path = TEMP_DATA + 'model_' + str(chunk_sz) + '.p'
with open(trained_model_path, 'wb') as model_file:
    pickle.dump(trained_model_bundle, model_file)