In [1]:
import sys, os

# Put the project's virtual-env site-packages first on sys.path so the imports
# below resolve against that environment.
env_root = '/N/project/baby_vision_curriculum/pythonenvs/hfenv/lib/python3.10/site-packages/'
sys.path.insert(0, env_root)

# Thread-count environment variables are set here, before numpy/torch are
# imported below — presumably so the limits are picked up at library load time.
os.environ['OPENBLAS_NUM_THREADS'] = '38' #@@@@ to help with the num_workers issue
os.environ['OMP_NUM_THREADS'] = '1'  #10

import numpy as np
import torch, torchvision
from torchvision import transforms as tr
from tqdm import tqdm
from pathlib import Path
# import math
import argparse
import pandas as pd
import warnings

import transformers

import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing as mp
from ddputils import is_main_process, save_on_master, setup_for_distributed

# torchvision.disable_beta_transforms_warning()
# import torchvision.transforms.v2 as tr #May 9: would require reinstalling the virtual env.
# we might do it later.

# import torch.nn as nn


# SCRIPT_DIR = os.path.realpath(os.path.dirname(inspect.getfile(inspect.currentframe()))) #os.getcwd() #
# # print('cwd: ',SCRIPT_DIR)
# #os.path.realpath(os.path.dirname(inspect.getfile(inspect.currentframe())))
# util_path = os.path.normpath(os.path.join(SCRIPT_DIR, '..', 'util'))
# sys.path.insert(0, util_path)    


# from train_downstream_VideoMAE import train_classifier_ddp
# from make_toybox_dataset import make_toybox_dataset
from transformers import VideoMAEConfig, VideoMAEModel
from torch.utils.data import Dataset
import av

from time import time
from copy import deepcopy

In [2]:
def _get_transform(image_size):
    """Return the per-frame preprocessing pipeline.

    The pipeline resizes to ``image_size``, center-crops to ``image_size``,
    converts the frame to float32 and normalizes every channel with
    mean 0.5 and std 0.25.
    """
    # Fixed statistics (not the ImageNet defaults 0.485/0.456/0.406 and
    # 0.229/0.224/0.225).
    channel_mean = [0.5, 0.5, 0.5]
    channel_std = [0.25, 0.25, 0.25]

    pipeline = [
        tr.Resize(image_size),
        tr.CenterCrop(image_size),
        tr.ConvertImageDtype(torch.float32),
        tr.Normalize(channel_mean, channel_std),
    ]
    return tr.Compose(pipeline)

def transform_vid(video):
    """Apply the frame transform to every frame of ``video`` and re-stack them.

    Used with standard video datasets such as torchvision.UCF101.
    If the second dimension of ``video`` is not 3, the tensor is permuted with
    (0, 3, 1, 2) so that channels come right after the frame dimension (TCHW).
    Frames are processed at a fixed image size of 224.
    """
    image_size = 224
    channels_second = video.shape[1] == 3
    if not channels_second:
        # Make it TCHW
        video = torch.permute(video, (0, 3, 1, 2))

    frame_transform = _get_transform(image_size)
    processed = [frame_transform(frame).unsqueeze(0) for frame in video]
    return torch.concat(processed, axis=0)

In [3]:
def get_config(image_size, args, num_labels=2):
    """Build the ``VideoMAEConfig`` for the architecture named in ``args.architecture``.

    Parameters
    ----------
    image_size : int
        Passed through as ``image_size`` to ``VideoMAEConfig``.
    args : object
        Must expose ``architecture``: ``''`` (default, 12 layers / 12 heads /
        hidden 768), ``'small1'`` (12 layers / 6 heads / hidden 384) or
        ``'small2'`` (6 layers / 6 heads / hidden 768).
    num_labels : int, default 2
        Passed through as ``num_labels`` to ``VideoMAEConfig``.

    Returns
    -------
    transformers.VideoMAEConfig

    Raises
    ------
    ValueError
        If ``args.architecture`` is not one of the known keywords.
    """
    # Only these four values differ between architectures; everything else
    # (patch size, channels, frames, tubelet size) is shared.
    presets = {
        '': dict(hidden_size=768, num_hidden_layers=12,
                 num_attention_heads=12, intermediate_size=3072),
        'small1': dict(hidden_size=384, num_hidden_layers=12,
                       num_attention_heads=6, intermediate_size=4*384),
        'small2': dict(hidden_size=768, num_hidden_layers=6,
                       num_attention_heads=6, intermediate_size=4*768),
    }

    arch_kw = args.architecture
    if arch_kw not in presets:
        raise ValueError(
            'Unknown architecture %r; expected one of %s'
            % (arch_kw, sorted(presets)))

    return transformers.VideoMAEConfig(
        image_size=image_size, patch_size=16, num_channels=3,
        num_frames=16, tubelet_size=2,
        num_labels=num_labels, **presets[arch_kw])


def init_model_from_checkpoint(model, checkpoint_path):
    """Load ``checkpoint['model_state_dict']`` from ``checkpoint_path`` into ``model``.

    Caution: ``model`` must be of the same class/shape as the model that
    produced the checkpoint, otherwise ``load_state_dict`` raises.

    The checkpoint is deserialized onto the CPU so that a file saved from a
    GPU process can be read on a machine with a different (or no) GPU;
    ``load_state_dict`` then copies the values into ``model``'s own parameters.

    Returns the same ``model`` object.
    """
    # NOTE: torch.load unpickles the file — only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    return model

def adapt_videomae(source_model, target_model):
    """Copy the backbone weights of ``source_model`` into ``target_model``.

    Both models must expose ``.videomae.embeddings`` and ``.videomae.encoder``.
    The embeddings are copied first, then the encoder; any other submodule of
    ``target_model`` is left untouched. Returns ``target_model``.
    """
    for part in ('embeddings', 'encoder'):
        donor = getattr(source_model.videomae, part)
        receiver = getattr(target_model.videomae, part)
        receiver.load_state_dict(donor.state_dict())
    return target_model
# def adapt_videomae(source_model, target_model):
#     # load the embeddings
#     target_model.embeddings.load_state_dict(
#         source_model.videomae.embeddings.state_dict())
# #     load the encoder
#     target_model.encoder.load_state_dict(
#         source_model.videomae.encoder.state_dict())
#     return target_model

def set_parameter_requires_grad(model, feature_extracting):
    """Freeze every parameter of ``model`` when ``feature_extracting`` is truthy.

    Does nothing otherwise. No parameter is re-enabled afterwards, so the
    whole model (including any classifier head) ends up frozen.
    """
    if not feature_extracting:
        return
    for weight in model.parameters():
        weight.requires_grad = False
            
def get_model(image_size, num_labels, feature_extracting, args):
    """Build a ``VideoMAEForVideoClassification`` whose backbone comes from a pre-training model.

    Steps:
      1. Create a ``VideoMAEForPreTraining`` model from ``get_config`` and, unless
         ``args.init_checkpoint_path == 'na'``, load the checkpoint into it.
      2. Create the classification model with ``num_labels`` labels and copy
         the embeddings/encoder weights over with ``adapt_videomae``.
      3. Warn if the patch-embedding projection weights of the two models differ.
      4. Freeze all parameters when ``feature_extracting`` is truthy.

    Returns the classification model.
    """
    pretrain_model = transformers.VideoMAEForPreTraining(get_config(image_size, args))

    ckpt_path = args.init_checkpoint_path
    if ckpt_path != 'na':
        print('args.init_checkpoint_path:', ckpt_path)
        # initialize the model using the checkpoint
        pretrain_model = init_model_from_checkpoint(pretrain_model, ckpt_path)

    classifier_config = get_config(image_size, args, num_labels=num_labels)
    classifier_model = transformers.VideoMAEForVideoClassification(config=classifier_config)
    classifier_model = adapt_videomae(pretrain_model, classifier_model)

    # Sanity check: after the copy, the first layer of both models must match.
    copied_weight = classifier_model.videomae.embeddings.patch_embeddings.projection.weight
    reference_weight = pretrain_model.videomae.embeddings.patch_embeddings.projection.weight
    weights_match = torch.all(copied_weight == reference_weight)
    if not weights_match:
        warnings.warn('Model not successfully initialized')

    if feature_extracting:
        set_parameter_requires_grad(classifier_model, feature_extracting)

    return classifier_model


In [4]:
class Args:
    """Minimal attribute bag: every keyword argument becomes an instance attribute."""

    def __init__(self, **kwargs):
        vars(self).update(kwargs)

In [5]:
# Run configuration. prot_name, seed and other_id are used by save_results to
# build the output file name.
prot_name='g0g1'
seed=401

# Device string the model and the input batches are moved to.
rank = 'cuda:0'
image_size= 224
num_classes=0# to get only the features, no classifier head
# Passed to get_model as feature_extracting: freezes all model parameters.
feature_extract=True
architecture='small2'
# 'na' makes get_model skip checkpoint loading, i.e. the weights stay at
# their fresh initialization.
init_checkpoint_path='na'
savedir='/N/project/baby_vision_curriculum/trained_models/generative/v2/benchmarks/toybox/'
batch_size=128#64
other_id='10fps.3ep'
# NOTE(review): this value is stored on args, but the dataset cells further
# down reassign frame_rate=3 before building the dataset — confirm which is intended.
frame_rate=10
num_workers=6

# Bundle the settings into one object for the helper functions that take `args`.
args = Args(architecture=architecture,
            init_checkpoint_path=init_checkpoint_path,
            savedir=savedir,
            prot_name=prot_name,
            seed=seed,
            other_id=other_id,
            frame_rate=frame_rate,
            num_workers=num_workers,
            batch_size=batch_size)


In [6]:
# Build the (frozen) VideoMAE model used for feature extraction.
xmodel = get_model(image_size, num_classes, feature_extract, args)
    


In [8]:
# inputs = image_processor(list(video), return_tensors="pt")
# inputs['pixel_values'].shape
# outputs.logits.shape

## Prepare the toybox dataset

In [9]:
import os
import cv2

class ToyboxDataset(torch.utils.data.Dataset):
    """Fixed-length frame clips read from Toybox videos with OpenCV.

    Expected layout under ``root_dir``: ``<supercategory>/<object>/<view file>``;
    every view file is one sample.

    ``__getitem__`` returns ``(clip, video_path)`` where ``clip`` is whatever
    ``transform`` returns for a ``(sample_len, H, W, 3)`` uint8 RGB tensor, or
    ``(None, None)`` when the video cannot be opened or decoded. A collate
    function that drops ``None`` items is therefore required when batching.

    Parameters
    ----------
    root_dir : str
        Root directory of the video tree.
    transform : callable
        Applied to the stacked frames tensor of one clip.
    frame_rate : number, default 10
        Target sampling rate; frames are subsampled by
        ``round(video_fps / frame_rate)`` (at least 1).
    sample_len : int, default 16
        Number of frames per clip.
    """

    def __init__(self, root_dir, transform, frame_rate=10, sample_len=16):
        self.root_dir = root_dir
        self.frame_rate = frame_rate
        self.sample_len = sample_len
        self.transform = transform
        self.samples = []
        # os.listdir order is arbitrary: sort so that sample indices are
        # reproducible across runs, and skip stray files at directory levels.
        for supercategory in sorted(os.listdir(self.root_dir)):
            supercategory_dir = os.path.join(self.root_dir, supercategory)
            if not os.path.isdir(supercategory_dir):
                continue
            for obj in sorted(os.listdir(supercategory_dir)):
                object_dir = os.path.join(self.root_dir, supercategory, obj)
                if not os.path.isdir(object_dir):
                    continue
                for view in sorted(os.listdir(object_dir)):
                    self.samples.append(os.path.join(object_dir, view))

    def __len__(self):
        return len(self.samples)

    @staticmethod
    def _pad_frames(frames, desired_frames):
        """Repeat the last frame until there are ``desired_frames`` items (no-op on an empty list)."""
        missing = desired_frames - len(frames)
        if frames and missing > 0:
            frames = frames + [frames[-1]] * missing
        return frames

    def get_all_frames(self, cap):
        """Read up to ``sample_len`` consecutive frames from ``cap`` (no subsampling).

        Short reads are padded by repeating the last frame. Returns an empty
        list when not a single frame could be decoded.
        """
        desired_frames = self.sample_len
        frames = []
        while len(frames) < desired_frames:
            ret, frame = cap.read()
            if not ret:
                # End of video
                break
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        return self._pad_frames(frames, desired_frames)

    def wrap_frames(self, frames):
        """Stack ``frames`` into one tensor and apply ``self.transform``.

        Returns ``None`` when the stacked tensor is not 4-D (e.g. no frames).
        """
        frames = torch.as_tensor(np.asarray(frames))
        if len(frames.shape)!=4: #torch.Size([16, 12xx, 19xx, 3])
            return None
        return self.transform(frames)

    def _read_clip(self, cap):
        """Decode one clip of ``sample_len`` RGB frames from an opened capture."""
        fps = cap.get(cv2.CAP_PROP_FPS)
        # Guard against a subsampling step of 0 (fps missing or much lower than
        # the target rate), which would make the modulo below divide by zero.
        ds_rate = max(1, round(fps/self.frame_rate))
        num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        sample_scope = self.sample_len*ds_rate
        if num_frames < sample_scope:
            # Not enough frames for a subsampled clip: take what is there.
            return self.get_all_frames(cap)

        # Start one fifth into the video, moved earlier if the remainder
        # would be too short for a full clip.
        start_frame = num_frames // 5
        if (num_frames-start_frame) < sample_scope:
            start_frame = num_frames-sample_scope

        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
        frames = []
        frame_count = 0
        while len(frames) < self.sample_len:
            ret, frame = cap.read()
            if not ret:
                # End of video
                break
            if frame_count % ds_rate == 0:
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            frame_count += 1

        # The reported frame count can overstate what is decodable; pad so
        # every clip has the same length and batches can be stacked.
        return self._pad_frames(frames, self.sample_len)

    def __getitem__(self, index):
        vid_path = self.samples[index]
        cap = cv2.VideoCapture(vid_path)
        try:
            if cap is None or not cap.isOpened():
                warnings.warn('unable to open video source: '+vid_path)
                return None, None
            frames = self._read_clip(cap)
        finally:
            # Always free the decoder, whichever path returned.
            if cap is not None:
                cap.release()

        frames_transformed = self.wrap_frames(frames)
        if frames_transformed is None:
            print(vid_path, 'gave None')
            return None, None
        return frames_transformed, vid_path
            
        

In [7]:
# Build a ToyboxDataset instance for the quick checks in the next cells.
toybox_root = '/N/project/baby_vision_curriculum/benchmarks/toybox/vids/toybox/'
transform = transform_vid
# Reassigns the global frame_rate set in the configuration cell (10 there).
frame_rate=3
sample_len=16
tb_dataset = ToyboxDataset(toybox_root, transform, 
                          frame_rate=frame_rate, sample_len=sample_len)

In [21]:
# Check the clip shape on `tb_dataset` (built in the cell above) so this cell
# also works when the notebook is run top to bottom on a fresh kernel.
tb_dataset[0][0].shape==torch.Size([16, 3, 224, 224])

True

In [None]:
# Time a handful of samples and keep copies of the clips for the visual check below.
xtt = []
for i in range(2000,3000,100):
    t0 = time()
    sx = tb_dataset[i][0]
    xtt.append(deepcopy(sx))
    print(time()-t0)
    assert len(sx)==16
# takes 1 second. might get faster later because of cache

import matplotlib.pyplot as plt
fig,ax = plt.subplots(1,8, figsize=(10,3))
j=9
for i in range(8):
    # Clips are normalized channel-first float tensors (mean 0.5, std 0.25 in
    # _get_transform); imshow needs height x width x channels with values in
    # [0, 1], so undo the normalization and move channels last.
    frame = (xtt[j][i] * 0.25 + 0.5).clamp(0, 1).permute(1, 2, 0)
    ax[i].imshow(frame.numpy())
    ax[i].axis('off')

## Inference loop

In [10]:
def my_collate(batch):
    """Collate ``(clip, name)`` items, dropping those whose clip is ``None``.

    Items with a ``None`` first element (unreadable videos) are removed before
    the default collation. If nothing is left, ``(None, None)`` is returned so
    the caller can unpack the batch and skip it; ``default_collate`` itself
    cannot handle an empty batch.
    """
    kept = tuple(item for item in batch if item[0] is not None)
    if not kept:
        return None, None
    return torch.utils.data.dataloader.default_collate(kept)

In [11]:
# DataLoader settings for the inference loop.
# NOTE(review): sampler_shuffle is not used by any active code in the visible cells.
sampler_shuffle = True #for the distributed sampler
# num_epochs = args.n_epoch
batch_size = args.batch_size# 128
pin_memory = False
num_workers = args.num_workers #number_of_cpu-1#32
# my_collate drops items whose clip is None before the default collation.
collate_fn = my_collate#None

In [12]:
# Dataset used by the inference DataLoader.
# NOTE(review): same construction as the earlier `tb_dataset` cell; the
# global frame_rate is reassigned to 3 here as well (10 in the config cell).
toybox_root = '/N/project/baby_vision_curriculum/benchmarks/toybox/vids/toybox/'
transform = transform_vid
frame_rate=3
sample_len=16
dataset = ToyboxDataset(toybox_root, transform, 
                          frame_rate=frame_rate, sample_len=sample_len)

In [13]:
# Sequential (unshuffled) loader over the whole dataset; the last, possibly
# smaller, batch is kept so every video is visited.
dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, pin_memory=pin_memory, collate_fn=collate_fn,
        num_workers=num_workers, shuffle=False, drop_last=False)#, sampler=samplers_dict[x])

In [14]:

# Move the model to the device named by `rank` and switch it to eval mode;
# the return value of eval() is assigned to `_` only to suppress the cell output.
xmodel = xmodel.to(rank)
_ = xmodel.eval()

In [15]:
def save_results(fnames, embeddings, args):
    """Write one embedding row per file name to a CSV under ``args.savedir``.

    Parameters
    ----------
    fnames : sequence of str
        One name per row of ``embeddings``.
    embeddings : 2-D array
        Shape ``(len(fnames), hdim)``; columns are written as ``dim0..dim{hdim-1}``.
    args : object
        Must expose ``savedir``, ``prot_name``, ``seed`` and ``other_id``.

    The rows are sorted by ``fnames`` and duplicate names are dropped (one row
    per name is kept). The output file is
    ``<savedir>/embeddings_<prot_name>_seed_<seed>_<other_id>.csv``; the
    directory is created if needed.
    """
    hdim = embeddings.shape[1]
    xdf = pd.DataFrame(embeddings, columns= ['dim'+str(i)
                                         for i in range(hdim)])
    xdf['fnames'] = fnames
    # Move the name column to the front.
    xdf = xdf[['fnames']+ list(xdf.columns[:-1])]

    xdf = xdf.sort_values('fnames')
    xdf = xdf.drop_duplicates(subset='fnames', ignore_index=True)

    savedir = args.savedir
    Path(savedir).mkdir(parents=True, exist_ok=True)
#         <model vs scores>_<prot>_seed_<seed>_other_<other>_<other id>
    result_fname = '_'.join(['embeddings', args.prot_name, 
                            'seed', str(args.seed),  
                            args.other_id])+'.csv'
    results_fpath = os.path.join(savedir, result_fname)
    # Write the frame built above (the original referenced an undefined name).
    xdf.to_csv(results_fpath, sep=',', float_format='%.6f', index=False)

In [16]:
# all_embeddings = []
# Inference loop: run the model over the dataloader and collect, per batch,
# the file names and the L2-normalized output vectors.
# NOTE(review): world_size and `outputs` are only referenced by the
# commented-out all_gather code below this cell.
world_size = 4
data = {
        'fnames':[],
        'embeddings': []   
    }
outputs = [None for _ in range(world_size)]

print_period=1#20
# Debug limit: the loop stops once the batch index reaches this value.
i_break = 10
with torch.no_grad():
    for i_t, xbatch in enumerate(tqdm(dataloader)):
        inputs, fnames = xbatch
        if inputs is None:
            continue
        inputs = inputs.to(rank)
        # NOTE(review): with num_classes=0 the `.logits` output is presumably the
        # pooled features rather than class scores — confirm against transformers.
        image_features = xmodel(pixel_values=inputs).logits
        # Normalize each row to unit L2 norm.
        image_features /= image_features.norm(dim=-1, keepdim=True)
        data['fnames'] += fnames
        data['embeddings'].append(image_features.detach().cpu().numpy())

        if (i_t%print_period)==0:
            # Bytes -> MB.
            memory_allocated = torch.cuda.memory_allocated() / 1024**2
            print(f'GPU memory allocated: {memory_allocated:.2f} MB')
        if i_t==i_break:
            break #@@@

# dist.all_gather_object(outputs, data)

            
#     if is_main_process():
#         print('finished processing')
#         allfnames, allembeddings = [],[]
#         for cdict in outputs:
#             allfnames += list(chain(*cdict['fnames'])) 
#             print('Aggregating worker results:',len(cdict['fnames']),'/', len(allfnames))
#             allembeddings +=cdict['embeddings']
            
#         allembeddings = np.concatenate(allembeddings)
    
#         save_results(allfnames, allembeddings, args)

  3%|█▏                                       | 1/34 [02:13<1:13:28, 133.60s/it]

GPU memory allocated: 1343.09 MB


  6%|██▌                                         | 2/34 [02:15<29:54, 56.07s/it]

GPU memory allocated: 1343.09 MB


  9%|███▉                                        | 3/34 [02:17<16:10, 31.30s/it]

GPU memory allocated: 1343.09 MB


 12%|█████▏                                      | 4/34 [02:19<09:51, 19.71s/it]

GPU memory allocated: 1343.09 MB


 15%|██████▍                                     | 5/34 [02:21<06:24, 13.27s/it]

GPU memory allocated: 1343.09 MB


 18%|███████▊                                    | 6/34 [02:22<04:22,  9.38s/it]

GPU memory allocated: 1343.09 MB


[mov,mp4,m4a,3gp,3g2,mj2 @ 0x163c90c0] moov atom not found
 21%|█████████                                   | 7/34 [03:55<16:26, 36.53s/it]

GPU memory allocated: 1343.09 MB


 24%|██████████▎                                 | 8/34 [03:57<11:02, 25.48s/it]

GPU memory allocated: 1343.09 MB


 26%|███████████▋                                | 9/34 [03:59<07:33, 18.13s/it]

GPU memory allocated: 1343.09 MB


 29%|████████████▋                              | 10/34 [04:00<05:14, 13.11s/it]

GPU memory allocated: 1343.09 MB
GPU memory allocated: 1343.09 MB


 29%|████████████▋                              | 10/34 [04:32<10:54, 27.29s/it]
