<a href="https://colab.research.google.com/github/thePegasusai/stripedhyena/blob/main/striped_hyena_enhanced_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# StripedHyena Neural Style Transfer - Enhanced TPU Version

This notebook provides an enhanced implementation of neural style transfer using the StripedHyena architecture developed by Liquid AI. The implementation is specifically optimized for Google Colab's TPU environment.

## What's New in This Enhanced Version

- **Detailed Architecture Explanations**: In-depth explanations of the StripedHyena architecture components
- **Interactive Examples**: User-friendly interface for experimenting with different content and style images
- **Custom Dataset Training**: Support for training on your own dataset of images
- **TPU Optimizations**: Performance enhancements specifically for Google Colab's TPU environment
- **Visualization Tools**: Better visualization of the style transfer process and results
- **Hyperparameter Tuning**: Interactive controls for adjusting style transfer parameters

## Table of Contents

1. [Setup and Dependencies](#setup)
2. [StripedHyena Architecture Overview](#architecture)
3. [Model Implementation](#implementation)
4. [Basic Style Transfer Demo](#basic_demo)
5. [Advanced Style Transfer with Parameter Tuning](#advanced_demo)
6. [Training on Custom Datasets](#custom_training)
7. [Performance Analysis: TPU vs. CPU vs. GPU](#performance)
8. [Exporting Your Model](#export)

## 1. Setup and Dependencies <a name="setup"></a>

First, let's set up our environment and install the necessary dependencies. This notebook is designed to work with TPUs, but will fall back to GPU or CPU if TPUs are not available.

In [2]:
import os
import sys
import torch

def check_tpu_availability():
    """Check if TPU is available and properly configured"""

    # Method 1: Check environment variables
    tpu_address = os.environ.get('COLAB_TPU_ADDR')
    if tpu_address:
        print(f"TPU address found: {tpu_address}")
        IS_TPU_AVAILABLE = True
    else:
        print("COLAB_TPU_ADDR environment variable not found")
        IS_TPU_AVAILABLE = False

    # Method 2: Try to import and detect TPU using torch_xla
    try:
        import torch_xla
        import torch_xla.core.xla_model as xm

        # Check if XLA devices are available
        device = xm.xla_device()
        print(f"XLA device detected: {device}")

        # Get TPU device count (using new API to avoid deprecation warning)
        try:
            import torch_xla.runtime as xr
            device_count = xr.world_size()
        except (ImportError, AttributeError):
            # Fallback to deprecated method if new one isn't available
            device_count = xm.xrt_world_size()
        print(f"Number of TPU cores: {device_count}")

        if device_count > 0:
            IS_TPU_AVAILABLE = True
            print("✅ TPU is available and accessible via torch_xla!")
        else:
            IS_TPU_AVAILABLE = False
            print("❌ TPU cores not detected")

    except ImportError:
        print("torch_xla not installed - installing now...")
        # Install torch_xla for TPU support
        os.system('pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html')
        print("Please restart runtime after installation")
        IS_TPU_AVAILABLE = False

    except Exception as e:
        print(f"Error accessing TPU via torch_xla: {e}")
        IS_TPU_AVAILABLE = False

    return IS_TPU_AVAILABLE

def check_gpu_availability():
    """Check GPU availability"""
    if torch.cuda.is_available():
        gpu_count = torch.cuda.device_count()
        gpu_name = torch.cuda.get_device_name(0)
        print(f"✅ GPU is available: {gpu_name}")
        print(f"Number of GPUs: {gpu_count}")
        return True
    else:
        print("❌ GPU is not available")
        return False

def setup_device():
    """Setup the appropriate device for training"""

    print("=" * 50)
    print("DEVICE DETECTION")
    print("=" * 50)

    # Check TPU first
    tpu_available = check_tpu_availability()

    if tpu_available:
        try:
            import torch_xla.core.xla_model as xm
            device = xm.xla_device()
            print(f"🚀 Using TPU: {device}")
            return device, "tpu"
        except Exception as e:
            print(f"Failed to initialize TPU device: {e}")

    # Check GPU if TPU not available
    gpu_available = check_gpu_availability()

    if gpu_available:
        device = torch.device("cuda")
        print(f"🚀 Using GPU: {device}")
        return device, "gpu"

    # Fallback to CPU
    device = torch.device("cpu")
    print("🚀 Using CPU")
    return device, "cpu"

# Run the detection
if __name__ == "__main__":
    device, device_type = setup_device()

    print("\n" + "=" * 50)
    print("ADDITIONAL TPU SETUP (if using TPU)")
    print("=" * 50)

    if device_type == "tpu":
        print("""
To use TPU effectively in your training loop, remember to:

1. Move your model to TPU:
   model = model.to(device)

2. Use xm.optimizer_step() instead of optimizer.step():
   import torch_xla.core.xla_model as xm
   xm.optimizer_step(optimizer)

3. Call xm.mark_step() at the end of each step to execute the pending XLA graph:
   xm.mark_step()

4. For multi-core TPU training, launch workers with xmp.spawn() and wrap your DataLoader with pl.MpDeviceLoader:
   import torch_xla.distributed.xla_multiprocessing as xmp
   import torch_xla.distributed.parallel_loader as pl
        """)

DEVICE DETECTION
COLAB_TPU_ADDR environment variable not found
XLA device detected: xla:0
Number of TPU cores: 1
✅ TPU is available and accessible via torch_xla!
🚀 Using TPU: xla:0

ADDITIONAL TPU SETUP (if using TPU)

To use TPU effectively in your training loop, remember to:

1. Move your model to TPU:
   model = model.to(device)

2. Use xm.optimizer_step() instead of optimizer.step():
   import torch_xla.core.xla_model as xm
   xm.optimizer_step(optimizer)

3. Call xm.mark_step() at the end of each step to execute the pending XLA graph:
   xm.mark_step()

4. For multi-core TPU training, launch workers with xmp.spawn() and wrap your DataLoader with pl.MpDeviceLoader:
   import torch_xla.distributed.xla_multiprocessing as xmp
   import torch_xla.distributed.parallel_loader as pl
        


In [3]:
import torch
import torch_xla.core.xla_model as xm
import time

device = xm.xla_device()
print(f"Using device: {device}")

# Test computation on TPU
print("\n🧪 Testing TPU computation...")
start_time = time.time()

# Create tensors on TPU
x = torch.randn(1000, 1000).to(device)
y = torch.randn(1000, 1000).to(device)

# Perform matrix multiplication
result = torch.matmul(x, y)

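# Important: mark_step() only dispatches the pending graph, so wait for the
# device to finish before stopping the timer (the first run also includes XLA compilation)
xm.mark_step()
xm.wait_device_ops()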

end_time = time.time()
print(f"✅ Matrix multiplication completed in {end_time - start_time:.4f} seconds")
print(f"Result shape: {result.shape}")
print(f"Result device: {result.device}")

Using device: xla:0

🧪 Testing TPU computation...
✅ Matrix multiplication completed in 0.9469 seconds
Result shape: torch.Size([1000, 1000])
Result device: xla:0


In [5]:
# Install required packages
!pip install torch torchvision matplotlib numpy pillow tqdm requests

# If a TPU was detected by setup_device() above, install PyTorch/XLA.
# The torch_xla release must match the installed torch version, so no version is pinned here.
if device_type == "tpu":
    !pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html



In [6]:
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import requests
from io import BytesIO
import time

# Detect a TPU by actually acquiring an XLA device, not just by importing torch_xla
try:
    import torch_xla.core.xla_model as xm
    device = xm.xla_device()
    IS_TPU_AVAILABLE = True
    print(f"Using TPU: {device}")
except Exception as e:  # ImportError, or a runtime error when no XLA device can be created
    IS_TPU_AVAILABLE = False
    print(f"TPU not available: {e}")

# Fall back to GPU, then CPU
if not IS_TPU_AVAILABLE:
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print(f"Using GPU: {torch.cuda.get_device_name(0)}")
    else:
        device = torch.device("cpu")
        print("Using CPU")

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Additional setup for different devices
if device.type == 'cuda':
    torch.cuda.manual_seed(42)
    torch.cuda.manual_seed_all(42)  # For multi-GPU setups
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

print(f"Device set to: {device}")
print(f"PyTorch version: {torch.__version__}")
print(f"Torchvision version: {torchvision.__version__}")

# Test tensor creation on the device
test_tensor = torch.randn(3, 3).to(device)
print(f"Test tensor shape: {test_tensor.shape}")
print(f"Test tensor device: {test_tensor.device}")

Using TPU: xla:0
Device set to: xla:0
PyTorch version: 2.6.0+cpu
Torchvision version: 0.21.0+cpu
Test tensor shape: torch.Size([3, 3])
Test tensor device: xla:0


## 2. StripedHyena Architecture Overview <a name="architecture"></a>

The StripedHyena architecture, developed by Liquid AI, is a hybrid neural network architecture that combines rotary (grouped) attention with gated convolutions. This section provides a detailed explanation of the architecture and its components.

### 2.1 Key Components of StripedHyena

#### Rotary Positional Embeddings
Rotary positional embeddings (RoPE) encode spatial information, allowing the model to understand the relative positions of features in both content and style images. This is crucial for maintaining spatial coherence in the stylized output.

```
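import torch

def apply_rotary_embeddings(x, freqs):
    # x: [batch, seq_len, dim] with an even dim
    # freqs: [max_seq_len, dim/2] rotation angles in radians
    seq_len = x.shape[1]
    # Pair adjacent channels into complex numbers: [batch, seq_len, dim/2]
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    # Turn each angle into a unit-magnitude complex rotation: [seq_len, dim/2]
    angles = freqs[:seq_len].float()
    rotations = torch.polar(torch.ones_like(angles), angles)
    x_rotated = x_complex * rotations
    x_rotated = torch.view_as_real(x_rotated).flatten(-2)
    return x_rotated.type_as(x)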
```
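The defining property of a rotary embedding is that the dot product between a rotated query and a rotated key depends only on their relative offset, not on their absolute positions. The following is a minimal, self-contained check of that property rather than part of the notebook's model; `rope_angles` and `rotate` are hypothetical helper names, and the base of 10000 is an assumed default.

```python
import torch

def rope_angles(seq_len, dim, base=10000.0):
    # Angle for position p and channel pair i: p * base^(-2i/dim)
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    positions = torch.arange(seq_len).float()
    return torch.outer(positions, inv_freq)  # [seq_len, dim/2]

def rotate(x, angles):
    # x: [seq_len, dim]; rotate each adjacent channel pair by its angle
    x_complex = torch.view_as_complex(x.float().reshape(x.shape[0], -1, 2))
    rotations = torch.polar(torch.ones_like(angles), angles)
    return torch.view_as_real(x_complex * rotations).flatten(-2)

torch.manual_seed(0)
dim, seq_len = 8, 16
q_vec, k_vec = torch.randn(dim), torch.randn(dim)
angles = rope_angles(seq_len, dim)

# Place the same query and key vectors at every position, then rotate them.
q = rotate(q_vec.repeat(seq_len, 1), angles)
k = rotate(k_vec.repeat(seq_len, 1), angles)

# Positions (2, 5) and (10, 13) share the offset 3, so their scores should match.
score_a = torch.dot(q[2], k[5])
score_b = torch.dot(q[10], k[13])
same_offset = torch.allclose(score_a, score_b, atol=1e-4)

# A rotation leaves vector norms unchanged.
norm_preserved = torch.allclose(q.norm(dim=-1), q_vec.norm().expand(seq_len), atol=1e-4)
print(same_offset, norm_preserved)
```

Both checks should print `True`: the attention score sees only the distance between two positions, which is what lets the model reason about relative layout.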

#### Gated Convolutions
The architecture employs gated convolutions that adaptively control information flow, allowing the model to selectively apply style features based on content characteristics. This results in more natural-looking style transfers that preserve important content details.

```
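import torch
import torch.nn as nn

class GatedConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        # Produce twice the output channels: one half carries values, the other half the gates
        self.conv = nn.Conv2d(in_channels, out_channels*2, kernel_size, stride, padding)
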
    def forward(self, x):
        # Split channels into two parts
        h = self.conv(x)
        a, b = torch.chunk(h, 2, dim=1)
        # Apply gating mechanism
        return a * torch.sigmoid(b)
```
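To see what the gate does to shapes and magnitudes, the sketch below runs a gated convolution on a random tensor. It repeats the class so that it runs on its own; the tensor sizes are arbitrary values chosen for illustration.

```python
import torch
import torch.nn as nn

class GatedConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels * 2, kernel_size, stride, padding)

    def forward(self, x):
        a, b = torch.chunk(self.conv(x), 2, dim=1)
        return a * torch.sigmoid(b)

torch.manual_seed(0)
layer = GatedConv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
x = torch.randn(2, 3, 32, 32)

with torch.no_grad():
    h = layer.conv(x)                         # value and gate halves stacked on the channel axis
    values, gates = torch.chunk(h, 2, dim=1)
    y = layer(x)                              # gated output

print(h.shape, y.shape)

# The sigmoid gate lies in (0, 1), so it can only attenuate each value, never amplify it.
bounded = bool((y.abs() <= values.abs() + 1e-6).all())
print(bounded)
```

The convolution emits 32 channels, and the gated output keeps 16 of them at the same spatial size; each output element is the corresponding value scaled by a learned factor between 0 and 1.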

#### Hybrid Attention Mechanism
StripedHyena combines local processing (through convolutions) with global context awareness (through attention mechanisms), creating a hybrid approach that captures both fine details and overall style patterns. This hybrid design is particularly effective for style transfer tasks.

### 2.2 Advantages for Neural Style Transfer

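1. **Memory Efficiency**: The gated-convolution (Hyena) layers do not materialize a full attention matrix, so they need far less memory on long sequences than a pure Transformer.
2. **Subquadratic Scaling**: The Hyena convolution layers scale roughly as O(N log N) in the sequence length N instead of O(N²). Attention layers, including the softmax attention used in this notebook's `StripedHyenaBlock`, remain quadratic.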
3. **Long-range Dependencies**: Efficiently captures relationships between distant parts of the image.
4. **Adaptive Processing**: Intelligently applies style based on content characteristics.
5. **TPU Compatibility**: Architecture is well-suited for TPU acceleration.
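The memory and scaling points above can be made concrete with some arithmetic on the size of a full attention score matrix. This is a rough sketch under stated assumptions: the feature map is taken at stride 8 (as for a VGG16 `conv4_2` output), there are 8 heads, scores are stored as 4-byte floats, and `attention_matrix_bytes` is a hypothetical helper name.

```python
def attention_matrix_bytes(image_size, stride=8, heads=8, bytes_per_element=4):
    """Return (tokens, bytes) for one image's full attention score matrix."""
    tokens = (image_size // stride) ** 2
    return tokens, heads * tokens * tokens * bytes_per_element

for size in (256, 512, 1024):
    tokens, n_bytes = attention_matrix_bytes(size)
    print(f"{size}x{size} image -> {tokens} tokens, {n_bytes / 2**20:.0f} MiB of attention scores")
# 256x256 image -> 1024 tokens, 32 MiB of attention scores
# 512x512 image -> 4096 tokens, 512 MiB of attention scores
# 1024x1024 image -> 16384 tokens, 8192 MiB of attention scores
```

Doubling the image side quadruples the token count and multiplies the score matrix by 16, which is why layers that avoid this matrix matter at high resolution.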

### 2.3 Adaptation for 2D Images

The original StripedHyena architecture was designed for 1D sequence processing. For neural style transfer, we adapt it to 2D image processing by:

1. Converting 2D images to sequences by flattening spatial dimensions
2. Applying StripedHyena processing
3. Reshaping back to 2D for convolutional processing
4. Using skip connections to preserve spatial information

This approach allows us to leverage the strengths of the StripedHyena architecture while maintaining the spatial structure necessary for image processing.
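The four steps above can be sketched as a round trip between a feature map and a token sequence. This is an illustrative sketch: `image_to_sequence` and `sequence_to_image` are hypothetical helper names, the tensor sizes are arbitrary, and an identity stands in for the sequence blocks.

```python
import torch

def image_to_sequence(feature_map):
    # [batch, channels, height, width] -> [batch, height*width, channels]
    _, _, h, w = feature_map.shape
    return feature_map.flatten(2).transpose(1, 2), (h, w)

def sequence_to_image(sequence, spatial_shape):
    # [batch, height*width, channels] -> [batch, channels, height, width]
    b, _, c = sequence.shape
    h, w = spatial_shape
    return sequence.transpose(1, 2).reshape(b, c, h, w)

torch.manual_seed(0)
features = torch.randn(2, 8, 4, 6)

# Step 1: flatten the spatial dimensions; token index = row * width + column.
seq, hw = image_to_sequence(features)

# Step 2: sequence processing would happen here (identity in this sketch).
processed = seq

# Step 3: reshape back to a 2D feature map.
restored = sequence_to_image(processed, hw)

# Step 4: a skip connection adds the original spatial features back in.
output = features + restored

print(seq.shape, restored.shape)
print(torch.equal(restored, features))
```

Because the flattening is row-major, the token at index `row * width + column` holds the channel vector of that pixel, and the reshape restores the layout exactly.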

## 3. Model Implementation <a name="implementation"></a>

Now, let's implement the StripedHyena neural style transfer model. We'll start with the core components and then build the complete model.

In [None]:
# Rotary Positional Embeddings implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class RotaryEmbedding(nn.Module):
    def __init__(self, dim, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x, seq_len=None):
        if seq_len is None:
            seq_len = x.shape[1]

        if seq_len != self.seq_len_cached:
            self.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.cos_cached = emb.cos()[None, :, None, :]
            self.sin_cached = emb.sin()[None, :, None, :]

        return self.cos_cached, self.sin_cached

# Function to apply rotary embeddings
def apply_rotary_pos_emb(q, k, cos, sin):
    # q, k: [batch, seq_len, heads, dim_head]; cos, sin: [1, seq_len, 1, dim_head]
    # Rotate each vector: x * cos + rotate_half(x) * sin
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

In [None]:
# Gated Convolution implementation
class GatedConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels*2, kernel_size, stride, padding)

    def forward(self, x):
        h = self.conv(x)
        a, b = torch.chunk(h, 2, dim=1)
        return a * torch.sigmoid(b)

In [None]:
# StripedHyena Block implementation
class StripedHyenaBlock(nn.Module):
    def __init__(self, dim, heads=4, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.dim_head = dim_head

        # Rotary embeddings
        self.rotary_emb = RotaryEmbedding(dim_head)

        # Projections for Q, K, V
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_k = nn.Linear(dim, inner_dim, bias=False)
        self.to_v = nn.Linear(dim, inner_dim, bias=False)
        self.to_out = nn.Linear(inner_dim, dim)

        # Feed-forward network (two linear layers with GELU; no gating)
        self.ff = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim),
            nn.Dropout(dropout)
        )

        # Layer norms
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        # Dropout
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Layer norm 1
        normed_x = self.norm1(x)

        # Self-attention with rotary embeddings
        q = self.to_q(normed_x)
        k = self.to_k(normed_x)
        v = self.to_v(normed_x)

        # Reshape for multi-head attention
        q = q.view(q.shape[0], q.shape[1], self.heads, self.dim_head)
        k = k.view(k.shape[0], k.shape[1], self.heads, self.dim_head)
        v = v.view(v.shape[0], v.shape[1], self.heads, self.dim_head)

        # Apply rotary embeddings
        cos, sin = self.rotary_emb(q, seq_len=q.shape[1])
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # Reshape for attention computation
        q = q.transpose(1, 2)  # [batch, heads, seq_len, dim_head]
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Compute attention scores
        scale = 1.0 / math.sqrt(self.dim_head)
        attn = torch.matmul(q, k.transpose(-2, -1)) * scale
        attn = F.softmax(attn, dim=-1)

        # Apply attention to values
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(out.shape[0], -1, self.heads * self.dim_head)
        out = self.to_out(out)
        out = self.dropout(out)

        # First residual connection
        x = x + out

        # Layer norm 2
        normed_x = self.norm2(x)

        # Feed-forward network
        ff_out = self.ff(normed_x)

        # Second residual connection
        return x + ff_out

In [None]:
# Complete StripedHyena Style Transfer Model
class StripedHyenaStyleTransfer(nn.Module):
    def __init__(self,
                 content_layers=['conv4_2'],
                 style_layers=['conv1_1', 'conv2_1', 'conv3_1', 'conv4_1', 'conv5_1']):
        super().__init__()

        self.content_layers = content_layers
        self.style_layers = style_layers

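        # VGG16 for feature extraction (ImageNet pre-trained); `weights=` replaces the deprecated `pretrained=True`
        vgg = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1).features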
        self.vgg = nn.ModuleList()
        self.vgg_layer_names = []

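        # Register the VGG16 layers under block-based names (conv1_1, relu1_1, ..., pool1, conv2_1, ...)
        # so that they match the names used in content_layers and style_layers
        block, conv_idx = 1, 0
        for layer in vgg.children():
            if isinstance(layer, nn.Conv2d):
                conv_idx += 1
                name = f'conv{block}_{conv_idx}'
            elif isinstance(layer, nn.ReLU):
                name = f'relu{block}_{conv_idx}'
                layer = nn.ReLU(inplace=False)  # Use non-inplace ReLU
            elif isinstance(layer, nn.MaxPool2d):
                name = f'pool{block}'
                # A pooling layer closes the current block
                block += 1
                conv_idx = 0
            elif isinstance(layer, nn.BatchNorm2d):
                name = f'bn{block}_{conv_idx}'
            else:
                raise RuntimeError(f'Unrecognized layer: {layer.__class__.__name__}')

            self.vgg.add_module(name, layer)
            self.vgg_layer_names.append(name)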

        # Freeze VGG parameters
        for param in self.vgg.parameters():
            param.requires_grad = False

        # StripedHyena blocks for content and style processing
        self.content_hyena = nn.ModuleList([
            StripedHyenaBlock(512, heads=8, dim_head=64, dropout=0.1)
            for _ in range(3)
        ])

        self.style_hyena = nn.ModuleList([
            StripedHyenaBlock(512, heads=8, dim_head=64, dropout=0.1)
            for _ in range(3)
        ])

        # Decoder network
        self.decoder = nn.Sequential(
            GatedConv2d(512, 256, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=2, mode='nearest'),
            GatedConv2d(256, 256, kernel_size=3, padding=1),
            GatedConv2d(256, 256, kernel_size=3, padding=1),
            GatedConv2d(256, 128, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=2, mode='nearest'),
            GatedConv2d(128, 128, kernel_size=3, padding=1),
            GatedConv2d(128, 64, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=2, mode='nearest'),
            GatedConv2d(64, 64, kernel_size=3, padding=1),
            nn.Conv2d(64, 3, kernel_size=3, padding=1),
            nn.Sigmoid()
        )

    def encode_features(self, x):
        """Extract VGG features"""
        features = {}
        for name, layer in self.vgg._modules.items():
            x = layer(x)
            if name in self.content_layers or name in self.style_layers:
                features[name] = x
        return features

    def process_content(self, content_features):
        """Process content features through StripedHyena blocks"""
        # Use the deepest content feature
        x = content_features[self.content_layers[-1]]
        batch_size, channels, height, width = x.shape

        # Reshape to sequence for StripedHyena processing
        x = x.view(batch_size, channels, -1).permute(0, 2, 1)  # [batch, seq_len, channels]

        # Apply StripedHyena blocks
        for block in self.content_hyena:
            x = block(x)

        # Reshape back to spatial
        x = x.permute(0, 2, 1).view(batch_size, channels, height, width)

        return x

    def process_style(self, style_features):
        """Process style features through StripedHyena blocks"""
        # Use the deepest style feature
        x = style_features[self.style_layers[-1]]
        batch_size, channels, height, width = x.shape

        # Reshape to sequence for StripedHyena processing
        x = x.view(batch_size, channels, -1).permute(0, 2, 1)  # [batch, seq_len, channels]

        # Apply StripedHyena blocks
        for block in self.style_hyena:
            x = block(x)

        # Reshape back to spatial
        x = x.permute(0, 2, 1).view(batch_size, channels, height, width)

        return x

    def forward(self, content_img, style_img, alpha=1.0, return_features=False):
        """Forward pass for style transfer"""
        # Extract features
        content_features = self.encode_features(content_img)
        style_features = self.encode_features(style_img)

        # Process content and style features
        processed_content = self.process_content(content_features)
        processed_style = self.process_style(style_features)

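        # The style features come from a deeper, lower-resolution layer (and possibly a
        # differently sized image), so resize them to the content feature map before blending
        if processed_style.shape[-2:] != processed_content.shape[-2:]:
            processed_style = F.interpolate(processed_style, size=processed_content.shape[-2:],
                                            mode='bilinear', align_corners=False)

        # Combine content and style features
        combined = processed_content * alpha + processed_style * (1 - alpha)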

        # Decode to generate stylized image
        output_img = self.decoder(combined)

        if return_features:
            return output_img, content_features, style_features, {
                'content': processed_content,
                'style': style_features  # Use original style features for loss computation
            }

        return output_img

In [None]:
# Initialize the model
model = StripedHyenaStyleTransfer().to(device)
print(f"Model initialized on {device}")

## 4. Basic Style Transfer Demo <a name="basic_demo"></a>

This section runs a basic style transfer demo on a pair of sample images. Note that the StripedHyena blocks and the decoder start from random weights, so until the model has been trained (Section 6) the output will not yet resemble a stylized image.

In [None]:
# Utility functions for image processing
def load_image_from_url(url, max_size=512):
    """Load an image from a URL"""
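    response = requests.get(url, timeout=30)
    response.raise_for_status()  # Fail early on HTTP errors instead of passing an error page to PIL
    img = Image.open(BytesIO(response.content)).convert('RGB')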

    # Resize while maintaining aspect ratio
    if max(img.size) > max_size:
        ratio = max_size / max(img.size)
        new_size = (int(img.size[0] * ratio), int(img.size[1] * ratio))
        img = img.resize(new_size, Image.LANCZOS)

    return img

def load_image_from_upload(uploaded_file, max_size=512):
    """Load an image from an uploaded file"""
    img = Image.open(uploaded_file).convert('RGB')

    # Resize while maintaining aspect ratio
    if max(img.size) > max_size:
        ratio = max_size / max(img.size)
        new_size = (int(img.size[0] * ratio), int(img.size[1] * ratio))
        img = img.resize(new_size, Image.LANCZOS)

    return img

def preprocess_image(img):
    """Convert PIL image to tensor for model input"""
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    img_tensor = transform(img).unsqueeze(0).to(device)
    return img_tensor

def deprocess_image(tensor):
    """Convert tensor to PIL image for display"""
    # Move to CPU (works for XLA, CUDA, and CPU tensors alike)
    tensor = tensor.detach().cpu()

    # The decoder ends in a sigmoid, so the output is already in [0, 1];
    # drop the batch dimension and clamp for safety
    tensor = tensor.squeeze(0).clamp(0, 1)

    # Convert to PIL image
    img = torchvision.transforms.ToPILImage()(tensor)
    return img

def display_images(content_img, style_img, output_img, figsize=(15, 5)):
    """Display content, style, and output images side by side"""
    fig, axes = plt.subplots(1, 3, figsize=figsize)

    axes[0].imshow(content_img)
    axes[0].set_title("Content Image")
    axes[0].axis("off")

    axes[1].imshow(style_img)
    axes[1].set_title("Style Image")
    axes[1].axis("off")

    axes[2].imshow(output_img)
    axes[2].set_title("Stylized Output")
    axes[2].axis("off")

    plt.tight_layout()
    plt.show()

In [None]:
# Sample images for demonstration
content_url = "https://images.pexels.com/photos/2559941/pexels-photo-2559941.jpeg"
style_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/e/ea/Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg/1280px-Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg"

try:
    content_img = load_image_from_url(content_url)
    style_img = load_image_from_url(style_url)

    # Display original images
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    axes[0].imshow(content_img)
    axes[0].set_title("Content Image")
    axes[0].axis("off")

    axes[1].imshow(style_img)
    axes[1].set_title("Style Image")
    axes[1].axis("off")

    plt.tight_layout()
    plt.show()
except Exception as e:
    print(f"Error loading sample images: {e}")
    print("Please upload your own images in the next cell.")

In [None]:
# Basic style transfer
def perform_style_transfer(content_img, style_img, alpha=1.0):
    """Perform style transfer using the StripedHyena model"""
    # Preprocess images
    content_tensor = preprocess_image(content_img)
    style_tensor = preprocess_image(style_img)

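    # Perform style transfer in eval mode so that dropout is disabled
    model.eval()
    with torch.no_grad():
        start_time = time.time()
        output_tensor = model(content_tensor, style_tensor, alpha=alpha)

        # On TPU, dispatch the pending graph and wait for it to finish so the timing is meaningful
        if IS_TPU_AVAILABLE:
            xm.mark_step()
            xm.wait_device_ops()

        end_time = time.time()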

    # Deprocess output
    output_img = deprocess_image(output_tensor)

    print(f"Style transfer completed in {end_time - start_time:.2f} seconds")
    return output_img

# Run style transfer on sample images
try:
    output_img = perform_style_transfer(content_img, style_img, alpha=0.8)
    display_images(content_img, style_img, output_img)
except NameError:
    print("Please upload content and style images first.")

## 5. Advanced Style Transfer with Parameter Tuning <a name="advanced_demo"></a>

Now, let's create a more advanced demo that allows for parameter tuning and interactive experimentation.

In [None]:
# Upload your own images
from google.colab import files
import ipywidgets as widgets
from IPython.display import display, clear_output

def upload_images():
    print("Please upload a content image:")
    content_file = files.upload()
    content_filename = list(content_file.keys())[0]
    content_img = load_image_from_upload(content_filename)

    print("\nPlease upload a style image:")
    style_file = files.upload()
    style_filename = list(style_file.keys())[0]
    style_img = load_image_from_upload(style_filename)

    return content_img, style_img

# Uncomment to upload your own images
# custom_content_img, custom_style_img = upload_images()

In [None]:
# Interactive demo with parameter tuning
def interactive_style_transfer(content_img, style_img):
    # Create widgets
    alpha_slider = widgets.FloatSlider(
        value=0.8,
        min=0.0,
        max=1.0,
        step=0.05,
        description='Content Weight:',
        continuous_update=False
    )

    run_button = widgets.Button(description="Apply Style Transfer")
    output = widgets.Output()

    # Display original images
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    axes[0].imshow(content_img)
    axes[0].set_title("Content Image")
    axes[0].axis("off")

    axes[1].imshow(style_img)
    axes[1].set_title("Style Image")
    axes[1].axis("off")

    plt.tight_layout()
    plt.show()

    # Define button click handler
    def on_button_clicked(b):
        with output:
            clear_output()
            print(f"Running style transfer with content weight = {alpha_slider.value}...")
            output_img = perform_style_transfer(content_img, style_img, alpha=alpha_slider.value)
            display_images(content_img, style_img, output_img)

    # Connect button to handler
    run_button.on_click(on_button_clicked)

    # Display widgets
    display(alpha_slider, run_button, output)

# Run interactive demo with sample images
try:
    interactive_style_transfer(content_img, style_img)
except NameError:
    print("Please upload content and style images first.")

## 6. Training on Custom Datasets <a name="custom_training"></a>

In this section, we'll implement the training process for the StripedHyena neural style transfer model using custom datasets. This allows you to train the model on your own collection of content and style images.

In [None]:
# Dataset and DataLoader implementation
import os
import random
from torch.utils.data import Dataset, DataLoader

class StyleTransferDataset(Dataset):
    """Dataset for neural style transfer training"""

    def __init__(self, content_dir, style_dir, image_size=256, transform=None):
        """
        Initialize the dataset.

        Args:
            content_dir: Directory containing content images
            style_dir: Directory containing style images
            image_size: Size to resize images to
            transform: Optional transform to apply to images
        """
        self.content_paths = []
        self.style_paths = []

        # Check if directories exist
        if os.path.exists(content_dir):
            self.content_paths = [os.path.join(content_dir, f) for f in os.listdir(content_dir)
                                if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
        else:
            print(f"Warning: Content directory {content_dir} not found.")

        if os.path.exists(style_dir):
            self.style_paths = [os.path.join(style_dir, f) for f in os.listdir(style_dir)
                               if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
        else:
            print(f"Warning: Style directory {style_dir} not found.")

        if not self.content_paths:
            raise ValueError(f"No content images found in {content_dir}")

        if not self.style_paths:
            raise ValueError(f"No style images found in {style_dir}")

        self.image_size = image_size

        if transform is None:
            self.transform = torchvision.transforms.Compose([
                torchvision.transforms.Resize(image_size),
                torchvision.transforms.CenterCrop(image_size),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                              std=[0.229, 0.224, 0.225])
            ])
        else:
            self.transform = transform

        print(f"Found {len(self.content_paths)} content images and {len(self.style_paths)} style images")

    def __len__(self):
        return len(self.content_paths)

    def __getitem__(self, idx):
        # Load content image
        content_path = self.content_paths[idx]
        content_img = Image.open(content_path).convert('RGB')
        content_tensor = self.transform(content_img)

        # Randomly select a style image
        style_path = random.choice(self.style_paths)
        style_img = Image.open(style_path).convert('RGB')
        style_tensor = self.transform(style_img)

        return {
            'content': content_tensor,
            'style': style_tensor,
            'content_path': content_path,
            'style_path': style_path
        }

def get_dataloader(content_dir, style_dir, batch_size=4, image_size=256, num_workers=2):
    """Create and return the data loader"""
    dataset = StyleTransferDataset(content_dir, style_dir, image_size)

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available()
    )

    return dataloader

In [None]:
# Loss functions for neural style transfer
class ContentLoss(nn.Module):
    """Content loss for neural style transfer"""

    def __init__(self):
        super(ContentLoss, self).__init__()
        self.criterion = nn.MSELoss()

    def forward(self, output_features, target_features):
        return self.criterion(output_features, target_features)

class StyleLoss(nn.Module):
    """Style loss for neural style transfer"""

    def __init__(self):
        super(StyleLoss, self).__init__()
        self.criterion = nn.MSELoss()

    def forward(self, output_features, target_features):
        # Calculate Gram matrices
        output_gram = self.gram_matrix(output_features)
        target_gram = self.gram_matrix(target_features)
        return self.criterion(output_gram, target_gram)

    def gram_matrix(self, features):
        """Calculate Gram matrix for style loss"""
        batch_size, ch, h, w = features.size()
        features = features.view(batch_size, ch, h * w)
        features_t = features.transpose(1, 2)
        gram = features.bmm(features_t) / (ch * h * w)
        return gram

class TVLoss(nn.Module):
    """Total variation loss for smoothing"""

    def __init__(self):
        super(TVLoss, self).__init__()

    def forward(self, x):
        batch_size, c, h, w = x.size()
        tv_h = torch.pow(x[:, :, 1:, :] - x[:, :, :-1, :], 2).sum()
        tv_w = torch.pow(x[:, :, :, 1:] - x[:, :, :, :-1], 2).sum()
        return (tv_h + tv_w) / (batch_size * c * h * w)
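
To make the Gram-matrix normalization in `StyleLoss` concrete, here is a minimal pure-Python sketch for a single feature map. It is not the notebook's PyTorch code: the helper name `gram_matrix_2d` and the toy values are assumptions chosen for illustration, and it only mirrors the division by `ch * h * w` used above.

```python
def gram_matrix_2d(features, h, w):
    """Gram matrix of one feature map given as `ch` rows of `h * w` values.

    Hypothetical helper for illustration; mirrors StyleLoss.gram_matrix
    for a single image, including the division by ch * h * w.
    """
    ch = len(features)
    norm = ch * h * w
    return [
        [sum(a * b for a, b in zip(features[i], features[j])) / norm
         for j in range(ch)]
        for i in range(ch)
    ]


# Two channels, each a flattened 1x2 feature map (toy values).
toy_features = [[1.0, 2.0], [3.0, 4.0]]
gram = gram_matrix_2d(toy_features, h=1, w=2)
print(gram)  # [[1.25, 2.75], [2.75, 6.25]]
```

The result is symmetric, and each entry measures how strongly two channels co-activate across spatial positions, which is why the style loss compares Gram matrices rather than raw features.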

In [None]:
# Training function
def train_model(content_dir, style_dir,
                num_epochs=100,
                batch_size=4,
                learning_rate=1e-4,
                content_weight=1.0,
                style_weight=10.0,
                tv_weight=0.001,
                checkpoint_interval=10,
                image_size=256):
    """Train the StripedHyena neural style transfer model"""

    # Create output directories
    os.makedirs("checkpoints", exist_ok=True)
    os.makedirs("results", exist_ok=True)

    # Initialize model
    model = StripedHyenaStyleTransfer().to(device)

    # Initialize loss functions
    content_criterion = ContentLoss().to(device)
    style_criterion = StyleLoss().to(device)
    tv_criterion = TVLoss().to(device)

    # Initialize optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Get data loader
    dataloader = get_dataloader(content_dir, style_dir, batch_size, image_size)

    # Training loop
    losses = []

    print(f"Starting training for {num_epochs} epochs...")
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        start_time = time.time()

        model.train()
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")

        for batch in progress_bar:
            # Move data to device
            content_images = batch['content'].to(device)
            style_images = batch['style'].to(device)

            # Zero gradients
            optimizer.zero_grad()

            # Forward pass
            output_images, content_features, style_features, output_features = model(
                content_images, style_images, return_features=True
            )

            # Calculate losses
            content_loss = content_criterion(output_features['content'], content_features[model.content_layers[0]])

            style_loss = 0
            for layer in model.style_layers:
                style_loss += style_criterion(output_features['style'][layer], style_features[layer])
            style_loss /= len(model.style_layers)

            tv_loss = tv_criterion(output_images)

            # Total loss
            loss = content_weight * content_loss + style_weight * style_loss + tv_weight * tv_loss

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            # Handle TPU synchronization if needed
            if IS_TPU_AVAILABLE:
                xm.mark_step()

            # Update progress
            epoch_loss += loss.item()
            progress_bar.set_postfix({
                "loss": f"{loss.item():.4f}",
                "content": f"{content_loss.item():.4f}",
                "style": f"{style_loss.item():.4f}"
            })

        # Calculate average epoch loss
        avg_epoch_loss = epoch_loss / len(dataloader)
        losses.append(avg_epoch_loss)

        # Print epoch summary
        elapsed_time = time.time() - start_time
        print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_epoch_loss:.4f} - Time: {elapsed_time:.2f}s")

        # Save checkpoint
        if (epoch + 1) % checkpoint_interval == 0:
            checkpoint_path = f"checkpoints/model_epoch_{epoch+1}.pth"
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_epoch_loss,
            }, checkpoint_path)
            print(f"Checkpoint saved to {checkpoint_path}")

    # Save final model
    torch.save(model.state_dict(), "checkpoints/model_final.pth")
    print("Training complete. Final model saved.")

    # Plot loss curve
    plt.figure(figsize=(10, 5))
    plt.plot(range(1, num_epochs + 1), losses)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss')
    plt.grid(True)
    plt.savefig('results/training_loss.png')
    plt.show()

    return model

In [None]:
# Prepare dataset directories and download sample images
def prepare_sample_dataset(num_images=5):
    """Prepare a sample dataset for training"""
    # Create directories
    os.makedirs("dataset/content", exist_ok=True)
    os.makedirs("dataset/style", exist_ok=True)

    # Sample content image URLs
    content_urls = [
        "https://images.pexels.com/photos/2559941/pexels-photo-2559941.jpeg",
        "https://images.pexels.com/photos/1563256/pexels-photo-1563256.jpeg",
        "https://images.pexels.com/photos/1366630/pexels-photo-1366630.jpeg",
        "https://images.pexels.com/photos/1366909/pexels-photo-1366909.jpeg",
        "https://images.pexels.com/photos/1366919/pexels-photo-1366919.jpeg"
    ]

    # Sample style image URLs
    style_urls = [
        "https://upload.wikimedia.org/wikipedia/commons/thumb/e/ea/Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg/1280px-Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg",
        "https://upload.wikimedia.org/wikipedia/commons/thumb/c/cc/Grant_Wood_-_American_Gothic_-_Google_Art_Project.jpg/1280px-Grant_Wood_-_American_Gothic_-_Google_Art_Project.jpg",
        "https://upload.wikimedia.org/wikipedia/commons/thumb/8/8c/Picasso_Three_Musicians_MoMA.jpg/1280px-Picasso_Three_Musicians_MoMA.jpg",
        "https://upload.wikimedia.org/wikipedia/commons/thumb/4/4c/Vassily_Kandinsky%2C_1913_-_Composition_7.jpg/1280px-Vassily_Kandinsky%2C_1913_-_Composition_7.jpg",
        "https://upload.wikimedia.org/wikipedia/commons/thumb/a/a5/Tsunami_by_hokusai_19th_century.jpg/1280px-Tsunami_by_hokusai_19th_century.jpg"
    ]

    # Download content images
    print("Downloading content images...")
    for i, url in enumerate(content_urls[:num_images]):
        try:
            response = requests.get(url, timeout=10)
            if response.status_code == 200:
                with open(f"dataset/content/content_{i+1}.jpg", "wb") as f:
                    f.write(response.content)
                print(f"Downloaded content image {i+1}/{num_images}")
            else:
                print(f"Failed to download content image {i+1}: HTTP {response.status_code}")
        except Exception as e:
            print(f"Error downloading content image {i+1}: {e}")

    # Download style images
    print("\nDownloading style images...")
    for i, url in enumerate(style_urls[:num_images]):
        try:
            response = requests.get(url, timeout=10)
            if response.status_code == 200:
                with open(f"dataset/style/style_{i+1}.jpg", "wb") as f:
                    f.write(response.content)
                print(f"Downloaded style image {i+1}/{num_images}")
            else:
                print(f"Failed to download style image {i+1}: HTTP {response.status_code}")
        except Exception as e:
            print(f"Error downloading style image {i+1}: {e}")

    print("\nSample dataset prepared.")
    return "dataset/content", "dataset/style"

# Uncomment to prepare a sample dataset
# content_dir, style_dir = prepare_sample_dataset(num_images=3)

In [None]:
# Upload your own dataset
def upload_dataset():
    """Upload custom dataset for training"""
    # Create directories
    os.makedirs("dataset/content", exist_ok=True)
    os.makedirs("dataset/style", exist_ok=True)

    print("Please upload content images (you can select multiple files):")
    content_files = files.upload()

    # Save content images
    for filename, content in content_files.items():
        with open(f"dataset/content/{filename}", "wb") as f:
            f.write(content)

    print(f"\nUploaded {len(content_files)} content images.")

    print("\nPlease upload style images (you can select multiple files):")
    style_files = files.upload()

    # Save style images
    for filename, content in style_files.items():
        with open(f"dataset/style/{filename}", "wb") as f:
            f.write(content)

    print(f"\nUploaded {len(style_files)} style images.")

    return "dataset/content", "dataset/style"

# Uncomment to upload your own dataset
# content_dir, style_dir = upload_dataset()

In [None]:
# Run training with interactive parameters
def interactive_training():
    """Interactive training with parameter selection"""
    # Create widgets for dataset selection
    dataset_options = widgets.RadioButtons(
        options=['Use sample dataset', 'Upload my own dataset'],
        description='Dataset:',
        disabled=False
    )

    # Create widgets for training parameters
    num_epochs_slider = widgets.IntSlider(
        value=20,
        min=5,
        max=100,
        step=5,
        description='Epochs:',
        disabled=False
    )

    batch_size_slider = widgets.IntSlider(
        value=2,
        min=1,
        max=8,
        step=1,
        description='Batch Size:',
        disabled=False
    )

    content_weight_slider = widgets.FloatSlider(
        value=1.0,
        min=0.1,
        max=10.0,
        step=0.1,
        description='Content Weight:',
        disabled=False
    )

    style_weight_slider = widgets.FloatSlider(
        value=10.0,
        min=1.0,
        max=50.0,
        step=1.0,
        description='Style Weight:',
        disabled=False
    )

    start_button = widgets.Button(
        description='Start Training',
        disabled=False,
        button_style='success',
        tooltip='Click to start training'
    )

    output = widgets.Output()

    # Display widgets
    display(dataset_options)
    display(widgets.HBox([num_epochs_slider, batch_size_slider]))
    display(widgets.HBox([content_weight_slider, style_weight_slider]))
    display(start_button)
    display(output)

    # Define button click handler
    def on_button_clicked(b):
        with output:
            clear_output()

            # Get dataset
            if dataset_options.value == 'Use sample dataset':
                content_dir, style_dir = prepare_sample_dataset(num_images=3)
            else:
                content_dir, style_dir = upload_dataset()

            # Get training parameters
            num_epochs = num_epochs_slider.value
            batch_size = batch_size_slider.value
            content_weight = content_weight_slider.value
            style_weight = style_weight_slider.value

            print("Starting training with the following parameters:")
            print(f"- Number of epochs: {num_epochs}")
            print(f"- Batch size: {batch_size}")
            print(f"- Content weight: {content_weight}")
            print(f"- Style weight: {style_weight}")
            print(f"- Content directory: {content_dir}")
            print(f"- Style directory: {style_dir}")
            print("\n")

            # Run training
            trained_model = train_model(
                content_dir=content_dir,
                style_dir=style_dir,
                num_epochs=num_epochs,
                batch_size=batch_size,
                content_weight=content_weight,
                style_weight=style_weight,
                checkpoint_interval=5,
                image_size=256
            )

            print("\nTraining complete! You can now use the trained model for style transfer.")

    # Connect button to handler
    start_button.on_click(on_button_clicked)

# Uncomment to run interactive training
# interactive_training()

## 7. Performance Analysis: TPU vs. CPU vs. GPU <a name="performance"></a>

In this section, we compare the inference time of the StripedHyena neural style transfer model across the available devices (TPU, GPU, and CPU) at several image sizes.
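
Accelerators queue work asynchronously, so a fair measurement needs a warm-up call and a synchronization step before the clock is read. The following is a minimal, framework-free sketch of that pattern. The names `benchmark`, `run_once`, and `sync` are hypothetical; in this notebook, `sync` would plausibly be `torch.cuda.synchronize` on GPU.

```python
import time


def benchmark(run_once, sync=None, warmup=1, repeats=3):
    """Return the average wall-clock seconds per call of `run_once`.

    `sync` is an optional zero-argument callable that blocks until the
    device has finished queued work (hypothetical hook for illustration).
    """
    for _ in range(warmup):
        run_once()  # warm-up: excludes one-time setup or compilation cost
    if sync is not None:
        sync()  # make sure warm-up work is done before starting the clock
    start = time.perf_counter()
    for _ in range(repeats):
        run_once()
    if sync is not None:
        sync()  # wait for asynchronous work before stopping the clock
    return (time.perf_counter() - start) / repeats


# Usage with a stand-in CPU workload.
avg_seconds = benchmark(lambda: sum(i * i for i in range(10_000)))
print(f"Average time: {avg_seconds:.6f} seconds")
```

Without the second `sync()` call, the timer can stop while the device is still computing, which makes accelerators look faster than they are.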

In [None]:
# Performance comparison function
def compare_performance(content_img, style_img, image_sizes=(256, 512, 1024)):
    """Compare performance of style transfer on different devices and image sizes"""
    results = []

    # Check available devices
    available_devices = []
    device_names = []

    if IS_TPU_AVAILABLE:
        available_devices.append(xm.xla_device())
        device_names.append("TPU")

    if torch.cuda.is_available():
        available_devices.append(torch.device("cuda"))
        device_names.append(f"GPU ({torch.cuda.get_device_name(0)})")

    available_devices.append(torch.device("cpu"))
    device_names.append("CPU")

    print(f"Testing performance on {len(available_devices)} devices: {', '.join(device_names)}")

    # Test each device and image size
    for size in image_sizes:
        print(f"\nTesting with image size: {size}x{size}")

        # Resize images
        content_resized = content_img.resize((size, size), Image.LANCZOS)
        style_resized = style_img.resize((size, size), Image.LANCZOS)

        for device, device_name in zip(available_devices, device_names):
            print(f"  Testing on {device_name}...")

            # Initialize model on device
            model = StripedHyenaStyleTransfer().to(device)
            model.eval()

            # Preprocess images
            transform = torchvision.transforms.Compose([
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])

            content_tensor = transform(content_resized).unsqueeze(0).to(device)
            style_tensor = transform(style_resized).unsqueeze(0).to(device)

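            # Warm-up run (on TPU this also triggers XLA compilation)
            with torch.no_grad():
                _ = model(content_tensor, style_tensor)
                if device_name == "TPU":
                    xm.mark_step()

            # Wait for pending device work so it is not counted in the timed run
            if device_name.startswith("GPU"):
                torch.cuda.synchronize()
            elif device_name == "TPU":
                xm.wait_device_ops()
            start_time = time.perf_counter()

            with torch.no_grad():
                for _ in range(3):  # Run multiple times for more accurate timing
                    _ = model(content_tensor, style_tensor)
                    if device_name == "TPU":
                        xm.mark_step()

            # GPU and TPU execution is asynchronous; block until it finishes
            # before stopping the timer
            if device_name.startswith("GPU"):
                torch.cuda.synchronize()
            elif device_name == "TPU":
                xm.wait_device_ops()
            end_time = time.perf_counter()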

            # Calculate average time
            avg_time = (end_time - start_time) / 3
            print(f"    Average time: {avg_time:.4f} seconds")

            results.append({
                "device": device_name,
                "image_size": size,
                "time": avg_time
            })

    # Plot results
    plt.figure(figsize=(12, 6))

    for device_name in device_names:
        device_results = [r for r in results if r["device"] == device_name]
        sizes = [r["image_size"] for r in device_results]
        times = [r["time"] for r in device_results]
        plt.plot(sizes, times, marker='o', label=device_name)

    plt.xlabel('Image Size (pixels)')
    plt.ylabel('Processing Time (seconds)')
    plt.title('StripedHyena Neural Style Transfer Performance Comparison')
    plt.grid(True)
    plt.legend()
    os.makedirs("results", exist_ok=True)  # The directory may not exist if training was skipped
    plt.savefig('results/performance_comparison.png')
    plt.show()

    return results

# Uncomment to run performance comparison
# try:
#     performance_results = compare_performance(content_img, style_img, image_sizes=[256, 512])
# except NameError:
#     print("Please load content and style images first.")

## 8. Exporting Your Model <a name="export"></a>

In this section, we export the trained model to TorchScript and package it with an inference script and a README for deployment or sharing.

In [None]:
# Export model function
def export_model(model_path="checkpoints/model_final.pth"):
    """Export the trained model for deployment"""
    # Create export directory
    os.makedirs("export", exist_ok=True)

    # Check if model exists
    if not os.path.exists(model_path):
        print(f"Model file {model_path} not found. Please train a model first.")
        return

    # Load model
    model = StripedHyenaStyleTransfer().to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    # Export to TorchScript
    print("Exporting model to TorchScript...")
    example_content = torch.randn(1, 3, 256, 256).to(device)
    example_style = torch.randn(1, 3, 256, 256).to(device)

    with torch.no_grad():
        traced_model = torch.jit.trace(model, (example_content, example_style))
        traced_model.save("export/model_traced.pt")

    print("Model exported to export/model_traced.pt")

    # Create a simple inference script
    inference_script = """
import torch
import torchvision.transforms as transforms
from PIL import Image

def load_image(image_path, max_size=512):
    '''Load and preprocess an image'''
    img = Image.open(image_path).convert('RGB')

    # Resize while maintaining aspect ratio
    if max(img.size) > max_size:
        ratio = max_size / max(img.size)
        new_size = (int(img.size[0] * ratio), int(img.size[1] * ratio))
        img = img.resize(new_size, Image.LANCZOS)

    return img

def preprocess_image(img):
    '''Convert PIL image to tensor for model input'''
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    img_tensor = transform(img).unsqueeze(0)
    return img_tensor

def deprocess_image(tensor):
    '''Convert tensor to PIL image for display'''
    tensor = tensor.squeeze(0).detach().clone()
    tensor = tensor.clamp(0, 1)

    # Convert to PIL image
    img = transforms.ToPILImage()(tensor)
    return img

def apply_style_transfer(model_path, content_path, style_path, output_path, alpha=0.8):
    '''Apply style transfer using the exported model'''
    # Note: docstrings here use single-quote delimiters because this script is
    # embedded in a triple-double-quoted string in the notebook.
    # alpha is accepted by the CLI but is not passed to the traced model.
    # Load model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = torch.jit.load(model_path).to(device)
    model.eval()

    # Load images
    content_img = load_image(content_path)
    style_img = load_image(style_path)

    # Preprocess images
    content_tensor = preprocess_image(content_img).to(device)
    style_tensor = preprocess_image(style_img).to(device)

    # Apply style transfer
    with torch.no_grad():
        output_tensor = model(content_tensor, style_tensor)

    # Save result
    output_img = deprocess_image(output_tensor)
    output_img.save(output_path)
    print(f"Stylized image saved to {output_path}")

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Apply neural style transfer using StripedHyena model")
    parser.add_argument("--model", type=str, default="model_traced.pt", help="Path to the exported model")
    parser.add_argument("--content", type=str, required=True, help="Path to the content image")
    parser.add_argument("--style", type=str, required=True, help="Path to the style image")
    parser.add_argument("--output", type=str, default="stylized_output.jpg", help="Path to save the output image")
    parser.add_argument("--alpha", type=float, default=0.8, help="Content weight (0.0 to 1.0)")

    args = parser.parse_args()
    apply_style_transfer(args.model, args.content, args.style, args.output, args.alpha)
"""

    with open("export/inference.py", "w") as f:
        f.write(inference_script)

    print("Inference script saved to export/inference.py")

    # Create a README file
    readme = """
# StripedHyena Neural Style Transfer

This package contains a trained StripedHyena neural style transfer model.

## Contents

- `model_traced.pt`: The exported TorchScript model
- `inference.py`: A script for applying style transfer

## Requirements

- Python 3.8+
- PyTorch 2.0+
- torchvision
- Pillow

## Usage

```bash
python inference.py --content path/to/content.jpg --style path/to/style.jpg --output stylized_output.jpg
```

## Parameters

- `--model`: Path to the exported model (default: model_traced.pt)
- `--content`: Path to the content image
- `--style`: Path to the style image
- `--output`: Path to save the output image (default: stylized_output.jpg)
- `--alpha`: Content weight (0.0 to 1.0, default: 0.8)
"""

    with open("export/README.md", "w") as f:
        f.write(readme)

    print("README file saved to export/README.md")

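    # Create a zip file (zipfile also works outside IPython, where `!zip` is unavailable)
    import zipfile
    zip_path = "export/striped_hyena_style_transfer.zip"
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for name in ("model_traced.pt", "inference.py", "README.md"):
            zf.write(os.path.join("export", name), arcname=name)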

    print("\nModel exported successfully! You can download the zip file from the Files tab.")
    return "export/striped_hyena_style_transfer.zip"

# Uncomment to export a trained model
# export_path = export_model()

In [None]:
# Download the exported model
def download_export(export_path="export/striped_hyena_style_transfer.zip"):
    """Download the exported model"""
    from google.colab import files

    if os.path.exists(export_path):
        files.download(export_path)
        print(f"Downloaded {export_path}")
    else:
        print(f"Export file {export_path} not found. Please export the model first.")

# Uncomment to download the exported model
# download_export()
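
Before sharing the archive, it can help to confirm that it contains the three files the export step writes. The sketch below uses only the standard library. `missing_from_archive` is a hypothetical helper, and the demo builds a throwaway archive with placeholder contents rather than reading a real export.

```python
import os
import tempfile
import zipfile

EXPECTED_FILES = {"model_traced.pt", "inference.py", "README.md"}


def missing_from_archive(zip_path, expected=EXPECTED_FILES):
    """Return the expected file names absent from the archive.

    Entries are compared by base name, so files stored under a folder
    such as `export/` are still recognised.
    """
    with zipfile.ZipFile(zip_path) as zf:
        present = {os.path.basename(name) for name in zf.namelist()}
    return sorted(set(expected) - present)


# Demo on a throwaway archive with placeholder contents (not a real model).
with tempfile.TemporaryDirectory() as tmp:
    demo_zip = os.path.join(tmp, "demo.zip")
    with zipfile.ZipFile(demo_zip, "w") as zf:
        zf.writestr("export/inference.py", "# placeholder")
        zf.writestr("export/README.md", "placeholder")
    missing = missing_from_archive(demo_zip)

print(missing)  # ['model_traced.pt']
```

An empty list means every expected file is present; pointing the helper at `export/striped_hyena_style_transfer.zip` would check a real export.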

## Conclusion

In this notebook, we've implemented and explored the StripedHyena neural style transfer model, optimized for Google Colab's TPU environment. We've covered:

1. The architecture and components of the StripedHyena model
2. Implementation of the model for neural style transfer
3. Basic and advanced style transfer demos
4. Training on custom datasets
5. Performance comparison across different hardware accelerators
6. Exporting the model for deployment

The StripedHyena architecture combines rotary attention with gated convolutions. This hybrid design is intended to keep memory use low and to scale close to linearly with input length, which makes it a promising fit for style transfer on high-resolution images. The performance comparison in Section 7 lets you check how well this holds on your own hardware.

Feel free to experiment with different content and style images, adjust the parameters, and train the model on your own datasets to achieve unique and personalized style transfer results.