# 0 - Environment Setup

In [18]:
# import package to load BERT model
!pip install transformers

# mount google drive to load dataset 
from google.colab import drive
drive.mount('/content/drive')

# for data handling
import pandas as pd
import numpy as np
from sklearn.utils import shuffle
import string

# pytorch module for model implementation
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.autograd import Variable
from transformers import BertModel, BertTokenizer

from tqdm import tqdm

# For saving model
from collections import OrderedDict
import urllib.request
import pickle

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [13]:
# Set up CUDA if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


# 1 - Load the Data

In [17]:
# Load dataset
train_songs = pd.read_csv('/content/drive/MyDrive/COMP89/train.csv')

# filter to only English songs 
train_songs = train_songs.loc[train_songs.Language == 'en']

# convert to categorical to get numerical classes
train_songs['Genre'] = train_songs['Genre'].astype('category') 


val_counts = train_songs['Genre'].value_counts()

new_songs = train_songs.loc[train_songs.Genre == val_counts.index[0]].sample(n=min(train_songs['Genre'].value_counts()), random_state=42)
for g in val_counts.index[1:]:
  genre_df = train_songs.loc[train_songs.Genre == g]
  genre_df = genre_df.sample(n=min(train_songs['Genre'].value_counts()), random_state=42)

  new_songs = pd.concat([new_songs, genre_df])

new_songs.head()

Unnamed: 0,Artist,Song,Genre,Language,Lyrics
247351,manic street preachers,motorcycle emptiness,Rock,en,Culture sucks down words\nItemize loathing and...
50981,santana,put your lights on,Rock,en,"Hey now, all you sinners\nPut your lights on, ..."
35632,matchbox 20,bent,Rock,en,If I fall along the way\nPick me up and dust m...
273875,count the stars,all good things,Rock,en,"This air is contagious, no one can save us, no..."
29567,killing joke,democracy,Rock,en,"You have a choice, we are your voice\nRed, blu..."


# 2 - Model

## 2.1 - Implementation

In [9]:
# bert word attention
class Word_RNN(nn.Module):
    def __init__(self, hidden_size):
        super(Word_RNN, self).__init__()
        self.hidden_size = hidden_size

        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.bert_model = BertModel.from_pretrained('bert-base-uncased', output_hidden_states = False).to(device)

        self.word_weight = nn.Linear(self.hidden_size, self.hidden_size)
        self.word_context_weight = nn.Linear(self.hidden_size, 1)

    def forward(self, raw_sents):
        sents = torch.zeros(1,len(raw_sents),self.hidden_size).to(device)

        encoded_input = self.tokenizer(raw_sents, return_tensors='pt', padding = True, truncation = True).to(device)
        h = self.bert_model(**encoded_input)[0]
        for i in range(len(raw_sents)):
            h_i = h[i, :, :].unsqueeze(0)
            u_i = torch.tanh(self.word_weight(h_i))
            u_iTw = self.word_context_weight(u_i).squeeze(2)

            attn_weights = F.softmax(u_iTw, dim=1)
            s_i = (attn_weights * h_i.permute(0,2,1)).sum(dim = 2)
            sents[:,i,:] = s_i
        return sents

# sentence attention
class Sent_RNN(nn.Module):
    def __init__(self, word_num_hidden, sentence_num_hidden):
        super(Sent_RNN, self).__init__()
        self.sentence_num_hidden = sentence_num_hidden

        self.lstm = nn.LSTM(word_num_hidden, sentence_num_hidden, bidirectional=True, batch_first = True)

        self.sent_weight = nn.Linear(2*sentence_num_hidden, 2*sentence_num_hidden)
        self.sent_context_weight = nn.Linear(2*sentence_num_hidden, 1)
        
    def forward(self, x):
        h_is, _ = self.lstm(x)
        u_is = torch.tanh(self.sent_weight(h_is))
        u_iTs = self.sent_context_weight(u_is).squeeze(2)
        a_is = F.softmax(u_iTs, dim=1)

        v = (a_is * h_is.permute(0,2,1)).sum(dim = 2)
        return v

class HAN(nn.Module):
    def __init__(self, sentence_num_hidden, word_hidden_size, num_classes):
        super(HAN, self).__init__()

        self.word_attn_rnn = Word_RNN(word_hidden_size)
        self.sent_attn_rnn = Sent_RNN(word_hidden_size, sentence_num_hidden)

        self.linear = nn.Linear(2*sentence_num_hidden, num_classes)
    
    def forward(self, raw_sents):
        word_embeddings = self.word_attn_rnn(raw_sents)
        sent_embeddings = self.sent_attn_rnn(word_embeddings)

        scores = self.linear(sent_embeddings)

        return scores

In [14]:
# Define model hyperparamters
sentence_num_hidden = 256
word_hidden_size = 768
num_classes = new_songs['Genre'].nunique()

# Create instance of model
model = HAN(sentence_num_hidden, word_hidden_size, num_classes).to(device)

# Freeze all BERT layers from training
non_bert_params = []
for name, _param in model.named_parameters():
    if 'bert' not in name:
        non_bert_params.append(_param)
    else:
        _param.requires_grad = False

# Define training paramters
num_epochs = 5
lr = 0.01

# Define loss function and optimiser
criterion = nn.CrossEntropyLoss()
optimiser = torch.optim.SGD(non_bert_params, lr=lr)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## 2.2 - Training

In [None]:
# Helper functions for accuracy
def predict(X):
    preds = []
    model.eval()
    for x in X:
        bars = x.split('/n')
        out = model(bars)
        train_preds.append(torch.argmax(F.softmax(out)).item())
    
    return preds

def Accuracy(preds, label):
    return np.mean(np.array(preds) == np.array(label))

<font color='red'>DO NOT RUN THIS CELL (WILL TAKE > 5 HOURS)</font>

In [21]:
X = new_songs.Lyrics.values.copy()
y = new_songs.Genre.cat.codes.values

model.train()
for epoch in range(num_epochs):
    # shuffle dataset for each epoch
    X, y = shuffle(X, y)

    # SGD
    for x, label in zip (X, y):
        optimiser.zero_grad()

        # split raw text input by each line in song
        bars = x.split('/n')
        out = model(bars)

        label = torch.LongTensor([label]).to(device)

        # find the loss
        train_loss = criterion(out, label)

        #backprop
        train_loss.backward()
        optimiser.step()

# Training Accuracy
print("Train acc: {}".format(Accuracy(predict(X), y)))

Train acc: 0.44253463894


## 2.3 - Save Model

In [None]:
partial_state_dict = OrderedDict()

for param_name in list(model.state_dict().keys()):
    if 'bert' not in param_name:
        partial_state_dict[param_name] = model.state_dict()[param_name]

# export
torch.save(partial_state_dict, '/content/drive/MyDrive/COMP89/partial_model_weights')

# 3 - Evaluation

In [None]:
# Import model weights
url = 'https://github.com/salkhalil/Lyrics2Vec/raw/main/saved_embeddings/partial_model_weights'
urllib.request.urlretrieve(url, './partial_model')

# Load Model
model = HAN(sentence_num_hidden, word_hidden_size, num_classes).to(device)
model.load_state_dict(torch.load('./partial_model'), strict=False)

# Import test set
test_data = pd.read_csv('https://github.com/salkhalil/Lyrics2Vec/raw/main/datasets/cheeky.csv')

X = test_data.Lyrics.values.copy()
y = test_data.Genre.cat.codes.values

print("Test acc: {}".format(Accuracy(predict(X_test), y_test)))

In [None]:
# loss plot

In [None]:
# confusion matrix