In [1]:
require("optim")
require('hdf5')

<h3>Multinomial Logistic Regression - Minibatch L-BFGS - L2 Regularization</h3>

In [39]:
-- Global count of loss/gradient evaluations; mlg increments and prints it.
neval = 0

--- Loss of an L2-regularised multinomial logistic (softmax) regression model.
-- @param W weights as a flat tensor; reshaped to (classes x (features + 1)),
--          where the last column holds the per-class intercepts
-- @param X inputs, one example per row (examples x features)
-- @param Y target matrix (examples x classes); one-hot in this notebook
-- @return loss: -sum(Y .* log-softmax scores) over all rows of X, plus
--         1.0 * 0.5 * ||coefficients||^2 (intercepts are not penalised)
-- @return p: softmax class probabilities (examples x classes)
-- @return W: the coefficient matrix without the intercept column
--         (classes x features)
function ml(W, X, Y)
    -- Unflatten: one row per class, last column = intercept.
    local W = W:reshape(Y:size(2), X:size(2)+1)
    
    --intercept: last column, transposed into a 1 x classes row
    local b = W:sub(1, W:size(1), W:size(2),W:size(2)):t()
    
    --coefficient: every column except the last (classes x features)
    W = W:sub(1, W:size(1),1,W:size(2)-1)
    
    --XW^T: raw class scores (examples x classes)
    local p = X*W:t()
    
    --XW^T + b: broadcast the intercept row down every example.
    -- NOTE(review): b is also passed as expand's result argument, so b itself
    -- is re-pointed at the (examples x classes) broadcast view.
    p:add(b:expand(b,p:size(1),b:size(2)))
    
    -- Transposed copy (classes x examples) so per-example reductions run
    -- over dimension 1; p itself is left untouched here.
    local arr = p:clone()
    arr = arr:t()
    
    -- Numerically stable log-sum-exp per example:
    --   log(sum_c exp(s_c)) = m + log(sum_c exp(s_c - m)),  m = max_c s_c
    local vmax = arr:max(1)
    local evmax = torch.expand(vmax,arr:size(1),vmax:size(2))
    arr:csub(evmax)
    arr:exp()
    arr = arr:sum(1)
    arr:log()
    arr:add(vmax)
    -- Back to one value per example (examples x 1), broadcast over classes.
    arr = arr:t()
    arr:expand(arr, arr:size(1), p:size(2))
    -- p now holds log-softmax values (log class probabilities).
    p:csub(arr)
    
    --L2 regularization: flatten the coefficients (intercepts excluded) into
    -- a column so that dot(norm, norm) is their squared Euclidean norm.
    local norm = W:reshape(W:size(1)*W:size(2), 1)
    
    -- Negative log-likelihood of the targets plus 0.5 * ||W||^2 with a
    -- hard-coded regularisation strength of 1.0 (mlg's gradient adds W,
    -- which assumes that same strength).
    local loss = (torch.sum(torch.cmul(Y,p))*-1) + 1.0 *0.5 * torch.dot(norm, norm)
    
    -- Convert log-probabilities into probabilities in place.
    p:exp()
    
    return loss, p, W
end

--- Minibatch loss and gradient for the L2-regularised softmax model.
-- Each call draws `bsize` distinct training rows at random, evaluates `ml`
-- on them, and returns the loss with a flat gradient laid out exactly like
-- the flat weight vector W (per class: features coefficients, then intercept).
-- Increments the global counter `neval` and prints it together with the loss.
-- NOTE(review): because every evaluation uses a fresh random minibatch, the
-- objective seen by L-BFGS is stochastic, which can cause loss spikes.
-- @param W flat weight vector (classes * (features + 1) x 1)
-- @param X training inputs, one example per row (examples x features)
-- @param Y target matrix (examples x classes)
-- @param bsize optional minibatch size; defaults to 1000 and is capped at the
--        number of examples
-- @return loss, gradient as a (classes * (features + 1)) x 1 tensor
function mlg(W, X, Y, bsize)
    -- Honour the caller's batch size, but never ask for more rows than exist
    -- (randperm(N):sub(1, bsize) would fail for bsize > N).
    bsize = math.min(bsize or 1000, X:size(1))

    -- bsize distinct row indices drawn uniformly from [1, N].
    local idx = torch.randperm(X:size(1)):sub(1, bsize):long()

    -- Gather the minibatch rows in one call each.
    local X_batch = X:index(1, idx)
    local Y_batch = Y:index(1, idx)

    -- Loss, softmax probabilities and coefficient matrix (classes x features).
    local loss, p, coef = ml(W, X_batch, Y_batch)
    local diff = torch.csub(p, Y_batch)

    -- d(loss)/d(coef) = diff^T * X + coef  (ml's L2 term has strength 1.0).
    local grad = diff:t() * X_batch
    grad:add(coef)

    -- Append the intercept column: d(loss)/d(b) = per-class sums of diff.
    -- diff:sum(1) is 1 x classes; transpose it to match the column slice.
    grad = grad:cat(torch.zeros(grad:size(1), 1), 2)
    grad:sub(1, grad:size(1), grad:size(2), grad:size(2)):add(diff:sum(1):t())

    neval = neval + 1
    print(neval, loss)
    return loss, grad:reshape(grad:size(1) * grad:size(2), 1)
end


--- Fit the softmax regression model with minibatch L-BFGS.
-- Weights start at zero. Every objective evaluation goes through mlg, which
-- samples a random minibatch of the training rows.
-- @param X training inputs, one example per row (examples x features)
-- @param Y training targets (examples x classes)
-- @param rate fixed L-BFGS step size (optim's learningRate; no line search)
-- @param iter maximum number of L-BFGS iterations (optim's maxIter)
-- @param lX optional tolX stopping tolerance for optim.lbfgs
-- @param batch optional minibatch size, forwarded to mlg
-- @return W coefficient matrix (classes x features)
-- @return b intercept column (classes x 1)
-- @return f_hist table of objective values returned by optim.lbfgs
function fit(X, Y, rate, iter, lX, batch)
    --Weight matrix must be passed in as vector
    local W = torch.zeros(Y:size(2) * (X:size(2)+1), 1)

    -- Objective closure for optim: returns (loss, flat gradient). Everything
    -- stays local so fitting does not leak loss/grad/f_hist/b into globals.
    local func = function(params)
        return mlg(params, X, Y, batch)
    end

    --optimization parameters
    local state = {learningRate = rate, maxIter = iter, tolX = lX}

    --LBFGS with no line search, therefore specify learning rate
    local f_hist
    W, f_hist = optim.lbfgs(func, W, state)

    W = W:reshape(Y:size(2), X:size(2)+1)

    --intercept: last column
    local b = W:sub(1, W:size(1), W:size(2), W:size(2))

    --coefficients: all remaining columns
    W = W:sub(1, W:size(1), 1, W:size(2)-1)

    return W, b, f_hist
end

--- Raw class scores X * W^T + b for every row of X.
-- @param X inputs, one example per row (examples x features)
-- @param W coefficient matrix (classes x features)
-- @param b intercept column (classes x 1)
-- @return (examples x classes) score tensor
function predict(X, W, b)
    -- Intercepts as a 1 x classes row, broadcast down all examples.
    local bias = b:t():expand(X:size(1), b:size(1))
    local scores = X * W:t()
    return scores:add(bias)
end

--- Fraction of rows whose first-column prediction equals the true label.
-- @param ypred predictions; only column 1 of each row is compared
-- @param ytrue true labels, same row count as ypred
-- @return accuracy in [0, 1]
function predict_score(ypred, ytrue)
    local n = ypred:size(1)
    local hits = 0
    local row = 1
    while row <= n do
        if ypred[row][1] == ytrue[row][1] then hits = hits + 1 end
        row = row + 1
    end
    return hits / n
end

<h3>Create Document-Word Matrix and One-Hot Encoding</h3>

In [3]:
--feature weight: counts
--- Build a dense document-word count matrix from padded word-index rows.
-- @param vocab number of columns of the result (vocabulary size)
-- @param max_sent_len number of leading columns of each row to scan;
--        optional, defaults to every column of sparseMatrix
-- @param sparseMatrix one document per row of word indices in [1, vocab];
--        entries equal to 0 are treated as padding and skipped
-- @return (documents x vocab) tensor; entry [i][w] counts word w in row i
function createDocWordMatrix(vocab, max_sent_len, sparseMatrix)
    max_sent_len = max_sent_len or sparseMatrix:size(2)
    -- Local so that building a matrix does not overwrite a global `docword`.
    local docword = torch.zeros(sparseMatrix:size(1), vocab)
    for i = 1, sparseMatrix:size(1) do
        -- Row views are taken once per document instead of per token.
        local row, counts = sparseMatrix[i], docword[i]
        for j = 1, max_sent_len do
            local idx = row[j]
            if idx ~= 0 then
                counts[idx] = counts[idx] + 1
            end
        end
    end
    return docword
end
 
--- One-hot encode integer class labels.
-- @param classes number of classes (columns of the result)
-- @param target labels in [1, classes], indexed as target[i]
--        (assumed 1-D; TODO confirm for other shapes)
-- @return (#labels x classes) tensor with a single 1 per row
function onehotencode(classes, target)
    -- Local so that encoding does not overwrite a global `onehot`.
    local onehot = torch.zeros(target:size(1), classes)
    for i = 1, target:size(1) do
        onehot[i][target[i]] = 1
    end
    return onehot
end

--- Write predictions as a two-column "ID,Category" CSV file.
-- Emits a header line, then one line "i,<pred[i][1]>" per row of pred.
-- @param fname path of the file to create (truncated if it already exists)
-- @param pred predictions, one per row; only column 1 of each row is written
function write2file(fname, pred)
    -- Keep the handle local (a global `f` would clobber other cells' state)
    -- and fail with io.open's own error message if the file can't be opened.
    local out = assert(io.open(fname, "w"))
    out:write("ID,Category\n")
    for i = 1, pred:size(1) do
        out:write(tostring(i) .. "," .. tostring(pred[i][1]) .. "\n")
    end
    out:close()
end

In [23]:
-- Load the SST-1 splits and dataset metadata from the preprocessed HDF5 file.
-- Every value is stored in a global so that later cells can use it.
f = hdf5.open("SST1.hdf5", "r")

local function read_all(key)
    return f:read(key):all()
end

X_train = read_all("train_input")
Y_train = read_all("train_output")
X_valid = read_all("valid_input")
Y_valid = read_all("valid_output")
X_test = read_all("test_input")
-- Scalar metadata is read back as a tensor; keep its first entry.
nclasses = read_all('nclasses'):long()[1]
nfeatures = read_all('nfeatures'):long()[1]

f:close()

In [24]:
-- Turn the word-index matrices into dense document-word count matrices and
-- one-hot encode the labels. 53 is the number of index columns scanned per
-- row (presumably the padded sentence length of the *_input tensors; TODO
-- confirm).
-- NOTE(review): X_test/Y_test are built from the VALIDATION split here, which
-- overwrites the X_test tensor read from "test_input" in the loading cell.
-- Later accuracy is therefore validation accuracy, and the real test inputs
-- are no longer available when predictions are written out.
X_train =createDocWordMatrix(nfeatures, 53, X_train)
Y_train = onehotencode(nclasses, Y_train)
X_test = createDocWordMatrix(nfeatures, 53, X_valid)
Y_test = onehotencode(nclasses, Y_valid)

In [None]:
-- Train with learning rate 0.1 and at most 10000 L-BFGS iterations (tolX and
-- the minibatch size are not passed), then print the elapsed time as
-- measured by os.time (seconds on typical platforms).
start_time = os.time()
W, b = fit(X_train, Y_train, 0.1, 10000)
end_time = os.time()
print(end_time - start_time)

1	1609.4379124341	


2	1607.3586826078	


3	18251.072658293	


4	17077.28110091	


5	50078.283463156	


6	45655.192875743	


7	40690.586748682	


8	15955.793744925	


9	22535.717145094	


10	20457.856570249	


11	18811.384282247	


12	16428.025272242	


13	3938.6189796809	


14	2399.1747518115	


15	2362.041152347	


16	2141.7695846278	


17	1875.2775135764	


18	1888.8720719079	


19	1594.2998504454	


20	1623.1160187209	


21	1602.5167124189	


22	1539.1112138393	


23	1516.1137135731	


24	1376.1617924826	


25	1437.5744039202	


26	1478.4896024227	


27	1465.5517501183	


28	1450.7652504355	


29	1467.2628944844	


30	1292.885404691	


31	1340.1996716273	


32	1325.1492772716	


33	1401.8547695907	


34	1344.334944316	


35	1284.3557639284	


36	1321.1709847512	


37	1294.9883284059	


38	1282.4550420848	


39	1285.8798751846	


40	1330.3688489913	


41	1230.3209867439	


42	1215.1129075184	


43	1203.3883801113	


44	1227.2995988727	


45	17159.526012995	


46	15047.170647128	


47	2000.4346729277	


48	1333.8692626627	


49	1298.6291138694	


50	1349.0978946307	


51	1326.4745725582	


52	1255.3581884336	


53	1207.3040133467	


54	1198.2386104713	


55	1227.5862188418	


56	1199.8200441249	


57	1233.7177939593	


58	1239.6282774099	


59	1226.6982948477	


60	1175.2516270314	


61	1245.7975004183	


62	1219.707405385	


63	1191.2522871732	


64	1168.5120057353	


65	1258.1900018342	


66	1222.0308565226	


67	1238.3975591612	


68	1225.6659923171	


69	1181.5932848435	


70	1200.9087183798	


71	1185.3737389738	


72	1187.5676838665	


73	1249.8316497253	


74	1235.2899631342	


75	1251.2922685846	


76	1220.5827779924	


77	1209.5150543579	


78	1241.5520334768	


79	1216.5463076308	


80	1185.891467001	


81	1226.2585900468	


82	1207.9505747696	


83	1198.9992390243	


84	1184.4843345826	


85	1193.8153977104	


86	1324.6641903235	


87	1331.1100864321	


88	1313.757828405	


89	1189.0643940134	


90	1170.7428525376	


91	1232.6700107459	


92	1237.3924591328	


93	1226.515599129	


94	1238.233736851	


95	1185.8590239893	


96	1192.6612229638	


97	1195.6778821103	


98	1181.0464566409	


99	1148.8045353824	


100	1202.0594042148	


101	1178.4350548825	


102	1200.5005420231	


103	1181.7772307751	


104	1242.7252830448	


105	1187.7604808948	


106	1216.7639782623	


107	1191.7732488377	


108	1202.9248580169	


109	1221.5290492045	


110	1182.3096471949	


111	1132.7017305831	


112	1183.5676962438	


113	1207.8177314065	


114	1215.9505418266	


115	1176.1182501606	


116	1260.5124962543	


117	1200.8750639866	


118	1190.4781832041	


119	1193.9303292134	


120	1219.8534984667	


121	1184.7545785428	


122	1223.8242569083	


123	1195.3514599279	


124	1188.5306645706	


125	1170.4550468307	


126	1152.6050752602	


127	1208.6480753245	


128	4389.1377032101	


129	3928.3879993094	


130	2011.904154879	


131	1708.8554368827	


132	1572.8278223783	


133	1468.5808832987	


134	1402.8193486399	


135	1431.1233980988	


136	1323.6199491854	


137	1333.5074958157	


138	1285.7656617906	


139	1260.083474453	


140	1326.8984513107	


141	1245.5598591674	


142	1274.4794115756	


143	1252.4940767473	


144	1246.9178446692	


145	1202.3206110398	


146	1166.7639717301	


147	1167.4524654042	


148	1193.7180795944	


149	1208.4960809611	


150	1225.7039282428	


151	1199.6664582233	


152	1167.0556655255	


153	1165.897446112	


154	1169.1229286922	


155	1188.2928973748	


156	1191.3248559632	


157	1179.5257011515	


158	1166.3891004319	


159	1132.2590512358	


160	1210.498759802	


161	1200.6035539054	


162	1192.2929282243	


163	1186.1192196392	


164	1164.2717149119	


165	1149.7176741104	


166	1202.0003829196	


167	1193.5929840187	


168	1249.7078616769	


169	1170.7286645451	


170	1182.107126536	


171	1182.254567052	


172	1160.3452579887	


173	1175.9879341114	


174	1172.5587265641	


175	1178.9892280981	


176	1189.2475645998	


177	1151.6378593273	




178	1193.2184901298	


179	1168.0789249034	


180	1179.3301179241	


181	1174.9242048629	


182	1213.323950671	


183	1172.9343842516	


184	1161.5560668624	


185	1178.3151923967	


186	1153.6008461214	


187	1189.6548569838	


188	1207.3235619136	


189	1190.7074659103	


190	1203.5875723237	


191	1158.3567774684	


192	1191.6561778995	


193	1214.3522831711	


194	1181.0994569822	


195	1156.2011527683	


196	1160.3540307327	


197	1132.9201274522	


198	1216.0946099936	


199	1150.6139694807	


200	1236.8218042793	


201	1148.6479051484	


202	1158.4765136214	


203	1159.3052418406	


204	1167.2179522312	


205	1142.1154445126	


206	1179.2654990282	


207	1170.6712991231	


208	1169.244766499	


209	1213.1556800221	


210	1162.9170105029	


211	1182.7672037788	


212	1157.1238051136	


213	1185.5967311302	


214	1176.9354403068	


215	20937.950734902	


216	19010.871975506	


217	6010.9229362742	


218	34556.147510946	


219	28787.317110344	


220	17330.629991166	


221	7281.1669665221	


222	6611.1049813915	


223	5162.9954620694	


224	4308.0333547137	


225	3549.978603214	


226	2597.1242792062	


227	2212.4571358121	


228	2160.3887150884	


229	1829.8134771227	


230	1643.00412071	


231	1445.8983294364	


232	1538.1977204818	


233	1397.0978065192	


234	1444.1539908356	


235	1438.5805333302	


236	1432.5034546104	


237	1471.6683927817	


238	1391.0856048821	


239	1330.9651507004	


240	1322.9746648432	


241	1289.1155354393	


242	1329.9867293121	


243	1299.1127803608	


244	2274.1198305139	


245	2194.6939405395	


246	2065.0043121457	


247	1608.4981929181	


248	1454.9483546927	


249	1357.6087789339	


250	1332.1906858954	


251	1284.3448632288	


252	1328.9936298514	


253	1248.593689092	


254	1262.3449857601	


255	

1225.8751506055	


256	1211.294435472	


257	1209.2673091395	


258	1248.0031669654	


259	1301.2429317296	


260	1381.4072737139	


261	1370.8335925622	


262	1326.3185435661	


263	1234.8582591392	


264	1207.0731665589	


265	

1227.7688113663	


266	1206.5728708473	


267	1199.47586913	


268	1182.2327733231	


269	1201.1450488882	


270	1219.4349590319	


271	1248.1597596099	


272	1209.7994047267	


273	1190.5755327428	


274	1243.148186046	


275	1216.2037110678	


276	1151.9065174473	


277	1223.6515468412	


278	1210.6680846372	


279	1219.9709883612	


280	1184.5855081513	


281	1162.4942300043	


282	1199.0223624666	


283	1196.6473831412	


284	1184.2980674505	


285	1170.257684122	


286	1200.3017058726	


287	1212.7994337831	


288	1167.5677721659	


289	1210.52666929	


290	1176.7805685566	


291	1215.9304174236	


292	1173.7442048521	


293	1182.7793951039	


294	1178.9365104293	


295	1208.0959353941	


296	1157.1621000528	


297	1204.4734867799	


298	1172.4899178912	


299	5373.1992841778	


300	4298.0595992099	


301	3868.6045975559	


302	2299.0558367115	


303	1600.6694216704	




304	1339.081089076	


305	1293.4616949818	


306	1270.3237917588	


307	1217.4205959282	


308	1219.6967638404	


309	1186.8240997501	


310	1195.3912043639	


311	1218.9082207461	


312	1158.6413106138	


313	1197.8282020309	


314	1175.5012294579	


In [28]:
-- Score X_test, keep the arg-max class index per row (torch.max along dim 2
-- returns values and indices; only the indices are kept), and compare it
-- with the arg-max of each one-hot target row.
-- NOTE(review): X_test/Y_test were rebuilt from the validation split in the
-- preprocessing cell, so this prints validation accuracy.
Y_pred = predict(X_test, W, b)
_, Y_pred = torch.max(Y_pred, 2)
_,Y_true = torch.max(Y_test, 2)
acc_score = predict_score(Y_pred, Y_true)
print(acc_score)

0.37511353315168	


In [22]:
-- Save the predicted class indices to an "ID,Category" CSV file.
-- NOTE(review): Y_pred comes from X_test, which the preprocessing cell replaced
-- with the validation split, so these are not test-set predictions. This
-- cell's execution counter (22) is also lower than the prediction cell's (28),
-- so the saved file may reflect an earlier Y_pred; re-run in order to confirm.
write2file("MNB_7.csv", Y_pred)


