## Model 3: Modified EfficientNet with SE Block

In [1]:
# Notebook title. The bare string is this cell's last expression, so Jupyter echoes it as the output.
"""
Burning Signals, Forecasting Wildfires
"""

'\nBurning Signals, Forecasting Wildfires\n'

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import os
import timm
from torchviz import make_dot

# ========== DATASET LOADING ==========
SEED = 42
torch.manual_seed(SEED)  # seeds weight init, loader shuffling and the random augmentations

DATA_DIR = "WildFire"
BATCH_SIZE = 32
TRAIN_IMAGE_SIZE = 160
# NOTE(review): validation runs at a higher resolution than training (224 vs 160).
# The model tolerates this (adaptive pooling in the head), but confirm it is intended.
VAL_IMAGE_SIZE = 224
# Standard ImageNet per-channel mean/std, shared by both transforms.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training images are augmented (flip, rotation, color jitter) to help the model generalize.
data_transform_train = transforms.Compose([
    transforms.Resize((TRAIN_IMAGE_SIZE, TRAIN_IMAGE_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    # Normalize each color channel so its values have roughly zero mean and unit std.
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Validation images are only resized and normalized (no augmentation).
data_transform_val = transforms.Compose([
    transforms.Resize((VAL_IMAGE_SIZE, VAL_IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

train_dataset = datasets.ImageFolder(os.path.join(DATA_DIR, "train"), transform=data_transform_train)
val_dataset = datasets.ImageFolder(os.path.join(DATA_DIR, "val"), transform=data_transform_val)
# Both splits must map class-folder names to the same integer labels.
assert train_dataset.classes == val_dataset.classes, "train/val class folders differ"

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

# Confirm dataset sizes
print(f"Training images: {len(train_dataset)}")
print(f"Validation images: {len(val_dataset)}")

Training images: 1887
Validation images: 410


# ===== Starting model definition =====

In [3]:
from torchviz import make_dot
# ========== PROXY NORMALIZED ACTIVATION ==========
# Normalizes each input sample on its own: LayerNorm over the channel vector at every pixel location of the CNN feature map.
# Used as a replacement for batch normalization, which normalizes across the samples in a batch.
class ProxyNormReLU(nn.Module):
    """Channel-wise LayerNorm at every pixel, then ReLU, then a learned per-channel shift/scale.

    Args:
        num_channels: channel count C of the incoming (B, C, H, W) feature map.

    forward(x) returns ``(relu(LN(x)) - beta) / gamma`` with the same shape as ``x``,
    where ``beta = beta_hat`` and ``gamma = max(1 + gamma_hat, 1e-5)`` per channel.
    """

    def __init__(self, num_channels):
        super().__init__()
        self.layernorm = nn.LayerNorm(num_channels)
        # Both start at zero, so initially the block is just relu(LN(x)).
        self.beta_hat = nn.Parameter(torch.zeros(1, num_channels))
        self.gamma_hat = nn.Parameter(torch.zeros(1, num_channels))

    def forward(self, x):
        channels = x.size(1)
        # Move channels last so LayerNorm normalizes each pixel's channel vector,
        # then restore the (B, C, H, W) layout.
        channels_last = x.permute(0, 2, 3, 1).contiguous()
        normed = self.layernorm(channels_last).permute(0, 3, 1, 2).contiguous()
        shift = self.beta_hat.view(1, channels, 1, 1)
        # Clamp keeps the divisor strictly positive even if gamma_hat drifts below -1.
        scale = (1 + self.gamma_hat).clamp(min=1e-5).view(1, channels, 1, 1)
        return (torch.relu(normed) - shift) / scale
        
# ========== SE BLOCK ==========
class SEBlock(nn.Module):
    """Squeeze-and-Excitation block: rescales each channel by a learned gate in (0, 1).

    Args:
        channel: number of input/output channels C.
        reduction: bottleneck ratio. The hidden layer has ``channel // reduction``
            units, floored at 1 so small channel counts still get a working bottleneck.

    forward(x): x of shape (B, C, H, W) -> tensor of the same shape.
    """

    def __init__(self, channel, reduction=4):
        super().__init__()
        # If channel < reduction, channel // reduction is 0. That gives a zero-width
        # hidden layer, so the gate would be a constant sigmoid(bias) independent of x.
        hidden = max(1, channel // reduction)
        self.pool = nn.AdaptiveAvgPool2d(1)
        # Layer order is unchanged (weights live at fc.0 / fc.2) so saved checkpoints still load.
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        # Squeeze: global average over H and W -> (B, C)
        y = self.pool(x).view(b, c)
        # Excite: per-channel gates, reshaped to broadcast over H and W
        y = self.fc(y).view(b, c, 1, 1)
        return x * y

# ========== MODEL ==========
class EfficientNetEnhanced(nn.Module):
    """EfficientNet backbone + grouped 3x3 conv + ProxyNormReLU + SE-gated classifier head.

    Args:
        num_classes: number of output logits.
        backbone: timm model name for the feature extractor. The head stripping
            below assigns ``classifier`` / ``global_pool``, which matches timm's
            EfficientNet family. Other model families may name their head differently.
        pretrained: forwarded to ``timm.create_model``. Set False to skip the
            weight download (e.g. when only drawing the architecture diagram).

    forward(x): image batch (B, 3, H, W) -> logits of shape (B, num_classes).
    """

    def __init__(self, num_classes=2, backbone="efficientnet_b0", pretrained=True):
        super().__init__()
        self.base = timm.create_model(backbone, pretrained=pretrained, features_only=False)
        # Replace timm's pooling and classifier with identities so self.base(x)
        # returns the un-pooled feature map consumed by the custom layers below.
        self.base.classifier = nn.Identity()
        self.base.global_pool = nn.Identity()

        # Width of the backbone's final feature map, read from the model instead of
        # hard-coding 1280 (the efficientnet_b0 value) so other variants also fit.
        in_channels = self.base.num_features
        # groups=16 splits channels into 16 independent groups, so in_channels and
        # 256 must both be divisible by 16.
        self.conv_grouped = nn.Conv2d(in_channels, 256, kernel_size=3, padding=1, groups=16)
        self.norm = ProxyNormReLU(256)

        self.head = nn.Sequential(
            SEBlock(256),             # channel attention on the 256-channel map
            nn.AdaptiveAvgPool2d(1),  # (B, 256, H, W) -> (B, 256, 1, 1)
            nn.Flatten(),             # -> (B, 256)
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        x = self.base(x)          # backbone feature map
        x = self.conv_grouped(x)  # project to 256 channels
        x = self.norm(x)
        x = self.head(x)          # logits
        return x

# Trace one dummy forward pass through a throwaway model and draw its autograd graph.
# These names are deliberately distinct from `model`, so re-running this cell after
# training cannot replace the trained model used by the report cell.
viz_model = EfficientNetEnhanced(num_classes=2)
dummy_images = torch.randn(1, 3, 224, 224)
dummy_logits = viz_model(dummy_images)

diagram = make_dot(dummy_logits, params=dict(viz_model.named_parameters()))
diagram.format = "png"
diagram.render("efficientnet_enhanced_diagram")  # echoes the path of the written PNG


'efficientnet_enhanced_diagram.png'

In [4]:
def train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimization pass over ``loader`` and return the mean training loss.

    Args:
        model: module to train; it is switched to train mode here.
        loader: iterable of ``(inputs, labels)`` batches (e.g. a DataLoader).
        optimizer: optimizer over the model's trainable parameters.
        criterion: loss called as ``criterion(outputs, labels)``. It is assumed to
            return a batch-mean scalar (the CrossEntropyLoss default reduction).
        device: device each batch is moved to.

    Returns:
        Loss averaged per sample. Each batch is weighted by its size, so a smaller
        final batch does not skew the epoch average.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    num_samples = 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        batch_size = labels.size(0)
        # Undo the batch-mean so batches contribute in proportion to their size.
        loss_sum += loss.item() * batch_size
        num_samples += batch_size
    if num_samples == 0:
        raise ValueError("train_one_epoch: loader yielded no samples")
    return loss_sum / num_samples

In [5]:
def evaluate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradients.

    Args:
        model: module to evaluate; it is switched to eval mode here.
        loader: iterable of ``(inputs, labels)`` batches (e.g. a DataLoader).
        criterion: loss called as ``criterion(outputs, labels)``. It is assumed to
            return a batch-mean scalar (the CrossEntropyLoss default reduction).
        device: device each batch is moved to.

    Returns:
        ``(mean_loss, accuracy, all_preds, all_labels)``:
        - ``mean_loss`` is averaged per sample, with batches weighted by size.
        - ``accuracy`` is the fraction of argmax predictions that match the labels.
        - ``all_preds`` / ``all_labels`` are flat lists in loader order.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    total = correct = 0
    loss_sum = 0.0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            batch_size = labels.size(0)
            # Undo the batch-mean so batches contribute in proportion to their size.
            loss_sum += loss.item() * batch_size
            _, predicted = torch.max(outputs, 1)  # index of the highest logit per sample
            total += batch_size
            correct += (predicted == labels).sum().item()
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    if total == 0:
        raise ValueError("evaluate: loader yielded no samples")
    return loss_sum / total, correct / total, all_preds, all_labels

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = EfficientNetEnhanced(num_classes=2).to(device)

# Class weights for CrossEntropyLoss: uniform for now; adjust if the classes are imbalanced.
class_weights = torch.tensor([1.0, 1.0], device=device)
criterion = nn.CrossEntropyLoss(weight=class_weights)

NUM_EPOCHS = 15
# Start below any reachable accuracy so the first epoch always writes a checkpoint.
# Otherwise the reload below could pick up a stale file from an earlier run.
best_acc = float("-inf")
best_model_path = "best_model.pth"

print("Training on 160x160 images...")
# Re-assert the non-augmented validation transform and rebuild its loader.
val_dataset.transform = data_transform_val
val_loader = DataLoader(val_dataset, batch_size=32)

# Phase 1: freeze the pretrained encoder so only the new layers are trained.
# NOTE(review): requires_grad=False does not freeze normalization statistics.
# train_one_epoch() puts the whole model in train mode, so any BatchNorm layers
# in the encoder (presumably present in EfficientNet) keep updating their running
# stats. Confirm this is intended.
for param in model.base.parameters():
    param.requires_grad = False

# Give the optimizer only the parameters that can actually receive gradients.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.AdamW(trainable_params, lr=1e-4)
# Halve the LR when validation loss has not improved for 2 epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2, factor=0.5)

for epoch in range(NUM_EPOCHS):
    train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc, all_preds, all_labels = evaluate(model, val_loader, criterion, device)

    print(f"Epoch {epoch+1}/{NUM_EPOCHS} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_acc:.4f}")

    # Keep the checkpoint with the best validation accuracy seen so far.
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), best_model_path)
        print("Best model saved")

    scheduler.step(val_loss)


# Restore the best checkpoint. map_location lets a GPU-saved file load on any device.
model.load_state_dict(torch.load(best_model_path, map_location=device))
model.to(device)
model.eval()

# all_preds / all_labels from this pass are consumed by the report cell.
val_loss, val_acc, all_preds, all_labels = evaluate(model, val_loader, criterion, device)
print(f"Validation Accuracy after fine-tuning: {val_acc:.4f}")

Training on 160x160 images...
Epoch 1/15 - Train Loss: 0.6627, Val Loss: 0.6078, Val Accuracy: 0.6561
Best model saved


In [None]:
# Label names come from the validation ImageFolder's sorted class folders. This keeps
# the report aligned with the integer labels even if the folders are renamed.
class_names = val_dataset.classes
# Pin the label set so the report and matrix stay well-formed even if a class is
# absent from all_labels / all_preds.
label_ids = list(range(len(class_names)))
print("\nClassification Report:")
print(classification_report(all_labels, all_preds, labels=label_ids, target_names=class_names))

# Confusion matrix: rows are true classes, columns are predicted classes
cm = confusion_matrix(all_labels, all_preds, labels=label_ids)
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=class_names, yticklabels=class_names, ax=ax)
ax.set_title("Confusion Matrix - EfficientNetEnhanced")
ax.set_xlabel("Predicted")
ax.set_ylabel("True")
os.makedirs("results", exist_ok=True)
fig.savefig("results/confusion_matrix_efficientnet_enhanced.png")
plt.show()
plt.close(fig)

# Model size (includes the frozen backbone parameters)
num_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {num_params:,}")

# Save the weights currently held by `model`
torch.save(model.state_dict(), "efficientnet_g16_ln_pn.pth")
print("Model saved as 'efficientnet_g16_ln_pn.pth'")