# Testing `TFNoiseAwareModel` in Jupyter Notebook

We'll start by testing the `TextRNN` model on a categorical problem from `tutorials/crowdsourcing`. In particular, we'll test for (a) basic performance and (b) proper construction / re-construction of the TF computation graph, both (i) after repeated notebook calls and (ii) when used with `GridSearch`.

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import os
# Point Snorkel at a SQLite database file named crowdsourcing.db located in the
# current working directory.
# NOTE(review): this is set before `snorkel` is imported, presumably because the
# connection string is read at import time -- confirm before reordering.
os.environ['SNORKELDB'] = 'sqlite:///{0}{1}crowdsourcing.db'.format(os.getcwd(), os.sep)

from snorkel import SnorkelSession
# Session object reused by the later cells for queries, marginal loading,
# scoring and grid search.
session = SnorkelSession()

### Load candidates and training marginals

In [None]:
from snorkel.models import candidate_subclass
# NOTE(review): RawText is not referenced directly in this cell; presumably it
# is imported for a registration side effect -- confirm before removing.
from snorkel.contrib.models.text import RawText

# Candidate class with a single `tweet` argument and five possible categories.
Tweet = candidate_subclass('Tweet', ['tweet'], cardinality=5)

# Training candidates are the ones in split 0, fetched in ascending id order.
train_query = session.query(Tweet).filter(Tweet.split == 0)
train_tweets = train_query.order_by(Tweet.id).all()

# Bare expression: the notebook displays the candidate count as the cell output.
len(train_tweets)

In [None]:
from snorkel.annotations import load_marginals
# Load the training marginals for the training tweets (split 0).
train_marginals = load_marginals(session, train_tweets, split=0)
# Bare expression: the notebook displays the shape as the cell output.
train_marginals.shape

### Train basic LSTM

In [None]:
from snorkel.contrib.rnn import TextRNN

# Hyperparameters forwarded as keyword arguments to `lstm.train`.
train_kwargs = dict(
    dim=100,
    lr=0.01,
    n_epochs=50,
    dropout=0.2,
    print_freq=10,
)

# Fixed seed; the model's cardinality is taken from the Tweet candidate class.
lstm = TextRNN(seed=1701, cardinality=Tweet.cardinality)
lstm.train(train_tweets, train_marginals, **train_kwargs)

In [None]:
import numpy as np

# Test candidates are the ones in split 1, fetched in ascending id order.
test_tweets = (
    session.query(Tweet)
    .filter(Tweet.split == 1)
    .order_by(Tweet.id)
    .all()
)
# Gold labels for the test split, read from a file in the working directory.
test_labels = np.load('crowdsourcing_test_labels.npy')

# `score` returns the correctly and incorrectly classified candidates;
# accuracy is the fraction that landed in the first group.
correct, incorrect = lstm.score(session, test_tweets, test_labels)
n_correct, n_incorrect = len(correct), len(incorrect)
acc = float(n_correct) / (n_correct + n_incorrect)

# Basic-performance check: the trained model must exceed 60% accuracy.
assert acc > 0.60

### Run `GridSearch`

In [None]:
from snorkel.learning.utils import GridSearch
from snorkel.learning import RangeParameter

# Construct a new model instance (same seed and cardinality as before) for the
# search rather than reusing the one trained above.
lstm = TextRNN(seed=1701, cardinality=Tweet.cardinality)

# The grid has two axes: the learning rate (`lr`, on a base-10 log scale) and
# `dim`.
rate_param = RangeParameter('lr', 1e-4, 1e-2, step=1, log_base=10)
dim_param = RangeParameter('dim', 50, 100, step=25)
searcher = GridSearch(
    session, lstm, train_tweets, train_marginals, [rate_param, dim_param]
)

# Remaining training hyperparameters, forwarded to `searcher.fit`.
# NOTE(review): `dim` is also one of the searched parameters; presumably the
# grid value takes precedence over this entry -- confirm.
train_kwargs = dict(
    dim=100,
    n_epochs=50,
    dropout=0.2,
    print_freq=10,
)

# The test split doubles as the validation set here, which is acceptable only
# because this notebook is itself a test. Kept as the final expression so the
# notebook displays whatever `fit` returns.
searcher.fit(test_tweets, test_labels, **train_kwargs)

In [None]:
# Score `lstm` on the test split again now that the grid search has run, and
# apply the same 60% accuracy threshold as before.
correct, incorrect = lstm.score(session, test_tweets, test_labels)
n_correct, n_incorrect = len(correct), len(incorrect)
acc = float(n_correct) / (n_correct + n_incorrect)
assert acc > 0.60