In [1]:
cd ../../

/home/taiamiti/Projects/microplastic_analysis


In [2]:
from src.data_prep import embeddings as emb

In [3]:
import os
from pathlib import Path

import fiftyone as fo
import numpy as np

In [4]:
from tqdm import tqdm

## Load lot2 as a FiftyOne ImageDirectory dataset

In [5]:
# Directory holding the raw lot2 images.
export_dir = "data/raw/lot2-30-05-2023-tak_nacl/"

# Build a FiftyOne dataset from the image directory; no labels file is supplied.
dataset = fo.Dataset.from_dir(
    dataset_dir=export_dir,
    dataset_type=fo.types.ImageDirectory,
    labels_path=None,
)

 100% |███████████████| 2267/2267 [1.3s elapsed, 0s remaining, 1.7K samples/s]         


## Compute and add embeddings

In [6]:
# File path of every sample in the dataset.
image_paths = dataset.values("filepath")
# Embed the images with the project helper.
# NOTE(review): later cells index `embeddings[row, :]` by sample position, so this
# presumably returns a 2-D array with one row per path, in the same order — confirm.
embeddings = emb.compute_embeddings(image_paths)

  return torch._C._cuda_getDeviceCount() > 0
100%|██████████| 10/10 [00:01<00:00,  8.61it/s]


In [7]:
import fiftyone.brain as fob

# Project the pre-computed embeddings down to 2 dimensions for plotting.
# The run is stored under the brain key "image_embeddings", and the seed is
# fixed so that the layout is repeatable.
results = fob.compute_visualization(
    dataset,
    embeddings=embeddings,
    brain_key="image_embeddings",
    num_dims=2,
    seed=51,
    verbose=True,
)

Generating visualization...


INFO:fiftyone.brain.internal.core.visualization:Generating visualization...


UMAP(random_state=51, verbose=True)
Tue Feb 20 14:11:42 2024 Construct fuzzy simplicial set
Tue Feb 20 14:11:42 2024 Finding Nearest Neighbors


  warn(


Tue Feb 20 14:11:44 2024 Finished Nearest Neighbor Search
Tue Feb 20 14:11:46 2024 Construct embedding


Epochs completed:   0%|            0/500 [00:00]

Tue Feb 20 14:11:47 2024 Finished embedding


## Manually annotate clusters by filter using tags

Run the app, then select each cluster and tag its samples with the matching filter: DAPI, TRI, CY2, or NAT.

In [8]:
# Start the FiftyOne App on the dataset. With auto=False the App is not shown
# in the cell output (the launch message suggests `session.show()` for that).
session = fo.launch_app(dataset, auto=False)
# NOTE(review): presumably opens the App in a separate browser tab — confirm.
session.open_tab()

Session launched. Run `session.show()` to open the App in a cell output.


INFO:fiftyone.core.session.session:Session launched. Run `session.show()` to open the App in a cell output.


<IPython.core.display.Javascript object>

In [9]:
# Scatter plot of the 2-D embedding projection. No `labels` argument is passed,
# so the points are not colored by any field.
plot = results.visualize(
    axis_equal=True,
)
plot.show(height=512)

# Attach the plot to the App session.
# NOTE(review): presumably this links selections made in the plot with the App — confirm.
session.plots.attach(plot)





FigureWidget({
    'data': [{'customdata': array(['65d53fb3885bdbca032481d9', '65d53fb3885bdbca032481da',
                                   '65d53fb3885bdbca032481db', '65d53fb3885bdbca032481dc',
                                   '65d53fb3885bdbca032481dd', '65d53fb3885bdbca032481de',
                                   '65d53fb3885bdbca032481df', '65d53fb3885bdbca032481e0',
                                   '65d53fb3885bdbca032481e1', '65d53fb3885bdbca032481e2'], dtype=object),
              'hovertemplate': 'x, y = %{x:.3f}, %{y:.3f}<br>ID: %{customdata}<extra></extra>',
              'mode': 'markers',
              'type': 'scattergl',
              'uid': 'ddbe4acd-ca0d-4903-801b-a000d55169a0',
              'x': array([-14.116757 , -13.452612 , -12.701604 , -12.971095 , -13.426902 ,
                          -14.9977255, -14.605405 , -14.354495 , -13.904014 , -13.338925 ],
                         dtype=float32),
              'y': array([12.319295 , 14.603811 , 12.432488 , 13.

In [None]:
# Count samples per tag, to check the manual tagging before computing the centers.
dataset.count_sample_tags()

## Compute embedding centers from the annotated tags

In [None]:
def get_embedding(embeddings, results, sample_id):
    """Return the embedding row that belongs to ``sample_id``.

    The row index is the position of ``sample_id`` in ``results.sample_ids``.
    If an ID occurs more than once, its last position is used.

    Args:
        embeddings: array indexed as ``embeddings[row, :]``.
        results: object exposing an iterable ``sample_ids`` attribute.
        sample_id: ID of the sample to look up.

    Returns:
        ``embeddings[row, :]`` for the matching row.

    Raises:
        KeyError: if ``sample_id`` is not present in ``results.sample_ids``.
    """
    row_of_id = {}
    for row, current_id in enumerate(results.sample_ids):
        row_of_id[current_id] = row
    row = row_of_id[sample_id]
    return embeddings[row, :]

In [None]:
# Compute one "center" per filter tag: the mean embedding of all samples that
# carry the tag. Each center keeps a leading axis of length 1.

# Row position of every sample ID, built once and reused for all tags
# (rebuilding it for every sample made the loop quadratic).
row_of_id = {sample_id: row for row, sample_id in enumerate(results.sample_ids)}

embedding_centers = []
for tag in ['CY2', 'TRI', 'DAPI', 'NAT']:
    # Read only the IDs of the tagged samples instead of loading each sample.
    tagged_ids = dataset.match_tags(tag).values("id")
    if not tagged_ids:
        # Without any tagged sample there is nothing to average.
        raise ValueError(
            f"No sample is tagged '{tag}'; tag the clusters in the App before computing centers."
        )
    tagged_embeddings = np.stack(
        [embeddings[row_of_id[sample_id], :] for sample_id in tagged_ids]
    )
    embedding_center = tagged_embeddings.mean(axis=0)[None, :]
    embedding_centers.append(embedding_center)

## Save embedding centers

In [None]:
# Map each filter tag to its center, converted to a plain list so that it can
# be written as JSON. `.tolist()[0]` drops the leading length-1 axis.
embedding_centers_serialized = {
    tag: center.tolist()[0]
    for tag, center in zip(['CY2', 'TRI', 'DAPI', 'NAT'], embedding_centers)
}

In [None]:
import json

In [None]:
# Destination file for the lot2 embedding centers.
out_path = "data/processed/compute_embedding_filter_centers/embedding_centers_lot2.json"

# Create the destination folder if it does not exist yet.
os.makedirs(name=os.path.dirname(out_path), exist_ok=True)

# Write the tag -> center mapping as JSON.
with open(out_path, mode="w") as centers_file:
    json.dump(embedding_centers_serialized, fp=centers_file)