In [1]:
import os
# Set before torch is imported so CUDA sees these when it initialises:
# enumerate GPUs in PCI bus order and expose only device 0 to this process.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})

import torch
import numpy as np
import pickle

from common import Data, Split, Batches, load_data, encode_y, load_split
from utils import ProgressBar, evaluate
from pack import Pack
from model import Model, TextOnlyModel
from torch_models import MultiLayerFCReLUClassifier, Encoder2
from cls import BRClassifier
from train import train_epoches

%load_ext autoreload
%autoreload 2

In [2]:
# Load the genre list, splits, pretrained embedding and pre-encoded text features.
# NOTE(review): pickle.load / torch.load can execute arbitrary code; only use trusted local artifacts.
# Use a context manager so the file handle is closed instead of leaking until garbage collection.
with open("../data/tmdb_genres_list.pkl", 'rb') as genre_file:
    genre_list = pickle.load(genre_file)

GENRES = load_data("../local/genres.pkl")
train = load_split("../local/train.pkl")
val = load_split("../local/val.pkl")
test = load_split("../local/test.pkl")
embedding = torch.load('../local/embedding.pth')
OVERVIEWS_ENCODED = load_data("../local/overviews_encoded.pkl")
TITLES_ENCODED = load_data("../local/titles_encoded.pkl")

In [3]:
# Seed every RNG the notebook can touch so weight initialisation (and any shuffling done by the
# project's batching code — presumably numpy/torch based, TODO confirm) is reproducible.
# NOTE(review): cuDNN LSTM kernels may still be non-deterministic on GPU.
import random

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA devices

# Model hyper-parameters (values unchanged from the original cell).
NUM_CLASSES = 19               # number of genre labels; presumably should equal len(GENRES) — verify
CLASSIFIER_DIMS = [1024, 512]  # hidden-layer sizes passed to BRClassifier
HIDDEN_DIM = 512               # LSTM hidden size per direction
NUM_LSTM_LAYERS = 3
# Encoder output size fed to the classifier. 1024 == 2 * HIDDEN_DIM, presumably because the
# bidirectional LSTM concatenates both directions — TODO confirm in Encoder2.
ENCODING_SIZE = 1024

classifier = BRClassifier(dims=CLASSIFIER_DIMS, num_class=NUM_CLASSES, encoding_size=ENCODING_SIZE, cuda=True)
encoder = Encoder2(encoder=torch.nn.LSTM, embedding=embedding, input_channel=embedding.embedding_dim,
                   hidden_dim=HIDDEN_DIM, num_layers=NUM_LSTM_LAYERS, cuda=True, bidirectional=True, dropout=0)
# Text-only model over the encoded movie overviews.
model = TextOnlyModel(encoder, classifier, OVERVIEWS_ENCODED, GENRES)
# Multi-label objective: an independent sigmoid + binary cross-entropy per genre.
loss = torch.nn.BCEWithLogitsLoss().cuda()

In [4]:
# Adam over the trainable parameters only; anything with requires_grad=False (frozen) is skipped.
trainable_params = [param for param in model.parameters() if param.requires_grad]
adam = torch.optim.Adam(trainable_params)

optimizer = adam
scheduler = None  # no learning-rate schedule

In [5]:
# Train for N_ROUNDS rounds of `n_epochs` epochs each, checkpointing after every round so an
# interrupted run still leaves usable weights on disk.
SAVE_DIR = "saved/overview-lstm2"
N_ROUNDS = 20
TRAIN_BATCH_SIZE = 32
n_epochs = 10  # epochs per round

# torch.save does not create missing directories; without this the first checkpoint would fail.
os.makedirs(SAVE_DIR, exist_ok=True)

loss_hist = []
for i in range(N_ROUNDS):
    epoch_losses = train_epoches(n_epochs=n_epochs, model=model, train=train, loss=loss, val=val,
                                 batch_size=TRAIN_BATCH_SIZE, optimizer=optimizer, scheduler=scheduler)
    loss_hist.append(epoch_losses)
    bn = (i + 1) * n_epochs  # cumulative epochs trained so far
    # Tag the checkpoint with the last entry of epoch_losses[1][1], truncated to 4 characters
    # (e.g. "0.52"). Presumably a validation metric — TODO confirm against train_epoches.
    score_tag = str(epoch_losses[1][1][-1])[:4]
    # NOTE(review): this pickles whole modules; saving .state_dict() is more portable.
    torch.save(model.encoder, "{}/encoder_{}_{}.pth".format(SAVE_DIR, bn, score_tag))
    torch.save(model.classifier, "{}/cls_{}_{}.pth".format(SAVE_DIR, bn, score_tag))

......................................5.....................................10......................................15.....................................20......................................25.....................................30......................................35.....................................40......................................45.....................................50......................................55.....................................60......................................65.....................................70......................................75.....................................80......................................85.....................................90......................................95.....................................100 F
epoch 0: 0.2301148
Val:	P(mi) 0.0 	R(mi): 0.0 	F1(mi): 0.0
P(ma) 0.0 	R(ma): 0.0 	F1(ma): 0.0
P(w) 0 	R(w): 0 	F1(w): 0


.

  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)


.....................................5.....................................10......................................15.....................................20......................................25.....................................30......................................35.....................................40......................................45.....................................50......................................55.....................................60......................................65.....................................70......................................75.....................................80......................................85.....................................90......................................95.....................................100 F
epoch 1: 0.2256135
Val:	P(mi) 0.0 	R(mi): 0.0 	F1(mi): 0.0
P(ma) 0.0 	R(ma): 0.0 	F1(ma): 0.0
P(w) 0 	R(w): 0 	F1(w): 0


......................................5.....................................10........

......................................5.....................................10......................................15.....................................20......................................25.....................................30......................................35.....................................40......................................45.....................................50......................................55.....................................60......................................65.....................................70......................................75.....................................80......................................85.....................................90......................................95.....................................100 F
epoch 9: 0.13446751
Val:	P(mi) 0.355224609375 	R(mi): 0.5517633674630261 	F1(mi): 0.4321996138422694
P(ma) 0.24829782222821242 	R(ma): 0.35273280459989415 	F1(ma): 0.2813262082764469
P(w) 0.43400543566191613

Val:	P(mi) 0.47216796875 	R(mi): 0.550997150997151 	F1(mi): 0.5085458848277675
P(ma) 0.3868311490316593 	R(ma): 0.5142204354550869 	F1(ma): 0.43462547957315556
P(w) 0.4924885948993364 	R(w): 0.550997150997151 	F1(w): 0.5164211834512606


......................................5.....................................10......................................15.....................................20......................................25.....................................30......................................35.....................................40......................................45.....................................50......................................55.....................................60......................................65.....................................70......................................75.....................................80......................................85.....................................90......................................95............

......................................5.....................................10......................................15.....................................20......................................25.....................................30......................................35.....................................40......................................45.....................................50......................................55.....................................60......................................65.....................................70......................................75.....................................80......................................85.....................................90......................................95.....................................100 F
epoch 4: 0.010382361
Val:	P(mi) 0.496337890625 	R(mi): 0.5528963829208594 	F1(mi): 0.5230927569792871
P(ma) 0.4295941741995723 	R(ma): 0.5426673844924571 	F1(ma): 0.4744371225895801
P(w) 0.5082087938423557 	

Val:	P(mi) 0.492431640625 	R(mi): 0.5561069754618142 	F1(mi): 0.5223358798394406
P(ma) 0.42957672830488935 	R(ma): 0.5277425302382298 	F1(ma): 0.4665925907328186
P(w) 0.4985725899187362 	R(w): 0.5561069754618142 	F1(w): 0.5237074360065753


......................................5.....................................10......................................15.....................................20......................................25.....................................30......................................35.....................................40......................................45.....................................50......................................55.....................................60......................................65.....................................70......................................75.....................................80......................................85.....................................90......................................95.........

......................................5.....................................10......................................15.....................................20......................................25.....................................30......................................35.....................................40......................................45.....................................50......................................55.....................................60......................................65.....................................70......................................75.....................................80......................................85.....................................90......................................95.....................................100 F
epoch 9: 0.005424143
Val:	P(mi) 0.501220703125 	R(mi): 0.5495182012847966 	F1(mi): 0.5242594484167518
P(ma) 0.4310078271203549 	R(ma): 0.5492017061753984 	F1(ma): 0.47048586523118957
P(w) 0.5137339147159962 

......................................5.....................................10......................................15.....................................20......................................25.....................................30......................................35.....................................40......................................45.....................................50......................................55.....................................60......................................65.....................................70......................................75.....................................80......................................85.....................................90......................................95.....................................100 F
epoch 7: 0.0042636204
Val:	P(mi) 0.47802734375 	R(mi): 0.5726820707809301 	F1(mi): 0.5210911510312708
P(ma) 0.4170699729455803 	R(ma): 0.5560627869467455 	F1(ma): 0.4708707900602363
P(w) 0.49447766793787207 

Val:	P(mi) 0.484375 	R(mi): 0.566533409480297 	F1(mi): 0.5222426954461701
P(ma) 0.418188233447275 	R(ma): 0.5750507656664019 	F1(ma): 0.4766415298914996
P(w) 0.49834644357442254 	R(w): 0.566533409480297 	F1(w): 0.5272029948310643


......................................5.....................................10......................................15.....................................20......................................25.....................................30......................................35.....................................40......................................45.....................................50......................................55.....................................60......................................65.....................................70......................................75.....................................80......................................85.....................................90......................................95..................

......................................5.....................................10......................................15.....................................20......................................25.....................................30......................................35.....................................40......................................45.....................................50......................................55.....................................60......................................65.....................................70......................................75.....................................80......................................85.....................................90......................................95.....................................100 F
epoch 2: 0.0036232325
Val:	P(mi) 0.4873046875 	R(mi): 0.5709382151029748 	F1(mi): 0.5258166491043204
P(ma) 0.4085601691607773 	R(ma): 0.5893660302239354 	F1(ma): 0.4711341837297902
P(w) 0.5101125161341326 	R

Val:	P(mi) 0.511474609375 	R(mi): 0.5557029177718833 	F1(mi): 0.5326722603610475
P(ma) 0.4298116825810588 	R(ma): 0.5348310037492194 	F1(ma): 0.47040663206109334
P(w) 0.5294396649694694 	R(w): 0.5557029177718833 	F1(w): 0.537964379115973


......................................5.....................................10......................................15.....................................20......................................25.....................................30......................................35.....................................40......................................45.....................................50......................................55.....................................60......................................65.....................................70......................................75.....................................80......................................85.....................................90......................................95..........

......................................5.....................................10......................................15.....................................20......................................25.....................................30......................................35.....................................40......................................45.....................................50......................................55.....................................60......................................65.....................................70......................................75.....................................80......................................85.....................................90......................................95.....................................100 F
epoch 7: 0.002796664
Val:	P(mi) 0.503173828125 	R(mi): 0.5580828594638505 	F1(mi): 0.5292078572345615
P(ma) 0.4246052951848302 	R(ma): 0.562260536982586 	F1(ma): 0.47529454367195656
P(w) 0.5233684575367241 	

Val:	P(mi) 0.484375 	R(mi): 0.5804564072557051 	F1(mi): 0.5280809156241683
P(ma) 0.41908238142744453 	R(ma): 0.5689107037331875 	F1(ma): 0.4756775770317976
P(w) 0.5038615179488082 	R(w): 0.5804564072557051 	F1(w): 0.5350329824032334


......................................5.....................................10......................................15.....................................20......................................25.....................................30......................................35.....................................40......................................45.....................................50......................................55.....................................60......................................65.....................................70......................................75.....................................80......................................85.....................................90......................................95...............

KeyboardInterrupt: 

In [None]:
# Evaluate the current model on the held-out test split.
# NOTE(review): the third positional argument is presumably the evaluation batch size — confirm in utils.evaluate.
EVAL_BATCH_SIZE = 128
evaluate(test, model, EVAL_BATCH_SIZE)

In [None]:
# Persist the Adam optimizer next to the model checkpoints (presumably to allow resuming training).
# NOTE(review): this pickles the whole optimizer object; optimizer.state_dict() is the portable form,
# but switching would change what torch.load returns for any existing reader of this file.
optimizer_path = "saved/overview-lstm2/optimizer.pth"
torch.save(optimizer, optimizer_path)