# Make prediction with trained DeepSDM

## Load packages

In [1]:
import time
import torch
import pytorch_lightning as pl
from types import SimpleNamespace
import mlflow
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from LitDeepSDMData import LitDeepSDMData
from LitUNetSDM import LitUNetSDM
import os
from matplotlib import pyplot as plt
import yaml

## Assign the experiment_id and run_id of mlflow of which model will be predicted

In [2]:
experiment_id = '510006563675756139'
run_id = '7985dc17d09643aea94b0ca98b4e2cc2'
logged_path = f'mlruns/{experiment_id}/{run_id}'

## Model configuration
Some settings do not matter here. Just copy these configs from 02_train_deepsdm.py

In [3]:
# load configurations
yaml_conf_logged = f'{logged_path}/artifacts/conf/DeepSDM_conf.yaml'
with open(yaml_conf_logged, 'r') as f:
    DeepSDM_conf = yaml.load(f, Loader = yaml.FullLoader)
DeepSDM_conf = SimpleNamespace(**DeepSDM_conf)

## Load parameters from the checkpoint of a trained model

Check which top-k models are logged

In [4]:
checkpoint_path = f'{logged_path}/checkpoints'
checkpoints = [f for f in os.listdir(checkpoint_path) if f.endswith('.ckpt') and f != 'last.ckpt']
checkpoints

['epoch=91-step=1564.ckpt',
 'epoch=140-step=2397.ckpt',
 'epoch=79-step=1360.ckpt']

Read the logged average state_dict of top-k models

In [5]:
avg_state_dict = torch.load(f'{logged_path}/artifacts/top_k_avg_state_dict/top_k_avg_state_dict.pt')
model = LitUNetSDM(yaml_conf_logged, predict_attention = False).cuda()
model.load_state_dict(avg_state_dict)

<All keys matched successfully>

## Initialize datamodel

In [6]:
deep_sdm_data = LitDeepSDMData(yaml_conf_logged)

./workspace/cooccurrence_vector_v3.json
./workspace/env_information.json
./workspace/k_information.json
./virtual/v3/v3_realworldvirtual_species_information_reorganize_bias_virtual.json


## Select the species and dates for prediction.
The format of date must be YYYY-MM-01

In [8]:
# The Carpodacus_formosanus is a species unknown to the model in the training process.
# We can still predict its distribution with help of species embeddings.
# We can also assign multiple species and dates for batch predictions. 
predict_dataloaders = deep_sdm_data.predict_dataloader(
    species_list = DeepSDM_conf.training_conf['species_list_predict'], 
    date_list = DeepSDM_conf.training_conf['date_list_predict'])
# predict_dataloaders = deep_sdm_data.predict_dataloader(
#     species_list = DeepSDM_conf.training_conf['species_list_resident'], 
#     date_list = ['2018-01-01', '2018-07-01', '2018-10-01'])

Setting up dataset for prediction...


## Start prediction
The results inlcuding png images and geotiff will be output to the `output_dir`

In [9]:
model.eval()
with torch.no_grad():
    raw_results = model.predict(predict_dataloaders, datamodule = deep_sdm_data, output_dir = f'predicts/{run_id}')

sp21_2018-10-01: 8.094553232192993 seconds ........................................

## Plot prediction results

In [9]:
raw_results.keys()

dict_keys(['Abroscopus_albogularis_2018-01-01', 'Abroscopus_albogularis_2018-07-01', 'Abroscopus_albogularis_2018-10-01', 'Accipiter_trivirgatus_2018-01-01', 'Accipiter_trivirgatus_2018-07-01', 'Accipiter_trivirgatus_2018-10-01', 'Accipiter_virgatus_2018-01-01', 'Accipiter_virgatus_2018-07-01', 'Accipiter_virgatus_2018-10-01', 'Acridotheres_cristatellus_2018-01-01', 'Acridotheres_cristatellus_2018-07-01', 'Acridotheres_cristatellus_2018-10-01', 'Actinodura_morrisoniana_2018-01-01', 'Actinodura_morrisoniana_2018-07-01', 'Actinodura_morrisoniana_2018-10-01', 'Aegithalos_concinnus_2018-01-01', 'Aegithalos_concinnus_2018-07-01', 'Aegithalos_concinnus_2018-10-01', 'Aix_galericulata_2018-01-01', 'Aix_galericulata_2018-07-01', 'Aix_galericulata_2018-10-01', 'Alauda_gulgula_2018-01-01', 'Alauda_gulgula_2018-07-01', 'Alauda_gulgula_2018-10-01', 'Alcedo_atthis_2018-01-01', 'Alcedo_atthis_2018-07-01', 'Alcedo_atthis_2018-10-01', 'Alcippe_morrisonia_2018-01-01', 'Alcippe_morrisonia_2018-07-01', 'A

In [10]:
plt.imshow(
    raw_results['Prinia_flaviventris_2018-09-01'][1],
    cmap='jet',
)
plt.colorbar()

KeyError: 'Prinia_flaviventris_2018-09-01'