In [1]:
#!pip install-U yt-dlp==2023.1.6 matplotlib==3.6.0 datasets[audio] rich

# install newest transformers build to be able to pass `inputs_embeds` through generate()
!pip install --upgrade git+https://github.com/huggingface/transformers.git
!pip install librosa

Collecting git+https://github.com/huggingface/transformers.git
  Cloning https://github.com/huggingface/transformers.git to /tmp/pip-req-build-jzepr01p
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/transformers.git /tmp/pip-req-build-jzepr01p
  Resolved https://github.com/huggingface/transformers.git to commit 48327c57182fdade7f7797d1eaad2d166de5c55b
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
[0m

In [2]:
!pip install -U torch torchaudio --no-cache-dir

[0m

In [3]:
!pip install evaluate

[0m

In [4]:
from musiccaps import load_musiccaps

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, random_split, Subset

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer

from rich import print as printr
from tqdm.auto import tqdm, trange
import matplotlib.pyplot as plt

import itertools
import math
import json
import random
from collections import defaultdict
from pathlib import Path
import evaluate
import itertools

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

import sys
sys.path.insert(1, './sota-music-tagging-models/training')
import model as sota_model

# Load musiccaps

In [5]:
ds = load_musiccaps(
    "./music_data",
    sampling_rate=16000,
    limit=None,
    num_proc=8,
    writer_batch_size=1000,
    return_without_audio=True,
)

music_files = {f.stem: f for f in Path('./music_npys/').iterdir()}

with open('summarized_captions.json') as f:
    summarized_captions = json.load(f)
    
ds = ds.add_column("summarized_captions", list(summarized_captions.values()))

Using custom data configuration google--MusicCaps-bedc2a0fd7888f2f
Reusing dataset csv (/root/.cache/huggingface/datasets/google___csv/google--MusicCaps-bedc2a0fd7888f2f/0.0.0/652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a)


In [6]:
class CaptionMusicDataset(Dataset):
    """Returns a torch Dataset of paired captions and music files"""

    def __init__(self, muscaps_ds, music_files, preds_mode=False):
        include_ytids = set(muscaps_ds['ytid']) & set(music_files.keys())
        include_inds = [i for i, ytid in enumerate(muscaps_ds['ytid']) if ytid in include_ytids]
        ds = muscaps_ds.select(include_inds)
        assert len(ds) == len(music_files)

        self.ytids_sorted = ds.sort("ytid")["ytid"]
        self.caption_list = ds.sort("ytid")["summarized_captions"]
        self.sorted_music_files = [music_files[ytid] for ytid in self.ytids_sorted]
        self.preds_mode = preds_mode

    def __len__(self):
        return len(self.sorted_music_files)

    def __getitem__(self, idx):
        music = np.load(self.sorted_music_files[idx], allow_pickle=True)
        music = np.stack([music[:80000], music[-80000:]])
        
        if self.preds_mode:
            return list(self.caption_list[idx]), music, self.ytids_sorted[idx]
        
        caption = np.random.choice(self.caption_list[idx])

        return caption, music

In [7]:
use_chat_aug = True

with open('chataug.json', 'r') as fp:
    chataug_captions = json.load(fp)
    
with open('musiccaps_split.json', 'r') as fp:
    musiccaps_split = json.load(fp)

train_ytids, valid_ytids, test_ytids = musiccaps_split['train'], musiccaps_split['valid'], musiccaps_split['test']

train_ds = ds.filter(lambda x: x['ytid'] in train_ytids)
valid_ds = ds.filter(lambda x: x['ytid'] in valid_ytids)
test_ds = ds.filter(lambda x: x['ytid'] in test_ytids)

train_music_files = {ytid: e for ytid, e in music_files.items() if ytid in train_ytids}
valid_music_files = {ytid: e for ytid, e in music_files.items() if ytid in valid_ytids}
test_music_files = {ytid: e for ytid, e in music_files.items() if ytid in test_ytids}

training_data = CaptionMusicDataset(muscaps_ds=train_ds, music_files=train_music_files)
valid_data = CaptionMusicDataset(muscaps_ds=valid_ds, music_files=valid_music_files)
test_data = CaptionMusicDataset(muscaps_ds=test_ds, music_files=test_music_files) 

Loading cached processed dataset at /root/.cache/huggingface/datasets/google___csv/google--MusicCaps-bedc2a0fd7888f2f/0.0.0/652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a/cache-9a9e3bf18ce5ccf8.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/google___csv/google--MusicCaps-bedc2a0fd7888f2f/0.0.0/652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a/cache-762f4c6e9e04b27d.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/google___csv/google--MusicCaps-bedc2a0fd7888f2f/0.0.0/652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a/cache-533892882886e54e.arrow


In [8]:
len(training_data), len(valid_data), len(test_data)

(4384, 549, 549)

# Training

In [9]:
class ResidualLinear(nn.Module):
    
    def __init__(self, dim):
        super().__init__()
        self.fc = nn.Linear(dim, dim)
        
    def forward(self, x):
        return x + self.fc(torch.nn.functional.relu(x))

class B2T(nn.Module):
    def __init__(self):
        super(B2T, self).__init__()
        self.hcnn = sota_model.HarmonicCNNCropped().to(device)
        state_dict = torch.load(f'sota-music-tagging-models/models/jamendo/hcnn/best_model.pth',
                        map_location=device)
        self.hcnn.load_state_dict(state_dict, strict=False)
        self.fc_net = nn.Sequential(
            nn.Linear(256, 768),
            nn.Dropout(0.7),
            ResidualLinear(768),
            nn.Dropout(0.5),
            ResidualLinear(768),
        )
        
    def forward(self, audio_array):
        audio_features = (self.hcnn(audio_array[:, 0, :])+self.hcnn(audio_array[:, 1, :]))/2
        audio_embedding = self.fc_net(audio_features)
    
        return audio_embedding

In [10]:
def tokenize(captions_batch):
    input_ids = tokenizer(captions_batch)["input_ids"]

    # wrap in eos and add placeholder for music embedding/mask
    input_ids = [torch.tensor([placeholder_id] + x + [eos_id]) for x in input_ids]
    # pad with -100, this index is masked in the cross-entropy loss
    input_ids_target = torch.nn.utils.rnn.pad_sequence(
        input_ids, batch_first=True, padding_value=mask_id
    ).to(device)

    # index -100 isn't valid as model input however, since the token embedding lookup fails,
    # so we need a second version as model input, with -100 replaced with another token (shouldn't matter which)
    input_ids = input_ids_target.clone()
    input_ids[input_ids == mask_id] = eos_id

    return input_ids, input_ids_target


def transform_input_ids(music_array, input_ids, input_ids_target):
    music_emb_ind = 0  # 1 if using <bos>, otherwise 0
    assert (input_ids[:, music_emb_ind] == placeholder_id).all()
    assert (input_ids_target[:, music_emb_ind] == placeholder_id).all()

    input_ids_target[:, music_emb_ind] = eos_id  # mask_id
    input_ids[:, music_emb_ind] = eos_id  # temp placeholder to make the embedding lookup work
    inputs_embeds = model.transformer.wte(input_ids)
    inputs_embeds[:, music_emb_ind] = b2t(music_array.cuda())  # insert music embedding

    return inputs_embeds, input_ids_target


def strip_eos(pred):
    """ 
    remove eos tokens from predicted captions 
    discards everything after the first <eos> that isn't the very first token
    (the hf can only skip eos but not stop at eos) 
    """
    pred = [p.removeprefix("<|endoftext|>") for p in pred]
    pred = [p[: p.find("<|endoftext|>")] if "<|endoftext|>" in p else p for p in pred]
    return pred


def eval(caption_batch, audio_batch, rm_eos=False, **kwargs):
    model.eval()
    b2t.eval()
    input_ids, input_ids_target = tokenize(caption_batch)
    inputs_embeds, _ = transform_input_ids(audio_batch, input_ids, input_ids_target)

    # only include <bos> (optional) and music_embedding, don't include true caption
    inputs_embeds = inputs_embeds[:, :1]
    output_ids = model.generate(inputs_embeds=inputs_embeds, **kwargs)
    pred = tokenizer.batch_decode(output_ids, skip_special_tokens=False)
    pred = [p.replace("\n", "").strip() for p in pred]
    return strip_eos(pred) if rm_eos else pred

In [11]:
def update_step(inputs_embeds, input_ids_target, apply_grad):
    model.train()
    b2t.train()
    loss = model.forward(inputs_embeds=inputs_embeds, labels=input_ids_target).loss
    loss.backward()
    
    if apply_grad:
        opt.step()
        opt.zero_grad()
    
    return loss.item()

def eval_step(string_info=""):
    
    caption_batch, embedding_batch = next(iter(eval_train_dataloader))
    pred = eval(caption_batch, embedding_batch, **generation_params)
    printr('[green bold]TRAIN TRUE: ' + caption_batch[0])
    printr('[yellow]TRAIN PRED: ' + pred[0])
    wlog('TRAIN TRUE: ' + caption_batch[0])
    wlog('TRAIN PRED: ' + pred[0])
    caption_batch, embedding_batch = next(iter(eval_valid_dataloader))
    pred = eval(caption_batch, embedding_batch, **generation_params)
    printr('[green bold]TEST TRUE: ' + caption_batch[0])
    printr('[yellow]TEST PRED: ' + pred[0])
    wlog('TEST TRUE: ' + caption_batch[0])
    wlog('TEST PRED: ' + pred[0])
    print()
    
def metrics_step(n=100, shuffles=10):
    
    lq = 0
    
    print(f"Computing metrics for n={n}")
    wlog(f"Computing metrics for n={n}")
    
    shuffled_meteor_valid, shuffled_bleu_valid = 0., 0.
    train_captions, train_preds = [], []
    valid_captions, valid_preds = [], []

    for i in tqdm(range(n)):

        caption_batch, embedding_batch = next(iter(eval_train_dataloader))
        pred = eval(caption_batch, embedding_batch, **generation_params)
        train_captions.append(caption_batch[0])
        train_preds.append(pred[0])

        caption_batch, embedding_batch = next(iter(eval_valid_dataloader))
        pred = eval(caption_batch, embedding_batch, **generation_params)
        valid_captions.append(caption_batch[0])
        valid_preds.append(pred[0])
        if "The low quality recording" in pred[0]:
            lq += 1

    meteor_train = meteor.compute(predictions=train_preds, references=train_captions)['meteor']
    bleu_train = google_bleu.compute(predictions=train_preds, references=train_captions)['google_bleu']

    meteor_valid = meteor.compute(predictions=valid_preds, references=valid_captions)['meteor']
    bleu_valid = google_bleu.compute(predictions=valid_preds, references=valid_captions)['google_bleu']
    
    for j in range(shuffles):
        shuffled_captions = sorted(valid_captions, key=lambda k: random.random())
        shuffled_meteor_valid += (1./(shuffles))*meteor.compute(predictions=valid_preds, references=shuffled_captions)['meteor']
        shuffled_bleu_valid += (1./(shuffles))*google_bleu.compute(predictions=valid_preds, references=shuffled_captions)['google_bleu']

    spec_meteor = meteor_valid - shuffled_meteor_valid
    spec_bleu = bleu_valid - shuffled_bleu_valid
        
    print(f"Train meteor: {meteor_train:.4f}, Train bleu: {bleu_train:.4f}")
    wlog(f"Train meteor: {meteor_train:.4f}, Train bleu: {bleu_train:.4f}")
    print(f"Valid meteor: {meteor_valid:.4f}, Valid bleu: {bleu_valid:.4f}")
    wlog(f"Valid meteor: {meteor_valid:.4f}, Valid bleu: {bleu_valid:.4f}")
    print(f"Valid spec-meteor: {spec_meteor:.4f}, Valid spec-bleu: {spec_bleu:.4f}")
    wlog(f"Valid spec-meteor: {spec_meteor:.4f}, Valid spec-bleu: {spec_bleu:.4f}")
    print(f"Low quality recording count: {lq}")
    wlog(f"Low quality recording count: {lq}")
    
    return {'meteor_train': meteor_train,
            'meteor_valid': meteor_valid,
            'bleu_train': bleu_train,
            'bleu_valid': bleu_valid,
            'spec_meteor': spec_meteor,
            'spec_bleu': spec_bleu} 

m_styles = {"meteor_train": ("tab:orange", "solid"),
    "meteor_valid": ("tab:orange", "dashed"),
    "bleu_train": ("tab:blue", "solid"),
    "bleu_valid": ("tab:blue", "dashed"),
    "spec_bleu": ("tab:blue", "dashdot"),
    "spec_meteor": ("tab:orange", "dashdot")
}
    

In [12]:
import re
def metrics_step(n=100, shuffles=10):
    
    print(f"Computing metrics for n={n}")
    wlog(f"Computing metrics for n={n}")
    
    shuffled_meteor_valid, shuffled_bleu_valid = 0., 0.
    train_captions, train_preds = [], []
    valid_captions, valid_preds = [], []

    for i in tqdm(range(n)):

        caption_batch, embedding_batch = next(iter(eval_train_dataloader))
        pred = eval(caption_batch, embedding_batch, **generation_params)
        train_captions.append(re.sub(r'[^\w\s]','',caption_batch[0]))
        train_preds.append(re.sub(r'[^\w\s]','',pred[0]))

        caption_batch, embedding_batch = next(iter(eval_valid_dataloader))
        pred = eval(caption_batch, embedding_batch, **generation_params)
        valid_captions.append(re.sub(r'[^\w\s]','',caption_batch[0]))
        valid_preds.append(re.sub(r'[^\w\s]','',pred[0]))

    meteor_train = meteor.compute(predictions=train_preds, references=train_captions)['meteor']
    bleu_train = google_bleu.compute(predictions=train_preds, references=train_captions)['google_bleu']

    meteor_valid = meteor.compute(predictions=valid_preds, references=valid_captions)['meteor']
    bleu_valid = google_bleu.compute(predictions=valid_preds, references=valid_captions)['google_bleu']
    
    for j in range(shuffles):
        shuffled_captions = sorted(valid_captions, key=lambda k: random.random())
        shuffled_meteor_valid += (1./(shuffles))*meteor.compute(predictions=valid_preds, references=shuffled_captions)['meteor']
        shuffled_bleu_valid += (1./(shuffles))*google_bleu.compute(predictions=valid_preds, references=shuffled_captions)['google_bleu']

    spec_meteor = meteor_valid - shuffled_meteor_valid
    spec_bleu = bleu_valid - shuffled_bleu_valid
        
    print(f"Train meteor: {meteor_train:.4f}, Train bleu: {bleu_train:.4f}")
    wlog(f"Train meteor: {meteor_train:.4f}, Train bleu: {bleu_train:.4f}")
    print(f"Valid meteor: {meteor_valid:.4f}, Valid bleu: {bleu_valid:.4f}")
    wlog(f"Valid meteor: {meteor_valid:.4f}, Valid bleu: {bleu_valid:.4f}")
    print(f"Valid spec-meteor: {spec_meteor:.4f}, Valid spec-bleu: {spec_bleu:.4f}")
    wlog(f"Valid spec-meteor: {spec_meteor:.4f}, Valid spec-bleu: {spec_bleu:.4f}")
    
    return {'meteor_train': meteor_train,
            'meteor_valid': meteor_valid,
            'bleu_train': bleu_train,
            'bleu_valid': bleu_valid,
            'spec_meteor': spec_meteor,
            'spec_bleu': spec_bleu} 

m_styles = {"meteor_train": ("tab:orange", "solid"),
    "meteor_valid": ("tab:orange", "dashed"),
    "bleu_train": ("tab:blue", "solid"),
    "bleu_valid": ("tab:blue", "dashed"),
    "spec_bleu": ("tab:blue", "dashdot"),
    "spec_meteor": ("tab:orange", "dashdot")
}
    

In [13]:
model_name = 'gpt2' # gpt2, gpt2-medium, gpt2-large, gpt2-xl
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
model.config.is_decoder = True # not sure if necessary

mask_id = -100 # don't change, this is fixed in torch cross-entropy loss!
eos_id = tokenizer.eos_token_id
placeholder_id = -200
b2t = B2T().cuda()
hcnn_lr = 5e-5
b2t_lr = 1e-4
gpt2_finetune_lr = 5e-5

opt = torch.optim.AdamW([
    # b2t needs to be the first parameter group
    {'params': b2t.hcnn.layer4.parameters(), 'lr': hcnn_lr/2},
    {'params': b2t.hcnn.layer5.parameters(), 'lr': hcnn_lr/2},
    {'params': b2t.hcnn.layer6.parameters(), 'lr': hcnn_lr},
    {'params': b2t.hcnn.layer7.parameters(), 'lr': hcnn_lr},
    {'params': b2t.fc_net.parameters(), 'lr': b2t_lr},
    
    # disable AdamW weight decay for gpt2 layer finetuning
    {'params': model.transformer.h[1].parameters(), 'lr': gpt2_finetune_lr, 'weight_decay': 0},
    {'params': model.transformer.h[2].parameters(), 'lr': gpt2_finetune_lr, 'weight_decay': 0}
])

batch_size = 24
num_epochs = 500
epoch = 104
gradient_acc_fact = 1
gpt2_finetune_start_epoch = 0

generation_params = dict(
    max_new_tokens=64,
    num_beams=4,
    do_sample=True,
    temperature=1.0,
    bos_token_id=eos_id,
    eos_token_id=eos_id,
    pad_token_id=mask_id,
    early_stopping=True,
    rm_eos=True,
)

losses = []

train_dataloader = DataLoader(training_data, batch_size, shuffle=True)
eval_train_dataloader = DataLoader(training_data, 1, shuffle=True)
eval_valid_dataloader = DataLoader(valid_data, 1, shuffle=True)

meteor = evaluate.load('meteor')
google_bleu = evaluate.load("google_bleu")
metrics = {'step': [], 'bleu_train': [], 'bleu_valid': [], 'meteor_train': [], 'meteor_valid': [],
          'spec_bleu': [], 'spec_meteor':[]}

checkpoint_dir = Path('checkpoints')
checkpoint_dir.mkdir(exist_ok=True)

load_pretrain = True
keep_training = True

pretrain_model =  "checkpoints/chkp_enc_summarized_5e-05_5e-05_0.0001_chataug.pt"
if load_pretrain:
    data_dict = torch.load(pretrain_model)
    model.load_state_dict(data_dict['model'])
    b2t.load_state_dict(data_dict['b2t'])
    
model_name_info = "small" if model_name == 'gpt2' else model_name
pretraining_info = "yes" if load_pretrain and not keep_training else "no"
chat_aug_info = "_chataug" if use_chat_aug else ""
string_info = f"enc_summarized_{hcnn_lr}_{gpt2_finetune_lr}_{b2t_lr}{chat_aug_info}"

def wlog(s):
    f = open(f"outputs/logs_{string_info}.txt",'a')
    f.write(s+"\n")
    f.close()
    
if keep_training:
    with open(f'outputs/train_metrics_{string_info}.npy', 'rb') as f:
        train_metrics = np.load(f, allow_pickle=True).item()
    
    losses = list(train_metrics['loss'])
    metrics = {k: list(v) for k,v in train_metrics.items()}

string_info

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


'enc_summarized_5e-05_5e-05_0.0001_chataug'

In [None]:
for epoch in tqdm(range(epoch, num_epochs)):
    wlog(f"\nEpoch {epoch}")

    torch.save({
        "model": model.state_dict(), 
        "b2t": b2t.state_dict(), 
        "opt": opt
    }, checkpoint_dir / f"chkp_{string_info}.pt")
    #}, checkpoint_dir / f"chkp_{epoch}.pt")
    
    if epoch>0 and epoch % 20 == 0:
        print("Checkpoint saved")
        wlog("Checkpoint saved")
        torch.save({
            "model": model.state_dict(), 
            "b2t": b2t.state_dict(), 
            "opt": opt
        }, checkpoint_dir / f"chkp_{string_info}_e{epoch}.pt")

    for step, (caption_batch, embedding_batch) in enumerate(tqdm(train_dataloader)):
        # tokenize and prepare inputs for forward
        input_ids, input_ids_target = tokenize(list(caption_batch))
        inputs_embeds, input_ids_target = transform_input_ids(
            embedding_batch, input_ids, input_ids_target
        )

        apply_grad_cond = step % gradient_acc_fact == 0
        losses.append(
            update_step(inputs_embeds, input_ids_target, apply_grad=apply_grad_cond)
        )

        if epoch % 5 == 0 and step % 500 == 0:
            wlog(f"Loss {np.mean(losses[-500:])}\n")
            eval_step(string_info=string_info)
                
            plt.plot(losses, label='train_loss')
            plt.savefig(f"plots/plot_loss_{string_info}.png")
            plt.show()
            plt.clf()
            
            if epoch % 10 == 0 and epoch>0:
                
                metrics_results = metrics_step(n=100)
                metrics['step'].append(len(losses))
                for m in ['meteor_train', 'meteor_valid', 'bleu_train', 'bleu_valid', 'spec_bleu', 'spec_meteor']:
                    metrics[m].append(metrics_results[m])
                    plt.plot(metrics['step'], metrics[m], label=m, linestyle=m_styles[m][1], color=m_styles[m][0])
                plt.legend()
                plt.savefig(f"plots/plot_metrics_{string_info}.png")
                plt.show()

                with open(f'outputs/train_metrics_{string_info}.npy', 'wb') as f:
                    metrics_to_save = {k: np.array(a) for k, a in metrics.items()}
                    metrics_to_save['loss'] = np.array(losses)
                    np.save(f, metrics_to_save)

# Generate eval captions

In [14]:
import evaluate

meteor = evaluate.load('meteor')
google_bleu = evaluate.load('google_bleu')

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


In [19]:
test_results_data = CaptionMusicDataset(muscaps_ds=test_ds, music_files=test_music_files,
                                  preds_mode=True) 
eval_test_dataloader = DataLoader(test_results_data, 1, shuffle=False)

eval_true_captions = []
eval_pred_captions = []
eval_ytids = []

model_path = "saved-models/best-enc-summarized.pt"
data_dict = torch.load(model_path)
model.load_state_dict(data_dict['model'])
b2t.load_state_dict(data_dict['b2t'])

generation_params = dict(
    max_new_tokens=200,
    num_beams=8,
    do_sample=True,
    temperature=1.0,
    bos_token_id=eos_id,
    eos_token_id=eos_id,
    pad_token_id=mask_id,
    early_stopping=True,
    rm_eos=True,
)

for i, (caption_batch, embedding_batch, ytid_batch) in tqdm(enumerate(eval_test_dataloader)):
    pred = eval(list(caption_batch[0]), embedding_batch, **generation_params)
    eval_true_captions.append([c[0] for c in caption_batch])
    eval_pred_captions.append(pred[0])
    eval_ytids.append(ytid_batch[0])

0it [00:00, ?it/s]

In [20]:
json.dump(dict(
    eval_true_captions=eval_true_captions,
    eval_pred_captions=eval_pred_captions,
    tracks_ids = eval_ytids,
), open('outputs/preds_gpt2_enc_summarized.json', 'w'))

In [21]:
lower_eval_true_captions = [[re.sub(r'[^\w\s]','',x[0]).lower() for x in caption_list] for caption_list in eval_true_captions]
lower_eval_pred_captions = [re.sub(r'[^\w\s]','',x).lower() for x in eval_pred_captions]

In [22]:
gleu_score = google_bleu.compute(predictions=lower_eval_pred_captions, 
                                 references=lower_eval_true_captions)['google_bleu']
meteor_score = meteor.compute(predictions=lower_eval_pred_captions, 
                              references=lower_eval_true_captions)['meteor']
print(gleu_score, meteor_score)

0.002148431790944433 0.03212579655756952
