# Make predictions with a trained DeepSDM model

## Load packages

In [1]:
import time
import torch
import pytorch_lightning as pl
from types import SimpleNamespace
import mlflow
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from LitDeepSDMData import LitDeepSDMData
from LitUNetSDM import LitUNetSDM
import os
from matplotlib import pyplot as plt
import yaml
import copy
from collections import OrderedDict

  from .autonotebook import tqdm as notebook_tqdm


## Assign the MLflow experiment_id and run_id of the model to be used for prediction

In [2]:
# MLflow identifiers of the run whose trained model is used for prediction.
experiment_id = '402040678262354481'
run_id = 'de206f42cb374a16815b52c5b1b07feb'
# Directory of that run, relative to the current working directory.
logged_path = '/'.join(['mlruns', experiment_id, run_id])

## Model configuration
Some settings do not matter here. Just copy these configs from 02_train_deepsdm.py

In [3]:
# Read the configuration YAML stored under the logged run's artifacts.
yaml_conf_logged = f'{logged_path}/artifacts/conf/DeepSDM_conf.yaml'
# NOTE(review): FullLoader is used here; if the config holds only plain YAML
# types, yaml.safe_load would be the stricter choice -- confirm before switching.
with open(yaml_conf_logged, 'r') as conf_file:
    conf_dict = yaml.load(conf_file, Loader=yaml.FullLoader)
# Expose the top-level config keys as attributes (e.g. DeepSDM_conf.training_conf).
DeepSDM_conf = SimpleNamespace(**conf_dict)

## Load parameters from the checkpoint of a trained model

Check which top-k models are logged

In [4]:
# Collect the checkpoint files saved for this run, leaving out 'last.ckpt'
# so that only the remaining '.ckpt' files are listed.
checkpoint_path = f'{logged_path}/checkpoints'
checkpoints = [
    ckpt_name
    for ckpt_name in os.listdir(checkpoint_path)
    if ckpt_name != 'last.ckpt' and ckpt_name.endswith('.ckpt')
]
checkpoints

['epoch=0-step=1.ckpt', 'epoch=1-step=2.ckpt', 'epoch=2-step=3.ckpt']

Read the logged averaged state_dict of the top-k models

In [6]:
# Load the averaged top-k weights onto the CPU first. Without map_location,
# torch.load tries to restore tensors to the device they were saved from,
# which fails when that device does not exist on this machine.
state_dict_path = f'{logged_path}/artifacts/top_k_avg_state_dict/top_k_avg_state_dict.pt'
avg_state_dict = torch.load(state_dict_path, map_location='cpu')

# Build the model from the logged config and move it to the GPU;
# load_state_dict copies the CPU tensors into the model's parameters.
model = LitUNetSDM(yaml_conf_logged).cuda()
model.load_state_dict(avg_state_dict)

<All keys matched successfully>

## Initialize the data module

In [8]:
# Build the data module from the logged YAML config path (the same path the model is constructed from).
deep_sdm_data = LitDeepSDMData(yaml_conf_logged)

./workspace/env_information.json
./workspace/species_information.json
./workspace/k_information.json
./workspace/cooccurrence_vector.json


## Select the species and dates for prediction.
Dates must be given in the format YYYY-MM-01.

In [9]:
# Carpodacus_formosanus is a species the model did not see during training;
# its distribution can still be predicted with the help of species embeddings.
# Several species and dates can be listed in the config for batch prediction.
species_to_predict = DeepSDM_conf.training_conf['species_list_predict']
dates_to_predict = DeepSDM_conf.training_conf['date_list_predict']
predict_dataloaders = deep_sdm_data.predict_dataloader(
    species_list=species_to_predict,
    date_list=dates_to_predict,
)

Setting up dataset for prediction...


## Start prediction
The results, including PNG images and GeoTIFF files, will be written to the `output_dir`.

In [10]:
# Predictions for this run are written under a directory named after the run id.
prediction_output_dir = f'predicts/{run_id}'

# Put the model in evaluation mode and disable gradient tracking while predicting.
model.eval()
with torch.no_grad():
    raw_results = model.predict(
        predict_dataloaders,
        datamodule=deep_sdm_data,
        output_dir=prediction_output_dir,
    )

Carpodacus_formosanus_2018-01-01: 7.589578628540039 seconds ..................

In [11]:
# Show which entries model.predict returned; these keys are what the plotting cell indexes into.
raw_results.keys()

dict_keys(['Passer_cinnamomeus_2018-01-01', 'Carpodacus_formosanus_2018-01-01'])

In [None]:
# Plot one of the predictions that is actually present in raw_results.
# A hard-coded key that was not part of this prediction run raises KeyError,
# so take the first available key; assign another key from
# raw_results.keys() to plot_key to view a different result.
plot_key = next(iter(raw_results))

# NOTE(review): index 1 selects the second element of the stored result; what
# that element represents is not visible here -- confirm against LitUNetSDM.predict.
plt.imshow(
    raw_results[plot_key][1],
    cmap='jet',
)
plt.title(plot_key)
plt.colorbar()