In [1]:
import os
# Pin this process to a single GPU. These variables are set BEFORE `import torch` so they are
# already in the environment when CUDA is initialised. CUDA_DEVICE_ORDER=PCI_BUS_ID numbers
# devices in PCI bus order (the order nvidia-smi shows) instead of CUDA's default fastest-first.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   
os.environ["CUDA_VISIBLE_DEVICES"]="0"

import torch
import numpy as np
import pickle

# Project-local modules.
from vocab import Vocab, load_vocab
from common import Data, Split, Batches, load_data, encode_y, load_split
from utils import ProgressBar, evaluate
from pack import Pack
from model import Model, TextOnlyModel
from torch_models import MultiLayerFCReLUClassifier, Encoder
from train import train_epoches

# Automatically re-import the project-local modules above when their source files change,
# so edits to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# Load the preprocessed artifacts: genre list, train/val/test splits, the embedding,
# and the encoded overview/title texts.
# NOTE(review): pickle.load / torch.load can execute arbitrary code while loading — only
# point these paths at trusted, locally produced files.
# NOTE(review): genre_list and TITLES_ENCODED are not referenced in the cells shown here.
with open("../data/tmdb_genres_list.pkl", 'rb') as genre_file:  # closes the handle deterministically
    genre_list = pickle.load(genre_file)

GENRES = load_data("../local/genres.pkl")
train = load_split("../local/train.pkl")
val = load_split("../local/val.pkl")
test = load_split("../local/test.pkl")
# Loaded as a full object: `.embedding_dim` is read from it in the model cell.
# NOTE(review): newer torch releases default torch.load to weights_only=True, which may reject a
# pickled module — confirm against the installed torch version.
embedding = torch.load('../local/embedding.pth')
OVERVIEWS_ENCODED = load_data("../local/overviews_encoded.pkl")
TITLES_ENCODED = load_data("../local/titles_encoded.pkl")

In [3]:
# Build the overview-only genre model:
#   - an Encoder wrapping a torch.nn.GRU (num_layers=3, bidirectional) over `embedding`;
#   - a MultiLayerFCReLUClassifier head with hidden dims [1024, 512] and 19 outputs.
# The classifier is constructed before the encoder, as before, so any random parameter
# initialisation presumably draws from the global RNG in the same order.
# NOTE(review): encoding_size=512 while the GRU is bidirectional with hidden_dim=512. If Encoder
# concatenated both directions its output would be 1024-wide — confirm how Encoder reduces them.
classifier = MultiLayerFCReLUClassifier(
    dims=[1024, 512],
    num_class=19,  # presumably one logit per genre in GENRES — confirm
    encoding_size=512,
    cuda=True,
)
encoder = Encoder(
    encoder=torch.nn.GRU,
    embedding=embedding,
    input_channel=embedding.embedding_dim,
    hidden_dim=512,
    num_layers=3,
    cuda=True,
    bidirectional=True,
)
model = TextOnlyModel(encoder, classifier, OVERVIEWS_ENCODED, GENRES)

# Multi-label objective: an independent sigmoid + binary cross-entropy per output logit.
loss = torch.nn.BCEWithLogitsLoss().cuda()

In [4]:
# Adam with library-default hyper-parameters, optimising only parameters that require
# gradients (anything frozen inside the model is skipped). No learning-rate schedule is used.
optimizer = torch.optim.Adam(p for p in model.parameters() if p.requires_grad)
adam = optimizer  # same object under its original name
scheduler = None

In [None]:
# Training loop: each round trains for `n_epochs` epochs, then checkpoints the encoder and the
# classifier, tagging the filenames with a cumulative epoch count and a score.
# NOTE(review): rounds start at START_ROUND = 2, so the first checkpoint is tagged 30 epochs.
# This assumes two earlier 10-epoch rounds already ran in this kernel. On a fresh kernel the tag
# overstates the epochs trained — set START_ROUND = 0 in that case.
CHECKPOINT_DIR = "saved/overview-gru"
START_ROUND = 2
N_ROUNDS = 20
# Create the checkpoint directory up front. Otherwise torch.save raises FileNotFoundError
# only after the first round of training has already run.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

loss_hist = []
n_epochs = 10
for i in range(START_ROUND, N_ROUNDS):
    epoch_losses = train_epoches(n_epochs=n_epochs, model=model, train=train, loss=loss, val=val,
                                 batch_size=32, optimizer=optimizer, scheduler=scheduler)
    loss_hist.append(epoch_losses)
    bn = (i + 1) * n_epochs  # cumulative epochs, counting the assumed earlier rounds
    # NOTE(review): epoch_losses[1][1][-1] is presumably the latest validation metric — confirm
    # against train_epoches. The first 4 characters of its str() are used as a filename tag.
    score_tag = str(epoch_losses[1][1][-1])[:4]
    # Whole modules are pickled (not state_dicts); loading them needs the same class definitions.
    torch.save(model.encoder, os.path.join(CHECKPOINT_DIR, "encoder_{}_{}.pth".format(bn, score_tag)))
    torch.save(model.classifier, os.path.join(CHECKPOINT_DIR, "cls_{}_{}.pth".format(bn, score_tag)))

......................................5.....................................10......................................15.....................................20......................................25.....................................30......................................35.....................................40......................................45.....................................50......................................55.....................................60......................................65.....................................70......................................75.....................................80......................................85.....................................90......................................95.....................................100 F
epoch 0: 0.122330375
Val:	P(mi) 0.375 	R(mi): 0.5581395348837209 	F1(mi): 0.4485981308411215
P(ma) 0.230913423269271 	R(ma): 0.4527017720389579 	F1(ma): 0.2837182727154955
P(w) 0.454673843763858 	R(w): 0.558

  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)


...................................5.....................................10......................................15.....................................20......................................25.....................................30......................................35.....................................40......................................45.....................................50......................................55.....................................60......................................65.....................................70......................................75.....................................80......................................85.....................................90......................................95.....................................100 F
epoch 1: 0.10903359
Val:	P(mi) 0.375732421875 	R(mi): 0.5637362637362637 	F1(mi): 0.4509229416935247
P(ma) 0.22148061404164177 	R(ma): 0.5219469927576015 	F1(ma): 0.2736769604883948
P(w) 0.45969847291822546 	R(

Val:	P(mi) 0.447021484375 	R(mi): 0.5304171494785631 	F1(mi): 0.4851616322204557
P(ma) 0.34027880118364456 	R(ma): 0.44104456252435537 	F1(ma): 0.37840818673405713
P(w) 0.4607256853464411 	R(w): 0.5304171494785631 	F1(w): 0.48970043430212107


......................................5.....................................10......................................15.....................................20......................................25.....................................30......................................35.....................................40......................................45.....................................50......................................55.....................................60......................................65.....................................70......................................75.....................................80......................................85.....................................90......................................95......

......................................5.....................................10......................................15.....................................20......................................25.....................................30......................................35.....................................40......................................45.....................................50......................................55.....................................60......................................65.....................................70......................................75.....................................80......................................85.....................................90......................................95.....................................100 F
epoch 4: 0.015529543
Val:	P(mi) 0.488525390625 	R(mi): 0.5433070866141733 	F1(mi): 0.5144620131122253
P(ma) 0.41190533314568656 	R(ma): 0.5096668909960707 	F1(ma): 0.447011474531566
P(w) 0.5052879556374674 	

......................................5.....................................10......................................15.....................................20......................................25.....................................30......................................35.....................................40......................................45.....................................50......................................55.....................................60......................................65.....................................70......................................75.....................................80......................................85.....................................90......................................95.....................................100 F
epoch 9: 0.011065042
Val:	P(mi) 0.4873046875 	R(mi): 0.5512289422811378 	F1(mi): 0.5172994687054554
P(ma) 0.4424038200871513 	R(ma): 0.4951340184329578 	F1(ma): 0.45999150140364536
P(w) 0.49907849335394 	R(w

......................................5.....................................10......................................15.....................................20......................................25.....................................30......................................35.....................................40......................................45.....................................50......................................55.....................................60......................................65.....................................70......................................75.....................................80......................................85.....................................90......................................95.....................................100 F
epoch 4: 0.010526063
Val:	P(mi) 0.4755859375 	R(mi): 0.5617070357554786 	F1(mi): 0.5150713907985193
P(ma) 0.4194280715025919 	R(ma): 0.5392311917443727 	F1(ma): 0.46338808129207143
P(w) 0.489436178842218 	R(

......................................5.....................................10......................................15.....................................20......................................25.....................................30......................................35.....................................40......................................45.....................................50......................................55.....................................60......................................65.....................................70......................................75.....................................80......................................85.....................................90......................................95.....................................100 F
epoch 2: 0.009197625
Val:	P(mi) 0.4853515625 	R(mi): 0.5579567779960707 	F1(mi): 0.5191278234756496
P(ma) 0.41539519922847284 	R(ma): 0.5192162507300713 	F1(ma): 0.45409986800998786
P(w) 0.5031501648995276 	

Val:	P(mi) 0.472900390625 	R(mi): 0.5474844544940645 	F1(mi): 0.5074665968037726
P(ma) 0.4199135767105648 	R(ma): 0.49803063193776786 	F1(ma): 0.45290511213453755
P(w) 0.47986007465522584 	R(w): 0.5474844544940645 	F1(w): 0.5099939981502698


......................................5.....................................10......................................15.....................................20......................................25.....................................30......................................35.....................................40......................................45.....................................50......................................55.....................................60......................................65.....................................70......................................75.....................................80......................................85.....................................90......................................95.......

In [None]:
# Final evaluation of the trained model on the held-out test split.
# NOTE(review): 128 is presumably the evaluation batch size — confirm against utils.evaluate.
evaluate(test, model, 128)