In [1]:
import re
import torch
import torch.nn.functional as F
from lib.utils import load, preprocess, get_dict, ask_dict

In [2]:
# Smallest n-gram size considered by the notebook (1 token = uni-gram)
NGRAM_LENGTH_MIN = 1 # uni-gram
# N-gram size in tokens — presumably the size handed to get_dict; TODO confirm against its callers
NGRAM_CONTEXT_WINDOW = 3

# Marker word that stands in for a line break inside the lyrics
END_TOKEN = "</s>"
# When True, the preprocessing cell replaces every "\n" with END_TOKEN
WITH_END_TOKEN = True

## Preprocessing

In [3]:
# Load the raw dataset; every entry of songs_raw["rows"] carries its lyrics
# under row["row"]["text"].
songs_raw = load('./data/songs.json')

# Positions (in the raw, unmodified "rows" list) of the songs kept out of the
# corpus. Listing the raw positions avoids having to shift each index by the
# number of rows already deleted.
excluded_rows = {
    49, # Shape of You
    50, # You Need Me, I Don't Need You
    70, # Dive
    76, # Take Me Back to London
}
# Fail loudly, as deleting by index would, if the dataset is shorter than expected
if max(excluded_rows) >= len(songs_raw["rows"]):
    raise IndexError("excluded song index out of range for the loaded dataset")

# Preprocess every remaining song that actually has lyrics
list_str = [
    preprocess(song["row"]["text"])
    for index, song in enumerate(songs_raw["rows"])
    if index not in excluded_rows and song["row"]["text"]
]

if WITH_END_TOKEN:
    # Keep line breaks as an explicit END_TOKEN word instead of plain whitespace
    list_str = [lyrics.replace("\n", f" {END_TOKEN} ") for lyrics in list_str]

## Initializations

In [4]:
# Tokenise every song: lowercase it, trim it, then cut on runs of whitespace
_split_on_whitespace = re.compile(r'[\s\n]+').split
list_list_words = [_split_on_whitespace(lyrics.lower().strip()) for lyrics in list_str]
# The same songs with their tokens glued back together by "+"
list_str_comp = list(map("+".join, list_list_words))

print(list_list_words[0])

['the', 'club', 'isnt', 'the', 'best', 'place', 'to', 'find', 'a', 'lover', '</s>', 'so', 'the', 'bar', 'is', 'where', 'i', 'go', '</s>', 'me', 'and', 'my', 'friends', 'at', 'the', 'table', 'doing', 'shots', '</s>', 'drinking', 'fast', 'and', 'then', 'we', 'talk', 'slow', '</s>', 'and', 'you', 'come', 'over', 'and', 'start', 'up', 'a', 'conversation', 'with', 'just', 'me', '</s>', 'and', 'trust', 'me', 'ill', 'give', 'it', 'a', 'chance', 'now', '</s>', 'take', 'my', 'hand', 'stop', 'put', 'van', 'the', 'man', 'on', 'the', 'jukebox', '</s>', 'and', 'then', 'we', 'start', 'to', 'dance', 'and', 'now', 'im', 'singing', 'like', '</s>', 'girl', 'you', 'know', 'i', 'want', 'your', 'love', '</s>', 'your', 'love', 'was', 'handmade', 'for', 'somebody', 'like', 'me', '</s>', 'come', 'on', 'now', 'follow', 'my', 'lead', '</s>', 'i', 'may', 'be', 'crazy', 'dont', 'mind', 'me', '</s>', 'say', 'boy', 'lets', 'not', 'talk', 'too', 'much', '</s>', 'grab', 'on', 'my', 'waist', 'and', 'put', 'that', 'bod

In [5]:
# Token count of every song, and the same counts in ascending order
lengths = list(map(len, list_list_words))
lengths_sorted = sorted(lengths)

# No n-gram can be longer than the shortest song, so that length is the cap
NGRAM_LENGTH_MAX = lengths_sorted[0]

# Shortest and longest song, then the resulting n-gram size range
print(lengths_sorted[0], lengths_sorted[-1])
print(NGRAM_LENGTH_MIN, NGRAM_LENGTH_MAX)

182 1904
1 182


## Dictionary

In [10]:
# Build the n-gram dictionary at the configured size (instead of a bare 3)
# and look up everything recorded for the prefix "i+dont"
dict_key_tuple = get_dict(list_list_words, NGRAM_CONTEXT_WINDOW)
list_key_tuple_answers = ask_dict(dict_key_tuple, "i+dont")
# Cache of dictionaries keyed by n-gram size; predict() fills it on demand
dicts = {}

list_key_tuple_answers

[('i+dont+wanna', 51, 0.0010843911462652294, 0.29310344827586204),
 ('i+dont+need', 47, 0.0009993408602836428, 0.27011494252873564),
 ('i+dont+want', 22, 0.0004677765728987264, 0.12643678160919541),
 ('i+dont+care', 9, 0.00019136314345856987, 0.05172413793103448),
 ('i+dont+know', 8, 0.00017010057196317324, 0.04597701149425287),
 ('i+dont+love', 7, 0.00014883800046777657, 0.040229885057471264),
 ('i+dont+deserve', 4, 8.505028598158662e-05, 0.022988505747126436),
 ('i+dont+like', 4, 8.505028598158662e-05, 0.022988505747126436),
 ('i+dont+even', 4, 8.505028598158662e-05, 0.022988505747126436),
 ('i+dont+really', 4, 8.505028598158662e-05, 0.022988505747126436),
 ('i+dont+ever', 3, 6.378771448618996e-05, 0.017241379310344827),
 ('i+dont+but', 2, 4.252514299079331e-05, 0.011494252873563218),
 ('i+dont+reckon', 1, 2.1262571495396654e-05, 0.005747126436781609),
 ('i+dont+have', 1, 2.1262571495396654e-05, 0.005747126436781609),
 ('i+dont+</s>', 1, 2.1262571495396654e-05, 0.005747126436781609),

In [7]:
def predict(list_list_words: list, input: str, seed = None, max_lines = 10, separator = "+"):
    """Print lyrics generated word by word from an n-gram dictionary.

    The prompt is split on `separator`; the n-gram size is the number of
    prompt tokens plus one. At every step the current context is looked up
    with `ask_dict`, one continuation is sampled with `torch.multinomial`
    using the last element of each answer tuple as its weight, the sampled
    word is printed, and the context slides forward by one token.

    Args:
        list_list_words: Tokenised songs, passed to `get_dict` when no
            dictionary of the required size is cached in `dicts` yet.
        input: Prompt tokens joined by `separator`; assumed to be sanitised.
        seed: Seed for a dedicated `torch.Generator`. `None` samples from
            torch's default RNG; any integer, including 0, is honoured.
        max_lines: Number of END_TOKEN occurrences to emit before stopping.
        separator: Token separator used in `input` and in dictionary keys.

    Returns:
        The string "Input cannot be empty" for an empty prompt, otherwise
        None. None is also returned early, without printing "END", when the
        current context has no recorded continuation.

    Note:
        The cache in `dicts` is keyed by n-gram size only, so a dictionary
        built from one `list_list_words` is reused for any later corpus.
    """
    if len(input) == 0:
        return "Input cannot be empty"
    list_inputs = input.split(separator)
    ngram_context_window = len(list_inputs) + 1

    # Provided that the input is sanitized already
    arg = input
    max_index = 0
    # Compare against None: a plain truthiness test would silently ignore seed=0
    generator = torch.Generator().manual_seed(seed) if seed is not None else None

    # Build the dictionary for this n-gram size once and reuse it afterwards
    if ngram_context_window not in dicts:
        dicts[ngram_context_window] = get_dict(list_list_words, ngram_context_window)
    dict_key_tuple = dicts[ngram_context_window]

    # Echo the prompt words as the start of the generated text
    print("START\n")
    print(*list_inputs, sep=" ", end=" ")

    while True:
        # Get the candidate continuations for the current context
        list_key_tuple_answers = ask_dict(dict_key_tuple, arg)
        if len(list_key_tuple_answers) == 0:
            # Dead end: nothing was ever recorded after this context
            return None

        # The last element of each answer tuple is used as its sampling weight
        tensor_probs = torch.tensor([rel for (_, _, _, rel) in list_key_tuple_answers], dtype=torch.float32)

        # Draw one index; generator=None falls back to torch's default RNG
        i = torch.multinomial(tensor_probs, num_samples=1, replacement=True, generator=generator).item()

        # The first element of the sampled answer is the full n-gram key
        key = list_key_tuple_answers[i][0]
        list_keys = key.split(separator)

        # Only the last token is new; the others are the context already printed
        next_word = list_keys[-1]
        if next_word == END_TOKEN:
            max_index += 1
            print("")
        else:
            print(next_word, end=" ")

        # Slide the context window: drop the oldest token, keep the new one.
        # Joined with `separator` so it matches the split above.
        arg = separator.join(list_keys[1:])
        if max_index >= max_lines:
            print("\nEND")
            break

# Generate lyrics from a three-word prompt; no seed is passed, so sampling uses torch's default RNG
predict(list_list_words, "the+club+isnt")

# Possible questions:
# - Can you generate the same output both without and with the seed?
# - How to generate accurate song lyrics from Shape of You by Ed Sheeran?
# - Why does your generation stop abruptly?

START

the club isnt the best place to find a lover 
thought id find her in a bottle 
god make me another one 
you got the kind of look in your eyes 
well baby im just being honest uh 
and i am only being honest with you i 
i get lonely and make mistakes from time to time 
se enioma enko ye bibia be ye ye 
bibia be ye ye 
bibia be ye ye ye 

END
