In [1]:
import string
import time
from nltk.corpus import stopwords
from nltk import download

import numpy as np
import pyspark
from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.sql.functions import monotonically_increasing_id
from pyspark.ml.feature import CountVectorizer
from pyspark.mllib.clustering import LDA
from pyspark.mllib.linalg import Vectors as MLlibVectors

In [2]:
download('stopwords')

[nltk_data] Downloading package stopwords to
[nltk_data]     /home/mobod2022/mob2022013/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


True

In [3]:
def read_txt(data_path, min_token_tf, max_token_tf, min_token_length, min_doc_length=50, is_vw_format=False):
    time_start = time.time()
    sqlContext = SQLContext(sc)

    def parse_file(kv):
        line = kv if is_vw_format else kv[1]
        line = line.replace('\n', ' ').replace('\t', ' ')
        
        if is_vw_format:
            line_list = []
            for token in line.split(' ')[1: ]:
                lst = token.split(':')
                if len(lst) == 1:
                    line_list.append(token)
                else:
                    line_list += [lst[0]] * int(float(lst[1]))
            line = ' '.join(line_list)
        
        for p in string.punctuation:
            line = line.replace(p, ' ')
        
        tokens = [e.strip().lower() for e in line.strip().split(' ') if len(e) > 0]
        if is_vw_format:
            return tokens
        else:
            return (kv[0], tokens)


    def filter_token(kv):
        token = kv[0]
        value = kv[1]

        if value > max_token_tf or value < min_token_tf:
            return False

        if len(token) < min_token_length:
            return False

        if token in stopwords_:
            return False

        for i in '0123456789':
            if i in token:
                return False

        return True


    def get_tokens(tokens):
        if is_vw_format:
            return tokens
        return tokens[1]


    def parseVectors(line):
        return [int(line[2]), line[0]]


    if is_vw_format:
        dataset = sc.textFile(data_path)
    else:
        dataset = sc.wholeTextFiles("{}/*".format(data_path))
    dataset = dataset.map(parse_file)
    
    word_counts = (dataset
                   .flatMap(lambda path_with_tokens: ((token, 1) for token in get_tokens(path_with_tokens)))
                   .reduceByKey(lambda cnt_1, cnt_2: cnt_1 + cnt_2)
                   .sortBy(lambda token_with_cnt: -token_with_cnt[1]))

    stopwords_ = set(stopwords.words('english'))

    word_counts = word_counts.filter(filter_token)
    vocab = set([e[0] for e in word_counts.collect()])

    print('Total number of tokens: {}'.format(len(vocab)))
    
    if is_vw_format:
        dataset = (dataset
                   .map(lambda kv: (0, list(filter(lambda t: t in vocab, kv))))
                   .filter(lambda kv: len(kv[1]) > min_doc_length))
    else:
        dataset = (dataset
                   .map(lambda kv: (kv[0].split('/')[-1], list(filter(lambda t: t in vocab, kv[1]))))
                   .filter(lambda kv: len(kv[1]) > min_doc_length))
    
    print('Total number of documents: {}'.format(dataset.count()))
    
    data_df = sqlContext.createDataFrame(dataset, ['id', 'tokens'])
    
    data_df = data_df.withColumn("id", monotonically_increasing_id())
    
    cv = CountVectorizer(inputCol="tokens", outputCol="vectors")
    cv_model = cv.fit(data_df)
    df_vect = cv_model.transform(data_df)

    bow = (df_vect
           .select('vectors', 'tokens', 'id')
           .rdd.map(parseVectors)
           .mapValues(MLlibVectors.fromML)
           .map(list))
    
    nnz = sum(bow.map(lambda x: list(x[1].values)).reduce(lambda x, y: x + y))
    print('Total collection size: {}'.format(nnz))

    print('Elapsed time : {} sec.'.format(int(time.time() - time_start)))
    return bow, cv_model, nnz

In [76]:
class TopicModel:
    
    def __init__(self, num_topics, cv_model, nnz, num_document_passes, use_phi_broadcast=True, beta=0.0):
        self.num_topics = num_topics                    # число тем в модели
        self.cv_model_vocabulary = cv_model.vocabulary  # векторизатор, объект класса `pyspark.ml.feature.CountVectorizer`
        self.nnz = nnz                                  # общее число словопозиций в коллекции
        self.num_document_passes = num_document_passes  # число проходов по документу на E-шаге
        self.use_phi_broadcast = use_phi_broadcast      # использование бродкастинга матрицы $\Phi$
        self.beta = beta                                # коэффициент регуляризации
        self.perplexity_list = []
        
        phi_wt_np = np.random.random((len(self.cv_model_vocabulary), self.num_topics))
        if self.use_phi_broadcast:
            self.phi_wt = sc.broadcast(phi_wt_np)
        else:
            self.phi_wt = phi_wt_np
                
    def fit(self, bow_data, num_collection_passes=10):
        self.perplexity_list = []
        time_start = time.time()
        for _ in range(num_collection_passes):
            
            def process_document(document, 
                                 num_topics=self.num_topics, 
                                 num_document_passes=self.num_document_passes,
                                 phi_wt=self.phi_wt, 
                                 use_phi_broadcast=self.use_phi_broadcast):
                
                theta = np.array([1/num_topics] * num_topics)
                n_dw = np.zeros(document.size)
                n_dw[document.indices] = document.values
                
                phi = phi_wt.value if use_phi_broadcast else phi_wt
                
                for _ in range(num_document_passes):
                    _p_tdw = np.einsum('wt,t->tw', phi, theta)
                    p_tdw = _p_tdw / np.sum(_p_tdw, axis=0, keepdims=True)
                    
                    _theta = np.einsum('t,wt->w', n_dw, p_tdw)
                    theta = _theta / np.sum(_theta, axis=0, keepdims=True)
                
                n_wt = np.einsum('w,tw->wt', n_dw, p_tdw)
                
                theta = np.nan_to_num(theta, 0)
                
                return n_wt, (n_dw * np.log((phi * theta).sum(axis=1))).sum()
        
            def E_step(rows):
                for row in rows:
                    _, documents = row
                    n_wt, perplexity = process_document(documents)
                    yield n_wt, perplexity

            E = bow_data.partitionBy(5).mapPartitions(E_step)

            n_wt, perplexity = E.reduce(lambda x, y: (x[0] + y[0], x[1] + y[1]))

            phi = self.phi_wt.value if self.use_phi_broadcast else self.phi_wt
            
            _phi_wt = np.maximum(n_wt + phi + self.beta, 1e-40)
            phi_wt  = _phi_wt / _phi_wt.sum(axis=0, keepdims=True)

            if self.use_phi_broadcast:
                self.phi_wt = sc.broadcast(phi_wt)
            else:
                self.phi_wt = phi_wt

            self.perplexity_list.append(int(np.exp(-perplexity / self.nnz)))
        
        print('Elapsed time : {} sec.\n'.format(int(time.time() - time_start)))
        
    def print_perplexity(self):
        print(self.perplexity_list)
    
    def print_topics(self, num_tokens=10):
        indexes = self.phi_wt.argsort(axis=0)[-num_tokens:][::-1]
        for idx, i_w in enumerate(range(indexes.shape[1])):
            print('topic_id: {}'.format(idx))
            for i in indexes[:, i_w]:
                print(model.cv_model_vocabulary[i])
            print('--------')

In [5]:
sc = SparkContext.getOrCreate()

In [6]:
data_path = "/data/mobod/tm/vw.wiki-en-20K.txt"
min_token_tf = 10
max_token_tf = 30000
min_token_length = 3

In [7]:
bow, cv_model, nnz = read_txt(data_path, \
                              min_token_tf, \
                              max_token_tf, \
                              min_token_length, \
                              min_doc_length=50, \
                              is_vw_format=True)

Total number of tokens: 42143
Total number of documents: 16029
Total collection size: 5193241.0
Elapsed time : 389 sec.


Перед проведением экспериментов, проверим, что данные разбиваются на достаточное число партиций и что среди них нет вырожденных (существенно меньших по объёму, чем прочие).

In [9]:
def count_row_in_partitions(rows):
    count = 0
    for _ in rows:
        count += 1
    yield count

In [10]:
%%time
bow.mapPartitions(count_row_in_partitions).collect()

CPU times: user 8 ms, sys: 4 ms, total: 12 ms
Wall time: 21.3 s


[8038, 7991]

In [11]:
%%time
bow.partitionBy(5).mapPartitions(count_row_in_partitions).collect()

CPU times: user 12 ms, sys: 0 ns, total: 12 ms
Wall time: 11.1 s


[3206, 3206, 3207, 3205, 3205]

Оценим время работы при num_document_passes=5 и num_collection_passes=10, без broadcast переменной, с числом тем num_topics = 10, 20, 50.

In [11]:
num_topics_list = [10, 20, 50]
for num_topics in num_topics_list:
    print('Num topics: ', num_topics)
    model = TopicModel(num_topics=num_topics,    # число тем в модели
                       cv_model=cv_model,        # векторизатор, объект класса `pyspark.ml.feature.CountVectorizer`
                       nnz=nnz,                  # общее число словопозиций в коллекции
                       num_document_passes=5,    # число проходов по документу на E-шаге
                       use_phi_broadcast=False,  # использование бродкастинга матрицы $\Phi$
                       beta=0.0)                 # коэффициент регуляризации
    model.fit(bow, num_collection_passes=10) 

Num topics:  10
Done
Elapsed time : 1200 sec.
Num topics:  20
Done
Elapsed time : 1953 sec.
Num topics:  50
Done
Elapsed time : 5161 sec.


Оценим время работы при num_document_passes=5 и num_collection_passes=10, с broadcast переменной, с числом тем num_topics = 10, 20, 50.

In [12]:
num_topics_list = [10, 20, 50]
for num_topics in num_topics_list:
    print('Num topics: ', num_topics)
    model = TopicModel(num_topics=num_topics,    # число тем в модели
                       cv_model=cv_model,        # векторизатор, объект класса `pyspark.ml.feature.CountVectorizer`
                       nnz=nnz,                  # общее число словопозиций в коллекции
                       num_document_passes=5,    # число проходов по документу на E-шаге
                       use_phi_broadcast=True,   # использование бродкастинга матрицы $\Phi$
                       beta=0.0)                 # коэффициент регуляризации
    model.fit(bow, num_collection_passes=10)

Num topics:  10
Done
Elapsed time : 1464 sec.
Num topics:  20
Done
Elapsed time : 2461 sec.
Num topics:  50
Done
Elapsed time : 4760 sec.


Можно сделать вывод о том, что использование broadcast не ускоряет алгоритм для маленького количества тем, но возможно есть небольшой прирост когда тем 50.

---------------------------

Оценим время работы, перплексию и интерпретируемость тем по топ-словам при num_topics=20 и num_collection_passes=10 и различных значениях num_document_passes = 1, 2, 5, 10.

num_document_passes = 1:

In [14]:
model = TopicModel(num_topics=20,            # число тем в модели
                   cv_model=cv_model,        # векторизатор, объект класса `pyspark.ml.feature.CountVectorizer`
                   nnz=nnz,                  # общее число словопозиций в коллекции
                   num_document_passes=1,    # число проходов по документу на E-шаге
                   use_phi_broadcast=False,  # использование бродкастинга матрицы $\Phi$
                   beta=0.0)                 # коэффициент регуляризации

In [15]:
%%time
model.fit(bow, num_collection_passes=10)

Done
Elapsed time : 693 sec.
CPU times: user 1.07 s, sys: 384 ms, total: 1.46 s
Wall time: 11min 33s


In [16]:
model.print_topics()

topic_id: 0
album
club
station
game
king
system
president
park
show
church
--------
topic_id: 1
album
town
side
air
white
club
district
political
title
division
--------
topic_id: 2
film
party
town
club
league
children
band
london
book
right
--------
topic_id: 3
album
league
club
president
political
king
party
french
division
championship
--------
topic_id: 4
league
king
town
society
game
district
station
man
song
band
--------
topic_id: 5
station
town
river
game
church
king
division
children
league
band
--------
topic_id: 6
film
system
church
king
party
show
album
football
road
band
--------
topic_id: 7
film
album
party
town
league
king
station
society
london
song
--------
topic_id: 8
film
station
game
band
football
research
law
center
published
division
--------
topic_id: 9
film
league
church
game
king
air
road
london
football
women
--------
topic_id: 10
film
album
party
town
right
band
park
son
man
song
--------
topic_id: 11
film
game
system
william
album
son
road
linear
river
build

In [17]:
model.print_perplexity()

[2, 11692, 11588, 11588, 11588, 11588, 11588, 11588, 11588, 11588]


-----------------------------

num_document_passes = 2:

In [18]:
model = TopicModel(num_topics=20,            # число тем в модели
                   cv_model=cv_model,        # векторизатор, объект класса `pyspark.ml.feature.CountVectorizer`
                   nnz=nnz,                  # общее число словопозиций в коллекции
                   num_document_passes=2,    # число проходов по документу на E-шаге
                   use_phi_broadcast=False,  # использование бродкастинга матрицы $\Phi$
                   beta=0.0)                 # коэффициент регуляризации

In [19]:
%%time
model.fit(bow, num_collection_passes=10)

Done
Elapsed time : 1095 sec.
CPU times: user 1.12 s, sys: 368 ms, total: 1.48 s
Wall time: 18min 15s


In [20]:
model.print_topics()

topic_id: 0
film
james
band
robert
london
business
village
book
song
george
--------
topic_id: 1
king
emperor
film
empire
roman
album
china
province
show
award
--------
topic_id: 2
station
railway
road
center
district
park
street
central
town
party
--------
topic_id: 3
club
band
song
album
live
video
party
game
championship
record
--------
topic_id: 4
league
cup
football
round
division
club
park
ret
goals
games
--------
topic_id: 5
album
film
band
song
award
show
radio
love
chart
video
--------
topic_id: 6
linear
socorro
peak
saint
kitt
spacewatch
anderson
mount
system
neat
--------
topic_id: 7
church
king
president
art
german
army
military
air
radio
force
--------
topic_id: 8
education
research
open
australia
medal
club
art
society
network
women
--------
topic_id: 9
saint
cause
party
system
right
political
center
river
stadium
william
--------
topic_id: 10
party
development
aircraft
air
court
cup
player
church
version
game
--------
topic_id: 11
regiment
foot
river
battalion
indian
inf

In [21]:
model.print_perplexity()

[1, 11628, 11488, 11430, 11353, 11260, 11144, 10993, 10795, 10531]


-----------------------------

num_document_passes = 5:

In [22]:
model = TopicModel(num_topics=20,            # число тем в модели
                   cv_model=cv_model,        # векторизатор, объект класса `pyspark.ml.feature.CountVectorizer`
                   nnz=nnz,                  # общее число словопозиций в коллекции
                   num_document_passes=5,    # число проходов по документу на E-шаге
                   use_phi_broadcast=False,  # использование бродкастинга матрицы $\Phi$
                   beta=0.0)                 # коэффициент регуляризации

In [23]:
%%time
model.fit(bow, num_collection_passes=10)

Done
Elapsed time : 1940 sec.
CPU times: user 1.22 s, sys: 384 ms, total: 1.6 s
Wall time: 32min 20s


In [24]:
model.print_topics()

topic_id: 0
water
species
research
river
often
region
common
form
level
areas
--------
topic_id: 1
club
division
league
japan
town
station
football
win
manchester
railway
--------
topic_id: 2
church
saint
town
village
french
trust
community
paris
institute
museum
--------
topic_id: 3
cup
league
club
stadium
football
round
goals
championship
total
match
--------
topic_id: 4
students
children
women
news
program
radio
president
society
human
media
--------
topic_id: 5
air
army
force
aircraft
military
fire
forces
battle
training
command
--------
topic_id: 6
isbn
book
published
books
modern
theory
god
written
text
volume
--------
topic_id: 7
party
king
china
roman
political
law
empire
democratic
province
hong
--------
topic_id: 8
race
points
indian
racing
india
miss
car
championship
women
event
--------
topic_id: 9
system
design
using
designed
model
type
standard
file
light
example
--------
topic_id: 10
album
song
show
episode
film
video
chart
records
live
love
--------
topic_id: 11
game
sa

In [25]:
model.print_perplexity()

[1, 11368, 10888, 10184, 9102, 7937, 7037, 6439, 6050, 5789]


-----------------------------

num_document_passes = 10:

In [26]:
model = TopicModel(num_topics=20,            # число тем в модели
                   cv_model=cv_model,        # векторизатор, объект класса `pyspark.ml.feature.CountVectorizer`
                   nnz=nnz,                  # общее число словопозиций в коллекции
                   num_document_passes=10,   # число проходов по документу на E-шаге
                   use_phi_broadcast=False,  # использование бродкастинга матрицы $\Phi$
                   beta=0.0)                 # коэффициент регуляризации

In [27]:
%%time
model.fit(bow, num_collection_passes=10)

Done
Elapsed time : 3609 sec.
CPU times: user 1.32 s, sys: 544 ms, total: 1.86 s
Wall time: 1h 9s


In [28]:
model.print_topics()

topic_id: 0
species
miss
polish
little
black
white
female
mexico
german
male
--------
topic_id: 1
king
roman
father
cause
opera
mother
son
man
empire
daughter
--------
topic_id: 2
film
show
episode
television
award
radio
awards
role
director
films
--------
topic_id: 3
india
indian
population
region
island
africa
water
islands
land
sea
--------
topic_id: 4
game
player
games
league
win
players
championship
play
champion
round
--------
topic_id: 5
air
aircraft
army
force
navy
engine
car
ship
racing
military
--------
topic_id: 6
system
information
systems
data
access
development
health
water
using
process
--------
topic_id: 7
book
isbn
published
books
story
man
magazine
press
novel
stories
--------
topic_id: 8
album
band
song
records
songs
chart
video
guitar
track
live
--------
topic_id: 9
saint
french
art
france
paris
des
bgcolor
works
museum
jean
--------
topic_id: 10
castle
prince
son
emperor
chateau
chinese
army
military
china
king
--------
topic_id: 11
system
space
model
using
power
d

In [29]:
model.print_perplexity()

[1, 10917, 9608, 7920, 6673, 6002, 5629, 5402, 5254, 5153]


-----------------------------

Можно сделать ожидаемый вывод о том, что время работы увеличивается, а перплексия снижается с увеличением числа проходов по документу.
Что касается интерпретируемости тем, то для 1 и 2 проходов её оказывается недостаточно, хотя для 2 проходов уже можно сделать вывод об одной теме, но бывают выбросы. Разница между 5 и 10 проходами не существенна, поэтому оптимальным числом проходов будет 5.

-----------------------------

-----------------------------

Рассмотрим простейшую регуляризацию а-ля LDA, с помощью разреживания константой beta = 0.0, -0.1, -1.0, при параметрах num_topics=20, num_collection_passes=10 и num_document_passes=5.

beta = 0.0:

In [96]:
model = TopicModel(num_topics=20,            # число тем в модели
                   cv_model=cv_model,        # векторизатор, объект класса `pyspark.ml.feature.CountVectorizer`
                   nnz=nnz,                  # общее число словопозиций в коллекции
                   num_document_passes=5,    # число проходов по документу на E-шаге
                   use_phi_broadcast=False,  # использование бродкастинга матрицы $\Phi$
                   beta=0.0)                 # коэффициент регуляризации

In [97]:
%%time
model.fit(bow, num_collection_passes=10)

Elapsed time : 2758 sec.

CPU times: user 1.18 s, sys: 292 ms, total: 1.47 s
Wall time: 45min 58s


In [98]:
model.print_perplexity()

[1, 11376, 10890, 10138, 8989, 7807, 6936, 6361, 5979, 5722]


Cчитаем разряженность матрицы:

In [99]:
sparsity = np.round(100 * (1.0 - np.count_nonzero(np.round(model.phi_wt, decimals=4)) / float(model.phi_wt.size)), decimals=1)
print('sparsity: ', sparsity, '%')

sparsity:  92.0 %


---

beta = -0.1:

In [100]:
model = TopicModel(num_topics=20,            # число тем в модели
                   cv_model=cv_model,        # векторизатор, объект класса `pyspark.ml.feature.CountVectorizer`
                   nnz=nnz,                  # общее число словопозиций в коллекции
                   num_document_passes=5,    # число проходов по документу на E-шаге
                   use_phi_broadcast=False,  # использование бродкастинга матрицы $\Phi$
                   beta=-0.1)                # коэффициент регуляризации

In [101]:
%%time
model.fit(bow, num_collection_passes=10)

Elapsed time : 2534 sec.

CPU times: user 1.16 s, sys: 320 ms, total: 1.48 s
Wall time: 42min 14s


In [102]:
model.print_perplexity()

[1, 11334, 10883, 10151, 9012, 7784, 6877, 6316, 5969, 5747]


In [103]:
sparsity = np.round(100 * (1.0 - np.count_nonzero(np.round(model.phi_wt, decimals=4)) / float(model.phi_wt.size)), decimals=1)
print('sparsity: ', sparsity, '%')

sparsity:  91.9 %


---

beta = -1.0:

In [104]:
model = TopicModel(num_topics=20,            # число тем в модели
                   cv_model=cv_model,        # векторизатор, объект класса `pyspark.ml.feature.CountVectorizer`
                   nnz=nnz,                  # общее число словопозиций в коллекции
                   num_document_passes=5,    # число проходов по документу на E-шаге
                   use_phi_broadcast=False,  # использование бродкастинга матрицы $\Phi$
                   beta=-1.0)                # коэффициент регуляризации

In [105]:
%%time
model.fit(bow, num_collection_passes=10)

Elapsed time : 2553 sec.

CPU times: user 1 s, sys: 280 ms, total: 1.28 s
Wall time: 42min 33s


In [106]:
model.print_perplexity()

[1, 11262, 10849, 9839, 8826, 7914, 7264, 6858, 6605, 6443]


In [107]:
sparsity = np.round(100 * (1.0 - np.count_nonzero(np.round(model.phi_wt, decimals=4)) / float(model.phi_wt.size)), decimals=1)
print('sparsity: ', sparsity, '%')

sparsity:  91.6 %


---

Перплексия увеличивается с увеличением параметра beta, как и разреженность.

---

beta = -25.0:

In [92]:
model = TopicModel(num_topics=20,            # число тем в модели
                   cv_model=cv_model,        # векторизатор, объект класса `pyspark.ml.feature.CountVectorizer`
                   nnz=nnz,                  # общее число словопозиций в коллекции
                   num_document_passes=5,    # число проходов по документу на E-шаге
                   use_phi_broadcast=False,  # использование бродкастинга матрицы $\Phi$
                   beta=-25.0)               # коэффициент регуляризации

In [93]:
%%time
model.fit(bow, num_collection_passes=10)

Elapsed time : 2208 sec.

CPU times: user 872 ms, sys: 420 ms, total: 1.29 s
Wall time: 36min 48s


In [94]:
model.print_perplexity()

[1, 96265710347577568, 74498027578544560, 55643859199365960, 35748269860209220, 20201163991965068, 12487707418476852, 8943836361390868, 6928327427149308, 5817054592067324]


In [95]:
sparsity = np.round(100 * (1.0 - np.count_nonzero(np.round(model.phi_wt, decimals=4)) / float(model.phi_wt.size)), decimals=1)
print('sparsity: ', sparsity, '%')

sparsity:  98.7 %


---

---

Заменим реализацию структуры для хранения счётчиков и матрицу Φ с numpy-массива на структуру из scipy для работы с разреженными матрицами.

In [108]:
from scipy.sparse import csr_matrix

In [109]:
class TopicModelSparse:

    def __init__(self, num_topics, cv_model, nnz, num_document_passes, use_phi_broadcast=True, beta=0.0):
        self.num_topics = num_topics                    
        self.cv_model_vocabulary = cv_model.vocabulary  
        self.nnz = nnz                                  
        self.num_document_passes = num_document_passes  
        self.use_phi_broadcast = use_phi_broadcast      
        self.beta = beta                                
        self.perplexity_list = []
        
        phi_wt_csr = csr_matrix(np.random.random((len(self.cv_model_vocabulary), self.num_topics)))
        if self.use_phi_broadcast:
            self.phiwt = sc.broadcast(phi_wt_csr)
        else:
            self.phiwt = phi_wt_csr
                
    def fit(self, bow_data, num_collection_passes=10):
        self.perplexity_list = []
        time_start = time.time()
        for _ in range(num_collection_passes):
            
            def process_document(document, 
                                 num_topics=self.num_topics, 
                                 num_document_passes=self.num_document_passes,
                                 phiwt=self.phiwt, 
                                 use_phi_broadcast=self.use_phi_broadcast):
                
                theta = csr_matrix(np.array([1/num_topics] * num_topics))
                n_dw = np.zeros(document.size)
                n_dw[document.indices] = document.values
                n_dw = csr_matrix(n_dw)
                
                phi = phiwt.value if use_phi_broadcast else phiwt
                
                for _ in range(num_document_passes):
                    #_p_tdw = np.einsum('wt,t->tw', phi, theta)
                    _p_tdw = (phi.multiply(theta)).T
                    p_tdw = _p_tdw / _p_tdw.sum(axis=0)
                    #_theta = np.einsum('t,wt->w', n_dw, p_tdw)
                    _theta = n_dw.dot(p_tdw.T)
                    theta = _theta / _theta.sum(axis=0)
                #n_wt = np.einsum('w,tw->wt', n_dw, p_tdw)
                n_wt = (p_tdw.multiply(n_dw)).T
                
                theta = np.nan_to_num(theta, 0)
                
                return n_wt, (n_dw.toarray() * np.log((phi.multiply(theta)).sum(axis=1))).sum()
        
            def E_step(rows):
                for row in rows:
                    _, documents = row
                    n_wt, perplexity = process_document(documents)
                    yield n_wt, perplexity

            E = bow_data.partitionBy(5).mapPartitions(E_step)


            n_wt, perplexity = E.reduce(lambda x, y: (x[0] + y[0], x[1] + y[1]))

            phi = self.phiwt.value if self.use_phi_broadcast else self.phiwt
            
            _phi_wt = np.maximum(n_wt.todense() + phi.todense() + self.beta, 1e-40)
            phiwt  = csr_matrix(_phi_wt / _phi_wt.sum(axis=0))

            if self.use_phi_broadcast:
                self.phiwt = sc.broadcast(phiwt)
            else:
                self.phiwt = phiwt


            self.perplexity_list.append(int(np.exp(-perplexity / (self.nnz + 1e-40))))
        
        print('Elapsed time : {} sec.\n'.format(int(time.time() - time_start)))
        
    def print_perplexity(self):
        print(self.perplexity_list)
    
    def print_topics(self, num_tokens=10):
        indexes = self.phiwt.argsort(axis=0)[-num_tokens:][::-1]
        for idx, i_w in enumerate(range(indexes.shape[1])):
            print('topic_id: {}'.format(idx))
            for i in indexes[:, i_w]:
                print(model.cv_model_vocabulary[i])
            print('--------')

In [110]:
num_topics_list = [10, 20, 50]
for num_topics in num_topics_list:
    print('Num topics: ', num_topics)
    model = TopicModel(num_topics=num_topics,    # число тем в модели
                       cv_model=cv_model,        # векторизатор, объект класса `pyspark.ml.feature.CountVectorizer`
                       nnz=nnz,                  # общее число словопозиций в коллекции
                       num_document_passes=5,    # число проходов по документу на E-шаге
                       use_phi_broadcast=False,  # использование бродкастинга матрицы $\Phi$
                       beta=-50.0)               # коэффициент регуляризации
    model.fit(bow, num_collection_passes=10)
    model.print_perplexity()

Num topics:  10
Elapsed time : 1454 sec.

[2, 347415074773276608, 289322479036779136, 268359768369659360, 246187881608650208, 217907687035157312, 202492230523938720, 187065982155361216, 175129898402068000, 166953857670722240]
Num topics:  20
Elapsed time : 2299 sec.

[1, 43639416790438482280448, 35206300837285222940672, 32543626795818739040256, 27056635451614492098560, 19475061364181778300928, 15029164789962184327168, 11855623453586166382592, 10058371743066716372992, 8895359532209015881728]
Num topics:  50
Elapsed time : 4315 sec.

[1, 4930901589741260510003590070272, 843449813812448929969275404288, 1604702367895, 1474799926561, 1430009871892, 1399796395377, 1390095885347, 1385804100251, 1376299168296]


In [111]:
num_topics_list = [10, 20, 50]
for num_topics in num_topics_list:
    print('Num topics: ', num_topics)
    model = TopicModel(num_topics=num_topics,    # число тем в модели
                       cv_model=cv_model,        # векторизатор, объект класса `pyspark.ml.feature.CountVectorizer`
                       nnz=nnz,                  # общее число словопозиций в коллекции
                       num_document_passes=5,    # число проходов по документу на E-шаге
                       use_phi_broadcast=False,  # использование бродкастинга матрицы $\Phi$
                       beta=-25.0)               # коэффициент регуляризации
    model.fit(bow, num_collection_passes=10)
    model.print_perplexity()

Num topics:  10
Elapsed time : 1072 sec.

[2, 9519996327750, 7827064885036, 7194103247180, 6573907893997, 5755058668682, 4964189431759, 4412839792071, 4039596268401, 3689146990891]
Num topics:  20
Elapsed time : 1916 sec.

[1, 83813007954602816, 65030323624663424, 52364805260211864, 37477876803085920, 23539375085451428, 14395727243317482, 9623690957252396, 7516768122379486, 6312555008361182]
Num topics:  50
Elapsed time : 4596 sec.

[1, 608460637590711856791552, 434890635895124380352512, 58106010018199044620288, 94450936673882144768, 4060532980514948608, 2262328083508298240, 1875518993904480256, 1744668108264572672, 1611464321411701248]
