## Fréchet Inception Distance (FID)

1. **Feature Extraction:** FID uses a pre-trained Inception v3 network to extract features from both real and generated images. This network captures high-level representations of the images.

2. **Distribution Comparison:** FID models the feature vectors of each set (real and generated) as samples from a multivariate Gaussian distribution.

3. **Statistical Moments:** FID calculates the mean and covariance of the feature distributions for both real and generated images.

4. **Distance Calculation:** The Frechet distance between these two Gaussian distributions is then computed. This distance is defined as:

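$$\mathrm{FID} = \lVert \mu_r - \mu_g \rVert^2 + \mathrm{Tr}\left(\Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2}\right)$$

   Where:
   - $\mu_r$ and $\mu_g$ are the mean feature vectors for real and generated images
   - $\Sigma_r$ and $\Sigma_g$ are the covariance matrices for real and generated images
   - $(\Sigma_r \Sigma_g)^{1/2}$ is the matrix square root of the product of the two covariance matrices
   - $\mathrm{Tr}$ denotes the trace of a matrix (sum of diagonal elements)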

**Interpretation:** A lower FID score indicates that the generated images are statistically closer to the real images in Inception feature space. A score of 0 means the two fitted Gaussians have the same mean and covariance.
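
As a quick check of the formula, here is a minimal sketch that evaluates it on two small hand-picked Gaussians. The helper name `fid_from_moments` and the toy moments are assumptions for illustration only. The sketch uses NumPy alone and relies on the fact that the trace of the matrix square root equals the sum of the square roots of the eigenvalues of $\Sigma_r \Sigma_g$, which are real and non-negative for valid covariance matrices.

```python
import numpy as np

def fid_from_moments(mu_r, sigma_r, mu_g, sigma_g):
    """Frechet distance between N(mu_r, sigma_r) and N(mu_g, sigma_g) (hypothetical helper)."""
    # Squared Euclidean distance between the means
    mean_term = float(np.sum((mu_r - mu_g) ** 2))
    # Tr((sigma_r sigma_g)^{1/2}) as the sum of square roots of the eigenvalues;
    # tiny negative or imaginary parts from round-off are discarded
    eigvals = np.linalg.eigvals(sigma_r @ sigma_g)
    trace_sqrt = float(np.sum(np.sqrt(np.clip(eigvals.real, 0.0, None))))
    return mean_term + float(np.trace(sigma_r) + np.trace(sigma_g)) - 2.0 * trace_sqrt

# Toy 2-D moments chosen so the result can be checked by hand
mu_r, sigma_r = np.array([0.0, 0.0]), np.eye(2)
mu_g, sigma_g = np.array([1.0, 2.0]), 4.0 * np.eye(2)

fid_shifted = fid_from_moments(mu_r, sigma_r, mu_g, sigma_g)
fid_same = fid_from_moments(mu_r, sigma_r, mu_r, sigma_r)
print(fid_shifted, fid_same)
```

By hand, the first case gives a mean term of $1^2 + 2^2 = 5$ and a trace term of $2 + 8 - 2 \cdot 4 = 2$, so the distance is 7. Comparing a Gaussian with itself gives 0.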


In [None]:
# Objects this section assumes are already defined:
# generator  = instance of the trained GAN generator model
# dataloader = DataLoader that yields batches of real images
# device     = torch.device that the GAN runs on

In [None]:
from torchvision import transforms

# Preprocessing steps for Inception v3 input
preprocess = transforms.Compose([
    transforms.Resize((299, 299)),  # Resize images to the Inception v3 input size
    # Normalize with ImageNet statistics, i.e. the same preprocessing applied to the
    # ImageNet data that Inception v3 was trained on (assumes pixel values in [0, 1])
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def prepare_inception_input(images):
    return preprocess(images)  # Apply preprocessing to a batch of image tensors

The Inception v3 model expects input images with the following characteristics:
- Size: 299x299 pixels
- Value range: pixel values in [0, 1] before normalization, since the ImageNet statistics are defined on that scale
- Normalization: Using ImageNet statistics (mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
- Tensor shape: (batch_size, 3, 299, 299)
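
If the image tensors are scaled to [-1, 1], as is common for GAN generators that end in a tanh layer (an assumption about this notebook's pipeline, not something stated here), they would need to be mapped to [0, 1] before the ImageNet normalization is applied. The sketch below shows that arithmetic on a dummy batch. It uses NumPy only to stay dependency-light, and the helper names `to_unit_range` and `imagenet_normalize` are hypothetical; the same expressions carry over to tensors.

```python
import numpy as np

# ImageNet channel statistics, shaped to broadcast over (N, 3, H, W)
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1)
IMAGENET_STD = np.array([0.229, 0.224, 0.225]).reshape(1, 3, 1, 1)

def to_unit_range(images):
    """Map values from [-1, 1] to [0, 1] (assumes a tanh-style value range)."""
    return np.clip((images + 1.0) / 2.0, 0.0, 1.0)

def imagenet_normalize(images):
    """Per-channel (x - mean) / std for a batch shaped (N, 3, H, W)."""
    return (images - IMAGENET_MEAN) / IMAGENET_STD

# Dummy batch standing in for generator output in [-1, 1]
batch = np.random.uniform(-1.0, 1.0, size=(4, 3, 8, 8))
normalized = imagenet_normalize(to_unit_range(batch))
print(normalized.shape)
```

Whether this rescaling is needed depends on how the real images and the generator output are scaled. Both sets should go through the same steps so that the comparison stays fair.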

For FID calculation, we're interested in the features from an intermediate layer of the Inception v3 model, not its final classification output. Typically, we use the 2048-dimensional output of the last pooling layer, just before the final fully connected layer.
To extract these features, we replace that final layer so the model returns the pooled features directly.

In [None]:
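import torch
from torchvision.models import inception_v3, Inception_V3_Weights

def load_inception_model():
    # Load Inception v3 with ImageNet-pretrained weights. transform_input=True makes the model
    # convert ImageNet-normalized inputs to the input scale the pretrained weights expect.
    model = inception_v3(weights=Inception_V3_Weights.IMAGENET1K_V1, transform_input=True)
    model.fc = torch.nn.Identity()  # Replace the final fully connected layer so the model returns 2048-d features
    model.eval()  # Set the model to evaluation mode
    return model.to(device)  # Move the model to the same device as our GAN

inception_model = load_inception_model()  # Load and prepare the modified Inception v3 model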

In [None]:
def extract_features(images):
    with torch.no_grad():
        features = inception_model(images)
    return features  # This will be a tensor of shape [batch_size, 2048]

In [None]:
def extract_real_features(dataloader, num_images=1000):
    real_features = []  # List to store features of real images
    image_count = 0  # Counter for processed images

    with torch.no_grad():  # Disable gradient computation for efficiency
        for batch in dataloader:
            images = batch[0].to(device)  # Move batch of images to the device (assuming batch[0] contains images)
            batch_size = images.size(0)  # Get the current batch size
            
            if image_count + batch_size > num_images:
                # If adding this batch would exceed num_images, only take what's needed
                images = images[:num_images - image_count]
            
            preprocessed_images = prepare_inception_input(images)  # Preprocess images for Inception v3
            batch_features = extract_features(preprocessed_images)  # Extract features using Inception v3
            
            real_features.append(batch_features)  # Add batch features to the list
            
            image_count += batch_features.size(0)  # Update the count of processed images
            if image_count >= num_images:
                break  # Stop if we've processed enough images

    real_features_tensor = torch.cat(real_features, dim=0)  # Concatenate all features into a single tensor
    return real_features_tensor[:num_images]  # Return exactly num_images features

# Extract features from 1000 real images
real_features = extract_real_features(dataloader, num_images=1000)

In [None]:
# Calculate mean and covariance of real features
real_mean = torch.mean(real_features, dim=0)  # Calculate mean across all samples for each feature
real_cov = torch.cov(real_features.T)  # Calculate covariance matrix of features
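
One property of this estimate is worth checking: a sample covariance built from N feature vectors has rank at most N - 1, so with 1000 images and 2048 feature dimensions it cannot be full rank. The sketch below illustrates this on small synthetic data, with made-up sizes standing in for 1000 and 2048. It uses `np.cov(..., rowvar=False)`, which computes the same unbiased estimate as `torch.cov(features.T)`.

```python
import numpy as np

rng = np.random.default_rng(0)
num_samples, feature_dim = 10, 64  # small stand-ins for 1000 and 2048
features = rng.normal(size=(num_samples, feature_dim))

mean = features.mean(axis=0)          # per-feature mean, shape (feature_dim,)
cov = np.cov(features, rowvar=False)  # unbiased covariance, shape (feature_dim, feature_dim)
rank = np.linalg.matrix_rank(cov)

print(cov.shape, rank)
```

A rank-deficient covariance may make the matrix square root in the FID formula less well-conditioned, which is one reason FID is often computed on many more samples than feature dimensions.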

In [None]:
def generate_and_extract_features(generator, num_samples=1000, batch_size=64):
    generated_features = []  # List to store features of generated images
    num_batches = (num_samples + batch_size - 1) // batch_size  # Ceiling division: number of batches needed
    num_generated = 0  # Number of samples generated so far

    generator.eval()  # Set generator to evaluation mode
    print(f"Starting feature extraction for {num_samples} samples with batch size {batch_size}")
    with torch.no_grad():  # Disable gradient computation for efficiency
        for batch_idx in range(num_batches):
            current_batch_size = min(batch_size, num_samples - num_generated)  # Shrink the last batch if needed
            print(f"Processing batch {batch_idx + 1}/{num_batches} with size {current_batch_size}")
            noise = torch.randn(current_batch_size, 100, 1, 1, device=device)  # Latent noise; assumes a 100-dimensional latent space
            fake_images = generator(noise)  # Generate fake images

            preprocessed_images = prepare_inception_input(fake_images)  # Preprocess images for Inception v3
            batch_features = extract_features(preprocessed_images)  # Extract features using Inception v3

            generated_features.append(batch_features)  # Add batch features to the list
            num_generated += batch_features.size(0)  # Count samples, not batches
            print(f"Extracted features shape: {batch_features.shape}")
    print(f"Feature extraction completed. Total features extracted: {num_generated}")

    generated_features_tensor = torch.cat(generated_features, dim=0)  # Concatenate all features into a single tensor
    return generated_features_tensor[:num_samples]  # Return exactly num_samples features
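
To produce exactly `num_samples` images, the number of batches is a ceiling division and the final batch has to shrink to whatever remains. The sketch below isolates that arithmetic. The helper `batch_sizes` is hypothetical and tracks the number of samples produced so far, not the number of batches.

```python
def batch_sizes(num_samples, batch_size):
    """Return the size of each batch needed to produce exactly num_samples items."""
    num_batches = (num_samples + batch_size - 1) // batch_size  # ceiling division
    sizes = []
    produced = 0  # samples produced so far (not batches)
    for _ in range(num_batches):
        current = min(batch_size, num_samples - produced)
        sizes.append(current)
        produced += current
    return sizes

sizes = batch_sizes(1000, 64)
print(len(sizes), sizes[-1], sum(sizes))
```

For 1000 samples and a batch size of 64, this gives 16 batches: 15 full batches (960 samples) followed by one batch of 40.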

In [None]:
# Generate 1000 fake images and extract their features
generated_features = generate_and_extract_features(generator, num_samples=1000)

In [None]:
# Calculate mean and covariance of generated features
generated_mean = torch.mean(generated_features, dim=0)  # Calculate mean across all samples for each feature
generated_cov = torch.cov(generated_features.T)  # Calculate covariance matrix of features

In [None]:
def calculate_frechet_inception_distance(real_mean, real_cov, generated_mean, generated_cov):
    """
    Calculate the Fréchet Inception Distance (FID) between real and generated image features.
    
    Args:
    real_mean (torch.Tensor): Mean of real image features.
    real_cov (torch.Tensor): Covariance matrix of real image features.
    generated_mean (torch.Tensor): Mean of generated image features.
    generated_cov (torch.Tensor): Covariance matrix of generated image features.
    
    Returns:
    float: The calculated FID score.
    """
    import numpy as np  # Import numpy for array operations
    from scipy import linalg  # Import linalg for the matrix square root
    
    # Convert to float64 NumPy arrays for the SciPy operations (float64 for numerical accuracy)
    real_mean_np = real_mean.cpu().numpy().astype(np.float64)
    real_cov_np = real_cov.cpu().numpy().astype(np.float64)
    generated_mean_np = generated_mean.cpu().numpy().astype(np.float64)
    generated_cov_np = generated_cov.cpu().numpy().astype(np.float64)
    
    # Squared L2 norm of the difference between the means
    mean_diff = np.sum((real_mean_np - generated_mean_np) ** 2)
    
    # Matrix square root of the product of the covariances
    covmean = linalg.sqrtm(real_cov_np.dot(generated_cov_np))
    
    # Numerical error can introduce small imaginary components; keep only the real part
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    
    # Trace term: Tr(Σ_r + Σ_g - 2 (Σ_r Σ_g)^{1/2})
    trace_term = np.trace(real_cov_np + generated_cov_np - 2 * covmean)
    
    # FID is the sum of the mean term and the trace term
    fid = mean_diff + trace_term
    
    return float(fid)  # Return FID as a Python float

# Calculate FID from the real and generated feature statistics
fid_score = calculate_frechet_inception_distance(real_mean, real_cov, generated_mean, generated_cov)
print(f"Fréchet Inception Distance (FID): {fid_score:.4f}")  # Print the calculated FID score