Task: Fine-tune ConvNeXt V2 (facebook/convnextv2-huge-22k-384)

Architecture: ConvNeXt V2 builds on the success of ConvNeXt V1, which was designed to improve the efficiency and performance of convolutional networks, making them competitive with transformer models.

Disaster imagery can include intricate details (e.g., damaged buildings, roads, etc.), and ConvNeXt V2 is particularly good at capturing such local patterns due to its advanced convolutional layers.
State-of-the-art: ConvNeXt V2 is one of the most powerful convolutional models, with an architecture designed for large-scale image classification benchmarks such as ImageNet. It has shown excellent performance on both high-level and fine-grained image tasks.
Efficiency: While it's large, ConvNeXt V2 is optimized for efficiency compared to some transformer models, making it more manageable in terms of computational cost for training.

In [None]:
# !pip install datasets

In [None]:
import torch
from torch.optim import AdamW
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader, IterableDataset
from datasets import load_dataset
from torchvision import transforms
from transformers import AutoModelForImageClassification, AutoImageProcessor
from tqdm.auto import tqdm

In [None]:
# Open the LADI v2 dataset in streaming mode (streaming=True) instead of a regular load
ds = load_dataset(
    "MITLL/LADI-v2-dataset",
    streaming=True,
)

In [None]:
# Names of the binary label columns; their order fixes the order of the
# targets in each multi-label vector and the number of model outputs.
label_keys = [
    'bridges_any',
    'buildings_any',
    'buildings_affected_or_greater',
    'buildings_minor_or_greater',
    'debris_any',
    'flooding_any',
    'flooding_structures',
    'roads_any',
    'roads_damage',
    'trees_any',
    'trees_damage',
    'water_any',
]

In [None]:
# Checkpoint identifier shared by the image processor and the model
model_name = "facebook/convnextv2-huge-22k-384"
processor = AutoImageProcessor.from_pretrained(model_name)

# Request one output per entry in label_keys. ignore_mismatched_sizes=True
# lets loading proceed even though the checkpoint's classifier weights do not
# match that output size.
model = AutoModelForImageClassification.from_pretrained(
    model_name,
    ignore_mismatched_sizes=True,
    num_labels=len(label_keys),
)

In [None]:
# Use CUDA when it is available, otherwise the CPU, and place the model there
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)

In [None]:
# Gradient scaler used by the mixed-precision training path, and the AdamW
# optimizer over every model parameter
scaler = GradScaler()
optimizer = AdamW(model.parameters(), lr=2e-5)

In [None]:
# Preprocessing applied to every image before it reaches the model:
# resize to 384x384, convert to a tensor, then normalize with ImageNet statistics
image_transforms = transforms.Compose([
    transforms.Resize(size=(384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

In [None]:
# StreamDataset class for handling image and label processing
class StreamDataset(IterableDataset):
    """Iterable dataset yielding ``(transformed_image, labels)`` for one split.

    ``dataset`` is indexed by ``split_name`` to obtain an iterable of items;
    each item must provide an ``'image'`` entry plus one entry per name in
    ``label_keys``.
    """

    def __init__(self, dataset, split_name, label_keys, image_transforms):
        self.dataset, self.split_name = dataset, split_name
        self.label_keys, self.image_transforms = label_keys, image_transforms

    def process_item(self, item):
        """Turn one raw item into ``(transformed_image, [int label, ...])``."""
        raw_image = item['image']

        # Labels are read in label_keys order and coerced to plain ints
        targets = []
        for name in self.label_keys:
            targets.append(int(item[name]))

        return self.image_transforms(raw_image), targets

    def __iter__(self):
        """Lazily process every item of the selected split, in order."""
        yield from map(self.process_item, self.dataset[self.split_name])

# Function to process the dataset for training
def process_dataset(model, dataset, split_name, label_keys, image_transforms, optimizer=None, train=False, batch_size=8):
    """Run one full pass over a dataset split, training or evaluating the model.

    Relies on the notebook-level globals ``device`` (where tensors are placed)
    and ``scaler`` (the ``GradScaler`` used for the mixed-precision update).

    Args:
        model: Model whose forward output exposes ``.logits``.
        dataset: Mapping from split name to an iterable of items.
        split_name: Key of the split to iterate.
        label_keys: Item keys read, in order, to build each label vector.
        image_transforms: Callable applied to every item's image.
        optimizer: Optimizer to step; required when ``train`` is True.
        train: If True, update the model; otherwise evaluate without gradients.
        batch_size: Number of samples per batch.

    Returns:
        Tuple ``(mean_loss, all_labels, all_preds)``: the per-sample mean of
        the BCE-with-logits loss, a list with one label array per sample, and
        a list with one array of sigmoid probabilities per sample.
        ``mean_loss`` is NaN when the split yields no samples.

    Raises:
        ValueError: If ``train`` is True and no optimizer was given.
    """
    if train and optimizer is None:
        raise ValueError("An optimizer is required when train=True")

    # train(True) == train(), train(False) == eval()
    model.train(train)

    # Build the loss module once instead of on every batch
    criterion = torch.nn.BCEWithLogitsLoss()

    running_loss = 0.0
    all_labels = []
    all_preds = []

    processed_dataset = StreamDataset(dataset, split_name, label_keys, image_transforms)
    # The collate function only transposes the batch into (images, labels) tuples;
    # stacking into tensors happens in the loop below
    loader = DataLoader(processed_dataset, batch_size=batch_size, collate_fn=lambda x: tuple(zip(*x)))

    for batch_images, batch_labels in tqdm(loader):
        batch_images = torch.stack(batch_images).to(device)
        batch_labels = torch.tensor(batch_labels, dtype=torch.float32).to(device)

        if train:
            optimizer.zero_grad()
            with autocast():
                outputs = model(batch_images)
                loss = criterion(outputs.logits, batch_labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            with torch.no_grad():
                outputs = model(batch_images)
                loss = criterion(outputs.logits, batch_labels)

        # loss is a mean over the batch; weight it by the batch size so that
        # dividing by the total sample count gives a true per-sample mean
        # (the last batch may be smaller than batch_size)
        running_loss += loss.item() * batch_labels.size(0)

        # Cast to float32 before the sigmoid: logits produced under autocast
        # may be half precision
        probabilities = torch.sigmoid(outputs.logits.detach().float()).cpu().numpy()

        all_preds.extend(probabilities)
        all_labels.extend(batch_labels.cpu().numpy())

        torch.cuda.empty_cache()

    num_samples = len(all_preds)
    mean_loss = running_loss / num_samples if num_samples else float('nan')
    return mean_loss, all_labels, all_preds

# Training configuration: number of passes over the data and samples per batch
num_epochs = 5
batch_size = 8

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")

    # One optimisation pass over the 'train' split
    train_loss, train_labels, train_preds = process_dataset(
        model,
        ds,
        'train',
        label_keys,
        image_transforms,
        optimizer,
        train=True,
        batch_size=batch_size,
    )
    print(f"Training Loss: {train_loss:.4f}")

    # The collected labels/predictions are not used here; drop them and
    # release cached GPU memory before validating
    del train_labels, train_preds
    torch.cuda.empty_cache()

    # One evaluation pass over the 'validation' split (no optimizer, train=False)
    val_loss, val_labels, val_preds = process_dataset(
        model,
        ds,
        'validation',
        label_keys,
        image_transforms,
        batch_size=batch_size,
    )
    print(f"Validation Loss: {val_loss:.4f}")

    del val_labels, val_preds
    torch.cuda.empty_cache()

print("Training complete. You can now evaluate the model using the evaluation pipeline.")