-- In [4]:
require('nn')
require('hdf5')
require('optim')

-- Load the preprocessed datasets from data.hdf5. These are globals on purpose:
-- the driver code further down reads X_train, X_valid, q_test, etc. by name.
f = hdf5.open("data.hdf5", "r")
local function dataset(name)
    return f:read(name):all()
end
X_train = dataset("train_input")
Y_train = dataset("train_output")
X_valid = dataset("valid_input")
q_valid = dataset("valid_q")
X_test = dataset("test_input")
q_test = dataset("test_q")
kaggle_valid = dataset("valid_kaggle")
-- Scalar metadata: take the first element of the stored array
-- (presumably a 1-element array — confirm against the preprocessing script).
nclasses = dataset('nclasses'):long()[1]
nfeatures = dataset('nfeatures'):long()[1]
f:close()

-- nSen: capacity for sentence starts (idxSen) and row count of the sentence
-- matrix (senMatrix). lenSen: column count of that matrix.
nSen = 41931
lenSen = 85

window_size = X_train:size(2)
n_actual = nfeatures

--- Find the row of every window that opens a sentence.
-- A row counts as a sentence start when its first three entries are all 3
-- (presumably the sentence-padding token id — confirm against preprocessing).
-- @param X 2-D tensor of word-id windows, one window per row
-- @param n capacity of the returned index tensor; finding more than n starts
--   raises an index error
-- @return 1 x n tensor whose leading entries are the start row numbers in
--   order; unused slots remain 0
function idxSen(X, n)
    local PAD = 3
    local starts = torch.zeros(1, n)
    local found = 0
    for r = 1, X:size(1) do
        local w = X[r]
        if w[1] == PAD and w[2] == PAD and w[3] == PAD then
            found = found + 1
            starts[1][found] = r
        end
    end
    return starts
end

--- Rebuild one row per sentence from the overlapping windows in X.
--
-- The code copies every 5th window whole, so it appears to assume 5-word
-- windows that overlap with a stride of one row (assumption — confirm that
-- window_size == 5 and the stride is 1). When the next whole copy would cross
-- the next sentence start, only the columns of the last window before that
-- start that were not already covered are copied.
--
-- @param X 2-D tensor of windows, one per row
-- @param new_sen_idx 1 x n tensor of sentence-start row numbers (as built by
--   idxSen)
-- @param sen_len number of output columns; a sentence needing more columns
--   raises an index error
-- @return nSen x sen_len tensor; row k holds the word ids recovered for
--   sentence k (starting with its first window, so presumably including its
--   leading padding tokens), followed by zeros
--
-- NOTE(review): the output height is the global nSen, not new_sen_idx:size(2).
-- NOTE(review): rows are written only for sen_counter-1 while sen_counter runs
--   up to new_sen_idx:size(2), so the final sentence's row stays all zeros.
-- NOTE(review): if new_sen_idx contains unused 0 slots, `i < 0` is never true
--   and the boundary branch indexes row -1; this assumes idxSen filled every slot.
function senMatrix(X, new_sen_idx, sen_len)
    -- Index (into new_sen_idx) of the NEXT sentence start; the output row
    -- currently being filled is sen_counter-1.
    local sen_counter = 2
    -- Current window row in X.
    local i = 1
    -- Number of columns already filled in the current output row.
    local c = 0
    local sen_matrix = torch.zeros(nSen, sen_len)
    while i < X:size(1) do
        if i < new_sen_idx[1][sen_counter] then
            -- Still inside the current sentence: copy the whole window.
            for j=1,5 do
                c = c + 1
                sen_matrix[sen_counter-1][c]= X[i][j]
            end
            -- Skip ahead by the window width to the next non-overlapping window.
            i = i + 5
        else
            -- Boundary: last window row before the next sentence starts.
            local last_win = new_sen_idx[1][sen_counter]-1
            -- last_win is d = last_win-(i-5) rows past the previous whole copy,
            -- so its columns 6-d..5 hold the words not yet copied
            -- (6-d == 2+5-(next_start-(i-5))).
            local mid_sen = (2+5-(new_sen_idx[1][sen_counter]-(i-5)))
            for j=mid_sen,5 do
                c = c + 1
                sen_matrix[sen_counter-1][c] = X[last_win][j]
            end
            -- Start filling the next sentence's row from its first window.
            c = 0
            i = new_sen_idx[1][sen_counter]
            sen_counter = sen_counter + 1
            if sen_counter > new_sen_idx:size(2) then
                break
            end
        end
    end
    return sen_matrix
end

--- Count adjacent word pairs into a dense K x K tensor.
-- Pairs are read from column 5 onward, so the first pair counted in a row
-- is (sentences[r][4], sentences[r][5]). Positions holding 0 are skipped.
-- @param sentences 2-D tensor of word ids, one sentence per row
-- @param K vocabulary size
-- @return K x K tensor; entry [a][b] counts how often b directly follows a
function bigramCount(sentences, K)
    local counts = torch.zeros(K, K)
    for r = 1, sentences:size(1) do
        local sen = sentences[r]
        for col = 5, sentences:size(2) do
            local cur = sen[col]
            if cur ~= 0 then
                local prev = sen[col - 1]
                counts[prev][cur] = counts[prev][cur] + 1
            end
        end
    end
    return counts
end

--- Count adjacent word pairs into a sparse Lua table.
-- Same traversal as bigramCount (columns 5..end, zeros skipped); the key is
-- "prev next" with the earlier word first.
-- @param sentences 2-D tensor (or any object with :size(d) and [i][j]) of word ids
-- @return table mapping "prev next" -> number of occurrences
function bigramHash(sentences)
    local counts = {}
    for r = 1, sentences:size(1) do
        local row = sentences[r]
        for col = 5, sentences:size(2) do
            local cur = row[col]
            if cur ~= 0 then
                local key = tostring(row[col - 1]) .. " " .. tostring(cur)
                counts[key] = (counts[key] or 0) + 1
            end
        end
    end
    return counts
end

--- Count word trigrams into a nested Lua table.
-- For every column j >= 6 holding a non-zero word w, the two preceding words
-- w_2 = row[j-2] and w_1 = row[j-1] form the context. The context key is
-- written oldest word first ("w_2 w_1"). bigramHash uses the same order for
-- its "prev next" keys, and trigramPerp / wbtrigramPerp use it for their
-- lookups, so contexts line up with bigram counts.
-- @param sentences 2-D tensor (or any object with :size(d) and [i][j]) of word ids
-- @return table: result["w_2 w_1"][tostring(w)] = number of occurrences
function trigramHash(sentences)
    -- local: this table is returned, not shared through a global.
    local trigrams = {}
    for i = 1, sentences:size(1) do
        local row = sentences[i]
        for j = 6, sentences:size(2) do
            local w = row[j]
            if w ~= 0 then
                -- Oldest word first so the key matches bigram keys "w_2 w_1".
                local key = tostring(row[j - 2]) .. " " .. tostring(row[j - 1])
                local subkey = tostring(w)
                local followers = trigrams[key]
                if not followers then
                    followers = {}
                    trigrams[key] = followers
                end
                followers[subkey] = (followers[subkey] or 0) + 1
            end
        end
    end
    return trigrams
end

--- Turn raw trigram counts into add-alpha smoothed conditional probabilities.
-- For every context key k present in both tables, each follower count c is
-- replaced by (c + alpha) / (bigramD[k] + nfeat). With c = 0 this gives
-- alpha / (bigramD[k] + nfeat), the value trigramPerp uses for an unseen
-- word after a known context. Contexts with no bigram entry keep their raw
-- counts.
-- NOTE: trigramD is modified in place (and also returned); normalizing the
-- same table twice applies the formula twice.
-- @param bigramD table "w_2 w_1" -> count (as built by bigramHash)
-- @param trigramD table "w_2 w_1" -> {word -> count} (as built by trigramHash)
-- @param nfeat amount added to each context count in the denominator
-- @param alpha pseudo-count added to every trigram count (defaults to 1)
-- @return trigramD, now holding probabilities
function trigramNorm(bigramD, trigramD, nfeat, alpha)
    if alpha == nil then
        alpha = 1
    end
    -- Walk the trigram contexts (usually far fewer than bigram keys); only
    -- existing follower entries are reassigned, which is safe during pairs().
    for context, followers in pairs(trigramD) do
        local context_count = bigramD[context]
        if context_count then
            local denom = context_count + nfeat
            for word, count in pairs(followers) do
                followers[word] = (count + alpha) / denom
            end
        end
    end
    return trigramD
end

--- Count word occurrences from column 4 onward of every sentence row.
-- Positions holding 0 are skipped.
-- @param sentences 2-D tensor of word ids, one sentence per row
-- @param K vocabulary size
-- @return 1 x K tensor; entry [1][w] is the number of occurrences of w
function unigramCount(sentences, K)
    local totals = torch.zeros(1, K)
    local tally = totals[1]  -- 1-D view sharing storage with `totals`
    for r = 1, sentences:size(1) do
        local sen = sentences[r]
        for col = 4, sentences:size(2) do
            local word = sen[col]
            if word ~= 0 then
                tally[word] = tally[word] + 1
            end
        end
    end
    return totals
end

--- Add a constant to every entry (additive smoothing).
-- The addition is done IN PLACE via Tensor:add, so `matrix` itself changes
-- and the returned value is that same object (no copy — the bigram matrix
-- is K x K, so avoiding a second allocation matters).
-- @param matrix tensor to smooth
-- @param alpha value added to each entry
-- @return matrix
function additive(matrix, alpha)
    return matrix:add(alpha)
end

--- Divide each row of a count matrix by a per-row total.
-- cv is transposed and broadcast across the columns of cm, so entry [a][b]
-- of the result is cm[a][b] / cv[1][a] (expects cv shaped 1 x cm:size(1)).
-- torch.cdiv allocates a new tensor; cm and cv are left unchanged.
-- @param cv 1 x K tensor of row totals
-- @param cm K x K tensor of counts
-- @return K x K tensor of row-normalized values
function normalize(cv, cm)
    local column = cv:t()
    local denom = torch.expand(column, cm:size(1), cm:size(2))
    return torch.cdiv(cm, denom)
end

--- Perplexity of a bigram model on the correct answers of a validation set.
-- For each example, the correct candidate word_choices[i][correct_idx[i]]
-- is scored by cm[prev][word], where prev is column 5 of the example's window.
-- @param cm K x K conditional probability matrix, cm[prev][next]
-- @param valid context windows, one row per example
-- @param correct_idx per-example index of the correct candidate
-- @param word_choices per-example candidate word ids
-- @return exp(-mean log probability of the correct words)
function bigramPerplexity(cm, valid, correct_idx, word_choices)
    local n = valid:size(1)
    local log_sum = 0
    for ex = 1, n do
        local target = word_choices[ex][correct_idx[ex]]
        local prev = valid[ex][5]
        log_sum = torch.log(cm[prev][target]) + log_sum
    end
    return torch.exp(-log_sum / n)
end

--- Perplexity of the add-alpha trigram model on the correct answers.
-- The context key is "w_2 w_1" from columns 4 and 5 of each window. The
-- probability of the correct word is:
--   * trigramnorm[context][word] when that trigram was seen;
--   * otherwise alpha / (bigramCounts[context] + nfeat) if the context was
--     seen as a bigram;
--   * otherwise alpha / nfeat.
-- @param trigramnorm table "w_2 w_1" -> {word -> probability}
-- @param bigramCounts table "w_2 w_1" -> count
-- @param valid context windows, one row per example
-- @param correct_idx per-example index of the correct candidate
-- @param word_choices per-example candidate word ids
-- @param alpha smoothing pseudo-count for unseen trigrams
-- @param nfeat smoothing term added to context counts
-- @return exp(-mean log probability of the correct words)
function trigramPerp(trigramnorm, bigramCounts, valid, correct_idx, word_choices, alpha, nfeat)
    local n = valid:size(1)
    local total_log = 0
    for ex = 1, n do
        local target = word_choices[ex][correct_idx[ex]]
        local row = valid[ex]
        local context = tostring(row[4]) .. " " .. tostring(row[5])
        local followers = trigramnorm[context]
        local prob = followers and followers[tostring(target)]
        if not prob then
            local seen = bigramCounts[context]
            if seen then
                prob = alpha / (seen + nfeat)
            else
                prob = alpha / nfeat
            end
        end
        total_log = torch.log(prob) + total_log
    end
    return torch.exp(-total_log / n)
end

--- Perplexity of an interpolated Witten-Bell bigram model on the correct answers.
--
-- With history h = previous word (column 5 of the window) and word w:
--   P(w | h) = (c(h, w) + T(h) * P_uni(w)) / (c(h) + T(h))
-- where c(h, w) = bigramCountMatrix[h][w], c(h) = sum of row h,
-- T(h) = number of distinct words observed after h (non-zero entries of row h),
-- P_uni(w) = unigrams[1][w] / sum(unigrams).
-- If h never occurs as a history (c(h) = 0) the model falls back to P_uni(w).
--
-- Expects RAW counts: if every entry is already > 0 (e.g. after add-one
-- smoothing), T(h) is the whole vocabulary and the interpolation degenerates.
-- @param bigramCountMatrix K x K tensor of bigram counts, [prev][next]
-- @param unigrams 1 x K tensor of unigram counts
-- @param valid context windows, one row per example
-- @param correct_idx per-example index of the correct candidate
-- @param word_choices per-example candidate word ids
-- @return exp(-mean log P(correct word | previous word))
function wbBigram_perp(bigramCountMatrix, unigrams, valid, correct_idx, word_choices)
    local unigram_total = unigrams:sum()
    -- Row statistics cost O(K) each, so compute them once per distinct history.
    local history_total, history_types = {}, {}
    local n = valid:size(1)
    local log_sum = 0
    for i = 1, n do
        local w = word_choices[i][correct_idx[i]]
        local h = valid[i][5]
        local c_h = history_total[h]
        if c_h == nil then
            local row = bigramCountMatrix[h]
            c_h = row:sum()
            history_total[h] = c_h
            -- :sum() on the whole mask returns a plain number (no byte overflow).
            history_types[h] = row:gt(0):sum()
        end
        local t_h = history_types[h]
        local p_uni = unigrams[1][w] / unigram_total
        local prob
        if c_h > 0 then
            prob = (bigramCountMatrix[h][w] + t_h * p_uni) / (c_h + t_h)
        else
            prob = p_uni
        end
        log_sum = torch.log(prob) + log_sum
    end
    return torch.exp(-log_sum / n)
end
        
--- Perplexity of the trigram model labelled Witten-Bell ("wb") on the correct answers.
-- Context key is "w_2 w_1" from columns 4 and 5 of each window; with
-- c = bigramCounts[context] and N = total number of bigram tokens:
--   * trigram seen                          -> trigramnorm[context][word]
--   * context in trigramnorm, word unseen   -> c / (c + N)
--   * context only in bigramCounts          -> alpha / (c + N)
--   * context never seen                    -> alpha / nfeat
-- NOTE(review): the two middle cases use different numerators (c vs alpha);
-- kept as written — confirm which was intended.
-- @param trigramnorm table "w_2 w_1" -> {word -> probability}
-- @param bigramCounts table "w_2 w_1" -> count
-- @param valid context windows, one row per example
-- @param correct_idx per-example index of the correct candidate
-- @param word_choices per-example candidate word ids
-- @param alpha pseudo-count used in the back-off cases
-- @param nfeat denominator for never-seen contexts
-- @param bigramCountMatrix optional tensor whose :sum() gives N. When omitted,
--   N is the sum of all counts in bigramCounts; bigramHash and bigramCount walk
--   the sentences identically, so this equals the sum of a raw count matrix.
-- @return exp(-mean log probability of the correct words)
function wbtrigramPerp(trigramnorm, bigramCounts, valid, correct_idx, word_choices, alpha, nfeat, bigramCountMatrix)
    -- N must be supplied or derived here: the function has no other access to
    -- a bigram count matrix. Computed once rather than per example.
    local total_bigrams
    if bigramCountMatrix ~= nil then
        total_bigrams = bigramCountMatrix:sum()
    else
        total_bigrams = 0
        for _, count in pairs(bigramCounts) do
            total_bigrams = total_bigrams + count
        end
    end

    local n = valid:size(1)
    local perp = 0
    for i = 1, n do
        local corr_word = word_choices[i][correct_idx[i]]
        local row = valid[i]
        local key = tostring(row[4]) .. " " .. tostring(row[5])
        local followers = trigramnorm[key]
        local prob = followers and followers[tostring(corr_word)]
        if not prob then
            local ctx = bigramCounts[key]
            if followers and ctx then
                prob = ctx / (ctx + total_bigrams)
            elseif ctx then
                prob = alpha / (ctx + total_bigrams)
            else
                prob = alpha / nfeat
            end
        end
        perp = torch.log(prob) + perp
    end
    return torch.exp(-perp / n)
end

--- Score the candidate words of each test example with a bigram model.
-- Candidate j of example i gets CM[prev][word_choices[i][j]], where prev is
-- column 5 of the example's window; each row of scores is then divided by
-- its sum so an example's candidate scores add up to 1.
-- The number of candidates is taken from word_choices:size(2) rather than
-- assumed.
-- @param CM K x K bigram probability matrix, CM[prev][next]
-- @param test context windows, one row per example
-- @param word_choices N x C tensor of candidate word ids
-- @return N x C tensor of normalized scores
function kaggle(CM, test, word_choices)
    local nchoices = word_choices:size(2)
    local scores = torch.zeros(word_choices:size(1), nchoices)
    for i = 1, test:size(1) do
        local pre_word = test[i][5]
        local cands = word_choices[i]
        for j = 1, nchoices do
            scores[i][j] = CM[pre_word][cands[j]]
        end
    end
    -- Row sums are N x 1; broadcast them to N x C and divide in place.
    local row_sums = torch.sum(scores, 2)
    return scores:cdiv(torch.expand(row_sums, scores:size(1), nchoices))
end

--- Write a Kaggle submission CSV.
-- Emits a fixed header line, then one line per row i of `scores`:
-- "<i>,<scores[i][1]>,...,<scores[i][C]>".
-- @param scores N x C tensor (or any object with :size(d) and [i][j])
-- @param fname output path; created or truncated
-- @raise error if the file cannot be opened
function write2file(scores, fname)
    local out = assert(io.open(fname, "w"))
    -- NOTE(review): header kept byte-for-byte, including the "CLass" spellings;
    -- confirm against the expected submission column names.
    out:write("ID,Class1,CLass2,Class3,Class4,Class5,Class6,Class7,Class8,Class9,Class10,Class11,Class12,CLass13,Class14,Class15,Class16,Class17,Class18,Class19,Class20,Class21,Class22,Class23,CLass24,Class25,Class26,Class27,Class28,Class29,Class30,Class31,Class32,Class33,Class34,CLass35,Class36,Class37,Class38,Class39,Class40,Class41,Class42,Class43,Class44,Class45,CLass46,Class47,Class48,Class49,Class50\n")
    for i = 1, scores:size(1) do
        -- Collect fields and join once instead of growing a string per column.
        local fields = { tostring(i) }
        for j = 1, scores:size(2) do
            fields[j + 1] = tostring(scores[i][j])
        end
        out:write(table.concat(fields, ","), "\n")
    end
    out:close()
end
sen_idx = idxSen(X_train, nSen)
sen_matrix = senMatrix(X_train, sen_idx, lenSen)

-- Dense bigram (K x K) and unigram (1 x K) counts.
CM = bigramCount(sen_matrix, nfeatures)
CV = unigramCount(sen_matrix, nfeatures)

-- Witten-Bell needs the RAW counts (it distinguishes seen from unseen
-- bigrams), and `additive` modifies its argument in place, so evaluate it
-- before CM/CV are smoothed below.
wbperp = wbBigram_perp(CM, CV, X_valid, kaggle_valid, q_valid)

-- Add-one smoothed bigram model: +1 on every bigram count, +nfeatures on
-- every unigram count, then divide each row of CM by its word's count.
CV = additive(CV, nfeatures)
CM = additive(CM, 1)
BCM = normalize(CV, CM)

bperp = bigramPerplexity(BCM, X_valid, kaggle_valid, q_valid)
scores = kaggle(BCM, X_test, q_test)
write2file(scores, "bigram.csv")

-- Trigram models use sparse hash tables: we initially started off with
-- tensors but realized hash tables are much faster here.
-- NOTE(review): bigramFeat is passed as `nfeat`, the smoothing term added to
-- context counts; where 845430 comes from is not shown here — confirm.
bigramFeat = 845430
bigramM = bigramHash(sen_matrix)
trigramM = trigramHash(sen_matrix)

-- trigramNorm rewrites trigramM in place and returns that same table.
trigramM_norm = trigramNorm(bigramM, trigramM, bigramFeat, 1)
tperp = trigramPerp(trigramM_norm, bigramM, X_valid, kaggle_valid, q_valid, 1, bigramFeat)
wbtperp = wbtrigramPerp(trigramM_norm, bigramM, X_valid, kaggle_valid, q_valid, 1, bigramFeat)

print("bigram (add-one) perplexity:", bperp)
print("bigram (Witten-Bell) perplexity:", wbperp)
print("trigram (add-one) perplexity:", tperp)
print("trigram (Witten-Bell) perplexity:", wbtperp)