# 0. Install Dependencies and Dataset

## 0.1 Install Python packages

In [1]:
%%bash
rm -rf sample_data

pip -q install transformers
pip -q install datasets
pip -q install tqdm
pip -q install sentencepiece 
pip -q install kaggle
pip -q install sacrebleu

## 0.2 Download the dataset using the Kaggle API

In [2]:
%%bash
mkdir ~/.kaggle/
cd ~/.kaggle/
touch kaggle.json
chmod 600 /root/.kaggle/kaggle.json
echo "{\"username\":\"cwcrystal8\",\"key\":\"75ba9516cfea9c5de8e657080f7428bd\"}" > kaggle.json

cd /content/

kaggle datasets download -d mateibejan/multilingual-lyrics-for-genre-classification
unzip multilingual-lyrics-for-genre-classification.zip
rm test.csv
rm multilingual-lyrics-for-genre-classification.zip
mv train.csv data.csv

Downloading multilingual-lyrics-for-genre-classification.zip to /content

Archive:  multilingual-lyrics-for-genre-classification.zip
  inflating: test.csv                
  inflating: train.csv               


mkdir: cannot create directory ‘/root/.kaggle/’: File exists
  0%|          | 0.00/103M [00:00<?, ?B/s]  9%|8         | 9.00M/103M [00:00<00:01, 60.1MB/s] 24%|##4       | 25.0M/103M [00:00<00:01, 68.5MB/s] 40%|###9      | 41.0M/103M [00:00<00:00, 82.9MB/s] 48%|####7     | 49.0M/103M [00:00<00:00, 80.5MB/s] 63%|######3   | 65.0M/103M [00:00<00:00, 94.8MB/s] 81%|########  | 83.0M/103M [00:00<00:00, 111MB/s] 100%|##########| 103M/103M [00:00<00:00, 124MB/s] 


## 0.3 Clone the Github repo

In [3]:
%%bash
cd /content/
!(stat -t /usr/local/lib/*/dist-packages/google/colab > /dev/null 2>&1) && exit
rm -rf 6.864-lyric-analysis
git clone "https://github.com/sohinik/6.864-lyric-analysis.git"

Cloning into '6.864-lyric-analysis'...


# 1. Clean and Format the Data

In [4]:
import numpy as np
import os
import random
import torch
from torch import cuda

os.chdir("/content/6.864-lyric-analysis")
print("Current Working Directory:", os.getcwd())

seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

if cuda.is_available():
  device = 'cuda'
  torch.cuda.manual_seed_all(seed)
else:
  print('WARNING: you are running this assignment on a cpu!')
  device = 'cpu'

Current Working Directory: /content/6.864-lyric-analysis


## 1.1 Get data from dataset

In [5]:
import utils
from data_processing import get_data
from utils import save_model, load_model

## Data Hyperparameters
clean_genre=True
genres=["Metal", "Jazz"]
num_included=None
num_words_per_stanza = None # we don't want to separate stanzas since we need the entire song
training_ratio = 0.9

raw_data, train_dict, test_dict = get_data("../data.csv", 
                                 clean_genre = clean_genre,
                                 genres = genres,
                                 num_included = num_included,
                                 num_words_per_stanza = num_words_per_stanza,
                                 training_ratio = training_ratio)

## 1.2 Get set of lyrics separated by newlines

In [6]:
def get_lyric_set(lyrics):
  return set([line.lower() for song in lyrics for line in song.split("\n") if len(line.strip("(),./\\|{}<>-_`~'\":;?!#@$%^&*")) and 0 < song.strip().find("\n")])

def split_by_genre(lyrics, labels):
  genre_lyrics = {}
  for i in range(len(lyrics)):
    label = labels[i]
    if label in genre_lyrics:
      genre_lyrics[label].append(lyrics[i])
    else:
      genre_lyrics[label] = [lyrics[i]]
  for genre in genre_lyrics:
    genre_lyrics[genre] = get_lyric_set(genre_lyrics[genre])
  return genre_lyrics

genre_lyrics_train = split_by_genre(train_dict['lyrics'], train_dict['labels'])
genre_lyrics_test = split_by_genre(test_dict['lyrics'], test_dict['labels'])

In [7]:
# train_dict['lyrics'] = [line.lower() for song in train_dict['lyrics'] for line in song.split("\n") if len(line.strip("(),./\\|{}<>-_`~'\":;?!#@$%^&*")) and 0 < song.strip().find("\n")]
# test_dict['lyrics'] = [line.lower() for song in test_dict['lyrics'] for line in song.split("\n") if len(line.strip("(),./\\|{}<>-_`~'\":;?!#@$%^&*")) and 0 < song.strip().find("\n")]

genre_lyrics_train['Jazz'] = random.sample(genre_lyrics_train['Jazz'], 12000)
genre_lyrics_train['Metal'] = random.sample(genre_lyrics_train['Metal'], 12000)
genre_lyrics_test['Jazz'] = random.sample(genre_lyrics_test['Jazz'], 3000)
genre_lyrics_test['Metal'] = random.sample(genre_lyrics_test['Metal'], 3000)

# train_dict['lyrics'] = train_dict['lyrics'][:15000]
# test_dict['lyrics'] = test_dict['lyrics'][:10000]

## 1.3 Define Dataset class

In [8]:
from torch.utils import data
from collections import Counter
from torch.nn import functional as F

# These IDs are reserved.
PAD_INDEX = 0
UNK_INDEX = 1
SOS_INDEX = 2
EOS_INDEX = 3

MAX_SENT_LENGTH_PLUS_SOS_EOS = 100 # from max of folk and metal max lengths

def get_vocab(data_list, size=7000):
  word_freq = Counter(word.strip("(),./\\|{}<>-_`~'\":;?!#@$%^&*").lower() for lyrics in data_list for word in lyrics.split())
  del word_freq[""]
  # print(word_freq.most_common(20))
  return list(pair[0] for pair in word_freq.most_common(size))

class SongDataset(data.Dataset):
  def __init__(self, songs, vocab, max_size=None):
    self.songs = list(songs)
    self.max_src_seq_length = MAX_SENT_LENGTH_PLUS_SOS_EOS
    self.v2id = {v : i + 4 for i, v in enumerate(vocab)}
    self.v2id['<pad>'] = PAD_INDEX
    self.v2id['<unk>'] = UNK_INDEX
    self.v2id['<sos>'] = SOS_INDEX
    self.v2id['<eos>'] = EOS_INDEX
    self.vocab = vocab
    self.vocab_size = len(self.v2id)
    self.id2v = {val : key for key, val in self.v2id.items()}

  def __len__(self):
    return len(self.songs)

  def __getitem__(self, index):
    return self.get_items(self.songs[index])

  def get_items(self, s):
    song = s.lower().split()
    song_len = len(song) + 2   # add <s> and </s> to each sentence
    song_ids = []
    for word in song:
      w = word.strip("(),./\\|{}<>-_`~'\":;?!#@$%^&*")
      if not len(w):
        song_len -= 1
        continue
      if w not in self.vocab:
        w = '<unk>'
      song_ids.append(self.v2id[w])
    song_ids = ([SOS_INDEX] + song_ids + [EOS_INDEX] + [PAD_INDEX] *
              (self.max_src_seq_length - song_len))
    attn_mask = ([1] * (song_len) + [0] * (self.max_src_seq_length - song_len))

    song_ids = torch.tensor(song_ids)

    song_vecs = F.one_hot(song_ids, num_classes=self.vocab_size).float()

    return song_vecs[:100], song_ids[:100], torch.tensor(min(song_len, 100)), torch.tensor(attn_mask[:100])

## 1.4 Make datasets for each genre

In [9]:
vocab = get_vocab(genre_lyrics_train['Jazz'] + genre_lyrics_train['Metal'] + genre_lyrics_test['Jazz'] + genre_lyrics_test['Metal'])
datasets_train = {genre: SongDataset(lyrics, vocab) for genre, lyrics in genre_lyrics_train.items()}
datasets_test = {genre: SongDataset(lyrics, vocab) for genre, lyrics in genre_lyrics_test.items()}

# jazz_dataset_train = datasets_train["Jazz"]
# jazz_dataset_test = datasets_test["Jazz"]
# metal_dataset_train = datasets_train["Metal"]
# metal_dataset_test = datasets_test["Metal"]

# 2. Design Model


## 2.1 Encoder 

In [10]:
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class Encoder(nn.Module):
    """Encodes a sequence of word embeddings"""
    def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.):
        super(Encoder, self).__init__()
        self.num_layers = num_layers
        self.rnn = nn.GRU(input_size, hidden_size, num_layers, 
                          batch_first=True, bidirectional=True, dropout=dropout)
        # self.rnn = nn.GRU(input_size, hidden_size, num_layers, 
        #                   batch_first=True, bidirectional=False, dropout=dropout)
        
    def forward(self, x, mask, lengths):
        """
        Applies a bidirectional GRU to sequence of embeddings x.
        The input mini-batch x needs to be sorted by length.
        x should have dimensions [batch, time, dim].
        """
        packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
        output, final = self.rnn(packed)
        output, _ = pad_packed_sequence(output, batch_first=True, total_length=MAX_SENT_LENGTH_PLUS_SOS_EOS)

        # we need to manually concatenate the final states for both directions
        fwd_final = final[0:final.size(0):2]
        bwd_final = final[1:final.size(0):2]
        final = torch.cat([fwd_final, bwd_final], dim=2)  # [num_layers, batch, 2*dim]

        return output, final

## 2.2 Decoder

In [11]:
class Decoder(nn.Module):
    """A conditional RNN decoder with attention."""
    
    def __init__(self, emb_size, hidden_size, attention=None, num_layers=1, dropout=0.5,
                 bridge=True):
        super(Decoder, self).__init__()
        
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.attention = attention
        self.dropout = dropout
                 
        self.rnn = nn.GRU(emb_size + 2*hidden_size, hidden_size, num_layers,
                          batch_first=True, dropout=dropout)

        # self.rnn = nn.GRU(emb_size, hidden_size, num_layers,
        #                   batch_first=True, dropout=dropout)
                 
        # to initialize from the final encoder state
        self.bridge = nn.Linear(2*hidden_size, hidden_size, bias=True) if bridge else None

        self.dropout_layer = nn.Dropout(p=dropout)
        self.pre_output_layer = nn.Linear(hidden_size + 2*hidden_size + emb_size,
                                          hidden_size, bias=False)
        # self.pre_output_layer = nn.Linear(hidden_size + emb_size,
        #                                    hidden_size, bias=False)
        
    def forward_step(self, prev_embed, encoder_hidden, src_mask, hidden, proj_key): #proj_key
        """Perform a single decoder step (1 word)"""


        # compute context vector using attention mechanism
        query = hidden[-1].unsqueeze(1)  # [#layers, B, D] -> [B, 1, D]

        context, attn_probs = self.attention(
            query=query, proj_key=proj_key,
            value=encoder_hidden, mask=src_mask)

        # update rnn hidden state
        rnn_input = torch.cat([prev_embed, context], dim=2)
        output, hidden = self.rnn(rnn_input, hidden)
        # output, hidden = self.rnn(prev_embed, hidden)
        
        pre_output = torch.cat([prev_embed, output, context], dim=2)
        # pre_output = torch.cat([prev_embed, output], dim=2)
        pre_output = self.dropout_layer(pre_output)
        pre_output = self.pre_output_layer(pre_output)

        return output, hidden, pre_output
    
    def forward(self, trg_embed, encoder_hidden, encoder_final, 
                src_mask, trg_mask, hidden=None, max_len=None):
        """Unroll the decoder one step at a time."""
                                         
        # the maximum number of steps to unroll the RNN
        if max_len is None:
            # max_len = trg_mask.size(-1)
            max_len = trg_embed.size(1)

        # initialize decoder hidden state
        if hidden is None:
            hidden = self.init_hidden(encoder_final)
        
        # pre-compute projected encoder hidden states
        # (the "keys" for the attention mechanism)
        # this is only done for efficiency
        proj_key = self.attention.key_layer(encoder_hidden)
        
        # here we store all intermediate hidden states and pre-output vectors
        decoder_states = []
        pre_output_vectors = []
        
        # unroll the decoder RNN for max_len steps
        for i in range(max_len):

            # prev_embed = trg_embed[:, i].unsqueeze(1)
            prev_embed = trg_embed[:,i:i+1,:]
            pre_output, hidden, output = self.forward_step(
              prev_embed, encoder_hidden, src_mask, hidden, proj_key)
            decoder_states.append(output)
            pre_output_vectors.append(pre_output)

        decoder_states = torch.cat(decoder_states, dim=1)
        pre_output_vectors = torch.cat(pre_output_vectors, dim=1)
        return decoder_states, hidden, pre_output_vectors  # [B, N, D]

    def init_hidden(self, encoder_final):
        """Returns the initial decoder state,
        conditioned on the final encoder state."""

        if encoder_final is None:
            return None  # start with zeros

        return torch.tanh(self.bridge(encoder_final))


## 2.3 Attention

In [12]:
class BahdanauAttention(nn.Module):
    """Implements Bahdanau (MLP) attention"""
    
    def __init__(self, hidden_size, key_size=None, query_size=None):
        super(BahdanauAttention, self).__init__()
        
        # We assume a bi-directional encoder so key_size is 2*hidden_size
        key_size = 2 * hidden_size if key_size is None else key_size
        query_size = hidden_size if query_size is None else query_size

        self.key_layer = nn.Linear(key_size, hidden_size, bias=False)
        self.query_layer = nn.Linear(query_size, hidden_size, bias=False)
        self.energy_layer = nn.Linear(hidden_size, 1, bias=False)
        
        # to store attention scores
        self.alphas = None
        
    def forward(self, query=None, proj_key=None, value=None, mask=None):
        assert mask is not None, "mask is required"
        
        # We first project the query (the decoder state).
        # The projected keys (the encoder states) were already pre-computated.
        query = self.query_layer(query)
        
        # Calculate scores.
        scores = self.energy_layer(torch.tanh(query + proj_key))
        scores = scores.squeeze(2).unsqueeze(1)
        
        # Mask out invalid positions.
        # The mask marks valid positions so we invert it using `mask & 0`.
        scores.data.masked_fill_(mask == 0, -float('inf'))
        
        # Turn scores to probabilities.
        alphas = F.softmax(scores, dim=-1)
        self.alphas = alphas        
        
        # The context vector is the weighted sum of the values.
        context = torch.bmm(alphas, value)
        
        # context shape: [B, 1, 2D], alphas shape: [B, 1, M]
        return context, alphas

## 2.4 EncoderDecoderLyricGenerator

In [13]:
class EncoderDecoderLyricGenerator(nn.Module):
    """
    A standard Encoder-Decoder architecture. Base for this and many 
    other models.
    """

    def __init__(self, encoders, decoders, generators, src_embeds, trg_embeds):
        '''
        Assumes that genres are jazz and metal.
        '''
        super(EncoderDecoderLyricGenerator, self).__init__()
        self.jazz_encoder = encoders['Jazz']
        self.metal_encoder = encoders['Metal']
        self.jazz_decoder = decoders['Jazz']
        self.metal_decoder = decoders['Metal']
        self.jazz_generator = generators['Jazz']
        self.metal_generator = generators['Metal']
        self.jazz_src_embed = src_embeds['Jazz']
        self.metal_src_embed = src_embeds['Metal']
        self.jazz_trg_embed = trg_embeds['Jazz']
        self.metal_trg_embed = trg_embeds['Metal']
        
    def forward(self, src, trg, src_mask, trg_mask, src_lengths, trg_lengths, src_genre, trg_genre):
        """Take in and process masked src and target sequences."""
        encoder_hidden, encoder_final = self.encode(src, src_mask, src_lengths, src_genre)
        return self.decode(encoder_hidden, encoder_final, src_mask, trg[:,:-1], trg_mask, trg_genre)
    
    def encode(self, src, src_mask, src_lengths, src_genre):
        encoder = self.jazz_encoder if src_genre == "Jazz" else self.metal_encoder
        embed = self.jazz_src_embed if src_genre == 'Jazz' else self.metal_src_embed
        return encoder(embed(src), src_mask, src_lengths)
    
    def decode(self, encoder_hidden, encoder_final, src_mask, trg, trg_mask, trg_genre,
               decoder_hidden=None):
        decoder = self.jazz_decoder if trg_genre == "Jazz" else self.metal_decoder
        embed = self.jazz_trg_embed if trg_genre == 'Jazz' else self.metal_trg_embed
        return decoder(embed(trg), encoder_hidden, encoder_final,
                            src_mask, trg_mask, hidden=decoder_hidden)
        
    def generate(self, outputs, genre):
        generator = self.jazz_generator if genre == "Jazz" else self.metal_generator
        return generator(outputs)

class Generator(nn.Module):
        """Define standard linear + softmax generation step."""
        def __init__(self, hidden_size, vocab_size):
            super(Generator, self).__init__()
            self.proj = nn.Linear(hidden_size, vocab_size, bias=False)

        def forward(self, x):
            return F.log_softmax(self.proj(x), dim=-1)

# 3. Train the Model

## 3.1 Create the model

In [14]:
emb_size=256
hidden_size=256
num_layers=2
dropout=0.2

vocab_size = len(vocab) + 4 # to account for tokens

genres_for_model = frozenset(genres)
encoders = {genre: Encoder(emb_size, hidden_size, dropout=dropout, num_layers=num_layers) for genre in genres_for_model}
decoders = {genre: Decoder(emb_size, hidden_size, attention=BahdanauAttention(hidden_size),dropout=dropout,num_layers=num_layers) for genre in genres_for_model}
generators = {genre: Generator(hidden_size, vocab_size) for genre in genres_for_model}
src_embeds = {genre: nn.Linear(vocab_size, emb_size) for genre in genres_for_model}
trg_embeds = {genre: nn.Linear(vocab_size, emb_size) for genre in genres_for_model}

model = EncoderDecoderLyricGenerator(
        encoders,
        decoders,
        generators,
        src_embeds,
        trg_embeds,
        ).to(device)

## 3.2 Define classifier model

In [15]:
import torch.nn as nn

class ModelOutputs:
    def __init__(self, logits = None, loss=None):
        self.logits = logits
        self.loss = loss

class GenreClassificationModel(nn.Module):
  def __init__(self, lm, num_labels, dropout=0.2, num_layers = 1, is_bidirectional = False):
    super(GenreClassificationModel, self).__init__()
    # (batch_size, num_tokens)
    # (batch_size, num_tokens, hidden_size)
    # (batch_size, 1 , hidden_size)
    # (batch_size, 1, num_labels)

    self.lm = lm
    self.dropout = nn.Dropout(dropout)
    self.encoder = nn.GRU(
        input_size  = lm.config.hidden_size,
        hidden_size = lm.config.hidden_size,
        num_layers = num_layers,
        batch_first = True,
        bidirectional = is_bidirectional,
        dropout = dropout
        )
    self.classifier = nn.Linear(lm.config.hidden_size, num_labels)
    self.bidirectional = is_bidirectional
    self.num_layers = num_layers
    

  def forward(self, input_ids, attn_mask, labels = None):
    '''
    Inputs;
    input_ids: (batch_size, num_tokens) tensor of input_ids
    attn_mask: (batch_size, num_tokens) tensor 
    labels (optional): (batch_size,) tensor


    Outputs:
    label_logits: (batch_size, num_labels) tensor of logits
    '''

    lm_outputs = self.lm(input_ids, attn_mask)
    hidden_states = lm_outputs.last_hidden_state
    
    hidden_states = self.dropout(hidden_states)

    _, hidden_states = self.encoder(hidden_states)

    if not self.bidirectional:
      hidden_states = hidden_states[-1]
    else: 
      hidden_states = torch.sum(hidden_states[-2:], dim=0)
    logits = self.classifier(hidden_states)

    loss = None

    if labels is not None:
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(logits, labels)
    
    return ModelOutputs(
        logits = logits,
        loss = loss)

## 3.3 Set the hyperparameters

In [16]:
# Hyper-parameters: you could try playing with different settings
num_epochs = 5
learning_rate = 1e-3
batch_size = 64

## 3.4 Define loss functions

In [17]:
class SelfLoss:
    """A self loss compute and train function."""

    def __init__(self, model, criterion, opt=None):
      self.model = model
      self.criterion = criterion
      self.opt = opt

    def __call__(self, x, y, norm=1, genre=None):
        '''
        x: output of our encoder-decoder (logits of each word)
        y: input of our encoder-decoder
        '''
        x = self.model.generate(x, genre)
        loss = self.criterion(x.contiguous().view(-1, x.size(-1)),
                              y.contiguous().view(-1))
        loss = loss / norm

        if self.opt is not None:
            loss.backward()          
            self.opt.step()
            self.opt.zero_grad()

        return loss.data.item() * norm

In [18]:
class CycleLoss:
    """A cycle loss compute and train function."""

    def __init__(self, model, criterion, opt=None):
      self.model = model
      self.criterion = criterion
      self.opt = opt

    def __call__(self, x, y, norm=1, genre=None):
        '''
        x: output of our encoder-decoder (logits of each word)
        y: input of our encoder-decoder
        '''
        x = self.model.generate(x, genre)
        loss = self.criterion(x.contiguous().view(-1, x.size(-1)),
                              y.contiguous().view(-1))
        loss = loss / norm

        if self.opt is not None:
            loss.backward()          
            self.opt.step()
            self.opt.zero_grad()

        return loss.data.item() * norm

In [19]:
# class ClassificationLoss:
#     """A classification model loss compute and train function."""

#     def __init__(self, classifier, criterion, opt=None):
#       self.classifier = classifier
#       self.criterion = criterion
#       self.opt = opt

#     def __call__(self, x, y, norm=1):
#         '''
#         x: output of our encoder-decoder (logits of each word)
#         y: input of our encoder-decoder
#         '''
#         # x = self.generator(x)
#         # x = torch.argmax(x, dim=-1)

#         self.classifier.eval()

#         prediction = torch.argmax(self.classifier(x)[2:4])
#         loss = self.criterion(prediction, y)

#         loss = loss / norm

#         if self.opt is not None:
#             loss.backward()          
#             self.opt.step()
#             self.opt.zero_grad()

#         return loss.data.item() * norm

In [20]:
def greedy_decode_text_style_transfer_batch(model, src_ids, src_lengths, src_mask, src_genre, trg_genre, max_len, grad=True):
  """Greedily decode a sentence for EncoderDecoder. Make sure to chop off the 
     EOS token!"""

  batch_size = src_ids.size()[0]
  if not grad:
    with torch.no_grad():
      encoder_hidden, encoder_finals = model.encode(src_ids, src_mask, src_lengths, src_genre)
      prev_y = F.one_hot(torch.ones(batch_size, 1, dtype=torch.long).fill_(SOS_INDEX), vocab_size).to(device).float()
      mask = torch.ones(batch_size, 1).type_as(src_ids).to(device)
  else:
    encoder_hidden, encoder_finals = model.encode(src_ids, src_mask, src_lengths, src_genre)
    prev_y = F.one_hot(torch.ones(batch_size, 1, dtype=torch.long).fill_(SOS_INDEX), vocab_size).to(device).float()
    mask = torch.ones(batch_size, 1).type_as(src_ids).to(device)

  output = []
  hidden = None
  
  # --------- Your code here --------- #
  for i in range(max_len):
    outputs, hidden, cur_output = model.decode(encoder_hidden, encoder_finals, src_mask, prev_y, mask, trg_genre, decoder_hidden=hidden)
    cur_output = model.generate(cur_output, trg_genre)
    output.append(cur_output)
    prev_y = cur_output
  # --------- Your code ends --------- #

  return torch.cat(output, dim=1)

In [21]:
def decode_to_string(ids, dataset):
  words = [dataset.id2v[i.item()] for i in ids]
  return " ".join(words)
# def decode_to_string_batch(batch_ids, dataset):
#   # input batch_size x max len x 1
#   output = []
#   for ids in batch_ids:
#     output.append(" ".join([dataset.id2v[i.item()] for i in ids]))
#   # batch size tensor of strings 
#   return output

## 3.5 Train

In [22]:
def run_epoch(data_loader, model, loss_compute, cycle_loss_compute, print_every, src_genre, int_genre):
  total_tokens = 0
  total_loss = []

  for i, (song_vecs, song_ids, song_lengths, song_attn) in enumerate(data_loader):

    song_vecs=song_vecs.to(device)
    song_ids = song_ids.to(device)
    song_attn = song_attn.to(device)
    song_attn = song_attn.unsqueeze(-2)

    _, _, self_output = model(song_vecs, song_vecs, song_attn, song_attn, song_lengths, song_lengths, src_genre, src_genre)

    loss0 = loss_compute(x=self_output, y=song_ids[:, 1:],
                        norm=song_vecs.size(0), genre=src_genre)
    
    cycle_pre_output = greedy_decode_text_style_transfer_batch(model, song_vecs, song_lengths, song_attn, src_genre, int_genre, MAX_SENT_LENGTH_PLUS_SOS_EOS)

    int_attn_mask = torch.ones(song_attn.shape).to(device)
    int_lengths = torch.ones(batch_size).fill_(MAX_SENT_LENGTH_PLUS_SOS_EOS)

    _, _, cycle_output = model(cycle_pre_output, song_vecs, int_attn_mask, song_attn, int_lengths, song_lengths, src_genre, int_genre)
    loss1 = cycle_loss_compute(x=cycle_output, y=song_ids[:, 1:],
                        norm=song_vecs.size(0), genre=src_genre)

    # decoded_outputs = greedy_decode_text_style_transfer_batch(model, src_ids, src_lengths, src_mask, src_genre, trg_genre, max_len)
    # decoded_strings = decode_to_string_batch(decoded_outputs, datasets_train[trg_genre])

    loss = (loss0 + loss1)/2
    total_loss.append(loss)

    if model.training and i % print_every == 0:
      print("Epoch Step: %d Loss: %f" % (i, loss / song_ids.size(0)))

  return total_loss


In [23]:
train = False
print_every = 50

if train:
  data_loaders = {genre: data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2) for genre, dataset in datasets_train.items()}
  
  criterion = nn.NLLLoss(reduction='sum', ignore_index=PAD_INDEX)
  criterion_cycle = nn.NLLLoss(reduction='sum', ignore_index=PAD_INDEX)
  # criterion_classifier = nn.CrossEntropyLoss()

  optim = torch.optim.Adam(model.parameters(), lr=learning_rate)
  
  self_loss_fn = SelfLoss(model,criterion, optim)
  cycle_loss_fn = CycleLoss(model,criterion, optim)
  # classifier_loss_fn = ClassifierLoss(classifier_model, criterion_classifier, optim)

  # folk_losses = []
  jazz_losses = []
  metal_losses = []
  losses = []

  step_id = 0

  print("model training")

  print(len(data_loaders['Jazz']))
  print(len(data_loaders['Metal']))

  for _ in range(num_epochs):

    model.train()
    epoch_loss_jazz = run_epoch(data_loader=data_loaders['Jazz'], model=model,
                loss_compute=self_loss_fn, cycle_loss_compute=cycle_loss_fn,
                print_every=print_every, src_genre='Jazz', int_genre='Metal')
    epoch_loss_metal = run_epoch(data_loader=data_loaders['Metal'], model=model,
                loss_compute=self_loss_fn, cycle_loss_compute=cycle_loss_fn, 
                print_every=print_every, src_genre='Metal', int_genre='Jazz')

    jazz_losses.extend(epoch_loss_jazz)
    metal_losses.extend(epoch_loss_metal)
    losses.extend(epoch_loss_jazz)
    losses.extend(epoch_loss_metal)
  print('model training complete')  
  utils.save_model(model, "/content/genre2genre.pt")

else:
  model.load_state_dict(torch.load("/content/genre2genre.pt"))


In [24]:
import matplotlib.pyplot as plt

if train:

  plt.plot(losses)
  plt.legend()
  plt.xlabel("Step")
  plt.ylabel("Loss")
  plt.title("Step vs loss for genre-to-genre generation model")
  plt.show()

# 4. Evaluate the Model


## 4.1 Calculate the basic cross entropy loss

In [25]:
test_batch_size = 1

datasets_test = {genre: SongDataset(lyrics, vocab) for genre, lyrics in genre_lyrics_test.items()}
data_loaders_test = {genre: data.DataLoader(dataset, batch_size=test_batch_size, shuffle=True) for genre, dataset in datasets_test.items()}

## 4.2 Qualitative analysis using greedy decode

In [32]:
model.eval()

total_loss = 0

# confusion_matrix = torch.zeros((num_labels, num_labels)).to(device)
num_test_batches = 0

jazz_to_jazz = []
jazz_to_metal = []
metal_to_jazz = []
metal_to_metal = []

for genre, data_loader in data_loaders_test.items():

    int_genre = "Jazz" if genre == "Metal" else "Metal"

    for i, (song_vecs, song_ids, song_lengths, song_attn) in enumerate(data_loader):

        if i > 100:
          break

        song_ids = song_ids.to(device)
        song_vecs = song_vecs.to(device)
        song_attn = song_attn.to(device)

        # output = greedy_decode(model, song_ids, song_lengths, song_attn, MAX_SENT_LENGTH_PLUS_SOS_EOS)
        output = greedy_decode_text_style_transfer_batch(model, song_vecs, song_lengths, song_attn, genre, genre, MAX_SENT_LENGTH_PLUS_SOS_EOS, grad=False)
        output = torch.argmax(output, dim=-1).squeeze()

        # print(genre, genre)
        # print("input:", decode_to_string(song_ids[0,1:song_lengths-1], datasets_test[genre]))
        # print("output:", decode_to_string(output, datasets_test[genre]))
        # print()
        if genre == 'Jazz':
          jazz_to_jazz.extend(decode_to_string(output, datasets_test[genre]).split())
        else:
          metal_to_metal.extend(decode_to_string(output, datasets_test[genre]).split())

        output = greedy_decode_text_style_transfer_batch(model, song_vecs, song_lengths, song_attn, genre, int_genre, MAX_SENT_LENGTH_PLUS_SOS_EOS, grad=False)
        output = torch.argmax(output, dim=-1).squeeze()

        # print(genre, int_genre)
        # print("input:", decode_to_string(song_ids[0,1:song_lengths-1], datasets_test[genre]))
        # print("output:", decode_to_string(output, datasets_test[genre]))
        # print()
        if genre == 'Jazz':
          jazz_to_metal.extend(decode_to_string(output, datasets_test[genre]).split())
        else:
          metal_to_jazz.extend(decode_to_string(output, datasets_test[genre]).split())

jazz_to_jazz = Counter(jazz_to_jazz)
jazz_to_metal = Counter(jazz_to_metal)
metal_to_jazz = Counter(metal_to_jazz)
metal_to_metal = Counter(metal_to_metal)

print(jazz_to_jazz.most_common(10))
print(jazz_to_metal.most_common(10))
print(metal_to_jazz.most_common(10))
print(metal_to_metal.most_common(10))

print()

[('till', 1298), ('there’s', 1054), ('oh', 951), ('air', 659), ('star', 648), ("what's", 565), ('silence', 460), ('cut', 380), ('hands', 379), ('ah', 373)]
[('until', 9367), ('lonely', 589), ('grace', 27), ('answer', 11), ('sound', 9), ('burn', 8), ('sad', 7), ('say', 6), ('greater', 6), ('thought', 5)]
[('there’s', 9970), ('murderer', 64), ('<eos>', 20), ("what's", 6), ('and', 5), ('the', 4), ('for', 3), ('soon', 3), ('<unk>', 2), ('down', 2)]
[('until', 2432), ('wanna', 1472), ('as', 1151), ('lonely', 1142), ("can't", 587), ("it's", 405), ('do', 261), ('never', 207), ('<unk>', 197), ('will', 196)]



## 4.3 BLEU score

In [27]:
import sacrebleu
from tqdm import tqdm

def compute_BLEU(model, data_loader, decoder, dataset, genre):

  bleu_score = []

  model.eval()
  for src_vecs, src_ids, src_lengths, src_mask in tqdm(data_loader):
    try:
      result = decoder(model, src_vecs.to(device), src_lengths, src_mask.to(device),
                          genre, genre,
                          max_len=MAX_SENT_LENGTH_PLUS_SOS_EOS)
      
      result = torch.argmax(result, dim=-1).squeeze()
      
      # remove <s>
      src_ids = src_ids[0, 1:]
      # remove </s> and <pad>
      src_ids = src_ids[:np.where(src_ids == EOS_INDEX)[0][0]]

      pred = decode_to_string(result, dataset)
      targ = decode_to_string(src_ids, dataset)

      if targ:
        bleu_score.append(sacrebleu.raw_corpus_bleu([pred], [[targ]]).score)
    
    except:
      continue

  return bleu_score

In [28]:
print('Jazz BLEU score: %f' % (np.mean(compute_BLEU(model, 
                                            data_loaders_test["Jazz"],
                                            greedy_decode_text_style_transfer_batch, 
                                            datasets_test["Jazz"], "Jazz"))))
print('Metal BLEU score: %f' % (np.mean(compute_BLEU(model, 
                                            data_loaders_test["Metal"],
                                            greedy_decode_text_style_transfer_batch, 
                                            datasets_test["Metal"], "Metal"))))

100%|██████████| 3000/3000 [04:37<00:00, 10.83it/s]
  0%|          | 2/3000 [00:00<04:48, 10.38it/s]

Jazz BLEU score: 0.155308


100%|██████████| 3000/3000 [04:38<00:00, 10.76it/s]

Metal BLEU score: 0.168149





In [29]:
def compute_BLEU_cycle(model, data_loader, decoder, og_dataset, intermediate_dataset, genre, int_genre):

  bleu_score = []

  model.eval()
  for src_vecs, src_ids, src_lengths, src_mask in tqdm(data_loader):
    try:
      if not len(src_ids.size()):
        continue

      int_result = decoder(model, src_vecs.to(device), src_lengths, src_mask.to(device), genre, int_genre,
                          max_len=MAX_SENT_LENGTH_PLUS_SOS_EOS)
      
      int_result = torch.argmax(int_result, dim=-1).squeeze()
      int_result = decode_to_string(int_result, intermediate_dataset)
      
      new_vecs, new_ids, new_lengths, new_mask = intermediate_dataset.get_items(int_result)

      new_vecs = new_vecs[None, :]
      new_ids = new_ids[None,:]
      new_lengths = new_lengths[None]
      new_mask = new_mask[None,:]

      # print(new_ids.shape, new_mask.shape, new_lengths.shape)
      result = decoder(model, new_vecs.to(device), new_lengths, new_mask.to(device), genre, int_genre,
                          max_len=MAX_SENT_LENGTH_PLUS_SOS_EOS)
      result = torch.argmax(result, dim=-1).squeeze() if len(result.shape) > 1 else result
      
      # remove <s>
      src_ids = src_ids[0, 1:]

      # remove </s> and <pad>
      src_ids = src_ids[:np.where(src_ids == EOS_INDEX)[0][0]]

      pred = decode_to_string(result, og_dataset)
      targ = decode_to_string(src_ids, og_dataset)

      if targ:
        bleu_score.append(sacrebleu.raw_corpus_bleu([pred], [[targ]]).score)

    except:
      continue

  return bleu_score

In [None]:
print('\nJazz -> Metal -> Jazz BLEU score: %f' % (np.mean(compute_BLEU_cycle(model,
                                            data_loaders_test["Jazz"],
                                            greedy_decode_text_style_transfer_batch, 
                                            datasets_test["Jazz"], datasets_test["Metal"], "Jazz", "Metal"))))
print('\nMetal -> Jazz -> Metal BLEU score: %f' % (np.mean(compute_BLEU_cycle(model,
                                            data_loaders_test["Metal"],
                                            greedy_decode_text_style_transfer_batch, 
                                            datasets_test["Metal"], datasets_test["Jazz"], "Metal", "Jazz"))))

 36%|███▌      | 1065/3000 [03:20<05:53,  5.47it/s]

## 4.2 Calculate the accuracy, precision, and recall

In [None]:
import seaborn as sn
import pandas as pd
import matplotlib.pyplot as plt


def get_recall(confusion_matrix):
  return torch.diag(confusion_matrix) / confusion_matrix.sum(dim = 1)

def get_precision(confusion_matrix):
  return torch.diag(confusion_matrix) / confusion_matrix.sum(dim = 0)

def get_accuracy(confusion_matrix):
  return torch.diag(confusion_matrix).sum() / confusion_matrix.sum()

def plot_confusion_matrix(confusion_matrix):
  df_cm = pd.DataFrame(confusion_matrix.cpu().numpy(), 
                     index = all_labels,
                     columns = all_labels)

  plt.figure(figsize = (10,7))
  sn.heatmap(df_cm, annot=True, fmt='g', cmap='Blues')  
  plt.ylabel("Actual")
  plt.xlabel("Predicted")

def plot_statistics(confusion_matrix):
  accuracy = get_accuracy(confusion_matrix)
  print("Accuracy:", accuracy.item())

  recall = get_recall(confusion_matrix)
  precision = get_precision(confusion_matrix)

  return pd.DataFrame({"Precision": precision.cpu().numpy(), "Recall": recall.cpu().numpy()}, index = all_labels)

plot_confusion_matrix(confusion_matrix)
plot_statistics(confusion_matrix)

In [None]:
plot_statistics(confusion_matrix)

Implementation Notes:
1. Split songs up into stanzas with 150 words each
2. Truncated it at 512 tokens (due to BERT limits)
3. Used GRU to combine information from each word in the song, LSTM does not improve performance
4. Spread across 5 genres as default
5. Equal number of datapoints for each genre
6. Different transformer types - ALBERT did not improve performance

Ideas:
1. use a different embedding (google word2vec or stanford glove)
2. use a different genre

TO DO:
1. change from stanza separation + tokenization + BERT embedding over to using entire songs, embedding them, and encoding + decoding them

In [None]:
utils.save_model(model, "/content/shitty_model.pt")

In [None]:
while True:
  pass