In [1]:
import pyro
import torch
import json
import pickle
import os 
from prodslda_cls import ProdSLDA

from sklearn.feature_extraction.text import CountVectorizer

MODEL_PATH = '/burg/nlp/users/zfh2000/style_results/pos_bigrams/2023-12-14_17_54_45/model_epoch5_20914.218841552734.pt'
DATA_DIR_PATH = '/burg/nlp/users/zfh2000/style_results/pos_bigrams/maxdf0.5_mindf5_DATA'

with open(os.path.join(DATA_DIR_PATH, 'bows.pickle'), 'rb') as in_file:
    bows = pickle.load(in_file)
        
with open(os.path.join(DATA_DIR_PATH, 'meta_vectorized.pickle'), 'rb') as in_file:
    meta_vectorized = pickle.load(in_file)    

with open(os.path.join(DATA_DIR_PATH, "raw_text.json"), 'r') as in_file:
    raw_text = json.load(in_file)    

with open(os.path.join(DATA_DIR_PATH, "authors_json.json"), 'r') as in_file:
    authors_json = json.load(in_file)    

with open(os.path.join(DATA_DIR_PATH, "meta_feature_to_names.json"), 'r') as in_file:
    meta_feature_to_names = json.load(in_file)

with open(os.path.join(DATA_DIR_PATH, "vectorizer.pickle"), 'rb') as in_file:
    vectorizer = pickle.load(in_file)

In [2]:

pyro.clear_param_store()

prodsdla = torch.load(MODEL_PATH)
prodsdla.eval()


ProdSLDA(
  (encoder): GeneralEncoder(
    (drop): Dropout(p=0, inplace=False)
    (fc1s): ModuleDict(
      (doc): Linear(in_features=9267, out_features=64, bias=True)
    )
    (fc2): Linear(in_features=64, out_features=64, bias=True)
    (fcmu): Linear(in_features=64, out_features=10, bias=True)
    (fclv): Linear(in_features=64, out_features=10, bias=True)
    (bnmu): BatchNorm1d(10, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
    (bnlv): BatchNorm1d(10, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
  )
  (decoder): Decoder(
    (beta): Linear(in_features=10, out_features=9267, bias=False)
    (bn): BatchNorm1d(9267, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
    (drop): Dropout(p=0, inplace=False)
  )
  (style_encoder): GeneralEncoder(
    (drop): Dropout(p=0, inplace=False)
    (fc1s): ModuleDict(
      (pos_bigrams): Linear(in_features=324, out_features=64, bias=True)
    )
    (fc2): Linear(in_features=64, out_featur

In [55]:
from tqdm import tqdm
def top_beta_document(model, vectorizer, top_k=20):
    betas_document = model.beta_document()
    features_to_betas = {}
    idx_to_name = {v:k for k,v in vectorizer.vocabulary_.items()}
    for feature, logits in betas_document.items():
        features_to_betas[feature] = []
        num_features = logits.shape[0]
        top_results = torch.topk(logits, top_k, dim=-1)
        
        ids = top_results.indices.cpu().numpy()
        values = top_results.values.cpu().numpy()
        
        for i in tqdm(range(num_features)):
            features_to_betas[feature].append({'values':values[i], 'top':[idx_to_name[idx] for idx in ids[i]]})
                
    return features_to_betas

def top_beta_meta(model, meta_feature_to_names, top_k=20):
    betas_metas = model.beta_meta()
    features_to_betas = {}
    for feature, logits in betas_metas.items():
        idx_to_name = {i:k for i,k in enumerate(meta_feature_to_names[feature])}
        features_to_betas[feature] = []
        num_features = logits.shape[0]
        top_results = torch.topk(logits, top_k, dim=-1)
        ids = top_results.indices.cpu().numpy()
        values = top_results.values.cpu().numpy()
        for i in tqdm(range(num_features)):
            features_to_betas[feature].append({'values':values[i], 'top':[idx_to_name[idx] for idx in ids[i]]})
        
    return features_to_betas 


In [60]:
top_words_per_latent = top_beta_document(prodsdla, vectorizer,  top_k=20)
top_meta_per_latent = top_beta_meta(prodsdla, meta_feature_to_names, top_k=20)

print('Document Term Info')
for latent, top in top_words_per_latent.items():
    print(f'\t{latent} ({len(top)}):')
    for i, results in enumerate(top):
        print(f'\t\t {latent} ({i}):\n{results["top"]}')
        print()

print('Meta Var Info')
for latent, top in top_meta_per_latent.items():

    print(f'\t{latent} ({len(top)}):')
    for i, results in enumerate(top):
        print(f'\t\t {latent} ({i}):\n{results["top"]}')
        print()


100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 2612.79it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 87381.33it/s]

Document Term Info
	beta_topic (10):
		 beta_topic (0):
['fine', 'hope', 'sally', '28', '09', 'stuff', '08', 'hourahead', 'night', 'attached', 'jim', 'tickets', 'energy', 'making', 'presentation', 'kitchen', 'manual', '713', 'linda', 'original']

		 beta_topic (1):
['silva', 'geae', '962', '7566', 'eb3892', 'manis', 'x33278', '1575', 'freyre', 'bambos', 'ccampbell', 'giuseppe', 'vande', 'kupiecki', '4727', 'noncore', 'kayne', 'mckinsey', 'nassos', 'centilli']

		 beta_topic (2):
['fax', '713', 'john', 'market', 'bcc', 'help', 'yes', 'needs', 'hey', 'request', 'credit', 'questions', 'deals', 'address', 'updates', 'message', 'email', '646', 'management', 'www']

		 beta_topic (3):
['print', 'attachment', 'enron', 'report', 'asked', '07', 'format', 'sppc', 'wrong', 'job', 'waiting', 'kate', 'information', 'language', 'retain', 'time', 'sorry', 'attached', '00', 'basis']

		 beta_topic (4):
['fyi', 'need', 'just', 'meeting', 'sent', 'thanks', 'ferc', '2000', 'mark', 'revised', 'going', 'dr




In [52]:

with torch.no_grad():
    for text, author, bow, meta in zip(raw_text['training'], authors_json['training']): #, bows['training'], meta_vectorized['training']):
        print(author, text)
        break
        # print(text)
        # print(author)
        # print(bow)
        # print(meta)
        # print('------------------')
        # result =  F.softmax(prodsdla.guide(bow.unsqueeze(0), meta.unsqueeze(0))[1])
        # print(result)
        # print('------------------')
        # print('------------------'


# label_to_topic = {}
# label_to_max = {}

# for d, text, encoded_label, label in zip(docs, data['data'], labels, data['labels']):
#     if label not in label_to_topic:
#         label_to_topic[label] = []
#         label_to_max[label] = []
#     # print(d)
#     # print(label)
#     # print('------------------')
#     result =  F.softmax(prodLDA.guide(d.unsqueeze(0), encoded_label.unsqueeze(0))[1])

#     # argmax = torch.argmax(result)
#     # print(argmax)
#     label_to_max[label].append(result[0].detach().cpu().numpy())
#     print(label, result)
#     label_to_topic[label].append((text,result))


# for label in label_to_topic:
#     print(label, np.mean(label_to_max[label]))


ValueError: not enough values to unpack (expected 4, got 2)