In [1]:
# gdown is used below to fetch the dataset and model archives from Google Drive;
# rasterio is presumably needed by the STARCOP / georeader code for raster I/O — TODO pin versions.
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install --quiet rasterio
%pip install --quiet --upgrade gdown

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m21.5/21.5 MB[0m [31m17.8 MB/s[0m eta [36m0:00:00[0m
[?25h

In [1]:
# Clone the repository to the absolute path the rest of the notebook uses (/content/STARCOP).
# The clone is skipped when the folder already exists, so this cell can be re-run safely.
![ -d /content/STARCOP ] || git clone https://github.com/spaceml-org/STARCOP.git /content/STARCOP

Cloning into 'STARCOP'...
remote: Enumerating objects: 154, done.
remote: Counting objects: 100% (154/154), done.
remote: Compressing objects: 100% (123/123), done.
remote: Total 154 (delta 48), reused 119 (delta 25), pack-reused 0
Receiving objects: 100% (154/154), 8.56 MiB | 16.72 MiB/s, done.
Resolving deltas: 100% (48/48), done.


In [17]:

# Download into the repository root so the extracted folders end up where later cells
# expect them (root_folder = /content/STARCOP/STARCOP_mini), independent of execution order.
%cd /content/STARCOP

# Data: STARCOP_mini.zip (~274 MB download).
# URLs are quoted because `?` is a shell glob character.
!gdown "https://drive.google.com/uc?id=1Qw96Drmk2jzBYSED0YPEUyuc2DnBechl" -O STARCOP_mini.zip

# Models:
!gdown "https://drive.google.com/uc?id=1TXFlAHO_eRdfbJGLNNt3KY0lJqjm3fdX" -O multistarcop_varon.zip
!gdown "https://drive.google.com/uc?id=1Kvnc_lOBn4z-xO1HFRyLZOMEldXWQvql" -O hyperstarcop_magic_rgb.zip

Downloading...
From (original): https://drive.google.com/uc?id=1Qw96Drmk2jzBYSED0YPEUyuc2DnBechl
From (redirected): https://drive.google.com/uc?id=1Qw96Drmk2jzBYSED0YPEUyuc2DnBechl&confirm=t&uuid=ee5a02fa-406b-4b44-ad52-f01add1d5bd6
To: /content/STARCOP/STARCOP_mini.zip
100% 274M/274M [00:08<00:00, 32.6MB/s]
Downloading...
From (original): https://drive.google.com/uc?id=1TXFlAHO_eRdfbJGLNNt3KY0lJqjm3fdX
From (redirected): https://drive.google.com/uc?id=1TXFlAHO_eRdfbJGLNNt3KY0lJqjm3fdX&confirm=t&uuid=db3a89ca-f4ee-4317-ba7e-f616b47f470b
To: /content/STARCOP/multistarcop_varon.zip
100% 73.5M/73.5M [00:03<00:00, 23.7MB/s]
Downloading...
From (original): https://drive.google.com/uc?id=1Kvnc_lOBn4z-xO1HFRyLZOMEldXWQvql
From (redirected): https://drive.google.com/uc?id=1Kvnc_lOBn4z-xO1HFRyLZOMEldXWQvql&confirm=t&uuid=59cc4205-0967-48f5-bba7-67e8a0c9f57f
To: /content/STARCOP/hyperstarcop_magic_rgb.zip
100% 73.3M/73.3M [00:03<00:00, 23.0MB/s]


In [18]:
# Extract the dataset and both model folders. -o overwrites existing files instead of
# prompting (an interactive prompt would hang the cell on a re-run).
# Clean-up removes only the three archives downloaded above, and only if every extraction
# succeeded (&&), so a failed unzip never deletes the downloads.
!unzip -q -o STARCOP_mini.zip && unzip -q -o multistarcop_varon.zip && unzip -q -o hyperstarcop_magic_rgb.zip && rm -f STARCOP_mini.zip multistarcop_varon.zip hyperstarcop_magic_rgb.zip

In [19]:
# Sanity check: the extracted STARCOP_mini data and the model folders should be listed here.
!ls

bash			     hyperstarcop_magic_rgb  README.md		       starcop
config.yaml		     _illustrations	     requirements_package.txt  STARCOP_mini
experiments		     LICENSE		     requirements.txt
final_checkpoint_model.ckpt  multistarcop_varon      scripts
hyperstarcop_mag1c_only      notebooks		     setup.py


In [20]:

# Make the repository root the working directory; relative paths used in later cells
# (requirements.txt, scripts/train_colab.py, scripts/configs/config.yaml) assume it.
%cd /content/STARCOP

/content/STARCOP


In [21]:

# Install the repository's requirements into the running kernel's environment.
%pip install --quiet -r requirements.txt
# ignore the warnings about the library versions

In [1]:
# additional libraries and exact versions:
!pip install git+https://github.com/spaceml-org/georeader.git
!pip install torchtext==0.14.1

Collecting git+https://github.com/spaceml-org/georeader.git
  Cloning https://github.com/spaceml-org/georeader.git to /tmp/pip-req-build-u0z2loz5
  Running command git clone --filter=blob:none --quiet https://github.com/spaceml-org/georeader.git /tmp/pip-req-build-u0z2loz5
  Resolved https://github.com/spaceml-org/georeader.git to commit 813d92ebbe79aa79b99c9a4bf05b7894cbc30bf3
  Preparing metadata (setup.py) ... done


In [1]:

# Re-enter the repository root. The execution counter reset to 1 above suggests the runtime
# was restarted after the installs, which resets the working directory. This is harmless
# if we are already there.
%cd /content/STARCOP

/content/STARCOP


In [11]:
%%writefile scripts/train_colab.py
import matplotlib
matplotlib.use('agg')

import os
import hydra
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import seed_everything
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning import Trainer
from starcop.dataset_setup import get_dataset
from starcop.model_setup import get_model
from starcop.data.data_logger import ImageLogger
from hydra.utils import get_original_cwd
import logging
import fsspec
from starcop.validation import  run_validation
from torch.utils.data import DataLoader
import random
import numpy as np
import matplotlib.pyplot as plt

@hydra.main(version_base=None, config_path="configs", config_name="config")
def train(settings : DictConfig) -> None:
    """Train a STARCOP model from a Hydra-composed config (Colab variant of the training script).

    Differences from the main training script:
    - No Weights & Biases logger is created.
    - No image-logging callback is used.
    - Nothing is uploaded to Google Cloud Storage.

    Side effects:
    - Creates ``<cwd>/checkpoint/`` and writes the best checkpoint there (lowest ``val_loss``).
    - Writes ``<cwd>/final_checkpoint_model.ckpt`` after training.
    - Mutates ``settings``: ``experiment_path`` and ``model.train``/``model.test``.

    Args:
        settings: Config composed by Hydra from ``configs/config.yaml`` plus command-line
            overrides.
    """
    log = logging.getLogger(__name__)

    # Local run folder. os.getcwd() is expected to be the run directory Hydra switched
    # into (per the logged output, a timestamped folder under experiments/).
    experiment_path = os.getcwd()
    checkpoint_path = os.path.join(experiment_path, "checkpoint").replace("\\", "/")
    for local_dir in (experiment_path, checkpoint_path):
        if not local_dir.startswith("gs://"):
            os.makedirs(local_dir, exist_ok=True)

    # Mirror the run folder's path, relative to the launch directory, under gs://starcop/.
    folder_relative_name = experiment_path.replace(get_original_cwd(), "")
    if folder_relative_name.startswith("/"):
        folder_relative_name = folder_relative_name[1:]
    if not folder_relative_name.endswith("/"):
        folder_relative_name += "/"
    remote_path = os.path.join("gs://starcop/", folder_relative_name)

    OmegaConf.set_struct(settings, False)
    # settings.experiment_path holds the remote (gs://) location; the local folder is
    # experiment_path. NOTE(review): confirm downstream readers expect the remote path.
    settings["experiment_path"] = remote_path

    log.info(f"trained models will be saved at {experiment_path}")
    log.info(f"Remote experiment path stored in settings.experiment_path: {remote_path} "
             f"(this Colab script does not upload anything there)")

    plt.ioff()

    # LOGGING SETUP: unlike the main training script, no Weights & Biases logger is created.
    log.info("SETTING UP LOGGERS")
    OmegaConf.set_struct(settings, True)

    log.info(f"Settings dump:{OmegaConf.to_yaml(settings)}")
    log.info(f"Using matplotlib backend: {matplotlib.get_backend()}")

    # ======================================================
    # EXPERIMENT SETUP
    # ======================================================
    # The config spells "no fixed seed" as the string "None".
    seed_everything(None if settings.seed == "None" else settings.seed)

    log.info("SETTING UP DATASET")
    data_module = get_dataset(settings)
    data_module.prepare_data()

    log.info("SETTING UP MODEL")
    settings.model.test = False
    settings.model.train = True
    model = get_model(settings, settings.experiment_name)

    log.info("SETTING UP CHECKPOINTING")
    metric_monitor = "val_loss"
    # save_top_k takes an int: keep only the single best checkpoint (lowest val_loss).
    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_path,
        save_top_k=1,
        verbose=True,
        monitor=metric_monitor,
        mode="min",
    )

    # Only checkpointing is active in this variant. Re-enabling the ImageLogger callback would
    # also require fetching plot batches from data_module.train_plot_dataloader /
    # test_plot_dataloader. An EarlyStopping callback on metric_monitor could be appended here
    # if early stopping is wanted.
    callbacks = [checkpoint_callback]

    log.info("START TRAINING")
    # See: https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html
    trainer = Trainer(
        fast_dev_run=False,
        # NOTE(review): how logger=None is interpreted depends on the Lightning version;
        # pass logger=False to disable logging explicitly.
        logger=None,
        callbacks=callbacks,
        default_root_dir=experiment_path,
        accumulate_grad_batches=1,
        gradient_clip_val=0.0,
        benchmark=False,
        accelerator=settings.training.accelerator,
        devices=settings.training.devices,
        max_epochs=settings.training.max_epochs,
        # float in [0.0, 1.0]: validate after that fraction of each training epoch;
        # int: validate after that many training batches.
        val_check_interval=settings.training.val_check_interval,
        log_every_n_steps=settings.training.train_log_every_n_steps,
    )

    trainer.fit(model, data_module)

    # Final weights are saved in the run folder, next to the checkpoint/ folder.
    trainer.save_checkpoint(os.path.join(experiment_path, "final_checkpoint_model.ckpt"))

    print("Training finished!")

# Script entry point: the hydra.main decorator on train() composes the config from
# configs/config.yaml plus command-line overrides (e.g. training.max_epochs=5) and passes it in.
if __name__ == "__main__":
    train()


Overwriting scripts/train_colab.py


In [12]:

import omegaconf
import pylab as plt
import torch
import omegaconf
import fsspec
import os
import json
import pandas as pd
import numpy as np
from starcop.torch_utils import to_device
import starcop.plot as starcoplot
from mpl_toolkits.axes_grid1 import make_axes_locatable

from starcop.data.datamodule import Permian2019DataModule
from starcop.models.model_module import ModelModule
from starcop.validation import run_validation

# Use the first GPU when one is available; fall back to CPU so the notebook still runs
# on a CPU-only runtime instead of failing at the first .to(device).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Google Cloud Storage filesystem handle (requires gcsfs); presumably used by later cells — TODO confirm.
fs = fsspec.filesystem("gs")
# Base config shipped with the repo; the relative path assumes cwd is the repository root.
config_general = omegaconf.OmegaConf.load("scripts/configs/config.yaml")
# Root of the extracted mini dataset; also passed to the training script as dataset.root_folder.
root_folder = "/content/STARCOP/STARCOP_mini"

In [14]:
# Train HyperSTARCOP (mag1c + RGB inputs) on the mini dataset for a short 5-epoch demo run.
# Hydra overrides are kept in a list for readability. Each one is shell-quoted so the
# bracketed lists are passed verbatim instead of being interpreted as shell glob patterns.
import shlex

train_overrides = [
    "dataset.input_products=[mag1c,TOA_AVIRIS_640nm,TOA_AVIRIS_550nm,TOA_AVIRIS_460nm]",
    "model.model_type=unet_semseg",
    "model.pos_weight=1",
    "experiment_name=HyperSTARCOP_magic_rgb_DEMO",
    "dataloader.num_workers=4",
    "dataset.use_weight_loss=True",
    "training.val_check_interval=0.5",
    "training.max_epochs=5",
    "products_plot=[rgb_aviris,mag1c,label,pred,differences]",
    "dataset.weight_sampling=True",
    "dataset.train_csv=train_mini10.csv",
    f"dataset.root_folder={root_folder}",
]
train_args = " ".join(shlex.quote(arg) for arg in train_overrides)
!python -m scripts.train_colab {train_args}


[2024-05-12 04:06:59,529][__main__][INFO] - trained models will be save at /content/STARCOP/experiments/HyperSTARCOP_magic_rgb_DEMO/2024-05-12_04-06
[2024-05-12 04:06:59,530][__main__][INFO] - At the end of training, models will be copied to gs://starcop/experiments/HyperSTARCOP_magic_rgb_DEMO/2024-05-12_04-06/
[2024-05-12 04:06:59,530][__main__][INFO] - SETTING UP LOGGERS
[2024-05-12 04:06:59,534][__main__][INFO] - Settings dump:experiment_name: HyperSTARCOP_magic_rgb_DEMO
seed: None
resume_from_checkpoint: false
wandb:
  wandb_project: your_wandb_project
  wandb_entity: your_wandb_entity
  images_logging: wandb
dataloader:
  batch_size: 32
  num_workers: 4
products_plot:
- rgb_aviris
- mag1c
- label
- pred
- differences
plot_samples: 8
dataset:
  input_products:
  - mag1c
  - TOA_AVIRIS_640nm
  - TOA_AVIRIS_550nm
  - TOA_AVIRIS_460nm
  output_products:
  - labelbinary
  use_weight_loss: true
  weight_loss: weight_mag1c
  training_size:
  - 128
  - 128
  training_size_overlap:
  - 64
