In [1]:
import sys, os

env_root = '/N/project/baby_vision_curriculum/pythonenvs/hfenv/lib/python3.10/site-packages/'
sys.path.insert(0, env_root)

In [2]:
import numpy as np
import torch, torchvision
from torchvision import transforms as tr
from tqdm import tqdm
from pathlib import Path
# import math
import argparse
import pandas as pd
import warnings

import transformers

# import torch.distributed as dist
# from torch.utils.data.distributed import DistributedSampler
# from torch.nn.parallel import DistributedDataParallel as DDP
# import torch.multiprocessing as mp
# from ddputils import is_main_process, save_on_master, setup_for_distributed


# from PIL import Image
from torch.utils.data import Dataset
import random
from copy import deepcopy
import json
import itertools
import gc

In [3]:
crop_size=224
crop_scale=(0.3, 1.0)#1 #0.3
use_gaussian_blur=False
use_horizontal_flip=False
use_color_distortion=False
color_jitter=0
patch_size=16
pred_mask_scale=(0.15, 0.2)
enc_mask_scale=(0.85,1)
aspect_ratio=(0.75,1.5)
num_enc_masks=1
num_pred_masks=4
allow_overlap=False
min_keep=10

In [4]:
def get_fpathlist(vid_root, subjdir, ds_rate=1):
    """
    # read the image files inside vid_root/subj_dir into a list. 
    # makes sure they're all jpg. also sorts them so that the order of the frames is correct.
    # subjdir = ['008MS']
    """
    
    fpathlist = sorted(list(Path(os.path.join(vid_root, subjdir)).iterdir()), 
                       key=lambda x: x.name)
    fpathlist = [str(fpath) for fpath in fpathlist if fpath.suffix=='.jpg']
    fpathlist = fpathlist[::ds_rate]
    return fpathlist
    
def get_train_val_split(fpathlist, val_ratio=0.1):
    """
    Splits the list of filepaths into a train list and val list
    """
    n_fr = len(fpathlist)
    val_size = int(n_fr*val_ratio)
    
    split1_idx = int((n_fr-val_size)/2)
    split2_idx = int((n_fr+val_size)/2)
    train_set =fpathlist[:split1_idx]+fpathlist[split2_idx:]
    val_set = fpathlist[split1_idx:split2_idx]
    return train_set, val_set

def get_fpathseqlist(fpathlist, seq_len, ds_rate=1, n_samples=None):
    """
    Returns a list of list that can be passed to ImageSequenceDataset
    # n_samples: int
    # between 1 and len(fpathlist)
    # If None, it's set to len(fpathlist)/seq_len
    """
    
    sample_len = seq_len*ds_rate
    if n_samples is None:
        n_samples = int(len(fpathlist)/seq_len)
        sample_stride = sample_len
    else:
        assert type(n_samples)==int
        assert len(fpathlist)>n_samples
        sample_stride = int(len(fpathlist)/n_samples)
        # for adult group, sample_stride ~=10. i.e. each frame contributes to more than 1 sample sequence, 
        # but doesn't appear in the same index of the sequence.

    fpathseqlist = [fpathlist[i:i+sample_len:ds_rate] 
                    for i in range(0, n_samples*sample_stride, sample_stride)]
    return fpathseqlist


def _get_transform(image_size):

    mean = [0.5, 0.5, 0.5]#np.mean(mean_all, axis=0) #mean_all[chosen_subj] 
    std = [0.25, 0.25, 0.25] #std_all[chosen_subj] 
    
#     [0.485, 0.456, 0.406]  # IMAGENET_DEFAULT_MEAN
#     [0.229, 0.224, 0.225]  # IMAGENET_DEFAULT_STD
    
    augs = [tr.Resize(image_size), tr.CenterCrop(image_size), 
            tr.ConvertImageDtype(torch.float32), 
             tr.Normalize(mean,std)]
    return tr.Compose(augs)

def get_fold(gx_fpathlist, fold, max_folds, args):
#     fold_size = int(len(gx_fpathlist)/max_folds)
    segment_size = int(30*60*30/args.ds_rate)
    
    fold_segments = []

    for i_st in range(0, len(gx_fpathlist), segment_size):
        if (i_st // segment_size) % max_folds == fold:
            fold_segments.append(gx_fpathlist[i_st:i_st + segment_size])
            
    fold_segments = list(itertools.chain.from_iterable(fold_segments))
    return fold_segments



In [5]:
import torchvision.transforms as transforms

def make_transforms(
    crop_size=224,
    crop_scale=(0.3, 1.0),
    normalization=((0.5, 0.5, 0.5),
                   (0.25, 0.25, 0.25))
):

    transform_list = []
    transform_list += [transforms.RandomResizedCrop(crop_size, scale=crop_scale)]
#     transform_list += [transforms.ConvertImageDtype(torch.float32)] 
    transform_list += [transforms.ToTensor()]
    transform_list += [transforms.Normalize(normalization[0], normalization[1])]

    transform = transforms.Compose(transform_list)
    return transform


# transform = make_transforms(
#     crop_size=crop_size,
#     crop_scale=crop_scale)

In [6]:
# class ImageSequenceDataset(Dataset):
import PIL
class ImageDataset(Dataset):
    """
    To use for video models. 
    """
    def __init__(self, image_paths, transform, shuffle=False):
        self.image_paths = image_paths
        self.transform = transform
        self.shuffle = shuffle

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Load the sequence of images
        fp = self.image_paths[idx][0]
        images = self.transform(PIL.Image.open(fp))
#         images = self.transform(torchvision.io.read_image(fp))
#         images = torch.cat([
#             self.transform(torchvision.io.read_image(fp)).unsqueeze(0)
#                      for fp in self.image_paths[idx]]) #with tochvision transform
        
#         if self.shuffle:
#             size = images.size(0)
#             perm = torch.randperm(size)
#             images = images[perm]
            
        return images

def make_dataset(subj_dirs, image_size, args):
    seq_len = args.num_frames #kwargs['seq_len']
#     n_groupframes=kwargs['n_groupframes']#1450000
    ds_rate = args.ds_rate #kwargs['ds_rate']
    jpg_root = args.jpg_root #kwargs['jpg_root']
#     image_size = kwargs['image_size']
    fold = args.fold #kwargs['fold']
    condition = args.condition #kwargs['condition']
    n_trainsamples = args.n_trainsamples
    
    crop_size = args.crop_size
    crop_scale = args.crop_scale
    
    transform = make_transforms(
        crop_size=crop_size,
        crop_scale=crop_scale)
    #_get_transform(image_size)
    gx_fpathlist = []
    for i_subj, subjdir in enumerate(tqdm(subj_dirs)):
        gx_fpathlist += get_fpathlist(jpg_root, subjdir, ds_rate=ds_rate)
    
    # added on May15
    max_folds = 3
    gx_fpathlist = get_fold(gx_fpathlist, fold, max_folds, args)
    print('Num. frames in the fold:',len(gx_fpathlist))

    #     if len(gx_fpathlist)>=n_groupframes:
#         gx_fpathlist = gx_fpathlist[:n_groupframes]
#         # 1450000/16 = 90625 => n_trainsamples=81560, n_valsamples= 9060
#         # 1274 iterations of train. 141 iterations of test. 
#         n_trainsamples = None
#         n_valsamples = None
#     else:
    
    # Train-val split
    gx_train_fp, gx_val_fp = get_train_val_split(gx_fpathlist, val_ratio=0.1)

    if condition=='longshuffle':
        random.shuffle(gx_train_fp)
    
    n_trainsamples = n_trainsamples #int(0.9*n_groupframes/seq_len) #81k
    
    n_maxvalsamples = int(len(gx_val_fp)/seq_len)
    n_valsamples = min(n_maxvalsamples, 10000)  #means don't do bootstraping for val. Use whatever number 0.1*len(gx_fpathlist) gives.
    
    gx_train_fpathseqlist = get_fpathseqlist(gx_train_fp, seq_len, ds_rate=1, n_samples=n_trainsamples)
    gx_val_fpathseqlist = get_fpathseqlist(gx_val_fp, seq_len, ds_rate=1, n_samples=n_valsamples)
    
    
    if condition=='shuffle':
        train_dataset = ImageSequenceDataset(gx_train_fpathseqlist, transform=transform, shuffle=True)
        val_dataset = ImageSequenceDataset(gx_val_fpathseqlist, transform=transform, shuffle=False)
    
    elif condition=='static':
        train_dataset = StillVideoDataset(gx_train_fpathseqlist, transform=transform)
        val_dataset = ImageSequenceDataset(gx_val_fpathseqlist, transform=transform, shuffle=False)
    elif condition=='image':
        train_dataset = ImageDataset(gx_train_fpathseqlist, transform=transform)
        val_dataset = ImageDataset(gx_val_fpathseqlist, transform=transform)
#         StillVideoDataset(gx_val_fpathseqlist, transform=transform)

    else:
        train_dataset = ImageSequenceDataset(gx_train_fpathseqlist, transform=transform, shuffle=False)
        val_dataset = ImageSequenceDataset(gx_val_fpathseqlist, transform=transform, shuffle=False)
        
    return {'train':train_dataset,
           'val': val_dataset}


In [7]:
class Args:
    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

In [105]:
jpg_root='/N/project/baby_vision_curriculum/homeview_subset_30fps/'
data_seed=401

saveroot='/N/project/baby_vision_curriculum/trained_models/predictive/v0/jul17/'
tbsaveroot='/N/project/baby_vision_curriculum/trained_models/predictive/v0/benchmarks/toybox/jul17/'
ucsaveroot='/N/project/baby_vision_curriculum/trained_models/predictive/v0/benchmarks/ucf101/jul17/'

n_epoch=5
curr='yo'
condition='image'
other_id=curr+'_'+condition
monitor='grad'
ds_rate=1
batch_size=256#64#16
n_trainsamples=12000#162000
max_epoch_iters=2000
tbbatch_size=64
ucbatch_size=64
mask_sampler='tube'
tubelet_size=1
num_frames=1

train_group='g2'#'g0'
fold=data_seed % 3
# mask_ratio=0.9
#adamw: lr=1.5e-4, wd=0.05, momentum=0.9 (doesn't get used)
#adam: lr=0.001. wd=1e-4, momentum=0.9 (doesn't get used)
#sgd: 0.1, 0, 0.9
optim='sgd'
lr=0.1
wd=0
architecture='small'#'base'
momentum=0.9



savedir=saveroot+'s1/' #"${saveroot}s1/"
# Initialization
init_checkpoint_path='na'

# other_seed=data_seed#$data_seed
script='pretrain_vjepa_v0.py'

args = Args(jpg_root=jpg_root,
            train_group=train_group,
            savedir=savedir,
            architecture=architecture,
            init_checkpoint_path=init_checkpoint_path,
            seed=data_seed,
            other_id=other_id,
            batch_size=batch_size,
           num_frames=num_frames,
           ds_rate=ds_rate,
           fold=fold,
           condition=condition,
           n_trainsamples=n_trainsamples,
           crop_size=crop_size,
           crop_scale=crop_scale)

In [106]:
g0='008MS+009SS+010BF+011EA+012TT+013LS+014SN+015JM+016TF+017EW'
g1='026AR+027SS+028CK+028MR+029TT+030FD+031HW+032SR+033SE+034JC'
g2='043MP+044ET+046TE+047MS+048KG+049JC+050AB+050AK+051DW'
g3='BR+CW+EA+ED+JB+KI+LS+SB+TR'
# Total number of frames in each age group: g0=1.68m, g1=1.77m, g2=1.45m

g0 = g0.split('+')
g1 = g1.split('+')
g2 = g2.split('+')
g3 = g3.split('+')

gRand=[]
for gx in [g0,g1,g2,g3]:
    gRand.extend(random.sample(gx, 3))
random.shuffle(gRand)

group_dict = {"g0":g0, "g1":g1, "g2":g2, "g3":g3, 'gr':gRand}
group = group_dict.get(args.train_group)
print(group)                                               

['043MP', '044ET', '046TE', '047MS', '048KG', '049JC', '050AB', '050AK', '051DW']


In [107]:
image_size=224
datasets = make_dataset(group, image_size, args)

100%|█████████████████████████████████████████████| 9/9 [00:22<00:00,  2.47s/it]


Num. frames in the fold: 1404192


In [108]:
import math
from multiprocessing import Value
from logging import getLogger

_GLOBAL_SEED = 0
logger = getLogger()

class MBMaskCollator(object):

    def __init__(
        self,
        input_size=(224, 224),
        patch_size=16,
        enc_mask_scale=(0.2, 0.8),
        pred_mask_scale=(0.2, 0.8),
        aspect_ratio=(0.3, 3.0),
        nenc=1,
        npred=2,
        min_keep=4,
        allow_overlap=False
    ):
        super(MBMaskCollator, self).__init__()
        if not isinstance(input_size, tuple):
            input_size = (input_size, ) * 2
        self.patch_size = patch_size
        self.height, self.width = input_size[0] // patch_size, input_size[1] // patch_size
        self.enc_mask_scale = enc_mask_scale
        self.pred_mask_scale = pred_mask_scale
        self.aspect_ratio = aspect_ratio
        self.nenc = nenc
        self.npred = npred
        self.min_keep = min_keep  # minimum number of patches to keep
        self.allow_overlap = allow_overlap  # whether to allow overlap b/w enc and pred masks
        self._itr_counter = Value('i', -1)  # collator is shared across worker processes

    def step(self):
        i = self._itr_counter
        with i.get_lock():
            i.value += 1
            v = i.value
        return v

    def _sample_block_size(self, generator, scale, aspect_ratio_scale):
        _rand = torch.rand(1, generator=generator).item()
        # -- Sample block scale
        min_s, max_s = scale
        mask_scale = min_s + _rand * (max_s - min_s)
        max_keep = int(self.height * self.width * mask_scale)
        # -- Sample block aspect-ratio
        min_ar, max_ar = aspect_ratio_scale
        aspect_ratio = min_ar + _rand * (max_ar - min_ar)
        # -- Compute block height and width (given scale and aspect-ratio)
        h = int(round(math.sqrt(max_keep * aspect_ratio)))
        w = int(round(math.sqrt(max_keep / aspect_ratio)))
        while h >= self.height:
            h -= 1
        while w >= self.width:
            w -= 1

        return (h, w)

    def _sample_block_mask(self, b_size, acceptable_regions=None):
        h, w = b_size

        def constrain_mask(mask, tries=0):
            """ Helper to restrict given mask to a set of acceptable regions """
            N = max(int(len(acceptable_regions)-tries), 0)
            for k in range(N):
                mask *= acceptable_regions[k]
        # --
        # -- Loop to sample masks until we find a valid one
        tries = 0
        timeout = og_timeout = 20
        valid_mask = False
        while not valid_mask:
            # -- Sample block top-left corner
            top = torch.randint(0, self.height - h, (1,))
            left = torch.randint(0, self.width - w, (1,))
            mask = torch.zeros((self.height, self.width), dtype=torch.int32)
            mask[top:top+h, left:left+w] = 1
            # -- Constrain mask to a set of acceptable regions
            if acceptable_regions is not None:
                constrain_mask(mask, tries)
            mask = torch.nonzero(mask.flatten())
            # -- If mask too small try again
            valid_mask = len(mask) > self.min_keep
            if not valid_mask:
                timeout -= 1
                if timeout == 0:
                    tries += 1
                    timeout = og_timeout
                    logger.warning(f'Mask generator says: "Valid mask not found, decreasing acceptable-regions [{tries}]"')
        mask = mask.squeeze()
        # --
        mask_complement = torch.ones((self.height, self.width), dtype=torch.int32)
        mask_complement[top:top+h, left:left+w] = 0
        # --
        return mask, mask_complement

    def __call__(self, batch):
        '''
        Create encoder and predictor masks when collating imgs into a batch
        # 1. sample enc block (size + location) using seed
        # 2. sample pred block (size) using seed
        # 3. sample several enc block locations for each image (w/o seed)
        # 4. sample several pred block locations for each image (w/o seed)
        # 5. return enc mask and pred mask
        '''
        B = len(batch)

        collated_batch = torch.utils.data.default_collate(batch)

        seed = self.step()
        g = torch.Generator()
        g.manual_seed(seed)
        p_size = self._sample_block_size(
            generator=g,
            scale=self.pred_mask_scale,
            aspect_ratio_scale=self.aspect_ratio)
        e_size = self._sample_block_size(
            generator=g,
            scale=self.enc_mask_scale,
            aspect_ratio_scale=(1., 1.))

        collated_masks_pred, collated_masks_enc = [], []
        min_keep_pred = self.height * self.width
        min_keep_enc = self.height * self.width
        for _ in range(B):

            masks_p, masks_C = [], []
            for _ in range(self.npred):
                mask, mask_C = self._sample_block_mask(p_size)
                masks_p.append(mask)
                masks_C.append(mask_C)
                min_keep_pred = min(min_keep_pred, len(mask))
            collated_masks_pred.append(masks_p)

            acceptable_regions = masks_C
            try:
                if self.allow_overlap:
                    acceptable_regions= None
            except Exception as e:
                logger.warning(f'Encountered exception in mask-generator {e}')

            masks_e = []
            for _ in range(self.nenc):
                mask, _ = self._sample_block_mask(e_size, acceptable_regions=acceptable_regions)
                masks_e.append(mask)
                min_keep_enc = min(min_keep_enc, len(mask))
            collated_masks_enc.append(masks_e)

        collated_masks_pred = [[cm[:min_keep_pred] 
                                for cm in cm_list] 
                               for cm_list in collated_masks_pred]
        collated_masks_pred = torch.utils.data.default_collate(collated_masks_pred)
        # --
        collated_masks_enc = [[cm[:min_keep_enc] 
                               for cm in cm_list] 
                              for cm_list in collated_masks_enc]
        collated_masks_enc = torch.utils.data.default_collate(collated_masks_enc)

        return collated_batch, collated_masks_enc, collated_masks_pred

In [109]:


mask_collator = MBMaskCollator(
    input_size=crop_size,
    patch_size=patch_size,
    pred_mask_scale=pred_mask_scale,
    enc_mask_scale=enc_mask_scale,
    aspect_ratio=aspect_ratio,
    nenc=num_enc_masks,
    npred=num_pred_masks,
    allow_overlap=allow_overlap,
    min_keep=min_keep)

In [110]:
pin_memory = False
# num_workers=
dataloaders = {x: torch.utils.data.DataLoader(
        datasets[x], batch_size=batch_size, pin_memory=pin_memory, 
        shuffle=False, collate_fn=mask_collator,
        drop_last=True)
                        for x in ['train', 'val']}


In [111]:
len(dataloaders['train'])

46

In [112]:
phase = 'train'
for inputs in tqdm(dataloaders[phase]):
    break

  0%|                                                    | 0/46 [00:01<?, ?it/s]


In [18]:
for item in inputs[0]:
    print('udata:',item.shape)
for item in inputs[1]:
    print('encoder context:',item.shape)
for item in inputs[2]:
    print('prediction target:',item.shape)

udata: torch.Size([3, 224, 224])
udata: torch.Size([3, 224, 224])
udata: torch.Size([3, 224, 224])
udata: torch.Size([3, 224, 224])
udata: torch.Size([3, 224, 224])
udata: torch.Size([3, 224, 224])
udata: torch.Size([3, 224, 224])
udata: torch.Size([3, 224, 224])
encoder context: torch.Size([8, 67])
prediction target: torch.Size([8, 36])
prediction target: torch.Size([8, 36])
prediction target: torch.Size([8, 36])
prediction target: torch.Size([8, 36])


In [82]:
from helper import init_model, init_opt

In [113]:
device= 'cuda:0'#'cpu'
model_name='vit_base'#'vit_small'
pred_depth=6
pred_emb_dim=384
    
encoder, predictor = init_model(
    device=device,
    patch_size=patch_size,
    crop_size=crop_size,
    pred_depth=pred_depth,
    pred_emb_dim=pred_emb_dim,
    model_name=model_name)
target_encoder = deepcopy(encoder)

INFO:root:VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  )
  (blocks): ModuleList(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): MLP(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate=none)
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (1): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_fe

In [114]:
wd=0
final_wd=0
start_lr=0.1
ref_lr=0.1
final_lr=0.1
ipe=2000 #iterations_per_epoch
warmup=0
num_epochs=1
ipe_scale=1.0
use_bfloat16=True

optimizer, scaler, scheduler, wd_scheduler = init_opt(
    encoder=encoder,
    predictor=predictor,
    wd=wd,
    final_wd=final_wd,
    start_lr=start_lr,
    ref_lr=lr,
    final_lr=final_lr,
    iterations_per_epoch=ipe,
    warmup=warmup,
    num_epochs=num_epochs,
    ipe_scale=ipe_scale,
    use_bfloat16=use_bfloat16)
# encoder = DistributedDataParallel(encoder, static_graph=True)
# predictor = DistributedDataParallel(predictor, static_graph=True)
# target_encoder = DistributedDataParallel(target_encoder)
for p in target_encoder.parameters():
    p.requires_grad = False

INFO:root:Using SGD


In [115]:
# -- momentum schedule
ema = (0.996, 1.0)
momentum_scheduler = (ema[0] + i*(ema[1]-ema[0])/(ipe*num_epochs*ipe_scale)
                      for i in range(int(ipe*num_epochs*ipe_scale)+1))


In [116]:
phase = 'train'
start_epoch = 0

# --
log_timings = True
log_freq = 10
checkpoint_freq = 50
# --

load_model = False
# -- load training checkpoint
if load_model:
    encoder, predictor, target_encoder, _, scaler, _ = load_checkpoint(
        device=device,
        r_path=load_path,
        encoder=encoder,
        predictor=predictor,
        target_encoder=target_encoder,
        opt=optimizer,
        scaler=scaler)
    for _ in range(start_epoch*ipe):
#         scheduler.step()
#         wd_scheduler.step()
#         next(momentum_scheduler)
        mask_collator.step()
    
def save_checkpoint(epoch):
    save_dict = {
        'encoder': encoder.state_dict(),
        'predictor': predictor.state_dict(),
        'target_encoder': target_encoder.state_dict(),
        'opt': None,
        'scaler': None if scaler is None else scaler.state_dict(),
        'epoch': epoch,
        'loss': loss_meter.avg,
        'batch_size': batch_size,
        'world_size': world_size,
        'lr': lr
    }
    if rank == 0:
        torch.save(save_dict, latest_path)
        if (epoch + 1) % checkpoint_freq == 0:
            torch.save(save_dict, save_path.format(epoch=f'{epoch + 1}'))

In [117]:
rank=0

In [118]:
from loggingtools import (
    CSVLogger,
    grad_logger,
    AverageMeter,
    gpu_timer)

import yaml

# -- LOGGING
folder = '/N/project/baby_vision_curriculum/trained_models/predictive/v0/jul17/'
#args['logging']['folder']
tag = 'jepa'
#args['logging']['write_tag']

dump = os.path.join(folder, 'params-ijepa.yaml')
with open(dump, 'w') as f:
    yaml.dump(args, f)
# ----------------------------------------------------------------------- #

# try:
#     mp.set_start_method('spawn')
# except Exception:
#     pass

# -- init torch distributed backend
# world_size, rank = init_distributed()
# logger.info(f'Initialized (rank/world-size) {rank}/{world_size}')
# if rank > 0:
#     logger.setLevel(logging.ERROR)

# -- log/checkpointing paths
log_file = os.path.join(folder, f'{tag}_r{rank}.csv')
save_path = os.path.join(folder, f'{tag}' + '-ep{epoch}.pth.tar')
latest_path = os.path.join(folder, f'{tag}-latest.pth.tar')
load_path = None
if load_model:
    load_path = os.path.join(folder, r_file) if r_file is not None else latest_path

# -- make csv_logger
csv_logger = CSVLogger(log_file,
                       ('%d', 'epoch'),
                       ('%d', 'itr'),
                       ('%.5f', 'loss'),
                       ('%.5f', 'mask-A'),
                       ('%.5f', 'mask-B'),
                       ('%d', 'time (ms)'))

In [119]:
import torch.nn.functional as F
from tensors import apply_masks, repeat_interleave_batch
from distributed import AllReduce

In [120]:
# -- TRAINING LOOP
for epoch in range(start_epoch, num_epochs):
    logger.info('Epoch %d' % (epoch + 1))

    # -- update distributed-data-loader epoch
#     unsupervised_sampler.set_epoch(epoch)

    loss_meter = AverageMeter()
    maskA_meter = AverageMeter()
    maskB_meter = AverageMeter()
    time_meter = AverageMeter()

    for itr, (udata, masks_enc, masks_pred) in enumerate(
        tqdm(dataloaders[phase])):
        
        i_break = 200
        if itr>i_break:
            break #@@@
            
        def load_imgs():
            # -- unsupervised imgs
            imgs = udata.to(device, non_blocking=True)
            masks_1 = [u.to(device, non_blocking=True) for u in masks_enc]
            masks_2 = [u.to(device, non_blocking=True) for u in masks_pred]
            return (imgs, masks_1, masks_2)
        imgs, masks_enc, masks_pred = load_imgs()
        maskA_meter.update(len(masks_enc[0][0]))
        maskB_meter.update(len(masks_pred[0][0]))

        def train_step():
#             _new_lr = scheduler.step()
#             _new_wd = wd_scheduler.step()
            # --

            def forward_target():
                with torch.no_grad():
                    h = target_encoder(imgs)
                    h = F.layer_norm(h, (h.size(-1),))  # normalize over feature-dim
                    B = len(h)
                    # -- create targets (masked regions of h)
                    h = apply_masks(h, masks_pred)
                    h = repeat_interleave_batch(h, B, repeat=len(masks_enc))
                    return h

            def forward_context():
                z = encoder(imgs, masks_enc)
                z = predictor(z, masks_enc, masks_pred)
                return z

            def loss_fn(z, h):
                loss = F.smooth_l1_loss(z, h)
                loss = AllReduce.apply(loss)
                return loss

            # Step 1. Forward
            with torch.cuda.amp.autocast(dtype=torch.bfloat16, 
                                         enabled=use_bfloat16):
                h = forward_target()
                z = forward_context()
                loss = loss_fn(z, h)

            #  Step 2. Backward & step
            if use_bfloat16:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()
            grad_stats = grad_logger(encoder.named_parameters())
            optimizer.zero_grad()

            # Step 3. momentum update of target encoder
            with torch.no_grad():
                m = next(momentum_scheduler)
                for param_q, param_k in zip(encoder.parameters(), target_encoder.parameters()):
                    param_k.data.mul_(m).add_((1.-m) * param_q.detach().data)

            return (float(loss), grad_stats)
        loss, grad_stats = train_step()
        (loss, grad_stats), etime = gpu_timer(train_step)
#         etime=0 #@@@
#         (loss, _new_lr, _new_wd, grad_stats), etime = gpu_timer(train_step)
        loss_meter.update(loss)
        time_meter.update(etime)

        # -- Logging
        def log_stats():
            csv_logger.log(epoch + 1, itr, loss, maskA_meter.val, maskB_meter.val, etime)
            if (itr % log_freq == 0) or np.isnan(loss) or np.isinf(loss):
                logger.info('[%d, %5d] loss: %.3f '
                            'masks: %.1f %.1f '
                            '[mem: %.2e] '
                            '(%.1f ms)'
                            % (epoch + 1, itr,
                               loss_meter.avg,
                               maskA_meter.avg,
                               maskB_meter.avg,
                               torch.cuda.max_memory_allocated() / 1024.**2,
                               time_meter.avg))

                if grad_stats is not None:
                    logger.info('[%d, %5d] grad_stats: [%.2e %.2e] (%.2e, %.2e)'
                                % (epoch + 1, itr,
                                   grad_stats.first_layer,
                                   grad_stats.last_layer,
                                   grad_stats.min,
                                   grad_stats.max))

        log_stats()

        assert not np.isnan(loss), 'loss is nan'

    # -- Save Checkpoint after every epoch
#     logger.info('avg. loss %.3f' % loss_meter.avg)
#         save_checkpoint(epoch+1)


INFO:root:Epoch 1


  0%|                                                    | 0/46 [00:00<?, ?it/s]

INFO:root:[1,     0] loss: 0.415 masks: 49.0 35.0 [mem: 1.80e+04] (391.3 ms)
INFO:root:[1,     0] grad_stats: [1.96e-02 1.06e-02] (1.02e-02, 2.24e-02)


 22%|█████████▎                                 | 10/46 [00:19<01:09,  1.94s/it]

INFO:root:[1,    10] loss: 0.240 masks: 50.2 31.8 [mem: 2.04e+04] (392.4 ms)
INFO:root:[1,    10] grad_stats: [4.54e-03 1.92e-03] (1.90e-03, 1.13e-02)


 43%|██████████████████▋                        | 20/46 [00:56<03:01,  6.99s/it]

INFO:root:[1,    20] loss: 0.204 masks: 49.1 32.3 [mem: 2.04e+04] (390.0 ms)
INFO:root:[1,    20] grad_stats: [3.36e-03 2.05e-03] (1.95e-03, 3.78e-03)


 59%|█████████████████████████▏                 | 27/46 [03:20<02:21,  7.44s/it]


KeyboardInterrupt: 