In [1]:
import os
import time
import e2cnn
import scipy
import torch
import numpy as np
import torchvision
from PIL import Image
from e2cnn import gspaces
from mlxtend.data import loadlocal_mnist

In [2]:
# MNIST dataset class
class MnistDataset(torch.utils.data.Dataset):

    def __init__(self, mode, transform=None, rotation=None):
        assert mode in ['train', 'test']
        data_dir = '/home/zhuokai/Desktop/nvme1n1p1/Data/MNIST/original'

        if mode == 'train':
            images_path = os.path.join(data_dir, 'train-images-idx3-ubyte')
            labels_path = os.path.join(data_dir, 'train-labels-idx1-ubyte')
        elif mode == 'test':
            self.rotation = rotation
            images_path = os.path.join(data_dir, 't10k-images-idx3-ubyte')
            labels_path = os.path.join(data_dir, 't10k-labels-idx1-ubyte')

        self.mode = mode
        self.transform = transform

        self.images, self.labels = loadlocal_mnist(images_path=images_path,
                                                   labels_path=labels_path)

        self.images = self.images.reshape(-1, 28, 28).astype(np.float32)
        self.labels = self.labels.astype(np.int64)
        self.num_samples = len(self.labels)

        # normalization and conversion
        self.to_tensor = torchvision.transforms.ToTensor()
        self.normalize = torchvision.transforms.Normalize((0.1307,), (0.3081,))

    def __getitem__(self, index):
        image, label = self.images[index], self.labels[index]
        image = Image.fromarray(image)
        if self.transform is not None:
            image = self.transform(image)
            image = self.to_tensor(np.array(image))
            image = self.normalize(image)


        return image, label

    def __len__(self):
        return len(self.labels)

In [3]:
# rotation equivariant network using e2cnn
class Rot_Eqv_Net_MNIST(torch.nn.Module):
    def __init__(self, image_size, num_rotation, n_classes=10):

        super(Rot_Eqv_Net_MNIST, self).__init__()

        # the model is equivariant under rotations by 45 degrees, modelled by C8
        self.r2_act = gspaces.Rot2dOnR2(N=num_rotation)

        # the input image is a scalar field, corresponding to the trivial representation
        in_type = e2cnn.nn.FieldType(self.r2_act, [self.r2_act.trivial_repr])

        # we store the input type for wrapping the images into a geometric tensor during the forward pass
        self.input_type = in_type

        # convolution 1
        # first specify the output type of the convolutional layer
        # we choose 24 feature fields, each transforming under the regular representation of C8
        out_type = e2cnn.nn.FieldType(self.r2_act, 24*[self.r2_act.regular_repr])
        if image_size != None:
            self.block1 = e2cnn.nn.SequentialModule(
                e2cnn.nn.MaskModule(in_type, image_size[0], margin=image_size[0]-28),
                e2cnn.nn.R2Conv(in_type, out_type, kernel_size=7, padding=2, bias=False),
                e2cnn.nn.InnerBatchNorm(out_type),
                e2cnn.nn.ReLU(out_type, inplace=True)
            )
        else:
            self.block1 = e2cnn.nn.SequentialModule(
                e2cnn.nn.MaskModule(in_type, 29, margin=1),
                e2cnn.nn.R2Conv(in_type, out_type, kernel_size=7, padding=2, bias=False),
                e2cnn.nn.InnerBatchNorm(out_type),
                e2cnn.nn.ReLU(out_type, inplace=True)
            )

        # convolution 2
        # the old output type is the input type to the next layer
        in_type = self.block1.out_type
        # the output type of the second convolution layer are 48 regular feature fields of C8
        out_type = e2cnn.nn.FieldType(self.r2_act, 48*[self.r2_act.regular_repr])
        self.block2 = e2cnn.nn.SequentialModule(
            e2cnn.nn.R2Conv(in_type, out_type, kernel_size=5, padding=2, bias=False),
            e2cnn.nn.InnerBatchNorm(out_type),
            e2cnn.nn.ReLU(out_type, inplace=True)
        )
        self.pool1 = e2cnn.nn.SequentialModule(
            e2cnn.nn.PointwiseAvgPoolAntialiased(out_type, sigma=0.66, stride=2)
        )

        # convolution 3
        # the old output type is the input type to the next layer
        in_type = self.block2.out_type
        # the output type of the third convolution layer are 48 regular feature fields of C8
        out_type = e2cnn.nn.FieldType(self.r2_act, 48*[self.r2_act.regular_repr])
        self.block3 = e2cnn.nn.SequentialModule(
            e2cnn.nn.R2Conv(in_type, out_type, kernel_size=5, padding=2, bias=False),
            e2cnn.nn.InnerBatchNorm(out_type),
            e2cnn.nn.ReLU(out_type, inplace=True)
        )

        # convolution 4
        # the old output type is the input type to the next layer
        in_type = self.block3.out_type
        # the output type of the fourth convolution layer are 96 regular feature fields of C8
        out_type = e2cnn.nn.FieldType(self.r2_act, 96*[self.r2_act.regular_repr])
        self.block4 = e2cnn.nn.SequentialModule(
            e2cnn.nn.R2Conv(in_type, out_type, kernel_size=5, padding=2, bias=False),
            e2cnn.nn.InnerBatchNorm(out_type),
            e2cnn.nn.ReLU(out_type, inplace=True)
        )
        self.pool2 = e2cnn.nn.SequentialModule(
            e2cnn.nn.PointwiseAvgPoolAntialiased(out_type, sigma=0.66, stride=2)
        )

        # convolution 5
        # the old output type is the input type to the next layer
        in_type = self.block4.out_type
        # the output type of the fifth convolution layer are 96 regular feature fields of C8
        out_type = e2cnn.nn.FieldType(self.r2_act, 96*[self.r2_act.regular_repr])
        self.block5 = e2cnn.nn.SequentialModule(
            e2cnn.nn.R2Conv(in_type, out_type, kernel_size=5, padding=2, bias=False),
            e2cnn.nn.InnerBatchNorm(out_type),
            e2cnn.nn.ReLU(out_type, inplace=True)
        )

        # convolution 6
        # the old output type is the input type to the next layer
        in_type = self.block5.out_type
        # the output type of the sixth convolution layer are 64 regular feature fields of C8
        out_type = e2cnn.nn.FieldType(self.r2_act, 64*[self.r2_act.regular_repr])
        self.block6 = e2cnn.nn.SequentialModule(
            e2cnn.nn.R2Conv(in_type, out_type, kernel_size=5, padding=1, bias=False),
            e2cnn.nn.InnerBatchNorm(out_type),
            e2cnn.nn.ReLU(out_type, inplace=True)
        )
        self.pool3 = e2cnn.nn.PointwiseAvgPoolAntialiased(out_type, sigma=0.66, stride=1, padding=0)

        self.gpool = e2cnn.nn.GroupPooling(out_type)

        # number of output channels
        # c = self.gpool.out_type.size
        c = 6400

        # Fully Connected
        self.fully_net = torch.nn.Sequential(
            torch.nn.Linear(c, 64),
            torch.nn.BatchNorm1d(64),
            torch.nn.ELU(inplace=True),
            torch.nn.Linear(64, n_classes),
        )

    def forward(self, input: torch.Tensor):
        # wrap the input tensor in a GeometricTensor
        # (associate it with the input type)
        x = e2cnn.nn.GeometricTensor(input, self.input_type)

        # apply each equivariant block

        # Each layer has an input and an output type
        # A layer takes a GeometricTensor in input.
        # This tensor needs to be associated with the same representation of the layer's input type
        #
        # The Layer outputs a new GeometricTensor, associated with the layer's output type.
        # As a result, consecutive layers need to have matching input/output types
        x = self.block1(x)
        x = self.block2(x)
        x = self.pool1(x)

        x = self.block3(x)
        x = self.block4(x)
        x = self.pool2(x)

        x = self.block5(x)
        x = self.block6(x)

        # pool over the spatial dimensions
        x = self.pool3(x)

        # pool over the group
        x = self.gpool(x)

        # unwrap the output GeometricTensor
        # (take the Pytorch tensor and discard the associated representation)
        x = x.tensor

        # classify with the final fully connected layers)
        x = self.fully_net(x.reshape(x.shape[0], -1))

        return x

In [4]:
# Print iterations progress
def print_progress_bar (iteration, total, prefix = '', suffix = '', decimals = 1, length = 100, fill = '█', printEnd = "\r"):
    """
    Call in a loop to create terminal progress bar
    @params:
        iteration   - Required  : current iteration (Int)
        total       - Required  : total iterations (Int)
        prefix      - Optional  : prefix string (Str)
        suffix      - Optional  : suffix string (Str)
        decimals    - Optional  : positive number of decimals in percent complete (Int)
        length      - Optional  : character length of bar (Int)
        fill        - Optional  : bar fill character (Str)
        printEnd    - Optional  : end character (e.g. "\r", "\r\n") (Str)
    """
    percent = ("{0:." + str(decimals) + "f}").format(100 * (iteration / float(total)))
    filledLength = int(length * iteration // total)
    bar = fill * filledLength + '-' * (length - filledLength)
    print('\r%s |%s| %s%% %s' % (prefix, bar, percent, suffix), end = printEnd)
    # Print New Line on Complete
    if iteration == total:
        print()

In [5]:
# train each epoch
def train(model, device, train_loader, optimizer):
    all_batch_losses = []
    model.train()
    loss_function = torch.nn.CrossEntropyLoss()
    for batch_idx, (data, target) in enumerate(train_loader):
        batch_time_start = time.time()
        # move data to GPU
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        # print(torch.nn.functional.softmax(output).data[0])
        # negative log likelihood loss.
        loss = loss_function(output, target)
        loss.backward()
        all_batch_losses.append(loss.item())
        optimizer.step()
        batch_time_end = time.time()
        batch_time_cost = batch_time_end - batch_time_start

        print_progress_bar(iteration=batch_idx+1,
                            total=len(train_loader),
                            prefix=f'Train batch {batch_idx+1}/{len(train_loader)},',
                            suffix='%s: %.3f, time: %.2f' % ('CE loss', all_batch_losses[-1], batch_time_cost),
                            length=50)

    # return the averaged batch loss
    return np.mean(all_batch_losses)

In [6]:
# validation
def val(model, device, val_loader):
    model.eval()
    loss_function = torch.nn.CrossEntropyLoss()
    all_batch_losses = []
    num_correct = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(val_loader):
            batch_time_start = time.time()
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = loss_function(output, target)
            all_batch_losses.append(loss.item())
            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)
            num_correct += pred.eq(target.view_as(pred)).sum().item()
            batch_time_end = time.time()
            batch_time_cost = batch_time_end - batch_time_start

            print_progress_bar(iteration=batch_idx+1,
                                total=len(val_loader),
                                prefix=f'Val batch {batch_idx+1}/{len(val_loader)},',
                                suffix='%s: %.3f, time: %.2f' % ('CE loss', all_batch_losses[-1], batch_time_cost),
                                length=50)

    # return the averaged batch loss
    return np.mean(all_batch_losses), num_correct

In [7]:
# kwargs for train and val dataloader
train_batch_size = 64
val_batch_size = 100
train_kwargs = {'batch_size': train_batch_size}
val_kwargs = {'batch_size': val_batch_size}
if torch.cuda.is_available():
    # additional cuda kwargs
    cuda_kwargs = {'num_workers': 1,
                    'pin_memory': True,
                    'shuffle': True}
    train_kwargs.update(cuda_kwargs)
    val_kwargs.update(cuda_kwargs)

    # device set up
    if torch.cuda.device_count() > 1:
        print('\n', torch.cuda.device_count(), 'GPUs available')
        device = torch.device('cuda')
    else:
        device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
    
image_size = (64, 64)
transform = torchvision.transforms.Compose([
                    # transform includes padding (right and bottom) to image_size and totensor
                    torchvision.transforms.Pad((0, 0, image_size[1]-28, image_size[0]-28), fill=0, padding_mode='constant'),
                ])

# split into train and validation dataset
dataset = MnistDataset(mode='train', transform=transform)
train_data, val_data = torch.utils.data.random_split(dataset, [50000, 10000])
# pytorch data loader
train_loader = torch.utils.data.DataLoader(train_data, **train_kwargs)
val_loader = torch.utils.data.DataLoader(val_data, **val_kwargs)

starting_epoch = 0
model = Rot_Eqv_Net_MNIST(image_size=image_size, num_rotation=8).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5, weight_decay=1e-5)

# train the model
num_epoch = 10
all_epoch_train_losses = []
all_epoch_val_losses = []
for epoch in range(starting_epoch, starting_epoch+num_epoch):
    print(f'\n Starting epoch {epoch+1}/{starting_epoch+num_epoch}')
    epoch_start_time = time.time()
    cur_epoch_train_loss = train(model, device, train_loader, optimizer)
    all_epoch_train_losses.append(cur_epoch_train_loss)
    cur_epoch_val_loss, num_correct = val(model, device, val_loader)
    all_epoch_val_losses.append(cur_epoch_val_loss)
    epoch_end_time = time.time()
    epoch_time_cost = epoch_end_time - epoch_start_time

    print('\nValidation set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            cur_epoch_val_loss, num_correct, len(val_loader.dataset),
            100. * num_correct / len(val_loader.dataset)))


 3 GPUs available


  sampled_basis = sampled_basis[mask, ...]



 Starting epoch 1/10
Train batch 782/782, |██████████████████████████████████████████████████| 100.0% CE loss: 1.390, time: 0.09
Val batch 100/100, |██████████████████████████████████████████████████| 100.0% CE loss: 1.265, time: 0.07

Validation set: Average loss: 1.0901, Accuracy: 6178/10000 (62%)


 Starting epoch 2/10
Train batch 782/782, |██████████████████████████████████████████████████| 100.0% CE loss: 0.853, time: 0.09
Val batch 100/100, |██████████████████████████████████████████████████| 100.0% CE loss: 0.988, time: 0.07

Validation set: Average loss: 0.9770, Accuracy: 6343/10000 (63%)


 Starting epoch 3/10
Train batch 782/782, |██████████████████████████████████████████████████| 100.0% CE loss: 1.179, time: 0.10
Val batch 100/100, |██████████████████████████████████████████████████| 100.0% CE loss: 1.063, time: 0.07

Validation set: Average loss: 0.9386, Accuracy: 6561/10000 (66%)


 Starting epoch 4/10
Train batch 782/782, |███████████████████████████████████████████████