# Make prediction with trained DeepSDM

## Load packages

In [1]:
import torch
import pytorch_lightning as pl
from types import SimpleNamespace
import mlflow
# from pytorch_lightning.strategies import DDPStrategy
# from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from LitDeepSDMData_prediction import LitDeepSDMData
from LitUNetSDM_prediction import LitUNetSDM
import os
from matplotlib import pyplot as plt
import yaml
import torch.multiprocessing as mp

  from .autonotebook import tqdm as notebook_tqdm


## Set the MLflow experiment_id and run_id of the trained model to use for prediction

In [2]:
# MLflow identifiers of the training run whose weights are used below.
experiment_id, run_id = '115656750127464383', 'e52c8ac9a3e24c75ac871f63bbdea060'
# Directory holding that run's logged files: ./mlruns/<experiment_id>/<run_id>
logged_path = os.path.join('./mlruns', experiment_id, run_id)

## Model configuration
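The configuration is read from the YAML file logged with this run, i.e. the configuration used when training with 02_train_deepsdm.py. Some of its settings do not matter for prediction.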

In [3]:
# load configurations
yaml_conf_logged = os.path.join(logged_path, 'artifacts', 'conf', 'DeepSDM_conf.yaml')
with open(yaml_conf_logged, 'r') as f:
    DeepSDM_conf = yaml.load(f, Loader = yaml.FullLoader)
DeepSDM_conf = SimpleNamespace(**DeepSDM_conf)

# For 1 GPU

## Load parameters from the checkpoint of a trained model

Check which top-k models are logged

In [None]:
checkpoint_path = os.path.join(logged_path, 'checkpoints')
# Keep the top-k checkpoint files only; 'last.ckpt' is excluded.
checkpoints = [
    fname
    for fname in os.listdir(checkpoint_path)
    if fname != 'last.ckpt' and fname.endswith('.ckpt')
]
checkpoints

 - Read the logged averaged state_dict of the top-k models
 - When constructing `LitUNetSDM`, set `predict_attention = True` to also write the attention score map (default: False)

In [None]:
# Optionally inspect one of the logged top-k checkpoints (`ckpt` is not used by the prediction below).
# Pick a file from `checkpoints` (listed above) instead of hard-coding a name: the previously
# hard-coded 'epoch=25-step=2938.ckpt' is not among the checkpoints logged for this run.
# map_location='cpu' lets the file load even where the device it was saved from is unavailable.
ckpt = torch.load(os.path.join(checkpoint_path, checkpoints[0]), map_location='cpu')

In [None]:
device = 'cuda:0'
# Averaged weights of the top-k checkpoints, logged as a run artifact.
avg_state_dict = torch.load(
    os.path.join(logged_path, 'artifacts', 'top_k_avg_state_dict', 'top_k_avg_state_dict.pt'),
    map_location=torch.device(device),
)
model = LitUNetSDM(
    custom_device=device,
    yaml_conf=yaml_conf_logged,
    predict_attention=False,  # True also writes the attention score maps
)
model.load_state_dict(avg_state_dict)

## Initialize the data module

In [None]:
# Data module that builds the prediction dataloaders on the chosen device.
deep_sdm_data = LitDeepSDMData(yaml_conf=yaml_conf_logged, device=device)

## Select the species and dates for prediction.
 - Dates must use the format YYYY-MM-01
 - If `species_list` or `date_list` is too large to load at once, split it and run this entire notebook multiple times

In [None]:
# Batch prediction over every species and date listed in the logged config;
# distributions can still be predicted with help of the species embeddings.
# To predict a subset instead, pass explicit lists, e.g.
#     species_list=['Carpodacus_formosanus'],
#     date_list=['2018-01-01', '2018-07-01', '2018-10-01']
predict_dataloaders = deep_sdm_data.predict_dataloader(
    species_list=DeepSDM_conf.training_conf['species_list_predict'],
    date_list=DeepSDM_conf.training_conf['date_list_predict'],
)

## Start prediction
The results, including PNG images and GeoTIFF files, are written to `output_dir`.

In [None]:
model.eval()  # evaluation mode (affects layers such as dropout / batch norm, if present)
with torch.no_grad():  # inference only: no autograd bookkeeping
    raw_results = model.predict(
        predict_dataloaders,
        datamodule=deep_sdm_data,
        output_dir=os.path.join('./predicts', run_id),
    )

# For Multiple GPUs

## Parameters to specify

In [4]:
num_gpu = 4  # number of GPUs to spread the prediction tasks over

# How to split the work: 'species' or 'date'.
# NOTE(review): no cell shown here reads `split_type`; split_tasks always
# distributes the flattened (species, date) pairs.
split_type = 'species'

predict_attention = True  # also write the attention score maps

## Load parameters from the checkpoint of a trained model

Check which top-k models are logged

In [5]:
checkpoint_path = os.path.join(logged_path, 'checkpoints')
# Top-k checkpoint files only; 'last.ckpt' is excluded.
checkpoints = [
    fname
    for fname in os.listdir(checkpoint_path)
    if fname != 'last.ckpt' and fname.endswith('.ckpt')
]
checkpoints

['epoch=22-step=2231.ckpt',
 'epoch=42-step=4171.ckpt',
 'epoch=16-step=1649.ckpt']

## Start prediction
The results, including PNG images and GeoTIFF files, are written to `output_dir`.

In [6]:
def load_state_dict(device, state_dict_path=None, checkpoint_files=None):
    """Load the model weights used for prediction, mapped onto ``device``.

    The pre-averaged top-k state dict logged with the run is used when it
    exists. Otherwise the ``'state_dict'`` entries of the given checkpoint
    files are loaded and averaged element-wise.

    Args:
        device: Device string such as ``'cuda:0'`` or ``'cpu'``; used as
            ``map_location`` for every ``torch.load`` call.
        state_dict_path: Path of the pre-averaged state dict. Defaults to
            ``<logged_path>/artifacts/top_k_avg_state_dict/top_k_avg_state_dict.pt``.
        checkpoint_files: Checkpoint paths to average when ``state_dict_path``
            does not exist. Defaults to the global ``checkpoints`` file names
            joined onto ``checkpoint_path``.

    Returns:
        The state dict (parameter name -> tensor). When averaging, every
        entry is converted to float32.

    Raises:
        FileNotFoundError: If ``state_dict_path`` does not exist and there
            are no checkpoint files to average.
    """
    if state_dict_path is None:
        state_dict_path = os.path.join(logged_path, 'artifacts', 'top_k_avg_state_dict', 'top_k_avg_state_dict.pt')
    map_location = torch.device(device)

    if os.path.exists(state_dict_path):
        avg_state_dict = torch.load(state_dict_path, map_location=map_location)
        print(f'Load state dictionary from {state_dict_path}')
        return avg_state_dict

    if checkpoint_files is None:
        checkpoint_files = [os.path.join(checkpoint_path, ckp) for ckp in checkpoints]
        source_label = checkpoints  # keep the original message (file names only)
    else:
        checkpoint_files = list(checkpoint_files)
        source_label = checkpoint_files
    if not checkpoint_files:
        # Previously this fell through to iterating over None (TypeError).
        raise FileNotFoundError(
            f'{state_dict_path} does not exist and there are no checkpoints to average')

    avg_state_dict = None
    for ckp_file in checkpoint_files:
        state_dict = torch.load(ckp_file, map_location=map_location)['state_dict']
        if avg_state_dict is None:
            # Reuse the first loaded mapping, but accumulate in float32 so that
            # reduced-precision entries (e.g. float16) are not summed at low precision.
            avg_state_dict = state_dict
            for key in avg_state_dict:
                avg_state_dict[key] = avg_state_dict[key].float()
        else:
            for key in state_dict:
                avg_state_dict[key] += state_dict[key].float()
    for key in avg_state_dict:
        avg_state_dict[key] = avg_state_dict[key] / len(checkpoint_files)
    print(f'Load state dictionary from average of {source_label}')
    return avg_state_dict


def split_tasks(species_list, date_list, num_gpus):
    """Split all (species, date) pairs into ``num_gpus`` contiguous chunks.

    Every pair is assigned to exactly one chunk. Chunk sizes differ by at
    most one, with the larger chunks first. The previous ``len // num_gpus``
    slicing silently dropped the remainder (and every task when there were
    fewer tasks than GPUs).

    Args:
        species_list: Species names to predict.
        date_list: Dates to predict for each species.
        num_gpus: Number of chunks to produce (one per GPU); must be >= 1.

    Returns:
        A list of ``num_gpus`` lists of ``(species, date)`` tuples. Some may
        be empty when there are fewer tasks than GPUs.

    Raises:
        ValueError: If ``num_gpus`` is smaller than 1.
    """
    if num_gpus < 1:
        raise ValueError(f'num_gpus must be >= 1, got {num_gpus}')
    all_tasks = [(sp, dt) for sp in species_list for dt in date_list]
    base, extra = divmod(len(all_tasks), num_gpus)
    chunks = []
    start = 0
    for i in range(num_gpus):
        # The first `extra` chunks absorb one leftover task each.
        size = base + (1 if i < extra else 0)
        chunks.append(all_tasks[start:start + size])
        start += size
    return chunks


def run_prediction_on_gpu(gpu_id, tasks, predict_attention=False, device_prefix='cuda'):
    """Predict every (species, date) pair in ``tasks`` on a single device.

    Meant as a multiprocessing target. The weights, model and data module are
    all created inside this function, i.e. inside the worker process.

    Args:
        gpu_id: Device index, appended to ``device_prefix`` (e.g. 'cuda:1').
        tasks: Iterable of ``(species, date)`` tuples to predict.
        predict_attention: Forwarded to ``LitUNetSDM``.
        device_prefix: Device type part of the device string.
    """
    target_device = '{}:{}'.format(device_prefix, gpu_id)
    weights = load_state_dict(target_device)

    net = LitUNetSDM(custom_device=target_device, yaml_conf=yaml_conf_logged, predict_attention=predict_attention)
    net.load_state_dict(weights)
    net.eval()

    datamodule = LitDeepSDMData(device=target_device, yaml_conf=yaml_conf_logged)
    out_dir = os.path.join('./predicts', run_id)

    for species, date in tasks:
        # Optional resume: skip the pair when
        # f'{out_dir}/tif/{species}_{date}_predict.tif' already exists
        # (output layout assumed; confirm against LitUNetSDM.predict).
        with torch.no_grad():
            loader = datamodule.predict_dataloader(species_list=[species], date_list=[date])
            net.predict(loader, datamodule=datamodule, output_dir=out_dir)

In [None]:
if __name__ == "__main__":
    # Use the GPU count from the parameter cell instead of a second hard-coded value.
    num_gpus = num_gpu
    species_list = DeepSDM_conf.training_conf['species_list_predict']
    # Only the first 80 prediction dates are processed in this run; adjust the
    # slice to cover the remaining dates in later runs.
    date_list = DeepSDM_conf.training_conf['date_list_predict'][:80]

    # Split (species, date) tasks across GPUs
    tasks_per_gpu = split_tasks(species_list, date_list, num_gpus)

    # One worker process per GPU that actually has work; an empty task list
    # would otherwise still load the model and data onto that GPU.
    workers = []
    for gpu_id, gpu_tasks in enumerate(tasks_per_gpu):
        if not gpu_tasks:
            continue
        p = mp.Process(target=run_prediction_on_gpu, args=(gpu_id, gpu_tasks, predict_attention))
        p.start()
        workers.append((gpu_id, p))

    for _, p in workers:
        p.join()

    # Surface worker crashes instead of finishing silently with missing outputs.
    failed = [(gpu_id, p.exitcode) for gpu_id, p in workers if p.exitcode != 0]
    if failed:
        raise RuntimeError(f'Prediction worker(s) failed (gpu_id, exitcode): {failed}')

Load state dictionary from ./mlruns/115656750127464383/e52c8ac9a3e24c75ac871f63bbdea060/artifacts/top_k_avg_state_dict/top_k_avg_state_dict.ptLoad state dictionary from ./mlruns/115656750127464383/e52c8ac9a3e24c75ac871f63bbdea060/artifacts/top_k_avg_state_dict/top_k_avg_state_dict.ptLoad state dictionary from ./mlruns/115656750127464383/e52c8ac9a3e24c75ac871f63bbdea060/artifacts/top_k_avg_state_dict/top_k_avg_state_dict.pt


Load state dictionary from ./mlruns/115656750127464383/e52c8ac9a3e24c75ac871f63bbdea060/artifacts/top_k_avg_state_dict/top_k_avg_state_dict.pt
./workspace/cooccurrence_vector.json./workspace/cooccurrence_vector.json

./workspace/cooccurrence_vector.json
./workspace/env_information.json./workspace/env_information.json

./workspace/k_information.json./workspace/k_information.json

./workspace/species_information.json./workspace/species_information.json

./workspace/env_information.json
./workspace/k_information.json./workspace/cooccurrence_vector.json

./workspace/sp

## Plot prediction results

In [None]:
# raw_results.keys()

In [None]:
# plt.imshow(
#     raw_results['Prinia_flaviventris_2018-09-01'][1],
#     cmap='jet',
# )
# plt.colorbar()