In [1]:
import numpy as np
import string
import os
import torch
import torchvision
import pandas as pd
import torch.nn as nn 
from torch.utils.data import DataLoader, Dataset, random_split
from torch.nn.utils.rnn import pad_sequence
from torch.optim.lr_scheduler import StepLR
from pathlib import Path
import json
import tqdm
import random
from collections import defaultdict

import warnings
# UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() 
# or sourceTensor.clone().detach().requires_grad_(True) rather than torch.tensor(sourceTensor)
warnings.simplefilter("ignore", category=UserWarning)

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## Data loading and cleanup

In [3]:
JAMENDO_TAGS = np.array(['genre---alternative','genre---ambient','genre---atmospheric','genre---chillout','genre---classical','genre---dance','genre---downtempo','genre---easylistening','genre---electronic','genre---experimental','genre---folk','genre---funk','genre---hiphop','genre---house','genre---indie','genre---instrumentalpop','genre---jazz','genre---lounge','genre---metal','genre---newage','genre---orchestral','genre---pop','genre---popfolk','genre---poprock','genre---reggae','genre---rock','genre---soundtrack','genre---techno','genre---trance','genre---triphop','genre---world','instrument---acousticguitar','instrument---bass','instrument---computer','instrument---drummachine','instrument---drums','instrument---electricguitar','instrument---electricpiano','instrument---guitar','instrument---keyboard','instrument---piano','instrument---strings','instrument---synthesizer','instrument---violin','instrument---voice','mood/theme---emotional','mood/theme---energetic','mood/theme---film','mood/theme---happy','mood/theme---relaxing'])

def get_top_tags(scores, k=3, threshold=.4):
    assert scores.shape == (2, 50)
    scores = (scores[0]+scores[1])/2
    indices = np.where(scores>threshold)[0]
    sorted_indices = indices[np.argsort(-scores[indices])[:k]]
    return JAMENDO_TAGS[sorted_indices]

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
jam_tags = {}
jam_pred_tags = {}
jam_embeddings = {}
jam_scores = {}

jam_embeddings_dir = Path('/content/drive/MyDrive/jam_embeddings')

for i in tdqm.tqdm(range(100)):
    try:
        with open(jam_embeddings_dir / f'tags_{i:02d}.json') as f:
            jam_tags.update(json.load(f))
    except Exception as e:
        print(e)
        continue
    data_dict = np.load(jam_embeddings_dir / f'embeddings_{i:02d}.npy', allow_pickle=True)
    jam_embeddings.update(data_dict.item())
    data_dict = np.load(jam_embeddings_dir / f'tag_scores_{i:02d}.npy', allow_pickle=True)
    jam_scores.update(data_dict.item())
    
for k, v in jam_scores.items():
    jam_pred_tags[k] = get_top_tags(v, k=3, threshold=0.4)
    
pred_tag_counts = np.array([len(v) for v in jam_pred_tags.values()])
print(f'avg number of pred tags = {(pred_tag_counts).mean()}, fraction of samples with 0 pred tags = {(pred_tag_counts==0).mean()}')

[Errno 2] No such file or directory: '/content/drive/MyDrive/jam_embeddings/tags_35.json'
avg number of pred tags = 1.1215934627170583, fraction of samples with 0 pred tags = 0.21151453245426688


In [15]:
def cleaning_text(caption):
    table = str.maketrans('','',string.punctuation)
    caption = caption.replace("---"," ")
    #caption = caption.replace("---",": ")
    # split the sentences into words
    desc = caption.split()
    #converts to lower case
    desc = [word.lower() for word in desc]
    #remove punctuation from each token
    desc = [word.translate(table) for word in desc]
    #remove hanging 's and a 
    desc = [word for word in desc if(len(word)>1)]
    #remove tokens with numbers in them
    desc = [word for word in desc if(word.isalpha())]
    #convert back to string
    caption = ' '.join(desc)

    return caption

def preprocess_captions(captions):
    preprocessed_captions = {}
    for audio_file in tqdm.tqdm(captions):
      text_caption = ", ".join(captions[audio_file])
      preprocessed_captions[audio_file] = cleaning_text(text_caption)
            
    return preprocessed_captions

#captions = dict(zip(df.ytid,df.caption))
jam_tags_processed = preprocess_captions(jam_tags)
jam_pred_tags_processed = preprocess_captions(jam_pred_tags)

100%|██████████| 269225/269225 [00:04<00:00, 56426.28it/s]
100%|██████████| 269225/269225 [00:02<00:00, 107242.46it/s]


In [17]:
# define a vocabulary
def text_vocabulary(descriptions):
    captions = list(descriptions.values())
    vocab = set(['<start>', '<end>', ':', ',', ';'])
    for caption in captions:
        for token in caption.strip().split():
            vocab.add(token)
    return vocab

# force <pad> to have idx 0 (convention)
vocab = ['<pad>'] + list(text_vocabulary(jam_pred_tags_processed)) + list(text_vocabulary(jam_tags_processed))

word_to_idx = {word: idx for idx, word in enumerate(vocab)}
idx_to_word = {idx: word for word, idx in word_to_idx.items()}

def tokenize(caption):
    caption = cleaning_text(caption)
    token_list = []
    # Add <start> to the beginning and <end> to the end of each caption
    caption_list = ["<start>"] + caption.split() + ["<end>"]
    token_list = [word_to_idx[word] for word in caption_list]
    return token_list 


In [152]:
# Define the audio captioning dataset
class JamendoTagDataset(Dataset):
    def __init__(self, jam_tags, jam_pred_tags, jam_embeddings):
        
        self.keys = sorted(jam_tags.keys())
        self.jam_tags = jam_tags
        self.jam_pred_tags = jam_pred_tags
        self.jam_embeddings = jam_embeddings

    def __len__(self):
        return len(self.jam_tags)

    def __getitem__(self, idx):
        id = self.keys[idx]
        
        tags = self.jam_tags[id]
        categories = defaultdict(set)
        for t in tags:
            assert '---' in t
            categories[t[:t.find('---')]].add(t[t.find('---')+3:])
            
        result = []
        for k in sorted(categories.keys()):
            cat_tags = list(categories[k])
            result.append(k + ': ' + ', '.join(random.sample(cat_tags, len(cat_tags))))
        tags_cap = '; '.join(result)
            
        #tags = [t.replace('---', ': ') for t in tags]
        #tags = [t[t.find('---')+3:] if '---' in t else t for t in tags]
        #random.shuffle(tags)
        #tags_cap = ', '.join(tags)
        
        emb = self.jam_embeddings[id]
        assert emb.shape == (2, 256)
        emb = np.concatenate([emb[0],emb[1]])
        assert emb.shape == (512,)
        
        return {"song_id": id.split("_")[0], "start_timestep": id.split("_")[1], "caption": tags_cap, "tokenized_caption":tokenize(tags_cap), "embedding":torch.from_numpy(emb).to(device)}
      

In [166]:
train_frac = 0.8
song_ids = list(set([x.split("_")[0] for x in jam_tags]))
# split efficiently dicts!
train_mask = random.choices([True, False], weights=[0.8, 0.2], k=len(song_ids))
song_is_train = {}
for i, song in enumerate(song_ids):
  song_is_train[song] = train_mask[i]

train_jam_tags = {x: jam_tags[x] for x in jam_tags if song_is_train[x.split("_")[0]]}
test_jam_tags = {x: jam_tags[x] for x in jam_tags if not song_is_train[x.split("_")[0]]}
train_jam_pred_tags = {x: jam_pred_tags[x] for x in jam_pred_tags if song_is_train[x.split("_")[0]]}
test_jam_pred_tags = {x: jam_pred_tags[x] for x in jam_pred_tags if not song_is_train[x.split("_")[0]]}
train_jam_embeddings = {x: jam_embeddings[x] for x in jam_embeddings if song_is_train[x.split("_")[0]]}
test_jam_embeddings = {x: jam_embeddings[x] for x in jam_embeddings if not song_is_train[x.split("_")[0]]}

training_data = JamendoTagDataset(train_jam_tags, train_jam_pred_tags, train_jam_embeddings)
test_data = JamendoTagDataset(test_jam_tags, test_jam_pred_tags, test_jam_embeddings)

print(len(training_data), len(test_data))

215270 53955


In [167]:
class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers=1, dropout=0.4):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = nn.Dropout(p=dropout)
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
        self.embedding.weight.data.uniform_(-0.1, 0.1)
        self.fc.weight.data.uniform_(-0.1, 0.1)
        self.fc.bias.data.fill_(0)

    def forward(self, input, features):
        embedded = self.embedding(input)
        embedded = self.dropout(embedded)
        # concatenate audio features and input embedding
        inputs = torch.cat((features.unsqueeze(1), embedded), dim=1)
        output, hidden = self.lstm(inputs)
        output = self.fc(output)[:,1:,:]

        return output, hidden


## Model training

In [168]:
# Instantiate the model
vocab_size = len(vocab)
input_size = vocab_size
hidden_size = 512
output_size = vocab_size
num_layers = 2
model = LSTMModel(input_size, hidden_size, output_size, num_layers)


# Train the model
lr = 5e-4
batch_size = 512
num_epochs = 20
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
criterion = nn.CrossEntropyLoss(ignore_index=word_to_idx['<pad>'])
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = StepLR(optimizer, step_size=10, gamma=0.5)
loss_fn = nn.CrossEntropyLoss(ignore_index=word_to_idx['<pad>'])

# Define the collate function for the audio captioning dataset
def collate_fn_try(batch):
    embeddings = []
    captions = []
    for b in batch:
        embeddings.append(b['embedding'])
        captions.append(torch.tensor(b['tokenized_caption']))
    padded_embeddings = nn.utils.rnn.pad_sequence(embeddings, batch_first=True)
    padded_captions = nn.utils.rnn.pad_sequence(captions, batch_first=True, padding_value=word_to_idx["<pad>"])  # Use the <pad> index for padding
    return padded_embeddings.to(device, dtype=torch.float), padded_captions.to(device, dtype=torch.long)

train_dataloader = DataLoader(training_data, batch_size, shuffle=True, collate_fn=collate_fn_try)
eval_train_dataloader = DataLoader(training_data, batch_size, shuffle=True, collate_fn=collate_fn_try)
eval_test_dataloader = DataLoader(test_data, batch_size, shuffle=True, collate_fn=collate_fn_try)

In [169]:
# load a batch
batch = next(iter(train_dataloader))
print(f"Embeddings batch, shape={batch[0].shape}")
print(f"Captions batch, shape={batch[1].shape}")

Embeddings batch, shape=torch.Size([512, 512])
Captions batch, shape=torch.Size([512, 36])


In [170]:
def generate_caption(idx, train, max_caption_length=64, show_true_caption=True, show_ytid=True):

  if train:
    true_tags = training_data[idx]["caption"]
    embedding = training_data[idx]["embedding"]
  else:
    true_tags = test_data[idx]["caption"]
    embedding = test_data[idx]["embedding"]

  x = embedding.unsqueeze(0).to(device, dtype=torch.float)
  model.eval()
  # breaks if starting sequence is only one token (?)
  caption = torch.tensor([word_to_idx[word] for word in ['<pad>', '<start>']]).unsqueeze(0).to(device)

  # Generate the caption word by word
  with torch.no_grad():
      while caption[0][-1] != word_to_idx['<end>'] and len(caption[0]) < max_caption_length:
          logits, hidden = model(caption[:, :-1], x)
          predicted_word_index = logits.argmax(-1)[:, -1].item()
          predicted_word = idx_to_word[predicted_word_index]
          caption = torch.cat([caption, torch.tensor([[predicted_word_index]], dtype=torch.long).to(device)], dim=1)

  predicted_caption = ' '.join([idx_to_word[word_idx] for word_idx in caption[0].tolist()][2:-1])

  if show_true_caption: 
    print(f"True tags: {true_tags} ")
  print(f"Predicted tags: {predicted_caption}\n")
  

In [171]:
len(train_dataloader)

421

In [None]:
# Train the model
pbar = tqdm.tqdm(range(num_epochs))
pbar.set_description(f"Epoch 1")
for epoch in pbar:
    model.train()  # set model to train mode
    for i, (x, captions) in enumerate(train_dataloader):
        optimizer.zero_grad()
        logits, hidden = model(captions[:, :-1], x)
        loss = loss_fn(logits.reshape(-1, vocab_size), captions[:, 1:].reshape(-1))
        loss.backward()
        optimizer.step()
        pbar.set_description(f"Epoch {epoch+1} | Train batch {i}/{len(train_dataloader)}")
    scheduler.step()

    model.eval()  # set model to eval mode
    eval_loss = 0
    with torch.no_grad():
        for i, (x, captions) in enumerate(eval_test_dataloader):
            x = x.to(device, dtype=torch.float)
            captions = captions.to(device, dtype=torch.long)
            logits, hidden = model(captions[:, :-1], x)
            loss = loss_fn(logits.reshape(-1, vocab_size), captions[:, 1:].reshape(-1))
            eval_loss += loss.item() * x.size(0)  # accumulate loss for entire eval dataset
            pbar.set_description(f"Epoch {epoch+1} | Eval batch {i}/{len(eval_test_dataloader)}")
        eval_loss /= len(eval_test_dataloader.dataset)  # compute average eval loss
        
    
    pbar.set_description(f"Epoch {epoch+1}, train loss {loss.item():.4f}, eval loss {eval_loss:.4f}")
    train_ids, test_ids = np.random.randint(len(training_data), size=2), np.random.randint(len(test_data), size=2)
    print("\n(train)")
    generate_caption(train_ids[0], train=True, show_true_caption=True) 
    generate_caption(train_ids[1], train=True, show_true_caption=True) 
    print("(eval)")
    generate_caption(test_ids[0], train=False, show_true_caption=True)
    generate_caption(test_ids[1], train=False, show_true_caption=True)
    print("\n")

Epoch 1, train loss 2.3914, eval loss 2.4226:   5%|▌         | 1/20 [01:26<27:26, 86.65s/it]


(train)
True tags: genre: classical, soundtrack; instrument: piano; mood/theme: sentimental, wedding, romantic 
Predicted tags: genre classical classical instrument instrument piano piano moodtheme moodtheme relaxing christmas

True tags: genre: rnb, hiphop, rap 
Predicted tags: genre electronic electronic

(eval)
True tags: genre: easylistening, trance, atmospheric; instrument: guitar, drums, bass, harp 
Predicted tags: genre ambient ambient instrument instrument synthesizer synthesizer synthesizer moodtheme moodtheme meditative meditative

True tags: genre: 80s, acidrock, world; instrument: bass 
Predicted tags: genre rock rock





Epoch 2, train loss 2.2446, eval loss 2.2896:  10%|█         | 2/20 [02:58<26:57, 89.88s/it]


(train)
True tags: genre: electronic, dance, electropop, electronica 
Predicted tags: genre dance pop pop instrument instrument drums synthesizer synthesizer bass moodtheme moodtheme energetic happy

True tags: genre: grunge, bluesrock, rock, indie; instrument: electricguitar, drums, bass, slideguitar 
Predicted tags: genre rock funk

(eval)
True tags: genre: chillout, atmospheric, ambient; mood/theme: love 
Predicted tags: genre ambient ambient instrument instrument synthesizer synthesizer moodtheme moodtheme relaxing relaxing

True tags: genre: electronic, downtempo, chillout; mood/theme: peaceful 
Predicted tags: genre chillout lounge lounge instrument instrument piano synthesizer





Epoch 3, train loss 2.1987, eval loss 2.2073:  15%|█▌        | 3/20 [04:32<25:54, 91.46s/it]


(train)
True tags: genre: techno, electronic, soundtrack; instrument: synthesizer 
Predicted tags: genre electronic soundtrack

True tags: genre: electronic, club, dance, trance, techno 
Predicted tags: genre electronic electronic

(eval)
True tags: genre: rockfrancais, rock, ska; instrument: electricguitar, bass, voice, keyboard, drums; mood/theme: energetic 
Predicted tags: genre rock rock

True tags: genre: punkrock, rock, rocknroll, poprock, indie 
Predicted tags: genre rock rock





Epoch 4, train loss 2.1421, eval loss 2.1502:  20%|██        | 4/20 [06:05<24:36, 92.27s/it]


(train)
True tags: genre: electronic, edm, soundtrack, ambient, dance; instrument: drums, synthesizer, electricguitar, drummachine, violin, pipeorgan, acousticguitar, sampler; mood/theme: energetic, driving, advertising, dark, urban, angry 
Predicted tags: genre electronic electronic

True tags: genre: dance, electropop, pop; instrument: synthesizer, drums, bass; mood/theme: crazy, fun, energetic, party 
Predicted tags: genre hiphop hiphop instrument instrument beat beat

(eval)
True tags: genre: ambient, chillout, lounge; instrument: drummachine 
Predicted tags: genre chillout easylistening lounge instrument instrument bass synthesizer synthesizer moodtheme moodtheme relaxing relaxing

True tags: genre: contemporary, classical 
Predicted tags: genre pop pop





Epoch 5, train loss 2.0083, eval loss 2.1088:  25%|██▌       | 5/20 [07:39<23:13, 92.87s/it]


(train)
True tags: genre: triphop, hiphop, downtempo; instrument: guitar, drums, bass; mood/theme: dark, upbeat 
Predicted tags: genre hiphop hiphop rap instrument instrument beat synthesizer moodtheme moodtheme conscient energetic

True tags: genre: soundtrack 
Predicted tags: genre ambient ambient instrument instrument synthesizer synthesizer

(eval)
True tags: genre: lounge, easylistening, chillout 
Predicted tags: genre chillout easylistening lounge instrument instrument piano synthesizer moodtheme moodtheme relaxing inspiring

True tags: genre: classical, contemporary 
Predicted tags: genre classical classical instrument instrument piano piano moodtheme moodtheme emotional emotional





Epoch 6, train loss 2.1110, eval loss 2.0751:  30%|███       | 6/20 [09:13<21:44, 93.21s/it]


(train)
True tags: genre: hiphop, electronic, 90s 
Predicted tags: genre lounge hiphop easylistening instrument instrument bass synthesizer synthesizer piano piano moodtheme moodtheme background deep

True tags: genre: pop, ambient, soundtrack, latin, experimental; instrument: synthesizer, percussion; mood/theme: documentary, children, society, advertising, comedy 
Predicted tags: genre reggae dub dub instrument instrument bass electricguitar electricguitar electricpiano electricpiano

(eval)
True tags: genre: dance, trance, trancemelodique 
Predicted tags: genre trance trance

True tags: genre: hiphop, downtempo, triphop; instrument: synthesizer, bass, drums, guitar 
Predicted tags: genre hiphop hiphop rap instrument moodtheme beat youtube





Epoch 7, train loss 2.1490, eval loss 2.0510:  35%|███▌      | 7/20 [10:48<20:18, 93.70s/it]


(train)
True tags: genre: club, chillout, electronic; instrument: bass 
Predicted tags: genre electronic electronic instrument instrument synthesizer synthesizer

True tags: genre: soul, latin, bossanova 
Predicted tags: genre electronic pop

(eval)
True tags: genre: progressive, electronic; mood/theme: aggressive 
Predicted tags: genre electronic electronic

True tags: genre: pop, waltz, cabaret 
Predicted tags: genre reggae pop instrument instrument guitar guitar moodtheme





Epoch 8, train loss 2.0950, eval loss 2.0325:  40%|████      | 8/20 [12:22<18:46, 93.85s/it]


(train)
True tags: genre: rockfrancais, rock, bluesrock; instrument: electricguitar, voice, bass, drums, harmonica 
Predicted tags: genre funk funk rnb instrument instrument bass bass electricguitar electricguitar drums drums

True tags: genre: minimal; instrument: guitar, classicalguitar 
Predicted tags: genre classical pop instrument instrument guitar piano

(eval)
True tags: genre: instrumentalpop, easylistening 
Predicted tags: genre rock pop instrument instrument guitar guitar bass bass

True tags: genre: orchestral, easylistening, ambient; instrument: piano; mood/theme: sweet, motivational, soft, inspiring 
Predicted tags: genre classical classical instrument instrument piano piano





Epoch 10 | Train batch 0/421:  45%|████▌     | 9/20 [13:57<17:15, 94.11s/it]                


(train)
True tags: genre: rap, hiphop, pop, rock 
Predicted tags: genre electronic electronic

True tags: genre: industrial, ebm, electronic; mood/theme: dark 
Predicted tags: genre electronic electronic

(eval)
True tags: genre: soundtrack, darkambient, ambient; mood/theme: thriller, dark, horror 
Predicted tags: genre soundtrack soundtrack classical moodtheme moodtheme epic epic

True tags: genre: electronic, popfolk, folk; instrument: synthesizer, guitar; mood/theme: fast 
Predicted tags: genre electronic pop pop instrument instrument synthesizer synthesizer





Epoch 10, train loss 2.1471, eval loss 2.0113:  50%|█████     | 10/20 [15:32<15:45, 94.56s/it]


(train)
True tags: genre: dance, electronic; instrument: drummachine, sampler, synthesizer; mood/theme: bright, energetic, retro 
Predicted tags: genre electronic dance dance instrument instrument synthesizer synthesizer computer computer moodtheme

True tags: genre: electronic 
Predicted tags: genre electronic electronic

(eval)
True tags: genre: rock, progressiverock, hardrock; instrument: bassguitar, drums, guitar 
Predicted tags: genre rock rock instrument instrument guitar guitar bass bass drums drums

True tags: genre: orchestral, experimental, ambient, electronic; instrument: electricguitar, computer 
Predicted tags: genre electronic electronic ambient instrument instrument synthesizer synthesizer piano piano moodtheme moodtheme emotional emotional inspiring





Epoch 11, train loss 1.8643, eval loss 1.9941:  55%|█████▌    | 11/20 [17:09<14:16, 95.13s/it]


(train)
True tags: genre: soundtrack, orchestral; instrument: choir, strings; mood/theme: epic, emotional, drama 
Predicted tags: genre orchestral soundtrack soundtrack moodtheme instrument film film

True tags: genre: poprock, powerpop, rock 
Predicted tags: genre rock rock pop instrument instrument guitar voice voice bass bass

(eval)
True tags: genre: ambient, soundtrack; instrument: piano, synthesizer; mood/theme: thoughtful, peaceful 
Predicted tags: genre ambient ambient soundtrack instrument instrument piano synthesizer synthesizer moodtheme

True tags: genre: experimental, idm, ambient, electronic, downtempo; mood/theme: dream 
Predicted tags: genre electronic electronic ambient instrument instrument synthesizer synthesizer computer moodtheme moodtheme emotional adventure





Epoch 12, train loss 2.1304, eval loss 1.9922:  60%|██████    | 12/20 [18:44<12:42, 95.26s/it]


(train)
True tags: genre: 70s, blues, ambient 
Predicted tags: genre rock pop pop instrument instrument guitar guitar

True tags: genre: classical, soundtrack; instrument: electricpiano, piano; mood/theme: melodic, reflective 
Predicted tags: genre classical classical instrument instrument piano piano

(eval)
True tags: genre: ambient 
Predicted tags: genre ambient soundtrack soundtrack instrument instrument strings synthesizer piano piano

True tags: genre: hardrock, rock, indie, instrumentalrock; instrument: bass, drums, synthesizer, electricguitar 
Predicted tags: genre rock metal hardrock punkrock instrument instrument electricguitar electricguitar bass bass drums drums





Epoch 13, train loss 1.9733, eval loss 1.9906:  65%|██████▌   | 13/20 [20:18<11:03, 94.75s/it]


(train)
True tags: genre: electronic 
Predicted tags: genre electronic trance trance moodtheme instrument uplifting space

True tags: genre: techno, trance; instrument: accordion, computer, keyboard 
Predicted tags: genre electronic electronic

(eval)
True tags: genre: alternative, pop, indie 
Predicted tags: genre pop pop rock instrument instrument drums synthesizer piano piano bass

True tags: genre: world, soundtrack, jazz; instrument: saxophone, hang; mood/theme: society 
Predicted tags: genre ambient ambient atmospheric instrument instrument synthesizer synthesizer





Epoch 14, train loss 2.2170, eval loss 1.9941:  70%|███████   | 14/20 [21:52<09:27, 94.61s/it]


(train)
True tags: genre: orchestral, soundtrack, classical; mood/theme: dark, epic, film 
Predicted tags: genre soundtrack classical classical moodtheme instrument adventure adventure film action action documentary

True tags: genre: poprock, rock 
Predicted tags: genre pop pop instrument instrument keyboard synthesizer synthesizer piano piano acousticguitar

(eval)
True tags: genre: electronic; instrument: beat, synthesizer; mood/theme: drive, energetic 
Predicted tags: genre electronic electronic soundtrack instrument instrument synthesizer synthesizer

True tags: genre: classical; instrument: piano, violin, oboe, drums; mood/theme: christmas 
Predicted tags: genre pop pop rock instrument instrument drums synthesizer piano piano guitar





Epoch 15 | Train batch 58/421:  70%|███████   | 14/20 [22:04<09:27, 94.61s/it]

### Model assessment
Predict with the trained LSTM model

In [None]:
train_idx, test_idx = np.random.randint(len(training_data)), np.random.randint(len(test_data))
print("\n(train)")
generate_caption(train_idx, train=True, show_true_caption=True) 
print("(eval)")
generate_caption(test_idx, train=False, show_true_caption=True)
print("\n")

In [None]:
train_idx, test_idx = np.random.randint(len(training_data)), np.random.randint(len(test_data))
print("\n(train)")
generate_caption(train_idx, train=True, show_true_caption=True) 
print("(eval)")
generate_caption(test_idx, train=False, show_true_caption=True)
print("\n")