In [None]:
import torch
import torch.nn as nn
import quaternion

class QuaternionLinear(nn.Module):
    """
    A linear layer that works with quaternion inputs and outputs.

    Each of the ``out_features`` outputs is a quaternion computed as
    ``sum_n W[o, n] (x) x[n] + bias[o]``, where ``(x)`` is the Hamilton product
    and the weight quaternion is
    ``W = weight_r + weight_i * i + weight_j * j + weight_k * k``. The bias is
    added to the real component only. Everything stays in PyTorch, so gradients
    flow to both the input and all parameters, and the layer runs on whatever
    device/dtype its parameters use.

    Args:
        in_features: Number of input quaternions per sample.
        out_features: Number of output quaternions per sample.

    Shapes:
        Input: ``(..., in_features, 4)`` with components ordered ``(r, i, j, k)``,
            or a real-valued ``(..., in_features)`` tensor, which is treated as
            quaternions with zero imaginary parts. When ``in_features == 4`` a
            ``(..., 4, 4)`` input is interpreted as quaternion-valued.
        Output: ``(..., out_features, 4)``.
    """

    def __init__(self, in_features, out_features):
        super(QuaternionLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Each output component sums 4 * in_features products; scaling the
        # N(0, 1) draw by 1/sqrt(4 * in_features) keeps pre-activations O(1)
        # instead of saturating downstream sigmoids.
        scale = (4 * in_features) ** -0.5
        self.weight_r = nn.Parameter(torch.randn(out_features, in_features) * scale)
        self.weight_i = nn.Parameter(torch.randn(out_features, in_features) * scale)
        self.weight_j = nn.Parameter(torch.randn(out_features, in_features) * scale)
        self.weight_k = nn.Parameter(torch.randn(out_features, in_features) * scale)
        self.bias = nn.Parameter(torch.randn(out_features))

    def forward(self, x):
        """
        Apply the quaternion affine map.

        Args:
            x: ``(..., in_features, 4)`` quaternion tensor, or a
                ``(..., in_features)`` real tensor promoted to pure-real
                quaternions.

        Returns:
            ``(..., out_features, 4)`` tensor of output quaternions.

        Raises:
            ValueError: if ``x`` matches neither accepted shape.
        """
        if x.dim() >= 2 and x.shape[-2:] == (self.in_features, 4):
            q = x
        elif x.dim() >= 1 and x.shape[-1] == self.in_features:
            zeros = torch.zeros_like(x)
            q = torch.stack([x, zeros, zeros, zeros], dim=-1)
        else:
            raise ValueError(
                f"expected input of shape (..., {self.in_features}, 4) or "
                f"(..., {self.in_features}), got {tuple(x.shape)}"
            )
        # Match the parameter dtype (e.g. float64 data into float32 weights).
        q = q.to(self.weight_r.dtype)

        r, a, b, c = q.unbind(dim=-1)
        lin = nn.functional.linear  # lin(v, W) == v @ W.t()
        wr, wi, wj, wk = self.weight_r, self.weight_i, self.weight_j, self.weight_k

        # Hamilton product (wr + wi i + wj j + wk k) (x) (r + a i + b j + c k),
        # using i*j = k, j*k = i, k*i = j and their anticommuting reverses.
        out_r = lin(r, wr) - lin(a, wi) - lin(b, wj) - lin(c, wk) + self.bias
        out_i = lin(a, wr) + lin(r, wi) + lin(c, wj) - lin(b, wk)
        out_j = lin(b, wr) - lin(c, wi) + lin(r, wj) + lin(a, wk)
        out_k = lin(c, wr) + lin(b, wi) - lin(a, wj) + lin(r, wk)
        return torch.stack([out_r, out_i, out_j, out_k], dim=-1)


class LeNet300_100(nn.Module):
    """
    LeNet-300-100 style network assembled from ``QuaternionLinear`` layers.

    Three quaternion layers (784 -> 300 -> 100 -> ``num_classes``) with a
    sigmoid after each of the first two; the final layer's output is returned
    without an activation.

    NOTE(review): ``forward`` flattens every non-batch axis, so a trailing
    quaternion axis on the input would be folded into the feature dimension —
    confirm the intended input layout.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.fc1 = QuaternionLinear(784, 300)
        self.fc2 = QuaternionLinear(300, 100)
        self.fc3 = QuaternionLinear(100, num_classes)

    def forward(self, x):
        # Collapse everything after the batch axis (view requires contiguity).
        hidden = x.view(x.size(0), -1)
        for layer in (self.fc1, self.fc2):
            hidden = torch.sigmoid(layer(hidden))
        return self.fc3(hidden)


In [None]:
import torchvision
import torchvision.transforms as transforms

# Load the MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)

# Create the model, loss function and optimizer
model = LeNet300_100(num_classes=10)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

num_epochs = 10
log_every = 100  # report the mean loss over this many batches

# Train the model
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for i, (images, labels) in enumerate(trainloader):
        optimizer.zero_grad()
        outputs = model(images)
        # If the model returns quaternion-valued outputs of shape
        # (batch, num_classes, 4), reduce each quaternion to its magnitude so
        # CrossEntropyLoss receives (batch, num_classes) logits.
        if outputs.dim() == 3:
            outputs = outputs.norm(dim=-1)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if (i + 1) % log_every == 0:
            print(f"Epoch [{epoch + 1}/{num_epochs}], Batch [{i + 1}/{len(trainloader)}], "
                  f"Loss: {running_loss / log_every:.4f}")
            running_loss = 0.0


Another version of the code:

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import quaternion

class _SplitQuaternionLinear(nn.Module):
    """
    Quaternion dense layer acting on flat real feature vectors.

    The last input dimension (``in_features``, a multiple of 4) is read as four
    equal chunks ``[r | i | j | k]`` holding ``in_features // 4`` quaternions.
    Each output quaternion is a sum of Hamilton products of learned weight
    quaternions with the input quaternions, plus a bias. The output uses the
    same ``[r | i | j | k]`` layout with width ``out_features``.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        if in_features % 4 or out_features % 4:
            raise ValueError(
                f"in_features and out_features must be multiples of 4, got "
                f"{in_features} and {out_features}"
            )
        self.in_features = in_features
        self.out_features = out_features
        n_in, n_out = in_features // 4, out_features // 4
        # Each output component sums 4 * n_in == in_features products.
        std = in_features ** -0.5
        self.weight_r = nn.Parameter(torch.randn(n_out, n_in) * std)
        self.weight_i = nn.Parameter(torch.randn(n_out, n_in) * std)
        self.weight_j = nn.Parameter(torch.randn(n_out, n_in) * std)
        self.weight_k = nn.Parameter(torch.randn(n_out, n_in) * std)
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x):
        r, a, b, c = x.chunk(4, dim=-1)
        lin = nn.functional.linear  # lin(v, W) == v @ W.t()
        wr, wi, wj, wk = self.weight_r, self.weight_i, self.weight_j, self.weight_k
        # Hamilton product (wr + wi i + wj j + wk k) (x) (r + a i + b j + c k).
        out_r = lin(r, wr) - lin(a, wi) - lin(b, wj) - lin(c, wk)
        out_i = lin(a, wr) + lin(r, wi) + lin(c, wj) - lin(b, wk)
        out_j = lin(b, wr) - lin(c, wi) + lin(r, wj) + lin(a, wk)
        out_k = lin(c, wr) + lin(b, wi) - lin(a, wj) + lin(r, wk)
        return torch.cat([out_r, out_i, out_j, out_k], dim=-1) + self.bias


class LeNet300_100(nn.Module):
    """
    LeNet-300-100 MLP for MNIST with quaternion-valued hidden layers.

    Data flow: 784 -> fc1 -> 300 -> qact1 (75 quaternions) -> ReLU -> fc2 -> 100
    -> qact2 (25 quaternions) -> ReLU -> fc3 -> 10 logits.
    """

    def __init__(self):
        super(LeNet300_100, self).__init__()

        # Fully connected layers interleaved with quaternion layers. PyTorch has
        # no built-in ``nn.QuaternionLinear``, so a local Hamilton-product layer
        # is used instead.
        self.fc1 = nn.Linear(784, 300)
        self.qact1 = _SplitQuaternionLinear(300, 300)
        self.fc2 = nn.Linear(300, 100)
        self.qact2 = _SplitQuaternionLinear(100, 100)
        self.fc3 = nn.Linear(100, 10)

        # Initialize the weight and bias parameters of the fully connected layers
        nn.init.normal_(self.fc1.weight, mean=0.0, std=0.1)
        nn.init.normal_(self.fc2.weight, mean=0.0, std=0.1)
        nn.init.normal_(self.fc3.weight, mean=0.0, std=0.1)
        nn.init.constant_(self.fc1.bias, 0.1)
        nn.init.constant_(self.fc2.bias, 0.1)
        nn.init.constant_(self.fc3.bias, 0.1)

    def forward(self, x):
        """Map images (any shape with 784 values per sample) to ``(batch, 10)`` logits."""
        x = x.reshape(-1, 784)

        # Without a nonlinearity the whole stack would collapse into a single
        # affine map; a split (component-wise) ReLU follows each quaternion layer.
        x = self.fc1(x)
        x = torch.relu(self.qact1(x))
        x = self.fc2(x)
        x = torch.relu(self.qact2(x))
        x = self.fc3(x)

        # Return the output logits
        return x
    
import torchvision
import torchvision.transforms as transforms

num_epochs = 10

# Build the MNIST training loader used below; this cell previously referenced
# an undefined `train_loader`, which raised NameError on the first epoch.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)

# Create an instance of the LeNet300_100 model
model = LeNet300_100()

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Train the model on the MNIST dataset
for epoch in range(num_epochs):
    model.train()
    for i, (images, labels) in enumerate(train_loader):
        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(images)

        # Calculate the loss
        loss = criterion(outputs, labels)

        # Backward pass
        loss.backward()

        # Update the weights
        optimizer.step()

        # Print the loss every 100 batches
        if (i+1) % 100 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}")


In this example, the fully connected layers are interleaved with quaternion-valued layers (`qact1`, `qact2`). Note that PyTorch does not provide a built-in `nn.QuaternionLinear` module, so these layers have to be implemented separately (for example, as a linear layer whose weights are combined with the input through the Hamilton product). The weight and bias parameters of the fully connected layers are initialized with the `nn.init` module. Finally, we define the loss function and optimizer and train the model on the MNIST dataset using a nested loop over the epochs and batches. The training loop expects a `DataLoader` named `train_loader` that yields batches of MNIST images converted to PyTorch tensors together with their integer labels.

Another version:

In [None]:
import torch
import torch.nn as nn
import quaternion

class LeNetQuat(nn.Module):
    """
    Two-branch LeNet-300-100 variant taking an image plus one quaternion.

    Each input row holds 784 real values (the image) followed by the 4
    components of a quaternion, so ``x`` is expected to be ``(batch, 788)``.
    The output is ``(batch, 11)``: the 10 outputs of the image branch followed
    by one non-negative magnitude from the quaternion branch.
    """

    def __init__(self):
        super(LeNetQuat, self).__init__()
        # Define the layers for the real-valued input
        self.fc1_real = nn.Linear(784, 300)
        self.fc2_real = nn.Linear(300, 100)
        self.fc3_real = nn.Linear(100, 10)
        self.relu_real = nn.ReLU()

        # Define the layers for the quaternion-valued input
        self.fc1_quat = nn.Linear(4, 300)
        self.fc2_quat = nn.Linear(300, 100)
        self.fc3_quat = nn.Linear(100, 10)
        self.relu_quat = nn.ReLU()

        # Define the output layer for the quaternion-valued input
        self.out_quat = nn.Linear(10, 4)

    def forward(self, x):
        # Split each row into the image part and the quaternion part
        x_real = x[:, :784]
        x_quat = x[:, 784:]

        # Compute the output for the real-valued input
        x_real = self.fc1_real(x_real)
        x_real = self.relu_real(x_real)
        x_real = self.fc2_real(x_real)
        x_real = self.relu_real(x_real)
        x_real = self.fc3_real(x_real)

        # Normalize the quaternion to unit length directly in torch. This keeps
        # dtype (float32), device and autograd intact; a numpy round-trip would
        # produce float64 and break on CUDA. F.normalize clamps the norm at a
        # small eps, so an all-zero quaternion maps to zeros instead of NaN.
        # (Multiplying by the identity quaternion (1, 0, 0, 0) is a no-op.)
        x_quat = nn.functional.normalize(x_quat, dim=1)
        x_quat = self.fc1_quat(x_quat)
        x_quat = self.relu_quat(x_quat)
        x_quat = self.fc2_quat(x_quat)
        x_quat = self.relu_quat(x_quat)
        x_quat = self.fc3_quat(x_quat)
        x_quat = self.out_quat(x_quat)
        # Euclidean norm of the 4 outputs, kept as a (batch, 1) column so it can
        # be concatenated with the (batch, 10) real-branch output.
        x_quat = x_quat.norm(dim=1, keepdim=True)

        # Combine the output for the real and quaternion-valued inputs
        return torch.cat([x_real, x_quat], dim=1)


In this example, we define a modified version of the LeNet-300-100 architecture that can take both real and quaternion-valued inputs. The input is split into two parts: the first 784 elements correspond to the real-valued input (i.e., an image of a handwritten digit), and the last 4 elements correspond to the quaternion-valued input.

The real-valued input is processed by the standard LeNet-300-100 layers (784-300-100-10 with ReLU activations). The quaternion-valued input is first normalized to unit length and then passed through a separate 4-300-100-10 stack of ordinary `nn.Linear` layers with ReLU activations, followed by a projection back to 4 values. The Euclidean norm of those 4 values is the output of the quaternion branch. The outputs of the two branches are then combined by concatenation, and the result is returned.

Note that this is just one possible way to modify the LeNet-300-100 architecture to work with quaternions, and there are likely other approaches that work as well or better.