In [1]:
#!pip install -U yt-dlp==2023.1.6 matplotlib==3.6.0 datasets[audio] rich

# install newest transformers build to be able to pass `inputs_embeds` through generate()
!pip install --upgrade git+https://github.com/huggingface/transformers.git

Collecting git+https://github.com/huggingface/transformers.git
  Cloning https://github.com/huggingface/transformers.git to /tmp/pip-req-build-1v_t3m8e
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/transformers.git /tmp/pip-req-build-1v_t3m8e
  Resolved https://github.com/huggingface/transformers.git to commit fb366b9a2a94b38171896f6ba9fb9ae8bffd77af
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
[0m

In [2]:
!pip install torch==1.13

[0m

In [3]:
!pip install evaluate

[0m

In [4]:
from musiccaps import load_musiccaps

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, random_split, Subset

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer

from rich import print as printr
from tqdm.auto import tqdm, trange
import matplotlib.pyplot as plt

import itertools
import math
import json
import random
from collections import defaultdict
from pathlib import Path
import evaluate
import itertools

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
f = open('logs.txt','a')

# Load musiccaps

In [5]:
def filter_muscaps_with_embeddings(ds, embeddings):
    """Some clips weren't downloaded so we couldn't embed them, get rid of that"""
    exclude_ids = set()
    for i in range(len(ds)):
        if ds[i]["ytid"] not in embeddings.keys():
            exclude_ids.add(i)
    ds = ds.select((i for i in range(len(ds)) if i not in set(exclude_ids)))
    assert len(ds) == len(embeddings)
    return ds


In [6]:
ds = load_musiccaps(
    "./music_data",
    sampling_rate=16000,
    limit=None,
    num_proc=8,
    writer_batch_size=1000,
    return_without_audio=True,
)

embeddings = np.load("embeddings.npy", allow_pickle=True).item()

with open('summarized_captions.json') as f:
    summarized_captions = json.load(f)
    
ds = ds.add_column("summarized_captions", list(summarized_captions.values()))

Using custom data configuration google--MusicCaps-bedc2a0fd7888f2f
Reusing dataset csv (/root/.cache/huggingface/datasets/google___csv/google--MusicCaps-bedc2a0fd7888f2f/0.0.0/652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a)


In [7]:
# get a list of music-related words to use for evaluation
aspects = []
for x in ds:
    aspect_str = x["aspect_list"]
    for t in "[]\"'":
        aspect_str = aspect_str.replace(t, "")
    aspects.extend(aspect_str.split(", "))

from collections import Counter

# only pick aspects that show up somewhat frequently
aspects = {s for s, count in Counter(aspects).most_common() if count >= 25}
len(aspects)

378

In [8]:
class SummarizedCaptionEmbedding(Dataset):
    """Returns a torch Dataset of paired captions and embeddings"""

    def __init__(self, muscaps_ds, embeddings, output_all_refs=False):
        ds = filter_muscaps_with_embeddings(muscaps_ds, embeddings)
        self.caption_list = ds.sort("ytid")["summarized_captions"]
        sorted_embs = [value for _, value in sorted(embeddings.items())]
        self.embeddings = torch.from_numpy(np.stack(sorted_embs)).to(device)
        self.output_all_refs = output_all_refs

    def __len__(self):
        return len(self.caption_list)

    def __getitem__(self, idx):
        emb = self.embeddings[idx]
        assert len(emb) == 512
        emb = (emb[:256] + emb[256:]) / 2
        
        if self.output_all_refs:
            return list(self.caption_list[idx]), emb
        
        caption = np.random.choice(self.caption_list[idx])
        return caption, emb

In [9]:
with open('musiccaps_split.json', 'r') as fp:
    musiccaps_split = json.load(fp)

train_ytids, valid_ytids, test_ytids = musiccaps_split['train'], musiccaps_split['valid'], musiccaps_split['test']

train_ds = ds.filter(lambda x: x['ytid'] in train_ytids)
valid_ds = ds.filter(lambda x: x['ytid'] in valid_ytids)
test_ds = ds.filter(lambda x: x['ytid'] in test_ytids)

train_embeddings = {ytid: e for ytid, e in embeddings.items() if ytid in train_ytids}
valid_embeddings = {ytid: e for ytid, e in embeddings.items() if ytid in valid_ytids}
test_embeddings = {ytid: e for ytid, e in embeddings.items() if ytid in test_ytids}

training_data = SummarizedCaptionEmbedding(muscaps_ds=train_ds, embeddings=train_embeddings)
valid_data = SummarizedCaptionEmbedding(muscaps_ds=valid_ds, embeddings=valid_embeddings)
test_data = SummarizedCaptionEmbedding(muscaps_ds=test_ds, embeddings=test_embeddings) 
    

Loading cached processed dataset at /root/.cache/huggingface/datasets/google___csv/google--MusicCaps-bedc2a0fd7888f2f/0.0.0/652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a/cache-9a9e3bf18ce5ccf8.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/google___csv/google--MusicCaps-bedc2a0fd7888f2f/0.0.0/652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a/cache-762f4c6e9e04b27d.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/google___csv/google--MusicCaps-bedc2a0fd7888f2f/0.0.0/652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a/cache-533892882886e54e.arrow


N = len(ds['ytid'])
shuffled_indices = torch.randperm(N)
train_N, valid_N = round(N*0.8), round(N*0.1)
train_indices, valid_indices, test_indices = shuffled_indices[:train_N], shuffled_indices[train_N:train_N+valid_N], shuffled_indices[train_N+valid_N:]
train_ids, valid_ids, test_ids = ds[train_indices]['ytid'], ds[valid_indices]['ytid'], ds[test_indices]['ytid']
split_json = {'train':train_ids, 'valid':valid_ids, 'test':test_ids}
with open('musiccaps_split.json', 'w') as fp:
    json.dump(split_json, fp)

# Training

### Tokenization

target should be:

`"<bos> <pad> caption <eos> <mask...>"` (first element is dropped in transformer.forward)

input should be:

`"<bos> <music-emb> caption <eos> <pad...>"` (last element is dropped in transformer.forward)

where

- `<bos>` = `<eos>` (for gpt2, see https://github.com/huggingface/transformers/issues/2026)
- `<mask>` is -100 (masked in cross-entropy, see https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html)
- `<pad>` is arbitrary (but needs to be valid embedding index)
- `<music-emb>` is the encoded music

to use a `<bos>` token, prepend it in `tokenize()`, set `music_emb_ind = 1`, and update the caption slicing in `eval`

In [10]:
class ResidualLinear(nn.Module):
    
    def __init__(self, dim):
        super().__init__()
        self.fc = nn.Linear(dim, dim)
        
    def forward(self, x):
        return x + self.fc(torch.nn.functional.relu(x))
    
def B2T():
    return nn.Sequential(
        nn.Linear(256, 768),
        ResidualLinear(768),
        nn.Dropout(0.6),
        ResidualLinear(768),
        nn.Dropout(0.4),
        ResidualLinear(768),
    )

In [11]:
def tokenize(captions_batch):
    input_ids = tokenizer(captions_batch)["input_ids"]

    # wrap in eos and add placeholder for music embedding/mask
    input_ids = [torch.tensor([placeholder_id] + x + [eos_id]) for x in input_ids]
    # pad with -100, this index is masked in the cross-entropy loss
    input_ids_target = torch.nn.utils.rnn.pad_sequence(
        input_ids, batch_first=True, padding_value=mask_id
    ).to(device)

    # index -100 isn't valid as model input however, since the token embedding lookup fails,
    # so we need a second version as model input, with -100 replaced with another token (shouldn't matter which)
    input_ids = input_ids_target.clone()
    input_ids[input_ids == mask_id] = eos_id

    return input_ids, input_ids_target


def transform_input_ids(music_embedding, input_ids, input_ids_target):
    music_emb_ind = 0  # 1 if using <bos>, otherwise 0
    assert (input_ids[:, music_emb_ind] == placeholder_id).all()
    assert (input_ids_target[:, music_emb_ind] == placeholder_id).all()

    input_ids_target[:, music_emb_ind] = eos_id  # mask_id
    input_ids[:, music_emb_ind] = eos_id  # temp placeholder to make the embedding lookup work
    inputs_embeds = model.transformer.wte(input_ids)

    inputs_embeds[:, music_emb_ind] = b2t(music_embedding)  # insert music embedding

    return inputs_embeds, input_ids_target


def strip_eos(pred):
    """ 
    remove eos tokens from predicted captions 
    discards everything after the first <eos> that isn't the very first token
    (the hf can only skip eos but not stop at eos) 
    """
    pred = [p.removeprefix("<|endoftext|>") for p in pred]
    pred = [p[: p.find("<|endoftext|>")] if "<|endoftext|>" in p else p for p in pred]
    return pred


def eval(caption_batch, embedding_batch, rm_eos=False, **kwargs):
    model.eval()
    input_ids, input_ids_target = tokenize(caption_batch)
    inputs_embeds, _ = transform_input_ids(embedding_batch, input_ids, input_ids_target)

    # only include <bos> (optional) and music_embedding, don't include true caption
    inputs_embeds = inputs_embeds[:, :1]
    output_ids = model.generate(inputs_embeds=inputs_embeds, **kwargs)
    pred = tokenizer.batch_decode(output_ids, skip_special_tokens=False)
    pred = [p.replace("\n", "").strip() for p in pred]
    return strip_eos(pred) if rm_eos else pred

In [12]:
def update_step(inputs_embeds, input_ids_target, apply_grad):
    model.train()
    loss = model.forward(inputs_embeds=inputs_embeds, labels=input_ids_target).loss
    loss.backward()
    
    if apply_grad:
        opt.step()
        opt.zero_grad()
    
    return loss.item()

def eval_step(string_info=""):
    
    caption_batch, embedding_batch = next(iter(eval_train_dataloader))
    pred = eval(list(caption_batch), embedding_batch, **generation_params)
    printr('[green bold]TRAIN TRUE: ' + caption_batch[0])
    printr('[yellow]TRAIN PRED: ' + pred[0])
    wlog('TRAIN TRUE: ' + caption_batch[0])
    wlog('TRAIN PRED: ' + pred[0])
    caption_batch, embedding_batch = next(iter(eval_valid_dataloader))
    pred = eval(list(caption_batch), embedding_batch, **generation_params)
    printr('[green bold]TEST TRUE: ' + caption_batch[0])
    printr('[yellow]TEST PRED: ' + pred[0])
    wlog('TEST TRUE: ' + caption_batch[0])
    wlog('TEST PRED: ' + pred[0])
    print()
    
def metrics_step(n=100, shuffles=10):
    
    lq = 0
    
    print(f"Computing metrics for n={n}")
    wlog(f"Computing metrics for n={n}")
    
    shuffled_meteor_valid, shuffled_bleu_valid = 0., 0.
    train_captions, train_preds = [], []
    valid_captions, valid_preds = [], []

    for i in tqdm(range(n)):

        caption_batch, embedding_batch = next(iter(eval_train_dataloader))
        pred = eval(caption_batch, embedding_batch, **generation_params)
        train_captions.append(caption_batch[0])
        train_preds.append(pred[0])

        caption_batch, embedding_batch = next(iter(eval_valid_dataloader))
        pred = eval(caption_batch, embedding_batch, **generation_params)
        valid_captions.append(caption_batch[0])
        valid_preds.append(pred[0])
        if "The low quality recording" in pred[0]:
            lq += 1

    meteor_train = meteor.compute(predictions=train_preds, references=train_captions)['meteor']
    bleu_train = google_bleu.compute(predictions=train_preds, references=train_captions)['google_bleu']

    meteor_valid = meteor.compute(predictions=valid_preds, references=valid_captions)['meteor']
    bleu_valid = google_bleu.compute(predictions=valid_preds, references=valid_captions)['google_bleu']
    
    for j in range(shuffles):
        shuffled_captions = sorted(valid_captions, key=lambda k: random.random())
        shuffled_meteor_valid += (1./(shuffles))*meteor.compute(predictions=valid_preds, references=shuffled_captions)['meteor']
        shuffled_bleu_valid += (1./(shuffles))*google_bleu.compute(predictions=valid_preds, references=shuffled_captions)['google_bleu']

    spec_meteor = meteor_valid - shuffled_meteor_valid
    spec_bleu = bleu_valid - shuffled_bleu_valid
        
    print(f"Train meteor: {meteor_train:.4f}, Train bleu: {bleu_train:.4f}")
    wlog(f"Train meteor: {meteor_train:.4f}, Train bleu: {bleu_train:.4f}")
    print(f"Valid meteor: {meteor_valid:.4f}, Valid bleu: {bleu_valid:.4f}")
    wlog(f"Valid meteor: {meteor_valid:.4f}, Valid bleu: {bleu_valid:.4f}")
    print(f"Valid spec-meteor: {spec_meteor:.4f}, Valid spec-bleu: {spec_bleu:.4f}")
    wlog(f"Valid spec-meteor: {spec_meteor:.4f}, Valid spec-bleu: {spec_bleu:.4f}")
    print(f"Low quality recording count: {lq}")
    wlog(f"Low quality recording count: {lq}")
    
    return {'meteor_train': meteor_train,
            'meteor_valid': meteor_valid,
            'bleu_train': bleu_train,
            'bleu_valid': bleu_valid,
            'spec_meteor': spec_meteor,
            'spec_bleu': spec_bleu} 

m_styles = {"meteor_train": ("tab:orange", "solid"),
    "meteor_valid": ("tab:orange", "dashed"),
    "bleu_train": ("tab:blue", "solid"),
    "bleu_valid": ("tab:blue", "dashed"),
    "spec_bleu": ("tab:blue", "dashdot"),
    "spec_meteor": ("tab:orange", "dashdot")
}
    

In [13]:
def metrics_step(n=100, shuffles=10):
    
    lq = 0
    
    print(f"Computing metrics for n={n}")
    wlog(f"Computing metrics for n={n}")
    
    shuffled_meteor_valid, shuffled_bleu_valid = 0., 0.
    train_captions, train_preds = [], []
    valid_captions, valid_preds = [], []

    for i in tqdm(range(n)):

        caption_batch, embedding_batch = next(iter(eval_train_dataloader))
        pred = eval(list(caption_batch), embedding_batch, **generation_params)
        train_captions.append(caption_batch[0])
        train_preds.append(pred[0])

        caption_batch, embedding_batch = next(iter(eval_valid_dataloader))
        pred = eval(list(caption_batch), embedding_batch, **generation_params)
        valid_captions.append(caption_batch[0])
        valid_preds.append(pred[0])
        if "The low quality recording" in pred[0]:
            lq += 1

    meteor_train = meteor.compute(predictions=train_preds, references=train_captions)['meteor']
    bleu_train = google_bleu.compute(predictions=train_preds, references=train_captions)['google_bleu']

    meteor_valid = meteor.compute(predictions=valid_preds, references=valid_captions)['meteor']
    bleu_valid = google_bleu.compute(predictions=valid_preds, references=valid_captions)['google_bleu']
    
    for j in range(shuffles):
        shuffled_captions = sorted(valid_captions, key=lambda k: random.random())
        shuffled_meteor_valid += (1./(shuffles))*meteor.compute(predictions=valid_preds, references=shuffled_captions)['meteor']
        shuffled_bleu_valid += (1./(shuffles))*google_bleu.compute(predictions=valid_preds, references=shuffled_captions)['google_bleu']

    spec_meteor = meteor_valid - shuffled_meteor_valid
    spec_bleu = bleu_valid - shuffled_bleu_valid
        
    print(f"Train meteor: {meteor_train:.4f}, Train bleu: {bleu_train:.4f}")
    wlog(f"Train meteor: {meteor_train:.4f}, Train bleu: {bleu_train:.4f}")
    print(f"Valid meteor: {meteor_valid:.4f}, Valid bleu: {bleu_valid:.4f}")
    wlog(f"Valid meteor: {meteor_valid:.4f}, Valid bleu: {bleu_valid:.4f}")
    print(f"Valid spec-meteor: {spec_meteor:.4f}, Valid spec-bleu: {spec_bleu:.4f}")
    wlog(f"Valid spec-meteor: {spec_meteor:.4f}, Valid spec-bleu: {spec_bleu:.4f}")
    print(f"Low quality recording count: {lq}")
    wlog(f"Low quality recording count: {lq}")
    
    return {'meteor_train': meteor_train,
            'meteor_valid': meteor_valid,
            'bleu_train': bleu_train,
            'bleu_valid': bleu_valid,
            'spec_meteor': spec_meteor,
            'spec_bleu': spec_bleu} 

m_styles = {"meteor_train": ("tab:orange", "solid"),
    "meteor_valid": ("tab:orange", "dashed"),
    "bleu_train": ("tab:blue", "solid"),
    "bleu_valid": ("tab:blue", "dashed"),
    "spec_bleu": ("tab:blue", "dashdot"),
    "spec_meteor": ("tab:orange", "dashdot")
}
    

In [14]:
model_name = 'gpt2' # gpt2, gpt2-medium, gpt2-large, gpt2-xl
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
model.config.is_decoder = True # not sure if necessary

mask_id = -100 # don't change, this is fixed in torch cross-entropy loss!
eos_id = tokenizer.eos_token_id
placeholder_id = -200
b2t = B2T().cuda()
b2t_lr = 1e-4
gpt2_finetune_lr = 1e-4

opt = torch.optim.AdamW([
    # b2t needs to be the first parameter group
    {'params': b2t.parameters(), 'lr': b2t_lr},
    
    # disable AdamW weight decay for gpt2 layer finetuning
    {'params': model.transformer.h[1].parameters(), 'lr': gpt2_finetune_lr, 'weight_decay': 0},
    {'params': model.transformer.h[2].parameters(), 'lr': gpt2_finetune_lr, 'weight_decay': 0},
])

batch_size = 32
num_epochs = 1000
epoch = 391
gradient_acc_fact = 2
gpt2_finetune_start_epoch = 0

generation_params = dict(
    max_new_tokens=64,
    num_beams=4,
    do_sample=True,
    temperature=1.0,
    bos_token_id=eos_id,
    eos_token_id=eos_id,
    pad_token_id=mask_id,
    early_stopping=True,
    rm_eos=True,
)

losses = []

train_dataloader = DataLoader(training_data, batch_size, shuffle=True)
eval_train_dataloader = DataLoader(training_data, 1, shuffle=True)
eval_valid_dataloader = DataLoader(valid_data, 1, shuffle=True)

meteor = evaluate.load('meteor')
google_bleu = evaluate.load("google_bleu")
metrics = {'step': [], 'bleu_train': [], 'bleu_valid': [], 'meteor_train': [], 'meteor_valid': [],
          'spec_bleu': [], 'spec_meteor':[]}

checkpoint_dir = Path('checkpoints')
checkpoint_dir.mkdir(exist_ok=True)

load_pretrain = True
keep_training = True

pretrain_model = "checkpoints/chkp_summarized_0.0001_0.0001.pt"
if load_pretrain:
    data_dict = torch.load(pretrain_model)
    model.load_state_dict(data_dict['model'])
    b2t.load_state_dict(data_dict['b2t'])
    
model_name_info = "small" if model_name == 'gpt2' else model_name
pretraining_info = "yes" if load_pretrain and not keep_training else "no"
string_info = f"summarized_{gpt2_finetune_lr}_{b2t_lr}"

def wlog(s):
    f = open(f"outputs/logs_{string_info}.txt",'a')
    f.write(s+"\n")
    f.close()
    
if keep_training:
    with open(f'outputs/train_metrics_{string_info}.npy', 'rb') as f:
        train_metrics = np.load(f, allow_pickle=True).item()
    
    losses = list(train_metrics['loss'])
    metrics = {k: list(v) for k,v in train_metrics.items()}

string_info

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


'summarized_0.0001_0.0001'

In [None]:
for epoch in tqdm(range(epoch, num_epochs)):
    wlog(f"\nEpoch {epoch}")

    torch.save({
        "model": model.state_dict(), 
        "b2t": b2t.state_dict(), 
        "opt": opt
    }, checkpoint_dir / f"chkp_{string_info}.pt")
    #}, checkpoint_dir / f"chkp_{epoch}.pt")
    
    if epoch>0 and epoch % 100 == 0:
        print("Checkpoint saved")
        wlog("Checkpoint saved")
        torch.save({
            "model": model.state_dict(), 
            "b2t": b2t.state_dict(), 
            "opt": opt
        }, checkpoint_dir / f"chkp_{string_info}_e{epoch}.pt")

    if epoch == gpt2_finetune_start_epoch:
        print("Started finetuning b2t")
        wlog("Started finetuning b2t")
        torch.save({
        "model": model.state_dict(), 
        "b2t": b2t.state_dict(), 
        "opt": opt
        }, checkpoint_dir / f"chkp_{string_info}.pt")
        opt.param_groups[0]["lr"] = b2t_lr
        for pg in opt.param_groups[1:]:
            pg["lr"] = gpt2_finetune_lr

    for step, (caption_batch, embedding_batch) in enumerate(tqdm(train_dataloader)):
        caption_batch = list(caption_batch)
        # tokenize and prepare inputs for forward
        input_ids, input_ids_target = tokenize(caption_batch)
        inputs_embeds, input_ids_target = transform_input_ids(
            embedding_batch, input_ids, input_ids_target
        )

        apply_grad_cond = step % gradient_acc_fact == 0
        losses.append(
            update_step(inputs_embeds, input_ids_target, apply_grad=apply_grad_cond)
        )

        if epoch % 5 == 0 and step % 500 == 0:
            wlog(f"Loss {np.mean(losses[-500:])}\n")
            eval_step(string_info=string_info)
                
            plt.plot(losses, label='train_loss')
            plt.savefig(f"plots/plot_loss_{string_info}.png")
            plt.show()
            plt.clf()
            
            if epoch % 10 == 0 and epoch>0:
                
                metrics_results = metrics_step(n=100)
                metrics['step'].append(len(losses))
                for m in ['meteor_train', 'meteor_valid', 'bleu_train', 'bleu_valid', 'spec_bleu', 'spec_meteor']:
                    metrics[m].append(metrics_results[m])
                    plt.plot(metrics['step'], metrics[m], label=m, linestyle=m_styles[m][1], color=m_styles[m][0])
                plt.legend()
                plt.savefig(f"plots/plot_metrics_{string_info}.png")
                plt.show()

                with open(f'outputs/train_metrics_{string_info}.npy', 'wb') as f:
                    metrics_to_save = {k: np.array(a) for k, a in metrics.items()}
                    metrics_to_save['loss'] = np.array(losses)
                    np.save(f, metrics_to_save)

# Generate eval captions

In [15]:
import evaluate

meteor = evaluate.load('meteor')
google_bleu = evaluate.load('google_bleu')

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


In [34]:
test_results_data = SummarizedCaptionEmbedding(muscaps_ds=test_ds, embeddings=test_embeddings, 
                                      output_all_refs=True) 
eval_test_dataloader = DataLoader(test_results_data, 1, shuffle=False)


eval_true_captions = []
eval_pred_captions = []
# load epoch checkpoint
model_path = f"checkpoints/chkp_summarized_0.0001_0.0001_e200.pt"
data_dict = torch.load(model_path)
model.load_state_dict(data_dict['model'])
b2t.load_state_dict(data_dict['b2t'])

gleu_score = 0
meteor_score = 0

N = len(eval_test_dataloader)

for i, (caption_batch, embedding_batch) in enumerate(tqdm(eval_test_dataloader)):
    pred = eval(caption_batch[0], embedding_batch, **generation_params)
    eval_true_captions.append([c[0] for c in caption_batch])
    eval_pred_captions.append(pred[0])
    if i >= 1000:
        break

  0%|          | 0/550 [00:00<?, ?it/s]

In [35]:
gleu_score = google_bleu.compute(predictions=eval_pred_captions, references=eval_true_captions)['google_bleu']
meteor_score = meteor.compute(predictions=eval_pred_captions, references=eval_true_captions)['meteor']
gleu_score, meteor_score

(0.11332526408450705, 0.2751089686071667)

In [36]:
json.dump(dict(
    eval_true_captions=eval_true_captions,
    eval_pred_captions=eval_pred_captions,
    tracks_ids = ["-"+str(x) for x in sorted(test_ytids)]
), open('outputs/preds_gpt2_summarized_notag_noaug.json', 'w'))

In [40]:
n_shuffles = 20
shuffled_gleu_score, shuffled_meteor_score = 0, 0
for _ in tqdm(range(n_shuffles)):
    data_true_shuffled = sorted(eval_true_captions, key=lambda k: random.random())
    meteor_sh = meteor.compute(predictions=eval_pred_captions,references=data_true_shuffled)['meteor']
    gleu_sh = google_bleu.compute(predictions=eval_pred_captions,references=data_true_shuffled)['google_bleu']
    print(meteor_score-meteor_sh, gleu_score-gleu_sh)
    shuffled_gleu_score += 1./n_shuffles * gleu_sh
    shuffled_meteor_score += 1./n_shuffles * meteor_sh
spec_meteor = meteor_score-shuffled_meteor_score
spec_gleu = gleu_score-shuffled_gleu_score

  0%|          | 0/20 [00:00<?, ?it/s]

0.03094851358580475 0.014059924546416733
0.03815543644008196 0.015874820683359422
0.029926634611551928 0.012379370909690365
0.036834558344496976 0.016551768128449787
0.03969293172003793 0.016401488737540035
0.03772487906066474 0.017732611093347492
0.039914643784096565 0.017677424993085666
0.03946004593897651 0.017661998778384605
0.030874043635128773 0.01593570585157532
0.03485621330683317 0.013482771959900816
0.03824850499696958 0.016788354827899837
0.03363027320316858 0.013672486306729265
0.038235269256087334 0.01778030554173951
0.034433729681416614 0.015187138949196377
0.037834529567885994 0.01640485636383908
0.036499843802813514 0.0161078539502498
0.03537662512721931 0.016438406547425302
0.03497738114899776 0.01555503501038459
0.029030000681703555 0.013785099328372283
0.03738044952635636 0.015449909976858328


In [41]:
spec_gleu, spec_meteor

(0.01574636662422224, 0.03570172537101457)