# **Prepare data**

In [11]:
!gdown --id 1rIcrwTKF7S-uO6CPsOta_ZGsiWiHOcJu

'gdown' is not recognized as an internal or external command,
operable program or batch file.


In [12]:
!unzip train_data.zip 

'unzip' is not recognized as an internal or external command,
operable program or batch file.


# **Import packages**

In [13]:
from tqdm import tqdm
import numpy as np
import gensim
import os 
from nltk.tokenize import word_tokenize
from gensim.models.doc2vec import Doc2Vec, TaggedDocument
from gensim.parsing.preprocessing import remove_stopwords
from gensim.parsing.preprocessing import STOPWORDS
import multiprocessing
import nltk
nltk.download('punkt')

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\ASUS\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [14]:
dir_path = os.path.dirname(os.path.realpath(os.getcwd()))
data_path = os.path.join(dir_path, 'data\\train_data')
cores = multiprocessing.cpu_count()

In [15]:
data_path

'C:\\Users\\ASUS\\Documents\\Projects\\Python\\topic-classification\\data\\train_data'

# **Pre-process data**

In [16]:
def rm_stopwords(tokenized_doc, stop_words):
    tok_without_sw=[]
    for txt_tokens in tokenized_doc:
        tok_without_sw = [word for word in tokenized_doc if not word.lower() in STOPWORDS]
    return tok_without_sw

In [17]:
def process_data(data):
    data = ' '.join(data)
    data = gensim.utils.simple_preprocess(data)
    data = ' '.join(data)
    processed_data = word_tokenize(data)
    processed_data = rm_stopwords(processed_data, STOPWORDS)
    return processed_data


In [18]:
def get_data(folder_path):
    dirs = os.listdir(folder_path)
    processed_doc = []
    for path in tqdm(dirs):
        file_paths = os.listdir(os.path.join(folder_path, path))
        for file_path in tqdm(file_paths):
            with open(os.path.join(folder_path, path, file_path), 'r',encoding='utf-8') as f:
                data = f.readlines()
                tokenized_doc = process_data(data)
                processed_doc.append([tokenized_doc,path])
    return processed_doc

In [19]:
def tagging_data(data):
    tagged_doc =[]
    for case in range(len(data)):
      case_i = TaggedDocument(data[case][0],[data[case][1]])
      tagged_doc.append(case_i)
    return tagged_doc

In [20]:
data_train = get_data(data_path)

100%|██████████| 1752/1752 [04:30<00:00,  6.49it/s]
100%|██████████| 1795/1795 [02:27<00:00, 12.18it/s]
100%|██████████| 286/286 [00:28<00:00,  9.96it/s]
100%|██████████| 1845/1845 [02:58<00:00, 10.33it/s]
100%|██████████| 1826/1826 [02:49<00:00, 10.77it/s]
100%|██████████| 1780/1780 [02:44<00:00, 10.79it/s]
100%|██████████| 1608/1608 [02:54<00:00,  9.21it/s]
100%|██████████| 1832/1832 [02:59<00:00, 10.23it/s]
100%|██████████| 8/8 [21:52<00:00, 164.12s/it]


In [21]:
len(data_train)

12724

In [22]:
tagged_doc = tagging_data(data_train)

# **Train model**

In [23]:
model = Doc2Vec(tagged_doc, vector_size=300, window=5, min_count=20, workers=cores, epochs = 80)

In [24]:
model_path = os.path.dirname(os.path.realpath(os.getcwd()))
model_path = os.path.join(model_path, 'models\\d2v.model')
model.save(model_path)

# **Test model**

In [25]:
model= Doc2Vec.load(model_path)

In [26]:
text_check = """
The Health Ministry has proposed to stop all non-essential activities with large gatherings during the incoming Lunar New Year holiday in light of recent Covid-19 surges.
In a document sent to the Government Office on Friday, the Ministry of Health stated the number of Covid-19 cases in the community has been rising, especially in areas with high population and traffic density.

The Omicron variant circulating around the world is also a concerning variable, despite Vietnam having recorded no such infection, it added.

Tran Dac Phu, former head of the General Department of Preventive Medicine, said essential activities for economic development and production should be sustained, while other activities could continue under certain conditions, or be suspended.

Choosing to live with the virus means there would be more cases, he added. "People should only go to places with large gatherings when it's truly necessary, avoid parties and reduce the number of people in meetings, especially in the coming times."

Vietnam has recorded over 1.5 million Covid-19 cases in the fourth coronavirus wave, with over 29,000 deaths. Both infections and deaths have been on the rise lately, with around 15,000 new cases daily.

Compared to the previous month, the number of community transmission cases has risen by 186.4 percent, and the number of severe cases, by 62.2 percent.

Tet, or the Lunar New Year, is Vietnam's biggest holiday which normally involves a lot of festivals, partying and traveling. The holiday peaks on February 1 this year."""

In [27]:
model = Doc2Vec.load(model_path)
test_doc = word_tokenize(text_check.lower())
result = model.docvecs.most_similar(positive=[model.infer_vector(test_doc)],topn=5)
result

  result = model.docvecs.most_similar(positive=[model.infer_vector(test_doc)],topn=5)


[('covid-19', 0.36554476618766785),
 ('health', 0.2853069007396698),
 ('politics', 0.2283198982477188),
 ('business', 0.17431792616844177),
 ('environment', 0.12269987910985947)]

In [28]:
!pip list

Package                       Version
----------------------------- ---------
backcall                      0.2.0
backports.functools-lru-cache 1.6.4
beautifulsoup4                4.10.0
boto                          2.49.0
boto3                         1.18.21
botocore                      1.21.41
Bottleneck                    1.3.2
brotlipy                      0.7.0
bz2file                       0.98
certifi                       2021.10.8
cffi                          1.15.0
charset-normalizer            2.0.4
click                         8.0.3
colorama                      0.4.4
cryptography                  36.0.0
Cython                        0.29.23
debugpy                       1.5.1
decorator                     5.1.0
entrypoints                   0.3
gensim                        4.0.1
idna                          3.3
ipykernel                     6.6.0
ipython                       7.30.1
jedi                          0.18.1
jmespath                      0.10.0
joblib    