In [18]:
from perceiver.model.core import (
    PerceiverDecoder,
    PerceiverEncoder,
    PerceiverIO
)

from perceiver.model.core.classifier import ClassificationOutputAdapter
from perceiver.model.core.adapter import TrainableQueryProvider

from perceiver.model.vision.image_classifier import ImageInputAdapter


# Fourier-encodes pixel positions and flatten along spatial dimensions
input_adapter = ImageInputAdapter(
  image_shape=(224, 224, 3),  # M = 224 * 224
  num_frequency_bands=64,
)

# Projects generic Perceiver decoder output to specified number of classes
output_adapter = ClassificationOutputAdapter(
  num_classes=1000,
  num_output_query_channels=1024,  # F
)

# Generic Perceiver encoder
encoder = PerceiverEncoder(
  input_adapter=input_adapter,
  num_latents=512,  # N
  num_latent_channels=512,  # D
  num_cross_attention_qk_channels=input_adapter.num_input_channels,  # C
  num_cross_attention_heads=1,
  num_self_attention_heads=8,
  num_self_attention_layers_per_block=6,
  num_self_attention_blocks=8,
  dropout=0.0,
)

query_provider = TrainableQueryProvider(64, 64)

# Generic Perceiver decoder
decoder = PerceiverDecoder(
  output_adapter=output_adapter,
  output_query_provider=query_provider,
  num_latent_channels=1024,  # D
  num_cross_attention_heads=1,
  dropout=0.0,
)

# Perceiver IO image classifier
model = PerceiverIO(encoder, decoder)

print('number of parameters: ', sum(p.numel() for p in model.parameters()))


number of parameters:  11841971
