In [1]:
import nltk
from nltk.util import ngrams
from collections import Counter
import numpy as np
import pdb
from moseq2_nlp.utils import load_data

In [2]:
# Input files handed to load_data.
model_path = '/media/data_cifs/matt/abraira_data/2021-02-19_Meloxicam/rST_model_1000.p'
index_path = '/media/data_cifs/matt/abraira_data/2021-02-19_Meloxicam/moseq2-index.role.yaml'

# Options forwarded unchanged to load_data as keyword arguments.
emissions = True
custom_groupings = []
num_syllables = 70
num_transitions = 100
bad_syllables = [-5]

# Later cells use `labels` and `sentences`; the other returned values are
# kept under the names load_data's call site gives them.
labels, usages, transitions, sentences, bigram_sentences = load_data(
    model_path,
    index_path,
    emissions=emissions,
    custom_groupings=custom_groupings,
    num_syllables=num_syllables,
    num_transitions=num_transitions,
    bad_syllables=bad_syllables,
)


80it [00:15,  5.19it/s]


In [3]:
def mikolov_score(count_super, sub_gram_counts, min_count, vocab_size):
    """Score an n-gram against the counts of its sub-grams.

    Computes ``(count_super - min_count) * vocab_size / prod(sub_gram_counts)``.

    Args:
        count_super: Count of the full n-gram.
        sub_gram_counts: Iterable of counts, one per sub-gram; their product
            is the denominator.
        min_count: Discount subtracted from ``count_super``.
        vocab_size: Multiplicative scaling factor.

    Returns:
        The score as a number (a NumPy scalar, since the denominator comes
        from ``np.prod``).
    """
    discounted = (count_super - min_count) * vocab_size
    denominator = np.prod(sub_gram_counts)
    return discounted / denominator

In [75]:
def contrastive_scoring(corpus1, corpus2, max_n, min_count, vocab_size, threshold=1.0):
    """Collect n-grams that score higher in ``corpus1`` than in ``corpus2``.

    For every n-gram of order 2 .. ``max_n`` found in ``corpus1``, a
    ``mikolov_score`` is computed twice: once from ``corpus1`` counts and once
    from ``corpus2`` counts. The n-gram is kept when its ``corpus1`` score
    exceeds both its ``corpus2`` score and the threshold for its order.

    Args:
        corpus1: Iterable of documents (token sequences) to mine n-grams from.
        corpus2: Iterable of documents used as the contrast set.
        max_n: Largest n-gram order considered.
        min_count: Discount passed to ``mikolov_score``. Either a scalar
            (applied to every order) or a list/tuple with one entry per
            order, starting at order 2; it needs at least ``max_n - 1``
            entries.
        vocab_size: Scaling factor passed to ``mikolov_score``.
        threshold: Minimum ``corpus1`` score for an n-gram to be kept. Scalar
            or per-order list/tuple, same convention as ``min_count``.

    Returns:
        dict mapping each retained n-gram (a tuple of tokens) to its
        ``corpus1`` score.
    """
    # Stand-in count for any n-gram or sub-gram never observed in corpus2;
    # keeps the corpus2 denominator non-zero.
    unseen_count = 1e-3

    # Broadcast scalars to one value per n-gram order (2 .. max_n).
    if not isinstance(min_count, (list, tuple)):
        min_count = (max_n - 1) * [min_count]
    if not isinstance(threshold, (list, tuple)):
        threshold = (max_n - 1) * [threshold]

    # ngram_dicts[k - 1] maps every k-gram to its total count in the corpus.
    ngram_dicts1 = [corpus_ngrams(corpus1, k) for k in range(1, max_n + 1)]
    ngram_dicts2 = [corpus_ngrams(corpus2, k) for k in range(1, max_n + 1)]

    contrastive_scoring_dict = {}
    for i, m in enumerate(range(1, max_n)):
        th = threshold[i]
        mc = min_count[i]

        # Index m holds the (m + 1)-grams.
        m_plus_gram_dict1 = ngram_dicts1[m]
        m_plus_gram_dict2 = ngram_dicts2[m]

        for mpg1, mpg_count1 in m_plus_gram_dict1.items():
            sub_grams = get_subsequences(mpg1)

            # Every sub-gram of an n-gram seen in corpus1 was itself seen in
            # corpus1, so plain indexing is safe here.
            all_sg1_counts = [ngram_dicts1[len(sg) - 1][sg] for sg in sub_grams]
            mks1 = mikolov_score(mpg_count1, all_sg1_counts, mc, vocab_size)

            # corpus2 may lack the n-gram or any of its sub-grams.
            mpg_count2 = m_plus_gram_dict2.get(mpg1, unseen_count)
            all_sg2_counts = [
                ngram_dicts2[len(sg) - 1].get(sg, unseen_count) for sg in sub_grams
            ]
            mks2 = mikolov_score(mpg_count2, all_sg2_counts, mc, vocab_size)

            if mks1 > mks2 and mks1 > th:
                contrastive_scoring_dict[mpg1] = mks1
    return contrastive_scoring_dict

In [5]:
import itertools
 
def get_subsequences(array):
    """Return every distinct contiguous proper subsequence of ``array``.

    Subsequences are produced by slicing, so each one has the same type as
    ``array`` (tuple in, tuples out). They are ordered by length first and
    by first occurrence within a length; duplicates are reported once. The
    full sequence itself is not included.

    Args:
        array: A sliceable sequence with a length, e.g. a tuple of tokens.

    Returns:
        list of slices of ``array`` with lengths 1 .. ``len(array) - 1``.
    """
    subseqs = []
    total = len(array)
    for width in range(1, total):
        # Stop at the last start index that still yields a full window;
        # going further only produces shorter slices that were already
        # collected for a smaller width.
        for start in range(total - width + 1):
            window = array[start:start + width]
            # List membership (not a set) so unhashable slices, e.g. lists,
            # keep working.
            if window not in subseqs:
                subseqs.append(window)
    return subseqs

In [6]:
def corpus_ngrams(corpus, n):
    """Count every n-gram of order ``n`` across all documents in ``corpus``.

    N-grams are extracted per document, so none spans a document boundary.

    Args:
        corpus: Iterable of documents, each a sequence of tokens.
        n: Order of the n-grams to count.

    Returns:
        A plain dict mapping each n-gram (tuple of ``n`` tokens) to its total
        count. It is deliberately not a Counter: looking up a missing key
        raises KeyError instead of returning 0.
    """
    # Accumulate into one Counter instead of rebuilding the running total
    # from scratch for every document.
    totals = Counter()
    for document in corpus:
        totals.update(ngrams(document, n))
    return dict(totals)

In [6]:
# Build two synthetic corpora of random tokens '0'..'99' in which the trigram
# ('1', '2', '3') is planted at each position with probability p1 (corpus1)
# or p2 (corpus2), so it should be strongly over-represented in corpus1.
corpus1 = []
corpus2 = []
p1 = .25
p2 = .01
# Local seeded generator: the synthetic data is the same on every run and the
# global NumPy random state is left untouched.
rng = np.random.default_rng(0)
for _ in range(100):
    doc1 = [str(i) for i in rng.integers(100, size=1000)]
    doc2 = [str(i) for i in rng.integers(100, size=1000)]
    # Stop two tokens early so the three-token pattern always fits.
    for j in range(len(doc1) - 2):
        if rng.random() < p1:
            doc1[j:j + 3] = ['1', '2', '3']
        if rng.random() < p2:
            doc2[j:j + 3] = ['1', '2', '3']
    corpus1.append(doc1)
    corpus2.append(doc2)

In [84]:
# Scoring settings. `threshold` carries one entry per n-gram order from 2 up
# to max_n, which is how contrastive_scoring indexes it.
vocab_size = 70
min_count = 1
threshold = [.01, .0001, .00005]
max_n = 4

# One-vs-rest pass: for each group label 0..4, score the sentences carrying
# that label against all remaining sentences and keep the winning n-grams.
all_cngs = []
for group in range(5):
    corpus1 = [sentence for i, sentence in enumerate(sentences) if labels[i] == group]
    corpus2 = [sentence for i, sentence in enumerate(sentences) if labels[i] != group]

    contrastive_scoring_dict = contrastive_scoring(
        corpus1, corpus2, max_n, min_count, vocab_size, threshold
    )
    all_cngs.append(set(contrastive_scoring_dict))

# Keep only the n-grams that no other group also selected.
pruned_cngs = []
for group in range(5):
    counter_groups = all_cngs[:group] + all_cngs[group + 1:]
    union_counter_groups = set().union(*counter_groups)
    pruned_cngs.append(all_cngs[group] - union_counter_groups)

In [85]:
# Show each group's exclusive n-grams, followed by the same blank-line
# spacing as a separate print('\n') would give.
for pcng in pruned_cngs:
    print(pcng, end='\n\n\n')

{('62', '42'), ('62', '56'), ('11', '63'), ('59', '49', '59'), ('63', '11'), ('33', '62'), ('33', '49')}


{('61', '59', '61'), ('20', '65'), ('58', '61')}


{('43', '55'), ('58', '32', '58'), ('25', '0'), ('29', '34'), ('53', '50'), ('48', '4'), ('21', '22'), ('5', '21'), ('31', '56'), ('23', '58', '23'), ('45', '52'), ('17', '28'), ('25', '9'), ('59', '43'), ('29', '37'), ('34', '50'), ('51', '1'), ('7', '13'), ('47', '37'), ('26', '24'), ('31', '36'), ('7', '47'), ('37', '30'), ('33', '38'), ('13', '11'), ('57', '55'), ('28', '17'), ('35', '45'), ('27', '22'), ('2', '36'), ('38', '18'), ('16', '46'), ('48', '53'), ('57', '60', '57'), ('19', '32'), ('2', '13'), ('28', '41'), ('30', '53'), ('53', '31'), ('50', '40'), ('27', '40'), ('60', '26'), ('53', '40'), ('45', '56'), ('10', '46'), ('9', '26'), ('53', '50', '53'), ('18', '50'), ('59', '60', '59'), ('49', '5'), ('12', '27'), ('56', '8'), ('28', '62'), ('22', '21'), ('15', '22'), ('25', '5'), ('62', '40'), ('43', '33'), ('53', '21')

In [38]:
# Show each group's full (unpruned) n-gram set, followed by the same
# blank-line spacing as a separate print('\n') would give.
for cng in all_cngs:
    print(cng, end='\n\n\n')

{('28', '12'), ('41', '51'), ('17', '46'), ('47', '29'), ('40', '53'), ('36', '34'), ('31', '4'), ('31', '1'), ('10', '50'), ('14', '59'), ('61', '39'), ('34', '36'), ('7', '60'), ('1', '14'), ('27', '28'), ('12', '28'), ('56', '4'), ('16', '47'), ('56', '1'), ('5', '35'), ('44', '47'), ('0', '2'), ('52', '1'), ('59', '14'), ('30', '37'), ('52', '48'), ('29', '16'), ('31', '24'), ('34', '4'), ('48', '52'), ('42', '53'), ('15', '10'), ('32', '43'), ('53', '56'), ('36', '51'), ('9', '19'), ('59', '49'), ('35', '1'), ('3', '44'), ('45', '38'), ('39', '60'), ('55', '7'), ('23', '43'), ('32', '7'), ('42', '48'), ('39', '59'), ('17', '15'), ('42', '4'), ('38', '32'), ('13', '9'), ('24', '1'), ('41', '55'), ('7', '44'), ('49', '6'), ('33', '14'), ('47', '6'), ('25', '49'), ('54', '18'), ('9', '7'), ('58', '23'), ('21', '30'), ('7', '55'), ('3', '11'), ('60', '39'), ('19', '51'), ('7', '9'), ('41', '16'), ('55', '57'), ('5', '23'), ('15', '41'), ('50', '31'), ('23', '45'), ('24', '31'), ('20',