In [1]:
from torch import nn
from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss, BCELoss, MultiLabelMarginLoss
from transformers import WEIGHTS_NAME, BertConfig, BertForSequenceClassification, BertTokenizer
from transformers import AdamW, get_linear_schedule_with_warmup
import pandas as pd
import numpy as np
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset, RandomSampler, SequentialSampler
from tqdm import trange, tqdm
import os
import copy
import logging
from seqeval.metrics import f1_score
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, confusion_matrix, precision_recall_fscore_support

from PIL import Image
from IPython.display  import display
import torchvision.transforms as transforms

from decimal import *

getcontext().prec = 4

logger = logging.getLogger(__name__)
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

epochs = 4
batch_size = 16

# Prepare Dataset

This notebook also collects image data to combine the CNN model with BERT. 

In [2]:
# use pandas to read excel data as a dataframe
spreadsheet = pd.read_excel("Data_Full.xlsx")
#spreadsheet = pd.read_excel("Data_Medium.xlsx")
#spreadsheet = pd.read_excel("Data_Smallest.xlsx")
#spreadsheet = pd.read_excel("CleanData.xlsx")

# use prebuilt tags dictionary to match up with images
genres = ['Mystery','Thriller','Action','Adventure','Horror','Crime','Drama','Comedy','War','Romance','Fantasy','SciFi','Family','Biography','Music','Western','Animation','History','Sport','Film-Noir']
tags = {}
tags = {genres[i]: i for i in range(len(genres))}
print(tags)

# extract relevant entries 
movies = spreadsheet[["description", "genre", "imdb_title_id"]].to_numpy()
data = []
images = []
image_folder = '/data/user/jvdelmon/fullimages/FullImages/'
labels_array = []
longest_sentence = 0
tfms = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor()])
counter = 0
for m in movies:
    #print(m[0].encode())
    if type(m[0]) != float:
        if counter % 200 == 0:
            print(counter)
            
        # tokenize text
        tokens = tokenizer.tokenize(m[0])
        tokens.insert(0, '[CLS]')
        tokens.append('[SEP]')
        tokens = tokenizer.convert_tokens_to_ids(tokens)

        if len(tokens) > longest_sentence:
            longest_sentence = len(tokens)

        # work out genre label
        genres = m[1].split(", ")
        label = []
        for g in genres:
            if g not in tags:
                tags[g] = len(tags)
                labels_array.append(g)
            label.append(tags[g])       
            
        # collect and preprocess images
        try:
            image = tfms(Image.open(image_folder+m[2]+".jpg").convert("RGB")).unsqueeze(0)
            
        except:
            pass

        data.append([tokens, image, label])
        counter += 1
        
print(longest_sentence)


{'Mystery': 0, 'Thriller': 1, 'Action': 2, 'Adventure': 3, 'Horror': 4, 'Crime': 5, 'Drama': 6, 'Comedy': 7, 'War': 8, 'Romance': 9, 'Fantasy': 10, 'SciFi': 11, 'Family': 12, 'Biography': 13, 'Music': 14, 'Western': 15, 'Animation': 16, 'History': 17, 'Sport': 18, 'Film-Noir': 19}
0
200
400
600
800
1000
1200
1400
1600
1800
2000
2200
2400
2600
2800
3000
3200
3400
3600
3800
4000
4200
4400
4600
4800
5000
5200
5400
5600
5800
6000
6200
6400
6600
6800
7000
7200
7400
7600
7800
8000
8200
8400
8600
8800
9000
9200
9400
9600
9800
10000
10200
10400
10600
10800
11000
11200
11400
11600
11800
12000
12200
12400
12600
12800
13000
13200
13400
13600
13800
14000
14200
14400
14600
14800
15000
15200
15400
15600
15800
16000
16200
16400
16600
16800
17000
17200
17400
17600
17800
18000
18200
18400
18600
18800
19000
19200
19400
19600
19800
20000
96


In [3]:
print(tags)
print(len(data))

{'Mystery': 0, 'Thriller': 1, 'Action': 2, 'Adventure': 3, 'Horror': 4, 'Crime': 5, 'Drama': 6, 'Comedy': 7, 'War': 8, 'Romance': 9, 'Fantasy': 10, 'SciFi': 11, 'Family': 12, 'Biography': 13, 'Music': 14, 'Western': 15, 'Animation': 16, 'History': 17, 'Sport': 18, 'Film-Noir': 19}
20051


In [4]:
'''
# fetch rotton tomatoes data
spreadsheet = pd.read_csv("movie_info.tsv", sep='\t')

# extract relevant entries 
movies = spreadsheet[["synopsis", "genre"]].to_numpy()
data = []
tags = {}
labels_array = []
longest_sentence = 0
for m in movies:
    #print(m[0].encode())
    if type(m[0]) != float and type(m[1]) != float:
        # tokenize text
        tokens = tokenizer.tokenize(m[0])
        tokens.insert(0, '[CLS]')
        tokens.append('[SEP]')
        tokens = tokenizer.convert_tokens_to_ids(tokens)

        if len(tokens) > longest_sentence:
            longest_sentence = len(tokens)

        # work out genre label
        genres = m[1].split("|")
        label = []
        for g in genres:
            if g not in tags:
                tags[g] = len(tags)
                labels_array.append(g)
            label.append(tags[g])       

        data.append([tokens, label])
        
longest_sentence = min(512, longest_sentence)
print(longest_sentence)
'''

'\n# fetch rotton tomatoes data\nspreadsheet = pd.read_csv("movie_info.tsv", sep=\'\t\')\n\n# extract relevant entries \nmovies = spreadsheet[["synopsis", "genre"]].to_numpy()\ndata = []\ntags = {}\nlabels_array = []\nlongest_sentence = 0\nfor m in movies:\n    #print(m[0].encode())\n    if type(m[0]) != float and type(m[1]) != float:\n        # tokenize text\n        tokens = tokenizer.tokenize(m[0])\n        tokens.insert(0, \'[CLS]\')\n        tokens.append(\'[SEP]\')\n        tokens = tokenizer.convert_tokens_to_ids(tokens)\n\n        if len(tokens) > longest_sentence:\n            longest_sentence = len(tokens)\n\n        # work out genre label\n        genres = m[1].split("|")\n        label = []\n        for g in genres:\n            if g not in tags:\n                tags[g] = len(tags)\n                labels_array.append(g)\n            label.append(tags[g])       \n\n        data.append([tokens, label])\n        \nlongest_sentence = min(512, longest_sentence)\nprint(longest_

In [5]:
# pad sequences and create n hot encodings for labels
multi_data = []
for m in data:
    # truncate sequences than BERT's hard limit of 512
    sequence = m[0][:512]
    
    # ensure that the final token in the sequence is the [SEP] token
    if sequence[-1] != 102:
        sequence[-1] = 102
 
    padding = [0] * max(0, longest_sentence - len(sequence))
    sequence.extend(padding)
    
    line = np.zeros((len(tags), 1))
    for j in m[2]:
        line[j] = 1
    multi_data.append([sequence, m[1], line])

Building the labels is similar to most other classifiers. Rather than a one hot encoding, each label corresponding to a given sample is represented with a 1. 

The tags dictionary maps the label values onto discrete integers. These correspond to features of the output layer. We then map label values onto these tags. 

In [6]:
tags

{'Mystery': 0,
 'Thriller': 1,
 'Action': 2,
 'Adventure': 3,
 'Horror': 4,
 'Crime': 5,
 'Drama': 6,
 'Comedy': 7,
 'War': 8,
 'Romance': 9,
 'Fantasy': 10,
 'SciFi': 11,
 'Family': 12,
 'Biography': 13,
 'Music': 14,
 'Western': 15,
 'Animation': 16,
 'History': 17,
 'Sport': 18,
 'Film-Noir': 19}

An example of the n hot encodings:

In [7]:
for i in range(10):
    print(multi_data[i][2].flatten())

[1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 1. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 1. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]


Processing the text data is more involved. The raw text data must be converted into numeric vectors which can be passed along to BERT. Additionally, these must be of a fixed length, and have several special tokens which will be covered later. This happens through the following steps:
1. Tokenization
2. Conversion to ids
3. Adding special characters and padding
4. Passing along id vectors (in the form of torch tensors) to BERT

The first step is to pass the text to the BERT tokenizer. This will break them down into subwords which can then be converted into ids. You'll also notice that all interior tokens are prefixed with ## and that all text is converted to lowercase. 

In [8]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

tokens = tokenizer.tokenize(movies[5][0])
print(tokens)

print(tokenizer.convert_tokens_to_ids(tokens))

['john', 'mcc', '##lane', 'attempts', 'to', 'ave', '##rt', 'disaster', 'as', 'rogue', 'military', 'operatives', 'seize', 'control', 'of', 'dull', '##es', 'international', 'airport', 'in', 'washington', ',', 'd', '.', 'c', '.']
[2198, 23680, 20644, 4740, 2000, 13642, 5339, 7071, 2004, 12406, 2510, 25631, 15126, 2491, 1997, 10634, 2229, 2248, 3199, 1999, 2899, 1010, 1040, 1012, 1039, 1012]


Additionally, special tokens need to be added to the beginning and end of the sentence. The [CLS] token is what BERT uses to signal the beginning of a piece of text and the [SEP] token is used between (and after) sentences. These map to ids 101 and 102, respectively.

Due to the difficulty of sentence segmentation, we ignore that and treat each sample as a single sentence. 

In [9]:
tokens.insert(0, '[CLS]')
tokens.append('[SEP]')

print(tokens)
print(tokenizer.convert_tokens_to_ids(tokens))

['[CLS]', 'john', 'mcc', '##lane', 'attempts', 'to', 'ave', '##rt', 'disaster', 'as', 'rogue', 'military', 'operatives', 'seize', 'control', 'of', 'dull', '##es', 'international', 'airport', 'in', 'washington', ',', 'd', '.', 'c', '.', '[SEP]']
[101, 2198, 23680, 20644, 4740, 2000, 13642, 5339, 7071, 2004, 12406, 2510, 25631, 15126, 2491, 1997, 10634, 2229, 2248, 3199, 1999, 2899, 1010, 1040, 1012, 1039, 1012, 102]


Inputs into BERT must be of a common length. Since the length of text samples will vary, we pad sentences to be of a uniform length. BERT supports lengths of up to 512 tokens, though having samples this large will cause memory issues with a batch size of 32 for 12-16 GB GPU's (https://github.com/google-research/bert#out-of-memory-issues). It's common practice to use either the longest sentence length or a fixed upper bound. In these examples, the sentences tend to be pretty short and so will keep memory manageable. 

Since the longest sentence in our dataset is 97 (as shown in the second cell), that's the value we'll pad our sentences to. 

In [10]:
np.array(multi_data[3][0]).flatten()

array([ 101, 2019, 6376, 9530, 1010, 2006, 1996, 2448, 2013, 1996, 2375,
       1010, 5829, 2046, 1037, 2496, 3232, 1005, 1055, 2160, 1998, 3138,
       2058, 2037, 3268, 1012,  102,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0])

At this point, the preprocessing is complete. The next step is to prepare the training, testing, and validation sets. This also converts the data to pytorch types. 

The attention mask is a vector denoting which tokens are padding (represented as 0) and which are not (represented as 1). These are used to get BERT to ignore padding tokens. 

In [11]:
# cache the data to skip doing preprocessing each iteration
# do this before creating splits to ensure each test is distinct
# done in four batches to prevent crashes in individual files 
import pickle

pickle.dump(multi_data[:5000], open('/data/user/jvdelmon/fullimages/multi_data', 'wb'))
pickle.dump(multi_data[5000:10000], open('/data/user/jvdelmon/fullimages/multi_data1', 'wb'))
pickle.dump(multi_data[10000:15000], open('/data/user/jvdelmon/fullimages/multi_data2', 'wb'))
pickle.dump(multi_data[15000:], open('/data/user/jvdelmon/fullimages/multi_data3', 'wb'))

Unless reprocessing the data, run from here down. 

In [2]:
import pickle

multi_data = pickle.load(open('/data/user/jvdelmon/fullimages/multi_data', 'rb'))
multi_data.extend(pickle.load(open('/data/user/jvdelmon/fullimages/multi_data1', 'rb')))
multi_data.extend(pickle.load(open('/data/user/jvdelmon/fullimages/multi_data2', 'rb')))
multi_data.extend(pickle.load(open('/data/user/jvdelmon/fullimages/multi_data3', 'rb')))
len(multi_data)

20051

In [3]:
test = multi_data[:10]
ims = [i[1] for i in test]
torch.stack(ims, axis=1)[0].shape

torch.Size([10, 3, 224, 224])

In [4]:
multi_data[0][1].shape

torch.Size([1, 3, 224, 224])

In [5]:
# as a test, compute logit values on the pretrained image model before building the combined model
# note: this is solely to ease memory as the full dataset uses around 11GB 

In [6]:
genres = ['Mystery','Thriller','Action','Adventure','Horror','Crime','Drama','Comedy','War','Romance','Fantasy','SciFi','Family','Biography','Music','Western','Animation','History','Sport','Film-Noir']
tags = {}
tags = {genres[i]: i for i in range(len(genres))}

In [7]:
import math

# build test, val, and training sets
threshold = [int(len(multi_data) * .8), int(len(multi_data) * .9)]
train_set, val_set, test_set = np.split(multi_data, threshold)
X_train = [] 
X_val = []
X_test = []

train_images = []
val_images = []
test_images = []

y_train = []
y_val = []
y_test = []
paddings = [] 

sequence_length = 96

# function to create sequence_length tensors to feed into bert, will pad or truncate as necessary
def BERT_tensorizer(doc, y=False):
    padding = max(0, sequence_length - len(doc))
    if y:
        #with_padding = [labels[CrossEntropyLoss().ignore_index]] * padding
        with_padding = [0] * padding
    else:
        with_padding = [CrossEntropyLoss().ignore_index] * padding
        with_padding = [0] * padding
        paddings.append(padding)
    return np.concatenate((np.array(doc[:sequence_length]).astype(int), with_padding))

# convert each doc to a tensor of fixed length
for doc in train_set:
    X_train.append(BERT_tensorizer(doc[0]))
    train_images.append(doc[1])
    y_train.append(doc[2])

for doc in val_set:
    X_val.append(BERT_tensorizer(doc[0]))
    val_images.append(doc[1])
    y_val.append(doc[2])

for doc in test_set:
    X_test.append(BERT_tensorizer(doc[0]))
    test_images.append(doc[1])
    y_test.append(doc[2])
    
# since the full training set is too large to process all at once, break it up into smaller chunks
# note: the full dataset will still run through the model each epoch
#for i in range(math.ceil(len(train_set/3200))):
#    pickle.dump(train_set[3200*i:3200*(i+1)], open('/data/user/jvdelmon/fullimages/training'+i, 'wb'))
    
# build attention masks
train_masks = [[float(i>0) for i in ii] for ii in X_train]
val_masks = [[float(i>0) for i in ii] for ii in X_val]
test_masks = [[float(i>0) for i in ii] for ii in X_test]

# convert to tensors
train_masks = torch.tensor(train_masks)
val_masks = torch.tensor(val_masks)
test_masks = torch.tensor(test_masks)

training_data = TensorDataset(torch.tensor(X_train).to(torch.int64), train_masks, torch.stack(train_images, axis=1)[0], torch.tensor(y_train).to(torch.int64))
val_data = TensorDataset(torch.tensor(X_val).to(torch.int64), val_masks, torch.stack(val_images, axis=1)[0], torch.tensor(y_val).to(torch.int64))
test_data = TensorDataset(torch.tensor(X_test).to(torch.int64), test_masks, torch.stack(test_images, axis=1)[0], torch.tensor(y_test).to(torch.int64))

multi_data = []  # clear memory

# then convert to dataloaders
train_dataloader = DataLoader(training_data, sampler=RandomSampler(training_data), batch_size=batch_size)
val_dataloader = DataLoader(val_data, sampler=RandomSampler(val_data), batch_size=batch_size)
test_dataloader = DataLoader(test_data, sampler=SequentialSampler(test_data), batch_size=batch_size)

The data is now ready for use with BERT.

# Building the Model

BERT is currently the cutting edge in NLP. The model itself is an extension of the transformer, developed by Google and published in the paper "Attention is All You Need" in July, 2017. The transformer uses multi-headed self attention rather than recurrence and significantly outperforms other models on a variety of NLP tasks, usually after only a single epoch. BERT adds the additional feature of bidirectionality, allowing the model to scan left and right as it processes samples. There are two main versions for working with English text: BERT Base and Bert Large. The Base model has 110M parameters and the large model has 340M parameters, but its memory constraints make it difficult to use on most GPUs. Additionally, there are a variety of other models in the transformers family, such as RoBERTa, GPT (and GPT2), and XLM, some of these are designed as enhancements to BERT, others work with translation or other languages. 

Multilabel classification requires 

In [8]:
from transformers import BertPreTrainedModel, BertModel
from transformers.configuration_bert import BertConfig

# extending BERT classes requires inheriting from the base model and adding on a head to perform a specific funtion
# I'm modeling my code on the transformers library's other models
class BertForMultiLabelSequenceClassification(BertPreTrainedModel):

    # create a child class from the BertPreTrainedModel 
    # this is the same base which is used in BertForSequenceClassification
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.sigmoid = nn.Sigmoid()
        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
        self.batch_norm = nn.BatchNorm2d(3)  # 3 for RGB input
        self.classifier2 = nn.Linear(2 * self.config.num_labels, self.config.num_labels)  # manually inputting size due to a label discrepancy
        self.init_weights()

    # updating the transformer head (code again adapted from BertForSequenceClassification)
    def forward(
        self, input_ids, attention_mask=None, image=None, token_type_ids=None, position_ids=None, head_mask=None, labels=None,
    ):
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            # omit input_embeds in favor of letting the transformers module handle conversion from ids to embeddings
            #inputs_embeds=inputs_embeds,
        )

        # when taking gradients, outputs is [loss/device information, logits]
        # if with torch.no_grad() was used, only logits would be returned
        # thus in this case, outputs[1] is used
        pooled_output = outputs[1]

        # apply standard dropout
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        
        # apply a sigmoid to match image outputs
        #logits = self.sigmoid(logits)
        
        # process image logits and concatenate with BERT's outputs
        image = self.batch_norm(image)
        logits_image = image_model(image)

        # testing with a fully connected layer
        # this tends to not produce useful results, likely because too much is also training at the same time
        #logits = torch.cat((logits, logits_image), 1)
        #logits = self.classifier2(logits)
        #logits = nn.functional.relu(logits)
        
        # try with average of logit values instead of cat
        logits = torch.mean(torch.stack([logits, logits_image]), dim=0)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here

        if labels is not None:
            # this is where the code differs
            # multilabel classification uses sigmoid + Binary Cross Entropy instead of the standard Cross Entropy 
            # used by the SequenceClassification version
            # here, the WithLogits version is used to combine both steps in an efficient implementation
            # for a list of loss functions available in torch.nn see:
            # https://pytorch.org/docs/stable/nn.html#loss-functions
            # test other options maybe, try BCELoss() to not include a second sigmoid
            loss_fct = BCEWithLogitsLoss()
            #loss_fct = BCELoss()
            #loss_fct = MultiLabelMarginLoss()
            
            # view reshapes samples to match logits
            loss = loss_fct(logits, labels.float().view(-1, self.num_labels))
            #loss = loss_fct(logits.view(-1, self.num_labels), labels.float().view(-1, self.num_labels))
            # loss for MLML
            #loss = loss_fct(logits.view(-1, self.num_labels), labels.long().view(-1, self.num_labels))

            outputs = (loss,) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions)

This section specifies hyperparamaters and other variables used to manage the training process. We define the number of epochs, the batch size, as well as instantiate the specific BERT model we plan to use. The optimizer we use is AdamW with the linear schedule to manage training rates. Once this is done, configure torch to use the GPU, then send model.to() to pass the model to the GPU and model.cuda() to initialize GPU training. 

The BERT paper recommends 2-4 epochs and a batch size of 32. In my experience, performance tends to waver in the 4th epoch, making 3 sufficient in most cases. A batch size of 32 can cause problems given the size of the model. The authors used 64GB TPU's in their training, the 16GB GPU's available through Cheaha will run out of memory on long sequences. Fortunately, use of gradient accumulation can mitigate this, allowing gradients to be accumulated through mini batches then updated once sufficiently many have come in. 

In [9]:
# using the from_pretrained model applies the bert model trained on a large corpus and fine tunes it to our task
# num_labels is the output size, with sentiment analysis, use 2
model = BertForMultiLabelSequenceClassification.from_pretrained('bert-base-uncased', num_labels=len(tags))  
image_model = torch.load('/data/user/jvdelmon/Model_Pretrain2')
print(image_model)

pad_token_label_id = CrossEntropyLoss().ignore_index


'''
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
     "weight_decay": 0.0},
    {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}
]
optimizer = AdamW(model.parameters(), lr=5e-5, eps=1e-8)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps =len(train_dataloader) * epochs)
'''

FULL_FINETUNING = True
if FULL_FINETUNING:
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay_rate': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay_rate': 0.0}
    ]
else:
    param_optimizer = list(model.classifier.named_parameters()) 
    optimizer_grouped_parameters = [{"params": [p for n, p in param_optimizer]}]
optimizer = AdamW(optimizer_grouped_parameters, lr=3e-5, eps=1e-8)
#scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps =len(train_dataloader) * epochs)

# setup GPU
torch.cuda.set_device(0)
device = torch.device("cuda", 0)  # , 0
model.to(device)
model.cuda()
image_model.to(device)
image_model.cuda()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [10]:
# evaluation method
# appears here because it's used in the validation stage
def evalModel(model, test_dataloader, multi_label=False):
    # predict
    model.eval()
    predictions, true_labels = [], []
    sample_out = []

    for batch in test_dataloader:
        #print(len(batch[0]), batch[0])
        batch = tuple(t.to(device) for t in batch)
        batch_labels = batch[2]
        inputs = {"input_ids": batch[0], "attention_mask": batch[1], "image": batch[2]}

        # calculating gradients is unnecessary in eval mode, saves time and memory to disable calculating it
        with torch.no_grad():
            outputs = model(**inputs)

            #batch_accuracy = flat_accuracy(outputs[0].detach().cpu().numpy(), batch_labels)
            #print(batch_accuracy)
            for i in outputs[0].cpu().detach().numpy():
                # in the case of multilabel classification, predicted labels are those with positive logit value
                if multi_label:
                    predictions.append((i > 0).astype(int))
                # single label case
                else:
                    predictions.append(np.argmax(i))
                
    return predictions

In [11]:
def train(model, optimizer, device, train_dataloader, val_dataloader):
    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(epochs, desc="Epoch")
    gradient_accumulation_steps = 32 / batch_size  # emulate batch size of 32
    logging_steps = 5000
    save_steps = 10000
    max_steps = 0
    epoch = 0
    
    # training cycle
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration")
        for step, batch in enumerate(epoch_iterator):
            #print(len(batch), batch[0])

            # tell the model to use training mode
            model.train()

            # GPU can only process tensors on the GPU
            batch = tuple(t.to(device) for t in batch)

            # configure parameters passed along to the model
            inputs = {"input_ids": batch[0], "attention_mask": batch[1], "image": batch[2], "labels": batch[3]}

            # the ** operator unpacks array/dictionary elements as optional arguments
            outputs = model(**inputs)     

            # model outputs are always tuple in pytorch-transformers, the first entry being loss
            loss = outputs[0]  

            # tell the model to do backprop
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)

            tr_loss += loss.item()
            
            # track gradient accumulation
            if (step + 1) % gradient_accumulation_steps == 0 or batch_size > 32:
                # this performs the SGD update step
                optimizer.step()
                #scheduler.step()  # Update learning rate schedule
                
                # zeroing gradients is necessary between parameter updates
                # by default, torch accumulates gradients each time loss is calculated
                # and will only reset when calling zero_grad()
                model.zero_grad()
            
        # validation
        predictions = evalModel(model, val_dataloader, multi_label=True)
  
        # output results data for epoch
        preds = []
        for i in predictions:
            preds.append(i.tolist())

        true_labels = []
        for i in y_val:
            true_labels.append(i.astype(int).flatten().tolist())

        accuracy = 0
        positive_accuracy = 0
        perfect_matches = 0
        num_samples = len(true_labels)
        print(len(true_labels), len(preds))
        for i in range(num_samples):
            if true_labels[i] == preds[i]:
                accuracy += 1
                positive_accuracy += 1
                perfect_matches += 1
            else:
                # positive accuracy
                # num correct over max(num_guessed, num_true)
                positive_accuracy += np.sum(2*np.array(true_labels[i])-1==np.array(preds[i])) / max(np.sum(true_labels[i]), sum(preds[i]))

                # flat accuracy
                accuracy += np.sum(np.array(true_labels[i])==np.array(preds[i])) / len(true_labels[i])

        preds = np.array(preds)
        true_labels = np.array(true_labels)
        epoch += 1

        print("Training epoch:", epoch)
        print("Size of test set:", len(true_labels))
        print("Perfect matches: {}".format(perfect_matches))
        print("Positive accuracy:", positive_accuracy/num_samples)
        print("Flat accuracy:", accuracy/num_samples)
        print("F1-Score: {}".format(f1_score(y_pred=preds, y_true=true_labels, average='weighted')))
        stats = precision_recall_fscore_support(y_pred=preds, y_true=true_labels, average=None, labels=range(len(tags)))
        stats = np.array(stats)

        for t in tags:
            print(tags[t], t, stats[:, tags[t]])

In [12]:
train(model, optimizer, device, train_dataloader, val_dataloader)

Epoch:   0%|          | 0/4 [00:00<?, ?it/s]
Iteration:   0%|          | 0/1003 [00:00<?, ?it/s][A
Iteration:   0%|          | 1/1003 [00:00<07:12,  2.32it/s][A
Iteration:   0%|          | 2/1003 [00:00<06:27,  2.59it/s][A
Iteration:   0%|          | 3/1003 [00:00<05:44,  2.90it/s][A
Iteration:   0%|          | 4/1003 [00:01<05:22,  3.10it/s][A
Iteration:   0%|          | 5/1003 [00:01<04:59,  3.34it/s][A
Iteration:   1%|          | 6/1003 [00:01<04:50,  3.43it/s][A
Iteration:   1%|          | 7/1003 [00:01<04:37,  3.59it/s][A
Iteration:   1%|          | 8/1003 [00:02<04:35,  3.62it/s][A
Iteration:   1%|          | 9/1003 [00:02<04:26,  3.74it/s][A
Iteration:   1%|          | 10/1003 [00:02<04:27,  3.72it/s][A
Iteration:   1%|          | 11/1003 [00:03<04:20,  3.81it/s][A
Iteration:   1%|          | 12/1003 [00:03<04:24,  3.74it/s][A
Iteration:   1%|▏         | 13/1003 [00:03<04:18,  3.82it/s][A
Iteration:   1%|▏         | 14/1003 [00:03<04:22,  3.77it/s][A
Iteration:   

2005 2005
Training epoch: 1
Size of test set: 2005
Perfect matches: 100
Positive accuracy: 0.16990856192851186
Flat accuracy: 0.8886783042394052
F1-Score: 0.15617024910194408
0 Mystery [  0.   0.   0. 103.]
1 Thriller [  0.   0.   0. 181.]
2 Action [  0.   0.   0. 210.]
3 Adventure [  0.   0.   0. 275.]
4 Horror [  0.   0.   0. 213.]
5 Crime [8.10810811e-02 2.29007634e-02 3.57142857e-02 2.62000000e+02]
6 Drama [4.86935867e-01 6.16232465e-01 5.44007077e-01 9.98000000e+02]
7 Comedy [2.80405405e-01 1.45614035e-01 1.91685912e-01 5.70000000e+02]
8 War [  0.   0.   0. 152.]
9 Romance [  0.   0.   0. 340.]
10 Fantasy [ 0.  0.  0. 68.]
11 SciFi [  0.   0.   0. 168.]
12 Family [  0.   0.   0. 120.]
13 Biography [ 0.  0.  0. 71.]
14 Music [  0.   0.   0. 130.]
15 Western [  0.   0.   0. 241.]
16 Animation [ 0.  0.  0. 17.]
17 History [ 0.  0.  0. 70.]
18 Sport [ 0.  0.  0. 18.]
19 Film-Noir [ 0.  0.  0. 29.]



Iteration:   0%|          | 1/1003 [00:00<04:05,  4.08it/s][A
Iteration:   0%|          | 2/1003 [00:00<04:12,  3.97it/s][A
Iteration:   0%|          | 3/1003 [00:00<04:10,  3.99it/s][A
Iteration:   0%|          | 4/1003 [00:01<04:15,  3.91it/s][A
Iteration:   0%|          | 5/1003 [00:01<04:12,  3.96it/s][A
Iteration:   1%|          | 6/1003 [00:01<04:16,  3.89it/s][A
Iteration:   1%|          | 7/1003 [00:01<04:12,  3.94it/s][A
Iteration:   1%|          | 8/1003 [00:02<04:16,  3.87it/s][A
Iteration:   1%|          | 9/1003 [00:02<04:12,  3.93it/s][A
Iteration:   1%|          | 10/1003 [00:02<04:16,  3.87it/s][A
Iteration:   1%|          | 11/1003 [00:02<04:12,  3.93it/s][A
Iteration:   1%|          | 12/1003 [00:03<04:16,  3.87it/s][A
Iteration:   1%|▏         | 13/1003 [00:03<04:12,  3.93it/s][A
Iteration:   1%|▏         | 14/1003 [00:03<04:15,  3.87it/s][A
Iteration:   1%|▏         | 15/1003 [00:03<04:11,  3.93it/s][A
Iteration:   2%|▏         | 16/1003 [00:04<04:15

2005 2005
Training epoch: 2
Size of test set: 2005
Perfect matches: 109
Positive accuracy: 0.17082294264339154
Flat accuracy: 0.8776807980049901
F1-Score: 0.18378629631882842
0 Mystery [  0.   0.   0. 103.]
1 Thriller [  0.   0.   0. 181.]
2 Action [8.18181818e-02 4.28571429e-02 5.62500000e-02 2.10000000e+02]
3 Adventure [1.78082192e-01 4.72727273e-02 7.47126437e-02 2.75000000e+02]
4 Horror [1.13636364e-01 7.04225352e-02 8.69565217e-02 2.13000000e+02]
5 Crime [1.56976744e-01 1.03053435e-01 1.24423963e-01 2.62000000e+02]
6 Drama [5.08214677e-01 4.64929860e-01 4.85609628e-01 9.98000000e+02]
7 Comedy [2.80318091e-01 2.47368421e-01 2.62814539e-01 5.70000000e+02]
8 War [  0.   0.   0. 152.]
9 Romance [2.74725275e-01 7.35294118e-02 1.16009281e-01 3.40000000e+02]
10 Fantasy [ 0.  0.  0. 68.]
11 SciFi [  0.   0.   0. 168.]
12 Family [  0.   0.   0. 120.]
13 Biography [ 0.  0.  0. 71.]
14 Music [  0.   0.   0. 130.]
15 Western [1.79487179e-01 5.80912863e-02 8.77742947e-02 2.41000000e+02]
16 Ani


Iteration:   0%|          | 1/1003 [00:00<04:05,  4.08it/s][A
Iteration:   0%|          | 2/1003 [00:00<04:12,  3.96it/s][A
Iteration:   0%|          | 3/1003 [00:00<04:10,  4.00it/s][A
Iteration:   0%|          | 4/1003 [00:01<04:15,  3.91it/s][A
Iteration:   0%|          | 5/1003 [00:01<04:12,  3.95it/s][A
Iteration:   1%|          | 6/1003 [00:01<04:17,  3.87it/s][A
Iteration:   1%|          | 7/1003 [00:01<04:13,  3.93it/s][A
Iteration:   1%|          | 8/1003 [00:02<04:17,  3.87it/s][A
Iteration:   1%|          | 9/1003 [00:02<04:13,  3.93it/s][A
Iteration:   1%|          | 10/1003 [00:02<04:16,  3.87it/s][A
Iteration:   1%|          | 11/1003 [00:02<04:12,  3.93it/s][A
Iteration:   1%|          | 12/1003 [00:03<04:16,  3.87it/s][A
Iteration:   1%|▏         | 13/1003 [00:03<04:12,  3.92it/s][A
Iteration:   1%|▏         | 14/1003 [00:03<04:15,  3.86it/s][A
Iteration:   1%|▏         | 15/1003 [00:03<04:11,  3.92it/s][A
Iteration:   2%|▏         | 16/1003 [00:04<04:15

2005 2005
Training epoch: 3
Size of test set: 2005
Perfect matches: 95
Positive accuracy: 0.1794679966749788
Flat accuracy: 0.8648877805486275
F1-Score: 0.19442449277738583
0 Mystery [  0.   0.   0. 103.]
1 Thriller [5.17241379e-02 1.65745856e-02 2.51046025e-02 1.81000000e+02]
2 Action [1.18110236e-01 7.14285714e-02 8.90207715e-02 2.10000000e+02]
3 Adventure [1.37254902e-01 5.09090909e-02 7.42705570e-02 2.75000000e+02]
4 Horror [1.08490566e-01 1.07981221e-01 1.08235294e-01 2.13000000e+02]
5 Crime [1.48936170e-01 1.06870229e-01 1.24444444e-01 2.62000000e+02]
6 Drama [5.02057613e-01 4.88977956e-01 4.95431472e-01 9.98000000e+02]
7 Comedy [2.64940239e-01 2.33333333e-01 2.48134328e-01 5.70000000e+02]
8 War [7.40740741e-02 1.31578947e-02 2.23463687e-02 1.52000000e+02]
9 Romance [1.60804020e-01 9.41176471e-02 1.18738404e-01 3.40000000e+02]
10 Fantasy [ 0.  0.  0. 68.]
11 SciFi [6.89655172e-02 3.57142857e-02 4.70588235e-02 1.68000000e+02]
12 Family [  0.   0.   0. 120.]
13 Biography [ 0.  0.  


Iteration:   0%|          | 1/1003 [00:00<04:07,  4.04it/s][A
Iteration:   0%|          | 2/1003 [00:00<04:15,  3.92it/s][A
Iteration:   0%|          | 3/1003 [00:00<04:12,  3.95it/s][A
Iteration:   0%|          | 4/1003 [00:01<04:18,  3.87it/s][A
Iteration:   0%|          | 5/1003 [00:01<04:14,  3.92it/s][A
Iteration:   1%|          | 6/1003 [00:01<04:19,  3.84it/s][A
Iteration:   1%|          | 7/1003 [00:01<04:16,  3.89it/s][A
Iteration:   1%|          | 8/1003 [00:02<04:20,  3.83it/s][A
Iteration:   1%|          | 9/1003 [00:02<04:15,  3.88it/s][A
Iteration:   1%|          | 10/1003 [00:02<04:19,  3.82it/s][A
Iteration:   1%|          | 11/1003 [00:02<04:15,  3.88it/s][A
Iteration:   1%|          | 12/1003 [00:03<04:19,  3.82it/s][A
Iteration:   1%|▏         | 13/1003 [00:03<04:15,  3.88it/s][A
Iteration:   1%|▏         | 14/1003 [00:03<04:18,  3.82it/s][A
Iteration:   1%|▏         | 15/1003 [00:03<04:14,  3.88it/s][A
Iteration:   2%|▏         | 16/1003 [00:04<04:18

2005 2005
Training epoch: 4
Size of test set: 2005
Perfect matches: 87
Positive accuracy: 0.18561928512053125
Flat accuracy: 0.8583790523690737
F1-Score: 0.203611116979714
0 Mystery [  0.   0.   0. 103.]
1 Thriller [8.00000000e-02 3.31491713e-02 4.68750000e-02 1.81000000e+02]
2 Action [1.00000000e-01 6.19047619e-02 7.64705882e-02 2.10000000e+02]
3 Adventure [1.45454545e-01 8.72727273e-02 1.09090909e-01 2.75000000e+02]
4 Horror [9.52380952e-02 9.38967136e-02 9.45626478e-02 2.13000000e+02]
5 Crime [1.20603015e-01 9.16030534e-02 1.04121475e-01 2.62000000e+02]
6 Drama [4.98522167e-01 5.07014028e-01 5.02732240e-01 9.98000000e+02]
7 Comedy [2.66037736e-01 2.47368421e-01 2.56363636e-01 5.70000000e+02]
8 War [9.45945946e-02 4.60526316e-02 6.19469027e-02 1.52000000e+02]
9 Romance [1.61392405e-01 1.50000000e-01 1.55487805e-01 3.40000000e+02]
10 Fantasy [ 0.  0.  0. 68.]
11 SciFi [8.45070423e-02 3.57142857e-02 5.02092050e-02 1.68000000e+02]
12 Family [1.02564103e-01 3.33333333e-02 5.03144654e-02 




In [13]:
predictions = evalModel(model, test_dataloader, multi_label=True)

In [14]:
# perform analysis on the predictions
preds = []
for i in predictions:
    preds.append(i.tolist())
    
true_labels = []
for i in y_test:
    true_labels.append(i.astype(int).flatten().tolist())
    
accuracy = 0
positive_accuracy = 0
perfect_matches = 0
num_samples = len(true_labels)
for i in range(num_samples):
    if true_labels[i] == preds[i]:
        accuracy += 1
        positive_accuracy += 1
        perfect_matches += 1
    else:
        # positive accuracy
        # num correct over max(num_guessed, num_true)
        positive_accuracy += np.sum(2*np.array(true_labels[i])-1==np.array(preds[i])) / max(np.sum(true_labels[i]), sum(preds[i]))
        
        # flat accuracy
        accuracy += np.sum(np.array(true_labels[i])==np.array(preds[i])) / len(true_labels[i])
     
preds = np.array(preds)
true_labels = np.array(true_labels)
    
print("Size of test set:", len(true_labels))
print("Perfect matches: {}".format(perfect_matches))
print("Positive accuracy:", positive_accuracy/num_samples)
print("Flat accuracy:", accuracy/num_samples)
print("F1-Score: {}".format(f1_score(y_pred=preds, y_true=true_labels, average='weighted')))
stats = precision_recall_fscore_support(y_pred=preds, y_true=true_labels, average=None, labels=range(len(tags)))
stats = np.array(stats)

for t in tags:
    print(tags[t], t, stats[:, tags[t]])

#print("Confusion Matrix:")
#print(confusion_matrix(preds, true_labels))

Size of test set: 2006
Perfect matches: 329
Positive accuracy: 0.4566716517115332
Flat accuracy: 0.9138085742771787
F1-Score: 0.5109011326998559
0 Mystery [6.00000000e-01 9.75609756e-02 1.67832168e-01 1.23000000e+02]
1 Thriller [4.46969697e-01 2.08480565e-01 2.84337349e-01 2.83000000e+02]
2 Action [6.82080925e-01 3.08900524e-01 4.25225225e-01 3.82000000e+02]
3 Adventure [  0.5648855    0.31092437   0.40108401 238.        ]
4 Horror [  0.82043344   0.71815718   0.76589595 369.        ]
5 Crime [  0.67549669   0.5164557    0.58536585 395.        ]
6 Drama [7.01312910e-01 6.53414883e-01 6.76517150e-01 9.81000000e+02]
7 Comedy [6.25688073e-01 5.25423729e-01 5.71189280e-01 6.49000000e+02]
8 War [ 0.63636364  0.38888889  0.48275862 36.        ]
9 Romance [  0.3943662    0.40776699   0.40095465 206.        ]
10 Fantasy [2.5000000e-01 1.2195122e-02 2.3255814e-02 8.2000000e+01]
11 SciFi [  0.84482759   0.31410256   0.45794393 156.        ]
12 Family [ 0.65853659  0.28125     0.39416058 96.     

  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)


In [15]:
for i in preds[:20]:
    print(i)

[0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0]
[0 0 0 0 0 0 1 0 0 1 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 1 0 1 0 0 0 0 1 0 0 0 0 0]
[0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 1 0 0 1 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0]
[0 0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0]
[0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0]
[0 0 0 0 0 0 1 0 0 1 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 1 0 1 0 0 0 0 1 0 0 0 0 0]
[0 0 0 0 1 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0]


Running on Data_Smallest with 12796 samples
Size of test set: 1277

4 epochs 
Perfect matches: 214
Positive accuracy: 0.444009397024275
Flat accuracy: 0.9143589378515038
F1-Score: 0.5423658558507969
0 Mystery [ 0.48387097  0.3125      0.37974684 96.        ]
1 Thriller [  0.46448087   0.40669856   0.43367347 209.        ]
2 Action [  0.51445087   0.39035088   0.44389027 228.        ]
3 Adventure [  0.61445783   0.49756098   0.54986523 205.        ]
4 Horror [  0.91385768   0.68732394   0.78456592 355.        ]
5 Crime [  0.52073733   0.62087912   0.56641604 182.        ]
6 Drama [0. 0. 0. 0.]
7 Comedy [8.48184818e-01 4.52464789e-01 5.90126292e-01 5.68000000e+02]
8 War [ 0.64285714  0.375       0.47368421 24.        ]
9 Romance [  0.28571429   0.21428571   0.24489796 112.        ]
10 Fantasy [ 0.46666667  0.10144928  0.16666667 69.        ]
11 SciFi [0. 0. 0. 0.]
12 Family [ 0.51282051  0.25974026  0.34482759 77.        ]
13 Biography [ 0.66666667  0.5         0.57142857 28.        ]
14 Music [ 0.35714286  0.27777778  0.3125     18.        ]
15 Western [  0.85714286   0.61538462   0.71641791 156.        ]
16 Animation [0. 0. 0. 0.]
17 History [ 1.          0.09090909  0.16666667 11.        ]
18 Sport [0. 0. 0. 0.]
19 Film-Noir [0. 0. 0. 0.]
20 Musical [ 0.31818182  0.17073171  0.22222222 41.        ]
21 Sci-Fi [  0.86956522   0.46783626   0.60836502 171.        ]

Perfect matches: 170
Positive accuracy: 0.40015661707125955
Flat accuracy: 0.9083078237346108
F1-Score: 0.510587001697972
0 Mystery [ 0.37313433  0.26041667  0.30674847 96.        ]
1 Thriller [  0.4491018    0.35885167   0.39893617 209.        ]
2 Action [  0.52554745   0.31578947   0.39452055 228.        ]
3 Adventure [  0.65806452   0.49756098   0.56666667 205.        ]
4 Horror [  0.90612245   0.62535211   0.74       355.        ]
5 Crime [  0.54205607   0.63736264   0.58585859 182.        ]
6 Drama [0. 0. 0. 0.]
7 Comedy [8.64541833e-01 3.82042254e-01 5.29914530e-01 5.68000000e+02]
8 War [ 0.63636364  0.29166667  0.4        24.        ]
9 Romance [  0.28985507   0.35714286   0.32       112.        ]
10 Fantasy [ 0.40540541  0.2173913   0.28301887 69.        ]
11 SciFi [0. 0. 0. 0.]
12 Family [ 0.46875     0.19480519  0.27522936 77.        ]
13 Biography [ 0.66666667  0.35714286  0.46511628 28.        ]
14 Music [ 0.57142857  0.22222222  0.32       18.        ]
15 Western [  0.80869565   0.59615385   0.68634686 156.        ]
16 Animation [0. 0. 0. 0.]
17 History [ 0.  0.  0. 11.]
18 Sport [0. 0. 0. 0.]
19 Film-Noir [0. 0. 0. 0.]
20 Musical [ 0.42307692  0.26829268  0.32835821 41.        ]
21 Sci-Fi [  0.86111111   0.3625731    0.51028807 171.        ]

2 epochs
Perfect matches: 238
Positive accuracy: 0.3987209605847025
Flat accuracy: 0.9167081939204152
F1-Score: 0.4656778510612374
0 Mystery [7.27272727e-01 8.33333333e-02 1.49532710e-01 9.60000000e+01]
1 Thriller [5.76271186e-01 1.62679426e-01 2.53731343e-01 2.09000000e+02]
2 Action [  0.48066298   0.38157895   0.42542787 228.        ]
3 Adventure [  0.71774194   0.43414634   0.54103343 205.        ]
4 Horror [  0.88850174   0.71830986   0.79439252 355.        ]
5 Crime [  0.57653061   0.62087912   0.5978836  182.        ]
6 Drama [0. 0. 0. 0.]
7 Comedy [9.09090909e-01 3.34507042e-01 4.89060489e-01 5.68000000e+02]
8 War [ 0.66666667  0.25        0.36363636 24.        ]
9 Romance [  0.47222222   0.15178571   0.22972973 112.        ]
10 Fantasy [ 0.  0.  0. 69.]
11 SciFi [0. 0. 0. 0.]
12 Family [1.00000000e+00 1.29870130e-02 2.56410256e-02 7.70000000e+01]
13 Biography [ 1.          0.17857143  0.3030303  28.        ]
14 Music [ 0.  0.  0. 18.]
15 Western [  0.86868687   0.55128205   0.6745098  156.        ]
16 Animation [0. 0. 0. 0.]
17 History [ 0.  0.  0. 11.]
18 Sport [0. 0. 0. 0.]
19 Film-Noir [0. 0. 0. 0.]
20 Musical [1.00000000e+00 2.43902439e-02 4.76190476e-02 4.10000000e+01]
21 Sci-Fi [  0.90277778   0.38011696   0.53497942 171.        ]

Perfect matches: 228
Positive accuracy: 0.38958496476115784
Flat accuracy: 0.9154979710970367
F1-Score: 0.4471914401927798
0 Mystery [5.00000000e-01 4.16666667e-02 7.69230769e-02 9.60000000e+01]
1 Thriller [5.52631579e-01 1.00478469e-01 1.70040486e-01 2.09000000e+02]
2 Action [  0.45714286   0.49122807   0.47357294 228.        ]
3 Adventure [  0.76635514   0.4          0.52564103 205.        ]
4 Horror [  0.92342342   0.57746479   0.71057192 355.        ]
5 Crime [  0.55042017   0.71978022   0.62380952 182.        ]
6 Drama [0. 0. 0. 0.]
7 Comedy [8.64963504e-01 4.17253521e-01 5.62945368e-01 5.68000000e+02]
8 War [ 0.75   0.25   0.375 24.   ]
9 Romance [  0.5483871    0.15178571   0.23776224 112.        ]
10 Fantasy [ 0.  0.  0. 69.]
11 SciFi [0. 0. 0. 0.]
12 Family [ 0.72727273  0.1038961   0.18181818 77.        ]
13 Biography [ 1.          0.21428571  0.35294118 28.        ]
14 Music [ 0.  0.  0. 18.]
15 Western [  0.81914894   0.49358974   0.616      156.        ]
16 Animation [0. 0. 0. 0.]
17 History [ 0.  0.  0. 11.]
18 Sport [0. 0. 0. 0.]
19 Film-Noir [0. 0. 0. 0.]
20 Musical [1.00000000e+00 2.43902439e-02 4.76190476e-02 4.10000000e+01]
21 Sci-Fi [9.16666667e-01 1.28654971e-01 2.25641026e-01 1.71000000e+02]

3 epochs
Perfect matches: 215
Positive accuracy: 0.39323936309057594
Flat accuracy: 0.9149996440521161
F1-Score: 0.47711102003170297
0 Mystery [7.00000000e-01 7.29166667e-02 1.32075472e-01 9.60000000e+01]
1 Thriller [5.12195122e-01 2.00956938e-01 2.88659794e-01 2.09000000e+02]
2 Action [  0.5359116    0.4254386    0.47432763 228.        ]
3 Adventure [  0.73786408   0.37073171   0.49350649 205.        ]
4 Horror [  0.888        0.62535211   0.7338843  355.        ]
5 Crime [  0.55769231   0.63736264   0.59487179 182.        ]
6 Drama [0. 0. 0. 0.]
7 Comedy [8.82845188e-01 3.71478873e-01 5.22924411e-01 5.68000000e+02]
8 War [ 0.56        0.58333333  0.57142857 24.        ]
9 Romance [  0.38095238   0.21428571   0.27428571 112.        ]
10 Fantasy [ 0.85714286  0.08695652  0.15789474 69.        ]
11 SciFi [0. 0. 0. 0.]
12 Family [5.55555556e-01 6.49350649e-02 1.16279070e-01 7.70000000e+01]
13 Biography [ 1.          0.28571429  0.44444444 28.        ]
14 Music [ 0.  0.  0. 18.]
15 Western [  0.87012987   0.42948718   0.5751073  156.        ]
16 Animation [0. 0. 0. 0.]
17 History [ 0.  0.  0. 11.]
18 Sport [0. 0. 0. 0.]
19 Film-Noir [0. 0. 0. 0.]
20 Musical [ 0.42857143  0.07317073  0.125      41.        ]
21 Sci-Fi [  0.92105263   0.40935673   0.56680162 171.        ]

Perfect matches: 218
Positive accuracy: 0.4107935264943871
Flat accuracy: 0.9156403502527276
F1-Score: 0.4940888922481723
0 Mystery [ 0.6875      0.11458333  0.19642857 96.        ]
1 Thriller [  0.45394737   0.33014354   0.38227147 209.        ]
2 Action [  0.54639175   0.23245614   0.32615385 228.        ]
3 Adventure [  0.78571429   0.37560976   0.50825083 205.        ]
4 Horror [  0.83900929   0.76338028   0.79941003 355.        ]
5 Crime [  0.53809524   0.62087912   0.57653061 182.        ]
6 Drama [0. 0. 0. 0.]
7 Comedy [9.23404255e-01 3.82042254e-01 5.40473225e-01 5.68000000e+02]
8 War [ 0.7         0.29166667  0.41176471 24.        ]
9 Romance [  0.39473684   0.26785714   0.31914894 112.        ]
10 Fantasy [ 0.5         0.13043478  0.20689655 69.        ]
11 SciFi [0. 0. 0. 0.]
12 Family [ 0.56        0.18181818  0.2745098  77.        ]
13 Biography [ 0.75        0.42857143  0.54545455 28.        ]
14 Music [ 1.          0.05555556  0.10526316 18.        ]
15 Western [  0.82644628   0.64102564   0.72202166 156.        ]
16 Animation [0. 0. 0. 0.]
17 History [ 0.  0.  0. 11.]
18 Sport [0. 0. 0. 0.]
19 Film-Noir [0. 0. 0. 0.]
20 Musical [ 0.4         0.09756098  0.15686275 41.        ]
21 Sci-Fi [  0.9          0.26315789   0.40723982 171.        ]

Running on DataMedium with 18488 samples
Size of test set: 1843
Based on the results of the Data Small test, I'm using only 2 epochs

Perfect matches: 314
Positive accuracy: 0.44732320491951755
Flat accuracy: 0.9220391653924017
F1-Score: 0.4810285072493092
0 Mystery [6.50000000e-01 1.08333333e-01 1.85714286e-01 1.20000000e+02]
1 Thriller [3.61111111e-01 4.76190476e-02 8.41423948e-02 2.73000000e+02]
2 Action [6.25766871e-01 2.81767956e-01 3.88571429e-01 3.62000000e+02]
3 Adventure [  0.52892562   0.31683168   0.39628483 202.        ]
4 Horror [  0.875        0.62222222   0.72727273 360.        ]
5 Crime [  0.68194842   0.62303665   0.65116279 382.        ]
6 Drama [6.91466083e-01 6.96802646e-01 6.94124108e-01 9.07000000e+02]
7 Comedy [6.40668524e-01 4.44015444e-01 5.24515393e-01 5.18000000e+02]
8 War [ 0.5         0.18181818  0.26666667 22.        ]
9 Romance [  0.42857143   0.29787234   0.35146444 141.        ]
10 Fantasy [4.44444444e-01 5.79710145e-02 1.02564103e-01 6.90000000e+01]
11 SciFi [0. 0. 0. 0.]
12 Family [ 0.6875      0.14285714  0.23655914 77.        ]
13 Biography [8.33333333e-01 2.89017341e-02 5.58659218e-02 1.73000000e+02]
14 Music [ 1.          0.10638298  0.19230769 47.        ]
15 Western [  0.7654321    0.50406504   0.60784314 123.        ]
16 Animation [0. 0. 0. 0.]
17 History [  0.   0.   0. 115.]
18 Sport [0. 0. 0. 0.]
19 Film-Noir [0. 0. 0. 0.]
20 Sci-Fi [  0.7704918    0.63087248   0.69372694 149.        ]
21 Musical [ 0.28571429  0.10810811  0.15686275 37.        ]

Perfect matches: 337
Positive accuracy: 0.4588081027310576
Flat accuracy: 0.9241355497459777
F1-Score: 0.48448282543563465
0 Mystery [5.45454545e-01 5.00000000e-02 9.16030534e-02 1.20000000e+02]
1 Thriller [4.11111111e-01 1.35531136e-01 2.03856749e-01 2.73000000e+02]
2 Action [6.63212435e-01 3.53591160e-01 4.61261261e-01 3.62000000e+02]
3 Adventure [  0.81132075   0.21287129   0.3372549  202.        ]
4 Horror [  0.85616438   0.69444444   0.76687117 360.        ]
5 Crime [  0.68767908   0.62827225   0.65663475 382.        ]
6 Drama [6.79282869e-01 7.51929438e-01 7.13762428e-01 9.07000000e+02]
7 Comedy [7.09897611e-01 4.01544402e-01 5.12946979e-01 5.18000000e+02]
8 War [ 0.2         0.13636364  0.16216216 22.        ]
9 Romance [  0.47368421   0.25531915   0.33179724 141.        ]
10 Fantasy [ 0.  0.  0. 69.]
11 SciFi [0. 0. 0. 0.]
12 Family [6.00000000e-01 3.89610390e-02 7.31707317e-02 7.70000000e+01]
13 Biography [1.00000000e+00 5.78034682e-03 1.14942529e-02 1.73000000e+02]
14 Music [ 0.  0.  0. 47.]
15 Western [  0.76530612   0.6097561    0.67873303 123.        ]
16 Animation [0. 0. 0. 0.]
17 History [  0.   0.   0. 115.]
18 Sport [0. 0. 0. 0.]
19 Film-Noir [0. 0. 0. 0.]
20 Sci-Fi [  0.87341772   0.46308725   0.60526316 149.        ]
21 Musical [3.33333333e-01 2.70270270e-02 5.00000000e-02 3.70000000e+01]

Perfect matches: 310
Positive accuracy: 0.44682582745523863
Flat accuracy: 0.9208059981255943
F1-Score: 0.47200835107796413
0 Mystery [  0.57692308   0.125        0.20547945 120.        ]
1 Thriller [3.68421053e-01 5.12820513e-02 9.00321543e-02 2.73000000e+02]
2 Action [6.72413793e-01 2.15469613e-01 3.26359833e-01 3.62000000e+02]
3 Adventure [  0.67741935   0.31188119   0.42711864 202.        ]
4 Horror [  0.87341772   0.575        0.69346734 360.        ]
5 Crime [  0.62559242   0.69109948   0.65671642 382.        ]
6 Drama [6.84542587e-01 7.17750827e-01 7.00753498e-01 9.07000000e+02]
7 Comedy [6.15384615e-01 4.63320463e-01 5.28634361e-01 5.18000000e+02]
8 War [ 0.31578947  0.27272727  0.29268293 22.        ]
9 Romance [  0.43209877   0.24822695   0.31531532 141.        ]
10 Fantasy [ 0.33333333  0.08695652  0.13793103 69.        ]
11 SciFi [0. 0. 0. 0.]
12 Family [ 0.61111111  0.14285714  0.23157895 77.        ]
13 Biography [  0.   0.   0. 173.]
14 Music [ 0.75        0.06382979  0.11764706 47.        ]
15 Western [  0.72641509   0.62601626   0.67248908 123.        ]
16 Animation [0. 0. 0. 0.]
17 History [  0.   0.   0. 115.]
18 Sport [0. 0. 0. 0.]
19 Film-Noir [0. 0. 0. 0.]
20 Sci-Fi [  0.84146341   0.46308725   0.5974026  149.        ]
21 Musical [ 0.21052632  0.10810811  0.14285714 37.        ]

Running on Data Large with 20157 samples
Size of test set: 2009
Also using 2 epochs

Perfect matches: 338
Positive accuracy: 0.4745727559316427
Flat accuracy: 0.9252228607629416
F1-Score: 0.5229275320921905
0 Mystery [8.00000000e-01 3.22580645e-02 6.20155039e-02 1.24000000e+02]
1 Thriller [4.21052632e-01 5.65371025e-02 9.96884735e-02 2.83000000e+02]
2 Action [  0.63485477   0.3984375    0.4896     384.        ]
3 Adventure [  0.62043796   0.35714286   0.45333333 238.        ]
4 Horror [  0.88016529   0.57567568   0.69607843 370.        ]
5 Crime [  0.6954023    0.6080402    0.64879357 398.        ]
6 Drama [6.69981917e-01 7.53048780e-01 7.09090909e-01 9.84000000e+02]
7 Comedy [6.68604651e-01 5.32407407e-01 5.92783505e-01 6.48000000e+02]
8 War [ 0.75        0.33333333  0.46153846 36.        ]
9 Romance [  0.49640288   0.33658537   0.40116279 205.        ]
10 Fantasy [5.00000000e-01 3.65853659e-02 6.81818182e-02 8.20000000e+01]
11 SciFi [0. 0. 0. 0.]
12 Family [ 0.53333333  0.25        0.34042553 96.        ]
13 Biography [ 0.82608696  0.25675676  0.39175258 74.        ]
14 Music [1.00000000e+00 2.04081633e-02 4.00000000e-02 4.90000000e+01]
15 Western [  0.87777778   0.58088235   0.69911504 136.        ]
16 Animation [ 0.  0.  0. 28.]
17 History [ 1.          0.05882353  0.11111111 34.        ]
18 Sport [ 0.66666667  0.4         0.5        50.        ]
19 Film-Noir [0. 0. 0. 0.]
20 Sci-Fi [  0.87878788   0.37179487   0.52252252 156.        ]
21 Musical [ 0.26315789  0.11904762  0.16393443 42.        ]

Perfect matches: 330
Positive accuracy: 0.45831259333001617
Flat accuracy: 0.9235259514005264
F1-Score: 0.505436043263464
0 Mystery [  0.59375      0.15322581   0.24358974 124.        ]
1 Thriller [4.83516484e-01 1.55477032e-01 2.35294118e-01 2.83000000e+02]
2 Action [6.98529412e-01 2.47395833e-01 3.65384615e-01 3.84000000e+02]
3 Adventure [  0.68932039   0.29831933   0.41642229 238.        ]
4 Horror [  0.88617886   0.58918919   0.70779221 370.        ]
5 Crime [  0.73381295   0.51256281   0.6035503  398.        ]
6 Drama [6.53145695e-01 8.01829268e-01 7.19890511e-01 9.84000000e+02]
7 Comedy [6.88249400e-01 4.42901235e-01 5.38967136e-01 6.48000000e+02]
8 War [ 0.66666667  0.33333333  0.44444444 36.        ]
9 Romance [  0.41176471   0.44390244   0.42723005 205.        ]
10 Fantasy [5.00000000e-01 4.87804878e-02 8.88888889e-02 8.20000000e+01]
11 SciFi [0. 0. 0. 0.]
12 Family [ 0.55555556  0.15625     0.24390244 96.        ]
13 Biography [ 0.9         0.12162162  0.21428571 74.        ]
14 Music [ 0.  0.  0. 49.]
15 Western [  0.80898876   0.52941176   0.64       136.        ]
16 Animation [ 0.  0.  0. 28.]
17 History [ 0.  0.  0. 34.]
18 Sport [ 0.  0.  0. 50.]
19 Film-Noir [0. 0. 0. 0.]
20 Sci-Fi [  0.875        0.44871795   0.59322034 156.        ]
21 Musical [ 0.57142857  0.0952381   0.16326531 42.        ]

Perfect matches: 340
Positive accuracy: 0.4535423925667845
Flat accuracy: 0.9232770713607049
F1-Score: 0.4978844267417332
0 Mystery [  0.64285714   0.14516129   0.23684211 124.        ]
1 Thriller [4.44444444e-01 2.82685512e-02 5.31561462e-02 2.83000000e+02]
2 Action [  0.56521739   0.44010417   0.49487555 384.        ]
3 Adventure [  0.50485437   0.43697479   0.46846847 238.        ]
4 Horror [  0.8619403    0.62432432   0.72413793 370.        ]
5 Crime [  0.69047619   0.58291457   0.63215259 398.        ]
6 Drama [6.79802956e-01 7.01219512e-01 6.90345173e-01 9.84000000e+02]
7 Comedy [7.12793734e-01 4.21296296e-01 5.29582929e-01 6.48000000e+02]
8 War [ 0.7         0.38888889  0.5        36.        ]
9 Romance [  0.5412844    0.28780488   0.37579618 205.        ]
10 Fantasy [7.50000000e-01 3.65853659e-02 6.97674419e-02 8.20000000e+01]
11 SciFi [0. 0. 0. 0.]
12 Family [ 0.72222222  0.13541667  0.22807018 96.        ]
13 Biography [ 0.85        0.22972973  0.36170213 74.        ]
14 Music [ 0.  0.  0. 49.]
15 Western [  0.75824176   0.50735294   0.60792952 136.        ]
16 Animation [ 0.  0.  0. 28.]
17 History [ 1.          0.05882353  0.11111111 34.        ]
18 Sport [ 0.  0.  0. 50.]
19 Film-Noir [0. 0. 0. 0.]
20 Sci-Fi [  0.82758621   0.30769231   0.44859813 156.        ]
21 Musical [ 0.6         0.07142857  0.12765957 42.        ]

Combined model, using mean to combine, 2 epochs, Data Full

Size of test set: 2006
Perfect matches: 247
Positive accuracy: 0.3671070122964445
Flat accuracy: 0.9080508474576381
F1-Score: 0.388382629789021
0 Mystery [  0.   0.   0. 123.]
1 Thriller [  0.   0.   0. 283.]
2 Action [5.55555556e-01 2.22513089e-01 3.17757009e-01 3.82000000e+02]
3 Adventure [7.24137931e-01 1.76470588e-01 2.83783784e-01 2.38000000e+02]
4 Horror [9.69696970e-01 2.60162602e-01 4.10256410e-01 3.69000000e+02]
5 Crime [  0.65591398   0.61772152   0.63624511 395.        ]
6 Drama [6.34346754e-01 7.86952090e-01 7.02456779e-01 9.81000000e+02]
7 Comedy [6.52073733e-01 4.36055470e-01 5.22622345e-01 6.49000000e+02]
8 War [ 0.  0.  0. 36.]
9 Romance [  0.5          0.2961165    0.37195122 206.        ]
10 Fantasy [ 0.  0.  0. 82.]
11 SciFi [1.00000000e+00 6.41025641e-03 1.27388535e-02 1.56000000e+02]
12 Family [ 0.  0.  0. 96.]
13 Biography [ 0.  0.  0. 74.]
14 Music [7.77777778e-01 8.04597701e-02 1.45833333e-01 8.70000000e+01]
15 Western [  0.   0.   0. 137.]
16 Animation [ 0.  0.  0. 28.]
17 History [ 0.  0.  0. 34.]
18 Sport [ 0.  0.  0. 50.]
19 Film-Noir [0. 0. 0. 0.]


4 epochs
Perfect matches: 328
Positive accuracy: 0.4605766035227663
Flat accuracy: 0.9144815553340078
F1-Score: 0.5110378547982589
0 Mystery [  0.54285714   0.15447154   0.24050633 123.        ]
1 Thriller [4.44444444e-01 1.13074205e-01 1.80281690e-01 2.83000000e+02]
2 Action [6.01941748e-01 3.24607330e-01 4.21768707e-01 3.82000000e+02]
3 Adventure [  0.61006289   0.40756303   0.48866499 238.        ]
4 Horror [  0.87288136   0.55826558   0.68099174 369.        ]
5 Crime [  0.66295265   0.60253165   0.63129973 395.        ]
6 Drama [6.77934272e-01 7.35983690e-01 7.05767351e-01 9.81000000e+02]
7 Comedy [6.45748988e-01 4.91525424e-01 5.58180227e-01 6.49000000e+02]
8 War [ 0.56        0.38888889  0.45901639 36.        ]
9 Romance [  0.45348837   0.37864078   0.41269841 206.        ]
10 Fantasy [ 0.  0.  0. 82.]
11 SciFi [  0.88888889   0.25641026   0.39800995 156.        ]
12 Family [ 0.56        0.14583333  0.23140496 96.        ]
13 Biography [ 0.8         0.21621622  0.34042553 74.        ]
14 Music [ 0.51666667  0.35632184  0.42176871 87.        ]
15 Western [  0.77173913   0.51824818   0.62008734 137.        ]
16 Animation [ 0.  0.  0. 28.]
17 History [ 1.          0.05882353  0.11111111 34.        ]
18 Sport [3.33333333e-01 2.00000000e-02 3.77358491e-02 5.00000000e+01]
19 Film-Noir [0. 0. 0. 0.]

6 epochs
Perfect matches: 322
Positive accuracy: 0.48209538052509243
Flat accuracy: 0.9093718843469694
F1-Score: 0.5468431308956253
0 Mystery [  0.39090909   0.3495935    0.36909871 123.        ]
1 Thriller [  0.37179487   0.30742049   0.33655706 283.        ]
2 Action [  0.55769231   0.45549738   0.50144092 382.        ]
3 Adventure [  0.53296703   0.40756303   0.46190476 238.        ]
4 Horror [  0.82130584   0.64769648   0.72424242 369.        ]
5 Crime [  0.57051282   0.67594937   0.61877173 395.        ]
6 Drama [6.98342541e-01 6.44240571e-01 6.70201485e-01 9.81000000e+02]
7 Comedy [6.03119584e-01 5.36209553e-01 5.67699837e-01 6.49000000e+02]
8 War [ 0.62068966  0.5         0.55384615 36.        ]
9 Romance [  0.39572193   0.3592233    0.37659033 206.        ]
10 Fantasy [ 0.42857143  0.18292683  0.25641026 82.        ]
11 SciFi [  0.76923077   0.51282051   0.61538462 156.        ]
12 Family [ 0.5         0.26041667  0.34246575 96.        ]
13 Biography [ 0.83333333  0.2027027   0.32608696 74.        ]
14 Music [ 0.46987952  0.44827586  0.45882353 87.        ]
15 Western [  0.81318681   0.54014599   0.64912281 137.        ]
16 Animation [ 0.  0.  0. 28.]
17 History [ 0.57142857  0.11764706  0.19512195 34.        ]
18 Sport [ 0.68        0.34        0.45333333 50.        ]
19 Film-Noir [0. 0. 0. 0.]

Images only, 6 epochs

Perfect matches: 176
Positive accuracy: 0.30803423064140917
Flat accuracy: 0.8884097706879402
F1-Score: 0.3376902784802167
0 Mystery [7.69230769e-02 8.13008130e-03 1.47058824e-02 1.23000000e+02]
1 Thriller [2.25225225e-01 8.83392226e-02 1.26903553e-01 2.83000000e+02]
2 Action [4.46043165e-01 1.62303665e-01 2.38003839e-01 3.82000000e+02]
3 Adventure [2.31707317e-01 1.59663866e-01 1.89054726e-01 2.38000000e+02]
4 Horror [7.20930233e-01 2.52032520e-01 3.73493976e-01 3.69000000e+02]
5 Crime [3.75000000e-01 1.13924051e-01 1.74757282e-01 3.95000000e+02]
6 Drama [6.13333333e-01 6.09582059e-01 6.11451943e-01 9.81000000e+02]
7 Comedy [5.34458509e-01 5.85516179e-01 5.58823529e-01 6.49000000e+02]
8 War [ 0.  0.  0. 36.]
9 Romance [1.92982456e-01 2.13592233e-01 2.02764977e-01 2.06000000e+02]
10 Fantasy [ 0.  0.  0. 82.]
11 SciFi [4.66666667e-01 8.97435897e-02 1.50537634e-01 1.56000000e+02]
12 Family [3.33333333e-01 8.33333333e-02 1.33333333e-01 9.60000000e+01]
13 Biography [ 0.  0.  0. 74.]
14 Music [ 0.21428571  0.13793103  0.16783217 87.        ]
15 Western [  0.484375     0.22627737   0.30845771 137.        ]
16 Animation [ 0.47058824  0.28571429  0.35555556 28.        ]
17 History [ 0.  0.  0. 34.]
18 Sport [ 0.  0.  0. 50.]
19 Film-Noir [0. 0. 0. 0.]

4 epochs
Perfect matches: 185
Positive accuracy: 0.31447324692588924
Flat accuracy: 0.8917497507477612
F1-Score: 0.34376919250550314
0 Mystery [1.42857143e-01 8.13008130e-03 1.53846154e-02 1.23000000e+02]
1 Thriller [3.16455696e-01 8.83392226e-02 1.38121547e-01 2.83000000e+02]
2 Action [4.13407821e-01 1.93717277e-01 2.63814617e-01 3.82000000e+02]
3 Adventure [2.61744966e-01 1.63865546e-01 2.01550388e-01 2.38000000e+02]
4 Horror [7.01388889e-01 2.73712737e-01 3.93762183e-01 3.69000000e+02]
5 Crime [4.02684564e-01 1.51898734e-01 2.20588235e-01 3.95000000e+02]
6 Drama [5.95397891e-01 6.33027523e-01 6.13636364e-01 9.81000000e+02]
7 Comedy [5.71200000e-01 5.50077042e-01 5.60439560e-01 6.49000000e+02]
8 War [ 0.  0.  0. 36.]
9 Romance [  0.27388535   0.20873786   0.2369146  206.        ]
10 Fantasy [2.5000000e-01 1.2195122e-02 2.3255814e-02 8.2000000e+01]
11 SciFi [3.60000000e-01 1.15384615e-01 1.74757282e-01 1.56000000e+02]
12 Family [1.33333333e-01 2.08333333e-02 3.60360360e-02 9.60000000e+01]
13 Biography [ 0.  0.  0. 74.]
14 Music [1.81818182e-01 6.89655172e-02 1.00000000e-01 8.70000000e+01]
15 Western [  0.58333333   0.15328467   0.24277457 137.        ]
16 Animation [ 0.4         0.07142857  0.12121212 28.        ]
17 History [ 0.  0.  0. 34.]
18 Sport [ 0.  0.  0. 50.]
19 Film-Noir [0. 0. 0. 0.]

Perfect matches: 171
Positive accuracy: 0.31995679627783313
Flat accuracy: 0.8899302093718869
F1-Score: 0.3269156809484112
0 Mystery [  0.   0.   0. 123.]
1 Thriller [2.20000000e-01 3.88692580e-02 6.60660661e-02 2.83000000e+02]
2 Action [4.33628319e-01 1.28272251e-01 1.97979798e-01 3.82000000e+02]
3 Adventure [2.50000000e-01 1.47058824e-01 1.85185185e-01 2.38000000e+02]
4 Horror [7.45762712e-01 2.38482385e-01 3.61396304e-01 3.69000000e+02]
5 Crime [2.93706294e-01 1.06329114e-01 1.56133829e-01 3.95000000e+02]
6 Drama [5.81683168e-01 7.18654434e-01 6.42954856e-01 9.81000000e+02]
7 Comedy [5.03676471e-01 6.33281972e-01 5.61092150e-01 6.49000000e+02]
8 War [ 0.  0.  0. 36.]
9 Romance [3.07692308e-01 1.35922330e-01 1.88552189e-01 2.06000000e+02]
10 Fantasy [ 0.  0.  0. 82.]
11 SciFi [3.54838710e-01 7.05128205e-02 1.17647059e-01 1.56000000e+02]
12 Family [2.41379310e-01 7.29166667e-02 1.12000000e-01 9.60000000e+01]
13 Biography [ 0.  0.  0. 74.]
14 Music [1.02564103e-01 4.59770115e-02 6.34920635e-02 8.70000000e+01]
15 Western [  0.73076923   0.13868613   0.23312883 137.        ]
16 Animation [ 0.35714286  0.17857143  0.23809524 28.        ]
17 History [ 0.  0.  0. 34.]
18 Sport [ 0.  0.  0. 50.]
19 Film-Noir [0. 0. 0. 0.]