In [3]:
import torch
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset, Subset
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image, ImageFile
import os
import numpy as np
from tqdm.notebook import tqdm
import json
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold
import random

# Parameters to tweak
batch_size = 64  # 32
learning_rate = 1e-5 #0.001
num_epochs = 200  # 5
checkpoint_interval = 50

# Directories
identifier = f"softmax-resnet_{batch_size}-batch_{learning_rate}-learning_{num_epochs}-epochs"
class_names = ['Boston', 'Charlotte', 'Manhattan', 'Pittsburgh']
folders = {
    'Boston': '../data/ma-boston/buildings',
    'Charlotte': '../data/nc-charlotte/buildings',
    'Manhattan': '../data/ny-manhattan/buildings',
    'Pittsburgh': '../data/pa-pittsburgh/buildings'
}
output_folder = os.path.join('softmax-output', identifier)
checkpoint_dir = os.path.join(output_folder, 'checkpoints')
model_save_path = os.path.join(output_folder, 'trained-model.pth')
loss_log_path = os.path.join(output_folder, 'loss-log.json')
new_image_path = '../data/ny-brooklyn/buildings/buildings_1370.jpg'
predictions_output_file = os.path.join(output_folder, 'predictions.txt')

# More Parameters
normalize_mean = [0.485, 0.456, 0.406]
normalize_std = [0.229, 0.224, 0.225]
num_classes = len(class_names)
weight_decay = 1e-5

# Allow loading of truncated images
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Define output folder
os.makedirs(output_folder, exist_ok=True)

# Define a custom dataset class
class CityDataset(Dataset):
    def __init__(self, folders, transform=None, max_images_per_class=500):  # Limit images per class
        self.image_paths = []
        self.labels = []
        self.transform = transform
        self.class_to_idx = {class_name: idx for idx, class_name in enumerate(folders.keys())}

        for class_name, folder in folders.items():
            class_images = [os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(('.jpg', '.jpeg', '.png'))]
            selected_images = random.sample(class_images, min(max_images_per_class, len(class_images)))
            self.image_paths.extend(selected_images)
            self.labels.extend([self.class_to_idx[class_name]] * len(selected_images))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        label = self.labels[idx]
        image = Image.open(image_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label

# Create dataset
dataset = CityDataset(folders)

# Define transformations with data augmentation
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to a smaller, fixed size
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=normalize_mean, std=normalize_std),
])

# Update dataset with transform
dataset.transform = transform

# Set device
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

# Load a pre-trained ResNet18 model (smaller than ResNet50)
weights = models.ResNet18_Weights.DEFAULT
model = models.resnet18(weights=weights)

# Modify the final layer to match the number of classes
model.fc = nn.Linear(model.fc.in_features, num_classes)
model.to(device)

# Model training function
def train_and_save_model(model, train_loader, val_loader, num_epochs, checkpoint_interval, checkpoint_dir):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    train_loss_log = []
    val_loss_log = []
    val_accuracy_log = []

    for epoch in range(num_epochs):
        # Training
        model.train()
        running_loss = 0.0
        for images, labels in tqdm(train_loader, desc=f"Epoch [{epoch+1}/{num_epochs}]", leave=False):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * images.size(0)

        epoch_loss = running_loss / len(train_loader.dataset)
        train_loss_log.append(epoch_loss)

        # Validation
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item() * images.size(0)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_epoch_loss = val_loss / len(val_loader.dataset)
        val_accuracy = correct / total
        val_loss_log.append(val_epoch_loss)
        val_accuracy_log.append(val_accuracy)

        print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {epoch_loss:.4f}, Val Loss: {val_epoch_loss:.4f}, Val Accuracy: {val_accuracy:.4f}')

        # Save checkpoint
        if (epoch + 1) % checkpoint_interval == 0:
            checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth')
            os.makedirs(checkpoint_dir, exist_ok=True)
            torch.save(model.state_dict(), checkpoint_path)

    # Save the final model weights
    torch.save(model.state_dict(), model_save_path)

    # Save the loss and accuracy logs
    with open(loss_log_path, 'w') as f:
        json.dump({
            'train_loss': train_loss_log,
            'val_loss': val_loss_log,
            'val_accuracy': val_accuracy_log
        }, f)

    # Plot the loss and accuracy curves
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(range(1, num_epochs + 1), train_loss_log, label='Train Loss')
    plt.plot(range(1, num_epochs + 1), val_loss_log, label='Validation Loss')
    plt.title('Training and Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(range(1, num_epochs + 1), val_accuracy_log)
    plt.title('Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')

    plt.tight_layout()
    plt.savefig(os.path.join(output_folder, 'training_curves.png'))
    plt.close()

# K-fold cross-validation function
def k_fold_cross_validation(dataset, num_folds=3):  # Reduced number of folds
    kfold = KFold(n_splits=num_folds, shuffle=True, random_state=42)
    fold_results = []

    for fold, (train_ids, val_ids) in enumerate(kfold.split(dataset), 1):
        print(f"Fold {fold}")
        
        train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
        val_subsampler = torch.utils.data.SubsetRandomSampler(val_ids)
        
        train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_subsampler)
        val_loader = DataLoader(dataset, batch_size=batch_size, sampler=val_subsampler)
        
        model = models.resnet18(weights=weights)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        model.to(device)
        
        train_and_save_model(model, train_loader, val_loader, num_epochs, checkpoint_interval, 
                             os.path.join(checkpoint_dir, f'fold_{fold}'))
        
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        
        accuracy = correct / total
        fold_results.append(accuracy)
        print(f"Fold {fold} accuracy: {accuracy:.4f}")
    
    print(f"Average accuracy across folds: {sum(fold_results) / len(fold_results):.4f}")

# Run k-fold cross-validation
k_fold_cross_validation(dataset)

# Function to predict on a new image
def predict_image(model, image_path, transform):
    model.eval()
    image = Image.open(image_path).convert('RGB')
    image = transform(image).unsqueeze(0).to(device)
    
    with torch.no_grad():
        outputs = model(image)
        probabilities = F.softmax(outputs, dim=1)[0]
        predicted_class = torch.argmax(probabilities).item()
    
    return probabilities, predicted_class

# Predict on a new image
new_image_probabilities, new_image_class = predict_image(model, new_image_path, transform)

# Print and save predictions
predictions = [f"{class_names[i]}: {prob:.2f}" for i, prob in enumerate(new_image_probabilities)]
print(f"Predictions for {new_image_path}:")
print(f"Predicted class: {class_names[new_image_class]}")
print("Class probabilities:")
for pred in predictions:
    print(pred)

with open(predictions_output_file, 'w') as f:
    f.write(f"Predictions for {new_image_path}:\n")
    f.write(f"Predicted class: {class_names[new_image_class]}\n")
    f.write("Class probabilities:\n")
    for pred in predictions:
        f.write(f"{pred}\n")

print(f"Predictions saved to {predictions_output_file}")

Using device: mps
Fold 1


Epoch [1/50]:   0%|          | 0/21 [00:00<?, ?it/s]

Epoch [1/50], Train Loss: 0.8915, Val Loss: 0.3843, Val Accuracy: 0.4633


Epoch [2/50]:   0%|          | 0/21 [00:00<?, ?it/s]

Epoch [2/50], Train Loss: 0.5957, Val Loss: 0.2719, Val Accuracy: 0.7121


Epoch [3/50]:   0%|          | 0/21 [00:00<?, ?it/s]