# Start the websocket server in the background

In [91]:
import os

def getListOfFiles(directory):
    """Group every file under *directory* by its top-level subdirectory.

    Returns a dict that maps the name of each first-level subdirectory (the
    "whale name") to a list of ``(whale_name, full_path)`` tuples, one for
    every file found at any depth beneath that subdirectory.

    Files located directly in *directory* are not included. A subdirectory
    that contains no files still gets an (empty) entry.
    """
    grouped = {}

    for current_dir, _, file_names in os.walk(directory):
        # The first component of the path relative to `directory` is the
        # whale name, no matter how deeply `current_dir` is nested.
        top_level = os.path.relpath(current_dir, directory).split(os.sep)[0]

        # "." means we are looking at `directory` itself, which has no whale.
        if top_level == ".":
            continue

        entries = grouped.setdefault(top_level, [])
        entries.extend(
            (top_level, os.path.join(current_dir, name)) for name in file_names
        )

    return grouped

In [92]:
import random
import json
import websockets

async def extract(allWhales, sample_size=100, uri="ws://localhost:8765"):
    """Send an "extract" request for a random sample of files.

    Args:
        allWhales: Mapping of whale name to a list of
            ``(whale_name, file_path)`` tuples, as built by
            ``getListOfFiles``.
        sample_size: Maximum number of files to sample. When fewer files are
            available, all of them are sent instead of failing.
        uri: Websocket endpoint to send the request to.

    Returns:
        The parsed JSON of the first received message that starts with
        ``[``, or ``None`` if the connection is closed before such a message
        arrives.

    Raises:
        ValueError: If ``allWhales`` contains no files, or ``sample_size`` is
            negative.
    """
    # Flatten the per-whale lists into one list of file paths.
    all_files = [
        file_path
        for whale_files in allWhales.values()
        for _, file_path in whale_files
    ]
    if not all_files:
        raise ValueError("allWhales contains no files to sample")

    # random.sample raises ValueError when asked for more items than the
    # population holds, so cap the request at what is actually available.
    random_files = random.sample(all_files, min(sample_size, len(all_files)))

    request_payload = {
        "type": "extract",
        "fileNames": random_files,
    }

    async with websockets.connect(uri) as websocket:
        await websocket.send(json.dumps(request_payload))

        try:
            while True:
                response = await websocket.recv()

                # Only a message that looks like a JSON array is treated as
                # the result; anything else received before it is ignored.
                if response.startswith('['):
                    return json.loads(response)
        except (
            websockets.exceptions.ConnectionClosedOK,
            websockets.exceptions.ConnectionClosedError,
        ):
            # The connection ended before a result arrived.
            return None

In [93]:
def format_identify_response(response):
    """Build an "identify" request payload from a sequence of items.

    Each item must provide the keys ``"path"``, ``"type"`` and
    ``"embedding"``; any other keys are left out of the payload. A missing
    key raises ``KeyError``.
    """
    wanted_keys = ("path", "type", "embedding")

    entries = []
    for item in response:
        entries.append({key: item[key] for key in wanted_keys})

    payload = {"type": "identify"}
    payload["data"] = entries
    return payload

async def identify(embeddings, uri="ws://localhost:8765"):
    """Send an "identify" request built from *embeddings*.

    Args:
        embeddings: Iterable of items with ``"path"``, ``"type"`` and
            ``"embedding"`` keys; it is converted to the request payload by
            ``format_identify_response``.
        uri: Websocket endpoint to send the request to.

    Returns:
        The parsed JSON of the first received message that starts with
        ``[``, or ``None`` if the connection is closed before such a message
        arrives.
    """
    request_payload = format_identify_response(embeddings)

    async with websockets.connect(uri) as websocket:
        await websocket.send(json.dumps(request_payload))

        try:
            while True:
                response = await websocket.recv()

                # Only a message that looks like a JSON array is treated as
                # the result; anything else received before it is ignored.
                if response.startswith('['):
                    return json.loads(response)
        except (
            websockets.exceptions.ConnectionClosedOK,
            websockets.exceptions.ConnectionClosedError,
        ):
            # The connection ended before a result arrived.
            return None


In [94]:
import time

allWhales = getListOfFiles("G:\\Whale Stuff\\_data\\cetaceans")

start_time = time.time()

extracted = await extract(allWhales)
print("Extracted!")

identified = await identify(extracted)
print("Identified!")

extracted = await extract(allWhales)
print("Extracted!")

identified = await identify(extracted)
print("Identified!")

end_time = time.time()

print(f"Time taken: {end_time - start_time} seconds")
print(f"Time per image: {(end_time - start_time) / len(identified)} seconds")


Extracted!
Identified!
Extracted!
Identified!
Time taken: 38.86849570274353 seconds
Time per image: 0.38868495702743533 seconds
