# In [2]:
import json
import os

import nltk
import numpy as np
from gensim.models import Word2Vec
from sklearn.metrics.pairwise import cosine_similarity

class RAGSystem:
    """Word2Vec-based retriever over a directory of JSON metadata files.

    Every ``*.json`` file in ``metadata_dir`` must be an object with the keys
    ``dataset`` (a list of item dicts), ``mainPageSpans`` and
    ``sidePanelSpans``. Each dataset item may carry ``text``, ``questions``,
    ``questions_customized`` and ``answers``; these are tokenized with NLTK
    and used together as the training corpus of a Word2Vec model.

    Invariant relied on by :meth:`retrieve`: every ``tokenized_*`` list is
    index-aligned with the list it was derived from (``tokenized_metadata[i]``
    belongs to ``metadata[i]``, and so on).
    """

    def __init__(self, metadata_dir):
        """Load all JSON files in *metadata_dir* and train the embedding model.

        Args:
            metadata_dir: Directory scanned (non-recursively) for ``*.json``.

        Raises:
            KeyError: If a file lacks ``dataset``, ``mainPageSpans`` or
                ``sidePanelSpans``.
        """
        self.metadata = []
        self.main_page_spans = []
        self.side_panel_spans = []
        self.questions = []
        self.questions_customized = []
        self.answers = []

        # os.listdir() order is arbitrary; sorting keeps the load order (and
        # therefore tie-breaking between equally similar items) reproducible.
        for filename in sorted(os.listdir(metadata_dir)):
            if not filename.endswith('.json'):
                continue
            file_path = os.path.join(metadata_dir, filename)
            with open(file_path, 'r', encoding='utf-8') as file:
                data = json.load(file)
            self.metadata.extend(data['dataset'])
            self.main_page_spans.extend(data['mainPageSpans'])
            self.side_panel_spans.extend(data['sidePanelSpans'])
            for item in data['dataset']:
                self.questions.extend(item.get('questions', []))
                self.questions_customized.extend(item.get('questions_customized', []))
                self.answers.extend(item.get('answers', []))

        # Exactly ONE token list per metadata item, so that an index into
        # tokenized_metadata is also a valid index into metadata.
        self.tokenized_metadata = []
        for item in self.metadata:
            # Items without text still get the key, as callers may expect it.
            text = self._text_to_string(item.setdefault('text', []))
            self.tokenized_metadata.append(nltk.word_tokenize(text.lower()))

        self.tokenized_questions_customized = [
            nltk.word_tokenize(question.lower()) for question in self.questions_customized
        ]
        self.tokenized_questions = [
            nltk.word_tokenize(question.lower()) for question in self.questions
        ]
        self.tokenized_answers = [
            nltk.word_tokenize(answer.lower()) for answer in self.answers
        ]

        corpus = (
            self.tokenized_metadata
            + self.tokenized_questions
            + self.tokenized_questions_customized
            + self.tokenized_answers
        )
        # min_count=1 keeps every token, including ones seen only once.
        self.model = Word2Vec(corpus, min_count=1)

    @staticmethod
    def _text_to_string(text):
        """Collapse a ``text`` field (str or nested list of str) into one string."""
        if text is None:
            return ''
        if isinstance(text, str):
            return text
        if isinstance(text, (list, tuple)):
            return ' '.join(RAGSystem._text_to_string(part) for part in text)
        return str(text)

    def retrieve(self, query, items, is_question=False, top_k=5, source=None):
        """Return the *top_k* entries most similar to *query*.

        Args:
            query: Free-text query; lower-cased and tokenized with NLTK.
            items: Token lists to rank (one list of tokens per candidate).
            is_question: Legacy selector used only when *source* is omitted:
                results are then taken from ``self.questions`` if true, else
                from ``self.metadata``.
            top_k: Maximum number of results.
            source: Sequence index-aligned with *items* from which the
                results are taken. Pass it whenever *items* is not derived
                from ``self.questions`` / ``self.metadata``.

        Returns:
            A list of at most *top_k* entries of the source sequence, ordered
            by decreasing cosine similarity (ties keep their original order).

        Raises:
            ValueError: If *source* is given but differs in length from *items*.
        """
        if not items:
            return []

        if source is None:
            source = self.questions if is_question else self.metadata
        elif len(source) != len(items):
            raise ValueError(
                f"source has {len(source)} entries but items has {len(items)}; "
                "they must be index-aligned"
            )

        query_vector = self.get_vector(nltk.word_tokenize(query.lower()))
        item_vectors = [self.get_vector(tokens) for tokens in items]

        # A single batched call instead of one call per candidate.
        similarity_scores = cosine_similarity([query_vector], item_vectors)[0]

        top_indices = sorted(
            range(len(similarity_scores)),
            key=lambda i: similarity_scores[i],
            reverse=True,
        )[:top_k]
        return [source[i] for i in top_indices]

    def get_vector(self, tokens):
        """Return the mean embedding of the in-vocabulary *tokens*.

        When no token is known to the model (or *tokens* is empty) a zero
        vector of the model's dimensionality is returned. The corpus is
        lower-cased, so a literal ``'UNK'`` entry never exists in the
        vocabulary and cannot serve as the fallback.
        """
        word_vectors = self.model.wv
        vectors = [word_vectors[token] for token in tokens if token in word_vectors]
        if vectors:
            return np.mean(vectors, axis=0)
        return np.zeros(self.model.vector_size, dtype=np.float32)

    def remove_empty_objects(self, obj):
        """Recursively drop ``None`` values and empty dicts/lists from *obj*.

        Returns the cleaned structure, or ``None`` when a dict or list ends up
        empty. Other values (including empty strings) are returned unchanged.
        The input is not modified.
        """
        if isinstance(obj, dict):
            cleaned_dict = {}
            for key, value in obj.items():
                cleaned_value = self.remove_empty_objects(value)
                if cleaned_value is not None:
                    cleaned_dict[key] = cleaned_value
            return cleaned_dict or None
        if isinstance(obj, list):
            cleaned_list = []
            for elem in obj:
                cleaned_elem = self.remove_empty_objects(elem)
                if cleaned_elem is not None:
                    cleaned_list.append(cleaned_elem)
            return cleaned_list or None
        return obj

    def generate(self, query, query_questions, query_questions_customized, query_answers):
        """Build a metadata document from the best matches for each query.

        Args:
            query: Query matched against the items' text.
            query_questions: Query matched against the stored questions.
            query_questions_customized: Query matched against the stored
                customized questions.
            query_answers: Query matched against the stored answers.

        Returns:
            A dict with ``dataset`` (one entry per retrieved metadata item,
            each sharing the same retrieved questions / customized questions /
            answers lists), ``mainPageSpans`` and ``sidePanelSpans``.
        """
        retrieved_items = self.retrieve(
            query, self.tokenized_metadata, source=self.metadata
        )
        # Each retrieval names its own source list so that indices are
        # resolved against the list the token lists were built from.
        retrieved_items_questions = self.retrieve(
            query_questions, self.tokenized_questions,
            is_question=True, source=self.questions,
        )
        retrieved_items_questions_customized = self.retrieve(
            query_questions_customized, self.tokenized_questions_customized,
            is_question=True, source=self.questions_customized,
        )
        retrieved_items_answers = self.retrieve(
            query_answers, self.tokenized_answers,
            is_question=True, source=self.answers,
        )

        new_retrieved_items = []
        for raw_item in retrieved_items:
            # A completely empty item cleans to None; treat it as "no fields".
            item = self.remove_empty_objects(raw_item) or {}
            new_item = {
                "image": item.get("image", []),
                "caption": item.get("caption", []),
                # Works for both a plain string and a list of fragments.
                "text": [self._text_to_string(item.get("text", []))],
                "questions": retrieved_items_questions,
                "questions_customized": retrieved_items_questions_customized,
                "answers": retrieved_items_answers,
                "preference": item.get("preference", []),
                "note": item.get("note", []),
                "summary": item.get("summary", []),
                "subfigureJSON": item.get("subfigureJSON", [])
            }
            new_retrieved_items.append(new_item)

        generated_metadata = {
            "dataset": new_retrieved_items,
            "mainPageSpans": self.main_page_spans,
            "sidePanelSpans": self.side_panel_spans
        }

        return generated_metadata


# Usage example: build the retriever from the JSON files in ./doc, run the
# four retrieval queries, and write the combined result to disk.
# NOTE: these statements run at module level, so the names bound here
# (rag_system, generated_metadata, ...) stay available afterwards.

# Directory scanned for *.json metadata files.
metadata_dir = 'doc'
rag_system = RAGSystem(metadata_dir)

# Query matched against the items' text. The numbered list skips item 3;
# that item is the content of query_questions below.
query = """Analyze the given image and its associated metadata to answer the following questions:

  1. What are the key visual elements in the image?
  2. How does the caption describe the image?
  4. What preferences or requests are mentioned in the metadata?
  5. Are there any notable subfigures or annotations in the image?

  Ignore the empty metadata and retrieve only the relevant information from the dataset to answer these questions. Provide a concise summary of the retrieved information.
"""
# Query matched against the stored questions.
query_questions = """
  3. Are there any specific questions related to the image? If so, provide a summary of the questions.
"""
# Query matched against the stored customized questions.
query_questions_customized = """
  Retrieve any custom questions related to the image.
"""
# Query matched against the stored answers.
query_answers = """
  Retrieve the answers corresponding to the questions related to the image.
"""
generated_metadata = rag_system.generate(query, query_questions, query_questions_customized, query_answers)


# Persist the result as pretty-printed JSON in the current working directory.
# json.dump escapes non-ASCII characters by default (ensure_ascii=True), so
# the bytes written do not depend on the platform's default text encoding.
output_file = 'generated_metadata.json'
with open(output_file, 'w') as file:
    json.dump(generated_metadata, file, indent=4)