# Linear latent space traversal

In [None]:
# Arguments for linear latent space traversal:
# generator: instance of the GAN generator model (defined elsewhere in the notebook)

In [None]:
def linear_latent_space_traversal(generator, start_vector, end_vector, num_samples):
    """
    Perform linear latent space traversal between two noise vectors and generate images.

    Latent vectors are interpolated as ``(1 - alpha) * start + alpha * end`` for
    ``num_samples`` evenly spaced ``alpha`` values in ``[0, 1]`` (both endpoints
    included). Each interpolated vector is passed through the generator as a
    batch of one.

    The generator runs in evaluation mode (no gradients) during the call. The
    ``training`` flag of every submodule is restored afterwards, so calling this
    in the middle of training does not silently leave the model in eval mode.

    Args:
    generator (nn.Module): The trained generator model. It must have at least one
        parameter (used to find its device) and accept a batch of shape
        ``(1, *start_vector.shape)``.
    start_vector (torch.Tensor): Starting noise vector without a batch dimension
        (e.g. ``(latent_dim, 1, 1)`` for a convolutional generator).
    end_vector (torch.Tensor): Ending noise vector, same shape as ``start_vector``.
    num_samples (int): Number of samples to generate along the traversal (>= 1).

    Returns:
    torch.Tensor: The generator outputs concatenated along dim 0, in traversal order.

    Raises:
    ValueError: If ``num_samples < 1`` or the two vectors have different shapes.
    """
    # Fail early with a clear message. Otherwise num_samples == 0 ends in an opaque
    # torch.cat error on an empty list, and mismatched shapes would broadcast silently.
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")
    if start_vector.shape != end_vector.shape:
        raise ValueError(
            "start_vector and end_vector must have the same shape, "
            f"got {tuple(start_vector.shape)} and {tuple(end_vector.shape)}"
        )

    # Move the latent vectors to wherever the generator's parameters live.
    device = next(generator.parameters()).device
    start_vector = start_vector.to(device)
    end_vector = end_vector.to(device)

    # Evenly spaced interpolation coefficients 0 .. 1, created directly on the device.
    alphas = torch.linspace(0, 1, num_samples, device=device)

    # Remember each submodule's mode so the caller's model state is left untouched.
    saved_modes = [(module, module.training) for module in generator.modules()]
    generator.eval()
    try:
        generated_images = []
        with torch.no_grad():  # Inference only; no autograd graph needed
            for alpha in alphas:
                interpolated_vector = start_vector * (1 - alpha) + end_vector * alpha
                # Add the batch dimension, e.g. (latent_dim, 1, 1) -> (1, latent_dim, 1, 1).
                generated_images.append(generator(interpolated_vector.unsqueeze(0)))
    finally:
        for module, was_training in saved_modes:
            module.training = was_training

    # Stack all generated images into a single tensor along the batch axis.
    return torch.cat(generated_images, dim=0)

In [None]:
# Create dummy data for testing linear latent space traversal.
# `generator` must already exist in the notebook.
latent_dim = 100  # Latent dimension size; must match the generator's input size
num_samples = 20  # Number of samples for traversal

# Random start and end vectors of shape (latent_dim, 1, 1). The traversal function
# adds the batch dimension itself, so each generator call sees (1, latent_dim, 1, 1).
start_vector = torch.randn(latent_dim, 1, 1)
end_vector = torch.randn(latent_dim, 1, 1)

# Device the generator's parameters live on. linear_latent_space_traversal moves the
# latent vectors there itself, so the generator does not need to be moved.
device = next(generator.parameters()).device

# Perform linear latent space traversal
traversal_images = linear_latent_space_traversal(generator, start_vector, end_vector, num_samples)

# Display the generated images in a single row
plt.figure(figsize=(20, 4))
for position, image in enumerate(traversal_images, start=1):
    plt.subplot(1, num_samples, position)
    # (C, H, W) -> (H, W, C) for matplotlib. Map [-1, 1] to [0, 1] (assumes a
    # tanh-style output range -- TODO confirm against the generator) and clamp so
    # imshow never receives out-of-range floats. No squeeze(): that would drop the
    # channel axis of single-channel images and break the permute.
    img = ((image.permute(1, 2, 0).cpu() + 1) / 2).clamp(0, 1).numpy()
    if img.shape[-1] == 1:
        # Grayscale output: drop the channel axis and render with a gray colormap.
        plt.imshow(img[..., 0], cmap="gray")
    else:
        plt.imshow(img)
    plt.axis('off')  # Turn off axis labels
plt.tight_layout()  # Adjust the layout to prevent overlap
plt.show()

print(f"Shape of traversal_images: {traversal_images.shape}")  # Print the shape of the generated images tensor