In [1]:
import re
import os
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics.pairwise import cosine_distances
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from llama_cpp import Llama

# Load the quantized Mistral-7B-Instruct model through llama.cpp
llm = Llama(
    model_path="./mistral-7b-instruct-v0.2.Q4_K_M.gguf",
    n_gpu_layers=45,
    n_ctx=4096,
    chat_format='llama-2',
)

# Bag-of-words vectorizer; generate_slices refits it on each pair of slices it compares
vectorizer = CountVectorizer()

# English stop words from NLTK, filtered out of every page before it is tokenized for the LLM
stop_words = stopwords.words('english')

AVX = 1 | AVX2 = 1 | AVX512 = 1 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 0 | VSX = 0 | 


In [2]:
def generate_slices(input, window_size, overlap_size=100, threshold=0.2):
    """
    Generate text slices from a tokenized text if it is too long for the LLM.

    Windows of ``window_size`` tokens are taken every
    ``window_size - overlap_size`` tokens and detokenized with the module-level
    ``llm``. A window is kept only if its bag-of-words cosine distance to the
    previously kept slice is at least ``threshold``; otherwise the window is
    slid forward one token at a time until it is distinct enough, its start
    reaches a multiple of ``window_size``, or the input is exhausted.

    Args:
        input (list): A tokenized text (token ids accepted by ``llm.detokenize``).
            The name shadows the builtin but is kept for keyword callers.
        window_size (int): The size of the window, in tokens.
        overlap_size (int): The size of the overlap between slices, in tokens.
        threshold (float): The minimum cosine distance to the previous slice.

    Returns:
        A list of slices (decoded strings). A text that fits in one window is
        returned as a single slice.

    Raises:
        ValueError: If the text needs slicing and ``window_size`` is not
            larger than ``overlap_size`` (the window could never advance).
    """

    slices = []

    def decode(tokens):
        # llm.detokenize returns bytes. A window boundary can presumably cut
        # through a multi-byte UTF-8 character; drop such partial bytes
        # instead of raising UnicodeDecodeError.
        return llm.detokenize(tokens).decode('utf-8', errors='ignore')

    def distance_to_previous(candidate):
        # The vectorizer is refit on just this pair, so the vocabulary is
        # limited to the two texts being compared.
        counts = vectorizer.fit_transform([candidate, slices[-1]])
        return cosine_distances(counts[0], counts[1])[0, 0]

    # If the input is shorter than the window size, return it as is
    if len(input) <= window_size:
        return [decode(input)]

    step = window_size - overlap_size
    if step <= 0:
        # A zero step makes range() fail with an unrelated message and a
        # negative step silently produces no slices at all.
        raise ValueError(
            f'window_size ({window_size}) must be greater than overlap_size ({overlap_size})'
        )

    # Otherwise, generate slices of the input and return them
    for start in range(0, len(input), step):
        candidate = decode(input[start:start + window_size])
        if not slices or distance_to_previous(candidate) >= threshold:
            slices.append(candidate)
            continue

        # Too similar to the previous slice: slide the window forward one
        # token at a time. This does not affect where the next regular window
        # of the outer loop starts.
        shifted_start = start
        while True:
            shifted_start += 1
            if shifted_start >= len(input):
                # Ran off the end without finding a distinct window. Keep the
                # unshifted candidate so the tail of the text is not dropped
                # and no empty slice is emitted.
                slices.append(candidate)
                break
            shifted = decode(input[shifted_start:shifted_start + window_size])
            if distance_to_previous(shifted) >= threshold or shifted_start % window_size == 0:
                slices.append(shifted)
                break
    return slices

In [3]:
# The question to ask the LLM
question = 'Tell me the main topic of this text: '

# The length of the question in tokens
len_encoded_question = len(llm.tokenize(question.encode('utf-8')))

# Token budget: the context size passed to Llama(n_ctx=...) minus the room kept for the generated answer.
# NOTE(review): the system prompt and chat-template tokens are not reserved here, so the longest
# prompts may leave slightly fewer than ANSWER_TOKEN_BUDGET tokens for the answer - confirm.
CONTEXT_SIZE = 4096
ANSWER_TOKEN_BUDGET = 150

# Set for constant-time membership tests (stop_words is a list)
stop_word_set = set(stop_words)

# Each file in the pages folder is split into words using NLTK word_tokenize, then the stop words are
# removed (case-sensitively) and the spacing around punctuation, brackets and quotes is repaired.
# Finally the text is encoded using the LLM tokenizer and passed to generate_slices, which produces
# slices that fit the token budget and have a cosine distance of at least 0.2 from the previous slice.
slices = {}
# Sorted so the processing order does not depend on the file system
for file in sorted(os.listdir('pages')):
    # Only hold the file open while reading it
    with open(f'pages/{file}', 'r', encoding='utf-8') as f:
        text = f.read()
    text = word_tokenize(text)
    text = [word for word in text if word not in stop_word_set]
    text = " ".join(text)
    # Re-attach closing punctuation to the preceding word
    text = re.sub(r"(?: ([.,;%'\":\)\]?!]))", r"\g<1>", text)
    # Re-attach opening brackets to the following word
    text = text.replace('( ', '(').replace('[ ', '[')
    # word_tokenize emits `` for an opening and '' for a closing double quote. The opening quote
    # must lose the space AFTER it (not the one before it), otherwise it is glued to the previous word.
    text = text.replace('`` ', '"').replace('\'\'', '"')
    encoded_full_text = llm.tokenize(text.encode('utf-8'))
    slices[file] = generate_slices(
        encoded_full_text,
        window_size=CONTEXT_SIZE - ANSWER_TOKEN_BUDGET - len_encoded_question,
    )
    print(f'file: {file}, tokens: {len(encoded_full_text)}, slices: {len(slices[file])}')

file: Affordable Care Act[medical].txt, tokens: 21170, slices: 6
file: Chewing[medical].txt, tokens: 937, slices: 1
file: Coeliac disease[medical].txt, tokens: 9768, slices: 3
file: Hochelaga, Land of Souls[non-medical].txt, tokens: 1866, slices: 1
file: Malignancy[medical].txt, tokens: 2122, slices: 1
file: Occupy movement[non-medical].txt, tokens: 14963, slices: 4
file: Outlaw King[non-medical].txt, tokens: 1832, slices: 1
file: Phoenician–Punic literature[non-medical].txt, tokens: 1711, slices: 1
file: Political journalism[non-medical].txt, tokens: 1386, slices: 1
file: United States Public Health Service[medical].txt, tokens: 3625, slices: 1


In [4]:
# A dictionary to store the topics for each file (file name -> list of answers, one per slice)
topics = {}

# For each file, ask the LLM the question for each slice and store the answer in the topics dictionary.
# The loop variable must not be called `slices`: that would rebind the global dictionary to the last
# file's list and make this cell fail when it is run a second time.
for file, file_slices in slices.items():
    print(f'file: {file}')
    # enumerate gives the true position even when two slices have identical text
    for slice_number, slice_text in enumerate(file_slices, start=1):
        chat = llm.create_chat_completion(
            messages=[
                {
                    "role": "system",
                    "content": "You are an helpful assistant who can find the main topic of a text and you answer with fewer than 100 words."
                },
                {
                    "role": "user",
                    "content": question + slice_text
                }
            ],
            max_tokens=150
        )
        print(f'    Generated answer for slice {slice_number}')
        topics.setdefault(file, []).append(chat['choices'][0]['message']['content'])

file: Affordable Care Act[medical].txt
    Generated answer for slice 1


Llama.generate: prefix-match hit


    Generated answer for slice 2


Llama.generate: prefix-match hit


    Generated answer for slice 3


Llama.generate: prefix-match hit


    Generated answer for slice 4


Llama.generate: prefix-match hit


    Generated answer for slice 5


Llama.generate: prefix-match hit


    Generated answer for slice 6
file: Chewing[medical].txt


Llama.generate: prefix-match hit


    Generated answer for slice 1
file: Coeliac disease[medical].txt


Llama.generate: prefix-match hit


    Generated answer for slice 1


Llama.generate: prefix-match hit


    Generated answer for slice 2


Llama.generate: prefix-match hit


    Generated answer for slice 3
file: Hochelaga, Land of Souls[non-medical].txt


Llama.generate: prefix-match hit


    Generated answer for slice 1
file: Malignancy[medical].txt


Llama.generate: prefix-match hit


    Generated answer for slice 1
file: Occupy movement[non-medical].txt


Llama.generate: prefix-match hit


    Generated answer for slice 1


Llama.generate: prefix-match hit


    Generated answer for slice 2


Llama.generate: prefix-match hit


    Generated answer for slice 3


Llama.generate: prefix-match hit


    Generated answer for slice 4
file: Outlaw King[non-medical].txt


Llama.generate: prefix-match hit


    Generated answer for slice 1
file: Phoenician–Punic literature[non-medical].txt


Llama.generate: prefix-match hit


    Generated answer for slice 1
file: Political journalism[non-medical].txt


Llama.generate: prefix-match hit


    Generated answer for slice 1
file: United States Public Health Service[medical].txt


Llama.generate: prefix-match hit


    Generated answer for slice 1


In [5]:
# Print the topic for each slice of each file
for file, topics_list in topics.items():
    print(file)
    # enumerate keeps the numbering correct even if two slices produced the same answer
    # (list.index would report the first match for both)
    for slice_number, topic in enumerate(topics_list, start=1):
        print(f'    Slice {slice_number}:')
        print(topic)
    print('\n')

Affordable Care Act[medical].txt
    Slice 1:
 The main topic of this text is the Affordable Care Act (ACA), also known as Obamacare. It discusses the enactment and major provisions of the law, which represents a significant regulatory overhaul of the U.S. healthcare system. The text covers topics such as insurance regulations, Medicaid expansion, cost savings, opposition, and legal challenges to the ACA.
    Slice 2:
 The main topic of this text is the history and implementation of the individual mandate in healthcare reform in the United States, specifically during the presidencies of Bill Clinton and Barack Obama. The text discusses various proposals for healthcare reform, including those from conservative think-tanks like The Heritage Foundation, and the eventual passage of the Affordable Care Act (ACA) under President Obama, which included an individual mandate. The text also touches on the political debates and controversies surrounding these proposals and their implementation.
 