In [None]:
# ===============================
# Wildfire Prediction - Week 1 Project
# ===============================

# 1. Import necessary libraries
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# 2. Load the dataset
# Preprocessing: resize every image to 224x224, then ToTensor() converts it to a
# float tensor of shape (C, H, W) with values scaled to [0, 1].
# NOTE: no per-channel mean/std normalization (transforms.Normalize) is applied.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Adjust these to match your folder structure. DATA_DIR is the folder that holds
# the train/valid/test sub-folders ("" means the current working directory).
DATA_DIR = ""
BATCH_SIZE = 16

train_dataset = datasets.ImageFolder(root=os.path.join(DATA_DIR, "train"), transform=transform)
valid_dataset = datasets.ImageFolder(root=os.path.join(DATA_DIR, "valid"), transform=transform)
test_dataset  = datasets.ImageFolder(root=os.path.join(DATA_DIR, "test"),  transform=transform)

# Only the training split is shuffled; validation/test keep a fixed order so
# evaluation results are reproducible.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader  = DataLoader(test_dataset,  batch_size=BATCH_SIZE, shuffle=False)

# Equivalent of .info() -> Dataset summary
print("Number of training images:", len(train_dataset))
print("Number of validation images:", len(valid_dataset))
print("Number of testing images:", len(test_dataset))
print("Classes:", train_dataset.classes)

# ImageFolder derives class_to_idx separately for each split from its sub-folder
# names, so a missing or extra class folder in one split would silently change
# the integer labels. Fail loudly if the splits disagree with the training set.
for split_name, split in (("validation", valid_dataset), ("testing", test_dataset)):
    if split.class_to_idx != train_dataset.class_to_idx:
        raise ValueError(
            f"Class mapping of the {split_name} set {split.class_to_idx} "
            f"does not match the training set {train_dataset.class_to_idx}"
        )

# Equivalent of .describe() -> Check one sample
img, label = train_dataset[0]
print("Sample image shape (C,H,W):", img.shape)
print("Sample image label:", train_dataset.classes[label])


# Equivalent of .isnull().sum() -> Check for missing/broken files.
# ImageFolder only indexes files that existed when it scanned the folders, so an
# existence check alone almost never finds anything. Decoding each file with the
# dataset's own loader also catches truncated/unreadable images.
# This decodes every image once, so it can take a while on large splits.
def _find_bad_files(dataset):
    """Return the paths in ``dataset.samples`` that are missing or fail to decode."""
    bad_paths = []
    for path, _ in dataset.samples:
        try:
            dataset.loader(path)
        except (OSError, SyntaxError, ValueError):
            # FileNotFoundError (missing file) is a subclass of OSError.
            bad_paths.append(path)
    return bad_paths


# `missing_files` keeps its original name: the bad paths in the training set.
missing_files = _find_bad_files(train_dataset)
print("Missing or corrupted files in training set:", len(missing_files))
print("Missing or corrupted files in validation set:", len(_find_bad_files(valid_dataset)))
print("Missing or corrupted files in testing set:", len(_find_bad_files(test_dataset)))

# 3. Explore the dataset

# 🔹 Class distribution for training set
# Integer label of every training sample; IDs index into train_dataset.classes.
labels = [target for _, target in train_dataset.samples]
class_names = train_dataset.classes
class_ids = list(range(len(class_names)))

# Print the raw counts as well, so any class imbalance is visible without the plot.
class_counts = np.bincount(labels, minlength=len(class_names))
for name, count in zip(class_names, class_counts):
    print(f"{name}: {count}")

# Pin the bar order to every class ID so bar positions always line up with the
# tick labels, even if some class has no training samples.
sns.countplot(x=labels, order=class_ids)
plt.xticks(ticks=class_ids, labels=class_names)
plt.title("Class Distribution in Training Dataset")
plt.xlabel("Class")
plt.ylabel("Count")
plt.show()


Number of training images: 30250
Number of validation images: 6300
Number of testing images: 6300
Classes: ['nowildfire', 'wildfire']
Sample image shape (C,H,W): torch.Size([3, 224, 224])
Sample image label: nowildfire
