In [1]:
import sys, os

env_root = '/N/project/baby_vision_curriculum/pythonenvs/hfenv/lib/python3.10/site-packages/'
# Put the project virtualenv's site-packages ahead of everything else on the import path.
sys.path.insert(0, env_root)

# Thread caps are set here, before numpy/torch are imported below
# (presumably so the native libraries pick them up at load time).
os.environ.update({
    'OPENBLAS_NUM_THREADS': '38',  # @@@@ to help with the num_workers issue
    'OMP_NUM_THREADS': '1',        # was 10
})

import numpy as np
import torch, torchvision
from torchvision import transforms as tr
from tqdm import tqdm
from pathlib import Path
# import math
import argparse
import pandas as pd
import warnings

import transformers

In [2]:
from transformers import VideoMAEConfig, VideoMAEModel
from torch.utils.data import Dataset
import av


def get_config(image_size, args, num_labels=2):
    """Build a ``transformers.VideoMAEConfig`` for the preset named by ``args.architecture``.

    Every preset uses patch_size=16, num_channels=3, num_frames=16 and tubelet_size=2;
    the presets differ only in transformer width and depth.

    Args:
        image_size: Spatial input resolution stored in the config.
        args: Object with an ``architecture`` attribute, one of
            ``''`` (default: 12 layers, hidden 768, 12 heads),
            ``'small1'`` (12 layers, hidden 384, 6 heads) or
            ``'small2'`` (6 layers, hidden 768, 6 heads).
        num_labels: Number of classification labels stored in the config.

    Returns:
        transformers.VideoMAEConfig

    Raises:
        ValueError: if ``args.architecture`` is not one of the presets above.
    """
    arch_kw = args.architecture
    # name -> (hidden_size, num_hidden_layers, num_attention_heads, intermediate_size)
    presets = {
        '': (768, 12, 12, 3072),            # default
        'small1': (384, 12, 6, 4 * 384),
        'small2': (768, 6, 6, 4 * 768),
    }
    if arch_kw not in presets:
        raise ValueError(
            f'Unknown architecture {arch_kw!r}; expected one of {sorted(presets)}')

    hidden_size, num_hidden_layers, num_attention_heads, intermediate_size = presets[arch_kw]
    return transformers.VideoMAEConfig(image_size=image_size, patch_size=16, num_channels=3,
                                       num_frames=16, tubelet_size=2,
                                       hidden_size=hidden_size,
                                       num_hidden_layers=num_hidden_layers,
                                       num_attention_heads=num_attention_heads,
                                       intermediate_size=intermediate_size,
                                       num_labels=num_labels)


def init_model_from_checkpoint(model, checkpoint_path):
    """Load ``checkpoint['model_state_dict']`` from ``checkpoint_path`` into ``model``.

    Caution: the checkpoint must have been saved from the same model class/config,
    otherwise ``load_state_dict`` raises on missing/unexpected/mis-shaped keys.

    The file is deserialized onto the CPU. ``load_state_dict`` copies the values into
    the model's existing parameters, so the model stays on whatever device it is on.
    This avoids restoring GPU-saved tensors onto a GPU (extra memory, or an error on
    CPU-only hosts).

    NOTE(review): ``torch.load`` unpickles the file; only load trusted checkpoints.

    Args:
        model: torch.nn.Module to initialize (modified in place).
        checkpoint_path: Path to a file holding a dict with a 'model_state_dict' key.

    Returns:
        The same ``model`` object.
    """
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    return model

def adapt_videomae(source_model, target_model):
    """Copy the VideoMAE backbone weights from ``source_model`` into ``target_model``.

    Only ``videomae.embeddings`` and then ``videomae.encoder`` are transferred. Every
    other submodule of ``target_model`` (e.g. its classifier head) keeps its weights.
    ``target_model`` is modified in place and also returned.
    """
    src_backbone = source_model.videomae
    dst_backbone = target_model.videomae
    for part_name in ('embeddings', 'encoder'):
        part_weights = getattr(src_backbone, part_name).state_dict()
        getattr(dst_backbone, part_name).load_state_dict(part_weights)
    return target_model

def set_parameter_requires_grad(model, feature_extracting):
    """Freeze every parameter except those of ``model.classifier``.

    Only acts when ``feature_extracting`` is truthy; otherwise the model's
    ``requires_grad`` flags are left untouched. Returns None.
    """
    if not feature_extracting:
        return
    model.requires_grad_(False)             # freeze the whole network ...
    model.classifier.requires_grad_(True)   # ... then re-enable the head
            
def get_model(image_size, num_labels, feature_extracting, args):
    """Create a VideoMAE video classifier whose backbone may come from a pretraining checkpoint.

    Steps:
      1. Build a ``VideoMAEForPreTraining`` from the ``get_config`` preset. If
         ``args.init_checkpoint_path`` is not the sentinel ``'na'``, load that checkpoint.
      2. Build a ``VideoMAEForVideoClassification`` with ``num_labels`` outputs and copy
         the embeddings/encoder into it via ``adapt_videomae``.
      3. Warn if the patch-embedding projection weights differ after the copy.
      4. If ``feature_extracting`` is truthy, freeze everything except the classifier.

    Returns:
        The classification model.
    """
    pretrain_model = transformers.VideoMAEForPreTraining(get_config(image_size, args))

    ckpt_path = args.init_checkpoint_path
    if ckpt_path != 'na':
        print('args.init_checkpoint_path:', ckpt_path)
        pretrain_model = init_model_from_checkpoint(pretrain_model, ckpt_path)

    classifier_model = transformers.VideoMAEForVideoClassification(
        config=get_config(image_size, args, num_labels=num_labels))
    classifier_model = adapt_videomae(pretrain_model, classifier_model)

    # Sanity check on the first backbone layer: it must now match the source model.
    src_proj = pretrain_model.videomae.embeddings.patch_embeddings.projection.weight
    dst_proj = classifier_model.videomae.embeddings.patch_embeddings.projection.weight
    if not torch.all(dst_proj == src_proj):
        warnings.warn('Model not successfully initialized')

    set_parameter_requires_grad(classifier_model, feature_extracting)
    return classifier_model


In [3]:
def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group, or None if it has none."""
    return next((group['lr'] for group in optimizer.param_groups), None)
    
def get_optimizer(model, feature_extract, args):
    """Create an Adam optimizer for ``model`` and list the trainable parameters.

    Prints "Params to learn:" followed by the name of every parameter with
    ``requires_grad=True``.

    Args:
        model: torch.nn.Module (typically already passed through
            ``set_parameter_requires_grad``).
        feature_extract: If truthy, only parameters with ``requires_grad=True`` are
            handed to the optimizer. Otherwise all ``model.parameters()`` are passed.
            Frozen ones never receive gradients, so Adam skips them.
        args: Object with ``lr`` (learning rate) and ``wd`` (weight decay) attributes.

    Returns:
        torch.optim.Adam
    """
    trainable = [(name, param) for name, param in model.named_parameters()
                 if param.requires_grad]

    print("Params to learn:")
    for name, _ in trainable:
        print("\t", name)

    if feature_extract:
        params_to_update = [param for _, param in trainable]
    else:
        params_to_update = model.parameters()

    return torch.optim.Adam(params_to_update, lr=args.lr, weight_decay=args.wd)

In [4]:
class Args:
    """Minimal attribute bag used in place of an ``argparse.Namespace`` in this notebook.

    Every keyword argument becomes an instance attribute of the same name.
    """

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

In [5]:

# ------------
# Dataset and Dataloader

def _get_transform(image_size, mean=(0.5, 0.5, 0.5), std=(0.25, 0.25, 0.25)):
    """Build the per-frame preprocessing pipeline shared by the video and image datasets.

    The pipeline does the following, in order:
      1. ``Resize(image_size)``: resize (an int size resizes the shorter edge).
      2. ``CenterCrop(image_size)``: take a square center crop.
      3. ``ConvertImageDtype(float32)``: convert to float32.
      4. ``Normalize(mean, std)``: normalize per channel.

    Args:
        image_size: Output spatial size (int).
        mean, std: Per-channel normalization statistics. The defaults keep the
            project's original values. ImageNet alternatives are
            mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225).

    Returns:
        torchvision.transforms.Compose
    """
    augs = [tr.Resize(image_size),
            tr.CenterCrop(image_size),
            tr.ConvertImageDtype(torch.float32),
            tr.Normalize(mean, std)]
    return tr.Compose(augs)



def transform_image_cifar10(image, image_size=224, num_frames=16):
    """Turn a single still image into a static pseudo-video clip.

    Used for standard single-image datasets such as torchvision CIFAR10 / ImageNet.
    The processing steps are:
      1. Convert the image with ``ToTensor``.
      2. Apply the same resize / center-crop / float32 / normalize pipeline used for
         video frames (``_get_transform``), so the two never drift apart.
      3. Repeat the result ``num_frames`` times along a new leading axis.

    Args:
        image: A single image accepted by ``torchvision.transforms.ToTensor``.
        image_size: Output spatial size.
        num_frames: Number of (identical) frames in the returned clip.

    Returns:
        Tensor of shape (num_frames, C, image_size, image_size).
    """
    frame = _get_transform(image_size)(tr.ToTensor()(image))
    return frame.unsqueeze(0).repeat(num_frames, 1, 1, 1)

def get_inp_label(task, batch):
    """Split a batch into ``(inputs, labels)`` for the given task.

    Only 'ucf101' is supported. Its batches are 3-element sequences whose middle
    element is discarded (presumably the audio track returned by torchvision's
    UCF101; confirm).

    Raises:
        NotImplementedError: for any other task.
    """
    if task != 'ucf101':
        raise NotImplementedError()
    inputs, _discarded, labels = batch
    return inputs, labels

def make_ucf101dataset(args):
    """Build the UCF101 (fold 1) train/val datasets as 16-frame TCHW clips.

    Args:
        args: Needs ``frame_rate`` (resampling frame rate for the clips) and
            ``num_workers``. One fewer worker than ``num_workers`` is handed to
            torchvision for scanning clip metadata, clamped at 0.

    Returns:
        ({'train': UCF101, 'val': UCF101}, 101)
    """
    ucf_root = '/N/project/baby_vision_curriculum/benchmarks/mainstream/ucf101/UCF-101'
    annotation_path = '/N/project/baby_vision_curriculum/benchmarks/mainstream/ucf101/UCF101TrainTestSplits-RecognitionTask/ucfTrainTestlist/'
    # NOTE(review): `transform_vid` is resolved when this function runs, but it is defined
    # in a later notebook cell. Run that cell first; a fresh "Run All" with
    # task='ucf101' would otherwise hit a NameError here.
    common_kwargs = dict(
        frames_per_clip=16,
        step_between_clips=1,
        frame_rate=args.frame_rate,  # int(30/args.ds_rate)
        fold=1,
        transform=transform_vid,
        output_format='TCHW',
        # Keep one CPU in reserve, but never go negative: torchvision uses this as a
        # DataLoader worker count, which must be >= 0.
        num_workers=max(0, args.num_workers - 1),
    )
    splits = {
        split: torchvision.datasets.UCF101(ucf_root, annotation_path,
                                           train=(split == 'train'), **common_kwargs)
        for split in ('train', 'val')
    }
    num_classes = 101
    return splits, num_classes

def make_cifar10dataset(args):
    """Build CIFAR-10 train/val datasets whose samples are pseudo-video clips.

    Each image goes through ``transform_image_cifar10``. The data is downloaded into
    the project benchmark folder if missing. ``args`` is unused; it is accepted so the
    signature matches ``make_ucf101dataset``.

    Returns:
        ({'train': CIFAR10, 'val': CIFAR10}, 10)
    """
    root_dir = '/N/project/baby_vision_curriculum/benchmarks/mainstream/cifar10'
    splits = {}
    for split_name, is_train in (('train', True), ('val', False)):
        splits[split_name] = torchvision.datasets.CIFAR10(
            root=root_dir, transform=transform_image_cifar10,
            train=is_train, download=True)
    return splits, 10

def make_dataset(args):
    """Dispatch to the dataset builder for ``args.task``.

    Returns:
        (dict of split name -> dataset, number of classes)

    Raises:
        NotImplementedError: if ``args.task`` is neither 'ucf101' nor 'cifar10'.
    """
    task = args.task
    if task == 'ucf101':
        builder = make_ucf101dataset
    elif task == 'cifar10':
        builder = make_cifar10dataset
    else:
        raise NotImplementedError()
    return builder(args)

In [6]:
# --- Experiment configuration --------------------------------------------------
# NOTE(review): the recorded outputs further down (543459 training clips, batches of 16)
# look like they came from a run with task='ucf101' and batch_size=16, not the values
# below. Re-run from a fresh kernel before trusting them.
task = 'cifar10'  # alternative: 'ucf101'

# Backbone checkpoint and output location.
ch_dir = '/N/project/baby_vision_curriculum/trained_models/generative/v2/'
init_checkpoint_path = ch_dir + "model_g0_seed_1111_other_1111_mask50_small2_30ep.pt"
# NOTE(review): savedir is the ucf101 benchmark folder even though task='cifar10' — confirm.
savedir = '/N/project/baby_vision_curriculum/trained_models/generative/v2/benchmarks/ucf101/'
prot_name = 'g0'
seed = 1111
other_id = '10fps.30ep'

# Training / data-loading knobs.
n_epoch = 1
save_model = 'n'
frame_rate = 10
batch_size = 64
num_workers = 45  # was 6
# The checkpoint file name says 'small2'; the preset must match it for load_state_dict.
architecture = 'small2'
lr = 1e-3
wd = 5e-5

args = Args(
    task=task, architecture=architecture,
    init_checkpoint_path=init_checkpoint_path,
    savedir=savedir, prot_name=prot_name, seed=seed, other_id=other_id,
    n_epoch=n_epoch, save_model=save_model,
    frame_rate=frame_rate, num_workers=num_workers, batch_size=batch_size,
    lr=lr, wd=wd,
)

In [16]:
# Build the {'train', 'val'} datasets for args.task together with the class count.
datasets, num_classes = make_dataset(args=args)

Files already downloaded and verified
Files already downloaded and verified


In [18]:
# Sanity check: shape of the first training sample's input tensor.
datasets['train'][0][0].size()

torch.Size([16, 3, 224, 224])

In [30]:
def transform_vid(video, image_size=224):
    """Resize, center-crop, convert to float32 and normalize every frame of a clip.

    Used with standard video datasets such as torchvision.UCF101.

    Args:
        video: 4-D tensor of frames. If dim 1 is not 3 it is assumed to be THWC and is
            permuted to TCHW.
        image_size: Output spatial size passed to ``_get_transform``.

    Returns:
        Tensor of shape (T, C, image_size, image_size).
    """
    if video.shape[1] != 3:  # make it TCHW
        video = torch.permute(video, (0, 3, 1, 2))
    # Every step of _get_transform (Resize, CenterCrop, ConvertImageDtype, Normalize)
    # accepts a leading batch dimension on tensors. Processing the whole clip in one
    # call gives the same result as the former per-frame loop + concat, without the
    # Python-level loop.
    return _get_transform(image_size)(video)

In [31]:
def custom_collate(batch):
    """Collate ``(video, _, label)`` samples into batched videos and labels.

    The middle element of each sample is dropped before ``default_collate``.
    NOTE(review): identical to ``ucf_collate`` defined later in this notebook; keep one.
    """
    pairs = [(video, label) for video, _, label in batch]
    return torch.utils.data.dataloader.default_collate(pairs)

In [9]:
# Build the UCF101 fold-1 training split by hand. It uses the same settings as
# make_ucf101dataset except the metadata worker count (45 here vs args.num_workers-1),
# so it can be compared against datasets['train'] below.
ucf_root = '/N/project/baby_vision_curriculum/benchmarks/mainstream/ucf101/UCF-101'
annotation_path = '/N/project/baby_vision_curriculum/benchmarks/mainstream/ucf101/UCF101TrainTestSplits-RecognitionTask/ucfTrainTestlist/'
frames_per_clip = 16
step_between_clips = 1
frame_rate = args.frame_rate  # int(30/args.ds_rate)
transform = transform_vid
output_format = 'TCHW'
num_workers = 45  # args.num_workers-1 #40
train_dataset = torchvision.datasets.UCF101(
    root=ucf_root,
    annotation_path=annotation_path,
    frames_per_clip=frames_per_clip,
    step_between_clips=step_between_clips,
    frame_rate=frame_rate,
    fold=1,
    train=True,
    transform=transform,
    output_format=output_format,
    num_workers=num_workers,
)

  0%|          | 0/833 [00:00<?, ?it/s]



In [32]:
# Sanity check: the helper-built dataset and the hand-built one yield the same first clip.
# NOTE(review): only meaningful when args.task == 'ucf101'.
(datasets['train'][0][0] == train_dataset[0][0]).all()



tensor(True)

In [13]:
# Stand-alone instance of the configured architecture, built only to inspect its layout.
image_size = 224
config_source = get_config(image_size, args)
model_source = transformers.VideoMAEForVideoClassification(config=config_source)
# alternative: transformers.VideoMAEForPreTraining(config_source)

In [14]:
# Display the module tree (layer types and sizes) of the inspection model built above.
model_source

VideoMAEForVideoClassification(
  (videomae): VideoMAEModel(
    (embeddings): VideoMAEEmbeddings(
      (patch_embeddings): VideoMAEPatchEmbeddings(
        (projection): Conv3d(3, 768, kernel_size=(2, 16, 16), stride=(2, 16, 16))
      )
    )
    (encoder): VideoMAEEncoder(
      (layer): ModuleList(
        (0): VideoMAELayer(
          (attention): VideoMAEAttention(
            (attention): VideoMAESelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=False)
              (key): Linear(in_features=768, out_features=768, bias=False)
              (value): Linear(in_features=768, out_features=768, bias=False)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): VideoMAESelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): VideoMAEIntermediate(
            (dense): Linear(in_f

In [None]:
# NOTE(review): stray scratch cell. 768 is the hidden_size of the 'small2' preset in
# get_config; presumably a leftover, consider deleting.
768

In [None]:
# Build the downstream classifier: a VideoMAE backbone (initialized from
# args.init_checkpoint_path) plus a num_classes-way head.
# Relies on `num_classes` from the dataset cell above.
# The loss comes from the model itself (outputs.loss), so no separate criterion is created.
# With feature_extract=True only the classifier head is trainable.
feature_extract = True
image_size = 224
xmodel = get_model(image_size, num_classes, feature_extract, args)

In [34]:
# Move the classifier to the first GPU; the training loop below also uses 'cuda:0'.
xmodel = xmodel.to(device='cuda:0')

In [35]:
# Adam over the trainable parameters (just the classifier head when feature_extract=True).
optimizer = get_optimizer(model=xmodel, feature_extract=feature_extract, args=args)

Params to learn:
	 classifier.weight
	 classifier.bias


In [37]:
def ucf_collate(batch):
    """Batch UCF101 samples, dropping the middle element of each ``(video, _, label)`` triple.

    Returns what ``default_collate`` produces for the resulting ``(video, label)``
    pairs: a two-element sequence of batched videos and batched labels.
    """
    def _drop_middle(sample):
        video, _, label = sample
        return video, label

    return torch.utils.data.dataloader.default_collate(list(map(_drop_middle, batch)))

In [38]:
# DataLoader settings.
sampler_shuffle = True  # for the distributed sampler (not used by the loaders below)
num_epochs = args.n_epoch
batch_size = args.batch_size  # 128
pin_memory = True
num_workers = 5  # args.num_workers #number_of_cpu-1#32
# UCF101 samples are (video, _, label) triples that need the middle element dropped;
# other tasks use the default collate.
collate_fn = ucf_collate if args.task == 'ucf101' else None

In [40]:
# One DataLoader per split.
# - Only the training split is shuffled. Without shuffling, training batches follow
#   the dataset's storage order.
# - Only the training split drops the last incomplete batch, so evaluation sees every
#   validation sample.
dataloaders = {
    split: torch.utils.data.DataLoader(
        datasets[split],
        batch_size=batch_size,
        pin_memory=pin_memory,
        collate_fn=collate_fn,
        num_workers=num_workers,
        shuffle=(split == 'train'),
        drop_last=(split == 'train'),
    )
    for split in ['train', 'val']
}

In [41]:
# Report dataset size vs. number of training batches.
print('len dset, len dloader: ',
      len(datasets['train']),
      len(dataloaders['train']))
print('dataloaders created')  # @@@

len dset, len dloader:  543459 33966
dataloaders created


In [42]:
# One debug pass over the train and val splits (stops after `i_break` batches per phase).
# NOTE(review): num_epochs is not used yet; there is no epoch loop.
rank = 'cuda:0'

train_acc_history = []
val_acc_history = []

# best_model_wts = deepcopy(model.state_dict())
best_acc = 0.0

for phase in ['train', 'val']:
    # dataloaders[phase].sampler.set_epoch(i_ep)
    is_train = phase == 'train'
    xmodel.train(is_train)  # training mode for 'train', evaluation mode for 'val'

    running_loss = torch.tensor([0.0], device=rank)
    running_corrects = torch.tensor([0.0], device=rank)

    i_iter = 0
    i_break, print_period = 3, 1  # @@@ debug (normal run: print_period = 100, no break)
    # Iterate over data.
    for batch in tqdm(dataloaders[phase]):
        # TODO: use get_inp_label(args.task, batch) once it supports every task.
        inputs, labels = batch
        inputs = inputs.to(rank)
        labels = labels.to(rank)

        # Only build the autograd graph when we are going to backpropagate;
        # the val phase runs without it to save memory and time.
        with torch.set_grad_enabled(is_train):
            outputs = xmodel(pixel_values=inputs, labels=labels)

        logits = outputs.logits
        loss = outputs.loss  # the model computes the classification loss from `labels`
        _, preds = torch.max(logits, 1)

        # backward + optimize only in the training phase
        if is_train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # statistics
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels)

        i_iter += 1
        if i_iter % print_period == 0:
            print('loss:', loss.item())

        if i_iter == i_break:
            break  # @@@@ debug

  0%|                                      | 2/33966 [00:03<15:34:03,  1.65s/it]

loss: 4.321846008300781
loss: 3.7920289039611816
loss: 3.2920782566070557


  0%|                                      | 2/33966 [00:05<25:27:39,  2.70s/it]
  0%|                                       | 1/13213 [00:01<6:37:20,  1.80s/it]

loss: 3.3286473751068115


  0%|                                       | 2/13213 [00:02<3:10:21,  1.16it/s]

loss: 3.337937116622925
loss: 3.229137659072876


  0%|                                       | 2/13213 [00:03<5:58:05,  1.63s/it]


In [44]:

# GPU memory currently occupied by PyTorch tensors, converted from bytes to MiB.
memory_allocated = torch.cuda.memory_allocated() / 2**20
print(f'GPU memory allocated: {memory_allocated:.2f} MB')

GPU memory allocated: 315.97 MB


In [23]:
# Leading (batch) dimension of the most recent `inputs` tensor.
inputs.shape[0]

16