# Non-linear (Cosine Interpolation) Latent Space Traversal

The formula for cosine interpolation between two values $a$ and $b$, with $t$ ranging from 0 to 1, is:

$$f(t) = a + (b - a)\,\frac{1 - \cos(\pi t)}{2}$$
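Differentiating the formula with respect to $t$ makes the easing behaviour explicit (a short derivation added for illustration):

$$f'(t) = (b - a)\,\frac{\pi}{2}\sin(\pi t)$$

Because $\sin(\pi t) = 0$ at $t = 0$ and $t = 1$, the rate of change vanishes at both endpoints. It peaks at $t = \tfrac{1}{2}$, where $f'(\tfrac{1}{2}) = \tfrac{\pi}{2}(b - a) \approx 1.57\,(b - a)$, compared with the constant rate $b - a$ of linear interpolation.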

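This weighting creates a smooth S-curve: the output starts slowly at $a$, accelerates toward the midpoint, and decelerates as it approaches $b$. Applied to latent vectors, it follows the same straight-line path as linear interpolation but spaces the samples more densely near the two endpoints, which can make a traversal played as an animation look smoother at its start and end.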
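A minimal scalar sketch of the formula, using only the standard library, shows how samples are distributed. `cosine_weight` and `cosine_interpolate` are illustrative helper names, not part of the notebook; the example interpolates from 0 to 1 with 11 evenly spaced values of $t$ and measures the distance between consecutive outputs.

```python
import math

def cosine_weight(t):
    """Cosine easing weight (1 - cos(pi * t)) / 2, mapping t in [0, 1] to [0, 1]."""
    return (1 - math.cos(math.pi * t)) / 2

def cosine_interpolate(a, b, t):
    """Scalar form of f(t) = a + (b - a) * (1 - cos(pi * t)) / 2."""
    return a + (b - a) * cosine_weight(t)

num_samples = 11
ts = [i / (num_samples - 1) for i in range(num_samples)]  # 0.0, 0.1, ..., 1.0
values = [cosine_interpolate(0.0, 1.0, t) for t in ts]

# Distance covered between consecutive samples: small near the ends, largest in the middle
steps = [values[i + 1] - values[i] for i in range(num_samples - 1)]

for t, v in zip(ts, values):
    print(f"t={t:.1f}  linear={t:.3f}  cosine={v:.3f}")
```

With linear interpolation every step would be 0.1; here the first and last steps are about 0.024 while the two middle steps are about 0.155.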

In [None]:
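import torch  # Tensor operations and the generator model
import matplotlib.pyplot as plt  # Plotting the generated images

# Arguments for cosine latent space traversal
# generator: instance of the trained GAN generator model (assumed to be defined earlier in the notebook)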

In [None]:
def cosine_latent_space_traversal(generator, start_vector, end_vector, num_samples):
    """Generate num_samples images along the cosine-eased path from start_vector to end_vector."""
    generator.eval()  # Inference behaviour for layers such as BatchNorm and Dropout
    device = next(generator.parameters()).device  # Device that holds the generator's weights
    start_vector = start_vector.to(device)
    end_vector = end_vector.to(device)

    generated_images = []  # List to store generated images
    with torch.no_grad():  # Disable gradient tracking for inference
        # Evenly spaced values of t in [0, 1], created on the same device as the latents
        for t in torch.linspace(0, 1, num_samples, device=device):
            weight = (1 - torch.cos(torch.pi * t)) / 2  # Cosine easing weight in [0, 1]
            interpolated_vector = start_vector + (end_vector - start_vector) * weight

            # DCGAN-style generators expect latents of shape (1, nz, 1, 1); nz is inferred from the input
            interpolated_vector = interpolated_vector.reshape(1, -1, 1, 1)

            generated_images.append(generator(interpolated_vector))  # Shape (1, C, H, W)

    # Concatenate along the batch dimension: (num_samples, C, H, W)
    return torch.cat(generated_images, dim=0)
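The loop above calls the generator once per sample. A minimal batched sketch, assuming the same DCGAN-style latent shape `(1, nz, 1, 1)`, builds every interpolated latent with one broadcasted expression and runs a single forward pass. `cosine_interpolation_batch` and `toy_generator` are hypothetical names; the toy generator is an untrained stand-in used only to check shapes, not the notebook's model.

```python
import torch
import torch.nn as nn

def cosine_interpolation_batch(start_vector, end_vector, num_samples):
    """Return all cosine-interpolated latents at once, shape (num_samples, nz, 1, 1)."""
    t = torch.linspace(0, 1, num_samples, device=start_vector.device, dtype=start_vector.dtype)
    weights = (1 - torch.cos(torch.pi * t)) / 2  # Cosine easing weights, shape (num_samples,)
    # Reshape to (num_samples, 1, 1, 1) so the weights broadcast against (1, nz, 1, 1) latents
    weights = weights.view(-1, 1, 1, 1)
    return start_vector + (end_vector - start_vector) * weights

# Hypothetical untrained stand-in: maps (N, 100, 1, 1) latents to (N, 3, 4, 4) images in [-1, 1]
toy_generator = nn.Sequential(
    nn.ConvTranspose2d(100, 3, kernel_size=4, stride=1, padding=0),
    nn.Tanh(),
)
toy_generator.eval()

torch.manual_seed(0)  # Reproducible random latents for this example
start = torch.randn(1, 100, 1, 1)
end = torch.randn(1, 100, 1, 1)

latents = cosine_interpolation_batch(start, end, 30)
with torch.no_grad():
    images = toy_generator(latents)  # One forward pass for all 30 samples
```

With the generator in eval mode, the batched call should produce the same images as the per-sample loop up to floating-point differences; in training mode, BatchNorm layers would compute statistics across the whole batch and the outputs would differ.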

In [None]:
# Generate random start and end vectors
start_vector = torch.randn(1, 100, 1, 1)  # Random start vector with shape (1, 100, 1, 1)
end_vector = torch.randn(1, 100, 1, 1)  # Random end vector with shape (1, 100, 1, 1)
num_samples = 30

In [None]:
# Test the cosine interpolation function
cosine_traversal_images = cosine_latent_space_traversal(generator, start_vector, end_vector, num_samples)  # Shape (num_samples, C, H, W)

# Display the generated images from cosine interpolation
plt.figure(figsize=(20, 4))  # Create a new figure with specified size
for i in range(num_samples):  # Iterate through each generated image
    plt.subplot(1, num_samples, i + 1)  # One subplot per image, in a single row
    # (C, H, W) -> (H, W, C) for imshow; drop a size-1 channel dimension for grayscale images
    img = cosine_traversal_images[i].permute(1, 2, 0).squeeze(-1).cpu()
    img = ((img + 1) / 2).clamp(0, 1).numpy()  # Map Tanh output from [-1, 1] to [0, 1]
    plt.imshow(img, cmap='gray')  # cmap only affects single-channel images
    plt.axis('off')  # Hide axes
plt.tight_layout()  # Adjust the layout to prevent overlap
plt.show()  # Display the plot