In [1]:
import comet_ml

import os
import sys
import pdb
import timeit

from pathlib import Path
from typing import Any, Dict, Tuple, Type, cast


import hydra
from omegaconf import OmegaConf, DictConfig

from torchvision import transforms
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CometLogger
from pytorch_lightning.callbacks import (
    EarlyStopping,
    ModelCheckpoint,
    LearningRateMonitor,
)
from pytorch_lightning.profiler import AdvancedProfiler, SimpleProfiler
from pytorch_lightning.profiler.pytorch import PyTorchProfiler

from models.moco2_module import MocoV2
from pl_bolts.models.self_supervised.moco.callbacks import MocoLRScheduler

from dataset.ssl.ssl_geolife_datamodule import GeoLifeDataModule
from models.ssl_online import SSLOnlineEvaluator
from models.utils import InputMonitor

  from .autonotebook import tqdm as notebook_tqdm
  stdout_func(


# Hyperparams

In [2]:
exp_opts = {'use_ffcv_loader': False,
            'batch_size': 32,
            'num_workers': 2,
            'data_dir': "/network/scratch/s/sara.ebrahim-elkafrawy/small_geo_data/",
            'log_dir': "/home/mila/s/sara.ebrahim-elkafrawy/scratch/ecosystem_project/exps/debug_exp",
            'config_file': "debug-exp.yaml",
            'random_init_path': "/home/mila/s/sara.ebrahim-elkafrawy/scratch/ecosystem_project/ckpts/resnet50_random_init",
            'ckpt_file': "/home/mila/s/sara.ebrahim-elkafrawy/scratch/ecosystem_project/ckpts/first_epoch_57_bs64_numworker4.ckpt",
            'max_epochs': 2,
            'gpus': 1,
#             'data':{
#                   'loaders':
#                             {
#                                 'num_workers': 2,   # 32
#                                 'batch_size': 32   # 256
#                             },
#                   'datatype': "img",  
#                   'bands':  ["rgb"],  
#                   'splits':{ 
#                             'train': val,  
#                             'val': val,
#                             'test': test,
#                           },
#                   'transforms':{
#                         - name: crop
#                           ignore: True
#                           p: 0.5
#                           center: true # disable randomness, crop around the image's center
#                           size: [256, 256]
#                         - name: resize
#                           ignore: False
#                           size: [224,224]
#                         - name: hflip
#                           ignore: "val"
#                           p: 0.5
#                         - name: vflip
#                           ignore: "val"
#                           p: 0.5    
#                         - name: normalize
#                           ignore: False
#                           means: [106.94150444, 114.87315837, 104.52826283]   # [0.4194, 0.4505, 0.4099], 
#                           std: [50.59516823, 44.13010964, 41.84300729]        # [0.20, 0.1759, 0.1694]
#                           band: "rgb"
#                         - name: normalize
#                           ignore: True
#                           means: [131.0458]     # 0.5139
#                           std: [53.0884]        # 0.2082
#                           band: "near_ir"
#                         - name: normalize
#                           ignore: True
#                           means: [298.1693]
#                           std: [459.3285]
#                           band: "altitude"
#                         - name: normalize
#                           ignore: True
#                           means: [17.4200]
#                           std: [9.5173]
#                           band: "landcover"
#                               }
#                     },
            
            'ssl':
                    {'learning_rate': 0.03,
                      'ssl_pretrained': False,
                      'num_keys': 3,
                      'schedule': [120, 160],
                      'base_encoder': 'resnet50',
                      'emb_dim': 128,
                      'num_workers': 32,
                      'num_negatives': 16384,
                      'encoder_momentum': 0.999,
                      'softmax_temperature': 0.07,
                      'momentum': 0.9,
                      'weight_decay': 1e-4,
                      'batch_size': 256,
                      'use_ddp': False,
                      'use_ddp2': False,
                    #   accelerator: gpu
                    #   strategy: ddp
                    #   devices: 1
                    #   num_nodes: 1
                      'online_max_epochs': 2,
                      'online_val_every_n_epoch': 1},}


In [3]:
exp_opts = cast(DictConfig, exp_opts)

In [4]:
exp_configs = OmegaConf.merge({}, exp_opts)

In [5]:
exp_configs.ssl.learning_rate

0.03

# GeoLife dataset

In [6]:
import os
import sys
import inspect
import torch

CURR_DIR = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
PARENT_DIR = os.path.dirname(CURR_DIR)
sys.path.insert(0, CURR_DIR)

In [7]:
from dataset.ffcv_loader.dataset_ffcv import GeoLifeCLEF2022DatasetFFCV
from dataset.pytorch_dataset import GeoLifeCLEF2022Dataset

In [8]:
save_dir = "/home/mila/s/sara.ebrahim-elkafrawy/scratch/ecosystem_project/tmp_geo"

In [10]:
train_dataset = GeoLifeCLEF2022Dataset(
    exp_configs.data_dir,
    "train",
    region="both",
    patch_data="all", # self.opts.data.bands,
    use_rasters=False,
    patch_extractor=None,
    transform=None,
    target_transform=None,
    )

val_dataset = GeoLifeCLEF2022Dataset(
        exp_configs.data_dir,
        "val",
        region="both",
        patch_data="all", #self.opts.data.bands,
        use_rasters=False,
        patch_extractor=None,
        transform=None,
        target_transform=None,
    )

train_dataloader = torch.utils.data.DataLoader(train_dataset, 
                                               num_workers=exp_configs.num_workers, 
                                               batch_size=exp_configs.batch_size,
                                               pin_memory=True,
                                               drop_last=True,
                                               shuffle=True)
val_dataloader = torch.utils.data.DataLoader(val_dataset, 
                                              num_workers=exp_configs.num_workers, 
                                              batch_size=exp_configs.batch_size,
                                              pin_memory=True,
                                              drop_last=False,
                                              shuffle=False)

# Model loading

In [11]:
model = MocoV2(exp_configs)



# PL Trainer

## callbacks

In [12]:
checkpoint_callback = ModelCheckpoint(dirpath=exp_configs.log_dir, filename='{epoch}')

lr_monitor = LearningRateMonitor(logging_interval="epoch")

moco_scheduler = MocoLRScheduler(
    initial_lr=exp_configs.ssl.learning_rate, 
    schedule=exp_configs.ssl.schedule, 
    max_epochs=exp_configs.max_epochs)

# online_evaluator = SSLOnlineEvaluator(
#     exp_configs,
#     data_dir=exp_configs.data_dir,
#     z_dim=model.mlp_dim,
# )

trainer_args = {}
trainer_args["callbacks"] = [
    checkpoint_callback,
    lr_monitor,
    moco_scheduler,
#     online_evaluator,
]

In [13]:
trainer = pl.Trainer(
        enable_progress_bar=True,
        default_root_dir=exp_configs.log_dir,
        max_epochs=exp_configs.max_epochs,
        gpus=exp_configs.gpus,
#         accelerator=exp_configs.ssl.accelerator,
#         devices=exp_configs.ssl.devices, 
#         num_nodes=exp_configs.ssl.num_nodes, 
#         strategy=exp_configs.ssl.strategy,
#         logger=comet_logger,
#         log_every_n_steps=trainer_args["log_every_n_steps"],
#         callbacks=trainer_args["callbacks"],
        overfit_batches=0.0,  ## make sure it is 0.0 when training
        precision=16,
        accumulate_grad_batches=int(exp_configs.batch_size / 4),
#         progress_bar_refresh_rate=0,
        #         strategy="ddp_find_unused_parameters_false",
        #         distributed_backend='ddp',
        #         profiler=profiler,
    )

Multiprocessing is handled by SLURM.
Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [14]:
start = timeit.default_timer()
trainer.fit(model, 
            train_dataloaders=train_dataloader, 
            val_dataloaders=val_dataloader, 
            ckpt_path=exp_configs.ckpt_file,)
stop = timeit.default_timer()

  rank_zero_warn("You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.")


Python 3.9.13 | packaged by conda-forge | (main, May 27 2022, 16:56:21) 
Type 'copyright', 'credits' or 'license' for more information
IPython 8.4.0 -- An enhanced Interactive Python. Type '?' for help.


check self.strategy.restore_checkpoint_after_setup



In [1]:  exit


Restoring states from the checkpoint path at /home/mila/s/sara.ebrahim-elkafrawy/scratch/ecosystem_project/ckpts/first_epoch_57_bs64_numworker4.ckpt





LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type                | Params
---------------------------------------------------------
0 | transforms_img   | DataAugmentationRGB | 0     
1 | transforms_jit   | DataAugmentationRGB | 0     
2 | transforms_gauss | DataAugmentationRGB | 0     
3 | encoder_q        | Sequential          | 23.5 M
4 | encoder_k        | Sequential          | 23.5 M
5 | heads_q          | ModuleList          | 13.4 M
6 | heads_k          | ModuleList          | 13.4 M
---------------------------------------------------------
36.9 M    Trainable params
36.9 M    Non-trainable params
73.8 M    Total params
147.536   Total estimated model params size (MB)


MisconfigurationException: You restored a checkpoint with current_epoch=58, but you have set Trainer(max_epochs=2).