* RAM++ is the next generation of RAM, which can recognize any category with high accuracy, including both predefined common categories and diverse open-set categories.
* RAM++ outperforms existing SOTA fundamental image recognition models on common tag categories, uncommon tag categories, and human-object interaction phrases.
* repository: https://github.com/xinyu1205/recognize-anything
* The tag-encoding implementation is inspired by this repository: https://github.com/AIVIETNAMResearch/VN_Multi_User_Video_Search

In [1]:
!git clone https://github.com/xinyu1205/recognize-anything.git
%cd recognize-anything

fatal: destination path 'recognize-anything' already exists and is not an empty directory.
/home/jupyter/Workspace/Amatos_hcm_ai/modules/ram_model/recognize-anything


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [2]:
import os
import glob
import json
import torch
import numpy as np
from PIL import Image
from ram.models import ram_plus
from ram import get_transform
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# Root folder of the extracted keyframes, relative to the working directory set
# by `%cd recognize-anything`; keyframes are globbed as
# `<keyframes_dir>/<video_id>/*.jpg` in the following cell.
keyframes_dir = '../../../db/keyframes'

In [6]:
# Set to an int to index only the first N videos (e.g. 2 for a quick smoke
# test); None indexes every video.
N_VIDEOS = None

keyframes_dir = '../../../db/keyframes'

# Keep only real, non-hidden sub-directories: stray files (e.g. `.DS_Store`)
# and hidden folders would otherwise become video ids with no keyframes.
video_ids = sorted(
    entry for entry in os.listdir(keyframes_dir)
    if not entry.startswith('.')
    and os.path.isdir(os.path.join(keyframes_dir, entry))
)[:N_VIDEOS]

# Maps video id -> sorted list of that video's keyframe JPEG paths.
video_keyframe_paths = {
    video_id: sorted(glob.glob(f'{keyframes_dir}/{video_id}/*.jpg'))
    for video_id in video_ids
}



In [7]:
video_keyframe_paths.keys()

dict_keys(['L01_V001', 'L01_V002', 'L01_V003', 'L01_V004', 'L01_V005', 'L01_V006', 'L01_V007', 'L01_V008', 'L01_V009', 'L01_V010', 'L01_V011', 'L01_V012', 'L01_V013', 'L01_V014', 'L01_V015', 'L01_V016', 'L01_V017', 'L01_V018', 'L01_V019', 'L01_V020', 'L01_V021', 'L01_V022', 'L01_V023', 'L01_V024', 'L01_V025', 'L01_V026', 'L01_V027', 'L01_V028', 'L01_V029', 'L01_V030', 'L01_V031', 'L02_V001', 'L02_V002', 'L02_V003', 'L02_V004', 'L02_V005', 'L02_V006', 'L02_V007', 'L02_V008', 'L02_V009', 'L02_V010', 'L02_V011', 'L02_V012', 'L02_V013', 'L02_V014', 'L02_V015', 'L02_V016', 'L02_V017', 'L02_V018', 'L02_V019', 'L02_V020', 'L02_V021', 'L02_V022', 'L02_V023', 'L02_V024', 'L02_V025', 'L02_V026', 'L02_V027', 'L02_V028', 'L02_V029', 'L02_V030', 'L02_V031', 'L03_V001', 'L03_V002', 'L03_V003', 'L03_V004', 'L03_V005', 'L03_V006', 'L03_V007', 'L03_V008', 'L03_V009', 'L03_V010', 'L03_V011', 'L03_V012', 'L03_V013', 'L03_V014', 'L03_V015', 'L03_V016', 'L03_V017', 'L03_V018', 'L03_V019', 'L03_V020', 'L03_

# Download checkpoint

In [8]:
def download_checkpoints(model):
    """Make sure the pretrained weights for ``model`` exist under ``pretrained/``.

    The file is downloaded to a ``.part`` path and renamed only after the
    transfer finishes, so an interrupted download is never mistaken for a
    complete checkpoint on the next run.

    Args:
        model: Name of the checkpoint to fetch. Only ``"ram_plus"`` is known.

    Returns:
        The relative path of the weights file.

    Raises:
        ValueError: If ``model`` is not a known checkpoint name.
        urllib.error.URLError: If the download fails.
    """
    # Function-scope import keeps this cell independent of the import cell.
    import urllib.request

    # model name -> (download URL, local path)
    checkpoints = {
        "ram_plus": (
            "https://huggingface.co/xinyu1205/recognize-anything-plus-model/resolve/main/ram_plus_swin_large_14m.pth",
            'pretrained/ram_plus_swin_large_14m.pth',
        ),
    }
    if model not in checkpoints:
        raise ValueError(
            f"Unknown model {model!r}; expected one of {sorted(checkpoints)}"
        )
    url, weights_path = checkpoints[model]

    os.makedirs('pretrained', exist_ok=True)

    if os.path.exists(weights_path):
        print("RAM plus weights already downloaded!")
        return weights_path

    partial_path = weights_path + '.part'
    print(f"Downloading {url} -> {weights_path}")
    try:
        urllib.request.urlretrieve(url, partial_path)
        # Atomic rename: the final path only ever holds a complete file.
        os.replace(partial_path, weights_path)
    finally:
        # Clean up the partial file if the download or rename failed.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    return weights_path

# Use a dedicated name for the checkpoint identifier: `model` is bound to the
# loaded network later in the notebook, and reusing it for a string here makes
# out-of-order execution error-prone.
model_name = "ram_plus"
download_checkpoints(model_name)
print(model_name, 'weights are downloaded!')

RAM plus weights already downloaded!
ram_plus weights are downloaded!


# Helper Functions

In [9]:
device = "cuda" if torch.cuda.is_available() else "cpu"

@torch.no_grad()
def forward_ram(model, imgs):
    """Tag a batch of images and return the selected tags with their logits.

    All tensors are created on ``imgs.device``, so the function does not
    depend on the notebook-level ``device`` variable.

    Args:
        model: RAM++ model. The attributes read here are ``visual_encoder``,
            ``image_proj``, ``label_embed``, ``num_class``, ``reweight_scale``,
            ``wordvec_proj``, ``tagging_head``, ``fc``, ``class_threshold``,
            ``delete_tag_index`` and ``tag_list`` (indexed like a numpy array).
        imgs: Batched image tensor; the first dimension is the batch size.

    Returns:
        ``(tag_outputs, tag_logits)``, two lists with one entry per image:
        the selected tag names (spaces replaced by underscores, case kept)
        and a numpy array with the raw, pre-sigmoid logit of each of them.
    """
    run_device = imgs.device
    bs = imgs.shape[0]

    image_embeds = model.image_proj(model.visual_encoder(imgs))
    image_atts = torch.ones(image_embeds.size()[:-1],
                            dtype=torch.long, device=run_device)

    # The first token is used as the global image representation.
    image_cls_embeds = image_embeds[:, 0, :]
    image_cls_embeds = image_cls_embeds / image_cls_embeds.norm(dim=-1, keepdim=True)

    # label_embed stacks `des_per_class` description embeddings per class.
    des_per_class = model.label_embed.shape[0] // model.num_class
    embed_dim = model.label_embed.shape[-1]

    reweight_scale = model.reweight_scale.exp()
    logits_per_image = (reweight_scale * image_cls_embeds @ model.label_embed.t())
    logits_per_image = logits_per_image.view(bs, -1, des_per_class)
    weight_normalized = torch.nn.functional.softmax(logits_per_image, dim=2)

    # Loop-invariant view, hoisted out of the per-image loop.
    descriptions = model.label_embed.view(-1, des_per_class, embed_dim)
    label_embed_reweight = torch.empty(bs, model.num_class, embed_dim,
                                       device=run_device, dtype=imgs.dtype)
    # One image at a time keeps the (class, description, dim) intermediate
    # from being materialised for the whole batch at once.
    for i in range(bs):
        weighted = weight_normalized[i].unsqueeze(-1) * descriptions
        label_embed_reweight[i] = weighted.sum(dim=1)

    label_embed = torch.nn.functional.relu(model.wordvec_proj(label_embed_reweight))

    # recognized image tags using alignment decoder
    tagging_embed = model.tagging_head(
        encoder_embeds=label_embed,
        encoder_hidden_states=image_embeds,
        encoder_attention_mask=image_atts,
        return_dict=False,
        mode='tagging',
    )

    logits = model.fc(tagging_embed[0]).squeeze(-1)

    # A class is selected when its probability beats its own threshold.
    selected = torch.sigmoid(logits) > model.class_threshold.to(logits.device)
    tag = selected.cpu().numpy()
    # Masking is done in numpy, which accepts an empty index list.
    tag[:, model.delete_tag_index] = False

    # Single device-to-host copy instead of one per image.
    logits_np = logits.cpu().numpy()

    tag_outputs = []
    tag_logits = []
    for b in range(bs):
        index = np.flatnonzero(tag[b])
        tokens = model.tag_list[index]
        # NOTE(review): case is preserved here, while the pickled tag list
        # built elsewhere in this notebook is lower-cased - confirm that
        # downstream matching is case-insensitive.
        tag_outputs.append([token.replace(" ", "_") for token in tokens])
        tag_logits.append(logits_np[b][index])

    return tag_outputs, tag_logits

# Run inference

In [10]:
# Preprocessing pipeline and network are both configured for 384px inputs.
transform = get_transform(image_size=384)

# Load the Swin-L RAM++ checkpoint, switch to inference mode and move it to
# the selected device in one chained expression.
model = ram_plus(
    pretrained='pretrained/ram_plus_swin_large_14m.pth',
    image_size=384,
    vit='swin_l',
).eval().to(device)

# Tag vocabulary exposed by the loaded model.
tag_list = model.tag_list

--------------
pretrained/ram_plus_swin_large_14m.pth
--------------
load checkpoint from pretrained/ram_plus_swin_large_14m.pth
vit: swin_l


In [22]:
# del model

In [11]:
tag_list

array(['3D CG rendering', '3D glasses', 'abacus', ..., 'zombie', 'zongzi',
       'zoo'], dtype='<U27')

In [12]:
# Normalise the vocabulary for saving: spaces become underscores and every tag
# is lower-cased.
# NOTE(review): `forward_ram` also replaces spaces but keeps the original case,
# so tags such as '3D_CG_rendering' differ from this list - confirm that
# downstream matching is case-insensitive.
output_tag_list = [tag.replace(" ", "_").lower() for tag in tag_list.tolist()]

In [13]:
output_tag_list[:5]

['3d_cg_rendering', '3d_glasses', 'abacus', 'abalone', 'monastery']

# Save tag list to pickle for later use

In [14]:
import pickle

In [15]:
# Save the normalised tag list for later use. The target has no file
# extension; it is a binary pickle of a plain Python list of strings.
with open("../../../util/tag_list", "wb") as fp:
    pickle.dump(output_tag_list, fp)

In [16]:
# Read the file back to check the round trip. pickle.load can execute
# arbitrary code, so only use it on files produced by this notebook.
with open("../../../util/tag_list", "rb") as fp:
    tag_list_from_pickle = pickle.load(fp)

In [17]:
tag_list_from_pickle[:5]

['3d_cg_rendering', '3d_glasses', 'abacus', 'abalone', 'monastery']

In [18]:
video_keyframe_paths.keys()

dict_keys(['L01_V001', 'L01_V002', 'L01_V003', 'L01_V004', 'L01_V005', 'L01_V006', 'L01_V007', 'L01_V008', 'L01_V009', 'L01_V010', 'L01_V011', 'L01_V012', 'L01_V013', 'L01_V014', 'L01_V015', 'L01_V016', 'L01_V017', 'L01_V018', 'L01_V019', 'L01_V020', 'L01_V021', 'L01_V022', 'L01_V023', 'L01_V024', 'L01_V025', 'L01_V026', 'L01_V027', 'L01_V028', 'L01_V029', 'L01_V030', 'L01_V031', 'L02_V001', 'L02_V002', 'L02_V003', 'L02_V004', 'L02_V005', 'L02_V006', 'L02_V007', 'L02_V008', 'L02_V009', 'L02_V010', 'L02_V011', 'L02_V012', 'L02_V013', 'L02_V014', 'L02_V015', 'L02_V016', 'L02_V017', 'L02_V018', 'L02_V019', 'L02_V020', 'L02_V021', 'L02_V022', 'L02_V023', 'L02_V024', 'L02_V025', 'L02_V026', 'L02_V027', 'L02_V028', 'L02_V029', 'L02_V030', 'L02_V031', 'L03_V001', 'L03_V002', 'L03_V003', 'L03_V004', 'L03_V005', 'L03_V006', 'L03_V007', 'L03_V008', 'L03_V009', 'L03_V010', 'L03_V011', 'L03_V012', 'L03_V013', 'L03_V014', 'L03_V015', 'L03_V016', 'L03_V017', 'L03_V018', 'L03_V019', 'L03_V020', 'L03_

In [None]:
#### TAGGING WITH FREQUENCY
# For every keyframe, run RAM++ and write one text line of space-separated
# tags in which each tag is repeated round(logit * 10) times. One output file
# per video, one line per keyframe, in sorted keyframe-path order.

bs = 32
save_dir_all = '../../../db'
save_dir = f'{save_dir_all}/ram_plus_encoded'

# Set True to resume an interrupted run: videos whose output file already
# exists are skipped instead of being re-encoded.
SKIP_EXISTING = False

# makedirs also creates missing parents and is a no-op for existing folders.
# No per-video sub-folder is created: results go to `<save_dir>/<video_id>.txt`
# and nothing is ever written into such a folder.
os.makedirs(save_dir, exist_ok=True)


def _load_image_batch(image_paths):
    """Open and preprocess the given images and stack them on `device`."""
    tensors = []
    for image_path in image_paths:
        # The context manager releases the file handle once the tensor exists.
        with Image.open(image_path) as image:
            tensors.append(transform(image))
    return torch.stack(tensors).to(device)


def _tags_to_context(tag_output, tag_logit):
    """Return one text line with each tag repeated round(logit * 10) times."""
    tag_frequency = np.round(tag_logit * 10).astype(int)
    words = []
    for tag, freq in zip(tag_output, tag_frequency):
        # NOTE(review): a tag whose logit rounds to <= 0 is emitted zero times.
        words.extend([tag] * freq)
    return ' '.join(map(str, words))


video_ids = sorted(video_keyframe_paths.keys())

for video_id in tqdm(video_ids, desc='videos'):
    out_path = f"{save_dir}/{video_id}.txt"
    if SKIP_EXISTING and os.path.exists(out_path):
        continue

    video_keyframe_path = video_keyframe_paths[video_id]
    tag_contexts = []

    # leave=False removes each finished inner bar so the output is not flooded.
    for i in tqdm(range(0, len(video_keyframe_path), bs), desc=video_id, leave=False):
        # Support batchsize inferencing
        images = _load_image_batch(video_keyframe_path[i:i+bs])

        # Forward ram model
        tag_outputs, tag_logits = forward_ram(model, images)

        # Encode result
        tag_contexts.extend(
            _tags_to_context(tag_output, tag_logit)
            for tag_output, tag_logit in zip(tag_outputs, tag_logits)
        )

    # Every keyframe must produce exactly one line; stop loudly otherwise.
    if len(tag_contexts) != len(video_keyframe_path):
        raise RuntimeError(
            f"{video_id}: produced {len(tag_contexts)} tag lines for "
            f"{len(video_keyframe_path)} keyframes"
        )

    # Write to a temporary file and rename it, so an interrupted run never
    # leaves a truncated result behind.
    tmp_path = f"{out_path}.tmp"
    with open(tmp_path, "w") as f:
        f.writelines("%s\n" % item for item in tag_contexts)
    os.replace(tmp_path, out_path)

  0%|          | 0/1471 [00:00<?, ?it/s]
  0%|          | 0/9 [00:00<?, ?it/s][A
 11%|█         | 1/9 [00:04<00:34,  4.31s/it][A
 22%|██▏       | 2/9 [00:08<00:29,  4.25s/it][A
 33%|███▎      | 3/9 [00:12<00:25,  4.24s/it][A
 44%|████▍     | 4/9 [00:16<00:21,  4.21s/it][A
 56%|█████▌    | 5/9 [00:21<00:16,  4.20s/it][A
 67%|██████▋   | 6/9 [00:25<00:12,  4.19s/it][A
 78%|███████▊  | 7/9 [00:29<00:08,  4.18s/it][A
 89%|████████▉ | 8/9 [00:33<00:04,  4.15s/it][A
100%|██████████| 9/9 [00:35<00:00,  3.95s/it][A
  0%|          | 1/1471 [00:35<14:31:12, 35.56s/it]
  0%|          | 0/7 [00:00<?, ?it/s][A
 14%|█▍        | 1/7 [00:04<00:24,  4.11s/it][A
 29%|██▊       | 2/7 [00:08<00:20,  4.10s/it][A
 43%|████▎     | 3/7 [00:12<00:16,  4.09s/it][A
 57%|█████▋    | 4/7 [00:16<00:12,  4.08s/it][A
 71%|███████▏  | 5/7 [00:20<00:08,  4.08s/it][A
 86%|████████▌ | 6/7 [00:24<00:04,  4.08s/it][A
100%|██████████| 7/7 [00:27<00:00,  3.94s/it][A
  0%|          | 2/1471 [01:03<12:36:09, 

In [23]:
del images

In [None]:
import gc
gc.collect()

In [20]:
# ##### SIMPLE TAGGING

# bs = 4
# save_dir_all = '../../../db'

# if not os.path.exists(save_dir_all):
#     os.mkdir(save_dir_all)


# save_dir = f'{save_dir_all}/ram_plus_encoded'

# save_dir_v2 = f'{save_dir_all}/ram_plus_encoded_v2'



# if not os.path.exists(save_dir):
#     os.mkdir(save_dir)

# if not os.path.exists(save_dir_v2):
#     os.mkdir(save_dir_v2)
# # for key, video_keyframe_paths in all_keyframe_paths.items():
# video_ids = sorted(video_keyframe_paths.keys())

# for video_id in tqdm(video_ids):
    
#     tag_contexts = []
#     video_keyframe_path = video_keyframe_paths[video_id]


#     for i in tqdm(range(0, len(video_keyframe_path), bs)):
#         # Support batchsize inferencing
#         images = []
#         image_paths = video_keyframe_path[i:i+bs]
#         for image_path in image_paths:
#             image = transform(Image.open(image_path)).unsqueeze(0)
#             images.append(image)
#         images = torch.cat(images).to(device)

#         # Forward ram model
#         tag_outputs, tag_logits = forward_ram(model, images)

#         # Encode result
#         for b in range(len(tag_outputs)):
#             tag_output, tag_logit = tag_outputs[b], tag_logits[b]
#             tag_context = ' '.join(map(str, tag_output))
#             tag_contexts.append(tag_context)

#     if len(tag_contexts) != len(video_keyframe_path):
#         print("Something wrong!!!!!")
#         break

#     # Saving the video tag context txt
#     with open(f"{save_dir_v2}/{video_id}.txt", "w") as f:
#         for item in tag_contexts:
#             f.write("%s\n" % item)   

  0%|          | 0/2 [00:00<?, ?it/s]
  0%|          | 0/68 [00:00<?, ?it/s][A
  1%|▏         | 1/68 [00:00<00:39,  1.69it/s][A
  3%|▎         | 2/68 [00:01<00:36,  1.81it/s][A
  4%|▍         | 3/68 [00:01<00:35,  1.82it/s][A
  6%|▌         | 4/68 [00:02<00:34,  1.83it/s][A
  7%|▋         | 5/68 [00:02<00:34,  1.83it/s][A
  9%|▉         | 6/68 [00:03<00:33,  1.85it/s][A
 10%|█         | 7/68 [00:03<00:33,  1.84it/s][A
 12%|█▏        | 8/68 [00:04<00:32,  1.83it/s][A
 13%|█▎        | 9/68 [00:04<00:32,  1.83it/s][A
 15%|█▍        | 10/68 [00:05<00:31,  1.82it/s][A
 16%|█▌        | 11/68 [00:06<00:31,  1.82it/s][A
 18%|█▊        | 12/68 [00:06<00:30,  1.83it/s][A
 19%|█▉        | 13/68 [00:07<00:30,  1.83it/s][A
 21%|██        | 14/68 [00:07<00:29,  1.82it/s][A
 22%|██▏       | 15/68 [00:08<00:29,  1.82it/s][A
 24%|██▎       | 16/68 [00:08<00:28,  1.82it/s][A
 25%|██▌       | 17/68 [00:09<00:28,  1.81it/s][A
 26%|██▋       | 18/68 [00:09<00:27,  1.81it/s][A
 28%|██▊   