In [None]:
import torch
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset
from torch import nn
import torch.nn.functional as F  # Import F for softmax
import torch.optim as optim
from PIL import Image, ImageFile
import os
import numpy as np
from tqdm.notebook import tqdm
import json
import matplotlib.pyplot as plt

# Parameters
identifier = 'experiment_01'  # suffix embedded in every output file/directory name below

# Class label -> folder holding that class's images. Key order matters:
# CityDataset numbers the classes by enumerating these keys.
folders = {
    'CityA': '../data/ma-boston/buildings',
    'CityB': '../data/nc-charlotte/buildings',
    'CityC': '../data/ny-manhattan/buildings',
    'CityD': '../data/pa-pittsburgh/buildings'
}
# Derived from `folders` (instead of being listed by hand) so the index -> name
# lookup used when decoding predictions always matches the dataset's label indices.
class_names = list(folders)
num_classes = len(class_names)

output_folder = 'softmax-output'

# Per-channel constants handed to transforms.Normalize
# (presumably the ImageNet statistics matching the pre-trained weights -- confirm).
normalize_mean = [0.485, 0.456, 0.406]
normalize_std = [0.229, 0.224, 0.225]

# Batch size of the feature-extraction DataLoader; the training DataLoader
# further down is built with its own batch size.
batch_size = 128
train_split_ratio = 0.8  # fraction of the dataset assigned to the training split
num_epochs = 10
learning_rate = 0.001
checkpoint_interval = 1  # save a checkpoint every N epochs

# Output locations, all rooted in `output_folder` and tagged with `identifier`
checkpoint_dir = os.path.join(output_folder, f'checkpoints-{identifier}')
model_save_path = os.path.join(output_folder, f'trained-model-{identifier}.pth')
loss_log_path = os.path.join(output_folder, f'loss-log-{identifier}.json')
training_params_path = os.path.join(output_folder, f'training-params-{identifier}.json')
feature_file_name = f'city-features-{identifier}.npy'
new_image_path = '../data/ny-brooklyn/buildings/buildings_1370.jpg' # test an image
predictions_output_file = os.path.join(output_folder, f'predictions-{identifier}.txt')

# Allow loading of truncated images
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Define output folder
os.makedirs(output_folder, exist_ok=True)

# Define a custom dataset class
class CityDataset(Dataset):
    """Image-classification dataset built from one folder of images per class.

    Args:
        folders: Mapping of class name -> directory path. Classes are numbered
            in the mapping's iteration order (see `class_to_idx`).
        transform: Optional callable applied to each PIL image in
            `__getitem__`. It may be assigned after construction through the
            `transform` attribute.

    Attributes:
        image_paths: Paths of every discovered image, grouped by class and
            sorted by filename within each class.
        labels: Integer class index for each entry of `image_paths`.
        class_to_idx: Mapping of class name -> integer label.
    """

    # Accepted file extensions; matched case-insensitively.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, folders, transform=None):
        self.image_paths = []
        self.labels = []
        self.transform = transform
        self.class_to_idx = {class_name: idx for idx, class_name in enumerate(folders)}

        for class_name, folder in folders.items():
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # sample order (and any index-based split) stable across runs.
            for filename in sorted(os.listdir(folder)):
                # lower() so that files such as 'IMG_01.JPG' are not skipped
                if filename.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.image_paths.append(os.path.join(folder, filename))
                    self.labels.append(self.class_to_idx[class_name])

    def __len__(self):
        """Return the number of discovered images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return `(image, label)` for sample `idx`.

        The image is always converted to 3-channel RGB before `transform` is
        applied, so grayscale, palette or RGBA files yield the same channel
        count as ordinary colour photos.
        """
        image_path = self.image_paths[idx]
        label = self.labels[idx]
        # The context manager closes the file handle; convert() returns an
        # independent, fully loaded copy that stays valid afterwards.
        with Image.open(image_path) as source:
            image = source.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label

# Create dataset
dataset = CityDataset(folders)
if len(dataset) == 0:
    # Fail with an explicit message instead of an IndexError on image_paths[0]
    raise RuntimeError(f"No images found in the configured folders: {list(folders.values())}")

# Automatically detect the input image size
first_image_path = dataset.image_paths[0]
with Image.open(first_image_path) as first_image:
    # PIL reports size as (width, height)
    image_width, image_height = first_image.size
# torchvision's Resize expects (height, width), so the PIL order is swapped here;
# passing first_image.size directly would transpose non-square images.
image_size = (image_height, image_width)

# Define transformations: resize, convert to tensor, normalize
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=normalize_mean, std=normalize_std),
])

# Update dataset with transform
dataset.transform = transform
# shuffle=False keeps the extracted features in the same order as dataset.image_paths
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

# Load a pre-trained ResNet50 model using the 'weights' parameter
weights = models.ResNet50_Weights.DEFAULT
model = models.resnet50(weights=weights)

# Modify the final layer to match the number of classes
model.fc = nn.Linear(model.fc.in_features, num_classes)

# Set device to MPS if available, otherwise fall back to CUDA or CPU
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

model.to(device)

# Save model outputs as features
def extract_and_save_features(model, data_loader, file_name):
    """Run `model` over every batch of `data_loader` and save outputs with labels.

    The model is switched to eval mode and gradients are disabled. Batches are
    moved to the module-level `device`; the results are written under the
    module-level `output_folder`.

    Args:
        model: Network whose forward outputs are stored as the features.
        data_loader: Iterable yielding `(images, labels)` batches.
        file_name: Name of the file created inside `output_folder`.

    The saved object is a dict with keys 'features' and 'labels'. Because a
    dict is passed to np.save, NumPy stores it as a pickled object array, so
    reading it back requires np.load(..., allow_pickle=True).
    """
    model.eval()
    output_chunks, label_chunks = [], []

    with torch.no_grad():
        batches = tqdm(data_loader, desc="Extracting features", leave=False)
        for batch_images, batch_labels in batches:
            batch_outputs = model(batch_images.to(device))
            output_chunks.append(batch_outputs.cpu().numpy())
            label_chunks.append(batch_labels.numpy())

    # Stitch the per-batch arrays back together along the sample axis
    payload = {
        'features': np.concatenate(output_chunks, axis=0),
        'labels': np.concatenate(label_chunks, axis=0),
    }
    target_path = os.path.join(output_folder, file_name)
    np.save(target_path, payload)

# Extract and save features
# NOTE: this call runs before any training and after model.fc was replaced by a
# fresh nn.Linear, so the saved outputs are the logits of that untrained head.
extract_and_save_features(model, data_loader, feature_file_name)

# Model training
def train_and_save_model(model, train_loader, num_epochs, checkpoint_interval, checkpoint_dir):
    """Train `model` with cross-entropy loss and Adam, then persist the run's artifacts.

    Reads the module-level settings `learning_rate`, `device`, `model_save_path`,
    `loss_log_path`, `training_params_path`, `output_folder` and `identifier`.

    Args:
        model: Network optimised in place. Batches are moved to `device` here,
            the model itself is not.
        train_loader: DataLoader yielding `(images, labels)` batches.
        num_epochs: Number of passes over `train_loader`.
        checkpoint_interval: Save a checkpoint every this many epochs. Zero,
            a negative value or None disables checkpointing.
        checkpoint_dir: Directory receiving `checkpoint_epoch_<n>.pth` files.

    Raises:
        ValueError: If `train_loader` yields no batches.

    Side effects:
        Writes the checkpoints, the final state_dict (`model_save_path`), the
        per-epoch loss list (`loss_log_path`), the run parameters
        (`training_params_path`) and a loss-curve PNG in `output_folder`.
    """
    if len(train_loader) == 0:
        raise ValueError("train_loader yields no batches; nothing to train on")

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    loss_log = []

    # A zero/None interval would otherwise raise on the modulo below.
    checkpointing = bool(checkpoint_interval) and checkpoint_interval > 0
    if checkpointing:
        # Created once up front rather than on every checkpointed epoch
        os.makedirs(checkpoint_dir, exist_ok=True)

    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        samples_seen = 0
        progress_bar = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{num_epochs}]", leave=False)
        for images, labels in progress_bar:
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # criterion returns the batch mean, so weight it by the batch size
            batch_count = images.size(0)
            running_loss += loss.item() * batch_count
            samples_seen += batch_count
            progress_bar.set_postfix(loss=loss.item())

        # Normalise by the samples actually processed; this equals
        # len(train_loader.dataset) unless the loader drops or resamples items.
        epoch_loss = running_loss / samples_seen
        loss_log.append(epoch_loss)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')

        # Save checkpoint
        if checkpointing and (epoch + 1) % checkpoint_interval == 0:
            checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print(f'Checkpoint saved at {checkpoint_path}')

    # Save the final model weights
    torch.save(model.state_dict(), model_save_path)

    # Save the loss log
    with open(loss_log_path, 'w') as f:
        json.dump(loss_log, f)

    # Save the training parameters
    training_params = {
        'num_epochs': num_epochs,
        'learning_rate': learning_rate,
        'checkpoint_interval': checkpoint_interval,
        'batch_size': train_loader.batch_size,
        'identifier': identifier
    }
    with open(training_params_path, 'w') as f:
        json.dump(training_params, f)

    # Plot the loss log on an explicit figure so no global pyplot state is left behind
    fig, ax = plt.subplots()
    ax.plot(range(1, len(loss_log) + 1), loss_log, marker='o')
    ax.set_title('Training Loss')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.grid(True)
    fig.savefig(os.path.join(output_folder, f'training_loss_plot-{identifier}.png'))
    plt.close(fig)

# Train and save the model
train_size = int(train_split_ratio * len(dataset))
val_size = len(dataset) - train_size
# Seed the split so the train/validation membership is identical on every run
split_generator = torch.Generator().manual_seed(42)
# NOTE(review): val_dataset is not used by any code shown in this notebook.
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size], generator=split_generator
)
# The training loader uses its own batch size (16), not the `batch_size` parameter.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

train_and_save_model(model, train_loader, num_epochs, checkpoint_interval, checkpoint_dir)

# Classification using softmax
# Load model weights
def _latest_checkpoint(directory):
    """Return the path of the highest-epoch `checkpoint_epoch_<n>.pth` in `directory`, or None."""
    prefix, suffix = 'checkpoint_epoch_', '.pth'
    if not os.path.isdir(directory):
        return None
    best_epoch, best_path = -1, None
    for name in os.listdir(directory):
        if not (name.startswith(prefix) and name.endswith(suffix)):
            continue
        epoch_text = name[len(prefix):-len(suffix)]
        if epoch_text.isdigit() and int(epoch_text) > best_epoch:
            best_epoch, best_path = int(epoch_text), os.path.join(directory, name)
    return best_path


# checkpoint_dir is a directory, so torch.load cannot be pointed at it directly;
# pick the most recent checkpoint file inside it instead.
latest_checkpoint = _latest_checkpoint(checkpoint_dir)
if latest_checkpoint is not None:
    print("Loading checkpoint...")
    # map_location lets weights saved on one device load on whichever is active now
    checkpoint = torch.load(latest_checkpoint, map_location=device)
    # train_and_save_model stores a bare state_dict; a wrapped
    # {'model_state_dict': ...} checkpoint is accepted as well.
    model.load_state_dict(checkpoint.get('model_state_dict', checkpoint))
else:
    print("Loading full model...")
    model.load_state_dict(torch.load(model_save_path, map_location=device))

# Function to classify a new image
def classify_new_image(image_path, model, transform):
    """Classify one image file and return per-class confidences.

    Uses the module-level `device` for inference and `class_names` to turn
    output indices into labels.

    Args:
        image_path: Path of the image to classify.
        model: Classifier; it is switched to eval mode.
        transform: Callable converting a PIL image into a model-ready tensor.

    Returns:
        List of `(class_name, confidence_percent)` tuples, one per model
        output, sorted by confidence in descending order.
    """
    model.eval()
    # Force 3-channel RGB so grayscale/palette/RGBA files match the channel
    # count the normalisation and the network expect; the context manager
    # closes the file while convert() hands back an independent copy.
    with Image.open(image_path) as input_image:
        rgb_image = input_image.convert('RGB')
    # unsqueeze(0) adds the batch dimension for a single image
    input_tensor = transform(rgb_image).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(input_tensor)
        probabilities = F.softmax(output, dim=1)
        probabilities = probabilities.cpu().numpy().flatten()

    predictions = [(class_names[i], prob * 100) for i, prob in enumerate(probabilities)]
    predictions.sort(key=lambda x: x[1], reverse=True)
    return predictions

# Example of classifying a new image with progress bar
# (a one-step bar: it only signals that the single classification has finished)
pbar = tqdm(total=1, desc="Classifying new image", leave=False)
with pbar:
    predictions = classify_new_image(new_image_path, model, transform)
    pbar.update(1)

# Save predictions to a file, echoing each line to the notebook output as well
with open(predictions_output_file, 'w') as f:
    for label, percentage in predictions:
        line = f'Predicted class: {label}, Confidence: {percentage:.2f}%'
        f.write(line + '\n')
        print(line)

### Model training

In [None]:
def train_and_save_model(model, train_loader, num_epochs=10, checkpoint_interval=1,
                         checkpoint_dir='checkpoints', learning_rate=0.001):
    """Train `model` with cross-entropy loss and Adam, saving results to fixed file names.

    This definition replaces the `train_and_save_model` defined earlier in the
    notebook once this cell has run. Unlike that version it writes to fixed
    paths relative to the working directory and produces no plot.

    Args:
        model: Network optimised in place. Batches are moved to the
            module-level `device` here, the model itself is not.
        train_loader: DataLoader yielding `(images, labels)` batches.
        num_epochs: Number of passes over `train_loader`.
        checkpoint_interval: Save a checkpoint every this many epochs. Zero,
            a negative value or None disables checkpointing.
        checkpoint_dir: Directory receiving `checkpoint_epoch_<n>.pth` files.
        learning_rate: Adam learning rate; also recorded in the saved
            parameters so the two can no longer drift apart.

    Raises:
        ValueError: If `train_loader` yields no batches.

    Side effects:
        Writes the checkpoints, 'resnet50_city_classifier.pth',
        'loss_log.json' and 'training_params.json'.
    """
    if len(train_loader) == 0:
        raise ValueError("train_loader yields no batches; nothing to train on")

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    loss_log = []

    # A zero/None interval would otherwise raise on the modulo below.
    checkpointing = bool(checkpoint_interval) and checkpoint_interval > 0
    if checkpointing:
        # Created once up front rather than on every checkpointed epoch
        os.makedirs(checkpoint_dir, exist_ok=True)

    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        samples_seen = 0
        progress_bar = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{num_epochs}]", leave=False)
        for images, labels in progress_bar:
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # criterion returns the batch mean, so weight it by the batch size
            batch_count = images.size(0)
            running_loss += loss.item() * batch_count
            samples_seen += batch_count
            progress_bar.set_postfix(loss=loss.item())

        # Normalise by the samples actually processed; this equals
        # len(train_loader.dataset) unless the loader drops or resamples items.
        epoch_loss = running_loss / samples_seen
        loss_log.append(epoch_loss)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')

        # Save checkpoint
        if checkpointing and (epoch + 1) % checkpoint_interval == 0:
            checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print(f'Checkpoint saved at {checkpoint_path}')

    # Save the final model weights
    torch.save(model.state_dict(), 'resnet50_city_classifier.pth')

    # Save the loss log
    with open('loss_log.json', 'w') as f:
        json.dump(loss_log, f)

    # Save the training parameters
    training_params = {
        'num_epochs': num_epochs,
        'learning_rate': learning_rate,
        'checkpoint_interval': checkpoint_interval,
        'batch_size': train_loader.batch_size
    }
    with open('training_params.json', 'w') as f:
        json.dump(training_params, f)

# Train and save the model
# NOTE(review): `model` has already been trained by the earlier
# train_and_save_model call, so this continues training the same weights
# rather than starting fresh -- confirm this is intended.
train_size = int(train_split_ratio * len(dataset))
val_size = len(dataset) - train_size
# Seed the split so the train/validation membership is identical on every run
split_generator = torch.Generator().manual_seed(42)
# NOTE(review): val_dataset is not used by any code shown in this notebook.
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size], generator=split_generator
)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

train_and_save_model(model, train_loader)

### Classification

In [None]:
import torch.nn.functional as F  # Import F for softmax

# Paths to the model and checkpoint
model_path = 'resnet50_city_classifier.pth'
# NOTE(review): no code shown in this notebook writes this file; the branch
# below only triggers if it was produced elsewhere.
checkpoint_path = 'resnet50_city_classifier_checkpoint.pth'

# Load model weights
# isfile (not exists) so a directory with this name is never handed to torch.load
if os.path.isfile(checkpoint_path):
    print("Loading checkpoint...")
    # map_location lets weights saved on one device load on whichever is active now
    checkpoint = torch.load(checkpoint_path, map_location=device)
    # Accept both a wrapped {'model_state_dict': ...} checkpoint and a bare state_dict
    model.load_state_dict(checkpoint.get('model_state_dict', checkpoint))
else:
    print("Loading full model...")
    model.load_state_dict(torch.load(model_path, map_location=device))

# Function to classify a new image
def classify_new_image(image_path, model, transform):
    """Classify one image file and return per-class confidences.

    Uses the module-level `device` for inference and `class_names` to turn
    output indices into labels.

    Args:
        image_path: Path of the image to classify.
        model: Classifier; it is switched to eval mode.
        transform: Callable converting a PIL image into a model-ready tensor.

    Returns:
        List of `(class_name, confidence_percent)` tuples, one per model
        output, sorted by confidence in descending order.
    """
    model.eval()
    # Force 3-channel RGB so grayscale/palette/RGBA files match the channel
    # count the normalisation and the network expect; the context manager
    # closes the file while convert() hands back an independent copy.
    with Image.open(image_path) as input_image:
        rgb_image = input_image.convert('RGB')
    # unsqueeze(0) adds the batch dimension for a single image
    input_tensor = transform(rgb_image).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(input_tensor)
        probabilities = F.softmax(output, dim=1)
        probabilities = probabilities.cpu().numpy().flatten()

    predictions = [(class_names[i], prob * 100) for i, prob in enumerate(probabilities)]
    predictions.sort(key=lambda x: x[1], reverse=True)
    return predictions

# Example of classifying a new image with progress bar
# (a one-step bar: it only signals that the single classification has finished)
new_image_path = '../data/ny-brooklyn/buildings/buildings_1370.jpg'
pbar = tqdm(total=1, desc="Classifying new image", leave=False)
with pbar:
    predictions = classify_new_image(new_image_path, model, transform)
    pbar.update(1)

# Save predictions to a file, echoing each line to the notebook output as well
output_file_path = os.path.join(output_folder, 'predictions.txt')
with open(output_file_path, 'w') as f:
    for label, percentage in predictions:
        line = f'Predicted class: {label}, Confidence: {percentage:.2f}%'
        f.write(line + '\n')
        print(line)