In [1]:
import pandas as pd
import csv
import pickle as pkl
import numpy as np
from fasttext import FastText
import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torch import nn
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import datetime
from nltk import pos_tag
from tqdm import tqdm

In [2]:
class NetworkModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.input = nn.Linear(300, 150)
        self.linear_stack = nn.Sequential(
            nn.Dropout(0.5),
            nn.ReLU(),
            nn.Linear(150, 100),
            nn.Dropout(0.5),
            nn.ReLU(),
            nn.Linear(100, 50),
            nn.Dropout(0.5),
            nn.ReLU(),
            nn.Linear(50, 25),
            nn.Dropout(0.4),
            nn.ReLU(),
            nn.Linear(25, 10),
            nn.Dropout(0.3),
            nn.ReLU(),
            nn.Linear(10, 5),
            nn.ReLU(),
        )
        self.output = nn.Linear(5, 2)

    def forward(self, x):
        x = self.input(x)
        x = self.linear_stack(x)
        x = self.output(x)
        return x

mlp = NetworkModel()
mlp.load_state_dict(torch.load('./models/best_model.pth'))

<All keys matched successfully>

In [3]:
mlp

NetworkModel(
  (input): Linear(in_features=300, out_features=150, bias=True)
  (linear_stack): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): ReLU()
    (2): Linear(in_features=150, out_features=100, bias=True)
    (3): Dropout(p=0.5, inplace=False)
    (4): ReLU()
    (5): Linear(in_features=100, out_features=50, bias=True)
    (6): Dropout(p=0.5, inplace=False)
    (7): ReLU()
    (8): Linear(in_features=50, out_features=25, bias=True)
    (9): Dropout(p=0.4, inplace=False)
    (10): ReLU()
    (11): Linear(in_features=25, out_features=10, bias=True)
    (12): Dropout(p=0.3, inplace=False)
    (13): ReLU()
    (14): Linear(in_features=10, out_features=5, bias=True)
    (15): ReLU()
  )
  (output): Linear(in_features=5, out_features=2, bias=True)
)

In [4]:
embedding_model = FastText.load_model('cc.en.300.bin')



In [5]:
def get_pos(word):
    """Gets the part of speech tag for a single word."""
    word_tag = pos_tag([word])
    return word_tag[0][1]  # Extract the tag from the tuple

def check_valid_pos(word):
    pos = get_pos(word)
    list_1 = ["JJ", "JJR", "JJS", "RB", "RBR", "RBS"]
    if(pos in list_1):
        return True
    return False

In [6]:
class Word2VA():
    def __init__(self, ):
        self.model = None
        self.mlp = None
        pass
    def set_embedding_model(self, model):
        self.model = model
    def set_mlp_model(self, model):
        self.mlp = model
    def return_embedding(self, word):
        if(self.model == None):
            print("No Embedding Model")
            return None
        return torch.tensor(self.model.get_word_vector(word))
    def return_va_values(self, embedding):
        if(self.mlp == None):
            print("No MLP ")
            return None
        return self.mlp(embedding)
    def va_value(self, word):
        if(type(word) != str):
            raise TypeError
        va_value_vector = self.return_va_values(self.return_embedding(word))
        return va_value_vector

word2va_model = Word2VA()
word2va_model.set_embedding_model(embedding_model)
word2va_model.set_mlp_model(mlp)


In [7]:
temp = word2va_model.va_value('aardvark')

In [8]:
temp

tensor([0.6305, 0.4369], grad_fn=<ViewBackward0>)

In [9]:
all_tags = pkl.load(open('./../mod_dataset/tags.pkl', 'rb'))

In [10]:
print(len(all_tags))
all_tags

3092178


['indie',
 'Mellow',
 'femalevocalistsgdchill',
 'pop',
 'Dreamy',
 'sweet',
 'alternative',
 'chillout',
 'happy',
 'Favorite',
 'wedding',
 'heartstrings',
 'indiewave',
 'sbeeaachouseiloves',
 'electronic',
 'rock',
 'heavy',
 'folk',
 'ambient',
 'experimental',
 'sexy',
 'memories',
 'dark',
 'acoustic',
 'indietronica',
 'beautiful',
 'bells',
 'guitar',
 'organ',
 'NoMeansNo',
 'punk',
 'hardcore',
 'jazzcore',
 'Canadian',
 'Nomeansno',
 'rock',
 'indie',
 'alternative',
 'electro',
 'Awesome',
 'bc',
 'Insanity',
 'tension',
 'pulse',
 'Wrong',
 'bujamsie',
 'BIRP',
 'somafm',
 'Bagel',
 'ktcth',
 'experimental',
 'insane',
 'wow',
 'Fave',
 'favorites',
 'pirate',
 'noise',
 'experimental',
 'fuck',
 'KILL',
 'indie',
 'rap',
 'wow',
 'Extreme',
 'fabfuckintastic',
 'noisetronica',
 'fockedop',
 'transgender',
 'rock',
 'lgbt',
 'played',
 'singlesphere',
 'noise',
 'electronic',
 'experimental',
 'indie',
 'rock',
 'fip',
 'hardcore',
 'rap',
 'classic',
 'wu',
 'hiphop',
 '

In [11]:
all_tags = list(set(all_tags))
print(len(all_tags))

52392


In [12]:
va_dict = dict()
for i in tqdm(range(len(all_tags))):
    if(check_valid_pos(all_tags[i])):
        # print(all_tags[i])
        va_dict[all_tags[i]] = word2va_model.va_value(all_tags[i])

100%|██████████| 52392/52392 [00:09<00:00, 5256.96it/s]


In [13]:
pkl.dump(va_dict, open("va_preprocessed_word.pkl", 'wb'))