<a href="https://colab.research.google.com/github/thePegasusai/CloudPhone/blob/main/striped_hyena_style_transfer_tpu.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# StripedHyena Neural Style Transfer - TPU Version

This notebook implements a neural style transfer model based on the StripedHyena architecture by Liquid AI, optimized for Google's Tensor Processing Units (TPUs).

## Overview

StripedHyena is a hybrid neural network architecture from Liquid AI that interleaves attention layers using rotary position embeddings with gated convolutions. This notebook adapts the architecture to 2D feature maps for neural style transfer and runs it on TPUs through PyTorch/XLA.

### Key Features

- Hybrid architecture combining attention and convolution components
- TPU execution through PyTorch/XLA
- Depthwise gated convolutions whose cost grows linearly with image size
- Data-dependent gating that adapts style application to content characteristics

## Setup

First, let's set up the TPU and install the necessary dependencies.

In [None]:
# Check if TPU is available
import os

# Import TensorFlow up front so both branches below can use it
import tensorflow as tf

if 'COLAB_TPU_ADDR' in os.environ:
    print('Running on TPU')
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='grpc://' + os.environ['COLAB_TPU_ADDR'])
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)
    print('Number of TPU cores available:', strategy.num_replicas_in_sync)
else:
    print('TPU not available. Running on CPU/GPU.')
    strategy = tf.distribute.get_strategy()

In [None]:
# Install PyTorch XLA for TPU support
!pip install torch==2.0.0 torch_xla==2.0.0 torchvision==0.15.1 -f https://storage.googleapis.com/libtpu-releases/index.html

# Install other dependencies
!pip install matplotlib pillow

In [None]:
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from typing import List, Tuple, Optional, Dict, Any, Union

# Set the TPU device
device = xm.xla_device()
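
The cell above assumes a TPU runtime is attached. As a minimal sketch for running the rest of the notebook elsewhere, the helper below falls back to CUDA or CPU when `torch_xla` is missing or no TPU can be acquired. The name `pick_device` is hypothetical and not part of the original notebook.

```python
import torch


def pick_device() -> torch.device:
    """Return an XLA device when torch_xla is usable, else CUDA, else CPU."""
    try:
        import torch_xla.core.xla_model as xm
        return xm.xla_device()
    except Exception:
        # torch_xla is not installed, or no TPU runtime could be acquired
        pass
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


fallback_device = pick_device()
print(fallback_device)
```

Note that the training and benchmark functions later in the notebook call `xm` directly, so they would still need a TPU runtime even with this fallback.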

## StripedHyena Architecture Components

Now, let's implement the core components of the StripedHyena architecture adapted for neural style transfer.

In [None]:
class RotaryPositionalEmbedding(nn.Module):
    """
    2D Rotary positional embedding adapted from StripedHyena's 1D implementation.
    This provides spatial position information for the attention mechanism.
    """
    def __init__(self, dim: int, max_height: int = 1024, max_width: int = 1024):
        super().__init__()
        self.dim = dim
        self.max_height = max_height
        self.max_width = max_width

        # Create position encodings for height and width dimensions
        inv_freq_h = 1.0 / (10000 ** (torch.arange(0, dim // 4, 2).float() / (dim // 4)))
        inv_freq_w = 1.0 / (10000 ** (torch.arange(0, dim // 4, 2).float() / (dim // 4)))

        self.register_buffer("inv_freq_h", inv_freq_h)
        self.register_buffer("inv_freq_w", inv_freq_w)

    def forward(self, h: int, w: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Generate 2D rotary embeddings for a feature map of size h×w

        Args:
            h: Height of the feature map
            w: Width of the feature map

        Returns:
            Tuple of cos and sin embeddings for 2D positions
        """
        # Positions must be floating point so they can be multiplied with inv_freq
        h_pos = torch.arange(h, device=self.inv_freq_h.device, dtype=self.inv_freq_h.dtype)
        w_pos = torch.arange(w, device=self.inv_freq_w.device, dtype=self.inv_freq_w.dtype)

        # Compute position encodings: [h, dim // 8] and [w, dim // 8]
        h_freqs = torch.einsum("i,j->ij", h_pos, self.inv_freq_h)
        w_freqs = torch.einsum("i,j->ij", w_pos, self.inv_freq_w)

        h_cos, h_sin = h_freqs.cos(), h_freqs.sin()
        w_cos, w_sin = w_freqs.cos(), w_freqs.sin()

        # apply_rotary_pos_emb rotates two groups of dim // 4 channels each,
        # so the embeddings need dim // 4 channels (dim must be divisible by 8)
        rot_dim = self.dim // 4
        cos_emb = torch.zeros((h, w, rot_dim), device=h_cos.device, dtype=h_cos.dtype)
        sin_emb = torch.zeros((h, w, rot_dim), device=h_sin.device, dtype=h_sin.dtype)

        # Interleave height and width position information. Broadcasting replaces
        # the per-pixel Python loop, which is very slow to trace on XLA.
        cos_emb[:, :, 0::2] = h_cos[:, None, :]
        cos_emb[:, :, 1::2] = w_cos[None, :, :]
        sin_emb[:, :, 0::2] = h_sin[:, None, :]
        sin_emb[:, :, 1::2] = w_sin[None, :, :]

        return cos_emb, sin_emb

def apply_rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """
    Apply rotary positional embeddings to input tensor x.

    Args:
        x: Input tensor of shape [..., height, width, channels]
        cos: Cosine part of rotary embeddings, broadcastable to [..., height, width, channels // 4]
        sin: Sine part of rotary embeddings, same shape as cos

    Returns:
        Tensor with rotary positional embeddings applied
    """
    # Split channels for rotation
    x_rope, x_pass = x.chunk(2, dim=-1)

    # Reshape for broadcasting
    x_rope_1, x_rope_2 = x_rope.chunk(2, dim=-1)

    # Apply rotation using complex multiplication formulation
    x_rotated_1 = x_rope_1 * cos - x_rope_2 * sin
    x_rotated_2 = x_rope_1 * sin + x_rope_2 * cos

    # Concatenate rotated features with pass-through features
    x_rotated = torch.cat([x_rotated_1, x_rotated_2], dim=-1)
    return torch.cat([x_rotated, x_pass], dim=-1)
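
To make the rotation in `apply_rotary_pos_emb` concrete, here is a small self-contained sketch in one dimension. It checks two properties that rotary embeddings are generally relied on for: rotating a channel pair keeps its norm, and the dot product of a rotated query and key depends only on the difference of their positions. The helper `rotate_pairs` and the positions `m` and `n` are illustrative names, not part of the notebook.

```python
import torch


def rotate_pairs(x1, x2, angle):
    """Rotate each (x1, x2) channel pair by `angle` radians."""
    cos, sin = torch.cos(angle), torch.sin(angle)
    return x1 * cos - x2 * sin, x1 * sin + x2 * cos


torch.manual_seed(0)
q1, q2 = torch.randn(8, dtype=torch.float64), torch.randn(8, dtype=torch.float64)
k1, k2 = torch.randn(8, dtype=torch.float64), torch.randn(8, dtype=torch.float64)
theta = torch.rand(8, dtype=torch.float64)  # one frequency per channel pair
m, n = 5, 2  # query and key positions

# Rotate query and key by their absolute positions
qm1, qm2 = rotate_pairs(q1, q2, m * theta)
kn1, kn2 = rotate_pairs(k1, k2, n * theta)
score_abs = (qm1 * kn1 + qm2 * kn2).sum()

# Rotate only the query, by the relative offset m - n
qr1, qr2 = rotate_pairs(q1, q2, (m - n) * theta)
score_rel = (qr1 * k1 + qr2 * k2).sum()

norm_preserved = torch.allclose(qm1**2 + qm2**2, q1**2 + q2**2)
relative_only = torch.allclose(score_abs, score_rel)
print(norm_preserved, relative_only)
```

The 2D version in this notebook applies the same rotation, with half of the frequencies driven by the row index and half by the column index.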

In [None]:
class SpatialGatedConvolution(nn.Module):
    """
    2D adaptation of the gated convolution mechanism from StripedHyena.
    This provides efficient spatial mixing with gating for adaptive feature selection.
    """
    def __init__(
        self,
        dim: int,
        kernel_size: int = 7,
        groups: int = 1,
        use_bias: bool = True
    ):
        super().__init__()
        self.dim = dim
        self.kernel_size = kernel_size

        # Split channels for gating mechanism
        self.proj = nn.Conv2d(dim, dim * 2, kernel_size=1, bias=use_bias)

        # Depth-wise convolution for spatial mixing
        self.spatial_conv = nn.Conv2d(
            dim,
            dim,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=dim,
            bias=use_bias
        )

        # Gating convolution
        self.gate_conv = nn.Conv2d(
            dim,
            dim,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=dim,
            bias=use_bias
        )

        # Output projection
        self.out_proj = nn.Conv2d(dim, dim, kernel_size=1, bias=use_bias)

        # Initialize with small weights for stability
        nn.init.normal_(self.spatial_conv.weight, std=0.02)
        nn.init.normal_(self.gate_conv.weight, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply gated convolution to input tensor.

        Args:
            x: Input tensor of shape [batch, channels, height, width]

        Returns:
            Output tensor after gated convolution
        """
        # Project to higher dimension and split
        x = self.proj(x)
        x, gate = x.chunk(2, dim=1)

        # Apply spatial convolution to features
        x = self.spatial_conv(x)

        # Apply gating convolution and activation
        gate = self.gate_conv(gate)
        gate = F.gelu(gate)

        # Apply gate to features
        x = x * gate

        # Project back to original dimension
        return self.out_proj(x)
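
Setting `groups=dim` is what makes the two spatial convolutions depthwise: each channel gets its own filter instead of mixing with every other channel. The sketch below compares parameter counts for a dense and a depthwise 7×7 convolution at `dim=64`; the helper `n_params` is a hypothetical name used only for this illustration.

```python
import torch
import torch.nn as nn


def n_params(module: nn.Module) -> int:
    """Count all parameters (weights and biases) of a module."""
    return sum(p.numel() for p in module.parameters())


dim, k = 64, 7
dense = nn.Conv2d(dim, dim, kernel_size=k, padding=k // 2)
depthwise = nn.Conv2d(dim, dim, kernel_size=k, padding=k // 2, groups=dim)

dense_params = n_params(dense)          # 64 * 64 * 49 weights + 64 biases
depthwise_params = n_params(depthwise)  # 64 * 1 * 49 weights + 64 biases
print(dense_params, depthwise_params)

# Both variants keep the spatial size because padding = kernel_size // 2
x = torch.randn(1, dim, 32, 32)
out_shape = depthwise(x).shape
print(out_shape)
```

Channel mixing is left to the 1×1 projections (`proj` and `out_proj`), which is why the block pairs them with the depthwise layers.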

In [None]:
class SpatialHyenaBlock(nn.Module):
    """
    2D adaptation of the Hyena block from StripedHyena.
    This combines gated convolutions with long-range dependencies modeling.
    """
    def __init__(
        self,
        dim: int,
        kernel_size: int = 7,
        expansion_factor: int = 2,
        dropout: float = 0.0
    ):
        super().__init__()
        self.dim = dim
        self.kernel_size = kernel_size
        hidden_dim = int(expansion_factor * dim)

        # Layer normalization
        self.norm = nn.LayerNorm(dim)

        # Feature expansion
        self.expand = nn.Conv2d(dim, hidden_dim, kernel_size=1)

        # Gated spatial convolution for local mixing
        self.spatial_gate = SpatialGatedConvolution(
            hidden_dim,
            kernel_size=kernel_size
        )

        # Long-range convolution for global context
        self.long_conv = nn.Conv2d(
            hidden_dim,
            hidden_dim,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=hidden_dim
        )

        # Feature projection back to original dimension
        self.contract = nn.Conv2d(hidden_dim, dim, kernel_size=1)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply Hyena block to input tensor.

        Args:
            x: Input tensor of shape [batch, channels, height, width]

        Returns:
            Output tensor after Hyena block processing
        """
        # Apply layer normalization (converting to channels-last and back)
        x_norm = x.permute(0, 2, 3, 1)  # [B, H, W, C]
        x_norm = self.norm(x_norm)
        x_norm = x_norm.permute(0, 3, 1, 2)  # [B, C, H, W]

        # Residual connection
        residual = x

        # Expand features
        x = self.expand(x_norm)

        # Apply gated spatial convolution
        x = self.spatial_gate(x)

        # Apply long-range convolution for global context
        x = self.long_conv(x)

        # Contract back to original dimension
        x = self.contract(x)

        # Apply dropout
        x = self.dropout(x)

        # Add residual connection
        return x + residual
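
The permute, normalize, permute pattern in `forward` is needed because `nn.LayerNorm(dim)` normalizes over the last axis, while convolutional tensors keep channels on axis 1. A minimal sketch with arbitrary example shapes shows that the result is a per-pixel normalization across channels:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 16, 8, 8)  # [B, C, H, W]
norm = nn.LayerNorm(16)

# Move channels last, normalize over them, then restore channels-first
y = norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

# With the default affine parameters, every pixel has ~zero mean over channels
per_pixel_mean = y.mean(dim=1)
print(y.shape, per_pixel_mean.abs().max().item())
```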

In [None]:
class SpatialGroupedAttention(nn.Module):
    """
    2D adaptation of the grouped attention mechanism from StripedHyena.
    This provides efficient attention with rotary positional embeddings.
    """
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        head_dim: int = 64,
        dropout: float = 0.0,
        max_height: int = 1024,
        max_width: int = 1024
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        inner_dim = num_heads * head_dim

        # Layer normalization
        self.norm = nn.LayerNorm(dim)

        # QKV projection
        self.qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        # Output projection
        self.proj = nn.Linear(inner_dim, dim)

        # Dropout
        self.dropout = nn.Dropout(dropout)

        # Rotary positional embedding
        self.rotary_emb = RotaryPositionalEmbedding(
            head_dim,
            max_height=max_height,
            max_width=max_width
        )

        # Scaling factor for attention
        self.scale = head_dim ** -0.5

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply spatial grouped attention to input tensor.

        Args:
            x: Input tensor of shape [batch, channels, height, width]

        Returns:
            Output tensor after attention
        """
        # Get dimensions
        B, C, H, W = x.shape

        # Apply layer normalization in channels-last layout (kept for the linear QKV projection)
        x_norm = x.permute(0, 2, 3, 1)  # [B, H, W, C]
        x_norm = self.norm(x_norm)

        # Residual connection
        residual = x

        # Project to QKV
        qkv = self.qkv(x_norm)  # [B, H, W, 3*inner_dim]
        qkv = qkv.reshape(B, H, W, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(3, 0, 4, 1, 2, 5)  # [3, B, num_heads, H, W, head_dim]
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Apply rotary positional embeddings
        cos_emb, sin_emb = self.rotary_emb(H, W)
        q = apply_rotary_pos_emb(q, cos_emb, sin_emb)
        k = apply_rotary_pos_emb(k, cos_emb, sin_emb)

        # Reshape for attention computation
        q = q.reshape(B, self.num_heads, H * W, self.head_dim)
        k = k.reshape(B, self.num_heads, H * W, self.head_dim)
        v = v.reshape(B, self.num_heads, H * W, self.head_dim)

        # Compute attention
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        # Apply attention to values
        out = attn @ v  # [B, num_heads, H*W, head_dim]
        out = out.reshape(B, self.num_heads, H, W, self.head_dim)
        out = out.permute(0, 2, 3, 1, 4).reshape(B, H, W, self.num_heads * self.head_dim)

        # Project back to original dimension
        out = self.proj(out)
        out = self.dropout(out)

        # Convert back to channels-first format
        out = out.permute(0, 3, 1, 2)  # [B, C, H, W]

        # Add residual connection
        return out + residual
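
Because every pixel is a token here, the score matrix `attn` has shape `[B, num_heads, H*W, H*W]`. The back-of-the-envelope sketch below estimates its size, assuming 4-byte floats and ignoring all other activations; `attention_matrix_bytes` is a hypothetical helper, not part of the notebook.

```python
def attention_matrix_bytes(height: int, width: int, num_heads: int,
                           batch: int = 1, bytes_per_elem: int = 4) -> int:
    """Size of one dense attention score matrix for an H x W feature map."""
    tokens = height * width
    return batch * num_heads * tokens * tokens * bytes_per_elem


for side in (32, 64, 128, 256):
    gib = attention_matrix_bytes(side, side, num_heads=4) / 2**30
    print(f"{side}x{side}: {gib:.3f} GiB")
```

Doubling the side length multiplies the matrix size by 16, so the resolution at which the attention layers run largely determines whether the model fits in device memory.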

In [None]:
class StripedHyenaStyleTransferBlock(nn.Module):
    """
    Hybrid block combining SpatialHyenaBlock and SpatialGroupedAttention.
    This is the core building block of the StripedHyena Style Transfer model.
    """
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        head_dim: int = 64,
        kernel_size: int = 7,
        expansion_factor: int = 2,
        dropout: float = 0.0,
        max_height: int = 1024,
        max_width: int = 1024
    ):
        super().__init__()
        self.dim = dim

        # Hyena block for efficient local and global mixing
        self.hyena = SpatialHyenaBlock(
            dim=dim,
            kernel_size=kernel_size,
            expansion_factor=expansion_factor,
            dropout=dropout
        )

        # Grouped attention for content-style interaction
        self.attention = SpatialGroupedAttention(
            dim=dim,
            num_heads=num_heads,
            head_dim=head_dim,
            dropout=dropout,
            max_height=max_height,
            max_width=max_width
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply hybrid StripedHyena block to input tensor.

        Args:
            x: Input tensor of shape [batch, channels, height, width]

        Returns:
            Output tensor after hybrid processing
        """
        # Apply Hyena block
        x = self.hyena(x)

        # Apply attention block
        x = self.attention(x)

        return x

## Neural Style Transfer Model

Now, let's implement the complete neural style transfer model based on the StripedHyena architecture.

In [None]:
class ContentEncoder(nn.Module):
    """
    Content encoder based on VGG-like architecture.
    This extracts hierarchical features from the content image.
    """
    def __init__(self):
        super().__init__()

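        # Use pretrained VGG16 features (the `weights` argument replaces the
        # deprecated `pretrained=True` in torchvision >= 0.13)
        vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features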

        # Extract specific layers
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()

        # Populate slices with VGG layers
        for x in range(4):
            self.slice1.add_module(str(x), vgg[x])
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg[x])
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg[x])
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg[x])

        # Freeze the encoder
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Extract hierarchical features from content image.

        Args:
            x: Input image tensor of shape [batch, 3, height, width]

        Returns:
            Dictionary of feature maps at different levels
        """
        h = self.slice1(x)
        h_relu1_2 = h
        h = self.slice2(h)
        h_relu2_2 = h
        h = self.slice3(h)
        h_relu3_3 = h
        h = self.slice4(h)
        h_relu4_3 = h

        features = {
            'relu1_2': h_relu1_2,
            'relu2_2': h_relu2_2,
            'relu3_3': h_relu3_3,
            'relu4_3': h_relu4_3
        }

        return features

In [None]:
class StyleEncoder(nn.Module):
    """
    Style encoder based on VGG-like architecture.
    This extracts style features from the style image.
    """
    def __init__(self):
        super().__init__()

        # Same architecture as content encoder for feature compatibility
        self.content_encoder = ContentEncoder()

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Extract style features from style image.

        Args:
            x: Input style image tensor of shape [batch, 3, height, width]

        Returns:
            Dictionary of style feature maps at different levels
        """
        # Extract raw features
        features = self.content_encoder(x)

        # Compute Gram matrices for style representation
        style_features = {}
        for key, value in features.items():
            B, C, H, W = value.shape
            feature_reshaped = value.view(B, C, H * W)
            gram = torch.bmm(feature_reshaped, feature_reshaped.transpose(1, 2))
            gram = gram / (C * H * W)
            style_features[key] = gram

        return style_features
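
The Gram matrix computed above records how strongly each pair of channels co-activates, summed over all spatial positions. A minimal standalone sketch (the function name `gram_matrix` is an assumption for illustration) shows two consequences: the matrix is symmetric, and shuffling the pixels leaves it unchanged, which is why it captures texture statistics rather than layout.

```python
import torch


def gram_matrix(features: torch.Tensor) -> torch.Tensor:
    """Normalized channel-by-channel correlations of a [B, C, H, W] feature map."""
    b, c, h, w = features.shape
    flat = features.reshape(b, c, h * w)
    return torch.bmm(flat, flat.transpose(1, 2)) / (c * h * w)


torch.manual_seed(0)
feats = torch.randn(2, 8, 5, 5)
gram = gram_matrix(feats)

# Shuffle spatial positions identically for every channel
perm = torch.randperm(25)
shuffled = feats.reshape(2, 8, 25)[:, :, perm].reshape(2, 8, 5, 5)

is_symmetric = torch.allclose(gram, gram.transpose(1, 2))
layout_invariant = torch.allclose(gram_matrix(shuffled), gram, atol=1e-5)
print(gram.shape, is_symmetric, layout_invariant)
```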

In [None]:
class StripedHyenaStyleTransferModel(nn.Module):
    """
    Neural style transfer model based on the adapted StripedHyena architecture.
    This combines traditional style transfer approaches with the efficiency
    and effectiveness of the StripedHyena architecture.
    """
    def __init__(
        self,
        content_weight: float = 1.0,
        style_weight: float = 1e5,
        tv_weight: float = 1e-6
    ):
        super().__init__()

        # Weights for different loss components
        self.content_weight = content_weight
        self.style_weight = style_weight
        self.tv_weight = tv_weight

        # Content and style encoders (frozen)
        self.content_encoder = ContentEncoder()
        self.style_encoder = StyleEncoder()

        # Freeze encoders
        for param in self.content_encoder.parameters():
            param.requires_grad = False
        for param in self.style_encoder.parameters():
            param.requires_grad = False

        # Synthesis network based on StripedHyena architecture
        self.synthesis_network = nn.ModuleList([
            # Initial convolution
            nn.Conv2d(3, 64, kernel_size=3, padding=1),

            # StripedHyena blocks for style transfer
            StripedHyenaStyleTransferBlock(
                dim=64,
                num_heads=4,
                head_dim=16,
                kernel_size=3,
                expansion_factor=2,
                dropout=0.0
            ),

            # Intermediate convolution
            nn.Conv2d(64, 128, kernel_size=3, padding=1, stride=2),

            StripedHyenaStyleTransferBlock(
                dim=128,
                num_heads=8,
                head_dim=16,
                kernel_size=3,
                expansion_factor=2,
                dropout=0.0
            ),

            # Deeper features
            nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=2),

            StripedHyenaStyleTransferBlock(
                dim=256,
                num_heads=8,
                head_dim=32,
                kernel_size=3,
                expansion_factor=2,
                dropout=0.0
            ),

            # Upsampling path
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),

            StripedHyenaStyleTransferBlock(
                dim=128,
                num_heads=8,
                head_dim=16,
                kernel_size=3,
                expansion_factor=2,
                dropout=0.0
            ),

            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),

            StripedHyenaStyleTransferBlock(
                dim=64,
                num_heads=4,
                head_dim=16,
                kernel_size=3,
                expansion_factor=2,
                dropout=0.0
            ),

            # Final convolution to RGB
            nn.Conv2d(64, 3, kernel_size=3, padding=1)
        ])

    def forward(
        self,
        content_img: torch.Tensor,
        style_img: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Perform neural style transfer.

        Args:
            content_img: Content image tensor of shape [batch, 3, height, width]
            style_img: Style image tensor of shape [batch, 3, height, width]

        Returns:
            Tuple of (stylized image, loss dictionary)
        """
        # Extract content features
        content_features = self.content_encoder(content_img)

        # Extract style features (Gram matrices)
        style_features = self.style_encoder(style_img)

        # Generate stylized image through synthesis network
        x = content_img
        for layer in self.synthesis_network:
            x = layer(x)

        # Apply tanh to ensure output is in [-1, 1] range
        stylized_img = torch.tanh(x)

        # Extract features from stylized image
        stylized_features = self.content_encoder(stylized_img)

        # Compute Gram matrices for stylized image
        stylized_gram = {}
        for key, value in stylized_features.items():
            B, C, H, W = value.shape
            feature_reshaped = value.view(B, C, H * W)
            gram = torch.bmm(feature_reshaped, feature_reshaped.transpose(1, 2))
            gram = gram / (C * H * W)
            stylized_gram[key] = gram

        # Compute content loss
        content_loss = 0.0
        for key in ['relu4_3']:
            content_loss += F.mse_loss(
                stylized_features[key],
                content_features[key]
            )

        # Compute style loss
        style_loss = 0.0
        style_layers = ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3']
        for key in style_layers:
            style_loss += F.mse_loss(
                stylized_gram[key],
                style_features[key]
            )

        # Compute total variation loss for smoothness
        tv_loss = torch.sum(torch.abs(stylized_img[:, :, :, :-1] - stylized_img[:, :, :, 1:])) + \
                 torch.sum(torch.abs(stylized_img[:, :, :-1, :] - stylized_img[:, :, 1:, :]))

        # Compute total loss
        total_loss = self.content_weight * content_loss + \
                    self.style_weight * style_loss + \
                    self.tv_weight * tv_loss

        # Return stylized image and loss components
        losses = {
            'total': total_loss,
            'content': content_loss,
            'style': style_loss,
            'tv': tv_loss
        }

        return stylized_img, losses
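
The total variation term in `forward` sums absolute differences between horizontally and vertically adjacent pixels, so it is zero for a flat image and large for high-frequency noise. The sketch below isolates that term on two tiny hand-built images; `total_variation` is a hypothetical helper that mirrors the expression used in the model.

```python
import torch


def total_variation(img: torch.Tensor) -> torch.Tensor:
    """Anisotropic total variation of a [B, C, H, W] image (sum, not mean)."""
    dw = (img[:, :, :, :-1] - img[:, :, :, 1:]).abs().sum()
    dh = (img[:, :, :-1, :] - img[:, :, 1:, :]).abs().sum()
    return dw + dh


flat = torch.full((1, 3, 4, 4), 0.5)

# 4x4 checkerboard: every neighbouring pair differs by exactly 1
checker = torch.zeros(1, 1, 4, 4)
checker[..., ::2, 1::2] = 1.0
checker[..., 1::2, ::2] = 1.0

tv_flat = total_variation(flat).item()
tv_checker = total_variation(checker).item()
print(tv_flat, tv_checker)
```

Because the term is a sum over all pixels, its magnitude grows with image size, which is one reason the default `tv_weight` is so small.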

## TPU-Optimized Training Function

Now, let's implement a training function optimized for TPUs.

In [None]:
def train_style_transfer_tpu(
    model: StripedHyenaStyleTransferModel,
    content_img: torch.Tensor,
    style_img: torch.Tensor,
    num_iterations: int = 1000,
    lr: float = 1e-3
) -> torch.Tensor:
    """
    Train the style transfer model for a specific content-style pair using TPU.

    Args:
        model: StripedHyenaStyleTransferModel instance
        content_img: Content image tensor of shape [1, 3, height, width]
        style_img: Style image tensor of shape [1, 3, height, width]
        num_iterations: Number of optimization iterations
        lr: Learning rate

    Returns:
        Stylized image tensor
    """
    # Move model and data to TPU
    model = model.to(device)
    content_img = content_img.to(device)
    style_img = style_img.to(device)

    # Initialize optimizer
    optimizer = torch.optim.Adam(model.synthesis_network.parameters(), lr=lr)

    # Training loop
    for i in range(num_iterations):
        # Forward pass
        stylized_img, losses = model(content_img, style_img)

        # Backward pass
        optimizer.zero_grad()
        losses['total'].backward()

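        # Update weights. barrier=True issues the XLA step barrier that a
        # ParallelLoader would otherwise provide, so the graph runs every iteration
        xm.optimizer_step(optimizer, barrier=True)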

        # Print progress
        if (i + 1) % 100 == 0:
            # Move losses to CPU for printing
            total_loss = losses['total'].item()
            content_loss = losses['content'].item()
            style_loss = losses['style'].item()
            tv_loss = losses['tv'].item()

            print(f"Iteration {i+1}/{num_iterations}, "
                  f"Total Loss: {total_loss:.4f}, "
                  f"Content Loss: {content_loss:.4f}, "
                  f"Style Loss: {style_loss:.4f}, "
                  f"TV Loss: {tv_loss:.4f}")

    # Final forward pass to get stylized image
    with torch.no_grad():
        stylized_img, _ = model(content_img, style_img)

    # Move result back to CPU
    return stylized_img.cpu()

## Image Processing Utilities

Let's implement utility functions for image processing.

In [None]:
def preprocess_image(image_path: str, target_size: Tuple[int, int] = (256, 256)) -> torch.Tensor:
    """
    Preprocess an image for the style transfer model.

    Args:
        image_path: Path to the image file
        target_size: Target size for resizing (height, width)

    Returns:
        Preprocessed image tensor of shape [1, 3, height, width]
    """
    # Define preprocessing transforms
    transform = transforms.Compose([
        transforms.Resize(target_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Load and preprocess image
    image = Image.open(image_path).convert('RGB')
    image_tensor = transform(image).unsqueeze(0)

    return image_tensor

def postprocess_image(tensor: torch.Tensor) -> np.ndarray:
    """
    Convert the output tensor to a displayable image.

    Args:
        tensor: Output tensor from the model of shape [1, 3, height, width]

    Returns:
        Postprocessed image array in range [0, 255]
    """
    # Denormalize
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

    tensor = tensor * std + mean

    # Clamp values to [0, 1]
    tensor = torch.clamp(tensor, 0, 1)

    # Convert to numpy array
    image = tensor.squeeze().permute(1, 2, 0).numpy()

    # Convert to uint8
    return (image * 255).astype(np.uint8)

def display_images(content_img: np.ndarray, style_img: np.ndarray, result_img: np.ndarray) -> None:
    """
    Display content, style, and result images side by side.

    Args:
        content_img: Content image array
        style_img: Style image array
        result_img: Result image array
    """
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))

    ax1.imshow(content_img)
    ax1.set_title('Content Image')
    ax1.axis('off')

    ax2.imshow(style_img)
    ax2.set_title('Style Image')
    ax2.axis('off')

    ax3.imshow(result_img)
    ax3.set_title('Stylized Image')
    ax3.axis('off')

    plt.tight_layout()
    plt.show()
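
`preprocess_image` and `postprocess_image` are meant to be inverses of each other up to resizing and quantization. The sketch below checks the normalize and denormalize round trip on random data and prints the value range that normalized images occupy. The helper names are assumptions for illustration. One observation worth keeping in mind: the normalized range extends beyond [-1, 1], whereas the model's `tanh` output is confined to that interval.

```python
import torch

MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)


def normalize(img: torch.Tensor) -> torch.Tensor:
    """Map a [B, 3, H, W] image in [0, 1] to ImageNet-normalized values."""
    return (img - MEAN) / STD


def denormalize(t: torch.Tensor) -> torch.Tensor:
    """Invert `normalize`."""
    return t * STD + MEAN


torch.manual_seed(0)
img = torch.rand(1, 3, 4, 4)
restored = denormalize(normalize(img))

# Extremes of the normalized range for black and white pixels
lowest = normalize(torch.zeros(1, 3, 1, 1)).min().item()
highest = normalize(torch.ones(1, 3, 1, 1)).max().item()
print(torch.allclose(restored, img, atol=1e-6), round(lowest, 2), round(highest, 2))
```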

## Example Usage

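Let's demonstrate how to use the StripedHyena neural style transfer model with TPU acceleration. The attention layers build an (H·W)×(H·W) score matrix per head at each resolution, so memory grows quadratically with the number of pixels; if the run exhausts device memory, pass a smaller `target_size` to `preprocess_image`.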

In [None]:
# Download example images
!wget -q -O content.jpg https://upload.wikimedia.org/wikipedia/commons/thumb/d/da/The_Si_o_se_Pol_at_night.jpg/1200px-The_Si_o_se_Pol_at_night.jpg
!wget -q -O style.jpg https://upload.wikimedia.org/wikipedia/commons/thumb/e/ea/Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg/1200px-Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg

# Display the original images
content_img_pil = Image.open('content.jpg').resize((256, 256))
style_img_pil = Image.open('style.jpg').resize((256, 256))

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
ax1.imshow(content_img_pil)
ax1.set_title('Content Image')
ax1.axis('off')
ax2.imshow(style_img_pil)
ax2.set_title('Style Image')
ax2.axis('off')
plt.tight_layout()
plt.show()

In [None]:
# Preprocess images
content_tensor = preprocess_image('content.jpg', target_size=(256, 256))
style_tensor = preprocess_image('style.jpg', target_size=(256, 256))

# Create model
model = StripedHyenaStyleTransferModel(
    content_weight=1.0,
    style_weight=1e5,
    tv_weight=1e-6
)

# Train model on TPU
print("Starting style transfer training on TPU...")
stylized_tensor = train_style_transfer_tpu(
    model=model,
    content_img=content_tensor,
    style_img=style_tensor,
    num_iterations=500,  # Reduced for demonstration
    lr=1e-3
)

# Postprocess result
content_np = np.array(content_img_pil)
style_np = np.array(style_img_pil)
result_np = postprocess_image(stylized_tensor)

# Display results
display_images(content_np, style_np, result_np)

# Save result
result_pil = Image.fromarray(result_np)
result_pil.save('stylized_result.jpg')
print("Stylized image saved to 'stylized_result.jpg'")

## Performance Analysis

Let's analyze the performance benefits of using TPUs for the StripedHyena neural style transfer model.

In [None]:
import time


def benchmark_performance(image_size=256, iterations=10):
    """
    Benchmark the performance of the model on TPU vs CPU.

    Args:
        image_size: Size of the test images
        iterations: Number of forward passes to time
    """
    # Create random test data
    content = torch.randn(1, 3, image_size, image_size)
    style = torch.randn(1, 3, image_size, image_size)

    # Create model
    model_cpu = StripedHyenaStyleTransferModel()
    model_tpu = StripedHyenaStyleTransferModel().to(device)

    # Move data to devices
    content_cpu = content
    style_cpu = style
    content_tpu = content.to(device)
    style_tpu = style.to(device)

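    # Warm-up (also triggers XLA compilation so it is excluded from the timing)
    with torch.no_grad():
        model_cpu(content_cpu, style_cpu)
        # Keep the output alive so mark_step has a tensor to materialize
        warmup_out, _ = model_tpu(content_tpu, style_tpu)
        xm.mark_step()
    xm.wait_device_ops()  # mark_step is asynchronous; wait for the device to finish

    # Benchmark CPU
    start_time = time.time()
    with torch.no_grad():
        for _ in range(iterations):
            model_cpu(content_cpu, style_cpu)
    cpu_time = time.time() - start_time

    # Benchmark TPU
    start_time = time.time()
    with torch.no_grad():
        for _ in range(iterations):
            tpu_out, _ = model_tpu(content_tpu, style_tpu)
            xm.mark_step()
    xm.wait_device_ops()  # block until queued TPU work completes before stopping the clock
    tpu_time = time.time() - start_time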

    # Print results
    print(f"Image size: {image_size}x{image_size}, Iterations: {iterations}")
    print(f"CPU time: {cpu_time:.4f} seconds")
    print(f"TPU time: {tpu_time:.4f} seconds")
    print(f"Speedup: {cpu_time / tpu_time:.2f}x")

    return {
        'image_size': image_size,
        'iterations': iterations,
        'cpu_time': cpu_time,
        'tpu_time': tpu_time,
        'speedup': cpu_time / tpu_time
    }
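
Timing asynchronous devices is easy to get wrong, because a call can return before the device has finished. As a device-agnostic sketch, the helper below runs one warm-up call, then averages wall-clock time over several iterations, and accepts an optional `sync` callable to block until pending work completes. On a TPU, `xm.wait_device_ops` would be one candidate for `sync`. The name `mean_seconds` is hypothetical.

```python
import time
from typing import Callable, Optional


def mean_seconds(fn: Callable[[], object], iterations: int = 10,
                 sync: Optional[Callable[[], None]] = None) -> float:
    """Average wall-clock seconds per call of `fn`, after one warm-up call."""
    fn()  # warm-up: excludes one-time costs such as compilation
    if sync is not None:
        sync()
    start = time.perf_counter()
    for _ in range(iterations):
        fn()
    if sync is not None:
        sync()  # make sure queued device work is included in the measurement
    return (time.perf_counter() - start) / iterations


calls = []
per_call = mean_seconds(lambda: calls.append(sum(range(10_000))), iterations=5)
print(f"{per_call * 1e6:.1f} microseconds per call")
```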

In [None]:
import time

# Run benchmarks for different image sizes
results = []
for size in [128, 256, 512]:
    result = benchmark_performance(image_size=size)
    results.append(result)

# Plot results
sizes = [r['image_size'] for r in results]
speedups = [r['speedup'] for r in results]
labels = [f"{s}x{s}" for s in sizes]

plt.figure(figsize=(10, 6))
# Categorical labels keep the bars evenly spaced; numeric positions (128, 256, 512)
# would render the default-width bars as thin slivers
plt.bar(labels, speedups)
plt.xlabel('Image Size')
plt.ylabel('TPU Speedup (x times faster)')
plt.title('StripedHyena Neural Style Transfer: TPU vs CPU Performance')
plt.grid(axis='y', linestyle='--', alpha=0.7)

for i, v in enumerate(speedups):
    plt.text(i, v + 0.1, f"{v:.2f}x", ha='center')

plt.tight_layout()
plt.show()

## Conclusion

This notebook demonstrates a neural style transfer model based on the StripedHyena architecture, optimized for TPU acceleration. The model combines the efficiency and effectiveness of the StripedHyena architecture with the computational power of TPUs to enable faster and more efficient style transfer.

Key advantages of this approach include:

1. **Hybrid Architecture**: The combination of attention and convolution components provides both global context awareness and local feature manipulation.

2. **TPU Acceleration**: PyTorch/XLA compiles the model for the TPU; the benchmark above measures the resulting forward-pass speedup over the CPU.

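3. **Memory Efficiency**: The depthwise gated convolutions add few parameters and scale linearly with image size. The attention layers scale quadratically with the number of pixels, so they set the practical limit on resolution.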

4. **Adaptive Style Application**: The data-dependent weighting mechanism enables more intelligent application of style based on content characteristics.

This implementation serves as a starting point for exploring the potential of the StripedHyena architecture for neural style transfer and other image processing tasks.