## Model 1: Convolutional Neural Network (CNN) with Attention Layer

This section demonstrates how to build and train a Convolutional Neural Network (CNN) with an attention mechanism. The model is designed for image classification tasks. We will use the Keras library to construct the model, compile it, and train it on a dataset.

### Steps:
1. **Build the CNN model with an attention layer.**
2. **Compile the model with an appropriate optimizer and loss function.**
3. **Train the model on a given dataset.**
4. **Plot the training accuracy and loss.**
5. **Evaluate the model's performance.**
6. **Visualize the attention weights on sample images.**

In [1]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader, Dataset
import matplotlib.pyplot as plt
import numpy as np
from typing import Tuple, Dict,List

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [3]:
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from typing import List, Tuple

# Define the transformation pipeline
transform: transforms.Compose = transforms.Compose([
    transforms.ToTensor(),  # Convert PIL Image to tensor
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize with mean and std for each channel
])

# Load CIFAR-10 dataset
# trainset and testset are instances of torchvision.datasets.CIFAR10
trainset: torchvision.datasets.CIFAR10 = torchvision.datasets.CIFAR10(
    root='./data',  # Directory to store the dataset
    train=True,     # This is the training set
    download=True,  # Download if not present
    transform=transform  # Apply the defined transforms
)
testset: torchvision.datasets.CIFAR10 = torchvision.datasets.CIFAR10(
    root='./data', 
    train=False,    # This is the test set
    download=True, 
    transform=transform
)

# Set batch size for data loading
batch_size: int = 64

# Create data loaders
# DataLoader handles batching, shuffling, and parallel data loading
trainloader: DataLoader = DataLoader(
    trainset, 
    batch_size=batch_size, 
    shuffle=True,   # Shuffle the training data
    num_workers=2   # Number of subprocesses for data loading
)
testloader: DataLoader = DataLoader(
    testset, 
    batch_size=batch_size, 
    shuffle=False,  # No need to shuffle test data
    num_workers=2
)

# Class names for CIFAR-10 dataset (to simulate satellite image classes)
class_names: List[str] = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# If you need to access the raw data (similar to the original cell)
# Convert uint8 to float and normalize to [0, 1]
train_data: torch.Tensor = torch.tensor(trainset.data).float() / 255.0
test_data: torch.Tensor = torch.tensor(testset.data).float() / 255.0
train_labels: torch.Tensor = torch.tensor(trainset.targets)
test_labels: torch.Tensor = torch.tensor(testset.targets)

# Move data to GPU if available
device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_data = train_data.to(device)
test_data = test_data.to(device)
train_labels = train_labels.to(device)
test_labels = test_labels.to(device)

# Print shapes for verification
print(f"Train data shape: {train_data.shape}")
print(f"Test data shape: {test_data.shape}")
print(f"Train labels shape: {train_labels.shape}")
print(f"Test labels shape: {test_labels.shape}")

Files already downloaded and verified
Files already downloaded and verified
Train data shape: torch.Size([50000, 32, 32, 3])
Test data shape: torch.Size([10000, 32, 32, 3])
Train labels shape: torch.Size([50000])
Test labels shape: torch.Size([10000])


In [4]:
# Get a batch of test images and labels
test_images, test_labels = next(iter(testloader))

In [5]:
class CNNModel(nn.Module):
    """
    A Convolutional Neural Network (CNN) model implemented in PyTorch.
    """
    def __init__(self):
        super(CNNModel, self).__init__()
        
        # Convolutional layers
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        
        # Pooling layer
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        
        # Fully connected layers
        self.fc1 = nn.Linear(64 * 8 * 8, 64)  # Changed from 64 * 4 * 4 to 64 * 8 * 8
        self.fc2 = nn.Linear(64, 10)
        
        # Activation functions
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the model.

        Args:
        x (torch.Tensor): Input tensor of shape (batch_size, 3, 32, 32)

        Returns:
        torch.Tensor: Output tensor of shape (batch_size, 10)
        """

        #x = x.permute(0, 3, 1, 2)
        # Apply convolutional layers with ReLU activation and max pooling
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = self.relu(self.conv3(x))

        # print("Shape before reshape:", x.shape)  # Add this line
        
        # Flatten the output for the fully connected layers
        x = x.reshape(x.size(0), -1)
        
        # Apply fully connected layers with ReLU activation
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        
        # Apply softmax to the output
        x = self.softmax(x)
        
        return x

def create_cnn_model() -> nn.Module:
    """
    Creates a Convolutional Neural Network (CNN) model.

    Returns:
    nn.Module: The constructed CNN model.
    """
    return CNNModel()

In [None]:

# Create the model instance
cnn_model = CNNModel().to('cuda:0')  # Adjust num_classes if needed

def train_model(model, train_loader, test_loader, criterion, optimizer, epochs):
    # Set the device (GPU if available, else CPU)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print ('device = ',device)
    model = model.to(device)  # Move the model to the selected device
    
    # Initialize dictionary to store training history
    history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}
    
    for epoch in range(epochs):
        # Training phase
        model.train()  # Set model to training mode
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)  # Move data to device
            #print (inputs.shape)
            #inputs = inputs.permute(0, 3, 1, 2)  # Permute here, before passing to the model
            #print ('After permute, shape is ',inputs.shape)
            optimizer.zero_grad()  # Reset gradients
           
            #inputs = inputs.permute(0, 3, 1, 2)
            
            outputs = model(inputs)  # Forward pass
            loss = criterion(outputs, labels)  # Compute loss
            loss.backward()  # Backpropagation
            optimizer.step()  # Update weights
            
            # Accumulate loss and accuracy statistics
            train_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            train_total += labels.size(0)
            train_correct += predicted.eq(labels).sum().item()
        
        # Calculate average training loss and accuracy
        train_loss = train_loss / len(train_loader.dataset)
        train_acc = train_correct / train_total
        
        # Validation phase
        model.eval()  # Set model to evaluation mode
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        
        with torch.no_grad():  # Disable gradient calculation for validation
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                #inputs = inputs.permute(0, 3, 1, 2)  # Permute here as well
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                
                # Accumulate validation loss and accuracy statistics
                val_loss += loss.item() * inputs.size(0)
                _, predicted = outputs.max(1)
                val_total += labels.size(0)
                val_correct += predicted.eq(labels).sum().item()
        
        # Calculate average validation loss and accuracy
        val_loss = val_loss / len(test_loader.dataset)
        val_acc = val_correct / val_total
        
        # Store the results in the history dictionary
        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)
        
        # Print epoch results
        print(f"Epoch {epoch+1}/{epochs}: "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
    
    return history

# Define the loss function (criterion)
criterion = nn.CrossEntropyLoss()

# Define the optimizer
optimizer = optim.Adam(cnn_model.parameters(), lr=0.001)  # You can adjust the learning rate

# Train the model
history = train_model(cnn_model, trainloader, testloader, criterion, optimizer, epochs=100)

device =  cuda
Epoch 1/100: Train Loss: 2.0750, Train Acc: 0.3778, Val Loss: 1.9730, Val Acc: 0.4870


In [None]:
def rescale_image(img: torch.Tensor) -> np.ndarray:
    """
    Rescale the image to be between 0 and 1.
    """
    img = img.numpy()
    img = (img - img.min()) / (img.max() - img.min())
    img = np.transpose(img, (1, 2, 0))  # Change from (C, H, W) to (H, W, C)
    return img

In [None]:
# Set the model to evaluation mode
cnn_model.eval()

# Get a batch of test images and labels using the testloader
test_loader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)
test_images, test_labels = next(iter(test_loader))

# Move test_images to the same device as the model
device = next(cnn_model.parameters()).device
test_images = test_images.to(device)
test_labels = test_labels.to(device)

# No need to permute dimensions as the DataLoader already provides the correct format
# test_images is already in shape [N, C, H, W]

# Get a batch of test images (assuming test_images is a PyTorch tensor)
test_batch = test_images[:5]  # Take the first 5 images

In [None]:

# Get predictions
with torch.no_grad():  # Disable gradient computation
    outputs = cnn_model(test_batch)
    probabilities = torch.softmax(outputs, dim=1)
    predictions = probabilities.cpu().numpy()  # Convert to numpy array

# Now 'predictions' is a numpy array containing the probabilities for each class

In [None]:
def display_predictions(predictions: np.ndarray, test_images: torch.Tensor, test_labels: torch.Tensor, class_names: List[str], num_images=5) -> None:
    """
    Displays the actual and predicted labels for a subset of test images.
    """
    # Ensure tensors are on CPU
    test_images = test_images.cpu()
    test_labels = test_labels.cpu()
    
    num_images = min(num_images, len(test_images))
    for i in range(num_images):
        plt.figure(figsize=(10, 4))
        
        # Display the original image with actual label
        plt.subplot(1, 2, 1)
        plt.imshow(rescale_image(test_images[i]))
        plt.axis('off')
        plt.title(f'Actual: {class_names[test_labels[i].item()]}')
        
        # Display the original image with predicted label
        plt.subplot(1, 2, 2)
        plt.imshow(rescale_image(test_images[i]))
        plt.axis('off')
        plt.title(f'Predicted: {class_names[np.argmax(predictions[i])]}')
        
        plt.show()

# Assume predictions, test_images, test_labels, and class_names are already defined

# Call the function
display_predictions(predictions, test_images, test_labels, class_names)

## Generative Adversarial Network (GAN)

This section demonstrates how to create, compile, train, and generate images using a Generative Adversarial Network (GAN). GANs consist of two neural networks, a generator and a discriminator, that compete against each other to produce realistic images.

### Steps:
1. **Create the generator and discriminator models.**
2. **Compile the models with appropriate optimizers and loss functions.**
3. **Train the GAN on a dataset of images obtained from the European Space Agency**
4. **Generate and display new images using the trained generator.**

# Applications of GANs in Satellite Imagery Processing

GANs (Generative Adversarial Networks) can significantly enhance the processing and analysis of satellite imagery in various ways:

1. **Super-resolution**
   - Enhance the resolution of low-quality satellite images
   - Useful for working with older or lower-resolution satellite data

2. **Image-to-image translation**
   - Transform images between different domains (e.g., daytime to nighttime)
   - Convert between different spectral bands

3. **Cloud removal**
   - Remove cloud cover from images to reveal ground features

4. **Data augmentation**
   - Generate synthetic satellite imagery to expand training datasets

5. **Change detection**
   - Compare GAN-generated "expected" images with actual imagery to detect changes

6. **Filling in missing data**
   - Reconstruct missing or corrupted parts of satellite images

7. **Multi-temporal analysis**
   - Generate time series of satellite imagery for studying changes over time

8. **Domain adaptation**
   - Adapt imagery from one geographic region to match characteristics of another

9. **Sensor fusion**
   - Combine data from multiple satellite sensors into composite images

10. **Anomaly detection**
    - Identify unusual features or patterns in satellite imagery

These applications can improve the quality and usability of satellite data, enabling more accurate analysis in fields such as:

- Environmental monitoring
- Urban planning
- Agriculture
- Disaster response

The specific application of GANs in your project would depend on your particular challenges and goals. For instance:

- Use super-resolution for low-quality imagery
- Apply cloud removal if cloud cover is a significant issue
- Implement change

In [1]:
import requests
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import os
import zipfile
from PIL import Image
import io
import random
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from typing import List, Tuple, Optional
from torchvision.utils import save_image
from torch.utils.data import TensorDataset, DataLoader, Dataset

In [2]:
# Check if CUDA is available and set the device accordingly
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In this setup, we're not generating anomalous images. Instead, we're using the existing EuroSAT dataset to create a dataset suitable for anomaly detection. Here's how it works:

1. We define certain classes as "normal" (e.g., 'Forest', 'AnnualCrop').
2. All other classes in the dataset are considered potential "anomalies".
3. We create a dataset that consists mostly of "normal" images, with a small percentage of "anomalous" images mixed in.

This approach is based on the idea that in a real-world anomaly detection scenario, you typically have mostly normal data with occasional anomalies.

Here's a breakdown of what the code is doing:

1. It downloads all images from the EuroSAT dataset.
2. It separates the images into two categories:
   - Normal: Images from the classes specified in `normal_classes`
   - Anomaly: Images from all other classes
3. It keeps all the "normal" images.
4. It randomly selects a subset of the "anomaly" images to match the specified `anomaly_ratio`.
5. It combines these normal and anomaly images to create the final dataset.

The resulting dataset can be used to train and evaluate anomaly detection models, including GANs. Here's how you might use it with a GAN:

1. Train the GAN only on the normal images (label 0).
2. Use the trained GAN to detect anomalies:
   - The GAN should be able to reconstruct normal images well.
   - It should struggle to reconstruct anomalous images (which it wasn't trained on).
   - This difference in reconstruction quality can be used to detect anomalies.

We're not generating new anomalous images. Instead, we're using real images from different classes as stand-ins for anomalies. This approach allows us to create a controlled dataset for developing and testing anomaly detection methods.

In [3]:
class EuroSATAnomalyDataset(Dataset):
    def __init__(self, 
                 image_size: int = 64, 
                 normal_classes: Optional[List[str]] = None, 
                 anomaly_ratio: float = 0.1, 
                 root_dir: Optional[str] = None):
        self.image_size: int = image_size
        self.normal_classes: List[str] = normal_classes if normal_classes else ['AnnualCrop']
        self.anomaly_ratio: float = anomaly_ratio
        
        self.root_dir: str = r"c:\Users\bbrel\esa_webinar"
        
        self.data_dir: str = os.path.join(self.root_dir, 'data')
        self.zip_path: str = os.path.join(self.data_dir, 'EuroSAT.zip')
        self.extract_path: str = os.path.join(self.data_dir, '2750')

        self.transform: transforms.Compose = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        self.data: List[Tuple[str, int]] = self.download_and_extract_dataset()

    def download_and_extract_dataset(self) -> List[Tuple[str, int]]:
        if not os.path.exists(self.zip_path):
            raise FileNotFoundError(f"EuroSAT.zip not found at {self.zip_path}. Please ensure the file is in the correct location.")

        if not os.path.exists(self.extract_path):
            print(f"Extracting dataset to {self.extract_path}")
            with zipfile.ZipFile(self.zip_path, 'r') as zip_ref:
                zip_ref.extractall(self.data_dir)

        data: List[Tuple[str, int]] = []
        all_classes: List[str] = [d for d in os.listdir(self.extract_path) if os.path.isdir(os.path.join(self.extract_path, d))]
        
        if not all_classes:
            raise FileNotFoundError(f"No class directories found in {self.extract_path}. Please check the dataset structure.")

        anomaly_classes: List[str] = [c for c in all_classes if c not in self.normal_classes]

        for class_name in all_classes:
            class_path: str = os.path.join(self.extract_path, class_name)
            is_normal: bool = class_name in self.normal_classes
            for img_name in os.listdir(class_path):
                img_path: str = os.path.join(class_path, img_name)
                if is_normal or (not is_normal and random.random() < self.anomaly_ratio):
                    data.append((img_path, 0 if is_normal else 1))

        return data

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        img_path, label = self.data[idx]
        image: Image.Image = Image.open(img_path).convert('RGB')
        image: torch.Tensor = self.transform(image)
        return image, label



In [4]:
def load_eurosat_anomaly_dataset(
    batch_size: int = 32, 
    image_size: int = 64, 
    normal_classes: Optional[List[str]] = None, 
    anomaly_ratio: float = 0.1, 
    num_workers: int = 4
) -> Tuple[DataLoader, EuroSATAnomalyDataset]:
    dataset: EuroSATAnomalyDataset = EuroSATAnomalyDataset(image_size, normal_classes, anomaly_ratio)
    dataloader: DataLoader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    return dataloader, dataset


In [5]:
# Usage
config = {
    'latent_dim': 100,
    'lr': 0.0002,
    'beta1': 0.5,
    'beta2': 0.999,
    'num_epochs': 100,
    'batch_size': 64,
    'image_size': 64,
    'print_every': 100,
    'save_interval': 10
}

In [6]:
normal_classes = None  
anomaly_ratio = 0.1  
dataloader, dataset = load_eurosat_anomaly_dataset(config['batch_size'], config['image_size'], normal_classes, anomaly_ratio)

print(f'Dataloader length: {len(dataloader)}')
print(f'Dataloader batch size: {dataloader.batch_size}')

# Print some information about the dataset
print(f"Dataset size: {len(dataset)}")
print(f"Normal classes: {dataset.normal_classes}")
print(f"Anomaly ratio: {dataset.anomaly_ratio}")

# Check the balance of normal and anomaly samples
normal_count = sum(1 for _, label in dataset if label == 0)
anomaly_count = sum(1 for _, label in dataset if label == 1)
print(f"Normal samples: {normal_count}")
print(f"Anomaly samples: {anomaly_count}")

Dataset size: 5352
Normal classes: ['AnnualCrop']
Anomaly ratio: 0.1
Normal samples: 3000
Anomaly samples: 2352


In [7]:
class Generator(nn.Module):
    
    
    """
    Generator model for a Generative Adversarial Network (GAN).


    This Generator takes a random noise vector of size 100 and transforms it 
    into a 64x64 RGB image. The architecture consists of several transposed 
    convolutional layers with increasing sizes, followed by batch normalization 
    and ReLU activation. The final layer uses a Tanh activation to produce 
    pixel values in the range [-1, 1].

    The network structure is as follows:
    1. Input: 100-dimensional noise vector
    2. Transposed Conv2d: 512 channels, 4x4
    3. Transposed Conv2d: 256 channels, 8x8
    4. Transposed Conv2d: 128 channels, 16x16
    5. Transposed Conv2d: 64 channels, 32x32
    6. Output layer: 3 channels, 64x64 with tanh activation
   The output is a 64x64x3 image.
   """

    def __init__(self, latent_dim=100, channels=3):
        super(Generator, self).__init__()
        self.latent_dim = latent_dim
        self.model = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
         
        """
        Forward pass of the generator.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, latent_dim)
        Returns:
            torch.Tensor: Generated images of shape (batch_size, 3, 64, 64)
        """
        
        # Reshape the input noise vector to (batch_size, latent_dim, 1, 1) and
        # pass it through the generator model. This transformation allows the
        # first transposed convolution to interpret each latent vector as a 1x1 "image"
        # with latent_dim channels, which it then upscales to the final 64x64 RGB image.
        
        img = self.model(x.view(x.size(0), self.latent_dim, 1, 1))
        return img



In [8]:
def build_generator() -> nn.Module:
    
    """
    Builds and returns an instance of the Generator model.

    Returns:
        nn.Module: An instance of the Generator model.
    """
    return Generator.to(device)

In [9]:
class Discriminator(nn.Module):
    
    """
    Discriminator model for a Generative Adversarial Network (GAN).

    This Discriminator takes a 64x64 RGB image and classifies it as real 
    or fake. The architecture consists of several fully connected layers with 
    convolutional layers with decreasing sizes, using LeakyReLU activations. 
    The final layer uses a Sigmoid activation to produce a probability output.


    The network structure is as follows:
    1. Input: 64x64x3 image
    2. Conv2d: 64 channels, 4x4 kernel, stride 2, padding 1
    3. Conv2d: 128 channels, 4x4 kernel, stride 2, padding 1
    4. Conv2d: 256 channels, 4x4 kernel, stride 2, padding 1
    5. Conv2d: 512 channels, 4x4 kernel, stride 2, padding 1
    6. Conv2d: 1 channel, 4x4 kernel, stride 1, padding 0
    7. Sigmoid activation

    Attributes:
        model (nn.Sequential): The sequential container of layers.

    Methods:
        forward(x): Defines the computation performed at every call.
    """

    def __init__(self, channels: int = 3) -> None:
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
           nn.Conv2d(channels, 64, 4, 2, 1, bias=False),
           nn.LeakyReLU(0.2, inplace=True),
           nn.Conv2d(64, 128, 4, 2, 1, bias=False),
           nn.BatchNorm2d(128),
           nn.LeakyReLU(0.2, inplace=True),
           nn.Conv2d(128, 256, 4, 2, 1, bias=False),
           nn.BatchNorm2d(256),
           nn.LeakyReLU(0.2, inplace=True),
           nn.Conv2d(256, 512, 4, 2, 1, bias=False),
           nn.BatchNorm2d(512),
           nn.LeakyReLU(0.2, inplace=True),
           nn.Conv2d(512, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the discriminator.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, 3, 64, 64)

        Returns:
            torch.Tensor: Probability of input being real, shape (batch_size, 1, 1, 1)
        """
        return self.model(x)



In [10]:
def build_discriminator() -> nn.Module:
    """
    Builds and returns an instance of the Discriminator model.

    Returns:
        nn.Module: An instance of the Discriminator model.
    """
    return Discriminator.to(device)

In [13]:
class DiscriminatorTrainer:

    """
    A class for training a Generative Adversarial Network (GAN) for anomaly detection.

    This class implements the training process for a GAN, specifically tailored for
    anomaly detection tasks. It trains a generator to produce realistic normal samples
    and a discriminator to distinguish between real and generated samples. After training,
    the GAN can be used to detect anomalies by comparing input images to generated ones.

    The class handles the entire training process, including:
    - Initializing and managing the generator and discriminator models
    - Setting up optimizers and loss functions
    - Implementing the GAN training loop
    - Saving generated images and trained models
    - Providing a method for anomaly detection using the trained models

    Attributes:
        generator (nn.Module): The generator model of the GAN.
        discriminator (nn.Module): The discriminator model of the GAN.
        config (dict): A dictionary containing all hyperparameters and configuration settings.
        device (torch.device): The device (CPU or GPU) on which to perform computations.
        criterion (nn.Module): The loss function (typically BCELoss for GANs).
        optimizer_G (torch.optim.Optimizer): The optimizer for the generator.
        optimizer_D (torch.optim.Optimizer): The optimizer for the discriminator.

    The config dictionary should contain the following keys:
    - 'latent_dim': Dimension of the latent space for the generator
    - 'lr': Learning rate for the optimizers
    - 'beta1', 'beta2': Beta parameters for the Adam optimizer
    - 'num_epochs': Number of epochs to train
    - 'batch_size': Batch size for training
    - 'print_every': How often to print training progress
    - 'save_interval': How often to save generated images

    Methods:
        train_step(real_images): Performs a single training step for both generator and discriminator.
        train(dataloader): Runs the full training loop over the entire dataset.
        save_images(epoch): Saves a batch of generated images.
        save_models(): Saves the trained generator and discriminator models.
        detect_anomalies(images, threshold): Uses the trained GAN to detect anomalies in input images.
    """
    
    def __init__(self, generator: nn.Module, discriminator: nn.Module, config: dict):
        
        """
        Initializes the DiscriminatorTrainer with given models and configuration.

        Args:
            generator (nn.Module): The generator model.
            discriminator (nn.Module): The discriminator model.
            config (dict): Configuration dictionary containing hyperparameters.
        """
        
        self.generator = generator
        self.discriminator = discriminator
        self.config = config
        
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        # Move models to the appropriate device
        self.generator.to(self.device)
        self.discriminator.to(self.device)
        
        self.criterion = nn.BCELoss()
        self.optimizer_G = optim.Adam(self.generator.parameters(), 
                                      lr=config['lr'], 
                                      betas=(config['beta1'], config['beta2']))
        self.optimizer_D = optim.Adam(self.discriminator.parameters(), 
                                      lr=config['lr'], 
                                      betas=(config['beta1'], config['beta2']))

    def train_step(self, real_images: torch.Tensor) -> Tuple[float, float]:

        """
        Performs a single training step for both the generator and discriminator.

        Args:
            real_images (torch.Tensor): A batch of real images to train on.

        Returns:
            tuple: A tuple containing the discriminator loss and generator loss for this step.
        """
        print("Starting train step")
        batch_size = real_images.size(0)
        real_images = real_images.to(self.device)

        # Train Discriminator
        self.optimizer_D.zero_grad()
        
        # Real images
        # Create a tensor of ones with shape (batch_size, 1, 1, 1)
        # This represents the "real" labels for the discriminator
        # The shape (batch_size, 1, 1, 1) matches the output of the discriminator
        # We use ones because we want the discriminator to identify these as real image
        real_labels = torch.ones(batch_size, 1, 1, 1).to(self.device)
        

        # Pass the real images through the discriminator
        # The discriminator attempts to classify these images as real or fake
        # output will have shape (batch_size, 1, 1, 1), where each value represents
        # the discriminator's confidence that the corresponding image is real 
        output = self.discriminator(real_images)
        
        # Calculate the loss for the real images
        # We use Binary Cross Entropy (BCE) loss, comparing the discriminator's output
        # to the real_labels (all ones)
        # This loss will be low if the discriminator correctly identifies the real images,
        # and high if it misclassifies them as fake
        d_loss_real = self.criterion(output, real_labels)
        
        # Fake images
        # Generate random noise as input for the generator
        # The shape is (batch_size, latent_dim), where latent_dim is defined in the config
        # This noise serves as the seed for generating fake images
        z = torch.randn(batch_size, self.config['latent_dim']).to(self.device)

        # Use the generator to create fake images from the random noise
        # fake_images will have the same shape as real images in the dataset
        fake_images = self.generator(z)
        print("Generated fake images")
        # Create a tensor of zeros with shape (batch_size, 1, 1, 1)
        # This represents the "fake" labels for the discriminator
        # We use zeros because we want the discriminator to identify these as fake images
        fake_labels = torch.zeros(batch_size, 1, 1, 1).to(self.device)

        # Pass the fake images through the discriminator
        # We use detach() to prevent gradients from flowing back to the generator
        # This is because we're currently training the discriminator, not the generator 
        output = self.discriminator(fake_images.detach())

        # Calculate the loss for the fake images
        # We use Binary Cross Entropy (BCE) loss, comparing the discriminator's output
        # to the fake_labels (all zeros)
        # This loss will be low if the discriminator correctly identifies the fake images,
        # and high if it misclassifies them as real 
        d_loss_fake = self.criterion(output, fake_labels)
        print("Computed discriminator loss")
        # Combine the losses from real and fake images
        # This gives us the total loss for the discriminator
        d_loss = d_loss_real + d_loss_fake

        # Compute gradients of the total loss with respect to the discriminator's parameters
        # This prepares for the backward pass in neural network training 
        d_loss.backward()
         
        # Update the discriminator's parameters using the computed gradients
        # This is the actual learning step for the discriminator
        self.optimizer_D.step()

        # Train Generator
        # Reset the gradients for the generator's parameters to zero
        # This is necessary before computing new gradients 
        self.optimizer_G.zero_grad()

        # Pass the fake images through the discriminator
        # Note: We're using the same fake images generated earlier, but this time we don't detach them
        # because we want to compute gradients with respect to the generator's parameters
        output = self.discriminator(fake_images)

        # Calculate the generator's loss
        # We use the same criterion (BCE loss) as before, but now we compare the output to real_labels
        # This trains the generator to produce images that the discriminator will classify as real
        g_loss = self.criterion(output, real_labels)
        print("Computed generator loss")

        # Compute gradients of the loss with respect to the generator's parameters 
        g_loss.backward()
         
        # Update the generator's parameters using the computed gradients
        self.optimizer_G.step()

        print("Finished train step")

        return d_loss.item(), g_loss.item()

    def train(self, dataloader: torch.utils.data.DataLoader) -> None:
        """
        Runs the full training loop over the entire dataset for the specified number of epochs.

        Args:
            dataloader (torch.utils.data.DataLoader): The DataLoader containing the training data.
        """

        print('Testing dataloader')
        for i, (real_images, labels) in enumerate(dataloader):
            print(f'Loaded batch {i+1}, shape: {real_images.shape}')
            if i == 2:  # Just print the first 3 batches
                break
        print('Dataloader test complete')
        total_batches = len(dataloader) * self.config['num_epochs']
        print ('total_batches = ',total_batches)
        print ('In train')
        with tqdm(total=total_batches, desc="Training Progress") as pbar:
            print ('Starting for loop')
            for epoch in range(self.config['num_epochs']):
                print(f'Starting epoch {epoch+1}')
                for i, (real_images, labels) in enumerate(dataloader):
                    print(f'Batch {i+1}: Loading data')
                    print(f'Batch {i+1}: Real images shape: {real_images.shape}')
                    print(f'Batch {i+1}: Calling train step')
                    d_loss, g_loss = self.train_step(real_images)

                    # Update tqdm progress bar
                    pbar.update(1)
                    pbar.set_postfix({
                        'Epoch': f"{epoch+1}/{self.config['num_epochs']}",
                        'Batch': f"{i+1}/{len(dataloader)}",
                        'D_loss': f"{d_loss:.4f}",
                        'G_loss': f"{g_loss:.4f}"
                    })

                # Save generated images
                if (epoch + 1) % self.config['save_interval'] == 0:
                    self.save_images(epoch)
                    pbar.write(f"Saved images for epoch {epoch+1}")

        print("Training finished!")
        self.save_models()

    def save_images(self, epoch: int) -> None:

        """
        Saves a batch of generated images.

        Args:
            epoch (int): The current epoch number, used in the filename of the saved image.
        """
        
        with torch.no_grad():
            fake = self.generator(torch.randn(64, self.config['latent_dim'], 1, 1, device=self.device))
            save_image(fake.detach(), f"fake_images_epoch_{epoch+1}.png", normalize=True)

    def save_models(self) -> None:

        """
        Saves the trained generator and discriminator models to disk.
        """
        
        torch.save(self.generator.state_dict(), "generator.pth")
        torch.save(self.discriminator.state_dict(), "discriminator.pth")

    def detect_anomalies(self, images: torch.Tensor, threshold: float) -> torch.Tensor:

        """
        Uses the trained GAN to detect anomalies in input images.

        Args:
            images (torch.Tensor): The input images to check for anomalies.
            threshold (float): The threshold value for determining anomalies.

        Returns:
            torch.Tensor: A boolean tensor indicating whether each input image is an anomaly.
        """
        
        self.generator.eval()
        self.discriminator.eval()

        # Disable gradient computation to save memory and speed up calculations
        # This is appropriate since we're not training, just doing inference
        with torch.no_grad():

        # Generate random noise as input for the generator
        # The shape matches the batch size of input images and the latent dimension from config
        # Adding two extra dimensions (1, 1) at the end to match potential generator input requirements
            z = torch.randn(images.size(0), self.config['latent_dim'], 1, 1).to(self.device)
            
            # Use the generator to create reconstructed images from the random noise
            # These reconstructed images should ideally be similar to normal, non-anomalous images
            reconstructed = self.generator(z)
            
            # Calculate the mean squared error between original images and reconstructed images
            # This error is computed for each image separately, across all channels and pixels
            # The result is a 1D tensor with one value per image in the batch
            reconstruction_error = torch.mean((images - reconstructed) ** 2, dim=(1, 2, 3))

            # Use the reconstruction error as the anomaly score
            # Higher reconstruction error suggests higher likelihood of being an anomaly
            anomaly_scores = reconstruction_error

            # Compare anomaly scores to the threshold
            # Returns a boolean tensor: True for anomalies (score > threshold), False otherwise
            return anomaly_scores > threshold


In [14]:
# Initialize the generator and discriminator
generator = Generator(config['latent_dim']).to(device)
discriminator = Discriminator().to(device)

# Initialize trainer
trainer = DiscriminatorTrainer(generator, discriminator, config)

In [15]:
print("Generator device:", next(trainer.generator.parameters()).device)
print("Discriminator device:", next(trainer.discriminator.parameters()).device)

Generator device: cuda:0
Discriminator device: cuda:0


In [16]:
print("Testing CUDA operation")
test_tensor = torch.rand(100, 100).cuda()
result = torch.matmul(test_tensor, test_tensor)
print("CUDA operation completed")

Testing CUDA operation
CUDA operation completed


In [None]:
# Train the GAN
trainer.train(dataloader)

Testing dataloader


In [None]:
class GANTester:
    
    """
    A class for testing a trained Generative Adversarial Network (GAN) for anomaly detection.

    This class provides methods to load a trained GAN model and use it to detect anomalies
    in a test dataset.

    Attributes:
        generator (nn.Module): The trained generator model.
        discriminator (nn.Module): The trained discriminator model.
        device (torch.device): The device (CPU or GPU) on which to perform computations.
        config (dict): Configuration dictionary containing model and testing parameters.
    """

    def __init__(self, generator: nn.Module, discriminator: nn.Module, config: dict):
        """
        Initializes the GANTester with trained models and configuration.

        Args:
            generator (nn.Module): The trained generator model.
            discriminator (nn.Module): The trained discriminator model.
            config (dict): Configuration dictionary containing model and testing parameters.
        """
        # Store the generator and discriminator models
        self.generator = generator
        self.discriminator = discriminator
        self.config = config

        # Determine the device (GPU if available, otherwise CPU)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        # Move models to the appropriate device
        self.generator.to(self.device)
        self.discriminator.to(self.device)

        # Set models to evaluation mode
        # This is crucial as it affects certain layers like BatchNorm and Dropout
        self.generator.eval()
        self.discriminator.eval()

    def load_test_data(self) -> DataLoader:
        """
        Loads and prepares the test dataset.

        Returns:
            DataLoader: A DataLoader containing the test dataset.
        """
        # Define the image transformations
        # These transformations ensure that all images are of the same size and format
        transform = transforms.Compose([
            transforms.Resize(self.config['image_size']),  # Resize image to a standard size
            transforms.CenterCrop(self.config['image_size']),  # Crop the center part of the image
            transforms.ToTensor(),  # Convert image to PyTorch tensor
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Normalize pixel values
        ])

        # Load the test dataset from the specified path
        test_dataset = datasets.ImageFolder(root=self.config['test_data_path'], transform=transform)

        # Create a DataLoader for efficient batching and parallel processing
        test_loader = DataLoader(test_dataset, batch_size=self.config['batch_size'], shuffle=False)
        
        return test_loader

    def detect_anomalies(self, images: torch.Tensor, threshold: float) -> torch.Tensor:
        """
        Detects anomalies in the input images using the trained GAN.

        Args:
            images (torch.Tensor): Batch of input images to check for anomalies.
            threshold (float): Threshold value for anomaly detection.

        Returns:
            torch.Tensor: A boolean tensor indicating anomalies (True) and normal samples (False).
        """
        # Disable gradient computation to save memory and speed up calculations
        with torch.no_grad():
            # Generate random noise as input for the generator
            # The shape matches the batch size of input images and the latent dimension from config
            z = torch.randn(images.size(0), self.config['latent_dim'], 1, 1).to(self.device)

            # Use the generator to create reconstructed images from the random noise
            reconstructed = self.generator(z)

            # Calculate the mean squared error between original images and reconstructed images
            # This error is computed for each image separately, across all channels and pixels
            reconstruction_error = torch.mean((images - reconstructed) ** 2, dim=(1, 2, 3))

            # Use the reconstruction error as the anomaly score
            anomaly_scores = reconstruction_error

            # Compare anomaly scores to the threshold
            # Returns a boolean tensor: True for anomalies (score > threshold), False otherwise
            return anomaly_scores > threshold

    def evaluate(self, dataloader: DataLoader, threshold: float) -> Tuple[List[bool], List[int]]:
        """
        Evaluates the anomaly detection performance on the entire test dataset.

        Args:
            dataloader (DataLoader): DataLoader containing the test dataset.
            threshold (float): Threshold value for anomaly detection.

        Returns:
            Tuple[List[bool], List[int]]: A tuple containing lists of anomaly detection results and true labels.
        """
        all_results = []
        all_labels = []

        # Iterate through all batches in the dataloader
        for images, labels in dataloader:
            # Move images to the same device as the models
            images = images.to(self.device)

            # Perform anomaly detection on the current batch
            results = self.detect_anomalies(images, threshold)
            
            # Collect results and labels
            # Convert tensors to Python lists for easier post-processing
            all_results.extend(results.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

        return all_results, all_labels

In [None]:
# Instantiate the GANTester
tester: GANTester = GANTester(generator, discriminator, config)

# Load the test data
test_loader: DataLoader = tester.load_test_data()

# Set a threshold for anomaly detection
threshold: float = 0.5  # Adjust this value based on your specific needs

# Evaluate the model
results: List[bool]
labels: List[int]
results, labels = tester.evaluate(test_loader, threshold)

# Print basic results
print(f"Total samples tested: {len(results)}")
print(f"Anomalies detected: {sum(results)}")
print(f"Actual anomalies: {sum(labels)}")

# Calculate accuracy
accuracy: float = sum(r == l for r, l in zip(results, labels)) / len(results)
print(f"Accuracy: {accuracy:.2f}")

In [None]:
from torchvision.utils import make_grid

def show_images_with_predictions(
    images: torch.Tensor, 
    reconstructed: torch.Tensor, 
    predictions: List[bool], 
    actual: List[int], 
    num_images: int = 20
) -> None:
    
    """
    Plot original images, their reconstructions, and highlight anomalies.
    
    Args:
        images (torch.Tensor): Original test images.
        reconstructed (torch.Tensor): Reconstructed images from the GAN.
        predictions (List[bool]): List of anomaly predictions (True for anomaly).
        actual (List[int]): List of actual labels (1 for anomaly, 0 for normal).
        num_images (int, optional): Number of images to display. Defaults to 20.
    """
    
    # Ensure we don't try to display more images than we have
    num_images = min(num_images, len(images))
    
    # Randomly select indices for the images we'll display
    indices = np.random.choice(len(images), num_images, replace=False)
    
    # Extract the selected images and their corresponding data
    orig_images = images[indices].cpu()
    recon_images = reconstructed[indices].cpu()
    pred_anomalies = [predictions[i] for i in indices]
    actual_anomalies = [actual[i] for i in indices]

    # Create a figure with two rows: original images on top, reconstructions on bottom
    fig, axes = plt.subplots(2, num_images, figsize=(num_images*2, 4))

    # Define color scheme for different prediction outcomes
    color_scheme = {
        (True, True): 'lime',    # True Positive: correctly identified anomaly
        (True, False): 'red',    # False Positive: normal sample incorrectly identified as anomaly
        (False, True): 'yellow', # False Negative: missed anomaly
        (False, False): 'none'   # True Negative: correctly identified normal sample
    }

    # Plot original and reconstructed images
    for i in range(num_images):
        for row, img in enumerate([orig_images, recon_images]):
            ax = axes[row, i]
            
            # Display the image
            ax.imshow(img[i].permute(1, 2, 0).detach().numpy() * 0.5 + 0.5)
            ax.axis('off')
            
            # Set the border color based on prediction and actual label
            border_color = color_scheme[(pred_anomalies[i], actual_anomalies[i])]
            ax.set_facecolor(border_color)

    plt.tight_layout()
    plt.show()

In [None]:
# Function to generate reconstructed images using the GAN
def generate_reconstructions(
    generator: torch.nn.Module, 
    num_samples: int, 
    latent_dim: int, 
    device: torch.device
) -> torch.Tensor:
    """
    Generate reconstructed images using the GAN's generator.

    Args:
        generator (torch.nn.Module): The trained generator model.
        num_samples (int): Number of samples to generate.
        latent_dim (int): Dimension of the latent space.
        device (torch.device): Device to perform the computation on.

    Returns:
        torch.Tensor: The generated (reconstructed) images.
    """
    with torch.no_grad():  # Disable gradient computation
        # Generate random noise as input for the generator
        z = torch.randn(num_samples, latent_dim, 1, 1).to(device)
        
        # Use the generator to create reconstructed images
        reconstructed = generator(z)
        
    return reconstructed

In [None]:
# Generate reconstructed images
reconstructed = generate_reconstructions(
    tester.generator, 
    len(test_loader.dataset), 
    config['latent_dim'], 
    tester.device
)

In [None]:
 # Get a batch of original images
original_images, _ = next(iter(test_loader))

In [None]:
# Ensure we have the same number of original and reconstructed images
num_images = min(len(original_images), len(reconstructed))
original_images = original_images[:num_images]
reconstructed = reconstructed[:num_images]

In [None]:
# Plot the results
show_images_with_predictions(original_images, reconstructed, results[:num_images], labels[:num_images])

## Variational Autoencoder (VAE)

This section demonstrates how to create, compile, train, and generate images using a Variational Autoencoder (VAE). VAEs are generative models that learn to encode input data into a latent space and decode from the latent space to generate new data.

### Steps:
1. **Create the encoder and decoder models.**
2. **Define the VAE by combining the encoder and decoder.**
3. **Compile the VAE with an appropriate loss function.**
4. **Train the VAE on a dataset of images.**
5. **Generate and display new images using the trained decoder.**