How to use CLIP-ReID as feature extractor? #10

Closed
mikel-brostrom opened this issue Jul 27, 2023 · 39 comments

@mikel-brostrom

mikel-brostrom commented Jul 27, 2023

Hi! First of all, I would like to know whether you are okay with me implementing these models here: https://github.com/mikel-brostrom/yolo_tracking. I would also like to know if there is any easy way of extracting features with these models. Keep up the great work!

@awarebayes

def forward_override(self, x: torch.Tensor, cv_emb = None, old_forward = None):
    _, image_features, image_features_proj = old_forward(x, cv_emb)
    return torch.cat([image_features[:,0], image_features_proj[:,0]], dim=1)
    
def main():

    model_path = cfg.TEST.WEIGHT
    batch_size = cfg.TEST.IMS_PER_BATCH
    output_dir = cfg.OUTPUT_DIR
    output_name = cfg.MODEL.OUTPUT_NAME

    train_loader, train_loader_normal, val_loader, num_query, num_classes, camera_num, view_num = make_dataloader(cfg)
    model = make_model(cfg, num_class=num_classes, camera_num=camera_num, view_num = view_num)

    logger.info(f"Loading model for eval from {model_path}. Batch size = {batch_size}")

    model.load_param(model_path)
    model = model.image_encoder

    old_forward = model.forward
    model.forward = lambda *args, **kwargs: forward_override(model, old_forward=old_forward, *args, **kwargs)

    device = torch.device('cuda:0')
    model = model.eval().to(device)

    image_dir = Path(args.test_dir)
    dataset = ReidDataset(image_dir, cfg)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True)
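    with torch.no_grad():  # inference only: skip autograd bookkeeping
        for i, (images, person_ids) in enumerate(loader):
            images = images.to(device)
            vectors = model(images)  # concatenated CLS features, 768 + 512 = 1280 dims for ViT-B-16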

@mikel-brostrom
Author

Thanks for the example code @awarebayes! Really appreciate it

@awarebayes

We are working on some fixes on the repo.
If you are using cosine distance for your descriptors, just change euclidean_dist to cosine_dist in the ranking step:
https://github.com/Syliz517/CLIP-ReID/blob/master/loss/triplet_loss.py#L123C26-L123C26

We have found that it makes the metrics worse, but the separation with cosine distance is much better. I can show you graphs of the average distances for same-ID vs other-ID pairs.
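For reference, here is a small, dependency-free sketch of the two distance choices for query-vs-gallery ranking. It is not the repo's torch implementation in loss/triplet_loss.py; the function names only mirror `euclidean_dist`/`cosine_dist` for readability, and `rank` is a hypothetical helper. The toy data shows that the two metrics can order the same gallery differently, because cosine distance ignores vector length.

```python
import math


def euclidean_dist(queries, gallery):
    """Pairwise Euclidean distances: one row per query, one column per gallery item."""
    return [[math.dist(q, g) for g in gallery] for q in queries]


def _l2_normalize(v):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def cosine_dist(queries, gallery):
    """Pairwise cosine distances, i.e. 1 - cosine similarity of L2-normalized vectors."""
    qn = [_l2_normalize(q) for q in queries]
    gn = [_l2_normalize(g) for g in gallery]
    return [[1.0 - sum(a * b for a, b in zip(q, g)) for g in gn] for q in qn]


def rank(distmat):
    """Gallery indices sorted from closest to farthest, per query row."""
    return [sorted(range(len(row)), key=row.__getitem__) for row in distmat]


queries = [[1.0, 0.0]]
gallery = [[3.0, 0.0], [0.9, 0.5]]
print(rank(euclidean_dist(queries, gallery)))  # [[1, 0]]
print(rank(cosine_dist(queries, gallery)))     # [[0, 1]]
```

The first gallery vector points in exactly the query's direction but is three times longer, so Euclidean distance pushes it back while cosine distance ranks it first.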

@mikel-brostrom
Author

mikel-brostrom commented Jul 27, 2023

> If you are using cosine distance for your descriptors, just change euclidean_dist to cosine_dist, used for ranking

Have you visualized the feature learning process? It would make sense if the features tended to cluster like beams shooting out from the origin. It would be enlightening if you could share the graphs ✨

@mikel-brostrom
Author

mikel-brostrom commented Jul 27, 2023

Let me know when this is in a usable state 😄. I have been trying to get the MSMT17_clipreid_12x12sie_ViT model working, with no luck:

Traceback (most recent call last):
  File "/home/mikel.brostrom/CLIP-ReID/feature_extraction_test.py", line 52, in <module>
    main()
  File "/home/mikel.brostrom/CLIP-ReID/feature_extraction_test.py", line 40, in main
    model.load_param(model_path)
  File "/home/mikel.brostrom/CLIP-ReID/model/make_model.py", line 121, in load_param
    self.state_dict()[i.replace('module.', '')].copy_(param_dict[i])
RuntimeError: The size of tensor a (2) must match the size of tensor b (15) at non-singleton dimension 0

You can reproduce this with:

import os
import torch
import yaml
import argparse
from config import cfg_base as cfg
from model.make_model import make_model

def forward_override(self, x: torch.Tensor, cv_emb = None, old_forward = None):
    _, image_features, image_features_proj = old_forward(x, cv_emb)
    return torch.cat([image_features[:,0], image_features_proj[:,0]], dim=1)
    
def main():
    parser = argparse.ArgumentParser(description="ReID Baseline Training")
    parser.add_argument(
        "--config_file",
        default="/home/mikel.brostrom/CLIP-ReID/MSMT17_clipreid_12x12sie_ViT-B-16_60_test_log.yml",
        help="path to config file",
        type=str
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER
    )

    args = parser.parse_args()

    if args.config_file != "":
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    model = make_model(cfg, num_class=1501, camera_num=2, view_num = 1)

    model_path = 'MSMT17_clipreid_12x12sie_ViT-B-16_60.pth'
    batch_size = 1
    print(f"Loading model for eval from {model_path}. Batch size = {batch_size}")

    model.load_param(model_path)
    model = model.image_encoder

    old_forward = model.forward
    model.forward = lambda *args, **kwargs: forward_override(model, old_forward=old_forward, *args, **kwargs)

    device = torch.device('cuda:0')
    model = model.eval().to(device)

    
main()

Had to delete some stuff from the configs as well to get to the weight loading part...

@awarebayes

You can override config parameters on the command line like this:

--config_file
src/configs/person/vit_clipreid.yml
--test_dir
"./dataset/reid/market1501/bounding_box_test"
MODEL.DEVICE_ID
"'0'"
DATASETS.NAMES
"market1501"
DATASETS.ROOT_DIR
"./dataset/reid"
TEST.WEIGHT
output/clipreid.pt
SOLVER.SEED
42
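The trailing `KEY VALUE` pairs end up in `cfg.merge_from_list(args.opts)`. As far as I know, yacs tries `ast.literal_eval` on each value, falls back to the raw string if that fails, and then checks the result against the type of the default. That is presumably why `MODEL.DEVICE_ID` is passed as `"'0'"`: after the shell strips the outer quotes, `'0'` evaluates to the string `"0"` rather than the integer `0`. The sketch below only mimics that pairing and decoding step with the standard library; `parse_opts` is a hypothetical helper, not part of yacs or this repo.

```python
import ast


def parse_opts(opts):
    """Mimic how a flat [KEY, VALUE, KEY, VALUE, ...] list is decoded (sketch only)."""
    if len(opts) % 2 != 0:
        raise ValueError("opts must contain KEY VALUE pairs")
    decoded = {}
    for key, raw in zip(opts[0::2], opts[1::2]):
        try:
            value = ast.literal_eval(raw)  # "42" -> 42, "'0'" -> "0"
        except (ValueError, SyntaxError):
            value = raw  # not a Python literal, e.g. a path or a dataset name
        decoded[key] = value
    return decoded


# Values as the script receives them, i.e. after the shell has removed the outer quotes.
opts = [
    "MODEL.DEVICE_ID", "'0'",
    "DATASETS.NAMES", "market1501",
    "DATASETS.ROOT_DIR", "./dataset/reid",
    "TEST.WEIGHT", "output/clipreid.pt",
    "SOLVER.SEED", "42",
]
print(parse_opts(opts))
```

`--config_file` and `--test_dir` are consumed by argparse before this step, so only the pairs after them reach `merge_from_list`.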

@awarebayes

As for the graphs, I am plotting the one-vs-rest cosine distance distribution.

When ranking with cosine:
[Figure: one-vs-rest distance distribution, $OUTPUT_NAME_vectors_distribution]

When ranking with euclidean:
[Figure: one-vs-rest distance distribution when ranking with Euclidean distance]

Do you have any examples of how to 'visualize the feature learning process'?
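A minimal sketch of how same-ID vs other-ID distance distributions like these could be collected, assuming you already have one feature vector and one person ID per image. `pair_distances` is a hypothetical helper, and it uses all pairs rather than a strict one-vs-rest loop; the two returned lists are what you would histogram.

```python
import math
from itertools import combinations


def cosine_distance(a, b):
    """1 - cosine similarity between two feature vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return 1.0 - dot / (math.hypot(*a) * math.hypot(*b))


def pair_distances(features, person_ids):
    """Split all pairwise cosine distances into same-ID and other-ID lists."""
    same, other = [], []
    for i, j in combinations(range(len(features)), 2):
        d = cosine_distance(features[i], features[j])
        (same if person_ids[i] == person_ids[j] else other).append(d)
    return same, other


features = [[1.0, 0.0], [1.0, 0.1], [0.0, 1.0]]
person_ids = [7, 7, 3]
same, other = pair_distances(features, person_ids)
print(f"mean same-ID distance:  {sum(same) / len(same):.3f}")
print(f"mean other-ID distance: {sum(other) / len(other):.3f}")
```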

@mikel-brostrom
Author

> Do you have any examples on how to 'visualize the feature learning process'?

https://github.com/KaiyangZhou/pytorch-center-loss

@awarebayes

They use a bottleneck of size 2, which is what they visualize. I don't think we can do that with person ReID:

self.fc1 = nn.Linear(128*3*3, 2)
self.prelu_fc1 = nn.PReLU()
self.fc2 = nn.Linear(2, num_classes)

@mikel-brostrom
Author

> You can override config parameters as:

Thanks, but it seems that the loaded ViT-B-16 does not match the one in MSMT17_clipreid_12x12sie_ViT-B-16_60.pth.

@awarebayes

awarebayes commented Jul 27, 2023

import os
import torch
import yaml
import argparse

from PIL import Image

from config import cfg
import torchvision.transforms as T

from model.make_model import make_model


def forward_override(self, x: torch.Tensor, cv_emb=None, old_forward=None):
    _, image_features, image_features_proj = old_forward(x, cv_emb)
    return torch.cat([image_features[:, 0], image_features_proj[:, 0]], dim=1)


def main():
    parser = argparse.ArgumentParser(description="ReID Baseline Training")
    parser.add_argument(
        "--config_file",
        default="",
        help="path to config file",
        type=str
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER
    )

    args = parser.parse_args()

    if args.config_file != "":
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    model = make_model(cfg, num_class=751, camera_num=2, view_num=1)

    model_path = '/home/mscherbina/Downloads/Market1501_clipreid_ViT-B-16_60.pth'
    batch_size = 1
    print(f"Loading model for eval from {model_path}. Batch size = {batch_size}")

    model.load_param(model_path)
    model = model.image_encoder

    old_forward = model.forward
    model.forward = lambda *args, **kwargs: forward_override(model, old_forward=old_forward, *args, **kwargs)

    device = torch.device('cuda:0')
    model = model.eval().to(device)

    transforms = T.Compose([
        T.Resize([256, 128]),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    image_dir = '....'
    for i in os.listdir(image_dir):
        img = Image.open(f"{image_dir}/{i}").convert("RGB")
        img = transforms(img).to(device)
        img = img.unsqueeze(0)
        with torch.no_grad():
            feats = model(img)
            print("i", i, feats.size())


if __name__ == "__main__":
    main()

Tested with ViT-CLIP-ReID Market

With that, you can keep the default config:

from config import cfg

@mikel-brostrom
Author

mikel-brostrom commented Jul 27, 2023

I changed the following content in:

config/defaults.py

-_C.MODEL.NAME = 'resnet50'
+_C.MODEL.NAME = 'ViT-B-16'

and adapted model_path in the code above. I am on master with no other modifications. I run the above script as:

python feature_extraction_test.py

but get the following error

Loading model for eval from /home/mikel.brostrom/CLIP-ReID/MSMT17_clipreid_12x12sie_ViT-B-16_60.pth. Batch size = 1
Traceback (most recent call last):
  File "/home/mikel.brostrom/CLIP-ReID/feature_extraction_test.py", line 73, in <module>
    main()
  File "/home/mikel.brostrom/CLIP-ReID/feature_extraction_test.py", line 48, in main
    model.load_param(model_path)
  File "/home/mikel.brostrom/CLIP-ReID/model/make_model.py", line 121, in load_param
    self.state_dict()[i.replace('module.', '')].copy_(param_dict[i])
KeyError: 'cv_embed'

@awarebayes

I changed load_param to:

    def load_param(self, trained_path):
        param_dict = torch.load(trained_path)
        for i in self.state_dict():
            self.state_dict()[i.replace('module.', '')].copy_(param_dict[i])
        print('Loading pretrained model from {}'.format(trained_path))

@mikel-brostrom
Author

mikel-brostrom commented Jul 27, 2023

OK, I think I was loading the SIE-OLP model. Let me try the one you linked.

@mikel-brostrom
Author

Loading Market1501_clipreid_ViT-B-16_60.pth now, and getting:

Loading model for eval from /home/mikel.brostrom/CLIP-ReID/Market1501_clipreid_ViT-B-16_60.pth. Batch size = 1
Traceback (most recent call last):
  File "/home/mikel.brostrom/CLIP-ReID/feature_extraction_test.py", line 73, in <module>
    main()
  File "/home/mikel.brostrom/CLIP-ReID/feature_extraction_test.py", line 48, in main
    model.load_param(model_path)
  File "/home/mikel.brostrom/CLIP-ReID/model/make_model.py", line 121, in load_param
    self.state_dict()[i.replace('module.', '')].copy_(param_dict[i])
RuntimeError: The size of tensor a (193) must match the size of tensor b (129) at non-singleton dimension 0

@mikel-brostrom
Author

Still the same issue after updating load_param.

@awarebayes

Hmm, which key is it?
Also, do you have Telegram? We could continue there.

@mikel-brostrom
Author

mikel-brostrom commented Jul 27, 2023

Sorry, no Telegram, only WhatsApp 😞

@mikel-brostrom
Author

Added a print in load_param:

    def load_param(self, trained_path):
        param_dict = torch.load(trained_path)
        for i in self.state_dict():
            print(i)
            self.state_dict()[i.replace('module.', '')].copy_(param_dict[i])
        print('Loading pretrained model from {}'.format(trained_path))

It stops at positional_embedding:

bottleneck.weight
bottleneck.bias
bottleneck.running_mean
bottleneck.running_var
bottleneck.num_batches_tracked
bottleneck_proj.weight
bottleneck_proj.bias
bottleneck_proj.running_mean
bottleneck_proj.running_var
bottleneck_proj.num_batches_tracked
image_encoder.class_embedding
image_encoder.positional_embedding
Traceback (most recent call last):
  File "/home/mikel.brostrom/CLIP-ReID/feature_extraction_test.py", line 73, in <module>
    main()
  File "/home/mikel.brostrom/CLIP-ReID/feature_extraction_test.py", line 48, in main
    model.load_param(model_path)
  File "/home/mikel.brostrom/CLIP-ReID/model/make_model.py", line 122, in load_param
    self.state_dict()[i.replace('module.', '')].copy_(param_dict[i])
RuntimeError: The size of tensor a (193) must match the size of tensor b (129) at non-singleton dimension 0
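When a checkpoint and a freshly built model disagree like this, it can help to list every mismatching key at once instead of stopping at the first failing `copy_`. Below is a framework-free sketch; it assumes you first turn both state dicts into `{name: shape}` dictionaries (with torch, something like `{k: tuple(v.shape) for k, v in sd.items()}`). `diff_shapes` is a hypothetical helper; the positional-embedding shapes in the example combine the 193/129 from the error above with the ViT-B-16 width of 768, and the `cv_embed` shape is made up for illustration.

```python
def diff_shapes(model_shapes, ckpt_shapes, prefix="module."):
    """Compare {name: shape} dicts and return (missing, unexpected, mismatched)."""
    # Strip a leading DataParallel-style "module." prefix from checkpoint keys.
    ckpt = {
        (k[len(prefix):] if k.startswith(prefix) else k): tuple(v)
        for k, v in ckpt_shapes.items()
    }
    missing = sorted(k for k in model_shapes if k not in ckpt)
    unexpected = sorted(k for k in ckpt if k not in model_shapes)
    mismatched = {
        k: (tuple(model_shapes[k]), ckpt[k])
        for k in model_shapes
        if k in ckpt and tuple(model_shapes[k]) != ckpt[k]
    }
    return missing, unexpected, mismatched


model_shapes = {
    "image_encoder.class_embedding": (768,),
    "image_encoder.positional_embedding": (193, 768),
}
ckpt_shapes = {
    "module.image_encoder.class_embedding": (768,),
    "module.image_encoder.positional_embedding": (129, 768),
    "cv_embed": (4, 768),  # hypothetical shape; only the key name comes from the thread
}
missing, unexpected, mismatched = diff_shapes(model_shapes, ckpt_shapes)
print("missing:", missing)
print("unexpected:", unexpected)
print("mismatched:", mismatched)
```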

@awarebayes

awarebayes commented Jul 27, 2023

I believe the config is to blame:

        self.h_resolution = int((cfg.INPUT.SIZE_TRAIN[0]-16)//cfg.MODEL.STRIDE_SIZE[0] + 1)
        self.w_resolution = int((cfg.INPUT.SIZE_TRAIN[1]-16)//cfg.MODEL.STRIDE_SIZE[1] + 1)
        self.vision_stride_size = cfg.MODEL.STRIDE_SIZE[0]
        clip_model = load_clip_to_cpu(self.model_name, self.h_resolution, self.w_resolution, self.vision_stride_size)
        clip_model.to("cuda")

The config I load with is https://gist.github.com/awarebayes/271beb52cabc9cf0bc77f592764e1b62; maybe switch to it.

@mikel-brostrom
Author

The stride is 16 in both configs...

@awarebayes

awarebayes commented Jul 27, 2023

In the end, I have these values here (ignore 457, it's the number of persons in my dataset):
[Screenshot of the resulting values]

@awarebayes

By the way, I noticed that ViT with TensorRT + FP16 does not run any slower than ResNet-50.

@mikel-brostrom
Author

OK, fixed. I now use

# Size of the image during training
_C.INPUT.SIZE_TRAIN = [256, 128]
# Size of the image during test
_C.INPUT.SIZE_TEST = [256, 128]

instead of

# Size of the image during training
_C.INPUT.SIZE_TRAIN = [384, 128]
# Size of the image during test
_C.INPUT.SIZE_TEST = [384, 128]
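The 193 vs 129 in the earlier error follows directly from the resolution formula quoted above: the image encoder's positional embedding has one row per patch plus one for the class token. A quick sketch of that arithmetic, assuming 16-pixel patches and the stride of 16 mentioned above (`num_positional_embeddings` is just an illustrative helper):

```python
def num_positional_embeddings(size_train, stride_size, patch_size=16):
    """Patches on an (H, W) input with the given stride, plus one class token."""
    h_resolution = (size_train[0] - patch_size) // stride_size[0] + 1
    w_resolution = (size_train[1] - patch_size) // stride_size[1] + 1
    return h_resolution * w_resolution + 1


print(num_positional_embeddings([384, 128], [16, 16]))  # 24 * 8 + 1 = 193
print(num_positional_embeddings([256, 128], [16, 16]))  # 16 * 8 + 1 = 129
```

So a checkpoint trained at 256x128 only loads into a model built with `SIZE_TRAIN = [256, 128]`, which is what the config change above restores.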

@mikel-brostrom
Author

Thanks for your patience 😄.

@mikel-brostrom
Author

mikel-brostrom commented Jul 27, 2023

I am somewhat surprised by the inference time. For a torch.Size([1, 3, 256, 128]) input, I got 0.058 seconds per image, averaged over 1000 images, on an old NVIDIA P2000. I didn't think it would be that low, even taking the small input size into consideration.

@mikel-brostrom
Author

mikel-brostrom commented Jul 27, 2023

0.37 seconds for a torch.Size([10, 3, 256, 128]) input, in case anybody is interested 😄.
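For anyone who wants to reproduce such numbers, here is a hedged sketch of an averaging timer. It is generic Python: `benchmark` is a hypothetical helper and the warm-up count is an arbitrary choice. For GPU models you would pass `torch.cuda.synchronize` as `sync`, because CUDA kernels run asynchronously and the clock could otherwise stop before the work is done. By the figures above, a batch of 10 works out to about 0.037 s per crop, versus 0.058 s for single-crop calls.

```python
import time


def benchmark(fn, inputs, warmup=10, sync=lambda: None):
    """Average wall-clock seconds per call of fn over inputs, after warm-up calls."""
    for x in inputs[:warmup]:
        fn(x)  # warm-up: lazy initialization, caches, allocator growth
    sync()
    start = time.perf_counter()
    for x in inputs:
        fn(x)
    sync()  # wait for queued (e.g. CUDA) work before stopping the clock
    return (time.perf_counter() - start) / len(inputs)


# Stand-in workload; with a real model this would be something like
# benchmark(model, batches, sync=torch.cuda.synchronize) inside torch.no_grad().
avg = benchmark(lambda n: sum(i * i for i in range(n)), [10_000] * 100)
print(f"{avg * 1000:.3f} ms per call")
```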

@mikel-brostrom
Author

mikel-brostrom commented Jul 27, 2023

Are you responsible for this repo now, @awarebayes? If so, would you be okay with me implementing these models here: https://github.com/mikel-brostrom/yolo_tracking? @Syliz517? I intend to use them for associating detections across frames, both with visual information only and with motion plus appearance information.

@mikel-brostrom
Author

mikel-brostrom commented Jul 27, 2023

> By the way I noticed ViT with tensorrt + fp16 does not work any slower than resnet50

Not any slower than the regular .pt ResNet weights, you mean? Do you have any plans to add a TensorRT converter to this repo?

@mikel-brostrom
Author

mikel-brostrom commented Jul 28, 2023

> As for the graphs, I am doing one vs rest cosine distance distribution.

Clearly, cosine distance separates the embeddings better. I am, however, not entirely sure why the distributions have such large tails.

@awarebayes

I can help you integrate CLIP-ReID into yolo_tracking.
I just got tasked with fixing this repo at work.
A TensorRT converter is unnecessary since it has ONNX support, and ONNX models convert to TensorRT natively using trtexec.
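A rough sketch of that route, assuming the model has already been exported to ONNX (for example with `torch.onnx.export`) and that TensorRT's `trtexec` binary is on the PATH. The file names are placeholders and both helper functions are hypothetical; `--onnx`, `--saveEngine` and `--fp16` are standard trtexec flags, and FP16 matches the setup mentioned above.

```python
import shutil
import subprocess


def trtexec_command(onnx_path, engine_path, fp16=True):
    """Build the trtexec argument list for an ONNX -> TensorRT engine conversion."""
    cmd = ["trtexec", f"--onnx={onnx_path}", f"--saveEngine={engine_path}"]
    if fp16:
        cmd.append("--fp16")  # build an FP16 engine
    return cmd


def build_engine(onnx_path, engine_path, fp16=True):
    """Run trtexec; raises if it is missing or the conversion fails."""
    if shutil.which("trtexec") is None:
        raise FileNotFoundError("trtexec not found on PATH")
    subprocess.run(trtexec_command(onnx_path, engine_path, fp16), check=True)


# Placeholder file names; call build_engine(...) on a machine with TensorRT installed.
print(" ".join(trtexec_command("clipreid.onnx", "clipreid.engine")))
```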

@mikel-brostrom
Author

I have it working in real time now.

@mikel-brostrom
Author

mikel-brostrom commented Jul 28, 2023

git clone https://github.com/mikel-brostrom/yolo_tracking
cd yolo_tracking
git checkout clip-reid
python examples/track.py --reid-model clip_market1501.pt
# or
python examples/track.py --reid-model clip_duke.pt

I haven't evaluated it on any MOT dataset yet.

@awarebayes

I don't know if using this guy's repo issues section as a communication medium is a good idea, but whatever.
Have you tried using video ReID nets?

https://paperswithcode.com/task/video-based-person-re-identification

They basically work the same way, but during training they average the features over the whole tracklet instead of using individual images' features.
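The simplest form of that idea is temporal average pooling: average the per-frame embeddings of one tracklet and re-normalize. A stdlib sketch; `tracklet_embedding` is a hypothetical name, and this only illustrates the averaging step, not a full video ReID model.

```python
import math


def l2_normalize(v):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def tracklet_embedding(frame_features):
    """Mean of the per-frame features of one tracklet, then L2-normalized."""
    n = len(frame_features)
    mean = [sum(column) / n for column in zip(*frame_features)]
    return l2_normalize(mean)


print(tracklet_embedding([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]))
```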

@mikel-brostrom
Author

mikel-brostrom commented Jul 31, 2023

> I dont know if keeping this guys repo issues section as a communication medium is a good idea

I think I have Telegram; I just uninstalled it because I didn't use it. Let me check.

> Have you tried using video reid nets?

Interesting, I had never heard of this! It could be quite heavy for real-time use, though, as the crops are stacked. This is, however, handled to some extent in modern multi-object trackers, where the individual appearance features are averaged over several frames, like here:
https://github.com/NirAharon/BoT-SORT/blob/251985436d6712aaf682aaaf5f71edb4987224bd/tracker/bot_sort.py#L41
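In BoT-SORT-style trackers, as far as I can tell, that averaging is an exponential moving average of each track's normalized embedding rather than a plain mean. A sketch under that assumption: `alpha = 0.9` is a commonly used value but should be treated as an assumption here, and `update_track_feature` is a hypothetical name, not BoT-SORT's API.

```python
import math


def l2_normalize(v):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def update_track_feature(smooth_feat, new_feat, alpha=0.9):
    """Exponential moving average of a track's appearance embedding."""
    new_feat = l2_normalize(new_feat)
    if smooth_feat is None:  # first detection of this track
        return new_feat
    mixed = [alpha * s + (1.0 - alpha) * f for s, f in zip(smooth_feat, new_feat)]
    return l2_normalize(mixed)


track = None
for detection_feat in ([3.0, 4.0], [4.0, 3.0], [0.0, 5.0]):
    track = update_track_feature(track, detection_feat)
print(track)
```

Compared with stacking crops for a video ReID net, this keeps a single vector per track and costs one extra vector update per frame.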

@Syliz517
Owner

> Hi! Would first of all like to know whether you are okay with me implementing these models here: https://github.com/mikel-brostrom/yolo_tracking. Then I would also like to know if there is any easy way of extracting features with these models. Keep up the great work!

Of course. I noticed you had some problems before; have you resolved them now?

@mikel-brostrom
Author

Thanks for your response @Syliz517! @awarebayes helped me out to get it working 😄

@edwardnguyen1705

> (Quotes @awarebayes's feature-extraction script from earlier in this thread in full.)

Hi @awarebayes,
I see that feats.size() is 1280 (768 + 512). If I want the model to output a different feature size, how should I configure it?

@awarebayes

I would suggest changing the size, seeing where it fails, and debugging from there
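If the goal is only a smaller descriptor rather than a retrained model, one option is to use just one part of the concatenated vector produced by forward_override: the first 768 values are the CLS token of image_features and the last 512 come from image_features_proj (dimensions as reported above for ViT-B-16). With a torch output of shape (batch, 1280) the equivalent slices are `feats[:, :768]` and `feats[:, 768:]`. A plain-Python sketch; `split_clip_reid_feature` is a hypothetical helper, and whether either part alone ranks as well as the concatenation would need to be checked on your own data.

```python
def split_clip_reid_feature(feat, vis_dim=768, proj_dim=512):
    """Split the concatenated [image_features CLS | projected CLS] vector."""
    if len(feat) != vis_dim + proj_dim:
        raise ValueError(f"expected {vis_dim + proj_dim} values, got {len(feat)}")
    return feat[:vis_dim], feat[vis_dim:]


feat = [0.0] * 768 + [1.0] * 512  # stand-in for one row of the model output
visual, projected = split_clip_reid_feature(feat)
print(len(visual), len(projected))  # 768 512
```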

4 participants