In [1]:
import os
import sys

# Site-packages directory of the Python environment to import from
# (presumably where transformers/torch are installed -- TODO confirm).
# NOTE(review): hard-coded, user-specific absolute path; other users must edit it.
env_root = '/N/slate/sheybani/pythonenvs/hfenv2/lib/python3.10/site-packages'
# Himanshu: '/N/slate/hhansar/hgenv/lib/python3.10/site-packages'
# Insert at the front of sys.path so imports in later cells search this directory first.
sys.path.insert(0,env_root)

In [2]:
import transformers
from transformers import AutoImageProcessor, VideoMAEForPreTraining

In [3]:
import numpy as np
import torch
from PIL import Image
from pathlib import Path
from tqdm import tqdm
from torch.utils.data import Dataset

In [4]:
def get_fpathlist(vid_root, subjdir, ds_rate=1):
    """
    List the JPG frames stored in ``vid_root/subjdir`` in file-name order.

    Args:
        vid_root: directory that contains one sub-directory per subject.
        subjdir: name of the subject sub-directory, e.g. '008MS'.
        ds_rate: keep every ``ds_rate``-th frame (1 keeps all frames).

    Returns:
        A list of path strings. Only entries whose suffix is exactly '.jpg'
        are kept; they are sorted by file name before downsampling so the
        frame order follows the (lexicographic) file-name order.
    """
    frame_dir = Path(os.path.join(vid_root, subjdir))

    # Sort on the bare file name so ordering does not depend on the OS listing order.
    entries = list(frame_dir.iterdir())
    entries.sort(key=lambda entry: entry.name)

    jpg_paths = []
    for entry in entries:
        if entry.suffix == '.jpg':
            jpg_paths.append(str(entry))

    # Temporal downsampling happens after filtering and sorting.
    return jpg_paths[::ds_rate]


class ImageSequenceDataset(Dataset):
    """
    Dataset of frame sequences, intended for video models.

    Each item is built from one list of image file paths: the images are
    opened with PIL and handed, as a list, to ``transform``.
    """
    def __init__(self, image_paths, transform):
        # image_paths: list of sequences; image_paths[i] is the list of frame paths for item i.
        # transform: a Hugging Face image processor; it is called in __getitem__ as
        #            transform(list_of_images, return_tensors="pt").
        self.transform = transform
        self.image_paths = image_paths

    def __len__(self):
        # One item per frame sequence.
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Open every frame of the requested sequence, in order.
        frames = []
        for frame_path in self.image_paths[idx]:
            frames.append(Image.open(frame_path))

        # The processor is given a single video, so take entry 0 of pixel_values
        # to return that one processed sequence.
        processed = self.transform(frames, return_tensors="pt")
        return processed.pixel_values[0]
    
def get_train_val_split(fpathlist, val_ratio=0.1):
    """
    Split a list of file paths into a training list and a validation list.

    The validation set is one contiguous block of ``int(len * val_ratio)``
    items taken from the middle of ``fpathlist``; the training set is
    everything before that block followed by everything after it.

    Returns:
        (train_set, val_set)
    """
    total = len(fpathlist)
    n_val = int(total * val_ratio)

    # Boundaries of the centred validation block.
    val_start = int((total - n_val) / 2)
    val_stop = int((total + n_val) / 2)

    val_set = fpathlist[val_start:val_stop]
    train_set = fpathlist[:val_start] + fpathlist[val_stop:]
    return train_set, val_set

def get_fpathseqlist(fpathlist, seq_len, ds_rate=1, n_samples=None):
    """
    Cut a flat list of frame paths into fixed-length sequences.

    The result is a list of lists that can be passed to ImageSequenceDataset.
    Every returned sequence holds exactly ``seq_len`` items.

    Args:
        fpathlist: flat, ordered list of frame paths.
        seq_len: number of frames in each returned sequence.
        ds_rate: step between consecutive frames inside a sequence, so one
            sequence spans ``seq_len * ds_rate`` entries of ``fpathlist``.
        n_samples: number of sequences to return (int >= 1).
            If None, the list is tiled with non-overlapping windows, giving
            ``len(fpathlist) // (seq_len * ds_rate)`` sequences.
            If given, window starts are evenly spaced; windows may overlap.

    Returns:
        list[list]: the sequences, in order of their start index.

    Raises:
        ValueError: if ``n_samples`` is given but is < 1, or that many
            full-length windows cannot be placed on ``fpathlist``.
    """
    n_frames = len(fpathlist)
    sample_len = seq_len * ds_rate  # entries of fpathlist spanned by one sequence

    if n_samples is None:
        # Count windows by their full span (not by seq_len), otherwise ds_rate > 1
        # produces start indices past the end of the list and empty sequences.
        n_samples = n_frames // sample_len
        sample_stride = sample_len
    else:
        assert type(n_samples) == int
        if n_samples < 1:
            raise ValueError(f"n_samples must be >= 1, got {n_samples}")
        if n_frames < sample_len:
            raise ValueError(
                f"need at least {sample_len} frames for one sequence, got {n_frames}")

        sample_stride = n_frames // n_samples
        if n_samples > 1:
            # The last window starts at (n_samples - 1) * stride and must still fit,
            # so shrink the stride when the even spacing would overrun the list.
            max_stride = (n_frames - sample_len) // (n_samples - 1)
            sample_stride = min(sample_stride, max_stride)
        if sample_stride < 1:
            raise ValueError(
                f"cannot place {n_samples} full-length sequences on {n_frames} frames")

    return [fpathlist[start:start + sample_len:ds_rate]
            for start in range(0, n_samples * sample_stride, sample_stride)]

In [6]:
# Root directory that holds one sub-directory of JPG frames per subject.
jpg_root = '/N/project/infant_image_statistics/preproc_saber/JPG_10fps/'
ds_rate = 1

# Minimum number of frames across age groups.
# Total number of frames in each age group: g0=1.68m, g1=1.77m, g2=1.45m
n_groupframes = 1_450_000

# Subject directory names for each age group.
g0 = [
    '008MS', '009SS_withrotation', '010BF_withrotation', '011EA_withrotation',
    '012TT_withrotation', '013LS', '014SN', '015JM', '016TF', '017EW_withrotation',
]
g1 = [
    '026AR', '027SS', '028CK', '028MR', '029TT',
    '030FD', '031HW', '032SR', '033SE', '034JC_withlighting',
]
g2 = [
    '043MP', '044ET', '046TE', '047MS', '048KG',
    '049JC', '050AB', '050AK_rotation', '051DW',
]

In [7]:
# Spatial size of the model input and number of frames per clip.
image_size = 224
num_frames = 16

# Sequence settings used when cutting the frame list into clips.
seq_len = num_frames  # equivalent to num_frames in VideoMAE()
n_samples = None  # e.g. 10 or 50000; None tiles the frame list without overlap

# Preprocessing: resize so the shortest edge is image_size, then centre-crop to a square.
# Alternative: AutoImageProcessor.from_pretrained("MCG-NJU/videomae-base")
square_crop = {"height": image_size, "width": image_size}
image_processor = transformers.VideoMAEImageProcessor(
    size={"shortest_edge": image_size},
    do_center_crop=True,
    crop_size=square_crop,
)

In [9]:
def make_dataset(subj_dirs, **kwargs):
    """
    Build the train/val ImageSequenceDatasets for one group of subjects.

    Args:
        subj_dirs: iterable of subject directory names under ``jpg_root``.
        **kwargs: all four keys are required:
            seq_len: frames per sequence.
            n_groupframes: cap on the total number of frames kept for the group.
            ds_rate: temporal downsampling applied when listing each subject's frames.
            jpg_root: directory that contains the subject directories.

    Returns:
        dict with keys 'train' and 'val', each an ImageSequenceDataset that uses
        the module-level ``image_processor`` as its transform.

    Note:
        Frames of all subjects are concatenated before splitting and windowing,
        so a sequence can straddle a subject boundary or the train/val seam.
    """
    seq_len = kwargs['seq_len']
    n_groupframes = kwargs['n_groupframes']  # e.g. 1450000
    ds_rate = kwargs['ds_rate']
    jpg_root = kwargs['jpg_root']

    # Gather the frames of every subject, then truncate to the group budget.
    gx_fpathlist = []
    for subjdir in tqdm(subj_dirs):
        gx_fpathlist.extend(get_fpathlist(jpg_root, subjdir, ds_rate=ds_rate))
    gx_fpathlist = gx_fpathlist[:n_groupframes]

    # Train-val split
    train_fp, val_fp = get_train_val_split(gx_fpathlist, val_ratio=0.1)

    # Downsampling was already applied while listing frames, hence ds_rate=1 here.
    split_frames = {'train': train_fp, 'val': val_fp}
    return {
        split: ImageSequenceDataset(
            get_fpathseqlist(frames, seq_len, ds_rate=1, n_samples=None),
            transform=image_processor,
        )
        for split, frames in split_frames.items()
    }

In [11]:
datasets = make_dataset(g2, seq_len=seq_len, jpg_root=jpg_root, ds_rate=ds_rate, n_groupframes=n_groupframes)

100%|███████████████████████████████████| 9/9 [00:10<00:00,  1.11s/it]


In [43]:
# One DataLoader per split, iterating the sequences in dataset order.
batch_size = 1
dataloaders = {
    split: torch.utils.data.DataLoader(datasets[split], batch_size=batch_size)
    for split in ('train', 'val')
}

In [None]:
# Encoder width settings; the trailing comments keep the smaller alternative values.
hidden_size = 768  # 384
intermediate_size = 3072  # 4*384
# NOTE(review): the name is misspelt ("atention"); kept because other cells may use it.
num_atention_heads = 12  # 6

config = transformers.VideoMAEConfig(
    # input geometry
    image_size=image_size,
    patch_size=16,
    num_channels=3,
    num_frames=num_frames,
    tubelet_size=2,
    # encoder
    hidden_size=hidden_size,
    num_hidden_layers=12,
    num_attention_heads=num_atention_heads,
    intermediate_size=intermediate_size,
    initializer_range=0.02,
    use_mean_pooling=True,
    # decoder
    decoder_num_attention_heads=6,
    decoder_hidden_size=384,
    decoder_num_hidden_layers=4,
    decoder_intermediate_size=1536,
    # loss
    norm_pix_loss=True,
)
# Built from the config alone (no from_pretrained call), i.e. freshly initialised weights.
model = transformers.VideoMAEForPreTraining(config)
#model components: base_model, encoder_to_decoder, decoder
# model.videomae==model.base_model
# model.videomae.embeddings, model.videomae.encoder

# embeddings: a Conv3D layer: Conv3d(3, 768, kernel_size=(2, 16, 16), stride=(2, 16, 16))
# Splits the image sequence into n tubes 
# Maps each tube (2x16x16) to a 768D vector using linear projection (one projector for all tubes)
# Returns a tensor of shape (789x768) for each image sequence.
 
# encoder: has 12 layers of type VideoMAELayer. Each layer is: [attention, linear, linear(768,3072), gelu, linear(3072,768), layernorm]

# encoder_to_decoder: one linear layer: (in_features=768, out_features=384, bias=False)

# decoder: has 4 VideoMAELayer layers + a linear projection from 384 to 1536 dimensions


# If you use the VideoMAEModel, it only includes the base model (encoder). 
# You may or may not choose to pass in bool_masked_pos argument.
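# If you don't, the output is of shape: [1568, 768] (the encoding of all tubes; 1568 = 8 x 14 x 14).
# If you do, the output is of a different shape, e.g. [789, 768]: presumably the encodings of the
# visible (unmasked) tubes only, since masked positions are dropped before the encoder -- verify
# against the Hugging Face VideoMAE documentation.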

In [None]:
# Adam over all model parameters; no hyperparameters are passed, so PyTorch's
# defaults (learning rate, betas, weight decay) apply.
optimizer = torch.optim.Adam(model.parameters())

In [44]:
# Number of tokens the model sees per clip: one token per tubelet
# (tubelet_size frames x patch_size x patch_size pixels).
num_patches_per_frame = (model.config.image_size // model.config.patch_size) ** 2
model_seq_length = (num_frames // model.config.tubelet_size) * num_patches_per_frame

# Mask exactly half of the tokens in every clip. VideoMAE requires each video in a
# batch to have the same number of masked patches, which an independent coin flip
# per token does not guarantee once batch_size > 1.
num_masked = model_seq_length // 2

for phase in ['train', 'val']:
    # NOTE: with a distributed sampler, call sampler.set_epoch(epoch) here.
    is_train = phase == 'train'
    if is_train:
        model.train()  # Set model to training mode
    else:
        model.eval()   # Set model to evaluate mode

    # Iterate over data.
    for inputs in tqdm(dataloaders[phase]):
        print(inputs.shape)

        # Size the mask from the actual batch: the last batch can be smaller than batch_size.
        n_clips = inputs.shape[0]
        # argsort of random noise is a random permutation of 0..L-1 per row, so
        # "< num_masked" selects exactly num_masked random positions in every row.
        bool_masked_pos = torch.rand(n_clips, model_seq_length).argsort(dim=1) < num_masked

        # Track gradients only while training; validation needs no autograd graph.
        with torch.set_grad_enabled(is_train):
            outputs = model(inputs, bool_masked_pos=bool_masked_pos)
            loss = outputs.loss

        if is_train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Smoke test: process only the first batch of each phase.
        break
        

  0%|                                       | 0/81562 [00:00<?, ?it/s]

torch.Size([1, 16, 3, 224, 224])


  0%|                                       | 0/81562 [00:02<?, ?it/s]
  0%|                                        | 0/9062 [00:00<?, ?it/s]

torch.Size([1, 16, 3, 224, 224])


  0%|                                        | 0/9062 [00:00<?, ?it/s]
