In [2]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

from snorkel import SnorkelSession

session = SnorkelSession()

from snorkel.models import candidate_subclass

ChemicalDisease = candidate_subclass('ChemicalDisease', ['chemical', 'disease'])

train_cands = session.query(ChemicalDisease).filter(ChemicalDisease.split == 0).all()
dev_cands = session.query(ChemicalDisease).filter(ChemicalDisease.split == 1).all()

In [20]:
import bz2
from six.moves.cPickle import load

with bz2.BZ2File('data/ctd.pkl.bz2', 'rb') as ctd_f:
    ctd_unspecified, ctd_therapy, ctd_marker = load(ctd_f)
    
    
def cand_in_ctd_unspecified(c):
    return 1 if c.get_cids() in ctd_unspecified else 0

def cand_in_ctd_therapy(c):
    return 1 if c.get_cids() in ctd_therapy else 0

def cand_in_ctd_marker(c):
    return 1 if c.get_cids() in ctd_marker else 0

def LF_in_ctd_unspecified(c):
    if(cand_in_ctd_unspecified(c)==1):
        return (-1,1)
    else:
        return (0,0)

def LF_in_ctd_therapy(c):
    if(cand_in_ctd_therapy(c)==1):
        return (-1,1)
    else:
        return (0,0)

def LF_in_ctd_marker(c):
    if(cand_in_ctd_marker(c)==1):
        return (1,1)
    else:
        return (0,0)

In [21]:
def LF_closer_chem(c):
    # Get distance between chemical and disease
    chem_start, chem_end = c.chemical.get_word_start(), c.chemical.get_word_end()
    dis_start, dis_end = c.disease.get_word_start(), c.disease.get_word_end()
    if dis_start < chem_start:
        dist = chem_start - dis_end
    else:
        dist = dis_start - chem_end
    # Try to find chemical closer than @dist/2 in either direction
    sent = c.get_parent()
    closest_other_chem = float('inf')
    for i in range(dis_end, min(len(sent.words), dis_end + dist / 2)):
        et, cid = sent.entity_types[i], sent.entity_cids[i]
        if et == 'Chemical' and cid != sent.entity_cids[chem_start]:
            return (-1,1)
    for i in range(max(0, dis_start - dist / 2), dis_start):
        et, cid = sent.entity_types[i], sent.entity_cids[i]
        if et == 'Chemical' and cid != sent.entity_cids[chem_start]:
            return (-1,1)
    return (0,0)

def LF_closer_dis(c):
    # Get distance between chemical and disease
    chem_start, chem_end = c.chemical.get_word_start(), c.chemical.get_word_end()
    dis_start, dis_end = c.disease.get_word_start(), c.disease.get_word_end()
    if dis_start < chem_start:
        dist = chem_start - dis_end
    else:
        dist = dis_start - chem_end
    # Try to find chemical disease than @dist/8 in either direction
    sent = c.get_parent()
    for i in range(chem_end, min(len(sent.words), chem_end + dist / 8)):
        et, cid = sent.entity_types[i], sent.entity_cids[i]
        if et == 'Disease' and cid != sent.entity_cids[dis_start]:
            return (-1,1)
    for i in range(max(0, chem_start - dist / 8), chem_start):
        et, cid = sent.entity_types[i], sent.entity_cids[i]
        if et == 'Disease' and cid != sent.entity_cids[dis_start]:
            return (-1,1)
    return (0,0)

In [3]:
from load_external_annotations import load_external_labels
load_external_labels(session, ChemicalDisease, split=1, annotator='gold')

from snorkel.annotations import load_gold_labels
L_gold_dev = load_gold_labels(session, annotator_name='gold', split=1)
L_gold_dev

AnnotatorLabels created: 888


<888x1 sparse matrix of type '<type 'numpy.int64'>'
	with 888 stored elements in Compressed Sparse Row format>

In [5]:
#gold_labels_dev = [x[0,0] for x in L_gold_dev.todense()]
#for i,L in enumerate(gold_labels_dev):
#    print(i,gold_labels_dev[i])

gold_labels_dev = []
for i,L in enumerate(L_gold_dev):
    gold_labels_dev.append(L[0,0])
    
    
print(len(gold_labels_dev))
print(gold_labels_dev.count(1),gold_labels_dev.count(-1))

888
(296, 592)


In [8]:
from gensim.parsing.preprocessing import STOPWORDS
import gensim.matutils as gm

from gensim.models.keyedvectors import KeyedVectors

# Load pretrained model (since intermediate data is not included, the model cannot be refined with additional data)
model = KeyedVectors.load_word2vec_format('../glove_w2v.txt', binary=False)  # C binary format


wordvec_unavailable= set()
def write_to_file(wordvec_unavailable):
    with open("wordvec_unavailable.txt","w") as f:
        for word in wordvec_unavailable:
            f.write(word+"\n")

def preprocess(tokens):
    btw_words = [word for word in tokens if word not in STOPWORDS]
    btw_words = [word for word in btw_words if word.isalpha()]
    return btw_words

def get_word_vectors(btw_words): # returns vector of embeddings of words
    word_vectors= []
    for word in btw_words:
        try:
            word_v = np.array(model[word])
            word_v = word_v.reshape(len(word_v),1)
            #print(word_v.shape)
            word_vectors.append(model[word])
        except:
            wordvec_unavailable.add(word)
    return word_vectors

def get_similarity(word_vectors,target_word): # sent(list of word vecs) to word similarity
    similarity = 0
    target_word_vector = 0
    try:
        target_word_vector = model[target_word]
    except:
        wordvec_unavailable.add(target_word+" t")
        return similarity
    target_word_sparse = gm.any2sparse(target_word_vector,eps=1e-09)
    for wv in word_vectors:
        wv_sparse = gm.any2sparse(wv, eps=1e-09)
        similarity = max(similarity,gm.cossim(wv_sparse,target_word_sparse))
    return similarity


In [42]:
##### Continuous ################

softmax_Threshold = 0.3
LF_Threshold = 0.3

import re
from snorkel.lf_helpers import (
    get_left_tokens, get_right_tokens, get_between_tokens,
    get_text_between, get_tagged_text,
)

import re
from snorkel.lf_helpers import (
    get_tagged_text,
    rule_regex_search_tagged_text,
    rule_regex_search_btw_AB,
    rule_regex_search_btw_BA,
    rule_regex_search_before_A,
    rule_regex_search_before_B,
)

def ltp(x):
    return '(' + '|'.join(x) + ')'

causal = ['induced', 'caused', 'due','associated with']

def LF_induce(c):
    return (1,1) if re.search(r'{{A}}.{0,20}induc.{0,20}{{B}}', get_tagged_text(c), flags=re.I) else (0,0)

def LF_causal(c):
    sc = 0
    word_vectors = get_word_vectors(get_between_tokens(c))
    for w in causal:
        sc=max(sc,get_similarity(word_vectors,w))
    if(re.search('{{A}}.{0,50}(not|no|none).{0,20}' + ltp(causal) + '.{0,50}{{B}}', get_tagged_text(c), re.I)):
        return (0,0)
    else:
        return (1,sc)
    
def LF_induce_name(c):
    return (1,1) if 'induc' in c.chemical.get_span().lower() else (0,0)   

    
def LF_c_induced_d(c):
    return (1,1) if (
        ('{{A}} {{B}}' in get_tagged_text(c)) and 
        (('-induc' in c[0].get_span().lower()) or ('-assoc' in c[0].get_span().lower()))
        ) else (0,0)

    
treat = ['treat', 'effective', 'prevent', 'resistant', 'slow', 'promise', 'therap']

def LF_treat(c):
    global LF_Threshold
    sc = 0
    word_vectors = get_word_vectors(get_between_tokens(c))
    for w in treat:
        sc=max(sc,get_similarity(word_vectors,w))
    if(re.search('{{A}}.{0,50}(not|no|none).{0,20}' + ltp(treat) + '.{0,50}{{B}}', get_tagged_text(c), re.I)):
        return (0,0)
    else:
        return (-1,sc)
    
def LF_treat_d(c):
    sc = 0
    word_vectors = get_word_vectors(get_left_tokens(c[1]))
    for w in treat:
        sc=max(sc,get_similarity(word_vectors,w))
    if(re.search('(not|no|none) .{0,50} {{B}}', get_tagged_text(c), re.I)):
        return (0,0)
    else:
        return (-1,sc)
    
def LF_c_d(c):
    return (1,1) if ('{{A}} {{B}}' in get_tagged_text(c)) else (0,0)

    
pat_terms = ['in a patient with ', 'in patients with']
def LF_in_patient_with(c):
    return (-1,1) if re.search(ltp(pat_terms) + '{{B}}', get_tagged_text(c), flags=re.I) else (0,0)

uncertain = ['combin', 'possible', 'unlikely']

def LF_uncertain(c):
    sc = 0
    word_vectors = get_word_vectors(get_left_tokens(c[1]))
    for w in uncertain:
        sc=max(sc,get_similarity(word_vectors,w))
    if(re.search('(not|no|none) .{0,50} {{B}}', get_tagged_text(c), re.I)):
        return (0,0)
    else:
        return (-1,sc)
    
def LF_far_c_d(c):
    if(rule_regex_search_btw_AB(c, '.{100,5000}', -1)==-1):
        return (-1,1)
    else:
        return (0,0)

def LF_far_d_c(c):
    if(rule_regex_search_btw_BA(c, '.{100,5000}', -1)==-1):
        return (-1,1)
    else:
        return (0,0)
    
def LF_develop_d_following_c(c):
    sc1 = 0
    sc2 = 0
    word_vectors = get_word_vectors(get_left_tokens(c[1]))
    sc1=max(sc1,get_similarity(word_vectors,'develop'))
    
    word_vectors = get_word_vectors(get_between_tokens(c))
    sc2=max(sc2,get_similarity(word_vectors,'following'))
    
    sc = (sc1+sc2)/2
    if(re.search('(not|no|none) .{0,50} {{B}}', get_tagged_text(c), re.I)):
        return (0,0)
    else:
        return (1,sc)
    

def LF_risk_d(c):
    sc = 0
    word_vectors = get_word_vectors(get_left_tokens(c[1]))
    sc=max(sc,get_similarity(word_vectors,'risk'))
    if(re.search(' (not|no|none) .{0,50}{{B}}', get_tagged_text(c), re.I)):
        return (0,0)
    else:
        return (1,sc)
    
def LF_improve_before_disease(c):
    sc = 0
    word_vectors = get_word_vectors(get_left_tokens(c[1]))
    sc=max(sc,get_similarity(word_vectors,'improve'))
    if(re.search(' (not|no|none) .{0,50}{{B}}', get_tagged_text(c), re.I)):
        return (0,0)
    else:
        return (1,sc)
    
procedure, following = ['inject', 'administrate'], ['following']
def LF_d_following_c(c):
    sc1 = 0
    sc2 = 0
    word_vectors = get_word_vectors(get_between_tokens(c))
    for w in following:
        sc1=max(sc1,get_similarity(word_vectors,w))
    
    word_vectors = get_word_vectors(get_right_tokens(c[1]))
    for w in procedure:
        sc2=max(sc2,get_similarity(word_vectors,w))
    
    sc = (sc1+sc2)/2
    return (1,sc)

def LF_measure(c):
    sc = 0
    word_vectors = get_word_vectors(get_left_tokens(c[0]))
    sc=max(sc,get_similarity(word_vectors,'measure'))
    return (-1,sc)
    
def LF_level(c):
    sc = 0
    word_vectors = get_word_vectors(get_right_tokens(c[0]))
    sc=max(sc,get_similarity(word_vectors,'level'))
    return (-1,sc)

def LF_neg_d(c):
    return (-1,1) if re.search('(none|not|no) .{0,25}{{B}}', get_tagged_text(c), flags=re.I) else (0,0)

    
WEAK_PHRASES = ['none', 'although', 'was carried out', 'was conducted',
                'seems', 'suggests', 'risk', 'implicated',
               'aim', 'investigate','assess','study']


def LF_weak_assertions(c):
    sc = 0
    word_vectors = get_word_vectors(get_left_tokens(c[1]))
    for w in WEAK_PHRASES:
        sc=max(sc,get_similarity(word_vectors,w))
    if(re.search(' (not|no|none) .{0,50}{{B}}', get_tagged_text(c), re.I)):
        return (0,0)
    else:
        return (1,sc)


In [43]:
def LF_ctd_marker_c_d(c):
    l,s = LF_c_d(c)
    return (l*cand_in_ctd_marker(c),s)

def LF_ctd_marker_induce(c):
    l,s = LF_c_induced_d(c)
    return (l*cand_in_ctd_marker(c),s)

def LF_ctd_therapy_treat(c):
    l,s = LF_treat(c)
    return (l* cand_in_ctd_therapy(c),s)

def LF_ctd_unspecified_treat(c):
    l,s = LF_treat(c)
    return (l * cand_in_ctd_unspecified(c),s)

def LF_ctd_unspecified_induce(c):
    l,s = LF_c_induced_d(c)
    return (l*cand_in_ctd_unspecified(c),s)


In [44]:
import numpy as np
import math

LFs = [LF_in_ctd_unspecified,LF_in_ctd_marker,LF_in_ctd_therapy,LF_closer_chem, 
       LF_closer_dis,LF_causal,LF_c_induced_d,LF_c_d,LF_in_patient_with,LF_uncertain,
       LF_far_c_d,LF_far_d_c,LF_develop_d_following_c,LF_d_following_c,LF_measure,
      LF_level,LF_neg_d,LF_weak_assertions,LF_ctd_marker_c_d,LF_ctd_therapy_treat,
      LF_ctd_unspecified_treat,LF_ctd_unspecified_induce,LF_improve_before_disease,
       LF_risk_d,LF_treat,LF_treat_d,LF_induce,LF_induce_name]


In [45]:
''' output:

    [[[L_x1],[S_x1]],
     [[L_x2],[S_x2]],
     ......
     ......
    ]

'''
def get_L_S_Tensor(cands): 
    
    L_S = []
    for ci in cands:
        L_S_ci=[]
        L=[]
        S=[]
        P_ik = []
        for LF in LFs:
            #print LF.__name__
            l,s = LF(ci)
            L.append(l)
            S.append((s+1)/2)  #to scale scores in [0,1] 
        L_S_ci.append(L)
        L_S_ci.append(S)
        L_S.append(L_S_ci) 
    return L_S

def get_L_S(cands):  # sign gives label abs value gives score
    
    L_S = []
    for ci in cands:
        l_s=[]
        for LF in LFs:
            #print LF.__name__
            l,s = LF(ci)
            s= (s+1)/2  #to scale scores in [0,1] 
            l_s.append(l*s)
        L_S.append(l_s)
    return L_S

def get_Initial_P_cap_L_S(L_S):
    P_cap = []
    for L,S in L_S:
        P_ik = []
        denominator=float(L.count(1)+L.count(-1))
        if(denominator==0):
            denominator=1
        P_ik.append(L.count(1)/denominator)
        P_ik.append(L.count(-1)/denominator)
        P_cap.append(P_ik)
    return P_cap



In [48]:
# import matplotlib.pyplot as plt
   
    
dev_L_S = get_L_S_Tensor(dev_cands)
train_L_S = get_L_S_Tensor(train_cands)
# test_L_S = get_L_S_Tensor(test_cands)


# train_P_cap= get_Initial_P_cap_L_S(train_L_S) 

# dev_P_cap = get_Initial_P_cap_L_S(dev_L_S)

# test_P_cap = get_Initial_P_cap_L_S(test_L_S)

import cPickle as pkl

pkl.dump(dev_L_S,open("dev_L_S.p","wb"))
pkl.dump(train_L_S,open("train_L_S.p","wb"))
# pkl.dump(test_L_S,open("test_L_S.p","wb"))

# pkl.dump(train_P_cap,open("train_P_cap.p","wb"))
# pkl.dump(dev_P_cap,open("dev_P_cap.p","wb"))
# pkl.dump(test_P_cap,open("test_P_cap.p","wb"))

In [89]:
#prepare batch data
# train_L_S_batch,dev_L_S_batch = get_L_S_batch()
# train_P_cap_batch,dev_P_cap_batch = get_P_cap_batch()

In [50]:
from sklearn.metrics import precision_recall_fscore_support

import cPickle as pkl


# pkl.dump(dev_L_S,open("dev_L_S.p","wb"))
# pkl.dump(train_L_S,open("train_L_S.p","wb"))
#pkl.dump(test_L_S,open("test_L_S.p","wb"))

#pkl.dump(train_P_cap,open("train_P_cap.p","wb"))
#pkl.dump(dev_P_cap,open("dev_P_cap.p","wb"))
#pkl.dump(test_P_cap,open("test_P_cap.p","wb"))

dev_L_S = pkl.load( open( "dev_L_S.p", "rb" ) )
train_L_S = pkl.load( open( "train_L_S.p", "rb" ) )
# test_L_S = pkl.load( open( "test_L_S.p", "rb" ) )

# train_P_cap = pkl.load( open( "train_P_cap.p", "rb" ) )
# dev_P_cap = pkl.load( open( "dev_P_cap.p", "rb" ) )
# test_P_cap = pkl.load( open( "test_P_cap.p", "rb" ) )

def get_L_S_batch():
    dev_L_batch = []
    dev_S_batch = []
    dev_L_S_batch = []
    train_L_batch = []
    train_S_batch = []
    train_L_S_batch = []
    for l,s in train_L_S:
        train_L_batch.append(l)
        train_S_batch.append(s)
    train_L_S_batch = [train_L_batch, train_S_batch]
    for l,s in dev_L_S:
        dev_L_batch.append(l)
        dev_S_batch.append(s)
    dev_L_S_batch = [dev_L_batch, dev_S_batch]
    return train_L_S_batch,dev_L_S_batch


def get_P_cap_batch():
    kp1_train= []
    kn1_train = []
    kp1_dev= []
    kn1_dev = []
    for pci in train_P_cap:
        kp1_train.append(pci[0])
        kn1_train.append(pci[1])
    for pci in dev_P_cap:
        kp1_dev.append(pci[0])
        kn1_dev.append(pci[1])
    return [kp1_train,kn1_train],[kp1_dev,kn1_dev]
        
def get_mini_batches(X,P_cap,bsize): #X : (train/dev/)_L_S_batch
    for i in range(0, len(X[0]) - bsize + 1, bsize):
        indices = slice(i, i + bsize)
        #print(indices)
        yield [X[0][indices],X[1][indices]],P_cap[indices]

# train_L_S_batch,dev_L_S_batch = get_L_S_batch()

#for x in get_mini_batches(train_L_S_batch,200):
#    print(len(x),len(x[0]),len(x[0][0]))
    


In [63]:
#stochastic + weighted cross entropy logits func + remove min(theta,0) in loss 
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.contrib.tensorboard.plugins import projector

def train_NN():
    print()
    result_dir = "./"
    config = projector.ProjectorConfig()
    tf.logging.set_verbosity(tf.logging.INFO)
    summary_writer = tf.summary.FileWriter(result_dir)

    tf.reset_default_graph()

    dim = 2 #(labels,scores)

    _x = tf.placeholder(tf.float64,shape=(dim,len(LFs)))

    alphas = tf.get_variable('alpha', _x.get_shape()[-1],initializer=tf.constant_initializer(0.2),
                            dtype=tf.float64)

    thetas = tf.get_variable('theta', _x.get_shape()[-1],initializer=tf.constant_initializer(0.2),
                            dtype=tf.float64)

    l,s = tf.unstack(_x)

    prelu_out_s = tf.maximum(tf.subtract(s,alphas), tf.zeros(shape=(len(LFs)),dtype=tf.float64))        

    mul_L_S = tf.multiply(l,prelu_out_s)

    phi_p1 = tf.reduce_sum(tf.multiply(mul_L_S,thetas))

    phi_n1 = tf.reduce_sum(tf.multiply(tf.negative(mul_L_S),thetas))

    phi_out = tf.stack([phi_n1,phi_p1])
    
    predict = tf.argmax(tf.nn.softmax(phi_out))

    loss = tf.negative(tf.reduce_logsumexp(phi_out))

    train_step = tf.train.GradientDescentOptimizer(0.0001).minimize(loss) 


    check_op = tf.add_check_numerics_ops()

    sess = tf.Session()
    init = tf.global_variables_initializer()
    sess.run(init)

    for i in range(1):
        c = 0
        te_prev=1
        total_te = 0
        for L_S_i in train_L_S:

            a,t,te_curr,_ = sess.run([alphas,thetas,loss,train_step],feed_dict={_x:L_S_i})
            total_te+=te_curr

            if(abs(te_curr-te_prev)<1e-200):
                break

            if(c%500==0):
                pl = []
                t_de=0
                for L_S_i in dev_L_S:
                    a,t,de_curr,p = sess.run([alphas,thetas,loss,predict],feed_dict={_x:L_S_i})
                    pl.append(p)
                    t_de+=de_curr
                predicted_labels = [-1 if x==0 else x for x in pl]
                print("dev err:",t_de/888)
                print(total_te/500)
                total_te=0
                print(a)
                print(t)
                print()
                print(predicted_labels.count(-1),predicted_labels.count(1))
                print(c," d ",precision_recall_fscore_support(np.array(gold_labels_dev),np.array(predicted_labels),average='macro'))
            c+=1
            te_prev = te_curr
        pl = []
        for L_S_i in dev_L_S:
            p = sess.run(predict,feed_dict={_x:L_S_i})
            pl.append(p)
        predicted_labels = [-1 if x==0 else x for x in pl]
        print(i,total_te)
        print(predicted_labels.count(-1),predicted_labels.count(1))
        print(precision_recall_fscore_support(np.array(gold_labels_dev),np.array(predicted_labels),average='macro'))
    
train_NN()


dev err: -0.808128308193
-0.00151581608805
[ 0.2         0.20000697  0.19999303  0.2         0.2         0.20000697
  0.2         0.2         0.2         0.19999303  0.2         0.2
  0.20000697  0.20000697  0.19999303  0.19999303  0.19999303  0.19999303
  0.2         0.19999303  0.2         0.2         0.20000697  0.20000697
  0.19999303  0.19999303]
[ 0.2         0.19997212  0.20002788  0.2         0.2         0.19998157
  0.2         0.2         0.2         0.20001714  0.2         0.2
  0.19998261  0.19998356  0.20001656  0.20001677  0.20002029  0.20002135
  0.2         0.20001834  0.2         0.2         0.19998443  0.19998413
  0.20001834  0.20001834]

805 83
0  d  (0.67499064581306589, 0.56672297297297292, 0.54548665345831537, None)
dev err: -0.835173447317
-0.80440284948
[ 0.19755467  0.20137045  0.19856294  0.19922373  0.19993977  0.20328807
  0.1997506   0.19974915  0.2         0.19683756  0.1995668   0.19952999
  0.20311313  0.20320042  0.19686939  0.19668855  0.19675667  0.

dev err: -2.02011446481
-1.87574590604
[ 0.12247436  0.22423201  0.16960538  0.18014492  0.19806343  0.23604349
  0.20110288  0.20206738  0.19981979  0.10127702  0.19000269  0.19275152
  0.23981858  0.24360987  0.10379307  0.09467813  0.09333744  0.09861487
  0.20167722  0.17357607  0.13759376  0.20097895  0.2416233   0.24130609
  0.10143693  0.10134621]
[ 0.41238783  0.04272337  0.29926804  0.26863834  0.20761203  0.01718905
  0.19554287  0.19156475  0.20071974  0.38860911  0.23685005  0.22727022
  0.03247346  0.04519465  0.37196367  0.39195348  0.41278226  0.4042859
  0.19318325  0.26402082  0.33275174  0.19604839  0.04987857  0.04733578
  0.39449853  0.3871097 ]

884 4
5500  d  (0.45814479638009048, 0.49915540540540543, 0.40238482384823854, None)
dev err: -2.25717250296
-2.14537073727
[ 0.10854275  0.22517874  0.16509213  0.1775073   0.19796135  0.2362943
  0.20151645  0.20260914  0.19981979  0.08384136  0.18882292  0.1921948
  0.24080837  0.24526978  0.08690064  0.07580321  0.07423

In [85]:
#stochastic + weighted cross entropy logits func + remove min(theta,0) in loss 
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.contrib.tensorboard.plugins import projector

def train_NN():
    print()
    result_dir = "./"
    config = projector.ProjectorConfig()
    tf.logging.set_verbosity(tf.logging.INFO)
    summary_writer = tf.summary.FileWriter(result_dir)

    tf.reset_default_graph()

    dim = 2 #(labels,scores)

    _x = tf.placeholder(tf.float64,shape=(dim,len(LFs)))

    alphas = tf.get_variable('alpha', _x.get_shape()[-1],initializer=tf.constant_initializer(0.2),
                            dtype=tf.float64)

    thetas = tf.get_variable('theta', _x.get_shape()[-1],initializer=tf.constant_initializer(1),
                            dtype=tf.float64)

    l,s = tf.unstack(_x)

    prelu_out_s = tf.maximum(tf.subtract(s,alphas), tf.zeros(shape=(len(LFs)),dtype=tf.float64))        

    mul_L_S = tf.multiply(l,prelu_out_s)

    phi_p1 = tf.reduce_sum(tf.multiply(mul_L_S,thetas))

    phi_n1 = tf.reduce_sum(tf.multiply(tf.negative(mul_L_S),thetas))

    phi_out = tf.stack([phi_n1,phi_p1])
    
    predict = tf.argmax(tf.nn.softmax(phi_out))

    loss = tf.negative(tf.reduce_logsumexp(phi_out))

    train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss) 


    check_op = tf.add_check_numerics_ops()

    sess = tf.Session()
    init = tf.global_variables_initializer()
    sess.run(init)

    for i in range(1):
        c = 0
        te_prev=1
        total_te = 0
        for L_S_i in train_L_S:

            a,t,te_curr,_ = sess.run([alphas,thetas,loss,train_step],feed_dict={_x:L_S_i})
            total_te+=te_curr

            if(abs(te_curr-te_prev)<1e-200):
                break

            if(c%4000==0):
                pl = []
                for L_S_i in dev_L_S:
                    a,t,de_curr,p = sess.run([alphas,thetas,loss,predict],feed_dict={_x:L_S_i})
                    pl.append(p)
                predicted_labels = [-1 if x==0 else x for x in pl]
                print()
                print(total_te/4000)
                total_te=0
#                 print(a)
#                 print(t)
#                 print()
                print(predicted_labels.count(-1),predicted_labels.count(1))
                print(c," d ",precision_recall_fscore_support(np.array(gold_labels_dev),np.array(predicted_labels),average='macro'))
            c+=1
            te_prev = te_curr
        pl = []
        for L_S_i in dev_L_S:
            p = sess.run(predict,feed_dict={_x:L_S_i})
            pl.append(p)
        predicted_labels = [-1 if x==0 else x for x in pl]
        print(i,total_te)
        print(predicted_labels.count(-1),predicted_labels.count(1))
        print(precision_recall_fscore_support(np.array(gold_labels_dev),np.array(predicted_labels),average='macro'))
    
train_NN()



-0.000226377142841
2251 545
0  d  (0.58865213829531426, 0.71341836734693875, 0.60408179957052133, None)

-1.94518369934e+28
2232 564
4000  d  (0.57825249752154351, 0.6933045525902668, 0.58885935866155448, None)

-5.04415736866e+58
2232 564
8000  d  (0.57825249752154351, 0.6933045525902668, 0.58885935866155448, None)

-1.33431810995e+89
2232 564
12000  d  (0.57825249752154351, 0.6933045525902668, 0.58885935866155448, None)

-3.67295517678e+119
2232 564
16000  d  (0.57825249752154351, 0.6933045525902668, 0.58885935866155448, None)

-9.52453175821e+149
2232 564
20000  d  (0.57825249752154351, 0.6933045525902668, 0.58885935866155448, None)
0 -1.71979948062e+170
2232 564
(0.57825249752154351, 0.6933045525902668, 0.58885935866155448, None)
