In [1]:
import os
import cv2
import numpy as np
from numpy import dot
from numpy.linalg import norm
import sys
import glob
import json
import h5py
import math
from tqdm import tqdm
import torch
import torchvision
import torchvision.transforms as trn
import torchvision.models as models
import torchvision.ops.roi_align as roi_align

from modules.until_module import PreTrainedModel, AllGather, CrossEn
from modules.module_cross import CrossModel, CrossConfig, Transformer as TransformerClip

from modules.module_clip import CLIP, convert_weights
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
import pickle
import pathlib

In [2]:
from modules.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from modules.modeling import CLIP4Clip
from modules.optimization import BertAdam
from util import parallel_apply, get_logger

# Prefer the GPU, but fall back to the CPU so the notebook still runs
# (slowly) on machines without CUDA instead of failing at `.to(device)`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# Argument
class args:
    """Static configuration namespace for the feature-extraction notebook.

    The class itself (not an instance) is used as the settings object: later
    cells overwrite some attributes in place and hand it to
    ``CLIP4Clip.from_pretrained`` as ``task_config``.
    """

    # --- dataset selection and locations ---------------------------------
    msvd = True  # or msvd = False for MSR-VTT
    dset = '../'  # change based on dataset location
    save_path = '../extracted_features'  # root directory for the pickled features

    # Placeholders; the dataset-selection cell replaces these with real paths.
    features_path = '..'
    msrvtt_csv = 'msrvtt.csv'
    data_path = 'MSRVTT_data.json'

    # --- video / text sampling options forwarded to the dataloader --------
    max_frames = 20
    feature_framerate = 1
    slice_framepos = 2
    eval_frame_order = 0  # forwarded to the loader as `frame_order`
    max_words = 32  # overridden per dataset in the dataset-selection cell

    # --- model / checkpoint options ----------------------------------------
    output_dir = 'pretrained'  # directory holding `pytorch_model.bin.<n>`
    cache_dir = ''  # empty -> the default BERT cache location is used
    cross_model = "cross-base"
    local_rank = 0  # NOTE(review): presumably the distributed rank; not used in this notebook

In [4]:
#MSVD 
if args.msvd:
    dset_path = os.path.join(os.path.join(args.dset,'dataset'),'MSVD')
    features_path = os.path.join(dset_path,'raw') # video .avi    
    name_list = glob.glob(features_path+os.sep+'*')
    args.features_path = features_path

    url2id = {}
    data_path = os.path.join(os.path.join(dset_path,'captions','youtube_mapping.txt'))
    args.data_path = data_path
    for line in open(data_path,'r').readlines():
        url2id[line.strip().split(' ')[0]] = line.strip().split(' ')[-1]

    path_to_saved_models = f"{args.save_path}/msvd"
    pathlib.Path(path_to_saved_models).mkdir(parents=True, exist_ok=True)
    save_file = path_to_saved_models+'/MSVD_Clip4Clip_features.pickle'
    args.max_words =30
    
    # Load video to dataloader
    %run ../dataloaders/dataloader_msvd.py import MSVD_Loader
    
    videos= MSVD_Loader(
        data_path=args.data_path,
        features_path=args.features_path,
        max_words=args.max_words,
        feature_framerate=args.feature_framerate,
        max_frames=args.max_frames,
        frame_order=args.eval_frame_order,
        slice_framepos=args.slice_framepos,
        transform_type = 0,
    ) 
#MSR-VTT    
else:
  
    dset_path = os.path.join(os.path.join(args.dset,'dataset'),'MSRVTT')
    features_path = os.path.join(dset_path,'raw') 
    args.features_path = features_path
    data_path=os.path.join(dset_path,'MSRVTT_data.json')
    args.data_path = data_path
    args.msrvtt_csv = os.path.join(dset_path,'msrvtt.csv')
    name_list = glob.glob(features_path+os.sep+'*')
    
    path_to_saved_models = "extracted/msrvtt"
    pathlib.Path(path_to_saved_models).mkdir(parents=True, exist_ok=True)
    save_file = path_to_saved_models+'/MSRVTT_Clip4Clip_features.pickle'
    args.max_words =73
    
    #Load video to dataloader
    %run ../dataloaders/dataloader_msrvtt.py import MSRVTT_RawDataLoader
    videos= MSRVTT_RawDataLoader(
        csv_path=args.msrvtt_csv,
        features_path=args.features_path,
        max_words=args.max_words,
        feature_framerate=args.feature_framerate,
        max_frames=args.max_frames,
        frame_order=args.eval_frame_order,
        slice_framepos=args.slice_framepos,
        transform_type = 0,
)

Video number: 1970
Id number: 1970


### Load the CLIP4Clip pretrained models

In [5]:
# Build CLIP4Clip and initialise it from a saved checkpoint in `args.output_dir`.
# The checkpoint file name carries the suffix `epoch - 1`
# (presumably a zero-based epoch index — confirm against the training script).
epoch = 5
model_file = os.path.join(args.output_dir, "pytorch_model.bin.{}".format(epoch-1))

# `weights_only=True` restricts unpickling to tensors and plain containers,
# which is all a state dict should contain; it avoids executing arbitrary
# pickled code from the checkpoint file.
model_state_dict = torch.load(model_file, map_location='cpu', weights_only=True)

# An empty `args.cache_dir` means "use the default BERT cache location".
cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed')

# NOTE(review): the log printed by this cell reports size mismatches for
# `clip.visual.positional_embedding` ([197, 768] vs [50, 768]) and
# `clip.visual.conv1.weight` (16x16 vs 32x32 kernels). The checkpoint looks
# like a 16x16-patch backbone while the model is built with 32x32 patches, so
# those weights are NOT loaded from the checkpoint — confirm the backbone
# configuration matches the checkpoint before trusting the extracted features.
model = CLIP4Clip.from_pretrained(args.cross_model, cache_dir=cache_dir, state_dict=model_state_dict, task_config=args)

  model_state_dict = torch.load(model_file, map_location='cpu')
Stage-One:True, Stage-Two:False
	 embed_dim: 512
	 image_resolution: 224
	 vision_layers: 12
	 vision_width: 768
	 vision_patch_size: 32
	 context_length: 77
	 vocab_size: 49408
	 transformer_width: 512
	 transformer_heads: 8
	 transformer_layers: 12
	 cut_top_layer: 0
Weights from pretrained model cause errors in CLIP4Clip: 
   size mismatch for clip.visual.positional_embedding: copying a param with shape torch.Size([197, 768]) from checkpoint, the shape in current model is torch.Size([50, 768]).
   size mismatch for clip.visual.conv1.weight: copying a param with shape torch.Size([768, 3, 16, 16]) from checkpoint, the shape in current model is torch.Size([768, 3, 32, 32]).


In [6]:
# Display the module tree of the loaded model for inspection.
model

CLIP4Clip(
  (clip): CLIP(
    (visual): VisualTransformer(
      (conv1): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
      (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (transformer): Transformer(
        (resblocks): Sequential(
          (0): ResidualAttentionBlock(
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
            )
            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (mlp): Sequential(
              (c_fc): Linear(in_features=768, out_features=3072, bias=True)
              (gelu): QuickGELU()
              (c_proj): Linear(in_features=3072, out_features=768, bias=True)
            )
            (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          )
          (1): ResidualAttentionBlock(
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLi

In [7]:
# Feature extraction below only uses the CLIP sub-module, so take it out of
# the CLIP4Clip wrapper and move it to the target device.
clip = model.clip.to(device)

### Extract clip features

In [8]:

# Encode every video with the CLIP image encoder and pickle the result as a
# single dict {video_id: numpy array of features} at `save_file`.
clip.eval()

data = {}
with torch.no_grad():
    for i in tqdm(range(len(videos))):
        video_id, video, video_mask = videos[i]

        # Keep only the entries of the first item whose mask equals 1.
        tensor = video[0]
        tensor = tensor[video_mask[0] == 1, :]
        tensor = torch.as_tensor(tensor).float()

        # Merge the two leading axes so the encoder receives a 4-D batch.
        video_frame, num, channel, h, w = tensor.shape
        tensor = tensor.view(video_frame * num, channel, h, w)
        video_frame, channel, h, w = tensor.shape

        output = clip.encode_image(tensor.to(device), video_frame=video_frame).float()
        data[video_id] = output.detach().cpu().numpy()

# Open the output file only after every video has been processed: opening it
# in 'wb' mode up front truncates it immediately, so an interrupted run would
# leave an empty pickle and destroy any previously saved features.
with open(save_file, 'wb') as handle:
    pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)


  0%|          | 3/1970 [00:11<2:05:18,  3.82s/it]


KeyboardInterrupt: 