In [19]:
import os
import sys

current_dir = os.getcwd()
kit_dir = os.path.abspath(os.path.join(current_dir, ".."))
repo_dir = os.path.abspath(os.path.join(kit_dir, ".."))

sys.path.append(kit_dir)
sys.path.append(repo_dir)

import logging
logger = logging.getLogger()
logger.setLevel(logging.INFO)

In [24]:
import chromadb
import numpy as np
import io
import requests
import json
import shutil
import time
import yaml
from io import BytesIO
from PIL import Image
import pandas as pd
from pandas import DataFrame
from dotenv import load_dotenv
import tarfile
load_dotenv(os.path.join(repo_dir,".env"))

True

## Batch image ingestion

In [126]:
!snapi app list | grep CLIP -A 4

CLIP
Name                : CLIP
ID                  : 6c14325a-1be7-4e48-b38f-19b33745fc3b
Playground          : False
Prediction Input    : text



In [127]:
# create predinction csv
def generate_csv(dataset_dir):
    image_paths = []

    for root, dirs, files in os.walk(dataset_dir):
        for file in files:
            if file.endswith(('.jpg', '.jpeg', '.png', '.gif')):
                image_path = os.path.relpath(os.path.join(root, file), dataset_dir)
                image_paths.append(image_path)

    df = pd.DataFrame({'image_path': image_paths, 'description': '', 'subset': '', 'metadata': ''})
    df.to_csv(os.path.join(dataset_dir,'predictions.csv'), index=False)

# Specify the directory containing your dataset
dataset_directory = '../data/images'
generate_csv(dataset_directory)

In [128]:
def load_config(file_path):
    with open(file_path, 'r') as file:
        config = yaml.safe_load(file)
    return config

config = load_config(os.path.join(kit_dir,'config.yaml'))

PENDING_RDU_JOB_STATUS = 'PENDING_RDU'
SUCCESS_JOB_STATUS = 'EXIT_WITH_0'
FAILED_JOB_STATUS = 'FAILED'

In [129]:
class BatchClipProcessor():
    
    def __init__(self, config) -> None:
        self.headers = {
            'content-type': 'application/json',
            'key': os.getenv('SAMBASTUDIO_KEY'),
        }
        self.datasets_path = f".{config['clip']['datasets']['datasets_path']}"
        self.dataset_id = None
        self.dataset_name = config['clip']['datasets']['dataset_name']
        self.dataset_description = config['clip']['datasets']['dataset_description']
        self.dataset_source_type = config['clip']['datasets']['dataset_source_type']
        self.dataset_source_file = f".{config['clip']['datasets']['dataset_source_file']}"
        
        self.clip_app_id = config['clip']['apps']['clip_app_id']
        self.application_field = config['clip']['apps']['application_field']
        
        self.base_url = config['clip']['urls']['base_url']
        self.datasets_url = config['clip']['urls']['datasets_url'] 
        self.projects_url = config['clip']['urls']['projects_url'] 
        self.jobs_url = config['clip']['urls']['jobs_url'] 
        self.download_results_url = config['clip']['urls']['download_results_url'] 
    
        self.project_name = config['clip']['projects']['project_name']
        self.project_description = config['clip']['projects']['project_description']
        self.project_id=None
        
        self.job_name = config['clip']['jobs']['job_name']
        self.job_task = config['clip']['jobs']['job_task']
        self.job_type = config['clip']['jobs']['job_type']
        self.job_description = config['clip']['jobs']['job_description']
        self.model_checkpoint = config['clip']['jobs']['model_checkpoint']
        
        self.output_path = config['clip']['output']['output_path']
        
        
    def _get_call(self, url, params = None, success_message = None):
        response = requests.get(url, params=params, headers=self.headers)

        if response.status_code == 200:
            logging.info('GET request successful!')
            logging.info(success_message)
            logging.debug(f'Response: {response.text}')
        else:
            logging.error(f'GET request failed with status code: {response.status_code}')
            logging.error(f'Error message: {response.text}')
        return response

    def _post_call(self, url, params, success_message = None):
        response = requests.post(url, json=params, headers=self.headers)

        if response.status_code == 200:
            logging.info('POST request successful!')
            logging.info(success_message)
            logging.debug(f'Response: {response.text}')
        else:
            logging.error(f'POST request failed with status code: {response.status_code}')
            raise Exception(f'Error message: {response.text}')
        return response
    
    def _delete_call(self, url):
        response = requests.delete(url, headers=self.headers)    
        if response.status_code == 200:
            logging.info(f'Dataset {self.dataset_name} deleted successfully.')
            logging.debug(f'Response: {response.text}')
        else:
            logging.error(f'Failed to delete the resource. Status code: {response.status_code}')
            raise Exception(f'Error message: {response.text}')    
        return response
    def _generate_csv(self, dataset_dir):
        image_paths = []
        for root, dirs, files in os.walk(dataset_dir):
            for file in files:
                if file.endswith(('.jpg', '.jpeg', '.png', '.gif')):
                    image_path = os.path.relpath(os.path.join(root, file), dataset_dir)
                    image_paths.append(image_path)

        df = pd.DataFrame({'image_path': image_paths, 'description': '', 'subset': '', 'metadata': ''})
        df.to_csv(os.path.join(dataset_dir,'predictions.csv'), index=False)

    def _get_df_output(self, response_content: str) -> DataFrame:
        compressed_bytes = io.BytesIO(response_content)
        
        with tarfile.open(fileobj=compressed_bytes, mode="r:gz") as tar:
            output_tar_member = tar.getmember(self.output_path)
            output_file = tar.extractfile(output_tar_member)
            output_df = pd.read_json(io.BytesIO(output_file.read()), lines=True)       
        return output_df

    def search_dataset(self, dataset_name):
        url = self.base_url + self.datasets_url + '/search'
        params = {
            'dataset_name': dataset_name
        }
        response = self._get_call(url, params, f'Dataset {dataset_name} found in SambaStudio')
        parsed_reponse = json.loads(response.text)
        return parsed_reponse['data']['dataset_id']

    def delete_dataset(self, dataset_name):
        dataset_id = self.search_dataset(dataset_name)
        url = self.base_url + self.datasets_url + '/' + dataset_id
        response = self._delete_call(url)
        logging.info(response.text)
        
        
    def create_dataset(self, path):
        # create clip directory and source.json file
        
        dataset_name = f'{self.dataset_name}_{int(time.time())}'
            
        clip_directory = os.path.join(self.datasets_path, dataset_name)
        
        if not os.path.isdir(self.datasets_path):
            os.mkdir(self.datasets_path) 
            
        if not os.path.isdir(clip_directory):
            logging.info(f'Datasets path: {clip_directory} wan \'t found')
            
            source_file_data = {
                "source_path": clip_directory
            }
            
            with open(self.dataset_source_file, 'w') as json_file:
                json.dump(source_file_data, json_file)

        shutil.copytree(path, clip_directory)
        
        self._generate_csv(clip_directory)
        
        # create dataset
        command = f'echo yes | snapi dataset add \
            --dataset-name {dataset_name} \
            --job_type {self.job_type} \
            --apps {self.clip_app_id} \
            --source_type {self.dataset_source_type} \
            --source_file {self.dataset_source_file} \
            --application_field {self.application_field} \
            --description "{self.dataset_description}"'
        
        os.system(command)
        logging.info(f'Creating dataset: {dataset_name}')
        
        return dataset_name
         
    def check_dataset_creation_progress(self, dataset_name):
        url = self.base_url + self.datasets_url + '/' + dataset_name
        response = self._get_call(url)
        if response.json()["data"]["status"]=="Available": 
            return True
        else:
            return False
            
    def create_load_project(self):

        url = self.base_url + self.projects_url + '/' + self.project_name

        response = self._get_call(url, success_message=f'Project {self.project_name} found in SambaStudio')
        not_found_error_message = f"{self.project_name} not found"

        if not_found_error_message in response.text:
            
            logging.info(f'Project {self.project_name} wasn\'t found in SambaStudio')
            
            url = self.base_url + self.projects_url

            params = {
                'project_name': self.project_name,
                'description': self.project_description
            }

            response = self._post_call(url, params, success_message=f'Project {self.project_name} created!')

        parsed_reponse = json.loads(response.text)
        self.project_id = parsed_reponse['data']['project_id']
        return self.project_id
    
    def run_job(self, dataset_name):
        
        url = self.base_url + self.projects_url + self.jobs_url.format(project_id=self.project_id)
        
        params = {
            'task': self.job_task,
            'job_type': self.job_type,
            'job_name': f'{self.job_name}_{int(time.time())}',
            'project': self.project_id,
            'model_checkpoint': self.model_checkpoint,
            'description': self.job_description,
            'dataset': dataset_name,
        }

        response = self._post_call(url, params, success_message='Job running')
        parsed_reponse = json.loads(response.text)
        job_id = parsed_reponse['data']['job_id']
        
        return job_id
    
    def check_job_progress(self, job_id):
        """Check job progress of a given job.

        Args:
            job_id (str): The id of the job to check.
            
        Returns:
            bool: True when the job is finished.
        """

        url = self.base_url + self.projects_url + self.jobs_url.format(project_id=self.project_id) + '/' + job_id

        status = PENDING_RDU_JOB_STATUS
        while status != SUCCESS_JOB_STATUS:
            response = self._get_call(url, success_message='Still waiting for job to finish')
            parsed_reponse = json.loads(response.text)   
            status = parsed_reponse['data']['status']
            logging.info(f'Job status: {status}')
            if status == SUCCESS_JOB_STATUS:
                logging.info('Job finished!')
                break
            elif status == FAILED_JOB_STATUS:
                logging.info('Job failed!')
                return False
            time.sleep(10)
        
        return True  
    
    def delete_job(self, job_id):
        url = self.base_url +  self.projects_url + self.jobs_url.format(project_id=self.project_id) + '/' + job_id
        response = self._delete_call(url)
        logging.info(response.text)
        
    def retrieve_results(self, job_id):
        url = self.base_url + self.projects_url + self.jobs_url.format(project_id=self.project_id) + '/' + job_id + self.download_results_url
        response = self._get_call(url, success_message='Results downloaded!')
        df = self._get_df_output(response.content)
        return df

In [130]:
clip = BatchClipProcessor(config)

In [131]:
clip.create_load_project()

INFO:root:GET request successful!
INFO:root:Project image_search_project found in SambaStudio


'edcd0f67-0f39-4775-8ba7-d76327894b3f'

In [132]:
dataset_name = clip.create_dataset(path=os.path.join(kit_dir,'data/images'))
while not clip.check_dataset_creation_progress(dataset_name):
    print("waiting for dataset available")
    time.sleep(1)

INFO:root:Datasets path: ../data/datasets/images_dataset_1710352467 wan 't found



Folder Information:
  - Number of Files: 41
  - Total Size: 16.24 MB

Are you sure you want to proceed? ([33myes[0m/no)
: Uploading files


INFO:root:Creating dataset: images_dataset_1710352467


Dataset folder upload complete: ../data/datasets/images_dataset_1710352467
Dataset added successfully.
Time taken to upload the dataset: 49.43410301208496 seconds


INFO:root:GET request successful!
INFO:root:None


In [133]:
clip.search_dataset(dataset_name)

INFO:root:GET request successful!
INFO:root:Dataset images_dataset_1710352467 found in SambaStudio


'3ef23e2b-33e5-499d-bc51-339b017a9c83'

In [134]:
job_id = clip.run_job(dataset_name)

INFO:root:POST request successful!
INFO:root:Job running


In [135]:
result = clip.check_job_progress(job_id) 

INFO:root:GET request successful!
INFO:root:Still waiting for job to finish
INFO:root:Job status: PENDING_RDU
INFO:root:GET request successful!
INFO:root:Still waiting for job to finish
INFO:root:Job status: PENDING_RDU
INFO:root:GET request successful!
INFO:root:Still waiting for job to finish
INFO:root:Job status: PENDING_RDU
INFO:root:GET request successful!
INFO:root:Still waiting for job to finish
INFO:root:Job status: EXIT_WITH_0
INFO:root:Job finished!


In [136]:
df=clip.retrieve_results(job_id)

INFO:root:GET request successful!
INFO:root:Results downloaded!


In [137]:
clip.delete_job(job_id)

INFO:root:Dataset images_dataset deleted successfully.
INFO:root:{}


In [138]:
clip.delete_dataset(dataset_name)

INFO:root:GET request successful!
INFO:root:Dataset images_dataset_1710352467 found in SambaStudio
INFO:root:Dataset images_dataset deleted successfully.
INFO:root:{"detail":"The Dataset: 3ef23e2b-33e5-499d-bc51-339b017a9c83 was successfully marked for deletion from the Dataset Hub."}


In [139]:
df.head()

Unnamed: 0,predictions,input,type
0,"[-0.002137135714292, 0.278615444898605, -0.159...",art/art_1.png,img
1,"[0.029079457744956003, 0.12692214548587802, -0...",places/places_2.png,img
2,"[0.018383597955107002, -0.15706798434257502, -...",appliances/appliances_1.png,img
3,"[-0.056953176856040004, 0.11261384934186901, -...",art/art_0.png,img
4,"[0.04671736434102, 0.208280116319656, -0.07212...",nature/nature_2.png,img


In [140]:

from chromadb.api.types import is_image, is_document, Images,  Documents, EmbeddingFunction, Embeddings, Protocol
from typing import cast, Union, TypeVar

## chromadb multimodal

In [141]:
Embeddable = Union[Documents, Images]
D = TypeVar("D", bound=Embeddable, contravariant=True)

class ClipEmbbeding(EmbeddingFunction[D]):
    def __init__(self) -> None:
        pass
    def __call__(self, input: D) -> Embeddings:
        embeddings: Embeddings = []
        for item in input:     
            if is_document(item):
                #TODO implement SN endpoint inference
                output = None
            elif is_image(item):
                image = Image.fromarray(item)
                buffer = io.BytesIO()
                image.save(buffer, format='PNG')
                buffer
                #TODO implement SN endpoint inference
                output = None
            embeddings.append(output["embedding"])
        return cast(Embeddings, embeddings)


In [161]:
client = chromadb.PersistentClient(path=os.path.join(kit_dir,"data"))
clip_embedding=ClipEmbbeding()
try:
    client.delete_collection(name="image_collection")
except:
    pass
collection=client.get_or_create_collection(name="image_collection", embedding_function=clip_embedding, metadata={"hnsw:space": "l2"})
collection.get()


{'ids': [],
 'embeddings': None,
 'metadatas': [],
 'documents': [],
 'uris': None,
 'data': None}

## Add individual images 

In [None]:
def get_images(folder_path):
    images=[]
    paths=[]
    for root, _dirs, files in os.walk(folder_path):
        for file in files:
            if file.endswith(".jpg") or file.endswith(".jpeg") or file.endswith(".png"):
                path=os.path.join(root, file)
                paths.append(path)
                image= np.array(Image.open(os.path.join(root, file)))
                images.append(image)
    return paths,images

In [None]:
paths, images=get_images(os.path.join(kit_dir,"data/images"))
print(len(paths))

400


In [None]:
collection.add(
    images=images,
    metadatas=[{"source": path} for path in paths],
    ids=paths,
    uris=paths
)

In [143]:
collection.get()#include=["uris","documents"])

{'ids': [],
 'embeddings': None,
 'metadatas': [],
 'documents': [],
 'uris': None,
 'data': None}

## Add batch preprocessed images

In [162]:
embeddings = list(df["predictions"]) 
paths = list(df["input"].apply(lambda x: os.path.join(kit_dir,'data/images',x)))

In [163]:
collection.add(
    embeddings=embeddings,
    metadatas=[{"source": path} for path in paths],
    ids=paths,
    uris=paths
)

In [164]:
collection.get()#include=["uris","documents"])

{'ids': ['/Users/jorgep/Documents/ask_public_own/ai-starter-kit/image_search/data/images/animals/animals_0.png',
  '/Users/jorgep/Documents/ask_public_own/ai-starter-kit/image_search/data/images/animals/animals_2.png',
  '/Users/jorgep/Documents/ask_public_own/ai-starter-kit/image_search/data/images/animals/animals_3.png',
  '/Users/jorgep/Documents/ask_public_own/ai-starter-kit/image_search/data/images/animals/animals_4.png',
  '/Users/jorgep/Documents/ask_public_own/ai-starter-kit/image_search/data/images/appliances/appliances_1.png',
  '/Users/jorgep/Documents/ask_public_own/ai-starter-kit/image_search/data/images/appliances/appliances_4.png',
  '/Users/jorgep/Documents/ask_public_own/ai-starter-kit/image_search/data/images/art/art_0.png',
  '/Users/jorgep/Documents/ask_public_own/ai-starter-kit/image_search/data/images/art/art_1.png',
  '/Users/jorgep/Documents/ask_public_own/ai-starter-kit/image_search/data/images/art/art_2.png',
  '/Users/jorgep/Documents/ask_public_own/ai-starte

## Search methods

In [169]:
def search_image_by_text(query,n=1):
    result=collection.query(query_texts=[query],include=["uris", "distances"],n_results=n)
    return result['uris'][0], result["distances"][0]

In [170]:
def search_image_by_image(path,n=5):
    image= np.array(Image.open(path))
    result=collection.query(query_images=[image],include=["uris", "distances"],n_results=n)
    return result['uris'][0], result["distances"][0]

In [171]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

def show_images(image_paths, distances):
    num_images = len(image_paths)
    fig, axes = plt.subplots(1, num_images, figsize=(10*num_images, 10))
    
    for i, path in enumerate(image_paths):
        img = mpimg.imread(path)
        axes[i].imshow(img)
        axes[i].axis('off')
        axes[i].set_title(f'Image {i+1}, d={distances[i]}')
    
    plt.show()


In [None]:
uris, distances = search_image_by_text("birds")
show_images(uris, distances)

In [None]:
uris, distances = search_image_by_image("../download.jpeg")
uris.insert(0, "../download.jpeg")
distances.insert(0, 0)
show_images(uris, distances)