<a href="https://colab.research.google.com/github/xiazeyu/PyTorchFire/blob/jupyter-examples/examples/calibration.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# install pytorchfire and dependencies

%pip install 'pytorchfire[examples]'
%pip install requests
%pip install matplotlib

In [None]:
# download the dataset

# which fire event to work with; selects the `<map_name>.npz` archive fetched below
map_name = "Bear_2020"

import requests

def download_file(url, filename, timeout=60, chunk_size=8192):
    """Stream the resource at ``url`` to the local file ``filename``.

    Args:
        url: HTTP(S) URL to fetch.
        filename: Destination path; an existing file is overwritten.
        timeout: Seconds ``requests`` waits for the connection / each read
            before giving up, so a stalled server cannot hang the cell forever.
        chunk_size: Number of bytes requested per streamed chunk.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    # `with` guarantees the streamed connection is released even if writing
    # to disk fails part-way through.
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()

        with open(filename, 'wb') as file:
            for chunk in response.iter_content(chunk_size=chunk_size):
                if chunk:  # skip empty keep-alive chunks
                    file.write(chunk)

# local archive name doubles as the remote file name on the examples branch
filename = f'{map_name}.npz'
file_url = f'https://github.com/xiazeyu/PyTorchFire/raw/refs/heads/jupyter-examples/examples/{filename}'

download_file(file_url, filename)
print(f"Dataset downloaded and saved as {filename}")


In [None]:
# load the dataset
import numpy as np

# Reuse `filename` from the download cell, so changing `map_name` switches datasets.
# The arrays are read once into a plain dict and the archive is closed. An open
# NpzFile would re-read (and decompress) an array from the zip on every `ds[key]`,
# which the training loop does repeatedly.
with np.load(filename) as npz_archive:
    ds = dict(npz_archive)

print('Dataset keys:', list(ds))

In [None]:
# define Trainer class (~DataLoader)

from pytorchfire import WildfireModel, BaseTrainer
from tqdm import tqdm
import torch

class DemoTrainer(BaseTrainer):
    """Trainer that drives a ``WildfireModel`` through the time series stored in ``ds``.

    ``ds`` is the dataset mapping loaded earlier in the notebook. The keys read
    here are ``wind_step_interval``, ``wind_towards_direction``,
    ``wind_velocity`` and, for training, ``target``.
    """

    def _apply_wind(self, ds, step, wind_step_interval):
        """Load the wind frame for ``step`` into the model when a new frame begins.

        Wind is stored at a coarser time resolution than the simulation. Frame
        ``k`` is loaded at step ``k * wind_step_interval`` and stays in effect
        until the next multiple of ``wind_step_interval``.
        """
        if step % wind_step_interval != 0:
            return
        frame = step // wind_step_interval
        self.model.wind_towards_direction = torch.tensor(
            ds['wind_towards_direction'][frame], device=self.device)
        self.model.wind_velocity = torch.tensor(
            ds['wind_velocity'][frame], device=self.device)

    def train(self, ds):
        """Calibrate the model parameters against ``ds['target']``.

        Each epoch replays the simulation from its initial state over growing
        horizons of ``steps_update_interval``, ``2 * steps_update_interval``,
        and so on, capped at ``max_steps``. After each horizon,
        ``model.accumulator`` is scored against the target frame at the
        horizon's last step with ``self.criterion``. The loss is then handed to
        ``self.backward``.
        """
        wind_step_interval = int(ds['wind_step_interval'])

        self.reset()
        self.model.to(self.device)
        self.model.train()

        # Ceiling division: when max_steps is not a multiple of
        # steps_update_interval, the last (partial) horizon is still trained on,
        # and max_steps < steps_update_interval still gives one iteration instead
        # of zero. The min(...) below clamps that final horizon to max_steps.
        max_iterations = (self.max_steps + self.steps_update_interval - 1) // self.steps_update_interval

        postfix = {}
        with tqdm() as progress_bar:
            for epochs in range(self.max_epochs):
                postfix['epoch'] = f'{epochs + 1}/{self.max_epochs}'
                self.model.reset()
                # Remembered so each horizon of this epoch can restart from the
                # same seed (presumably the same random draws).
                batch_seed = self.model.seed

                for iterations in range(max_iterations):
                    postfix['iteration'] = f'{iterations + 1}/{max_iterations}'
                    iter_max_steps = min(self.max_steps, (iterations + 1) * self.steps_update_interval)
                    progress_bar.reset(total=iter_max_steps)

                    for steps in range(iter_max_steps):
                        postfix['step'] = f'{steps + 1}/{iter_max_steps}'
                        self._apply_wind(ds, steps, wind_step_interval)
                        self.model.compute(attach=self.check_if_attach(steps, iter_max_steps))

                        progress_bar.set_postfix(postfix)
                        progress_bar.update(1)

                    # compare against the observed frame at the end of this horizon
                    outputs = self.model.accumulator
                    targets = torch.tensor(ds['target'][iter_max_steps - 1], device=self.device)

                    loss = self.criterion(outputs, targets)
                    postfix['loss'] = f'{loss.item():.4f}'

                    self.backward(loss)
                    # rewind to the epoch's starting state before the next, longer horizon
                    self.model.reset(seed=batch_seed)

    def evaluate(self, ds):
        """Run one gradient-free simulation of ``self.max_steps`` steps.

        Returns:
            list of numpy arrays, one per step. Each is
            ``model.state[0] | model.state[1]`` (shown in the progress bar as
            "burning" and "burned") after that step.
        """
        wind_step_interval = int(ds['wind_step_interval'])

        self.reset()
        self.model.to(self.device)
        self.model.eval()

        postfix = {}
        output_list = []

        with tqdm(total=self.max_steps) as progress_bar, torch.no_grad():
            for steps in range(self.max_steps):
                postfix['steps'] = f'{steps + 1}/{self.max_steps}'

                # uses self.device (not a notebook-global) so the class works on its own
                self._apply_wind(ds, steps, wind_step_interval)

                self.model.compute()
                outputs = self.model.state[0] | self.model.state[1]

                postfix['burning'] = self.model.state[0].sum().detach().cpu().item()
                postfix['burned'] = self.model.state[1].sum().detach().cpu().item()

                output_list.append(outputs.cpu().detach().numpy())

                progress_bar.set_postfix(postfix)
                progress_bar.update(1)

        return output_list

In [None]:
# define trainer

from pytorchfire import WildfireModel, BaseTrainer
from tqdm import tqdm
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Simulation inputs taken straight from the dataset. Only the first wind frame is
# supplied here; DemoTrainer loads later frames while stepping.
environment_data = {
    'p_veg': torch.tensor(ds['p_veg']),
    'p_den': torch.tensor(ds['p_den']),
    'wind_towards_direction': torch.tensor(ds['wind_towards_direction'][0]),
    'wind_velocity': torch.tensor(ds['wind_velocity'][0]),
    'slope': torch.tensor(ds['slope']),
    'initial_ignition': torch.tensor(ds['initial_ignition'], dtype=torch.bool),
}

# starting parameter values, i.e. the state of the model before calibration
parameter_data = {
    'a': torch.tensor(0.0),
    'p_h': torch.tensor(0.15),
    'p_continue': torch.tensor(ds['p_continue']),
    'c_1': torch.tensor(0.0),
    'c_2': torch.tensor(0.0),
}

trainer = DemoTrainer(
    model=WildfireModel(environment_data, parameter_data),
    device=torch.device(device),
)

# optimisation schedule
trainer.max_epochs = 5
trainer.steps_update_interval = 10  # steps added to the training horizon per iteration
trainer.max_steps = int(ds['max_steps'])
trainer.lr = 0.005
trainer.seed = None  # NOTE(review): presumably leaves the run unseeded (not reproducible), confirm in BaseTrainer


In [None]:
# Visualize the simulation (before calibration)

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

# generate the uncalibrated simulation (the model still has its initial parameters)
before_calibration = np.array(trainer.evaluate(ds))
ground_truth = ds['target']

# simulation and target side by side: each frame is widened along axis 2
combined = np.concatenate((before_calibration, ground_truth), axis=2)

fig, ax = plt.subplots()

# Fix the colour scale for the whole animation. set_array() never rescales, so
# autoscaling from frame 0 alone could saturate or wash out later frames.
im = ax.imshow(combined[0], vmin=float(combined.min()), vmax=float(combined.max()))
ax.set_title('Left: Uncalibrated, Right: Target')

def update(frame):
    im.set_array(combined[frame])
    return [im]

ani = FuncAnimation(
    fig, update, frames=len(combined), interval=100, blit=True
)

# Pillow ships with matplotlib, so saving the GIF does not depend on ffmpeg
ani.save('calibration_before.gif', writer='pillow', fps=10)
# close the figure so the inline backend does not also render a static copy
plt.close(fig)
HTML(ani.to_jshtml())

![Uncalibrated simulation (left) vs. ground-truth target (right)](https://github.com/xiazeyu/PyTorchFire/blob/jupyter-examples/examples/calibration_before.gif?raw=1)

In [None]:
# perform parameter calibration: fit the model's parameters to the target frames in `ds`
trainer.train(ds)

In [None]:
# Visualize the simulation

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

# generate the simulation with the calibrated parameters
after_calibration = np.array(trainer.evaluate(ds))
ground_truth = ds['target']

# simulation and target side by side: each frame is widened along axis 2
combined = np.concatenate((after_calibration, ground_truth), axis=2)

fig, ax = plt.subplots()

# Fix the colour scale for the whole animation. set_array() never rescales, so
# autoscaling from frame 0 alone could saturate or wash out later frames.
im = ax.imshow(combined[0], vmin=float(combined.min()), vmax=float(combined.max()))
ax.set_title('Left: Calibrated, Right: Target')

def update(frame):
    im.set_array(combined[frame])
    return [im]

ani = FuncAnimation(
    fig, update, frames=len(combined), interval=100, blit=True
)

# Pillow ships with matplotlib, so saving the GIF does not depend on ffmpeg
ani.save('calibration_after.gif', writer='pillow', fps=10)
# close the figure so the inline backend does not also render a static copy
plt.close(fig)
HTML(ani.to_jshtml())

![Calibrated simulation (left) vs. ground-truth target (right)](https://github.com/xiazeyu/PyTorchFire/blob/jupyter-examples/examples/calibration_after.gif?raw=1)