<a href="https://colab.research.google.com/github/vfrantc/quaternion_neurons/blob/main/train_quaternion_resnet18_cifar10.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install git+https://github.com/TParcollet/Quaternion-Neural-Networks.git

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/TParcollet/Quaternion-Neural-Networks.git
  Cloning https://github.com/TParcollet/Quaternion-Neural-Networks.git to /tmp/pip-req-build-2q53jn2d
  Running command git clone --filter=blob:none --quiet https://github.com/TParcollet/Quaternion-Neural-Networks.git /tmp/pip-req-build-2q53jn2d
  Resolved https://github.com/TParcollet/Quaternion-Neural-Networks.git to commit f8de5d5e5a3f9c694a0d62cffc64ec4ccdffd1bc
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: Pytorch-QNN
  Building wheel for Pytorch-QNN (setup.py) ... [?25l[?25hdone
  Created wheel for Pytorch-QNN: filename=Pytorch_QNN-1-py3-none-any.whl size=21516 sha256=6169ef6380b69831283c053b633c4f92bf751e552afabd7ec0b6ba2bb88bc5f4
  Stored in directory: /tmp/pip-ephem-wheel-cache-ui81cocp/wheels/55/78/10/235c627601beea89722aa1507e19d17aae118511b3de0799b6
Succ

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms

### Blocks of the model

In [None]:
from core_qnn.quaternion_layers import QuaternionConv
from core_qnn.quaternion_layers import QuaternionLinear

In [None]:
class QuaternionBatchNorm2d(nn.Module):
    """Applies a 2D Quaternion Batch Normalization to the incoming data.

    The channel dimension of the input is split into four equal chunks that
    are treated as the r, i, j, k components of a quaternion feature map.
    Each component is centred on its own mean, and all four are divided by a
    single shared standard deviation computed from the summed squared
    deviations of the components. A learnable per-quaternion-feature scale
    ``gamma`` and a per-channel shift ``beta`` are then applied.

    NOTE(review): the means and the variance are taken over the *whole*
    tensor of each component (batch, channels and spatial dims together),
    not per channel, and no running statistics are tracked, so eval mode
    also normalises with the current batch's statistics. This behaviour is
    preserved as-is.

    Args:
        num_features: number of input channels; must be divisible by 4.
        gamma_init: initial value of the scale parameter ``gamma``.
        beta_param: whether the shift parameter ``beta`` is trainable.
        training: initial value of ``nn.Module.training``.

    Raises:
        ValueError: if ``num_features`` is not a multiple of 4.
    """

    def __init__(self, num_features, gamma_init=1., beta_param=True, training=True):
        super(QuaternionBatchNorm2d, self).__init__()
        if num_features % 4 != 0:
            raise ValueError(
                'num_features must be divisible by 4, got %d' % num_features)
        # Number of quaternion features; each one spans 4 real channels.
        self.num_features = num_features // 4
        # Cast so an integer gamma_init still yields a floating-point
        # parameter (integer tensors cannot require gradients).
        self.gamma_init = float(gamma_init)
        self.beta_param = beta_param
        self.gamma = nn.Parameter(torch.full([1, self.num_features, 1, 1], self.gamma_init))
        self.beta = nn.Parameter(torch.zeros(1, self.num_features * 4, 1, 1), requires_grad=self.beta_param)
        self.training = training
        # Plain float: follows the input's device/dtype automatically,
        # unlike a tensor attribute that `.to(device)` would not move.
        self.eps = 1e-5

    def reset_parameters(self):
        """Re-initialise ``gamma`` to ``gamma_init`` and ``beta`` to zero.

        The existing parameters are updated in place instead of being
        replaced, so they stay on their current device/dtype and any
        optimizer already holding references to them keeps working.
        """
        with torch.no_grad():
            self.gamma.fill_(self.gamma_init)
            self.beta.zero_()

    def forward(self, input):
        """Normalise ``input`` and apply the learned scale and shift.

        Args:
            input: tensor whose dim 1 holds ``4 * num_features`` channels
                laid out as [r | i | j | k]; the parameter shapes
                ``(1, C, 1, 1)`` expect a 4-D (N, C, H, W) tensor.

        Returns:
            Tensor with the same shape as ``input``.
        """
        components = torch.chunk(input, 4, dim=1)
        deltas = [c - torch.mean(c) for c in components]
        # One variance shared by all four components (quaternion norm).
        quat_variance = torch.mean(deltas[0]**2 + deltas[1]**2 + deltas[2]**2 + deltas[3]**2)
        denominator = torch.sqrt(quat_variance + self.eps)

        beta_components = torch.chunk(self.beta, 4, dim=1)

        # Multiply gamma (stretch scale) and add beta (shift scale).
        outputs = [
            (self.gamma * (delta / denominator)) + beta
            for delta, beta in zip(deltas, beta_components)
        ]
        return torch.cat(outputs, dim=1)

    def extra_repr(self):
        # Summarise hyper-parameters only; printing the full gamma/beta
        # tensors floods the model summary.
        return 'num_features={}, eps={}'.format(self.num_features, self.eps)

In [None]:
class QBasicBlock(nn.Module):
    """Quaternion counterpart of the ResNet basic residual block.

    Main path: 3x3 quaternion conv (strided) -> quaternion BN -> ReLU ->
    3x3 quaternion conv -> quaternion BN. The block input is added back via
    ``shortcut`` and a final ReLU is applied. ``shortcut`` is a strided 1x1
    quaternion conv + quaternion BN when the stride or the channel count
    changes, and an empty ``nn.Sequential`` (identity) otherwise.
    """
    # Ratio of output channels to ``planes`` (1 for the basic block).
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = QuaternionConv(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = QuaternionBatchNorm2d(planes)
        self.conv2 = QuaternionConv(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = QuaternionBatchNorm2d(planes)

        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                QuaternionConv(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                QuaternionBatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        hidden += self.shortcut(x)
        return F.relu(hidden)

In [None]:
class QBottleneck(nn.Module):
    """Quaternion counterpart of the ResNet bottleneck residual block.

    Main path: 1x1 quaternion conv -> 3x3 quaternion conv (strided) ->
    1x1 quaternion conv expanding to ``expansion * planes`` channels, each
    followed by quaternion batch normalisation, with ReLUs after the first
    two. The block input is added back via ``shortcut`` before a final ReLU.
    ``shortcut`` is a strided 1x1 quaternion conv + quaternion BN when the
    stride or the channel count changes, and the identity otherwise.
    """
    # Output channels of the block are ``expansion * planes``.
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(QBottleneck, self).__init__()
        out_planes = self.expansion * planes
        # stride is passed explicitly to every QuaternionConv, as the other
        # call sites in this notebook do. NOTE(review): core_qnn's
        # QuaternionConv may not provide a default for it — confirm.
        self.conv1 = QuaternionConv(in_planes, planes, kernel_size=1, stride=1, bias=False)
        self.bn1 = QuaternionBatchNorm2d(planes)
        self.conv2 = QuaternionConv(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = QuaternionBatchNorm2d(planes)
        self.conv3 = QuaternionConv(planes, out_planes, kernel_size=1, stride=1, bias=False)
        self.bn3 = QuaternionBatchNorm2d(out_planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != out_planes:
            # Projection shortcut. Must use the imported QuaternionConv:
            # torch.nn has no QuaternionConv, so ``nn.QuaternionConv`` raised
            # AttributeError whenever a projection was needed.
            self.shortcut = nn.Sequential(
                QuaternionConv(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                QuaternionBatchNorm2d(out_planes)
            )

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

### The model (quaternion ResNet)

In [None]:
class QResNet(nn.Module):
    """ResNet built from quaternion layers, for 3-channel (RGB) images.

    The RGB input is turned into a 4-channel quaternion tensor by prepending
    an all-zero real channel (0 + R i + G j + B k). It is then passed through
    a quaternion 3x3 stem, four residual stages (strides 1, 2, 2, 2), global
    average pooling and a quaternion linear classifier.

    Args:
        block: residual block class (e.g. ``QBasicBlock`` or ``QBottleneck``);
            must expose an ``expansion`` class attribute.
        num_blocks: number of blocks in each of the four stages.
        num_classes: output width of the final ``QuaternionLinear``. The
            first two outputs are discarded in ``forward``, so the model
            returns ``num_classes - 2`` logits (10 with the default of 12).
            Presumably the width is padded because the quaternion layer
            needs a multiple of 4 — TODO confirm.
    """

    def __init__(self, block, num_blocks, num_classes=12):
        super(QResNet, self).__init__()
        # Running channel count, advanced by _make_layer as stages are built.
        self.in_planes = 64

        # Stem: 4 quaternion input channels (zero real part + RGB) -> 64.
        self.conv1 = QuaternionConv(4, 64, kernel_size=3, stride=1, padding=1, bias=False)
        # NOTE(review): the stem uses a real-valued nn.BatchNorm2d, unlike the
        # QuaternionBatchNorm2d used inside the blocks; kept unchanged so
        # existing checkpoints still load.
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = QuaternionLinear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one applies ``stride``."""
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map a (N, 3, H, W) image batch to (N, num_classes - 2) logits."""
        # Prepend a zero real part: (N, 3, H, W) -> (N, 4, H, W) quaternion input.
        zeros = torch.zeros(x.shape[0], 1, x.shape[2], x.shape[3], dtype=x.dtype, device=x.device)
        x = torch.cat((zeros, x), dim=1)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pool to 1x1. For 32x32 inputs layer4 is 4x4, so this
        # equals the former avg_pool2d(out, 4); for other resolutions it
        # keeps the linear input at 512 * expansion features.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        # Drop the first two (padding) outputs.
        out = out[:, 2:]
        return out

### Training procedure

In [None]:
# Training
def train(net, criterion, trainloader, optimizer, device, epoch, log_interval=1):
    """Run one training epoch over ``trainloader`` and report progress.

    Args:
        net: model to train; switched to training mode.
        criterion: loss called as ``criterion(outputs, targets)``.
        trainloader: sized iterable of ``(inputs, targets)`` batches.
        optimizer: optimizer stepping ``net``'s parameters.
        device: device each batch is moved to.
        epoch: epoch index, used only in the header line.
        log_interval: print running loss/accuracy every ``log_interval``
            batches and always after the last batch. The default of 1
            prints after every batch (the original behaviour).

    Returns:
        ``(mean_batch_loss, accuracy_percent)`` for the epoch, or
        ``(0.0, 0.0)`` if no samples were seen.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    log_interval = max(1, log_interval)
    num_batches = len(trainloader)
    train_loss = 0
    correct = 0
    total = 0
    batch_idx = -1
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        if (batch_idx + 1) % log_interval == 0 or batch_idx + 1 == num_batches:
            print(batch_idx, num_batches, 'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))

    if batch_idx < 0 or total == 0:
        return 0.0, 0.0
    return train_loss / (batch_idx + 1), 100. * correct / total


### Testing procedure

In [None]:
def test(net, testloader, device, criterion, epoch):
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            print(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/ckpt.pth')
        best_acc = acc


### Load data

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best test accuracy so far; updated by test()
start_epoch = 0

# Seed torch's RNG (random crops/flips, DataLoader shuffling and worker
# seeds) so reruns are comparable.
# NOTE(review): core_qnn's layer initialisation may draw from its own RNG,
# and cudnn.benchmark can select nondeterministic kernels — confirm if exact
# reproducibility is required.
SEED = 0
torch.manual_seed(SEED)

# Per-channel CIFAR-10 normalisation constants, shared by both pipelines.
# NOTE(review): these std values are reportedly not the per-pixel
# training-set std (often quoted as ~(0.247, 0.243, 0.261)) — confirm
# before changing; kept as-is so results stay comparable.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified


### Create the model

In [None]:
# Model
print('==> Building model..')
# [2, 2, 2, 2] basic blocks per stage is the ResNet-18 layout.
net = QResNet(QBasicBlock, [2, 2, 2, 2]).to(device)
if device == 'cuda':
    # Let cuDNN autotune convolution algorithms for the fixed input size,
    # and replicate the model across all visible GPUs.
    cudnn.benchmark = True
    net = torch.nn.DataParallel(net)

# Cross-entropy loss; SGD with momentum and L2 weight decay; cosine
# learning-rate annealing over 200 scheduler steps.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    net.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=5e-4,
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

==> Building model..


### Training 

In [None]:
# One scheduler step per epoch. Run exactly as many epochs as the cosine
# schedule spans (its T_max) so the loop length and the LR schedule cannot
# drift apart.
for epoch in range(start_epoch, start_epoch + scheduler.T_max):
    train(net, criterion, trainloader, optimizer, device, epoch)
    test(net, testloader, device, criterion, epoch)
    scheduler.step()

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
33 100 Loss: 0.289 | Acc: 93.088% (3165/3400)
34 100 Loss: 0.287 | Acc: 93.086% (3258/3500)
35 100 Loss: 0.292 | Acc: 93.056% (3350/3600)
36 100 Loss: 0.298 | Acc: 92.919% (3438/3700)
37 100 Loss: 0.305 | Acc: 92.816% (3527/3800)
38 100 Loss: 0.301 | Acc: 92.897% (3623/3900)
39 100 Loss: 0.299 | Acc: 92.925% (3717/4000)
40 100 Loss: 0.300 | Acc: 92.902% (3809/4100)
41 100 Loss: 0.303 | Acc: 92.857% (3900/4200)
42 100 Loss: 0.300 | Acc: 92.907% (3995/4300)
43 100 Loss: 0.304 | Acc: 92.932% (4089/4400)
44 100 Loss: 0.306 | Acc: 92.889% (4180/4500)
45 100 Loss: 0.309 | Acc: 92.804% (4269/4600)
46 100 Loss: 0.308 | Acc: 92.787% (4361/4700)
47 100 Loss: 0.312 | Acc: 92.750% (4452/4800)
48 100 Loss: 0.307 | Acc: 92.878% (4551/4900)
49 100 Loss: 0.310 | Acc: 92.840% (4642/5000)
50 100 Loss: 0.305 | Acc: 92.922% (4739/5100)
51 100 Loss: 0.305 | Acc: 92.942% (4833/5200)
52 100 Loss: 0.302 | Acc: 92.962% (4927/5300)
53 100 Loss: 0.

### Test the model and compute per-class metrics

In [None]:
# Collect ground-truth labels and predicted class indices over the test set,
# then summarise per-class precision / recall / F1.
from sklearn.metrics import classification_report

# Don't rely on an earlier cell having left the model in eval mode: the stem's
# nn.BatchNorm2d would otherwise use batch statistics and update its running
# statistics during this pass.
net.eval()

test_labels = []
predictions = []

with torch.no_grad():
    for images, labels in testloader:
        outputs = net(images.to(device))
        _, predicted = outputs.max(1)

        # Labels never need to leave the CPU; only predictions come back.
        test_labels.extend(labels.tolist())
        predictions.extend(predicted.cpu().tolist())

cr = classification_report(test_labels, predictions, target_names=classes)

# Display the classification report
print("Classification Report:")
print(cr)

Classification Report:
              precision    recall  f1-score   support

       plane       0.92      0.95      0.93      1000
         car       0.97      0.98      0.98      1000
        bird       0.93      0.90      0.91      1000
         cat       0.87      0.84      0.86      1000
        deer       0.92      0.94      0.93      1000
         dog       0.89      0.89      0.89      1000
        frog       0.95      0.96      0.95      1000
       horse       0.96      0.96      0.96      1000
        ship       0.96      0.96      0.96      1000
       truck       0.95      0.95      0.95      1000

    accuracy                           0.93     10000
   macro avg       0.93      0.93      0.93     10000
weighted avg       0.93      0.93      0.93     10000

