In [None]:
import numpy as np
from copy import deepcopy
from math import exp, sqrt
import matplotlib.pyplot as plt
import matplotlib.animation as animation

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import time
from scipy.sparse import lil_matrix
import os
import math

In [None]:
def gaussian_mixture_field(size, n_components=None):
    """Build a random terrain multiplier field from a mixture of Gaussian bumps.

    Args:
        size: Side length of the square grid; the field covers the unit square.
        n_components: Number of Gaussian bumps. A falsy value (``None`` or 0)
            means "draw the count from a Poisson(20) distribution".

    Returns:
        A ``(size, size)`` float array. The summed bumps are min-max
        normalised and inverted, so values lie in ``[0, 2]`` with the highest
        bump mapping to 0 and the lowest point mapping to 2. If the field is
        flat (no bumps were drawn, or the grid has a single cell) every cell
        is 1.0, the midpoint of that range.
    """
    x, y = np.meshgrid(np.linspace(0, 1, size), np.linspace(0, 1, size))
    field = np.zeros((size, size))

    if not n_components:
        n_components = np.random.poisson(20)

    for _ in range(n_components):
        cx, cy = np.random.uniform(0, 1, 2) # random center
        sx, sy = np.random.uniform(0.01, 0.2, 2) # random covariance scale
        w = np.random.exponential(10) # weight scale

        gaussian = w * np.exp(-(((x - cx) ** 2) / (2 * sx**2) +
                                ((y - cy) ** 2) / (2 * sy**2)))
        field += gaussian

    if field.size == 0:
        return field

    lowest = field.min()
    span = field.max() - lowest
    if span == 0:
        # A flat field cannot be min-max normalised (0 / 0 would yield NaN).
        return np.ones((size, size))

    field = (1 - (field - lowest) / span) * 2
    return field

In [None]:
class Simulator():
    """Stochastic cellular-automaton wildfire simulator on a square grid.

    Every cell of ``self.map`` holds a fire intensity (0 = not burning).
    ``step`` advances the fire by one tick and ``simulate`` ignites a few
    cells and steps until nothing burns any more. The grid as it was *before*
    each tick is stored in ``self.maps`` keyed by the tick number.
    """

    # Largest intensity a cell can reach.
    MAX_INTENSITY = 5
    # Extra ignition points lie within +/- this many cells of the first one.
    IGNITION_RADIUS = 20

    def __init__(self, size=256, wind_speed=0, wind_direction=[0,0], response_rate=0.1, response_start=20, base_spread_rate=0.3, n_components=None, decay_rate=1e-3):
        """Create an empty grid and a random terrain field.

        Args:
            size: Side length of the square grid.
            wind_speed: Wind strength; values <= 0 disable the wind term.
            wind_direction: Two-component direction, (row, column) order.
            response_rate: Rate constant of the firefighting probability.
            response_start: First tick at which firefighting can happen.
            base_spread_rate: Base probability scale for growth and spreading.
            n_components: Forwarded to ``gaussian_mixture_field``.
            decay_rate: Exponential decay rate applied to spread probabilities.
        """
        self.size = size
        self.map = np.zeros((size, size))
        self.wind_speed = wind_speed
        self.wind_direction = wind_direction
        self.response_rate = response_rate
        self.response_start = response_start
        self.spread_rate = base_spread_rate
        self.time = 0
        self.decay_rate = decay_rate
        self.maps = {}

        self.terrain = gaussian_mixture_field(size, n_components=n_components)

    def step(self):
        """Advance the fire by one tick.

        Burning cells may intensify, spread to their eight neighbours, and be
        damped by firefighting (next to unburnt cells) or, late in the run,
        by the grid border. Non-burning interior cells with at least six
        burning neighbours ignite. The pre-tick grid is recorded in
        ``self.maps[self.time]`` before ``self.time`` is incremented.
        """
        size = self.size
        new_map = self.map.copy()

        # Values that do not change during a tick.
        decay = np.exp(-self.decay_rate * self.time)
        ignition_damping = exp(-self.time / 1000)
        windy = self.wind_speed > 0
        wind_norm = np.linalg.norm(self.wind_direction) + 1e-6
        responding = self.time >= self.response_start

        for i in range(size):
            for j in range(size):
                if self.map[i, j] >= 1:
                    if np.random.rand() < self.spread_rate*self.terrain[i, j]*decay and new_map[i, j] < self.MAX_INTENSITY:
                        new_map[i, j] += 1

                    for di in (-1, 0, 1):
                        for dj in (-1, 0, 1):
                            if di == 0 and dj == 0:
                                continue

                            ni, nj = i + di, j + dj

                            if 0 <= ni < size and 0 <= nj < size:
                                spread_chance = self.spread_rate*self.map[i, j]

                                if windy:
                                    wind_influence = (di * self.wind_direction[0] + dj * self.wind_direction[1]) / wind_norm
                                    wind_influence *= np.random.normal(1, 0.5)
                                    if wind_influence > 0:
                                        spread_chance *= (1 + self.wind_speed * wind_influence)

                                # Terrain, decay and clipping apply with or
                                # without wind; previously they were skipped
                                # entirely when wind_speed was 0.
                                spread_chance *= self.terrain[ni, nj]
                                spread_chance *= decay
                                spread_chance = min(max(spread_chance, 0.0), 1.0)

                                if np.random.rand() < spread_chance and new_map[ni, nj] <= new_map[i, j]:
                                    if new_map[ni, nj] < self.MAX_INTENSITY:
                                        if np.random.rand() <= ignition_damping:
                                            new_map[ni, nj] += 1

                                if responding and new_map[ni, nj] == 0:
                                    if np.random.rand() < 1 - exp(-(self.response_rate*(0.5+self.terrain[i,j]) * (self.time - self.response_start))):
                                        if new_map[i, j] > 0:
                                            new_map[i, j] -= 1 # Firefighting effort
                            elif decay < 0.5:
                                # Neighbour lies outside the grid.
                                if np.random.rand() < 1 - exp(-(self.response_rate) * (self.time - self.response_start)):
                                    if new_map[i, j] > 0:
                                        new_map[i, j] -= 1 # Edge effect

                elif 0 < i < size - 1 and 0 < j < size - 1:
                    # Interior cells only: border cells have at most five
                    # neighbours and can never reach the threshold of six.
                    neighbors_on_fire = np.sum(self.map[i-1:i+2, j-1:j+2] >= 1)
                    if neighbors_on_fire >= 6 and new_map[i, j] == 0:
                        new_map[i, j] += 1

        # self.map is rebound below, so the old array can be stored directly.
        self.maps[self.time] = self.map
        self.map = new_map
        self.time += 1

    def simulate(self):
        """Ignite the grid and run ``step`` until no cell is burning.

        One seed cell is placed uniformly at random with an intensity of at
        least 1, so every run records at least one tick. A Poisson-distributed
        number of further ignition points is scattered around the seed, always
        inside the grid.
        """
        nodes = np.random.poisson(3)

        x_init, y_init = np.random.randint(0, self.size, size=2)
        self.map[x_init, y_init] = int(np.clip(np.random.poisson(3), 1, self.MAX_INTENSITY))

        radius = self.IGNITION_RADIUS
        for _ in range(nodes - 1):
            # Rejection-sample an offset that stays on the grid in both
            # directions (offset (0, 0) is always valid, so this terminates).
            while True:
                dx, dy = np.random.randint(-radius, radius + 1, size=2)
                if 0 <= x_init + dx < self.size and 0 <= y_init + dy < self.size:
                    break

            if self.map[x_init + dx, y_init + dy] == 0:
                self.map[x_init + dx, y_init + dy] = min(np.random.poisson(3), self.MAX_INTENSITY)

        while np.any(self.map > 0):
            self.step()

In [None]:
#simulator = Simulator(size=256, wind_speed=1.5, wind_direction=[1,2], response_rate=0.05, response_start=100, base_spread_rate=0.05)
#simulator.simulate()

In [None]:
def make_heatmap_gif(simulator, filename="simulation.gif", cmap="plasma"):
    """Render the recorded fire maps of a simulator as an animated GIF.

    Args:
        simulator: Object exposing ``maps`` (dict of tick -> 2-D array) and
            ``terrain`` (2-D array drawn as a translucent overlay).
        filename: Output path of the GIF.
        cmap: Matplotlib colormap name used for the fire intensity.

    Raises:
        ValueError: If ``simulator.maps`` is empty (nothing to animate).
    """
    times = sorted(simulator.maps.keys())
    if not times:
        raise ValueError("simulator.maps is empty; run the simulation before rendering a GIF")
    frames = [simulator.maps[t] for t in times]

    # One colour scale for the whole animation.
    vmin = min(frame.min() for frame in frames)
    vmax = max(frame.max() for frame in frames)

    fig, ax = plt.subplots()
    try:
        # --- fire as background ---
        fire_img = ax.imshow(frames[0], cmap=cmap, vmin=vmin, vmax=vmax)

        # --- terrain overlay ---
        terrain_img = ax.imshow(simulator.terrain, cmap="Greens", alpha=0.2)  # low alpha on top

        cbar = fig.colorbar(fire_img, ax=ax)
        cbar.set_label("Fire Intensity", rotation=270, labelpad=15)

        def update(frame):
            fire_img.set_data(frame)
            return [fire_img, terrain_img]

        ani = animation.FuncAnimation(
            fig, update, frames=frames, interval=60, blit=True
        )

        ani.save(filename, writer="pillow")
    finally:
        # Release the figure even if saving fails.
        plt.close(fig)

# make_heatmap_gif(simulator, "fire.gif")

In [None]:
# One approach is a GNN, which yields a network that scales to different grid sizes.
# Another is a transformer, which is easier to train.
# I'm going to try a Graph Attention Network (GAT).

In [None]:
def adjacency_matrix(length, width):
    """Sparse adjacency of a ``length`` x ``width`` grid, cells numbered row-major.

    Each cell is linked to itself and to its (up to eight) surrounding cells
    that lie inside the grid. Returns a float32 COO matrix of shape
    ``(length * width, length * width)`` whose stored entries are all 1.
    """
    n_cells = length * width
    graph = lil_matrix((n_cells, n_cells), dtype=np.float32)

    # Eight surrounding cells plus the cell itself (self-loop).
    offsets = [(-1,0), (-1,1), (0,1), (1,1), (1,0), (1,-1), (0,-1), (-1,-1), (0,0)]

    for r in range(length):
        for c in range(width):
            source = r * width + c
            for d_r, d_c in offsets:
                rr, cc = r + d_r, c + d_c
                if rr < 0 or rr >= length or cc < 0 or cc >= width:
                    continue  # neighbour falls off the grid
                graph[source, rr * width + cc] = 1

    return graph.tocoo()

In [None]:
class FireGraph(Dataset):
    """Dataset of per-timestep fire graphs stored as ``.npy`` files on disk.

    Every sample is an ``(N, 8)`` array with one row per grid cell: seven
    input features (terrain, last known intensity, age of that knowledge,
    wind direction x2, wind speed, time) followed by the true intensity as
    the label column. The grid adjacency (shared by all samples) is kept in
    ``self.adjacency_matrix`` as a coalesced sparse COO tensor.
    """

    def __init__(self, length=256, width=256, path="simulation_data"):
        """Build the grid adjacency and make sure the storage directory exists.

        Args:
            length: Number of grid rows.
            width: Number of grid columns.
            path: Root directory; files go to ``<path>/<length>x<width>``.
        """
        self.path = path
        self.length = length
        self.width = width

        adj = adjacency_matrix(self.length, self.width)
        # Coalesce once so that .indices() / .values() can be used directly.
        self.adjacency_matrix = torch.sparse_coo_tensor(
            indices=torch.tensor(np.vstack((adj.row, adj.col)), dtype=torch.long),
            values=torch.tensor(adj.data, dtype=torch.float32),
            size=adj.shape
        ).coalesce()

        self.data = []

        self.save_dir = os.path.join(self.path, f"{self.length}x{self.width}")
        os.makedirs(self.save_dir, exist_ok=True)

    def generate_data(self, topology:np.array=None, past_info:np.array=None, wind_direction:np.array=np.array([0,0]),
                         wind_speed:int=0, time:int=0, label:np.array=None):
        """Assemble the ``(N, 8)`` feature-plus-label table for one timestep.

        Args:
            topology: 2-D terrain array.
            past_info: ``(rows, cols, 2)`` array; channel 0 is the last known
                intensity, channel 1 the age of that knowledge.
            wind_direction: Two-component wind direction.
            wind_speed: Scalar wind speed.
            time: Timestep index.
            label: 2-D array of true intensities.

        Returns:
            float32 array of shape ``(N, 8)``; the last column is the label.

        Raises:
            ValueError: If ``topology``, ``past_info`` or ``label`` is missing.
        """
        if topology is None or past_info is None or label is None:
            raise ValueError("topology, past_info and label are all required")

        flat_topo = np.asarray(topology).ravel()
        cell_shape = flat_topo.shape

        columns = [
            flat_topo,
            past_info[:, :, 0].ravel(),
            past_info[:, :, 1].ravel(),
            np.full(cell_shape, wind_direction[0], dtype=np.float32),
            np.full(cell_shape, wind_direction[1], dtype=np.float32),
            np.full(cell_shape, wind_speed, dtype=np.float32),
            np.full(cell_shape, time, dtype=np.float32),
            np.asarray(label).ravel(),
        ]

        # Samples are converted to float32 tensors when read, so storing
        # float32 halves memory and disk use without changing what the model sees.
        return np.stack(columns, axis=1).astype(np.float32)

    def save_data(self, data:np.array):
        """Write ``data`` to a new ``.npy`` file in the storage directory.

        The file name is a nanosecond timestamp, bumped until it is unused,
        so simulations saved in quick succession never overwrite each other.
        """
        stamp = time.time_ns()
        target = os.path.join(self.save_dir, f"{stamp}.npy")
        while os.path.exists(target):
            stamp += 1
            target = os.path.join(self.save_dir, f"{stamp}.npy")
        np.save(target, data)

    def generate_dataset(self):
        """Load every stored ``.npy`` file and concatenate them along the time axis.

        Empty recordings are skipped. When nothing usable is found the dataset
        becomes empty instead of keeping samples from an earlier call.
        """
        arrays = []
        for file in sorted(os.listdir(self.save_dir)):
            if not file.endswith(".npy"):
                continue
            arr = np.load(os.path.join(self.save_dir, file))
            if arr.size == 0:
                continue  # a simulation that recorded no timesteps
            arrays.append(arr)

        self.data = np.concatenate(arrays, axis=0) if arrays else []

    def __len__(self):
        """Number of loaded timesteps."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(features, labels)`` for one timestep as ``(N, F)`` float and ``(N,)`` long tensors."""
        arr = self.data[idx]  # arr shape = (N, F+1)
        x = arr[:, :-1]       # (N, F)
        y = arr[:, -1]        # (N,)
        return torch.from_numpy(x).float(), torch.from_numpy(y).long()


In [None]:
# referenced https://github.com/gordicaleksa/pytorch-GAT/blob/main/The%20Annotated%20GAT%20(Cora).ipynb for some of the code.

class BelieverModel(nn.Module):
    """Multi-head graph attention network producing per-node class logits.

    Input features are projected to ``num_heads * num_features_per_head``
    channels, passed through ``num_layers`` attention layers that aggregate
    over the edges of a sparse adjacency, and mapped to ``num_output_classes``
    logits by a final linear layer.
    """

    def __init__(self, nodes=256*256, input_features=7, num_layers=3, num_heads=3, num_features_per_head=4, num_output_classes=5, dropout=False):
        super().__init__()
        self.leakyrelu = nn.LeakyReLU(0.2)
        self.relu = nn.ReLU()
        # Either a Dropout module or the literal False (meaning "no dropout").
        self.dropout = nn.Dropout(p=dropout) if dropout else False
        self.N = nodes
        self.num_heads = num_heads
        self.num_features_per_head = num_features_per_head

        hidden_width = num_heads * num_features_per_head

        self.initial_transformation = nn.Linear(input_features, hidden_width, bias=False)
        nn.init.xavier_normal_(self.initial_transformation.weight)

        self.a_lefts = nn.ParameterList()
        self.a_rights = nn.ParameterList()
        self.Ws = nn.ParameterList()
        for _ in range(num_layers):
            # Initialisation order (left, right, W) is kept so that seeded
            # runs draw the same random numbers.
            attn_src = nn.Parameter(torch.zeros(size=(num_heads, num_features_per_head)))
            nn.init.xavier_uniform_(attn_src)
            attn_dst = nn.Parameter(torch.zeros(size=(num_heads, num_features_per_head)))
            nn.init.xavier_uniform_(attn_dst)
            mix = nn.Parameter(torch.zeros(size=(num_heads, num_features_per_head, num_features_per_head)))
            nn.init.xavier_normal_(mix)

            self.a_lefts.append(attn_src)
            self.a_rights.append(attn_dst)
            self.Ws.append(mix)

        self.final_transformation = nn.Linear(hidden_width, num_output_classes)
        nn.init.xavier_normal_(self.final_transformation.weight)

    def _attention_layer(self, feats, adj, W, a_left, a_right):
        """Run one attention layer and return ``(new_feats, coalesced_adj)``.

        ``feats`` is (N, H, F). For every edge (i, j) the score is
        LeakyReLU(a_left . h_i + a_right . h_j) with h = feats @ W per head;
        scores are softmax-normalised over the edges leaving node i and used
        to weight the neighbour features h_j.
        """
        n_nodes = feats.size(0)
        projected = torch.einsum("nhf,hfo->nho", feats, W)  # (N, H, F)

        src_score = (projected * a_left).sum(-1)   # (N, H)
        dst_score = (projected * a_right).sum(-1)  # (N, H)

        adj = adj.coalesce()
        src, dst = adj.indices()
        src = src.long()
        dst = dst.long()
        edge_score = self.leakyrelu(src_score[src] + dst_score[dst])  # (E, H)

        n_heads = edge_score.size(1)
        if hasattr(torch.Tensor, "scatter_reduce"):
            # Largest score among the edges of each source node, (N, H).
            node_max = torch.zeros((n_nodes, n_heads), device=edge_score.device, dtype=edge_score.dtype).scatter_reduce(
                0, src.unsqueeze(-1).expand(-1, n_heads), edge_score, reduce="amax", include_self=False
            )
        else:
            # Slow fallback for torch builds without scatter_reduce.
            node_max = torch.full((n_nodes, n_heads), -1e9, device=edge_score.device, dtype=edge_score.dtype)
            for edge_idx, node_idx in enumerate(src):
                node_max[node_idx] = torch.maximum(node_max[node_idx], edge_score[edge_idx])

        # Subtract the per-node maximum before exponentiating for stability.
        weights = torch.exp(edge_score - node_max[src])  # (E, H)

        totals = torch.zeros((n_nodes, n_heads), device=edge_score.device, dtype=edge_score.dtype)
        totals.index_add_(0, src, weights)  # softmax denominator per node, (N, H)

        attention = weights / (totals[src] + 1e-9)  # (E, H)

        weighted = projected[dst] * attention.unsqueeze(-1)  # (E, H, F)

        gathered = torch.zeros_like(projected, device=projected.device, dtype=projected.dtype)
        gathered.index_add_(0, src, weighted)  # sum over neighbours, (N, H, F)
        return self.relu(gathered), adj

    def forward(self, x, adj):
        """Compute class logits for every node.

        Args:
            x: Node features, (N, F_in).
            adj: (N, N) sparse COO adjacency; only its index pattern is used.

        Returns:
            Logits of shape (N, num_output_classes).
        """
        if self.dropout:
            x = self.dropout(x)

        n_nodes = x.size(0)
        hidden = self.initial_transformation(x)  # (N, H * F)
        hidden = hidden.view(n_nodes, self.num_heads, self.num_features_per_head)  # (N, H, F)

        for W, a_left, a_right in zip(self.Ws, self.a_lefts, self.a_rights):
            hidden, adj = self._attention_layer(hidden, adj, W, a_left, a_right)

        flat = hidden.reshape(n_nodes, self.num_heads * self.num_features_per_head)
        return self.final_transformation(flat)

In [None]:
class DatasetGenerator():
    """Runs fire simulations and stores them as training data in a ``FireGraph``.

    For every simulated tick a sample is built from the terrain, the wind, the
    tick number, the true fire map (label) and a partial "knowledge" map that
    is refreshed through randomly placed 7x7 observation windows.
    """

    def __init__(self, size=256, wind_speed=1.5, wind_direction=[1,2], response_rate=0.05, response_start=100, base_spread_rate=0.05, perturb=True):
        """Store the nominal simulation parameters.

        Args:
            size: Side length of the simulated grid.
            wind_speed, wind_direction, response_rate, response_start,
            base_spread_rate: Nominal ``Simulator`` parameters.
            perturb: If true, every simulation draws its parameters at random
                around the nominal values (the wind direction is drawn
                uniformly from [-1, 1]^2); otherwise the nominal values are
                used unchanged.
        """
        self.size = size
        self.wind_speed = wind_speed
        self.wind_direction = wind_direction
        self.response_rate = response_rate
        self.response_start = response_start
        self.base_spread_rate = base_spread_rate
        self.perturb = perturb
        self.dataset = FireGraph(length=self.size, width=self.size)

    def _make_simulator(self, std_dev):
        """Create one simulator, with randomly perturbed parameters if enabled."""
        if not self.perturb:
            return Simulator(size=self.size, wind_speed=self.wind_speed, wind_direction=self.wind_direction, response_rate=self.response_rate,
                response_start=self.response_start, base_spread_rate=self.base_spread_rate)

        # abs() keeps the standard deviations valid for negative nominal values.
        return Simulator(
            size=self.size,
            wind_speed=np.random.normal(self.wind_speed, abs(std_dev*self.wind_speed)),
            wind_direction=[np.random.uniform(-1, 1), np.random.uniform(-1,1)],
            response_rate=max(np.random.normal(self.response_rate, abs(std_dev*self.response_rate)), 0.005),
            response_start=max(np.random.normal(self.response_start, abs(math.floor(std_dev*self.response_start))), 0),
            base_spread_rate=max(np.random.normal(self.base_spread_rate, abs(std_dev*self.base_spread_rate)), 0),
        )

    def generate(self, num_sims=100, std_dev=0.2):
        """Run ``num_sims`` simulations and save each one as a ``(T, N, 8)`` array.

        Args:
            num_sims: Number of simulations to run.
            std_dev: Relative standard deviation of the parameter perturbation.

        Simulations that recorded no timesteps are not saved.
        """
        for _ in range(num_sims):
            simulator = self._make_simulator(std_dev)
            simulator.simulate()

            if not simulator.maps:
                # Nothing burned: an empty array would not have the
                # (T, N, 8) layout of the other recordings.
                continue

            simulation_data = []
            past_info = np.zeros((self.size, self.size, 2))
            for t in simulator.maps:
                if 0.5 > np.random.rand():
                    # Observation window centre, drawn over the actual grid.
                    coords = np.random.randint(0, self.size, 2)
                    if t>0:
                        prev_map = simulator.maps[t-1]
                        y_min, y_max = max(0, coords[0]-3), min(self.size, coords[0]+4)
                        x_min, x_max = max(0, coords[1]-3), min(self.size, coords[1]+4)

                        past_info[y_min:y_max, x_min:x_max, 0] = prev_map[y_min:y_max, x_min:x_max]
                        past_info[y_min:y_max, x_min:x_max, 1] = 0

                    # NOTE(review): the age counter only advances on ticks
                    # where an observation is attempted, not on every tick.
                    # Confirm this is intended.
                    past_info[:, :, 1] += 1

                data_point = self.dataset.generate_data(topology=simulator.terrain, past_info=past_info, wind_direction=np.array(simulator.wind_direction),
                                           wind_speed=simulator.wind_speed, time=t, label=simulator.maps[t])
                simulation_data.append(data_point)

            self.dataset.save_data(np.array(simulation_data))


In [None]:
class Trainer():
    """Couples a ``DatasetGenerator`` with a ``BelieverModel`` and trains it.

    Fresh simulations are generated for every epoch, loaded, trained on, and
    deleted again. ``model_dir`` is the checkpoint file: it is loaded on
    construction when it exists and written by ``train`` / ``save_model``.
    """

    def __init__(self, size=256, model_layers=10, lr=1e-4, num_output_classes=5, model_dir=None):
        """Build generator, model, optimiser, scheduler and loss.

        Args:
            size: Side length of the simulated grid.
            model_layers: Number of attention layers in the model.
            lr: AdamW learning rate.
            num_output_classes: Number of intensity classes the model predicts.
            model_dir: Optional checkpoint path. A missing file is not an
                error; the model then starts from its random initialisation.
        """
        self.generator = DatasetGenerator(size=size)

        self.model_dir = model_dir
        self.num_output_classes = num_output_classes
        self.model = BelieverModel(num_layers=model_layers, num_output_classes=num_output_classes, nodes=size**2)
        if model_dir and os.path.isfile(model_dir):
            # map_location lets a checkpoint written on GPU load on a CPU-only host.
            self.model.load_state_dict(torch.load(model_dir, map_location="cpu"))

        # NOTE: the model is deliberately not wrapped in nn.DataParallel. The
        # previous wrapping code referenced an undefined local and crashed on
        # multi-GPU hosts, and DataParallel splits inputs along dim 0, which
        # does not work for the (flattened nodes, sparse adjacency) inputs.

        self.optimizer = optim.AdamW(self.model.parameters(), lr=lr, weight_decay=0.01)
        self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size=1, gamma=0.95)

        self.loss_fn = nn.CrossEntropyLoss() # ordinal labels

    def generate(self, num_sims=100):
        """Run ``num_sims`` simulations and store them on disk."""
        self.generator.generate(num_sims=num_sims)

    def clear_data(self):
        """Delete every ``.npy`` file in the dataset's storage directory."""
        save_dir = self.generator.dataset.save_dir

        for name in os.listdir(save_dir):
            if name.endswith(".npy"):
                os.remove(os.path.join(save_dir, name))

    def train(self, device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
              num_sims=100, num_epochs=100):
        """Train for ``num_epochs`` epochs on freshly simulated data.

        Args:
            device: Device to train on.
            num_sims: Simulations generated per epoch.
            num_epochs: Number of epochs.

        Labels above the highest model class are clamped to that class, so
        the top class means "intensity >= num_output_classes - 1". The
        checkpoint is written at the end when ``model_dir`` is set.
        """
        self.model.to(device)
        self.model.train()

        dataset = self.generator.dataset
        adjacency = dataset.adjacency_matrix.to(device).coalesce()
        top_class = self.num_output_classes - 1

        # The block-diagonal adjacency depends only on the batch size, so it
        # is rebuilt only when the batch size changes.
        cached_batch_size, adj_batch = None, None

        for epoch in range(num_epochs):
            self.generate(num_sims=num_sims)

            dataset.generate_dataset()
            print(f"Generated the simulations for {epoch+1}/{num_epochs}.")

            if len(dataset) == 0:
                print(f"Epoch {epoch+1}/{num_epochs} | no samples generated, skipping.")
                self.clear_data()
                continue

            train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

            total_loss = 0.0
            total_correct = 0
            total_nodes = 0

            for batch_x, batch_y in train_loader:
                B, N, num_feats = batch_x.shape
                flat_x = batch_x.to(device).reshape(B * N, num_feats)
                # CrossEntropyLoss needs class indices in [0, C).
                targets = batch_y.to(device).reshape(-1).clamp(min=0, max=top_class)

                if B != cached_batch_size:
                    adj_batch = block_diag_batch(adjacency, B)
                    cached_batch_size = B

                self.optimizer.zero_grad()
                logits = self.model(flat_x, adj_batch)

                loss = self.loss_fn(logits, targets) # averaged over all nodes in all maps

                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                self.optimizer.step()

                total_loss += loss.item()
                preds = logits.argmax(dim=-1)
                # Compare flat predictions with flat targets.
                total_correct += (preds == targets).sum().item()
                total_nodes += targets.numel()

            avg_loss = total_loss / len(train_loader)
            acc = total_correct / total_nodes
            print(f"Epoch {epoch+1}/{num_epochs} | Loss: {avg_loss:.4f} | Acc: {acc:.4f}")

            self.clear_data()

            self.scheduler.step()

        if self.model_dir:
            self.save_model()
        else:
            print("model_dir is not set; the trained weights were not saved.")

    def generate_comparison_gif(self, filename="comparison.gif", cmap="plasma", 
                                device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
        """Simulate one fire and save a three-panel GIF: truth, knowledge, prediction.

        Args:
            filename: Output path of the GIF.
            cmap: Matplotlib colormap name for the fire intensity.
            device: Device used for the model forward passes.

        The knowledge panel shows, for every frame, the knowledge channel of
        that timestep dimmed by ``exp(-age / 10)`` using the age channel.

        Raises:
            RuntimeError: If no simulation file exists or it holds no timesteps.
        """
        # Generate one simulation and load the saved .npy so we have the full time sequence
        self.generate(num_sims=1)
        dataset = self.generator.dataset
        # find the most recent saved simulation file in the dataset save_dir
        files = [f for f in os.listdir(dataset.save_dir) if f.endswith('.npy')]
        if not files:
            raise RuntimeError(f"No simulation files found in {dataset.save_dir}")
        latest = max(files, key=lambda f: os.path.getmtime(os.path.join(dataset.save_dir, f)))
        sim_path = os.path.join(dataset.save_dir, latest)

        sim = np.load(sim_path)  # shape = (T, N, F+1)
        if sim.ndim != 3 or sim.shape[0] == 0:
            raise RuntimeError(f"Simulation file {sim_path} contains no timesteps")
        T, N = sim.shape[0], sim.shape[1]
        size = int(round(np.sqrt(N)))

        # Terrain is static, so the first timestep is enough.
        terrain = sim[0, :, 0].reshape(size, size)
        # Knowledge and its age change over time: keep the whole sequence.
        knowledge_seq = sim[:, :, 1].reshape(T, size, size)
        age_seq = sim[:, :, 2].reshape(T, size, size)
        faded_seq = knowledge_seq * np.exp(-age_seq / 10)

        # truth_seq: (T, size, size)
        truth_seq = sim[:, :, -1].reshape(T, size, size)

        self.model.to(device)
        self.model.eval()
        adjacency = dataset.adjacency_matrix.to(device)

        # Compute model predictions for each timestep (so predictions vary across frames)
        pred_seq = []
        with torch.no_grad():
            for t in range(T):
                x_t = torch.from_numpy(sim[t, :, :-1]).float().to(device)  # (N, F)
                logits = self.model(x_t, adjacency)
                pred_seq.append(logits.argmax(dim=1).cpu().numpy().reshape(size, size))
        pred_seq = np.array(pred_seq)

        vmin = min(truth_seq.min(), pred_seq.min(), knowledge_seq.min())
        vmax = max(truth_seq.max(), pred_seq.max(), knowledge_seq.max())

        fig, axes = plt.subplots(1, 3, figsize=(12, 4))
        try:
            panels = ["Truth", "Knowledge", "Model Prediction"]
            fire_imgs = []
            terrain_imgs = []

            for ax, title in zip(axes, panels):
                ax.set_title(title)
                ax.axis("off")
                fire_imgs.append(ax.imshow(np.zeros((size, size)), cmap=cmap, vmin=vmin, vmax=vmax))
                terrain_imgs.append(ax.imshow(terrain, cmap="Greens", alpha=0.2))

            def update(frame_idx):
                fire_imgs[0].set_data(truth_seq[frame_idx])   # truth
                fire_imgs[1].set_data(faded_seq[frame_idx])   # knowledge, fading with age
                fire_imgs[2].set_data(pred_seq[frame_idx])    # model prediction
                return fire_imgs + terrain_imgs

            # disable blit to avoid backend-specific issues when saving with pillow
            ani = animation.FuncAnimation(fig, update, frames=T, interval=200, blit=False)
            ani.save(filename, writer="pillow")
        finally:
            plt.close(fig)

    def save_model(self):
        """Write the model's state dict to ``self.model_dir``.

        Missing parent directories are created.

        Raises:
            ValueError: If ``model_dir`` was not set.
        """
        if not self.model_dir:
            raise ValueError("model_dir is not set; cannot save the model")
        parent = os.path.dirname(self.model_dir)
        if parent:
            os.makedirs(parent, exist_ok=True)
        torch.save(self.model.state_dict(), self.model_dir)


In [None]:
def block_diag_batch(adj, batch_size):
    """
    Build block-diagonal adjacency matrix for efficient processing of batches.

    Args:
        adj: Sparse COO tensor of shape (R, C) shared by every graph in the batch.
        batch_size: Number of copies to place on the diagonal.

    Returns:
        Sparse COO tensor of shape (R * batch_size, C * batch_size) in which
        copy ``b`` of ``adj`` occupies rows ``b*R .. (b+1)*R - 1`` and
        columns ``b*C .. (b+1)*C - 1``.
    """
    # indices()/values() are only available on coalesced tensors.
    adj = adj.coalesce()
    n_rows, n_cols = adj.size(0), adj.size(1)
    indices = adj.indices()  # (2, E)
    values = adj.values()    # (E,)

    copy_ids = torch.arange(batch_size, device=indices.device, dtype=indices.dtype)
    # Row / column shift of every copy, shaped (2, B, 1) for broadcasting.
    offsets = torch.stack((copy_ids * n_rows, copy_ids * n_cols)).unsqueeze(-1)

    # (2, 1, E) + (2, B, 1) -> (2, B, E). Flattening the last two axes keeps
    # row indices in the first row and column indices in the second, ordered
    # copy by copy to match values.repeat below.
    expanded = (indices.unsqueeze(1) + offsets).reshape(2, -1)

    expanded_values = values.repeat(batch_size)

    size = (n_rows * batch_size, n_cols * batch_size)
    return torch.sparse_coo_tensor(expanded, expanded_values, size=size, device=adj.device)

In [None]:
# Path of the model checkpoint. Trainer reads weights from it on construction,
# and train() / save_model() write the model's state dict to it.
model_dir = "model/model.pth"
trainer = Trainer(model_dir = model_dir)

In [None]:
# Simulate one fire and render the truth / knowledge / prediction panels to
# the default output file ("comparison.gif") using the current model weights.
trainer.generate_comparison_gif()

In [None]:
# Train with the defaults (100 epochs, 100 new simulations per epoch); the
# model's state dict is written to model_dir when training finishes.
trainer.train()