In [1]:
import os

import fiftyone as fo
import numpy as np
from deep_translator import GoogleTranslator
from flask import Flask, request, jsonify
from sklearn.metrics.pairwise import cosine_similarity

from routes.interact_with_csv_files import csv_routes
from utils import dataset_manager, model_manager
from utils import dataset_manager as dataset_manager_module

KeyboardInterrupt: 

In [None]:
# Root of the AIC data folder -- organize it as described in the GitHub README.
# Set the AIC_DATA_DIR environment variable to point elsewhere; by default the
# folder is expected at ../data relative to the notebook's working directory.
data_dir = os.environ.get('AIC_DATA_DIR', os.path.join('..', 'data'))

# Build the Dataset from the module alias (`dataset_manager_module`) so the
# name `dataset_manager`, which later cells use for the *instance*, no longer
# shadows the module. Rebinding the module name made this cell fail with an
# AttributeError when re-run.
dataset_manager = dataset_manager_module.Dataset(dataset_name='AIC_2024',
                                                 data_dir=data_dir)
dataset_manager.load_metadata()
dataset = dataset_manager.get_fiftyone_dataset()

In [None]:
# CLIP ViT-L/14 model; its `inference` output is compared against the image
# CLIP embeddings in the search functions below.
model_clip14 = model_manager.CLIP_14_model()

# TASK-former weights and config must be downloaded separately before running
# this cell (contact vtphatt2 for the link) and placed under ./task-former.
_task_former_dir = os.path.join(os.getcwd(), 'task-former')
model_config_file = os.path.join(_task_former_dir, 'code', 'training',
                                 'model_configs', 'ViT-B-16.json')
model_file = os.path.join(_task_former_dir, 'model', 'tsbir_model_final.pt')
model_task_former = model_manager.TASK_former_model(
    model_config_file=model_config_file,
    model_file=model_file,
)

In [4]:
def searchByText(text_query, k = 200, keyframe_gap = 8):
    """Rank images by CLIP similarity to ``text_query`` and group nearby hits.

    The query is embedded with ``model_clip14.inference`` and compared (cosine
    similarity) with every embedding returned by
    ``dataset_manager.get_image_clip14_embeddings()``. The ``k`` best-scoring
    images are then grouped greedily: the best not-yet-grouped hit becomes an
    anchor, and every lower-ranked, not-yet-grouped hit from the same video
    whose ``keyframe_id`` differs from the anchor's by less than
    ``keyframe_gap`` joins the anchor's group.

    Args:
        text_query: Text to search for (the API translates it to English first).
        k: Maximum number of top-scoring images to consider. Clamped to the
            number of available embeddings; ``k <= 0`` returns an empty list.
        keyframe_gap: Strict upper bound on the ``keyframe_id`` distance from
            a group's anchor for a hit to be merged into that group.

    Returns:
        A list of ``[video_name, [(filepath, frame_id), ...]]`` entries, ordered
        by the score of each group's anchor. Frames within a group are sorted
        by ``frame_id``.
    """
    similarities = cosine_similarity([model_clip14.inference(text_query)],
                                     dataset_manager.get_image_clip14_embeddings())[0]
    # Clamp k. With fewer embeddings than k, iterating range(k) would index past
    # the end of top_k_indices. With k == 0, argsort()[-0:] would select every index.
    k = min(k, len(similarities))
    if k <= 0:
        return []
    top_k_indices = similarities.argsort()[-k:][::-1]
    image_samples = dataset_manager.get_image_samples()
    # Sample records for the top-k hits, best first.
    hits = [image_samples[idx] for idx in top_k_indices]

    submission_list = []
    visited = [False] * k
    for i in range(k):
        if visited[i]:
            continue
        anchor = hits[i]
        video_name = anchor['video']
        anchor_keyframe = int(anchor['keyframe_id'])
        frames = [(anchor['filepath'], int(anchor['frame_id']))]
        visited[i] = True
        for j in range(i + 1, k):
            other = hits[j]
            if (not visited[j] and other['video'] == video_name
                    and abs(anchor_keyframe - int(other['keyframe_id'])) < keyframe_gap):
                frames.append((other['filepath'], int(other['frame_id'])))
                visited[j] = True
        frames.sort(key=lambda frame: frame[1])
        submission_list.append([video_name, frames])
    return submission_list

In [38]:
def temporalSearch(text_first_this, text_then_that, k = 100, range_size = 8, stride = 10):
    """Find windows of consecutive samples where one description is followed by another.

    Each video's samples (indices ``low..high`` from
    ``dataset_manager.get_video_range()``) are scanned with windows of
    ``range_size`` consecutive samples, starting every ``stride`` samples.

    A window's score is the mean of two values:
    - the best cosine similarity of ``text_first_this`` over the leading 70% of the window;
    - the best cosine similarity of ``text_then_that`` over the trailing 70% (from the 30% mark).

    The two parts overlap.

    Args:
        text_first_this: Description of the earlier event (English).
        text_then_that: Description of the later event (English).
        k: Number of best-scoring windows to return.
        range_size: Window length in samples; must be >= 1.
        stride: Step between consecutive window starts; must be >= 1.

    Returns:
        A list of ``[video_name, [(filepath, frame_id), ...]]`` entries, one per
        window, in descending score order; each holds all ``range_size``
        samples of the window.

    Raises:
        ValueError: If ``range_size`` or ``stride`` is less than 1.
    """
    if range_size < 1:
        raise ValueError("range_size must be >= 1")
    if stride < 1:
        raise ValueError("stride must be >= 1")

    video_range = dataset_manager.get_video_range()
    image_samples = dataset_manager.get_image_samples()

    first_embedding = model_clip14.inference(text_first_this)
    then_embedding = model_clip14.inference(text_then_that)

    # Window split points (same 70% / 30% marks as before). The max/min clamps
    # only matter for tiny windows (range_size == 1), where int(0.7 * 1) == 0
    # used to produce an empty slice and crash cosine_similarity.
    first_end = max(1, int(0.7 * range_size))
    then_start = min(range_size - 1, int(0.3 * range_size))

    results = []
    for video_name in video_range.keys():
        low = video_range[video_name][0]
        high = video_range[video_name][1] + 1
        num_vectors = high - low
        if num_vectors < range_size:
            continue  # no complete window fits in this video
        vectors = [image_samples[idx]['clip-14'] for idx in range(low, high)]
        # Score every sample once per video instead of once per window.
        first_sims = cosine_similarity([first_embedding], vectors)[0]
        then_sims = cosine_similarity([then_embedding], vectors)[0]

        for start in range(0, num_vectors - range_size + 1, stride):
            end = start + range_size
            block_similarity = (np.max(first_sims[start:start + first_end])
                                + np.max(then_sims[start + then_start:end])) / 2
            results.append((block_similarity, video_name, start))

    results.sort(key=lambda result: result[0], reverse=True)

    submission_list = []
    for _, video_name, start in results[:k]:
        low = video_range[video_name][0]
        frames = []
        for offset in range(start, start + range_size):
            sample = image_samples[low + offset]
            frames.append((sample['filepath'], int(sample['frame_id'])))
        submission_list.append([video_name, frames])

    return submission_list

In [None]:
app = Flask(__name__)

# Register the blueprint with the main app
app.register_blueprint(csv_routes)


def _no_cache(response):
    """Mark ``response`` as non-cacheable (browsers and proxies) and return it."""
    response.headers['Cache-Control'] = 'no-store, no-cache, must-revalidate, max-age=0'
    response.headers['Pragma'] = 'no-cache'
    response.headers['Expires'] = '0'
    return response


def _translate_vi_to_en(text):
    """Translate Vietnamese ``text`` to English with deep_translator's GoogleTranslator."""
    # A fresh translator per call, as before: don't share one instance across
    # requests unless it is known to be safe for concurrent use.
    return GoogleTranslator(source='vi', target='en').translate(text)


def _text_fields(*names):
    """Read non-empty string fields ``names`` from the JSON request body.

    Returns ``(values, None)`` on success. Otherwise returns ``(None, error)``,
    where ``error`` is a ``(response, 400)`` tuple naming the first missing or
    invalid field. A missing, non-JSON, or non-object body counts as every
    field missing.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    values = []
    for name in names:
        value = data.get(name)
        if not isinstance(value, str) or not value.strip():
            error = jsonify({"error": f"'{name}' must be a non-empty string"})
            return None, (error, 400)
        values.append(value)
    return values, None


@app.route('/search_by_text', methods=['POST'])
def search_by_text():
    """Text search. Expects a JSON body ``{"searchText": <Vietnamese text>}``.

    The text is translated to English and passed to ``searchByText`` (k=100).
    Responds with the translation and the grouped ``submission_list``, or with
    400 if ``searchText`` is missing or blank.
    """
    fields, error = _text_fields('searchText')
    if error is not None:
        return error
    (search_text,) = fields

    translated_text = _translate_vi_to_en(search_text)
    submission_list = searchByText(translated_text, k=100)

    response = jsonify({
        "translated_text": translated_text,
        "submission_list": submission_list,
    })
    return _no_cache(response), 200


@app.route('/temporal_search', methods=['POST'])
def temporal_search():
    """Two-step temporal search.

    Expects a JSON body ``{"textFirstThis": ..., "textThenThat": ...}``
    containing Vietnamese text. Both texts are translated to English and passed
    to ``temporalSearch`` (k=30, range_size=15). Responds with both translations
    and the ``submission_list``, or with 400 if either field is missing or blank.
    """
    fields, error = _text_fields('textFirstThis', 'textThenThat')
    if error is not None:
        return error
    text_first_this, text_then_that = fields

    translated_first_this = _translate_vi_to_en(text_first_this)
    translated_then_that = _translate_vi_to_en(text_then_that)

    submission_list = temporalSearch(translated_first_this, translated_then_that,
                                     k=30, range_size=15)

    response = jsonify({
        "translated_first_this": translated_first_this,
        "translated_then_that": translated_then_that,
        "submission_list": submission_list,
    })
    return _no_cache(response), 200


# Blocks this cell until the server is stopped. debug=True enables Werkzeug's
# interactive debugger, so keep the server on localhost (app.run's default
# host) and never expose it publicly. The auto-reloader is off because it
# restarts the Python process, which does not work from a notebook kernel.
app.run(debug=True, use_reloader=False)