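#### This Jupyter notebook contains our initial implementation of the GCNN architecture. This solution is slow for three reasons:
1. `conv2d` is called far more often than needed: the kernels are not stacked into a single weight tensor, so the group convolution issues a separate convolution for every combination of output/input group element and output/input channel.
2. It does not use a filter bank: every rotated kernel is recomputed on the fly in each forward pass instead of being built once and reused.
3. It does not batch the kernel transformations: each 2D kernel is rotated individually inside Python loops.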

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
import numpy as np
import random
from torch.utils.data import DataLoader, Subset
from torch.nn import AdaptiveAvgPool3d
import os
from tqdm import tqdm

# --- Reproducibility: seed the Python, NumPy and PyTorch RNGs with the same value ---
random.seed(2)
np.random.seed(2)
torch.manual_seed(2)

# --- Device: use the GPU when PyTorch can see one, otherwise fall back to the CPU ---
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Data Loader

def get_mnist_loaders(data_dir="./data", batch_size=64, n_train=400, n_test=80, n_val=50, digits=(0, 1, 2, 3, 4)):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    rotation_aug = transforms.Compose([
        transforms.ToTensor(),
        transforms.RandomRotation((0, 360)),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    trainset = datasets.MNIST(data_dir, train=True, download=True, transform=transform)
    testset = datasets.MNIST(data_dir, train=False, download=True, transform=rotation_aug)

    # Filter datasets to only include specified digits
    train_indices = [i for i, target in enumerate(trainset.targets) if target in digits]
    test_indices = [i for i, target in enumerate(testset.targets) if target in digits]

    train_subset = Subset(trainset, train_indices[:min(n_train, len(train_indices))])
    test_subset = Subset(testset, test_indices[:min(n_test, len(test_indices))])
    val_subset = Subset(testset, test_indices[min(n_test, len(test_indices)):min(n_test + n_val, len(test_indices))])

    digit_map = {digit: i for i, digit in enumerate(digits)}

    # Fix the labels in the subsets directly
    def map_targets(subset):
        subset.dataset.targets[subset.indices] = torch.tensor([
            digit_map[subset.dataset.targets[i].item()] for i in subset.indices
        ])

    map_targets(train_subset)
    map_targets(test_subset)
    map_targets(val_subset)

    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_subset, batch_size=batch_size, shuffle=False)
    val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader, val_loader

# Keep only the digits 0, 1, 7 and 8; get_mnist_loaders relabels them to classes 0..3.
train_loader, test_loader, val_loader = get_mnist_loaders(
    digits=(0, 1, 7, 8),
)



# Group Definition
class CyclicGroup:
    def __init__(self, order: int):
        self.n_int = order  # Group order for C_n (discrete cyclic group)

    def elements(self, device=None):
        # Discrete rotation group elements as angles (in radians)
        # These represent H in G = R^2 ⋉ H
        angles = [2 * np.pi * k / self.n_int for k in range(self.n_int)]
        return torch.tensor(angles, dtype=torch.float32, device=device)

    def product(self, h, h_prime):
        # Group product: h ⋅ h'
        return torch.remainder(h + h_prime, 2 * np.pi)

    def inverse(self, h):
        # Group inverse: h⁻¹
        return torch.remainder(-h, 2 * np.pi)

    def left_regular_representation(self, h, x):
        # Applies the left regular representation matrix to x ∈ ℝ²
        # Corresponds to L^H_h in the theory
        return torch.matmul(self.matrix_representation(h), x)

    def matrix_representation(self, h):
        # Standard SO(2) matrix representation
        return torch.tensor([
            [torch.cos(h), -torch.sin(h)],
            [torch.sin(h), torch.cos(h)]
        ], device=h.device)


# Rotation of 2D kernels (L^H_h(k))
def rotate_kernel(kernel_2d, group, angle_rad):
    # Implements L^H_h(k): transform kernel by group element h
    # i.e., rotate the kernel in the spatial domain (inverse action)

    # Create normalized grid
    y, x = torch.meshgrid(torch.linspace(-1, 1, kernel_2d.shape[0], device=kernel_2d.device),
                          torch.linspace(-1, 1, kernel_2d.shape[1], device=kernel_2d.device),
                          indexing='ij')
    grid = torch.stack([x, y], dim=-1).view(-1, 2)

    # Apply inverse transformation (h⁻¹ ▷ y)
    angle_inv = group.inverse(angle_rad)
    rot_grid_flat = torch.stack([group.left_regular_representation(angle_inv, coord) for coord in grid], dim=0)
    rot_grid = rot_grid_flat.view(kernel_2d.shape[0], kernel_2d.shape[1], 2).unsqueeze(0)

    # Grid sample (resample the kernel at transformed coordinates)
    kernel = kernel_2d.unsqueeze(0).unsqueeze(0)
    rotated = F.grid_sample(kernel, rot_grid, align_corners=True, mode='bilinear', padding_mode='zeros')

    return rotated.squeeze()  # Returns k(h⁻¹ ▷ y)


# Lifting Layer: ℝ² → G
class LiftingConvolution(nn.Module):
    """Lifting convolution ℝ² → G: correlate the input with every rotated copy of a kernel.

    For each group element h the learnable kernel is transformed by L^H_h (via
    ``rotate_kernel``) and applied with ``F.conv2d``.  The per-element results are
    stacked along a new group axis.

    Args:
        group: group object providing ``elements(device=...)``.
        in_channels: number of input image channels.
        out_channels: number of output channels per group element.
        kernel_size: spatial kernel size.
        padding: padding passed to ``F.conv2d``.
    """

    def __init__(self, group, in_channels, out_channels, kernel_size, padding):
        super().__init__()
        self.group = group
        self.padding = padding  # read by forward(); previously never stored (AttributeError)
        self.base_kernel = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
        nn.init.kaiming_uniform_(self.base_kernel, a=np.sqrt(5))

    def forward(self, x):
        # Input: x ∈ ℝ² image batch (B, C_in, H, W)
        # Output: f_out ∈ G → ℝ^C, shape (B, C_out, G, H', W') (stack of rotated conv results)

        group_elements = self.group.elements(device=x.device)
        outputs = []

        for angle in group_elements:
            rotated_kernel = torch.zeros_like(self.base_kernel)
            for oc in range(self.base_kernel.shape[0]):
                for ic in range(self.base_kernel.shape[1]):
                    # Apply L^H_h to each 2D kernel slice
                    rotated_kernel[oc, ic] = rotate_kernel(self.base_kernel[oc, ic], self.group, angle)

            # Standard 2D convolution (L^ℝ²_x)
            conv_out = F.conv2d(x, rotated_kernel, padding=self.padding)
            outputs.append(conv_out)

        return torch.stack(outputs, dim=2)  # new G-dimension at axis 2: (B, C_out, G, H', W')

# Group Convolution Layer: G → G
class GroupConvolution(nn.Module):
    """Group convolution G → G for a cyclic rotation group.

    For output element g and input element h, the spatial kernel belonging to the
    relative element g⁻¹h is rotated by g (L^H_g) and correlated with the input
    slice at h.  Contributions are summed over h and over input channels:

        f_out[c_out, g] = Σ_h Σ_c_in conv2d(f_in[c_in, h], L^H_g k[c_out, c_in, g⁻¹h])

    Sharing one kernel per *relative* group element (instead of a free kernel for
    every (g, h) pair) ties the G output slices together.  This is what makes the
    layer rotation-equivariant, up to the accuracy of rotate_kernel's resampling.

    Args:
        group: group object providing ``n_int`` (group order) and
            ``elements(device=...)``; element index k is assumed to be the angle
            2*pi*k/n, as in CyclicGroup.
        in_channels: input channels per group element.
        out_channels: output channels per group element.
        kernel_size: spatial kernel size.
        padding: padding passed to ``F.conv2d``.
    """

    def __init__(self, group, in_channels, out_channels, kernel_size, padding):
        super().__init__()
        self.group = group
        self.G = group.n_int
        self.padding = padding  # read by forward(); previously never stored (AttributeError)

        # Kernels indexed by the relative group element g_out⁻¹ g_in.  This layout also
        # gives kaiming init the true fan-in (C_in * G * k * k).
        self.base_kernel = nn.Parameter(torch.randn(
            out_channels, in_channels, self.G, kernel_size, kernel_size
        ))
        nn.init.kaiming_uniform_(self.base_kernel, a=np.sqrt(5))

    def forward(self, x):
        # Input: f_in ∈ G → ℝ^C_in, shape (B, C_in, G, H, W)
        # Output: f_out ∈ G → ℝ^C_out, shape (B, C_out, G, H', W')
        out_channels, in_channels, _, k_h, k_w = self.base_kernel.shape
        group_elements = self.group.elements(device=x.device)

        outputs = []
        for g_out_idx, g_out in enumerate(group_elements):
            per_channel = []
            for oc in range(out_channels):
                acc = 0
                for g_in_idx in range(self.G):
                    # Index of the relative element g_out⁻¹ * g_in in C_n.
                    rel_idx = (g_in_idx - g_out_idx) % self.G
                    for ic in range(in_channels):
                        # Apply L^H_{g_out} to the kernel for the relative element.
                        rot_k = rotate_kernel(self.base_kernel[oc, ic, rel_idx], self.group, g_out)
                        rot_k = rot_k.reshape(1, 1, k_h, k_w)
                        x_in = x[:, ic, g_in_idx].unsqueeze(1)
                        acc = acc + F.conv2d(x_in, rot_k, padding=self.padding)
                per_channel.append(acc)  # (B, 1, H', W')
            outputs.append(torch.cat(per_channel, dim=1))  # (B, C_out, H', W')

        return torch.stack(outputs, dim=2)  # G → ℝ^C_out



class CNN(nn.Module):
    """Plain (non-equivariant) CNN baseline.

    Architecture: (conv -> layer norm -> ReLU) x (1 + num_hidden), then global
    average pooling and a linear classifier.

    Args:
        in_channels: number of input image channels.
        out_channels: number of output logits (classes).
        kernel_size: spatial kernel size of every convolution.
        num_hidden: number of hidden conv layers after the first one.
        hidden_channels: channel width of every conv layer.
    """

    def __init__(self, in_channels, out_channels, kernel_size, num_hidden, hidden_channels):
        super().__init__()
        padding = kernel_size // 2  # keeps H, W unchanged for odd kernel sizes (e.g. 5 -> 2)
        self.first_conv = nn.Conv2d(in_channels, hidden_channels, kernel_size, padding=padding)
        self.convs = nn.ModuleList([
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size, padding=padding) for _ in range(num_hidden)
        ])
        self.final_linear = nn.Linear(hidden_channels, out_channels)

    def forward(self, x):
        x = self.first_conv(x)
        x = F.layer_norm(x, x.shape[-3:])  # normalize over (C, H, W) per sample
        x = F.relu(x)
        for conv in self.convs:
            x = conv(x)
            x = F.layer_norm(x, x.shape[-3:])
            x = F.relu(x)
        # flatten(1) rather than squeeze(): squeeze() would also drop the batch axis
        # for a batch of one, producing (C,) logits.
        x = F.adaptive_avg_pool2d(x, 1).flatten(1)
        return self.final_linear(x)


# Group Equivariant CNN Architecture
class GroupEquivariantCNN(nn.Module):
    """Group-equivariant CNN: lifting conv, a stack of group convs, then invariant pooling.

    Args:
        group: group object passed to the lifting and group convolution layers.
        in_channels: number of input image channels.
        out_channels: number of output logits (classes).
        kernel_size: spatial kernel size of every layer.
        num_hidden: number of group convolution layers.
        hidden_channels: channels per group element in every layer.
    """

    def __init__(self, group, in_channels, out_channels, kernel_size, num_hidden, hidden_channels):
        super().__init__()
        self.lifting_conv = LiftingConvolution(group, in_channels, hidden_channels, kernel_size, padding=kernel_size//2)
        self.gconvs = nn.ModuleList([
            GroupConvolution(group, hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2)
            for _ in range(num_hidden)
        ])
        # Projection from G → ℝ by global average over G, H, W
        self.projection_layer = AdaptiveAvgPool3d(1)
        self.final_linear = nn.Linear(hidden_channels, out_channels)

    def forward(self, x):
        # Step 1: Lift input image from ℝ² to G (lifting convolution)
        x = self.lifting_conv(x)
        x = F.layer_norm(x, x.shape[-4:])  # normalize over (C, G, H, W) per sample
        x = F.relu(x)

        # Step 2: Apply group convolutions G → G
        for gconv in self.gconvs:
            x = gconv(x)
            x = F.layer_norm(x, x.shape[-4:])
            x = F.relu(x)

        # Step 3: Pool over all dimensions (incl. group axis) and classify.
        # flatten(1) rather than squeeze(): squeeze() would also drop the batch axis
        # for a batch of one.
        x = self.projection_layer(x).flatten(1)
        return self.final_linear(x)


# Training utilities

def _average(total, count):
    """Return total / count, or NaN when count is zero (e.g. an empty loader)."""
    return total / count if count else float("nan")


def _evaluate(model, loader, criterion, desc):
    """Return (mean loss, accuracy) of `model` over `loader`, without building autograd graphs."""
    correct, total, loss_sum = 0, 0, 0.0
    with torch.no_grad():
        for x, y in tqdm(loader, desc=desc, leave=False):
            x, y = x.to(device), y.to(device)
            out = model(x)
            loss_sum += criterion(out, y).item() * x.size(0)
            correct += (out.argmax(1) == y).sum().item()
            total += x.size(0)
    return _average(loss_sum, total), _average(correct, total)


def train_model(model_name, model_hparams, optimizer_name, optimizer_hparams, num_epochs=15, loaders=None):
    """Train a CNN or GCNN and record per-epoch loss/accuracy on train, test and validation.

    Args:
        model_name: "GCNN" selects GroupEquivariantCNN; any other value selects CNN.
        model_hparams: keyword arguments for the model constructor.
        optimizer_name: name of a class in ``torch.optim`` (e.g. "Adam").
        optimizer_hparams: keyword arguments for the optimizer.
        num_epochs: number of training epochs (default 15, as before).
        loaders: optional (train_loader, test_loader, val_loader) tuple; defaults to
            the module-level loaders.

    Returns:
        (model, history) where history maps 'train_loss', 'train_acc', 'test_loss',
        'test_acc', 'val_loss', 'val_acc' to per-epoch lists.  Metrics of an empty
        loader are NaN instead of raising ZeroDivisionError.
    """
    if loaders is None:
        loaders = (train_loader, test_loader, val_loader)
    train_dl, test_dl, val_dl = loaders

    model_class = GroupEquivariantCNN if model_name == "GCNN" else CNN
    model = model_class(**model_hparams).to(device)

    optimizer = getattr(torch.optim, optimizer_name)(model.parameters(), **optimizer_hparams)
    criterion = nn.CrossEntropyLoss()

    history = {'train_loss': [], 'train_acc': [], 'test_loss': [], 'test_acc': [], 'val_loss': [], 'val_acc': []}

    for epoch in range(1, num_epochs + 1):
        model.train()
        correct, total, loss_sum = 0, 0, 0.0
        for x, y in tqdm(train_dl, desc=f"[{model_name}] Epoch {epoch} - Training"):
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            out = model(x)
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()
            loss_sum += loss.item() * x.size(0)
            correct += (out.argmax(1) == y).sum().item()
            total += x.size(0)
        train_loss = _average(loss_sum, total)
        train_acc = _average(correct, total)

        model.eval()
        test_loss, test_acc = _evaluate(model, test_dl, criterion, f"[{model_name}] Epoch {epoch} - Test")
        val_loss, val_acc = _evaluate(model, val_dl, criterion, f"[{model_name}] Epoch {epoch} - Validation")

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['test_loss'].append(test_loss)
        history['test_acc'].append(test_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        print(f"\nEpoch {epoch}: {model_name} | "
              f"Train Acc: {train_acc:.4f} | Test Acc: {test_acc:.4f} | Val Acc: {val_acc:.4f}")

    return model, history


cnn_model, cnn_results = train_model(
    model_name="CNN",
    # The loaders keep digits (0, 1, 7, 8), relabelled to classes 0..3, so the
    # classifier needs exactly 4 outputs; a 5th logit would never match any target.
    model_hparams={"in_channels": 1, "out_channels": 4, "kernel_size": 5, "num_hidden": 3, "hidden_channels": 16},
    optimizer_name="Adam",
    optimizer_hparams={"lr": 0.001, "weight_decay": 1e-5},
)


# Model Training Calls
gcnn_model, gcnn_results = train_model(
    model_name="GCNN",
    # 4 output classes: the loaders keep digits (0, 1, 7, 8), relabelled to 0..3.
    model_hparams={"in_channels": 1, "out_channels": 4, "kernel_size": 5, "num_hidden": 3,
                   "hidden_channels": 16, "group": CyclicGroup(order=4)},
    optimizer_name="Adam",
    optimizer_hparams={"lr": 0.001, "weight_decay": 1e-5},
)




[CNN] Epoch 1 - Training: 100%|██████████| 7/7 [00:00<00:00, 17.18it/s]
                                                                 


Epoch 1: CNN | Train Acc: 0.2750 | Test Acc: 0.2125 | Val Acc: 0.2400


[CNN] Epoch 2 - Training: 100%|██████████| 7/7 [00:00<00:00, 18.64it/s]
                                                                 


Epoch 2: CNN | Train Acc: 0.2325 | Test Acc: 0.2250 | Val Acc: 0.2400


[CNN] Epoch 3 - Training: 100%|██████████| 7/7 [00:00<00:00, 16.90it/s]
                                                                 


Epoch 3: CNN | Train Acc: 0.3450 | Test Acc: 0.3875 | Val Acc: 0.5400


[CNN] Epoch 4 - Training: 100%|██████████| 7/7 [00:00<00:00, 16.71it/s]
                                                                 


Epoch 4: CNN | Train Acc: 0.5500 | Test Acc: 0.3250 | Val Acc: 0.3400


[CNN] Epoch 5 - Training: 100%|██████████| 7/7 [00:00<00:00, 18.57it/s]
                                                                 


Epoch 5: CNN | Train Acc: 0.4800 | Test Acc: 0.3750 | Val Acc: 0.4000


[CNN] Epoch 6 - Training: 100%|██████████| 7/7 [00:00<00:00, 18.47it/s]
                                                                 


Epoch 6: CNN | Train Acc: 0.4825 | Test Acc: 0.3000 | Val Acc: 0.4600


[CNN] Epoch 7 - Training: 100%|██████████| 7/7 [00:00<00:00, 17.75it/s]
                                                                 


Epoch 7: CNN | Train Acc: 0.5950 | Test Acc: 0.2250 | Val Acc: 0.3200


[CNN] Epoch 8 - Training: 100%|██████████| 7/7 [00:00<00:00, 18.39it/s]
                                                                 


Epoch 8: CNN | Train Acc: 0.5700 | Test Acc: 0.4125 | Val Acc: 0.4200


[CNN] Epoch 9 - Training: 100%|██████████| 7/7 [00:00<00:00, 18.29it/s]
                                                                 


Epoch 9: CNN | Train Acc: 0.6900 | Test Acc: 0.4375 | Val Acc: 0.5200


[CNN] Epoch 10 - Training: 100%|██████████| 7/7 [00:00<00:00, 18.39it/s]
                                                                  


Epoch 10: CNN | Train Acc: 0.6750 | Test Acc: 0.4375 | Val Acc: 0.4400


[CNN] Epoch 11 - Training: 100%|██████████| 7/7 [00:00<00:00, 18.20it/s]
                                                                  


Epoch 11: CNN | Train Acc: 0.6750 | Test Acc: 0.4875 | Val Acc: 0.5400


[CNN] Epoch 12 - Training: 100%|██████████| 7/7 [00:00<00:00, 18.73it/s]
                                                                  


Epoch 12: CNN | Train Acc: 0.6500 | Test Acc: 0.5500 | Val Acc: 0.7000


[CNN] Epoch 13 - Training: 100%|██████████| 7/7 [00:00<00:00, 18.65it/s]
                                                                  


Epoch 13: CNN | Train Acc: 0.6975 | Test Acc: 0.4500 | Val Acc: 0.6400


[CNN] Epoch 14 - Training: 100%|██████████| 7/7 [00:00<00:00, 18.55it/s]
                                                                  


Epoch 14: CNN | Train Acc: 0.7300 | Test Acc: 0.5000 | Val Acc: 0.4800


[CNN] Epoch 15 - Training: 100%|██████████| 7/7 [00:00<00:00, 17.17it/s]
                                                                  


Epoch 15: CNN | Train Acc: 0.8075 | Test Acc: 0.6000 | Val Acc: 0.5600


[GCNN] Epoch 1 - Training: 100%|██████████| 7/7 [19:23<00:00, 166.23s/it]
                                                                          


Epoch 1: GCNN | Train Acc: 0.2400 | Test Acc: 0.3000 | Val Acc: 0.2600


[GCNN] Epoch 2 - Training: 100%|██████████| 7/7 [19:02<00:00, 163.26s/it]
                                                                          


Epoch 2: GCNN | Train Acc: 0.2750 | Test Acc: 0.3625 | Val Acc: 0.2800


[GCNN] Epoch 3 - Training: 100%|██████████| 7/7 [18:54<00:00, 162.07s/it]
                                                                          


Epoch 3: GCNN | Train Acc: 0.2825 | Test Acc: 0.3625 | Val Acc: 0.2800


[GCNN] Epoch 4 - Training: 100%|██████████| 7/7 [19:06<00:00, 163.72s/it]
                                                                          


Epoch 4: GCNN | Train Acc: 0.2825 | Test Acc: 0.3625 | Val Acc: 0.2800


[GCNN] Epoch 5 - Training: 100%|██████████| 7/7 [18:45<00:00, 160.85s/it]
                                                                          


Epoch 5: GCNN | Train Acc: 0.3575 | Test Acc: 0.5875 | Val Acc: 0.5000


[GCNN] Epoch 6 - Training: 100%|██████████| 7/7 [18:55<00:00, 162.16s/it]
                                                                          


Epoch 6: GCNN | Train Acc: 0.3525 | Test Acc: 0.5750 | Val Acc: 0.5000


[GCNN] Epoch 7 - Training: 100%|██████████| 7/7 [18:51<00:00, 161.62s/it]
                                                                          


Epoch 7: GCNN | Train Acc: 0.4375 | Test Acc: 0.5375 | Val Acc: 0.3800


[GCNN] Epoch 8 - Training: 100%|██████████| 7/7 [18:37<00:00, 159.63s/it]
                                                                          


Epoch 8: GCNN | Train Acc: 0.5400 | Test Acc: 0.6375 | Val Acc: 0.5400


[GCNN] Epoch 9 - Training: 100%|██████████| 7/7 [18:28<00:00, 158.35s/it]
                                                                          


Epoch 9: GCNN | Train Acc: 0.5200 | Test Acc: 0.6250 | Val Acc: 0.5600


[GCNN] Epoch 10 - Training: 100%|██████████| 7/7 [18:48<00:00, 161.28s/it]
                                                                           


Epoch 10: GCNN | Train Acc: 0.5550 | Test Acc: 0.6125 | Val Acc: 0.5200


[GCNN] Epoch 11 - Training: 100%|██████████| 7/7 [18:56<00:00, 162.29s/it]
                                                                           


Epoch 11: GCNN | Train Acc: 0.5575 | Test Acc: 0.6500 | Val Acc: 0.5800


[GCNN] Epoch 12 - Training: 100%|██████████| 7/7 [18:35<00:00, 159.40s/it]
                                                                           


Epoch 12: GCNN | Train Acc: 0.6025 | Test Acc: 0.6125 | Val Acc: 0.5600


[GCNN] Epoch 13 - Training: 100%|██████████| 7/7 [18:48<00:00, 161.27s/it]
                                                                           


Epoch 13: GCNN | Train Acc: 0.6075 | Test Acc: 0.6000 | Val Acc: 0.6400


[GCNN] Epoch 14 - Training: 100%|██████████| 7/7 [19:00<00:00, 162.91s/it]
                                                                           


Epoch 14: GCNN | Train Acc: 0.5925 | Test Acc: 0.6375 | Val Acc: 0.5800


[GCNN] Epoch 15 - Training: 100%|██████████| 7/7 [19:02<00:00, 163.28s/it]
                                                                           


Epoch 15: GCNN | Train Acc: 0.6225 | Test Acc: 0.6625 | Val Acc: 0.5400


[CNN] Epoch 1 - Training: 100%|██████████| 7/7 [00:00<00:00, 17.18it/s]
                                                                 
Epoch 1: CNN | Train Acc: 0.2750 | Test Acc: 0.2125 | Val Acc: 0.2400
[CNN] Epoch 2 - Training: 100%|██████████| 7/7 [00:00<00:00, 18.64it/s]
                                                                 
Epoch 2: CNN | Train Acc: 0.2325 | Test Acc: 0.2250 | Val Acc: 0.2400
[CNN] Epoch 3 - Training: 100%|██████████| 7/7 [00:00<00:00, 16.90it/s]
                                                                 
Epoch 3: CNN | Train Acc: 0.3450 | Test Acc: 0.3875 | Val Acc: 0.5400
[CNN] Epoch 4 - Training: 100%|██████████| 7/7 [00:00<00:00, 16.71it/s]
                                                                 
Epoch 4: CNN | Train Acc: 0.5500 | Test Acc: 0.3250 | Val Acc: 0.3400
[CNN] Epoch 5 - Training: 100%|██████████| 7/7 [00:00<00:00, 18.57it/s]
                                                                 
Epoch 5: CNN | Train Acc: 0.4800 | Test Acc: 0.3750 | Val Acc: 0.4000
[CNN] Epoch 6 - Training: 100%|██████████| 7/7 [00:00<00:00, 18.47it/s]
                                                                 
Epoch 6: CNN | Train Acc: 0.4825 | Test Acc: 0.3000 | Val Acc: 0.4600
[CNN] Epoch 7 - Training: 100%|██████████| 7/7 [00:00<00:00, 17.75it/s]
                                                                 
Epoch 7: CNN | Train Acc: 0.5950 | Test Acc: 0.2250 | Val Acc: 0.3200
[CNN] Epoch 8 - Training: 100%|██████████| 7/7 [00:00<00:00, 18.39it/s]
                                                                 
Epoch 8: CNN | Train Acc: 0.5700 | Test Acc: 0.4125 | Val Acc: 0.4200
[CNN] Epoch 9 - Training: 100%|██████████| 7/7 [00:00<00:00, 18.29it/s]
                                                                 
Epoch 9: CNN | Train Acc: 0.6900 | Test Acc: 0.4375 | Val Acc: 0.5200
[CNN] Epoch 10 - Training: 100%|██████████| 7/7 [00:00<00:00, 18.39it/s]
                                                                  
Epoch 10: CNN | Train Acc: 0.6750 | Test Acc: 0.4375 | Val Acc: 0.4400
[CNN] Epoch 11 - Training: 100%|██████████| 7/7 [00:00<00:00, 18.20it/s]
                                                                  
Epoch 11: CNN | Train Acc: 0.6750 | Test Acc: 0.4875 | Val Acc: 0.5400
[CNN] Epoch 12 - Training: 100%|██████████| 7/7 [00:00<00:00, 18.73it/s]
                                                                  
Epoch 12: CNN | Train Acc: 0.6500 | Test Acc: 0.5500 | Val Acc: 0.7000
[CNN] Epoch 13 - Training: 100%|██████████| 7/7 [00:00<00:00, 18.65it/s]
                                                                  
Epoch 13: CNN | Train Acc: 0.6975 | Test Acc: 0.4500 | Val Acc: 0.6400
[CNN] Epoch 14 - Training: 100%|██████████| 7/7 [00:00<00:00, 18.55it/s]
                                                                  
Epoch 14: CNN | Train Acc: 0.7300 | Test Acc: 0.5000 | Val Acc: 0.4800
[CNN] Epoch 15 - Training: 100%|██████████| 7/7 [00:00<00:00, 17.17it/s]
                                                                  
**Epoch 15: CNN | Train Acc: 0.8075 | Test Acc: 0.6000 | Val Acc: 0.5600**


[GCNN] Epoch 1 - Training: 100%|██████████| 7/7 [19:23<00:00, 166.23s/it]
                                                                          
Epoch 1: GCNN | Train Acc: 0.2400 | Test Acc: 0.3000 | Val Acc: 0.2600
[GCNN] Epoch 2 - Training: 100%|██████████| 7/7 [19:02<00:00, 163.26s/it]
                                                                          
Epoch 2: GCNN | Train Acc: 0.2750 | Test Acc: 0.3625 | Val Acc: 0.2800
[GCNN] Epoch 3 - Training: 100%|██████████| 7/7 [18:54<00:00, 162.07s/it]
                                                                          
Epoch 3: GCNN | Train Acc: 0.2825 | Test Acc: 0.3625 | Val Acc: 0.2800
[GCNN] Epoch 4 - Training: 100%|██████████| 7/7 [19:06<00:00, 163.72s/it]
                                                                          
Epoch 4: GCNN | Train Acc: 0.2825 | Test Acc: 0.3625 | Val Acc: 0.2800
[GCNN] Epoch 5 - Training: 100%|██████████| 7/7 [18:45<00:00, 160.85s/it]
                                                                          
Epoch 5: GCNN | Train Acc: 0.3575 | Test Acc: 0.5875 | Val Acc: 0.5000
[GCNN] Epoch 6 - Training: 100%|██████████| 7/7 [18:55<00:00, 162.16s/it]
                                                                          
Epoch 6: GCNN | Train Acc: 0.3525 | Test Acc: 0.5750 | Val Acc: 0.5000
[GCNN] Epoch 7 - Training: 100%|██████████| 7/7 [18:51<00:00, 161.62s/it]
                                                                          
Epoch 7: GCNN | Train Acc: 0.4375 | Test Acc: 0.5375 | Val Acc: 0.3800
[GCNN] Epoch 8 - Training: 100%|██████████| 7/7 [18:37<00:00, 159.63s/it]
                                                                          
Epoch 8: GCNN | Train Acc: 0.5400 | Test Acc: 0.6375 | Val Acc: 0.5400
[GCNN] Epoch 9 - Training: 100%|██████████| 7/7 [18:28<00:00, 158.35s/it]
                                                                          
Epoch 9: GCNN | Train Acc: 0.5200 | Test Acc: 0.6250 | Val Acc: 0.5600
[GCNN] Epoch 10 - Training: 100%|██████████| 7/7 [18:48<00:00, 161.28s/it]
                                                                           
Epoch 10: GCNN | Train Acc: 0.5550 | Test Acc: 0.6125 | Val Acc: 0.5200
[GCNN] Epoch 11 - Training: 100%|██████████| 7/7 [18:56<00:00, 162.29s/it]
                                                                           
Epoch 11: GCNN | Train Acc: 0.5575 | Test Acc: 0.6500 | Val Acc: 0.5800
[GCNN] Epoch 12 - Training: 100%|██████████| 7/7 [18:35<00:00, 159.40s/it]
                                                                           
Epoch 12: GCNN | Train Acc: 0.6025 | Test Acc: 0.6125 | Val Acc: 0.5600
[GCNN] Epoch 13 - Training: 100%|██████████| 7/7 [18:48<00:00, 161.27s/it]
                                                                           
Epoch 13: GCNN | Train Acc: 0.6075 | Test Acc: 0.6000 | Val Acc: 0.6400
[GCNN] Epoch 14 - Training: 100%|██████████| 7/7 [19:00<00:00, 162.91s/it]
                                                                           
Epoch 14: GCNN | Train Acc: 0.5925 | Test Acc: 0.6375 | Val Acc: 0.5800
[GCNN] Epoch 15 - Training: 100%|██████████| 7/7 [19:02<00:00, 163.28s/it]


**Epoch 15: GCNN | Train Acc: 0.6225 | Test Acc: 0.6625 | Val Acc: 0.5400**