Image classification using a pretrained model

In [9]:
from PIL import Image
import timm
import torch
from torchvision import transforms

# Load the sample image and force three RGB channels.
img = Image.open(
'/home/plankton/Desktop/20220512-00120347_003.127bar_28.24C_44.jpg'
).convert('RGB')

model = timm.create_model('vit_large_patch14_clip_224.openai_ft_in12k_in1k', pretrained=True)
model = model.eval()

# Get model specific transforms (normalization, resize) from the model's own
# pretrained data config instead of hand-writing a Resize/CenterCrop/Normalize chain.
# The pipeline is named `transform` (singular) so that it does not shadow the
# `transforms` module imported from torchvision at the top of this cell.
data_config = timm.data.resolve_model_data_config(model)
transform = timm.data.create_transform(**data_config, is_training=False)

# Inference only: disable autograd so no graph is recorded for the forward pass
# and the resulting tensors carry no grad_fn.
with torch.no_grad():
    output = model(transform(img).unsqueeze(0))  # unsqueeze single image into batch of 1

    # Softmax over the class dimension, scaled by 100 so the values read as percentages.
    top5_probabilities, top5_class_indices = torch.topk(output.softmax(dim=1) * 100, k=5)


In [14]:
# Show the five largest softmax scores (scaled by 100) and the class indices they belong to.
print(top5_probabilities, top5_class_indices)

tensor([[9.2531e+01, 1.2603e-01, 6.8981e-02, 5.7594e-02, 4.2416e-02]],
       grad_fn=<TopkBackward0>) tensor([[111, 110, 712,  52, 126]])


In [None]:
import timm
from pprint import pprint
# List the timm model names that match the wildcard pattern '*vit*' and pretty-print them.
model_names = timm.list_models('*vit*')
pprint(model_names)

In [7]:
# Number of model names returned for the catch-all wildcard pattern '*'.
len(timm.list_models('*'))

889

Image embeddings

In [1]:
import timm
import torch
from PIL import Image
from timm.data.dataset import ImageDataset
from timm.data.loader import create_loader
from tqdm import tqdm

model = timm.create_model(
    'vit_large_patch14_clip_224.openai_ft_in12k_in1k',
    pretrained=True,
    num_classes=0,  # remove classifier nn.Linear
)
model = model.eval()

# get model specific transforms (normalization, resize)
data_config = timm.data.resolve_model_data_config(model)
transform = timm.data.create_transform(**data_config, is_training=False)

dataset = ImageDataset('/home/plankton/Results/M181/AI_crop_test/cropped_images/', transform=transform)

# Will contain one flattened feature vector per dataset item, in dataset order.
features = []

# Inference only: disable autograd so the forward passes do not record a graph.
with torch.no_grad():
    for sample in tqdm(dataset):
        # sample[0] is the transformed image; unsqueeze turns it into a batch of 1.
        # (With num_classes=0, calling model(...) directly is the equivalent shortcut.)
        output = model.forward_features(sample[0].unsqueeze(0))
        # output is unpooled, a (1, 257, 1024) shaped tensor

        output = model.forward_head(output, pre_logits=True)
        # output is a (1, num_features) shaped tensor

        # Store the single row as a flat 1-D NumPy vector.
        features.append(output.numpy().reshape(-1))


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 76/76 [00:46<00:00,  1.65it/s]


In [67]:
# Shape of the first stored feature vector.
features[0].shape

(1024,)

In [16]:
# Print `output` as left behind by whichever cell assigned it last.
# The commented-out line below is an optional conversion to a NumPy array.
#output = output.detach().numpy()
print(output)

[[-0.79853034  0.44331598  1.4494123  ... -0.61171705 -0.3869704
   0.06875669]]


In [2]:
from sklearn.decomposition import PCA

# Fit a three-component PCA on the feature vectors, then project the same
# vectors onto the fitted components.
pca = PCA(n_components=3)
pca.fit(features)
projections = pca.transform(features)

# Report the fraction of variance captured by each retained component.
print(f"explained variance ratio (first three components): {pca.explained_variance_ratio_!s}")

explained variance ratio (first three components): [0.18673451 0.11203103 0.07482025]


In [3]:
# Number of projected points (one row per feature vector passed to the PCA).
len(projections)

76

In [89]:
import numpy as np
import base64
import os

def image_to_data_uri(image_path):
    """Encode an image file as a base64 ``data:`` URI.

    The MIME type is guessed from the file extension so that e.g. ``.png``
    files are labelled ``image/png``. When the extension is unknown or does
    not map to an image type, ``image/jpeg`` is used as the fallback.

    Args:
        image_path: Path of the image file to read.

    Returns:
        A string of the form ``data:<mime>;base64,<payload>``.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    import mimetypes

    mime_type, _ = mimetypes.guess_type(str(image_path))
    if mime_type is None or not mime_type.startswith("image/"):
        mime_type = "image/jpeg"

    with open(image_path, "rb") as image_file:
        encoded_image = base64.b64encode(image_file.read()).decode("utf-8")
    return "data:" + mime_type + ";base64," + encoded_image

def listdir_fullpath(d):
    """Return the entries of directory ``d`` joined onto ``d``, sorted by name.

    ``os.listdir`` returns entries in arbitrary order, so the names are sorted
    to make the result reproducible between runs and machines.

    Args:
        d: Directory to list.

    Returns:
        A list of ``os.path.join(d, name)`` for every entry in ``d``.

    Raises:
        OSError: If ``d`` cannot be listed.
    """
    return [os.path.join(d, f) for f in sorted(os.listdir(d))]

d = '/home/plankton/Results/M181/AI_crop_test/cropped_images/'

# Take the paths from the dataset itself rather than from os.listdir(d):
# the feature vectors (and therefore the PCA projections) were produced by
# iterating `dataset`, so only its own file order guarantees that
# image_paths[i] is the image behind features[i].
image_paths = np.array(dataset.filenames(absolute=True))
assert len(image_paths) == len(features), "image paths and feature vectors are out of sync"

# No class labels are available. np.zeros_like on a string array yields empty
# strings, so every point ends up in a single unnamed group.
labels = np.zeros_like(image_paths)

# Plain `str` keys (not np.str_) keep the mapping safe to serialize.
img_data_uris = {str(path): image_to_data_uri(path) for path in image_paths}


In [92]:
# Preview only: the full mapping holds every image as a base64 string, and
# displaying it would dump all of that into the notebook output.
{path: uri[:48] + '...' for path, uri in list(img_data_uris.items())[:3]}

{'/home/plankton/Results/M181/AI_crop_test/cropped_images/crop_3_20220509-05474297_001.211bar_30.24C_bg_corrected (1).bmp.png': '

In [93]:
import plotly.graph_objects as go
import plotly.express as px
import numpy as np

from typing import Dict
from pathlib import Path
from IPython.display import display, HTML


def display_projections(
    labels: np.ndarray,
    projections: np.ndarray,
    image_paths: np.ndarray,
    image_data_uris: Dict[str, str],
    show_legend: bool = False,
    show_markers_with_text: bool = True
) -> None:
    """Render an interactive 3D scatter plot with a click-to-preview image panel.

    One ``Scatter3d`` trace is created per unique label. Clicking a point looks
    up that point's image path in ``image_data_uris`` and shows the matching
    image in a fixed panel in the top-left corner of the output.

    Args:
        labels: One label per point; used to group points into traces and as
            the marker text.
        projections: 2-D array with one row per point; columns 0, 1 and 2 are
            plotted as x, y and z.
        image_paths: One image path per point, attached to the point as
            ``customdata``.
        image_data_uris: Mapping from image path to the ``data:`` URI that is
            shown when the corresponding point is clicked.
        show_legend: Whether the figure legend is shown.
        show_markers_with_text: If True, points are drawn as markers plus label
            text; otherwise as markers only.

    Raises:
        ValueError: If ``projections`` is not 2-D with at least three columns,
            or if ``labels``, ``projections`` and ``image_paths`` differ in
            length.
    """
    import json

    # Accept plain sequences as well as arrays; boolean masking below needs arrays.
    labels = np.asarray(labels)
    projections = np.asarray(projections)
    image_paths = np.asarray(image_paths)

    if projections.ndim != 2 or projections.shape[1] < 3:
        raise ValueError("projections must be a 2-D array with at least 3 columns")
    if not (len(labels) == len(projections) == len(image_paths)):
        raise ValueError("labels, projections and image_paths must have the same length")

    # Create a separate trace for each unique label
    unique_labels = np.unique(labels)
    traces = []
    for unique_label in unique_labels:
        mask = labels == unique_label
        customdata_masked = image_paths[mask]
        trace = go.Scatter3d(
            x=projections[mask][:, 0],
            y=projections[mask][:, 1],
            z=projections[mask][:, 2],
            mode='markers+text' if show_markers_with_text else 'markers',
            text=labels[mask],
            customdata=customdata_masked,
            name=str(unique_label),
            marker=dict(size=8),
            hovertemplate="<b>class: %{text}</b><br>path: %{customdata}<extra></extra>"
        )
        traces.append(trace)

    # Create the 3D scatter plot
    fig = go.Figure(data=traces)
    fig.update_layout(
        scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'),
        width=1000,
        height=1000,
        showlegend=show_legend
    )

    # Convert the chart to an HTML div string and add an ID to the div.
    # include_plotlyjs='cdn' makes plotly emit its own <script> tag for the
    # plotly.js build that matches the installed plotly.py, instead of relying
    # on the `plotly-latest` CDN alias, which is documented as no longer updated.
    plotly_div = fig.to_html(full_html=False, include_plotlyjs='cdn', div_id="scatter-plot-3d")

    # Serialize the lookup table as JSON (a valid JavaScript object literal)
    # rather than interpolating the Python dict repr, which is not guaranteed
    # to be valid JavaScript (e.g. for np.str_ keys). "</" is escaped so that
    # no value can terminate the surrounding <script> element early.
    image_data_uris_json = json.dumps(
        {str(path): uri for path, uri in image_data_uris.items()}
    ).replace("</", "<\\/")

    # JavaScript that shows the clicked point's image in the preview panel.
    # Wrapped in a function scope so re-running the cell does not collide with
    # names left behind by a previous output.
    javascript_code = f"""
    <script>
        (function() {{
            // Lookup table from image path to data URI, built once.
            var imageDataURIs = {image_data_uris_json};

            function displayImage(imagePath) {{
                var imageDataURI = imageDataURIs[imagePath];
                if (imageDataURI === undefined) {{
                    return;
                }}
                var imageElement = document.getElementById('image-display');
                var placeholderText = document.getElementById('placeholder-text');
                imageElement.src = imageDataURI;
                imageElement.style.display = 'block';
                placeholderText.style.display = 'none';
            }}

            // Get the Plotly chart element by its ID
            var chartElement = document.getElementById('scatter-plot-3d');

            // Add a click event listener to the chart element
            chartElement.on('plotly_click', function(data) {{
                var customdata = data.points[0].customdata;
                displayImage(customdata);
            }});
        }})();
    </script>
    """

    # Create an HTML template including the chart div and JavaScript code
    html_template = f"""
    <!DOCTYPE html>
    <html>
        <head>
            <style>
                #image-container {{
                    position: fixed;
                    top: 0;
                    left: 0;
                    width: 200px;
                    height: 200px;
                    padding: 5px;
                    border: 1px solid #ccc;
                    background-color: white;
                    z-index: 1000;
                    box-sizing: border-box;
                    display: flex;
                    align-items: center;
                    justify-content: center;
                    text-align: center;
                }}
                #image-display {{
                    width: 100%;
                    height: 100%;
                    object-fit: contain;
                }}
            </style>
        </head>
        <body>
            {plotly_div}
            <div id="image-container">
                <img id="image-display" src="" alt="Selected image" style="display: none;" />
                <p id="placeholder-text">Click on a data entry to display an image</p>
            </div>
            {javascript_code}
        </body>
    </html>
    """

    # Display the HTML template in the Jupyter Notebook
    display(HTML(html_template))

In [94]:
# Render the interactive 3D scatter of the PCA projections; clicking a point
# shows the image registered for that point's path in img_data_uris.
display_projections(
    labels=labels,
    projections=projections,
    image_paths=image_paths,
    image_data_uris=img_data_uris
)