In [33]:
import numpy as np
import pathlib

def computePurity(goldenClusters, outputClusters, numDocs):
    """Compute and print the purity of a clustering against golden clusters.

    Every output cluster is credited with the size of its largest overlap
    with any single golden cluster. The credits are summed over all output
    clusters and divided by ``numDocs``.

    Args:
        goldenClusters: mapping from golden label to a collection of document ids.
        outputClusters: mapping from output label to a set of document ids.
        numDocs: total number of documents, used as the denominator.

    Returns:
        The purity value as a float. It is also printed to stdout.
    """
    matched = 0
    for members in outputClusters.values():
        # Best agreement of this output cluster with any golden cluster;
        # defaults to 0 when there are no golden clusters at all.
        overlaps = (len(members.intersection(reference))
                    for reference in goldenClusters.values())
        matched += max(overlaps, default=0)

    purity = float(matched) / numDocs
    print("\tPurity accuracy: " + str(purity))
    return purity

def computeNMIscore(goldenClusters, outputClusters, numDocs):
    """Compute and print the normalized mutual information of two clusterings.

    The score is ``2 * I(output; golden) / (H(output) + H(golden))``, where
    all probabilities are estimated as cluster sizes divided by ``numDocs``.

    Args:
        goldenClusters: mapping from golden label to a collection of document ids.
        outputClusters: mapping from output label to a set of document ids.
        numDocs: total number of documents.

    Returns:
        The NMI value. It is also printed to stdout. When the combined
        entropy is zero (for example, both clusterings put every document
        in one cluster) the score is defined as 1.0 instead of the NaN that
        a 0/0 division would produce.
    """
    MIscore = 0.0
    for docs in outputClusters.values():
        for goldenDocs in goldenClusters.values():
            overlap = len(docs.intersection(goldenDocs))
            if overlap == 0:
                # An empty intersection contributes nothing (0 * log 0 := 0).
                continue
            MIscore += (overlap / numDocs) \
                * np.log(overlap * numDocs / (len(docs) * len(goldenDocs)))

    # Sum of the entropies of both clusterings.
    entropy = 0.0
    for clusters in (outputClusters, goldenClusters):
        for docs in clusters.values():
            size = len(docs)
            if size == 0:
                # Skip empty clusters: log(0) would turn the whole sum into NaN.
                continue
            fraction = size / numDocs
            entropy -= fraction * np.log(fraction)

    if entropy == 0:
        # Neither clustering splits the data, so they agree trivially.
        value = 1.0
    else:
        value = 2 * MIscore / entropy
    print("\tNMI score: " + str(value))
    return value


In [34]:
def _groupIdsByLabel(labels):
    """Map each distinct label to the set of positions at which it occurs."""
    clusters = {}
    for docId, label in enumerate(labels):
        clusters.setdefault(label, set()).add(docId)
    return clusters


def evaluate(pathGoldenLabelsFile, resultsDir="results"):
    """Score every ``*theta`` file under ``resultsDir`` against golden labels.

    The golden-labels file is read one label per line (whitespace-stripped);
    the line index is the document id. Each file below ``resultsDir`` whose
    name ends with ``"theta"`` is read as one row of space-separated numbers
    per document, and each document is assigned to the column with the
    largest value. Purity and NMI are computed for every such file, written
    to ``<resultsDir>/PurityNMI.accuracy``, and their mean and standard
    deviation are written to that file and printed.

    Args:
        pathGoldenLabelsFile: path of the golden-labels file.
        resultsDir: directory that is searched recursively for theta files
            and that receives the ``PurityNMI.accuracy`` report.

    Raises:
        ValueError: if a theta file has a different number of rows than
            there are golden labels, or if no theta file is found.
    """
    # Context managers make sure every opened file is closed again.
    with open(pathGoldenLabelsFile) as labelsFile:
        goldenLabels = [line.strip() for line in labelsFile]
    numDocs = len(goldenLabels)
    goldenClusters = _groupIdsByLabel(goldenLabels)

    resultsPath = pathlib.Path(resultsDir)
    with open(resultsPath / "PurityNMI.accuracy", "w") as results_file:
        results_file.write("Golden-labels in: " + str(pathGoldenLabelsFile) + "\n\n")
        purity = []
        nmi = []

        # Sorted so that the order of the reported scores is reproducible.
        for filepath in sorted(resultsPath.glob('**/*')):
            if not filepath.is_file() or not filepath.name.endswith("theta"):
                continue
            with open(filepath) as thetaFile:
                docsTopicProbs = [np.fromstring(line, sep=' ') for line in thetaFile]
            if len(docsTopicProbs) != numDocs:
                raise ValueError(
                    "Error: the number of documents is different to the number of labels!")

            # Hard-assign every document to its highest-valued topic.
            topicLabels = ["Topic_" + str(np.argmax(probs)) for probs in docsTopicProbs]
            outputClusters = _groupIdsByLabel(topicLabels)

            value = computePurity(goldenClusters, outputClusters, numDocs)
            results_file.write("\tPurity: " + str(value) + "\n")
            purity.append(value)
            value = computeNMIscore(goldenClusters, outputClusters, numDocs)
            results_file.write("\tNMI: " + str(value) + "\n")
            nmi.append(value)

        if len(purity) == 0 or len(nmi) == 0:
            raise ValueError("Error: There is no file ending with theta")

        purityValues = np.array(purity)
        nmiValues = np.array(nmi)

        puritySummary = ("\n---\nMean purity: " + str(np.mean(purityValues))
            + ", standard deviation: " + str(np.std(purityValues)))
        nmiSummary = ("\nMean NMI: " + str(np.mean(nmiValues))
            + ", standard deviation: " + str(np.std(nmiValues)))

        # The same summary lines go to the report file and to stdout.
        results_file.write(puritySummary)
        results_file.write(nmiSummary)

        print(puritySummary)
        print(nmiSummary)
