# NIHDAM model 09/04/2018

In [1]:
from training_tools import *
from custom_hdam import CustomHDAMVAE, pad_batch
from custom_cnn_model import CustomCNN
import pickle

In [2]:
import gensim
nm=gensim.models.KeyedVectors.load_word2vec_format('./data/network_undirected.embd')
D=pickle.load(open('./data/docids_linkids_map_clean.pkl'))

In [3]:
HDAMAP=pickle.load(open('/mnt/sshd/vvkulkarni/docid_token_dict_for_hdam.pkl'))

In [4]:
def get_network_features(docid):
    if docid in D:
        linkids=D[docid]
        linkids_str=[unicode(l) for l in linkids] #80 is twitter
        if len(linkids_str):
            return nm[linkids_str].mean(axis=0)
        else:
            return np.zeros(shape=128)
    else:
        return np.zeros(128)

def get_network_features_batch(batch):
    docids=batch.docid.data.cpu().numpy()
    nfeatures=np.array([get_network_features(d) for d in docids])
    if use_cuda:
        return Variable(torch.from_numpy(nfeatures)).float().cuda()
    else:
        return Variable(torch.from_numpy(nfeatures)).float()

In [5]:
def get_hdam_features(docid):
    if docid in HDAMAP:
        return HDAMAP[docid]
    else:
        return [list(s) for s in np.zeros(shape=(10,10)).astype(int)]
    
def get_hdam_features_batch(batch):
    docids=batch.docid.data.cpu().numpy()
    nfeatures=pad_batch([get_hdam_features(d) for d in docids])
    if use_cuda:
        return nfeatures.cuda()
    else:
        return nfeatures

In [6]:
def get_examples(b):
    fnn_view = get_network_features_batch(b)
    hdam_view=get_hdam_features_batch(b)
    cnn_view=b.example.transpose(0,1)
    return (fnn_view,hdam_view, cnn_view)

def get_targets(b):
    return b.target

In [7]:
if use_cuda:
    feature_tensor_type=torch.cuda.LongTensor
    target_tensor_type=torch.cuda.LongTensor
else:
    feature_tensor_type=torch.LongTensor
    target_tensor_type=torch.LongTensor
INDEX=torchtext.data.Field(sequential=False, use_vocab=False)
TEXT=torchtext.data.ReversibleField(tensor_type=feature_tensor_type, fix_length=40)
TARGET=torchtext.data.Field(preprocessing=int, sequential=False, use_vocab=False, tensor_type=target_tensor_type)
SOURCE=torchtext.data.Field(preprocessing=int, sequential=False, use_vocab=False, tensor_type=target_tensor_type)
DOCID=torchtext.data.Field(preprocessing=int, sequential=False, use_vocab=False, tensor_type=target_tensor_type)

In [8]:
train_iterator, val_iterator, test_iterator=load_dataset_fake('./data/title_selected_sources_ids_with_targets_train.input', './data/title_selected_sources_ids_with_targets_val.input','./data/title_selected_sources_ids_with_targets_test_final.input', INDEX, TEXT, TARGET, SOURCE, DOCID, min_freq=1, batch_size=64)

In [9]:
from gensim.models.keyedvectors import KeyedVectors 
word_vectors = KeyedVectors.load_word2vec_format("GoogleNews-vectors-negative300.bin", binary=True) 

In [10]:
wv_matrix = []                                                                                                                                                           
for i in range(len(TEXT.vocab)):                                                                                                                                      
    word = TEXT.vocab.itos[i]                                                                                                                                        
    if word in word_vectors.vocab:                                                                                                                                       
        wv_matrix.append(word_vectors.word_vec(word))                                                                                                                    
    else:  
        wv_matrix.append(np.random.uniform(-0.01, 0.01, 300).astype("float32"))                                                                                          

# one for UNK and one for zero padding                                                                                                                                   
wv_matrix.append(np.random.uniform(-0.01, 0.01, 300).astype("float32"))                                                                                                  
wv_matrix.append(np.zeros(300).astype("float32"))                                                                                                                        
wv_matrix = np.array(wv_matrix)  

In [11]:
params = {                                                                                                                                                                   
        "MODEL": 'non-static',                                                                                                                                                  
        "MAX_SENT_LEN": 40,                                                                           
        "BATCH_SIZE": 100,                                                                                                                                                       
        "WORD_DIM": 300,                                                                                                                                                         
        "VOCAB_SIZE": len(TEXT.vocab),                                                                                                                                        
        "HIDDEN_SIZE": 128,                                                                                                                                      
        "FILTERS": [3,4,5],                                                                                                                                                      
        "FILTER_NUM": [100,100,100],                                                                                                                                             
        "DROPOUT_PROB": 0.5,                                                                                                                                                     
        "NORM_LIMIT": 3,                                                                                                                                                         
        "GPU": -1,                                                                                                                                                      
        "model_prefix":'cnn',
        "WV_MATRIX":wv_matrix
    }                                                                                                                                          

In [12]:
def to_var(x):                                                                                                                                                                   
    if torch.cuda.is_available():                                                                                                                                                
        x = x.cuda()                                                                                                                                                             
    return Variable(x)                                                                                                                                                           


In [13]:
class FusionEncoder(nn.Module):
    def __init__(self, input_fnn_size, fnn_hidden_size, hdam_hidden_size, output_hidden_size, cnn_params):
        super(FusionEncoder, self).__init__()
        self.network_view1=CustomNet(input_fnn_size, fnn_hidden_size)
        self.network_view2=CustomHDAMVAE(n_classes=hdam_hidden_size, return_softmax=False)
        self.network_view3=CustomCNN(**cnn_params)
        self.fc1 = nn.Linear(fnn_hidden_size+hdam_hidden_size+cnn_params['HIDDEN_SIZE'], output_hidden_size) 
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(output_hidden_size, output_hidden_size)
        self.b11=nn.BatchNorm1d(input_fnn_size)
    
    def forward(self, x):
        xb=self.b11(x[0])
        out_view1=self.network_view1(xb)
        self.network_view2.training=self.training
        out_view2, mu, log_var =self.network_view2(x[1])
        self.network_view3.training=self.training
        out_view3=self.network_view3(x[2])
        concat_views=torch.cat((out_view1, out_view2, out_view3),1)
        out = self.fc1(concat_views)
        out = self.relu(out)
        out = self.fc2(out)
        return out, mu, log_var

In [14]:
class FusionDecoder(nn.Module):
    def __init__(self, latent_hidden_size, num_classes):
        super(FusionDecoder, self).__init__()
        self.fc1 = nn.Linear(latent_hidden_size, latent_hidden_size) 
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(latent_hidden_size, num_classes) 
        self.b11=nn.BatchNorm1d(latent_hidden_size)
    
    def forward(self, x):
        out = self.b11(x)
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

In [15]:
class FusionAutoEncoder(nn.Module):
    def __init__(self, input_fnn_size, fnn_hidden_size, hdam_hidden_size, output_hidden_size, cnn_params, num_classes):
        super(FusionAutoEncoder, self).__init__()
        self.encoder = FusionEncoder(input_fnn_size, fnn_hidden_size, hdam_hidden_size, 2*output_hidden_size, cnn_params)
        self.decoder = FusionDecoder(output_hidden_size, num_classes)
        
    def reparameterize(self, mu, log_var):                                                                                                                                       
        """"z = mean + eps * sigma where eps is sampled from N(0, 1)."""                                                                                                         
        eps = to_var(torch.randn(mu.size(0), mu.size(1)))                                                                                                                        
        z = mu + eps * torch.exp(log_var/2)    # 2 for convert var to std                                                                                                        
        return z                                                                                                                                                                 
                                                                                                                                                                                 
    def forward(self, x):                                                                                                                                                        
        h, hdam_mu, hdam_log_var = self.encoder(x)  
        mu, log_var = torch.chunk(h, 2, dim=1)  # mean and log variance.                                                                                                         
        z = self.reparameterize(mu, log_var)                                                                                                                                     
        out = self.decoder(z)                                                                                                                                                    
        return F.log_softmax(out), mu, log_var, hdam_mu, hdam_log_var                                                                                                                                                                   

In [16]:
my_model=FusionAutoEncoder(128, 128, 128, 128, params, 3)

In [17]:
if use_cuda:
    my_model=my_model.cuda()

In [18]:
my_model

FusionAutoEncoder(
  (encoder): FusionEncoder(
    (network_view1): CustomNet(
      (fc1): Linear(in_features=128, out_features=128, bias=True)
      (relu): ReLU()
      (fc2): Linear(in_features=128, out_features=128, bias=True)
    )
    (network_view2): CustomHDAMVAE(
      (word_attn_model): AttentionWordRNN(
        (lookup): Embedding(100000, 300)
        (word_gru): GRU(300, 100, bidirectional=True)
        (softmax_word): Softmax()
      )
      (sent_attn_model): AttentionSentRNNVAE(
        (sent_gru): GRU(200, 200, bidirectional=True)
        (final_linear): Linear(in_features=200, out_features=128, bias=True)
        (softmax_sent): Softmax()
        (final_softmax): Softmax()
      )
    )
    (network_view3): CustomCNN(
      (embedding): Embedding(55072, 300, padding_idx=55071)
      (conv_0): Conv1d(1, 100, kernel_size=(900,), stride=(300,))
      (conv_1): Conv1d(1, 100, kernel_size=(1200,), stride=(300,))
      (conv_2): Conv1d(1, 100, kernel_size=(1500,), stride=(3

In [20]:
for b in train_iterator:
    break

In [27]:
my_model.encoder.network_view2(get_examples(b)[1])

  word_attn_norm = self.softmax_word(word_attn.transpose(1,0))


RuntimeError: input must have 3 dimensions, got 2

In [19]:
def get_kl(mu, log_var):
    klloss= torch.mean(torch.sum(0.5 * (mu**2 + torch.exp(log_var) - log_var -1), 1))
    return klloss

In [20]:
def train_vae(model, optimizer, train_dataset, model_dir, model_prefix, num_epochs, get_examples, get_targets, lr=0.01, max_norm=None, compute_metric=None, eval_data=None, plot_every=50, strict_batch=False):                                                                                                                                                       
    # Always feed examples batch first                                                                                                                                           
    epoch_losses=[]                                                                                                                                                              
    metrics=[]                                                                                                                                                                   
    training_metrics=[]                                                                                                                                                          
    mean_losses=[]                                                                                                                                                               
    train_dataset.iterations=0                                                                                                                                                   
    save_every=1                                                                                                                                                                 
    parameters = filter(lambda p: p.requires_grad, model.parameters())                                                                                                           
    loss_function=nn.NLLLoss()                                                                                                                                                   
    for epoch in np.arange(0,num_epochs):                                                                                                                                        
        batch_losses=[]                                                                                                                                                          
        mean_losses=[]                                                                                                                                                           
        for i, b in enumerate(train_dataset):                                                                                                                                    
            if strict_batch and (b.batch_size != train_dataset.batch_size):                                                                                                      
                continue                                                                                                                                                         
            model.train()                                                                                                                                                        
            model.zero_grad()                                                                                                                                                    
            output, mu, log_var, hdam_mu, hdam_log_var=model(get_examples(b))                                                                                                                                        
            targets=get_targets(b)                                                                                                                                               
            likelihood_loss = loss_function(output, targets)
            kl_loss=get_kl(mu, log_var)
            hdam_loss=get_kl(hdam_mu, hdam_log_var)
            total_loss=likelihood_loss + 0.01*kl_loss+0.001*hdam_loss
            print total_loss.data[0], likelihood_loss.data[0], kl_loss.data[0], hdam_loss.data[0]
            batch_losses.append(total_loss.data[0])                                                                                                                                    
            if (i%plot_every==0):                                                                                                                                                
                mean_losses.append(np.mean(batch_losses))                                                                                                                        
                yvals=np.array(mean_losses)                                                                                                                                      
                xvals=np.arange(0, len(yvals))                                                                                                                                   
                vis.line(Y=yvals, X=xvals, win='batch_loss', opts={'title':'batch_loss'})                                                                                        
            total_loss.backward()                                                                                                                                                      
            if max_norm is not None:                                                                                                                                             
                nn.utils.clip_grad_norm(parameters, max_norm=max_norm)                                                                                                           
            optimizer.step()                                                                                                                                                     
                                                                                                                                                                                 
        epoch_losses.append(np.mean(batch_losses))                                                                                                                               
        if compute_metric is not None:                                                                                                                                           
            tmetric=compute_metric(train_dataset, model, get_examples, get_targets)                                                                                              
            training_metrics.append(tmetric)                                                                                                                                     
            vis.line(np.array(training_metrics), win='training_metric', opts={'title':'training_metric'})                                                                        
            metric=compute_metric(eval_data, model, get_examples, get_targets)                                                                                                   
            metrics.append(metric)                                                                                                                                               
            vis.line(Y=np.array(metrics), X=np.arange(0, len(np.array(metrics))), win='metric', opts={'title':'metric'})                                                         
            vis.line(Y=np.array(epoch_losses), X=np.arange(0, len(np.array(metrics))), win='loss', opts={'title':'loss'})                                                        
        torch.save(model.state_dict(), "{}/{}_{}.dict".format(model_dir, model_prefix, int(epoch)))                                                                              
    return epoch_losses[-1]            

In [None]:
optimizer = optim.Adadelta(filter(lambda p: p.requires_grad, my_model.parameters()), lr=1.0)

In [None]:
train_vae(my_model, optimizer, train_iterator, '/mnt/sshd/vvkulkarni/fakenews_nn_models/vae_std_gauss_attn/', 'vae_std_gauss_attn', 5,  get_examples, get_targets, lr=0.1, max_norm=None, compute_metric=None, eval_data=val_iterator)

1.09893202782 1.09482586384 0.26179921627 1.48818874359
1.10495090485 1.10033726692 0.250787734985 2.10569643974
1.12733221054 1.12311983109 0.237090736628 1.84140563011
1.06069874763 1.05687820911 0.217010840774 1.6504881382
1.15713071823 1.15310835838 0.222701206803 1.79526793957
1.17215895653 1.16831088066 0.224200949073 1.60615873337
1.09699273109 1.09332716465 0.208028867841 1.58522737026
1.01284265518 1.00933766365 0.196828603745 1.53673815727
1.15013206005 1.14653003216 0.200257301331 1.59937942028
1.16921675205 1.16583991051 0.197928875685 1.39766013622
1.1489123106 1.14548969269 0.188274383545 1.53985059261
1.1373090744 1.13395631313 0.188969969749 1.46302556992
1.05769586563 1.05453372002 0.182414993644 1.33795166016
1.11177825928 1.10858845711 0.177874043584 1.41113185883
1.14496815205 1.1419481039 0.179036155343 1.22964465618
1.0587669611 1.05566740036 0.173264592886 1.36694228649
1.09003043175 1.08674228191 0.160062432289 1.68748903275
1.14517235756 1.14189481735 0.1643349

1.13900852203 1.11085140705 2.7353246212 0.803859472275
1.00607919693 0.986951649189 1.82587969303 0.86878734827
0.960458397865 0.93265080452 2.70097994804 0.797819912434
0.899978876114 0.873034834862 2.61759018898 0.768135249615
1.00107324123 0.967038750648 3.32665419579 0.767936944962
0.966722548008 0.954855501652 1.10793614388 0.787649571896
1.12634551525 1.10567140579 1.99134767056 0.760630071163
1.03930008411 1.02096045017 1.7549841404 0.789819419384
0.888615787029 0.855561435223 3.19942688942 1.06004142761
1.10254621506 1.05563151836 4.62110853195 0.703560352325
1.02597975731 0.999144017696 2.58963894844 0.939380824566
0.997835040092 0.981365382671 1.54214894772 1.04814422131
1.05204701424 1.03907847404 1.22575700283 0.711004257202
0.958790779114 0.933824360371 2.42278146744 0.738635420799
0.966952443123 0.919654905796 4.62210988998 1.0764850378
1.07735157013 1.04056739807 3.60401582718 0.743996500969
0.980076789856 0.955128490925 2.4057598114 0.890734314919
0.873185575008 0.8372

1.12119698524 1.08661806583 3.38646125793 0.714343607426
1.00154399872 0.977203428745 2.36464977264 0.694068193436
0.888871371746 0.850272774696 3.78450775146 0.753520667553
0.749295055866 0.720191001892 2.82263803482 0.877670764923
0.890851616859 0.855575084686 3.45899915695 0.68650662899
1.1506267786 1.12251913548 2.74563622475 0.651187062263
1.11715829372 1.08613908291 3.03416466713 0.677558004856
1.0032389164 0.982759416103 1.97797310352 0.699812948704
0.878911793232 0.849653482437 2.86383032799 0.620030879974
0.90973418951 0.883759140968 2.53665709496 0.608512103558
0.92393964529 0.892884671688 3.04721450806 0.582820832729
0.895126461983 0.855896949768 3.85786414146 0.650870084763
0.935267031193 0.890290975571 4.43150806427 0.660966813564
0.878454566002 0.825323879719 5.24006795883 0.730059742928
0.813331902027 0.763585686684 4.90656614304 0.680571198463
0.990056753159 0.944510996342 4.47846221924 0.761141955853
0.954639077187 0.931610763073 2.22108507156 0.817470252514
0.83924740

1.06181001663 1.00923192501 5.1836681366 0.741379737854
0.97790491581 0.931608378887 4.57318735123 0.564682126045
0.879630923271 0.829002737999 4.97987699509 0.829380452633
0.804524838924 0.760558366776 4.33773469925 0.589150846004
1.02691757679 0.987146377563 3.93031406403 0.467980563641
1.15855693817 1.13204264641 2.60379433632 0.476313471794
1.02232563496 0.98519796133 3.66192173958 0.508429884911
0.92256718874 0.892883181572 2.91688156128 0.515160918236
0.956841468811 0.926962196827 2.94139790535 0.465301781893
1.00052261353 0.970448791981 2.96122860909 0.461596220732
0.938953220844 0.911761820316 2.67157673836 0.475671589375
0.933733642101 0.893753647804 3.95272397995 0.45272949338
0.965111672878 0.916332483292 4.83253669739 0.453836679459
0.90256267786 0.86463034153 3.74836587906 0.448688387871
0.779251456261 0.739532887936 3.92449402809 0.473637402058
0.859581947327 0.820973455906 3.8167090416 0.44144192338
0.914421200752 0.863282620907 5.06978654861 0.440718740225
1.17712879181

0.913157045841 0.874133348465 3.85191202164 0.504548788071
0.853657782078 0.811231315136 4.18960189819 0.530443191528
0.864161550999 0.825863838196 3.76549768448 0.642728328705
0.973963618279 0.950643241405 2.26829648018 0.637404441833
0.900965332985 0.864628493786 3.57859039307 0.550953090191
0.864214777946 0.824463307858 3.91984391212 0.552989542484
0.898636162281 0.863831341267 3.43397641182 0.46500903368
0.878988802433 0.800138056278 7.82738208771 0.576923489571
0.791431963444 0.744664132595 4.61507987976 0.617019712925
0.986067056656 0.940088689327 4.54784965515 0.499819904566
0.877097606659 0.842975258827 3.35967540741 0.525571167469
0.922336041927 0.88876324892 3.30359792709 0.53681409359
1.09759426117 1.072083354 2.49549198151 0.555963635445
0.966974079609 0.939761281013 2.66330289841 0.579777777195
1.06295335293 1.03329098225 2.89728879929 0.689461529255
0.886915266514 0.854778409004 3.15657901764 0.571067035198
0.984757602215 0.966462850571 1.76374363899 0.657329380512
0.9592

0.958398580551 0.908850312233 4.88389825821 0.709323465824
0.924733638763 0.876683294773 4.72847700119 0.765577375889
0.781846523285 0.740998744965 4.01395702362 0.708214700222
0.962658166885 0.933066308498 2.89639234543 0.627930343151
0.900661468506 0.859264016151 4.06396770477 0.757773220539
0.886743307114 0.845508933067 4.05858421326 0.64853990078
0.9150583148 0.870594382286 4.3832116127 0.631812632084
0.845640718937 0.792116463184 5.29005670547 0.623732328415
0.996336340904 0.957931160927 3.77636098862 0.641606152058
0.787127435207 0.738002061844 4.84092664719 0.716100633144
0.832794308662 0.77760642767 5.44419670105 0.745864927769
0.854708492756 0.79083865881 6.31366586685 0.733188867569
0.803064227104 0.739817380905 6.25870943069 0.659766972065
1.05413591862 1.00001704693 5.34737873077 0.645064413548
0.862313628197 0.819982349873 4.16503858566 0.680934965611
0.891200661659 0.834196031094 5.63806200027 0.624007701874
0.936648786068 0.890644073486 4.54043769836 0.600314557552
0.890

0.853667259216 0.808739602566 4.43190431595 0.608600199223
0.899272620678 0.854527711868 4.41373968124 0.607518255711
0.909719228745 0.876777172089 3.23789095879 0.563119292259
0.810975909233 0.772493302822 3.79451823235 0.537393629551
0.946593284607 0.901420772076 4.46443748474 0.528167366982
0.883904457092 0.842633664608 4.06792831421 0.591489136219
1.03258287907 0.98137485981 5.06882762909 0.519765198231
0.988447248936 0.950064718723 3.78076171875 0.574870467186
0.941818773746 0.910250127316 3.09032320976 0.665425240993
0.769413053989 0.723358273506 4.55007743835 0.554003000259
0.924219191074 0.883541405201 4.01669168472 0.510892271996
0.935991644859 0.896102130413 3.92752242088 0.614291608334
1.0145316124 0.977863848209 3.60389852524 0.62888532877
1.04767286777 1.0201612711 2.6921479702 0.59007525444
0.810064196587 0.775282084942 3.41914606094 0.59067851305
0.918434202671 0.887575268745 3.02703619003 0.588595032692
1.0792388916 1.0381911993 4.04713869095 0.576203584671
0.8683629632

0.802417159081 0.757010042667 4.48065662384 0.600566506386
0.994846999645 0.948558926582 4.57000398636 0.58803576231
0.85596716404 0.808651268482 4.67074632645 0.60845965147
1.03037345409 0.989272892475 4.05558824539 0.54468524456
0.956381857395 0.930068671703 2.5726146698 0.587073087692
1.05275523663 1.02240252495 2.9867875576 0.48485276103
0.9744836092 0.931500971317 4.23571586609 0.625516951084
0.981001615524 0.949822247028 3.06355500221 0.54382032156
0.936401367188 0.909477233887 2.63719058037 0.552242100239
0.861958682537 0.826267719269 3.51457023621 0.545246839523
0.801005005836 0.75379049778 4.66278600693 0.586655020714
0.850458085537 0.806923806667 4.29005050659 0.633767843246
0.70016092062 0.647242486477 5.22892332077 0.629204750061
0.961685717106 0.906443178654 5.46410942078 0.601455152035
0.863227963448 0.822461366653 4.01405763626 0.626046955585
0.973519742489 0.938851773739 3.40538787842 0.614095807076
0.809643626213 0.772257626057 3.67887067795 0.597300052643
0.9572257399

0.898965299129 0.850761651993 4.76657867432 0.537889122963
0.830085098743 0.785696089268 4.37782812119 0.610684037209
1.00212967396 0.955924451351 4.55033540726 0.701943337917
0.823587119579 0.784595608711 3.83601593971 0.631325662136
0.949287116528 0.910617232323 3.81272625923 0.542647361755
0.946340739727 0.92258888483 2.31892299652 0.562621355057
0.891389787197 0.864390194416 2.64689230919 0.530686974525
1.07090008259 1.04531395435 2.50067281723 0.579298198223
0.962768375874 0.925531804562 3.6666533947 0.570035099983
0.946913361549 0.91536450386 3.08949875832 0.653879225254
0.909896314144 0.876005411148 3.33141851425 0.576720833778
0.829213619232 0.789488732815 3.91474747658 0.577367424965
0.984426021576 0.945536851883 3.82818865776 0.607316255569
0.831294476986 0.791833877563 3.89225578308 0.538030803204
0.900917351246 0.863455832005 3.69215011597 0.540007412434
0.904322504997 0.881645143032 2.21283435822 0.549020469189
0.918959200382 0.889290511608 2.91517996788 0.516891598701
0.9

0.863878190517 0.829143881798 3.40385103226 0.695799231529
0.824519693851 0.769057154655 5.48188924789 0.643699645996
1.06518101692 1.02807247639 3.6495885849 0.612604379654
0.947274029255 0.914783895016 3.18467783928 0.643389582634
0.920918405056 0.877217173576 4.3143196106 0.558008313179
0.913293480873 0.879744887352 3.30194187164 0.529140889645
0.76268774271 0.717050433159 4.51324033737 0.504882097244
0.923473954201 0.878996908665 4.39738941193 0.503196537495
0.873936712742 0.833665132523 3.97762417793 0.495323061943
0.910898625851 0.861916422844 4.84825992584 0.499608844519
0.855580985546 0.802460610867 5.23669624329 0.753413975239
1.07117068768 1.03730404377 3.33426523209 0.52404743433
0.919782519341 0.887721180916 3.14963579178 0.564975678921
0.900782763958 0.87424826622 2.60236978531 0.510832548141
0.882448852062 0.838189840317 4.37295150757 0.529474496841
0.96568775177 0.935078680515 3.00198841095 0.58921533823
0.860311031342 0.818045437336 4.16158676147 0.649730801582
0.814734

0.90486985445 0.866167426109 3.78900504112 0.812352955341
0.913167357445 0.875921070576 3.64537096024 0.792538702488
0.761777579784 0.711647391319 4.93764734268 0.753720760345
1.08085799217 1.02046787739 5.96028423309 0.787225544453
0.853408336639 0.81650364399 3.60839676857 0.820757031441
0.971772968769 0.937743067741 3.32065725327 0.823326170444
0.780881404877 0.737806081772 4.23158311844 0.759458065033
0.859953045845 0.808540701866 5.06041383743 0.80819773674
1.02917981148 0.989899396896 3.85138320923 0.766615152359
1.14313828945 1.09980285168 4.25242471695 0.811209142208
0.87395465374 0.824488818645 4.85197591782 0.946124434471
1.0413287878 0.981565237045 5.88613080978 0.902275323868
0.841018259525 0.790961384773 4.90195083618 1.03737139702
0.886658549309 0.846655905247 3.90388774872 0.963760435581
0.934154689312 0.898168385029 3.5148358345 0.83789986372
0.882598221302 0.840221881866 4.15441036224 0.83226108551
0.793901145458 0.757721662521 3.54237270355 0.75574195385
0.81511795520

0.916011750698 0.885411322117 2.91219687462 1.47844922543
0.804050922394 0.772284269333 2.97584104538 2.00828886032
0.94415551424 0.921900451183 2.20404410362 0.214624732733
0.974956154823 0.951736807823 2.17790484428 1.4402731657
0.980996191502 0.93889850378 4.0777888298 1.31982421875
0.824372887611 0.784175872803 3.88075256348 1.38953089714
0.725774466991 0.665691077709 5.85759830475 1.50740170479
0.864107131958 0.822546601295 4.00555944443 1.50494372845
0.984676420689 0.937381744385 4.60026836395 1.29197847843
0.837011516094 0.795453965664 4.00071191788 1.55045008659
1.01687407494 0.966433048248 4.90182256699 1.42279648781
0.844540297985 0.810864686966 3.24369883537 1.23862373829
0.770743131638 0.729504644871 3.96941900253 1.54430282116
0.886521935463 0.836950004101 4.81022310257 1.4697419405
0.831285953522 0.796596348286 3.33839440346 1.30563855171
0.845263957977 0.795680880547 4.83417797089 1.24131393433
0.825160503387 0.768494784832 5.52269744873 1.43874907494
0.778685688972 0.72

0.81631654501 0.783648192883 3.13761591911 1.29219174385
0.83848553896 0.790410101414 4.67258405685 1.34960961342
0.704360127449 0.649843096733 5.32088088989 1.30822706223
1.03847336769 0.995924174786 4.13188076019 1.23039710522
0.960363030434 0.920406341553 3.88611984253 1.0954887867
0.635043382645 0.594581305981 3.92417860031 1.22029781342
1.5136692524 1.44109535217 7.11573457718 1.41651523113
0.910040259361 0.867138504982 4.14605045319 1.4412664175
1.05121445656 1.03462862968 1.51051986217 1.48055505753
0.873040735722 0.848331511021 2.32135295868 1.49571490288
0.935242116451 0.896426737309 3.73829007149 1.43245899677
0.865468502045 0.836411476135 2.75075483322 1.54950153828
1.09060168266 1.06771099567 2.15169000626 1.37381052971
1.05210042 1.02493834496 2.59745073318 1.18751418591
1.00668621063 0.966026067734 3.90174221992 1.64265024662
0.878578186035 0.843539774418 3.36608552933 1.37760722637
0.827582359314 0.78879737854 3.75345802307 1.25036013126
0.877614855766 0.826903283596 4.9

0.809630095959 0.758524775505 4.9328083992 1.77725088596
0.814041018486 0.788710236549 2.34669780731 1.86377549171
0.746360480785 0.674930155277 6.99137353897 1.51660716534
1.07516586781 1.03139173985 4.23991298676 1.37490737438
0.729595661163 0.672597646713 5.52551698685 1.74284732342
0.897982180119 0.853045105934 4.32731628418 1.66393303871
0.979375302792 0.940001547337 3.79026651382 1.47111403942
0.89087253809 0.845607995987 4.36068058014 1.65770447254
0.881111383438 0.839995384216 3.96487426758 1.4672100544
0.859771668911 0.8136895895 4.46262311935 1.45585024357
0.965008974075 0.925194144249 3.83792901039 1.43550229073
1.00510966778 0.965043365955 3.8497197628 1.56910133362
0.788025915623 0.75033390522 3.60457897186 1.64621460438
0.732545316219 0.66798967123 6.2925567627 1.63007414341
0.803920865059 0.756474256516 4.586581707 1.58079898357
0.820041537285 0.76320797205 5.51454782486 1.68808376789
0.83802741766 0.78578132391 5.07908153534 1.45532631874
0.903696238995 0.848612308502 5

0.993293941021 0.947972476482 4.34450531006 1.87639772892
1.08051812649 1.03595733643 4.27066421509 1.85420465469
0.993496716022 0.951804101467 3.98121285439 1.88049256802
0.891301751137 0.857321381569 3.20685553551 1.91182267666
0.870862364769 0.833895027637 3.49130940437 2.05421376228
0.894003152847 0.833111941814 5.88593912125 2.03183245659
1.03853881359 1.00054430962 3.59510707855 2.04353308678
0.927341103554 0.884829163551 4.05522918701 1.95964920521
0.741846859455 0.70446652174 3.5699634552 1.6807076931
1.1130926609 1.07373452187 3.75063443184 1.851801157
0.791291415691 0.739107251167 5.02362585068 1.94790184498
0.677483916283 0.626615941525 4.90900564194 1.77793598175
1.02289295197 0.961937189102 5.89714288712 1.98438155651
1.10306477547 1.06347155571 3.78247261047 1.768486619
0.894144356251 0.860577225685 3.18430733681 1.72408127785
0.804880619049 0.760959148407 4.1833987236 2.08747625351
0.793816387653 0.742785274982 4.89764595032 2.05463123322
0.890889525414 0.833265244961 5.

0.757869064808 0.703216314316 5.27195358276 1.93324482441
0.839165508747 0.803547024727 3.37705039978 1.84800577164
0.740119636059 0.670632958412 6.70600891113 2.4265639782
0.857538461685 0.806160628796 4.93765115738 2.00136137009
0.835180938244 0.789146244526 4.40374135971 1.9973089695
0.789056122303 0.739237546921 4.79168796539 1.90169286728
0.591439902782 0.514205753803 7.55046272278 1.72952461243
1.06191527843 1.00908720493 5.11004638672 1.72757077217
0.997850298882 0.922911703587 7.20751237869 2.86346292496
0.820239543915 0.754999220371 6.28222799301 2.41805171967
0.865070819855 0.813866078854 4.87975597382 2.40721011162
0.744832754135 0.682704687119 5.97384691238 2.38961434364
0.924106180668 0.874095857143 4.78768539429 2.13351416588
0.950753629208 0.910349369049 3.80777168274 2.32657027245
0.906441807747 0.854857683182 4.95221376419 2.06194329262
0.788385570049 0.74502146244 4.12383985519 2.1257045269
0.972978234291 0.940615177155 3.0382065773 1.98102676868
0.802212536335 0.7418

0.710077166557 0.664511859417 4.37902259827 1.77509891987
0.740535140038 0.693840444088 4.48841953278 1.81047451496
1.0163654089 0.965148329735 4.95068597794 1.71020996571
0.938808739185 0.901164412498 3.57352209091 1.90906715393
0.806834995747 0.759860932827 4.51492834091 1.82478404045
0.783589184284 0.72833687067 5.31395959854 2.11272001266
0.799177646637 0.750543534756 4.66489982605 1.98516118526
0.831488847733 0.782627701759 4.67167377472 2.1443798542
0.912070989609 0.87330365181 3.70826745033 1.68465864658
0.805395364761 0.752384960651 5.11346769333 1.8757686615
0.940414011478 0.87716114521 6.14810323715 1.77180373669
0.98207962513 0.940309524536 3.93023347855 2.46775364876
0.840241968632 0.787321805954 5.05220842361 2.39805150032
0.835847437382 0.785401642323 4.82988786697 2.14691376686
0.595330774784 0.516877651215 7.6200633049 2.25251746178
0.597176790237 0.521459877491 7.3435049057 2.2818171978
0.87883490324 0.799106180668 7.75302457809 2.19846701622
0.701317846775 0.651291966

0.745803892612 0.69867759943 4.49699020386 2.15640521049
0.798092484474 0.743839263916 5.18787908554 2.37442731857
0.702661335468 0.644155502319 5.60737848282 2.43207216263
0.92015504837 0.870687067509 4.71714782715 2.29652523994
0.941851198673 0.902410686016 3.71143746376 2.32610821724
0.890125453472 0.835449635983 5.1969537735 2.70627188683
0.815396726131 0.758305907249 5.47423601151 2.34846758842
0.920646011829 0.864961504936 5.32082843781 2.47623515129
0.787928521633 0.730449616909 5.5251750946 2.22714686394
0.992007672787 0.956096827984 3.39590668678 1.95174753666
0.843550980091 0.796115815639 4.54691410065 1.96599507332
0.891201198101 0.845170676708 4.42591762543 1.77132105827
0.802341520786 0.768663406372 3.16878509521 1.99026930332
0.699125468731 0.626335442066 7.07269573212 2.06310224533
0.698375582695 0.643666863441 5.27114534378 1.99728024006
0.675151526928 0.599314630032 7.3766951561 2.06997084618
0.71960234642 0.659584999084 5.81014537811 1.91590058804
0.966858267784 0.902

0.949448108673 0.905540883541 4.19170761108 1.9901291132
0.921051025391 0.878175735474 4.10740184784 1.80126798153
0.695546805859 0.647222876549 4.65339374542 1.78997182846
0.767598807812 0.714473605156 5.12023448944 1.92286086082
0.845749676228 0.778449177742 6.52663469315 2.03410410881
0.766498982906 0.710189163685 5.42594480515 2.05038237572
0.787827312946 0.731593549252 5.44445943832 1.78915274143
0.973270356655 0.924827635288 4.66119003296 1.83084475994
0.868007421494 0.82041823864 4.52530574799 2.33615756035
0.670866727829 0.6116553545 5.69576644897 2.253698349
0.819514811039 0.760897874832 5.62641906738 2.35279202461
0.806373596191 0.749474465847 5.50002670288 1.89889442921
0.896640121937 0.848844349384 4.56618356705 2.1339828968
0.865875482559 0.822592198849 4.12535762787 2.029733181
0.877120256424 0.841439843178 3.33861446381 2.29425001144
0.825710296631 0.781332492828 4.2350153923 2.0276491642
0.704227805138 0.650557219982 5.11426401138 2.52796244621
0.823375403881 0.76588886

0.889964222908 0.850294291973 3.77512288094 1.91865813732
0.660926282406 0.602518796921 5.63540697098 2.0534620285
0.600234866142 0.52678591013 7.14971494675 1.95181131363
0.66168063879 0.576496899128 8.3259677887 1.92403519154
1.47601532936 1.40649688244 6.71700143814 2.34841346741
0.803266942501 0.748343765736 5.16307783127 3.29239559174
0.804112970829 0.751152992249 4.93641424179 3.59579944611
0.78094792366 0.730366706848 4.72440910339 3.33711600304
0.693079292774 0.637681603432 5.24582147598 2.93948221207
0.814778804779 0.760480761528 5.12058162689 3.09225463867
0.836766183376 0.788326144218 4.55072259903 2.93279743195
1.01759421825 0.961390674114 5.33696699142 2.83388853073
0.771867990494 0.720752179623 4.8011803627 3.10403800011
0.822053909302 0.767799794674 5.08806180954 3.37347555161
0.810449123383 0.764376938343 4.29820394516 3.09013009071
1.04317426682 1.00009918213 3.97525835037 3.32248306274
0.828924059868 0.784666121006 4.12547874451 3.00320672989
0.96212553978 0.899127840

0.758362352848 0.707948088646 4.81356191635 2.27860641479
0.701100826263 0.639465332031 5.94120073318 2.22348308563
0.936555087566 0.88950407505 4.46780443192 2.37300395966
0.788288593292 0.735224545002 5.06675291061 2.39652824402
0.721977710724 0.666752755642 5.28572463989 2.36771082878
0.8340716362 0.771931529045 5.98342609406 2.3058450222
0.91697961092 0.867525875568 4.70896816254 2.36404585838
0.860138475895 0.80364716053 5.36733150482 2.81800913811
0.780359625816 0.704921364784 7.28279352188 2.61034226418
0.962452530861 0.879964709282 7.95875740051 2.90026879311
0.731960713863 0.669694900513 5.9921503067 2.34431743622
0.820047318935 0.748505532742 6.92598867416 2.28191328049
0.685878574848 0.632636427879 5.10427570343 2.19942903519
0.650399982929 0.599801361561 4.84043312073 2.19429755211
0.863146424294 0.813096165657 4.80738306046 1.97645866871
0.836334526539 0.776978611946 5.71598100662 2.1961479187
0.938611745834 0.891737878323 4.47274875641 2.14634585381
0.868833720684 0.81622

0.977624356747 0.936498045921 3.90431022644 2.08319425583
0.74323785305 0.702009260654 3.87419128418 2.48672819138
0.769603013992 0.714620649815 5.29660701752 2.01633167267
0.572735249996 0.507201015949 6.33530759811 2.18120145798
0.874608337879 0.816145956516 5.60576915741 2.4047088623
0.594106674194 0.538430154324 5.33918714523 2.28464841843
0.740122020245 0.683616757393 5.40066814423 2.49859142303
0.914335846901 0.854275465012 5.73624134064 2.6979663372
0.783663511276 0.734720826149 4.65435028076 2.39920663834
0.918962597847 0.86282479763 5.38925170898 2.24529314041
0.706147372723 0.640586435795 6.27712392807 2.78970336914
0.838553845882 0.796907186508 3.88715982437 2.77508711815
0.837980091572 0.780393362045 5.48504686356 2.7362651825
0.765192985535 0.700026392937 6.25326347351 2.63392949104
1.06950724125 1.01799070835 4.90505647659 2.46591639519
0.811432778835 0.765918552876 4.29102134705 2.60402798653
0.705860853195 0.652149677277 5.08876848221 2.82348608971
0.798729121685 0.7392

0.405037969351 0.357707589865 4.30927324295 4.23766279221
0.402843594551 0.351639330387 4.69777917862 4.2264881134
0.362352013588 0.314788788557 4.33206939697 4.24252653122
0.361777186394 0.308168292046 4.92531108856 4.35579967499
0.394357770681 0.34128588438 4.92721319199 3.79978060722
0.328837305307 0.285639435053 3.92658829689 3.93200683594
0.39345857501 0.345024913549 4.41199302673 4.31372976303
0.324905455112 0.274238735437 4.69863176346 3.68041372299
0.413138329983 0.362916499376 4.56737661362 4.54808282852
0.317710042 0.258634030819 5.53179168701 3.75811314583
0.471915125847 0.414668291807 5.31234836578 4.12337303162
0.484307825565 0.429465055466 5.09921360016 3.85063505173
0.616747140884 0.564950525761 4.80619621277 3.73465871811
0.520653605461 0.470631688833 4.58174800873 4.20447015762
0.247285470366 0.196068525314 4.71749687195 4.04198026657
0.401205331087 0.3519333601 4.510014534 4.17183256149
0.299828618765 0.240746930242 5.492664814 4.1550450325
0.500456213951 0.4454123973

0.393077462912 0.34393632412 4.52261924744 3.91493868828
0.54288995266 0.493110120296 4.57592391968 4.02058124542
0.413438826799 0.365291953087 4.41092586517 4.0376162529
0.483424514532 0.439135432243 4.05085849762 3.78051900864
0.543045341969 0.497629076242 4.16166448593 3.7996404171
0.494769245386 0.45048931241 4.03176212311 3.96230173111
0.444117248058 0.400126546621 4.00208711624 3.96982264519
0.58066624403 0.540210962296 3.6744761467 3.7105345726
0.473528891802 0.426652491093 4.32079219818 3.66849851608
0.439662992954 0.394401907921 4.13265371323 3.93455886841
0.480779081583 0.435956448317 4.10059165955 3.81672763824
0.328085839748 0.279468983412 4.45703840256 4.0464720726
0.422431230545 0.374903082848 4.35070133209 4.02114677429
0.625330448151 0.581533670425 3.99962615967 3.80052661896
0.385659396648 0.336418658495 4.54012393951 3.8395011425
0.537667155266 0.494078755379 3.98627448082 3.725659132
0.548371732235 0.503989696503 4.07210969925 3.66088747978
0.352787852287 0.308226108

0.587940216064 0.540017724037 4.41451120377 3.77736592293
0.466914862394 0.411495387554 5.1283826828 4.13563632965
0.401768833399 0.351940065622 4.590883255 3.9199385643
0.285799056292 0.231536284089 4.98920965195 4.37067747116
0.441008418798 0.392057955265 4.52633142471 3.68714666367
0.437012076378 0.385798543692 4.69218111038 4.29171657562
0.317658394575 0.27027246356 4.33130121231 4.07294845581
0.315978825092 0.270380884409 4.13024759293 4.29546260834
0.58449357748 0.539684593678 4.06152868271 4.19372606277
0.346249580383 0.30103635788 4.12291288376 3.98408746719
0.43522977829 0.386402159929 4.52870321274 3.54056429863
0.472343444824 0.430487990379 3.8096575737 3.75888299942
0.503589808941 0.459293723106 4.0808134079 3.48793387413
0.478026747704 0.429293632507 4.47198152542 4.01329183578
0.552696287632 0.504417479038 4.470515728 3.57365584373
0.412604689598 0.372627586126 3.60173726082 3.95971679688
0.467469155788 0.422061920166 4.18132209778 3.59401655197
0.506361186504 0.456971734

0.615765869617 0.564486980438 4.75475072861 3.73135399818
0.36442604661 0.313162565231 4.76972532272 3.56624507904
0.487354010344 0.430127024651 5.36200332642 3.60693836212
0.575262844563 0.530929327011 4.04794692993 3.85402750969
0.586781919003 0.545179724693 3.77218985558 3.88034629822
0.478375971317 0.436295032501 3.85492062569 3.53173589706
0.461021780968 0.415804356337 4.1029253006 4.18818855286
0.427151173353 0.381451666355 4.20223665237 3.67711877823
0.487748801708 0.442446112633 4.14045000076 3.89819145203
0.425345420837 0.381800323725 3.92892241478 4.25586795807
0.383195310831 0.339668750763 3.99971914291 3.52936148643
0.50234234333 0.453623592854 4.49384498596 3.78029251099
0.425073981285 0.381049484015 4.05293798447 3.49511933327
0.429005980492 0.387148708105 3.80506181717 3.80663776398
0.392579376698 0.345452994108 4.30978345871 4.02855205536
0.317441791296 0.27234005928 4.12448740005 3.85685515404
0.35552534461 0.307590246201 4.4491648674 3.44345402718
0.396373510361 0.345

0.4533604375455199

In [32]:
def get_f1_on(dataset_iterator, model, get_examples, get_targets):                                                                                                               
    all_preds=[]                                                                                                                                                                 
    all_targets=[]                                                                                                                                                               
    model.training=False                                                                                                                                                         
    for e in dataset_iterator:                                                                                                                                                   
        model.zero_grad()                                                                                                                                                        
        output,_,_,_,_=model(get_examples(e))                                                                                                                                            
        classix=list(np.argmax(output.cpu().data.numpy(), axis=1))                                                                                                               
        targets=get_targets(e).cpu().data.numpy()                                                                                                                                
        all_preds.extend(classix)                                                                                                                                                
        all_targets.extend(targets)                                                                                                                                              
    model.training=True                                                                                                                                                          
    return f1_score(all_targets, all_preds, average='weighted')                                                                                                                  

In [33]:
get_f1_on(test_iterator, my_model, get_examples, get_targets)

0.8029463835580358

In [None]:
get_f1_on(train_iterator, my_model, get_examples, get_targets)

In [None]:
def get_preds_on(dataset_iterator, model, get_examples):                                                                                                                         
    results=[]                                                                                                                                                                   
    model.training=False                                                                                                                                                         
    for e in dataset_iterator:                                                                                                                                                   
        model.zero_grad()                                                                                                                                                        
        output,_,_,_,_=model(get_examples(e))                                                                                                                                            
        for pred, ix in zip(output, e.index):                                                                                                                                    
            results.append((ix.cpu().data.numpy()[0], pred.cpu().data.numpy()))                                                                                                  
                                                                                                                                                                                 
    model.training=True                                                                                                                                                          
    return results                                                                                                                                                               


In [None]:
train_preds=get_preds_on(train_iterator, my_model, get_examples)
test_preds=get_preds_on(test_iterator, my_model, get_examples)
val_preds=get_preds_on(val_iterator, my_model, get_examples)
import pickle
pickle.dump((train_preds, val_preds, test_preds),open('/mnt/sshd/vvkulkarni/fakenews_nn_models/preds_std_vae_attn.pkl','wb'))

In [52]:
my_model.load_state_dict(torch.load('/mnt/sshd/vvkulkarni/fakenews_nn_models/vae_std_gauss_attn/vae_std_gauss_attn_4.dict'))

In [None]:
my_model=my_model.cuda()
get_f1_on(test_iterator, my_model, get_examples, get_targets)

In [None]:
def get_f1_on(dataset_iterator, model, get_examples, get_targets):                                                                                                               
    all_preds=[]                                                                                                                                                                 
    all_targets=[]                                                                                                                                                               
    model.training=False                                                                                                                                                         
    for e in dataset_iterator:                                                                                                                                                   
        model.zero_grad()                                                                                                                                                        
        output,_,_,_,_=model(get_examples(e)) 
        print output
        classix=list(np.argmax(output.cpu().data.numpy(), axis=1))                                                                                                               
        targets=get_targets(e).cpu().data.numpy()                                                                                                                                
        all_preds.extend(classix)                                                                                                                                                
        all_targets.extend(targets)                                                                                                                                              
    model.training=True                                                                                                                                                          
    return f1_score(all_targets, all_preds, average='weighted')                                                                                                                  

In [55]:
def get_preds_report(dataset_iterator, model, get_examples, get_targets):                                                                                                               
    all_preds=[]                                                                                                                                                                 
    all_targets=[] 
    all_indices=[]
    model.training=False                                                                                                                                                         
    for e in dataset_iterator:                                                                                                                                                   
        model.zero_grad()                                                                                                                                                        
        output,_,_,_,_=model(get_examples(e)) 
        classix=list(np.argmax(output.cpu().data.numpy(), axis=1))                                                                                                               
        targets=get_targets(e).cpu().data.numpy()
        indices=e.index.cpu().data.numpy()
        all_preds.extend(classix)                                                                                                                                                
        all_targets.extend(targets)
        all_indices.extend(indices)
    model.training=True                                                                                                                                                          
    return all_targets, all_preds, all_indices                                                                                                                

In [56]:
train_preds=get_preds_report(train_iterator, my_model, get_examples, get_targets)
test_preds=get_preds_report(test_iterator, my_model, get_examples, get_targets)
val_preds=get_preds_report(val_iterator, my_model, get_examples, get_targets)
import pickle
pickle.dump((train_preds, val_preds, test_preds),open('/mnt/sshd/vvkulkarni/fakenews_nn_models/preds_our_model_attn_vae_clf_report.pkl','wb'))
from sklearn.metrics import classification_report
print classification_report(test_preds[0], test_preds[1])

             precision    recall  f1-score   support

          0       0.79      0.80      0.80      3362
          1       0.87      0.78      0.82      2855
          2       0.72      0.80      0.76      2110

avg / total       0.80      0.80      0.80      8327



In [57]:
print classification_report(test_preds[0], test_preds[1], digits=4)

             precision    recall  f1-score   support

          0     0.7904    0.8040    0.7971      3362
          1     0.8750    0.7793    0.8244      2855
          2     0.7179    0.8043    0.7586      2110

avg / total     0.8010    0.7956    0.7967      8327

