In [1]:
import os
from transformers import AutoTokenizer, AutoModel
from sentence_transformers import CrossEncoder
from dotenv import load_dotenv
from datasets import Dataset
import pandas as pd
import requests
import torch
import json

# Load environment variables (e.g. CLARIN_API_TOKEN) from a local .env file
load_dotenv()

# Replace this with your actual question or leave as is
default_question = "Give me very short summary of all tattoos containing lions. Summary must be up to 3 sentences."

# Load data: parse the CSV once and slice it into the two column sets used below.
# .copy() detaches each slice from raw_df so later column assignment does not
# trigger pandas' SettingWithCopyWarning.
raw_df = pd.read_csv('data.csv')
image_df = raw_df[['post_id', 'subreddit', 'image_path']].copy()

# Update the image paths as per your environment.
# The target directory can be overridden with the TATTOO_DATA_DIR environment
# variable (also picked up from .env); the default keeps the original location.
old_prefix = '/net/pr2/projects/plgrid/plggtattooai'
new_prefix = os.getenv('TATTOO_DATA_DIR', '/Users/ewojcik/Code/pwr/AMC/amc-lab3/data')
# regex=False: the prefix is a literal path, not a pattern (the default differs across pandas versions)
image_df['image_path'] = image_df['image_path'].str.replace(old_prefix, new_prefix, regex=False)

descriptions_df = raw_df[['post_id', 'tattoo_description', 'tattoo_color', 'tattoo_style', 'Title']].copy()

# Build (or reuse) a single CSV holding the text content of every post.
posts_cache_path = 'tattoos/posts_content.csv'

if not os.path.exists(posts_cache_path):
    # Get all CSV files in the specified directory.
    # sorted(): os.listdir returns entries in arbitrary order; sorting keeps the
    # concatenated row order reproducible between runs and machines.
    csv_path = 'tattoos/posts_per_subreddit'
    csv_files = sorted(f for f in os.listdir(csv_path) if f.endswith('.csv'))
    if not csv_files:
        raise FileNotFoundError(f"No CSV files found in '{csv_path}'")

    # Read every file first, then concatenate once (concat inside a loop
    # re-copies all accumulated rows on each iteration).
    frames = [pd.read_csv(os.path.join(csv_path, name)) for name in csv_files]
    posts_content_df = (
        pd.concat(frames, ignore_index=True)[['Id', 'Content']]
        .rename(columns={'Id': 'post_id'})
    )
    posts_content_df['Content'] = posts_content_df['Content'].fillna('')
    posts_content_df.to_csv(posts_cache_path, index=False)
else:
    posts_content_df = pd.read_csv(posts_cache_path)
    # Empty strings are stored as blank CSV fields and come back as NaN on read
    posts_content_df['Content'] = posts_content_df['Content'].fillna('')

# Attach post content to each description row (left join on post_id), then
# discard every row that still has a missing value and renumber the index.
tattoos_df = (
    descriptions_df
    .merge(posts_content_df, on='post_id', how='left')
    .dropna(ignore_index=True)
)

# Wrap the cleaned frame as a Hugging Face Dataset
tattoos_dataset = Dataset.from_pandas(tattoos_df)

def concatenate_text(examples):
    """Build the single "text" field used for embedding and retrieval.

    The Title, Content, tattoo_description, tattoo_color and tattoo_style
    values of ``examples`` are joined, in that order, with " \\n " between
    consecutive parts.

    Returns:
        dict with one key, "text", holding the joined string.
    """
    field_order = ("Title", "Content", "tattoo_description", "tattoo_color", "tattoo_style")
    parts = [examples[field] for field in field_order]
    return {"text": " \n ".join(parts)}

# Add the combined "text" column, record its whitespace-delimited word count,
# and keep only examples longer than 15 words.
tattoos_dataset = (
    tattoos_dataset
    .map(concatenate_text)
    .map(lambda row: {"text_length": len(row["text"].split())})
    .filter(lambda row: row["text_length"] > 15)
)

# %% [markdown]
# ## Load Embedding Model and Compute Embeddings

# Tokenizer and encoder are loaded from the same checkpoint
model_ckpt = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
embedding_tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
embedding_model = AutoModel.from_pretrained(model_ckpt)

# Backend preference: CUDA first, then Apple MPS, otherwise CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
embedding_model.to(device)

def cls_pooling(model_output):
    """Return the hidden state at sequence position 0 for every batch row.

    ``model_output`` must expose ``last_hidden_state``; the result is that
    array indexed with ``[:, 0]``.
    """
    hidden_states = model_output.last_hidden_state
    first_token = hidden_states[:, 0]
    return first_token

def get_embeddings(text_list):
    """Embed a list of texts with the module-level tokenizer and model.

    The texts are tokenized as one padded, truncated PyTorch batch, moved to
    ``device``, run through ``embedding_model`` without gradient tracking,
    and reduced with ``cls_pooling``.

    Returns:
        The pooled model output (a tensor on ``device``), one row per text.
    """
    batch = embedding_tokenizer(
        text_list,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    batch_on_device = {name: tensor.to(device) for name, tensor in batch.items()}
    with torch.no_grad():
        outputs = embedding_model(**batch_on_device)
    pooled = cls_pooling(outputs)
    return pooled

# Load or compute embeddings; the result is cached on disk so the full
# embedding pass over the dataset only runs when no cache exists.
embeddings_path = "data/tattoos_embeddings"
if not os.path.exists(embeddings_path):
    # One text per call; [0] unwraps the single-row batch from get_embeddings
    embeddings_dataset = tattoos_dataset.map(
        lambda row: {"embeddings": get_embeddings([row["text"]]).detach().cpu().numpy()[0]}
    )
    os.makedirs("data", exist_ok=True)
    embeddings_dataset.save_to_disk(embeddings_path)
else:
    embeddings_dataset = Dataset.load_from_disk(embeddings_path)

# Add Faiss index for efficient similarity search.
# NOTE(review): the checkpoint name ends in "dot-v1" while no metric is passed
# here — confirm the index's default metric matches the model's intended similarity.
embeddings_dataset.add_faiss_index(column="embeddings")

# %% [markdown]
# ## Load Reranker (Cross-Encoder)

# Cross-encoder used later to re-score (question, text) pairs returned by the Faiss lookup
rerank_model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2")

# %% [markdown]
# ## RAG Function (Query and Answer)

def rag_answer(context, question, timeout=60):
    """Ask the CLARIN chat-completions endpoint to answer ``question`` from ``context``.

    Args:
        context: Retrieved text inserted into the prompt as context information.
        question: The user's query.
        timeout: Seconds to wait for the HTTP request before giving up.
            Without a timeout ``requests`` may block indefinitely.

    Returns:
        The stripped assistant message on success; a fixed explanatory string
        when the CLARIN_API_TOKEN environment variable is unset; or
        "No response from the model." when the request fails, returns a
        non-200 status, or the response body does not have the expected shape.
    """
    no_response = "No response from the model."

    user_token = os.getenv('CLARIN_API_TOKEN')
    if not user_token:
        return "No CLARIN_API_TOKEN found. Please set it in your environment."

    url = "https://services.clarin-pl.eu/api/v1/oapi/chat/completions"
    headers = {
        "accept": "application/json",
        "Authorization": f"Bearer {user_token}",
        "Content-Type": "application/json"
    }
    prompt = f"""
Context information is below.
---------------------
{context}
---------------------
Given the context information and not prior knowledge, answer the query.
Question: {question}
Answer:
"""
    payload = {
        "model": "llama",
        "messages": [
            {
                "role": "user",
                "content": prompt
            }
        ]
    }

    # Connection errors and timeouts are reported the same way as a bad status
    # instead of raising into the caller (a widget callback).
    try:
        response = requests.post(url, headers=headers, json=payload, timeout=timeout)
    except requests.RequestException:
        return no_response

    if response.status_code != 200:
        return no_response

    # Guard against invalid JSON or a body missing choices[0].message.content
    try:
        assistant_message = response.json()['choices'][0]['message']['content']
    except (KeyError, IndexError, TypeError, ValueError):
        return no_response
    if not isinstance(assistant_message, str):
        return no_response
    return assistant_message.strip()

# %% [markdown]
# ## Helper Function for Text Wrapping

def wrap_text(text, width=100):
    """Re-flow ``text`` so no line exceeds ``width`` characters.

    Returns the wrapped lines joined with newline characters; an empty input
    yields an empty string.
    """
    from textwrap import fill
    return fill(text, width=width)

# %% [markdown]
# ## Interactive Demo with Plotly and Widgets

Map:   0%|          | 0/78233 [00:00<?, ? examples/s]

Map:   0%|          | 0/78233 [00:00<?, ? examples/s]

Filter:   0%|          | 0/78233 [00:00<?, ? examples/s]

  0%|          | 0/77 [00:00<?, ?it/s]

In [4]:
import ipywidgets as widgets
from IPython.display import display, clear_output, HTML
import plotly.graph_objects as go

# Widgets: an output pane for results, a trigger button, and a free-text
# question box pre-filled with the default question.
output_area = widgets.Output()
search_button = widgets.Button(button_style='info', description="Search")
question_input = widgets.Text(
    description='Question:',
    value=default_question,
    layout=widgets.Layout(width='70%'),
)

def _retrieve_and_rerank(question, k):
    """Return the k nearest documents for ``question``, sorted by cross-encoder score (best first)."""
    question_embedding = get_embeddings([question]).cpu().detach().numpy()
    scores, samples = embeddings_dataset.get_nearest_examples(
        "embeddings", question_embedding, k=k
    )

    samples_df = pd.DataFrame.from_dict(samples)
    samples_df["scores"] = scores
    samples_df["scores_re_ranker"] = rerank_model.predict(
        [(question, txt) for txt in samples_df["text"]]
    )
    # .copy() gives an independent frame, so callers can add columns without
    # pandas' SettingWithCopyWarning.
    return samples_df.sort_values("scores_re_ranker", ascending=False).head(k).copy()


def _plot_reranked(top_results):
    """Display a horizontal bar chart of re-ranker scores labelled by post title."""
    # Plotly renders line breaks only for "<br>"; a bare "\n" has no visible
    # effect in axis labels, so convert the wrapped text.
    wrapped_titles = top_results["Title"].apply(
        lambda title: wrap_text(title, 100).replace("\n", "<br>")
    )

    fig = go.Figure(
        data=go.Bar(
            x=top_results["scores_re_ranker"],
            y=wrapped_titles,
            orientation='h',
            marker=dict(color='blue')
        )
    )
    fig.update_layout(
        title="Top Retrieved Documents (re-ranked)",
        xaxis_title="Re-ranker Score",
        yaxis_title="Title",
        yaxis=dict(autorange='reversed'),  # So the highest score is at the top
        height=600
    )
    display(fig)


def _render_images(top_results, max_images=3):
    """Show up to ``max_images`` inline images for results from the 'tattoos' subreddit."""
    import base64

    top_tattoos = top_results.merge(image_df, on='post_id', how='left')
    top_tattoos = top_tattoos[top_tattoos['subreddit'] == 'tattoos']
    top_tattoos = top_tattoos.dropna(subset=['image_path'])

    images_to_display = top_tattoos.head(max_images)
    if images_to_display.empty:
        print("No images from the 'tattoos' subreddit found in the top results.")
        return

    image_tags = []
    for image_path in images_to_display['image_path']:
        if not os.path.exists(image_path):
            image_tags.append('<p>Image not found.</p>')
            continue
        # Embed the file as a base64 data URI so it displays inline
        with open(image_path, "rb") as img_file:
            b64_image = base64.b64encode(img_file.read()).decode('utf-8')
        img_type = 'png' if image_path.lower().endswith('.png') else 'jpeg'
        image_tags.append(
            f'<img src="data:image/{img_type};base64,{b64_image}" width="200" style="margin-right:10px;"/>'
        )
    # Display images horizontally
    display(HTML('<div style="display: flex; gap: 10px;">' + ''.join(image_tags) + '</div>'))


def on_search_button_clicked(b):
    """Button callback: retrieve, re-rank, visualise, and answer the current question.

    All output (chart, images, answer text, validation message) is rendered
    inside ``output_area``, which is cleared at the start of every click.

    Args:
        b: The clicked button instance passed by ipywidgets (unused).
    """
    with output_area:
        clear_output()
        question = question_input.value.strip()
        if not question:
            print("Please enter a valid question.")
            return

        # Number of documents fetched from Faiss, re-ranked, and fed to the LLM
        k_for_rag = 10
        top_results = _retrieve_and_rerank(question, k_for_rag)

        _plot_reranked(top_results)
        _render_images(top_results)

        # Generate RAG answer from the re-ranked texts, best match first
        context = "\n\n".join(top_results["text"].tolist())
        answer = rag_answer(context, question)
        if answer:
            print(f"**Answer:**\n{wrap_text(answer, 100)}")
        else:
            print("No answer generated.")

# Attach the click event to the button: each click calls
# on_search_button_clicked with the button instance as its argument.
search_button.on_click(on_search_button_clicked)

# Display widgets in order: question box, search button, then the output
# pane that the click handler renders into.
display(question_input, search_button, output_area)

Text(value='Give me very short summary of all tattoos containing lions. Summary must be up to 3 sentences.', d…

Button(button_style='info', description='Search', style=ButtonStyle())

Output()