In [1]:
import os
import pickle

import numpy as np
import torch

from vocab import Vocab, load_vocab
from common import Data, Split, Batches, load_data, encode_y, load_split
from utils import ProgressBar, evaluate
from pack import Pack
from model import Model, Model2, TextOnlyModel

%load_ext autoreload
%autoreload 2

In [2]:
# Load the genre list plus the preprocessed artefacts (splits, embedding,
# encoded overviews/titles) — presumably written by an earlier preprocessing step.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted local files.
with open("../data/tmdb_genres_list.pkl", 'rb') as genres_file:
    genre_list = pickle.load(genres_file)

GENRES = load_data("../local/genres.pkl")
train = load_split("../local/train.pkl")
val = load_split("../local/val.pkl")
test = load_split("../local/test.pkl")
embedding = torch.load('../local/embedding.pth')
OVERVIEWS_ENCODED = load_data("../local/overviews_encoded.pkl")
TITLES_ENCODED = load_data("../local/titles_encoded.pkl")

In [3]:
# Model hyperparameters, named so they can be tuned in one place.
HIDDEN_DIM = 512
NUM_LAYERS = 3

# Build the full Model2, then wrap its encoder and classifier in a TextOnlyModel
# that is fed OVERVIEWS_ENCODED with GENRES (presumably overview text -> genre labels).
_model = Model2(embedding, hidden_dim=HIDDEN_DIM, num_layers=NUM_LAYERS, cuda=True)
tomodel = TextOnlyModel(_model.encoder, _model.classifier, OVERVIEWS_ENCODED, GENRES)

# BCEWithLogitsLoss applies the sigmoid itself, so the model must output raw logits.
loss = torch.nn.BCEWithLogitsLoss().cuda()

In [4]:
def train_batches(model:Model, batch_size:int, display:bool=True,
                  train_optimizer=None, criterion=None, data_split=None):
    """Run one training epoch over a data split and return the per-batch losses.

    Args:
        model: model wrapper exposing ``zero_grad()``, ``get_y(batches, i)`` and
            ``predict(batches, i)`` (the methods used below).
        batch_size: number of examples per batch.
        display: whether the progress bar is rendered.
        train_optimizer: optimizer to step. Defaults to the notebook-level ``optimizer``.
        criterion: loss applied to ``(model_output, y_true)``. Defaults to the
            notebook-level ``loss``.
        data_split: split to train on. Defaults to the notebook-level ``train``.

    Returns:
        list with one loss value (numpy scalar) per batch.
    """
    train_optimizer = optimizer if train_optimizer is None else train_optimizer
    criterion = loss if criterion is None else criterion
    data_split = train if data_split is None else data_split

    # Shuffle *before* building the batches, so the new order is the one that is
    # iterated over regardless of whether Batches snapshots the split.
    data_split.shuffle()
    batches = Batches(data_split, batch_size)

    pb = ProgressBar(batches.batch_N, display=display)
    pb.reset()

    losses = []
    for i in range(batches.batch_N):
        model.zero_grad()
        train_optimizer.zero_grad()

        y_true = model.get_y(batches, i)
        # Float CUDA target tensor; the Variable wrapper is kept for pre-0.4 PyTorch
        # (it is a no-op on newer versions).
        y_true = torch.autograd.Variable(torch.from_numpy(y_true)).cuda().type(torch.cuda.FloatTensor)
        model_output = model.predict(batches, i)

        l = criterion(model_output, y_true)
        l.backward()
        train_optimizer.step()

        # reshape(-1)[0] handles both the 1-element array (PyTorch < 0.4) and the
        # 0-d array (PyTorch >= 0.4, where indexing with [0] raises IndexError).
        losses.append(l.data.cpu().numpy().reshape(-1)[0])

        pb.tick()

    return losses

def train_epoches(model:Model, n_epochs:int, batch_size:int):
    """Train ``model`` for ``n_epochs`` epochs, reporting metrics after each one.

    After every epoch the mean batch loss is printed, then ``evaluate`` is
    run on the notebook-level ``train`` split and then the ``val`` split.

    Returns:
        list with one entry per epoch; each entry is the list of per-batch
        losses returned by ``train_batches``.
    """
    history = []
    for epoch_idx in range(n_epochs):
        batch_losses = train_batches(model, batch_size, display=False)
        history.append(batch_losses)
        print("epoch {}:".format(epoch_idx), np.mean(batch_losses))
        # Training split first, then validation; same output order as before.
        for label, split in (("Train:", train), ("Val:", val)):
            print(label, end="\t")
            evaluate(split, model, batch_size)
    return history

In [5]:
# Adam (default hyperparameters) over the trainable parameters of _model.
# tomodel reuses _model.encoder and _model.classifier, so their parameters are included.
optimizer = torch.optim.Adam(p for p in _model.parameters() if p.requires_grad)

In [6]:
# Training schedule: N_ROUNDS rounds of EPOCHS_PER_ROUND epochs each, with the
# encoder and classifier checkpointed after every round.
N_ROUNDS = 10
EPOCHS_PER_ROUND = 20
BATCH_SIZE = 32
CHECKPOINT_DIR = "saved"

# torch.save fails if the target directory does not exist.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

loss_hist = []
for round_idx in range(N_ROUNDS):
    epoch_losses = train_epoches(tomodel, EPOCHS_PER_ROUND, BATCH_SIZE)
    loss_hist.append(epoch_losses)
    # Checkpoint label = cumulative number of epochs trained so far.
    bn = (round_idx + 1) * EPOCHS_PER_ROUND
    # NOTE(review): this pickles whole modules; saving state_dict() would be more
    # portable, but loaders elsewhere may expect full modules.
    torch.save(tomodel.encoder, os.path.join(CHECKPOINT_DIR, "encoder_{}.pth".format(bn)))
    torch.save(tomodel.classifier, os.path.join(CHECKPOINT_DIR, "rnn_cls_{}.pth".format(bn)))

epoch 0: 0.29861966
Train:	

  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)


P 0.0 	R: 0.0 	F1: 0.0
Val:	P 0.0 	R: 0.0 	F1: 0.0
epoch 1: 0.2325613
Train:	P 0.0002150471567693773 	R: 0.5384615384615384 	F1: 0.0004299226139294927
Val:	P 0.00048828125 	R: 1.0 	F1: 0.0009760858955588092
epoch 2: 0.2181127
Train:	P 0.0022119136124850234 	R: 0.5625 	F1: 0.004406499586890663
Val:	P 0.002685546875 	R: 0.5238095238095238 	F1: 0.005343696866650473
epoch 3: 0.21074232
Train:	P 0.010107216368160732 	R: 0.5264 	F1: 0.019833614661200868
Val:	P 0.00830078125 	R: 0.5074626865671642 	F1: 0.01633437424933942
epoch 4: 0.20542252
Train:	P 0.03852416208411416 	R: 0.616519174041298 	F1: 0.07251698713315022
Val:	P 0.03271484375 	R: 0.5851528384279476 	F1: 0.06196531791907514
epoch 5: 0.19967698
Train:	P 0.06122699763448128 	R: 0.6228125 	F1: 0.11149338480042517
Val:	P 0.054443359375 	R: 0.5883905013192612 	F1: 0.09966480446927375
epoch 6: 0.19395664
Train:	P 0.12091794414918128 	R: 0.5664939550949913 	F1: 0.19929618471353702
Val:	P 0.112060546875 	R: 0.5563636363636364 	F1: 0.1865474

epoch 11: 0.022481171
Train:	P 0.9546557709440571 	R: 0.972369985606108 	F1: 0.963431459176239
Val:	P 0.443115234375 	R: 0.5266976204294834 	F1: 0.48130469371519485
epoch 12: 0.01999401
Train:	P 0.9572363368252895 	R: 0.963809582727582 	F1: 0.9605117139334155
Val:	P 0.46923828125 	R: 0.5188984881209503 	F1: 0.4928205128205128
epoch 13: 0.020063225
Train:	P 0.952935393689902 	R: 0.9651213441194773 	F1: 0.9589896585305528
Val:	P 0.452880859375 	R: 0.5296973158195317 	F1: 0.48828639115556727
epoch 14: 0.020375036
Train:	P 0.958925993057049 	R: 0.966228138059124 	F1: 0.9625632169729863
Val:	P 0.457275390625 	R: 0.5227463019815797 	F1: 0.48782393540825625
epoch 15: 0.019253656
Train:	P 0.9681730207981322 	R: 0.9742789130367576 	F1: 0.9712163703041695
Val:	P 0.473388671875 	R: 0.5340126686863124 	F1: 0.5018765368189466
epoch 16: 0.01855374
Train:	P 0.9667598537679334 	R: 0.968694206735209 	F1: 0.967726063625321
Val:	P 0.4716796875 	R: 0.5280131183383439 	F1: 0.4982591876208898
epoch 17: 0.01

epoch 1: 0.009770058
Train:	P 0.9809836871371079 	R: 0.987628355808487 	F1: 0.984294807576715
Val:	P 0.455810546875 	R: 0.5544995544995545 	F1: 0.5003349859305909
epoch 2: 0.009318443
Train:	P 0.9851003041381217 	R: 0.9903026559604694 	F1: 0.9876946296838182
Val:	P 0.458251953125 	R: 0.5477093667931136 	F1: 0.49900305729097444
epoch 3: 0.009609735
Train:	P 0.9837793001751098 	R: 0.9798959608323133 	F1: 0.9818337906823443
Val:	P 0.488525390625 	R: 0.5260252365930599 	F1: 0.5065822784810126
epoch 4: 0.011988357
Train:	P 0.9842093944886486 	R: 0.9776021482408227 	F1: 0.9808946449894369
Val:	P 0.484619140625 	R: 0.54028307022319 	F1: 0.510939510939511
epoch 5: 0.011553955
Train:	P 0.9829805535928235 	R: 0.9850686534080414 	F1: 0.9840234957636891
Val:	P 0.475341796875 	R: 0.55218377765173 	F1: 0.510889530307006
epoch 6: 0.009857952
Train:	P 0.9789560996589967 	R: 0.9815191277028276 	F1: 0.9802359382930618
Val:	P 0.46826171875 	R: 0.5541750939034961 	F1: 0.5076088394865688
epoch 7: 0.0123315

epoch 11: 0.00703427
Train:	P 0.9903843199901693 	R: 0.9923353956967402 	F1: 0.9913588978750884
Val:	P 0.471923828125 	R: 0.5519703026841805 	F1: 0.5088181100289549
epoch 12: 0.008638781
Train:	P 0.9843937206230223 	R: 0.9769207317073171 	F1: 0.9806429893957246
Val:	P 0.479736328125 	R: 0.5289367429340511 	F1: 0.5031366022276277
epoch 13: 0.012100804
Train:	P 0.9817209916746029 	R: 0.992268281322776 	F1: 0.9869664587065291
Val:	P 0.45703125 	R: 0.554995552920249 	F1: 0.5012719239523363
epoch 14: 0.00725058
Train:	P 0.991060182482873 	R: 0.9889941445169993 	F1: 0.9900260856222187
Val:	P 0.479736328125 	R: 0.5360065466448445 	F1: 0.5063128059778407
epoch 15: 0.0050875526
Train:	P 0.9954532886854475 	R: 0.9935913160799705 	F1: 0.9945214308733484
Val:	P 0.47998046875 	R: 0.5349659863945578 	F1: 0.5059837858705445
epoch 16: 0.0051365574
Train:	P 0.9856225615188473 	R: 0.9926363664490578 	F1: 0.9891170304599827
Val:	P 0.455322265625 	R: 0.5524289099526066 	F1: 0.4991970021413276
epoch 17: 0.