In [29]:
import numpy as np
from scipy.stats import rankdata
from scipy.special import softmax
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import json
import pickle
from scipy.stats import rankdata

def mrrs(out, labels):
    """Compute the mean reciprocal rank (MRR) of the gold candidate per question.

    Each row of ``out`` holds the scores of one question's candidates; a higher
    score ranks better. The rank of the gold candidate (the entry at index
    ``label``) comes from ``scipy.stats.rankdata`` applied to the negated
    scores, so the best score gets rank 1 and tied scores share their average
    rank.

    As a side effect a histogram of gold ranks is printed as
    ``Total @ Last: [...]``: slot ``k`` counts the questions whose gold rank,
    truncated to an integer, equals ``k`` (slot 0 stays 0 because ranks start
    at 1).

    Args:
        out: 2-D array-like of shape (n_questions, n_candidates) with scores.
        labels: Sequence holding the gold candidate index of each question.

    Returns:
        The MRR as a float, or 0.0 when ``labels`` is empty (nothing is
        printed in that case).
    """
    if len(labels) == 0:
        return 0.0
    out = np.asarray(out, dtype=float)
    # One slot per possible rank (1..n_candidates) plus the unused slot 0; the
    # size follows the data so any number of candidates per question fits.
    last = [0] * (out.shape[1] + 1)
    mrr = 0.0
    for label, scores in zip(labels, out):
        rank = rankdata(-scores)[label]
        # Ties yield fractional average ranks; int() buckets them downwards.
        last[int(rank)] += 1
        mrr += 1 / rank
    print("Total @ Last:", last)
    return mrr / len(labels)

def precision_at(out,labels,prank=1):
    """Return the fraction of questions whose gold candidate ranks within ``prank``.

    Each row of ``out`` holds the scores of one question's candidates; a higher
    score ranks better. Ranks come from ``scipy.stats.rankdata`` on the negated
    scores, so the best score has rank 1 and tied scores share their average
    rank (a two-way tie for first place gives rank 1.5, which does not count
    as a hit for ``prank=1``).

    Args:
        out: 2-D array-like of shape (n_questions, n_candidates) with scores.
        labels: Sequence holding the gold candidate index of each question.
        prank: Rank cut-off; a question is a hit when its gold rank is
            ``<= prank``.

    Returns:
        Hits divided by ``len(labels)``, or 0.0 when ``labels`` is empty.
    """
    if len(labels) == 0:
        return 0.0
    hits = 0
    # asarray lets plain nested lists be negated row by row like ndarrays.
    for label, scores in zip(labels, np.asarray(out, dtype=float)):
        if rankdata(-scores)[label] <= prank:
            hits += 1
    return hits / len(labels)

def mrrwrapper(qid2c,qid2indexmap,preds_prob):
    """Assemble per-question score rows and evaluate the ranking metrics.

    For every question id in ``qid2c`` the row is built from
    ``preds_prob[ix][1]`` for the first six indices ``ix`` listed in
    ``qid2indexmap[qid]``. Questions with fewer than six indices are skipped.
    The gold candidate index of a kept question is ``int(qid2c[qid])``.

    Returns:
        Tuple ``(mrr, precision_at_1, precision_at_3)`` as computed by
        ``mrrs`` and ``precision_at`` on the assembled rows.
    """
    n_candidates = 6  # fixed row width expected per question
    gold = []
    rows = []
    for qid in qid2c:
        row = []
        for ix in qid2indexmap[qid]:
            if len(row) == n_candidates:
                break  # surplus candidates are ignored
            row.append(preds_prob[ix][1])
        if len(row) == n_candidates:
            rows.append(row)
            gold.append(int(qid2c[qid]))
    score_matrix = np.array(rows)
    return (
        mrrs(score_matrix, gold),
        precision_at(score_matrix, gold, 1),
        precision_at(score_matrix, gold, 3),
    )

def mrrVisualization(n_trainSamples,mrrScore):
    ''' Plots the MRR score against the amount of training data.

    n_trainSamples -- x-axis values (axis is labelled "TrainData")
    mrrScore       -- y-axis values, one per entry of n_trainSamples

    The figure is written to "NN_Tokens_Linear_Mrr.png" in the current
    working directory and then shown.
    '''
    fig, ax = plt.subplots()
    ax.plot(n_trainSamples, mrrScore, marker='x', label='MRR vs TrainData')
    ax.set_title("Scores vs TrainData")
    ax.set_xlabel("TrainData")
    ax.set_ylabel("Scores")
    ax.legend()
    fig.savefig("NN_Tokens_Linear_Mrr.png")
    plt.show()
    

def precisionVisualization(n_trainSamples,precisionAt1,precisionAt3):
    ''' Plots Precision@1 and Precision@3 against the amount of training data.

    n_trainSamples -- x-axis values (axis is labelled "TrainData")
    precisionAt1   -- y values of the Precision@1 curve ('x' markers)
    precisionAt3   -- y values of the Precision@3 curve ('o' markers)

    The figure is written to "NN_Tokens_Linear_Precision.png" in the current
    working directory and then shown.
    '''
    fig, ax = plt.subplots()
    curves = (
        (precisionAt1, 'Precision@1 vs TrainData', 'x'),
        (precisionAt3, 'Precision@3 vs TrainData', 'o'),
    )
    for values, curve_label, curve_marker in curves:
        ax.plot(n_trainSamples, values, label=curve_label, marker=curve_marker)
    ax.set_title("Precisions vs TrainData")
    ax.set_xlabel("TrainData")
    ax.set_ylabel("Precisions")
    ax.legend()
    fig.savefig("NN_Tokens_Linear_Precision.png")
    plt.show()

def load_ranking(fname, data_dir="/scratch/pbanerj6/sml-dataset"):
    """Load the pickled ranking data stored as ``<data_dir>/ranking_<fname>.p``.

    Args:
        fname: Name inserted into the file name, e.g. ``"val"`` or ``"test"``.
        data_dir: Directory that contains the pickle file. The default is the
            dataset location this notebook was written against.

    Returns:
        The unpickled object. Callers in this notebook unpack it as
        ``(qid2c, qid2indexmap)``.
    """
    # NOTE: pickle.load can execute arbitrary code embedded in the file, so
    # only point this at ranking files from a trusted source.
    # The context manager closes the file handle even if unpickling fails.
    with open(data_dir + "/ranking_" + fname + ".p", "rb") as fh:
        return pickle.load(fh)

def accuracyVisualization(n_trainSamples,valAccuracy,testAccuracy):
    ''' Plots test and validation accuracy against the amount of training data.

    n_trainSamples -- x-axis values (axis is labelled "TrainData")
    valAccuracy    -- y values of the validation curve ('.' markers)
    testAccuracy   -- y values of the test curve ('o' markers)

    The figure is written to "NN_Tokens_Linear_Accuracy.png" in the current
    working directory. plt.show() is not called here.
    '''
    fig, ax = plt.subplots()
    # The test curve is drawn first, so it is also listed first in the legend.
    curves = (
        (testAccuracy, 'TestAccuracy vs TrainData', 'o'),
        (valAccuracy, 'ValAccuracy vs TrainData', '.'),
    )
    for values, curve_label, curve_marker in curves:
        ax.plot(n_trainSamples, values, label=curve_label, marker=curve_marker)
    ax.set_title("Scores vs TrainData")
    ax.set_xlabel("TrainData")
    ax.set_ylabel("Scores")
    ax.legend()
    fig.savefig("NN_Tokens_Linear_Accuracy.png")

# Ranking data of the test split: per-question gold index and candidate row indices.
qid2c, qid2indexmap = load_ranking(fname="test")

In [31]:
# Scorer outputs of the six "small-128" models. The val and test files of a
# model sit side by side in that model's output directory, so both lists are
# generated from one sequence of directory names and stay index-aligned.
valnames, testnames = (
    [
        "/scratch/pbanerj6/" + model_dir + "/" + split + ".tsv-score.tsv"
        for model_dir in (
            "sml-class-bert-simple-base-1e5-r1-small-128",
            "sml-class-bert-large-1e5-r1-small-128",
            "sml-class-bert-simple-1e5-r1-small-128",
            "sml-class-bert-simple-large-1e5-r1-small-128",
            "sml-class-bert-cnn-5e6-r1-small-128",
            "sml-class-bert-large-cnn-1e5-r1-small-128",
        )
    ]
    for split in ("val", "test")
)

# Test-split score file of the first model in the list.
filename = testnames[0]

def get_mrr(filename,qid2c,qid2indexmap):
    """Score one tab-separated prediction file with ``mrrwrapper``.

    The last tab-separated field of every line is parsed as a float and stored
    as ``[0, score]``; the 0 is only a filler so the score sits at index 1,
    which is the position ``mrrwrapper`` reads.

    Args:
        filename: Path of the score file, one prediction per line.
        qid2c: Passed through to ``mrrwrapper``.
        qid2indexmap: Passed through to ``mrrwrapper``.

    Returns:
        Whatever ``mrrwrapper`` returns for these predictions.
    """
    with open(filename, "r") as fd:
        preds = [[0, float(row.strip().split("\t")[-1])] for row in fd]
    return mrrwrapper(qid2c=qid2c, qid2indexmap=qid2indexmap, preds_prob=preds)


# The ranking pickles are the same for every model, so read each one from disk
# once here instead of on every loop iteration.
val_ranking = load_ranking("val")
test_ranking = load_ranking("test")

for filename1,filename2 in zip(valnames,testnames):
    # Paths look like /scratch/<user>/<model-dir>/<file>; component 3 of the
    # "/"-split is the model directory.
    print("Model:",filename1.split("/")[3])
    qid2c,qid2indexmap = val_ranking
    print("Val\t",get_mrr(filename1,qid2c,qid2indexmap))
    qid2c,qid2indexmap = test_ranking
    print("Test\t",get_mrr(filename2,qid2c,qid2indexmap))

Model: sml-class-bert-simple-base-1e5-r1-small-128
Total @ Last: [0, 766, 320, 195, 104, 72, 43]
Val	 (0.692088888888891, 0.51, 0.854)
Total @ Last: [0, 719, 348, 192, 142, 63, 36]
Test	 (0.6740349206349232, 0.47933333333333333, 0.8386666666666667)
Model: sml-class-bert-large-1e5-r1-small-128
Total @ Last: [0, 748, 362, 175, 103, 67, 45]
Val	 (0.6890682539682561, 0.498, 0.856)
Total @ Last: [0, 730, 357, 204, 113, 61, 35]
Test	 (0.681633333333336, 0.486, 0.8606666666666667)
Model: sml-class-bert-simple-1e5-r1-small-128
Total @ Last: [0, 756, 319, 167, 108, 80, 70]
Val	 (0.6836481481481497, 0.5033333333333333, 0.828)
Total @ Last: [0, 724, 310, 190, 120, 88, 68]
Test	 (0.6674925925925951, 0.4826666666666667, 0.816)
Model: sml-class-bert-simple-large-1e5-r1-small-128
Total @ Last: [0, 806, 306, 173, 104, 68, 43]
Val	 (0.7085111111111126, 0.536, 0.8566666666666667)
Total @ Last: [0, 767, 313, 179, 112, 75, 54]
Test	 (0.689888888888891, 0.5106666666666667, 0.8393333333333334)
Model: sml-cl

In [34]:
med_val = [
    "/scratch/pbanerj6/sml-class-bert-2e6-r1-med/val.tsv-score.tsv",
    "/scratch/pbanerj6/sml-class-bert-large-v2-5e6-full/val.tsv-score.tsv",
]
med_test = [
    "/scratch/pbanerj6/sml-class-bert-2e6-r1-med/test.tsv-score.tsv",
    "/scratch/pbanerj6/sml-class-bert-large-v2-5e6-full/test.tsv-score.tsv",
]

# Read each ranking pickle once, outside the loop, and close the file handle
# right after loading.
# NOTE: pickle.load can execute arbitrary code embedded in the file; these
# pickles are assumed to be trusted local artifacts.
with open("../xgboost/ranking_val_med.p","rb") as fh:
    med_val_ranking = pickle.load(fh)
with open("../xgboost/ranking_test_med.p","rb") as fh:
    med_test_ranking = pickle.load(fh)

for filename1,filename2 in zip(med_val,med_test):
    # Component 3 of the "/"-split path is the model directory name.
    print("Model:",filename1.split("/")[3])
    qid2c,qid2indexmap = med_val_ranking
    print("Val\t",get_mrr(filename1,qid2c,qid2indexmap))
    qid2c,qid2indexmap = med_test_ranking
    print("Test\t",get_mrr(filename2,qid2c,qid2indexmap))

Model: sml-class-bert-2e6-r1-med
Total @ Last: [0, 28342, 11046, 5482, 2894, 1553, 811]
Val	 (0.7352362713358449, 0.5650734120651133, 0.8950486753909991)
Total @ Last: [0, 28794, 10824, 5277, 2878, 1565, 790]
Test	 (0.7405171016166736, 0.5739506862432173, 0.8955872965209065)
Model: sml-class-bert-large-v2-5e6-full
Total @ Last: [0, 29267, 10890, 5176, 2630, 1428, 737]
Val	 (0.7480311355311282, 0.5835062240663901, 0.904304979253112)
Total @ Last: [0, 29597, 10742, 5008, 2670, 1379, 732]
Test	 (0.7519528261618383, 0.5899696776252793, 0.9046042132141717)


In [35]:

# sml-class-bert-1e5-r1-small-128/eval_results.txt:eval_accuracy_2_Test = 0.7072666666666667  -- CNN - Done
# sml-class-bert-1e5-r1-small-128/eval_results.txt:eval_accuracy_2_Val = 0.7148666666666667 -- CNN - Done

# sml-class-bert-large-1e5-r1-small-128/eval_results.txt:eval_accuracy_1_Test = 0.7211333333333333 -- SEQCLASS - Done
# sml-class-bert-large-1e5-r1-small-128/eval_results.txt:eval_accuracy_1_Val = 0.7240666666666666 -- SEQCLASS


# sml-class-bert-2e6-r1-med/eval_results.txt:eval_accuracy_4_Test = 0.7545613048144557 - SEQCLASS  - Done
# sml-class-bert-2e6-r1-med/eval_results.txt:eval_accuracy_4_Val = 0.7510662660937356 -  SEQCLASS  - Done


# sml-class-bert-cnn-5e6-r1-small-128/eval_results.txt:eval_accuracy_4_Test = 0.7078 -- CNN 
# sml-class-bert-cnn-5e6-r1-small-128/eval_results.txt:eval_accuracy_4_Val = 0.7088 -- CNN 

# sml-class-bert-large-cnn-1e5-r1-small-128/eval_results.txt:eval_accuracy_2_Test = 0.7110666666666666
# sml-class-bert-large-cnn-1e5-r1-small-128/eval_results.txt:eval_accuracy_2_Val = 0.7258



# /scratch/pbanerj6/sml-class-bert-simple-large-1e5-r1-small-128     eval_accuracy_1_Test = 0.7131333333333333
# /scratch/pbanerj6/sml-class-bert-simple-large-1e5-r1-small-128     eval_accuracy_1_Val = 0.7310666666666666

# /scratch/pbanerj6/sml-class-bert-simple-1e5-r1-small-128     eval_accuracy_2_Test = 0.7076
# /scratch/pbanerj6/sml-class-bert-simple-1e5-r1-small-128     eval_accuracy_2_Val = 0.7110666666666666




# Type			Base 	Large
# Models

# Simple 			Done 	Done 
# SeqClass        		Done 
# CNN 			Done	Done 

# python run_scorer.py --data_dir ../datasets/ranking/small --fname val.tsv --bert_model bert-base-cased --task_name obqa --output_dir /scratch/pbanerj6/sml-class-bert-cnn-5e6-r1-small-128 --modeltype cnn2
# python run_scorer.py --data_dir ../datasets/ranking/small --fname test.tsv --bert_model bert-base-cased --task_name obqa --output_dir /scratch/pbanerj6/sml-class-bert-cnn-5e6-r1-small-128 --modeltype cnn2


# python run_scorer.py --data_dir ../datasets/ranking/small --fname val.tsv --bert_model bert-large-cased --task_name obqa --output_dir /scratch/pbanerj6/sml-class-bert-large-cnn-1e5-r1-small-128 --modeltype cnn
# python run_scorer.py --data_dir ../datasets/ranking/small --fname test.tsv --bert_model bert-large-cased --task_name obqa --output_dir /scratch/pbanerj6/sml-class-bert-large-cnn-1e5-r1-small-128 --modeltype cnn

# python run_scorer.py --data_dir ../datasets/ranking/small --fname val.tsv --bert_model bert-base-cased --task_name obqa --output_dir /scratch/pbanerj6/sml-class-bert-simple-1e5-r1-small-128 --modeltype simple
# python run_scorer.py --data_dir ../datasets/ranking/small --fname test.tsv --bert_model bert-base-cased --task_name obqa --output_dir /scratch/pbanerj6/sml-class-bert-simple-1e5-r1-small-128 --modeltype simple

# python run_scorer.py --data_dir ../datasets/ranking/small --fname val.tsv --bert_model bert-large-cased --task_name obqa --output_dir /scratch/pbanerj6/sml-class-bert-simple-large-1e5-r1-small-128  --modeltype simple
# python run_scorer.py --data_dir ../datasets/ranking/small --fname test.tsv --bert_model bert-large-cased --task_name obqa --output_dir /scratch/pbanerj6/sml-class-bert-simple-large-1e5-r1-small-128  --modeltype simple


# python run_scorer.py --data_dir ../datasets/ranking-med/ --fname test.tsv --bert_model bert-large-cased --task_name obqa --output_dir /scratch/pbanerj6/sml-class-bert-large-v2-5e6-full --modeltype seqclass



In [36]:
import torch

In [37]:
# 4x5 tensor filled with ones.
a = torch.ones(4, 5)

In [38]:
# Display the tensor to check its shape and contents.
a

tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])

In [39]:
# 4x3 tensor filled with ones.
b = torch.ones([4, 3])

In [40]:
# Display the tensor to check its shape and contents.
b

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])

In [41]:
# 4x2 tensor filled with zeros.
c = torch.zeros(4, 2)

In [51]:
# Join b and c column-wise into a 4x5 ones/zeros mask and apply it to a element-wise.
a * torch.cat((b, c), dim=1)

tensor([[1., 1., 1., 0., 0.],
        [1., 1., 1., 0., 0.],
        [1., 1., 1., 0., 0.],
        [1., 1., 1., 0., 0.]])