# Slicing CDR Relation Extraction 

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys
sys.path.append('/dfs/scratch0/vschen/metal')

import metal
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
np.set_printoptions(precision=4, suppress=True)

In [None]:
# Record the library and interpreter versions this notebook was run with.
print('PyTorch: ', torch.__version__)
print('MeTaL:   ', metal.__version__)
print('Python:  ', sys.version)
# NOTE(review): same 'Python:' label as the previous line; this one prints
# the structured sys.version_info tuple rather than the version string.
print('Python:  ', sys.version_info)

## Initialize CDR Dataset
To uncompress the SQLite db: ```bzip2 -d cdr.db.bz2```

In [None]:
from metal.contrib.backends.wrapper import SnorkelDataset
import os

# The SQLite database file is looked up in the current working directory.
db_conn_str   = os.path.join(os.getcwd(),"cdr.db")
# Relation name followed by its argument names
# (presumably the Snorkel candidate subclass definition -- confirm).
candidate_def = ['ChemicalDisease', ['chemical', 'disease']]

# Build the train/dev/test datasets in a single call.
# NOTE(review): max_seq_len presumably caps the token sequence length -- confirm.
train, dev, test = SnorkelDataset.splits(db_conn_str, 
                                         candidate_def, 
                                         max_seq_len=125)

# Report the number of examples in each split.
print(f'[TRAIN] {len(train)}')
print(f'[DEV]   {len(dev)}')
print(f'[TEST]  {len(test)}')

## Get Pretrained Embeddings

Download [GloVe embeddings](http://nlp.stanford.edu/data/glove.6B.zip):
`wget http://nlp.stanford.edu/data/glove.6B.zip \
&& mkdir -p glove.6B \
&& unzip glove.6B.zip -d glove.6B \
&& rm glove.6B.zip`

In [None]:
from embeddings import EmbeddingLoader, load_embeddings
# 50-dimensional GloVe vectors in plain-text format.
# NOTE(review): this points at the parent directory, while the download
# instructions in this notebook unzip into ./glove.6B -- confirm the path.
emb_path  = "../glove.6B/glove.6B.50d.txt"
embs  = EmbeddingLoader(emb_path, fmt='text')

## Generate `L_*` to target slices

In [None]:
from labeling_functions import LFs
print ([lf.__name__ for lf in LFs])

In [None]:
%%time 
from snorkel import SnorkelSession
session = SnorkelSession()

from snorkel.annotations import LabelAnnotator
# Apply every labeling function in LFs to each split to get the label matrices.
labeler = LabelAnnotator(lfs=LFs)
L_train = labeler.apply(split=0)
L_dev = labeler.apply(split=1) # used for debugging
L_test = labeler.apply(split=2) # used for evaluation

from snorkel.learning.structure import DependencySelector
# Select LF dependencies from the training label matrix; they are passed to
# the generative model's training call below.
ds = DependencySelector()
deps = ds.select(L_train, threshold=0.1)
from snorkel.learning import GenerativeModel

# need to extract `accs` from gen_model
gen_model = GenerativeModel(lf_propensity=True)
# The step size is scaled down by the number of rows in L_train.
gen_model.train(
    L_train, deps=deps, decay=0.95, step_size=0.1/L_train.shape[0], reg_param=0.0
)

# Learned per-LF accuracies: NaNs are replaced by 0 and values are capped at
# 0.999. `accs` is read as a notebook-level global when slice models are built.
accs = np.array(gen_model.learned_lf_stats()['Accuracy'])
accs[np.isnan(accs)] = 0
accs = np.minimum(accs, 0.999)

# Generative-model marginals for the training label matrix.
gen_marginals = gen_model.marginals(L_train)

In [None]:
# Work on a copy so that L_train keeps its original label values.
L = L_train.copy()
L[L==-1] = 2 # convert to multiclass
# Gold dev labels: the second element of each dev example.
Y_dev = np.array([ex[1] for ex in dev])

In [None]:
from metal.label_model import LabelModel
# Fit the MeTaL label model on the remapped training label matrix.
label_model = LabelModel(k=2, seed=123)
label_model.train_model(L, Y_dev=Y_dev)
# NOTE(review): L_dev is scored as produced by the labeler; unlike L it has
# not had -1 remapped to 2 -- confirm this is intended.
label_model.score((L_dev, Y_dev))

### Weak Labels in Dataset

In [None]:
# Label-model probabilities for the remapped training label matrix L.
metal_marginals = label_model.predict_proba(L)
metal_marginals

In [None]:
# Stack gen_marginals with its complement and transpose, so the first column
# is gen_marginals and the second is 1 - gen_marginals
# (assumes gen_marginals is 1-D -- TODO confirm).
snorkel_marginals = np.vstack((gen_marginals, 1-gen_marginals)).T
snorkel_marginals

In [None]:
from metal.contrib.slicing.sqlite_wrapper \
    import SnorkelDataset as SnorkelSliceDataset

# Training split (split=0) paired with the MeTaL label model's marginals.
train_metal = SnorkelSliceDataset(
    db_conn_str,
    candidate_def,
    split=0,
    train_marginals=metal_marginals
)

# Same split paired with the Snorkel generative model's marginals.
train_snorkel = SnorkelSliceDataset(
    db_conn_str,
    candidate_def,
    split=0,
    train_marginals=snorkel_marginals
)

### Custom Slicing Dataset

In [None]:
# Training split with the dense label matrix attached and no marginals.
train_slice = SnorkelSliceDataset(
    db_conn_str,
    candidate_def,
    split=0,
    L_train=L_train.todense()
)

# Dense label matrix plus the MeTaL label model's marginals.
train_slice_metal = SnorkelSliceDataset(
    db_conn_str,
    candidate_def,
    split=0,
    L_train=L_train.todense(),
    train_marginals=metal_marginals
)

# Dense label matrix plus the Snorkel generative model's marginals.
train_slice_snorkel = SnorkelSliceDataset(
    db_conn_str,
    candidate_def,
    split=0,
    L_train=L_train.todense(),
    train_marginals=snorkel_marginals
)

In [None]:
from metal.contrib.slicing.online_dp import SliceDPModel
from metal.end_model import EndModel
from metal.modules import LSTMModule
def init_model(use_end_model=False, r=None, reweight=None):
    """Build an LSTM-backed model with this notebook's fixed training settings.

    Word embeddings for the notebook-level ``train.word_dict`` are loaded from
    the notebook-level ``embs`` and wrapped in a fresh ``LSTMModule``, which
    becomes the input module of either an ``EndModel`` or a ``SliceDPModel``.

    Args:
        use_end_model: when truthy, build an ``EndModel`` with layer sizes
            ``[200, 2]``; otherwise build a ``SliceDPModel``.
        r: forwarded positionally to ``SliceDPModel``; unused for ``EndModel``.
        reweight: forwarded positionally to ``SliceDPModel``; unused for
            ``EndModel``.

    Returns:
        The constructed model, with learning rate, validation metric, batch
        size and epoch count already written into its train config.
    """
    word_vectors = load_embeddings(train.word_dict, embs)
    encoder = LSTMModule(
        embed_size=50,
        hidden_size=100,
        embeddings=word_vectors,
        lstm_reduction='attention',
        dropout=0.0,
        num_layers=1,
        freeze=False,
    )

    if not use_end_model:
        # The slice model reads the notebook-level LF accuracies `accs`.
        model = SliceDPModel(
            encoder,
            accs,
            r,
            reweight,
            seed=123,
            use_cuda=True,
            input_layer_config=dict(
                input_relu=False,
                input_batchnorm=False,
                input_dropout=0.0,
            ),
        )
    else:
        model = EndModel([200, 2], input_module=encoder, seed=123, use_cuda=True)

    # Both model types get the same training settings.
    train_config = model.config['train_config']
    train_config['optimizer_config']['optimizer_common']['lr'] = 0.01
    train_config['validation_metric'] = 'f1'
    train_config['batch_size'] = 32
    train_config['n_epochs'] = 10
    return model

from metal.modules import LSTMModule
from metal.tuners import RandomSearchTuner


def search_slice_weights(train_loader, dev_loader, r, rw, max_search=1, search_space=None, log_dir='./run_logs'):
    """Random-search ``SliceDPModel`` hyperparameters and return the result.

    A fresh ``LSTMModule`` is built from the notebook-level ``train``
    vocabulary and ``embs`` embedding loader, then handed to a
    ``RandomSearchTuner`` together with the notebook-level LF accuracies
    ``accs``.

    Args:
        train_loader: training data, forwarded as the only positional
            training argument.
        dev_loader: validation data handed to the tuner's ``search``.
        r: third positional ``SliceDPModel`` argument, forwarded unchanged.
        rw: fourth positional ``SliceDPModel`` argument (the reweighting
            flag), forwarded unchanged.
        max_search: forwarded to the tuner's ``search`` as ``max_search``.
        search_space: search-space dict; when ``None`` a discrete grid over
            ``"slice_weight"`` is used.
        log_dir: directory handed to ``RandomSearchTuner``.

    Returns:
        Whatever ``RandomSearchTuner.search`` returns; callers in this
        notebook use it as the trained model.
    """
    wembs = load_embeddings(train.word_dict, embs)
    lstm = LSTMModule(
        embed_size=50,
        hidden_size=100,
        embeddings=wembs,
        lstm_reduction='attention',
        dropout=0.0,
        num_layers=1,
        freeze=False,
    )

    searcher = RandomSearchTuner(SliceDPModel, validation_metric='f1', log_dir=log_dir)

    if search_space is None:
        search_space = {
            "slice_weight": [0, 0.25, 0.5, 0.75, 1.0]
        }

    input_layer_config = {
        "input_relu": False,
        "input_batchnorm": False,
        "input_dropout": 0.0,
    }

    trained_model = searcher.search(
        search_space,
        dev_loader,
        train_args=[train_loader],
        # Forward this function's own `rw` argument; no name `reweight`
        # exists in this scope, so referencing it raises NameError.
        init_args=[lstm, accs, r, rw],
        init_kwargs={"use_cuda": True, "input_layer_config": input_layer_config},
        train_kwargs={
            "lr": 0.01,
            "batch_size": 32,
            "n_epochs": 10
        },
        max_search=max_search
    )
    return trained_model

## (a) `Oracle`: EndModel Trained on Full GT

In [None]:
# EndModel trained directly on the `train` dataset and validated on `dev`.
oracle = init_model(use_end_model=True)
%time oracle.train_model(train, dev_data=dev)
# Precision / recall / F1 on the test split.
oracle_score = oracle.score(test, metric=['precision', 'recall', 'f1'])

## (b) `BaseWeak`: EndModel trained on weak labels

In [None]:
from metal.end_model import EndModel
from metal.modules import LSTMModule

# Same EndModel architecture, trained on the dataset carrying the Snorkel
# generative model's marginals.
base_weak = init_model(use_end_model=True)
%time base_weak.train_model(train_snorkel, dev_data=dev)
# Precision / recall / F1 on the test split.
base_weak_score = base_weak.score(test, metric=['precision', 'recall', 'f1'])

## (e) `SliceOursWeak`: Slice Model with $\tilde{Y}$ priors

In [None]:
# Single-run alternative without hyperparameter search:
# slice_ours_weak = init_model(r=200, reweight=True)
# %time slice_ours_weak.train_model(train_slice_snorkel, dev_data=dev)
# Search "slice_weight" over the range [0, 1.0] on a linear scale.
search_space = {
    "slice_weight": {"range": [0,1.0] ,"scale": "linear"}
}
# Up to 10 search iterations with rw=True.
%time slice_ours_weak = search_slice_weights(train_slice_snorkel, dev, r=200, rw=True, \
                                             max_search=10, search_space=search_space)

# Precision / recall / F1 on the test split.
slice_ours_weak_score = slice_ours_weak.score(test, metric=['precision', 'recall', 'f1'])

## (f) `SliceUWWeak`: Unweighted Slice model with $\tilde{Y}$ priors

In [None]:
# Single-run alternative without hyperparameter search
# (init_model's keyword for the flag is `reweight`):
# slice_uw_weak = init_model(r=200, reweight=False)
# %time slice_uw_weak.train_model(train_slice_snorkel, dev_data=dev)
# Search "slice_weight" over the range [0, 1.0] on a linear scale.
search_space = {
    "slice_weight": {"range": [0,1.0] ,"scale": "linear"}
}
# Up to 10 search iterations with rw=False.
%time slice_uw_weak = search_slice_weights(train_slice_snorkel, dev, r=200, rw=False, \
                                             max_search=10, search_space=search_space)

# Precision / recall / F1 on the test split.
slice_uw_weak_score = slice_uw_weak.score(test, metric=['precision', 'recall', 'f1'])

## Overall Scores

In [None]:
# Test-set scores of all four models, one per line.
print ("Oracle_score:", oracle_score)
print ("base_weak_score:", base_weak_score)
print ("slice_ours_weak_score:", slice_ours_weak_score)
print ("slice_uw_weak_score:", slice_uw_weak_score)

## Slice-specific scores

In [None]:
# TODO: don't call private fns
Yp_oracle, Y = oracle._get_predictions(test)
Yp_base_weak, Y = base_weak._get_predictions(test)
Yp_slice_ours_weak, Y = slice_ours_weak._get_predictions(test)
Yp_slice_uw_weak, Y = slice_uw_weak._get_predictions(test)

#### `slice_ours` (re-weighting, accuracy priors) vs. `base_weak` (end_model trained on weak labels)

In [None]:
# Densify the test label matrix for the per-LF slice comparisons. The guard
# keeps this cell re-runnable: once L_test is already dense it no longer has
# a `todense` method, so an unguarded second run would raise AttributeError.
if hasattr(L_test, "todense"):
    L_test = L_test.todense()

In [None]:
from metal.contrib.slicing.experiment_utils import compare_LF_slices

#### `slice_ours_weak` (slice model with weak priors + reweighting) vs. `base_weak` (end_model trained on weak labels)

In [None]:
# Per-LF comparison of the two models' test predictions, using accuracy.
print ("slice_ours_weak vs base_weak")
compare_LF_slices(Yp_slice_ours_weak, Yp_base_weak,
                  Y, L_test, LFs, metric='accuracy', delta_threshold=0)

#### `slice_ours_weak` vs. `slice_uw_weak` (unweighted slice model)

In [None]:
# Per-LF comparison of the reweighted and unweighted slice models' test
# predictions, using accuracy.
print ("slice_ours_weak vs. slice_uw_weak")
compare_LF_slices(Yp_slice_ours_weak, Yp_slice_uw_weak,
                  Y, L_test, LFs, metric='accuracy', delta_threshold=0)

In [None]:
# Per-LF comparison of the slice model's and the oracle's test predictions,
# using accuracy.
print ("slice_ours_weak vs. oracle")
compare_LF_slices(Yp_slice_ours_weak, Yp_oracle,
                  Y, L_test, LFs, metric='accuracy', delta_threshold=0)

In [None]:
# Per-LF comparison of the unweighted slice model's and the weak-label
# EndModel's test predictions, using accuracy.
print ("slice_uw_weak vs. base_weak")
compare_LF_slices(Yp_slice_uw_weak, Yp_base_weak,
                  Y, L_test, LFs, metric='accuracy', delta_threshold=0)