# Train a Latent Dirichlet Allocation (LDA) Topic Model

This notebook trains an LDA model for predicting the latent topics in customer comments.

## Notebook Summary

**It takes as inputs...**

2. A range of standard packages for modelling and evaluation
3. A custom py script for preparing the data


**It applies the following process...**

1. Data preperation
2. Model evaluation
3. Model selection (confirmed by the user)
4. Export model to be used for latent topic predictions

**It's outputs are...**

1. Latent Dirichlet Allocation latent topic prediction model


## Further Notes
**During the project multiple models were evaluated. It was found that:**

1. LDA was comparitively quick to train and returned excellent coherence scores (>0.5)
2. Embeddings based methods returned too broad a range of topics to useful for human labelling

**Future Improvements:**

1. Spell check on the text inputs to improve the BOW representation
2. Fine tune an embedding model (similar approach to what would be done for a transformer based sentiment model) and apply GenAI to automatically label the generated topics to remove the human labelling constraint

In [None]:
__name__ = '__main__'

# 1 Package Imports and Constants

## 1.0 Package Imports

In [1]:
# snowflake functions
from snowflake.snowpark.context import get_active_session

# Snowpark ML
from snowflake.ml.registry import Registry
from snowflake.ml._internal.utils import identifier

# modelling and prep libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import ast

from gensim import corpora
from gensim.models import LdaMulticore, coherencemodel

import spacy
import nltk
import re
import time

## 1.1 Notebook Constants

In [None]:
import pandas as pd
import numpy as np

# Set seed for reproducibility
seed = 1234
np.random.seed(seed)

# Define the columns as mentioned
all_comment_cols = [
    'comment_reponse', 'promoter_comment', 'passive_comment', 'detractor_comment', 
    'overall_experience_comment', 'admission_information_comment', 'went_well_improved_comment', 
    'staff_communication_comment', 'inclusion_support_perspn_comment', 'communication_and_treatment_comment', 
    'cultural_and_spiritual_needs_comment', 'included_in_decisions_comment', 'pain_relief_comment', 
    'worries_or_concerns_comment', 'meals_comment', 'cleanliness_comment', 'leaving_preparation_comment', 
    'medication_comment', 'anything_to_add'
]

# Generate sample data
num_samples = 500

data = {
    'unique_id': np.arange(1, num_samples + 1),
    'comment_reponse': np.random.choice(['Good service', 'Bad experience', 'Average care', None], num_samples),
    'promoter_comment': np.random.choice(['Excellent staff', 'Very satisfied', 'Great environment', None], num_samples),
    'passive_comment': np.random.choice(['It was okay', 'Nothing special', 'Neutral feelings', None], num_samples),
    'detractor_comment': np.random.choice(['Terrible service', 'Very disappointed', 'Bad management', None], num_samples),
    'overall_experience_comment': np.random.choice(['Wonderful stay', 'Could be better', 'Poor facilities', None], num_samples),
    'admission_information_comment': np.random.choice(['Clear instructions', 'Confusing process', 'No information', None], num_samples),
    'went_well_improved_comment': np.random.choice(['Food quality', 'Staff behavior', 'Cleanliness', None], num_samples),
    'staff_communication_comment': np.random.choice(['Very informative', 'Lack of communication', 'Average', None], num_samples),
    'inclusion_support_perspn_comment': np.random.choice(['Very supportive', 'Not inclusive', 'Okay', None], num_samples),
    'communication_and_treatment_comment': np.random.choice(['Excellent treatment', 'Poor communication', 'Good but slow', None], num_samples),
    'cultural_and_spiritual_needs_comment': np.random.choice(['Met all needs', 'Ignored needs', 'Neutral', None], num_samples),
    'included_in_decisions_comment': np.random.choice(['Fully included', 'Excluded', 'Partially included', None], num_samples),
    'pain_relief_comment': np.random.choice(['Effective', 'Ineffective', 'Average', None], num_samples),
    'worries_or_concerns_comment': np.random.choice(['Addressed all concerns', 'Ignored concerns', 'Somewhat addressed', None], num_samples),
    'meals_comment': np.random.choice(['Tasty meals', 'Bland food', 'Good variety', None], num_samples),
    'cleanliness_comment': np.random.choice(['Very clean', 'Dirty', 'Acceptable', None], num_samples),
    'leaving_preparation_comment': np.random.choice(['Well prepared', 'Unprepared', 'Average preparation', None], num_samples),
    'medication_comment': np.random.choice(['On time', 'Late', 'No issues', None], num_samples),
    'anything_to_add': np.random.choice(['Thank you', 'Never coming back', 'Great experience', None], num_samples)
}

# Create DataFrame
df_raw = pd.DataFrame(data)

# Display the DataFrame
df_raw.head()

In [2]:
# set seed for reproducible results
seed = 1234
np.random.seed(seed)

# whether to evaluate the hyperparameter n_topics. Setting True will run the final section. 
# This is only necessary when you want to re-evaluate the number of topics LDA will fit
evaluate_num_topics = False


# col name grouping for easy filtering
all_comment_cols = ['comment_reponse','promoter_comment','passive_comment','detractor_comment','overall_experience_comment','admission_information_comment',
                   'went_well_improved_comment','staff_communication_comment', 'inclusion_support_perspn_comment', 'communication_and_treatment_comment',
                   'cultural_and_spiritual_needs_comment','included_in_decisions_comment','pain_relief_comment','worries_or_concerns_comment',
                    'meals_comment', 'cleanliness_comment','leaving_preparation_comment','medication_comment','anything_to_add']

## 1.2 Import Data

In [None]:
session = get_active_session()

## Define the SQL query
#query = "select top 500 * from dw.public.hospital_reviews" # TODO: remove 500 row limit once notebook is running
query = "select top 500 * from demo_datasets.public.hospital_reviews"
## Execute the query and fetch the data into a Pandas DataFrame
df_raw = session.sql(query).to_pandas()
#
## Display the DataFrame
df_raw.head()

In [3]:
# change all cols to lower case. Personal preference but makes writting out col headers faster
df_raw.columns = df_raw.columns.str.lower()

# rename the id key to match existing processing pipelines from the project created prior to the adding of a unique_id to the dataset
df_raw = df_raw.rename(columns={'unique_id': 'surrogate_survey_id'})

# create a seperate copy to compare to raw if necessary
df_imported = df_raw.copy()

In [None]:

# spacy's language model...
nlp = spacy.load("en_core_web_sm")
# nltk's part-of-speech tagger
# NB: if other nltk components are required, download them from github and 
# upload them to the appropruate sub folder of the nltk_data folder next to this notebook
# more details can be found here: https://www.nltk.org/data.html

if 'nltk_data' not in nltk.data.path:
    nltk.data.path.append('nltk_data')

# 2 Process

## 2.0 Data Preperation

**Tidy format**

In [None]:
# added line breaks to deal with the line limit
stop_word_list = ['a', 'about', 'above', 'across', 'after', 'afterwards', 'again', 'against', 'all', 'almost', 'alone', 'along', 'already', 'also', 'although', 'always', 'am', 'among', 'amongst', 'amoungst', 'amount', 'an', 'and', 'another', 'any', 'anyhow', 'anyone', 'anything', 'anyway', 'anywhere', 'are', 'around', 'as', 'at', 
                 'back', 'be', 'became', 'because', 'become', 'becomes', 'becoming', 'been', 'before', 'beforehand', 'behind', 'being', 'below', 'beside', 'besides', 'between', 'beyond', 'bill', 'both', 'bottom', 'but', 'by', 'call', 'can', 'cannot', 'cant', 'co', 'computer', 'con', 'could', 'couldnt', 'cry', 
                 'de', 'describe', 'detail', 'did', 'didn', 'do', 'does', 'doesn', 'doing', 'don', 'done', 'down', 'due', 'during', 'each', 'eg', 'eight', 'either', 'eleven', 'else', 'elsewhere', 'empty', 'enough', 'etc', 'even', 'ever', 'every', 'everyone', 'everything', 'everywhere', 'except', 
                 'few', 'fifteen', 'fify', 'fill', 'find', 'fire', 'first', 'five', 'for', 'former', 'formerly', 'forty', 'found', 'four', 'from', 'front', 'full', 'further', 'get', 'give', 'go', 'had', 'has', 'hasnt', 'have', 'he', 'hence', 'her', 'here', 'hereafter', 'hereby', 'herein', 'hereupon', 'hers', 'herself', 'him', 'himself', 'his', 'how', 'however', 'hundred', 
                 'i', 'ie', 'if', 'in', 'inc', 'indeed', 'interest', 'into', 'is', 'it', 'its', 'itself', 'just', 'keep', 'kg', 'km', 'last', 'latter', 'latterly', 'least', 'less', 'ltd', 'made', 'make', 'many', 'may', 'me', 'meanwhile', 'might', 'mill', 'mine', 'more', 'moreover', 'most', 'mostly', 'move', 'much', 'must', 'my', 'myself', 
                 'name', 'namely', 'neither', 'never', 'nevertheless', 'next', 'nine', 'no', 'nobody', 'none', 'noone', 'nor', 'not', 'nothing', 'now', 'nowhere', 'of', 'off', 'often', 'on', 'once', 'one', 'only', 'onto', 'or', 'other', 'others', 'otherwise', 'our', 'ours', 'ourselves', 'out', 'over', 'own', 'part', 'per', 'perhaps', 'please', 'put', 'quite', 
                 'rather', 'rather', 're', 'really', 'regarding', 'same', 'say', 'see', 'seem', 'seemed', 'seeming', 'seems', 'serious', 'several', 'she', 'should', 'show', 'side', 'since', 'sincere', 'six', 'sixty', 'so', 'some', 'somehow', 'someone', 'something', 'sometime', 'sometimes', 'somewhere', 'still', 'such', 'system', 
                 'take', 'ten', 'than', 'that', 'the', 'their', 'them', 'themselves', 'then', 'thence', 'there', 'thereafter', 'thereby', 'therefore', 'therein', 'thereupon', 'these', 'they', 'thick', 'thin', 'third', 'this', 'those', 'though', 'three', 'through', 'throughout', 'thru', 'thus', 'to', 'together', 'too', 'top', 'toward', 'towards', 'twelve', 'twenty', 'two', 
                 'un', 'under', 'unless', 'until', 'up', 'upon', 'us', 'used', 'using', 'various', 'very', 'very', 'via', 
                 'was', 'we', 'well', 'were', 'what', 'whatever', 'when', 'whence', 'whenever', 'where', 'whereafter', 'whereas', 'whereby', 'wherein', 'whereupon', 'wherever', 'whether', 'which', 'while', 'whither', 'who', 'whoever', 'whole', 'whom', 'whose', 'why', 'will', 'with', 'within', 'without', 'would', 'yet', 'you', 'your', 'yours', 'yourself', 'yourselves']

# If a single line gets too long Snowflake (streamlit) silently fails to populate the list
# The stop word list must not be empty or the notebook will fail
stop_word_list

In [None]:
# preprocessing functions. TODO once stage_packages are available in Snowflake AU use that to load the preprocessing .py scripts provided
# that way any edits to preprocessing are done in one place vs updating each notebook that contains the following code

def get_clean_text(df, col_name):
    """
    Parent function to call all related functions to prepare text for modelling.
    :param df: dataframe with customer comments to clean
    :param col: column name in dataframe containing comments
    :param encode_ascii: whether to remove non-ascii characters from comments
    :param lowercase: whether to convert all comments to lower case. If applying VADER set this to False
    :param lemmas: whether to convert matching words to spacy lemmas
    :param remove_punc: whether to remove punctuation

    :return: dataframe with comments column cleaned as per spec
    """

    # remove non-taggable tokens/characters i.e. emojis
    df.loc[:, col_name] = df[col_name].map(drop_non_ascii)

    # lowercase texts
    df.loc[:, col_name] = df[col_name].map(lambda x: x.lower())

    # lemmatize words
    df.loc[:, col_name] = df[col_name].astype(str).map(lemma)

    # remove punctuation
    df.loc[:, col_name] = df[col_name].map(drop_punctuation)

    return df


def drop_non_ascii(comment):
    """
    Removes non ascii characters from comments i.e. emojis
    If applying VADER retain emojis
    :param comment: a comment
    :return: comment with only ascii characters
    """
    comment = comment.encode('ascii', errors = 'ignore')
    return comment

def lemma(comment):
    """
    Lemmatize comments using spacy lemmatizer.
    :param comment: a comment
    :return: lemmatized comment
    """
    lemmatized = nlp(comment)
    lemmatized_final = ' '.join([word.lemma_ for word in lemmatized if word.lemma_ != '\'s'])
    return lemmatized_final


def drop_punctuation(comment):
    """
    Removes punctuation characters from comments
    If applying VADER retain punctuation
    :param comment: a comment
    :return: comment without punctuations
    """
    regex = re.compile('[' + re.escape('!"#%&\'()*+,-./:;<=>?@[\\]^_`{|}~')+'0-9\\r\\t\\n]')
    nopunct = regex.sub(" ", comment)
    nopunct_words = nopunct.split(' ')
    filter_words = [word.strip() for word in nopunct_words if word != '']
    words = ' '.join(filter_words)
    return words

def get_tidy_text(df, col_names_to_keep, id_col_name):
    """
    Takes a wide form dataframe and converts it to long form (tidy format)
    Also extends features and applies a standard set of quality control rules

    :param df: dataframe to restructure
    :param col_names_to_keep: comment columns of df to keep
    :param id_col_name: column of df that holds the unique survey id
    :return: tidy dataframe
    """

    # reduce to desired cols
    df_new = df[col_names_to_keep]

    # pivot to long form and order by survey (default order is column_header)
    df_new = pd.melt(df_new, id_vars=[id_col_name], var_name='comment_header', value_name='comment_response')

    # sort to keep all comments for a survey together when viewing the dataframe
    df_new = df_new.sort_values(by=id_col_name, ascending=True)

    # get count of populated comments before applying quality control
    initial_comment_count = df_new.dropna().shape[0]

    # extend features
    df_new['word_count'] = df_new['comment_response'].apply(count_words)

    # apply quality control rules

    # drop rows (comments) with fewer than 3 words
    df_new = df_new.dropna(subset=['comment_response']).loc[df_new['word_count'] >= 3]

    final_survey_count = len(df_new['surrogate_survey_id'].unique())
    
    print('Started with {} surveys and {} comments. \n '
          'Left with {} surveys and {} comments after quality control.'.format(df.shape[0],
                                                                               initial_comment_count,
                                                                               final_survey_count,
                                                                               df_new.shape[0]))

    return df_new


def count_words(text):
    """
    Simple function to count words based on spaces.
    :param text: text to count words in
    :return: integer of words in text
    """
    if isinstance(text, str):
        return len(text.split())
    else:
        return 0



def tokenize_comments(df, col_name, keep_verbs=False, min_comment_length=2):
    """
    Processes a column of comments from a dataframe to prepare for LDA.
    Assumes the comments have already been cleaned by get_clean_text
    :param df: dataframe of comments
    :param col_name: column of the dataframe that contains the cleaned survey comments
    :param keep_verbs: if True include nouns AND verbs. If False keep only nouns. Default is False
    :param min_comment_length: comments with fewer words than this will be dropped. Defailt is 2
    :return: dataframe with comments formatted as a list of processed n_grams
    """

    print(df.shape)
    # drop stop words and erroneous small words/typos
    df_new = df[col_name].map(lambda x: [word for word in x.split() \
                                            if word not in stop_word_list\
                                                and len(word) > 2])

    # filter for the desired part-of-speech tags (nouns only or nouns+verbs)
    df_new = df_new.apply(filter_by_pos_tags, keep_verbs=keep_verbs)

    # filter to comments at or above a minimum length
    df_new = df_new[df_new.apply(lambda x: len(x) >= min_comment_length)]

    return df_new


def filter_by_pos_tags(df, keep_verbs=False):
    """
    Parses and applies part-of-speech tags to comments. Comments are then filtered to nouns and (optionally) verbs.
    Nouns are more indicative of topics thus this filtering improves topic interpretability.
    :param df: dataframe of comments (only) to parse
    :param keep_verbs: if True include nouns AND verbs. If False keep only nouns. Default is False
    :return: list of filtered comments
    """
    pos_comment = nltk.pos_tag(df)

    if keep_verbs:
        # keep nouns and verbs
        filtered_list = [word[0] for word in pos_comment if word[1] in ['NN','VB', 'VBD', 'VBG', 'VBN', 'VBZ']]
    else:
        # keep nouns only
        filtered_list = [word[0] for word in pos_comment if word[1] in ['NN']]
    return filtered_list


In [None]:
df_imported = df_raw

In [4]:
cols_to_include = all_comment_cols + ['surrogate_survey_id']

stacked_df = get_tidy_text(df_imported, cols_to_include, 'surrogate_survey_id')

Started with 29884 surveys and 84732 comments. 
 Left with 25443 surveys and 74470 comments after quality control.


**Clean text**

In [5]:
# takes ~8min

start_time = time.time()

clean_comments_df = get_clean_text(stacked_df[['comment_response']], 'comment_response')

end_time = time.time()
print("Time taken: {} seconds".format(end_time - start_time))


final_comments= tokenize_comments(df=clean_comments_df, 
                                   col_name='comment_response',
                                   keep_verbs=False,
                                   min_comment_length=2)

end_time = time.time()

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df[col_name] = df[col_name].map(drop_non_ascii)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df[col_name] = df[col_name].map(lambda x: x.lower())
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df[col_name] = df[col_name].astype(str).map(lemma)


Time taken: 455.4314966201782 seconds


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df[col_name] = df[col_name].map(drop_punctuation)


In [None]:
# check the tokenization has worked
final_comments.head()

**Tokenize and Filter Comments**

In [8]:
# takes ~1.5min

start_time = time.time()

final_comments= tokenize_comments(df=clean_comments_df,
                                   col_name='comment_response',
                                   keep_verbs=False,
                                   min_comment_length=2)

end_time = time.time()
print("Time taken: {} seconds".format(end_time - start_time))

Time taken: 89.71900153160095 seconds


**Generate Dictionary and Document Term Matrix for Model Fitting**

The dictionary maps each token (word) to a numeric ID. The document term matrix maps each token within a comment to its ID in the dictionary
Both objects are required for model fitting.

In [9]:
dictionary = corpora.Dictionary(final_comments)
doc_term_matrix = [dictionary.doc2bow(doc) for doc in final_comments]

## 2.1 Fit a 10 Topic Model

The number if topics is a critical hyperparameter in LDA. Experimentation suggested a good coherence was reached at 10 topics. The best coherence was at around 18-22 topics but FIRN decided that was too many to be interpretable initially. Once SCHL is comfortable with the output of topic models, FIRN strongly recommends increasing the num_topics to at least 18 to get a more granular view of latent themes in the customer comments.


In section 2.2 there is code to evaluate a range of values of num_topics.

In [10]:
ldamodel = LdaMulticore(doc_term_matrix, 
               num_topics=10, # 10 was found to be a good balance of coherence, diversity and quality (~0.6 c_v coherence)
               id2word = dictionary, 
               passes=40, # adjust this down to speed up fitting
               iterations=200, # adjust this down to speed up fitting
               chunksize = 10000, # adjust this down if the environment hits memory issues. Training will be slower.
               eval_every = None, # this uses perplexity but we want to use coherence. However, dropping this hyperparameter may improve a model.
               random_state=seed)

In [None]:
# measure the coherence of the topics. 
cm = coherencemodel.CoherenceModel(model=ldamodel, texts=final_comments,
                                                         dictionary=dictionary, coherence='c_v')
c_v_coherence = cm.get_coherence()
c_v_coherence

**Save the Model**

This also saves the objects related to the model - dictionary, model state and some precompted topic distribution values (the file ending with ...expElogbeta.npy).


In [None]:
-- here is another way to define the database and schema for the model registry.  
-- "Another way" means an SQL alternative to the python code version in the train sentiment model notebook 
use database SCHL_DEV;
create schema if not exists ML_MODEL_REGISTRY;
use schema ML_MODEL_REGISTRY

In [None]:
# function to find the current model version and return the next version to use
def get_next_model_version(reg, model_name):
    models = reg.show_models()
    if models.empty:
        return 'V_1'
    elif model_name not in models['name'].to_list():
        return 'V_1'
    max_version = max(
        ast.literal_eval(models.loc[models['name'] == model_name, 'versions'].values[0])
    )
    return 'V_{}'.format({int(max_version.split('_')[-1]) + 1})

# Save Model

## Option 1: Save the model artifacts to a stage

This code is not required if using the registry. However, it is here if you need a way to save the model objects somewhere else so you can download them and use them against a notebook (or an environment outside snowflake).

In [None]:
# how to "put" files in snowflake 


model_name = "CEMPLICITY_TOPIC_MODEL"

# get current db and schema to create the registry against

db = identifier._get_unescaped_name(session.get_current_database())
schema = identifier._get_unescaped_name(session.get_current_schema())

# Create a model registry
reg = Registry(session=session, database_name=db, schema_name=schema) # TODO: change this to just session=session once target db and schema are defined


next_model_version = get_next_model_version(reg, model_name) # this will always return V_1 while the model registry is not being used by this model

# save the file to the temporary local file storage associated with the notebook session. !!This will not persist across sessions
ldamodel.save('{}_{}'.format(model_name, next_model_version))

print('The next version of the model {} is: {}'.format(model_name, next_model_version))


# now push the files to a permenant stage
put_result_model = session.file.put('{}_{}'.format(model_name, next_model_version),'@TEST', auto_compress = False, overwrite=True)
put_result_id2word = session.file.put('{}_{}.id2word'.format(model_name, next_model_version),'@TEST', auto_compress = False, overwrite=True)
put_result_state = session.file.put('{}_{}.state'.format(model_name, next_model_version),'@TEST', auto_compress = False, overwrite=True)
put_result_betas = session.file.put('{}_{}.expElogbeta.npy'.format(model_name, next_model_version),'@TEST', auto_compress = False, overwrite=True)

print('\n model: {}\n id2word: {}\n state: {}\n betas: {}'.format(put_result_model[0].status,
                                                              put_result_id2word[0].status,
                                                              put_result_state[0].status,
                                                              put_result_betas[0].status))
# UPLOADED means saved, SKIPPED means the file already exists. The put is set to overwrite so we should always see UPLOADED

## Option 2: Save the model to the Snowflake Registry

Requires a custom model to log the LDA model to a snowflake model registry. Gensim models are not currently supported by Snowflake

!!This will not work as snowflake's conversion of the custom model class to a UDF triggers a max recursion issue with Pickle. Increasing the max recursion limit causes the notebook to crash. This code is kept here for when snowflake resolve the problem.

In [None]:
# required packages
from snowflake.ml.model import custom_model
from snowflake.ml.model import model_signature

In [None]:
-- Push the files to a permenant stage so they persist and can be called from any notebook
use database SCHL_DEV;
use schema ML_MODEL_REGISTRY;

create stage if not exists test; --TODO: needs to be somewhere more permenant once determined by SCHL.

In [None]:
# TODO: can delete the test stage on the ML registry schema and this code once the notebook runs and log_model() executes successfully.

model_name = "CEMPLICITY_TOPIC_MODEL"

# get current db and schema to create the registry against

db = identifier._get_unescaped_name(session.get_current_database())
schema = identifier._get_unescaped_name(session.get_current_schema())

# Create a model registry
reg = Registry(session=session, database_name=db, schema_name=schema) # TODO: change this to just session=session once target db and schema are defined

next_model_version = get_next_model_version(reg, model_name)

# save the file to the temporary local file storage associated with the notebook session. !!This will not persist across sessions
ldamodel.save('{}_{}'.format(model_name, next_model_version))

print('The next version of the model {} is: {}'.format(model_name, next_model_version))

In [None]:
# create custom class for Gensim LDA models. Any such model can use this class.

class custom_lda_model(custom_model.CustomModel):
    def __init__(self, context: custom_model.ModelContext) -> None:
        super().__init__(context)

        # model
        model_dir = self.context.path('model_file')
        self.model = LdaMulticore.load(model_dir)

        # load the mapping of words to IDs
        id2word_dir = self.context.path('id2word')
        self.id2word = corpora.Dictionary.load(id2word_dir) # to get around a picking recursion error

    # define the predict method
    @custom_model.inference_api
    def predict(self, input: pd.DataFrame) -> pd.DataFrame:
       
        # extract the text feature as a series as the bag of words method does not accept a dataframe
        comments = input.iloc[:,0]

        # create the matrix of documents the terms (words) within them
        doc_term_matrix = [self.id2word.doc2bow(doc) for doc in comments]
 
        #for each document get the topic with the highest attribution.
        print('preparing predictions for model: {} on samples of len: {}'.format(self.model,len(doc_term_matrix)))
        preds = [self.model.get_document_topics(doc) for doc in doc_term_matrix] # attribution per topic for all topics per doc
        top_topics = [max(doc, key=lambda x: x[1])[0] for doc in preds] # topic with the highest attribution per doc
        attributions = [max(doc, key=lambda x: x[1])[1] for doc in preds] # value of the highest attribution per doc
        
        model_output = pd.DataFrame({
            'predicted_topic': top_topics,
            '__name__ = '__main__'': attributions
        }) 
    
        return model_output

In [None]:
# specify the location of the model artifacts which are those saved locally to this notebook session. 

lda_mc2 = custom_model.ModelContext(
    artifacts={
        'model_file': 'CEMPLICITY_TOPIC_MODEL_V_1',
        'expElogbeta': 'CEMPLICITY_TOPIC_MODEL_V_1.expElogbeta.npy',
        'id2word': 'CEMPLICITY_TOPIC_MODEL_V_1.id2word',
        'state':'CEMPLICITY_TOPIC_MODEL_V_1.state'
    }
)

**Verify the custom model works**

In [None]:
# sample of data
data_sample = pd.DataFrame(final_comments[:5])

# load the custom model via the custom class created above
my_lda_model = custom_lda_model(lda_mc2)

# run inference
output_data = my_lda_model.predict(data_sample)
output_data

**Location of the error when trying to log the custom model class** 

log_model() fails with the error: "PicklingError: Could not pickle object as excessively deep recursion required"

tried:
1. increasing the recursion limit with sys.setrecursionlimit. W/H fails at around 30k which is still too small for the log_model method to work
2. setting the resource rstack limit
3. vectorizing the for loop in the model class 

In [None]:
# Log the model
# the name + version must be unique and both are strings. The name could be a UUID but a short description of the use case is more practical.
# the model can be logged with metrics, to track history, and the features

# use the sample data and inference result to infer the signature (col names and types)
inferred_signature = model_signature.infer_signature(input_data=data_sample, output_data=output_data)

logged_model = reg.log_model(
        model = my_lda_model,
        model_name = model_name,
        version_name = get_next_model_version(reg, model_name), # assumes versioning is incremental
        metrics = {'c_v_coherence': c_v_coherence},
        signatures={"predict": inferred_signature},
        options={"relax_version": False},
        comment = 'This is the topic model trained on the Cemplicity survey data.'
)

# bug with comments not saving when submitted by log_model. TODO: retest as this seems to have been resolved by the time we did the sentiment model
#logged_model.comment = 'This is the topic model trained on the Cemplicity survey data.'


# Explore The Optimum Value for num_topics

In [None]:
# takes 60+ min to run on a local PC as it's building a topic model for every number between 5 and 25.


if evaluate_num_topics:
    
    counter = 1
    
    for i in [1,2,3]:
        start_time = time.time()
        print('evaluating min doc length of {}'.format(counter))

        # these are only necessary if inspecting variance in attribution between models across topics and terms 
        # this may cause memory problems in snowflake
        #dictionary = dictionaries['dictionary{}'.format(counter)]
        #doc_term_matrix = doc_term_matrices['doc_term_matrix{}'.format(counter)]
        
        coherence = []
        for k in range(5,25):
            print('Round: '+str(k))
            ldamodel = LdaMulticore(doc_term_matrix, num_topics=k, id2word = dictionary, passes=40, workers=7, # set this equal to the number of PHYSICAL cores
                           iterations=200, chunksize = 10000, eval_every = None)
            
            cm = coherencemodel.CoherenceModel(model=ldamodel, texts=final_reviews,
                                                             dictionary=dictionary, coherence='c_v')
            coherence.append((k,cm.get_coherence()))
    
        end_time = time.time()
        print("Time taken: {} seconds".format(end_time - start_time))
        
        
        x_val = [x[0] for x in coherence]
        y_val = [x[1] for x in coherence]
        
        print('\n\nPlot for min doc length of {}'.format(counter))
        plt.plot(x_val,y_val)
        plt.scatter(x_val,y_val)
        plt.title('Number of Topics vs. Coherence')
        plt.xlabel('Number of Topics')
        plt.ylabel('Coherence')
        plt.xticks(x_val)
        plt.show()
        counter += 1