# Make predictions with a trained DeepSDM model

## Load packages

In [6]:
import time
import torch
import pytorch_lightning as pl
from types import SimpleNamespace
import mlflow
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from LitDeepSDMData import LitDeepSDMData
from LitUNetSDM import LitUNetSDM
import os
from matplotlib import pyplot as plt
import yaml
import copy
from collections import OrderedDict

## Assign the MLflow experiment_id and run_id of the model to use for prediction

In [14]:
# MLflow identifiers of the trained run whose model is used for prediction.
experiment_id = '923241236702221005'
run_id = 'd227de26b8784f629682acd0d0922682'

# Directory of that run inside the local `mlruns` folder.
logged_path = '/'.join(['mlruns', experiment_id, run_id])

## Model configuration
Some settings do not matter for prediction. The configs are the ones from 02_train_deepsdm.py; the cell below reads them from the `DeepSDM_conf.yaml` artifact logged with the run.

In [15]:
# Read the DeepSDM configuration that was logged as an artifact of the run and
# expose its top-level keys as attributes.
# NOTE(review): yaml.FullLoader is kept unchanged; if the file only contains
# plain YAML types, yaml.safe_load would be the stricter choice -- confirm
# before switching.
with open(f'{logged_path}/artifacts/conf/DeepSDM_conf.yaml', 'r') as f:
    DeepSDM_conf = SimpleNamespace(**yaml.load(f, Loader=yaml.FullLoader))

# Bundle the sorted environment, species and date lists into one object.
# NOTE(review): the `*_val` entries are filled from the `*_train` config lists
# rather than from separate validation lists -- confirm this is intended here.
info = SimpleNamespace(
    env_list=sorted(DeepSDM_conf.env_list),
    non_normalize_env_list=sorted(DeepSDM_conf.non_normalize_env_list),
    species_list=sorted(DeepSDM_conf.species_list_train),
    species_list_val=sorted(DeepSDM_conf.species_list_train),
    species_list_smoothviz=sorted(DeepSDM_conf.species_list_smoothviz),
    date_list=sorted(DeepSDM_conf.date_list_train),
    date_list_val=sorted(DeepSDM_conf.date_list_train),
    date_list_smoothviz=sorted(DeepSDM_conf.date_list_smoothviz),
)

# The nested `conf` mapping of the configuration, exposed as attributes too.
conf = SimpleNamespace(**DeepSDM_conf.conf)

## Load parameters from the checkpoint of a trained model

Check which top-k models are logged

In [16]:
# Folder inside the run directory that holds the saved checkpoints.
checkpoint_path = f'{logged_path}/checkpoints'

# Keep every `.ckpt` file except `last.ckpt`.
checkpoints = [
    name
    for name in os.listdir(checkpoint_path)
    if name != 'last.ckpt' and name.endswith('.ckpt')
]
checkpoints

['epoch=27-step=2072.ckpt',
 'epoch=21-step=1628.ckpt',
 'epoch=25-step=1924.ckpt']

Read the logged average state_dict of top-k models

In [17]:
# Load the averaged state_dict that was logged with the run.
# map_location='cpu' makes deserialization independent of the device the
# tensors were saved on; load_state_dict() below copies the values into the
# parameters of the model, which already lives on the GPU.
avg_state_dict = torch.load(
    f'{logged_path}/artifacts/avg_state_dict/avg_state_dict.pt',
    map_location='cpu',
)

# Build the model on the GPU and fill it with the averaged weights.
# The last expression displays the key-matching report of load_state_dict().
model = LitUNetSDM(info=info, conf=conf).cuda()
model.load_state_dict(avg_state_dict)

## Initialize the datamodule

In [12]:
# Build the DeepSDM datamodule from the `info` and `conf` objects.
deep_sdm_data = LitDeepSDMData(info=info, conf=conf)

./workspace/env_information.json
./workspace/species_information.json
./workspace/k_information.json
./workspace/cooccurrence_vector.json


## Select the species and dates for prediction.
The format of date must be YYYY-MM-01

In [18]:
# Carpodacus_formosanus is a species the model did not see during training;
# its distribution can still be predicted with the help of species embeddings.
# Several species and dates can be given at once for batch prediction.
predict_dataloaders = deep_sdm_data.predict_dataloader(
    species_list=DeepSDM_conf.species_list_predict,
    date_list=DeepSDM_conf.date_list_predict,
)

Setting up dataset for prediction...


## Start prediction
The results, including PNG images and GeoTIFF files, will be written to the `output_dir`.

In [20]:
# Inference only: switch the model to evaluation mode and disable gradient
# tracking while predicting. Output goes to a folder named after the run.
model.eval()
with torch.no_grad():
    raw_results = model.predict(
        predict_dataloaders,
        datamodule=deep_sdm_data,
        output_dir=f'predicts/{run_id}',
    )

Prinia_flaviventris_2018-12-01: 8.210442543029785 seconds .........................

In [None]:
# Show which entries are available in the prediction results.
raw_results.keys()

In [None]:
# Display one entry of the prediction results as an image with a colour scale.
# NOTE(review): index [1] picks the second element stored under this key; what
# that element represents is not visible here -- confirm against model.predict.
selected_map = raw_results['Prinia_flaviventris_2018-09-01'][1]
plt.imshow(selected_map, cmap='jet')
plt.colorbar()