## non-negative matrix factorization (NMF) based topic modeling
This notebook presents the NMF approach

In [1]:
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import sklearn
import random

# step 1: import
import torch
from torch import nn 
from torch import optim
from torch.optim.lr_scheduler import StepLR

from utils import *
from estimators import *

### Gensim
import gensim
import gensim.corpora as corpora
from sklearn.feature_extraction.text import CountVectorizer
### load NMF utility functions
from nmf_util import *
### load coherence score
import gensim.downloader as api
from coherence_score import *

from sklearn.preprocessing import LabelEncoder
from sklearn import model_selection, naive_bayes, svm
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from collections import Counter
from sklearn.metrics import classification_report

### Load Data from Json

In [2]:
### json load the dataset
with open('../cleaned_data/Spam_Ham.json', 'r') as jf:
    cleaned_data = json.load(jf)

In [3]:
### split data into 'sentence' and 'label'
sentences = [it['sentence'] for it in cleaned_data]
labels = [it['label'] for it in cleaned_data]

In [4]:
set(labels)

{'ham', 'spam'}

### Load pre-trained GloVe embeddings

In [5]:
model_glove = api.load("glove-twitter-100")   ## load pretrained glove embeddings

### Use Count Vectors as features

In [6]:
## convert the corpora to Count vectors
count = CountVectorizer(max_df=.95, min_df=10, max_features=5000)
x_count = count.fit_transform(sentences)
## convert to matrix --- feature-document matrix
count_mat = x_count.toarray().T 

In [7]:
## features
features = count.get_feature_names()
len(list(features)),len(list(labels))

(869, 5572)

# 1. gassian_method L2 loss


In [8]:
## NMF methods for topic modeling
k = 100   ## the number of topics -- tune it for better result
W0,H0,err0=gaussian_method(count_mat, k, max_iter=4)  ## will return factor matrices: W, H and root mean squared error


In [9]:
np.square(count_mat - W0@H0).sum()

22905.426368242883

## coherence score

In [10]:
dic0 = top_keywords(W0, features, num=20)

In [11]:
## compute the coherence score for each topic
coherence_vec = []
for i in range(W0.shape[1]):  
    coherence_vec.append(coherence(dic0[i], model_glove))

In [12]:
np.mean(coherence_vec)   ## the mean coherence score of all topics

0.5104895

In [13]:
from sklearn.preprocessing import LabelEncoder
from sklearn import model_selection, naive_bayes, svm
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from collections import Counter
from sklearn.metrics import classification_report


indices = list(range(len(labels)))   ## indices of documents

## split data into train and test
ind_train, ind_test, y_train, y_test = train_test_split(
    indices, labels, test_size=0.2, random_state=2021, stratify=labels)

## train/test datasets

#H0 = H0.detach().numpy()
print(H0.shape)
x_train, x_test = H0[:, ind_train],H0[:, ind_test]

## encode labels to integers
Encoder = LabelEncoder()
Y_train = Encoder.fit_transform(y_train)
Y_test = Encoder.fit_transform(y_test)


# Classifier - Algorithm - SVM -- linear kernel
# fit the training dataset on the classifier
SVM = svm.SVC(C=1., kernel='linear', degree=3, gamma='auto', random_state=82, class_weight='balanced')
SVM.fit(x_train.T, Y_train)# predict the labels on validation dataset
predictions_SVM = SVM.predict(x_test.T) # make predictions
print(classification_report(Y_test, predictions_SVM, digits=3))

(100, 5572)
              precision    recall  f1-score   support

           0      0.982     0.978     0.980       966
           1      0.863     0.886     0.874       149

    accuracy                          0.966      1115
   macro avg      0.923     0.932     0.927      1115
weighted avg      0.966     0.966     0.966      1115



# 2. SGD without MI

In [14]:
A = torch.FloatTensor(count_mat)
#A. type
print(A.shape)
A

torch.Size([869, 5572])


tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [15]:
class SeparableCritic(nn.Module):
    """Separable critic. where the output value is g(x) h(y). """

    def __init__(self, dim1, dim2, hidden_dim, embed_dim, layers, activation, **extra_kwargs):
        super(SeparableCritic, self).__init__()
        self._g = mlp(dim1, hidden_dim, embed_dim, layers, activation)
        self._h = mlp(dim2, hidden_dim, embed_dim, layers, activation)

    def forward(self, x, y):
        scores = torch.matmul(self._h(y), self._g(x).t())
        return scores



In [16]:
#step 2: create model Class
class GaussianNMF(torch.nn.Module):
    """
    class for Non-Negetive Matrix Multiplication using Gaussian Method
    
    """
    def __init__(self, A, k):
        """ initialization """
        super(GaussianNMF, self).__init__()
        self.rows = A.size(0)
        self.cols = A.size(1)
        self.A = A
        
        self.W = torch.nn.Parameter(torch.rand(self.rows, k), requires_grad=True) 
        self.H = torch.nn.Parameter(torch.rand(k, self.cols), requires_grad=True)
        #print(self.H)
        self.num_topics = k
        
  
    def forward(self):
        return self.W.matmul(self.H) 
  
    def batch_gd_train(self, epochs, batch_size, lr):
        """
        train with full batch gradient descent
        :params[in]: epochs,
        
        :params[out]: W, H
        """
        optimizer = torch.optim.SGD(self.parameters(), lr=lr)
        for i in range(epochs):
            pred = self.forward()
            loss = (self.A - pred).pow(2).sum()
            ## backward
            optimizer.zero_grad()   ## zero all gradients
            loss.backward()   ## find derivatives
            optimizer.step()  
            self.W.data[self.W.data < 0] = 0. 
            self.H.data[self.H.data < 0] = 0. 
            print('loss at Epoch ',i, ' ',loss.item())
        ## return
        return self.W, self.H
        
    ## split an iterable of items into batches 
    def chunks(self, ls, batch_size): 
        """ 
        Yield successive n-sized chunks from ls, an iterable. 
        :params[in]: ls, an iterable of items 
        :params[in]: batch_size, an integer, batch size 
        returns a generator 
        """ 
        for i in range(0, len(ls), batch_size): 
            yield ls[i:i + batch_size]
        
    '''def estimate_mutual_information(estimator, x, y, critic_fn,
                                baseline_fn=None, alpha_logit=None, **kwargs):
        
        if estimator == 'smile':
            mi = smile_lower_bound(scores, **kwargs)
    '''

    def sgd_train(self, epochs, batch_size, lr):
        """
        train with stochastic gradient descent
        :params[in]: epochs,
        
        :params[out]: W, H
        
        ** x = mini_data.T
        y = mini_datah.T
        """
        
        optimizer = torch.optim.SGD(self.parameters(), lr=lr)                           
        scheduler= StepLR(optimizer, step_size=10, gamma=0.8)
        data_index = list(range(self.cols))   ## all column indices
        for i in range(epochs):
            mini_batches = self.chunks(data_index, batch_size)
            for it in mini_batches:
                mini_data = A[:, it]
                
                mini_datah= self.H[:, it]
               
                ## data in a minibatch
                pred = self.forward()[:, it]  ## prediction
                loss = (mini_data - pred).pow(2).sum()
                ## backward
                optimizer.zero_grad()   ## zero all gradients
                loss.backward()   ## find derivatives
                optimizer.step()  
                self.W.data[self.W.data < 0] = 0. 
                self.H.data[self.H.data < 0] = 0.
               
            ## shuffle indices
            scheduler.step()
            np.random.shuffle(data_index) 
            
            ## current loss
            cur_loss = (self.A - self.W@self.H).pow(2).sum()
            print('loss at Epoch ',i, ' ',cur_loss.item())
        ## return             
        return self.W, self.H

 

In [17]:
nmf_method1 = GaussianNMF(A, 100) ## matrix factorization
W1, H1 = nmf_method1.sgd_train(epochs=40, batch_size=1024, lr=2.e-2)

loss at Epoch  0   31351094.0
loss at Epoch  1   395727.1875
loss at Epoch  2   56213.27734375
loss at Epoch  3   41646.6015625
loss at Epoch  4   40193.1328125
loss at Epoch  5   40109.69140625
loss at Epoch  6   39858.4765625
loss at Epoch  7   39402.65625
loss at Epoch  8   38831.80859375
loss at Epoch  9   37860.8203125
loss at Epoch  10   37110.359375
loss at Epoch  11   36707.05859375
loss at Epoch  12   36096.63671875
loss at Epoch  13   35323.9921875
loss at Epoch  14   34516.2109375
loss at Epoch  15   33655.1484375
loss at Epoch  16   32799.0390625
loss at Epoch  17   31930.1015625
loss at Epoch  18   31105.078125
loss at Epoch  19   30328.244140625
loss at Epoch  20   29739.7578125
loss at Epoch  21   29199.62109375
loss at Epoch  22   28677.41015625
loss at Epoch  23   28174.40234375
loss at Epoch  24   27692.0390625
loss at Epoch  25   27217.283203125
loss at Epoch  26   26771.93359375
loss at Epoch  27   26339.5546875
loss at Epoch  28   25932.5078125
loss at Epoch  29   

In [18]:
(A-W1@H1).pow(2).sum()

tensor(23056.5703, grad_fn=<SumBackward0>)

In [19]:
from sklearn.preprocessing import LabelEncoder
from sklearn import model_selection, naive_bayes, svm
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from collections import Counter
from sklearn.metrics import classification_report



indices = list(range(len(labels)))   ## indices of documents

## split data into train and test
ind_train, ind_test, y_train, y_test = train_test_split(
    indices, labels, test_size=0.2, random_state=2021, stratify=labels)

## train/test datasets

H1 = H1.detach().numpy()
print(H1.shape)
x_train, x_test = H1[:, ind_train],H1[:, ind_test]

## encode labels to integers
Encoder = LabelEncoder()
Y_train = Encoder.fit_transform(y_train)
Y_test = Encoder.fit_transform(y_test)


# Classifier - Algorithm - SVM -- linear kernel
# fit the training dataset on the classifier
SVM = svm.SVC(C=1., kernel='linear', degree=3, gamma='auto', random_state=82, class_weight='balanced')
SVM.fit(x_train.T, Y_train)# predict the labels on validation dataset
predictions_SVM = SVM.predict(x_test.T) # make predictions
print(classification_report(Y_test, predictions_SVM, digits=3))

(100, 5572)
              precision    recall  f1-score   support

           0      0.987     0.977     0.982       966
           1      0.862     0.919     0.890       149

    accuracy                          0.970      1115
   macro avg      0.925     0.948     0.936      1115
weighted avg      0.971     0.970     0.970      1115



## coherence score

In [56]:
dic1 = top_keywords(W1, features, num=20)

In [57]:
## compute the coherence score for each topic
coherence_vec = []
for i in range(W1.shape[1]):  
    coherence_vec.append(coherence(dic1[i], model_glove))
    
np.mean(coherence_vec)   ## the mean coherence score of all topics

0.620296

# 3. SGD with MI

In [20]:
A = torch.FloatTensor(count_mat)
#A. type
print(A.shape)
#A=A.type(torch.FloatTensor)
A

torch.Size([869, 5572])


tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [21]:
class GaussianNMF_MI(GaussianNMF):
    """
    class for Non-Negetive Matrix Multiplication using Gaussian Method 
    with mutual information
    
    """
    def __init__(self, A, k, critic_config):
        """ initialization
        
        :params[in]: A, k
        :params[in]: critic_config, a dictionary, 
        
        
        """
        super(GaussianNMF_MI, self).__init__(A, k)
        self.critic_config = critic_config
        ### critic function for computing mutual information
        self.critic_config['dim1'] = self.rows
        self.critic_config['dim2'] = k
        self.critic = SeparableCritic(**self.critic_config)
  

    def sgd_train(self, epochs, batch_size, lr, xi, W_init = None):
        """
        train with stochastic gradient descent
        :params[in]: epochs,
        
        :params[out]: W, H
        
        ** 
        x = mini_data.T
        y = mini_datah.T
        """
        if W_init is not None:
            self.W.data = W_init.data
        optimizer = torch.optim.SGD(self.parameters(), lr)
        scheduler = StepLR(optimizer, step_size=1, gamma=0.9)
        data_index = list(range(self.cols))   ## all column indices
        for i in range(epochs):
            mini_batches = self.chunks(data_index, batch_size)
            for it in mini_batches:
                mini_data = A[:, it]
                mini_datah= self.H[:, it]

                # calculate mi 
                mi = estimate_mutual_information('smile', mini_data.T, mini_datah.T, 
                                                 self.critic)
                ## data in a minibatch
                pred = self.forward()[:, it]  ## prediction
                #guassian
                loss = (mini_data - pred).pow(2).sum() - xi * mi
                #possian
                #loss = (pred-mini_data*torch.log(pred)).sum()-xi*mi
              
                ## backward
                #lr = scheduler.get_lr()
                optimizer.zero_grad()   ## zero all gradients
                loss.backward()## find derivatives
                #load_state_dict(state_dict)
                optimizer.step()  
                self.W.data[self.W.data < 0] = 0. 
                self.H.data[self.H.data < 0] = 0. 
            ## renew learning rate
            #print('Epoch前:', i,'LR:', lr)
            scheduler.step()
            #print('Epoch:', i,'LR:', lr)
            ## shuffle indices
            np.random.shuffle(data_index) 
            ## current loss
            cur_loss = (self.A - self.W@self.H).pow(2).sum()
            print('loss at Epoch ',i, ' ',cur_loss.item())
            #print(lr)
        ## return
        return self.W, self.H

In [22]:
A = torch.FloatTensor(count_mat)
#A. type
print(A.shape)
#A=A.type(torch.FloatTensor)
A

torch.Size([869, 5572])


tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [23]:

#NMF-MI method 
k=100
dim1,dim2=A.shape[0],k
mi_params = dict(estimator='smile',critic='separable', baseline='unnormalized')


data_params = {
    'dim': dim1,
    'batch_size': 64,
    'cubic': None
}

critic_params = {
     
    'dim1': dim1,
    'dim2': dim2,
    'layers': 2,
    'embed_dim': 32,
    'hidden_dim': 256,
    'activation': 'relu',
}

critic = SeparableCritic(**critic_params)#.cuda()

#A = torch.randn(count_mat.shape[0],count_mat.shape[1])
critic_params = {
    'layers': 2,
    'embed_dim': 32,
    'hidden_dim': 256,
    'activation': 'relu',
}
##init nmf especially the W, H
nmf1 = GaussianNMF_MI(A, k, critic_params)
##put W in every client

W, H = nmf1.sgd_train(epochs=40, batch_size=1024, lr=1.e-1,xi = 0)#, W_init = torch.rand([2153, 20])) 
#def server_train(self, epoch, client_num, batch_sieze, lr, xi)
#W, H = nmf1.server_train(epoch=4, client_num = 3, batch_size=128, lr=2.e-4, xi=0 )
                 

loss at Epoch  0   624907264.0
loss at Epoch  1   128653712.0
loss at Epoch  2   10460243968.0
loss at Epoch  3   3099857408.0
loss at Epoch  4   43187040256.0
loss at Epoch  5   2322850048.0
loss at Epoch  6   807626496.0
loss at Epoch  7   386543264.0
loss at Epoch  8   11900224.0
loss at Epoch  9   155077.375
loss at Epoch  10   41710.21484375
loss at Epoch  11   41894.96875
loss at Epoch  12   38850.515625
loss at Epoch  13   38284.765625
loss at Epoch  14   37976.69921875
loss at Epoch  15   36883.6640625
loss at Epoch  16   36692.74609375
loss at Epoch  17   36463.1015625
loss at Epoch  18   36178.3359375
loss at Epoch  19   36010.66015625
loss at Epoch  20   35840.515625
loss at Epoch  21   35643.20703125
loss at Epoch  22   35435.6328125
loss at Epoch  23   35227.5546875
loss at Epoch  24   35044.3359375
loss at Epoch  25   34875.17578125
loss at Epoch  26   34730.859375
loss at Epoch  27   34603.60546875
loss at Epoch  28   34491.140625
loss at Epoch  29   34389.45703125
loss 

###### SVM Classifier

In [28]:
from sklearn.preprocessing import LabelEncoder
from sklearn import model_selection, naive_bayes, svm
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from collections import Counter
from sklearn.metrics import classification_report

In [29]:
indices = list(range(len(labels)))   ## indices of documents

In [30]:
## split data into train and test
ind_train, ind_test, y_train, y_test = train_test_split(
    indices, labels, test_size=0.2, random_state=2021, stratify=labels)

In [33]:

## train/test datasets

H = H.detach().numpy()
print(H.shape)
x_train, x_test = H[:, ind_train],H[:, ind_test]

(100, 5572)


In [34]:
## encode labels to integers
Encoder = LabelEncoder()
Y_train = Encoder.fit_transform(y_train)
Y_test = Encoder.fit_transform(y_test)


# Classifier - Algorithm - SVM -- linear kernel
# fit the training dataset on the classifier
SVM = svm.SVC(C=1., kernel='linear', degree=3, gamma='auto', random_state=82, class_weight='balanced')
SVM.fit(x_train.T, Y_train)# predict the labels on validation dataset
predictions_SVM = SVM.predict(x_test.T) # make predictions
print(classification_report(Y_test, predictions_SVM, digits=3))

              precision    recall  f1-score   support

           0      0.882     0.836     0.859       966
           1      0.206     0.275     0.236       149

    accuracy                          0.761      1115
   macro avg      0.544     0.556     0.547      1115
weighted avg      0.792     0.761     0.775      1115



## coherence score 

In [155]:
dic = top_keywords(W, features, num=20)

In [156]:
## compute the coherence score for each topic
coherence_vec = []
for i in range(W.shape[1]):  
    coherence_vec.append(coherence(dic[i], model_glove))

np.mean(coherence_vec)   ## the mean coherence score of all topics

0.50627863

# Federated learning SGD

In [24]:
A = torch.FloatTensor(count_mat)
#A. type
print(A.shape)
#A=A.type(torch.FloatTensor)


torch.Size([869, 5572])


In [25]:
def evaluation(H, labels):
    indices = list(range(len(labels)))   ## indices of documents
    
    ## split data into train and test
    ind_train, ind_test, y_train, y_test = train_test_split(
        indices, labels, test_size=0.2, random_state=2021, stratify=labels)
    H_new = H.detach().numpy()
    x_train, x_test = H_new[:, ind_train],H_new[:, ind_test]
    
    ## encode labels to integers
    Encoder = LabelEncoder()
    Y_train = Encoder.fit_transform(y_train)
    Y_test = Encoder.fit_transform(y_test)


    # Classifier - Algorithm - SVM -- linear kernel
    # fit the training dataset on the classifier
    SVM = svm.SVC(C=1., kernel='linear', degree=3, gamma='auto', random_state=82, class_weight='balanced')
    SVM.fit(x_train.T, Y_train)# predict the labels on validation dataset
    predictions_SVM = SVM.predict(x_test.T) # make predictions
    print(classification_report(Y_test, predictions_SVM, digits=3))

In [32]:
help(random_sample)

NameError: name 'random_sample' is not defined

In [33]:
import random

class Fed_NMF(nn.Module):
    """
    Federaated NMF 
    """
    def __init__(self, A, k, K):
        """
        initialization 
        :params[in], A, full matrix of all texts data
        :params[in], k, number of topics
        
        """
        super(Fed_NMF, self).__init__()
        self.A = A
        self.rows,self.cols = A.size()
        self.k = k  ## topic number
        self.K = K  ## number of clients

    def split_clients(self):
        """
        split the full matrix into m clients by column
        
        :params[in]: K, the number of clients
        
        :params[out]: B, tuple of tensors
        """
        data_index = list(range(self.cols))   ## all column indices
        #np.random.shuffle(data_index) 
        ## split into chunks after shuffle
        B = torch.chunk(self.A[:,data_index], self.K, dim = 1)
        return B


    def split_into_chunks(self, ls, batch_size):
        """
        split a list of number into chunks up to a certain batch_size
        
        :params[in]: ls, a list of numbers
        :params[in]: batch_size, an interger     
        
        :params[out]: a generator
        """
        np.random.shuffle(ls) 
        for i in range(0, len(ls), batch_size): 
            yield ls[i:i + batch_size]

    
    def server_train(self, labels, iters, C, epoch, batch_size, lr, xi):
        """
        federated learning
        
        :params[in], iters, the number of interations for fedrated learning
        :params[in], C, the fraction of clients for each iteration
        :params[in], epoch, the number of epochs for local SGD
        :params[in], batch_size, the batch_size for local SGD
        :params[in], lr, the learning rate for local SGD
        :params[in], xi, the mutual information for local SGD
        
        :params[out]
        :
        """
        m = int(max(C*self.K,1))  ## number of clients for each iteration
        ## split whole dataset into K clients
        # torch.manual_seed(4)
        B = self.split_clients()
        ## column number for all clients -- using list comprehension
        num_cols = [it.size(1) for it in B]
        #Cols = num_cols.sum()
        ## initialization of W tensor
        W = torch.rand(self.rows, self.k)
        ## nmf model for all clients using dictionary comprehension
        nmf_models = {i: GaussianNMF_MI(B[i], self.k, critic_params) for i in range(self.K)}
        random.seed(18)
        for i in range(iters):     ## each iteration                  
            set_clients = random.sample(list(range(self.K)), m)  ## selection of client's id 
            temp = [num_cols[it1] for it1 in set_clients]            
            col_sum = sum(temp)
            tmp_W = torch.zeros(self.rows, self.k)
            ## for each selected client
            for j in set_clients:
                W1, _ = nmf_models[j].sgd_train(epoch, batch_size, lr, xi, W)
                tmp_W += W1.detach().data * num_cols[j]/col_sum
            W = tmp_W
            H_list = [nmf_models[it2].H.detach().data for it2 in range(self.K) ]
            ## after convergence
            H = torch.cat(H_list , 1)
            evaluation(H,labels)
        return W, H

    def server_train_epoch(self, labels, iters, C, epoch, batch_size, lr, xi):
        """
        federated learning by training over all clients in each iteration
        
        :params[in], iters, the number of interations for fedrated learning
        :params[in], list, each data's label
        :params[in], C, the fraction of clients for each iteration
        :params[in], epoch, the number of epochs for local SGD
        :params[in], batch_size, the batch_size for local SGD
        :params[in], lr, the learning rate for local SGD
        :params[in], xi, the mutual information for local 
        SGD
        
        :params[out]
        :
        """
        m = int(max(C*self.K,1))  ## batch_size for clients
        ## split whole dataset into K clients
        #torch.manual_seed(4)
        B = self.split_clients()
        ## column number for all clients -- using list comprehension
        num_cols = [it.size(1) for it in B]
        #Cols = num_cols.sum()
        ## initialization of W tensor
        W = torch.rand(self.rows, self.k)
        ## nmf model for all clients using dictionary comprehension
        nmf_models = {i: GaussianNMF_MI(B[i], self.k, critic_params) for i in range(self.K)}
        #random.seed(18)
        for i in range(iters):     ## each iteration -- training over all clients 
            ## split clients into batches
            batches = self.split_into_chunks(list(range(self.K)), m)  
            ### for loop over all batches of clients
            for set_clients in batches:   ## selection of client's id 
                temp = [num_cols[it1] for it1 in set_clients]            
                col_sum = sum(temp)
                tmp_W = torch.zeros(self.rows, self.k)
                ## for each selected client
                for j in set_clients:
                    W1, _ = nmf_models[j].sgd_train(epoch, batch_size, lr, xi, W)
                    tmp_W += W1.detach().data * num_cols[j]/col_sum
                W = tmp_W     ##update W
            ## list of H matrices for each iteration
            H_list = [nmf_models[it2].H.detach().data for it2 in range(self.K) ]
            ## merge into H
            H = torch.cat(H_list , 1)
            ### evaluate its performance on classification
            evaluation(H, labels)
        return W, H

In [34]:
nmf1 = Fed_NMF(A, 100,30)

In [35]:
dim1,dim2=A.shape[0],100
mi_params = dict(estimator='smile',critic='separable', baseline='unnormalized')

data_params = {
    'dim': dim1,
    'batch_size': 64,
    'cubic': None
}

critic_params = {
     
    'dim1': dim1,
    'dim2': dim2,
    'layers': 2,
    'embed_dim': 32,
    'hidden_dim': 256,
    'activation': 'relu',
}

critic = SeparableCritic(**critic_params)#.cuda()

#A = torch.randn(count_mat.shape[0],count_mat.shape[1])
critic_params = {
    'layers': 2,
    'embed_dim': 32,
    'hidden_dim': 256,
    'activation': 'relu',
}


W, H = nmf1.server_train(labels, iters =20, C=0.2, epoch=40, batch_size=128, lr=2.e-4, xi=0)

loss at Epoch  0   254087.75
loss at Epoch  1   168398.421875
loss at Epoch  2   125886.34375
loss at Epoch  3   100978.515625
loss at Epoch  4   84788.328125
loss at Epoch  5   73557.1328125
loss at Epoch  6   65394.4453125
loss at Epoch  7   59239.75
loss at Epoch  8   54470.19921875
loss at Epoch  9   50679.96484375
loss at Epoch  10   47623.0234375
loss at Epoch  11   45117.265625
loss at Epoch  12   43037.98828125
loss at Epoch  13   41297.10546875
loss at Epoch  14   39825.12890625
loss at Epoch  15   38572.50390625
loss at Epoch  16   37498.28515625
loss at Epoch  17   36571.90234375
loss at Epoch  18   35769.3203125
loss at Epoch  19   35071.6796875
loss at Epoch  20   34462.1953125
loss at Epoch  21   33928.078125
loss at Epoch  22   33458.578125
loss at Epoch  23   33044.78125
loss at Epoch  24   32679.310546875
loss at Epoch  25   32356.00390625
loss at Epoch  26   32069.48046875
loss at Epoch  27   31814.962890625
loss at Epoch  28   31588.767578125
loss at Epoch  29   3138

loss at Epoch  33   1519.640625
loss at Epoch  34   1519.5570068359375
loss at Epoch  35   1519.506103515625
loss at Epoch  36   1519.457275390625
loss at Epoch  37   1519.4085693359375
loss at Epoch  38   1519.38330078125
loss at Epoch  39   1519.3212890625
              precision    recall  f1-score   support

           0      0.853     0.522     0.647       966
           1      0.118     0.416     0.184       149

    accuracy                          0.508      1115
   macro avg      0.486     0.469     0.416      1115
weighted avg      0.755     0.508     0.586      1115

loss at Epoch  0   6181.7587890625
loss at Epoch  1   4041.716796875
loss at Epoch  2   3104.574462890625
loss at Epoch  3   2616.222412109375
loss at Epoch  4   2329.746337890625
loss at Epoch  5   2145.603515625
loss at Epoch  6   2019.164794921875
loss at Epoch  7   1927.225341796875
loss at Epoch  8   1857.8516845703125
loss at Epoch  9   1803.41796875
loss at Epoch  10   1760.50927734375
loss at Epoch  11 

loss at Epoch  6   1348.661376953125
loss at Epoch  7   1348.10400390625
loss at Epoch  8   1347.16552734375
loss at Epoch  9   1347.5599365234375
loss at Epoch  10   1347.36572265625
loss at Epoch  11   1346.77490234375
loss at Epoch  12   1346.5726318359375
loss at Epoch  13   1346.265869140625
loss at Epoch  14   1346.1595458984375
loss at Epoch  15   1345.942626953125
loss at Epoch  16   1345.664794921875
loss at Epoch  17   1345.6893310546875
loss at Epoch  18   1345.78369140625
loss at Epoch  19   1345.463623046875
loss at Epoch  20   1345.395751953125
loss at Epoch  21   1345.378662109375
loss at Epoch  22   1345.294189453125
loss at Epoch  23   1345.24560546875
loss at Epoch  24   1345.24462890625
loss at Epoch  25   1345.172607421875
loss at Epoch  26   1345.145751953125
loss at Epoch  27   1345.0474853515625
loss at Epoch  28   1345.028076171875
loss at Epoch  29   1345.068359375
loss at Epoch  30   1345.0101318359375
loss at Epoch  31   1344.962646484375
loss at Epoch  32   

loss at Epoch  24   1377.1927490234375
loss at Epoch  25   1377.2265625
loss at Epoch  26   1377.188720703125
loss at Epoch  27   1377.1904296875
loss at Epoch  28   1377.17724609375
loss at Epoch  29   1377.177490234375
loss at Epoch  30   1377.1773681640625
loss at Epoch  31   1377.1837158203125
loss at Epoch  32   1377.2012939453125
loss at Epoch  33   1377.2034912109375
loss at Epoch  34   1377.181396484375
loss at Epoch  35   1377.1734619140625
loss at Epoch  36   1377.17041015625
loss at Epoch  37   1377.1728515625
loss at Epoch  38   1377.1719970703125
loss at Epoch  39   1377.1734619140625
loss at Epoch  0   1380.6124267578125
loss at Epoch  1   1381.3973388671875
loss at Epoch  2   1379.09765625
loss at Epoch  3   1377.844482421875
loss at Epoch  4   1377.942138671875
loss at Epoch  5   1378.327392578125
loss at Epoch  6   1377.849853515625
loss at Epoch  7   1377.876708984375
loss at Epoch  8   1377.8868408203125
loss at Epoch  9   1377.9642333984375
loss at Epoch  10   1377.

loss at Epoch  38   1196.2454833984375
loss at Epoch  39   1196.240234375
loss at Epoch  0   1324.9906005859375
loss at Epoch  1   1324.6497802734375
loss at Epoch  2   1324.2763671875
loss at Epoch  3   1323.49951171875
loss at Epoch  4   1323.549072265625
loss at Epoch  5   1323.5543212890625
loss at Epoch  6   1322.789306640625
loss at Epoch  7   1322.995361328125
loss at Epoch  8   1323.2796630859375
loss at Epoch  9   1323.1947021484375
loss at Epoch  10   1323.498046875
loss at Epoch  11   1323.361083984375
loss at Epoch  12   1323.3780517578125
loss at Epoch  13   1323.6546630859375
loss at Epoch  14   1323.683349609375
loss at Epoch  15   1323.7537841796875
loss at Epoch  16   1323.817626953125
loss at Epoch  17   1323.7481689453125
loss at Epoch  18   1323.88818359375
loss at Epoch  19   1323.841796875
loss at Epoch  20   1323.751708984375
loss at Epoch  21   1323.89599609375
loss at Epoch  22   1323.9044189453125
loss at Epoch  23   1323.94287109375
loss at Epoch  24   1323.9

loss at Epoch  10   1388.955322265625
loss at Epoch  11   1389.107177734375
loss at Epoch  12   1389.238525390625
loss at Epoch  13   1389.350830078125
loss at Epoch  14   1389.456298828125
loss at Epoch  15   1389.5416259765625
loss at Epoch  16   1389.62841796875
loss at Epoch  17   1389.699951171875
loss at Epoch  18   1389.765380859375
loss at Epoch  19   1389.82275390625
loss at Epoch  20   1389.8748779296875
loss at Epoch  21   1389.923583984375
loss at Epoch  22   1389.964599609375
loss at Epoch  23   1390.000732421875
loss at Epoch  24   1390.034423828125
loss at Epoch  25   1390.065185546875
loss at Epoch  26   1390.091552734375
loss at Epoch  27   1390.115966796875
loss at Epoch  28   1390.1370849609375
loss at Epoch  29   1390.1572265625
loss at Epoch  30   1390.174072265625
loss at Epoch  31   1390.189453125
loss at Epoch  32   1390.20361328125
loss at Epoch  33   1390.2164306640625
loss at Epoch  34   1390.228271484375
loss at Epoch  35   1390.2379150390625
loss at Epoch  

loss at Epoch  25   1239.90625
loss at Epoch  26   1239.9151611328125
loss at Epoch  27   1239.8914794921875
loss at Epoch  28   1239.894775390625
loss at Epoch  29   1239.9149169921875
loss at Epoch  30   1239.9493408203125
loss at Epoch  31   1239.96240234375
loss at Epoch  32   1239.96337890625
loss at Epoch  33   1239.978759765625
loss at Epoch  34   1240.0106201171875
loss at Epoch  35   1240.005126953125
loss at Epoch  36   1239.9832763671875
loss at Epoch  37   1239.9981689453125
loss at Epoch  38   1240.0201416015625
loss at Epoch  39   1240.02197265625
loss at Epoch  0   1529.6455078125
loss at Epoch  1   1528.760009765625
loss at Epoch  2   1528.313232421875
loss at Epoch  3   1527.858154296875
loss at Epoch  4   1528.1275634765625
loss at Epoch  5   1528.100830078125
loss at Epoch  6   1529.130126953125
loss at Epoch  7   1528.0740966796875
loss at Epoch  8   1528.5941162109375
loss at Epoch  9   1528.3458251953125
loss at Epoch  10   1528.85888671875
loss at Epoch  11   152

loss at Epoch  35   1373.896240234375
loss at Epoch  36   1373.908447265625
loss at Epoch  37   1373.8975830078125
loss at Epoch  38   1373.898681640625
loss at Epoch  39   1373.899169921875
loss at Epoch  0   1323.5894775390625
loss at Epoch  1   1323.4425048828125
loss at Epoch  2   1321.3482666015625
loss at Epoch  3   1322.31884765625
loss at Epoch  4   1322.6090087890625
loss at Epoch  5   1322.4227294921875
loss at Epoch  6   1322.6824951171875
loss at Epoch  7   1322.0006103515625
loss at Epoch  8   1322.271728515625
loss at Epoch  9   1322.47607421875
loss at Epoch  10   1322.6971435546875
loss at Epoch  11   1322.804443359375
loss at Epoch  12   1322.9461669921875
loss at Epoch  13   1322.9658203125
loss at Epoch  14   1323.1116943359375
loss at Epoch  15   1322.989501953125
loss at Epoch  16   1323.16162109375
loss at Epoch  17   1323.15966796875
loss at Epoch  18   1323.1326904296875
loss at Epoch  19   1323.130615234375
loss at Epoch  20   1323.16259765625
loss at Epoch  21

loss at Epoch  9   1324.9462890625
loss at Epoch  10   1325.045654296875
loss at Epoch  11   1325.3675537109375
loss at Epoch  12   1325.2314453125
loss at Epoch  13   1325.2987060546875
loss at Epoch  14   1325.2000732421875
loss at Epoch  15   1325.1285400390625
loss at Epoch  16   1325.2579345703125
loss at Epoch  17   1325.2679443359375
loss at Epoch  18   1325.385498046875
loss at Epoch  19   1325.4296875
loss at Epoch  20   1325.5321044921875
loss at Epoch  21   1325.529296875
loss at Epoch  22   1325.5439453125
loss at Epoch  23   1325.5770263671875
loss at Epoch  24   1325.6005859375
loss at Epoch  25   1325.5789794921875
loss at Epoch  26   1325.57177734375
loss at Epoch  27   1325.6248779296875
loss at Epoch  28   1325.6873779296875
loss at Epoch  29   1325.668701171875
loss at Epoch  30   1325.68359375
loss at Epoch  31   1325.7198486328125
loss at Epoch  32   1325.6907958984375
loss at Epoch  33   1325.698486328125
loss at Epoch  34   1325.70361328125
loss at Epoch  35   13

loss at Epoch  26   1490.221923828125
loss at Epoch  27   1490.25146484375
loss at Epoch  28   1490.2200927734375
loss at Epoch  29   1490.2685546875
loss at Epoch  30   1490.2808837890625
loss at Epoch  31   1490.294677734375
loss at Epoch  32   1490.316162109375
loss at Epoch  33   1490.3519287109375
loss at Epoch  34   1490.3548583984375
loss at Epoch  35   1490.3536376953125
loss at Epoch  36   1490.36181640625
loss at Epoch  37   1490.349609375
loss at Epoch  38   1490.3636474609375
loss at Epoch  39   1490.36767578125
loss at Epoch  0   1323.646240234375
loss at Epoch  1   1324.6273193359375
loss at Epoch  2   1322.731201171875
loss at Epoch  3   1321.9564208984375
loss at Epoch  4   1321.8370361328125
loss at Epoch  5   1322.1629638671875
loss at Epoch  6   1322.3167724609375
loss at Epoch  7   1322.15185546875
loss at Epoch  8   1322.227294921875
loss at Epoch  9   1322.796875
loss at Epoch  10   1322.751708984375
loss at Epoch  11   1322.8326416015625
loss at Epoch  12   1322.

loss at Epoch  7   1239.2899169921875
loss at Epoch  8   1239.4578857421875
loss at Epoch  9   1239.35986328125
loss at Epoch  10   1239.55419921875
loss at Epoch  11   1239.7303466796875
loss at Epoch  12   1239.7041015625
loss at Epoch  13   1239.868408203125
loss at Epoch  14   1239.987060546875
loss at Epoch  15   1240.189453125
loss at Epoch  16   1240.2109375
loss at Epoch  17   1240.14697265625
loss at Epoch  18   1240.2655029296875
loss at Epoch  19   1240.28955078125
loss at Epoch  20   1240.402099609375
loss at Epoch  21   1240.3619384765625
loss at Epoch  22   1240.474365234375
loss at Epoch  23   1240.56591796875
loss at Epoch  24   1240.53466796875
loss at Epoch  25   1240.5340576171875
loss at Epoch  26   1240.54052734375
loss at Epoch  27   1240.59326171875
loss at Epoch  28   1240.58251953125
loss at Epoch  29   1240.6190185546875
loss at Epoch  30   1240.64599609375
loss at Epoch  31   1240.640380859375
loss at Epoch  32   1240.66552734375
loss at Epoch  33   1240.6761

loss at Epoch  32   1121.266357421875
loss at Epoch  33   1121.25244140625
loss at Epoch  34   1121.2630615234375
loss at Epoch  35   1121.2655029296875
loss at Epoch  36   1121.274658203125
loss at Epoch  37   1121.300537109375
loss at Epoch  38   1121.287109375
loss at Epoch  39   1121.293701171875
              precision    recall  f1-score   support

           0      0.853     0.524     0.649       966
           1      0.119     0.416     0.185       149

    accuracy                          0.509      1115
   macro avg      0.486     0.470     0.417      1115
weighted avg      0.755     0.509     0.587      1115

loss at Epoch  0   1374.76025390625
loss at Epoch  1   1376.156494140625
loss at Epoch  2   1375.51904296875
loss at Epoch  3   1375.49853515625
loss at Epoch  4   1375.8515625
loss at Epoch  5   1376.506591796875
loss at Epoch  6   1376.09814453125
loss at Epoch  7   1375.937744140625
loss at Epoch  8   1376.246826171875
loss at Epoch  9   1376.0955810546875
loss at E

loss at Epoch  4   1320.6243896484375
loss at Epoch  5   1319.796630859375
loss at Epoch  6   1320.654052734375
loss at Epoch  7   1320.8037109375
loss at Epoch  8   1320.81982421875
loss at Epoch  9   1320.28662109375
loss at Epoch  10   1320.708740234375
loss at Epoch  11   1320.5562744140625
loss at Epoch  12   1320.585693359375
loss at Epoch  13   1320.67138671875
loss at Epoch  14   1320.7220458984375
loss at Epoch  15   1320.79052734375
loss at Epoch  16   1320.930908203125
loss at Epoch  17   1321.015869140625
loss at Epoch  18   1320.9912109375
loss at Epoch  19   1321.0775146484375
loss at Epoch  20   1321.10400390625
loss at Epoch  21   1321.1055908203125
loss at Epoch  22   1321.1175537109375
loss at Epoch  23   1321.093505859375
loss at Epoch  24   1321.131103515625
loss at Epoch  25   1321.141357421875
loss at Epoch  26   1321.121826171875
loss at Epoch  27   1321.18115234375
loss at Epoch  28   1321.1702880859375
loss at Epoch  29   1321.18994140625
loss at Epoch  30   13

loss at Epoch  23   1194.9674072265625
loss at Epoch  24   1194.95849609375
loss at Epoch  25   1194.9970703125
loss at Epoch  26   1195.029296875
loss at Epoch  27   1195.055908203125
loss at Epoch  28   1195.04638671875
loss at Epoch  29   1195.0413818359375
loss at Epoch  30   1195.05859375
loss at Epoch  31   1195.054931640625
loss at Epoch  32   1195.0601806640625
loss at Epoch  33   1195.077392578125
loss at Epoch  34   1195.0950927734375
loss at Epoch  35   1195.1055908203125
loss at Epoch  36   1195.10498046875
loss at Epoch  37   1195.11376953125
loss at Epoch  38   1195.120849609375
loss at Epoch  39   1195.1224365234375
loss at Epoch  0   1265.1605224609375
loss at Epoch  1   1264.1766357421875
loss at Epoch  2   1263.590087890625
loss at Epoch  3   1263.308349609375
loss at Epoch  4   1262.7158203125
loss at Epoch  5   1263.513916015625
loss at Epoch  6   1263.048828125
loss at Epoch  7   1263.715576171875
loss at Epoch  8   1263.439208984375
loss at Epoch  9   1263.3847656

loss at Epoch  34   1374.1949462890625
loss at Epoch  35   1374.1915283203125
loss at Epoch  36   1374.2177734375
loss at Epoch  37   1374.2271728515625
loss at Epoch  38   1374.2265625
loss at Epoch  39   1374.242431640625
loss at Epoch  0   1264.423095703125
loss at Epoch  1   1265.2327880859375
loss at Epoch  2   1263.24267578125
loss at Epoch  3   1263.49560546875
loss at Epoch  4   1262.3172607421875
loss at Epoch  5   1262.400390625
loss at Epoch  6   1262.691650390625
loss at Epoch  7   1262.635009765625
loss at Epoch  8   1262.759521484375
loss at Epoch  9   1262.724853515625
loss at Epoch  10   1263.16845703125
loss at Epoch  11   1263.169921875
loss at Epoch  12   1263.106689453125
loss at Epoch  13   1263.1976318359375
loss at Epoch  14   1263.394775390625
loss at Epoch  15   1263.3046875
loss at Epoch  16   1263.292724609375
loss at Epoch  17   1263.410888671875
loss at Epoch  18   1263.547119140625
loss at Epoch  19   1263.5164794921875
loss at Epoch  20   1263.57678222656

loss at Epoch  9   1215.294921875
loss at Epoch  10   1215.001220703125
loss at Epoch  11   1215.522705078125
loss at Epoch  12   1215.4298095703125
loss at Epoch  13   1215.577392578125
loss at Epoch  14   1215.7806396484375
loss at Epoch  15   1215.6497802734375
loss at Epoch  16   1215.713623046875
loss at Epoch  17   1215.733154296875
loss at Epoch  18   1215.76220703125
loss at Epoch  19   1215.8720703125
loss at Epoch  20   1215.92333984375
loss at Epoch  21   1215.929931640625
loss at Epoch  22   1215.97900390625
loss at Epoch  23   1216.0
loss at Epoch  24   1216.0263671875
loss at Epoch  25   1216.0966796875
loss at Epoch  26   1216.07958984375
loss at Epoch  27   1216.111328125
loss at Epoch  28   1216.119140625
loss at Epoch  29   1216.1419677734375
loss at Epoch  30   1216.1689453125
loss at Epoch  31   1216.1724853515625
loss at Epoch  32   1216.1669921875
loss at Epoch  33   1216.16796875
loss at Epoch  34   1216.17431640625
loss at Epoch  35   1216.1793212890625
loss at 

loss at Epoch  28   1376.33544921875
loss at Epoch  29   1376.340087890625
loss at Epoch  30   1376.346435546875
loss at Epoch  31   1376.3643798828125
loss at Epoch  32   1376.3748779296875
loss at Epoch  33   1376.37109375
loss at Epoch  34   1376.375
loss at Epoch  35   1376.3720703125
loss at Epoch  36   1376.3880615234375
loss at Epoch  37   1376.38720703125
loss at Epoch  38   1376.3966064453125
loss at Epoch  39   1376.401123046875
loss at Epoch  0   1221.7762451171875
loss at Epoch  1   1221.673095703125
loss at Epoch  2   1220.3038330078125
loss at Epoch  3   1219.8951416015625
loss at Epoch  4   1220.2618408203125
loss at Epoch  5   1220.9361572265625
loss at Epoch  6   1220.1708984375
loss at Epoch  7   1220.3309326171875
loss at Epoch  8   1220.775390625
loss at Epoch  9   1220.757080078125
loss at Epoch  10   1220.706787109375
loss at Epoch  11   1220.728271484375
loss at Epoch  12   1220.81982421875
loss at Epoch  13   1221.2449951171875
loss at Epoch  14   1221.163574218

loss at Epoch  6   1213.126220703125
loss at Epoch  7   1213.423095703125
loss at Epoch  8   1213.6156005859375
loss at Epoch  9   1213.3729248046875
loss at Epoch  10   1213.4788818359375
loss at Epoch  11   1213.32470703125
loss at Epoch  12   1213.4437255859375
loss at Epoch  13   1213.6826171875
loss at Epoch  14   1213.6944580078125
loss at Epoch  15   1213.90234375
loss at Epoch  16   1213.87890625
loss at Epoch  17   1213.786376953125
loss at Epoch  18   1213.826904296875
loss at Epoch  19   1213.9202880859375
loss at Epoch  20   1213.880859375
loss at Epoch  21   1213.97265625
loss at Epoch  22   1213.9498291015625
loss at Epoch  23   1213.99169921875
loss at Epoch  24   1214.0635986328125
loss at Epoch  25   1214.0318603515625
loss at Epoch  26   1214.097412109375
loss at Epoch  27   1214.04931640625
loss at Epoch  28   1214.0684814453125
loss at Epoch  29   1214.105712890625
loss at Epoch  30   1214.10205078125
loss at Epoch  31   1214.135986328125
loss at Epoch  32   1214.15

loss at Epoch  21   1320.81396484375
loss at Epoch  22   1320.781005859375
loss at Epoch  23   1320.9044189453125
loss at Epoch  24   1320.87890625
loss at Epoch  25   1320.910400390625
loss at Epoch  26   1320.9482421875
loss at Epoch  27   1320.9375
loss at Epoch  28   1320.982177734375
loss at Epoch  29   1320.9462890625
loss at Epoch  30   1320.9471435546875
loss at Epoch  31   1320.9544677734375
loss at Epoch  32   1320.9835205078125
loss at Epoch  33   1320.9892578125
loss at Epoch  34   1320.988037109375
loss at Epoch  35   1321.007080078125
loss at Epoch  36   1321.0086669921875
loss at Epoch  37   1321.0054931640625
loss at Epoch  38   1321.0184326171875
loss at Epoch  39   1321.02294921875
loss at Epoch  0   1343.7109375
loss at Epoch  1   1343.6456298828125
loss at Epoch  2   1342.7125244140625
loss at Epoch  3   1342.1123046875
loss at Epoch  4   1342.319580078125
loss at Epoch  5   1341.94921875
loss at Epoch  6   1342.072509765625
loss at Epoch  7   1342.129150390625
loss

loss at Epoch  32   1282.8448486328125
loss at Epoch  33   1282.87109375
loss at Epoch  34   1282.8671875
loss at Epoch  35   1282.863525390625
loss at Epoch  36   1282.872802734375
loss at Epoch  37   1282.8828125
loss at Epoch  38   1282.876953125
loss at Epoch  39   1282.884521484375
loss at Epoch  0   1323.30224609375
loss at Epoch  1   1321.76708984375
loss at Epoch  2   1321.8446044921875
loss at Epoch  3   1320.7996826171875
loss at Epoch  4   1321.265380859375
loss at Epoch  5   1320.927978515625
loss at Epoch  6   1320.794189453125
loss at Epoch  7   1320.277099609375
loss at Epoch  8   1320.6337890625
loss at Epoch  9   1320.64501953125
loss at Epoch  10   1320.938232421875
loss at Epoch  11   1321.10888671875
loss at Epoch  12   1320.990966796875
loss at Epoch  13   1321.009521484375
loss at Epoch  14   1320.84765625
loss at Epoch  15   1321.1116943359375
loss at Epoch  16   1321.250244140625
loss at Epoch  17   1321.3194580078125
loss at Epoch  18   1321.264892578125
loss a

loss at Epoch  7   1231.6912841796875
loss at Epoch  8   1231.358642578125
loss at Epoch  9   1231.8426513671875
loss at Epoch  10   1231.767578125
loss at Epoch  11   1231.827880859375
loss at Epoch  12   1232.07666015625
loss at Epoch  13   1232.18115234375
loss at Epoch  14   1232.2239990234375
loss at Epoch  15   1232.21044921875
loss at Epoch  16   1232.3045654296875
loss at Epoch  17   1232.3818359375
loss at Epoch  18   1232.508544921875
loss at Epoch  19   1232.47216796875
loss at Epoch  20   1232.46630859375
loss at Epoch  21   1232.492919921875
loss at Epoch  22   1232.6011962890625
loss at Epoch  23   1232.570556640625
loss at Epoch  24   1232.608642578125
loss at Epoch  25   1232.6768798828125
loss at Epoch  26   1232.6488037109375
loss at Epoch  27   1232.635498046875
loss at Epoch  28   1232.6724853515625
loss at Epoch  29   1232.6956787109375
loss at Epoch  30   1232.711669921875
loss at Epoch  31   1232.716064453125
loss at Epoch  32   1232.731689453125
loss at Epoch  3

loss at Epoch  27   1328.1900634765625
loss at Epoch  28   1328.2078857421875
loss at Epoch  29   1328.2041015625
loss at Epoch  30   1328.229248046875
loss at Epoch  31   1328.232177734375
loss at Epoch  32   1328.2391357421875
loss at Epoch  33   1328.2540283203125
loss at Epoch  34   1328.268798828125
loss at Epoch  35   1328.2745361328125
loss at Epoch  36   1328.2744140625
loss at Epoch  37   1328.272705078125
loss at Epoch  38   1328.2823486328125
loss at Epoch  39   1328.3031005859375
              precision    recall  f1-score   support

           0      0.854     0.525     0.650       966
           1      0.119     0.416     0.185       149

    accuracy                          0.510      1115
   macro avg      0.486     0.470     0.418      1115
weighted avg      0.755     0.510     0.588      1115

loss at Epoch  0   1208.8160400390625
loss at Epoch  1   1208.2530517578125
loss at Epoch  2   1209.5950927734375
loss at Epoch  3   1210.267822265625
loss at Epoch  4   1209.8

loss at Epoch  0   1324.048583984375
loss at Epoch  1   1323.12158203125
loss at Epoch  2   1322.5458984375
loss at Epoch  3   1323.298583984375
loss at Epoch  4   1323.7313232421875
loss at Epoch  5   1323.61474609375
loss at Epoch  6   1323.3167724609375
loss at Epoch  7   1322.570556640625
loss at Epoch  8   1322.7291259765625
loss at Epoch  9   1322.794677734375
loss at Epoch  10   1322.98779296875
loss at Epoch  11   1323.055419921875
loss at Epoch  12   1322.8994140625
loss at Epoch  13   1322.95703125
loss at Epoch  14   1323.21044921875
loss at Epoch  15   1323.4781494140625
loss at Epoch  16   1323.305419921875
loss at Epoch  17   1323.427734375
loss at Epoch  18   1323.5086669921875
loss at Epoch  19   1323.568603515625
loss at Epoch  20   1323.6171875
loss at Epoch  21   1323.58349609375
loss at Epoch  22   1323.61328125
loss at Epoch  23   1323.71875
loss at Epoch  24   1323.75732421875
loss at Epoch  25   1323.74462890625
loss at Epoch  26   1323.7503662109375
loss at Epoc

In [30]:
import random
help(random.sample)

Help on method sample in module random:

sample(population, k) method of random.Random instance
    Chooses k unique random elements from a population sequence or set.
    
    Returns a new list containing elements from the population while
    leaving the original population unchanged.  The resulting list is
    in selection order so that all sub-slices will also be valid random
    samples.  This allows raffle winners (the sample) to be partitioned
    into grand prize and second place winners (the subslices).
    
    Members of the population need not be hashable or unique.  If the
    population contains repeats, then each occurrence is a possible
    selection in the sample.
    
    To choose a sample in a range of integers, use range as an argument.
    This is especially fast and space efficient for sampling from a
    large population:   sample(range(10000000), 60)



coherence

In [125]:
dic0 = top_keywords(W, features, num=20)


## compute the coherence score for each topic
coherence_vec = []
for i in range(W.shape[1]):  
    coherence_vec.append(coherence(dic0[i], model_glove))

np.mean(coherence_vec)   ## the mean coherence score of all topics


0.5654555

# SVM classifier

In [126]:

    indices = list(range(len(labels)))   ## indices of documents
    
    ## split data into train and test
    ind_train, ind_test, y_train, y_test = train_test_split(
        indices, labels, test_size=0.2, random_state=2021, stratify=labels)
    H_new = H.detach().numpy()
    x_train, x_test = H_new[:, ind_train],H_new[:, ind_test]
    
    ## encode labels to integers
    Encoder = LabelEncoder()
    Y_train = Encoder.fit_transform(y_train)
    Y_test = Encoder.fit_transform(y_test)


    # Classifier - Algorithm - SVM -- linear kernel
    # fit the training dataset on the classifier
    SVM = svm.SVC(C=1., kernel='linear', degree=3, gamma='auto', random_state=82, class_weight='balanced')
    SVM.fit(x_train.T, Y_train)# predict the labels on validation dataset
    predictions_SVM = SVM.predict(x_test.T) # make predictions
    print(classification_report(Y_test, predictions_SVM, digits=3))


              precision    recall  f1-score   support

           0      0.514     0.603     0.555      1067
           1      0.519     0.429     0.469      1066

    accuracy                          0.516      2133
   macro avg      0.516     0.516     0.512      2133
weighted avg      0.516     0.516     0.512      2133



'\n#A = A.detach().numpy()\n##print(H.shape)\n#\nx_train, x_test = A[:, ind_train],A[:, ind_test]\n'