In [1]:
from torch import nn
from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss, MultiLabelMarginLoss
from transformers import WEIGHTS_NAME, BertConfig, BertForSequenceClassification, BertTokenizer
from transformers import AdamW, get_linear_schedule_with_warmup
import pandas as pd
import numpy as np
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset, RandomSampler, SequentialSampler
from tqdm import trange, tqdm
import os
import copy
import logging
from seqeval.metrics import f1_score
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, confusion_matrix, precision_recall_fscore_support

logger = logging.getLogger(__name__)
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

batch_size = 16
epochs = 2

# Prepare Dataset



In [2]:
# use pandas to read excel data as a dataframe
spreadsheet = pd.read_excel("Data_Full.xlsx")
#spreadsheet = pd.read_excel("Data_Medium.xlsx")
#spreadsheet = pd.read_excel("Data_Smallest.xlsx")
#spreadsheet = pd.read_excel("CleanData.xlsx")

# use prebuilt tags dictionary to match up with images
genres = ['Mystery','Thriller','Action','Adventure','Horror','Crime','Drama','Comedy','War','Romance','Fantasy','SciFi','Family','Biography','Music','Western','Animation','History','Sport','Film-Noir']
tags = {}
tags = {genres[i]: i for i in range(len(genres))}
print(tags)

# extract relevant entries 
movies = spreadsheet[["description", "genre"]].to_numpy()
data = []
#tags = {}
labels_array = []
longest_sentence = 0
for m in movies:
    #print(m[0].encode())
    if type(m[0]) != float:
        # tokenize text
        tokens = tokenizer.tokenize(m[0])
        tokens.insert(0, '[CLS]')
        tokens.append('[SEP]')
        tokens = tokenizer.convert_tokens_to_ids(tokens)

        if len(tokens) > longest_sentence:
            longest_sentence = len(tokens)

        # work out genre label
        genres = m[1].split(", ")
        label = []
        for g in genres:
            if g not in tags:
                tags[g] = len(tags)
                labels_array.append(g)
            label.append(tags[g])       

        data.append([tokens, label])
        
print(longest_sentence)


{'Mystery': 0, 'Thriller': 1, 'Action': 2, 'Adventure': 3, 'Horror': 4, 'Crime': 5, 'Drama': 6, 'Comedy': 7, 'War': 8, 'Romance': 9, 'Fantasy': 10, 'SciFi': 11, 'Family': 12, 'Biography': 13, 'Music': 14, 'Western': 15, 'Animation': 16, 'History': 17, 'Sport': 18, 'Film-Noir': 19}
96


In [3]:
'''
# fetch rotton tomatoes data
spreadsheet = pd.read_csv("movie_info.tsv", sep='\t')

# extract relevant entries 
movies = spreadsheet[["synopsis", "genre"]].to_numpy()
data = []
tags = {}
labels_array = []
longest_sentence = 0
for m in movies:
    #print(m[0].encode())
    if type(m[0]) != float and type(m[1]) != float:
        # tokenize text
        tokens = tokenizer.tokenize(m[0])
        tokens.insert(0, '[CLS]')
        tokens.append('[SEP]')
        tokens = tokenizer.convert_tokens_to_ids(tokens)

        if len(tokens) > longest_sentence:
            longest_sentence = len(tokens)

        # work out genre label
        genres = m[1].split("|")
        label = []
        for g in genres:
            if g not in tags:
                tags[g] = len(tags)
                labels_array.append(g)
            label.append(tags[g])       

        data.append([tokens, label])
        
longest_sentence = min(512, longest_sentence)
print(longest_sentence)
'''

'\n# fetch rotton tomatoes data\nspreadsheet = pd.read_csv("movie_info.tsv", sep=\'\t\')\n\n# extract relevant entries \nmovies = spreadsheet[["synopsis", "genre"]].to_numpy()\ndata = []\ntags = {}\nlabels_array = []\nlongest_sentence = 0\nfor m in movies:\n    #print(m[0].encode())\n    if type(m[0]) != float and type(m[1]) != float:\n        # tokenize text\n        tokens = tokenizer.tokenize(m[0])\n        tokens.insert(0, \'[CLS]\')\n        tokens.append(\'[SEP]\')\n        tokens = tokenizer.convert_tokens_to_ids(tokens)\n\n        if len(tokens) > longest_sentence:\n            longest_sentence = len(tokens)\n\n        # work out genre label\n        genres = m[1].split("|")\n        label = []\n        for g in genres:\n            if g not in tags:\n                tags[g] = len(tags)\n                labels_array.append(g)\n            label.append(tags[g])       \n\n        data.append([tokens, label])\n        \nlongest_sentence = min(512, longest_sentence)\nprint(longest_

In [4]:
# pad sequences and create n hot encodings for labels
multi_data = []
for m in data:
    # truncate sequences than BERT's hard limit of 512
    sequence = m[0][:512]
    
    # ensure that the final token in the sequence is the [SEP] token
    if sequence[-1] != 102:
        sequence[-1] = 102
 
    padding = [0] * max(0, longest_sentence - len(sequence))
    sequence.extend(padding)
    
    line = np.zeros((len(tags), 1))
    for j in m[1]:
        line[j] = 1
    multi_data.append([sequence, line])

Building the labels is similar to most other classifiers. Rather than a one hot encoding, each label corresponding to a given sample is represented with a 1. 

The tags dictionary maps the label values onto discrete integers. These correspond to features of the output layer. We then map label values onto these tags. 

In [5]:
tags

{'Mystery': 0,
 'Thriller': 1,
 'Action': 2,
 'Adventure': 3,
 'Horror': 4,
 'Crime': 5,
 'Drama': 6,
 'Comedy': 7,
 'War': 8,
 'Romance': 9,
 'Fantasy': 10,
 'SciFi': 11,
 'Family': 12,
 'Biography': 13,
 'Music': 14,
 'Western': 15,
 'Animation': 16,
 'History': 17,
 'Sport': 18,
 'Film-Noir': 19}

An example of the n hot encodings:

In [6]:
for i in range(10):
    print(multi_data[i][1].flatten())

[1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 1. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 1. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]


Processing the text data is more involved. The raw text data must be converted into numeric vectors which can be passed along to BERT. Additionally, these must be of a fixed length, and have several special tokens which will be covered later. This happens through the following steps:
1. Tokenization
2. Conversion to ids
3. Adding special characters and padding
4. Passing along id vectors (in the form of torch tensors) to BERT

The first step is to pass the text to the BERT tokenizer. This will break them down into subwords which can then be converted into ids. You'll also notice that all interior tokens are prefixed with ## and that all text is converted to lowercase. 

In [7]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

tokens = tokenizer.tokenize(movies[5][0])
print(tokens)

print(tokenizer.convert_tokens_to_ids(tokens))

['john', 'mcc', '##lane', 'attempts', 'to', 'ave', '##rt', 'disaster', 'as', 'rogue', 'military', 'operatives', 'seize', 'control', 'of', 'dull', '##es', 'international', 'airport', 'in', 'washington', ',', 'd', '.', 'c', '.']
[2198, 23680, 20644, 4740, 2000, 13642, 5339, 7071, 2004, 12406, 2510, 25631, 15126, 2491, 1997, 10634, 2229, 2248, 3199, 1999, 2899, 1010, 1040, 1012, 1039, 1012]


Additionally, special tokens need to be added to the beginning and end of the sentence. The [CLS] token is what BERT uses to signal the beginning of a piece of text and the [SEP] token is used between (and after) sentences. These map to ids 101 and 102, respectively.

Due to the difficulty of sentence segmentation, we ignore that and treat each sample as a single sentence. 

In [8]:
tokens.insert(0, '[CLS]')
tokens.append('[SEP]')

print(tokens)
print(tokenizer.convert_tokens_to_ids(tokens))

['[CLS]', 'john', 'mcc', '##lane', 'attempts', 'to', 'ave', '##rt', 'disaster', 'as', 'rogue', 'military', 'operatives', 'seize', 'control', 'of', 'dull', '##es', 'international', 'airport', 'in', 'washington', ',', 'd', '.', 'c', '.', '[SEP]']
[101, 2198, 23680, 20644, 4740, 2000, 13642, 5339, 7071, 2004, 12406, 2510, 25631, 15126, 2491, 1997, 10634, 2229, 2248, 3199, 1999, 2899, 1010, 1040, 1012, 1039, 1012, 102]


Inputs into BERT must be of a common length. Since the length of text samples will vary, we pad sentences to be of a uniform length. BERT supports lengths of up to 512 tokens, though having samples this large will cause memory issues with a batch size of 32 for 12-16 GB GPU's (https://github.com/google-research/bert#out-of-memory-issues). It's common practice to use either the longest sentence length or a fixed upper bound. In these examples, the sentences tend to be pretty short and so will keep memory manageable. 

Since the longest sentence in our dataset is 97 (as shown in the second cell), that's the value we'll pad our sentences to. 

In [9]:
np.array(multi_data[3][0]).flatten()

array([ 101, 2019, 6376, 9530, 1010, 2006, 1996, 2448, 2013, 1996, 2375,
       1010, 5829, 2046, 1037, 2496, 3232, 1005, 1055, 2160, 1998, 3138,
       2058, 2037, 3268, 1012,  102,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0])

At this point, the preprocessing is complete. The next step is to prepare the training, testing, and validation sets. This also converts the data to pytorch types. 

The attention mask is a vector denoting which tokens are padding (represented as 0) and which are not (represented as 1). These are used to get BERT to ignore padding tokens. 

In [10]:
# build test, val, and training sets
threshold = [int(len(multi_data) * .8), int(len(multi_data) * .9)]
train_set, val_set, test_set = np.split(multi_data, threshold)
X_train = [] 
X_val = []
X_test = []

y_train = []
y_val = []
y_test = []
paddings = [] 

sequence_length = longest_sentence

# function to create sequence_length tensors to feed into bert, will pad or truncate as necessary
def BERT_tensorizer(doc, y=False):
    padding = max(0, sequence_length - len(doc))
    if y:
        #with_padding = [labels[CrossEntropyLoss().ignore_index]] * padding
        with_padding = [0] * padding
    else:
        with_padding = [CrossEntropyLoss().ignore_index] * padding
        with_padding = [0] * padding
        paddings.append(padding)
    return np.concatenate((np.array(doc[:sequence_length]).astype(int), with_padding))

# convert each doc to a tensor of fixed length
for doc in train_set:
    X_train.append(BERT_tensorizer(doc[0]))
    y_train.append(doc[1])

for doc in val_set:
    X_val.append(BERT_tensorizer(doc[0]))
    y_val.append(doc[1])

for doc in test_set:
    X_test.append(BERT_tensorizer(doc[0]))
    y_test.append(doc[1])

# build attention masks
train_masks = [[float(i>0) for i in ii] for ii in X_train]
val_masks = [[float(i>0) for i in ii] for ii in X_val]
test_masks = [[float(i>0) for i in ii] for ii in X_test]

# convert to tensors
train_masks = torch.tensor(train_masks)
val_masks = torch.tensor(val_masks)
test_masks = torch.tensor(test_masks)

training_data = TensorDataset(torch.tensor(X_train).to(torch.int64), train_masks, torch.tensor(y_train).to(torch.int64))
val_data = TensorDataset(torch.tensor(X_val).to(torch.int64), val_masks, torch.tensor(y_val).to(torch.int64))
test_data = TensorDataset(torch.tensor(X_test).to(torch.int64), test_masks, torch.tensor(y_test).to(torch.int64))

# then convert to dataloaders
train_dataloader = DataLoader(training_data, sampler=RandomSampler(training_data), batch_size=batch_size)
val_dataloader = DataLoader(val_data, sampler=RandomSampler(val_data), batch_size=batch_size)
test_dataloader = DataLoader(test_data, sampler=SequentialSampler(test_data), batch_size=batch_size)

The data is now ready for use with BERT.

# Building the Model

BERT is currently the cutting edge in NLP. The model itself is an extension of the transformer, developed by Google and published in the paper "Attention is All You Need" in July, 2017. The transformer uses multi-headed self attention rather than recurrence and significantly outperforms other models on a variety of NLP tasks, usually after only a single epoch. BERT adds the additional feature of bidirectionality, allowing the model to scan left and right as it processes samples. There are two main versions for working with English text: BERT Base and Bert Large. The Base model has 110M parameters and the large model has 340M parameters, but its memory constraints make it difficult to use on most GPUs. Additionally, there are a variety of other models in the transformers family, such as RoBERTa, GPT (and GPT2), and XLM, some of these are designed as enhancements to BERT, others work with translation or other languages. 

Multilabel classification requires 

In [11]:
from transformers import BertPreTrainedModel, BertModel
from transformers.configuration_bert import BertConfig

# extending BERT classes requires inheriting from the base model and adding on a head to perform a specific funtion
# I'm modeling my code on the transformers library's other models
class BertForMultiLabelSequenceClassification(BertPreTrainedModel):

    # create a child class from the BertPreTrainedModel 
    # this is the same base which is used in BertForSequenceClassification
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
        self.init_weights()

    # updating the transformer head (code again adapted from BertForSequenceClassification)
    def forward(
        self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, labels=None,
    ):
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            # omit input_embeds in favor of letting the transformers module handle conversion from ids to embeddings
            #inputs_embeds=inputs_embeds,
        )

        # when taking gradients, outputs is [loss/device information, logits]
        # if with torch.no_grad() was used, only logits would be returned
        # thus in this case, outputs[1] is used
        pooled_output = outputs[1]

        # apply standard dropout
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here

        if labels is not None:
            # this is where the code differs
            # multilabel classification uses sigmoid + Binary Cross Entropy instead of the standard Cross Entropy 
            # used by the SequenceClassification version
            # here, the WithLogits version is used to combine both steps in an efficient implementation
            # for a list of loss functions available in torch.nn see:
            # https://pytorch.org/docs/stable/nn.html#loss-functions
            loss_fct = BCEWithLogitsLoss()
            #loss_fct = MultiLabelMarginLoss()
            # loss for BCE
            loss = loss_fct(logits.view(-1, self.num_labels), labels.float().view(-1, self.num_labels))
            # loss for MLML
            #loss = loss_fct(logits.view(-1, self.num_labels), labels.long().view(-1, self.num_labels))

            outputs = (loss,) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions)

This section specifies hyperparamaters and other variables used to manage the training process. We define the number of epochs, the batch size, as well as instantiate the specific BERT model we plan to use. The optimizer we use is AdamW with the linear schedule to manage training rates. Once this is done, configure torch to use the GPU, then send model.to() to pass the model to the GPU and model.cuda() to initialize GPU training. 

The BERT paper recommends 2-4 epochs and a batch size of 32. In my experience, performance tends to waver in the 4th epoch, making 3 sufficient in most cases. A batch size of 32 can cause problems given the size of the model. The authors used 64GB TPU's in their training, the 16GB GPU's available through Cheaha will run out of memory on long sequences. Fortunately, use of gradient accumulation can mitigate this, allowing gradients to be accumulated through mini batches then updated once sufficiently many have come in. 

In [12]:
# using the from_pretrained model applies the bert model trained on a large corpus and fine tunes it to our task
# num_labels is the output size, with sentiment analysis, use 2

# running these lines will automatically download weights if not cached
# BERT base
model = BertForMultiLabelSequenceClassification.from_pretrained('bert-base-uncased', 
                                                                 cache_dir='/data/user/jvdelmon/bertlarge/',
                                                                 num_labels=len(tags))  

# BERT large
#model = BertForMultiLabelSequenceClassification.from_pretrained('bert-large-uncased', 
#                                                                cache_dir='/data/user/jvdelmon/bertlarge/',
#                                                               num_labels=len(tags))  


FULL_FINETUNING = True
if FULL_FINETUNING:
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']  # look into what these mean
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay_rate': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay_rate': 0.0}
    ]
else:
    param_optimizer = list(model.classifier.named_parameters()) 
    optimizer_grouped_parameters = [{"params": [p for n, p in param_optimizer]}]
optimizer = AdamW(optimizer_grouped_parameters, lr=3e-5, eps=1e-8)
#scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps =len(train_dataloader) * epochs)

# setup GPU
torch.cuda.set_device(0)
device = torch.device("cuda", 0)  # , 0
model.to(device)
model.cuda()

BertForMultiLabelSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-1

In [13]:
def evalModel(model, test_dataloader, multi_label=False):
    # predict
    model.eval()
    predictions, true_labels = [], []
    sample_out = []

    for batch in test_dataloader:
        #print(len(batch[0]), batch[0])
        batch = tuple(t.to(device) for t in batch)
        batch_labels = batch[2]
        inputs = {"input_ids": batch[0], "attention_mask": batch[1]}

        # calculating gradients is unnecessary in eval mode, saves time and memory to disable calculating it
        with torch.no_grad():
            outputs = model(**inputs)

            #batch_accuracy = flat_accuracy(outputs[0].detach().cpu().numpy(), batch_labels)
            #print(batch_accuracy)
            for i in outputs[0].cpu().detach().numpy():
                # in the case of multilabel classification, predicted labels are those with positive logit value
                if multi_label:
                    predictions.append((i > 0).astype(int))
                # single label case
                else:
                    predictions.append(np.argmax(i))
                
    return predictions

In [16]:
def train(model, optimizer, device, train_dataloader, val_dataloader):
    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(epochs, desc="Epoch")
    gradient_accumulation_steps = 32 / batch_size  # emulate batch size of 32
    logging_steps = 5000
    save_steps = 10000
    max_steps = 0
    epoch = 0
    
    # training cycle
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration")
        for step, batch in enumerate(epoch_iterator):
            #print(len(batch), batch[0])

            # tell the model to use training mode
            model.train()

            # GPU can only process tensors on the GPU
            batch = tuple(t.to(device) for t in batch)

            # configure parameters passed along to the model
            inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[2]}

            # the ** operator unpacks array/dictionary elements as optional arguments
            outputs = model(**inputs)

            # model outputs are always tuple in pytorch-transformers, the first entry being loss
            loss = outputs[0]  

            # tell the model to do backprop
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)

            tr_loss += loss.item()
            
            # track gradient accumulation
            if (step + 1) % gradient_accumulation_steps == 0:
                # this performs the SGD update step
                optimizer.step()
                #scheduler.step()  # Update learning rate schedule
                
                # zeroing gradients is necessary between parameter updates
                # by default, torch accumulates gradients each time loss is calculated
                # and will only reset when calling zero_grad()
                model.zero_grad()

        # validation
        predictions = evalModel(model, val_dataloader, multi_label=True)
  
        # output results data for epoch
        preds = []
        for i in predictions:
            preds.append(i.tolist())

        true_labels = []
        for i in y_val:
            true_labels.append(i.astype(int).flatten().tolist())

        accuracy = 0
        positive_accuracy = 0
        perfect_matches = 0
        num_samples = len(true_labels)
        print(len(true_labels), len(preds))
        for i in range(num_samples):
            if true_labels[i] == preds[i]:
                accuracy += 1
                positive_accuracy += 1
                perfect_matches += 1
            else:
                # positive accuracy
                # num correct over max(num_guessed, num_true)
                positive_accuracy += np.sum(2*np.array(true_labels[i])-1==np.array(preds[i])) / max(np.sum(true_labels[i]), sum(preds[i]))

                # flat accuracy
                accuracy += np.sum(np.array(true_labels[i])==np.array(preds[i])) / len(true_labels[i])

        preds = np.array(preds)
        true_labels = np.array(true_labels)
        epoch += 1

        print("Training epoch:", epoch)
        print("Size of test set:", len(true_labels))
        print("Perfect matches: {}".format(perfect_matches))
        print("Positive accuracy:", positive_accuracy/num_samples)
        print("Flat accuracy:", accuracy/num_samples)
        print("F1-Score: {}".format(f1_score(y_pred=preds, y_true=true_labels, average='weighted')))
        stats = precision_recall_fscore_support(y_pred=preds, y_true=true_labels, average=None, labels=range(len(tags)))
        stats = np.array(stats)

        for t in tags:
            print(tags[t], t, stats[:, tags[t]])

In [17]:
train(model, optimizer, device, train_dataloader, val_dataloader)


Epoch:   0%|          | 0/2 [00:00<?, ?it/s][A

Iteration:   0%|          | 0/1003 [00:00<?, ?it/s][A[A

Iteration:   0%|          | 1/1003 [00:00<02:45,  6.05it/s][A[A

Iteration:   0%|          | 2/1003 [00:00<02:49,  5.92it/s][A[A

Iteration:   0%|          | 3/1003 [00:00<02:44,  6.08it/s][A[A

Iteration:   0%|          | 4/1003 [00:00<02:47,  5.95it/s][A[A

Iteration:   0%|          | 5/1003 [00:00<02:43,  6.10it/s][A[A

Iteration:   1%|          | 6/1003 [00:01<02:47,  5.97it/s][A[A

Iteration:   1%|          | 7/1003 [00:01<02:42,  6.12it/s][A[A

Iteration:   1%|          | 8/1003 [00:01<02:46,  5.98it/s][A[A

Iteration:   1%|          | 9/1003 [00:01<02:42,  6.12it/s][A[A

Iteration:   1%|          | 10/1003 [00:01<02:45,  5.98it/s][A[A

Iteration:   1%|          | 11/1003 [00:01<02:41,  6.13it/s][A[A

Iteration:   1%|          | 12/1003 [00:01<02:46,  5.96it/s][A[A

Iteration:   1%|▏         | 13/1003 [00:02<02:42,  6.11it/s][A[A

Iteration:   1%|

2005 2005
Training epoch: 1
Size of test set: 2005
Perfect matches: 87
Positive accuracy: 0.1773482959268493
Flat accuracy: 0.8605236907730629
F1-Score: 0.19628402153637492
0 Mystery [9.09090909e-02 1.94174757e-02 3.20000000e-02 1.03000000e+02]
1 Thriller [7.89473684e-02 3.31491713e-02 4.66926070e-02 1.81000000e+02]
2 Action [1.25000000e-01 9.04761905e-02 1.04972376e-01 2.10000000e+02]
3 Adventure [9.63855422e-02 5.81818182e-02 7.25623583e-02 2.75000000e+02]
4 Horror [1.16935484e-01 1.36150235e-01 1.25813449e-01 2.13000000e+02]
5 Crime [1.60000000e-01 1.37404580e-01 1.47843943e-01 2.62000000e+02]
6 Drama [5.04522613e-01 5.03006012e-01 5.03763171e-01 9.98000000e+02]
7 Comedy [2.56658596e-01 1.85964912e-01 2.15666328e-01 5.70000000e+02]
8 War [1.35922330e-01 9.21052632e-02 1.09803922e-01 1.52000000e+02]
9 Romance [1.61490683e-01 7.64705882e-02 1.03792415e-01 3.40000000e+02]
10 Fantasy [5.55555556e-02 1.47058824e-02 2.32558140e-02 6.80000000e+01]
11 SciFi [6.86274510e-02 4.16666667e-02 5.



Iteration:   0%|          | 1/1003 [00:00<02:36,  6.41it/s][A[A

Iteration:   0%|          | 2/1003 [00:00<02:42,  6.14it/s][A[A

Iteration:   0%|          | 3/1003 [00:00<02:40,  6.24it/s][A[A

Iteration:   0%|          | 4/1003 [00:00<02:45,  6.03it/s][A[A

Iteration:   0%|          | 5/1003 [00:00<02:42,  6.14it/s][A[A

Iteration:   1%|          | 6/1003 [00:01<02:47,  5.96it/s][A[A

Iteration:   1%|          | 7/1003 [00:01<02:43,  6.11it/s][A[A

Iteration:   1%|          | 8/1003 [00:01<02:47,  5.95it/s][A[A

Iteration:   1%|          | 9/1003 [00:01<02:43,  6.07it/s][A[A

Iteration:   1%|          | 10/1003 [00:01<02:47,  5.92it/s][A[A

Iteration:   1%|          | 11/1003 [00:01<02:43,  6.07it/s][A[A

Iteration:   1%|          | 12/1003 [00:02<02:47,  5.91it/s][A[A

Iteration:   1%|▏         | 13/1003 [00:02<02:43,  6.06it/s][A[A

Iteration:   1%|▏         | 14/1003 [00:02<02:47,  5.90it/s][A[A

Iteration:   1%|▏         | 15/1003 [00:02<02:42,  6.06

2005 2005
Training epoch: 2
Size of test set: 2005
Perfect matches: 87
Positive accuracy: 0.19322527015793806
Flat accuracy: 0.861396508728176
F1-Score: 0.20618738088068472
0 Mystery [1.05263158e-01 1.94174757e-02 3.27868852e-02 1.03000000e+02]
1 Thriller [2.08333333e-02 5.52486188e-03 8.73362445e-03 1.81000000e+02]
2 Action [8.53658537e-02 6.66666667e-02 7.48663102e-02 2.10000000e+02]
3 Adventure [1.43884892e-01 7.27272727e-02 9.66183575e-02 2.75000000e+02]
4 Horror [1.06976744e-01 1.07981221e-01 1.07476636e-01 2.13000000e+02]
5 Crime [1.57894737e-01 1.71755725e-01 1.64533821e-01 2.62000000e+02]
6 Drama [4.82553191e-01 5.68136273e-01 5.21859181e-01 9.98000000e+02]
7 Comedy [3.14285714e-01 1.92982456e-01 2.39130435e-01 5.70000000e+02]
8 War [6.52173913e-02 3.94736842e-02 4.91803279e-02 1.52000000e+02]
9 Romance [2.07692308e-01 7.94117647e-02 1.14893617e-01 3.40000000e+02]
10 Fantasy [ 0.  0.  0. 68.]
11 SciFi [1.26315789e-01 7.14285714e-02 9.12547529e-02 1.68000000e+02]
12 Family [1.66




In [18]:
predictions = evalModel(model, test_dataloader, multi_label=True)

In [19]:
# do a comparison 
preds = []
for i in predictions:
    preds.append(i.tolist())
    
true_labels = []
for i in y_test:
    true_labels.append(i.astype(int).flatten().tolist())
    
accuracy = 0
positive_accuracy = 0
perfect_matches = 0
num_samples = len(true_labels)
for i in range(num_samples):
    if true_labels[i] == preds[i]:
        accuracy += 1
        positive_accuracy += 1
        perfect_matches += 1
    else:
        # positive accuracy
        # num correct over max(num_guessed, num_true)
        positive_accuracy += np.sum(2*np.array(true_labels[i])-1==np.array(preds[i])) / max(np.sum(true_labels[i]), sum(preds[i]))
        
        # flat accuracy
        accuracy += np.sum(np.array(true_labels[i])==np.array(preds[i])) / len(true_labels[i])
     
preds = np.array(preds)
true_labels = np.array(true_labels)
    
print("Size of test set:", len(true_labels))
print("Perfect matches: {}".format(perfect_matches))
print("Positive accuracy:", positive_accuracy/num_samples)
print("Flat accuracy:", accuracy/num_samples)
print("F1-Score: {}".format(f1_score(y_pred=preds, y_true=true_labels, average='micro')))
stats = precision_recall_fscore_support(y_pred=preds, y_true=true_labels, average=None, labels=range(len(tags)))
stats = np.array(stats)

for t in tags:
    print(tags[t], t, stats[:, tags[t]])

#print("Confusion Matrix:")
#print(confusion_matrix(preds, true_labels))

Size of test set: 2006
Perfect matches: 357
Positive accuracy: 0.4733300099700907
Flat accuracy: 0.9163010967098802
F1-Score: 0.5523860303918955
0 Mystery [  0.69230769   0.14634146   0.24161074 123.        ]
1 Thriller [4.18181818e-01 1.62544170e-01 2.34096692e-01 2.83000000e+02]
2 Action [6.37614679e-01 3.63874346e-01 4.63333333e-01 3.82000000e+02]
3 Adventure [  0.66037736   0.29411765   0.40697674 238.        ]
4 Horror [  0.79402985   0.72086721   0.75568182 369.        ]
5 Crime [  0.6375       0.64556962   0.64150943 395.        ]
6 Drama [6.63033606e-01 7.44138634e-01 7.01248799e-01 9.81000000e+02]
7 Comedy [6.99438202e-01 3.83667180e-01 4.95522388e-01 6.49000000e+02]
8 War [ 0.61904762  0.36111111  0.45614035 36.        ]
9 Romance [  0.5443038    0.20873786   0.30175439 206.        ]
10 Fantasy [5.00000000e-01 7.31707317e-02 1.27659574e-01 8.20000000e+01]
11 SciFi [  0.79381443   0.49358974   0.60869565 156.        ]
12 Family [ 0.65217391  0.15625     0.25210084 96.        ]

  'recall', 'true', average, warn_for)


In [20]:
len(preds)

2006

In [21]:
preds[:20]

array([[0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 1, 1, 0,

Raw output data


Running on Data_Smallest with 12796 samples
Size of test set: 1277

4 epochs 
Perfect matches: 214
Positive accuracy: 0.444009397024275
Flat accuracy: 0.9143589378515038
F1-Score: 0.5423658558507969
0 Mystery [ 0.48387097  0.3125      0.37974684 96.        ]
1 Thriller [  0.46448087   0.40669856   0.43367347 209.        ]
2 Action [  0.51445087   0.39035088   0.44389027 228.        ]
3 Adventure [  0.61445783   0.49756098   0.54986523 205.        ]
4 Horror [  0.91385768   0.68732394   0.78456592 355.        ]
5 Crime [  0.52073733   0.62087912   0.56641604 182.        ]
6 Drama [0. 0. 0. 0.]
7 Comedy [8.48184818e-01 4.52464789e-01 5.90126292e-01 5.68000000e+02]
8 War [ 0.64285714  0.375       0.47368421 24.        ]
9 Romance [  0.28571429   0.21428571   0.24489796 112.        ]
10 Fantasy [ 0.46666667  0.10144928  0.16666667 69.        ]
11 SciFi [0. 0. 0. 0.]
12 Family [ 0.51282051  0.25974026  0.34482759 77.        ]
13 Biography [ 0.66666667  0.5         0.57142857 28.        ]
14 Music [ 0.35714286  0.27777778  0.3125     18.        ]
15 Western [  0.85714286   0.61538462   0.71641791 156.        ]
16 Animation [0. 0. 0. 0.]
17 History [ 1.          0.09090909  0.16666667 11.        ]
18 Sport [0. 0. 0. 0.]
19 Film-Noir [0. 0. 0. 0.]
20 Musical [ 0.31818182  0.17073171  0.22222222 41.        ]
21 Sci-Fi [  0.86956522   0.46783626   0.60836502 171.        ]

Perfect matches: 170
Positive accuracy: 0.40015661707125955
Flat accuracy: 0.9083078237346108
F1-Score: 0.510587001697972
0 Mystery [ 0.37313433  0.26041667  0.30674847 96.        ]
1 Thriller [  0.4491018    0.35885167   0.39893617 209.        ]
2 Action [  0.52554745   0.31578947   0.39452055 228.        ]
3 Adventure [  0.65806452   0.49756098   0.56666667 205.        ]
4 Horror [  0.90612245   0.62535211   0.74       355.        ]
5 Crime [  0.54205607   0.63736264   0.58585859 182.        ]
6 Drama [0. 0. 0. 0.]
7 Comedy [8.64541833e-01 3.82042254e-01 5.29914530e-01 5.68000000e+02]
8 War [ 0.63636364  0.29166667  0.4        24.        ]
9 Romance [  0.28985507   0.35714286   0.32       112.        ]
10 Fantasy [ 0.40540541  0.2173913   0.28301887 69.        ]
11 SciFi [0. 0. 0. 0.]
12 Family [ 0.46875     0.19480519  0.27522936 77.        ]
13 Biography [ 0.66666667  0.35714286  0.46511628 28.        ]
14 Music [ 0.57142857  0.22222222  0.32       18.        ]
15 Western [  0.80869565   0.59615385   0.68634686 156.        ]
16 Animation [0. 0. 0. 0.]
17 History [ 0.  0.  0. 11.]
18 Sport [0. 0. 0. 0.]
19 Film-Noir [0. 0. 0. 0.]
20 Musical [ 0.42307692  0.26829268  0.32835821 41.        ]
21 Sci-Fi [  0.86111111   0.3625731    0.51028807 171.        ]

2 epochs
Perfect matches: 238
Positive accuracy: 0.3987209605847025
Flat accuracy: 0.9167081939204152
F1-Score: 0.4656778510612374
0 Mystery [7.27272727e-01 8.33333333e-02 1.49532710e-01 9.60000000e+01]
1 Thriller [5.76271186e-01 1.62679426e-01 2.53731343e-01 2.09000000e+02]
2 Action [  0.48066298   0.38157895   0.42542787 228.        ]
3 Adventure [  0.71774194   0.43414634   0.54103343 205.        ]
4 Horror [  0.88850174   0.71830986   0.79439252 355.        ]
5 Crime [  0.57653061   0.62087912   0.5978836  182.        ]
6 Drama [0. 0. 0. 0.]
7 Comedy [9.09090909e-01 3.34507042e-01 4.89060489e-01 5.68000000e+02]
8 War [ 0.66666667  0.25        0.36363636 24.        ]
9 Romance [  0.47222222   0.15178571   0.22972973 112.        ]
10 Fantasy [ 0.  0.  0. 69.]
11 SciFi [0. 0. 0. 0.]
12 Family [1.00000000e+00 1.29870130e-02 2.56410256e-02 7.70000000e+01]
13 Biography [ 1.          0.17857143  0.3030303  28.        ]
14 Music [ 0.  0.  0. 18.]
15 Western [  0.86868687   0.55128205   0.6745098  156.        ]
16 Animation [0. 0. 0. 0.]
17 History [ 0.  0.  0. 11.]
18 Sport [0. 0. 0. 0.]
19 Film-Noir [0. 0. 0. 0.]
20 Musical [1.00000000e+00 2.43902439e-02 4.76190476e-02 4.10000000e+01]
21 Sci-Fi [  0.90277778   0.38011696   0.53497942 171.        ]

Perfect matches: 228
Positive accuracy: 0.38958496476115784
Flat accuracy: 0.9154979710970367
F1-Score: 0.4471914401927798
0 Mystery [5.00000000e-01 4.16666667e-02 7.69230769e-02 9.60000000e+01]
1 Thriller [5.52631579e-01 1.00478469e-01 1.70040486e-01 2.09000000e+02]
2 Action [  0.45714286   0.49122807   0.47357294 228.        ]
3 Adventure [  0.76635514   0.4          0.52564103 205.        ]
4 Horror [  0.92342342   0.57746479   0.71057192 355.        ]
5 Crime [  0.55042017   0.71978022   0.62380952 182.        ]
6 Drama [0. 0. 0. 0.]
7 Comedy [8.64963504e-01 4.17253521e-01 5.62945368e-01 5.68000000e+02]
8 War [ 0.75   0.25   0.375 24.   ]
9 Romance [  0.5483871    0.15178571   0.23776224 112.        ]
10 Fantasy [ 0.  0.  0. 69.]
11 SciFi [0. 0. 0. 0.]
12 Family [ 0.72727273  0.1038961   0.18181818 77.        ]
13 Biography [ 1.          0.21428571  0.35294118 28.        ]
14 Music [ 0.  0.  0. 18.]
15 Western [  0.81914894   0.49358974   0.616      156.        ]
16 Animation [0. 0. 0. 0.]
17 History [ 0.  0.  0. 11.]
18 Sport [0. 0. 0. 0.]
19 Film-Noir [0. 0. 0. 0.]
20 Musical [1.00000000e+00 2.43902439e-02 4.76190476e-02 4.10000000e+01]
21 Sci-Fi [9.16666667e-01 1.28654971e-01 2.25641026e-01 1.71000000e+02]

3 epochs
Perfect matches: 215
Positive accuracy: 0.39323936309057594
Flat accuracy: 0.9149996440521161
F1-Score: 0.47711102003170297
0 Mystery [7.00000000e-01 7.29166667e-02 1.32075472e-01 9.60000000e+01]
1 Thriller [5.12195122e-01 2.00956938e-01 2.88659794e-01 2.09000000e+02]
2 Action [  0.5359116    0.4254386    0.47432763 228.        ]
3 Adventure [  0.73786408   0.37073171   0.49350649 205.        ]
4 Horror [  0.888        0.62535211   0.7338843  355.        ]
5 Crime [  0.55769231   0.63736264   0.59487179 182.        ]
6 Drama [0. 0. 0. 0.]
7 Comedy [8.82845188e-01 3.71478873e-01 5.22924411e-01 5.68000000e+02]
8 War [ 0.56        0.58333333  0.57142857 24.        ]
9 Romance [  0.38095238   0.21428571   0.27428571 112.        ]
10 Fantasy [ 0.85714286  0.08695652  0.15789474 69.        ]
11 SciFi [0. 0. 0. 0.]
12 Family [5.55555556e-01 6.49350649e-02 1.16279070e-01 7.70000000e+01]
13 Biography [ 1.          0.28571429  0.44444444 28.        ]
14 Music [ 0.  0.  0. 18.]
15 Western [  0.87012987   0.42948718   0.5751073  156.        ]
16 Animation [0. 0. 0. 0.]
17 History [ 0.  0.  0. 11.]
18 Sport [0. 0. 0. 0.]
19 Film-Noir [0. 0. 0. 0.]
20 Musical [ 0.42857143  0.07317073  0.125      41.        ]
21 Sci-Fi [  0.92105263   0.40935673   0.56680162 171.        ]

Perfect matches: 218
Positive accuracy: 0.4107935264943871
Flat accuracy: 0.9156403502527276
F1-Score: 0.4940888922481723
0 Mystery [ 0.6875      0.11458333  0.19642857 96.        ]
1 Thriller [  0.45394737   0.33014354   0.38227147 209.        ]
2 Action [  0.54639175   0.23245614   0.32615385 228.        ]
3 Adventure [  0.78571429   0.37560976   0.50825083 205.        ]
4 Horror [  0.83900929   0.76338028   0.79941003 355.        ]
5 Crime [  0.53809524   0.62087912   0.57653061 182.        ]
6 Drama [0. 0. 0. 0.]
7 Comedy [9.23404255e-01 3.82042254e-01 5.40473225e-01 5.68000000e+02]
8 War [ 0.7         0.29166667  0.41176471 24.        ]
9 Romance [  0.39473684   0.26785714   0.31914894 112.        ]
10 Fantasy [ 0.5         0.13043478  0.20689655 69.        ]
11 SciFi [0. 0. 0. 0.]
12 Family [ 0.56        0.18181818  0.2745098  77.        ]
13 Biography [ 0.75        0.42857143  0.54545455 28.        ]
14 Music [ 1.          0.05555556  0.10526316 18.        ]
15 Western [  0.82644628   0.64102564   0.72202166 156.        ]
16 Animation [0. 0. 0. 0.]
17 History [ 0.  0.  0. 11.]
18 Sport [0. 0. 0. 0.]
19 Film-Noir [0. 0. 0. 0.]
20 Musical [ 0.4         0.09756098  0.15686275 41.        ]
21 Sci-Fi [  0.9          0.26315789   0.40723982 171.        ]

Running on DataMedium with 18488 samples
Size of test set: 1843
Based on the results of the Data Small test, I'm using only 2 epochs

Perfect matches: 314
Positive accuracy: 0.44732320491951755
Flat accuracy: 0.9220391653924017
F1-Score: 0.4810285072493092
0 Mystery [6.50000000e-01 1.08333333e-01 1.85714286e-01 1.20000000e+02]
1 Thriller [3.61111111e-01 4.76190476e-02 8.41423948e-02 2.73000000e+02]
2 Action [6.25766871e-01 2.81767956e-01 3.88571429e-01 3.62000000e+02]
3 Adventure [  0.52892562   0.31683168   0.39628483 202.        ]
4 Horror [  0.875        0.62222222   0.72727273 360.        ]
5 Crime [  0.68194842   0.62303665   0.65116279 382.        ]
6 Drama [6.91466083e-01 6.96802646e-01 6.94124108e-01 9.07000000e+02]
7 Comedy [6.40668524e-01 4.44015444e-01 5.24515393e-01 5.18000000e+02]
8 War [ 0.5         0.18181818  0.26666667 22.        ]
9 Romance [  0.42857143   0.29787234   0.35146444 141.        ]
10 Fantasy [4.44444444e-01 5.79710145e-02 1.02564103e-01 6.90000000e+01]
11 SciFi [0. 0. 0. 0.]
12 Family [ 0.6875      0.14285714  0.23655914 77.        ]
13 Biography [8.33333333e-01 2.89017341e-02 5.58659218e-02 1.73000000e+02]
14 Music [ 1.          0.10638298  0.19230769 47.        ]
15 Western [  0.7654321    0.50406504   0.60784314 123.        ]
16 Animation [0. 0. 0. 0.]
17 History [  0.   0.   0. 115.]
18 Sport [0. 0. 0. 0.]
19 Film-Noir [0. 0. 0. 0.]
20 Sci-Fi [  0.7704918    0.63087248   0.69372694 149.        ]
21 Musical [ 0.28571429  0.10810811  0.15686275 37.        ]

Perfect matches: 337
Positive accuracy: 0.4588081027310576
Flat accuracy: 0.9241355497459777
F1-Score: 0.48448282543563465
0 Mystery [5.45454545e-01 5.00000000e-02 9.16030534e-02 1.20000000e+02]
1 Thriller [4.11111111e-01 1.35531136e-01 2.03856749e-01 2.73000000e+02]
2 Action [6.63212435e-01 3.53591160e-01 4.61261261e-01 3.62000000e+02]
3 Adventure [  0.81132075   0.21287129   0.3372549  202.        ]
4 Horror [  0.85616438   0.69444444   0.76687117 360.        ]
5 Crime [  0.68767908   0.62827225   0.65663475 382.        ]
6 Drama [6.79282869e-01 7.51929438e-01 7.13762428e-01 9.07000000e+02]
7 Comedy [7.09897611e-01 4.01544402e-01 5.12946979e-01 5.18000000e+02]
8 War [ 0.2         0.13636364  0.16216216 22.        ]
9 Romance [  0.47368421   0.25531915   0.33179724 141.        ]
10 Fantasy [ 0.  0.  0. 69.]
11 SciFi [0. 0. 0. 0.]
12 Family [6.00000000e-01 3.89610390e-02 7.31707317e-02 7.70000000e+01]
13 Biography [1.00000000e+00 5.78034682e-03 1.14942529e-02 1.73000000e+02]
14 Music [ 0.  0.  0. 47.]
15 Western [  0.76530612   0.6097561    0.67873303 123.        ]
16 Animation [0. 0. 0. 0.]
17 History [  0.   0.   0. 115.]
18 Sport [0. 0. 0. 0.]
19 Film-Noir [0. 0. 0. 0.]
20 Sci-Fi [  0.87341772   0.46308725   0.60526316 149.        ]
21 Musical [3.33333333e-01 2.70270270e-02 5.00000000e-02 3.70000000e+01]

Perfect matches: 310
Positive accuracy: 0.44682582745523863
Flat accuracy: 0.9208059981255943
F1-Score: 0.47200835107796413
0 Mystery [  0.57692308   0.125        0.20547945 120.        ]
1 Thriller [3.68421053e-01 5.12820513e-02 9.00321543e-02 2.73000000e+02]
2 Action [6.72413793e-01 2.15469613e-01 3.26359833e-01 3.62000000e+02]
3 Adventure [  0.67741935   0.31188119   0.42711864 202.        ]
4 Horror [  0.87341772   0.575        0.69346734 360.        ]
5 Crime [  0.62559242   0.69109948   0.65671642 382.        ]
6 Drama [6.84542587e-01 7.17750827e-01 7.00753498e-01 9.07000000e+02]
7 Comedy [6.15384615e-01 4.63320463e-01 5.28634361e-01 5.18000000e+02]
8 War [ 0.31578947  0.27272727  0.29268293 22.        ]
9 Romance [  0.43209877   0.24822695   0.31531532 141.        ]
10 Fantasy [ 0.33333333  0.08695652  0.13793103 69.        ]
11 SciFi [0. 0. 0. 0.]
12 Family [ 0.61111111  0.14285714  0.23157895 77.        ]
13 Biography [  0.   0.   0. 173.]
14 Music [ 0.75        0.06382979  0.11764706 47.        ]
15 Western [  0.72641509   0.62601626   0.67248908 123.        ]
16 Animation [0. 0. 0. 0.]
17 History [  0.   0.   0. 115.]
18 Sport [0. 0. 0. 0.]
19 Film-Noir [0. 0. 0. 0.]
20 Sci-Fi [  0.84146341   0.46308725   0.5974026  149.        ]
21 Musical [ 0.21052632  0.10810811  0.14285714 37.        ]

Running on Data Large with 20157 samples
Size of test set: 2009
Also using 2 epochs

Perfect matches: 338
Positive accuracy: 0.4745727559316427
Flat accuracy: 0.9252228607629416
F1-Score: 0.5229275320921905
0 Mystery [8.00000000e-01 3.22580645e-02 6.20155039e-02 1.24000000e+02]
1 Thriller [4.21052632e-01 5.65371025e-02 9.96884735e-02 2.83000000e+02]
2 Action [  0.63485477   0.3984375    0.4896     384.        ]
3 Adventure [  0.62043796   0.35714286   0.45333333 238.        ]
4 Horror [  0.88016529   0.57567568   0.69607843 370.        ]
5 Crime [  0.6954023    0.6080402    0.64879357 398.        ]
6 Drama [6.69981917e-01 7.53048780e-01 7.09090909e-01 9.84000000e+02]
7 Comedy [6.68604651e-01 5.32407407e-01 5.92783505e-01 6.48000000e+02]
8 War [ 0.75        0.33333333  0.46153846 36.        ]
9 Romance [  0.49640288   0.33658537   0.40116279 205.        ]
10 Fantasy [5.00000000e-01 3.65853659e-02 6.81818182e-02 8.20000000e+01]
11 SciFi [0. 0. 0. 0.]
12 Family [ 0.53333333  0.25        0.34042553 96.        ]
13 Biography [ 0.82608696  0.25675676  0.39175258 74.        ]
14 Music [1.00000000e+00 2.04081633e-02 4.00000000e-02 4.90000000e+01]
15 Western [  0.87777778   0.58088235   0.69911504 136.        ]
16 Animation [ 0.  0.  0. 28.]
17 History [ 1.          0.05882353  0.11111111 34.        ]
18 Sport [ 0.66666667  0.4         0.5        50.        ]
19 Film-Noir [0. 0. 0. 0.]
20 Sci-Fi [  0.87878788   0.37179487   0.52252252 156.        ]
21 Musical [ 0.26315789  0.11904762  0.16393443 42.        ]

Perfect matches: 330
Positive accuracy: 0.45831259333001617
Flat accuracy: 0.9235259514005264
F1-Score: 0.505436043263464
0 Mystery [  0.59375      0.15322581   0.24358974 124.        ]
1 Thriller [4.83516484e-01 1.55477032e-01 2.35294118e-01 2.83000000e+02]
2 Action [6.98529412e-01 2.47395833e-01 3.65384615e-01 3.84000000e+02]
3 Adventure [  0.68932039   0.29831933   0.41642229 238.        ]
4 Horror [  0.88617886   0.58918919   0.70779221 370.        ]
5 Crime [  0.73381295   0.51256281   0.6035503  398.        ]
6 Drama [6.53145695e-01 8.01829268e-01 7.19890511e-01 9.84000000e+02]
7 Comedy [6.88249400e-01 4.42901235e-01 5.38967136e-01 6.48000000e+02]
8 War [ 0.66666667  0.33333333  0.44444444 36.        ]
9 Romance [  0.41176471   0.44390244   0.42723005 205.        ]
10 Fantasy [5.00000000e-01 4.87804878e-02 8.88888889e-02 8.20000000e+01]
11 SciFi [0. 0. 0. 0.]
12 Family [ 0.55555556  0.15625     0.24390244 96.        ]
13 Biography [ 0.9         0.12162162  0.21428571 74.        ]
14 Music [ 0.  0.  0. 49.]
15 Western [  0.80898876   0.52941176   0.64       136.        ]
16 Animation [ 0.  0.  0. 28.]
17 History [ 0.  0.  0. 34.]
18 Sport [ 0.  0.  0. 50.]
19 Film-Noir [0. 0. 0. 0.]
20 Sci-Fi [  0.875        0.44871795   0.59322034 156.        ]
21 Musical [ 0.57142857  0.0952381   0.16326531 42.        ]

Perfect matches: 340
Positive accuracy: 0.4535423925667845
Flat accuracy: 0.9232770713607049
F1-Score: 0.4978844267417332
0 Mystery [  0.64285714   0.14516129   0.23684211 124.        ]
1 Thriller [4.44444444e-01 2.82685512e-02 5.31561462e-02 2.83000000e+02]
2 Action [  0.56521739   0.44010417   0.49487555 384.        ]
3 Adventure [  0.50485437   0.43697479   0.46846847 238.        ]
4 Horror [  0.8619403    0.62432432   0.72413793 370.        ]
5 Crime [  0.69047619   0.58291457   0.63215259 398.        ]
6 Drama [6.79802956e-01 7.01219512e-01 6.90345173e-01 9.84000000e+02]
7 Comedy [7.12793734e-01 4.21296296e-01 5.29582929e-01 6.48000000e+02]
8 War [ 0.7         0.38888889  0.5        36.        ]
9 Romance [  0.5412844    0.28780488   0.37579618 205.        ]
10 Fantasy [7.50000000e-01 3.65853659e-02 6.97674419e-02 8.20000000e+01]
11 SciFi [0. 0. 0. 0.]
12 Family [ 0.72222222  0.13541667  0.22807018 96.        ]
13 Biography [ 0.85        0.22972973  0.36170213 74.        ]
14 Music [ 0.  0.  0. 49.]
15 Western [  0.75824176   0.50735294   0.60792952 136.        ]
16 Animation [ 0.  0.  0. 28.]
17 History [ 1.          0.05882353  0.11111111 34.        ]
18 Sport [ 0.  0.  0. 50.]
19 Film-Noir [0. 0. 0. 0.]
20 Sci-Fi [  0.82758621   0.30769231   0.44859813 156.        ]
21 Musical [ 0.6         0.07142857  0.12765957 42.        ]

4 epochs
Perfect matches: 345
Positive accuracy: 0.5001244400199109
Flat accuracy: 0.9231865695280422
F1-Score: 0.5578971781672757
0 Mystery [  0.39830508   0.37903226   0.38842975 124.        ]
1 Thriller [  0.41071429   0.32508834   0.36291913 283.        ]
2 Action [6.48514851e-01 3.41145833e-01 4.47098976e-01 3.84000000e+02]
3 Adventure [  0.57746479   0.34453782   0.43157895 238.        ]
4 Horror [  0.82580645   0.69189189   0.75294118 370.        ]
5 Crime [  0.63443396   0.6758794    0.65450122 398.        ]
6 Drama [6.69520548e-01 7.94715447e-01 7.26765799e-01 9.84000000e+02]
7 Comedy [6.97270471e-01 4.33641975e-01 5.34728830e-01 6.48000000e+02]
8 War [ 0.55555556  0.41666667  0.47619048 36.        ]
9 Romance [  0.42857143   0.33658537   0.37704918 205.        ]
10 Fantasy [ 0.55555556  0.18292683  0.27522936 82.        ]
11 SciFi [0. 0. 0. 0.]
12 Family [ 0.6097561   0.26041667  0.3649635  96.        ]
13 Biography [ 0.72413793  0.28378378  0.40776699 74.        ]
14 Music [ 0.78947368  0.30612245  0.44117647 49.        ]
15 Western [  0.80232558   0.50735294   0.62162162 136.        ]
16 Animation [ 0.  0.  0. 28.]
17 History [ 0.75        0.17647059  0.28571429 34.        ]
18 Sport [ 0.72727273  0.48        0.57831325 50.        ]
19 Film-Noir [0. 0. 0. 0.]
20 Sci-Fi [  0.67295597   0.68589744   0.67936508 156.        ]
21 Musical [ 0.5         0.07142857  0.125      42.        ]

Perfect matches: 349
Positive accuracy: 0.4942342790774852
Flat accuracy: 0.9222363002850885
F1-Score: 0.5558154331550054
0 Mystery [  0.42708333   0.33064516   0.37272727 124.        ]
1 Thriller [  0.3805668    0.33215548   0.35471698 283.        ]
2 Action [  0.59183673   0.453125     0.51327434 384.        ]
3 Adventure [  0.58666667   0.3697479    0.45360825 238.        ]
4 Horror [  0.83386581   0.70540541   0.76427526 370.        ]
5 Crime [  0.65797101   0.57035176   0.61103634 398.        ]
6 Drama [6.69071235e-01 7.54065041e-01 7.09030100e-01 9.84000000e+02]
7 Comedy [6.63677130e-01 4.56790123e-01 5.41133455e-01 6.48000000e+02]
8 War [ 0.625       0.41666667  0.5        36.        ]
9 Romance [  0.48         0.35121951   0.4056338  205.        ]
10 Fantasy [ 0.59090909  0.15853659  0.25       82.        ]
11 SciFi [0. 0. 0. 0.]
12 Family [ 0.53846154  0.29166667  0.37837838 96.        ]
13 Biography [ 0.80645161  0.33783784  0.47619048 74.        ]
14 Music [ 0.75        0.06122449  0.11320755 49.        ]
15 Western [  0.78125      0.55147059   0.64655172 136.        ]
16 Animation [ 0.5         0.07142857  0.125      28.        ]
17 History [ 0.58333333  0.20588235  0.30434783 34.        ]
18 Sport [ 0.74074074  0.4         0.51948052 50.        ]
19 Film-Noir [0. 0. 0. 0.]
20 Sci-Fi [  0.75         0.53846154   0.62686567 156.        ]
21 Musical [ 0.27272727  0.07142857  0.11320755 42.        ]

Running Data Full on BERT Large - 2 epochs

Size of test set: 2006
Perfect matches: 370
Positive accuracy: 0.4944749086075116
Flat accuracy: 0.9176969092721945
F1-Score: 0.5600788402239633
0 Mystery [  0.43820225   0.31707317   0.36792453 123.        ]
1 Thriller [  0.46067416   0.28975265   0.35574837 283.        ]
2 Action [  0.56443299   0.57329843   0.56883117 382.        ]
3 Adventure [  0.5748503    0.40336134   0.47407407 238.        ]
4 Horror [  0.81065089   0.74254743   0.77510608 369.        ]
5 Crime [  0.62982005   0.62025316   0.625      395.        ]
6 Drama [7.14445689e-01 6.50356779e-01 6.80896478e-01 9.81000000e+02]
7 Comedy [7.71771772e-01 3.95993837e-01 5.23421589e-01 6.49000000e+02]
8 War [ 0.63888889  0.63888889  0.63888889 36.        ]
9 Romance [  0.6125       0.23786408   0.34265734 206.        ]
10 Fantasy [ 0.51851852  0.17073171  0.25688073 82.        ]
11 SciFi [  0.81443299   0.50641026   0.62450593 156.        ]
12 Family [ 0.70967742  0.22916667  0.34645669 96.        ]
13 Biography [ 0.62711864  0.5         0.55639098 74.        ]
14 Music [ 0.51923077  0.31034483  0.38848921 87.        ]
15 Western [  0.78125      0.54744526   0.64377682 137.        ]
16 Animation [ 1.          0.21428571  0.35294118 28.        ]
17 History [ 0.28571429  0.23529412  0.25806452 34.        ]
18 Sport [ 0.80952381  0.34        0.47887324 50.        ]
19 Film-Noir [0. 0. 0. 0.]

Perfect matches: 402
Positive accuracy: 0.5119225656364245
Flat accuracy: 0.9204636091724931
F1-Score: 0.5690768241257756
0 Mystery [5.00000000e-01 1.13821138e-01 1.85430464e-01 1.23000000e+02]
1 Thriller [4.34285714e-01 2.68551237e-01 3.31877729e-01 2.83000000e+02]
2 Action [  0.62271062   0.44502618   0.51908397 382.        ]
3 Adventure [  0.60233918   0.43277311   0.50366748 238.        ]
4 Horror [  0.8189415    0.79674797   0.80769231 369.        ]
5 Crime [  0.71341463   0.59240506   0.6473029  395.        ]
6 Drama [7.16775599e-01 6.70744139e-01 6.92996314e-01 9.81000000e+02]
7 Comedy [7.12359551e-01 4.88443760e-01 5.79524680e-01 6.49000000e+02]
8 War [ 0.5862069   0.47222222  0.52307692 36.        ]
9 Romance [  0.53846154   0.30582524   0.39009288 206.        ]
10 Fantasy [ 0.54166667  0.15853659  0.24528302 82.        ]
11 SciFi [  0.87671233   0.41025641   0.55895197 156.        ]
12 Family [ 0.55384615  0.375       0.44720497 96.        ]
13 Biography [ 0.69767442  0.40540541  0.51282051 74.        ]
14 Music [ 0.52631579  0.34482759  0.41666667 87.        ]
15 Western [  0.78225806   0.7080292    0.74329502 137.        ]
16 Animation [ 0.75        0.10714286  0.1875     28.        ]
17 History [ 0.66666667  0.05882353  0.10810811 34.        ]
18 Sport [ 0.65789474  0.5         0.56818182 50.        ]
19 Film-Noir [0. 0. 0. 0.]


BERT Large - 3 epochs
Perfect matches: 372
Positive accuracy: 0.49796444001329404
Flat accuracy: 0.9177716849451758
F1-Score: 0.5638168091315486
0 Mystery [  0.41538462   0.2195122    0.28723404 123.        ]
1 Thriller [  0.43965517   0.36042403   0.3961165  283.        ]
2 Action [  0.58544304   0.48429319   0.53008596 382.        ]
3 Adventure [  0.58860759   0.3907563    0.46969697 238.        ]
4 Horror [  0.86531987   0.69647696   0.77177177 369.        ]
5 Crime [  0.72222222   0.52658228   0.6090776  395.        ]
6 Drama [7.00523560e-01 6.81957187e-01 6.91115702e-01 9.81000000e+02]
7 Comedy [7.08333333e-01 4.71494607e-01 5.66142461e-01 6.49000000e+02]
8 War [ 0.59259259  0.44444444  0.50793651 36.        ]
9 Romance [  0.51851852   0.33980583   0.41055718 206.        ]
10 Fantasy [ 0.55555556  0.24390244  0.33898305 82.        ]
11 SciFi [  0.82105263   0.5          0.62151394 156.        ]
12 Family [ 0.6         0.28125     0.38297872 96.        ]
13 Biography [ 0.77777778  0.37837838  0.50909091 74.        ]
14 Music [ 0.515625    0.37931034  0.43708609 87.        ]
15 Western [  0.82051282   0.46715328   0.59534884 137.        ]
16 Animation [ 0.  0.  0. 28.]
17 History [ 0.75        0.08823529  0.15789474 34.        ]
18 Sport [ 0.7037037   0.38        0.49350649 50.        ]
19 Film-Noir [0. 0. 0. 0.]



BERT Large - 4 epochs
Perfect matches: 328
Positive accuracy: 0.49958457959455044
Flat accuracy: 0.9136340977068901
F1-Score: 0.5657790814737901
0 Mystery [  0.4          0.2601626    0.31527094 123.        ]
1 Thriller [  0.36964981   0.33568905   0.35185185 283.        ]
2 Action [6.45000000e-01 3.37696335e-01 4.43298969e-01 3.82000000e+02]
3 Adventure [  0.56470588   0.40336134   0.47058824 238.        ]
4 Horror [  0.83742331   0.7398374    0.78561151 369.        ]
5 Crime [  0.6685879    0.58734177   0.62533693 395.        ]
6 Drama [6.88605108e-01 7.14576962e-01 7.01350675e-01 9.81000000e+02]
7 Comedy [6.35869565e-01 5.40832049e-01 5.84512906e-01 6.49000000e+02]
8 War [ 0.76923077  0.27777778  0.40816327 36.        ]
9 Romance [  0.3755102    0.44660194   0.40798226 206.        ]
10 Fantasy [ 0.55555556  0.18292683  0.27522936 82.        ]
11 SciFi [  0.80851064   0.48717949   0.608      156.        ]
12 Family [ 0.60606061  0.20833333  0.31007752 96.        ]
13 Biography [ 0.67924528  0.48648649  0.56692913 74.        ]
14 Music [ 0.51315789  0.44827586  0.47852761 87.        ]
15 Western [  0.74561404   0.62043796   0.67729084 137.        ]
16 Animation [ 0.66666667  0.28571429  0.4        28.        ]
17 History [ 0.4         0.23529412  0.2962963  34.        ]
18 Sport [ 0.59459459  0.44        0.50574713 50.        ]
19 Film-Noir [0. 0. 0. 0.]

Perfect matches: 328
Positive accuracy: 0.5041126620139587
Flat accuracy: 0.9143070787637197
F1-Score: 0.5722273915499336
0 Mystery [  0.47058824   0.19512195   0.27586207 123.        ]
1 Thriller [4.13407821e-01 2.61484099e-01 3.20346320e-01 2.83000000e+02]
2 Action [  0.60288809   0.43717277   0.50682853 382.        ]
3 Adventure [  0.53763441   0.42016807   0.47169811 238.        ]
4 Horror [  0.87959866   0.71273713   0.78742515 369.        ]
5 Crime [  0.63783784   0.59746835   0.61699346 395.        ]
6 Drama [6.95247934e-01 6.86034659e-01 6.90610570e-01 9.81000000e+02]
7 Comedy [5.93703148e-01 6.10169492e-01 6.01823708e-01 6.49000000e+02]
8 War [ 0.55882353  0.52777778  0.54285714 36.        ]
9 Romance [  0.40101523   0.38349515   0.39205955 206.        ]
10 Fantasy [ 0.52        0.31707317  0.39393939 82.        ]
11 SciFi [  0.78070175   0.57051282   0.65925926 156.        ]
12 Family [ 0.53191489  0.26041667  0.34965035 96.        ]
13 Biography [ 0.71698113  0.51351351  0.5984252  74.        ]
14 Music [ 0.45679012  0.42528736  0.44047619 87.        ]
15 Western [  0.80898876   0.52554745   0.63716814 137.        ]
16 Animation [ 0.58823529  0.35714286  0.44444444 28.        ]
17 History [ 0.40909091  0.26470588  0.32142857 34.        ]
18 Sport [ 0.60526316  0.46        0.52272727 50.        ]
19 Film-Noir [0. 0. 0. 0.]