In [1]:
import torch

from perceiver.model.core import (
    PerceiverDecoder,
    PerceiverEncoder,
    PerceiverIO
)

from perceiver.model.core.classifier import ClassificationOutputAdapter
from perceiver.model.core.adapter import TrainableQueryProvider

from perceiver.model.vision.image_classifier import ImageInputAdapter

from torch import nn, optim

from dataloader import xrd_dataloader, binary_dataloader
%load_ext autoreload

In [2]:
from perceiver.model.core import (
    PerceiverDecoder,
    PerceiverEncoder,
    PerceiverIO
)

from perceiver.model.core.classifier import ClassificationOutputAdapter
from perceiver.model.core.adapter import TrainableQueryProvider

from perceiver.model.vision.image_classifier import ImageInputAdapter

D_INPUT = 100

# Fourier-encodes pixel positions and flatten along spatial dimensions
input_adapter = ImageInputAdapter(
  image_shape=(D_INPUT, 1),  # M = 224 * 224
  num_frequency_bands=32,
)

# Projects generic Perceiver decoder output to specified number of classes
output_adapter = ClassificationOutputAdapter(
  num_classes=D_INPUT,
  num_output_query_channels=512,  # F
)

# Generic Perceiver encoder
encoder = PerceiverEncoder(
  input_adapter=input_adapter,
  num_latents=512,  # N
  num_latent_channels=512,  # D changed from 1028
  num_cross_attention_qk_channels=input_adapter.num_input_channels,  # C
  num_cross_attention_heads=1,
  num_self_attention_heads=8,
  num_self_attention_layers_per_block=6,
  num_self_attention_blocks=8,
  dropout=0.0,
)

query_provider = TrainableQueryProvider(1, 512) # very arbitrary!

# Generic Perceiver decoder
decoder = PerceiverDecoder(
  output_adapter=output_adapter,
  output_query_provider=query_provider,
  num_latent_channels=512,  # D
  num_cross_attention_heads=1,
  dropout=0.0,
)

# Perceiver IO image classifier
mse_loss = nn.MSELoss()
model = PerceiverIO(encoder, decoder)

print('number of parameters: ', sum(p.numel() for p in model.parameters()))


number of parameters:  11965366


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [6]:
def train_model(num_epochs=100):
    outputs = []
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
    for epoch in range(num_epochs):
        for idx, data in enumerate(binary_dataloader):
            data = data.reshape(2, D_INPUT, 1)
            data = data.float()
            # data = data.to(device)
            # ===================forward=====================
            output = model(data)
            loss = mse_loss(output, data.squeeze())
            # ===================backward====================
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            n_correct = torch.sum(torch.round(output) == data[:, :, 0]) / 2

            if idx % 5 == 0:
                print(f"Finished batch {idx} in epoch {epoch + 1}. Loss: {loss.item():.4f}")
                print(f"The model classified {n_correct:.4f} percent of points correctly.")

        print('epoch [{}/{}], loss:{:.4f}'.format(epoch + 1, num_epochs, loss.item()))
        outputs.append((epoch, data, output))



# Train the model

model.train(True)
train_model(num_epochs=1)
model.train(False)

Finished batch 0 in epoch 1. Loss: 0.3392
The model classified 0.4900 percent of points correctly.
Finished batch 5 in epoch 1. Loss: 0.2956
The model classified 0.5350 percent of points correctly.
Finished batch 10 in epoch 1. Loss: 0.2913
The model classified 0.5050 percent of points correctly.
Finished batch 15 in epoch 1. Loss: 0.3285
The model classified 0.5200 percent of points correctly.
Finished batch 20 in epoch 1. Loss: 0.2958
The model classified 0.5700 percent of points correctly.
Finished batch 25 in epoch 1. Loss: 0.3263
The model classified 0.4900 percent of points correctly.
Finished batch 30 in epoch 1. Loss: 0.2587
The model classified 0.5550 percent of points correctly.
Finished batch 35 in epoch 1. Loss: 0.3133
The model classified 0.4450 percent of points correctly.
Finished batch 40 in epoch 1. Loss: 0.2829
The model classified 0.5100 percent of points correctly.
Finished batch 45 in epoch 1. Loss: 0.3040
The model classified 0.5300 percent of points correctly.
Fi

KeyboardInterrupt: 

In [None]:

for item in xrd_dataloader:
    print(item.size())
    batch1 = item
    batch = item.reshape(2, 10000, 1)
    break

out.detach().numpy()[0, :].shape
spectra = out.detach().numpy()[0, :]