In [None]:
%%capture
!pip install bertopic
!pip install -U sentence-transformers

In [None]:
# Data handling and deep learning.
import pandas as pd
import torch
import torch.nn as nn

# Topic modelling / sentence embeddings / NLP tooling.
import nltk
from bertopic import BERTopic
from sentence_transformers import SentenceTransformer

# Colab-only: Google Drive access.
from google.colab import drive

# Mount Drive; the project folder used later lives under /content/gdrive/MyDrive/.
drive.mount('/content/gdrive/', force_remount=True)

# Fetch the NLTK "punkt" tokenizer data (kept last so the cell still displays its return value).
nltk.download('punkt')

Mounted at /content/gdrive/
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

In [None]:
# Work from the project folder on Google Drive so the relative paths used later
# in this notebook (./larn_data_smaller.pt, ./Topic/...) resolve.
# NOTE(review): hardcoded absolute Colab path — adjust if the project lives elsewhere.
%cd "/content/gdrive/MyDrive/Yu-Wen-Chen-PhD"
# Execute the helper notebooks into this kernel's namespace. Presumably they define
# names used below but not defined in this notebook (e.g. preprocess_bertopic,
# us_cn, us_sy, us_ca, ...) — TODO confirm which notebook provides each.
%run "./constants.ipynb"
%run "./utils.ipynb"
%run "./modules.ipynb"

/content/gdrive/MyDrive/Yu-Wen-Chen-PhD
Drive already mounted at /content/gdrive/; to attempt to forcibly remount, call drive.mount("/content/gdrive/", force_remount=True).
[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
/content/gdrive/MyDrive/Yu-Wen-Chen-PhD
Using device: cuda
Drive already mounted at /content/gdrive/; to attempt to forcibly remount, call drive.mount("/content/gdrive/", force_remount=True).
/content/gdrive/MyDrive/Yu-Wen-Chen-PhD
Drive already mounted at /content/gdrive/; to attempt to forcibly remount, call drive.mount("/content/gdrive/", force_remount=True).
[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
/content/g

## **Topic Embedding Module**

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)


class TopicEmbeddingLayer(nn.Module):
    """Build one embedding per topic as a probability-weighted sum of document embeddings.

    Every document is encoded with a sentence-transformers model, then the topic
    embeddings are ``probs.T @ doc_emb``: topic ``k`` is
    ``sum_i probs[i, k] * doc_emb[i]``. The weights are used as given (they are
    not renormalized per topic).

    Args:
        emb_model: sentence-transformers model name or local path, e.g.
            ``'all-mpnet-base-v2'`` (the default).
    """

    def __init__(self, emb_model='all-mpnet-base-v2'):
        super().__init__()
        # The argument used to be ignored (mpnet was hardcoded); honour it now.
        self.model = SentenceTransformer(emb_model, device=device)

    def forward(self, texts, probs, d_doc=None):
        """Return the topic embeddings for ``texts`` weighted by ``probs``.

        Args:
            texts: sequence of N document strings.
            probs: array-like (numpy array, tensor, nested list) of shape
                (N, num_topics) holding document-topic probabilities, e.g. the
                second value returned by ``BERTopic.fit_transform`` when
                ``calculate_probabilities=True``.
            d_doc: expected embedding width. ``None`` (default) uses the model's
                own output dimension (768 for all-mpnet-base-v2, the previously
                hardcoded value). If given, it must match the model's output.

        Returns:
            float64 tensor of shape (num_topics, d_doc) on ``device``.

        Raises:
            ValueError: if ``probs`` is not 1-D/2-D, its row count differs from
                ``len(texts)``, or ``d_doc`` disagrees with the model's width.
        """
        texts = list(texts)
        probs_t = torch.as_tensor(probs, dtype=torch.float64, device=device)
        if probs_t.dim() not in (1, 2) or probs_t.shape[0] != len(texts):
            raise ValueError(
                f"probs must have one row per text: got shape {tuple(probs_t.shape)} "
                f"for {len(texts)} texts"
            )

        if texts:
            # One batched call instead of encoding each document separately.
            embeddings = self.model.encode(texts, convert_to_numpy=True, show_progress_bar=False)
            doc_emb = torch.as_tensor(embeddings, dtype=torch.float64, device=device).reshape(len(texts), -1)
            if d_doc is not None and doc_emb.shape[1] != d_doc:
                raise ValueError(
                    f"d_doc={d_doc} does not match the model embedding width {doc_emb.shape[1]}"
                )
        else:
            if d_doc is None:
                # Fall back to the old default if the model cannot report its width.
                d_doc = self.model.get_sentence_embedding_dimension() or 768
            doc_emb = torch.zeros((0, d_doc), dtype=torch.float64, device=device)

        # (num_topics, N) @ (N, d_doc) -> (num_topics, d_doc)
        return torch.matmul(probs_t.t(), doc_emb)

Using device: cuda


## **Load the data**


In [None]:
# Unpack the preprocessed (smaller) LARN data bundle into its eight components.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
# Newer PyTorch releases may default to weights_only=True, which can reject
# non-tensor contents; if loading fails, confirm and pass weights_only=False.
(bmap, cmap, wmap, revmap,
 span_data, span_size, target_pred_ix_set, We) = torch.load('./larn_data_smaller.pt')

In [None]:
# Entity pairs to process (us-cn and us-ru are included). The three lists are
# parallel: position i of each describes the same pair —
#   interesting_pair_num[i]: index into span_data,
#   filter_entity_name[i]:   filter passed to preprocess_bertopic (presumably
#                            defined by the %run helper notebooks — verify),
#   entity_name[i]:          readable name, also used in output file names.
interesting_pair_num = [0, 1, 4, 5, 6, 7, 8, 16]
filter_entity_name = [us_cn, us_sy, us_ca, us_ru, us_in, us_uk, us_jp, cn_in]
entity_name = ["us-cn", "us-sy", "us-ca", "us-ru", "us-in", "us-uk", "us-jp", "cn-in"]

# zip() in the processing loop silently drops trailing items if these lists
# drift out of sync, so fail loudly instead.
assert len(interesting_pair_num) == len(filter_entity_name) == len(entity_name), (
    "interesting_pair_num, filter_entity_name and entity_name must have the same length"
)
# Names are used in ./Topic/*-{entity}.pk paths; duplicates would overwrite results.
assert len(set(entity_name)) == len(entity_name), "entity_name entries must be unique"

## **BERTopic & create topic embeddings** ##
We fit a BERTopic model and create topic embeddings only for the interesting entity pairs selected above.


In [None]:
import os
import pickle  # used for the .pk outputs; imported here so this cell does not rely on a helper notebook

TOPIC_DIR = "./Topic"
os.makedirs(TOPIC_DIR, exist_ok=True)  # the output folder may not exist on a fresh checkout

# Load the sentence-transformer once; creating it inside the loop reloaded the
# model for every entity pair.
topic_embedding_layer = TopicEmbeddingLayer()

for num, filter_entity, entity in zip(interesting_pair_num, filter_entity_name, entity_name):
    print(f"We are now working on {entity}.")

    # A fresh BERTopic model per pair (fitting mutates the model).
    # NOTE(review): no random seed is set, so topics may differ between runs.
    topic_model = BERTopic(language="english", calculate_probabilities=True, n_gram_range=(1, 3), nr_topics=10)

    # preprocess data for BERTopic
    texts = preprocess_bertopic(span_data[num][5], filter_entity)

    topics, probs = topic_model.fit_transform(texts)

    # save topics & document-topic probabilities
    data = [topics, probs]
    with open(f"{TOPIC_DIR}/prob-{entity}.pk", "wb") as f:
        pickle.dump(data, f)

    n_topics = probs.shape[1]

    print("Generating topic embedding...")
    # (n_topics, embedding dim); the tensor lives on `device`, so unpickling it
    # later needs that device too (consider .cpu() if loading elsewhere).
    topic_embs = topic_embedding_layer(texts, probs)
    assert topic_embs.shape[0] == n_topics, "expected one embedding per topic"
    print(f"Finished! The shape of topic embedding is {topic_embs.shape}")

    # save the embeddings
    with open(f"{TOPIC_DIR}/embedding-{entity}.pk", "wb") as f:
        pickle.dump(topic_embs, f)

    # save topic info
    with open(f"{TOPIC_DIR}/info-{entity}.pk", "wb") as f:
        pickle.dump(topic_model.get_topics(), f)

    # A bare expression inside a loop is never rendered; show the figure explicitly.
    topic_model.visualize_barchart().show()


We are now working on us-cn.
Starting preprocessing!
We are at 0 documents!
We are at 10000 documents!
We are at 20000 documents!
We are at 30000 documents!
We are at 40000 documents!
We are at 50000 documents!


Downloading:   0%|          | 0.00/1.18k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/190 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/10.2k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/612 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/116 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/39.3k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/349 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/350 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/13.2k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]



Downloading:   0%|          | 0.00/1.18k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/190 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/10.1k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/571 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/116 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/39.3k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/349 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/438M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/239 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/363 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/13.1k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Generating topic embedding...
Finished! The shape of topic embedding is torch.Size([10, 768])
We are now working on us-sy.
Starting preprocessing!
We are at 0 documents!
We are at 10000 documents!
Generating topic embedding...
Finished! The shape of topic embedding is torch.Size([10, 768])
We are now working on us-ca.
Starting preprocessing!
We are at 0 documents!
We are at 10000 documents!
We are at 20000 documents!
We are at 30000 documents!
Generating topic embedding...
Finished! The shape of topic embedding is torch.Size([10, 768])
We are now working on us-ru.
Starting preprocessing!
We are at 0 documents!
We are at 10000 documents!
We are at 20000 documents!
We are at 30000 documents!
We are at 40000 documents!
Generating topic embedding...
Finished! The shape of topic embedding is torch.Size([10, 768])
We are now working on us-in.
Starting preprocessing!
We are at 0 documents!
We are at 10000 documents!
We are at 20000 documents!
We are at 30000 documents!
Generating topic embedd