In [3]:
import numpy as np
import torch
from torch.autograd import Variable
from torch import nn
import loader
import threes

# Layer widths of the Q-network built at the bottom of this file. train()
# feeds the network the first INPUT_SIZE (20) columns of every batch row.
INPUT_SIZE = 20
HIDDEN_SIZE = 256

# Path handed to save_parameters() when train() finishes or is interrupted.
FILENAME = "saved_parameters"

# When True, train() moves the network to the GPU before the first step.
CUDA = False

# Move codes; train() appends each one as the last input column when it
# scores the follow-up boards.
MOVES = [0, 1, 2, 3]


def train(model, data_loaders, optimizer, num_epochs=100, log_every=100, verbose=True):
    """Train ``model`` on one-step Q-learning targets, then save its parameters.

    Every iteration fetches one batch with ``data_loaders.get(model)``, builds
    the regression target ``batch[:, 20] + max_move Q(next_board, move)`` and
    performs a single optimizer step on ``model.loss``.

    Batch layout as indexed by this function (column meanings are inferred
    from usage -- TODO confirm against the loader):
      * columns 0..19 -- network input for the step that was taken
      * column 20     -- value added to the best follow-up score
                         (presumably the immediate reward)
      * columns 21..  -- board handed to ``threes.Threes(data=...)``
                         (presumably the resulting next state)

    Args:
        model: object exposing ``network``, ``Q(batch, as_variable=False)``,
            ``loss(out, y)`` and ``save_parameters(filename)``.
        data_loaders: object whose ``get(model)`` returns a 2-D batch array.
        optimizer: optimizer providing ``zero_grad()`` and ``step()``.
        num_epochs: number of batches to train on.
        log_every: print the loss every ``log_every`` iterations.
        verbose: when False, nothing is printed.

    A ``KeyboardInterrupt`` stops training early; the parameters are saved to
    ``FILENAME`` in either case.
    """
    if CUDA:
        model.network.cuda()

    if verbose:
        print(u'Training the model!')
        print(u'Interrupt at any time to get current model')

    iter_ = 0
    try:
        while iter_ < num_epochs:
            iter_ += 1
            batch = data_loaders.get(model)
            num_rows = batch.shape[0]
            rewards = batch[:, 20]
            future = batch[:, 21:]

            # Score every follow-up board under each candidate move; the move
            # code is appended as the last input column.
            future_scores = np.zeros((num_rows, len(MOVES)))
            for i, move in enumerate(MOVES):
                move_column = np.full((num_rows, 1), move)
                future_scores[:, i] = model.Q(np.hstack((future, move_column))).ravel()

            for i, row in enumerate(future):
                game = threes.Threes(save_game=False, data=row.tolist())
                # NOTE(review): assumes threes.MoveEnum iterates in the same
                # order as MOVES -- confirm.
                for j, move in enumerate(threes.MoveEnum):
                    if not game.canMove(move):
                        # Illegal moves must never win the max below.
                        future_scores[i, j] = float('-inf')
                if not game.getPossibleMoves():
                    # Dead board: replace this row only (the previous slice
                    # ``i:`` clobbered every later row as well), so the max
                    # stays finite.
                    # NOTE(review): this makes the target twice column 20 for
                    # dead boards -- confirm that is intended rather than 0.
                    future_scores[i, :] = rewards[i]

            y = rewards + np.max(future_scores, axis=1)
            x = batch[:, :20]
            # x and y are NumPy arrays at this point, so they have no .cuda();
            # model.Q / model.loss turn them into tensors, which is where any
            # device transfer has to happen.

            optimizer.zero_grad()
            out = model.Q(x, as_variable=True)
            loss = model.loss(out, y)
            loss.backward()
            optimizer.step()

            if iter_ % log_every == 0 and verbose:
                # NOTE(review): loss.data[0] relies on pre-0.4 PyTorch; newer
                # releases expect loss.item().
                print(u"Minibatch {0: >6}  | loss {1: >5.2f} ".format(iter_, loss.data[0]))

    except KeyboardInterrupt:
        # Deliberate: an interrupt just ends training; the model is still saved.
        pass
    model.save_parameters(FILENAME)


class QLearningNet(object):
    """Pair a Q-value network with the criterion used to train it.

    Attributes:
        network: module evaluated on input batches; also the owner of the
            ``state_dict`` that is saved and loaded.
        criterion: callable ``criterion(out, target)`` returning the loss.
    """

    def __init__(self, network, criterion):
        self.network = network
        self.criterion = criterion

    def _to_network_device(self, tensor):
        """Return ``tensor`` on the GPU when the network's parameters are there."""
        first_param = next(iter(self.network.parameters()), None)
        if first_param is not None and first_param.is_cuda:
            return tensor.cuda()
        return tensor

    def Q(self, batch, as_variable=False):
        """Evaluate the network on ``batch``.

        Args:
            batch: array-like accepted by ``torch.FloatTensor``.
            as_variable: when True return the network output itself (still
                attached to the graph, usable for ``backward``); otherwise
                return its values as a NumPy array.
        """
        inputs = self._to_network_device(torch.FloatTensor(batch))
        out = self.network.forward(Variable(inputs, requires_grad=False))
        if as_variable:
            return out
        # .cpu() is a no-op for CPU tensors and required before .numpy() on GPU.
        return out.data.cpu().numpy()

    def loss(self, out, y):
        """Return ``criterion(out, y)`` with ``y`` converted to a tensor.

        ``y`` is reshaped to ``out``'s shape so that a flat ``(N,)`` target is
        compared element-wise with an ``(N, 1)`` output instead of being
        broadcast to ``(N, N)``.
        """
        target = torch.FloatTensor(y)
        if out.is_cuda:
            target = target.cuda()
        target = Variable(target, requires_grad=False)
        return self.criterion(out, target.view_as(out))

    def save_parameters(self, filename):
        """Write the network's ``state_dict`` to ``filename``."""
        torch.save(self.network.state_dict(), filename)

    def load_parameters(self, filename):
        """Load a ``state_dict`` from ``filename`` into the network.

        Storages are first mapped to the CPU so a checkpoint written on a GPU
        machine loads anywhere; ``load_state_dict`` then copies the values
        into the network's existing parameters.
        """
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        state = torch.load(filename, map_location=lambda storage, loc: storage)
        self.network.load_state_dict(state)


# Q-network: one hidden ReLU layer mapping an INPUT_SIZE-wide row to a
# single scalar score.
network = nn.Sequential(
    nn.Linear(INPUT_SIZE, HIDDEN_SIZE),
    nn.ReLU(),
    nn.Linear(HIDDEN_SIZE, 1),
)
criterion = nn.MSELoss()
q_learning_net = QLearningNet(network, criterion)

# Explicitly flag every parameter as trainable before building the optimizer.
for p in q_learning_net.network.parameters():
    p.requires_grad = True

optimizer = torch.optim.SGD(network.parameters(), lr=0.001, momentum=0.5)
data_loader = loader.Loader()

# Kick off training; parameters are written to FILENAME when it returns.
train(q_learning_net, data_loader, optimizer)

Training the model!
Interrupt at any time to get current model
[ 0.19643228]
[ 0.16971444]
[ 0.15414011]
[ 0.13500652]
[ 0.21520588]
[ 0.20198394]
[ 0.19425091]
[ 0.17996536]
[ 0.34917015]
[ 0.33262733]
[ 0.32743236]
[ 0.31456152]
[ 0.15741025]
[ 0.11890027]
[ 0.07387105]
[ 0.02213898]
[ 0.27517048]
[ 0.21912783]
[ 0.15079381]
[ 0.08863664]
[ 0.16824798]
[ 0.14021651]
[ 0.09977485]
[ 0.05163632]
[-0.14951964]
[-0.20698978]
[-0.24309157]
[-0.27675447]
[-0.41793707]
[-0.44966877]
[-0.48364702]
[-0.51956975]
[-0.4365612]
[-0.48115334]
[-0.52512288]
[-0.08646557]
[-0.11739291]
[-0.15368645]
[-0.1888859]
tu MoveEnum.Down
[-0.2202488]
[-0.23273827]
[-0.25197542]
[-0.27279058]
[-0.20524023]
[-0.2172007]
[-0.21921369]
[-0.21305467]
[-0.14001748]
[-0.1757133]
[-0.20473678]
[-0.23460041]
[-0.08131445]
[-0.14699249]
[-0.19694956]
[-0.23479865]
[-0.35625729]
[-0.38551569]
[-0.41956383]
[-0.44544291]
[-0.80055439]
[-0.81873423]
[-0.82224005]
[-0.82359111]
[-0.70704955]
[-0.73943812]
[-0.76769984]
[

[ 7.80636024]
[ 7.82105732]
[ 7.79150295]
[ 7.85941219]
[ 7.82545185]
[ 7.79149103]
[ 6.8131485]
[ 6.92362022]
[ 7.03922844]
[ 7.15483618]
[ 6.68851852]
[ 6.92662144]
[ 3.48909116]
[ 3.49153066]
[ 3.49504232]
[ 3.49828386]
[ 4.14030695]
[ 4.14823246]
[ 4.64596033]
[ 4.45334339]
[ 37.94160461]
[ 37.59373856]
[ 30.34679985]
[ 30.16838264]
[ 29.99749184]
[ 29.82962799]
[ 30.99729538]
[ 30.65321541]
[ 0.51271617]
[ 0.44638336]
[ 0.44426727]
[ 0.56585914]
[ 1.68758249]
[ 1.53712821]
[ 1.44573092]
[ 1.41814065]
[ 2.26946878]
[ 2.06262565]
[ 1.8918016]
[ 1.87946486]
[ 2.2836287]
[ 2.12462354]
[ 2.05995345]
[ 2.03165936]
[ 3.0410769]
[ 2.89887881]
[ 2.80369616]
[ 2.70706177]
[ 3.275949]
[ 3.12140346]
[ 3.0043838]
[ 2.90428472]
[ 3.57793403]
[ 3.41407275]
[ 3.28565717]
[ 3.19238544]
[ 4.25326633]
[ 4.10666847]
[ 3.98997617]
[ 3.8743732]
[ 4.31308794]
[ 4.16136599]
[ 4.04369593]
[ 3.94207358]
[ 5.59258795]
[ 5.37449074]
[ 5.17860317]
[ 5.09757805]
[ 5.74421024]
[ 5.50980234]
[ 5.32863998]
[ 5.19

[-1.08005667]
[-0.99046397]
[-0.88044345]
[-0.75507247]
[-1.44933605]
[-1.29940295]
[-1.13956523]
[-0.4879393]
[-0.34738135]
[-0.20624381]
[-0.15490785]
[ 0.122293]
[ 0.09538442]
[ 0.07296389]
[ 0.05054337]
[-0.50060898]
[-0.55191624]
[-0.59416831]
[-0.60615963]
[-0.62194443]
[-0.64959562]
[-0.67755181]
[-0.63523555]
[-0.50606227]
[-0.45859194]
[-0.45998171]
[-0.08902568]
[-0.10177863]
[-0.11453152]
[-0.13043368]
[-0.52529699]
[-0.54752564]
[-0.54986167]
[-0.55219764]
[-0.53264451]
[-0.56305933]
[-0.57742518]
[-0.55306417]
[-0.58391279]
[-0.61314178]
[-0.54649377]
[-0.57734239]
[-0.60819101]
[-0.62680137]
[-0.57475811]
[-0.60560673]
[-0.63821363]
[-1.6116221]
[-1.61610413]
[-1.62058592]
[-1.59678674]
[-0.50923777]
[-0.51113993]
[-0.51304209]
[-1.29582191]
[-1.27294469]
[-1.24982834]
[-1.22451615]
[-0.51341939]
[-0.4905422]
[-0.479139]
[-0.48104113]
[ 0.04814428]
[ 0.02542698]
[-0.00419915]
[-0.03479409]
[-0.8714003]
[-0.88048398]
[-0.87620187]
[-0.8567729]
[-0.63346684]
[-0.58854246]
[

[ 19.16117859]
[ 19.09155273]
[ 19.02192879]
[ 18.95230293]
[ 18.10639]
[ 18.03166008]
[ 17.95692825]
[ 17.88219833]
[ 18.01095009]
[ 17.93618011]
[ 17.86141205]
[ 18.12197876]
[ 18.04725075]
[ 17.97251892]
[ 17.897789]
[ 18.23346329]
[ 18.14376831]
[ 18.06903839]
[ 17.99430656]
[ 18.41147232]
[ 18.33276367]
[ 18.25803375]
[ 18.18330383]
[ 18.36343575]
[ 18.28870392]
[ 18.213974]
[ 18.96373177]
[ 18.88238907]
[ 18.80147171]
[ 18.72055435]
[ 19.50178719]
[ 19.43159103]
[ 19.36139488]
[ 19.29119873]
[ 19.17897987]
[ 19.10878563]
[ 19.03858757]
[ 18.96839142]
[ 18.88093948]
[ 18.79566383]
[ 18.72546959]
[ 19.4238205]
[ 19.27435875]
[ 19.95078278]
[ 19.80815125]
[ 19.8583107]
[ 19.71195984]
[ 20.10866547]
[ 20.05038071]
[ 19.99804688]
[ 19.94893265]
[ 18.19192886]
[ 18.10172653]
[ 18.02775574]
[ 17.96114349]
[ 18.35020447]
[ 18.18756104]
[ 25.8717308]
[ 25.80233192]
[ 25.73293495]
[ 24.97005653]
[ 24.9205265]
[ 24.87099457]
[ 25.86140823]
[ 25.76242065]
[ 4.63086414]
[ 4.5725565]
[ 4.55092

[-0.91331673]
[-0.89713669]
[-0.88095665]
[-0.86477661]
[-0.72052526]
[-0.70434499]
[-0.68816519]
[-0.67198491]
[-0.65657353]
[-0.64039373]
[-0.62421346]
[-0.60803366]
[-1.85644889]
[-1.84026897]
[-1.82408881]
[-3.09917355]
[-3.07805181]
[-3.03580832]
[-3.49740553]
[-3.47399139]
[-3.45057726]
[-3.42716312]
[-3.47074938]
[-3.44733524]
[-3.42392111]
[-3.40050697]
[-3.38390231]
[-3.32325363]
[-3.26260519]
[-3.20195651]
[-3.49850607]
[-3.44112396]
[-3.38047528]
[-3.31982684]
[-3.46742082]
[-3.40677214]
[-3.34612346]
[-3.28547502]
[-3.44218087]
[-3.38345242]
[-3.3247242]
[-3.3062973]
[-3.27098322]
[-3.23566914]
[-3.19774961]
[-3.584939]
[-3.56152487]
[-3.53811073]
[-3.51878071]
[-3.0588007]
[-3.01658702]
[-2.97326851]
[-2.91236496]
[-3.11362863]
[-3.07141495]
[-3.0276463]
[-2.96674275]
[-3.15109038]
[-3.11360121]
[-3.05384922]
[-3.11616755]
[-3.05791593]
[-2.99966431]
[-2.94141245]
[-3.02280927]
[-2.98532009]
[-2.94591522]
[-2.89088821]
[-3.1176219]
[-3.07985687]
[-3.04209208]
[-2.99788165]

[-14.99494743]
[-15.57263279]
[-16.15032005]
[-16.72800446]
[-27.30477524]
[-27.88246155]
[-28.46014595]
[-29.03783226]
[-35.80801773]
[-36.40738678]
[-37.00675583]
[-37.60612488]
[-37.61502457]
[-38.2169342]
[-38.82148361]
[-39.42713165]
[-38.05656052]
[-38.65966797]
[-39.26276779]
[-39.86587524]
[-37.52877045]
[-38.13067627]
[-38.73546982]
[-39.34111404]
[-38.24164581]
[-38.81933594]
[-39.40074539]
[-40.00638962]
[-38.86051559]
[-39.4382019]
[-40.03559113]
[-31.61539459]
[-32.21993637]
[-32.82557678]
[-33.43136597]
[-48.30915451]
[-48.93902206]
[-49.56889343]
[-50.1987648]
[-47.04806519]
[-47.67793655]
[-48.30780411]
[-48.93767166]
[-46.04859161]
[-46.67070389]
[-47.29341888]
[-47.9161377]
[-46.70994949]
[-47.32846451]
[-47.94712448]
[-48.56881332]
tu MoveEnum.Right
[-48.88885498]
[-49.4665451]
[-50.0442276]
[-50.6219101]
[-53.74288559]
[-54.3205719]
[-54.89825821]
[-52.40450287]
[-52.98218918]
[-53.55987549]
[-54.1375618]
[-52.85337448]
[-53.43106079]
[-54.0087471]
[-51.3258934]
[-5

[ 37.25411224]
[ 37.0679245]
[ 62.72046661]
[ 62.53465271]
[ 62.34883118]
[ 62.16300964]
[ 62.51654816]
[ 62.33072662]
[ 62.14491272]
[ 61.95909119]
[ 63.43190002]
[ 63.24607849]
[ 63.06025696]
[ 62.87444305]
[ 61.81295776]
[ 61.62713623]
[ 61.44132996]
[ 61.25550842]
[ 62.07287598]
[ 61.88706207]
[ 61.70124054]
[ 66.74200439]
[ 66.55618286]
[ 66.37036896]
[ 66.18455505]
[ 67.16776276]
[ 66.98194122]
[ 66.79612732]
[ 66.61030579]
[ 66.58601379]
[ 66.40019226]
[ 66.21437836]
[ 66.02855682]
[ 64.8189621]
[ 64.69880676]
[ 64.5786438]
[ 64.45848083]
[ 65.69830322]
[ 65.57814026]
[ 65.45798492]
[ 67.01016998]
[ 66.89001465]
[ 66.76985168]
[ 69.23130798]
[ 69.04548645]
[ 68.85967255]
[ 68.67385101]
[ 74.59314728]
[ 74.40733337]
[ 74.22151947]
[ 74.03570557]
[ 75.44145203]
[ 75.25563049]
[ 75.06980896]
[ 74.88399506]
[ 73.47180176]
[ 73.28720093]
[ 73.16703796]
[ 76.55053711]
[ 76.1789093]
[ 75.99308777]
[ 76.67593384]
[ 76.49011993]
[ 76.3042984]
[ 76.13037872]
[ 110.87552643]
[ 110.59680176

[-3.96862435]
[-3.98542213]
[-4.00222015]
[-4.00830221]
[-3.90973997]
[-3.92653775]
[-3.94333577]
[-4.00830221]
[-4.00830221]
[-4.00830221]
[-4.00830221]
[-4.00830221]
[-4.00830221]
[-4.00830221]
[-4.12869692]
[-8.5794487]
[-8.73078728]
[-8.88212585]
[-9.03346539]
[-4.00830221]
[-4.00830221]
[-4.03210974]
[-4.18338728]
[-4.43366051]
[-4.58493805]
[-4.73621607]
[-4.88749361]
[-4.89036989]
[-5.05844545]
[-5.22652102]
[-4.74372482]
[-4.89500237]
[-5.04628038]
[-5.19755793]
[-4.54274702]
[-4.691679]
[-4.84061098]
[-4.80511618]
[-4.95404768]
[-5.10298014]
[-5.25191212]
[-5.18560219]
[-5.33453417]
[-5.48346615]
[-5.63239813]
[-5.3961606]
[-5.72742271]
[-5.91754103]
[-6.21540451]
[-6.36433697]
[-6.31210232]
[-6.46148205]
[-6.61086178]
[-6.82595825]
[-6.97533798]
[-7.12471724]
[-7.27409649]
[-7.25760174]
[-7.40698051]
[-7.55636024]
[-7.7057395]
[-8.30205345]
[-8.46813202]
[-8.63421059]
[-8.80028915]
[-10.22329712]
[-10.37267876]
[-10.52205658]
[-10.67143631]
[-7.42425346]
[-7.57363319]
[-7.723

[-82.7037735]
[-54.97911453]
[-56.64258957]
[-58.30606461]
[-52.51561737]
[-54.17909241]
[-55.84257126]
[-57.5060463]
[-61.51742935]
[-63.18090439]
[-64.84438324]
[-66.50785828]
[-53.09771729]
[-54.76119614]
[-56.42467117]
[-58.08815002]
[-51.52110291]
[-53.18457794]
[-54.84805298]
[-56.51153183]
[-45.89135361]
[-47.55483246]
[-49.21831131]
[-50.88178635]
[-5.0216856]
[-5.0216856]
[-5.0216856]
[-5.0216856]
[-5.0216856]
[-5.0216856]
[-5.0216856]
[-5.0216856]
[-5.0216856]
[-5.0216856]
[-5.0216856]
[-5.37247753]
[-84.77911377]
[-87.53296661]
[-90.28682709]
[-93.04067993]
[-5.0216856]
[-5.0216856]
[-6.26401806]
[-9.01767254]
[-5.0216856]
[-5.0216856]
[-5.0216856]
[-5.1269021]
[-178.4655304]
[-181.21589661]
[-183.96624756]
[-186.71661377]
[-5.0216856]
[-5.0216856]
[-5.0216856]
[-5.0216856]
[-5.0216856]
[-5.0216856]
[-5.0216856]
[-5.0216856]
[-5.0216856]
[-5.0216856]
[-5.0216856]
[-5.0216856]
[-5.0216856]
[-5.0216856]
[-5.0216856]
[-197.85929871]
[-200.61296082]
[-203.36662292]
[-206.1202697

[-5.1613698]
[-5.1613698]
[-5.1613698]
[-5.1613698]
[-5.1613698]
[-5.1613698]
[-5.13001871]
[-5.13001871]
[-5.13001871]
[-5.0989871]
[-5.09443808]
[-5.08988905]
[-5.08534002]
[-5.08398819]
[-5.08153629]
[-5.0790844]
[-5.0766325]
[-5.13001871]
[-5.13001871]
[-5.13001871]
[-5.12978172]
[-5.10247278]
[-5.09792376]
[-5.09337473]
[-5.0888257]
[-5.13001871]
[-5.13001871]
[-5.13001871]
[-5.12994528]
[-5.13001871]
[-5.13001871]
[-5.1284337]
[-5.10405731]
[-5.1040225]
[-5.10398817]
[-5.10395336]
[-5.13001871]
[-5.13001871]
[-5.13001871]
[-5.13001871]
[-5.13001871]
[-5.13001871]
[-5.12736416]
[-5.13001871]
[-5.13001871]
[-5.13001871]
[-5.13001871]
[-5.13001871]
[-5.13001871]
[-5.13001871]
[-5.13001871]
[-5.13001871]
[-5.13001871]
[-5.13001871]
[-5.13001871]
[-5.13001871]
[-5.13001871]
[-5.13001871]
[-5.13001871]
[-5.13001871]
[-5.13001871]
[-5.13001871]
[-5.13001871]
[-5.13001871]
[-5.13001871]
[-5.13001871]
[-5.13001871]
[-5.13001871]
[-5.13001871]
[-5.13001871]
[-5.13001871]
[-5.13001871]
[-5.

[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
tu MoveEnum.Up
[-4.85355854]
[-4.8518343]
[-4.85097218]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.87486076]
[-4.84926748]
[-4.84926748]
[-4.84926748]
[-4.84926748]
[-4.84926748]
[-4.84926748]
[-4.84926748]
[-4.84926748]
[-4.84926748]
[-4.84926748]
[-4.84

[-4.7280407]
[-4.7280407]
[-4.7280407]
[-4.7280407]
[-4.7280407]
[-4.7280407]
[-4.7280407]
[-4.7280407]
[-4.7280407]
[-4.7280407]
[-4.7280407]
[-4.7280407]
[-4.7280407]
[-4.7280407]
[-4.7280407]
[-4.7280407]
[-4.7280407]
[-4.7280407]
[-4.7280407]
[-4.7280407]
[-4.7280407]
[-4.7280407]
[-4.7280407]
[-4.7280407]
[-4.7280407]
[-4.7280407]
[-4.7280407]
[-4.7280407]
[-4.7280407]
[-4.7280407]
[-4.7280407]
[-4.7280407]
[-4.7280407]
[-4.67743015]
[-4.67800331]
[-4.67857647]
[-4.67914963]
[-4.68839025]
[-4.68896341]
[-4.68953657]
[-4.7280407]
[-4.7280407]
[-4.7280407]
[-4.7280407]
[-4.71725559]
[-4.71782875]
[-4.71840191]
[-4.7280407]
[-4.7280407]
[-4.7280407]
[-4.7280407]
[-4.71061468]
[-4.711761]
[-4.74058437]
[-4.74058437]
[-4.74058437]
[-4.74058437]
[-4.72842884]
[-4.72917271]
[-4.72991705]
[-4.73066092]
[-4.74058437]
[-4.74058437]
[-4.74058437]
[-4.74058437]
[-4.74058437]
[-4.74058437]
[-4.74058437]
[-4.74058437]
[-4.74058437]
[-4.74058437]
[-4.74058437]
[-4.74058437]
[-4.74058437]
[-4.740