In [160]:
require 'hdf5'
require 'math'

-- Load the preprocessed training/validation arrays from data.hdf5 into the
-- globals used by the later cells, then close the file.
f = hdf5.open("data.hdf5", "r")

-- Read one named dataset from the open file in full.
local function read_all(name)
    return f:read(name):all()
end

X_train = read_all("windows_train")
Y_train = read_all("train_Y")

X_valid = read_all("windows_valid")
X_valid_nospaces = read_all("valid_kaggle_without_spaces")

Y_valid = read_all("valid_reduced_Y")
Y_valid_spaces = read_all("valid_answers")

-- These datasets hold a single value; take the first element as a number.
nclasses = read_all('nclasses'):long()[1]
nfeatures = read_all('nfeatures'):long()[1]

f:close()

In [174]:
--- Build a string key from a 1-D tensor by joining its elements with spaces.
-- Every element, including the last, is followed by one space, so a row
-- {5, 7, 9} becomes "5 7 9 ". An empty input yields "".
-- @param longTensor 1-D tensor (anything supporting :size(1) and [i] indexing)
-- @return string key
function createHash(longTensor)
    local pieces = {}
    local len = longTensor:size(1)
    for idx = 1, len do
        pieces[idx] = tostring(longTensor[idx]) .. " "
    end
    return table.concat(pieces)
end

--- Build a string key from the first row of a 2-D tensor.
-- The format matches createHash: each element of row 1 is followed by one
-- space, so {{4, 5, 6}} becomes "4 5 6 ". The column count comes from :size(2).
-- @param longTensor 2-D tensor; only row 1 is read
-- @return string key
function createHash2(longTensor)
    local pieces = {}
    local width = longTensor:size(2)
    for col = 1, width do
        pieces[#pieces + 1] = tostring(longTensor[1][col]) .. " "
    end
    return table.concat(pieces)
end

--- Overwrite the leading part of a window with the padding id 30 (</s>)
-- while we are close to a sentence start, and update the restart counter.
-- @param r   restart counter; padding is applied while r < n, and r then
--            grows by one
-- @param n   window (n-gram) length
-- @param win window; positions 1 .. n-r are written IN PLACE with 30
-- @return    updated counter (reset to 1 when win[n] is 30), and win itself
function add_padding(r, n, win)
    local pad_len = n - r
    if pad_len > 0 then
        for pos = 1, pad_len do
            win[pos] = 30
        end
        r = r + 1
    end
    -- A padding id in the final slot marks a sentence end: restart the count.
    if win[n] == 30 then
        r = 1
    end
    return r, win
end

--- Count, for each distinct padded window, how often it occurs and how often
-- it is followed by a space.
-- @param windows_train 2-D tensor with one window per row. Each row goes
--        through add_padding, which writes padding ids (30) into it in place.
--        windows_train[i] is presumably a view, so the caller's tensor is
--        modified; confirm if that matters.
-- @param space_train   per-row label; 2 means "a space follows"
-- @param n_gram        window length, forwarded to add_padding
-- @return space_table  key -> number of rows labelled 2
-- @return count_table  key -> total number of rows
function count_train(windows_train, space_train, n_gram)
    local count_table = {}
    local space_table = {}
    local restart = 1
    for i = 1, windows_train:size(1) do
        -- Checks if a </s> is necessary (only for CBM!).
        -- Declared local so the padded row does not leak into the global table.
        local padded
        restart, padded = add_padding(restart, n_gram, windows_train[i])
        local key = createHash(padded)
        if count_table[key] then
            count_table[key] = count_table[key] + 1
        else
            count_table[key] = 1
            space_table[key] = 0
        end
        if space_train[i] == 2 then
            space_table[key] = space_table[key] + 1
        end
    end
    return space_table, count_table
end

--- Convert two scores into the share of the first: num / (num + den).
-- There is no guard, so if both are zero the result is 0/0.
function normalize(num, den)
    local total = num + den
    return num / total
end

--- Additive (Laplace) smoothing of a count.
-- @param count      observed count for the outcome
-- @param total      total observations
-- @param alpha      additive constant
-- @param vocab_size number of outcomes the mass is spread over
-- @return (count + alpha) / (total + alpha * vocab_size)
function laplace_smooth(count, total, alpha, vocab_size)
    local numerator = count + alpha
    local denominator = total + vocab_size * alpha
    return numerator / denominator
end

--- Perplexity of the count-based model (CBM) on the binary labels
-- "space follows" / "no space follows".
-- Each validation window is padded the same way as in count_train and looked
-- up in the training tables. It is then scored with the Laplace-smoothed,
-- renormalised probability of its observed label.
-- @param space_table   key -> count of "space follows" (from count_train)
-- @param total_table   key -> total count (from count_train)
-- @param windows_valid 2-D tensor of windows; rows are padded in place by add_padding
-- @param space_valid   per-row label: 1 = next char is NOT a space, otherwise a space
-- @param n_gram        window length
-- @param alpha         additive smoothing constant
-- @param vocab_size    vocabulary size passed to laplace_smooth
-- @return exp(-mean log-probability of the observed labels)
function count_perp(space_table, total_table, windows_valid, space_valid, n_gram, alpha, vocab_size)
    local restart = 1
    local perp = 0
    local n_windows = windows_valid:size(1)
    for i = 1, n_windows do
        -- Checks if a </s> is necessary (only for CBM!).
        -- `padded` is local so it does not leak into the global table.
        local padded
        restart, padded = add_padding(restart, n_gram, windows_valid[i])
        local key = createHash(padded)
        local p_label
        if total_table[key] then
            local p_space = laplace_smooth(space_table[key], total_table[key], alpha, vocab_size)
            local p_nospace = laplace_smooth(total_table[key] - space_table[key], total_table[key], alpha, vocab_size)
            if space_valid[i] == 1 then
                -- Next char is NOT a space.
                p_label = normalize(p_nospace, p_space)
            else
                -- Next char IS a space.
                p_label = normalize(p_space, p_nospace)
            end
        else
            -- Unseen context. With zero counts, the smoothed and renormalised
            -- estimate is (0 + alpha) / (0 + 2 * alpha) = 1/2 for either label.
            -- alpha / vocab_size is not a distribution over the two labels,
            -- and it exceeds 1 whenever alpha > vocab_size.
            p_label = 0.5
        end
        perp = perp + math.log(p_label)
    end
    return math.exp(-perp / n_windows)
end

--- Sweep the Laplace constant over alpha = 0.1, 0.2, ..., 10.0 and report the
-- validation perplexity (count_perp) for each value.
-- @param space_table, total_table  tables from count_train
-- @param x_valid, y_valid          validation windows and labels
-- @param w  window length (forwarded to count_perp as n_gram)
-- @param a  UNUSED; kept for call compatibility, since the alpha grid is fixed
-- @param n  vocabulary size (forwarded to count_perp as vocab_size)
-- @return a comma-separated STRING of perplexities in grid order, with a
--         trailing comma ("p1,p2,...,p100,"), not a Lua array
function count_based_CV(space_table, total_table, x_valid, y_valid, w, a, n)
    local scores = {}
    for i = 1, 100 do
        -- Use a distinct name so the (ignored) parameter `a` is not shadowed.
        local alpha = i / 10
        local perplexity = count_perp(space_table, total_table, x_valid, y_valid, w, alpha, n)
        scores[i] = tostring(perplexity) .. ","
    end
    return table.concat(scores)
end

--- Prepend (w - 1) columns of padding id 30 to every row, so the first real
-- character of each sentence has a full window of context.
-- @param x_valid_nospaces 2-D tensor of sentences; assumed to be a
--        LongTensor, because torch.cat needs both inputs to have the same type
-- @param w window length
-- @return new LongTensor with x_valid_nospaces:size(2) + w - 1 columns
function pre_pad(x_valid_nospaces, w)
    local n_rows = x_valid_nospaces:size(1)
    local left = torch.LongTensor(n_rows, w - 1):fill(30)
    return torch.cat(left, x_valid_nospaces, 2)
end

--- Mean squared error of the predicted number of spaces per sentence, using a
-- greedy left-to-right count-based model (CBM).
-- At each column j, the window sentence[j .. j+w-1] is looked up. If the
-- smoothed probability that a space follows exceeds p, id 30 is inserted at
-- column j+w and counted. The rest of the row shifts right by one and its
-- last column is dropped.
--     --space_table : hash_table of ngram : n_spaces
--     --total_table : hash_table of ngram : counts
--     --x_valid : sentences with NO spaces (left-padded, see pre_pad)
--     --y_valid : number of spaces in each sentence
--     --w : window size
--     --n : nfeatures; unseen windows get p_space = 1/n
--     --a : additive alpha
--     --p : threshold parameter
--     --max_sentences : optional cap on scored rows (default 1000), never
--                       more than x_valid:size(1)
-- @return mean squared error over the scored sentences
function count_greedy_predict(space_table, total_table, x_valid, y_valid, w, n, a, p, max_sentences)
    local n_cols = x_valid:size(2)
    -- Clamp so inputs with fewer than 1000 rows no longer index past the end.
    local n_eval = math.min(max_sentences or 1000, x_valid:size(1))
    local mean_error_sq = 0
    for i = 1, n_eval do
        local spaces = 0
        local sentence = x_valid:sub(i, i, 1, n_cols):clone()
        -- The look-ahead column j + w must exist, so j stops at n_cols - w.
        -- A bound of n_cols - w + 1 reads column n_cols + 1 on rows that
        -- have no padding after column w.
        for j = 1, n_cols - w do
            -- Padding reached: the sentence is over.
            if sentence[1][j + w] == 30 then
                break
            end
            local key = createHash2(sentence:sub(1, 1, j, j + w - 1))
            local p_space
            if total_table[key] then
                local p_yes = laplace_smooth(space_table[key], total_table[key], a, n)
                local p_no = laplace_smooth(total_table[key] - space_table[key], total_table[key], a, n)
                p_space = normalize(p_yes, p_no)
            else
                p_space = 1 / n
            end
            if p_space > p then
                spaces = spaces + 1
                -- Shift columns j+w .. n_cols-1 one step right, then insert at j+w.
                for k = 1, n_cols - j - w do
                    local idx = n_cols - k
                    sentence[1][idx + 1] = sentence[1][idx]
                end
                sentence[1][j + w] = 30
            end
        end
        mean_error_sq = mean_error_sq + (spaces - y_valid[i]) ^ 2
    end
    return mean_error_sq / n_eval
end

--- Mean squared error of the predicted number of spaces per sentence, using
-- repeated greedy passes combined in a score table.
-- Each pass starts at a different column j and inserts spaces greedily, as in
-- count_greedy_predict. Every insertion at column pos proposes the score
-- exp(p_space) + score of the previous insertion in that pass. The table keeps
-- the largest proposal per column together with its space count. The
-- prediction is the space count at the best-scoring column.
--     --space_table : hash_table of ngram : n_spaces
--     --total_table : hash_table of ngram : counts
--     --x_valid : sentences with NO spaces (left-padded, see pre_pad)
--     --y_valid : number of spaces in each sentence
--     --w : window size
--     --n : nfeatures
--     --a : additive alpha
--     --p : threshold parameter
--     --max_sentences : optional cap on scored rows (default 1000), never
--                       more than x_valid:size(1)
-- @return mean squared error over the scored sentences
function count_dp_predict(space_table, total_table, x_valid, y_valid, w, n, a, p, max_sentences)
    local n_cols = x_valid:size(2)
    -- Clamp so inputs with fewer than 1000 rows no longer index past the end.
    local n_eval = math.min(max_sentences or 1000, x_valid:size(1))
    local mean_error_sq = 0
    for i = 1, n_eval do
        -- array[pos]: best score proposed for column pos; index 1 is the start.
        local array = torch.Tensor(n_cols + 1):fill(0)
        -- Seed value for the start position (kept from the original design).
        array[1] = 2
        -- array_space[pos]: number of spaces on the path that set array[pos].
        local array_space = torch.Tensor(n_cols + 1):fill(0)

        -- One greedy pass per starting column j, so segmentations that a
        -- single greedy pass would skip can still be discovered.
        for j = 1, n_cols - w do
            -- Column of the most recent insertion in this pass (1 = start).
            local prev_space = 1
            -- Fresh copy per pass, because spaces are inserted into it.
            local sentence = x_valid:sub(i, i, 1, n_cols):clone()

            -- k stops at n_cols - w so the look-ahead column k + w stays in the
            -- row. A bound of n_cols - w + 1 reads column n_cols + 1.
            for k = j, n_cols - w do
                -- Early termination when padding is reached.
                if sentence[1][k + w] == 30 then
                    break
                end

                local key = createHash2(sentence:sub(1, 1, k, k + w - 1))
                -- Unseen windows keep p_space = 0, so for p >= 0 they never
                -- trigger an insertion.
                local p_space = 0
                if total_table[key] then
                    local p_yes = laplace_smooth(space_table[key], total_table[key], a, n)
                    local p_no = laplace_smooth(total_table[key] - space_table[key], total_table[key], a, n)
                    p_space = normalize(p_yes, p_no)
                end

                -- Threshold exceeded: propose a space at column k + w.
                if p_space > p then
                    local candidate = math.exp(p_space) + array[prev_space]
                    if candidate > array[k + w] then
                        array[k + w] = candidate
                        array_space[k + w] = 1 + array_space[prev_space]
                    end
                    prev_space = k + w
                    -- Shift columns k+w .. n_cols-1 one step right, then insert at k+w.
                    for m = 1, n_cols - k - w do
                        local idx = n_cols - m
                        sentence[1][idx + 1] = sentence[1][idx]
                    end
                    sentence[1][k + w] = 30
                end
            end
        end

        -- Locals, so the argmax results do not leak into the global table.
        local _, best_idx = torch.max(array, 1)
        mean_error_sq = mean_error_sq + (array_space[best_idx[1]] - y_valid[i]) ^ 2
    end
    return mean_error_sq / n_eval
end

In [162]:
do
    -- Window (n-gram) length = number of columns of the training windows.
    local width = X_train:size(2)
    win_size = width

    -- Build the space / total count tables from the training windows.
    ngram_space, ngram_total = count_train(X_train, Y_train, width)

    -- Adds some extra padding for MSE calculations
    X_valid_nospaces = pre_pad(X_valid_nospaces, width)
end

In [163]:
-- Validation perplexity of the count-based model with Laplace constant 1.
additive = 1
perplexity = count_perp(ngram_space, ngram_total, X_valid, Y_valid, win_size, additive, nfeatures)

local report = {
    "COUNT BASED MODEL",
    "================================",
    "Window size: " .. tostring(win_size),
    "Laplace smoothing parameter: " .. tostring(additive),
    "Perplexity: " .. tostring(perplexity),
}
for _, line in ipairs(report) do
    print(line)
end

COUNT BASED MODEL	
Window size: 3	
Laplace smoothing parameter: 1	
Perplexity: 1.1981899697832	


In [175]:
-- Greedy CBM: MSE of the predicted space counts at this threshold.
threshold = 0.34
additive = 1
local greedy_mse = count_greedy_predict(ngram_space, ngram_total, X_valid_nospaces, Y_valid_spaces,
                                        win_size, nfeatures, additive, threshold)
print(greedy_mse)

10.727	


In [173]:
-- Multi-pass ("dynamic programming") CBM: MSE of the predicted space counts.
threshold = 0.34
additive = 1
local dp_mse = count_dp_predict(ngram_space, ngram_total, X_valid_nospaces, Y_valid_spaces,
                                win_size, nfeatures, additive, threshold)
print(dp_mse)

1	


2	


3	


4	


5	


6	


7	


8	


9	


10	


11	


12	


13	


14	


15	


16	


17	


18	


19	


20	


21	


22	


23	


24	


25	


26	


27	


28	


29	


30	


31	


32	


33	


34	


35	


36	


37	


38	


39	


40	


41	


42	


43	


44	


45	




46	


47	


48	


49	


50	


51	


52	


53	


54	


55	


56	


57	


58	


59	


60	


61	


62	


63	


64	


65	


66	


67	


68	


69	


70	


71	


72	


73	


74	


75	


76	


77	


78	


79	


80	


81	


82	


83	


84	


85	


86	


87	


88	


89	


90	


91	


92	


93	


94	


95	


96	


97	


98	


99	


100	


101	


102	


103	


104	


105	


106	


107	


108	


109	


110	


111	


112	


113	


114	


115	


116	


117	


118	


119	


120	


121	


122	


123	


124	


125	


126	


127	


128	


129	


130	


131	


132	


133	


134	


135	


136	


137	


138	


139	


140	


141	


142	


143	


144	


145	


146	


147	


148	


149	


150	


151	


152	


153	


154	


155	


156	


157	


158	


159	


160	


161	


162	


163	


164	


165	


166	


167	


168	


169	


170	


171	


172	


173	


174	


175	


176	


177	


178	


179	


180	


181	


182	


183	


184	


185	


186	


187	


188	


189	


190	


191	


192	


193	


194	


195	


196	


197	


198	


199	


200	


201	


202	


203	


204	


205	


206	


207	


208	


209	


210	


211	


212	


213	


214	


215	


216	


217	


218	


219	


220	


221	


222	


223	


224	


225	


226	


227	


228	


229	


230	


231	


232	


233	


234	


235	


236	


237	


238	


239	


240	


241	


242	


243	


244	


245	


246	


247	


248	


249	


250	


251	


252	


253	


254	


255	


256	




257	


258	


259	


260	


261	


262	


263	


264	


265	


266	


267	


268	


269	


270	


271	


272	


273	


274	


275	


276	


277	


278	


279	


280	


281	


282	


283	


284	


285	


286	


287	


288	


289	


290	


291	


292	


293	


294	


295	


296	


297	


298	


299	


300	


301	


302	


303	


304	


305	


306	


307	


308	


309	


310	


311	


312	


313	


314	


315	


316	


317	


318	


319	


320	


321	


322	


323	


324	


325	


326	


327	


328	


329	


330	


331	


332	


333	


334	


335	


336	


337	


338	


339	


340	


341	


342	


343	


344	


345	


346	


347	


348	


349	


350	


351	


352	


353	


354	


355	


356	


357	


358	


359	


360	


361	


362	


363	


364	


365	


366	


367	


368	


369	


370	


371	


372	


373	


374	


375	


376	


377	


378	


379	


380	


381	


382	


383	


384	


385	


386	


387	


388	


389	


390	


391	


392	


393	


394	


395	


396	


397	


398	


399	


400	


401	


402	


403	


404	


405	


406	


407	


408	


409	


410	


411	


412	


413	


414	


415	


416	


417	


418	


419	


420	


421	


422	


423	


424	


425	


426	


427	


428	


429	


430	


431	


432	


433	


434	


435	


436	


437	


438	


439	


440	


441	


442	


443	


444	


445	


446	


447	


448	


449	


450	


451	


452	


453	


454	


455	


456	


457	


458	


459	


460	


461	


462	


463	


464	


465	


466	


467	


468	


469	


470	


471	


472	


473	


474	


475	


476	


477	


478	


479	


480	


481	


482	


483	


484	


485	


486	


487	


488	


489	


490	


491	


492	


493	


494	


495	


496	


497	


498	


499	


500	


501	


502	


503	


504	


505	


506	


507	


508	


509	


510	


511	


512	


513	


514	


515	


516	


517	


518	


519	


520	


521	


522	


523	


524	


525	


526	


527	


528	


529	


530	


531	


532	


533	


534	


535	


536	


537	


538	


539	


540	


541	


542	


543	


544	


545	


546	


547	


548	


549	


550	


551	


552	


553	


554	


555	


556	


557	


558	


559	


560	


561	


562	


563	


564	


565	


566	


567	


568	


569	


570	


571	


572	


573	


574	


575	


576	


577	


578	


579	


580	


581	


582	


583	


584	


585	


586	


587	


588	


589	


590	


591	


592	


593	


594	


595	


596	


597	


598	


599	


600	


601	


602	


603	


604	


605	


606	


607	


608	


609	


610	


611	


612	


613	


614	


615	


616	


617	


618	


619	


620	


621	


622	


623	


624	


625	


626	


627	


628	


629	


630	


631	


632	


633	


634	


635	


636	


637	


638	


639	


640	


641	


642	


643	


644	


645	


646	


647	


648	


649	


650	


651	


652	


653	


654	


655	


656	


657	


658	


659	


660	


661	


662	


663	


664	


665	


666	


667	


668	


669	


670	


671	


672	


673	


674	


675	


676	


677	


678	


679	


680	


681	


682	


683	


684	


685	


686	


687	


688	


689	


690	


691	


692	


693	


694	


695	


696	


697	


698	


699	


700	


701	


702	


703	


704	


705	


706	


707	


708	


709	


710	


711	


712	


713	


714	


715	


716	


717	


718	


719	


720	


721	


722	


723	


724	


725	


726	


727	


728	


729	


730	


731	


732	


733	


734	


735	


736	


737	


738	


739	


740	


741	


742	


743	


744	


745	


746	


747	


748	


749	


750	


751	


752	


753	


754	


755	


756	


757	


758	


759	


760	


761	


762	


763	


764	


765	


766	


767	


768	


769	


770	


771	


772	


773	


774	


775	


776	


777	


778	


779	


780	


781	


782	


783	


784	


785	


786	


787	


788	


789	


790	


791	


792	


793	


794	


795	


796	


797	


798	


799	


800	


801	


802	


803	


804	


805	


806	


807	


808	


809	


810	


811	


812	


813	


814	


815	


816	


817	


818	


819	


820	


821	


822	


823	


824	


825	


826	


827	


828	


829	


830	


831	


832	


833	


834	


835	


836	


837	


838	


839	


840	


841	


842	


843	


844	


845	


846	


847	


848	


849	


850	


851	


852	


853	


854	


855	


856	


857	


858	


859	


860	


861	


862	


863	


864	


865	


866	


867	


868	


869	


870	


871	


872	


873	


874	


875	


876	


877	


878	


879	


880	


881	


882	


883	


884	


885	


886	


887	


888	


889	


890	


891	


892	


893	


894	


895	


896	


897	


898	


899	


900	


901	


902	


903	


904	


905	


906	


907	


908	


909	


910	


911	


912	


913	


914	


915	


916	


917	


918	


919	


920	


921	


922	


923	


924	


925	


926	


927	


928	


929	


930	


931	


932	


933	


934	


935	


936	


937	


938	


939	


940	


941	


942	


943	


944	


945	


946	


947	


948	


949	


950	


951	


952	


953	


954	


955	


956	


957	


958	


959	


960	


961	


962	


963	


964	


965	


966	


967	


968	


969	


970	


971	


972	


973	


974	


975	


976	


977	


978	


979	


980	


981	


982	


983	


984	


985	


986	


987	


988	


989	


990	


991	


992	


993	


994	


995	


996	


997	


998	


999	


1000	
10.495	


In [128]:
-- Inspect the reference space count of the first validation sentence.
-- This is a bare expression, so it relies on the notebook echoing its value;
-- it is not a valid statement in a plain Lua script.
Y_valid_spaces[1]

13	
