# Sponsor content detection in YouTube videos
## Transfomers for binary text classification
This notebook seeks to accomplish the task of sponsored-content detection using a binary text classification model. The text classification model is created by fine-tuning a DistilBERT pre-trained model.

## Motivation
Several similar projects based on a BERT-type text classification model have been written about in on the Internet. Unfortunately, in both instances the authors do not share details about the performance of the model. Instead, they used vague language like "95% accuracy" without qualifying that in any meaningful way. What is more, the trained models in both instances then demonstrably perform poorly in the downstream task of task classification, but no exact numbers are reported. 

We wanted to investigate how well a text classification model can perform on what is essentially a span extraction task.

In [4]:
import os
import sys
import itertools
import re

import numpy as np
import torch
from datasets import Dataset, IterableDataset, IterableDatasetDict, ClassLabel, load_dataset, load_from_disk, load_metric
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, TrainingArguments, Trainer, DataCollatorWithPadding
import pandas as pd
import pyarrow as pa

sys.path.append(os.path.dirname(os.path.realpath('..')))
from data_loader import load_examples_from_chunks, load_captions_from_chunks, get_intersection_range, segment_text, load_data_from_chunks

os.environ["WANDB_DISABLED"] = "true"

# Prepare the data

Read the transcripts from the `data.N.json.gz` and extract examples using `load_examples_from_chunks`. 

In [8]:
def chunk_video(captions, segment_ranges):
    """
    Convert a list of captions with multiple sponsor ranges
    into multiple lists of captions with 0-1 sponsor ranges each.
    """
    if len(segment_ranges) == 1:
        return captions, segment_ranges[0]
    
    segment_ranges.sort()
    
    last_chunk_end = 0
    for i, r in enumerate(segment_ranges):
        start_idx, end_idx = r
        if i + 1 < len(segment_ranges):
            # Pick the mid-point between the end of this segment and the start of the next
            # to end this chunk
            chunk_end = (end_idx + segment_ranges[i + 1][0]) // 2
        else:
            chunk_end = len(captions) - 1
        
        yield captions[last_chunk_end:chunk_end], start_idx - last_chunk_end, end_idx - last_chunk_end
        last_chunk_end = chunk_end

MAX_DURATION_PER_TOKEN = 1
        
class SpanExtractionDataset(torch.utils.data.IterableDataset):
    def __init__(self, dataset: 'Iterable'):
        super().__init__()
        self.dataset = dataset
        
    def __iter__(self):
        for video_id, captions, sponsor_times in self.dataset:
            drop_row = False
            segment_ranges = [get_intersection_range(captions, start_time, end_time) for start_time, end_time in sponsor_times]
            # Filter out broken ranges
            segment_ranges = [r for r in segment_ranges if r[0] is not None and r[1] is not None and r[0] != r[1]]

            for caption_chunk, start_idx, end_idx in chunk_video(captions, segment_ranges):
                text = segment_text(caption_chunk)
                sponsor_text = segment_text(caption_chunk[start_idx:end_idx + 1]).strip()
                if len(sponsor_text) == 0:
                    continue
                
                start_char_idx = text.index(sponsor_text)
                end_char_idx = start_char_idx + len(sponsor_text)
                yield text, start_char_idx, end_char_idx
                
# Chunks 1-15 for training but skip 12 because it has a broken encoding (not UTF-8)
raw_train_dataset = SpanExtractionDataset(load_data_from_chunks('data', './', [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]))
raw_test_dataset = SpanExtractionDataset(itertools.islice(load_data_from_chunks('data', './', [16]), 1000))

# Tokenize inputs
Tokenize the datatset with the pre-trained tokenizer. Sequences are padded to the maximum length supported by BERT and truncated if longer.

In [122]:
def index_of_token(offset_mapping, char_index, default_value):
    for i, r in enumerate(offset_mapping):
        if i == 0:
            # Skip the [CLS]
            continue
        if r[0] <= char_index <= r[1]:
            return i
        
    return default_value

class TokenizedSpanExtractionDataset(torch.utils.data.IterableDataset):
    def __init__(self, dataset: 'Iterable', tokenize_function):
        super().__init__()
        self.dataset = dataset
        self.tokenizer = tokenize_function
        
    def __iter__(self):
        for text, start_char_idx, end_char_idx in self.dataset:
            output = self.tokenizer(text)
            
            offset_mapping = output['offset_mapping']
            # 0 is the special ignored_index [CLS]
            start_position = index_of_token(offset_mapping, start_char_idx, default_value=0)
            end_position = index_of_token(offset_mapping, end_char_idx, default_value=0)
            if start_position != 0 and end_position == 0:
                end_position = len(output['input_ids']) - 2
            
            yield {**output, 'start_char_index': start_char_idx, 'end_char_index': end_char_idx, 'start_positions': start_position, 'end_positions': end_position}

            
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
def tokenize_function(text):
    return tokenizer(
        text,
        max_length=512,
        truncation=True,
        return_offsets_mapping=True,
        padding='max_length',
        stride=128,
    )

train_dataset = TokenizedSpanExtractionDataset(raw_train_dataset, tokenize_function)
test_dataset = TokenizedSpanExtractionDataset(raw_test_dataset, tokenize_function)

loading configuration file https://huggingface.co/distilbert-base-uncased/resolve/main/config.json from cache at /root/.cache/huggingface/transformers/23454919702d26495337f3da04d1655c7ee010d5ec9d77bdb9e399e00302c0a1.91b885ab15d631bf9cee9dc9d25ece0afd932f2f5130eba28f2055b2220c0333
Model config DistilBertConfig {
  "_name_or_path": "distilbert-base-uncased",
  "activation": "gelu",
  "architectures": [
    "DistilBertForMaskedLM"
  ],
  "attention_dropout": 0.1,
  "dim": 768,
  "dropout": 0.1,
  "hidden_dim": 3072,
  "initializer_range": 0.02,
  "max_position_embeddings": 512,
  "model_type": "distilbert",
  "n_heads": 12,
  "n_layers": 6,
  "pad_token_id": 0,
  "qa_dropout": 0.1,
  "seq_classif_dropout": 0.2,
  "sinusoidal_pos_embds": false,
  "tie_weights_": true,
  "transformers_version": "4.19.0",
  "vocab_size": 30522
}

loading file https://huggingface.co/distilbert-base-uncased/resolve/main/vocab.txt from cache at /root/.cache/huggingface/transformers/0e1bbfda7f63a99bb52e3915dcf10

# Prepare for training
Set training parameters, configure metrics, etc.

In [177]:
torch.cuda.empty_cache()
model = AutoModelForQuestionAnswering.from_pretrained("distilbert-base-uncased")
training_args = TrainingArguments(
    output_dir="distilbert-span-extraction-uncased", 
    per_device_train_batch_size=48, 
    per_device_eval_batch_size=48,
    save_total_limit=2, 
    max_steps=9_000,
    save_steps=300,
    eval_steps=301,
    save_strategy='steps',
    evaluation_strategy='steps',
    ignore_data_skip=True)

accuracy_metric = load_metric("accuracy")
precision_metric = load_metric("precision")
recall_metric = load_metric("recall")

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset
)

loading configuration file https://huggingface.co/distilbert-base-uncased/resolve/main/config.json from cache at /root/.cache/huggingface/transformers/23454919702d26495337f3da04d1655c7ee010d5ec9d77bdb9e399e00302c0a1.91b885ab15d631bf9cee9dc9d25ece0afd932f2f5130eba28f2055b2220c0333
Model config DistilBertConfig {
  "_name_or_path": "distilbert-base-uncased",
  "activation": "gelu",
  "architectures": [
    "DistilBertForMaskedLM"
  ],
  "attention_dropout": 0.1,
  "dim": 768,
  "dropout": 0.1,
  "hidden_dim": 3072,
  "initializer_range": 0.02,
  "max_position_embeddings": 512,
  "model_type": "distilbert",
  "n_heads": 12,
  "n_layers": 6,
  "pad_token_id": 0,
  "qa_dropout": 0.1,
  "seq_classif_dropout": 0.2,
  "sinusoidal_pos_embds": false,
  "tie_weights_": true,
  "transformers_version": "4.19.0",
  "vocab_size": 30522
}

loading weights file https://huggingface.co/distilbert-base-uncased/resolve/main/pytorch_model.bin from cache at /root/.cache/huggingface/transformers/9c169103d7e5a

# Train the model ⚡
We're using the default number of batches, but we terminate the training early because we observe that the model performs extremely well on all metric on the test dataset and because the training loss and validation loss are comparable after step 30,000, indicating that there is not too much over- or under-fitting, and that the model is not likely to learn anything else.

In [178]:
trainer.train('./distilbert-span-extraction-uncased/checkpoint-2100')

Loading model from ./distilbert-span-extraction-uncased/checkpoint-2100).
***** Running training *****
  Num examples = 432000
  Num Epochs = 9223372036854775807
  Instantaneous batch size per device = 48
  Total train batch size (w. parallel, distributed & accumulation) = 48
  Gradient Accumulation steps = 1
  Total optimization steps = 9000
  Continuing training from checkpoint, will skip to saved global_step
  Continuing training from epoch 0
  Continuing training from global step 2100
[34mOpening ./data.1.json.gz for reading...[0m
The following columns in the training set don't have a corresponding argument in `DistilBertForQuestionAnswering.forward` and have been ignored: start_char_index, end_char_index, offset_mapping. If start_char_index, end_char_index, offset_mapping are not expected by `DistilBertForQuestionAnswering.forward`,  you can safely ignore this message.


Step,Training Loss,Validation Loss
2107,0.6113,No log
2408,0.6113,No log
2709,0.5952,No log
3010,0.3779,No log
3311,0.3779,No log
3612,0.2829,No log
3913,0.2829,No log
4214,0.2095,No log
4515,0.1332,No log
4816,0.1332,No log


***** Running Evaluation *****
  Num examples: Unknown
  Batch size = 48
[34mClosed ./data.1.json.gz.[0m
[34mOpening ./data.2.json.gz for reading...[0m
[34mClosed ./data.2.json.gz.[0m
[34mOpening ./data.3.json.gz for reading...[0m
[34mClosed ./data.3.json.gz.[0m
[34mOpening ./data.4.json.gz for reading...[0m
[34mClosed ./data.4.json.gz.[0m
[34mOpening ./data.5.json.gz for reading...[0m
[34mClosed ./data.5.json.gz.[0m
[34mOpening ./data.6.json.gz for reading...[0m
[34mClosed ./data.6.json.gz.[0m
[34mOpening ./data.7.json.gz for reading...[0m
[34mClosed ./data.7.json.gz.[0m
[34mOpening ./data.8.json.gz for reading...[0m
[34mClosed ./data.8.json.gz.[0m
[34mOpening ./data.9.json.gz for reading...[0m
Saving model checkpoint to distilbert-span-extraction-uncased/checkpoint-2400
Configuration saved in distilbert-span-extraction-uncased/checkpoint-2400/config.json
Model weights saved in distilbert-span-extraction-uncased/checkpoint-2400/pytorch_model.bin
Deletin

[34mClosed ./data.7.json.gz.[0m
[34mOpening ./data.8.json.gz for reading...[0m
[34mClosed ./data.8.json.gz.[0m
[34mOpening ./data.9.json.gz for reading...[0m
[34mClosed ./data.9.json.gz.[0m
[34mOpening ./data.10.json.gz for reading...[0m
[34mClosed ./data.10.json.gz.[0m
[34mOpening ./data.11.json.gz for reading...[0m
[34mClosed ./data.11.json.gz.[0m
[34mOpening ./data.1.json.gz for reading...[0m
[34mClosed ./data.1.json.gz.[0m
[34mOpening ./data.2.json.gz for reading...[0m
[34mClosed ./data.2.json.gz.[0m
[34mOpening ./data.3.json.gz for reading...[0m
[34mClosed ./data.3.json.gz.[0m
[34mOpening ./data.4.json.gz for reading...[0m
Saving model checkpoint to distilbert-span-extraction-uncased/checkpoint-4500
Configuration saved in distilbert-span-extraction-uncased/checkpoint-4500/config.json
Model weights saved in distilbert-span-extraction-uncased/checkpoint-4500/pytorch_model.bin
Deleting older checkpoint [distilbert-span-extraction-uncased/checkpoint-390

[34mClosed ./data.3.json.gz.[0m
[34mOpening ./data.4.json.gz for reading...[0m
[34mClosed ./data.4.json.gz.[0m
[34mOpening ./data.5.json.gz for reading...[0m
[34mClosed ./data.5.json.gz.[0m
[34mOpening ./data.6.json.gz for reading...[0m
[34mClosed ./data.6.json.gz.[0m
[34mOpening ./data.7.json.gz for reading...[0m
[34mClosed ./data.7.json.gz.[0m
[34mOpening ./data.8.json.gz for reading...[0m
[34mClosed ./data.8.json.gz.[0m
[34mOpening ./data.9.json.gz for reading...[0m
Saving model checkpoint to distilbert-span-extraction-uncased/checkpoint-6600
Configuration saved in distilbert-span-extraction-uncased/checkpoint-6600/config.json
Model weights saved in distilbert-span-extraction-uncased/checkpoint-6600/pytorch_model.bin
Deleting older checkpoint [distilbert-span-extraction-uncased/checkpoint-6000] due to args.save_total_limit
[34mClosed ./data.9.json.gz.[0m
[34mOpening ./data.10.json.gz for reading...[0m
***** Running Evaluation *****
  Num examples: Unknown

[34mClosed ./data.2.json.gz.[0m
[34mOpening ./data.3.json.gz for reading...[0m
[34mClosed ./data.3.json.gz.[0m
[34mOpening ./data.4.json.gz for reading...[0m
Saving model checkpoint to distilbert-span-extraction-uncased/checkpoint-8700
Configuration saved in distilbert-span-extraction-uncased/checkpoint-8700/config.json
Model weights saved in distilbert-span-extraction-uncased/checkpoint-8700/pytorch_model.bin
Deleting older checkpoint [distilbert-span-extraction-uncased/checkpoint-8100] due to args.save_total_limit
[34mClosed ./data.4.json.gz.[0m
[34mOpening ./data.5.json.gz for reading...[0m
***** Running Evaluation *****
  Num examples: Unknown
  Batch size = 48
[34mClosed ./data.5.json.gz.[0m
[34mOpening ./data.6.json.gz for reading...[0m
[34mClosed ./data.6.json.gz.[0m
[34mOpening ./data.7.json.gz for reading...[0m
[34mClosed ./data.7.json.gz.[0m
[34mOpening ./data.8.json.gz for reading...[0m
[34mClosed ./data.8.json.gz.[0m
[34mOpening ./data.9.json.gz f

TrainOutput(global_step=9000, training_loss=0.10736266560024685, metrics={'train_runtime': 9452.9679, 'train_samples_per_second': 45.7, 'train_steps_per_second': 0.952, 'total_flos': 5.641626990001766e+16, 'train_loss': 0.10736266560024685, 'epoch': 18.0})

In [179]:
!tar -zcvf distilbert-span-extraction-uncased-checkpoint-9000.tar.gz distilbert-span-extraction-uncased/checkpoint-9000
!curl --upload-file ./distilbert-span-extraction-uncased-checkpoint-9000.tar.gz https://bashupload.com/ | cat

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
distilbert-span-extraction-uncased/checkpoint-9000/
distilbert-span-extraction-uncased/checkpoint-9000/config.json
distilbert-span-extraction-uncased/checkpoint-9000/pytorch_model.bin
distilbert-span-extraction-uncased/checkpoint-9000/training_args.bin
distilbert-span-extraction-uncased/checkpoint-9000/optimizer.pt
distilbert-span-extraction-uncased/checkpoint-9000/scheduler.pt
distilbert-span-extraction-uncased/checkpoint-9000/trainer_state.json
distilbert-span-extraction-uncased/checkpoint-9000/rng_state.pth
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible

```
Epoch	Training Loss	Validation Loss	Accuracy	Precision	Recall
    1	0.140800	    0.131712	    0.951745	0.962096	0.940169
    2	0.095600	    0.137763	    0.955120	0.957821	0.951820
    3	0.050900	    0.155389	    0.956762	0.966651	0.945832
```
We chose to use the model trained after 2 epochs because 3 seems to overfit the training set.

In [6]:
model = None
trainer = None
trained = None
torch.cuda.empty_cache()

def softmax_outputs(outputs) -> dict:
    return torch.nn.functional.softmax(outputs.logits, dim=-1)[0].tolist()

trained = AutoModelForQuestionAnswering.from_pretrained('./distilbert-span-extraction-uncased/checkpoint-9000')
trained.to('cuda')

DistilBertForQuestionAnswering(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0): TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
            

# Run on full video transcripts

In [14]:
metric = load_metric('squad')

def evaluate(text, start_char_idx, end_char_idx):
#     print([text, text[start_char:end_char + 1]])

    inputs = tokenize_function(text)

    start_position = index_of_token(inputs['offset_mapping'], start_char_idx, default_value=0)
    end_position = index_of_token(inputs['offset_mapping'], end_char_idx, default_value=0)
    if start_position != 0 and end_position == 0:
        end_position = len(inputs['input_ids']) - 2

    outputs = trained(input_ids=torch.tensor([inputs['input_ids']]).cuda())

    pred_start_position = torch.argmax(outputs.start_logits).cpu()
    pred_end_position = torch.argmax(outputs.end_logits).cpu()
    
    pred_char_start_idx = index_of_token(inputs['offset_mapping'], pred_start_position, default_value=0)
    pred_char_end_idx = index_of_token(inputs['offset_mapping'], pred_end_position, default_value=0)
    
    predicted_answers = [{
        'id': 0,
        'prediction_text': text[pred_char_start_idx:pred_char_end_idx+1],
    }]
    
    theoretical_answers = [{
        'id': 0,
        'answers': [{
            'answer_start': start_char_idx,
            'text':text[start_char_idx:end_char_idx+1]
        }],
    }]

    print('Actual', (start_position, end_position))
    print('Predicted', (pred_start_position, pred_end_position))
    print('---------')
    return metric.compute(predictions=predicted_answers, references=theoretical_answers)
    
for inputs in itertools.islice(SpanExtractionDataset(load_data_from_chunks('data', './', [16])), 50):
    print(evaluate(*inputs))

[34mOpening ./data.16.json.gz for reading...[0m


Actual (1, 39)
Predicted (tensor(1), tensor(55))
---------
{'exact_match': 0.0, 'f1': 5.88235294117647}
Actual (388, 510)
Predicted (tensor(399), tensor(510))
---------
{'exact_match': 0.0, 'f1': 2.0725388601036268}
Actual (11, 24)
Predicted (tensor(9), tensor(27))
---------
{'exact_match': 0.0, 'f1': 0.0}
Actual (0, 0)
Predicted (tensor(0), tensor(0))
---------
{'exact_match': 0.0, 'f1': 0.0}
Actual (30, 77)
Predicted (tensor(1), tensor(77))
---------
{'exact_match': 0.0, 'f1': 5.714285714285714}
Actual (0, 0)
Predicted (tensor(0), tensor(0))
---------
{'exact_match': 0.0, 'f1': 0.0}
Actual (380, 510)
Predicted (tensor(424), tensor(510))
---------
{'exact_match': 0.0, 'f1': 0.6269592476489028}
Actual (466, 482)
Predicted (tensor(0), tensor(0))
---------
{'exact_match': 0.0, 'f1': 0.0}
Actual (36, 98)
Predicted (tensor(30), tensor(95))
---------
{'exact_match': 0.0, 'f1': 0.0}
Actual (0, 0)
Predicted (tensor(0), tensor(0))
---------
{'exact_match': 0.0, 'f1': 0.0}
Actual (1, 11)
Predic

In [183]:
import itertools
from collections import defaultdict

from data_loader import Caption, load_captions_from_chunks, segment_text, get_intersection_range

def caption_times(c):
    return c.start, c.end

def prediction_times(p):
    return tuple(p[0])

def tumbling_time_window(captions, duration, key=caption_times):
    results = [captions[0]]
    for caption in captions:
        if key(results[-1])[1] - key(results[0])[0] <= duration:
            results.append(caption)
        else:
            yield results
            results = [caption]

    yield results
    
def session_time_window(captions, duration, key=caption_times):
    captions_iter = iter(captions)
    results = [next(captions_iter)]
    for caption in captions_iter:
        if key(results[-1])[1] - key(caption)[0] <= duration:
            results.append(caption)
        else:
            yield results
            results = [caption]

    yield results

def batch(iterable, n):
    length = len(iterable)
    for i in range(0, length, n):
        yield iterable[i:min(i + n, length)]
        
def decode_label(outputs):
    content, sponsor = outputs
    
    prediction_dict = {'sponsor': sponsor, 'content': content}
    prediction_dict = {k: v for k, v in sorted(prediction_dict.items(), key=lambda item: item[1], reverse=True)}

    return next(iter(prediction_dict.items()))
        
def predict_in_batches(texts, batch_size: int = 8):    
    batches = list(batch(texts, batch_size))
    for b in batches:
        inputs = defaultdict(list)
        for text in b:
            tokenized = tokenize_function({ 'text': text })
            for k, v in tokenized.items():
                inputs[k].append(v)
            
        inputs = { k: torch.tensor(v).cuda() for k, v in inputs.items() }
        outputs = trained(**inputs)
        predictions = torch.nn.functional.softmax(outputs.logits, dim=-1).tolist()
        yield from predictions
        
def predict_sponsor_segments(captions, window_duration=10):
    windows = list(tumbling_time_window(captions, window_duration))
    window_texts = [segment_text(window) for window in windows]
    predictions = predict_in_batches(window_texts, 4)
    
    for window, text, prediction in zip(windows, window_texts, predictions):
        yield [window[0].start, window[-1].end], text, *decode_label(prediction)
        
def merge_prediction_(predictions):
    assert len(set((label for _, _, label, _ in predictions))) == 1
    # All co-occurring predictions have the same label so we merge them
    merged_start, merged_end = predictions[0][0][0], predictions[-1][0][1]
    merged_text = ' '.join((text for _, text, _, _ in predictions))
    # Don't know what the correct way to compute the joint probability here is,
    # just assume they are independent; We don't really use this number anywhere
    prob = np.prod([prob for _, _, _, prob in predictions])
    return [merged_start, merged_end], merged_text, predictions[0][2], prob

def merge_predictions(predictions, within_duration=5):
    for co_occuring in session_time_window(predictions, within_duration, key=prediction_times):
        merged = [co_occuring[0]]
        for times, text, label, prob in co_occuring[1:]:
            _, _, prev_label, _ = merged[0]
            if label == prev_label:
                merged.append((times, text, label, prob))
            else:
                yield merge_prediction_(merged)
                merged = [(times, text, label, prob)]
        
        if len(merged) > 0:
            yield merge_prediction_(merged)
        

In [184]:
import itertools

def range_equals(left: 'Tuple[float, float]', right: 'Tuple[float, float]', eps: float) -> bool:
    left_start, left_end = left
    right_start, right_end = right
    
    return (abs(left_start - right_start) <= eps
        and abs(left_end - right_end) <= eps)

def count_range_equals(pairs, eps: float) -> int:
    cnt = 0
    for left, right in pairs:
        if range_equals(left, right, eps):
            cnt += 1
    return cnt

assert range_equals([0, 5], [0, 5], eps=0)
assert range_equals([1, 6], [0, 5], eps=1)
assert range_equals([-1, 4], [0, 5], eps=1)
assert not range_equals([-2, 4], [0, 5], eps=1)
assert not range_equals([1, 7], [0, 5], eps=1)

def range_negation(base: 'Tuple[float, float]', ranges: 'List[Tuple[float, float]]') -> 'List[Tuple[float, float]]':
    """
    base:    |-------------|
    ranges:  | ***   **    |
    Return:  |#   ###  ####|
    """
    results = []
    last_end = base[0]
    for r in ranges:
        if last_end != r[0]:
            results.append((last_end, r[0]))
        last_end = r[1]
    if last_end != base[1]:
        results.append((last_end, base[1]))
        
    return results
    
assert range_negation((2, 10), [(3,4), (5, 6)]) == [(2, 3), (4, 5), (6, 10)]
assert range_negation((2, 6), [(3,4), (5, 6)]) == [(2, 3), (4, 5)]
assert range_negation((3, 6), [(3,4), (5, 6)]) == [(4, 5)]

In [24]:
from termcolor import colored

def create_labels_from_range(captions, sponsor_ranges):
    caption_labels = np.zeros(len(captions), dtype=bool)
    for start_idx, end_idx in sponsor_ranges:
        for i in range(start_idx, end_idx + 1):
            caption_labels[i] = True

    token_labels = []
    for i, caption in enumerate(captions):
        num_tokens = len(caption.text.split())
        token_labels.extend([caption_labels[i]] * num_tokens)
    return token_labels

def create_labels_from_times(captions, sponsor_times):
    ranges = [get_intersection_range(captions, *pair[1]) for pair in sponsor_times]
    return create_labels_from_range(captions, ranges)

def evaluate(videos, eps=5):
    from tqdm.auto import tqdm
    
    predicted_labels = np.empty(0)
    actual_labels = np.empty(0)
    # Values for our close match metric (exact match with threshold)
    # Number of maches
    close_matches = 0
    # Number of predicted ranges
    total_predicted_ranges = 0
    
    for video_id, captions, sponsor_ranges in tqdm(videos):
        print(colored(f'{video_id} {sponsor_ranges}', None, 'on_magenta'))
        sponsor_times = [(captions[start].start, captions[end].end) for start, end in sponsor_ranges]
        predicted_sponsor_times = []

        for times, text, label, prob in merge_predictions(predict_sponsor_segments(captions, window_duration=10), within_duration=10):
            if label == 'sponsor':
                predicted_sponsor_times.append((f'{int(prob * 100)}%', times))

            color = { 'sponsor': 'yellow', 'content': None }[label]
            # print(colored(f'{int(prob * 100)}% {times[0]} <--> {times[1]} {text}', color=color))
            
            if any((range_equals(times, actual_times, eps) for actual_times in sponsor_times)):
                close_matches += 1
            total_predicted_ranges += 1

        predicted_sponsor_ranges = [get_intersection_range(captions, *pair[1]) for pair in predicted_sponsor_times]
        predicted_labels = np.append(predicted_labels, create_labels_from_range(captions, predicted_sponsor_ranges))
        actual_labels = np.append(actual_labels, create_labels_from_range(captions, sponsor_ranges))
        
        print(f'\tPredicted={predicted_sponsor_ranges},\n\tExpected={sponsor_ranges}')
    
    from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score
    
    close_match_score = close_matches / total_predicted_ranges
    print(f'Exact match (with {eps}s threshold)', close_match_score)
    print('Confusion matrix', confusion_matrix(actual_labels, predicted_labels))
    print('Accuracy', accuracy_score(actual_labels, predicted_labels))
    print('Precision', precision_score(actual_labels, predicted_labels))
    print('Recall', recall_score(actual_labels, predicted_labels))
        
evaluate(list(itertools.islice(load_captions_from_chunks('data', './', [1]), 0, 10)))

[34mFound ./data.1.json.gz.[0m
[34mOpening ./data.1.json.gz for reading...[0m


Dropping --6T95cQa50 because sponsor times do not match the captions
Dropping --BXjAWlPDQ because sponsor times do not match the captions


  0%|          | 0/10 [00:00<?, ?it/s]

[45m---jcia5ufM [[28, 45]][0m
	Predicted=[(40, 45), (204, 208)],
	Expected=[[28, 45]]
[45m--4bRr1Pwlg [[28, 56]][0m
	Predicted=[(33, 42), (49, 52), (323, 326)],
	Expected=[[28, 56]]
[45m--4EqGOaEgU [[40, 54]][0m
	Predicted=[],
	Expected=[[40, 54]]
[45m--540zBQ6GI [[0, 5]][0m
	Predicted=[(0, 4)],
	Expected=[[0, 5]]
[45m--6CCgW32LE [[0, 29]][0m
	Predicted=[(0, 5), (19, 32), (81, 86)],
	Expected=[[0, 29]]
[45m--B_ZkOUCDc [[0, 2]][0m
	Predicted=[(0, 2), (116, 126)],
	Expected=[[0, 2]]
[45m--CWTjd8rkY [[746, 826]][0m
	Predicted=[(543, 548), (747, 764), (771, 776), (784, 819)],
	Expected=[[746, 826]]
[45m--cYDnBVfvE [[0, 0]][0m
	Predicted=[],
	Expected=[[0, 0]]
[45m--JOtw1cCso [[25, 46]][0m
	Predicted=[(24, 47), (445, 452)],
	Expected=[[25, 46]]
[45m--JVap-3nJU [[0, 4]][0m
	Predicted=[(0, 2), (127, 129), (153, 153), (614, 624)],
	Expected=[[0, 4]]
Exact match (with 5s threshold) 0.08333333333333333
Confusion matrix [[27189   360]
 [  510   821]]
Accuracy 0.969875346260387