In [None]:
%load_ext autoreload
%autoreload 2

In [28]:
import os
from pathlib import Path
from zipfile import ZipFile
import urllib.request
from collections import defaultdict

import torch
import torch.nn as nn
import torchvision.transforms.functional as F
from torchvision.utils import draw_bounding_boxes
from tfrecord.torch.dataset import MultiTFRecordDataset
import matplotlib.pyplot as plt
from matplotlib import colors
import numpy as np

In [ ]:
class CNN(torch.nn.Module):
    """Small convolutional network that maps a 12-channel grid to one 64 x 64 map.

    Two conv / ReLU / max-pool stages reduce the input to a single channel,
    the result is resized to 64 x 64, passed through a dense layer over the
    4096 flattened values and reshaped back to 64 x 64.
    """

    def __init__(self):
        super(CNN, self).__init__()

        # Layers are listed in execution order; their positions in the
        # Sequential determine the parameter names (cnn.0, cnn.3, cnn.8).
        layers = [
            nn.Conv2d(12, 24, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(24, 1, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Upsample(size=64),
            nn.Flatten(),
            nn.Linear(4096, 4096),
            nn.Unflatten(dim=1, unflattened_size=(64, 64)),
        ]
        self.cnn = nn.Sequential(*layers)

    def forward(self, x):
        """Run the layer stack and insert a channel axis at dim 1 of the result."""
        out = self.cnn(x)
        return torch.unsqueeze(out, 1)


In [ ]:
class Reshape(torch.nn.Module):
    """Module wrapper that reshapes its input to a fixed target shape.

    The target shape is given as positional arguments, e.g.
    ``Reshape(-1, 32, 6, 6)``, and is stored as a tuple in ``self.shape``.
    """

    def __init__(self, *args):
        super().__init__()
        self.shape = args

    def forward(self, x):
        """Return ``x`` rearranged to ``self.shape``."""
        # reshape returns a view whenever view() would have succeeded and only
        # copies when the memory layout (e.g. after a transpose) makes a view
        # impossible, where view() would raise instead.
        return x.reshape(self.shape)


class ConvoAE(torch.nn.Module):
    """Convolutional encoder/decoder emitting two-class log-probabilities per pixel.

    The encoder compresses a 12-channel input to a 2-value code; the decoder
    expands that code back to a 4096-value map and a final dense layer turns it
    into a (2, 4096) tensor normalised with log-softmax over the class axis.
    The spatial sizes in the comments assume a 64 x 64 input.
    """

    def __init__(self):
        super().__init__()

        slope = 0.01  # negative slope shared by every LeakyReLU
        drop = 0.1  # dropout probability used inside the encoder

        encoder_layers = [
            nn.Conv2d(12, 24, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(slope),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 64 x 64 -> 32 x 32
            nn.Conv2d(24, 24, kernel_size=3, stride=1, padding=0),  # -> 30 x 30
            nn.LeakyReLU(slope),
            nn.Dropout(drop),
            nn.Conv2d(24, 32, kernel_size=3, stride=2, padding=0),  # -> 14 x 14
            nn.LeakyReLU(slope),
            nn.Dropout(drop),
            nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=0),  # -> 6 x 6
            nn.Flatten(),
            nn.Linear(32 * 6 * 6, 2),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.Linear(2, 32 * 6 * 6),
            Reshape(-1, 32, 6, 6),
            nn.ConvTranspose2d(32, 32, kernel_size=3, stride=1, padding=0),  # -> 8 x 8
            nn.LeakyReLU(slope),
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1),  # -> 15 x 15
            nn.LeakyReLU(slope),
            nn.ConvTranspose2d(16, 16, kernel_size=3, stride=2, padding=0),  # -> 31 x 31
            nn.LeakyReLU(slope),
            nn.ConvTranspose2d(16, 8, kernel_size=3, stride=1, padding=0),  # -> 33 x 33
            nn.LeakyReLU(slope),
            nn.ConvTranspose2d(8, 1, kernel_size=2, stride=2, padding=1),  # -> 64 x 64
            nn.Flatten(),
            nn.Linear(4096, 2 * 4096),
            nn.Unflatten(dim=1, unflattened_size=(2, 4096)),
            nn.LogSoftmax(dim=1),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` to the 2-value code and decode it to class log-probabilities."""
        return self.decoder(self.encoder(x))


In [ ]:
def reweight(cls_num_list, beta=0.9999):
    """Compute normalised per-class weights from per-class sample counts.

    Each count ``n`` gets the raw weight ``(1 - beta) / (1 - beta**n)``; the
    weights are then rescaled so that they sum to the number of classes.

    Args:
        cls_num_list: iterable with one sample count per class.
        beta: base of the weighting formula.

    Returns:
        A float tensor with one weight per class.
    """
    raw_weights = [(1 - beta) / (1 - beta**count) for count in cls_num_list]
    per_cls_weights = torch.Tensor(raw_weights)
    scale = len(cls_num_list) / per_cls_weights.sum()
    per_cls_weights *= scale
    return per_cls_weights


class FocalLoss(nn.Module):
    """Focal loss: cross-entropy with each term scaled by ``(1 - p) ** gamma``.

    With ``gamma == 0`` this is exactly (optionally class-weighted)
    cross-entropy.

    Args:
        weight: optional per-class weight tensor, forwarded to ``nll_loss``.
        gamma: non-negative focusing exponent.
    """

    def __init__(self, weight=None, gamma=0.0):
        super().__init__()
        assert gamma >= 0
        self.gamma = gamma
        self.weight = weight

    def forward(self, input, target):
        """Return the mean focal loss.

        Args:
            input: class scores with the class axis at dim 1.
            target: integer class indices.
        """
        # torch.nn.functional is reached through ``nn`` because the name ``F``
        # in this file is bound to torchvision.transforms.functional, which has
        # no softmax / cross_entropy.
        log_probs = nn.functional.log_softmax(input, dim=1)
        modulated = (1 - log_probs.exp()) ** self.gamma * log_probs
        # nll_loss rather than cross_entropy: the modulated values are already
        # (scaled) log-probabilities; cross_entropy would apply log_softmax a
        # second time and re-normalise the focal term away.
        return nn.functional.nll_loss(modulated, target, weight=self.weight)


In [None]:
# parameters

# Batch size passed to the data loaders
BATCH_SIZE = 256

# Record keys stacked (in this order) as model input channels
FEATURES = [
    "elevation", "th", "vs",
    "tmmn", "tmmx", "sph",
    "pr", "pdsi", "NDVI",
    "population", "erc", "PrevFireMask",
]
# Record key used as the prediction target
LABELS = ["FireMask"]

# Optimisation settings
EPOCHS = 10
LEARNING_RATE = 1e-2

# Grid dimensions used when reshaping flat record arrays
LENGTH, WIDTH = 64, 64
ARR_SIZE = LENGTH * WIDTH

In [None]:
# Number of target classes passed to the detection model
num_classes = 2
# Use the GPU when one is visible to torch, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# set up data directory: ./data under the current working directory,
# created (with parents) if it does not exist yet
data_dir = os.path.abspath("data")
os.makedirs(data_dir, exist_ok=True)

In [None]:
# download data zip (skipped when archive.zip is already present)
data_zip = os.path.join(data_dir, "archive.zip")
if not os.path.exists(data_zip):
    url = "https://www.kaggle.com/api/v1/datasets/download/fantineh/next-day-wildfire-spread"
    # Download under a temporary name and rename only after the transfer
    # finished: an interrupted download must not leave a truncated
    # archive.zip that the exists-check above would treat as complete.
    partial_zip = data_zip + ".part"
    try:
        urllib.request.urlretrieve(url, partial_zip)
        os.replace(partial_zip, data_zip)
    finally:
        # Only true when the download or rename failed
        if os.path.exists(partial_zip):
            os.remove(partial_zip)

In [None]:
# extract files from zip
# `files` maps each split name to the stems of the archive members whose
# path contains that name; members are extracted only if not on disk yet.
files = defaultdict(list)
file_types = ["eval", "train", "test"]
with ZipFile(data_zip, "r") as archive:
    for member in archive.namelist():
        stem = Path(member).stem
        matching_types = [kind for kind in file_types if kind in member]
        for kind in matching_types:
            files[kind].append(stem)
        already_extracted = os.path.exists(os.path.join(data_dir, member))
        if not already_extracted:
            archive.extract(member, data_dir)

In [None]:
# get all records into a data loader
from torch.utils.data import default_collate
def collate_fn(batch):
    """Collate a list of records, skipping records that contain -1 in a fire mask.

    A record is kept only when neither its 'PrevFireMask' nor its 'FireMask'
    entry contains the value -1; the remaining records are merged with
    ``default_collate``.
    """
    def _has_marker(record):
        return -1 in record['PrevFireMask'] or -1 in record['FireMask']

    kept = [record for record in batch if not _has_marker(record)]
    return default_collate(kept)

def get_loader_from_file_type(files: dict, record_path: str, file_types: list):
    """Build a DataLoader over every record file belonging to the given file types.

    Args:
        files: mapping from file type (e.g. "train") to a list of record file
            names; each name is substituted into ``record_path``.
        record_path: path pattern with a ``{}`` placeholder for the file name.
        file_types: file types whose files should be read together.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding batches of ``BATCH_SIZE``
        records collated with ``collate_fn``.

    Raises:
        ValueError: if no file is registered for any of ``file_types``.
    """
    names = [name for file_type in file_types for name in files[file_type]]
    if not names:
        raise ValueError(f"no record files found for file types {file_types}")
    dataset = MultiTFRecordDataset(
        # Use the pattern that was passed in; the previous body read the
        # notebook-level ``tfrecord_path`` and ignored this argument.
        record_path,
        None,
        splits={name: 1.0 for name in names},
        infinite=False,
    )
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn
    )
    return loader

# Path pattern for the extracted record files; "{}" is filled with a file name
tfrecord_path = os.path.join(data_dir, "{}.tfrecord")
# Training batches come from the "train" files; "test" and "eval" files are
# read together as the held-out set.
train_loader = get_loader_from_file_type(
    files=files, record_path=tfrecord_path, file_types=["train"]
)
test_loader = get_loader_from_file_type(
    files=files, record_path=tfrecord_path, file_types=["test", "eval"]
)

In [None]:
def plot_losses(train_losses, test_losses):
    """Plot train and test loss per epoch on a single labelled figure.

    Args:
        train_losses: sequence with one train loss per epoch.
        test_losses: sequence with one test loss per epoch.
    """
    fig, ax = plt.subplots()
    ax.plot(range(1, len(train_losses) + 1), train_losses, label="Train Loss")
    ax.plot(range(1, len(test_losses) + 1), test_losses, label="Test Loss")
    # Tick every epoch of the longer series so that neither curve extends
    # past the labelled range when the two lists differ in length.
    n_epochs = max(len(train_losses), len(test_losses))
    ax.set_xticks(list(range(1, n_epochs + 1)))
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_title("Loss per epoch")
    ax.legend()
    plt.show()

In [None]:
# Pull one batch from each loader and print the shape of its 'FireMask' entry.
# After the loop `data` holds the batch taken from test_loader.
for split_loader in (train_loader, test_loader):
    data = next(iter(split_loader))
    print(data['FireMask'].shape)

In [None]:
def get_dataset_items(data, item_list, length=LENGTH, width=WIDTH):
    """Stack selected entries of a batch into one (batch, item, length, width) tensor.

    Args:
        data: mapping from key to tensor; each entry is indexed as
            ``[:, None, :]`` (presumably batch x length*width - confirm
            against the loader).
        item_list: keys to stack, in order, along dim 1.
        length: size of dim 2 of the result.
        width: size of dim 3 of the result.
    """
    channels = []
    for key in item_list:
        # Insert the axis the entries are concatenated along
        channels.append(data[key][:, None, :])
    stacked = torch.cat(channels, dim=1)
    batch_size, n_items = stacked.shape[0], stacked.shape[1]
    return stacked.reshape(batch_size, n_items, length, width)

In [None]:
# gather batch of features and labels from the most recently fetched batch
features, labels = (get_dataset_items(data, keys) for keys in (FEATURES, LABELS))
for stacked in (features, labels):
    print(stacked.shape)

In [None]:
def make_rcnn_target(labels: list[torch.Tensor], n_boxes=2, l=64, w=64, label_function=torch.max):
    """Build detection targets from label maps using an evenly spaced grid of boxes.

    length and width should be divisible evenly by n_boxes,
    i.e., a 64x64 image can have 2, 4, 8, etc. evenly spaced boxes.

    Args:
        labels: iterable of label maps, each indexed as ``label[0, row, col]``.
        n_boxes: number of boxes along one side of the grid, i.e., n_boxes=2
            makes 4 boxes, n_boxes=4 makes 16 boxes.
        l: length of the map (taken to be the number of rows, matching the
            (length, width) reshape used elsewhere in this notebook).
        w: width of the map (number of columns).
        label_function: reduces the label values inside one box to a single
            value convertible with ``int()`` (default ``torch.max``).

    Returns:
        A list with one dict per label map holding ``"boxes"`` (float tensor,
        ``n_boxes**2 x 4``, rows are ``x1, y1, x2, y2``) and ``"labels"``
        (int64 tensor with ``n_boxes**2`` entries). Box ``x * n_boxes + y`` is
        the cell in grid column ``x`` and grid row ``y``.
    """
    assert l % n_boxes == 0
    assert w % n_boxes == 0

    # x runs along the width (columns), y along the length (rows)
    x_step = w // n_boxes
    y_step = l // n_boxes
    targets = []

    for label in labels:
        # Create the targets on the label's device so they can be used with it
        boxes = torch.zeros((n_boxes**2, 4), device=label.device)
        target_labels = torch.zeros(n_boxes**2, dtype=torch.int64, device=label.device)
        for x in range(n_boxes):
            for y in range(n_boxes):
                idx = x * n_boxes + y
                x1, y1 = x * x_step, y * y_step
                x2, y2 = x1 + x_step, y1 + y_step
                boxes[idx] = boxes.new_tensor([x1, y1, x2 - 1, y2 - 1])
                # Rows are selected by the y range and columns by the x range
                # so the label describes the same pixels as the box; the
                # previous slice used [x range, y range], i.e. the transposed
                # region.
                region = label[0, y1:y2, x1:x2]
                target_labels[idx] = int(label_function(region))

        targets.append({"boxes": boxes, "labels": target_labels})
    return targets

In [None]:
def train(model, rcnn, loader, optimizer, n=None, n_boxes=2):
    """Run one training pass of ``model`` using the R-CNN classifier loss.

    Note: a later cell of this notebook defines another function named
    ``train``; this version is only in effect until that cell runs.

    Args:
        model: network producing the images that are fed to ``rcnn``.
        rcnn: detection model; called with images and targets, its
            ``"loss_classifier"`` entry is the training loss.
        loader: iterable of batch dicts.
        optimizer: optimizer stepped once per batch.
        n: use only the first ``n`` samples of each batch; ``None`` uses the
            whole batch.
        n_boxes: grid size forwarded to ``make_rcnn_target``.

    Returns:
        ``(loss, features, labels, weights, total_loss)`` where the first four
        come from the last batch (``None`` if the loader was empty) and
        ``total_loss`` is the sum of the per-batch losses.
    """
    model.train()
    rcnn.train()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    total_loss = 0
    loss = features = labels = weights = None
    for data in loader:
        # Slicing with [:None] keeps the whole batch. The previous default,
        # n = len(data), counted the keys of the batch dict rather than its
        # samples and silently truncated every batch.
        features = get_dataset_items(data, FEATURES)[:n]
        labels = get_dataset_items(data, LABELS)[:n]
        # Inputs and targets must live on the model's device; the returned
        # features/labels stay where the loader put them.
        weights = model(features.to(device))
        target = [
            {key: value.to(device) for key, value in box_target.items()}
            for box_target in make_rcnn_target(labels, n_boxes=n_boxes)
        ]
        loss = rcnn(weights, target)["loss_classifier"]

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    print(f"Train loss: {total_loss:>7f}")
    return loss, features, labels, weights, total_loss

In [None]:
def test(model, rcnn, loader, n=None):
    """Run ``model`` and ``rcnn`` in eval mode over ``loader`` without gradients.

    Note: a later cell of this notebook defines another function named
    ``test``; this version is only in effect until that cell runs.

    Args:
        model: network producing the images that are fed to ``rcnn``.
        rcnn: detection model, called with images only.
        loader: iterable of batch dicts.
        n: use only the first ``n`` samples of each batch; ``None`` uses the
            whole batch.

    Returns:
        ``(output, features, labels)`` of the last batch only (all ``None``
        if the loader was empty).
    """
    model.eval()
    rcnn.eval()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    output = features = labels = None
    with torch.no_grad():
        for data in loader:
            # [:None] keeps the whole batch; len(data) would be the number of
            # keys in the batch dict, not the number of samples.
            features = get_dataset_items(data, FEATURES)[:n]
            labels = get_dataset_items(data, LABELS)[:n]
            # Compute prediction on the model's device
            weights = model(features.to(device))
            output = rcnn(weights)
    return output, features, labels

In [None]:
def calculate_test_loss(model, rcnn, loader, n=None, n_boxes=2):
    """Hacky workaround to get test losses for RCNN.

    The detection model is switched to training mode so that calling it with
    targets yields a loss dict, but everything runs under ``torch.no_grad()``
    and the model's previous mode is restored afterwards.

    Args:
        model: network producing the images that are fed to ``rcnn``.
        rcnn: detection model; its ``"loss_classifier"`` entry is accumulated.
        loader: iterable of batch dicts.
        n: use only the first ``n`` samples of each batch; ``None`` uses the
            whole batch.
        n_boxes: grid size forwarded to ``make_rcnn_target``.

    Returns:
        ``(loss, total_loss)``: the loss dict of the last batch (``None`` if
        the loader was empty) and the summed classifier loss.
    """
    total_loss = 0
    loss = None
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    rcnn_was_training = rcnn.training
    model.eval()
    rcnn.train()
    try:
        with torch.no_grad():
            for data in loader:
                # [:None] keeps the whole batch; len(data) would be the number
                # of keys in the batch dict, not the number of samples.
                features = get_dataset_items(data, FEATURES)[:n]
                labels = get_dataset_items(data, LABELS)[:n]
                target = [
                    {key: value.to(device) for key, value in box_target.items()}
                    for box_target in make_rcnn_target(labels, n_boxes=n_boxes)
                ]

                # Compute prediction and loss
                weights = model(features.to(device))
                loss = rcnn(weights, target)
                total_loss += loss['loss_classifier'].item()
    finally:
        # Do not leave the detection model in training mode as a side effect
        rcnn.train(rcnn_was_training)
    print(f"Test loss: {total_loss:>7f}")
    return loss, total_loss

In [None]:
# from models.cnn import CNN
from torchvision.models.detection import fasterrcnn_resnet50_fpn

# Build both networks and move them to the selected device
# (Module.to returns the module itself).
model = CNN().to(device)
rcnn = fasterrcnn_resnet50_fpn(num_classes=num_classes, progress=True).to(device)

# NOTE(review): only the CNN's parameters are given to the optimizer, so the
# detection model's weights are never updated - confirm this is intended.
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)
loss_fn = torch.nn.CrossEntropyLoss()

train_losses = []
test_losses = []

for t in range(EPOCHS):
    print(f'Epoch {t+1}\n')
    train_result = train(model, rcnn, train_loader, optimizer, n_boxes=4)
    weights, train_loss = train_result[3], train_result[4]
    pred, features, labels = test(model, rcnn, test_loader)
    test_loss = calculate_test_loss(model, rcnn, test_loader, n_boxes=4)[1]
    train_losses.append(train_loss)
    test_losses.append(test_loss)
plot_losses(train_losses, test_losses)
    

In [None]:
# visualize features and labels

# Column headings of the visualization grid, one per plotted plane
TITLES = [
    'Elevation',
    'Wind Direction',
    'Wind Velocity',
    'Min Temperature',
    'Max Temperature',
    'Humidity',
    'Precip',
    'Drought',
    'Vegetation',
    'Population Density',
    'Energy Release Component',
    'Previous Fire Mask',
    'True Fire Mask',
    'Predicted Fire Mask',
]

# Grid layout: one row per sample, one column per title (14)
rows = 5
cols = len(TITLES)

# Three-colour map for the mask planes: values in [-1, -0.1) are black,
# [-0.1, 0.001) silver and [0.001, 1] orangered.
BOUNDS = [-1, -0.1, 0.001, 1]
CMAP = colors.ListedColormap(['black', 'silver', 'orangered'])
NORM = colors.BoundaryNorm(BOUNDS, CMAP.N)

In [None]:
fig = plt.figure(figsize=(15, 6.5))
fig.suptitle("Visualizations", fontsize=20)
# The last three title columns are mask planes and use the discrete colour map
first_mask_col = cols - 3
# samples: one grid row per sample, one column per feature/label plane
for i in range(rows):
    planes = torch.cat((features[i], labels[i]), dim=0)
    for j, plane in enumerate(planes):
        plt.subplot(rows, cols, i * cols + j + 1)
        if i == 0:
            # Column headings only on the first row, one word per line
            plt.title(TITLES[j].replace(' ', '\n'))
        if j < first_mask_col:
            style = {"cmap": 'viridis'}
        else:
            style = {"cmap": CMAP, "norm": NORM}
        plt.imshow(plane.detach().numpy(), **style)
        plt.axis('off')
plt.tight_layout()

In [None]:
def show(imgs):
    """Display one image tensor, or a list of them, side by side without ticks.

    Each image is detached, converted to a PIL image and drawn on its own
    axes; the shape of the resulting pixel array is printed.
    """
    images = imgs if isinstance(imgs, list) else [imgs]
    _, axs = plt.subplots(ncols=len(images), squeeze=False)
    for ax, img in zip(axs[0], images):
        pixels = np.asarray(F.to_pil_image(img.detach()))
        ax.imshow(pixels)
        print(pixels.shape)
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

# Number of leading predicted boxes drawn per sample
BOXES = 5
for i in range(rows):
    top_boxes = pred[i]['boxes'][:BOXES]
    boxed = draw_bounding_boxes(labels[i], top_boxes)
    show(boxed)

# Encoder-Decoder with Focal Loss

In [None]:
# get class counts
# float64 accumulators: the running totals can exceed the range in which
# float32 represents every integer exactly.
class_counts = torch.zeros(2, dtype=torch.float64)
for data in train_loader:
    label_batch = get_dataset_items(data, LABELS)
    # bincount with minlength keeps one slot per class even when a batch
    # contains a single class. unique(return_counts=True) returns only one
    # count in that case, which `+=` would broadcast into both slots.
    class_counts += torch.bincount(
        label_batch.flatten().long(), minlength=class_counts.numel()
    )

In [None]:
# from models.focal_loss import reweight
# Turn the class counts into normalised per-class weights and move them to
# the selected device.
per_class_weights = reweight(class_counts, beta=0.999999)
per_class_weights = per_class_weights.to(device)
print(per_class_weights)

In [None]:
def train(model, loss_fn, loader, optimizer, n=None, max_batches=None):
    """Run one training pass of ``model`` with a per-pixel classification loss.

    Note: this redefines the ``train`` function from an earlier cell of this
    notebook.

    Args:
        model: network returning class scores with the class axis at dim 1.
        loss_fn: callable taking ``(scores, targets)`` with scores of shape
            (samples * pixels, classes) and integer targets.
        loader: iterable of batch dicts.
        optimizer: optimizer stepped once per batch.
        n: use only the first ``n`` samples of each batch; ``None`` uses the
            whole batch.
        max_batches: stop after this many batches; ``None`` processes the
            whole loader. Pass 1 for a quick smoke run.

    Returns:
        ``(loss, features, labels, total_loss)`` where the first three come
        from the last processed batch (``None`` if no batch was processed)
        and ``total_loss`` is the sum of the per-batch losses.
    """
    model.train()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    total_loss = 0
    loss = features = labels = None
    for i, data in enumerate(loader):
        # Replaces an unconditional `break` that ended every epoch after the
        # first batch.
        if max_batches is not None and i >= max_batches:
            break
        # [:None] keeps the whole batch; len(data) would be the number of
        # keys in the batch dict, not the number of samples.
        features = get_dataset_items(data, FEATURES)[:n]
        labels = get_dataset_items(data, LABELS)[:n]

        pred = model(features.to(device))

        # (batch, classes, pixels) -> (batch * pixels, classes), the same
        # ordering as the flattened labels
        _pred = pred.flatten(2).permute(0, 2, 1).reshape(-1, pred.shape[1])
        _labels = labels.flatten().long().to(device)

        loss = loss_fn(_pred, _labels)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    print(f"Train loss: {total_loss:>7f}")
    return loss, features, labels, total_loss

In [None]:
def test(model, loss_fn, loader, n=None, max_batches=None):
    """Evaluate ``model`` over ``loader`` without gradients.

    Note: this redefines the ``test`` function from an earlier cell of this
    notebook.

    Args:
        model: network returning class scores with the class axis at dim 1.
        loss_fn: callable taking ``(scores, targets)`` with scores of shape
            (samples * pixels, classes) and integer targets.
        loader: iterable of batch dicts.
        n: use only the first ``n`` samples of each batch; ``None`` uses the
            whole batch.
        max_batches: stop after this many batches; ``None`` processes the
            whole loader. Every processed batch is kept in memory, so pass a
            small number when only a few batches are needed.

    Returns:
        ``(all_pred, all_features, all_labels, total_loss)``: three lists with
        one tensor per processed batch and the sum of the per-batch losses.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    total_loss = 0
    all_pred = []
    all_features = []
    all_labels = []
    with torch.no_grad():
        for i, data in enumerate(loader):
            # Replaces an unconditional `break` that limited the evaluation
            # to the first batch.
            if max_batches is not None and i >= max_batches:
                break
            # [:None] keeps the whole batch; len(data) would be the number of
            # keys in the batch dict, not the number of samples.
            features = get_dataset_items(data, FEATURES)[:n]
            labels = get_dataset_items(data, LABELS)[:n]
            # Compute prediction and loss
            pred = model(features.to(device))
            # (batch, classes, pixels) -> (batch * pixels, classes), the same
            # ordering as the flattened labels
            _pred = pred.flatten(2).permute(0, 2, 1).reshape(-1, pred.shape[1])
            _labels = labels.flatten().long().to(device)

            loss = loss_fn(_pred, _labels)
            total_loss += loss.item()
            # Keep predictions on the CPU so they can be converted to numpy
            # for plotting and do not pile up in device memory.
            all_pred.append(pred.cpu())
            all_features.append(features)
            all_labels.append(labels)
    print(f"Test loss: {total_loss:>7f}")
    return all_pred, all_features, all_labels, total_loss

In [None]:
# from models.encoder_decoder import ConvoAE
# from models.focal_loss import FocalLoss
# Build the encoder/decoder on the selected device (Module.to returns the
# module itself) and train it with the class-weighted focal loss.
model = ConvoAE().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)
loss_fn = FocalLoss(weight=per_class_weights)

train_losses, test_losses = [], []

for t in range(EPOCHS):
    print(f'Epoch {t+1}\n')
    train_result = train(model, loss_fn, train_loader, optimizer)
    loss, features, labels, train_loss = train_result
    test_result = test(model, loss_fn, test_loader)
    test_pred, test_features, test_labels, test_loss = test_result
    train_losses.append(train_loss)
    test_losses.append(test_loss)
plot_losses(train_losses, test_losses)

In [None]:
fig = plt.figure(figsize=(15,6.5))
fig.suptitle("Visualizations", fontsize=20)
# samples: first evaluated batch
features = test_features[0]
labels = test_labels[0]
pred = test_pred[0]
# Turn the class scores into a predicted mask once, before the row loop.
# Doing this inside the loop re-applied argmax to its own output, which
# collapsed every prediction to zeros from the second row onwards.
pred = pred.argmax(1, keepdim=True).reshape(pred.shape[0], 1, LENGTH, WIDTH)
# Never index past the end of a batch that has fewer than `rows` samples
n_rows = min(rows, len(features))
for i in range(n_rows):
    # features, true label and predicted mask, brought to a common
    # device/dtype before concatenation
    plots = torch.cat(
        (
            features[i].cpu(),
            labels[i].cpu().to(features.dtype),
            pred[i].cpu().to(features.dtype),
        ),
        dim=0,
    )
    for j, plot in enumerate(plots):
        plot = plot.detach().numpy()
        plt.subplot(rows, cols, i*cols+j+1)
        if i==0:
            title = TITLES[j].replace(' ', '\n')
            plt.title(title)
        if j >= cols-3:
            # the last three columns are mask planes
            plt.imshow(plot, cmap=CMAP, norm=NORM)
        else:
            plt.imshow(plot, cmap='viridis')
        plt.axis('off')
plt.tight_layout()