# Extract_embeddings code

In [1]:
import pickle
from argparse import ArgumentParser
from pathlib import Path
from typing import Callable, List

import torch
from efficientnet_pytorch import EfficientNet
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

In [2]:
import os

In [3]:
os.chdir("E:/final/Dissertation/")

In [4]:
def get_model() -> torch.nn.Module:
    model = EfficientNet.from_pretrained("efficientnet-b4")
    model.eval()
    return model

In [5]:
def get_transform(width: int, height: int) -> Callable[[Image.Image], torch.Tensor]:
    """Provide a transform that converts image to a square and keeps the cell shape"""
    target_dim = max(width, height)
    w_pad = (target_dim - width) // 2
    h_pad = (target_dim - height) // 2
    transform = transforms.Compose([
        transforms.Pad(padding=(w_pad, h_pad)),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    return transform

In [6]:
def get_image_embeddings(image: Path) -> List[float]:
    input_image = Image.open(image).convert("RGB")
    preprocess = get_transform(input_image.width, input_image.height)
    input_tensor = preprocess(input_image)
    input_batch = input_tensor.unsqueeze(0)

    if torch.cuda.is_available():
        input_batch = input_batch.to('cuda')
        model.to('cuda')

    with torch.no_grad():
        output = model(input_batch)

    return output[0].cpu().detach().tolist()

In [7]:
import glob

# Input and output paths
input_path = "data/extracted/*/*.jpg"
output_path = "data/embeddings_efficientnetb4.pkl"

# Get a list of image paths using glob
image_paths = glob.glob(input_path)

# Process the images and save the embeddings
model = get_model()
mapping = {str(i): get_image_embeddings(image=i) for i in tqdm(sorted(image_paths))}
output_file = Path(output_path)
output_file.write_bytes(pickle.dumps(mapping))

Downloading: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth" to C:\Users\pavan/.cache\torch\hub\checkpoints\efficientnet-b4-6ed6700e.pth


  0%|          | 0.00/74.4M [00:00<?, ?B/s]

Loaded pretrained weights for efficientnet-b4


100%|████████████████████████████████████████████████████████████████████████████████| 839/839 [00:37<00:00, 22.61it/s]


7650857

In [8]:
print(model)

EfficientNet(
  (_conv_stem): Conv2dStaticSamePadding(
    3, 48, kernel_size=(3, 3), stride=(2, 2), bias=False
    (static_padding): ZeroPad2d((0, 1, 0, 1))
  )
  (_bn0): BatchNorm2d(48, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
  (_blocks): ModuleList(
    (0): MBConvBlock(
      (_depthwise_conv): Conv2dStaticSamePadding(
        48, 48, kernel_size=(3, 3), stride=[1, 1], groups=48, bias=False
        (static_padding): ZeroPad2d((1, 1, 1, 1))
      )
      (_bn1): BatchNorm2d(48, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
      (_se_reduce): Conv2dStaticSamePadding(
        48, 12, kernel_size=(1, 1), stride=(1, 1)
        (static_padding): Identity()
      )
      (_se_expand): Conv2dStaticSamePadding(
        12, 48, kernel_size=(1, 1), stride=(1, 1)
        (static_padding): Identity()
      )
      (_project_conv): Conv2dStaticSamePadding(
        48, 24, kernel_size=(1, 1), stride=(1, 1), bias=False
  

# Reduce dimensionality

In [9]:
import pickle
from argparse import ArgumentParser
from pathlib import Path
from typing import List, Optional, Tuple

import numpy as np
from loguru import logger
from sklearn.manifold import TSNE
from tqdm import tqdm

In [10]:
def load_embeddings(file: Path) -> Tuple[List[Path], np.ndarray]:
    embeddings = pickle.loads(file.read_bytes())
    files, vectors = zip(*list(embeddings.items()))
    return files, np.array(vectors)

In [11]:
def apply_tSNE(vectors: np.ndarray) -> np.ndarray:
    tSNE = TSNE(n_components=2, random_state=42)
    logger.info(f"Applying {tSNE}")
    return tSNE.fit_transform(vectors)

In [12]:
def reduce_dimensionality(embeddings_file: Path, output_file: Path):
    files, vectors = load_embeddings(file=embeddings_file)
    points = apply_tSNE(vectors=vectors)
    points_mapping = dict(zip(files, points))
    output_file.write_bytes(pickle.dumps(points_mapping))

In [13]:
points_path= "data/points_efficientnetb4.pkl"

In [14]:
reduce_dimensionality(
    output_file= Path(points_path),
    embeddings_file= Path(output_path))

[32m2023-08-05 14:45:04.695[0m | [1mINFO    [0m | [36m__main__[0m:[36mapply_tSNE[0m:[36m3[0m - [1mApplying TSNE(random_state=42)[0m


# Compute Clusters

In [15]:
from sklearn.cluster import KMeans
from kneed import KneeLocator

In [16]:
def make_clusters(vectors: np.ndarray, n_clusters: int) -> np.ndarray:
    kMeans = KMeans(n_clusters=n_clusters, random_state=42)
    logger.info(f"Clustering with {kMeans}")
    return kMeans.fit_predict(vectors)

In [17]:
def find_optimal_number_of_clusters(values: List[int], data: np.ndarray) -> int:
    logger.info("Finding optimal value for k ...")
    inertias = [KMeans(n_clusters=n, random_state=42).fit(data).inertia_ for n in tqdm(values)]
    kneedle = KneeLocator(x=values, y=inertias, direction='decreasing', curve='convex')
    logger.info(inertias)
    logger.info(f"Optimal k={kneedle.elbow}")
    return kneedle.knee

In [18]:
from sklearn.metrics import silhouette_score, calinski_harabasz_score, davies_bouldin_score

In [19]:
def compute_clusters(embeddings_file: Path, output_file: Path, clusters: Optional[int] = None):
    files, vectors = load_embeddings(file=embeddings_file)
    if clusters is None:
        clusters = find_optimal_number_of_clusters(range(5, 45, 5), data=vectors)
    clusters = make_clusters(vectors=vectors, n_clusters=clusters)
    cluster_mapping = dict(zip(files, clusters))
    output_file.write_bytes(pickle.dumps(cluster_mapping))
    # Calculate evaluation metrics
    silhouette_avg = silhouette_score(vectors, clusters)
    calinski_harabasz = calinski_harabasz_score(vectors, clusters)
    davies_bouldin = davies_bouldin_score(vectors, clusters)

    print(f"Silhouette Score: {silhouette_avg}")
    print(f"Calinski-Harabasz Index: {calinski_harabasz}")
    print(f"Davies-Bouldin Index: {davies_bouldin}")

In [20]:
compute_clusters(
    embeddings_file=Path("data/embeddings_efficientnetb4.pkl"),
    output_file=Path("data/clusters_efficientb4.pkl")
)

[32m2023-08-05 14:45:09.232[0m | [1mINFO    [0m | [36m__main__[0m:[36mfind_optimal_number_of_clusters[0m:[36m2[0m - [1mFinding optimal value for k ...[0m
100%|████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:04<00:00,  1.93it/s]
[32m2023-08-05 14:45:13.375[0m | [1mINFO    [0m | [36m__main__[0m:[36mfind_optimal_number_of_clusters[0m:[36m5[0m - [1m[116136.93599971384, 103938.42847235149, 96294.99528701497, 91430.92832692462, 87433.58699905376, 84307.07848045694, 81694.01671964716, 79156.05586027892][0m
[32m2023-08-05 14:45:13.376[0m | [1mINFO    [0m | [36m__main__[0m:[36mfind_optimal_number_of_clusters[0m:[36m6[0m - [1mOptimal k=15[0m
[32m2023-08-05 14:45:13.378[0m | [1mINFO    [0m | [36m__main__[0m:[36mmake_clusters[0m:[36m3[0m - [1mClustering with KMeans(n_clusters=15, random_state=42)[0m


Silhouette Score: 0.083615066339086
Calinski-Harabasz Index: 60.99920881527343
Davies-Bouldin Index: 2.543386772475993


# Export data to csv

In [21]:
import pandas as pd
from collections import defaultdict

In [22]:
def load_file_contents(file: Path) -> Tuple[List[Path], List]:
    return pickle.loads(file.read_bytes())

In [23]:
def export_csv(clusters_file: Path, points_file: Path, output_file: Path):
    clusters_data = load_file_contents(clusters_file)
    points_data = load_file_contents(points_file)

    data = defaultdict(dict)
    for key, cluster in clusters_data.items():
        data[key]["cluster"] = cluster
        point = points_data[key]
        data[key]["point_x"], data[key]["point_y"] = point

    df: pd.DataFrame = pd.DataFrame.from_dict(data, orient="index") \
        .rename_axis('cell_location') \
        .reset_index() \
        .assign(cell_location=lambda df: df.cell_location.apply(Path)) \
        .assign(parent_image=lambda df: df.cell_location.apply(lambda path: path.parent.name))

    df.to_csv(output_file, index=False)

In [24]:
export_csv(
        clusters_file=Path("data/clusters_efficientb4.pkl"),
        points_file=Path("data/points_efficientnetb4.pkl"),
        output_file=Path("data/export_efficientnetb4.csv"))

# Plots

In [25]:
from functools import partial
from multiprocessing import Lock, cpu_count
from multiprocessing.pool import ThreadPool
import plotly.express as px
import plotly.graph_objects as go
from PIL import Image
from sklearn.manifold import TSNE

In [26]:
def load_file_contents(file: Path) -> Tuple[List[Path], np.ndarray]:
    data = pickle.loads(file.read_bytes())
    files, points = zip(*list(data.items()))
    return files, np.array(points)

In [27]:
lock = Lock()

In [28]:
def create_figure_2d(points: np.ndarray, clusters: np.ndarray, images: List[Path], limit_images: int = 500) -> go.FigureWidget:
    x = points[:, 0]
    y = points[:, 1]
    fig = px.scatter(x=x, y=y, color=clusters, opacity=0.75, size_max=5)
    np.random.seed(42)
    logger.info("Adding images to layout")
    with ThreadPool(cpu_count() * 4) as pool:
        shuffled_data = np.random.permutation(list(zip(images, x, y)))
        _ = list(tqdm(
            pool.imap(
                partial(add_image_to_layout, lock=lock, fig=fig),
                shuffled_data[:limit_images]
            ),
            total=min(len(images), limit_images))
        )
    return fig

In [29]:
def add_image_to_layout(image_x_y: Tuple[Path, float, float], fig: go.FigureWidget, lock: Lock):
    image, x, y = image_x_y
    pil_image = Image.open(image)
    lock.acquire()
    fig.add_layout_image(dict(
        source=pil_image,
        x=x,
        y=y,
        xref="x",
        yref="y",
        sizex=2,
        sizey=2,
        opacity=1,
        xanchor="center", yanchor="middle",
        layer="below",
    ))
    lock.release()


In [30]:
def create_figure(points: np.ndarray, clusters: np.ndarray, images: List[Path], limit_images: int) -> go.FigureWidget:
    fig = create_figure_2d(points=points, clusters=clusters, images=images, limit_images=limit_images)
    return fig

In [31]:
def prepare_plot(points_file: Path, clusters_file: Path, output_file: Path, limit_images: int):
    points_files, points = load_file_contents(file=points_file)
    _, clusters = load_file_contents(file=clusters_file)
    fig: go.FigureWidget = create_figure(points=points, clusters=clusters, images=points_files, limit_images=limit_images)
    output_file.write_text(fig.to_json())

In [32]:
prepare_plot(
        output_file=Path("data/plot_efficientnetb4.json"),
        points_file=Path("data/points_efficientnetb4.pkl"),
        clusters_file=Path("data/clusters_efficientb4.pkl"),
        limit_images=10000 )

[32m2023-08-05 14:45:15.315[0m | [1mINFO    [0m | [36m__main__[0m:[36mcreate_figure_2d[0m:[36m6[0m - [1mAdding images to layout[0m
100%|████████████████████████████████████████████████████████████████████████████████| 839/839 [02:42<00:00,  5.16it/s]


# Visualise Embeddings

In [33]:
import json
import dash
import dash_core_components as dcc
import dash_html_components as html
from dash.dependencies import Input, Output
from functools import lru_cache



The dash_core_components package is deprecated. Please replace
`import dash_core_components as dcc` with `from dash import dcc`



The dash_html_components package is deprecated. Please replace
`import dash_html_components as html` with `from dash import html`



In [34]:
external_stylesheets = ['https://codepen.io/chriddyp/pen/bWLwgP.css']
app = dash.Dash(__name__, external_stylesheets=external_stylesheets)


def load_figure(file: Path) -> go.Figure:
    data = json.loads(file.read_text())
    fig = go.Figure(data=data)
    fig.update_traces(
        marker_size=12,
        selector=dict(mode='markers')
    )
    return fig

In [35]:
@lru_cache
def load_file_contents(file: Path) -> List[Path]:
    data = pickle.loads(file.read_bytes())
    files, points = zip(*list(data.items()))
    return files

def prepare_dash(fig: go.FigureWidget) -> dash.Dash:
    app.layout = html.Div([
        html.Pre(id="selection-data", style={"fontSize": "22px"}),
        dcc.Graph(figure=fig, id="scatter-plot", style={"width": "100vw", "height": "100vh"}),
    ])
    return app

In [36]:
@app.callback(
    Output(component_id='selection-data', component_property='children'),
    Input(component_id='scatter-plot', component_property='clickData')
)
def handle_selection(data):
    index = data['points'][0]["pointIndex"]
    files = load_file_contents(Path("data/embeddings_efficientnetb4.pkl"))
    
    return json.dumps(
        {
            "cell": files[index],
            "parent_image": Path(files[index]).parent.name
        }, indent=4
    )

In [37]:
figure = load_figure(file=Path("data/plot_efficientnetb4.json"))
dash = prepare_dash(fig=figure)
dash.run_server(port=8056,debug=False, use_reloader=False)


# Model Evaluation

In [38]:
#1. Parameters
def count_parameters(model):
    return sum(p.numel() for p in model.parameters())
num_parameters = count_parameters(model)
print(f"Number of parameters in the model: {num_parameters}")

Number of parameters in the model: 19341616


[2023-08-05 14:48:01,406] ERROR in app: Exception on /_dash-update-component [POST]
Traceback (most recent call last):
  File "C:\Users\pavan\anaconda3\envs\anaconda3\lib\site-packages\flask\app.py", line 2529, in wsgi_app
    response = self.full_dispatch_request()
  File "C:\Users\pavan\anaconda3\envs\anaconda3\lib\site-packages\flask\app.py", line 1825, in full_dispatch_request
    rv = self.handle_user_exception(e)
  File "C:\Users\pavan\anaconda3\envs\anaconda3\lib\site-packages\flask\app.py", line 1823, in full_dispatch_request
    rv = self.dispatch_request()
  File "C:\Users\pavan\anaconda3\envs\anaconda3\lib\site-packages\flask\app.py", line 1799, in dispatch_request
    return self.ensure_sync(self.view_functions[rule.endpoint])(**view_args)
  File "C:\Users\pavan\anaconda3\envs\anaconda3\lib\site-packages\dash\dash.py", line 1265, in dispatch
    ctx.run(
  File "C:\Users\pavan\anaconda3\envs\anaconda3\lib\site-packages\dash\_callback.py", line 450, in add_context
    output