Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
3ed3213
add timm wrapper class
dnth Sep 28, 2023
a70186d
update image extensions
dnth Sep 28, 2023
702ad3e
init_model
dnth Sep 28, 2023
fc17084
test timm embeddings
dnth Sep 28, 2023
2bb95bd
add class docstring
dnth Sep 28, 2023
eb50a8b
update github actions test
dnth Sep 28, 2023
0054808
update naming
dnth Sep 29, 2023
9b10782
update naming timm encoder
dnth Oct 4, 2023
b045b27
add init py
dnth Oct 4, 2023
a3994cd
initial commit
dnth Oct 5, 2023
39ff163
auto download gdino weights
dnth Oct 6, 2023
979cf1a
add segment anything model
dnth Oct 6, 2023
902d860
update
dnth Oct 6, 2023
2f54d69
remake api
dnth Oct 7, 2023
5079baf
update enrich api
dnth Oct 7, 2023
d1f2b31
use fastdup functions
dnth Oct 9, 2023
4f53dc3
update download from web utils
dnth Oct 9, 2023
f839ede
update fastdup controller
dnth Oct 9, 2023
4b1e864
update models implement feedback
dnth Oct 9, 2023
0427574
undo unintended changes
dnth Oct 9, 2023
454b207
reduce test images
dnth Oct 10, 2023
eb6220a
update test
dnth Oct 10, 2023
0d3a406
use TimmEncoder as name
dnth Oct 10, 2023
a05c731
update timm to include fastdup exception report
dnth Oct 10, 2023
155799c
use fastdup methods
dnth Oct 10, 2023
e500a06
update bgr to rgb convert
dnth Oct 10, 2023
b47812f
add default location for saved embeddings
dnth Oct 10, 2023
41ba84c
update sam
dnth Oct 10, 2023
e593805
add utils
dnth Oct 11, 2023
b132f8f
Merge branch 'dnth/timm-integration' into dnth/combine-merge
dnth Oct 12, 2023
fb3722f
Merge branch 'dnth/mmdet-integration' into dnth/combine-merge
dnth Oct 12, 2023
0a3689e
flatten directory
dnth Oct 12, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/test-pypi-wheels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@ jobs:
if: runner.os == 'Linux'
run: |
python -m pip install --upgrade pip
pip install fastdup pytest pytest-md pytest-emoji datasets
pip install fastdup pytest pytest-md pytest-emoji datasets timm

- name: Install dependencies (macOS)
if: runner.os == 'macOS'
run: |
python -m pip install --upgrade pip
pip install pytest pytest-cov datasets fastdup
pip install pytest pytest-cov datasets fastdup timm
pip install opencv-python-headless
pip install fastdup --no-deps -U

Expand Down
153 changes: 153 additions & 0 deletions fastdup/embeddings_timm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import os
import numpy as np
import logging
from PIL import Image
from tqdm.auto import tqdm
from fastdup.sentry import fastdup_capture_exception
from fastdup.image import fastdup_imread
from fastdup.utils import get_images_from_path

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("fastdup.embeddings.timm")

try:
import torch
except ImportError as e:
fastdup_capture_exception("embeddings_missing_pytorch_install", e, True)
logger.error(
"The `torch` package is not installed. Please run `pip install torch` or equivalent."
)

try:
import timm
except ImportError as e:
fastdup_capture_exception("embeddings_missing_timm_install", e, True)
logger.error(
"The `timm` package is not installed. Please run `pip install timm` or equivalent."
)


class TimmEncoder:
    """
    A wrapper class for TIMM (PyTorch Image Models) to simplify model initialization and
    feature extraction for image datasets.

    Attributes:
        model_name (str): The name of the model architecture to use.
        num_classes (int): The number of classes for the model. Use num_classes=0 to exclude the last layer.
        pretrained (bool): Whether to load pretrained weights.
        device (str): Which device to load the model on. Choices: "cuda" or "cpu".
            When omitted, "cuda" is used if available, otherwise "cpu".
        torch_compile (bool): Whether to use torch.compile to optimize model.

        embeddings (np.ndarray): The computed embeddings for the images (None until computed).
        file_paths (list): The file paths corresponding to the computed embeddings (None until computed).
        img_folder (str): The folder path containing images for which embeddings are computed.

    Methods:
        __init__(model_name, num_classes=0, pretrained=True, **kwargs): Initialize the wrapper.
        _initialize_model(**kwargs): Internal method to initialize the TIMM model.
        compute_embeddings(image_folder_path, save_dir="saved_embeddings"): Compute and save
            embeddings in a local folder.

    Example:
        >>> wrapper = TimmEncoder(model_name='resnet18')
        >>> wrapper.compute_embeddings('path/to/image/folder')
    """

    def __init__(
        self,
        model_name: str,
        num_classes: int = 0,
        pretrained: bool = True,
        device: str = None,
        torch_compile: bool = False,
        **kwargs,
    ):
        """
        Store the configuration, pick a device and build the model.

        Args:
            model_name: Architecture name passed to `timm.create_model`.
            num_classes: Passed to `timm.create_model` as `num_classes`.
            pretrained: Passed to `timm.create_model` as `pretrained`.
            device: Target device; auto-selected when None.
            torch_compile: When True the model is wrapped with `torch.compile`.
            **kwargs: Forwarded unchanged to `timm.create_model`.
        """
        self.model_name = model_name
        self.num_classes = num_classes
        self.pretrained = pretrained
        self.torch_compile = torch_compile

        # Pick available device if not specified.
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device

        self._initialize_model(**kwargs)
        self.embeddings = None
        self.file_paths = None
        self.img_folder = None

    def _initialize_model(self, **kwargs):
        """Create the TIMM model, optionally compile it, and move it to `self.device` in eval mode."""
        logger.info(f"Initializing model - {self.model_name}.")
        self.model = timm.create_model(
            self.model_name,
            num_classes=self.num_classes,
            pretrained=self.pretrained,
            **kwargs,
        )

        if self.torch_compile:
            logger.info("Running torch.compile.")
            self.model = torch.compile(self.model, mode="max-autotune")

        self.model.eval()
        self.model = self.model.to(self.device)

        logger.info(f"Model loaded on device - {self.device}")

    def compute_embeddings(self, image_folder_path, save_dir="saved_embeddings"):
        """
        Compute one embedding per readable image in a folder and save the results.

        Every direct entry of `image_folder_path` is attempted; entries that cannot
        be read or processed are logged and skipped. On success the stacked
        embeddings are stored in `self.embeddings`, the matching absolute paths in
        `self.file_paths`, and both are written to `save_dir` as
        `<model>_embeddings.npy` and `<model>_file_paths.txt`, where `<model>` is
        the part of `model_name` after the last "/".

        Args:
            image_folder_path: Folder whose entries are embedded (not recursive).
            save_dir: Output directory; created if it does not exist.

        Raises:
            ValueError: If no entry of the folder could be embedded.
        """
        self.img_folder = image_folder_path

        data_config = timm.data.resolve_model_data_config(self.model)
        transforms = timm.data.create_transform(**data_config, is_training=False)

        embeddings_list = []
        file_paths = []

        # os.listdir returns entries in arbitrary order; sorting makes the row
        # order of the saved embeddings reproducible between runs. The same list
        # drives the progress bar so that its total matches what is iterated.
        entries = sorted(os.listdir(image_folder_path))

        for image_file in tqdm(
            entries,
            desc="Computing embeddings",
            total=len(entries),
            unit=" images",
        ):
            img_path = os.path.join(image_folder_path, image_file)

            try:
                img = fastdup_imread(img_path, input_dir=None, kwargs=None)
                img = img[:, :, ::-1]  # Convert to RGB
                img = Image.fromarray(img)
                # unsqueeze(0) adds the batch dimension expected by the model.
                img_tensor = transforms(img).unsqueeze(0).to(self.device)

                with torch.no_grad():
                    output = self.model.forward_features(img_tensor)
                    output = self.model.forward_head(output, pre_logits=True)

                embeddings = output.cpu().numpy()
                embeddings_list.append(embeddings)
                file_paths.append(os.path.abspath(img_path))

            except Exception as e:
                # Best effort: one unreadable entry must not abort the whole run.
                logger.error(f"Skipping {img_path} due to error: {e}")

        # np.vstack raises an opaque ValueError on an empty list; fail with a
        # message that names the actual problem instead.
        if not embeddings_list:
            raise ValueError(
                f"No embeddings were computed - no readable images found in {image_folder_path}."
            )

        self.embeddings = np.vstack(embeddings_list)
        self.file_paths = file_paths

        os.makedirs(save_dir, exist_ok=True)

        logger.info(f"Saving embeddings in directory - {save_dir} .")

        # Model names may carry a hub prefix ("org/name"); keep only the last part.
        model_stem = self.model_name.split('/')[-1]

        np.save(
            os.path.join(save_dir, f"{model_stem}_embeddings.npy"),
            self.embeddings,
        )

        with open(
            os.path.join(save_dir, f"{model_stem}_file_paths.txt"),
            "w",
        ) as f:
            for path in self.file_paths:
                f.write(f"{path}\n")
80 changes: 78 additions & 2 deletions fastdup/fastdup_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import pathlib
import re


class FastdupController:
def __init__(self, work_dir: Union[str, Path], input_dir: Union[str, Path, list] = None):
"""
Expand Down Expand Up @@ -199,7 +198,7 @@ def annotations(self, valid_only=True):
df_annot = self._df_annot.query(f'{FD.ANNOT_VALID}') if valid_only and self._df_annot is not None \
else self._df_annot
return df_annot


def similarity(self, data: bool = True, split: Union[str, List[str]] = None,
include_unannotated: bool = False, load_crops: bool = False) -> pd.DataFrame:
Expand Down Expand Up @@ -1260,6 +1259,7 @@ def _verify_fastdup_run_args(self, input_dir, work_dir, df_annot, subset, data_t
else:
assert False, f"Wrong data type {data_type}"


def caption(self, model_name='automatic', device = 'cpu', batch_size: int = 8, subset: list = None, vqa_prompt: str = None, kwargs=None) -> pd.DataFrame:
if not self._fastdup_applied:
raise RuntimeError('Fastdup was not applied yet, call run() first')
Expand All @@ -1283,7 +1283,83 @@ def caption(self, model_name='automatic', device = 'cpu', batch_size: int = 8, s
assert False, "Unknown model name provided. Available models for caption generation are 'vitgpt2', 'blip2', and 'blip'.\n Available models for VQA are 'vqa' and 'age'."

return df

def enrich(self, task, model, input_df, input_col, num_rows=None, device=None):
    """
    Add model-generated columns to a DataFrame and return it.

    Supported (task, model) pairs and the columns they add:
      - "zero-shot-classification" / "recognize-anything-model": "ram_tags"
      - "zero-shot-classification" / "tag2text": "tag2text_tags", "tag2text_caption"
      - "zero-shot-detection" / "grounding-dino": "grounding_dino_bboxes",
        "grounding_dino_scores", "grounding_dino_labels"
      - "zero-shot-segmentation" / "segment-anything": "sam_masks"

    Args:
        task: Enrichment task name (see above).
        model: Model name for the task (see above).
        input_df: DataFrame to enrich. The detection and segmentation tasks
            also read its "filename" column.
        input_col: Column fed to the model: the image path for classification,
            the text prompt for detection, the bounding boxes for segmentation.
        num_rows: If truthy, only the first `num_rows` rows are enriched, on a
            copy. Otherwise `input_df` itself receives the new columns.
        device: Passed to the model constructor as `device`.

    Returns:
        The enriched DataFrame.

    Raises:
        ValueError: If the task/model combination is not supported.
        KeyError: If `input_col` is missing for the segmentation task.
        Exception: Errors raised by the models (including `unload_model`) and
            by bounding-box tensor conversion propagate with their own type.
    """
    self._fastdup_applied = True  # Hack: Allow users to run enrichment without first running fastdup

    if num_rows:
        # Copy the slice so that adding columns below does not write to a
        # view of the caller's frame (avoids pandas' SettingWithCopyWarning).
        df = input_df.head(num_rows).copy()
    else:
        df = input_df

    # Stays None when no (task, model) branch matches; used both to validate
    # the selection and to decide whether there is a model to unload.
    enrichment_model = None

    try:
        if task == "zero-shot-classification":
            if model == "recognize-anything-model":
                from fastdup.models_ram import RecognizeAnythingModel

                enrichment_model = RecognizeAnythingModel(device=device)
                df["ram_tags"] = df[input_col].apply(enrichment_model.run_inference)

            elif model == "tag2text":
                from fastdup.models_tag2text import Tag2TextModel

                enrichment_model = Tag2TextModel(device=device)
                # Run inference once per image and reuse the result for both
                # columns instead of invoking the model twice.
                results = [enrichment_model.run_inference(x) for x in df[input_col]]
                df["tag2text_tags"] = [r[0].replace(" | ", " . ") for r in results]
                df["tag2text_caption"] = [r[2] for r in results]

        elif task == "zero-shot-detection":
            if model == "grounding-dino":
                from fastdup.models_grounding_dino import GroundingDINO

                enrichment_model = GroundingDINO(device=device)

                # Plain lists (rather than zip(*df.apply(...))) also work
                # when the frame has no rows.
                results = [
                    enrichment_model.run_inference(filename, text_prompt=prompt)
                    for filename, prompt in zip(df["filename"], df[input_col])
                ]
                df["grounding_dino_bboxes"] = [r["boxes"] for r in results]
                df["grounding_dino_scores"] = [r["scores"] for r in results]
                df["grounding_dino_labels"] = [r["labels"] for r in results]

        elif task == "zero-shot-segmentation":
            if model == "segment-anything":
                import torch
                from fastdup.models_sam import SegmentAnythingModel

                # Check the column explicitly so that only a genuinely missing
                # column is reported as such; conversion errors keep their type.
                if input_col not in df.columns:
                    raise KeyError(
                        f"Column `{input_col}` does not exist."
                    )

                enrichment_model = SegmentAnythingModel(device=device)

                tensor_list = [
                    torch.tensor(bbox, dtype=torch.float32)
                    for bbox in df[input_col]
                ]

                def preprocess_and_run(filename, bbox):
                    result = enrichment_model.run_inference(filename, bboxes=bbox)

                    # Move GPU results to host memory before storing them.
                    if isinstance(result, torch.Tensor) and result.device.type == "cuda":
                        result = result.cpu()

                    return result

                df["sam_masks"] = [
                    preprocess_and_run(filename, bbox)
                    for filename, bbox in zip(df["filename"], tensor_list)
                ]

        if enrichment_model is None:
            raise ValueError("Please select a valid enrichment model")
    finally:
        # Release the model even when inference failed part-way through.
        if enrichment_model is not None:
            enrichment_model.unload_model()

    return df


def is_fastdup_dir(work_dir):
Expand Down
Loading