# Topic Modelling - Gensim

In [4]:
import matplotlib.pyplot as plt
import gensim
import numpy as np
import spacy
import pandas as pd

from gensim.models import CoherenceModel, LdaModel, LsiModel, HdpModel
from gensim.models.wrappers import LdaMallet
from gensim.corpora import Dictionary
import pyLDAvis.gensim

import os, re, operator, warnings
warnings.filterwarnings('ignore')  # Let's not pay heed to them right now
%matplotlib inline

In [10]:
# Load spaCy's small English pipeline; `nlp` is called by proprocess_text below.
# (spacy is already imported in the first cell, so this import is a harmless repeat.)
import spacy
nlp = spacy.load("en_core_web_sm")

In [23]:
# preprocess the text 
def proprocess_text(string_list):
    """Tokenise each sufficiently long string and keep only selected word classes.

    Strings of 15 characters or fewer are skipped entirely, so the result can
    be shorter than ``string_list``. Every remaining string is run through the
    module-level spaCy pipeline ``nlp`` and reduced to the text of the tokens
    whose ``pos_`` tag is ADJ, ADV, NOUN, NUM, PROPN or VERB.

    Args:
        string_list: iterable of strings.

    Returns:
        A list of lists of token strings, one inner list per retained string,
        in input order.
    """
    kept_pos = {'ADJ', 'ADV', 'NOUN', 'NUM', 'PROPN', 'VERB'}
    return [
        [tok.text for tok in nlp(text) if tok.pos_ in kept_pos]
        for text in string_list
        if len(text) > 15
    ]

In [24]:
# Read the article table, drop rows without a title, and print the shape before and after.
# NOTE(review): absolute, user-specific path — this cell only runs where that file exists.
my_df = pd.read_csv("/Users/sdeshpande/Desktop/bioinformatices/full_articlesLDA.csv")
print(my_df.shape)
# Removes rows whose 'title' is missing; my_df itself is modified (inplace=True).
my_df.dropna(subset = ['title'], inplace = True)
print(my_df.shape)
my_df.head()

(21053, 7)
(20140, 7)


Unnamed: 0.1,Unnamed: 0,paper_id,title,abstract,body_text,doi,title_abstract_body
0,0,0001418189999fea7f7cbe3e82703d71c85a6fe5,Absence of surface expression of feline infect...,Feline infectious peritonitis virus (FIPV) pos...,Feline infectious peritonitis (FIP) is a fatal...,http://doi.org/10.1016/j.vetmic.2006.11.026,Absence of surface expression of feline infect...
1,1,000affa746a03f1fe4e3b3ef1a62fdfa9b9ac52a,Correlation between antimicrobial consumption ...,Objectives: This study was conducted to invest...,The incidence of health-care-associated infect...,http://doi.org/10.1016/j.jmii.2013.10.008,Correlation between antimicrobial consumption ...
2,2,000e754142ba65ef77c6fdffcbcbe824e141ea7b,Laboratory-based surveillance of hospital-acqu...,"Of 7,772 laboratory-confirmed cases of respira...",The human respiratory viruses include adenovir...,http://doi.org/10.1016/j.ajic.2017.01.009,Laboratory-based surveillance of hospital-acqu...
3,3,000eec3f1e93c3792454ac59415c928ce3a6b4ad,Pneumonie virale sévère de l'immunocompétent V...,Reçu et accepté le 7 février 2004 Les infectio...,Les pathologies infectieuses respiratoires son...,http://doi.org/10.1016/j.reaurg.2004.02.009,Pneumonie virale sévère de l'immunocompétent V...
4,4,001259ae6d9bfa9376894f61aa6b6c5f18be2177,Microheterogeneity of S-glycoprotein of mouse ...,"IEF, isoelectric focusing; NC, nitrocellulose;...",(Accepted 10 January 1992) is a neurotropic co...,http://doi.org/10.1016/0166-0934(92)90173-B,Microheterogeneity of S-glycoprotein of mouse ...


In [25]:
# Plain Python list of the (non-missing) title values.
health_titles = my_df['title'].tolist()

In [26]:
# One POS-filtered token list per title. proprocess_text skips titles of 15
# characters or fewer, so this list can be shorter than health_titles.
processed_titles = proprocess_text(health_titles)

In [28]:
processed_titles[0]

['Absence',
 'surface',
 'expression',
 'feline',
 'infectious',
 'peritonitis',
 'virus',
 'FIPV',
 'antigens',
 'infected',
 'cells',
 'isolated',
 'cats',
 'FIP']

In [27]:
# Number of titles before and after preprocessing; the gap is the titles that
# proprocess_text skipped for being 15 characters or shorter.
print(len(health_titles))
print(len(processed_titles))

20140
20047


In [29]:
# Fit a gensim Phrases model on the token lists, using its default settings.
bigram = gensim.models.Phrases(processed_titles)

In [30]:
# Apply the fitted Phrases model to every token list (detected pairs are merged, e.g. 'infected_cells').
texts = [bigram[line] for line in processed_titles]

In [31]:
texts[0]

['Absence',
 'surface',
 'expression',
 'feline_infectious',
 'peritonitis_virus',
 'FIPV',
 'antigens',
 'infected_cells',
 'isolated',
 'cats',
 'FIP']

In [32]:
# Build the token <-> id mapping, then encode every document as a bag of words.
dictionary = Dictionary(texts)
corpus = [dictionary.doc2bow(text) for text in texts]

In [33]:
# Latent Semantic Indexing with 10 topics over the bag-of-words corpus.
lsimodel = LsiModel(corpus=corpus, num_topics=10, id2word=dictionary)
lsimodel.show_topics(num_topics=5)  # Showing only the top 5 topics

[(0,
  '0.760*"virus" + 0.282*"infection" + 0.139*"coronavirus" + 0.130*"respiratory_syndrome" + 0.119*"protein" + 0.108*"viral" + 0.101*"human" + 0.099*"based" + 0.098*"RNA" + 0.096*"Virus"'),
 (1,
  '0.556*"C" + -0.359*"virus" + 0.327*"Virus" + 0.223*"N" + 0.198*"A" + 0.188*"E" + 0.168*"R" + 0.129*"1" + 0.121*"infection" + 0.093*"S"'),
 (2,
  '0.545*"Virus" + -0.444*"C" + 0.334*"infection" + -0.301*"virus" + -0.122*"R" + -0.121*"N" + -0.113*"E" + 0.109*"Middle_East" + -0.106*"A" + 0.101*"coronavirus"'),
 (3,
  '-0.624*"infection" + 0.565*"Virus" + 0.302*"virus" + -0.159*"coronavirus" + -0.127*"viral" + -0.105*"respiratory_syndrome" + 0.091*"RNA" + -0.080*"human" + -0.077*"Middle_East" + 0.065*"Protein"'),
 (4,
  '-0.416*"infection" + -0.304*"C" + 0.267*"N" + -0.264*"Virus" + 0.248*"1" + 0.240*"Middle_East" + 0.219*"coronavirus" + 0.192*"respiratory_syndrome" + 0.168*"Respiratory_Syndrome" + 0.164*"A"')]

In [34]:
# Hierarchical Dirichlet Process topic model over the bag-of-words corpus.
# The fit is stochastic; random_state is fixed so re-running the cell gives the same topics.
hdpmodel = HdpModel(corpus=corpus, id2word=dictionary, random_state=0)
hdpmodel.show_topics()

[(0,
  '0.001*virus + 0.001*infection + 0.001*analysis + 0.001*using + 0.001*specific + 0.001*cells + 0.001*coronavirus + 0.001*human + 0.001*1 + 0.001*detection + 0.000*Virus + 0.000*patients + 0.000*disease + 0.000*based + 0.000*induced + 0.000*2 + 0.000*\u202bفي\u202c + 0.000*China + 0.000*Novel + 0.000*respiratory_syndrome'),
 (1,
  '0.001*virus + 0.001*infection + 0.001*E + 0.001*Virus + 0.001*2 + 0.001*Clinical + 0.001*coronavirus + 0.001*Infection + 0.001*children + 0.001*respiratory_syndrome + 0.001*trial + 0.001*influenza + 0.001*protein + 0.001*expression + 0.000*based + 0.000*Human + 0.000*Cells + 0.000*using + 0.000*RNA + 0.000*Respiratory'),
 (2,
  '0.001*virus + 0.001*1 + 0.001*coronavirus + 0.001*viral + 0.001*novel + 0.001*N + 0.001*A + 0.001*Virus + 0.001*infection + 0.001*2 + 0.001*RNA + 0.001*3 + 0.001*analysis + 0.001*Human + 0.001*Protein + 0.001*cells + 0.000*characterization + 0.000*ovalbumin + 0.000*protein + 0.000*human'),
 (3,
  '0.002*Respiratory_Syndrome + 0

In [35]:
# Latent Dirichlet Allocation with 10 topics over the bag-of-words corpus.
# The fit is stochastic; random_state is fixed so re-running the cell gives the same topics.
ldamodel = LdaModel(corpus=corpus, num_topics=10, id2word=dictionary, random_state=0)
ldamodel.show_topics()

[(0,
  '0.014*"novel" + 0.014*"vaccination" + 0.013*"Global" + 0.012*"Different" + 0.010*"Systematic_Review" + 0.010*"water" + 0.009*"Epidemics" + 0.009*"affects" + 0.009*"epidemiological" + 0.009*"air"'),
 (1,
  '0.016*"MACROPHAGES" + 0.012*"virus" + 0.011*"respiratory_syncytial" + 0.011*"replication" + 0.010*"patients" + 0.010*"VIRUS" + 0.010*"outcomes" + 0.009*"Active" + 0.009*"associated" + 0.009*"Wuhan_China"'),
 (2,
  '0.012*"virus" + 0.009*"virus_infection" + 0.009*"antiviral" + 0.008*"2009_H1N1" + 0.007*"Respiratory_Infections" + 0.007*"Human_Coronavirus" + 0.007*"Regions" + 0.007*"antigen" + 0.007*"experimental_infection" + 0.007*"PCR"'),
 (3,
  '0.017*"respiratory_syndrome" + 0.013*"0_0" + 0.012*"virus" + 0.011*"human" + 0.011*"coronavirus" + 0.009*"disease" + 0.009*"evolution" + 0.008*"porcine_reproductive" + 0.008*"protein" + 0.008*"host"'),
 (4,
  '0.013*"Singapore" + 0.012*"Pandemic" + 0.012*"Reporting" + 0.010*"infection" + 0.010*"response" + 0.009*"COVID-19" + 0.008*"fo

In [36]:
# Render the interactive pyLDAvis view of the gensim LDA model inside the notebook.
# NOTE(review): newer pyLDAvis releases expose this module as `pyLDAvis.gensim_models` — confirm the installed version.
pyLDAvis.enable_notebook()
pyLDAvis.gensim.prepare(ldamodel, corpus, dictionary)

# Topic Modelling overall

In [37]:
my_df.head()

Unnamed: 0.1,Unnamed: 0,paper_id,title,abstract,body_text,doi,title_abstract_body
0,0,0001418189999fea7f7cbe3e82703d71c85a6fe5,Absence of surface expression of feline infect...,Feline infectious peritonitis virus (FIPV) pos...,Feline infectious peritonitis (FIP) is a fatal...,http://doi.org/10.1016/j.vetmic.2006.11.026,Absence of surface expression of feline infect...
1,1,000affa746a03f1fe4e3b3ef1a62fdfa9b9ac52a,Correlation between antimicrobial consumption ...,Objectives: This study was conducted to invest...,The incidence of health-care-associated infect...,http://doi.org/10.1016/j.jmii.2013.10.008,Correlation between antimicrobial consumption ...
2,2,000e754142ba65ef77c6fdffcbcbe824e141ea7b,Laboratory-based surveillance of hospital-acqu...,"Of 7,772 laboratory-confirmed cases of respira...",The human respiratory viruses include adenovir...,http://doi.org/10.1016/j.ajic.2017.01.009,Laboratory-based surveillance of hospital-acqu...
3,3,000eec3f1e93c3792454ac59415c928ce3a6b4ad,Pneumonie virale sévère de l'immunocompétent V...,Reçu et accepté le 7 février 2004 Les infectio...,Les pathologies infectieuses respiratoires son...,http://doi.org/10.1016/j.reaurg.2004.02.009,Pneumonie virale sévère de l'immunocompétent V...
4,4,001259ae6d9bfa9376894f61aa6b6c5f18be2177,Microheterogeneity of S-glycoprotein of mouse ...,"IEF, isoelectric focusing; NC, nitrocellulose;...",(Accepted 10 January 1992) is a neurotropic co...,http://doi.org/10.1016/0166-0934(92)90173-B,Microheterogeneity of S-glycoprotein of mouse ...


In [38]:
# Keep only the identifier and title columns for the scikit-learn experiments below.
health_data = my_df[["paper_id","title"]]

In [39]:
# Number of topics fitted by the scikit-learn NMF and LDA models below.
no_topics = 3 #@param {type:"integer"}

# Number of highest-weighted words printed for each topic.
no_top_words = 4 #@param {type:"integer"}

# Number of highest-weighted documents printed for each topic.
no_top_documents = 3 #@param {type:"integer"}

from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
from sklearn.decomposition import NMF, LatentDirichletAllocation
import numpy as np

In [40]:
# Raw title strings (no spaCy preprocessing) passed to the scikit-learn vectorizers.
titles = health_data.title.tolist()

In [41]:
def display_topics(H, W, feature_names, documents, no_top_words, no_top_documents):
    """Print each topic's strongest words followed by its strongest documents.

    Args:
        H: topic-by-term weight matrix; each row is one topic (rows must
            support ``argsort``).
        W: document-by-topic weight matrix, indexed as ``W[:, topic]``.
        feature_names: term for each column index of ``H``.
        documents: document text for each row index of ``W``.
        no_top_words: how many top-weighted terms to print per topic.
        no_top_documents: how many top-weighted documents to print per topic.

    Returns:
        None; everything is written to stdout.
    """
    for k, word_weights in enumerate(H):
        print("Topic %d:" % (k))
        # Term indices by descending weight, truncated to the requested count.
        best_words = word_weights.argsort()[::-1][:no_top_words]
        print(" ".join(feature_names[j] for j in best_words))
        # Document indices by descending weight for this topic.
        ranked_docs = np.argsort(W[:, k])[::-1]
        for d in ranked_docs[:no_top_documents]:
            print(documents[d])


In [42]:
# NMF is able to use tf-idf
# TF-IDF features for NMF: English stop words removed, terms must occur in at
# least 2 titles (min_df=2) and in at most 95% of them (max_df=0.95).
tfidf_vectorizer = TfidfVectorizer(max_df=0.95, min_df=2, stop_words='english')
tfidf = tfidf_vectorizer.fit_transform(titles)
# NOTE(review): get_feature_names() is absent from recent scikit-learn releases
# (get_feature_names_out() replaces it) — confirm the installed version.
tfidf_feature_names = tfidf_vectorizer.get_feature_names()

In [43]:
# Run NMF
# Factorise the TF-IDF matrix into no_topics components.
# NOTE(review): the `alpha` argument was replaced by alpha_W / alpha_H in newer
# scikit-learn releases — confirm the installed version.
nmf_model = NMF(n_components=no_topics, random_state=1, alpha=.1, l1_ratio=.5, init='nndsvd').fit(tfidf)
nmf_W = nmf_model.transform(tfidf)  # document-by-topic weights
nmf_H = nmf_model.components_  # topic-by-term weights

In [44]:
# Print the top words and top titles of every NMF topic.
print("NMF Topics")
display_topics(nmf_H, nmf_W, tfidf_feature_names, titles, no_top_words, no_top_documents)
print("--------------")

NMF Topics
Topic 0:
virus infection porcine influenza
Transmissible Gastroenteritis Virus of Pigs and Porcine Epidemic Diarrhea Virus
Virus-Vectored Influenza Virus Vaccines
Multiplex real-time RT-PCR for the simultaneous detection and quantification of transmissible gastroenteritis virus and porcine epidemic diarrhea virus
Topic 1:
respiratory syndrome acute infections
Severe Acute Respiratory Syndrome-associated Coronavirus Infection
Comparative Epidemiology of Human Infections with Middle East Respiratory Syndrome and Severe Acute Respiratory Syndrome Coronaviruses among Healthcare Personnel
Middle East respiratory syndrome
Topic 2:
coronavirus sars protein cov
Peptide Mimicrying Between SARS Coronavirus Spike Protein and Human Proteins Reacts with SARS Patient Serum
Characterization of protein-protein interactions between the nucleocapsid protein and membrane protein of the SARS coronavirus
Antibody responses against SARS-coronavirus and its nucleocaspid in SARS patients
----------

In [45]:
# LDA is a probabilistic model over word counts, so it is fitted on raw term
# counts (CountVectorizer) rather than on TF-IDF weights.
tf_vectorizer = CountVectorizer(max_df=0.95, min_df=2, stop_words='english')
tf = tf_vectorizer.fit_transform(titles)
# NOTE(review): get_feature_names() is absent from recent scikit-learn releases
# (get_feature_names_out() replaces it) — confirm the installed version.
tf_feature_names = tf_vectorizer.get_feature_names()

# Run LDA (online variational updates, 5 passes, fixed seed)
lda_model = LatentDirichletAllocation(n_components=no_topics, max_iter=5, learning_method='online', learning_offset=50.,random_state=0).fit(tf)
lda_W = lda_model.transform(tf)  # document-by-topic weights
lda_H = lda_model.components_  # topic-by-term weights

# Print the top words and top titles of every LDA topic.
print("LDA Topics")
display_topics(lda_H, lda_W, tf_feature_names, titles, no_top_words, no_top_documents)

LDA Topics
Topic 0:
protein coronavirus virus sars
Interferon-induced HERC5 is evolving under positive selection and inhibits HIV-1 particle production by a novel mechanism targeting Rev/RRE-dependent RNA nuclear export Interferon-induced HERC5 is evolving under positive selection and inhibits HIV-1 particle production by a novel mechanism targeting Rev/RRE-dependent RNA nuclear export
Proteome and phosphoproteome analysis of honeybee (Apis mellifera) venom collected from electrical stimulation and manual extraction of the venom gland Proteome and phosphoproteome analysis of honeybee (Apis mellifera) venom collected from electrical stimulation and manual extraction of the venom gland
From the Similarity Analysis of Protein Cavities to the Functional Classification of Protein Families Using Cavbase Keywords: protein binding pockets; classification of protein binding pockets; cluster analysis of protein binding pockets; protein kinases; SARS protease
Topic 1:
respiratory health study dis

# Zero-shot topic modelling

In [46]:
import numpy as np
import torch


class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience.

    Call the instance once per validation round with the current loss and the
    model. The model is checkpointed (via ``model.save(path)``) whenever the
    loss improves; after ``patience`` consecutive rounds without improvement,
    ``early_stop`` is set to True.

    Source code: https://github.com/Bjarten/early-stopping-pytorch
    """

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to. ``None`` falls
                            back to 'checkpoint.pt'.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0  # consecutive calls without improvement
        self.best_score = None  # best value of -val_loss seen so far
        self.early_stop = False
        # np.inf rather than the np.Inf alias, which NumPy 2.0 removed.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = 'checkpoint.pt' if path is None else path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        """Record one validation result and update the stopping state.

        Args:
            val_loss: validation loss for the current round (lower is better).
            model: object exposing ``save(path)``; saved on every improvement.
        """
        # Negate so that a larger score means a better model.
        score = -val_loss

        if self.best_score is None:
            # First call: always counts as the best so far.
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            # Not better than the best score by at least `delta`.
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """Saves model when validation loss decrease."""
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')

        model.save(self.path)

        self.val_loss_min = val_loss


In [47]:
"""PyTorch class for feed forward inference network."""

from collections import OrderedDict
from torch import nn
import torch


class ContextualInferenceNetwork(nn.Module):

    """Inference network driven only by the contextual (BERT) embedding."""

    def __init__(self, input_size, bert_size, output_size, hidden_sizes,
                 activation='softplus', dropout=0.2):
        """
        Build the layers of the inference network.
        Args
            input_size : int, dimension of the bag-of-words input
            bert_size : dimension of the contextual embedding given to forward
            output_size : int, dimension of each of the two outputs
            hidden_sizes : tuple, widths of the hidden layers
            activation : string, 'softplus' or 'relu', default 'softplus'
            dropout : float, dropout probability applied after the hidden stack, default 0.2
        """
        super().__init__()
        assert isinstance(input_size, int), "input_size must by type int."
        assert isinstance(output_size, int), "output_size must be type int."
        assert isinstance(hidden_sizes, tuple), \
            "hidden_sizes must be type tuple."
        assert activation in ['softplus', 'relu'], \
            "activation must be 'softplus' or 'relu'."
        assert dropout >= 0, "dropout must be >= 0."

        self.input_size = input_size
        self.output_size = output_size
        self.hidden_sizes = hidden_sizes
        self.dropout = dropout

        activation_cls = {'softplus': nn.Softplus, 'relu': nn.ReLU}.get(activation)
        if activation_cls is not None:
            self.activation = activation_cls()

        first_width = hidden_sizes[0]
        last_width = hidden_sizes[-1]

        # input_layer is created but not used by forward(); it is kept so the
        # set of registered parameters stays the same.
        self.input_layer = nn.Linear(input_size+input_size, first_width)
        self.adapt_bert = nn.Linear(bert_size, first_width)

        # One Linear + activation pair per consecutive pair of hidden widths.
        hidden_blocks = OrderedDict()
        for idx in range(len(hidden_sizes) - 1):
            hidden_blocks['l_{}'.format(idx)] = nn.Sequential(
                nn.Linear(hidden_sizes[idx], hidden_sizes[idx + 1]),
                self.activation)
        self.hiddens = nn.Sequential(hidden_blocks)

        self.f_mu = nn.Linear(last_width, output_size)
        self.f_mu_batchnorm = nn.BatchNorm1d(output_size, affine=False)

        self.f_sigma = nn.Linear(last_width, output_size)
        self.f_sigma_batchnorm = nn.BatchNorm1d(output_size, affine=False)

        self.dropout_enc = nn.Dropout(p=self.dropout)

    def forward(self, x, x_bert):
        """Return (mu, log_sigma) computed from x_bert; x is accepted but not used."""
        hidden = self.activation(self.adapt_bert(x_bert))
        hidden = self.dropout_enc(self.hiddens(hidden))

        mu = self.f_mu_batchnorm(self.f_mu(hidden))
        log_sigma = self.f_sigma_batchnorm(self.f_sigma(hidden))
        return mu, log_sigma




class CombinedInferenceNetwork(nn.Module):

    """Inference network fed by the bag-of-words input and the contextual embedding together."""

    def __init__(self, input_size, bert_size, output_size, hidden_sizes,
                 activation='softplus', dropout=0.2):
        """
        Build the layers of the inference network.
        Args
            input_size : int, dimension of the bag-of-words input
            bert_size : dimension of the contextual embedding given to forward
            output_size : int, dimension of each of the two outputs
            hidden_sizes : tuple, widths of the hidden layers
            activation : string, 'softplus' or 'relu', default 'softplus'
            dropout : float, dropout probability applied after the hidden stack, default 0.2
        """
        super().__init__()
        assert isinstance(input_size, int), "input_size must by type int."
        assert isinstance(output_size, int), "output_size must be type int."
        assert isinstance(hidden_sizes, tuple), \
            "hidden_sizes must be type tuple."
        assert activation in ['softplus', 'relu'], \
            "activation must be 'softplus' or 'relu'."
        assert dropout >= 0, "dropout must be >= 0."

        self.input_size = input_size
        self.output_size = output_size
        self.hidden_sizes = hidden_sizes
        self.dropout = dropout

        activation_cls = {'softplus': nn.Softplus, 'relu': nn.ReLU}.get(activation)
        if activation_cls is not None:
            self.activation = activation_cls()

        first_width = hidden_sizes[0]
        last_width = hidden_sizes[-1]

        # The embedding is projected to input_size and concatenated with the
        # bag of words, hence the doubled width of input_layer.
        self.input_layer = nn.Linear(input_size+input_size, first_width)
        self.adapt_bert = nn.Linear(bert_size, input_size)
        # bert_layer is created but not used by forward(); it is kept so the
        # set of registered parameters stays the same.
        self.bert_layer = nn.Linear(first_width, first_width)

        # One Linear + activation pair per consecutive pair of hidden widths.
        hidden_blocks = OrderedDict()
        for idx in range(len(hidden_sizes) - 1):
            hidden_blocks['l_{}'.format(idx)] = nn.Sequential(
                nn.Linear(hidden_sizes[idx], hidden_sizes[idx + 1]),
                self.activation)
        self.hiddens = nn.Sequential(hidden_blocks)

        self.f_mu = nn.Linear(last_width, output_size)
        self.f_mu_batchnorm = nn.BatchNorm1d(output_size, affine=False)

        self.f_sigma = nn.Linear(last_width, output_size)
        self.f_sigma_batchnorm = nn.BatchNorm1d(output_size, affine=False)

        self.dropout_enc = nn.Dropout(p=self.dropout)

    def forward(self, x, x_bert):
        """Return (mu, log_sigma) computed from x concatenated with the projected x_bert."""
        projected = self.adapt_bert(x_bert)
        joined = torch.cat((x, projected), 1)

        hidden = self.activation(self.input_layer(joined))
        hidden = self.dropout_enc(self.hiddens(hidden))

        mu = self.f_mu_batchnorm(self.f_mu(hidden))
        log_sigma = self.f_sigma_batchnorm(self.f_sigma(hidden))
        return mu, log_sigma


In [48]:
"""PyTorch class for feed forward AVITM network."""

import torch
from torch import nn
from torch.nn import functional as F

class DecoderNetwork(nn.Module):

    """AVITM Network.

    Wraps an inference network (chosen by ``infnet``) that produces the
    posterior parameters, samples topic proportions ``theta`` from them, and
    maps ``theta`` through the topic-word matrix ``beta`` to a word distribution.
    """

    def __init__(self, input_size, bert_size, infnet, n_components=10, model_type='prodLDA',
                 hidden_sizes=(100,100), activation='softplus', dropout=0.2,
                 learn_priors=True):
        """
        Initialize DecoderNetwork.
        Args
            input_size : int, dimension of input
            bert_size : dimension of the contextual embedding, passed on to the
                inference network
            infnet : string, 'zeroshot' (ContextualInferenceNetwork) or
                'combined' (CombinedInferenceNetwork)
            n_components : int, number of topic components, (default 10)
            model_type : string, 'prodLDA' or 'LDA' (default 'prodLDA')
            hidden_sizes : tuple, length = n_layers, (default (100, 100))
            activation : string, 'softplus', 'relu', (default 'softplus')
            dropout : float, dropout probability applied to theta (default 0.2)
            learn_priors : bool, make priors learnable parameter
        Raises
            Exception : if infnet is neither 'zeroshot' nor 'combined'
        """
        super(DecoderNetwork, self).__init__()
        assert isinstance(input_size, int), "input_size must by type int."
        assert isinstance(n_components, int) and n_components > 0, \
            "n_components must be type int > 0."
        assert model_type in ['prodLDA', 'LDA'], \
            "model type must be 'prodLDA' or 'LDA'"
        assert isinstance(hidden_sizes, tuple), \
            "hidden_sizes must be type tuple."
        assert activation in ['softplus', 'relu'], \
            "activation must be 'softplus' or 'relu'."
        assert dropout >= 0, "dropout must be >= 0."

        self.input_size = input_size
        self.n_components = n_components
        self.model_type = model_type
        self.hidden_sizes = hidden_sizes
        self.activation = activation
        self.dropout = dropout
        self.learn_priors = learn_priors
        self.topic_word_matrix = None  # set on every forward() call

        # NOTE(review): `dropout` is not forwarded to the inference network, so
        # it always uses its own default (0.2) — confirm this is intended.
        if infnet == "zeroshot":
            self.inf_net = ContextualInferenceNetwork(
                input_size, bert_size, n_components, hidden_sizes, activation)
        elif infnet == "combined":
            self.inf_net = CombinedInferenceNetwork(
                input_size, bert_size, n_components, hidden_sizes, activation)
        else:
            raise Exception('Missing infnet parameter, options are zeroshot and combined')

        # init prior parameters
        # \mu_1k = log \alpha_k + 1/K \sum_i log \alpha_i;
        # \alpha = 1 \forall \alpha
        topic_prior_mean = 0.0
        self.prior_mean = torch.tensor(
            [topic_prior_mean] * n_components)
        if torch.cuda.is_available():
            self.prior_mean = self.prior_mean.cuda()
        if self.learn_priors:
            self.prior_mean = nn.Parameter(self.prior_mean)

        # \Sigma_1kk = 1 / \alpha_k (1 - 2/K) + 1/K^2 \sum_i 1 / \alpha_k;
        # \alpha = 1 \forall \alpha
        topic_prior_variance = 1. - (1. / self.n_components)
        self.prior_variance = torch.tensor(
            [topic_prior_variance] * n_components)
        if torch.cuda.is_available():
            self.prior_variance = self.prior_variance.cuda()
        if self.learn_priors:
            self.prior_variance = nn.Parameter(self.prior_variance)

        # Topic-word matrix: n_components x input_size, Xavier-initialised.
        self.beta = torch.Tensor(n_components, input_size)
        if torch.cuda.is_available():
            self.beta = self.beta.cuda()
        self.beta = nn.Parameter(self.beta)
        nn.init.xavier_uniform_(self.beta)

        self.beta_batchnorm = nn.BatchNorm1d(input_size, affine=False)

        # dropout on theta
        self.drop_theta = nn.Dropout(p=self.dropout)

    @staticmethod
    def reparameterize(mu, logvar):
        """Reparameterize the theta distribution.

        Returns ``mu + eps * exp(0.5 * logvar)`` with ``eps`` drawn from a
        standard normal of the same shape as ``logvar``.
        """
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def forward(self, x, x_bert):
        """Forward pass.

        Args
            x : bag-of-words input, passed to the inference network
            x_bert : contextual embedding, passed to the inference network
        Returns
            tuple (prior_mean, prior_variance, posterior_mu, posterior_sigma,
            posterior_log_sigma, word_dist)
        """
        # batch_size x n_components
        posterior_mu, posterior_log_sigma = self.inf_net(x, x_bert)
        posterior_sigma = torch.exp(posterior_log_sigma)

        # generate samples from theta
        theta = F.softmax(
            self.reparameterize(posterior_mu, posterior_log_sigma), dim=1)
        theta = self.drop_theta(theta)

        # prodLDA vs LDA
        if self.model_type == 'prodLDA':
            # in: batch_size x input_size x n_components
            word_dist = F.softmax(
                self.beta_batchnorm(torch.matmul(theta, self.beta)), dim=1)
            # word_dist: batch_size x input_size
            self.topic_word_matrix = self.beta
        elif self.model_type == 'LDA':
            # simplex constrain on Beta
            beta = F.softmax(self.beta_batchnorm(self.beta), dim=1)
            self.topic_word_matrix = beta
            word_dist = torch.matmul(theta, beta)
            # word_dist: batch_size x input_size

        return self.prior_mean, self.prior_variance, \
            posterior_mu, posterior_sigma, posterior_log_sigma, word_dist

    def get_theta(self, x, x_bert):
        """Sample topic proportions for a batch without tracking gradients.

        The result is one random draw (softmax of a reparameterized sample),
        so repeated calls can return different values. No dropout is applied
        to theta here.
        """
        with torch.no_grad():
            # batch_size x n_components
            posterior_mu, posterior_log_sigma = self.inf_net(x, x_bert)

            # generate samples from theta
            theta = F.softmax(
                self.reparameterize(posterior_mu, posterior_log_sigma), dim=1)

            return theta


In [49]:
import datetime
import multiprocessing as mp
import os
import warnings
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import torch
import wordcloud
from scipy.special import softmax
from torch import optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from tqdm import tqdm


class CTM:
    """Class to train the contextualized topic model. This is the more general class that we are keeping to
    avoid breaking code; users should use the two subclasses ZeroShotTM and CombinedTM to do topic modeling.
        :param input_size: int, dimension of input
        :param bert_input_size: int, dimension of input that comes from BERT embeddings
        :param inference_type: string, you can choose between the contextual model and the combined model
        :param n_components: int, number of topic components, (default 10)
        :param model_type: string, 'prodLDA' or 'LDA' (default 'prodLDA')
        :param hidden_sizes: tuple, length = n_layers, (default (100, 100))
        :param activation: string, 'softplus', 'relu', (default 'softplus')
        :param dropout: float, dropout to use (default 0.2)
        :param learn_priors: bool, make priors a learnable parameter (default True)
        :param batch_size: int, size of batch to use for training (default 64)
        :param lr: float, learning rate to use for training (default 2e-3)
        :param momentum: float, momentum to use for training (default 0.99)
        :param solver: string, optimizer 'adam' or 'sgd' (default 'adam')
        :param num_epochs: int, number of epochs to train for, (default 100)
        :param reduce_on_plateau: bool, reduce learning rate by 10x on plateau of 10 epochs (default False)
        :param num_data_loader_workers: int, number of data loader workers (default cpu_count). set it to 0 if you are using Windows
    """

    def __init__(self, input_size, bert_input_size, inference_type="combined", n_components=10, model_type='prodLDA',
                 hidden_sizes=(100, 100), activation='softplus', dropout=0.2,
                 learn_priors=True, batch_size=64, lr=2e-3, momentum=0.99,
                 solver='adam', num_epochs=100, reduce_on_plateau=False, num_data_loader_workers=mp.cpu_count()):
        warnings.simplefilter('always', DeprecationWarning)

        if self.__class__.__name__ == "CTM":
            warnings.warn(
                "Direct call to CTM is deprecated and will be removed in version 2, use CombinedTM or ZeroShotTM",
                DeprecationWarning)

        assert isinstance(input_size, int) and input_size > 0, \
            "input_size must by type int > 0."
        assert isinstance(n_components, int) and input_size > 0, \
            "n_components must by type int > 0."
        assert model_type in ['LDA', 'prodLDA'], \
            "model must be 'LDA' or 'prodLDA'."
        assert isinstance(hidden_sizes, tuple), \
            "hidden_sizes must be type tuple."
        assert activation in ['softplus', 'relu'], \
            "activation must be 'softplus' or 'relu'."
        assert dropout >= 0, "dropout must be >= 0."
        assert isinstance(learn_priors, bool), "learn_priors must be boolean."
        assert isinstance(batch_size, int) and batch_size > 0, \
            "batch_size must be int > 0."
        assert lr > 0, "lr must be > 0."
        assert isinstance(momentum, float) and 0 < momentum <= 1, \
            "momentum must be 0 < float <= 1."
        assert solver in ['adam', 'sgd'], "solver must be 'adam' or 'sgd'."
        assert isinstance(reduce_on_plateau, bool), \
            "reduce_on_plateau must be type bool."
        assert isinstance(num_data_loader_workers, int) and num_data_loader_workers >= 0, \
            "num_data_loader_workers must by type int >= 0. set 0 if you are using windows"

        self.input_size = input_size
        self.n_components = n_components
        self.model_type = model_type
        self.hidden_sizes = hidden_sizes
        self.activation = activation
        self.dropout = dropout
        self.learn_priors = learn_priors
        self.batch_size = batch_size
        self.lr = lr
        self.bert_size = bert_input_size
        self.momentum = momentum
        self.solver = solver
        self.num_epochs = num_epochs
        self.reduce_on_plateau = reduce_on_plateau
        self.num_data_loader_workers = num_data_loader_workers

        self.model = DecoderNetwork(
            input_size, self.bert_size, inference_type, n_components, model_type, hidden_sizes, activation,
            dropout, learn_priors)
        self.early_stopping = None

        # init optimizer
        if self.solver == 'adam':
            self.optimizer = optim.Adam(
                self.model.parameters(), lr=lr, betas=(self.momentum, 0.99))
        elif self.solver == 'sgd':
            self.optimizer = optim.SGD(
                self.model.parameters(), lr=lr, momentum=self.momentum)

        # init lr scheduler
        if self.reduce_on_plateau:
            self.scheduler = ReduceLROnPlateau(self.optimizer, patience=10)

        # performance attributes
        self.best_loss_train = float('inf')

        # training attributes
        self.model_dir = None
        self.train_data = None
        self.nn_epoch = None

        # validation attributes
        self.validation_data = None

        # learned topics
        self.best_components = None

        # Use cuda if available
        if torch.cuda.is_available():
            self.USE_CUDA = True
        else:
            self.USE_CUDA = False

        if self.USE_CUDA:
            self.model = self.model.cuda()

    def _loss(self, inputs, word_dists, prior_mean, prior_variance,
              posterior_mean, posterior_variance, posterior_log_variance):

        # KL term
        # var division term
        var_division = torch.sum(posterior_variance / prior_variance, dim=1)
        # diff means term
        diff_means = prior_mean - posterior_mean
        diff_term = torch.sum(
            (diff_means * diff_means) / prior_variance, dim=1)
        # logvar det division term
        logvar_det_division = \
            prior_variance.log().sum() - posterior_log_variance.sum(dim=1)
        # combine terms
        KL = 0.5 * (
            var_division + diff_term - self.n_components + logvar_det_division)

        # Reconstruction term
        RL = -torch.sum(inputs * torch.log(word_dists + 1e-10), dim=1)

        loss = KL + RL

        return loss.sum()

    def _train_epoch(self, loader):
        """Train epoch."""
        self.model.train()
        train_loss = 0
        samples_processed = 0

        for batch_samples in loader:
            # batch_size x vocab_size
            X = batch_samples['X']
            X = X.reshape(X.shape[0], -1)
            X_bert = batch_samples['X_bert']
            if self.USE_CUDA:
                X = X.cuda()
                X_bert = X_bert.cuda()

            # forward pass
            self.model.zero_grad()
            prior_mean, prior_variance, posterior_mean, posterior_variance, posterior_log_variance, word_dists =\
                self.model(X, X_bert)

            # backward pass
            loss = self._loss(
                X, word_dists, prior_mean, prior_variance,
                posterior_mean, posterior_variance, posterior_log_variance)
            loss.backward()
            self.optimizer.step()

            # compute train loss
            samples_processed += X.size()[0]
            train_loss += loss.item()

        train_loss /= samples_processed

        return samples_processed, train_loss

    def fit(self, train_dataset, validation_dataset=None, save_dir=None, verbose=False, patience=5, delta=0):
        """
        Train the CTM model.
        :param train_dataset: PyTorch Dataset class for training data.
        :param validation_dataset: PyTorch Dataset class for validation data. If not None, the training stops if
        validation loss doesn't improve after a given patience
        :param save_dir: directory to save checkpoint models to.
        :param verbose: verbose
        :param patience: How long to wait after last time validation loss improved. Default: 5
        :param delta: Minimum change in the monitored quantity to qualify as an improvement. Default: 0
        """
        # Print settings to output file
        if verbose:
            print("Settings: \n\
                   N Components: {}\n\
                   Topic Prior Mean: {}\n\
                   Topic Prior Variance: {}\n\
                   Model Type: {}\n\
                   Hidden Sizes: {}\n\
                   Activation: {}\n\
                   Dropout: {}\n\
                   Learn Priors: {}\n\
                   Learning Rate: {}\n\
                   Momentum: {}\n\
                   Reduce On Plateau: {}\n\
                   Save Dir: {}".format(
                self.n_components, 0.0,
                1. - (1. / self.n_components), self.model_type,
                self.hidden_sizes, self.activation, self.dropout, self.learn_priors,
                self.lr, self.momentum, self.reduce_on_plateau, save_dir))

        self.model_dir = save_dir
        self.train_data = train_dataset
        self.validation_data = validation_dataset
        if self.validation_data is not None:
            self.early_stopping = EarlyStopping(patience=patience, verbose=verbose, path=save_dir, delta=delta)
        train_loader = DataLoader(
            self.train_data, batch_size=self.batch_size, shuffle=True,
            num_workers=self.num_data_loader_workers)

        # init training variables
        train_loss = 0
        samples_processed = 0

        # train loop
        pbar = tqdm(self.num_epochs, position=0, leave=True)
        for epoch in range(self.num_epochs):
            self.nn_epoch = epoch
            # train epoch
            s = datetime.datetime.now()
            sp, train_loss = self._train_epoch(train_loader)
            samples_processed += sp
            e = datetime.datetime.now()
            pbar.update(1)
            pbar.set_description("Epoch: [{}/{}]\t Seen Samples: [{}/{}]\tTrain Loss: {}\tTime: {}".format(
                epoch + 1, self.num_epochs, samples_processed,
                len(self.train_data) * self.num_epochs, train_loss, e - s))

            if self.validation_data is not None:
                self.best_loss_train = train_loss
                self.best_components = self.model.beta

                validation_loader = DataLoader(self.validation_data, batch_size=self.batch_size, shuffle=True,
                                               num_workers=self.num_data_loader_workers)
                # train epoch
                s = datetime.datetime.now()
                val_samples_processed, val_loss = self._validation(validation_loader)
                e = datetime.datetime.now()

                # report
                if verbose:
                    print("Epoch: [{}/{}]\tSamples: [{}/{}]\tValidation Loss: {}\tTime: {}".format(
                        epoch + 1, self.num_epochs, val_samples_processed,
                        len(self.validation_data) * self.num_epochs, val_loss, e - s))

                self.early_stopping(val_loss, self)
                if self.early_stopping.early_stop:
                    print("Early stopping")

                    break
            else:
                # save best
                if train_loss < self.best_loss_train:
                    self.best_loss_train = train_loss
                    self.best_components = self.model.beta

                if save_dir is not None:
                    self.save(save_dir)

        pbar.close()

    def _validation(self, loader):
        """Validation epoch."""
        self.model.eval()
        val_loss = 0
        samples_processed = 0
        for batch_samples in loader:
            # batch_size x vocab_size
            X = batch_samples['X']
            X = X.reshape(X.shape[0], -1)
            X_bert = batch_samples['X_bert']

            if self.USE_CUDA:
                X = X.cuda()
                X_bert = X_bert.cuda()

            # forward pass
            self.model.zero_grad()
            prior_mean, prior_variance, posterior_mean, posterior_variance, posterior_log_variance, word_dists =\
                self.model(X, X_bert)
            loss = self._loss(X, word_dists, prior_mean, prior_variance,
                              posterior_mean, posterior_variance, posterior_log_variance)

            # compute train loss
            samples_processed += X.size()[0]
            val_loss += loss.item()

        val_loss /= samples_processed

        return samples_processed, val_loss

    def get_thetas(self, dataset, n_samples=20):
        """
        Get the document-topic distribution for a dataset of topics. Includes multiple sampling to reduce variation via
        the parameter n_sample.
        :param dataset: a PyTorch Dataset containing the documents
        :param n_samples: the number of sample to collect to estimate the final distribution (the more the better).
        """
        warnings.warn("Call to `get_thetas` is deprecated and will be removed in version 2, "
                      "use `get_doc_topic_distribution` instead",
                      DeprecationWarning)
        return self.get_doc_topic_distribution(dataset, n_samples=n_samples)

    def get_doc_topic_distribution(self, dataset, n_samples=20):
        """
        Get the document-topic distribution for a dataset of topics. Includes multiple sampling to reduce variation via
        the parameter n_sample.
        :param dataset: a PyTorch Dataset containing the documents
        :param n_samples: the number of sample to collect to estimate the final distribution (the more the better).
        """
        self.model.eval()

        loader = DataLoader(
            dataset, batch_size=self.batch_size, shuffle=False,
            num_workers=self.num_data_loader_workers)
        pbar = tqdm(n_samples, position=0, leave=True)
        final_thetas = []
        for sample_index in range(n_samples):
            with torch.no_grad():
                collect_theta = []

                for batch_samples in loader:
                    # batch_size x vocab_size
                    X = batch_samples['X']
                    X = X.reshape(X.shape[0], -1)
                    X_bert = batch_samples['X_bert']

                    if self.USE_CUDA:
                        X = X.cuda()
                        X_bert = X_bert.cuda()

                    # forward pass
                    self.model.zero_grad()
                    collect_theta.extend(self.model.get_theta(X, X_bert).cpu().numpy().tolist())

                pbar.update(1)
                pbar.set_description("Sampling: [{}/{}]".format(sample_index + 1, n_samples))

                final_thetas.append(np.array(collect_theta))
        pbar.close()
        return np.sum(final_thetas, axis=0) / n_samples

    def get_most_likely_topic(self, doc_topic_distribution):
        """ get the most likely topic for each document
        :param doc_topic_distribution: ndarray representing the topic distribution of each document
            (one row per document, or a single 1-D distribution)
        :return: index of the highest-probability topic, one per document
        """
        # Reduce over the topic axis (the last one). Reducing over axis 0 would
        # instead return, for every topic, the document where it peaks.
        return np.argmax(doc_topic_distribution, axis=-1)

    def predict(self, dataset, k=10):
        """Predict input."""
        self.model.eval()

        loader = DataLoader(
            dataset, batch_size=self.batch_size, shuffle=False,
            num_workers=self.num_data_loader_workers)

        preds = []

        with torch.no_grad():
            for batch_samples in loader:
                # batch_size x vocab_size
                X = batch_samples['X']
                X = X.reshape(X.shape[0], -1)
                X_bert = batch_samples['X_bert']

                if self.USE_CUDA:
                    X = X.cuda()
                    X_bert = X_bert.cuda()

                # forward pass
                self.model.zero_grad()
                _, _, _, _, _, word_dists = self.model(X, X_bert)

                _, indices = torch.sort(word_dists, dim=1)
                preds += [indices[:, :k]]

            preds = torch.cat(preds, dim=0)
        return preds

    def get_topics(self, k=10):
        """
        Retrieve topic words.
        :param k: int, number of words to return per topic, default 10.
        """
        assert k <= self.input_size, "k must be <= input size."
        component_dists = self.best_components
        topics = defaultdict(list)
        for i in range(self.n_components):
            _, idxs = torch.topk(component_dists[i], k)
            component_words = [self.train_data.idx2token[idx]
                               for idx in idxs.cpu().numpy()]
            topics[i] = component_words
        return topics

    def get_topic_lists(self, k=10):
        """
        Retrieve the lists of topic words.
        :param k: (int) number of words to return per topic, default 10.
        """
        assert k <= self.input_size, "k must be <= input size."
        # TODO: collapse this method with the one that just returns the topics
        component_dists = self.best_components
        topics = []
        for i in range(self.n_components):
            _, idxs = torch.topk(component_dists[i], k)
            component_words = [self.train_data.idx2token[idx]
                               for idx in idxs.cpu().numpy()]
            topics.append(component_words)
        return topics

    def _format_file(self):
        model_dir = "contextualized_topic_model_nc_{}_tpm_{}_tpv_{}_hs_{}_ac_{}_do_{}_lr_{}_mo_{}_rp_{}". \
            format(self.n_components, 0.0, 1 - (1. / self.n_components),
                   self.model_type, self.hidden_sizes, self.activation,
                   self.dropout, self.lr, self.momentum,
                   self.reduce_on_plateau)
        return model_dir

    def save(self, models_dir=None):
        """
        Save model. (Experimental Feature, not tested)
        :param models_dir: path to directory for saving NN models.
        """
        warnings.simplefilter('always', Warning)
        warnings.warn("This is an experimental feature that we has not been fully tested. Refer to the following issue:"
                      "https://github.com/MilaNLProc/contextualized-topic-models/issues/38",
                      Warning)

        if (self.model is not None) and (models_dir is not None):

            model_dir = self._format_file()
            if not os.path.isdir(os.path.join(models_dir, model_dir)):
                os.makedirs(os.path.join(models_dir, model_dir))

            filename = "epoch_{}".format(self.nn_epoch) + '.pth'
            fileloc = os.path.join(models_dir, model_dir, filename)
            with open(fileloc, 'wb') as file:
                torch.save({'state_dict': self.model.state_dict(),
                            'dcue_dict': self.__dict__}, file)

    def load(self, model_dir, epoch):
        """
        Load a previously trained model. (Experimental Feature, not tested)
        :param model_dir: directory where models are saved.
        :param epoch: epoch of model to load.
        """

        warnings.simplefilter('always', Warning)
        warnings.warn("This is an experimental feature that we has not been fully tested. Refer to the following issue:"
                      "https://github.com/MilaNLProc/contextualized-topic-models/issues/38",
                      Warning)

        epoch_file = "epoch_" + str(epoch) + ".pth"
        model_file = os.path.join(model_dir, epoch_file)
        with open(model_file, 'rb') as model_dict:
            checkpoint = torch.load(model_dict)

        for (k, v) in checkpoint['dcue_dict'].items():
            setattr(self, k, v)

        self.model.load_state_dict(checkpoint['state_dict'])

    def get_topic_word_matrix(self):
        """
        Return the topic-word matrix (dimensions: number of topics x length of the vocabulary).
        If model_type is LDA, the matrix is normalized; otherwise the matrix is unnormalized.
        """
        return self.model.topic_word_matrix.cpu().detach().numpy()

    def get_topic_word_distribution(self):
        """
        Return the topic-word distribution (dimensions: number of topics x length of the vocabulary).
        """
        mat = self.get_topic_word_matrix()
        return softmax(mat, axis=1)

    def get_word_distribution_by_topic_id(self, topic):
        """
        Return the word probability distribution of a topic sorted by probability.
        :param topic: id of the topic (int)
        :returns list of tuples (word, probability) sorted by the probability in descending order
        """
        if topic >= self.n_components:
            raise Exception('Topic id must be lower than the number of topics')
        else:
            wd = self.get_topic_word_distribution()
            t = [(word, wd[topic][idx]) for idx, word in self.train_data.idx2token.items()]
            t = sorted(t, key=lambda x: -x[1])
        return t

    def get_wordcloud(self, topic_id, n_words=5, background_color="black"):
        """
        Draw and show a word cloud of a topic's top words. Adapted from the code found here:
        http://amueller.github.io/word_cloud/auto_examples/simple.html#sphx-glr-auto-examples-simple-py and
        here https://github.com/ddangelov/Top2Vec/blob/master/top2vec/Top2Vec.py
        :param topic_id: id of the topic
        :param n_words: number of words to show in word cloud
        :param background_color: color of the background
        """
        top_words = self.get_word_distribution_by_topic_id(topic_id)[:n_words]
        # word -> probability, used by the cloud as the relative word weight
        frequencies = {word: score for word, score in top_words}

        plt.figure(figsize=(10, 4), dpi=200)
        plt.axis("off")
        cloud = wordcloud.WordCloud(width=1000, height=400, background_color=background_color)
        plt.imshow(cloud.generate_from_frequencies(frequencies))
        plt.title("Displaying Topic " + str(topic_id), loc='center', fontsize=24)
        plt.show()

    def get_predicted_topics(self, dataset, n_samples):
        """
        Return the a list containing the predicted topic for each document (length: number of documents).
        :param dataset: CTMDataset to infer topics
        :param n_samples: number of sampling of theta
        :return: the predicted topics
        """
        predicted_topics = []
        thetas = self.get_doc_topic_distribution(dataset, n_samples)

        for idd in range(len(dataset)):
            predicted_topic = np.argmax(thetas[idd] / np.sum(thetas[idd]))
            predicted_topics.append(predicted_topic)
        return predicted_topics


class ZeroShotTM(CTM):
    """
    ZeroShotTM, as described in https://arxiv.org/pdf/2004.07737v1.pdf

    A CTM whose inference type is fixed to "zeroshot".

    :param input_size: int, dimension of input
    :param bert_input_size: int, dimension of input that comes from BERT embeddings
    :param n_components: int, number of topic components, (default 10)
    :param model_type: string, 'prodLDA' or 'LDA' (default 'prodLDA')
    :param hidden_sizes: tuple, length = n_layers, (default (100, 100))
    :param activation: string, 'softplus', 'relu', (default 'softplus')
    :param dropout: float, dropout to use (default 0.2)
    :param learn_priors: bool, make priors a learnable parameter (default True)
    :param batch_size: int, size of batch to use for training (default 64)
    :param lr: float, learning rate to use for training (default 2e-3)
    :param momentum: float, momentum to use for training (default 0.99)
    :param solver: string, optimizer 'adam' or 'sgd' (default 'adam')
    :param num_epochs: int, number of epochs to train for, (default 100)
    :param reduce_on_plateau: bool, reduce learning rate by 10x on plateau of 10 epochs (default False)
    :param num_data_loader_workers: int, number of data loader workers (default cpu_count). set it to 0 if you are using Windows
    """

    def __init__(self, input_size, bert_input_size, n_components=10, model_type='prodLDA',
                 hidden_sizes=(100, 100), activation='softplus', dropout=0.2,
                 learn_priors=True, batch_size=64, lr=2e-3, momentum=0.99,
                 solver='adam', num_epochs=100, reduce_on_plateau=False, num_data_loader_workers=mp.cpu_count()):
        # Forward everything by keyword; only the inference type is pinned here.
        super().__init__(
            input_size=input_size,
            bert_input_size=bert_input_size,
            inference_type="zeroshot",
            n_components=n_components,
            model_type=model_type,
            hidden_sizes=hidden_sizes,
            activation=activation,
            dropout=dropout,
            learn_priors=learn_priors,
            batch_size=batch_size,
            lr=lr,
            momentum=momentum,
            solver=solver,
            num_epochs=num_epochs,
            reduce_on_plateau=reduce_on_plateau,
            num_data_loader_workers=num_data_loader_workers,
        )


class CombinedTM(CTM):
    """
    CombinedTM, as described in https://arxiv.org/pdf/2004.03974.pdf

    A CTM whose inference type is fixed to "combined".

    :param input_size: int, dimension of input
    :param bert_input_size: int, dimension of input that comes from BERT embeddings
    :param n_components: int, number of topic components, (default 10)
    :param model_type: string, 'prodLDA' or 'LDA' (default 'prodLDA')
    :param hidden_sizes: tuple, length = n_layers, (default (100, 100))
    :param activation: string, 'softplus', 'relu', (default 'softplus')
    :param dropout: float, dropout to use (default 0.2)
    :param learn_priors: bool, make priors a learnable parameter (default True)
    :param batch_size: int, size of batch to use for training (default 64)
    :param lr: float, learning rate to use for training (default 2e-3)
    :param momentum: float, momentum to use for training (default 0.99)
    :param solver: string, optimizer 'adam' or 'sgd' (default 'adam')
    :param num_epochs: int, number of epochs to train for, (default 100)
    :param reduce_on_plateau: bool, reduce learning rate by 10x on plateau of 10 epochs (default False)
    :param num_data_loader_workers: int, number of data loader workers (default cpu_count). set it to 0 if you are using Windows
    """

    def __init__(self, input_size, bert_input_size, n_components=10, model_type='prodLDA',
                 hidden_sizes=(100, 100), activation='softplus', dropout=0.2,
                 learn_priors=True, batch_size=64, lr=2e-3, momentum=0.99,
                 solver='adam', num_epochs=100, reduce_on_plateau=False, num_data_loader_workers=mp.cpu_count()):
        # Forward everything by keyword; only the inference type is pinned here.
        super().__init__(
            input_size=input_size,
            bert_input_size=bert_input_size,
            inference_type="combined",
            n_components=n_components,
            model_type=model_type,
            hidden_sizes=hidden_sizes,
            activation=activation,
            dropout=dropout,
            learn_priors=learn_priors,
            batch_size=batch_size,
            lr=lr,
            momentum=momentum,
            solver=solver,
            num_epochs=num_epochs,
            reduce_on_plateau=reduce_on_plateau,
            num_data_loader_workers=num_data_loader_workers,
        )

In [50]:
from sklearn.feature_extraction.text import CountVectorizer
import string
from nltk.corpus import stopwords as stop_words
import warnings

class WhiteSpacePreprocessing():
    """
    Provides a very simple preprocessing script that filters infrequent tokens from text
    """
    def __init__(self, documents, stopwords_language="english", vocabulary_size=2000):
        """
        :param documents: list of strings
        :param stopwords_language: string of the language of the stopwords (see nltk stopwords)
        :param vocabulary_size: the number of most frequent words to include in the documents. Infrequent words will be discarded from the list of preprocessed documents
        """
        self.documents = documents
        self.stopwords = set(stop_words.words(stopwords_language))
        self.vocabulary_size = vocabulary_size

    def preprocess(self):
        """
        Lowercase, strip punctuation and stopwords, then keep only the most frequent words.

        Note that if after filtering some documents do not contain words we remove them. That is why we return also the
        list of unpreprocessed documents.
        :return: preprocessed documents, unpreprocessed documents and the vocabulary list (sorted alphabetically)
        """
        # Build the punctuation -> space table once instead of once per document.
        punctuation_to_space = str.maketrans(string.punctuation, ' ' * len(string.punctuation))

        cleaned_docs = []
        for doc in self.documents:
            tokens = doc.lower().translate(punctuation_to_space).split()
            # split() never yields empty strings, so only stopwords need filtering
            cleaned_docs.append(' '.join(w for w in tokens if w not in self.stopwords))

        vectorizer = CountVectorizer(max_features=self.vocabulary_size, token_pattern=r'\b[a-zA-Z]{2,}\b')
        # Only the fitted vocabulary is needed; the count matrix was being discarded.
        vectorizer.fit(cleaned_docs)
        # get_feature_names() was removed from newer scikit-learn releases;
        # fall back to it only where the replacement does not exist.
        if hasattr(vectorizer, 'get_feature_names_out'):
            vocabulary = set(vectorizer.get_feature_names_out())
        else:
            vocabulary = set(vectorizer.get_feature_names())

        filtered_docs = [' '.join(w for w in doc.split() if w in vocabulary)
                         for doc in cleaned_docs]

        # Drop documents that ended up empty, keeping both lists aligned.
        preprocessed_docs, unpreprocessed_docs = [], []
        for i, doc in enumerate(filtered_docs):
            if len(doc) > 0:
                preprocessed_docs.append(doc)
                unpreprocessed_docs.append(self.documents[i])

        # Sorted so the returned vocabulary order is reproducible across runs.
        return preprocessed_docs, unpreprocessed_docs, sorted(vocabulary)


class SimplePreprocessing(WhiteSpacePreprocessing):
    """Deprecated alias of WhiteSpacePreprocessing, kept for backward compatibility."""

    def __init__(self, documents, stopwords_language="english"):
        """
        :param documents: list of strings
        :param stopwords_language: string of the language of the stopwords (see nltk stopwords)
        """
        super().__init__(documents, stopwords_language)
        # NOTE: this alters the process-wide warning filter so the notice is always shown.
        warnings.simplefilter('always', DeprecationWarning)

        # The previous guard compared the class name with "CTM" (a copy-paste
        # slip), so this deprecation warning could never be emitted.
        warnings.warn("SimplePreprocessing is deprecated and will be removed in version 2.0, "
                      "use WhiteSpacePreprocessing", DeprecationWarning)



In [51]:
import torch
from torch.utils.data import Dataset
import scipy.sparse

class CTMDataset(Dataset):

    """Class to load BOW dataset."""

    def __init__(self, X, X_bert, idx2token):
        """
        Args
            X : array-like, shape=(n_samples, n_features)
                Document word matrix.
            X_bert : sequence with one entry per row of X
                Contextual embeddings.
            idx2token : mapping from vocabulary index to token, stored as-is.

        Raises
            Exception : if X and X_bert do not contain the same number of documents.
        """
        if X.shape[0] != len(X_bert):
            raise Exception("Wait! BoW and Contextual Embeddings have different sizes! "
                            "You might want to check if the BoW preparation method has removed some documents. ")

        self.X = X
        self.X_bert = X_bert
        self.idx2token = idx2token

    def __len__(self):
        """Return length of dataset."""
        return self.X.shape[0]

    def __getitem__(self, i):
        """Return sample from dataset at index i as a dict with keys 'X' and 'X_bert'."""
        row = self.X[i]
        # issparse() accepts every SciPy sparse format and avoids the deprecated
        # private `scipy.sparse.csr` namespace used by the old exact-type check.
        if scipy.sparse.issparse(row):
            X = torch.FloatTensor(row.toarray())
        else:
            X = torch.FloatTensor(row)
        X_bert = torch.FloatTensor(self.X_bert[i])

        return {'X': X, 'X_bert': X_bert}

In [52]:
import numpy as np
from sentence_transformers import SentenceTransformer
import scipy.sparse
import warnings
from sklearn.feature_extraction.text import CountVectorizer


def get_bag_of_words(data, min_length):
    """
    Creates the bag of words: one row of token-id counts per document, as a CSR matrix.

    Entries equal to None are ignored; a document whose remaining ids sum to 0 is skipped.
    :param data: iterable of arrays of token ids
    :param min_length: minimum number of columns of every count row (passed to np.bincount)
    """
    count_rows = []
    for doc in data:
        # drop the None placeholders
        token_ids = doc[doc != np.array(None)]
        if np.sum(token_ids) != 0:
            count_rows.append(np.bincount(token_ids.astype('int'), minlength=min_length))

    return scipy.sparse.csr_matrix(count_rows)


def bert_embeddings_from_file(text_file, sbert_model_to_load, batch_size=200):
    """
    Creates SBERT Embeddings from an input file, one embedding per line of the file.

    :param text_file: path of a UTF-8 text file
    :param sbert_model_to_load: model identifier passed to SentenceTransformer
    :param batch_size: batch size passed to the encoder (default 200)
    :return: np.ndarray built from the encoder output
    """
    encoder = SentenceTransformer(sbert_model_to_load)
    with open(text_file, encoding="utf-8") as handle:
        # lines keep their trailing newline, exactly as read
        lines = handle.readlines()

    embeddings = encoder.encode(lines, show_progress_bar=True, batch_size=batch_size)
    return np.array(embeddings)


def bert_embeddings_from_list(texts, sbert_model_to_load, batch_size=200):
    """
    Creates SBERT Embeddings from a list of texts.

    :param texts: list of strings to encode
    :param sbert_model_to_load: model identifier passed to SentenceTransformer
    :param batch_size: batch size passed to the encoder (default 200)
    :return: np.ndarray built from the encoder output
    """
    encoder = SentenceTransformer(sbert_model_to_load)
    embeddings = encoder.encode(texts, show_progress_bar=True, batch_size=batch_size)
    return np.array(embeddings)


class TopicModelDataPreparation:
    """
    Builds ``CTMDataset`` objects by pairing a bag-of-words matrix with
    contextualized document embeddings.

    The vectorizer, vocabulary and id->token map are learned in
    ``create_training_set`` and reused by ``create_test_set`` and
    ``create_validation_set``.
    """

    def __init__(self, contextualized_model=None):
        """
        :param contextualized_model: model identifier forwarded to
            ``bert_embeddings_from_list``; required by the ``create_*`` methods
        """
        self.contextualized_model = contextualized_model
        self.vocab = []
        self.id2token = {}
        self.vectorizer = None

    def load(self, contextualized_embeddings, bow_embeddings, id2token):
        """
        Wrap already computed embeddings in a ``CTMDataset``.

        Note the argument order: contextualized embeddings come first here,
        whereas ``CTMDataset`` receives the bag of words first.
        """
        return CTMDataset(bow_embeddings, contextualized_embeddings, id2token)

    def create_training_set(self, text_for_contextual, text_for_bow):
        """
        Fit the vectorizer on ``text_for_bow``, embed ``text_for_contextual``
        and return the resulting training ``CTMDataset``.

        Also stores ``self.vectorizer``, ``self.vocab`` (list of tokens) and
        ``self.id2token`` (column index -> token).

        :raises Exception: if no contextualized model was configured
        """
        if self.contextualized_model is None:
            raise Exception("You should define a contextualized model if you want to create the embeddings")

        # TODO: this count vectorizer removes tokens that have len = 1, might be unexpected for the users
        self.vectorizer = CountVectorizer()

        train_bow_embeddings = self.vectorizer.fit_transform(text_for_bow)
        train_contextualized_embeddings = bert_embeddings_from_list(text_for_contextual, self.contextualized_model)

        # scikit-learn 1.2 removed get_feature_names(); prefer the replacement
        # and fall back only for releases that predate it. The result is
        # converted to a plain list so ``self.vocab`` keeps its original type.
        if hasattr(self.vectorizer, "get_feature_names_out"):
            self.vocab = list(self.vectorizer.get_feature_names_out())
        else:
            self.vocab = self.vectorizer.get_feature_names()
        self.id2token = dict(enumerate(self.vocab))

        return CTMDataset(train_bow_embeddings, train_contextualized_embeddings, self.id2token)

    def create_test_set(self, text_for_contextual, text_for_bow=None):
        """
        Build a ``CTMDataset`` for inference, reusing the fitted vectorizer.

        :param text_for_contextual: texts to embed with the contextualized model
        :param text_for_bow: optional preprocessed texts; when omitted a
            single-column all-zero sparse matrix stands in for the bag of words
        :raises Exception: if no contextualized model was configured, or if
            ``text_for_bow`` is given before ``create_training_set`` has fitted
            the vectorizer
        """
        if self.contextualized_model is None:
            raise Exception("You should define a contextualized model if you want to create the embeddings")

        if text_for_bow is not None:
            if self.vectorizer is None:
                # Previously this surfaced as an AttributeError on None.
                raise Exception("The vectorizer is not fitted: call create_training_set "
                                "before passing text_for_bow to create_test_set")
            test_bow_embeddings = self.vectorizer.transform(text_for_bow)
        else:
            # dummy matrix
            test_bow_embeddings = scipy.sparse.csr_matrix(np.zeros((len(text_for_contextual), 1)))
        test_contextualized_embeddings = bert_embeddings_from_list(text_for_contextual, self.contextualized_model)

        return CTMDataset(test_bow_embeddings, test_contextualized_embeddings, self.id2token)

    def create_validation_set(self, text_for_contextual, text_for_bow=None):
        """Alias of ``create_test_set`` with the same arguments and result."""
        return self.create_test_set(text_for_contextual=text_for_contextual, text_for_bow=text_for_bow)


class QuickText:
    """
    Integrated class to handle all the text preprocessing needed
    """
    def __init__(self, bert_model, text_for_bow, text_for_bert=None):
        """
        :param bert_model: string, bert model to use
        :param text_for_bow: list, list of sentences with the preprocessed text
        :param text_for_bert: list, list of sentences with the unpreprocessed text
        """
        self.vocab_dict = {}
        self.vocab = []
        self.index_dd = None
        self.idx2token = None
        self.bow = None
        self.bert_model = bert_model
        self.text_handler = ""
        self.data_bert = None
        self.text_for_bow = text_for_bow
        self.text_for_bert = text_for_bert
        self.loaded_from_config = False

    def prepare_bow(self):
        """
        Build the vocabulary and the sparse document-term matrix.

        Documents come from ``text_for_bow`` (or ``text_for_bert`` when the
        former is None) and are split on whitespace. Terms receive ids in order
        of first appearance. Populates ``vocab_dict`` (term -> id), ``vocab``,
        ``idx2token`` (id -> term) and ``bow``.

        :raises Exception: if both ``text_for_bow`` and ``text_for_bert`` are None
        """
        docs = self.text_for_bow if self.text_for_bow is not None else self.text_for_bert
        if docs is None:
            # Previously this failed later with an opaque "NoneType is not iterable".
            raise Exception("text_for_bow and text_for_bert cannot both be None")

        # CSR triplet: every term occurrence contributes one entry with value 1.
        indptr = [0]
        indices = []
        data = []
        vocabulary = {}

        for d in docs:
            for term in d.split():
                index = vocabulary.setdefault(term, len(vocabulary))
                indices.append(index)
                data.append(1)
            indptr.append(len(indices))

        self.vocab_dict = vocabulary
        self.vocab = list(vocabulary.keys())

        # The process-wide ``warnings.simplefilter('always', DeprecationWarning)``
        # call was dropped: it changed global state and did not apply to the
        # ``Warning``-category message issued below.
        if len(self.vocab) > 2000:
            # Trailing space after "requires" fixes the fused "requiressignificantly".
            warnings.warn("The vocab you are using has more than 2000 words, reconstructing high-dimensional vectors requires "
                          "significantly more training epochs and training samples. "
                          "Consider reducing the number of vocabulary items. "
                          "See https://github.com/MilaNLProc/contextualized-topic-models#preprocessing "
                          "and https://github.com/MilaNLProc/contextualized-topic-models#tldr", Warning)

        self.idx2token = {v: k for (k, v) in self.vocab_dict.items()}
        self.bow = scipy.sparse.csr_matrix((data, indices, indptr), dtype=int)

    def load_configuration(self, bow_embeddings, contextualized_embeddings, vocab, id2token):
        """
        This method defines a way to instantiate the model with pre-trained data.

        After this call ``load_dataset`` uses the given objects as they are and
        skips ``prepare_bow`` and the embedding step.

        :raises AssertionError: if the two embeddings disagree on the number of
            documents, or ``vocab`` and ``id2token`` differ in length
        """
        assert len(contextualized_embeddings) == bow_embeddings.shape[0], \
            "contextualized and bag-of-words embeddings must cover the same number of documents"
        assert len(vocab) == len(id2token), "vocab and id2token must have the same length"
        self.data_bert = contextualized_embeddings
        self.bow = bow_embeddings
        self.vocab = vocab
        self.idx2token = id2token
        self.loaded_from_config = True

    def load_pre_trained_contextualized(self, contextualized_embeddings):
        """
        In case the contextualized embeddings have been already trained, it is possible to load them with this method
        """
        self.data_bert = contextualized_embeddings

    def load_dataset(self):
        """
        Return the training ``CTMDataset``.

        Unless ``load_configuration`` was used, this first builds the bag of
        words and, when no embeddings were preloaded, encodes ``text_for_bert``
        (falling back to ``text_for_bow``) with ``bert_embeddings_from_list``.
        """
        if not self.loaded_from_config:
            self.prepare_bow()

            if self.data_bert is None:
                texts = self.text_for_bert if self.text_for_bert is not None else self.text_for_bow
                self.data_bert = bert_embeddings_from_list(texts, self.bert_model)

        return CTMDataset(self.bow, self.data_bert, self.idx2token)

class TextHandler:
    """
    Class used to handle the text preparation and the BagOfWord

    Deprecated in favour of ``QuickText``; constructing an instance emits a warning.
    """
    def __init__(self, file_name=None, sentences=None):
        """
        :param file_name: path of a UTF-8 text file with one document per line
        :param sentences: list of documents; takes precedence over ``file_name``
        """
        self.file_name = file_name
        self.sentences = sentences
        self.vocab_dict = {}
        self.vocab = []
        self.index_dd = None
        self.idx2token = None
        self.bow = None

        # The notice used to sit behind ``len(self.vocab) > 2000``; vocab is
        # always empty at construction time, so it could never be shown.
        # The process-wide simplefilter() call was dropped as well: it changed
        # global warning state and did not target this ``Warning`` category.
        warnings.warn("TextHandler class is deprecated and will be removed in version 2.0. Use QuickText.",
                      Warning, stacklevel=2)

    def prepare(self):
        """
        Build the vocabulary and the sparse document-term matrix.

        Documents come from ``sentences`` or, failing that, from the lines of
        ``file_name``; each is split on whitespace and terms receive ids in
        order of first appearance. Populates ``vocab_dict`` (term -> id),
        ``vocab``, ``idx2token`` (id -> term) and ``bow``.

        :raises Exception: if neither ``sentences`` nor ``file_name`` was given
        """
        if self.sentences is not None:
            docs = self.sentences
        elif self.file_name is not None:
            with open(self.file_name, encoding="utf-8") as filino:
                docs = filino.readlines()
        else:
            # Single failure path; the second raise in the old code was unreachable.
            raise Exception("Sentences and file_names cannot both be none")

        # CSR triplet: every term occurrence contributes one entry with value 1.
        indptr = [0]
        indices = []
        data = []
        vocabulary = {}

        for d in docs:
            for term in d.split():
                index = vocabulary.setdefault(term, len(vocabulary))
                indices.append(index)
                data.append(1)
            indptr.append(len(indices))

        self.vocab_dict = vocabulary
        self.vocab = list(vocabulary.keys())

        if len(self.vocab) > 2000:
            # Trailing space after "requires" fixes the fused "requiressignificantly".
            warnings.warn("The vocab you are using has more than 2000 words, reconstructing high-dimensional vectors requires "
                          "significantly more training epochs and training samples. "
                          "Consider reducing the number of vocabulary items. "
                          "See https://github.com/MilaNLProc/contextualized-topic-models#preprocessing "
                          "and https://github.com/MilaNLProc/contextualized-topic-models#tldr", Warning)

        self.idx2token = {v: k for (k, v) in self.vocab_dict.items()}
        self.bow = scipy.sparse.csr_matrix((data, indices, indptr), dtype=int)

In [53]:
# Make sure NLTK's stop-word corpus is available locally
# (presumably needed by the stop-word removal in the preprocessing cell below — TODO confirm).
import nltk
nltk.download('stopwords')

[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/sdeshpande/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


True

In [64]:
# Work on a small sample: keep only the first 100 titles of `health_data`.
titles = health_data.title.tolist()[:100]

  and should_run_async(code)


In [65]:
# Preprocess the sampled titles with English stop words configured.
# `preprocess()` hands back the cleaned documents, the corresponding original
# texts and the vocabulary (presumably aligned one-to-one — TODO confirm
# against WhiteSpacePreprocessing).
sp = WhiteSpacePreprocessing(titles, stopwords_language='english')
preprocessed_documents, unpreprocessed_corpus, vocab = sp.preprocess()

  and should_run_async(code)


In [66]:
# Peek at the first two preprocessed titles.
preprocessed_documents[:2]

  and should_run_async(code)


['absence surface expression feline infectious peritonitis virus fipv antigens infected cells isolated cats fip',
 'correlation antimicrobial consumption incidence health care associated infections due methicillin resistant staphylococcus aureus vancomycin resistant enterococci university hospital taiwan']

In [67]:
# Build the training set: the original titles feed the sentence encoder,
# the preprocessed titles feed the bag-of-words vectorizer.
tp = TopicModelDataPreparation("distiluse-base-multilingual-cased")
training_dataset = tp.create_training_set(text_for_contextual=unpreprocessed_corpus, text_for_bow=preprocessed_documents)
# Show the first ten entries of the vocabulary learned by the vectorizer.
tp.vocab[:10]

  and should_run_async(code)
0it [1:17:46, ?it/s]
0it [00:15, ?it/s]
Batches: 100%|██████████| 1/1 [00:12<00:00, 12.28s/it]


['absence',
 'absolute',
 'accompanied',
 'acid',
 'acidification',
 'acids',
 'acquired',
 'activate',
 'activation',
 'activities']

In [71]:
# Size the contextual input from the embeddings actually stored in the dataset
# instead of hard-coding 512, so switching the sentence encoder cannot leave
# the model with a mismatched input width.
ctm = ZeroShotTM(input_size=len(tp.vocab),
                 bert_input_size=training_dataset.X_bert.shape[1],
                 n_components=50, num_epochs=1)
ctm.fit(training_dataset) # run the model

  and should_run_async(code)
0it [00:00, ?it/s]

KeyboardInterrupt: 

In [None]:
# Inspect the fitted topics (presumably the top 5 words per topic — TODO confirm).
ctm.get_topic_lists(5)

In [None]:
# Score a few titles that were not used for training. `titles` holds only the
# first 100 entries, so `titles[101:105]` was always an empty list; slice the
# full title column instead.
held_out_titles = health_data.title.tolist()[101:105]
# The model consumes dataset objects (as `ctm.fit` does), not raw strings, so
# wrap the texts with the already fitted data-preparation helper.
testing_dataset = tp.create_test_set(text_for_contextual=held_out_titles)
# n_samples: how many times to sample the distribution (see the documentation)
topics_predictions = ctm.get_thetas(testing_dataset, n_samples=5) # get all the topic predictions

In [None]:
# Pick the highest-scoring topic for the first predicted document and show
# that topic's entry from the 10-item topic lists.
topic_number = np.argmax(topics_predictions[0]) # get the topic id of the first document
ctm.get_topic_lists(10)[topic_number] 