In [1]:
import os
import torch
import pandas as pd
from PIL import Image
from tqdm import tqdm
from sklearn.preprocessing import LabelEncoder
from torchvision import transforms,models
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torch import nn, optim
from sklearn.model_selection import train_test_split

In [2]:
CSV_PATH = "d:/multiagent-disaster-reasoning/data/fusion_cleaned.csv"
IMAGE_ROOT = "d:/multiagent-disaster-reasoning/data/CrisisMMD_v2.0/data_image"
BATCH_SIZE = 32
EPOCHS = 10
LEARNING_RATE = 1e-4
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
class DisasterImageDataset(Dataset):
    def __init__(self, df, image_root, transform=None):
        self.df = df
        self.image_root = image_root
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        for _ in range(10):
            row = self.df.iloc[idx]
            image_path = os.path.join(self.image_root, row["image_name"])
            try:
                image = Image.open(image_path).convert("RGB")
                if self.transform:
                    image = self.transform(image)
                label = row["label"]
                return image, label
            except Exception as e:
                print(f" skipping unreadable image : {image_path}")
                idx = (idx+1) % len(self.df)
        raise RuntimeError("All 10 attempts to load image failed.")


df = pd.read_csv(CSV_PATH)

label_encoder = LabelEncoder()
df['label'] = label_encoder.fit_transform(df['disaster_type'])

num_classes = df['label'].nunique()
print("Classes:", label_encoder.classes_)

train_df, val_df = train_test_split(df, test_size=0.1, stratify=df['label'], random_state=42)

class_counts=train_df['label'].value_counts().sort_index().values
class_weights=1. / class_counts
sample_weights= train_df['label'].map(lambda x:class_weights[x])
sampler = WeightedRandomSampler(sample_weights,len(sample_weights),replacement=True)

train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

train_dataset=DisasterImageDataset(train_df,IMAGE_ROOT, transform =train_transform)
val_dataset = DisasterImageDataset(val_df,IMAGE_ROOT,transform = val_transform)

train_loader = DataLoader(train_dataset,batch_size=BATCH_SIZE,sampler =sampler,num_workers =0)
val_loader= DataLoader(val_dataset, batch_size =BATCH_SIZE,shuffle=False,num_workers =0)

model = models.resnet50(weights ="IMAGENET1K_V1")
model.fc = nn.Linear(model.fc.in_features,num_classes)
model = model.to(DEVICE)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr = LEARNING_RATE)

print("Verifying first batch...")
for imgs, labels in train_loader:
    print(f"Batch loaded : {imgs.shape}, labels: {labels.shape}")
    break

for epoch in range(EPOCHS):
    model.train()
    running_loss, correct = 0.0, 0
    for imgs, labels in tqdm(train_loader, desc = f"Epoch {epoch+1}/{EPOCHS}"):
        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels.long())
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        correct += (outputs.argmax(1) == labels).sum().item()

    acc = correct/len(train_dataset)
    print(f"Train Loss : {running_loss: .4f} | Accuracy: {acc:.4f}")

    model.eval()
    val_correct = 0
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
            outputs = model(imgs)
            val_correct += (outputs.argmax(1) == labels.long()).sum().item()
    val_acc = val_correct/len(val_dataset)
    print(f"Validation Accuracy : {val_acc:.4f}")

torch.save(model.state_dict(), "vision_agent_resnet50.pth")
print("model saved as vision_agent_resnet50.pth")


Classes: ['california_wildfires' 'hurricane_harvey' 'hurricane_irma'
 'hurricane_maria' 'iraq_iran_earthquake' 'srilanka_floods']
Verifying first batch...
Batch loaded : torch.Size([32, 3, 224, 224]), labels: torch.Size([32])


Epoch 1/10: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 240/240 [32:24<00:00,  8.10s/it]


Train Loss :  79.6857 | Accuracy: 0.8882
Validation Accuracy : 0.8911


Epoch 2/10: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 240/240 [32:16<00:00,  8.07s/it]


Train Loss :  33.1956 | Accuracy: 0.9525
Validation Accuracy : 0.8876


Epoch 3/10: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 240/240 [37:19<00:00,  9.33s/it]


Train Loss :  24.9271 | Accuracy: 0.9674
Validation Accuracy : 0.9087


Epoch 4/10: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 240/240 [33:21<00:00,  8.34s/it]


Train Loss :  21.5960 | Accuracy: 0.9697
Validation Accuracy : 0.8993


Epoch 5/10: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 240/240 [37:24<00:00,  9.35s/it]


Train Loss :  17.5957 | Accuracy: 0.9771
Validation Accuracy : 0.8911


Epoch 6/10: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 240/240 [52:52<00:00, 13.22s/it]


Train Loss :  14.5955 | Accuracy: 0.9807
Validation Accuracy : 0.9052


Epoch 7/10: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 240/240 [54:27<00:00, 13.61s/it]


Train Loss :  12.6550 | Accuracy: 0.9841
Validation Accuracy : 0.9122


Epoch 8/10: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 240/240 [51:34<00:00, 12.89s/it]


Train Loss :  13.2956 | Accuracy: 0.9829
Validation Accuracy : 0.8934


Epoch 9/10: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 240/240 [50:52<00:00, 12.72s/it]


Train Loss :  12.9791 | Accuracy: 0.9836
Validation Accuracy : 0.9063


Epoch 10/10: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 240/240 [51:09<00:00, 12.79s/it]


Train Loss :  10.5683 | Accuracy: 0.9867
Validation Accuracy : 0.9040
model saved as vision_agent_resnet50.pth
