## DEMO

In [4]:
from sentence_transformers import SentenceTransformer, util
import numpy as np
from utils import create_document_embeddings, get_document_embeddings

In [2]:
PATH_TO_EMBEDDINGS = r"../data/embeddings"

In [5]:
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

In [57]:
create_document_embeddings(model, [["Hello, world!", "Goodbye. This is the end."], ["This is a test", "This is not a test"]], PATH_TO_EMBEDDINGS, "demo")


array([[-2.09693518e-02,  5.88474944e-02,  2.82677300e-02,
         2.62138695e-02, -9.06362943e-03, -7.64047354e-02,
         5.53913191e-02, -3.68839912e-02, -7.84823671e-04,
        -1.27433864e-02,  1.58711728e-02,  2.14681029e-02,
         6.23353664e-03,  8.90833605e-03, -7.68442303e-02,
         8.39350000e-03, -1.81107409e-02,  1.22036822e-02,
        -1.23330548e-01,  2.18525846e-02, -2.26291306e-02,
         4.76487502e-02, -2.50382796e-02,  3.09831463e-02,
        -5.96614145e-02,  3.93032730e-02,  4.28193845e-02,
         2.59722732e-02, -1.50488382e-02, -6.17700703e-02,
         2.94973440e-02,  4.22532260e-02,  3.13867591e-02,
        -2.40303874e-02,  9.24553722e-03,  6.24796227e-02,
        -6.07810020e-02, -7.88905621e-02,  2.15569939e-02,
         7.95919355e-03,  1.69848315e-02, -1.63327418e-02,
        -4.67240810e-02, -1.32707031e-02,  2.02383064e-02,
         2.97268182e-02, -2.98732631e-02, -1.25523005e-02,
         8.27551261e-02,  5.37742898e-02, -1.58221647e-0

In [58]:
l_embs = get_document_embeddings(r"../data/embeddings/demo.npy")

## Testing REGEX

In [None]:
import re

In [None]:
content = 'compression quality, or color which are capable enough to dramatically modify or restructure the underlying image properties111https://www.nytimes.com/interactive/2023/06/28/technology/ai-detection-midjourney-stable-diffusion-dalle.html?auth=register-google&utm_source=pocket-newtab-intl-en. Image forgeries are almost always followed by these operations like adding some noise or applying JPEG compression to obscure any traces of image generation or forgeries and thereby'

# content = re.sub(r'\[(\d+[,]?)+\]([^\w\t])', r'\2', content) # matches when punctuation follows the citation
# content = re.sub(r'\[(\d+[,]?)+\](\w+)', r' \2', content)
content = re.sub(r'(\d+)?http[s]?:\/\/[^\s]+([^\.\?\!\,:; ])', r'', content)
# content = re.sub(r'(\d+)?(http[s]?|file):\/\/(^.\s)*([\.,]+)?', r'\3', content)
print(content)

In [None]:
text = r'''Due
to the efficiency of hierarchical architectures, U-Net-based Diffusion Models (DMs)[1]have achieved remarkable performance in visual generation tasks[2,3,4].
Recently, inspired by the success of Transformers[5,6,7], Transformer-based DMs dubbed Diffusion Transformers (DiTs) have been developed[8]and exhibit great scalability on more complex generation tasks. Particularly, the state-of-the-art (SOTA) generation framework Sora[9]is built upon DiTs, highlighting their great potential and effectiveness.
However, their large model size and intensive computations involved in the iterative denoising process result in slow inference speeds[10], calling for effective model compression methods.'''


## Official

In [1]:
from pathlib import Path
import numpy as np
from sentence_transformers import SentenceTransformer
import logging
import re

  from tqdm.autonotebook import tqdm, trange


**TODO**
- Skip poorly fomarted papers
- Split paper by SECTION rather than '\n' char
- Remove other citation types (e.g. (Name et. al., year))
- Obtain Keywords from first section if possible
- Make section embedding Hierarchical (combine the embeddings of subections!)
- Add better ending condition (The end of the paper is not always denoted by a SECTION)
- Conver ending_words matching to regex
- Adjust multi_sec_embed to create an list of ndarrays before concatenating
- Change how the `preprocess` and `is_paper_good` methods are used (e.g. change them to be class methods)

**Potential for cleaning via HTML Parsing**
- Remove Code
- See when paper content ends?
- Remove citations
- Remove links

**Problem Papers**
- 2411.13913v1_content (Unusable)
- 2403.11585v3_content (Should end earlier)


**Embeddings**
- Each embedding file contains a collection of embeddings with shape = (num_of_secs, emb_size). NOTE THIS IS SUBJECT TO CHANGE!
- The last embedding in the file is the papers mean embedding.


**Model Info**
- https://huggingface.co/sentence-transformers/sentence-t5-large

**Things to Read**
- https://huggingface.co/docs/transformers/model_memory_anatomy
- https://huggingface.co/docs/transformers/perf_train_gpu_one
- https://huggingface.co/docs/transformers/llm_tutorial_optimization

**Training End Dates**
- LLAMA : Dec. 2023
- sentence-t5-large : ??


In [85]:
class EmbeddingLibrary():
    def __init__(self, path_to_papers:Path, path_to_embs:Path, model:SentenceTransformer, end_of_paper_words, name:str, log_lvl=logging.INFO):
        self.path_to_papers:Path = path_to_papers
        self.path_to_embs:Path = path_to_embs
        self.model = model
        self.emb_size = model.encode(['']).shape[-1]
        self.paper_ids:list[str] = [file.stem.replace('_content', '') for file in path_to_papers.iterdir()]
        self.log_lvl:int = log_lvl
        self.name:str = name
        self.end_of_paper_words:list[str] = end_of_paper_words # TODO Change to use regex
        self.full_paper_embs:np.ndarray = None

    def preprocess(self, content: str):
        # remove links (do not capture (most) punctuation at the end)
        content = re.sub(r'(\d+)?http[s]?:\/\/[^\s]+([^\.\,:; ])', r'', content) 
        # remove [] citations
        content = re.sub(r'\[(\d+[,]?)+\]([^\w\s])', r'\2', content) # for when punc follow cite
        content = re.sub(r'\[(\d+[,]?)+\](\w+)', r' \2', content) # for when chars follow cite
        return content

    def is_paper_good(self, paper:str):
        # check for usage of SECTION
        # check for section names: introduction, conclusion
        n_secs = len(re.findall(r'SECTION:', paper))
        paper_lower = paper.lower()
        n_important_secs = len(re.findall(r'section: [\d\.]*[introduction|methodology|methods|conclusions|conclusion|results|experimental results and analysis|experimental results]', paper_lower))

        return True if n_secs > 4 and n_important_secs > 2 else False

    def multi_sec_embed(self, skip_existing=True, encoding='utf-8'):
    
        logger = logging.getLogger(f'{self.name}_embs_{logging._levelToName[self.log_lvl]}.log')
        logger.setLevel(self.log_lvl)
        logging.basicConfig(
            filename=f'../log/{self.name}_embs_{logging._levelToName[self.log_lvl]}.log',
            level=self.log_lvl,
            encoding=encoding
        )

        try:
            # calculate embedding size
            emb_size = self.model.encode(['']).shape[-1]


            # TODO Implement Skip 
            for file in self.path_to_papers.iterdir():
                try:
                    with open(file, 'r', encoding='utf-8') as f:

                        # read in the content, split into sections, extract and embed the title 
                        paper = f.read()
                        if not self.is_paper_good(paper): logger.info(f'SKIPPING FILE {file.name}\n'); continue

                        paper_chunks = paper.split('\n')
                        logger.info(f'STARTING TO EMBED: {file.name}')
                        title = paper_chunks[0]

                        if 'introduction' in title.lower(): logger.info(f'SKIPPING FILE {file.name}\n'); continue 
                        else: title = title.replace('SECTION: ', '')
                        title_emb = self.model.encode(title)
                        paper_embs = title_emb.reshape(1, emb_size)

                        n_secs = 1 # start at one for title
                        sec = [] # a list of all sentences in a section
                        all_sents = []
                        for chunk in paper_chunks[1:]:
                            if chunk == '':
                                continue
                            
                            elif 'SECTION' in chunk:
                                logger.debug(f'STARTING NEW SECTION:')
                                logger.debug(f'TITLE: {chunk}\n')
                                # check if at end
                                if any([word in chunk.lower() for word in self.end_of_paper_words]):
                                    full_paper_emb = self.model.encode(all_sents).mean(axis=0).reshape(1, emb_size)
                                    paper_embs = np.concatenate([paper_embs, full_paper_emb], axis=0)
                                    logger.info(f'END OF FILE - SAVING EMBEDDINGS FOR {file.name}\n')
                                    np.save(self.path_to_embs / file.stem.replace('_content', ''), paper_embs)
                                    break
                                if len(sec) > 0:
                                    logger.debug('EMBEDDING PREVIOUOS SECTION')
                                    n_secs += 1
                                    sec_emb = model.encode(sec).mean(axis=0)
                                    logging.info(f'NUM OF EMBS IN SECTION {n_secs} = {len(sec)}')
                                    logging.debug(f'SEC EMB SHAPE: {sec_emb.shape}')
                                    paper_embs = np.concatenate([paper_embs, sec_emb.reshape(1, emb_size)], axis=0)
                                    logging.debug(f'CURRENT PAPER_EMB SHAPE: {paper_embs.shape}')
                                    sec = []
                                continue

                            else:
                                # preprocess the chunk
                                logger.debug(f'CHUNK BEING PREPROCESSED:\n{chunk}\n')
                                chunk = self.preprocess(chunk)
                                sec.append(chunk)
                                all_sents.extend(chunk)
                                logger.debug(f'PREPROCESSED CHUNK:\n{chunk}\n')

                    assert paper_embs.shape[0] - 1 == n_secs, f"Num of embeddings {paper_embs.shape[0] - 1} does NOT match num of sections {n_secs}"

                except AssertionError as assEx:
                    logger.error(f'Exception caught while embedding file:{file.stem}')
                    logger.error(assEx)
                    continue
            return None
        except Exception as ex:
            print('SHIT')
            print(ex)
            raise ex
        
    def update_full_paper_embs(self) -> None:
        paper_embs = []
        for file in self.path_to_embs.iterdir():
            paper_embs.append(np.load(file)[-1, :]) # get the last embedding

        self.full_paper_embs = np.stack(paper_embs, axis=0)
        return
    
    def search_papers(self, query:str, n_results=5):
        assert self.full_paper_embs is not None, print("self.full_paper_embs must be set to use this function!")
        q_emb = self.model.encode([query])
        scores = self.model.similarity(q_emb, self.full_paper_embs)
        print(scores)
        top_n_idxs = np.argsort(scores).tolist()[0][:n_results]
        top_papers = [self.paper_ids[idx] for idx in top_n_idxs]
        return top_papers



        


In [83]:
path_to_papers = Path(r'../papers')
path_to_embs = Path(r'../data/embeddings')

END_OF_PAPER_WORDS = [
    'references',
    'acknowledgements', 
    'acknowledgement',
]

papers_to_skip = [
    '2311.11329v2_content.txt',
    '2411.09324v2_content.txt',
    '2411.14259v1_content.txt'
]


In [84]:
model = SentenceTransformer('sentence-transformers/sentence-t5-large', device='cuda')

In [76]:
emb_lib = EmbeddingLibrary(
    path_to_papers=path_to_papers,
    path_to_embs=path_to_embs,
    model=model,
    end_of_paper_words=END_OF_PAPER_WORDS,
    papers_to_skip=papers_to_skip,
    name="Prototype"
)

emb_lib.update_full_paper_embs()

In [80]:
x = model.encode(['RF jamming transmits radio signals indiscriminately across a range of frequencies, causing interference and disrupting communication. Jamming can be broadly categorized as active jamming and reactive jamming. Active jamming continuously emits powerful interference signals, but its continuous operation leaves detectable traces, making it vulnerable to defensive techniques. Reactive jamming adjusts its jamming behavior according to observed signals in the environment. It remains silent when the channel is idle but initiates high-power signal transmission upon detecting activity on the channel. The drawback of these approaches is that spectrum owners may promptly detect the presence of an attack and respond accordingly.'])

y = model.encode(['The cat was fat. So was the dog.', 'He ran so fast. But not fast enough. Sad fat dog'])

In [82]:
y.shape

(2, 768)

In [77]:
query = ' the research provides an innovative and efficient solution for code generation from ML task descriptions, showcasing the capabilities of Linguacodus. By capitalizing on the Code4ML dataset’s wealth of resources and introducing a structured approach to instruction synthesis and code generation, we bridge the gap between natural language task descriptions and executable code, making machine learning development more accessible and efficient'

emb_lib.search_papers(query)

tensor([[0.7425, 0.7420, 0.7429, 0.7443, 0.7437, 0.7426, 0.7426, 0.7418, 0.7421,
         0.7428, 0.7431, 0.7434, 0.7432, 0.7447, 0.7426, 0.7425, 0.7430, 0.7425,
         0.7437, 0.7419, 0.7440, 0.7423, 0.7412, 0.7420, 0.7433, 0.7431, 0.7438,
         0.7427, 0.7441, 0.7427, 0.7429, 0.7430, 0.7430, 0.7424, 0.7423, 0.7447,
         0.7436, 0.7422, 0.7425, 0.7439, 0.7442, 0.7432, 0.7438]])


['2411.09598v1',
 '2403.12778v2',
 '2411.09101v1',
 '2411.09702v1',
 '2309.01837v3']

In [8]:
x = np.load(r'..\data\embeddings\2308.07279v2.npy')
x.shape

(18, 768)

In [6]:
# def paper_emb(model:SentenceTransformer, path_to_papers:Path, save_dir:Path, ending_words:list[str], log_lvl=logging.INFO, encoding='utf-8'):
#     logger = logging.getLogger(f'embs_{model._get_name()}_{path_to_papers.name}_{log_lvl}')
#     logger.setLevel(log_lvl)
#     logging.basicConfig(
#         filename=f'../log/embs_{model._get_name()}_{path_to_papers.name}_{logging._levelToName[log_lvl]}.log',
#         level=log_lvl,
#         encoding=encoding
#     )
#     try:
#         emb_size = model.encode(['']).shape
#         for file in path_to_papers.iterdir():
#             try:
#                 with open(file, 'r', encoding='utf-8') as f:
#                     paper = f.read()
#                     if not is_paper_good(paper): logger.info(f'SKIPPING FILE {file.name}\n'); continue 

#                     logger.info(f'STARTING TO EMBED: {file.name}')
                    

def preprocess(content: list[str]|str):
    # remove links (do not capture (most) punctuation at the end)
    content = re.sub(r'(\d+)?http[s]?:\/\/[^\s]+([^\.\,:; ])', r'', content) 
    # remove [] citations
    content = re.sub(r'\[(\d+[,]?)+\]([^\w\s])', r'\2', content) # for when punc follow cite
    content = re.sub(r'\[(\d+[,]?)+\](\w+)', r' \2', content) # for when chars follow cite
    return content
def is_paper_good(paper:str):
    # check for usage of SECTION
    # check for section names: introduction, conclusion, etc.
    n_secs = len(re.findall(r'SECTION:', paper))
    paper_lower = paper.lower()
    n_important_secs = len(re.findall(r'section: [\d\.]*[introduction|methodology|methods|conclusions|conclusion|results|experimental results and analysis|experimental results]', paper_lower))
    return True if n_secs > 4 and n_important_secs > 2 else False

def multi_sec_embed(model:SentenceTransformer, path_to_papers:Path, save_dir:Path, ending_words:list[str], log_lvl=logging.INFO, encoding='utf-8'):
    
    logger = logging.getLogger(f'embs_{model._get_name()}_{path_to_papers.name}_{logging._levelToName[log_lvl]}')
    logger.setLevel(log_lvl)
    logging.basicConfig(
        filename=f'../log/embs_{model._get_name()}_{path_to_papers.name}_{logging._levelToName[log_lvl]}.log',
        level=log_lvl,
        encoding=encoding
    )
    
    try:
        # calculate embedding size
        emb_size = model.encode(['']).shape[-1]
        for file in path_to_papers.iterdir():
            try:
                with open(file, 'r', encoding='utf-8') as f:

                    # read in the content, split into sections, extract and embed the title 
                    paper = f.read()
                    if not is_paper_good(paper): logger.info(f'SKIPPING FILE {file.name}\n'); continue 
                    logger.info(f'STARTING TO EMBED: {file.name}')
                    paper_chunks = paper.split('\n')
                    title = paper_chunks[0]

                    if 'introduction' in title.lower(): logger.info(f'SKIPPING FILE {file.name}\n'); continue 
                    else: title = title.replace('SECTION: ', '')
                    title_emb = model.encode(title)
                    paper_embs = title_emb.reshape(1, emb_size)

                    n_secs = 1 # start at one for title
                    sec = [] # a list of all sentences in a section
                    all_sents = []
                    for chunk in paper_chunks[1:]:
                        if chunk == '':
                            continue
                        
                        elif 'SECTION' in chunk:
                            logger.debug(f'STARTING NEW SECTION:')
                            logger.debug(f'TITLE: {chunk}\n')
                            # check if at end
                            if any([word in chunk.lower() for word in ending_words]):
                                full_paper_emb = model.encode(all_sents).mean(axis=0).reshape(1, emb_size)
                                paper_embs = np.concatenate([paper_embs, full_paper_emb], axis=0)
                                logger.info(f'END OF FILE - SAVING EMBEDDINGS FOR {file.name}\n')
                                np.save(save_dir / file.stem.replace('_content', ''), paper_embs)
                                break
                            if len(sec) > 0:
                                logger.debug('EMBEDDING PREVIOUOS SECTION')
                                n_secs += 1
                                sec_emb = model.encode(sec).mean(axis=0)
                                logging.info(f'NUM OF EMBS IN SECTION {n_secs} = {len(sec)}')
                                logging.debug(f'SEC EMB SHAPE: {sec_emb.shape}')
                                paper_embs = np.concatenate([paper_embs, sec_emb.reshape(1, emb_size)], axis=0)
                                logging.debug(f'CURRENT PAPER_EMB SHAPE: {paper_embs.shape}')
                                sec = []
                            continue
                            
                        else:
                            # preprocess the chunk
                            logger.debug(f'CHUNK BEING PREPROCESSED:\n{chunk}\n')
                            chunk = preprocess(chunk)
                            sec.append(chunk)
                            all_sents.extend(chunk)
                            logger.debug(f'PREPROCESSED CHUNK:\n{chunk}\n')

                assert paper_embs.shape[0] - 1 == n_secs, f"Num of embeddings {paper_embs.shape[0] - 1} does NOT match num of sections {n_secs}"

            except AssertionError as assEx:
                logger.error(f'Exception caught while embedding file:{file.stem}')
                logger.error(assEx)
                continue
        return None
    except Exception as ex:
        print('SHIT')
        print(ex)
        raise ex


In [7]:
multi_sec_embed(model=model, path_to_papers=path_to_papers, save_dir=path_to_embs, ending_words=ENDING_WORDS, log_lvl=logging.INFO)

Batches: 100%|██████████| 1/1 [00:00<00:00,  7.84it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  6.22it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  7.80it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  1.37it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  4.34it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 10.26it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 21.98it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 10.94it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  5.36it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  2.62it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  3.21it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  3.54it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  7.71it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  2.05it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 10.83it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 10.62it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  2.65it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 10.75it/s]
Batches: 1

In [9]:
emb_files = list(Path(r'../data/embeddings/').iterdir()) 
files = list(Path(r'../papers/').iterdir())
for e, f in zip(emb_files, files):
    print(e, f)
    if e.stem != f.stem.replace('_content', ''): print('ERROR'); break

..\data\embeddings\2308.07279v2.npy ..\papers\2308.07279v2_content.txt
..\data\embeddings\2309.01837v3.npy ..\papers\2309.01837v3_content.txt
..\data\embeddings\2311.00207v3.npy ..\papers\2311.00207v3_content.txt
..\data\embeddings\2312.09440v2.npy ..\papers\2311.11329v2_content.txt
ERROR


In [107]:
a = np.load('../data/embeddings/2308.07279v2.npy')

In [29]:
arrs = model.encode(['The dog ran away. I caught up to him though. This is where things got werid... The cat ran away too.', 'The cat ran away. I couldn\'t find her'])

Batches: 100%|██████████| 1/1 [00:00<00:00,  4.85it/s]


In [30]:
arrs

array([[-0.0573801 , -0.02413465, -0.02862503, ..., -0.01144615,
         0.01713123,  0.02775556],
       [-0.01482539, -0.00964323, -0.01514543, ..., -0.00653667,
         0.01400161,  0.05429142]], dtype=float32)

In [31]:
x, y = arrs[0, :], arrs[1, :]

model.similarity(x, y)

tensor([[0.8562]])