# Model Building

In [5]:
import pickle
import torch
import torch.nn
import torch.optim
import torch.utils.data

In [6]:
dataset_path = '../dataset'
store_file = os.path.join(dataset_path,'embedding.pickle')

## Construct Dataset class

In [7]:
class WorshipLyricDataset(torch.utils.data.Dataset):
    """Worhip Song dataset from Genius.
    """

    def __init__(self, embedding_file: str, sentence_length: int = None, sentence_step: int = 1):
        self.sentence_step = sentence_step
        self.sentence_length = sentence_length
        self.embedding_file = embedding_file

        # Load the embedding.
        with open(embedding_file, 'rb') as fp:
            store = pickle.load(fp)
        
        self.corpus = store['mapping']
        # self.songid_to_embed_all = store['embedding']
        # self.idx_to_songid = {idx: songid for idx,songid in enumerate(sorted(store['embedding'].keys()))}
        self.songids = sorted(store['embedding'].keys())

        # Break each lyric into contiguous sentences.
        self.songid_to_embed = {}
        self.songid_to_nextword = {}
        for songid,embed in store['embedding'].items():

            # Group embedding to contiguous sentence length.
            if sentence_length:
                self.songid_to_embed[songid] = []
                self.songid_to_nextword[songid] = []
                for i in range(0, len(embed) - self.sentence_length, self.sentence_step):
                    self.songid_to_embed[songid].append(embed[i:i+self.sentence_length])
                    self.songid_to_nextword[songid].append(embed[i+self.sentence_length])
                self.songid_to_embed[songid] = torch.tensor(self.songid_to_embed[songid], dtype=torch.long)
                self.songid_to_nextword[songid] = torch.tensor(self.songid_to_nextword[songid], dtype=torch.long)

            # Return original lyric embedding.
            else:
                self.songid_to_embed[songid] = torch.tensor(embed, dtype=torch.long)
                self.songid_to_nextword[songid] = torch.tensor([], dtype=torch.long)

        # Add padding based on longest song.
        self.embed_pad = torch.nn.utils.rnn.pad_sequence([self.songid_to_embed[songid] for songid in self.songids], batch_first=True, padding_value=0)
        self.nextword_pad = torch.nn.utils.rnn.pad_sequence([self.songid_to_nextword[songid] for songid in self.songids], batch_first=True, padding_value=0)
        # longest_size = max(embed.shape[0] for songid,embed in self.songid_to_embed)
        # for songid in self.songid_to_embed.keys():
        #     self.songid_to_embed.


    def __len__(self):
        return len(self.songid_to_embed)

    def __getitem__(self, idx):
        # songid = self.songids[idx] # Get song ID from dataset index.

        lyric = {
            # 'embed': self.songid_to_embed[songid],
            'embed': self.embed_pad[idx],
            # 'nextword': self.songid_to_nextword[songid],
            'nextword': self.nextword_pad[idx],
            # 'songid': songid,
            'songid': self.songids[idx],
        }
        return lyric

In [8]:
lyric_dataset = WorshipLyricDataset(embedding_file=store_file, sentence_length=5)

In [9]:
lyric_dataset[0]['embed'].shape, lyric_dataset[0]['nextword'].shape

(torch.Size([1284, 5]), torch.Size([1284]))

In [10]:
# Construct data loader.
dataloader = torch.utils.data.DataLoader(lyric_dataset, batch_size=1, shuffle=True, num_workers=0)

## Construct Model

In [11]:
class LyricGenerator(torch.nn.Module):
    def __init__(self, sentence_length: int, corpus_length: int, n_hidden: int = 128, n_layers: int = 2, drop_prob: float = 0.5):
        super().__init__()
        self.n_hidden = n_hidden
        self.lstm = torch.nn.LSTM(
            input_size=sentence_length,
            hidden_size=n_hidden,
            num_layers=n_layers,
            dropout=drop_prob,
            bidirectional=True,
            batch_first=True,
            )
        self.dropout = torch.nn.Dropout(p=drop_prob)
        self.fc = torch.nn.Linear(in_features=n_hidden, out_features=corpus_length)
    
    def forward(self, x):
        x = x.float() # Convert input to float type for LSTM.

        # Run inputs through LSTM.
        lstm_out, _ = self.lstm(x)
        print(lstm_out.shape)

        # Pass LSTM outputs through dropout layer.
        out = self.dropout(lstm_out)

        # Stack-up LSTM outputs.
        # out = out.contiguous().view(-1, self.n_hidden)
        out = out.view(-1, self.n_hidden)
        print(out.shape)

        out = self.fc(out)
        return out


## Train

In [12]:
def train(model, epoch, optim, criterion, loader, device='cpu'):
    """Helper to train the model."""
    model.train()
    for e in range(epoch):
        running_loss = 0.0
        for lyric in loader:

            # Send data to desired device.
            x = lyric['embed'].to(device)
            y = lyric['nextword'].to(device)

            # Evaluate the model.
            y_pred = model(x)

            # Compute losses.
            print(y_pred.shape, y.shape)
            loss = criterion(y_pred, y)

            # Zero the gradient, back-propagate, and step the optimizer.
            optim.zero_grad()
            loss.backward()
            optim.step()

            # Accumulate the loss for this epoch.
            running_loss += loss.item()

        # Report epoch results.
        print(f'Epoch {e}: loss {running_loss}')

In [13]:
import time
from contextlib import contextmanager
@contextmanager
def timing(description='Elapsed time'):
    """Context manager to print elapsed time from call."""
    start_time = time.time()
    yield
    stop_time = time.time()
    print(f"{description}: {stop_time - start_time} seconds")

In [14]:
# Set runtime device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define the model.
corpus_length = len(lyric_dataset.corpus)
model = LyricGenerator(
    sentence_length=lyric_dataset.sentence_length,
    corpus_length=corpus_length,
    )

In [15]:
print(corpus_length)

6272


In [16]:
# Learning parameters.
epoch = 1
lr = 1e-2

# Train the model.
# Display training time too.
with timing():
    model.to(device)
    optim = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.CrossEntropyLoss(reduction='mean')
    train(model, loader=dataloader, epoch=epoch, optim=optim, criterion=criterion, device=device)

torch.Size([1, 1284, 256])
torch.Size([2568, 128])
torch.Size([2568, 6272]) torch.Size([1, 1284])


ValueError: Expected input batch_size (2568) to match target batch_size (1).