# Multimodal Article Question Answering Assistant

In [1]:
import json
from utility_functions import *
from transformers import CLIPProcessor, CLIPModel
from torch import load, matmul, argsort
from torch.nn.functional import softmax
import gradio as gr
from IPython.display import Image

import ollama

### Load Data

In [2]:
# Load article contents: one list of text items and one list of image items.
# load_from_json is not defined in this notebook (presumably it comes from the
# `utility_functions` star import). similarity_search indexes these lists by
# embedding row number, so their order must match the embedding tensors below.
text_content_list = load_from_json('data/text_content.json')
image_content_list = load_from_json('data/image_content.json')

# Load the saved embedding tensors. `load` is torch.load (imported by name
# above); weights_only=True is passed so that the files are loaded in
# torch's restricted "weights only" mode.
text_embeddings = load('data/text_embeddings.pt', weights_only=True)
image_embeddings = load('data/image_embeddings.pt', weights_only=True)

In [3]:
# Sanity check: display the shape of each embedding tensor
# (the recorded output shows one 512-wide row per content item).
print(text_embeddings.shape)
print(image_embeddings.shape)

torch.Size([86, 512])
torch.Size([17, 512])


### Multimodal search

In [4]:
def similarity_search(query_embed, target_embeddings, content_list, k=5, threshold=0.05, temperature=0.5):
    """
    Perform similarity search over embeddings and return top k results.

    Similarities are dot products between the query and every target row.
    They are rescaled with a temperature softmax, so the scores of all
    targets sum to 1 and `threshold` is relative to the size of the
    candidate pool rather than an absolute similarity.

    Args:
        query_embed (torch.Tensor): Query embedding, either a single vector
            of shape (dim,) or a batch of shape (n, dim). Only the first
            row of a batch is searched.
        target_embeddings (torch.Tensor): Target embeddings matrix to search
            over, one row per entry of `content_list`.
        content_list (list): List of content items corresponding to embeddings
        k (int, optional): Number of top results to return. Defaults to 5.
        threshold (float, optional): Minimum softmax score a result must
            reach to be returned. Defaults to 0.05.
        temperature (float, optional): Temperature for softmax scaling;
            smaller values sharpen the distribution. Defaults to 0.5.

    Returns:
        tuple: (results, scores) where:
            - results: List of up to k content matches, best first
            - scores: Corresponding softmax scores as Python floats
    """
    # Treat a bare 1-D vector as a batch containing one query
    if query_embed.dim() == 1:
        query_embed = query_embed.unsqueeze(0)

    # Calculate similarities. `matmul` and `softmax` are the names imported
    # at the top of the notebook; the `torch` namespace itself is not
    # imported there, so it is deliberately not referenced here.
    similarities = matmul(query_embed, target_embeddings.T)

    # Rescale similarities via softmax and keep the first query's scores
    scores = softmax(similarities / temperature, dim=1)[0]

    # Sort once, best first, getting scores and their row indices together
    ranked_scores, ranked_indices = scores.sort(descending=True)

    # Filter by threshold and get top k
    selected = [
        (index, score)
        for index, score in zip(ranked_indices.tolist(), ranked_scores.tolist())
        if score >= threshold
    ][:k]

    # Get corresponding content items and scores
    top_results = [content_list[index] for index, _ in selected]
    result_scores = [score for _, score in selected]

    return top_results, result_scores

### Query

In [5]:
def context_retrieval(query, text_embeddings, image_embeddings, text_content_list, image_content_list,
                    text_k=15, image_k=5,
                    text_threshold=0.01, image_threshold=0.25,
                    text_temperature=0.25, image_temperature=0.5):
    """
    Retrieve the text and image items most relevant to a query.

    The query is embedded once and the same embedding is searched against
    the text embeddings and the image embeddings, each with its own
    k / threshold / temperature settings.

    Args:
        query: Query passed to `embed_text`.
        text_embeddings: Embeddings searched for text matches.
        image_embeddings: Embeddings searched for image matches.
        text_content_list (list): Items corresponding to `text_embeddings`.
        image_content_list (list): Items corresponding to `image_embeddings`.
        text_k (int, optional): Maximum number of text results. Defaults to 15.
        image_k (int, optional): Maximum number of image results. Defaults to 5.
        text_threshold (float, optional): Minimum score for text results. Defaults to 0.01.
        image_threshold (float, optional): Minimum score for image results. Defaults to 0.25.
        text_temperature (float, optional): Softmax temperature for text search. Defaults to 0.25.
        image_temperature (float, optional): Softmax temperature for image search. Defaults to 0.5.

    Returns:
        tuple: (text_results, image_results), the matched content items
        from each search; the scores are discarded.
    """
    # embed_text is defined outside this notebook; the original comment
    # describes it as embedding the query using CLIP
    query_vector = embed_text(query)

    # text search first, then image search, both against the same query vector
    text_matches = similarity_search(
        query_vector,
        text_embeddings,
        text_content_list,
        k=text_k,
        threshold=text_threshold,
        temperature=text_temperature,
    )[0]
    image_matches = similarity_search(
        query_vector,
        image_embeddings,
        image_content_list,
        k=image_k,
        threshold=image_threshold,
        temperature=image_temperature,
    )[0]

    return text_matches, image_matches

### Prompt Engineering

In [6]:

def construct_prompt(query, text_results, image_results):
    """
    Construct a prompt for the LLM to generate a response.

    Args:
        query (str): The user's question; inserted verbatim into the prompt.
        text_results (list): Retrieved text items, each a mapping with the
            keys 'article_title', 'section' and 'text'.
        image_results (list): Retrieved image items, each a mapping with the
            keys 'article_title', 'section', 'image_path' and 'caption'.

    Returns:
        str: The prompt: the query, one formatted entry per text item, one
        formatted entry per image item, then the answering instruction.
        Empty result lists simply contribute no entries.
    """
    # One markdown-style entry per retrieved text snippet
    text_context = "".join(
        f"**Article title:** {text['article_title']}\n"
        f"**Section:**  {text['section']}\n"
        f"**Snippet:** {text['text']}\n\n"
        for text in text_results
    )

    # One markdown-style entry per retrieved image (path and caption only;
    # the image files themselves are not read here)
    image_context = "".join(
        f"**Article title:** {image['article_title']}\n"
        f"**Section:**  {image['section']}\n"
        f"**Image Path:**  {image['image_path']}\n"
        f"**Image Caption:** {image['caption']}\n\n"
        for image in image_results
    )

    # construct prompt
    return f"""Given the query "{query}" and the following relevant snippets:

    {text_context}
    {image_context}

    Please provide a concise and accurate answer to the query, incorporating relevant information from the provided snippets where possible.

    """

### Chat UI

In [8]:
# Pull the llama3.2-vision model through Ollama; stream_chat below requests
# this same model name when it calls ollama.chat.
ollama.pull('llama3.2-vision')

ProgressResponse(status='success', completed=None, total=None, digest=None)

In [9]:
# Function to interact with the Ollama model
def stream_chat(message, history):
    """
    Answer a chat message with retrieved article context, streaming the reply.

    The text of the message is used to retrieve relevant text and image
    items, a context-augmented prompt is built from them, and the prompt is
    sent to the Ollama model together with the prior conversation.

    Args:
        message (dict): The user input; its "text" entry is the query.
        history (list): Previous conversation messages. This list is
            modified in place: the augmented user turn is appended before
            the model is called and the assistant turn once streaming ends.

    Yields:
        str: The response accumulated so far, growing with every chunk.
    """
    # context retrieval (embeddings and content lists are notebook globals)
    snippets, figures = context_retrieval(message["text"], text_embeddings, image_embeddings, text_content_list, image_content_list)

    # The user turn carries the constructed prompt rather than the raw
    # message, plus the paths of the retrieved images
    user_turn = {
        "role": "user",
        "content": construct_prompt(message["text"], snippets, figures),
        "images": [figure["image_path"] for figure in figures],
    }
    history.append(user_turn)

    # Stream the model's reply; `history` now ends with the current user turn
    reply_so_far = ""
    for part in ollama.chat(model='llama3.2-vision', messages=history, stream=True):
        reply_so_far = reply_so_far + part['message']['content']
        yield reply_so_far  # cumulative text, not just the newest chunk

    # Record the assistant's full response in the history
    history.append({"role": "assistant", "content": reply_so_far})

In [10]:
# Create a Gradio ChatInterface around stream_chat and launch it.
# multimodal=True is set, and stream_chat reads the query from message["text"];
# the examples below are given in that same {"text": ...} form.
gr.ChatInterface(
    fn=stream_chat,  # The function handling the chat
    type="messages",  # Using "messages" to enable chat-style conversation
    examples=[{"text": "What is CLIP's contrastive loss function?"}, 
              {"text": "What are the three paths described for making LLMs multimodal?"},
              {"text": "What is an intuitive explanation of multimodal embeddings?"}],  # Example inputs
    multimodal=True,
).launch()

* Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.


