# Train deepBlink

In [None]:
import requests
import subprocess
import yaml

from pathlib import Path

In [None]:
# All generated artifacts live in an 'outputs' folder one level above the
# current working directory.
outputs_path = Path.cwd().parent / 'outputs'

# Dataset locations (read-only here).
datasets_path = outputs_path.joinpath('datasets')
piscis_datasets_path = datasets_path.joinpath('piscis')

# Destination for pretrained and newly trained deepBlink models, plus a
# scratch folder (relative to the working directory) for configs and logs.
deepblink_models_path = outputs_path.joinpath('deepblink_models')
tmp_path = Path('tmp')

# Create the folders this notebook writes into; existing folders are left as-is.
for _folder in (deepblink_models_path, tmp_path):
    _folder.mkdir(parents=True, exist_ok=True)

In [None]:
# IPython shell escape: installs the deepblink package. The `deepblink` command
# run later via subprocess is presumably provided by this package.
!pip install deepblink

### Download pretrained deepBlink models.

In [None]:
# Figshare API endpoint for the article hosting the pretrained deepBlink models.
api_url = 'https://api.figshare.com/v2/articles/12958127'

# Fetch the article metadata. Check the HTTP status first so an API error is
# reported as such, rather than as a KeyError on the missing 'files' entry.
api_response = requests.get(api_url, timeout=60)
api_response.raise_for_status()
files = api_response.json()['files']

# Download every file whose name starts with 'deepblink' into the models folder.
for file in files:
    file_name = file['name']
    if not file_name.startswith('deepblink'):
        continue
    # Using the streamed response as a context manager releases the connection
    # even if the download or the write fails part-way.
    with requests.get(file['download_url'], stream=True, timeout=60) as response:
        response.raise_for_status()
        with open(deepblink_models_path / file_name, 'wb') as handle:
            # 1 MiB chunks: the file contents are identical, with far fewer write calls.
            for block in response.iter_content(chunk_size=1 << 20):
                handle.write(block)

### Train deepBlink with the Piscis dataset.

In [None]:
# Generate a default training config in tmp_path with the deepBlink CLI; it is
# read back below as tmp_path / 'config.yaml'. check=True makes a failing
# command raise, so a stale config.yaml from an earlier run cannot be silently reused.
subprocess.run(['deepblink', 'config'], cwd=tmp_path, check=True)

# Load the generated training config.
with open(tmp_path / 'config.yaml', 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# Train on the combined Piscis dataset, save models into the same folder as the
# downloaded pretrained models, and run for 400 epochs.
config['dataset_args']['name']['value'] = str(piscis_datasets_path / '20230905' / 'combined' / '20230905_combined.npz')
config['savedir']['value'] = str(deepblink_models_path)
config['train_args']['epochs']['value'] = 400

In [None]:
# Grid cell sizes to train with; each size gets its own training run.
cell_sizes = (1, 2, 4)

# Runs that exit non-zero, recorded as (cell_size, exit code, stderr log path).
failed_runs = []

# Train deepBlink for each grid cell size.
for cell_size in cell_sizes:

    # The same config dict is reused; only the per-run fields change.
    config['dataset_args']['cell_size']['value'] = cell_size
    config['run_name']['value'] = f'20230905_cell_size_{cell_size}'

    config_path = tmp_path / f'config_cell_size_{cell_size}.yaml'
    with open(config_path, 'w', encoding='utf-8') as f:
        yaml.safe_dump(config, f)

    stdout_path = tmp_path / f'stdout_cell_size_{cell_size}.txt'
    stderr_path = tmp_path / f'stderr_cell_size_{cell_size}.txt'
    with open(stdout_path, 'w') as stdout_file, open(stderr_path, 'w') as stderr_file:
        result = subprocess.run(
            ['deepblink', 'train', '-c', str(config_path)],
            stdout=stdout_file,
            stderr=stderr_file,
        )

    # Output goes to log files, so a failed run would otherwise be invisible.
    # Record it and continue, so the remaining cell sizes are still attempted.
    if result.returncode != 0:
        failed_runs.append((cell_size, result.returncode, stderr_path))

if failed_runs:
    details = '; '.join(
        f'cell_size={size} (exit code {code}, see {path})'
        for size, code, path in failed_runs
    )
    raise RuntimeError(f'deepBlink training failed for: {details}')