In [1]:
require("optim")
require('hdf5')

<h3>Multinomial Logistic Regression: Minibatch L-BFGS with L2 Regularization</h3>

In [6]:
--- Regularized softmax-regression objective.
-- Reshapes the flat parameter tensor `W` to (Y:size(2) x (X:size(2)+1)); the
-- last column holds the per-class bias and the remaining columns the weights.
-- @param W flat parameter tensor with Y:size(2)*(X:size(2)+1) elements
-- @param X design matrix, one example per row (N x X:size(2))
-- @param Y target matrix (N x Y:size(2)); presumably one-hot -- confirm with caller
-- @return loss: -sum(Y .* log_softmax(scores)) + 0.5 * ||weights||^2
--   (the bias column is not penalized)
-- @return probs: row-wise softmax probabilities (N x Y:size(2))
-- @return weights: the (Y:size(2) x X:size(2)) weight block, without bias
function ml(W, X, Y)
    local nclass = Y:size(2)
    local nfeat = X:size(2)
    local params = W:reshape(nclass, nfeat + 1)
    local bias = params:narrow(2, nfeat + 1, 1):t()   -- 1 x nclass
    local weights = params:narrow(2, 1, nfeat)        -- nclass x nfeat

    -- Class scores: X * weights^T + bias (bias broadcast over rows).
    local logp = X * weights:t()
    logp:add(bias:expand(logp:size(1), nclass))

    -- Numerically stable log-sum-exp over classes for each row:
    -- lse = m + log(sum(exp(s - m))), m being the row maximum.
    local shifted = logp:clone():t()                  -- nclass x N
    local colmax = shifted:max(1)                     -- 1 x N
    shifted:csub(colmax:expand(nclass, colmax:size(2)))
    shifted:exp()
    local lse = shifted:sum(1):log():add(colmax):t()  -- N x 1

    -- Scores minus log-sum-exp = log-softmax.
    logp:csub(lse:expand(lse:size(1), nclass))

    -- Cross-entropy plus 0.5 * squared L2 norm of the weights.
    local flat = weights:reshape(weights:size(1) * weights:size(2), 1)
    local loss = 0.5 * torch.dot(flat, flat) - torch.sum(torch.cmul(Y, logp))

    logp:exp()   -- log-probabilities -> probabilities, in place
    return loss, logp, weights
end

--- Minibatch loss and gradient for the softmax-regression objective `ml`.
-- Draws a random minibatch (rows sampled without replacement) from (X, Y) on
-- every call. Returns the batch loss and the flattened gradient in the same
-- layout as `W`, as the `optim.lbfgs` closure requires.
-- NOTE(review): each evaluation sees a different batch, so L-BFGS compares
-- losses/gradients from different samples; the jumps in the logged losses are
-- consistent with that. Consider a fixed batch per fit or full-batch L-BFGS.
-- @param W flat parameters, Y:size(2)*(X:size(2)+1) elements
-- @param X design matrix (N x nfeatures)
-- @param Y target matrix (N x nclasses)
-- @param bsize optional minibatch size (default 10000), clamped to N
-- @return loss on the batch (also printed)
-- @return gradient as a (nclasses*(nfeatures+1)) x 1 tensor
-- @return batch softmax probabilities (bsize x nclasses)
function mlg(W, X, Y, bsize)
    local n = X:size(1)
    -- Clamp so datasets smaller than the batch size do not crash randperm:sub.
    bsize = math.min(bsize or 10000, n)

    -- Random subset of row indices, gathered in one call per tensor.
    local idx = torch.randperm(n):sub(1, bsize):long()
    local X_batch = X:index(1, idx)
    local Y_batch = Y:index(1, idx)

    local loss, p, weights = ml(W, X_batch, Y_batch)
    local diff = torch.csub(p, Y_batch)   -- d(loss)/d(scores), bsize x nclasses

    -- Weight gradient: diff^T * X plus the L2 term (coefficient 1, as in ml).
    local grad = diff:t() * X_batch
    grad:add(weights)
    -- Bias gradient: per-class column sums of diff (nclasses x 1), appended as
    -- the last column to match ml's (nclasses x (nfeatures+1)) layout.
    grad = torch.cat(grad, diff:sum(1):t(), 2)

    print(loss)
    return loss, grad:reshape(grad:size(1) * grad:size(2), 1), p
end


--- Fit L2-regularized softmax regression with minibatch L-BFGS.
-- Parameters start at zero. Every objective evaluation goes through `mlg`,
-- which samples a fresh minibatch.
-- @param X design matrix (N x nfeatures)
-- @param Y target matrix (N x nclasses)
-- @param rate forwarded to optim.lbfgs as config.learningRate
-- @param iter forwarded as config.maxIter
-- @param lX forwarded as config.tolX (nil lets optim use its own default)
-- @return W weights (nclasses x nfeatures)
-- @return b biases (nclasses x 1)
-- @return f_hist function-value history returned by optim.lbfgs
function fit(X, Y, rate, iter, lX)
    local nclass, nfeat = Y:size(2), X:size(2)
    local params = torch.zeros(nclass * (nfeat + 1), 1)

    -- Closure for optim: locals only, so no globals leak from each evaluation.
    local function feval(theta)
        local loss, grad = mlg(theta, X, Y)
        return loss, grad
    end

    local state = {learningRate = rate, maxIter = iter, tolX = lX}
    local xopt, f_hist = optim.lbfgs(feval, params, state)

    -- Split the (nclass x (nfeat+1)) matrix into weights and the bias column.
    local Wb = xopt:reshape(nclass, nfeat + 1)
    local b = Wb:sub(1, nclass, nfeat + 1, nfeat + 1)
    local W = Wb:sub(1, nclass, 1, nfeat)
    return W, b, f_hist
end

--- Raw class scores for each row of X.
-- @param X design matrix (N x nfeatures)
-- @param W weights (nclasses x nfeatures)
-- @param b biases as a column (nclasses x 1), as returned by fit
-- @return (N x nclasses) scores X * W^T + b^T
function predict(X, W, b)
    local scores = X * W:t()
    local bias_row = b:t()   -- 1 x nclasses, broadcast over the rows
    return scores:add(bias_row:expand(scores:size(1), bias_row:size(2)))
end

--- Fraction of rows whose first column matches between predictions and truth.
-- @param ypred predicted labels, indexed as ypred[i][1], with ypred:size(1) rows
-- @param ytrue true labels, indexed the same way
-- @return accuracy in [0, 1] (0/0 when ypred has no rows)
function predict_score(ypred, ytrue)
    local total = ypred:size(1)
    local hits = 0
    for row = 1, total do
        hits = hits + ((ypred[row][1] == ytrue[row][1]) and 1 or 0)
    end
    return hits / total
end

<h3>Create Document-Word Count Matrix and One-Hot Encoding</h3>

In [3]:
--feature weight: counts
function createDocWordMatrix(vocab, max_sent_len, sparseMatrix)
    docword = torch.zeros(sparseMatrix:size(1), vocab)
    for i=1,sparseMatrix:size(1) do
        for j=1, max_sent_len do
            local idx = (sparseMatrix[i][j])
            if idx ~= 0 then
                docword[i][idx] = 1 + docword[i][idx]
            end
        end
    end
    return docword
end
 
--- One-hot encode a vector of class labels.
-- @param classes number of classes (output columns)
-- @param target 1-D tensor of labels in 1..classes
-- @return (target:size(1) x classes) tensor with a single 1 per row
function onehotencode(classes, target)
    local n = target:size(1)
    local onehot = torch.zeros(n, classes)   -- local: do not leak a global
    for i = 1, n do
        onehot[i][target[i]] = 1
    end
    return onehot
end

--- Write predictions as a CSV with header "ID,Category" and 1-based row IDs.
-- @param fname output path (overwritten)
-- @param pred labels with pred:size(1) rows; pred[i][1] is written for row i
-- @raise error (via assert) if the file cannot be opened for writing
function write2file(fname, pred)
    -- Keep the handle local so a global `f` (this notebook also uses one for
    -- the HDF5 file) is not clobbered. Fail loudly if open fails.
    local fh = assert(io.open(fname, "w"))
    fh:write("ID,Category\n")
    for i = 1, pred:size(1) do
        fh:write(tostring(i) .. "," .. tostring(pred[i][1]) .. "\n")
    end
    fh:close()
end
 


In [10]:
-- Load the SST1 split(s) and dataset sizes. The handle is local to this cell;
-- only the data tensors and sizes are kept as globals for later cells.
local f = hdf5.open("SST1.hdf5", "r")

-- NOTE(review): the fit cell uses X_train/Y_train. Uncomment these lines (and
-- the matching conversion lines) when running the notebook top to bottom.
--X_train = f:read("train_input"):all()
--Y_train = f:read("train_output"):all()
--X_valid = f:read("valid_input"):all()
--Y_valid = f:read("valid_output"):all()
X_test = f:read("test_input"):all()
nclasses = f:read('nclasses'):all():long()[1]
nfeatures = f:read('nfeatures'):all():long()[1]

f:close()

In [11]:
-- Number of word-index columns scanned per row; presumably the padded
-- sentence width of SST1 -- confirm against X_test:size(2).
local MAX_SENT_LEN = 53
--X_train = createDocWordMatrix(nfeatures, MAX_SENT_LEN, X_train)
--Y_train = onehotencode(nclasses, Y_train)
X_test = createDocWordMatrix(nfeatures, MAX_SENT_LEN, X_test)
--Y_test = onehotencode(nclasses, Y_valid)

In [8]:
-- Train; W and b stay global for the prediction cell.
-- NOTE(review): requires X_train/Y_train to have been loaded and converted.
local LEARNING_RATE, MAX_ITER = 0.095, 1000
--local start_time = os.time()
W, b = fit(X_train, Y_train, LEARNING_RATE, MAX_ITER)
--print(os.time() - start_time)

16094.379124338	


16067.60022604	


1910848.2529167	


1833612.5642696	


1725260.8038949	


1552498.6165235	


852705.36061441	


409522.39938089	


391933.86094358	


776029.72864099	


716728.56969043	


613431.90267036	


538983.5361749	


1334786.7093546	


1252885.519009	


494231.61454557	


693605.47570207	


635654.15422549	


569545.71326754	


417746.26965817	


367524.74434727	


352742.92770786	


303562.21646064	


185876.52786197	


115270.98968925	


94695.413107974	


87063.70559094	


84886.176365813	


78281.653352938	


73114.056765294	


67587.908684947	


62704.650146377	


60363.037640082	


54221.425292961	


52913.453030302	


47402.910804419	


44792.496452725	


39166.141674874	


36675.748369301	


30851.76240868	


29395.098450646	


28359.665930465	


26704.173657157	


25687.050068687	


23745.172389734	


22422.523063395	


22087.801119268	


20987.751389283	


20946.680915495	


19673.127219413	


18859.895623923	


18738.052380145	


18084.857517999	


17401.497417122	


17161.259922317	


16788.465257901	


16644.644860327	


16143.60894516	


16165.91154124	


15768.395845135	


15468.127668287	


15390.563484435	


14858.412668968	


14891.152358196	


14784.111588549	


14754.120786206	


14245.086906765	


14118.031214681	


14086.015955437	


14057.638342324	


13716.77331581	


13732.251430056	


13529.06740241	


13431.035630557	


13395.541825372	


13274.975414942	


12880.69252401	


13035.173887364	


13615.500345411	


13214.253095979	


13359.308840588	


12667.733089843	


12492.789270848	


12240.149261127	


12466.35877488	


11868.926600356	


11870.641793581	


12030.206272747	


11967.80844601	


11906.690851623	


11871.345419854	


11780.66663759	


11862.226633878	


11666.440786815	


66657.271120028	


60454.046697019	


57555.073000835	


50630.344367049	


41241.370158122	


32090.41986232	


23128.634383263	


17473.084285866	


14822.819535445	


13563.603701078	


12535.440804775	


12395.127314735	


11886.657861434	


11645.616533578	


11608.764471872	


11517.36227757	


11126.283671745	


11220.200778645	


11149.18974644	


11124.979428142	


11286.669658489	


11109.417074473	


11192.823744935	


11230.983683461	


11113.593719554	


11146.08167621	


11457.469545661	


11285.968201304	


11259.884777535	


11165.163348185	


11223.554777273	


11010.587395687	


11076.397402669	


11098.973271353	


11025.990746339	


10884.072675552	


11080.671350532	


10982.455150223	


10991.221389294	


10954.441970272	


11082.339383755	


10817.446054269	


10995.152903613	


11009.60700381	


10948.091377713	


10771.838610403	


10841.677345884	


10853.887596732	


10814.814893904	


10738.11050028	


10919.285095651	


10875.662409699	


10629.327958983	


10832.328729531	


10768.574609808	


10774.319182196	


10841.99183319	


10830.06010922	


10710.439857626	


10697.348005055	


10733.960777927	


10819.441716399	


10725.45619637	


10711.176893217	


10689.92103368	


10715.644804645	


10678.902843573	


10748.707878741	


10698.418182517	


10787.007351607	


10761.640881771	


10701.665146623	


10652.993900476	


10722.553294391	


10622.391660766	


10620.856516884	


10675.58271168	


10634.547014923	


10679.276515646	


10642.453838845	


10587.548864688	


10797.294054935	


10800.651056248	


10637.201523221	


10612.265225849	


10662.132563379	


10648.121256827	


10539.511174657	


10634.84767893	


10576.956069431	


10343.157937631	


10586.438277185	


10631.98228765	


10475.423812437	


10607.583083611	


10798.47806698	


61214.270219461	


55247.841403245	


15543.356801662	


14432.529436497	


13334.848427249	


12554.177296292	


12183.213348812	


11624.5893925	


11197.240070554	


11017.896638579	


10890.804273579	


10596.05623331	


10712.387002098	


10594.365354586	


10676.205397523	


10494.503441885	


10402.199490039	


10485.916581822	


10499.9393305	


10501.998412887	


10411.765553639	


10478.327869447	


10469.513667565	


10590.748262519	


10544.523394716	


10501.183909051	


10447.82232285	


10580.064296818	


10439.26542253	


10446.888274584	


10464.645445279	


10525.503730174	


10332.894842535	


10436.721206138	


10485.052965965	


10491.945115617	


10428.33613366	


10479.779522033	


10343.276000886	


10339.264843321	


10414.193280897	


10487.609528226	


10306.273534664	


10512.544568763	


10510.557444434	


10348.505391963	


10378.255802131	


10514.809268948	


10520.819336888	


10281.195907192	


10342.315281332	


10387.978421415	


10315.436744159	


10386.097201994	


10514.625863629	


10282.980197484	


10444.488499102	


10347.83313748	


10444.591601948	


10382.688426379	


10355.114643763	


10327.491357292	


10451.56491098	


10361.205797619	


10420.365369353	


10255.30210421	


10393.677076561	


10729.799431814	


10691.680798268	


10572.314883093	


10532.945217436	


10364.425363395	


10416.849982077	


10355.328362526	


10371.476734434	


17593.959857313	


16356.213336417	


14741.431251858	


12871.140486002	


11930.056088298	


11423.425321878	


11037.960474697	


10939.612988761	


10729.508462082	


10641.96172224	


10592.251790738	


10688.686265091	


10450.192625858	


10529.471315822	


10344.202496261	


10365.203729809	


10285.437496992	


10374.124283863	


10402.209809146	


10293.508849686	


10298.440113484	


10263.235216396	


10242.222697347	


10154.484697917	


10298.962075324	


10309.549620223	


10221.31055634	


10364.544046409	


10241.988295777	


10252.483664184	


10196.681918224	


10288.511440467	


10317.605301129	


10335.94325701	


10367.839183087	


10246.367977348	


10115.670794949	


10095.677559579	


10142.724815317	


10328.585983222	


10363.587944998	


10254.782964664	


10335.894030914	


10130.227140198	


10283.51577329	


10240.607801374	


10221.979895343	


10314.193720893	


10253.329535891	


10315.551659207	


10321.624320882	


10313.333303005	


10284.65442141	


10161.431346108	


10252.21965832	


10192.604709838	


10249.777871608	


10131.132706738	


10219.304686362	


10223.916348604	


10279.210789729	


10309.026396736	


10245.958865749	


10155.774475891	


10289.573516222	


10143.065731001	


10237.170151736	


10224.403553372	


10307.516876193	


10171.651623922	


10146.901655427	


10183.933150714	


10131.690605565	


10138.709489456	


10479.415298323	


10400.612033471	


10445.483556621	


10227.96961659	


10178.905571791	


10253.634994033	


10219.788062924	


10183.521966987	


10117.488216485	


10231.759268536	


10179.514791279	


10130.722354166	


10100.374155732	


10281.359395213	


10197.515667761	


10109.397994605	


10192.861598336	


10095.792326459	


10273.987022802	


10344.277026812	


10230.594632023	


10158.945454258	


10150.90780987	


10218.888983434	


10103.566959943	


10302.906366217	


10131.52998327	


10150.221523421	


10104.630544121	


10112.052134764	


10345.612144049	


10207.044466146	


10174.467665828	


10210.882693271	


10087.054246601	


10095.602702232	


10294.555679405	


10175.896486286	


10208.825894713	


10240.163154294	


10299.701237483	


10202.009815676	


10211.636164834	


10130.17198059	


10078.166327064	


10194.737464569	


10115.150214817	


10147.775199065	


10114.343469798	


10108.532209635	


10156.13173148	


10181.364191947	


10131.519439752	


10139.876373092	


10069.482216238	


10362.667721194	


10106.31452109	


10151.942721972	


10138.854908563	


10218.136265717	


10066.984911644	


2060025.7114149	


1863810.22454	


4174709.8251734	


3687007.0067764	


1854462.8819901	


1551300.5337961	


1199613.5398993	


823538.70702177	


611745.18961457	


533052.7333307	


453667.38560963	


399508.01476511	


349426.30492064	


322780.68038065	


273312.76957672	


270620.50021248	


234497.05222036	


224231.68762235	


204545.86580475	


192390.45362946	


176059.5528942	


171543.13469634	


153464.39445782	


139654.488142	


134101.071732	


123087.58991343	


108485.80656468	


106882.20549801	


100255.43232528	


91920.638869182	


90544.170413498	


81169.689644021	


81068.578246131	


75095.806675142	


71109.400261875	


65755.006266531	


63097.246093922	


60864.541356554	


54456.838206467	


53549.873093085	


49810.803599464	


47930.677478009	


44626.344261464	


42219.929988383	


40551.166070845	


39130.819531752	


36610.516694411	


35329.698534221	


32965.460428002	


32509.36139499	


30554.487802983	


29430.885293638	


28312.946368427	


26806.014681716	


25674.434795604	


24912.029396934	


24553.224557979	


22267.772586888	


22929.107993813	


22034.176051973	


21399.241596869	


20946.47659956	


20413.936579622	


20128.316332021	


19424.276088318	


18944.295444789	


18753.693797559	


17964.678431186	


17372.391396645	


17853.186277371	


16760.017228854	


17086.35522451	


16712.035974714	


16429.74611325	


16430.221887228	


15721.670305884	


15954.789127085	


15881.483722747	


15083.326624663	


15144.196366721	


14866.727375293	


14667.952461393	


15066.824693606	


14719.989119306	


14425.265350087	


14297.212264441	


14439.947346689	


14472.884936726	


14126.620822807	


14287.486538787	


13891.949920256	


13841.082855059	


13746.997982903	


13760.101600736	


13650.434738031	


13741.613629594	


13266.477682379	


13129.817734704	


12956.189908119	


12920.76130299	


12784.985790262	


12530.634393884	


12683.883524592	


12578.701938298	




12609.686522295	


12657.515371417	


12517.908210792	


12382.260485744	


12236.671333499	


12369.69883641	


12293.910864749	


12166.934457631	


12048.392511893	


11964.459037912	


12183.981070269	


11992.857529618	


11821.616110715	


12074.650977822	


11999.454635207	


11902.027102726	


11827.597669392	


11764.313897364	


11725.469344001	


11886.386572919	


11856.061216637	


11654.951530378	


11810.646085722	


11535.645281614	


11533.923050494	


11445.7155264	


11527.689481239	


11458.533881192	


11567.099252253	


11444.779184007	


11449.896175667	


11368.787299864	


11316.609923063	


11330.040069216	


11423.884734612	


11396.404707632	


13381.294027621	


12739.375396719	


12486.338499753	


12127.400326557	


11691.285201741	


11721.783238124	


11542.169388754	


11290.70225434	


11099.002075141	


11246.693249603	


10969.828693266	


10956.314786144	


11056.670736859	


10902.364967447	


11001.004528051	


10964.00208745	


10838.935660146	


10949.726696886	


11108.461244113	


10980.46859396	


11009.973885563	


11087.311171485	


10942.035566934	


10810.370486658	


10870.450353889	


10970.413295021	


11026.778897842	


10896.467021232	


10968.151195967	


787761.97455039	


700291.52722506	


179482.45553936	


65862.581626614	


50716.28459284	


44414.029024133	


31140.885885391	


27312.65449178	


21245.126677436	


17955.045430881	


16503.142692832	


14829.645076271	


14636.141799275	


13928.177018718	


13299.10286072	


12960.890403469	


12620.786196546	


12361.638688879	


12320.867952392	


11982.684306649	


11961.449351543	


12057.983236145	


11829.045309798	


11613.778562223	


11708.59452878	


11371.492554951	


11599.755348117	


11406.330404195	


11199.624692378	


11154.452493708	


11286.003910064	


11115.732091667	


11086.105705111	


11055.39648732	


11143.630939013	


11094.566245296	


10946.421397028	


11013.210319794	


10903.404018579	


10914.453485494	


10838.502914751	


10768.91774144	


10771.067385356	


10752.551642112	


10733.874138009	


10710.554368815	


10656.050822029	


10752.81729156	


10733.742490298	


10702.904774227	


10823.64745712	


10862.379315448	


10808.567150826	


10648.205871476	


10718.953613067	


10580.000351115	


10744.967205698	


10752.213106954	


10714.805520716	


10536.65477254	


10700.2250713	


10456.984384505	


10537.992521573	


10754.362067551	


10473.684516427	


10662.685091124	


10597.147062461	


10431.079765762	


10536.774024479	


10503.919533483	


10391.513394654	


10391.29040258	


10501.357069539	


10420.768274831	


10497.020951412	


10513.108164064	


10616.364292751	


10380.280303718	




10394.817276024	


10536.083686758	


10559.368109055	


10437.93634188	


10477.171570944	


10499.202789609	


10542.267053527	


10614.250678406	


10516.461425119	


10611.375298346	


10443.699256049	


10399.859153382	


10372.782298091	


10423.560386723	


10424.444856158	


10341.65572607	


10407.769054136	


10376.535077732	


10368.550405093	


10411.433023953	


10394.082025743	


10415.0862126	


10384.933516437	


10488.860038428	


10429.720163806	


10385.593224935	


10232.440688485	


10343.774397146	


19702.108316577	


18570.333225643	


12092.251852255	


11319.768893652	


11028.604995664	


10727.28071216	


10841.498900726	


10558.253772369	


10537.0218475	


10531.203595631	


10333.754046775	


10399.123326073	


10510.329641482	


10393.035354081	


10438.704066182	


10507.581913271	


10419.025821824	


10413.549760423	


10291.844955757	


10340.31669404	


10261.09259233	


10462.38383352	


10442.41146209	


10354.03293329	


10372.947193696	


10236.742061765	


10237.698394787	


10220.434998015	


10295.166838162	


10276.313055026	


10409.635121383	


10344.116148823	


10329.018952466	


10278.711165589	


10362.002338338	


10249.898171799	


10252.793350939	


10309.283503954	


10197.246522178	


10450.710929464	


10490.600499617	


10330.926463538	


10370.170694832	


10156.398081718	


10401.264366726	


10316.043044033	


10429.546241867	


10381.350811146	


10280.886751767	


10330.47660203	


10188.64433786	


10200.004263155	


10404.212696484	


10210.535477997	


10287.507197662	


10369.871998551	


10338.299365267	


10288.435668303	


10222.155067374	


10321.759908213	


10351.923608662	


10165.723325192	


10353.267559426	


10250.064516147	


10409.941556297	


10297.411632753	


10300.266024754	


10278.267915836	


10235.748192585	


10395.890623594	


10334.410684843	


10390.78655754	


10298.29148273	


10335.806087752	


10393.134717727	


10279.946459018	


10257.8431174	


10234.261339589	


10394.480358623	


10221.041612613	


10379.961594374	


10253.500971879	


10254.699101568	


10224.977424056	


10317.378186791	


10362.894001765	


10331.612008209	


10175.052380859	


10194.299100046	


10531.210675161	


10196.025749591	


10299.332085629	


10229.548177881	


10220.108327363	


10326.711349133	


10214.450486174	


10264.513006534	


10138.020504937	


10137.907461682	


10203.861946958	


10346.436376216	


10437.117730039	


10209.186047356	


10268.525353437	


10067.053303625	


10298.69765519	


10277.8240986	


10320.82198774	


10218.344267259	


10221.470872455	


9953.5792859502	


10235.87542395	


10123.281814703	


10284.642380312	


10213.132405336	


10251.466615436	


10167.128426944	


10172.928357206	


10296.476498616	


10177.853091327	


10116.789975571	


10190.813233873	


10259.723387708	


10187.628127066	


10233.373777418	


10176.880227632	


10302.783975489	


10121.836684898	


10166.260433716	


10263.015316694	


10103.881665983	


10227.699510086	


10051.636635216	


10265.299077297	


10196.826048039	


10131.175375294	


10182.838154011	


10154.244358784	


10265.552203603	


10169.044658366	


10155.212636466	


10236.040279181	


10156.051847426	


10153.718035674	


10087.785395582	


10265.43959296	


10295.588620043	


10237.738488753	


10314.286803674	


10159.979723898	


10211.524383671	


10244.916778866	


10190.578185477	


10255.453756146	


10181.51073754	


10087.770608103	


10196.626410456	


10165.975335738	


10145.869916461	


10225.127769939	


10099.122365543	


10303.129269544	


10162.176905631	


10189.969427786	


10139.46383086	


10165.608000263	


10118.892429089	


10083.780184196	


10259.611752666	


11543.190540967	


11320.236848981	


10673.16703639	


10604.957810515	


10442.319148911	


10344.886192261	


10304.219961629	


10213.675172005	


10237.148196212	


10255.134165613	


10211.06800324	


10088.356404781	


10186.260042861	


10095.079558978	


10192.192309826	


10110.485658951	


10210.505982307	


10013.999618797	


9995.7506841217	


10072.196126212	


10140.638252095	


10142.423797014	


10216.242317303	


10135.177289475	


10137.60508045	


10144.400603083	


10167.215037658	


10097.354955394	


10026.740132421	


10032.839515635	


10144.783775205	


9976.2463962621	


10148.620965883	


10144.026275176	


10092.684154721	


10086.225525165	


10194.70018158	


10204.405092914	


10004.749804867	


10091.326789836	


10160.282300878	


10151.417834454	


10186.804998641	


10173.840061724	


10123.705226252	


10060.131985546	


10033.29290275	


10148.640498711	


10143.394315296	


10136.95682409	


10036.036229624	


10111.414342861	


10086.190449537	


10123.644685657	


10161.136713472	


10230.699412325	


10103.239480892	


10052.905772639	


10148.199477433	


10334.820058748	


10071.374847968	


10099.431115021	


10166.141465578	


10110.611849015	


10105.261948991	


10133.520020717	


10067.410221341	


10177.580846843	


10120.59952866	


10116.072138116	


10080.529621091	


10067.419999047	


10151.97841815	


10042.661998734	


10060.588048823	


10157.333452457	


10144.908244934	


10075.259174407	


10136.285679594	


10049.48906436	


10066.416671815	


10025.822502798	


10011.713835991	


9957.119385411	


10038.375392832	


10122.607661493	


10000.239414364	


10092.721520781	


10052.491273685	


10074.35483625	


10039.423736535	


10083.547654368	


10011.636609504	


10164.677036907	


10174.718066068	


10143.967869876	


10074.651200736	


10031.311271062	


10082.180718861	


10139.869738043	


10149.201958941	


10178.614023218	


10066.184019778	


10092.434507203	


10078.419478539	


10137.523491936	


10114.232228283	


9978.7892370639	


10034.329485135	


10023.189390948	


10030.417584241	


10175.577365223	


10190.855751527	


10106.616883749	


10097.470273696	


10093.28034395	


10156.768252823	


9958.2943201416	


9999.8615275733	


10012.106377673	


10106.429613762	


10169.562960644	


10155.178900596	


10139.310168574	


10056.984451941	


10112.943224607	


10166.32892721	


9990.5836905357	


10034.835751614	


10051.661448452	


10008.181742788	


9982.5882169475	


9952.9592770406	


10093.990809849	


10082.665538551	


10211.831959579	


10089.556496197	


10081.100851125	


9962.71588274	


10063.573233429	


10052.027604422	


10025.41393417	


10011.987928496	


10066.651909412	


10091.93472832	


10005.863082647	


10107.800600575	


10059.629438537	


10020.775839879	


9991.7209940949	


10025.846547815	


10082.57748864	


10113.613937659	


10133.236635624	


10130.221766209	


10032.91965221	


In [12]:
-- Predicted label = index of the highest score in each row. Y_pred is global
-- because the next cell writes it out; the discarded max values stay local.
local scores = predict(X_test, W, b)
local _, labels = torch.max(scores, 2)
Y_pred = labels
--local _, Y_true = torch.max(Y_test, 2)
--acc_score = predict_score(Y_pred, Y_true)
--print(acc_score)

In [13]:
-- Write the test-set predictions as a submission CSV.
local submission_path = "MNB_6.csv"
write2file(submission_path, Y_pred)


