In [25]:
import re
import torch
import torch.nn.functional as F
from lib.utils import load, preprocess, get_dict, ask_dict, sanitize

In [26]:
NGRAM_CONTEXT_WINDOW = 2

END_TOKEN = "</s>"
WITH_END_TOKEN = False

## Preprocessing

In [27]:
list_str = load('./data/songs.json')
list_str = [preprocess(str["row"]["text"]) for str in list_str["rows"] if str["row"]["text"]]

if WITH_END_TOKEN:
    # Treat new lines as a word or not
    list_str = [str.replace("\n", f" {END_TOKEN} ") for str in list_str]

list_str

['The club isnt the best place to find a lover\nSo the bar is where I go\nMe and my friends at the table doing shots\nDrinking fast and then we talk slow\nAnd you come over and start up a conversation with just me\nAnd trust me Ill give it a chance now\nTake my hand stop put Van the Man on the jukebox\nAnd then we start to dance and now Im singing like\nGirl you know I want your love\nYour love was handmade for somebody like me\nCome on now follow my lead\nI may be crazy dont mind me\nSay boy lets not talk too much\nGrab on my waist and put that body on me\nCome on now follow my lead\nCome come on now follow my lead\nIm in love with the shape of you\nWe push and pull like a magnet do\nAlthough my heart is falling too\nIm in love with your body\nAnd last night you were in my room\nAnd now my bed sheets smell like you\nEvery day discovering something brand new\nIm in love with your body\nOh I oh I oh I oh I\nIm in love with your body\nOh I oh I oh I oh I\nIm in love with your body\nOh I 

## Intializations

In [28]:
# Split the lines into words and lowercase them
list_list_words = sanitize(list_str)

list_list_words

[['the',
  'club',
  'isnt',
  'the',
  'best',
  'place',
  'to',
  'find',
  'a',
  'lover',
  'so',
  'the',
  'bar',
  'is',
  'where',
  'i',
  'go',
  'me',
  'and',
  'my',
  'friends',
  'at',
  'the',
  'table',
  'doing',
  'shots',
  'drinking',
  'fast',
  'and',
  'then',
  'we',
  'talk',
  'slow',
  'and',
  'you',
  'come',
  'over',
  'and',
  'start',
  'up',
  'a',
  'conversation',
  'with',
  'just',
  'me',
  'and',
  'trust',
  'me',
  'ill',
  'give',
  'it',
  'a',
  'chance',
  'now',
  'take',
  'my',
  'hand',
  'stop',
  'put',
  'van',
  'the',
  'man',
  'on',
  'the',
  'jukebox',
  'and',
  'then',
  'we',
  'start',
  'to',
  'dance',
  'and',
  'now',
  'im',
  'singing',
  'like',
  'girl',
  'you',
  'know',
  'i',
  'want',
  'your',
  'love',
  'your',
  'love',
  'was',
  'handmade',
  'for',
  'somebody',
  'like',
  'me',
  'come',
  'on',
  'now',
  'follow',
  'my',
  'lead',
  'i',
  'may',
  'be',
  'crazy',
  'dont',
  'mind',
  'me',
  's

## Dictionary

In [32]:
dict_key_tuple = get_dict(list_list_words, NGRAM_CONTEXT_WINDOW)
list_key_tuple_answers = ask_dict(dict_key_tuple, "i")

list_key_tuple = [(key, tuple[0], tuple[1]) for key, tuple in dict_key_tuple.items()]
tensor_probs = torch.tensor([tuple[1] for _, tuple in dict_key_tuple.items()], dtype=torch.float32)

def bigram(tensor_probs, list_key_tuple):
    index = torch.multinomial(tensor_probs, num_samples=1, replacement=True).item()
    return list_key_tuple[index][0]

dict_key_tuple
list_key_tuple_answers

[('i+dont', 201, 0.0045314155601145255, 0.10953678474114441),
 ('i+know', 124, 0.0027955001465383143, 0.06757493188010899),
 ('i+love', 70, 0.0015781049214329192, 0.03814713896457766),
 ('i+was', 63, 0.0014202944292896274, 0.03433242506811989),
 ('i+wanna', 62, 0.0013977500732691571, 0.033787465940054495),
 ('i+just', 50, 0.0011272178010235138, 0.027247956403269755),
 ('i+can', 50, 0.0011272178010235138, 0.027247956403269755),
 ('i+could', 50, 0.0011272178010235138, 0.027247956403269755),
 ('i+got', 45, 0.0010144960209211623, 0.02452316076294278),
 ('i+wont', 39, 0.0008792298847983407, 0.02125340599455041),
 ('i+will', 38, 0.0008566855287778705, 0.020708446866485014),
 ('i+oh', 36, 0.0008115968167369299, 0.019618528610354225),
 ('i+want', 33, 0.0007439637486755191, 0.017983651226158037),
 ('i+think', 30, 0.0006763306806141083, 0.01634877384196185),
 ('i+cant', 29, 0.0006537863245936379, 0.01580381471389646),
 ('i+never', 27, 0.0006086976125526974, 0.014713896457765668),
 ('i+said', 27,