In [1]:
import sys, os

env_root = '/N/project/baby_vision_curriculum/pythonenvs/hfenv/lib/python3.10/site-packages/'
sys.path.insert(0, env_root)

In [2]:
import numpy as np
import torch, torchvision
from torchvision import transforms as tr
from tqdm import tqdm
from pathlib import Path
# import math
import argparse
import pandas as pd
import warnings

import transformers

# import torch.distributed as dist
# from torch.utils.data.distributed import DistributedSampler
# from torch.nn.parallel import DistributedDataParallel as DDP
# import torch.multiprocessing as mp
# from ddputils import is_main_process, save_on_master, setup_for_distributed


# from PIL import Image
from torch.utils.data import Dataset
import random
from copy import deepcopy
import json
import itertools
import gc

import logging

In [3]:
def get_fpathlist(vid_root, subjdir, ds_rate=1):
    """
    # read the image files inside vid_root/subj_dir into a list. 
    # makes sure they're all jpg. also sorts them so that the order of the frames is correct.
    # subjdir = ['008MS']
    """
    
    fpathlist = sorted(list(Path(os.path.join(vid_root, subjdir)).iterdir()), 
                       key=lambda x: x.name)
    fpathlist = [str(fpath) for fpath in fpathlist if fpath.suffix=='.jpg']
    fpathlist = fpathlist[::ds_rate]
    return fpathlist
    
def get_train_val_split(fpathlist, val_ratio=0.1):
    """
    Splits the list of filepaths into a train list and val list
    """
    n_fr = len(fpathlist)
    val_size = int(n_fr*val_ratio)
    
    split1_idx = int((n_fr-val_size)/2)
    split2_idx = int((n_fr+val_size)/2)
    train_set =fpathlist[:split1_idx]+fpathlist[split2_idx:]
    val_set = fpathlist[split1_idx:split2_idx]
    return train_set, val_set

def get_fpathseqlist(fpathlist, seq_len, ds_rate=1, n_samples=None):
    """
    Returns a list of list that can be passed to ImageSequenceDataset
    # n_samples: int
    # between 1 and len(fpathlist)
    # If None, it's set to len(fpathlist)/seq_len
    """
    
    sample_len = seq_len*ds_rate
    if n_samples is None:
        n_samples = int(len(fpathlist)/seq_len)
        sample_stride = sample_len
    else:
        assert type(n_samples)==int
        assert len(fpathlist)>n_samples
        sample_stride = int(len(fpathlist)/n_samples)
        # for adult group, sample_stride ~=10. i.e. each frame contributes to more than 1 sample sequence, 
        # but doesn't appear in the same index of the sequence.

    fpathseqlist = [fpathlist[i:i+sample_len:ds_rate] 
                    for i in range(0, n_samples*sample_stride, sample_stride)]
    return fpathseqlist

def get_fold(gx_fpathlist, fold, max_folds, args):
#     fold_size = int(len(gx_fpathlist)/max_folds)
    segment_size = int(30*60*30/args.ds_rate)
    
    fold_segments = []

    for i_st in range(0, len(gx_fpathlist), segment_size):
        if (i_st // segment_size) % max_folds == fold:
            fold_segments.append(gx_fpathlist[i_st:i_st + segment_size])
            
    fold_segments = list(itertools.chain.from_iterable(fold_segments))
    return fold_segments



In [4]:
# import torchvision.transforms as transforms

def _get_transform(image_size):

    mean = [0.5, 0.5, 0.5]#np.mean(mean_all, axis=0) #mean_all[chosen_subj] 
    std = [0.25, 0.25, 0.25] #std_all[chosen_subj] 
#     [0.485, 0.456, 0.406]  # IMAGENET_DEFAULT_MEAN
#     [0.229, 0.224, 0.225]  # IMAGENET_DEFAULT_STD
    transform_list = [tr.Resize(image_size), tr.CenterCrop(image_size), 
            tr.ConvertImageDtype(torch.float32), 
             tr.Normalize(mean,std)]
#     Alternative: 
#     transform_list = [tr.RandomResizedCrop(crop_size, scale=crop_scale),
#                     tr.ToTensor(),
#                     tr.Normalize(mean,std)]# crop_scale=(0.3, 1.0),
    return tr.Compose(transform_list)


In [5]:
# class ImageSequenceDataset(Dataset):
import PIL
class ImageDataset(Dataset):
    """
    To use for video models. 
    """
    def __init__(self, image_paths, transform, shuffle=False):
        self.image_paths = image_paths
        self.transform = transform
        self.shuffle = shuffle

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Load the sequence of images
        fp = self.image_paths[idx][0]
        images = self.transform(PIL.Image.open(fp))
#         images = self.transform(torchvision.io.read_image(fp))
#         images = torch.cat([
#             self.transform(torchvision.io.read_image(fp)).unsqueeze(0)
#                      for fp in self.image_paths[idx]]) #with tochvision transform
        
#         if self.shuffle:
#             size = images.size(0)
#             perm = torch.randperm(size)
#             images = images[perm]
            
        return images

class ImageSequenceDataset(Dataset):
    """
    To use for video models. 
    """
    def __init__(self, image_paths, transform, shuffle=False):
        self.image_paths = image_paths
        self.transform = transform
        self.shuffle = shuffle

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Load the sequence of images
#         fp = self.image_paths[idx][0]
#         images = self.transform(PIL.Image.open(fp))
#         images = self.transform(torchvision.io.read_image(fp))
        images = torch.cat([
            self.transform(torchvision.io.read_image(fp)).unsqueeze(0)
                     for fp in self.image_paths[idx]]) #with tochvision transform

        if self.shuffle:
            size = images.size(0)
            perm = torch.randperm(size)
            images = images[perm]
            
        return images
    


In [6]:
from homeview import (_get_transform, get_fpathlist,
    get_fold, get_train_val_split, get_fpathseqlist, 
    get_group,                  
    ImageSequenceDataset, ImageDataset)


In [7]:
def make_dataset(subj_dirs, image_size, args):
    seq_len = args.num_frames #kwargs['seq_len']
#     n_groupframes=kwargs['n_groupframes']#1450000
    ds_rate = args.ds_rate #kwargs['ds_rate']
    jpg_root = args.jpg_root #kwargs['jpg_root']
#     image_size = kwargs['image_size']
    fold = args.fold #kwargs['fold']
    condition = args.condition #kwargs['condition']
    n_trainsamples = args.n_trainsamples
    
    crop_size = args.crop_size
    crop_scale = args.crop_scale
    
    transform = _get_transform(image_size)
#     make_transforms(
#         crop_size=crop_size,
#         crop_scale=crop_scale)
    #_get_transform(image_size)
    gx_fpathlist = []
    for i_subj, subjdir in enumerate(tqdm(subj_dirs)):
        gx_fpathlist += get_fpathlist(jpg_root, subjdir, ds_rate=ds_rate)
    
    # added on May15
    max_folds = 3
    gx_fpathlist = get_fold(gx_fpathlist, fold, max_folds, args)
    print('Num. frames in the fold:',len(gx_fpathlist))

    #     if len(gx_fpathlist)>=n_groupframes:
#         gx_fpathlist = gx_fpathlist[:n_groupframes]
#         # 1450000/16 = 90625 => n_trainsamples=81560, n_valsamples= 9060
#         # 1274 iterations of train. 141 iterations of test. 
#         n_trainsamples = None
#         n_valsamples = None
#     else:
    
    # Train-val split
    gx_train_fp, gx_val_fp = get_train_val_split(gx_fpathlist, val_ratio=0.1)

    if condition=='longshuffle':
        random.shuffle(gx_train_fp)
    
    n_trainsamples = n_trainsamples #int(0.9*n_groupframes/seq_len) #81k
    
    n_maxvalsamples = int(len(gx_val_fp)/seq_len)
    n_valsamples = min(n_maxvalsamples, 10000)  #means don't do bootstraping for val. Use whatever number 0.1*len(gx_fpathlist) gives.
    
    gx_train_fpathseqlist = get_fpathseqlist(gx_train_fp, seq_len, ds_rate=1, n_samples=n_trainsamples)
    gx_val_fpathseqlist = get_fpathseqlist(gx_val_fp, seq_len, ds_rate=1, n_samples=n_valsamples)
    
    
    if condition=='shuffle':
        train_dataset = ImageSequenceDataset(gx_train_fpathseqlist, transform=transform, shuffle=True)
        val_dataset = ImageSequenceDataset(gx_val_fpathseqlist, transform=transform, shuffle=False)
    
    elif condition=='static':
        train_dataset = StillVideoDataset(gx_train_fpathseqlist, transform=transform)
        val_dataset = ImageSequenceDataset(gx_val_fpathseqlist, transform=transform, shuffle=False)
    elif condition=='image':
        train_dataset = ImageDataset(gx_train_fpathseqlist, transform=transform)
        val_dataset = ImageDataset(gx_val_fpathseqlist, transform=transform)
#         StillVideoDataset(gx_val_fpathseqlist, transform=transform)

    else:
        train_dataset = ImageSequenceDataset(gx_train_fpathseqlist, transform=transform, shuffle=False)
        val_dataset = ImageSequenceDataset(gx_val_fpathseqlist, transform=transform, shuffle=False)
        
    return {'train':train_dataset,
           'val': val_dataset}


In [8]:
class Args:
    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

In [9]:
crop_size=224
crop_scale=(0.3, 1.0)
use_gaussian_blur=False
use_horizontal_flip=False
use_color_distortion=False
color_jitter=0
patch_size=16
pred_mask_scale=(0.15, 0.2)
enc_mask_scale=(0.85,1)
aspect_ratio=(0.75,1.5)
num_enc_masks=1
num_pred_masks=4
allow_overlap=False
min_keep=10

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger()




jpg_root='/N/project/baby_vision_curriculum/homeview_subset_30fps/'
data_seed=401

saveroot='/N/project/baby_vision_curriculum/trained_models/predictive/v0/jul17/'
tbsaveroot='/N/project/baby_vision_curriculum/trained_models/predictive/v0/benchmarks/toybox/jul17/'
ucsaveroot='/N/project/baby_vision_curriculum/trained_models/predictive/v0/benchmarks/ucf101/jul17/'

n_epoch=1
curr='yo'
condition='default'#'image'
other_id=curr+'_'+condition
monitor='grad'
ds_rate=30 #1
batch_size=8
n_trainsamples=12000#162000
max_epoch_iters=2000
tbbatch_size=64
ucbatch_size=64
mask_sampler='tube'
tubelet_size=1
num_frames=4
image_size=224

train_group='g2'#'g0'
fold=data_seed % 3
# mask_ratio=0.9
#adamw: lr=1.5e-4, wd=0.05, momentum=0.9 (doesn't get used)
#adam: lr=0.001. wd=1e-4, momentum=0.9 (doesn't get used)
#sgd: 0.1, 0, 0.9
optim='sgd'
lr=0.1
wd=0
architecture='base'#'small'#
momentum=0.9

world_size= 1

savedir=saveroot+'s1/' #"${saveroot}s1/"
# Initialization
init_checkpoint_path='na'

# other_seed=data_seed#$data_seed
script='pretrain_vjepa_v0.py'

args = Args(jpg_root=jpg_root,
            train_group=train_group,
            savedir=savedir,
            architecture=architecture,
            init_checkpoint_path=init_checkpoint_path,
            seed=data_seed,
            other_id=other_id,
            batch_size=batch_size,
           num_frames=num_frames,
            tubelet_size=tubelet_size,
           ds_rate=ds_rate,
           fold=fold,
           condition=condition,
           n_trainsamples=n_trainsamples,
           crop_size=crop_size,
            image_size=image_size,
            patch_size=patch_size,
           crop_scale=crop_scale)

In [10]:
g0='008MS+009SS+010BF+011EA+012TT+013LS+014SN+015JM+016TF+017EW'
g1='026AR+027SS+028CK+028MR+029TT+030FD+031HW+032SR+033SE+034JC'
g2='043MP+044ET+046TE+047MS+048KG+049JC+050AB+050AK+051DW'
g3='BR+CW+EA+ED+JB+KI+LS+SB+TR'
# Total number of frames in each age group: g0=1.68m, g1=1.77m, g2=1.45m

g0 = g0.split('+')
g1 = g1.split('+')
g2 = g2.split('+')
g3 = g3.split('+')

gRand=[]
for gx in [g0,g1,g2,g3]:
    gRand.extend(random.sample(gx, 3))
random.shuffle(gRand)

group_dict = {"g0":g0, "g1":g1, "g2":g2, "g3":g3, 'gr':gRand}
group = group_dict.get(args.train_group)
print(group)                                               

['043MP', '044ET', '046TE', '047MS', '048KG', '049JC', '050AB', '050AK', '051DW']


In [11]:
image_size=224
datasets = make_dataset(group, image_size, args)

100%|█████████████████████████████████████████████| 9/9 [00:33<00:00,  3.70s/it]

Num. frames in the fold: 46811





In [12]:
# import math
# from multiprocessing import Value
# from logging import getLogger

# _GLOBAL_SEED = 0
# logger = getLogger()


In [13]:
from mask import MaskCollator as MBMaskCollator

In [14]:
mask_collator = MBMaskCollator(
    input_size=crop_size,
    patch_size=patch_size,
    pred_mask_scale=pred_mask_scale,
    enc_mask_scale=enc_mask_scale,
    aspect_ratio=aspect_ratio,
    nenc=num_enc_masks,
    npred=num_pred_masks,
    allow_overlap=allow_overlap,
    min_keep=min_keep)

In [15]:
pin_memory = False
# num_workers=
dataloaders = {x: torch.utils.data.DataLoader(
        datasets[x], batch_size=batch_size, pin_memory=pin_memory, 
        shuffle=False, collate_fn=mask_collator,
        drop_last=True)
                        for x in ['train', 'val']}


In [54]:
len(dataloaders['train'])

1500

In [62]:
phase = 'train'
for inputs in tqdm(dataloaders[phase]):
    break

  0%|                                                   | 0/750 [00:00<?, ?it/s]


In [63]:
for item in inputs[0]:
    print('udata:',item.shape)
for item in inputs[1]:
    print('encoder context:',item.shape)
for item in inputs[2]:
    print('prediction target:',item.shape)

udata: torch.Size([6, 3, 224, 224])
udata: torch.Size([6, 3, 224, 224])
udata: torch.Size([6, 3, 224, 224])
udata: torch.Size([6, 3, 224, 224])
udata: torch.Size([6, 3, 224, 224])
udata: torch.Size([6, 3, 224, 224])
udata: torch.Size([6, 3, 224, 224])
udata: torch.Size([6, 3, 224, 224])
udata: torch.Size([6, 3, 224, 224])
udata: torch.Size([6, 3, 224, 224])
udata: torch.Size([6, 3, 224, 224])
udata: torch.Size([6, 3, 224, 224])
udata: torch.Size([6, 3, 224, 224])
udata: torch.Size([6, 3, 224, 224])
udata: torch.Size([6, 3, 224, 224])
udata: torch.Size([6, 3, 224, 224])
encoder context: torch.Size([16, 48])
prediction target: torch.Size([16, 36])
prediction target: torch.Size([16, 36])
prediction target: torch.Size([16, 36])
prediction target: torch.Size([16, 36])


In [16]:
import vision_transformer as vit
from tensors import trunc_normal_

 #@@@
def init_model(
    device,
    image_size=224,
    patch_size=16,
    tubelet_size=1,
    num_frames=1,
    model_name='vit_base',
    pred_depth=6,
    pred_emb_dim=384
):
#     encoder = vit_base(
    encoder = vit.__dict__[model_name](
        img_size=[image_size],
        patch_size=patch_size,
        num_frames=num_frames,
        tubelet_size=tubelet_size) #@@@
#     predictor = vit_predictor(
    predictor = vit.__dict__['vit_predictor'](
        sequence_shape=encoder.sequence_shape,
        embed_dim=encoder.embed_dim,
        predictor_embed_dim=pred_emb_dim,
        depth=pred_depth,
        num_heads=encoder.num_heads)

    def init_weights(m):
        if isinstance(m, torch.nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, 0)
        elif isinstance(m, torch.nn.LayerNorm):
            torch.nn.init.constant_(m.bias, 0)
            torch.nn.init.constant_(m.weight, 1.0)

    for m in encoder.modules():
        init_weights(m)

    for m in predictor.modules():
        init_weights(m)

    encoder.to(device)
    predictor.to(device)
    logger.info(encoder)
    return encoder, predictor

In [17]:
from helper import init_opt #init_model, 

In [18]:
device= 'cuda:0'#'cpu'
model_name='vit_'+args.architecture #'vit_base'#'vit_small'
pred_depth=6
pred_emb_dim=384
# tubelet_size=2#2#1 #@@@
# num_frames=6#6#1 #@@@

encoder, predictor = init_model(
    device=device,
    image_size=image_size,
    patch_size=patch_size,
    pred_depth=pred_depth,
    pred_emb_dim=pred_emb_dim,
    model_name=model_name,
    tubelet_size=tubelet_size,
    num_frames=num_frames)
target_encoder = deepcopy(encoder)

INFO:root:VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv3d(3, 768, kernel_size=(1, 16, 16), stride=(1, 16, 16))
  )
  (blocks): ModuleList(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): MLP(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate=none)
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (1): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear

In [20]:
# the masks are defined for certain spatial
# patches. Their index is based on single frame.
# in order to apply them to spatiotemporal patches, we need to convert 
# their indices. In particular, for target indices where we want to
# choose from future frames, we might want to add a random multiple
# of num_patches_per_frame to all masks. the random integer is in [0, num_frames-1]

def update_masks(masks, args):
    T = args.num_frames//args.tubelet_size
    num_patches_per_frame = (args.image_size//args.patch_size)**2
    
    for i_mask, m in enumerate(masks):
        frame_index = np.random.randint(0,T)
        m+=frame_index*num_patches_per_frame
        masks[i_mask] = m
    return masks


def apply_masks(x, masks):
    """
    :param x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)]
    :param masks: list of tensors containing indices of patches in [N] to keep
    """
    all_x = []
    for m in masks:
        mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1))
        all_x += [torch.gather(x, dim=1, index=mask_keep)]
    return torch.cat(all_x, dim=0)

In [21]:
wd=0
final_wd=0
start_lr=0.1
ref_lr=0.1
final_lr=0.1
ipe=2000 #iterations_per_epoch
warmup=0
num_epochs=1
ipe_scale=1.0
use_bfloat16=True

optimizer, scaler, scheduler, wd_scheduler = init_opt(
    encoder=encoder,
    predictor=predictor,
    wd=wd,
    final_wd=final_wd,
    start_lr=start_lr,
    ref_lr=lr,
    final_lr=final_lr,
    iterations_per_epoch=ipe,
    warmup=warmup,
    num_epochs=num_epochs,
    ipe_scale=ipe_scale,
    use_bfloat16=use_bfloat16)
# encoder = DistributedDataParallel(encoder, static_graph=True)
# predictor = DistributedDataParallel(predictor, static_graph=True)
# target_encoder = DistributedDataParallel(target_encoder)
for p in target_encoder.parameters():
    p.requires_grad = False

INFO:root:Using SGD


In [22]:
# -- momentum schedule
ema = (0.996, 1.0)
momentum_scheduler = (ema[0] + i*(ema[1]-ema[0])/(ipe*num_epochs*ipe_scale)
                      for i in range(int(ipe*num_epochs*ipe_scale)+1))


In [23]:


# def save_checkpoint(epoch):
#     save_dict = {
#         'encoder': encoder.state_dict(),
#         'predictor': predictor.state_dict(),
#         'target_encoder': target_encoder.state_dict(),
#         'opt': None,
#         'scaler': None if scaler is None else scaler.state_dict(),
#         'epoch': epoch,
#         'loss': loss_meter.avg,
#         'batch_size': batch_size,
#         'world_size': world_size,
#         'lr': lr
#     }
#     if rank == 0:
#         torch.save(save_dict, latest_path)
#         if (epoch + 1) % checkpoint_freq == 0:
#             torch.save(save_dict, save_path.format(epoch=f'{epoch + 1}'))

In [24]:
rank=0
r_file=None

In [26]:
from loggingtools import (
    CSVLogger,
    grad_logger,
    AverageMeter,
    gpu_timer)

import yaml

# -- LOGGING
folder = '/N/project/baby_vision_curriculum/trained_models/predictive/v0/jul24/'
#args['logging']['folder']
tag = 'jepa'
#args['logging']['write_tag']

dump = os.path.join(folder, 'params-ijepa.yaml')
with open(dump, 'w') as f:
    yaml.dump(args, f)
# ----------------------------------------------------------------------- #

# try:
#     mp.set_start_method('spawn')
# except Exception:
#     pass

# -- init torch distributed backend
# world_size, rank = init_distributed()
# logger.info(f'Initialized (rank/world-size) {rank}/{world_size}')
# if rank > 0:
#     logger.setLevel(logging.ERROR)

# -- log/checkpointing paths
log_file = os.path.join(folder, f'{tag}_r{rank}.csv')
save_path = os.path.join(folder, f'{tag}' + '-ep{epoch}.pth.tar')
latest_path = os.path.join(folder, f'{tag}-latest.pth.tar') #save checkpoint path
load_path = None
load_path = os.path.join(folder, r_file) if r_file is not None else latest_path

# -- make csv_logger
csv_logger = CSVLogger(log_file,
                       ('%d', 'epoch'),
                       ('%d', 'itr'),
                       ('%.5f', 'loss'),
                       ('%.5f', 'mask-A'),
                       ('%.5f', 'mask-B'),
                       ('%d', 'time (ms)'))

In [27]:
print(load_path)

/N/project/baby_vision_curriculum/trained_models/predictive/v0/jul24/jepa-latest.pth.tar


In [28]:
import torch.nn.functional as F
from tensors import apply_masks, repeat_interleave_batch
from distributed import AllReduce

In [31]:
def save_checkpoint(epoch):
    save_dict = {
        'encoder': encoder.state_dict(),
        'predictor': predictor.state_dict(),
        'target_encoder': target_encoder.state_dict(),
        'opt': optimizer.state_dict(),
        'scaler': None if scaler is None else scaler.state_dict(),
        'epoch': epoch,
        'loss': loss_meter.avg,
        'batch_size': batch_size,
        'world_size': world_size,
        'lr': lr
    }
    if rank == 0:
        torch.save(save_dict, latest_path)
#         if (epoch + 1) % checkpoint_freq == 0:
#             torch.save(save_dict, save_path.format(epoch=f'{epoch + 1}'))

In [32]:
phase = 'train'
start_epoch = 0

# --
log_timings = True
log_freq = 10
checkpoint_freq = 50
# --

load_model = False #True #False
# -- load training checkpoint
if load_model:
    encoder, predictor, target_encoder, optimizer, scaler, start_epoch = load_checkpoint(
        r_path=load_path,
        encoder=encoder,
        predictor=predictor,
        target_encoder=target_encoder,
        opt=optimizer,
        scaler=scaler)
    for _ in range(start_epoch*ipe):
        mask_collator.step()
#     start_epoch+=1
#         scheduler.step()
#         wd_scheduler.step()
#         next(momentum_scheduler)
        

            
# -- TRAINING LOOP
for epoch in range(start_epoch, start_epoch+num_epochs):
    logger.info('Epoch %d' % (epoch + 1))

    # -- update distributed-data-loader epoch
#     unsupervised_sampler.set_epoch(epoch)

    loss_meter = AverageMeter()
    maskA_meter = AverageMeter()
    maskB_meter = AverageMeter()
    time_meter = AverageMeter()

    for itr, (udata, masks_enc, masks_pred) in enumerate(
        tqdm(dataloaders[phase])):
        
        i_break = 11
        if itr>i_break:
            break #@@@
        masks_enc = update_masks(masks_enc, args)
        masks_pred = update_masks(masks_pred, args)
        
        def load_imgs():
            # -- unsupervised imgs
            imgs = udata.to(device, non_blocking=True)
            masks_1 = [u.to(device, non_blocking=True) for u in masks_enc]
            masks_2 = [u.to(device, non_blocking=True) for u in masks_pred]
            return (imgs, masks_1, masks_2)
        imgs, masks_enc, masks_pred = load_imgs()
        maskA_meter.update(len(masks_enc[0][0]))
        maskB_meter.update(len(masks_pred[0][0]))

        def train_step():
#             _new_lr = scheduler.step()
#             _new_wd = wd_scheduler.step()
            # --

            def forward_target():
                with torch.no_grad():
                    h = target_encoder(imgs)
                    h = F.layer_norm(h, (h.size(-1),))  # normalize over feature-dim
                    B = len(h)
                    # -- create targets (masked regions of h)
                    h = apply_masks(h, masks_pred)
#                     print('h shape:', h.shape)
                    h = repeat_interleave_batch(h, B, repeat=len(masks_enc))
                    return h

            def forward_context():
                z = encoder(imgs, masks_enc)
                z = predictor(z, masks_enc, masks_pred)
                return z

            def loss_fn(z, h):
                loss = F.smooth_l1_loss(z, h)
                loss = AllReduce.apply(loss)
                return loss

            # Step 1. Forward
            with torch.cuda.amp.autocast(dtype=torch.bfloat16, 
                                         enabled=use_bfloat16):
                h = forward_target()
                z = forward_context()
                loss = loss_fn(z, h)

            #  Step 2. Backward & step
            if phase == 'val':
                return float(loss),0
            
            if use_bfloat16:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()
            grad_stats = grad_logger(encoder.named_parameters())
            optimizer.zero_grad()

            # Step 3. momentum update of target encoder
            with torch.no_grad():
                m = next(momentum_scheduler)
                for param_q, param_k in zip(encoder.parameters(), target_encoder.parameters()):
                    param_k.data.mul_(m).add_((1.-m) * param_q.detach().data)
            return (float(loss), grad_stats)
            
        if phase=='val':
            (loss, _), etime = gpu_timer(train_step)
            logger.info('val loss: %.3f etime %.1f ms', loss, etime)
            continue
        (loss, grad_stats), etime = gpu_timer(train_step)
        loss_meter.update(loss)
        time_meter.update(etime)

        # -- Logging
        def log_stats():
            csv_logger.log(epoch + 1, itr, loss, maskA_meter.val, maskB_meter.val, etime)
            if (itr % log_freq == 0) or np.isnan(loss) or np.isinf(loss):
                logger.info('[%d, %5d] loss: %.3f '
                            'masks: %.1f %.1f '
                            '[mem: %.2e] '
                            '(%.1f ms)'
                            % (epoch + 1, itr,
                               loss_meter.avg,
                               maskA_meter.avg,
                               maskB_meter.avg,
                               torch.cuda.max_memory_allocated() / 1024.**2,
                               time_meter.avg))

                if grad_stats is not None:
                    logger.info('[%d, %5d] grad_stats: [%.2e %.2e] (%.2e, %.2e)'
                                % (epoch + 1, itr,
                                   grad_stats.first_layer,
                                   grad_stats.last_layer,
                                   grad_stats.min,
                                   grad_stats.max))

        log_stats()

        assert not np.isnan(loss), 'loss is nan'

    # -- Save Checkpoint after every epoch
    logger.info('avg. loss %.3f' % loss_meter.avg)
    save_checkpoint(epoch+1)


INFO:root:Epoch 1


  0%|                                                  | 0/1500 [00:00<?, ?it/s]

INFO:root:[1,     0] loss: 0.169 masks: 73.0 30.0 [mem: 2.42e+03] (248.4 ms)
INFO:root:[1,     0] grad_stats: [2.85e-03 1.26e-03] (1.26e-03, 3.01e-03)


  1%|▎                                        | 10/1500 [00:04<10:08,  2.45it/s]

INFO:root:[1,    10] loss: 0.186 masks: 62.0 33.7 [mem: 2.45e+03] (242.7 ms)
INFO:root:[1,    10] grad_stats: [1.35e-03 8.26e-04] (8.16e-04, 2.22e-03)


  1%|▎                                        | 12/1500 [00:05<10:20,  2.40it/s]

INFO:root:avg. loss 0.184



