In [1]:
import os
# GPU selection is configured through environment variables here, ahead of the
# `import torch` further down in this cell, so the settings are already in
# place when the CUDA runtime is initialised.
# CUDA_DEVICE_ORDER=PCI_BUS_ID makes device indices follow PCI bus order;
# CUDA_VISIBLE_DEVICES="0" exposes only device 0 to this process.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   
os.environ["CUDA_VISIBLE_DEVICES"]="0"

import torch
import numpy as np
import pickle
import cv2
import os

from torchvision.models import vgg16
from torchvision import transforms
from torch import nn

from vocab import Vocab, load_vocab
from common import Data, Split, Batches, load_data, encode_y, load_split
from utils import ProgressBar, evaluate
from pack import Pack
from model import PosterOnlyModel, Model
from torch_models import MultiLayerFCReLUClassifier
from itertools import chain
from train import train_epoches

%load_ext autoreload
%autoreload 2

In [2]:
# Load the pickled genre list. A context manager is used so the file handle is
# closed deterministically instead of being left to the garbage collector.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only point this at trusted, locally generated data.
with open("../data/tmdb_genres_list.pkl", 'rb') as genre_file:
    genre_list = pickle.load(genre_file)

# Project-level loaders for the label data, the three dataset splits and the
# poster array (paths are relative to the notebook's working directory).
GENRES = load_data("../local/genres.pkl")
train = load_split("../local/train.pkl")
val = load_split("../local/val.pkl")
test = load_split("../local/test.pkl")
POSTERS = load_data("../local/posters.npy")

# VGG16 with pretrained weights; later cells use its `features` sub-module.
MODEL = vgg16(pretrained=True)

In [3]:
# Freeze the base of the VGG feature stack: every module in MODEL.features
# except the final seven keeps its weights fixed during training. In
# torchvision's VGG16 those last seven modules are the last three conv layers
# (each followed by a ReLU) plus the closing max-pool, so only those three
# conv layers remain trainable.
frozen_layers = list(MODEL.features.children())[:-7]
for layer in frozen_layers:
    for weight in layer.parameters():
        weight.requires_grad = False

In [4]:
# Convolutional part of the pretrained VGG16, moved to the GPU.
CNN = MODEL.features.cuda()

# Classification head fed by the CNN encoding.
# NOTE(review): 25088 presumably equals the flattened 512x7x7 VGG16 feature
# map for 224x224 inputs, and 19 the number of genre labels — confirm against
# the poster preprocessing and GENRES.
classifier = MultiLayerFCReLUClassifier(
    encoding_size=25088,
    dims=[4096, 4096],
    num_class=19,
    cuda=True,
)

# Combine encoder, head, poster data and genre labels into one model object.
model = PosterOnlyModel(CNN, classifier, POSTERS, GENRES)

In [5]:
# Hand Adam only the parameters that are still trainable, i.e. those whose
# requires_grad flag has not been switched off.
trainable_params = [p for p in model.parameters() if p.requires_grad]
adam = torch.optim.Adam(trainable_params)

# `optimizer` / `scheduler` are the names the training cell passes along;
# no learning-rate scheduler is used.
optimizer = adam
scheduler = None

# Binary cross-entropy on raw logits, placed on the GPU.
loss = nn.BCEWithLogitsLoss().cuda()

In [6]:
# Train in rounds of `n_epochs` epochs each. After every round the value
# returned by train_epoches is appended to `loss_hist`, and the encoder and
# classifier are checkpointed to disk.
loss_hist = []
n_epochs = 10
n_rounds = 20
checkpoint_dir = "saved/poster-vgg"

# Create the checkpoint directory up front: without it the first torch.save
# would fail only after a full round of training had already been spent.
os.makedirs(checkpoint_dir, exist_ok=True)

for i in range(n_rounds):
    epoch_losses = train_epoches(n_epochs=n_epochs, model=model, train=train, loss=loss, val=val,
                  batch_size=32, optimizer=optimizer, scheduler=scheduler)
    loss_hist.append(epoch_losses)

    # Cumulative number of epochs trained so far; used in the filenames.
    bn = (i+1)*n_epochs
    # First four characters of the last entry of epoch_losses[1][1], used as a
    # short tag in the checkpoint filenames.
    # NOTE(review): presumably the latest validation metric — confirm against
    # train_epoches' return structure.
    score_tag = str(epoch_losses[1][1][-1])[:4]

    # The whole module objects are saved (not state_dicts), matching the
    # format any existing loading code expects.
    torch.save(model.encoder, "{}/encoder_{}_{}.pth".format(checkpoint_dir, bn, score_tag))
    torch.save(model.classifier, "{}/cls_{}_{}.pth".format(checkpoint_dir, bn, score_tag))

......................................5.....................................10......................................15.....................................20......................................25.....................................30......................................35.....................................40......................................45.....................................50......................................55.....................................60......................................65.....................................70......................................75.....................................80......................................85.....................................90......................................95.....................................100 F
epoch 0: 0.21802922
Val:	P(mi) 0.088623046875 	R(mi): 0.41580756013745707 	F1(mi): 0.14610585630911652
P(ma) 0.043372570942812334 	R(ma): 0.10944310516085227 	F1(ma): 0.04953729318794403
P(w) 0.4228255874873

  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)


.....................................5.....................................10......................................15.....................................20......................................25.....................................30......................................35.....................................40......................................45.....................................50......................................55.....................................60......................................65.....................................70......................................75.....................................80......................................85.....................................90......................................95.....................................100 F
epoch 1: 0.20316555
Val:	P(mi) 0.052001953125 	R(mi): 0.6283185840707964 	F1(mi): 0.09605411499436302
P(ma) 0.030321302087653518 	R(ma): 0.26613148038635914 	F1(ma): 0.05108918459018593
P(w) 0.109333903436758

......................................5.....................................10......................................15.....................................20......................................25.....................................30......................................35.....................................40......................................45.....................................50......................................55.....................................60......................................65.....................................70......................................75.....................................80......................................85.....................................90......................................95.....................................100 F
epoch 9: 0.0769019
Val:	P(mi) 0.29150390625 	R(mi): 0.5091684434968017 	F1(mi): 0.3707498835584536
P(ma) 0.24264563554494534 	R(ma): 0.5608375085229987 	F1(ma): 0.3194441720057705
P(w) 0.3271594570436207 	R(

Val:	P(mi) 0.35791015625 	R(mi): 0.4939353099730458 	F1(mi): 0.4150622876557192
P(ma) 0.31509083199468835 	R(ma): 0.5034400624826589 	F1(ma): 0.3747485554583298
P(w) 0.3827853625758906 	R(w): 0.4939353099730458 	F1(w): 0.42522050300938974


......................................5.....................................10......................................15.....................................20......................................25.....................................30......................................35.....................................40......................................45.....................................50......................................55.....................................60......................................65.....................................70......................................75.....................................80......................................85.....................................90......................................95.........

......................................5.....................................10......................................15.....................................20......................................25.....................................30......................................35.....................................40......................................45.....................................50......................................55.....................................60......................................65.....................................70......................................75.....................................80......................................85.....................................90......................................95.....................................100 F
epoch 4: 0.030443257
Val:	P(mi) 0.362060546875 	R(mi): 0.4510340632603406 	F1(mi): 0.40167930660888407
P(ma) 0.3162971567759289 	R(ma): 0.4437305688422581 	F1(ma): 0.3587083118234757
P(w) 0.3871810673233985 

Val:	P(mi) 0.32470703125 	R(mi): 0.48188405797101447 	F1(mi): 0.38798133022170367
P(ma) 0.2982323464242372 	R(ma): 0.5341450779052689 	F1(ma): 0.3683351177746191
P(w) 0.34327757410135573 	R(w): 0.48188405797101447 	F1(w): 0.39346016123392935


......................................5.....................................10......................................15.....................................20......................................25.....................................30......................................35.....................................40......................................45.....................................50......................................55.....................................60......................................65.....................................70......................................75.....................................80......................................85.....................................90......................................95......

......................................5.....................................10......................................15.....................................20......................................25.....................................30......................................35.....................................40......................................45.....................................50......................................55.....................................60......................................65.....................................70......................................75.....................................80......................................85.....................................90......................................95.....................................100 F
epoch 9: 0.035175912
Val:	P(mi) 0.366943359375 	R(mi): 0.48359073359073357 	F1(mi): 0.4172681843420322
P(ma) 0.28936266512238756 	R(ma): 0.5861887398924963 	F1(ma): 0.3681972924138874
P(w) 0.4150825399111428

Val:	P(mi) 0.34765625 	R(mi): 0.4992987377279102 	F1(mi): 0.409902130109384
P(ma) 0.2870055749179976 	R(ma): 0.5506052193755032 	F1(ma): 0.3637487657061622
P(w) 0.37568795869890526 	R(w): 0.4992987377279102 	F1(w): 0.42149445815591857


......................................5.....................................10......................................15.....................................20......................................25.....................................30......................................35.....................................40......................................45.....................................50......................................55.....................................60......................................65.....................................70......................................75.....................................80......................................85.....................................90......................................95.............

......................................5.....................................10......................................15.....................................20......................................25.....................................30......................................35.....................................40......................................45.....................................50......................................55.....................................60......................................65.....................................70......................................75.....................................80......................................85.....................................90......................................95.....................................100 F
epoch 4: 0.03535888
Val:	P(mi) 0.36083984375 	R(mi): 0.49464524765729584 	F1(mi): 0.41727837380011296
P(ma) 0.3035563194329146 	R(ma): 0.5357508176931046 	F1(ma): 0.37731119196208024
P(w) 0.3951805770307307 

Val:	P(mi) 0.34814453125 	R(mi): 0.4774020756611985 	F1(mi): 0.40265424255259075
P(ma) 0.31048711859652967 	R(ma): 0.5249656422986022 	F1(ma): 0.36955050730956757
P(w) 0.3819051962659794 	R(w): 0.4774020756611985 	F1(w): 0.4128261542725


......................................5.....................................10......................................15.....................................20......................................25.....................................30......................................35.....................................40......................................45.....................................50......................................55.....................................60......................................65.....................................70......................................75.....................................80......................................85.....................................90......................................95...........

......................................5.....................................10......................................15.....................................20......................................25.....................................30......................................35.....................................40......................................45.....................................50......................................55.....................................60......................................65.....................................70......................................75.....................................80......................................85.....................................90......................................95.....................................100 F
epoch 9: 0.04733278
Val:	P(mi) 0.355224609375 	R(mi): 0.5059109874826148 	F1(mi): 0.4173838209982788
P(ma) 0.30045800965573277 	R(ma): 0.5506269033485165 	F1(ma): 0.37333245604494847
P(w) 0.39660072276612207

Val:	P(mi) 0.335205078125 	R(mi): 0.46605566870332654 	F1(mi): 0.38994603805737005
P(ma) 0.2949655691594088 	R(ma): 0.5113462392166778 	F1(ma): 0.36142249169560453
P(w) 0.3526343535614362 	R(w): 0.46605566870332654 	F1(w): 0.39476104544520474


......................................5.....................................10......................................15.....................................20......................................25.....................................30......................................35.....................................40......................................45.....................................50......................................55.....................................60......................................65.....................................70......................................75.....................................80......................................85.....................................90......................................95.....

......................................5.....................................10......................................15.....................................20......................................25.....................................30......................................35.....................................40......................................45.....................................50......................................55.....................................60......................................65.....................................70......................................75.....................................80......................................85.....................................90......................................95.....................................100 F
epoch 4: 0.05232189
Val:	P(mi) 0.3291015625 	R(mi): 0.49925925925925924 	F1(mi): 0.3967039434961742
P(ma) 0.29823374209333126 	R(ma): 0.5548106629439541 	F1(ma): 0.37858363042937804
P(w) 0.34911580932945846 

Val:	P(mi) 0.36376953125 	R(mi): 0.49029285949325435 	F1(mi): 0.41765942536790474
P(ma) 0.3033416246493193 	R(ma): 0.5818103289671342 	F1(ma): 0.3820760017963795
P(w) 0.39942438819083287 	R(w): 0.49029285949325435 	F1(w): 0.4272795964123326


......................................5.....................................10......................................15.....................................20......................................25.....................................30......................................35.....................................40......................................45.....................................50......................................55.....................................60......................................65.....................................70......................................75.....................................80......................................85.....................................90......................................95.......

......................................5.....................................10......................................15.....................................20......................................25.....................................30......................................35.....................................40......................................45.....................................50......................................55.....................................60......................................65.....................................70......................................75.....................................80......................................85.....................................90......................................95.....................................100 F
epoch 9: 0.055917326
Val:	P(mi) 0.35498046875 	R(mi): 0.5039861351819758 	F1(mi): 0.41655923220169033
P(ma) 0.3094181461344697 	R(ma): 0.5771841370878402 	F1(ma): 0.3882870956609483
P(w) 0.3755716230256079 	

KeyboardInterrupt: 

In [None]:
# Run the project's `evaluate` helper on the test split with the current model.
# NOTE(review): 128 is presumably the evaluation batch size — confirm against
# utils.evaluate's signature.
evaluate(test, model, 128)