In [None]:
# check what's in the current directory
# (quick sanity check of the working directory contents; output is only displayed, not captured)
!ls

## Set up pre-trained weights
With internet enabled, PyTorch checks its cache for pre-trained weights as needed and automatically downloads the required file if it is not found. To run without internet, the pre-trained checkpoint file must already exist locally, so I've uploaded a pre-trained ResNet-50 backbone as part of my private dataset. We can copy the file into the checkpoint directory under the default torch home location, `~/.cache/torch/hub/checkpoints`, and PyTorch will find it there instead of needing to download it.

In [None]:
import os
from pathlib import Path
import shutil

# Kaggle input datasets are read-only, so the pre-trained weights are copied
# from the attached dataset into the writable torch hub cache directory.
readonly_checkpoint_dir = Path('/kaggle/input/cmdwheatdet/torch_cache/hub/checkpoints')
readwrite_checkpoint_dir = Path('/root/.cache/torch/hub/checkpoints')

# `exist_ok=True` keeps this cell safe to re-run and avoids the
# check-then-create race of testing `is_dir()` first.
readwrite_checkpoint_dir.mkdir(parents=True, exist_ok=True)

pretrained_checkpoints = sorted(readonly_checkpoint_dir.glob('*.pth'))
if not pretrained_checkpoints:
    # Fail with an actionable message instead of a bare IndexError.
    raise FileNotFoundError(
        f'No pre-trained .pth checkpoint found in {readonly_checkpoint_dir}; '
        'is the dataset attached to this notebook?'
    )

# Only the first checkpoint (in sorted order) is copied.
checkpoint_path = pretrained_checkpoints[0]
shutil.copy(checkpoint_path, readwrite_checkpoint_dir)

## Install dependencies
I'm installing my little Python package that I put together for this Kaggle wheat detection challenge. The source code is here: https://github.com/sheromon/wheat-detection.

In [None]:
# I appreciate that when you add a Kaggle dataset, it unzips and untars everything, but it's actually inconvenient
# for me in this situation where `pip download` give you a tar.gz file and not a wheel.
# seems like `pip install --find-links <directory-path>` doesn't work when the install file, normally .tar.gz, is unzipped.
# anyway, this feels silly, but we can fix it by tarring up the untarred files for configobj.
!mkdir /kaggle/working/deps/
!cd /kaggle/input/cmdwheatdet/deps/deps/configobj-5.0.6/ && tar -czf /kaggle/working/deps/configobj-5.0.6.tar.gz configobj-5.0.6/

In [None]:
!pip install --no-index --find-links /kaggle/input/cmdwheatdet/deps/deps --find-links /kaggle/working/deps/ cmd-wheat-det

In [None]:
# confirm that PyTorch Lightning is importable and display the installed version
import pytorch_lightning as pl
pl.__version__

In [None]:
import os
from pathlib import Path
import pprint

import numpy as np
import torch

from wheat.config import load_config
from wheat.scripts import train, evaluate, predict

In [None]:
# load the default configuration file
# (`load_config` is called without arguments, so it uses whatever default it defines)
config = load_config()

In [None]:
# override selected entries of the loaded configuration for this Kaggle run
# NOTE(review): presumably `numpy_seed` seeds NumPy's RNG inside the wheat package -- confirm
config['numpy_seed'] = 1234
# location of the input data; the path suggests the Kaggle global-wheat-detection dataset
config['data_dir'] = '/kaggle/input/global-wheat-detection'

In [None]:
# show the effective configuration (defaults plus the overrides) in the cell output
pp = pprint.PrettyPrinter(indent=2)
print(pp.pformat(config))

In [None]:
!nvidia-smi

In [None]:
# PyTorch Lightning flags for this run; edit the values here
pl_args_dict = {
    'max_epochs': 30,
    'gpus': 1,
}

In [None]:
# run training!
# `config` holds the experiment settings and `pl_args_dict` the PyTorch Lightning flags set earlier in this notebook
train.train(config, pl_args_dict)

In [None]:
# find the last model checkpoint
# Search every `version_*` run directory rather than hard-coding `version_0`:
# Lightning's default logger starts a new `version_N` directory per run, so a
# fixed path would silently pick up a stale checkpoint after re-running training.
checkpoint_files = sorted(
    Path('lightning_logs').glob('version_*/checkpoints/*.ckpt'),
    key=lambda ckpt: ckpt.stat().st_mtime,
)
if not checkpoint_files:
    # Fail with an actionable message instead of a bare StopIteration.
    raise FileNotFoundError(
        'No .ckpt files found under lightning_logs/version_*/checkpoints; '
        'did training finish and save a checkpoint?'
    )
# the most recently written checkpoint wins if more than one is present
checkpoint_path = checkpoint_files[-1]
# keep `checkpoint_dir` defined: the evaluation cell derives its output directory from it
checkpoint_dir = checkpoint_path.parent
checkpoint_path

In [None]:
# run evaluation
# When CMD_WHEAT_OUTPUT_DIR is set, detections and ground truth annotations
# are saved to .csv files so they can be loaded and analyzed later.
# Here the outputs go next to the checkpoints, in the run's log directory.
eval_output_dir = checkpoint_dir.parent
os.environ['CMD_WHEAT_OUTPUT_DIR'] = str(eval_output_dir)

eval_pl_args = {'gpus': 1}
evaluate.evaluate(config, eval_pl_args, checkpoint_path)

In [None]:
# run inference
# uses the same `pl_args_dict` flags as training and the checkpoint located earlier in this notebook
predict.predict(config, pl_args_dict, checkpoint_path)