# Install required packages
- Torch 2.1.2 with CUDA 12.1
- Torchvision 0.16.2 with CUDA 12.1
- MMCV 2.1.0 with CUDA 12.1
- MMDetection 3.0.3
- TIMM

# Download weights file
- Vision Transformer huge (630M) / 16

In [None]:
# Install the pinned torch / torchvision builds (+cu121) from the PyTorch cu121 wheel index.
!pip install torch==2.1.2+cu121 torchvision==0.16.2+cu121 --index-url https://download.pytorch.org/whl/cu121
# Install MMCV 2.1.0 using the OpenMMLab wheel index for cu121 / torch2.1.
!pip install mmcv==2.1.0 -f https://download.openmmlab.com/mmcv/dist/cu121/torch2.1/index.html
# Install/upgrade OpenMIM, then use it to install MMDetection.
# NOTE(review): no mmdet version is pinned here, so the installed release may differ between runs.
!pip install -U openmim
!mim install mmdet
!pip install timm

# Download the SSLhuge_satellite.pth checkpoint into the current working directory.
!wget s3.amazonaws.com/dataforgood-fb-data/forests/v1/models/saved_checkpoints/SSLhuge_satellite.pth

In [None]:
# Sanity-check the environment: import each package and print its version.
# An ImportError in this cell means the install cell did not complete.
import shutil
import os

import torch
# Prints the torch version and whether a CUDA device is usable from torch.
print(torch.__version__, torch.cuda.is_available())

# Check MMDetection installation
import mmdet
print(mmdet.__version__)

import mmcv
print('mmcv', mmcv.__version__)

import mmengine
print(mmengine.__version__)

# Check mmcv installation: importing mmcv.ops and printing the CUDA and
# compiler versions it reports.
from mmcv.ops import get_compiling_cuda_version, get_compiler_version
print(get_compiling_cuda_version())
print(get_compiler_version())

# Download and unzip the GitHub repository

In [None]:
# Fetch a zip archive of the repository's main branch into the current directory.
!wget https://github.com/wri/mmdetection-satellite-dinov2/archive/refs/heads/main.zip
# Extract it; -o overwrites existing files without prompting, so the cell can be re-run.
!unzip -o main.zip

# Mount Google Drive, copy the training data, and move the backbone weights

In [None]:
# Mount google drive

import os
import shutil
from google.colab import drive
drive.mount('/content/drive')
# Extract the dataset archive stored on Drive into the repository's data/coco/ directory.
# The paths below are relative to the current working directory; since Drive is mounted at
# /content/drive, this presumably expects the notebook to be running from /content.
!unzip drive/MyDrive/coco/tree-may-3.zip -d mmdetection-satellite-dinov2-main/data/coco/
# Move the downloaded backbone checkpoint into the repository's models/ directory.
shutil.move("SSLhuge_satellite.pth", "mmdetection-satellite-dinov2-main/models/SSLhuge_satellite.pth")

# Update the COCO JSON files as necessary (TODO: remove)

In [None]:
## Clean the training annotations: read train2.json, filter/normalise it, write train.json.
##   * collapse everything to a single 'tree' category with id 0
##   * drop images that are missing from the train/ directory or listed in bad_labels
##   * optionally randomly sub-sample the NEON imagery
##   * clamp negative bbox x/y origins to 0
import json
import os
from pprint import pprint
import numpy as np

data_path = 'mmdetection-satellite-dinov2-main/data/coco/tree-may-3/'
with open(f'{data_path}/train2.json') as f:
    d = json.load(f)

# Single-class dataset: the category list is replaced wholesale.
d['categories'] = [{'id': 0, 'name': 'tree'}]

# File names actually present in the train directory (a set, for O(1) membership tests).
exists = set(os.listdir('mmdetection-satellite-dinov2-main/data/coco/tree-may-3/train/'))

# Ids of images to drop from the output. A set, so that repeated additions are
# harmless and the final filter is not quadratic in the number of images.
does_not_exist = set()

# Images excluded by file name. Every entry must be followed by a comma: adjacent
# string literals are silently concatenated, which previously merged
# 'palm_springs_2018_60.tif' and 'palm_springs_2018_61.tif' into one bogus name
# so that neither image was excluded.
bad_labels = {
    'bishop_2020_7.tif', 'riverside_2020_58.tif', 'claremont_2016_33.tif',
    'chico_2020_36.tif', 'TREE_296500_5039769.tif',
    'b47cb6e0-arcos-11-14336_1024.png', 'a081c588-intl-intl-13824_14336.png',
    '611eb561-intl-intl-13824_14848.png',
    '34785f4c-Lari-basemap30-1536_10240.png', '057d64f7-Lari-basemap30-8192_12288.png',
    'a47d2bdb-Lari-basemap27-5632_0.png', '38a007cb-arcos-11-13312_0.png',
    'chico_2020_26.tif',
    'long_beach_2016_31.tif', 'long_beach_2016_42.tif', 'long_beach_2016_43.tif',
    'long_beach_2016_59.tif', 'long_beach_2016_79.tif', 'long_beach_2018_20.tif',
    'long_beach_2018_21.tif', 'long_beach_2018_41.tif', 'long_beach_2018_57.tif',
    'claremont_2016_3.tif',
    'palm_springs_2016_13.tif',
    'palm_springs_2016_28.tif', 'palm_springs_2016_34.tif', 'palm_springs_2016_35.tif',
    'palm_springs_2016_60.tif', 'palm_springs_2016_68.tif', 'palm_springs_2018_25.tif',
    'palm_springs_2018_28.tif', 'palm_springs_2018_60.tif', 'palm_springs_2018_61.tif',
    'palm_springs_2018_78.tif', 'palm_springs_2018_89.tif', 'palm_springs_2018_93.tif',
    'palm_springs_2020_19.tif', 'palm_springs_2020_56.tif',
    'palm_springs_2020_61.tif', 'palm_springs_2020_8.tif', 'palm_springs_2020_98.tif',
}

# Substrings used to recognise image sources from their file names.
all_neon = ['BART', 'GRSM', 'TALL', 'WREF', 'LENO', 'RMNP', 'TREE', 'CLBJ', 'OSBS', 'DEJU', 'NIWO']

# Sources subject to random sub-sampling. 'WREF' used to be listed twice, which
# made every WREF image get sampled (and counted) twice.
# Not sub-sampled (disabled in the original list): 'RMNP', 'TREE', 'CLBJ', 'OSBS', 'NIWO'
neon = ['BART', 'GRSM', 'TALL', 'WREF', 'LENO']
naips = ['claremont', 'long_beach', 'chico', 'santa_monica', 'riverside', 'palm_springs']  # 'eureka'

downsample_pretrain = True
downsample_train = False
pretrain = False

# Hoisted out of the loop: the sources an image must match when pretrain is set.
pretrain_sources = all_neon + naips

num_removed = 0
num_kept = 0
for image in d['images']:
    file_name = image['file_name']

    # With pretrain set, keep only images whose name matches a known source.
    if pretrain and not any(source in file_name for source in pretrain_sources):
        does_not_exist.add(image['id'])

    if file_name not in exists:
        does_not_exist.add(image['id'])
    if file_name in bad_labels:
        print(f"Removing {file_name}")
        does_not_exist.add(image['id'])

    # Each matching image is sampled at most once per enabled mode.
    if any(site in file_name for site in neon):
        if downsample_pretrain:
            # Drops an image when the draw from {0, 1, 2, 3} is 3, i.e. about 1/4
            # of the matching images. (The original note said "remove 1/2 of the
            # NEON data ... 428 neon images"; the threshold is unchanged here --
            # adjust it if 1/2 was really intended.)
            if np.random.choice(4) >= 3:
                print(f"Removing {file_name}")
                does_not_exist.add(image['id'])
                num_removed += 1
            else:
                num_kept += 1
        if downsample_train:
            # Drops an image when the draw is 1, 2 or 3, i.e. about 3/4 of the
            # matching images. (The original note said "remove 2/3 ... 88 left for
            # the main training"; threshold unchanged.)
            if np.random.choice(4) >= 1:
                does_not_exist.add(image['id'])

for annot in d['annotations']:
    annot['category_id'] = 0
    # Clamp negative box origins to the image edge.
    # NOTE(review): if bbox is COCO-style [x, y, width, height], this moves the
    # box instead of cropping it, because width/height are left untouched -- confirm.
    if annot['bbox'][0] < 0:
        annot['bbox'][0] = 0.
    if annot['bbox'][1] < 0:
        annot['bbox'][1] = 0.

# Only the image list is filtered; annotations of dropped images stay in the file.
d['images'] = [x for x in d['images'] if x['id'] not in does_not_exist]
print(len(d['images']))

with open('mmdetection-satellite-dinov2-main/data/coco/tree-may-3/train.json', 'w') as f:
    json.dump(d, f)

# Normalise the validation annotations: read val2.json, write val.json.
with open(f'{data_path}/val2.json') as f:
    d = json.load(f)

# Single-class dataset: the category list is replaced wholesale and every
# annotation is pointed at category id 0.
d['categories'] = [{'id': 0, 'name': 'tree'}]
for annot in d['annotations']:
    annot['category_id'] = 0

# Bounding boxes and the image list are written unchanged. The earlier loop that
# unpacked bbox[0] / bbox[1] here never used the values, so it has been dropped.
# NOTE(review): unlike the training split, negative bbox origins are not clamped
# for validation -- confirm that this is intended.
with open('mmdetection-satellite-dinov2-main/data/coco/tree-may-3/val.json', 'w') as f:
    json.dump(d, f)

# Copy edited files into the installed mmdet package

In [None]:
import mmdet
import shutil

# Directory that holds the installed mmdet package, with a trailing slash.
mmdet_location = "/".join(mmdet.__file__.split("/")[:-1]) + "/"

# Root of the mmdet source tree inside the unzipped repository.
repo_mmdet_root = "/content/mmdetection-satellite-dinov2-main/mmdet/"

# Files copied over the installed package, given relative to the mmdet package
# root and listed in the order they are copied.
overlay_files = (
    # Backbone files
    "models/backbones/SSLVisionTransformer.py",
    "models/backbones/vit_rvsa_mtp_branches.py",
    "models/backbones/__init__.py",
    # Neck files
    "models/necks/__init__.py",
    "models/necks/fpn.py",
    # Head files
    "models/dense_heads/__init__.py",
    "models/dense_heads/crpn_head.py",
    "models/dense_heads/cascade_rpn_head.py",
    # Assigners
    "models/task_modules/assigners/__init__.py",
    "models/task_modules/assigners/dynamic_assigner.py",
    "models/task_modules/assigners/ranking_assigner.py",
    "models/task_modules/assigners/iou2d_calculator.py",
    "models/task_modules/assigners/hierarchical_assigner.py",
    # Transformer layers
    "models/layers/transformer/dino_layers.py",
)

# shutil.copy replaces the destination file if it already exists.
for relative_path in overlay_files:
    shutil.copy(repo_mmdet_root + relative_path,
                f'{mmdet_location}{relative_path}')

In [None]:
# Load the detector config from the repository, override a few fields, and train.
import sys
import os.path as osp
# Put the unzipped repository on sys.path; the relative path is resolved
# against the current working directory.
sys.path.append(osp.abspath('mmdetection-satellite-dinov2-main/'))

from mmdet.apis import init_detector
from mmengine.runner import Runner
from mmengine.config import Config, DictAction

# Config file from the repository, and the backbone checkpoint location inside
# the repository's models/ directory.
config='mmdetection-satellite-dinov2-main/projects/ViTDet/configs/vitdet-codetr.py'
checkpoint = 'mmdetection-satellite-dinov2-main/models/SSLhuge_satellite.pth'
cfg = Config.fromfile(config)
# Point the backbone's init_cfg at the checkpoint above.
cfg['model']['backbone']['init_cfg']['checkpoint'] = checkpoint
# NOTE(review): presumably makes the runner resume from an existing checkpoint
# in work_dir when one is present -- confirm against the MMEngine documentation.
cfg['resume'] = True
cfg['train_cfg']['max_epochs'] = 80
# work_dir is ./work_dirs/<config file stem>-codetr; for this config the stem
# already ends in "-codetr", giving ./work_dirs/vitdet-codetr-codetr.
cfg.work_dir = osp.join('./work_dirs', osp.splitext(osp.basename(config))[0] + '-codetr')

# Build the runner from the edited config and start training.
runner = Runner.from_cfg(cfg)
runner.train()