# Qdrant & Image Data

Using Semantic Search for Accurate Skin Cancer Image Comparison

Vector databases are a relatively new way of interacting with abstract data representations derived from opaque machine learning models, with deep learning architectures being the most common. These representations are often called vectors or embeddings, and they are a compressed version of the data used to train a machine learning model to accomplish a task (e.g., sentiment analysis, speech recognition, or object detection).

## Table of Contents

1. Overview
2. Set Up
3. Image Embeddings
4. Semantic Search
5. Conclusion

## 1. Overview

This tutorial provides an in-depth walkthrough of how to apply semantic search techniques to image data. In particular,
we'll go over an example of how to assist doctors in comparing rare or challenging images of skin cancer with pre-labeled
images categorized with different diseases. By leveraging the power of semantic search, medical professionals can enhance 
their diagnostic capabilities and make more accurate decisions regarding skin cancer diagnosis. That said, you can swap the 
dataset used in this tutorial and follow along with minimal adjustments to the code. 

The dataset used can be found on the [Hugging Face Hub](https://huggingface.co/datasets/marmal88/skin_cancer), and the
`datasets` library fetches it for you when you load it. Here is a short description of each of the variables available.

- `image` - PIL object of size 600x450
- `image_id` - unique id for the image
- `lesion_id` - unique id for the type of lesion on the skin of the patient
- `dx` - diagnosis given to the patient (e.g., melanocytic_Nevi, melanoma, benign_keratosis-like_lesions, basal_cell_carcinoma, 
actinic_keratoses, vascular_lesions, dermatofibroma)
- `dx_type` - type of diagnosis (e.g., histo, follow_up, consensus, confocal)
- `age` - the age of the patient, ranging from 5 to 86 (some values are missing)
- `sex` - the gender of the patient (female, male, and unknown)
- `localization` - location of the spot in the body (e.g., 'lower extremity', 'upper extremity', 'neck', 'face', 'back', 
'chest', 'ear', 'abdomen', 'scalp', 'hand', 'trunk', 'unknown', 'foot', 'genital', 'acral')

By the end of the tutorial, you will be able to extract embeddings from images using transformers and conduct image-to-image semantic search with Qdrant.

## 2. Set Up

Before you run any line of code, please make sure you have 
1. downloaded the data
2. created a virtual environment (if not in Google Colab)
3. installed the packages below
4. started a container with Qdrant

```bash
# with conda or mamba if you have it installed
mamba create -n my_env python=3.10
mamba activate my_env

# or with virtualenv
python -m venv venv
source venv/bin/activate

# install packages
pip install qdrant-client transformers datasets torch numpy
```

The open-source version of Qdrant is available as a Docker image that can be pulled and run on any machine with Docker installed. If you don't have Docker installed on your machine, follow the instructions in the [official documentation](https://docs.docker.com/get-docker/). After that, open your terminal and download the image with the following command.

```sh
docker pull qdrant/qdrant
```

Next, initialize Qdrant with the following command, and you should be good to go.

```sh
docker run -p 6333:6333 \
    -v $(pwd)/qdrant_storage:/qdrant/storage \
    qdrant/qdrant
```

Verify that you are ready to go by importing the following libraries and connecting to Qdrant via its Python client.

In [None]:
from transformers import ViTImageProcessor, ViTModel
from qdrant_client import QdrantClient
from qdrant_client.http import models
from datasets import load_dataset
import numpy as np
import torch

In [None]:
client = QdrantClient(host="localhost", port=6333)

In [None]:
my_collection = "image_collection"
client.recreate_collection(
    collection_name=my_collection,
    vectors_config=models.VectorParams(size=768, distance=models.Distance.COSINE)
)
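
The collection above is configured with `Distance.COSINE`, which means points are compared by the angle between their vectors rather than by their length. The following is a minimal sketch of that measure in plain Python, assuming toy 3-dimensional vectors with made-up values (real embeddings in this tutorial have 768 dimensions). The helper name `cosine_similarity` is hypothetical and is not part of the Qdrant client.

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

# Toy "embeddings" (hypothetical values, not real model output).
v = [1.0, 2.0, 3.0]
scaled = [2.0, 4.0, 6.0]        # same direction as v, twice the length
orthogonal = [3.0, 0.0, -1.0]   # dot product with v is zero

same_direction = cosine_similarity(v, scaled)
unrelated = cosine_similarity(v, orthogonal)
print(same_direction, unrelated)
```

Vectors that point the same way score close to 1 regardless of their magnitude, while perpendicular vectors score close to 0.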

## 3. Image Embeddings

In computer vision systems, vector databases are used to store image features. Image features are vector representations 
of images that capture their visual content, and they are used to improve the performance of computer vision tasks such 
as object detection, image classification, and image retrieval.

To extract these useful feature representations from our images, we'll use vision transformers (ViTs). ViTs are
models that enable computers to "see" and understand visual information in a way loosely similar to how humans do. They
use a transformer architecture to process images and extract meaningful features from them.

To understand how ViTs work, imagine you have a large jigsaw puzzle with many different pieces. To solve the puzzle, 
you would typically look at the individual pieces, their shapes, and how they fit together to form the full picture. ViTs 
work in a similar fashion, meaning, instead of looking at the entire image at once, vision transformers break it down 
into smaller parts called "patches." Each of these patches is like one piece of the puzzle that captures a specific portion 
of the image, and these patches are then analyzed and processed by the ViTs.

By analyzing these patches, ViTs identify important patterns, such as edges, colors, and textures, and combine them
to form a coherent understanding of a given image.

That said, let's get started using a transformer to analyze and interpret our images more effectively.
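
The patch idea can be sketched in plain Python. The sizes below are assumptions: a 224x224 model input and 16x16 patches, which is what the `vitb16` suffix of the checkpoint used later in this tutorial suggests. The helper names `count_patches` and `split_into_patches` are hypothetical, and a real ViT additionally projects each patch to a vector and prepends one extra classification token.

```python
def count_patches(height, width, patch_size):
    """Number of non-overlapping square patches that tile an image."""
    if height % patch_size or width % patch_size:
        raise ValueError("image sides must be divisible by the patch size")
    return (height // patch_size) * (width // patch_size)

def split_into_patches(image, patch_size):
    """Split a 2-D grid (list of rows) into patch grids, in row-major order."""
    rows, cols = len(image), len(image[0])
    patches = []
    for top in range(0, rows, patch_size):
        for left in range(0, cols, patch_size):
            patch = [row[left:left + patch_size]
                     for row in image[top:top + patch_size]]
            patches.append(patch)
    return patches

# Assumed sizes: 224x224 input, 16x16 patches.
n_patches = count_patches(224, 224, 16)
sequence_length = n_patches + 1  # one extra classification token

# A tiny 4x4 "image" split into 2x2 patches to make the tiling visible.
toy_image = [[r * 4 + c for c in range(4)] for r in range(4)]
toy_patches = split_into_patches(toy_image, 2)
print(n_patches, sequence_length, len(toy_patches))
```

Under these assumptions the model sees a sequence of 197 tokens per image, which matches the number of embedding vectors mentioned later in this section.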

We'll start by reading in the data and examining one sample.

In [None]:
dataset = load_dataset("marmal88/skin_cancer", split='train')
dataset

In [None]:
dataset[8500]

In [None]:
image = dataset[8500]["image"]
image

The image at index 8500, as shown above, is an instance of melanoma, which is a type of skin cancer that starts in the cells called melanocytes. These are responsible for producing a pigment called melanin that gives color to our skin, hair, and eyes. When melanocytes become damaged or mutate, they can start growing and dividing rapidly, forming a cancerous growth known as melanoma. Melanoma often appears as an unusual or changing mole, spot, or growth on the skin, and it can be caused by excessive exposure to ultraviolet (UV) radiation from the sun or tanning beds, as well as genetic factors. If detected early, melanoma can usually be treated successfully, but if left untreated, it can spread to other parts of the body and become more difficult to treat.

Melanoma can often be difficult to detect, so we want to empower doctors with the ability to compare and contrast cases that are difficult to classify without invasive procedures (e.g., taking a sample of the patient's skin).

In order to search through the images and provide the most similar ones to the doctors, we'll need to download a pre-trained model that will help us extract embeddings from the images in our dataset. We'll do this using the transformers library and Facebook's DINO model.

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processor = ViTImageProcessor.from_pretrained('facebook/dino-vitb16')
model = ViTModel.from_pretrained('facebook/dino-vitb16').to(device)

Let's process the instance of melanoma we selected earlier.

In [None]:
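inputs = processor(images=image, return_tensors="pt").to(device)  # keep inputs on the same device as the model
inputs['pixel_values'].shape, inputs

In [None]:
with torch.no_grad():  # inference only, no gradients needed
    one_embedding = model(**inputs).last_hidden_state
one_embedding.shape, one_embedding[0, 0, :20]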

As you can see, what we get back from our preprocessing function is a multi-dimensional tensor with shape
[`batch_size`, `channels`, `rows`, `columns`]. The `batch_size` is the number of samples passed through our
feature extractor, and the channels represent the red, green, and blue components of the image. Lastly, the rows and
columns represent the height and width of the image, and this 4-dimensional representation is the input our model
expects. In return, the model gives us a tensor of shape [`batch_size`, `patches`, `dimensions`]. What's left for us
to do is to choose a pooling method for our embedding, since we want to store and compare a single vector per image
rather than 197 of them (one per patch plus one extra classification token). For this final step,
we'll use mean pooling.

In [None]:
one_embedding.mean(dim=1).shape
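
To make the pooling step concrete, here is a small sketch in plain Python that mimics `mean(dim=1)` on a toy tensor of shape [1, 3, 4] instead of [1, 197, 768]. The values are made up, and the helpers `mean_pool` and `cls_pool` are hypothetical names. Taking only the first token is shown purely as a common alternative for comparison; the tutorial itself uses mean pooling.

```python
# One image, 3 tokens, 4 dimensions: a toy stand-in for [1, 197, 768].
hidden_state = [[
    [1.0, 0.0, 2.0, 4.0],   # token 0: the classification token
    [3.0, 2.0, 0.0, 4.0],   # token 1: first patch
    [5.0, 4.0, 1.0, 4.0],   # token 2: second patch
]]

def mean_pool(batch):
    """Average over the token axis, like `tensor.mean(dim=1)`."""
    pooled = []
    for tokens in batch:
        n_tokens = len(tokens)
        # zip(*tokens) groups the values of each dimension across tokens
        pooled.append([sum(column) / n_tokens for column in zip(*tokens)])
    return pooled

def cls_pool(batch):
    """Keep only the first token of every sample, like `tensor[:, 0, :]`."""
    return [tokens[0] for tokens in batch]

mean_vector = mean_pool(hidden_state)
cls_vector = cls_pool(hidden_state)
print(mean_vector)
print(cls_vector)
```

Note that averaging over all tokens, as done in this tutorial, includes the classification token alongside the patch tokens.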

Let's create a function with the process we just walked through above and map it to our dataset to get an 
embedding vector for each image.

In [None]:
def get_embeddings(batch):
    inputs = processor(images=batch['image'], return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs).last_hidden_state.mean(dim=1).cpu().numpy()
    batch['embeddings'] = outputs
    return batch

In [None]:
dataset = dataset.map(get_embeddings, batched=True, batch_size=16)

In [None]:
dataset

As you can see, we now have an embedding for each of the images in our dataset.

We'll now save the embeddings as a NumPy array so that we don't have to compute them again later. After that,
we'll create a payload with the metadata for each of our images. We can accomplish
this by converting the rest of the columns we didn't use before into a JSON object for each sample.

In [None]:
np.save("vectors", np.array(dataset['embeddings']), allow_pickle=False)

In [None]:
payload = dataset.select_columns([
    'dx', 'dx_type', 'age', 'sex', 'localization'
]).to_pandas().fillna({"age": 0}).to_dict(orient="records")
payload[:3]

Note that in the cell above we use `.fillna({"age": 0})` because there are several missing values in the `age` column. Because
we don't want to guess the age of a patient, we set these values to 0 as a placeholder. It is also important to note that, at the time of writing,
Qdrant will not take in NumPy `NaN`s but rather `None` values only.
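
Since `None` values are accepted, one alternative to a placeholder of 0 is to convert missing values to `None`, so that a missing age cannot be mistaken for a real one. The sketch below is only an illustration of that idea and is not what this tutorial does. The helper `nan_to_none` and the sample records are hypothetical.

```python
import math

def nan_to_none(record):
    """Return a copy of `record` with float NaN values replaced by None."""
    return {
        key: None if isinstance(value, float) and math.isnan(value) else value
        for key, value in record.items()
    }

# Hypothetical records shaped like the payload built above.
raw_payload = [
    {"dx": "melanoma", "age": 55.0, "sex": "male"},
    {"dx": "dermatofibroma", "age": float("nan"), "sex": "female"},
]

clean_payload = [nan_to_none(record) for record in raw_payload]
print(clean_payload[1])
```

With this approach, any later filter on `age` would need to account for points where the field is empty.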

To make sure each image has an explicit id inside Qdrant, we'll create a list of integer ids, one for each row in
our dataset. In addition, we'll load the embeddings we just saved.

In [None]:
ids = list(range(dataset.num_rows))
embeddings = np.load("vectors.npy").tolist()  # np.save("vectors", ...) wrote vectors.npy

Now we are ready to upsert the ids, embeddings, and payloads into our collection in batches.

In [None]:
batch_size = 1000

for i in range(0, dataset.num_rows, batch_size):

    # exclusive upper bound of the current batch, capped at the dataset size
    high_idx = min(i+batch_size, dataset.num_rows)

    batch_of_ids = ids[i: high_idx]
    batch_of_embs = embeddings[i: high_idx]
    batch_of_payloads = payload[i: high_idx]

    client.upsert(
        collection_name=my_collection,
        points=models.Batch(
            ids=batch_of_ids,
            vectors=batch_of_embs,
            payloads=batch_of_payloads
        )
    )
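
The index arithmetic in the loop above can be checked in isolation. The following sketch uses a hypothetical helper, `iter_batches`, and made-up sizes (2,500 items in batches of 1,000) to show that the bounds cover every item exactly once, including the shorter final batch.

```python
def iter_batches(n_items, batch_size):
    """Yield (start, stop) index pairs that cover range(n_items) in order."""
    for start in range(0, n_items, batch_size):
        yield start, min(start + batch_size, n_items)

# Hypothetical sizes: 2,500 points sent 1,000 at a time.
bounds = list(iter_batches(2500, 1000))
print(bounds)

# Slicing with these bounds visits every item exactly once.
items = list(range(2500))
rebuilt = [x for start, stop in bounds for x in items[start:stop]]
print(rebuilt == items)
```

Python slices already stop at the end of a list, so the `min` is not strictly required for slicing, but it keeps the bounds explicit when they are logged or reused.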

In [None]:
client.count(
    collection_name=my_collection, 
    exact=True,
)

In [None]:
from datasets import Image

In [None]:
dataset[7000]   #.select_columns("image").cast_column("image", Image(decode=False))

In [None]:
dataset.filter(
    lambda x: x == "ISIC_0031944", input_columns="image_id"
).select_columns("image")[0]['image']

In [None]:
dataset[0]["image"]

In [None]:
client.scroll(collection_name=my_collection, limit=5)  # peek at a few stored points

## 4. Semantic Search
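
Image-to-image search boils down to embedding a query image with the same model and asking for the stored points whose vectors are closest to it. The sketch below is a brute-force stand-in in plain Python, not the Qdrant API: it uses made-up 2-dimensional vectors and payloads, and the helpers `normalize` and `top_k` are hypothetical. Qdrant answers the same question using an approximate HNSW index rather than scanning every point, and with the collection built above the query would go through the client instead (for instance its `search` method, which is an assumption about the client version used here and worth checking against the client documentation).

```python
import heapq
import math

def normalize(vector):
    """Scale a vector to unit length so a dot product equals cosine similarity."""
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / norm for x in vector]

def top_k(query, points, k):
    """Return the k (score, id, payload) triples most similar to `query`."""
    q = normalize(query)
    scored = []
    for point_id, vector, payload in points:
        score = sum(a * b for a, b in zip(q, normalize(vector)))
        scored.append((score, point_id, payload))
    # rank by score only, highest first
    return heapq.nlargest(k, scored, key=lambda item: item[0])

# Hypothetical 2-D embeddings and payloads; real ones have 768 dimensions.
points = [
    (0, [1.0, 0.0], {"dx": "melanoma"}),
    (1, [0.9, 0.1], {"dx": "melanoma"}),
    (2, [0.0, 1.0], {"dx": "dermatofibroma"}),
    (3, [-1.0, 0.0], {"dx": "vascular_lesions"}),
]

query_embedding = [1.0, 0.05]
results = top_k(query_embedding, points, k=2)
for score, point_id, payload in results:
    print(point_id, round(score, 3), payload["dx"])
```

The payload returned with each hit is what lets a doctor see the diagnosis attached to the most similar stored images.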

## 5. Conclusion