In [2]:
import os
import yaml

import fasttext
import torch
from tqdm import tqdm

from model.encoder import CandidateEncoderConfig
from model.decoder import CandidateDecoderConfig
from config.general_config import GeneralConfig
from trainer.trainer import TrainerConfig
from dataset.dataset import SellersDataset

In [3]:
# Run on the GPU when torch reports one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load the "vae" section of the project configuration.
# A malformed file must stop the notebook right here: swallowing
# yaml.YAMLError would leave `config` undefined (NameError below on a fresh
# kernel) or, worse, silently reuse a stale `config` from an earlier run.
with open("config/config.yaml", "r", encoding="utf-8") as file:
    config = yaml.safe_load(file)["vae"]

# The encoder/decoder configs receive their own section merged with the
# shared "general" section as keyword arguments.
general_config = GeneralConfig(**config["general"])
encoder_config = CandidateEncoderConfig(**config["encoder"], **config["general"])

decoder_config = CandidateDecoderConfig(**config["decoder"], **config["general"])

In [4]:
# Collect the constructor arguments first so the dataset settings can be
# inspected on their own; most values come from the loaded configs.
dataset_kwargs = dict(
    dataset_path="data/dataset_fasttext/",
    embedder_name=general_config.embedder_name,
    raw_data_path=general_config.raw_data_path,
    device=DEVICE,
    bow_remove_stopwords=general_config.bow_remove_stopwords,
    bow_remove_sentiment=general_config.bow_remove_sentiment,
    nn_embedding_size=encoder_config.lstm_hidden_dim,
    trim_tr=general_config.trim_tr,
)
dataset = SellersDataset(**dataset_kwargs)

# NOTE(review): presumably the one-off step that builds the files under
# dataset_path; left disabled as in the original -- confirm before enabling.
# dataset.prepare_dataset(dropna=False)
dataset.load_dataset()

Loading dataset...
[2022-05-30 00:11:54,750] {dataset.py:226} INFO - Loading dataset...
Loaded dataset!
[2022-05-30 00:11:56,807] {dataset.py:245} INFO - Loaded dataset!




In [7]:
train_file = "data/dataset_fasttext/train.txt"

# Collect one textual description per dataset item (progress shown by tqdm).
texts = [
    dataset.get_textual_description(idx)
    for idx in tqdm(range(len(dataset)))
]

# Write one description per line. The encoding is pinned to UTF-8 because
# fastText reads its training corpus as UTF-8; relying on the platform
# default could corrupt non-ASCII text. A generator avoids building a second
# full copy of the corpus in memory just for writelines().
with open(train_file, "w", encoding="utf-8") as file:
    file.writelines(text + "\n" for text in texts)

100%|███████████████████████████████████| 55252/55252 [00:23<00:00, 2322.65it/s]


In [14]:
model_path = "model/fasttext/cv.en.100.bin"

# Create the output directory *before* training: save_model() fails when the
# target directory is missing, and that failure would otherwise only surface
# after the whole training run has finished.
os.makedirs(os.path.dirname(model_path), exist_ok=True)

# Train unsupervised fastText embeddings on the dumped corpus:
# 100-dimensional vectors, character n-grams of length 3..6, 5 epochs.
model = fasttext.train_unsupervised(train_file, minn=3, maxn=6, epoch=5, dim=100)

model.save_model(model_path)

Read 4M words
Number of words:  14794
Number of labels: 0
Progress: 100.0% words/sec/thread:   58734 lr:  0.000000 avg.loss:  2.037953 ETA:   0h 0m 0s100.0% words/sec/thread:   58734 lr: -0.000003 avg.loss:  2.037953 ETA:   0h 0m 0s


In [13]:
# Sanity check of the trained embeddings: list the closest vocabulary
# entries to a probe word as (similarity, word) pairs.
query_word = 'excel'
model.get_nearest_neighbors(query_word)

[(0.9046449065208435, 'msexcel'),
 (0.8958105444908142, 'microsoft'),
 (0.887099027633667, 'office'),
 (0.8751331567764282, 'word'),
 (0.8606970310211182, 'entry'),
 (0.8443353176116943, 'msoffice'),
 (0.8427814245223999, 'data'),
 (0.8026374578475952, 'typing'),
 (0.7979578375816345, 'spreadsheet'),
 (0.780920684337616, 'powerpoint')]