In [2]:
import os
import pandas as pd
import csv
import pickle as pkl
import numpy as np
from fasttext import FastText
import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torch import nn
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import datetime
from nltk import pos_tag
from tqdm import tqdm

In [3]:
class NetworkModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.input = nn.Linear(300, 150)
        self.linear_stack = nn.Sequential(
            nn.Dropout(0.5),
            nn.ReLU(),
            nn.Linear(150, 100),
            nn.Dropout(0.5),
            nn.ReLU(),
            nn.Linear(100, 50),
            nn.Dropout(0.5),
            nn.ReLU(),
            nn.Linear(50, 25),
            nn.Dropout(0.4),
            nn.ReLU(),
            nn.Linear(25, 10),
            nn.Dropout(0.3),
            nn.ReLU(),
            nn.Linear(10, 5),
            nn.ReLU(),
        )
        self.output = nn.Linear(5, 2)

    def forward(self, x):
        x = self.input(x)
        x = self.linear_stack(x)
        x = self.output(x)
        return x

mlp = NetworkModel()
mlp.load_state_dict(torch.load('./models/best_model.pth'))

<All keys matched successfully>

In [4]:
embedding_model = FastText.load_model('cc.en.300.bin')



In [5]:
class Word2VA():
    def __init__(self, ):
        self.model = None
        self.mlp = None
        pass
    def set_embedding_model(self, model):
        self.model = model
    def set_mlp_model(self, model):
        self.mlp = model
    def return_embedding(self, word):
        if(self.model == None):
            print("No Embedding Model")
            return None
        return torch.tensor(self.model.get_word_vector(word))
    def return_va_values(self, embedding):
        if(self.mlp == None):
            print("No MLP ")
            return None
        return self.mlp(embedding)
    def va_value(self, word):
        if(type(word) != str):
            raise TypeError
        va_value_vector = self.return_va_values(self.return_embedding(word))
        return va_value_vector




word2va_model = Word2VA()
word2va_model.set_embedding_model(embedding_model)
word2va_model.set_mlp_model(mlp)


In [6]:
dir_1 = "./../dataset/1year_top500tracks_with_tags/"
file_list = os.listdir(dir_1)

In [7]:
def get_pos(word):
    """Gets the part of speech tag for a single word."""
    word_tag = pos_tag([word])
    return word_tag[0][1]  # Extract the tag from the tuple
def valid_word(word):
    if len(word.split()) > 1 or (not word.isalpha() )or ( not word.isascii()):
        return False
    pos = get_pos(word)
    list_1 = ["JJ", "JJR", "JJS", "RB", "RBR", "RBS"]
    if(pos in list_1):
        return True
    return False

In [8]:
all_user_va_words = dict()

In [31]:
for idx in tqdm(range(len(file_list))):
    file_1 = file_list[idx]
    temp = open(dir_1 + file_1)
    temp = temp.read()
    temp = temp.split('\n')
    user_weight_dict = dict()
    for i in temp[1:]:
        # print(i)
        x = i.split(',')
        # print(x)
        for j in range(len(x)):
            if(x[j].isdigit()):
                play_count = int(x[j])
                x = x[(j+1):]
                break
        # if(x[1] == '"Tyler' or x[1] == ' I Am A Thot"'):
        #     play_count = int(x[3])
        #     x = x[4:]
        # else:
        #     play_count = int(x[2])
        #     x = x[3:]
        for j in range(0, len(x), 2):
            if(valid_word(x[j])):
                # print(x[j])
                if(x[j] in user_weight_dict):
                    user_weight_dict[x[j]].append(int(x[j+1])*play_count)
                else:
                    user_weight_dict[x[j]] = [int(x[j+1])*play_count]
    user_tag_set = list(user_weight_dict.keys())
    va_end = np.array([0.0, 0.0])
    total_count = 0
    for i in user_tag_set:
        va_tag = np.array(word2va_model.va_value(i).tolist())
        va_end += va_tag*(np.sum(user_weight_dict[i]))
        total_count += (np.sum(user_weight_dict[i]))
    va_end /= total_count
    all_user_va_words[file_1] = va_end

100%|██████████| 541/541 [07:44<00:00,  1.16it/s]


In [19]:
all_user_va_words

{'48e9a7f191695247670d622319fba921.csv': array([0.60159399, 0.46248166]),
 '1bc8474dc465d16f5c2a424b2a7427e7.csv': array([0.58019775, 0.46163458]),
 '2bdf82895c47ce2fd668fd6369da78db.csv': array([0.62774141, 0.44810855]),
 'cebe7bb75898cb69fe9bb992d10cf954.csv': array([0.6268855 , 0.45200578]),
 '929e1073f2ec54bf04e18ca8f33a1724.csv': array([0.64130253, 0.44796665]),
 'e81cc3349b02197a403e412ad57157cc.csv': array([0.62133687, 0.44458032]),
 '3ae99483f8e0b14ff7c264625c96c507.csv': array([0.63131315, 0.44317926]),
 'd125a06735486fa936de1f856e01bbb4.csv': array([0.65810375, 0.45522119]),
 '003534c470fd624743ee3f7acef28a53.csv': array([0.58984938, 0.45279962]),
 '05ed17d0ecc78208ba5d6a3661a9fc95.csv': array([0.54021596, 0.47222574]),
 '20f3da7086adafce8a7adf3a85c67c7a.csv': array([0.63223746, 0.4423741 ]),
 'ae37da73621a312704f806a04664d5ae.csv': array([0.63972423, 0.45009325]),
 'ff7aefe14748d4cfed2a7f6cbc6cb226.csv': array([0.62422694, 0.44680883]),
 '2fe1a2176412d3a3307a28047d16dfbc.csv

In [17]:
pkl.dump(all_user_va_words, open('./all_user_va_words.pkl', 'wb'))

In [15]:
all_user_va_words

{'48e9a7f191695247670d622319fba921.csv': array([0.62776437, 0.44471566]),
 '1bc8474dc465d16f5c2a424b2a7427e7.csv': array([0.60038346, 0.44718212]),
 '2bdf82895c47ce2fd668fd6369da78db.csv': array([0.6284668 , 0.43972123]),
 'cebe7bb75898cb69fe9bb992d10cf954.csv': array([0.6467357 , 0.44958853]),
 '929e1073f2ec54bf04e18ca8f33a1724.csv': array([0.63511963, 0.44009491]),
 'e81cc3349b02197a403e412ad57157cc.csv': array([0.62639084, 0.43650207]),
 '3ae99483f8e0b14ff7c264625c96c507.csv': array([0.64190446, 0.44145089]),
 'd125a06735486fa936de1f856e01bbb4.csv': array([0.6496025, 0.4453403]),
 '003534c470fd624743ee3f7acef28a53.csv': array([0.61486697, 0.43366431]),
 '05ed17d0ecc78208ba5d6a3661a9fc95.csv': array([0.61232984, 0.43005474]),
 '20f3da7086adafce8a7adf3a85c67c7a.csv': array([0.62350067, 0.43497552]),
 'ae37da73621a312704f806a04664d5ae.csv': array([0.63573935, 0.43915293]),
 'ff7aefe14748d4cfed2a7f6cbc6cb226.csv': array([0.62891423, 0.43685042]),
 '2fe1a2176412d3a3307a28047d16dfbc.csv':