##**Environment detection and dependency installation**

In [None]:
# Perform Google Colab installs (if running in Google Colab)
import os

if "COLAB_GPU" in os.environ:
    print("[INFO] Running in Google Colab, installing requirements.")
    !pip install -U torch # requires torch 2.1.1+ (for efficient sdpa implementation)
    !pip install tqdm # for progress bars
    !pip install sentence-transformers # for embedding models
    !pip install accelerate # for quantization model loading
    !pip install bitsandbytes # for quantizing models (less storage space)
#     !pip install flash-attn --no-build-isolation # for faster attention mechanism = faster LLM inference

[INFO] Running in Google Colab, installing requirements.


##**Import libraries**

In [None]:
import pandas as pd
import random
from spacy.lang.en import English
from tqdm import tqdm
import re
from sentence_transformers import SentenceTransformer
import torch
import numpy as np
re.compile('<title>(.*)title>')

re.compile(r'<title>(.*)title>', re.UNICODE)

## **Data Input**

In [None]:
# Get TEXT document
text_path = pd.read_csv("IELTS.txt", sep="\t")

In [None]:
text_path

Unnamed: 0,Skip to main content
0,Texts
1,Video
2,Audio
3,Software
4,Images
...,...
19926,University Printing House
19927,Shaftesbury Road
19928,Cambridge CB2 8BS
19929,UK


In [None]:
def open_and_read_txt(txt_path: str) -> list[dict]:

    pages_and_texts = []

    # Open a text file and read it line by line
    with open(txt_path, "r", encoding="utf-8") as file:
        lines = file.readlines()

    for line_number, line in enumerate(lines):
        text = line.strip()  # Remove extra whitespace characters
        pages_and_texts.append({
            "line_number": line_number + 1,  # Current line number, starting from 1
            "line_char_count": len(text),  # Number of characters
            "line_word_count": len(text.split(" ")),  # Number of words
            "line_sentence_count_raw": len(text.split(". ")),  # Number of sentences
            "line_token_count": len(text) / 4,  # Estimated number of tokens (1 token is about 4 characters)
            "text": text  # The text content of the current line
        })

    return pages_and_texts


# Call the function to read the IELTS.txt file
txt_path = "IELTS.txt"
lines_and_texts = open_and_read_txt(txt_path)

# View the results of the first two rows
lines_and_texts[:2]

[{'line_number': 1,
  'line_char_count': 20,
  'line_word_count': 4,
  'line_sentence_count_raw': 1,
  'line_token_count': 5.0,
  'text': 'Skip to main content'},
 {'line_number': 2,
  'line_char_count': 0,
  'line_word_count': 1,
  'line_sentence_count_raw': 1,
  'line_token_count': 0.0,
  'text': ''}]

In [None]:
random.sample(lines_and_texts, k=3)

[{'line_number': 3075,
  'line_char_count': 0,
  'line_word_count': 1,
  'line_sentence_count_raw': 1,
  'line_token_count': 0.0,
  'text': ''},
 {'line_number': 3155,
  'line_char_count': 0,
  'line_word_count': 1,
  'line_sentence_count_raw': 1,
  'line_token_count': 0.0,
  'text': ''},
 {'line_number': 24954,
  'line_char_count': 0,
  'line_word_count': 1,
  'line_sentence_count_raw': 1,
  'line_token_count': 0.0,
  'text': ''}]

In [None]:
df = pd.DataFrame(lines_and_texts)
df.head()

Unnamed: 0,line_number,line_char_count,line_word_count,line_sentence_count_raw,line_token_count,text
0,1,20,4,1,5.0,Skip to main content
1,2,0,1,1,0.0,
2,3,5,1,1,1.25,Texts
3,4,0,1,1,0.0,
4,5,5,1,1,1.25,Video


In [None]:
# Get stats
df.describe().round(2)

Unnamed: 0,line_number,line_char_count,line_word_count,line_sentence_count_raw,line_token_count
count,40321.0,40321.0,40321.0,40321.0,40321.0
mean,20161.0,20.31,4.24,1.09,5.08
std,11639.81,26.61,4.55,0.31,6.65
min,1.0,0.0,1.0,1.0,0.0
25%,10081.0,0.0,1.0,1.0,0.0
50%,20161.0,3.0,1.0,1.0,0.75
75%,30241.0,42.0,7.0,1.0,10.5
max,40321.0,117.0,32.0,6.0,29.25


##**Data and text processing**

#### **Use spaCy and tqdm for sentence segmentation and statistics on text**

In [None]:
from spacy.lang.en import English # see https://spacy.io/usage for install instructions

nlp = English()

# Add a sentencizer pipeline, see https://spacy.io/api/sentencizer/
nlp.add_pipe("sentencizer")

# Create a document instance as an example
doc = nlp("This is a sentence. This another sentence.")
assert len(list(doc.sents)) == 2

# Access the sentences of the document
list(doc.sents)

[This is a sentence., This another sentence.]

In [None]:
# Step 2: Initialize the English model of spaCy
nlp = English()
nlp.add_pipe("sentencizer")  # Add the sentencizer to the spaCy processing pipeline

# Step 3: Split the text column of the dataset into sentences
pages_and_texts = [{"text": str(row[0]).strip()} for row in text_path.values]  # Extract the text column

# Step 4: Iterate over each row of text, split it into sentences, and count the number of sentences
for item in tqdm(pages_and_texts):
    # Use spaCy to split sentences
    item["sentences"] = list(nlp(item["text"]).sents)

    # Ensure sentences are in string format
    item["sentences"] = [str(sentence) for sentence in item["sentences"]]

    # Count the number of sentences
    item["page_sentence_count_spacy"] = len(item["sentences"])

# Step 5: View the results (for example, the first two rows)
for page in pages_and_texts[:2]:
    print(page)

100%|██████████| 19931/19931 [00:02<00:00, 8967.13it/s]

{'text': 'Texts', 'sentences': ['Texts'], 'page_sentence_count_spacy': 1}
{'text': 'Video', 'sentences': ['Video'], 'page_sentence_count_spacy': 1}





In [None]:
# Inspect an example
random.sample(pages_and_texts, k=1)

[{'text': 'C', 'sentences': ['C'], 'page_sentence_count_spacy': 1}]

In [None]:
df = pd.DataFrame(pages_and_texts)
df.describe().round(2)

Unnamed: 0,page_sentence_count_spacy
count,19931.0
mean,1.19
std,2.12
min,1.0
25%,1.0
50%,1.0
75%,1.0
max,295.0


#### **Chunking ten sentences together**

In [None]:
# Define split size to turn groups of sentences into chunks
num_sentence_chunk_size = 10

# Create a function that recursively splits a list into desired sizes
def split_list(input_list: list,
               slice_size: int) -> list[list[str]]:
    """
    Splits the input_list into sublists of size slice_size (or as close as possible).

    For example, a list of 17 sentences would be split into two lists of [[10], [7]]
    """
    return [input_list[i:i + slice_size] for i in range(0, len(input_list), slice_size)]

# Loop through pages and texts and split sentences into chunks
for item in tqdm(pages_and_texts):
    item["sentence_chunks"] = split_list(input_list=item["sentences"],
                                         slice_size=num_sentence_chunk_size)
    item["num_chunks"] = len(item["sentence_chunks"])

100%|██████████| 19931/19931 [00:00<00:00, 680671.52it/s]


In [None]:
# Sample an example from the group (note: many samples have only 1 chunk as they have <=10 sentences total)
random.sample(pages_and_texts, k=1)

[{'text': "fh ii ii really doesn't suit ihe way we work these* days. Its",
  'sentences': ["fh ii ii really doesn't suit ihe way we work these* days.",
   'Its'],
  'page_sentence_count_spacy': 2,
  'sentence_chunks': [["fh ii ii really doesn't suit ihe way we work these* days.",
    'Its']],
  'num_chunks': 1}]

In [None]:
# Create a DataFrame to get stats
df = pd.DataFrame(pages_and_texts)
df.describe().round(2)

Unnamed: 0,page_sentence_count_spacy,num_chunks
count,19931.0,19931.0
mean,1.19,1.0
std,2.12,0.21
min,1.0,1.0
25%,1.0,1.0
50%,1.0,1.0
75%,1.0,1.0
max,295.0,30.0


#### **Splitting each chunk into its own item**

In [None]:
# Create a new list to store information about each sentence chunk
pages_and_chunks = []

# Iterate over each text block
for item in tqdm(pages_and_texts):
    for sentence_chunk in item["sentence_chunks"]:  # Iterate over each chunk
        chunk_dict = {}  # Store information about the current chunk

        # Optionally, add page number information
        chunk_dict["page_number"] = item.get("page_number", None)  # Default to None if no page number

        # Join the sentences in the chunk into a single string
        joined_sentence_chunk = " ".join(sentence_chunk).replace("\n", " ").strip()
        # Regular expression replacement: replace ". A" with ".\nA" to handle sentence separators
        joined_sentence_chunk = re.sub(r"\. ([A-Z])", r". \1", joined_sentence_chunk)
        chunk_dict["sentence_chunk"] = joined_sentence_chunk

        # Gather statistics about the chunk
        chunk_dict["chunk_char_count"] = len(joined_sentence_chunk)  # Character count
        chunk_dict["chunk_word_count"] = len(joined_sentence_chunk.split(" "))  # Word count
        chunk_dict["chunk_token_count"] = len(joined_sentence_chunk) / 4  # Estimate token count (1 token ≈ 4 characters)

        # Add the current chunk to the list
        pages_and_chunks.append(chunk_dict)

# View statistics: how many chunks there are
print(f"Total chunks: {len(pages_and_chunks)}")

# Example print of the first two chunks
for chunk in pages_and_chunks[:2]:
    print(chunk)


100%|██████████| 19931/19931 [00:00<00:00, 254172.03it/s]


Total chunks: 19960
{'page_number': None, 'sentence_chunk': 'Texts', 'chunk_char_count': 5, 'chunk_word_count': 1, 'chunk_token_count': 1.25}
{'page_number': None, 'sentence_chunk': 'Video', 'chunk_char_count': 5, 'chunk_word_count': 1, 'chunk_token_count': 1.25}


In [None]:
# View a random sample
random.sample(pages_and_chunks, k=1)

[{'page_number': None,
  'sentence_chunk': 'Test Tip Pay attention',
  'chunk_char_count': 22,
  'chunk_word_count': 4,
  'chunk_token_count': 5.5}]

Now we've broken our whole textbook into chunks of 10 sentences or less as well as the page number they came from.

In [None]:
# Get stats about our chunks
df = pd.DataFrame(pages_and_chunks)
df.describe().round(2)

Unnamed: 0,chunk_char_count,chunk_word_count,chunk_token_count
count,19960.0,19960.0,19960.0
mean,41.26,7.67,10.31
std,58.21,11.59,14.55
min,1.0,1.0,0.25
25%,16.0,3.0,4.0
50%,41.0,7.0,10.25
75%,60.0,11.0,15.0
max,2346.0,438.0,586.5


Here,because we foung that Chunks that are too short (token count ≤ 30) may lack sufficient contextual information, resulting in embeddings generated that are not meaningful enough.So,we selcet token_length more than 30.

#### **Select token_length >30**

In [None]:
# Show random chunks with under 30 tokens in length
min_token_length = 30
for row in df[df["chunk_token_count"] <= min_token_length].sample(5).iterrows():
    print(f'Chunk token count: {row[1]["chunk_token_count"]} | Text: {row[1]["sentence_chunk"]}')

Chunk token count: 1.75 | Text: Writing
Chunk token count: 3.25 | Text: party starter
Chunk token count: 2.75 | Text: attach it):
Chunk token count: 5.75 | Text: Choose TWO letters. A-E


Hmm looks like some of our chunks have quite a low token count.

How about we check for samples with less than 30 tokens (about the length of a sentence) and see if they are worth keeping?

In [None]:
pages_and_chunks_over_min_token_len = df[df["chunk_token_count"] > min_token_length].to_dict(orient="records")
pages_and_chunks_over_min_token_len[:2]

[{'page_number': None,
  'sentence_chunk': '[fudging from] the complexity of the material that has been collected from different parts of the landscape \r and brought to the site, they | the people] must have had an elementary knowledge of chemistry to be able to \r combine these materials to produce ibis form. Its not a straightforward process,™ said Henshilwood. \r \r \r 1 *2 Scanning involves searching a text quickly for a specific piece \r of information. Practise scanning the passage for the words/ \r numbers in the box. \r \r \r 75,000 100,000 200,000 artefacts ochre \r \r \r 48 \r \r \r \r \r \r \r \r \r Reading skills \r \r \r 2 Using words from the passage \r \r Their are several types of question that ask you to write a word and/or \r number from the passage. \r \r * You will be told the maximum number of words to write. \r \r * You must only write words that are in the passage. Make sure you \r copy the spelling correctly, \r \r 1 ^ ^ need to change the words in the passage 

#### **Embedding our text chunks**
Our goal is to turn each of our chunks into a numerical representation (an embedding vector, where a vector is a sequence of numbers arranged in order).

In [None]:
 !pip install sentence-transformers



In [None]:
# !pip install --upgrade --force-reinstall torchvision torchaudio torchtext torch

In [None]:
embedding_model = SentenceTransformer(model_name_or_path="all-mpnet-base-v2",device="cpu") # choose the device to load the model to (note: GPU will often be *much* faster than CPU)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


How about we add an embedding field to each of our chunk items in Single processing?

In [None]:
%%time

# Send the model to the GPU
embedding_model.to("cuda") # requires a GPU installed, for reference on my local machine, I'm using a NVIDIA RTX 4090

# Create embeddings one by one on the GPU
for item in tqdm(pages_and_chunks_over_min_token_len):
    item["embedding"] = embedding_model.encode(item["sentence_chunk"])

100%|██████████| 30/30 [00:01<00:00, 24.88it/s]

CPU times: user 1.83 s, sys: 281 ms, total: 2.12 s
Wall time: 1.45 s





How about batch processing?

In [None]:
# Turn text chunks into a single list
text_chunks = [item["sentence_chunk"] for item in pages_and_chunks_over_min_token_len]

In [None]:
%%time

# Embed all texts in batches
text_chunk_embeddings = embedding_model.encode(text_chunks,
                                               batch_size=32, # you can use different batch sizes here for speed/performance, I found 32 works well for this use case
                                               convert_to_tensor=True) # optional to return embeddings as tensor instead of array

text_chunk_embeddings

CPU times: user 465 ms, sys: 7.39 ms, total: 472 ms
Wall time: 395 ms


tensor([[ 0.0071, -0.0755, -0.0205,  ...,  0.0258, -0.0396,  0.0141],
        [ 0.0259, -0.0702, -0.0171,  ...,  0.0283, -0.0479, -0.0148],
        [ 0.0564, -0.0397, -0.0207,  ...,  0.0192, -0.0396, -0.0039],
        ...,
        [ 0.0109,  0.0319, -0.0289,  ...,  0.0763,  0.0237, -0.0272],
        [ 0.0297, -0.0098, -0.0201,  ...,  0.0699,  0.0285, -0.0238],
        [ 0.0270, -0.0286,  0.0103,  ...,  0.0457, -0.0323, -0.0239]],
       device='cuda:0')

#### **Save embeddings to file**


In [None]:
# Save embeddings to file
text_chunks_and_embeddings_df = pd.DataFrame(pages_and_chunks_over_min_token_len)
embeddings_df_save_path = "text_chunks_and_embeddings_df.csv"
text_chunks_and_embeddings_df.to_csv(embeddings_df_save_path, index=False)

In [None]:
# Import saved file and view
text_chunks_and_embedding_df_load = pd.read_csv(embeddings_df_save_path)
text_chunks_and_embedding_df_load.head()

Unnamed: 0,page_number,sentence_chunk,chunk_char_count,chunk_word_count,chunk_token_count,embedding
0,,[fudging from] the complexity of the material ...,1381,293,345.25,[ 7.10453186e-03 -7.55081177e-02 -2.05419790e-...
1,,You do not \r need to write full sentences or ...,1113,220,278.25,[ 2.58541796e-02 -7.01962784e-02 -1.70616377e-...
2,,49 \r \r \r \r \r \r \r \r \r \r \r \r \r \r \...,722,153,180.5,[ 5.64004555e-02 -3.96810472e-02 -2.07439456e-...
3,,"1 For Question 4, which word/s in the passage ...",1628,302,407.0,[ 6.36236519e-02 -6.75108954e-02 -3.08494326e-...
4,,50 \r \r \r \r \r \r \r \r \r Reading skills \...,993,204,248.25,[-1.61707476e-02 -7.01474622e-02 -4.12495732e-...


#### **Chunking and embedding questions**

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Import texts and embedding df
text_chunks_and_embedding_df = pd.read_csv("text_chunks_and_embeddings_df.csv")

# Convert embedding column back to np.array (it got converted to string when it got saved to CSV)
text_chunks_and_embedding_df["embedding"] = text_chunks_and_embedding_df["embedding"].apply(lambda x: np.fromstring(x.strip("[]"), sep=" "))

# Convert texts and embedding df to list of dicts
pages_and_chunks = text_chunks_and_embedding_df.to_dict(orient="records")

# Convert embeddings to torch tensor and send to device (note: NumPy arrays are float64, torch tensors are float32 by default)
embeddings = torch.tensor(np.array(text_chunks_and_embedding_df["embedding"].tolist()), dtype=torch.float32).to(device)
embeddings.shape

torch.Size([30, 768])

####**Similarity search**

In [None]:
text_chunks_and_embedding_df.head()

Unnamed: 0,page_number,sentence_chunk,chunk_char_count,chunk_word_count,chunk_token_count,embedding
0,,[fudging from] the complexity of the material ...,1381,293,345.25,"[0.00710453186, -0.0755081177, -0.020541979, 0..."
1,,You do not \r need to write full sentences or ...,1113,220,278.25,"[0.0258541796, -0.0701962784, -0.0170616377, 0..."
2,,49 \r \r \r \r \r \r \r \r \r \r \r \r \r \r \...,722,153,180.5,"[0.0564004555, -0.0396810472, -0.0207439456, 0..."
3,,"1 For Question 4, which word/s in the passage ...",1628,302,407.0,"[0.0636236519, -0.0675108954, -0.0308494326, 0..."
4,,50 \r \r \r \r \r \r \r \r \r Reading skills \...,993,204,248.25,"[-0.0161707476, -0.0701474622, -0.0412495732, ..."


In [None]:
embeddings[0]

tensor([ 7.1045e-03, -7.5508e-02, -2.0542e-02,  5.3325e-02, -7.2728e-02,
        -2.5395e-02,  1.0261e-02,  5.3646e-02, -1.3570e-02,  4.7879e-03,
         6.0068e-02, -1.1683e-02,  8.1880e-02,  1.5494e-02, -4.0406e-02,
        -1.0858e-02,  4.4307e-02, -2.1993e-02, -2.7438e-02,  2.4333e-02,
        -2.9353e-02,  2.8078e-02, -8.1387e-03, -5.5367e-02, -2.4519e-02,
         1.2610e-02, -3.0707e-02, -3.1068e-02,  1.2954e-02, -6.4856e-02,
        -1.5680e-02,  3.6647e-02, -3.9773e-02, -1.8977e-02,  2.2121e-06,
        -5.6490e-02, -2.3480e-02,  1.3385e-02, -5.0297e-02,  2.9676e-02,
         7.0152e-02,  5.9260e-02,  4.5085e-02, -9.0790e-03, -2.8074e-03,
         4.1529e-03,  2.0545e-02,  5.0574e-02,  3.6151e-02,  1.8902e-02,
         1.0677e-02, -1.4560e-02,  4.3371e-02, -1.5906e-02,  9.5423e-02,
         9.4999e-03,  6.2746e-03,  1.3461e-02,  5.6207e-02,  1.2438e-01,
        -2.8209e-02,  3.6967e-02, -1.4678e-02,  1.2529e-02,  1.0953e-02,
        -1.3343e-02,  3.8065e-02, -8.5211e-02,  5.8

## **Time to perform a semantic search**

In [None]:
from sentence_transformers import util, SentenceTransformer

embedding_model = SentenceTransformer(model_name_or_path="all-mpnet-base-v2",
                                      device=device) # choose the device to load the model to

In [None]:
# 1. Define the query
# Note: This could be anything. But since we're working with a nutrition textbook, we'll stick with nutrition-based queries.
query = "reading skill"
print(f"Query: {query}")

# 2. Embed the query to the same numerical space as the text examples
# Note: It's important to embed your query with the same model you embedded your examples with.
query_embedding = embedding_model.encode(query, convert_to_tensor=True)

# 3. Get similarity scores with the dot product (we'll time this for fun)
from time import perf_counter as timer

start_time = timer()
dot_scores = util.dot_score(a=query_embedding, b=embeddings)[0]
end_time = timer()

print(f"Time take to get scores on {len(embeddings)} embeddings: {end_time-start_time:.5f} seconds.")

# 4. Get the top-k results (we'll keep this to 5)
top_results_dot_product = torch.topk(dot_scores, k=5)
top_results_dot_product

Query: reading skill
Time take to get scores on 30 embeddings: 0.00833 seconds.


torch.return_types.topk(
values=tensor([0.5677, 0.5357, 0.5271, 0.5006, 0.4489], device='cuda:0'),
indices=tensor([23,  2,  4, 17,  5], device='cuda:0'))

In [None]:
# Define helper function to print wrapped text
import textwrap

def print_wrapped(text, wrap_length=80):
    wrapped_text = textwrap.fill(text, wrap_length)
    print(wrapped_text)

Show the result!!

In [None]:
print(f"Query: '{query}'\n")
print("Results:")
# Loop through zipped together scores and indicies from torch.topk
for score, idx in zip(top_results_dot_product[0], top_results_dot_product[1]):
    print(f"Score: {score:.4f}")
    # Print relevant sentence chunk (since the scores are in descending order, the most relevant chunk will be first)
    print("Text:")
    print_wrapped(pages_and_chunks[idx]["sentence_chunk"])
    # Print the page number too so we can reference the textbook further (and check the results)
    print(f"Page number: {pages_and_chunks[idx]['page_number']}")
    print("\n")

Query: 'reading skill'

Results:
Score: 0.5677
Text:
61                                       Reading skills       2.2 Look at tliis
task based on (he Reading passage. For each     question, underline the type of
information you need to scan Tor.   The first two have been done for you.
Which paragraph contains the following information?     N. B You may use any
letter more than once     Write the correct letter. A-E, next to questions 1-7
below,     1 visual evidence of the gecko's ability to resist water     2 a
question that is yet to be answered by the researchers     3 the method used to
calculate the gripping power of geckos     4 the researcher's opinion of the
gecko’s gripping ability     5 a mention of the different environments where
geckos can be found     6 the contrast between Stark's research and the work of
other researchers     7 the definition of a scientific term       2.3 It is
important to fully understand what you are looking for in   the passage. Answer
these quest

### Functionizing our semantic search pipeline

In [None]:
def retrieve_relevant_resources(query: str,
                                embeddings: torch.tensor,
                                model: SentenceTransformer=embedding_model,
                                n_resources_to_return: int=5,
                                print_time: bool=True):
    """
    Embeds a query with model and returns top k scores and indices from embeddings.
    """

    # Embed the query
    query_embedding = model.encode(query,
                                   convert_to_tensor=True)

    # Get dot product scores on embeddings
    start_time = timer()
    dot_scores = util.dot_score(query_embedding, embeddings)[0]
    end_time = timer()

    if print_time:
        print(f"[INFO] Time taken to get scores on {len(embeddings)} embeddings: {end_time-start_time:.5f} seconds.")

    scores, indices = torch.topk(input=dot_scores,
                                 k=n_resources_to_return)

    return scores, indices

def print_top_results_and_scores(query: str,
                                 embeddings: torch.tensor,
                                 pages_and_chunks: list[dict]=pages_and_chunks,
                                 n_resources_to_return: int=5):
    """
    Takes a query, retrieves most relevant resources and prints them out in descending order.

    Note: Requires pages_and_chunks to be formatted in a specific way (see above for reference).
    """

    scores, indices = retrieve_relevant_resources(query=query,
                                                  embeddings=embeddings,
                                                  n_resources_to_return=n_resources_to_return)

    print(f"Query: {query}\n")
    print("Results:")
    # Loop through zipped together scores and indicies
    for score, index in zip(scores, indices):
        print(f"Score: {score:.4f}")
        # Print relevant sentence chunk (since the scores are in descending order, the most relevant chunk will be first)
        print_wrapped(pages_and_chunks[index]["sentence_chunk"])
        # Print the page number too so we can reference the textbook further and check the results
        print(f"Page number: {pages_and_chunks[index]['page_number']}")
        print("\n")

In [None]:
query = "listerning skills"

# Get just the scores and indices of top related results
scores, indices = retrieve_relevant_resources(query=query,
                                              embeddings=embeddings)
scores, indices

[INFO] Time taken to get scores on 30 embeddings: 0.00007 seconds.


(tensor([0.3542, 0.3339, 0.3295, 0.2792, 0.2771], device='cuda:0'),
 tensor([ 4, 23,  2,  0,  5], device='cuda:0'))

##**Loading the LLM**

In [None]:
!pip install bitsandbytes accelerate



In [None]:
from huggingface_hub import login
login("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")

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.utils import is_flash_attn_2_available

model_id = "google/gemma-2b-it"
use_quantization_config = False

from transformers import BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(load_in_4bit=True,
                                         bnb_4bit_compute_dtype=torch.float16)


if (is_flash_attn_2_available()) and (torch.cuda.get_device_capability(0)[0] >= 8):
  attn_implementation = "flash_attention_2"
else:
  attn_implementation = "sdpa"
print(f"[INFO] Using attention implementation: {attn_implementation}")

# 2. Pick a model we'd like to use (this will depend on how much GPU memory you have available)
#model_id = "google/gemma-7b-it"
model_id = model_id # (we already set this above)
print(f"[INFO] Using model_id: {model_id}")

# 3. Instantiate tokenizer (tokenizer turns text into numbers ready for the model)
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_id)

# 4. Instantiate the model
llm_model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=model_id,
                                                 torch_dtype=torch.float16, # datatype to use, we want float16
                                                 quantization_config=quantization_config if use_quantization_config else None,
                                                 low_cpu_mem_usage=False, # use full memory
                                                 attn_implementation=attn_implementation) # which attention version to use

if not use_quantization_config: # quantization takes care of device setting automatically, so if it's not used, send model to GPU
    llm_model.to("cuda")

[INFO] Using attention implementation: sdpa
[INFO] Using model_id: google/gemma-2b-it


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

We've got an LLM!

Let's check it out.

In [None]:
llm_model

GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 2048, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x GemmaDecoderLayer(
        (self_attn): GemmaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
          (act_fn): GELUActivation()
        )
        (input_layernorm): GemmaRMSNorm((2048,), eps=1e-06)
        (post_attention_layernorm): GemmaRMSNorm((2048,), eps=1e-06)
      )
    )
    (norm): GemmaRMSNorm((2048,), 

How about we get the number of parameters in our model?

In [None]:
def get_model_num_params(model: torch.nn.Module):
    return sum([param.numel() for param in model.parameters()])

get_model_num_params(llm_model)

2506172416

####**Generating text with our LLM(gemma)**

In [None]:
input_text = "how can I improve speaking skills"
print(f"Input text:\n{input_text}")

# Create prompt template for instruction-tuned model
dialogue_template = [
    {"role": "user",
     "content": input_text}
]

# Apply the chat template
prompt = tokenizer.apply_chat_template(conversation=dialogue_template,
                                       tokenize=False, # keep as raw text (not tokenized)
                                       add_generation_prompt=True)
print(f"\nPrompt (formatted):\n{prompt}")

Input text:
how can I improve speaking skills

Prompt (formatted):
<bos><start_of_turn>user
how can I improve speaking skills<end_of_turn>
<start_of_turn>model



In [None]:
%%time

# Tokenize the input text (turn it into numbers) and send it to GPU
input_ids = tokenizer(prompt, return_tensors="pt").to("cuda")
print(f"Model input (tokenized):\n{input_ids}\n")

# Generate outputs passed on the tokenized input
outputs = llm_model.generate(**input_ids,
                             max_new_tokens=256) # define the maximum number of new tokens to create
print(f"Model output (tokens):\n{outputs[0]}\n")

Model input (tokenized):
{'input_ids': tensor([[    2,     2,   106,  1645,   108,  1139,   798,   590,  4771, 13041,
          7841,   107,   108,   106,  2516,   108]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')}

Model output (tokens):
tensor([     2,      2,    106,   1645,    108,   1139,    798,    590,   4771,
         13041,   7841,    107,    108,    106,   2516,    108,    688, 235274,
        235265,  19670, 186522,  66058,    108, 235290, 100922,    575,  30893,
           675,  11634,  22660,    689,   6016,    675,    476,   5255,   9670,
        235265,    108, 235290,  20470,    476,   5255,  10036,   2778,    689,
          3650,  16875, 235265,    108, 235290,  15940,   5804,  13041,    578,
         10724,   1355,    577,  11441,   4516,    604,  13194, 235265,    109,
           688, 235284, 235265,  26349,    611,  58212,  42535,  66058,    108,
        235290,   8138,   6137,    577,    573,   97

In [None]:
# Decode the output tokens to text
outputs_decoded = tokenizer.decode(outputs[0])
print(f"Model output (decoded):\n{outputs_decoded}\n")

Model output (decoded):
<bos><bos><start_of_turn>user
how can I improve speaking skills<end_of_turn>
<start_of_turn>model
**1. Practice Regularly:**
- Engage in conversations with native speakers or practice with a language partner.
- Join a language exchange group or online forum.
- Record yourself speaking and listen back to identify areas for improvement.

**2. Focus on Pronunciation:**
- Pay attention to the sounds of the language, including intonation, rhythm, and stress.
- Use pronunciation tools and recordings to learn correct pronunciation.
- Practice speaking words and phrases out loud, focusing on different accents.

**3. Expand Your Vocabulary:**
- Read extensively in the target language, both fiction and non-fiction.
- Listen to podcasts, audiobooks, and music in the language.
- Use flashcards and spaced repetition techniques to learn new words.

**4. Read Fluently:**
- Start with children's books or simple texts and gradually progress to more complex materials.
- Pay atten

In [None]:
print(f"Input text: {input_text}\n")
print(f"Output text:\n{outputs_decoded.replace(prompt, '').replace('<bos>', '').replace('<eos>', '')}")

Input text: how can I improve speaking skills

Output text:
**1. Practice Regularly:**
- Engage in conversations with native speakers or practice with a language partner.
- Join a language exchange group or online forum.
- Record yourself speaking and listen back to identify areas for improvement.

**2. Focus on Pronunciation:**
- Pay attention to the sounds of the language, including intonation, rhythm, and stress.
- Use pronunciation tools and recordings to learn correct pronunciation.
- Practice speaking words and phrases out loud, focusing on different accents.

**3. Expand Your Vocabulary:**
- Read extensively in the target language, both fiction and non-fiction.
- Listen to podcasts, audiobooks, and music in the language.
- Use flashcards and spaced repetition techniques to learn new words.

**4. Read Fluently:**
- Start with children's books or simple texts and gradually progress to more complex materials.
- Pay attention to grammar, punctuation, and sentence structure.
- Use a 

Augmentation.

In [None]:
# IELTS-style questions generated with GPT-4
gpt4_questions = [
    "What are the best strategies to improve your IELTS speaking score?",
    "How can you effectively manage time during the IELTS reading test?",
    "Describe techniques to write a high-scoring IELTS essay.",
    "What role does vocabulary play in the IELTS listening section?",
    "Explain the importance of practicing mock tests for the IELTS exam.",
    "How can you build fluency for the IELTS speaking test?",
]

# Manually created question list
manual_questions = [
    "What are common mistakes to avoid in the IELTS writing task?",
    "How should you prepare for the IELTS listening section?",
    "What is the ideal structure for an IELTS Task 2 essay?",
    "What are the key differences between IELTS Academic and General Training?",
    "How can you improve your band score in the IELTS reading test?",
]

# Combine GPT-4 generated and manually created questions
query_list = gpt4_questions + manual_questions

And now let's check if our `retrieve_relevant_resources()` function works with our list of queries.

In [None]:
import random
query = random.choice(query_list)

print(f"Query: {query}")

# Get just the scores and indices of top related results
scores, indices = retrieve_relevant_resources(query=query,
                                              embeddings=embeddings)
scores, indices

Query: What are common mistakes to avoid in the IELTS writing task?
[INFO] Time taken to get scores on 30 embeddings: 0.00009 seconds.


(tensor([0.4694, 0.4677, 0.4606, 0.4605, 0.4586], device='cuda:0'),
 tensor([12, 23,  4, 17,  0], device='cuda:0'))

####**Augmenting our prompt with context items**

In [None]:
def prompt_formatter(query: str, context_items: list[dict]) -> str:
    """
    Augments query with text-based context from context_items.
    """
    # Join context items into a single paragraph
    context = "\n\n".join([item["sentence_chunk"] for item in context_items])

    # Create the improved base prompt
    base_prompt = f"""Based on the following context items, provide the most helpful and detailed answer to the query below.
    If you cannot find relevant information in the provided context, use your general knowledge and logical reasoning to generate a well-informed, accurate, and practical answer.
    Ensure that your response remains factual, logical, and does not speculate beyond reasonable assumptions.

    Before generating the final answer, show your thought process step by step. These steps should include:
    1. Identifying relevant information from the provided context (if available).
    2. Explaining how the context or your reasoning is applied to answer the query.
    3. Highlighting any assumptions made if the context is insufficient.

    Finally, provide your answer in a clear and concise manner. Use the following examples as a reference for the ideal answer style. Your answer should not include the examples themselves, only follow their structure and tone.

    Example 1:
    Query: What are the best strategies to improve your IELTS speaking score?
    Answer: To improve your IELTS speaking score, focus on fluency and coherence by practicing speaking with friends or recording yourself and listening for areas of improvement. Expand your vocabulary by learning phrases and idioms relevant to common IELTS topics, such as education, environment, and technology. Additionally, practice answering past IELTS speaking questions under timed conditions to simulate the test environment.

    Example 2:
    Query: How can you effectively manage time during the IELTS reading test?
    Answer: To manage time effectively during the IELTS reading test, start by quickly skimming the passage to get a general idea of its content. Then, read the questions and underline key information. Divide your time equally across the three sections, spending no more than 20 minutes per section. If you encounter difficult questions, move on and return to them later if time permits.

    Example 3:
    Query: What is the ideal structure for an IELTS Writing Task 2 essay?
    Answer: A high-scoring IELTS Writing Task 2 essay should include an introduction that clearly states your position, two or three body paragraphs with arguments supported by examples, and a conclusion that summarizes your key points. Ensure coherence and cohesion by using linking words such as "however," "therefore," and "in addition." Also, proofread your essay to avoid grammatical mistakes and spelling errors.

    Context:
    {context}

    Query: {query}

    Explain your thought process step by step:
    1. ...
    2. ...
    3. ...

    Final Answer:"""

    # Update the base prompt with context items and query
    dialogue_template = [
        {"role": "user", "content": base_prompt}
    ]

    # Generate the prompt
    prompt = tokenizer.apply_chat_template(conversation=dialogue_template,
                                          tokenize=False,
                                          add_generation_prompt=True)
    return prompt

Let's try our function out.

In [None]:
query = random.choice(query_list)
print(f"Query: {query}")

# Get relevant resources
scores, indices = retrieve_relevant_resources(query=query,
                                              embeddings=embeddings)

# Create a list of context items
context_items = [pages_and_chunks[i] for i in indices]

# Format prompt with context items
prompt = prompt_formatter(query=query,
                          context_items=context_items)
print(prompt)

Query: What are the key differences between IELTS Academic and General Training?
[INFO] Time taken to get scores on 30 embeddings: 0.00008 seconds.
<bos><start_of_turn>user
Based on the following context items, provide the most helpful and detailed answer to the query below.
    If you cannot find relevant information in the provided context, use your general knowledge and logical reasoning to generate a well-informed, accurate, and practical answer.
    Ensure that your response remains factual, logical, and does not speculate beyond reasonable assumptions.

    Before generating the final answer, show your thought process step by step. These steps should include:
    1. Identifying relevant information from the provided context (if available).
    2. Explaining how the context or your reasoning is applied to answer the query.
    3. Highlighting any assumptions made if the context is insufficient.

    Finally, provide your answer in a clear and concise manner. Use the following exam

We can tokenize this and pass it straight to our LLM.

In [None]:
%%time

input_ids = tokenizer(prompt, return_tensors="pt").to("cuda")

# Generate an output of tokens
outputs = llm_model.generate(**input_ids,
                             temperature=0.7, # lower temperature = more deterministic outputs, higher temperature = more creative outputs
                             do_sample=True,
                             max_new_tokens=256) # how many new tokens to generate from prompt

# Turn the output tokens into text
output_text = tokenizer.decode(outputs[0])

print(f"Query: {query}")
print(f"RAG answer:\n{output_text.replace(prompt, '')}")

Query: What are the key differences between IELTS Academic and General Training?
RAG answer:
<bos>**Thought Process:**

**1. Identifying Relevant Information**

* The passage does not provide any directly relevant information about IELTS Academic and General Training, so I cannot identify any key differences between the two programs from the context.

**2. Explaining Reasoning**

I am unable to generate a response because the context does not provide any information about the key differences between IELTS Academic and General Training.

**3. Assumptions Made**

The context does not provide any assumptions, so I cannot generate a response.<eos>
CPU times: user 3.31 s, sys: 5.68 ms, total: 3.31 s
Wall time: 3.31 s


How about we functionize the generation step to make it easier to use?

In [None]:
def ask(query,
        temperature=0.7,
        max_new_tokens=512,
        format_answer_text=True,
        return_answer_only=True):
    """
    Takes a query, finds relevant resources/context and generates an answer to the query based on the relevant resources.
    """

    # Get just the scores and indices of top related results
    scores, indices = retrieve_relevant_resources(query=query,
                                                  embeddings=embeddings)

    # Create a list of context items
    context_items = [pages_and_chunks[i] for i in indices]

    # Add score to context item
    for i, item in enumerate(context_items):
        item["score"] = scores[i].cpu() # return score back to CPU

    # Format the prompt with context items
    prompt = prompt_formatter(query=query,
                              context_items=context_items)

    # Tokenize the prompt
    input_ids = tokenizer(prompt, return_tensors="pt").to("cuda")

    # Generate an output of tokens
    outputs = llm_model.generate(**input_ids,
                                 temperature=temperature,
                                 do_sample=True,
                                 max_new_tokens=max_new_tokens)

    # Turn the output tokens into text
    output_text = tokenizer.decode(outputs[0])

    if format_answer_text:
        # Replace special tokens and unnecessary help message
        output_text = output_text.replace(prompt, "").replace("<bos>", "").replace("<eos>", "").replace("Sure, here is the answer to the user query:\n\n", "")

    # Only return the answer without the context items
    if return_answer_only:
        return output_text

    return output_text, context_items

Let's try it out.

In [None]:
# random.choice(query_list)

In [None]:
query = random.choice(query_list)
print(f"Query: {query}")

# Answer query with context and return context
answer, context_items = ask(query=query,
                            temperature=0.7,
                            max_new_tokens=512,
                            return_answer_only=False)

print(f"Answer:\n")
print_wrapped(answer)
# print(f"Context items:")
#context_items

Query: What are the best strategies to improve your IELTS speaking score?
[INFO] Time taken to get scores on 30 embeddings: 0.00007 seconds.
Answer:

## Thought process:  **Step 1: Identifying relevant information**  * The context
mentions that improving your IELTS speaking score requires practicing speaking
with friends or recording yourself and listening for areas of improvement. * It
also suggests learning phrases and idioms relevant to common IELTS topics. *
These suggest that practicing speaking in a social setting, learning vocabulary,
and being familiar with idiomatic expressions are key strategies for improving
speaking skills.  **Step 2: Applying the context**  The context advises
practicing speaking with friends, recording yourself, and listening for areas of
improvement. It also suggests learning vocabulary and idioms relevant to common
IELTS topics.  **Step 3: Assumptions**  * The context does not provide any
specific information or guidelines for practicing speaking in a s

####**Text to speech model**

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM

# Using the gemma-2b-it model
model_id = "google/gemma-2b-it"

# Loading the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Set chat_template correctly
tokenizer.chat_template = "{% for message in messages %}{{ '<|im_start|>' + message['role'] + '\\n' + message['content'] + '\\n' }}{% endfor %}"

# Loading the model
llm_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)

# Move the model to the GPU (if available)
llm_model.to("cuda")

print("The model and tokenizer were successfully loaded, and the chat_template was set!")


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

模型和 tokenizer 已成功加载，并设置了 chat_template！


In [None]:
!pip install transformers==4.51.3 accelerate==1.6.0 --no-warn-script-location --quiet

In [None]:
#!pip install --upgrade transformers accelerate


In [None]:
from transformers import VitsModel, AutoTokenizer
import torch

model = VitsModel.from_pretrained("facebook/mms-tts-eng")
tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")

text = answer
inputs = tokenizer(text, return_tensors="pt")

with torch.no_grad():
    output = model(**inputs).waveform


In [None]:
from IPython.display import Audio

Audio(output.numpy(), rate=model.config.sampling_rate)

In [None]:
from scipy.io.wavfile import write
from google.colab import files
import numpy as np

# Assume `output` is the generated audio data (Tensor) with sampling rate `model.config.sampling_rate`
sampling_rate = model.config.sampling_rate
output_filename = "generated_audio.wav"

# 1. Convert audio data to numpy array
audio_data = output.numpy()  # 转换为 numpy 数组

# 2. Check the range of the audio data and normalize it to [-1.0, 1.0]
# If the data range is not [-1.0, 1.0], you need to normalize it first
if audio_data.min() < -1.0 or audio_data.max() > 1.0:
    audio_data = audio_data / np.max(np.abs(audio_data))  # Normalized to [-1.0, 1.0]

# 3. Make sure the audio data is a one-dimensional array (mono)
if len(audio_data.shape) > 1:
    audio_data = audio_data.squeeze()  # Remove redundant dimensions

# 4. Convert data from [-1.0, 1.0] to int16 range [-32768, 32767]
audio_data = (audio_data * 32767).astype(np.int16)

# 5. Save audio to .wav file
write(output_filename, sampling_rate, audio_data)

# 6. Download the audio file to your local computer
files.download(output_filename)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
print(tokenizer)  # Check if the current tokenizer is correct
print(hasattr(tokenizer, "chat_template"))  # Check if chat_template is set

VitsTokenizer(name_or_path='facebook/mms-tts-eng', vocab_size=38, model_max_length=1000000000000000019884624838656, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '<unk>', 'pad_token': 'k'}, clean_up_tokenization_spaces=True, added_tokens_decoder={
	0: AddedToken("k", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	38: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
)
True


####**SoVITS model**

In [None]:
%%writefile /content/setup.sh
set -e
cd /content
rm -rf GPT-SoVITS
git clone https://github.com/RVC-Boss/GPT-SoVITS.git
cd GPT-SoVITS

if conda env list | awk '{print $1}' | grep -Fxq "GPTSoVITS"; then
    :
else
    conda create -n GPTSoVITS python=3.10 -y
fi

source activate GPTSoVITS

bash install.sh --source HF --download-uvr5

In [None]:
%pip install -q condacolab
import condacolab
condacolab.install_from_url("https://repo.anaconda.com/archive/Anaconda3-2024.10-1-Linux-x86_64.sh")
!cd /content && bash setup.sh

In [None]:
!cd /content/GPT-SoVITS && source activate GPTSoVITS && export is_share=True && python webui.py

##**Gradio with writing evluating system+ tts + system workflow**

In [None]:
!pip install --upgrade gradio



In [None]:
pip install gradio requests



In [None]:
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, VitsModel, AutoTokenizer as TTSTokenizer
from sentence_transformers import SentenceTransformer, util
import numpy as np
import time

# Add language detection function
def detect_language(text):
    """Detect whether the text is Chinese or English"""
    # Simple judgment: if it contains Chinese characters, it is considered Chinese
    for char in text:
        if '\u4e00' <= char <= '\u9fff':
            return "zh"
    return "en"

# 1. Load existing models and data
device = "cuda" if torch.cuda.is_available() else "cpu"
embedding_model = SentenceTransformer("all-mpnet-base-v2").to(device)

# Simple user data storage
user_progress = {}

# Load your text data and embeddings
try:
    import pandas as pd
    text_chunks_and_embedding_df = pd.read_csv("text_chunks_and_embeddings_df.csv")
    # Transform Embed
    text_chunks_and_embedding_df["embedding"] = text_chunks_and_embedding_df["embedding"].apply(
        lambda x: np.fromstring(x.strip("[]"), sep=" "))
    pages_and_chunks = text_chunks_and_embedding_df.to_dict(orient="records")
    embeddings = torch.tensor(np.array(text_chunks_and_embedding_df["embedding"].tolist()),
                              dtype=torch.float32).to(device)
    print("✅ Successfully loaded data")
except Exception as e:
    print(f"❌ Failed to load data: {e}")
    # If the data loading fails, you can provide some sample data
    pages_and_chunks = []
    embeddings = None

# 2. Load LLM
model_id = "google/gemma-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
llm_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto"
)

# Add TTS model (lazy loading to save memory)
tts_model = None
tts_tokenizer = None

def load_tts_model():
    global tts_model, tts_tokenizer
    if tts_model is None:
        print("Loading TTS model...")
        tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
        tts_tokenizer = TTSTokenizer.from_pretrained("facebook/mms-tts-eng")
        print("✅ TTS model loaded")
    return tts_model, tts_tokenizer

# 3. Define retrieval and generation functions
def retrieve_relevant_resources(query, embeddings, n_resources=5):
    """Search for related resources"""
    query_embedding = embedding_model.encode(query, convert_to_tensor=True)
    dot_scores = util.dot_score(query_embedding, embeddings)[0]
    scores, indices = torch.topk(dot_scores, k=n_resources)
    return scores, indices

def generate_answer(query, context_items, avg_relevance_score=0.0):
    """Generate answers, determine the answer language based on the query language, and decide whether to use the search content based on relevance"""
    # Set a relevance threshold below which LLM knowledge is used instead of search content
    relevance_threshold = 0.65

    # All output uses English by default
    language = "en"

    # Determine the prompt content based on relevance
    if avg_relevance_score < relevance_threshold:
        # Low relevance, let the model use its own knowledge
        prompt = f"""Here is a question about the IELTS exam. Since no sufficiently relevant reference materials were found, please use your own knowledge to answer.

Question: {query}

Answer:"""
    else:
        # Highly relevant, using the retrieved content
        context = "\n\n".join([item["sentence_chunk"] for item in context_items])

        prompt = f"""Based on the following IELTS materials, answer the question:

Content:
{context}

Question: {query}

Answer:"""

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = llm_model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.7,
            do_sample=True
        )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True).replace(prompt, "")
    return response, language

def evaluate_response_quality(query, response, relevance_score):
    """Evaluate the quality of the system's answers"""
    # Evaluate based on relevance and heuristic rules
    coherence_score = min(1.0, 0.5 + relevance_score * 0.5)  # Simple heuristic rules

    # Generate English evaluation report
    if relevance_score > 0.8:
        relevance_comment = "Highly Relevant - The answer directly addresses the core of the question"
    elif relevance_score > 0.6:
        relevance_comment = "Relevant - The answer covers main aspects of the question"
    else:
        relevance_comment = "Partially Relevant - The answer only partially addresses the question"

    # Text length evaluation
    if len(response) < 50:
        length_comment = "Too Short - The answer may not be comprehensive"
    elif len(response) > 500:
        length_comment = "Extensive - The answer is very comprehensive"
    else:
        length_comment = "Adequate - The answer length is reasonable"

    # Portfolio Assessment Report
    report = f"""
### Answer Quality Assessment

**Relevance**: {relevance_score:.2f}/1.0 - {relevance_comment}
**Coherence**: {coherence_score:.2f}/1.0
**Length**: {len(response)} characters - {length_comment}

**Overall Rating**: {(relevance_score + coherence_score) / 2:.2f}/1.0
    """

    return report

def process_query(query):
    """Processes the query and returns results and performance indicators (improved version with relevance evaluation)"""
    # Check if data has been loaded
    if embeddings is None:
        return "Data not loaded, please run the data processing code first", "", "Performance metrics unavailable: data not loaded"

    # Recording start time
    start_time = time.time()

    # Get relevant resources
    retrieval_start = time.time()
    scores, indices = retrieve_relevant_resources(query, embeddings)
    retrieval_time = time.time() - retrieval_start

    context_items = [pages_and_chunks[i] for i in indices]

    # Calculate the average relevance score
    avg_relevance = scores.mean().item()

    # Record generation start time
    generation_start = time.time()

    # Generate Answer (now passes relevance score)
    answer, detected_language = generate_answer(query, context_items, avg_relevance)

    # Calculate generation time
    generation_time = time.time() - generation_start
    total_time = time.time() - start_time

    # Preparing context information for display
    context_display = ""
    for i, (score, idx) in enumerate(zip(scores, indices)):
        context_display += f"**Reference {i+1}** (Relevance: {score:.2f}):\n{pages_and_chunks[idx]['sentence_chunk'][:200]}...\n\n"

    # Quality Assessment
    quality_report = evaluate_response_quality(query, answer, avg_relevance)

    # Preparing performance indicators
    metrics = f"""
### System Performance Metrics
- **Retrieval Time**: {retrieval_time:.2f} seconds
- **Generation Time**: {generation_time:.2f} seconds
- **Total Response Time**: {total_time:.2f} seconds
- **Average Relevance Score**: {avg_relevance:.4f}/1.0
- **Response Mode**: {"Retrieved Context" if avg_relevance >= 0.65 else "LLM Knowledge"}
{quality_report}
"""

    return answer, context_display, metrics

def process_query_with_history(query, history=""):
    """Processing queries and saving history"""
    answer, context, metrics = process_query(query)

    # Update History
    timestamp = time.strftime("%H:%M:%S")
    new_history = f"{history}<hr><b>[{timestamp}] Q:</b> {query}<br><b>A:</b> {answer}<br>"

    return answer, context, metrics, new_history

def track_user_activity(username, activity_type, content, score=None):
    """Tracking user learning activities"""
    if username not in user_progress:
        user_progress[username] = {
            "queries": [],
            "writing_samples": [],
            "practice_tests": [],
            "last_active": None
        }

    timestamp = time.strftime("%Y-%m-%d %H:%M:%S")

    if activity_type == "query":
        user_progress[username]["queries"].append({
            "timestamp": timestamp,
            "query": content,
            "relevance_score": score
        })
    elif activity_type == "writing":
        user_progress[username]["writing_samples"].append({
            "timestamp": timestamp,
            "sample": content[:100] + "...",  # Save Summary
            "score": score
        })
    elif activity_type == "practice":
        user_progress[username]["practice_tests"].append({
            "timestamp": timestamp,
            "test_type": content,
            "completed": True
        })

    user_progress[username]["last_active"] = timestamp

    # Build Progress Summary
    summary = f"""
### Learning Progress Summary ({username})
- Questions asked: {len(user_progress[username]["queries"])}
- Writing samples: {len(user_progress[username]["writing_samples"])}
- Practice tests: {len(user_progress[username]["practice_tests"])}
- Last active: {user_progress[username]["last_active"]}

#### Recent Activity
"""

    # Add the last 5 events
    recent_queries = user_progress[username]["queries"][-3:] if user_progress[username]["queries"] else []
    recent_writings = user_progress[username]["writing_samples"][-2:] if user_progress[username]["writing_samples"] else []

    for q in recent_queries:
        summary += f"- [{q['timestamp']}] Question: {q['query'][:50]}...\n"

    for w in recent_writings:
        summary += f"- [{w['timestamp']}] Writing practice\n"

    return summary

def process_query_with_tracking(query, username):
    """Handle inquiries and track learning progress"""
    answer, context, metrics = process_query(query)

    # Extracting relevance scores from metrics
    import re
    relevance_match = re.search(r"Average Relevance Score: ([\d\.]+)", metrics)
    relevance_score = float(relevance_match.group(1)) if relevance_match else None

    # Track this query
    progress = track_user_activity(username, "query", query, relevance_score)

    return answer, context, metrics, progress

# Modify the TTS function to ensure that the complete content is processed
def text_to_speech(text):
    """Convert text to speech, process full English text"""
    if not text:
        return None

    # Load the TTS model (if not already loaded)
    model, tokenizer = load_tts_model()

    # If the text is too long, process it in segments and concatenate them
    max_segment_length = 500  # Maximum length of each segment
    segments = []

    # Segment text
    if len(text) > max_segment_length:
        words = text.split()
        current_segment = []
        current_length = 0

        for word in words:
            current_length += len(word) + 1  # +1 for space
            if current_length <= max_segment_length:
                current_segment.append(word)
            else:
                segments.append(" ".join(current_segment))
                current_segment = [word]
                current_length = len(word) + 1

        if current_segment:
            segments.append(" ".join(current_segment))
    else:
        segments = [text]

    # Process each paragraph and concatenate
    full_waveform = None
    sample_rate = None

    for segment in segments:
        inputs = tokenizer(segment, return_tensors="pt")
        with torch.no_grad():
            output = model(**inputs).waveform

        if full_waveform is None:
            full_waveform = output[0].numpy()
            sample_rate = model.config.sampling_rate
        else:
            # Add a short pause (0.3 seconds of silence)
            pause = np.zeros(int(0.3 * sample_rate))
            full_waveform = np.concatenate([full_waveform, pause, output[0].numpy()])

    # Return to full audio
    return (sample_rate, full_waveform)

# Text-to-speech language detection wrapper function
def tts_with_language_check(text):
    if not text:
        return None, "Please provide text content"

    language = detect_language(text)
    if language == "zh":
        return None, "⚠️ Only English text is supported for TTS. Please provide English text."
    else:
        try:
            audio = text_to_speech(text)
            return audio, "✅ Conversion successful! Full content converted to speech."
        except Exception as e:
            return None, f"❌ Error during conversion: {str(e)}"

# Added IELTS writing scoring function
def evaluate_ielts_writing(writing_sample, username="default_user"):
    """Evaluate IELTS writing samples and keep track of records"""
    if not writing_sample:
        return "Please provide a writing sample for evaluation."

    prompt = f"""As an IELTS examiner, please assess the following student writing sample.
    Provide scores and specific suggestions based on these criteria:
    1. Task Response
    2. Coherence and Cohesion
    3. Lexical Resource
    4. Grammatical Range and Accuracy

    Give scores in 0.5 increments (e.g., 6.0, 6.5) for each category, and provide an overall score.

    Student writing sample:
    {writing_sample}

    Score and detailed feedback:"""

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = llm_model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.2,
            do_sample=True
        )
    feedback = tokenizer.decode(outputs[0], skip_special_tokens=True).replace(prompt, "")

    # Extract total score
    import re
    score_match = re.search(r"Overall score.*?(\d+\.?\d*)", feedback)
    overall_score = float(score_match.group(1)) if score_match else None

    # Record writing activities
    track_user_activity(username, "writing", writing_sample, overall_score)

    return feedback

# Added mock exam feature
def generate_practice_question(section_type):
    """Generate IELTS practice questions"""
    section_type_english = section_type.split(" ")[0]  # Get the English section

    prompt = f"""Create an IELTS {section_type_english} practice question.
    Include detailed questions, guidance, and scoring criteria.
    For Writing or Speaking sections, provide a sample question and response framework.
    For Listening or Reading sections, provide sample questions and answer options."""

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = llm_model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.7,
            do_sample=True
        )
    practice = tokenizer.decode(outputs[0], skip_special_tokens=True).replace(prompt, "")
    return practice

# Add learning plan generation function
def generate_study_plan(target_score, weeks_available, strengths, weaknesses):
    """Generate a personalized IELTS study plan"""
    prompt = f"""Create a personalized IELTS study plan with the following conditions:

    Target Score: {target_score}
    Available Time: {weeks_available} weeks
    Strengths: {strengths}
    Weaknesses: {weaknesses}

    Please provide:
    1. Detailed weekly study plan
    2. Recommended learning resources
    3. Specific exercises for weaknesses
    4. Regular mock test schedule
    5. Pre-exam preparation strategy
    """

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = llm_model.generate(
            **inputs,
            max_new_tokens=1024,
            temperature=0.7,
            do_sample=True
        )
    plan = tokenizer.decode(outputs[0], skip_special_tokens=True).replace(prompt, "")
    return plan

def export_history(history):
    """Export session history as text"""
    if not history:
        return "No conversation history to export"

    try:
        from bs4 import BeautifulSoup
        import re

        # Parsing HTML and extracting text using BeautifulSoup
        soup = BeautifulSoup(history, "html.parser")
        text = soup.get_text()

        # Cleaning up the text
        cleaned_text = re.sub(r'\s+', ' ', text).strip()

        # Returns the cleaned text and timestamp
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        return f"Conversation exported. Filename: ielts_conversation_{timestamp}.txt\n\n{cleaned_text[:100]}..."
    except:
        # Simple alternate extraction
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        return f"Conversation exported. Filename: ielts_conversation_{timestamp}.txt"

# Improve system architecture diagram generation function
def generate_system_diagram():
    """Generate comprehensive system architecture diagram"""
    from PIL import Image, ImageDraw, ImageFont
    import io
    import base64
    import os

    # Create a larger image to show the full content
    width, height = 1000, 650
    image = Image.new("RGB", (width, height), "white")
    draw = ImageDraw.Draw(image)

    # Try loading a better font
    try:
        # Try common fonts, depending on the system
        font_paths = [
            "arial.ttf",
            "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
            "/System/Library/Fonts/Helvetica.ttc"
        ]

        font = None
        for path in font_paths:
            if os.path.exists(path):
                font = ImageFont.truetype(path, 16)
                title_font = ImageFont.truetype(path, 24)
                break

        if font is None:
            font = ImageFont.load_default()
            title_font = font
    except:
        font = ImageFont.load_default()
        title_font = font

    # Define the main process box
    main_flow_boxes = [
        {"x": 100, "y": 150, "width": 150, "height": 80, "text": "User Query", "color": "#FFD580", "description": "Questions about IELTS"},
        {"x": 320, "y": 150, "width": 150, "height": 80, "text": "Vector Retrieval", "color": "#90EE90", "description": "Embedding search"},
        {"x": 540, "y": 150, "width": 150, "height": 80, "text": "Relevant Content", "color": "#ADD8E6", "description": "Retrieved materials"},
        {"x": 760, "y": 150, "width": 150, "height": 80, "text": "LLM Generation", "color": "#DDA0DD", "description": "Gemma-2b-it model"},
        {"x": 760, "y": 300, "width": 150, "height": 80, "text": "Result Display", "color": "#FFCCCB", "description": "Answer with context"}
    ]

    # Define other functional modules
    feature_boxes = [
        {"x": 100, "y": 300, "width": 150, "height": 80, "text": "Writing Assessment", "color": "#F0E68C", "description": "Essay evaluation"},
        {"x": 320, "y": 300, "width": 150, "height": 80, "text": "Study Plan", "color": "#98FB98", "description": "Learning roadmap"},
        {"x": 540, "y": 300, "width": 150, "height": 80, "text": "Practice Tests", "color": "#87CEEB", "description": "Mock exams"},
        {"x": 320, "y": 450, "width": 150, "height": 80, "text": "Text-to-Speech", "color": "#D8BFD8", "description": "Voice output"},
        {"x": 540, "y": 450, "width": 150, "height": 80, "text": "Progress Tracking", "color": "#FFB6C1", "description": "Learning journey"}
    ]

    # Draw the main title
    title = "IELTS Learning Assistant System Architecture"
    title_w, title_h = draw.textsize(title, font=title_font) if hasattr(draw, 'textsize') else (width//2, 40)
    draw.text(((width - title_w) // 2, 30), title, fill="#4B0082", font=title_font)

    # Draw the main process box and description
    for box in main_flow_boxes:
        x, y = box["x"], box["y"]
        w, h = box["width"], box["height"]
        draw.rectangle([(x, y), (x+w, y+h)], fill=box["color"], outline="#000000", width=2)

        # Adding Main Text
        text_w, text_h = draw.textsize(box["text"], font=font) if hasattr(draw, 'textsize') else (w//2, h//3)
        text_x = x + (w - text_w) // 2
        text_y = y + (h - text_h) // 3
        draw.text((text_x, text_y), box["text"], fill="#000000", font=font)

        # Add description text
        desc_w, desc_h = draw.textsize(box["description"], font=font) if hasattr(draw, 'textsize') else (w//2, h//3)
        desc_x = x + (w - desc_w) // 2
        desc_y = y + h - text_h - 10
        draw.text((desc_x, desc_y), box["description"], fill="#333333", font=font)

    # Draw function module boxes and descriptions
    for box in feature_boxes:
        x, y = box["x"], box["y"]
        w, h = box["width"], box["height"]
        draw.rectangle([(x, y), (x+w, y+h)], fill=box["color"], outline="#000000", width=2)

        # Adding Main Text
        text_w, text_h = draw.textsize(box["text"], font=font) if hasattr(draw, 'textsize') else (w//2, h//3)
        text_x = x + (w - text_w) // 2
        text_y = y + (h - text_h) // 3
        draw.text((text_x, text_y), box["text"], fill="#000000", font=font)

        # Add description text
        desc_w, desc_h = draw.textsize(box["description"], font=font) if hasattr(draw, 'textsize') else (w//2, h//3)
        desc_x = x + (w - desc_w) // 2
        desc_y = y + h - text_h - 10
        draw.text((desc_x, desc_y), box["description"], fill="#333333", font=font)

    # Drawing Connection Lines - Main Process
    for i in range(len(main_flow_boxes)-2):  # The last frame is processed separately
        x1 = main_flow_boxes[i]["x"] + main_flow_boxes[i]["width"]
        y1 = main_flow_boxes[i]["y"] + main_flow_boxes[i]["height"]//2
        x2 = main_flow_boxes[i+1]["x"]
        y2 = main_flow_boxes[i+1]["y"] + main_flow_boxes[i+1]["height"]//2

        # Lire
        draw.line([(x1, y1), (x2, y2)], fill="#000000", width=3)

        # Arrow
        arrow_size = 10
        draw.polygon([(x2-arrow_size, y2-arrow_size//2), (x2, y2), (x2-arrow_size, y2+arrow_size//2)], fill="#000000")

    # Connection from LLM Generation to Result Display
    llm_idx = 3  # LLM Generation Index
    result_idx = 4  #Result Display Index

    x1 = main_flow_boxes[llm_idx]["x"] + main_flow_boxes[llm_idx]["width"]//2
    y1 = main_flow_boxes[llm_idx]["y"] + main_flow_boxes[llm_idx]["height"]
    x2 = main_flow_boxes[result_idx]["x"] + main_flow_boxes[result_idx]["width"]//2
    y2 = main_flow_boxes[result_idx]["y"]

    # Lines and arrows
    draw.line([(x1, y1), (x1, y1+30), (x2, y1+30), (x2, y2)], fill="#000000", width=3)
    arrow_size = 10
    draw.polygon([(x2-arrow_size//2, y2-arrow_size), (x2, y2), (x2+arrow_size//2, y2-arrow_size)], fill="#000000")

    # LLM Generation to TTS connection
    tts_idx = 3  # Text-to-Speech Index

    x1 = main_flow_boxes[llm_idx]["x"]
    y1 = main_flow_boxes[llm_idx]["y"] + main_flow_boxes[llm_idx]["height"]//2
    x2 = feature_boxes[tts_idx]["x"] + feature_boxes[tts_idx]["width"]//2
    y2 = feature_boxes[tts_idx]["y"]

    draw.line([(x1, y1), (x1-30, y1), (x1-30, y2-30), (x2, y2-30), (x2, y2)], fill="#000000", width=2)
    arrow_size = 10
    draw.polygon([(x2-arrow_size//2, y2-arrow_size), (x2, y2), (x2+arrow_size//2, y2-arrow_size)], fill="#000000")

    #Add a legend
    legend_y = 580
    legend_items = [
        {"text": "Main Process Flow", "color": "#000000", "width": 3},
        {"text": "Additional Features", "color": "#000000", "width": 2},
    ]

    for i, item in enumerate(legend_items):
        x_pos = 100 + i * 300
        # Line
        draw.line([(x_pos, legend_y), (x_pos + 50, legend_y)], fill=item["color"], width=item["width"])
        # text
        draw.text((x_pos + 60, legend_y - 5), item["text"], fill="#000000", font=font)

    # Add bottom note
    footer_text = "Built with SentenceTransformer, Google Gemma-2b-it, and Facebook MMS-TTS-Eng"
    footer_w, footer_h = draw.textsize(footer_text, font=font) if hasattr(draw, 'textsize') else (width//2, 20)
    draw.text(((width - footer_w) // 2, height - 30), footer_text, fill="#666666", font=font)

    return image

# 4. Creating the Gradio Interface
def create_interface():
    """Creating the Gradio Interface"""
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        with gr.Tab("Main Application"):
            gr.Markdown("# IELTS Learning Assistant")
            gr.Markdown("""
            This application uses artificial intelligence to answer questions about the IELTS exam. It is based on IELTS textbook content and can help you understand key aspects of the exam and improve your skills.

            **Language Support**:
            - Text-to-speech functionality is only available for English
            """)

            # User Information
            with gr.Row():
                username_input = gr.Textbox(
                    label="Username",
                    placeholder="Enter your username to track learning progress",
                    value="default_user"
                )

            with gr.Row():
                with gr.Column(scale=2):
                    query_input = gr.Textbox(
                        label="Question",
                        placeholder="Enter your question about the IELTS exam...",
                        lines=2
                    )
                    with gr.Row():
                        submit_btn = gr.Button("Submit", variant="primary")
                        clear_btn = gr.Button("Clear")

                    response_output = gr.Textbox(
                        label="Answer",
                        lines=10,
                        placeholder="AI's answer will appear here...",
                    )

                    # New: Independent performance indicator display
                    metrics_output = gr.Markdown(label="Performance Metrics")

                with gr.Column(scale=1):
                    with gr.Accordion("References", open=False):
                        context_output = gr.Markdown()

            # Session History
            with gr.Accordion("Conversation History", open=False):
                history_display = gr.HTML()
                with gr.Row():
                    export_btn = gr.Button("Export Conversation")
                    clear_history_btn = gr.Button("Clear History")
                export_status = gr.Textbox(label="Export Status", visible=False)

            # Example Question
            examples = [
                ["How can I improve my IELTS listening score?"],
                ["What is the ideal structure for IELTS Writing Task 2?"],
                ["How to manage time effectively in the IELTS reading test?"],
                ["What should I pay attention to in the IELTS speaking test?"],
                ["What are the common mistakes in IELTS Writing Task 1?"],
                ["How to achieve band 7+ in IELTS?"]
            ]

            # Setup Events - Updated to use tracking and history
            submit_btn.click(
                fn=process_query_with_history,
                inputs=[query_input, history_display],
                outputs=[response_output, context_output, metrics_output, history_display]
            )
            clear_btn.click(
                lambda: ["", "", ""],
                outputs=[query_input, response_output, metrics_output]
            )
            export_btn.click(
                fn=export_history,
                inputs=history_display,
                outputs=export_status
            )
            clear_history_btn.click(
                lambda: "",
                outputs=history_display
            )
            gr.Examples(
                examples=examples,
                inputs=query_input
            )

            # Adding text-to-speech functionality
            with gr.Accordion("Text-to-Speech Feature", open=False):
                gr.Markdown("## Convert Text to Speech")
                gr.Markdown("Enter English text or use the answer content to convert to speech.")

                with gr.Row():
                    text_for_tts = gr.Textbox(
                        label="Text to convert to speech",
                        lines=2,
                        placeholder="Enter English text to convert to speech..."
                    )
                    with gr.Column():
                        convert_btn = gr.Button("Convert Text", variant="secondary")
                        convert_response_btn = gr.Button("Convert Answer", variant="secondary")

                audio_output = gr.Audio(label="Speech Output")
                tts_status = gr.Markdown()  # Add status information display

                # Connecting TTS Function
                convert_btn.click(
                    fn=tts_with_language_check,
                    inputs=text_for_tts,
                    outputs=[audio_output, tts_status]
                )
                convert_response_btn.click(
                    fn=tts_with_language_check,
                    inputs=response_output,
                    outputs=[audio_output, tts_status]
                )

            # Add learning progress display
            view_progress_btn = gr.Button("View Learning Progress")
            progress_display = gr.Markdown(label="Learning Progress")

            view_progress_btn.click(
                fn=lambda username: track_user_activity(username, "query", "View progress"),
                inputs=username_input,
                outputs=progress_display
            )

        # Writing Grading Tab
        with gr.Tab("Writing Assessment"):
            gr.Markdown("# IELTS Writing Assessment")
            gr.Markdown("Upload your IELTS writing sample to get professional scoring and feedback.")

            writing_username = gr.Textbox(
                label="Username",
                placeholder="Enter your username to track progress",
                value="default_user"
            )

            writing_input = gr.Textbox(
                label="Paste your IELTS writing sample",
                placeholder="Paste your Task 1 or Task 2 writing here...",
                lines=10
            )
            evaluate_btn = gr.Button("Get Assessment", variant="primary")
            evaluation_output = gr.Markdown(label="Scoring and Feedback")
            writing_progress = gr.Markdown(label="Writing Progress")

            # Updated writing grades to track progress
            def evaluate_with_tracking(sample, username):
                feedback = evaluate_ielts_writing(sample, username)
                progress = track_user_activity(username, "writing", sample)
                return feedback, progress

            evaluate_btn.click(
                fn=evaluate_with_tracking,
                inputs=[writing_input, writing_username],
                outputs=[evaluation_output, writing_progress]
            )

        # Added mock exam tab
        with gr.Tab("Practice Tests"):
            gr.Markdown("# IELTS Practice Tests")
            gr.Markdown("Select an exam section to generate corresponding practice questions.")

            practice_username = gr.Textbox(
                label="Username",
                placeholder="Enter your username to track progress",
                value="default_user"
            )

            section_selector = gr.Dropdown(
                label="Select Exam Section",
                choices=["Listening", "Reading", "Writing", "Speaking"],
                value="Writing"
            )

            generate_btn = gr.Button("Generate Practice Question", variant="primary")
            practice_output = gr.Markdown(label="Practice Question")

            # Generate mock questions and track activity
            def generate_practice_with_tracking(section, username):
                practice = generate_practice_question(section)
                # Record this mock exam activity
                track_user_activity(username, "practice", section)
                return practice

            generate_btn.click(
                fn=generate_practice_with_tracking,
                inputs=[section_selector, practice_username],
                outputs=practice_output
            )

        # Add a learning plan tab
        with gr.Tab("Study Plan"):
            gr.Markdown("# Personalized IELTS Study Plan")
            gr.Markdown("Input your goals and conditions to get a customized IELTS study plan.")

            with gr.Row():
                target_score = gr.Slider(
                    label="Target Overall Score",
                    minimum=5.0,
                    maximum=9.0,
                    step=0.5,
                    value=7.0
                )
                weeks = gr.Slider(
                    label="Available Study Weeks",
                    minimum=1,
                    maximum=24,
                    step=1,
                    value=8
                )

            strengths = gr.Textbox(
                label="Your Strengths",
                placeholder="e.g.: Listening, Reading...",
                lines=2
            )

            weaknesses = gr.Textbox(
                label="Areas to Improve",
                placeholder="e.g.: Writing, Speaking...",
                lines=2
            )

            plan_btn = gr.Button("Generate Study Plan", variant="primary")
            plan_output = gr.Markdown(label="Personalized Study Plan")

            plan_btn.click(
                fn=generate_study_plan,
                inputs=[target_score, weeks, strengths, weaknesses],
                outputs=plan_output
            )

        # System Architecture Tab
        with gr.Tab("System Architecture"):
            with gr.Row():
                gr.Markdown("# IELTS Learning Assistant System Architecture")

            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("""
        ## Overall Architecture
        This system uses Retrieval-Augmented Generation (RAG) combined with Text-to-Speech (TTS) technology to provide IELTS learning assistance.

        ## How This System Works

        This IELTS Learning Assistant uses a powerful combination of **Retrieval-Augmented Generation (RAG)** and **Text-to-Speech** technology to provide accurate, contextually relevant responses to your IELTS questions.

        The system follows these steps:
        1. When you ask a question, it's converted to a vector representation
        2. This vector is compared against our IELTS materials database
        3. The most relevant content is retrieved
        4. The AI combines this content with its own knowledge to generate a helpful answer
        5. If relevance is low, the AI relies more on its own knowledge

        All answers are based on standard IELTS curriculum materials and best practices in IELTS preparation.
        """)

                with gr.Column(scale=1):
                    gr.Markdown("""
        ## Components
        1. **Data Preprocessing Module**
          - Splits IELTS textbook materials into semantic chunks
          - Generates embedding vectors using SentenceTransformer
          - Stores text chunks and corresponding embedding vectors

        2. **Retrieval Module**
          - Uses vector similarity search
          - Calculates similarity based on all-mpnet-base-v2 model
          - Selects the N most relevant text chunks

        3. **Generation Module**
          - Uses Google Gemma-2b-it model
          - Constructs context prompts based on retrieved content
          - Generates answers based on relevance threshold

        4. **Text-to-Speech Module**
          - Uses Facebook MMS-TTS-Eng model
          - Converts generated text to natural speech
        """)

            # Display the enhanced system diagram
            gr.Image(value=generate_system_diagram(), label="System Flow Diagram")

            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("""
        ## Learning & Assessment Components

        5. **Writing Assessment Module**
          - Evaluates IELTS writing using large language model
          - Provides detailed scoring and improvement suggestions

        6. **Study Plan Module**
          - Generates personalized learning plans based on user goals and time
          - Provides targeted learning suggestions and resource recommendations
        """)

                with gr.Column(scale=1):
                    gr.Markdown("""
        7. **Practice Test Module**
          - Generates IELTS practice questions for each section
          - Helps users familiarize with test format and content

        8. **Learning Progress Tracking**
          - Records user learning activities and achievements
          - Provides visualization of learning journey
        """)

            with gr.Accordion("Technical Details", open=False):
                gr.Markdown("""
        ### Implementation Details

        **Models Used:**
        - **Embedding Model**: SentenceTransformer (all-mpnet-base-v2)
        - **Language Model**: Google Gemma-2b-it (2 billion parameters)
        - **TTS Model**: Facebook MMS-TTS-Eng

        **RAG Implementation:**
        ```python
        # Vector search with relevance threshold
        def retrieve_relevant_resources(query, embeddings, n_resources=5):
            query_embedding = embedding_model.encode(query, convert_to_tensor=True)
            dot_scores = util.dot_score(query_embedding, embeddings)[0]
            scores, indices = torch.topk(dot_scores, k=n_resources)
            return scores, indices

        # Dynamic prompt selection based on relevance
        def generate_answer(query, context_items, avg_relevance_score=0.0):
            # Use threshold to decide prompt strategy
            relevance_threshold = 0.65

            if avg_relevance_score < relevance_threshold:
                # Low relevance - rely on model knowledge
                prompt = "Here is a question about the IELTS exam. Since no sufficiently relevant reference materials were found, please use your own knowledge to answer.\\n\\nQuestion: " + query + "\\n\\nAnswer:"
            else:
                # High relevance - use retrieved content
                context = "\\n\\n".join([item["sentence_chunk"] for item in context_items])
                prompt = "Based on the following IELTS materials, answer the question:\\n\\nContent:\\n" + context + "\\n\\nQuestion: " + query + "\\n\\nAnswer:"
                """)
    return demo

# 5. Startup interface
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(share=True)

✅ Successfully loaded data


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://e60a11c6369c472413.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
