In [1]:
import torch
import torchvision
import numpy as np
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
from torch.utils.data import DataLoader,ConcatDataset
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
import matplotlib.pyplot as plt
from PIL import Image
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
#credit to 'https://github.com/gidariss/FeatureLearningRotNet/blob/master/architectures/NetworkInNetwork.py'
import math

class BasicBlock(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size):
        super(BasicBlock, self).__init__()
        padding = (kernel_size - 1) // 2
        self.layers = nn.Sequential()
        self.layers.add_module('Conv', nn.Conv2d(in_planes, out_planes,
                                                 kernel_size=kernel_size, stride=1, padding=padding, bias=False))
        self.layers.add_module('BatchNorm', nn.BatchNorm2d(out_planes))
        self.layers.add_module('ReLU', nn.ReLU(inplace=True))

    def forward(self, x):
        return self.layers(x)

class GlobalAveragePooling(nn.Module):
    def __init__(self):
        super(GlobalAveragePooling, self).__init__()

    def forward(self, feat):
        num_channels = feat.size(1)
        return F.avg_pool2d(feat, (feat.size(2), feat.size(3))).view(-1, num_channels)

class NetworkInNetwork(nn.Module):
    def __init__(self, opt):
        super(NetworkInNetwork, self).__init__()

        num_classes = opt['num_classes']
        num_inchannels = opt['num_inchannels'] if ('num_inchannels' in opt) else 3
        num_stages = opt['num_stages'] if ('num_stages' in opt) else 3
        use_avg_on_conv3 = opt['use_avg_on_conv3'] if ('use_avg_on_conv3' in opt) else True

        assert(num_stages >= 3)
        nChannels = 192
        nChannels2 = 160
        nChannels3 = 96

        blocks = [nn.Sequential() for _ in range(num_stages)]
        # 1st block
        blocks[0].add_module('Block1_ConvB1', BasicBlock(num_inchannels, nChannels, 5))
        blocks[0].add_module('Block1_ConvB2', BasicBlock(nChannels, nChannels2, 1))
        blocks[0].add_module('Block1_ConvB3', BasicBlock(nChannels2, nChannels3, 1))
        blocks[0].add_module('Block1_MaxPool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

        # 2nd block
        blocks[1].add_module('Block2_ConvB1', BasicBlock(nChannels3, nChannels, 5))
        blocks[1].add_module('Block2_ConvB2', BasicBlock(nChannels, nChannels, 1))
        blocks[1].add_module('Block2_ConvB3', BasicBlock(nChannels, nChannels, 1))
        blocks[1].add_module('Block2_AvgPool', nn.AvgPool2d(kernel_size=3, stride=2, padding=1))

        # 3rd block
        blocks[2].add_module('Block3_ConvB1', BasicBlock(nChannels, nChannels, 3))
        blocks[2].add_module('Block3_ConvB2', BasicBlock(nChannels, nChannels, 1))
        blocks[2].add_module('Block3_ConvB3', BasicBlock(nChannels, nChannels, 1))

        if num_stages > 3 and use_avg_on_conv3:
            blocks[2].add_module('Block3_AvgPool', nn.AvgPool2d(kernel_size=3, stride=2, padding=1))

        for s in range(3, num_stages):
            blocks[s].add_module('Block'+str(s+1)+'_ConvB1', BasicBlock(nChannels, nChannels, 3))
            blocks[s].add_module('Block'+str(s+1)+'_ConvB2', BasicBlock(nChannels, nChannels, 1))
            blocks[s].add_module('Block'+str(s+1)+'_ConvB3', BasicBlock(nChannels, nChannels, 1))

        # global average pooling and classifier
        blocks.append(nn.Sequential())
        blocks[-1].add_module('GlobalAveragePooling', GlobalAveragePooling())
        blocks[-1].add_module('Classifier', nn.Linear(nChannels, num_classes))

        self._feature_blocks = nn.ModuleList(blocks)
        self.all_feat_names = ['conv'+str(s+1) for s in range(num_stages)] + ['classifier',]
        assert(len(self.all_feat_names) == len(self._feature_blocks))

    def _parse_out_keys_arg(self, out_feat_keys):
        # By default return the features of the last layer / module.
        out_feat_keys = [self.all_feat_names[-1],] if out_feat_keys is None else out_feat_keys

        if len(out_feat_keys) == 0:
            raise ValueError('Empty list of output feature keys.')

        for f, key in enumerate(out_feat_keys):
            if key not in self.all_feat_names:
                raise ValueError('Feature with name {0} does not exist. Existing features: {1}.'.format(key, self.all_feat_names))
            elif key in out_feat_keys[:f]:
                raise ValueError('Duplicate output feature key: {0}.'.format(key))

        # Find the highest output feature in `out_feat_keys
        max_out_feat = max([self.all_feat_names.index(key) for key in out_feat_keys])

        return out_feat_keys, max_out_feat

    def forward(self, x, out_feat_keys=None):
        """Forward an image `x` through the network and return the asked output features.

        Args:
          x: input image.
          out_feat_keys: a list/tuple with the feature names of the features
                that the function should return. By default the last feature of
                the network is returned.

        Return:
            out_feats: If multiple output features were asked then `out_feats`
                is a list with the asked output features placed in the same
                order as in `out_feat_keys`. If a single output feature was
                asked then `out_feats` is that output feature (and not a list).
        """
        out_feat_keys, max_out_feat = self._parse_out_keys_arg(out_feat_keys)
        out_feats = [None] * len(out_feat_keys)

        feat = x
        for f in range(max_out_feat + 1):
            feat = self._feature_blocks[f](feat)
            key = self.all_feat_names[f]
            if key in out_feat_keys:
                out_feats[out_feat_keys.index(key)] = feat

        out_feats = out_feats[0] if len(out_feats) == 1 else out_feats
        return out_feats

    def weight_initialization(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if m.weight.requires_grad:
                    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                    m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                if m.weight.requires_grad:
                    m.weight.data.fill_(1)
                if m.bias.requires_grad:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                if m.bias.requires_grad:
                    m.bias.data.zero_()

def create_model(opt):
    return NetworkInNetwork(opt)

if __name__ == '__main__':
    size = 32
    opt = {'num_classes': 4, 'num_stages': 4}

    net = create_model(opt)
    x = torch.autograd.Variable(torch.FloatTensor(1, 3, size, size).uniform_(-1, 1))

    out = net(x, out_feat_keys=net.all_feat_names)
    for f in range(len(out)):
        print('Output feature {0} - size {1}'.format(
            net.all_feat_names[f], out[f].size()))

    out = net(x)
    print('Final output: {0}'.format(out.size()))

Output feature conv1 - size torch.Size([1, 96, 16, 16])
Output feature conv2 - size torch.Size([1, 192, 8, 8])
Output feature conv3 - size torch.Size([1, 192, 4, 4])
Output feature conv4 - size torch.Size([1, 192, 4, 4])
Output feature classifier - size torch.Size([1, 4])
Final output: torch.Size([1, 4])


In [3]:
def rotate_img(img, rot):
    if rot == 0: # 0 degrees rotation
        return img
    elif rot == 90: # 90 degrees rotation
        return np.flipud(np.transpose(img, (1,0,2)))
    elif rot == 180: # 90 degrees rotation
        return np.fliplr(np.flipud(img))
    elif rot == 270: # 270 degrees rotation / or -90
        return np.transpose(np.flipud(img), (1,0,2))
    else:
        raise ValueError('rotation should be 0, 90, 180, or 270 degrees')

In [4]:
class CIFAR10Rotate(CIFAR10):
    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        super(CIFAR10Rotate, self).__init__(root, train=train, transform=transform, target_transform=target_transform, download=download)

    def __getitem__(self, index):
        img, target = self.data[index // 4], self.targets[index // 4]

        # Use PIL for rotation
        pil_img = Image.fromarray(img)
        rotated_img = pil_img.rotate(90 * (index % 4 ))

        # Convert back to NumPy array
        #img = np.array(rotated_img).permute(2, 0, 1)
        img = torch.tensor(np.transpose(rotated_img, (2, 0, 1)))

        label = index % 4
        return img, label

    def __len__(self):
        # Return the total number of images in the new dataset
        return 4 * len(self.data)  # Each image is rotated four times


In [5]:
#define the loss function
class RotationPredictionLoss(nn.Module):
    """
    This class implements the rotation prediction loss as described in the paper.
    It assumes that the model's output are logits for each of the K classes (geometric transformations).
    """

    def __init__(self):
        super(RotationPredictionLoss, self).__init__()

    def forward(self, logits, targets):
        """
        Calculate the loss given the logits and the target classes.

        :param logits: Predicted logits from the model for each of the K classes. Shape (batch_size, K)
        :param targets: Actual labels of the geometric transformations applied. Shape (batch_size,)
        :return: The mean rotation prediction loss for the batch.
        """
        # Calculate the log probabilities
        log_probs = F.log_softmax(logits, dim=1)

        # Gather the log probabilities by the target class labels
        target_log_probs = log_probs.gather(dim=1, index = targets.unsqueeze(1)).squeeze(1)

        # Calculate the negative mean of these log probabilities
        #loss = -target_log_probs.mean()
        loss = -target_log_probs.sum()/4/128

        return loss

# Example usage:
# Assuming `outputs` are the raw logits from the model's forward pass, and `labels` are the correct class labels
# outputs = model(input_data)
# labels = torch.tensor([correct_class_labels])  # This should be provided as per the dataset
# loss_fn = RotationPredictionLoss()
# loss = loss_fn(outputs, labels)
# loss.backward()  # Backpropagate the loss

# Note that in an actual implementation, you would also include optimizer steps and update the model parameters.

In [6]:
#hyperparameter setting
import torch.optim as optim
INITIAL_LR = 0.1
MOMENTUM = 0.9
REG = 5e-4
criterian = RotationPredictionLoss()
opt = {'num_classes': 4, 'num_stages': 4,'use_avg_on_conv3': False}
NIN_net_4block = create_model(opt)
optimizer = optim.SGD(NIN_net_4block.parameters(), lr = INITIAL_LR,
                      momentum = MOMENTUM,
                      weight_decay=REG)

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [7]:
cifar_dataset_train = CIFAR10Rotate(root='./data', train=True, download=True)
cifar_dataset_train_loader = DataLoader(cifar_dataset_train, batch_size=128*4, shuffle=False)
cifar_dataset_val = CIFAR10Rotate(root='./data', train=False, download=True)
cifar_dataset_val_loader = DataLoader(cifar_dataset_val, batch_size=128*4, shuffle=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:11<00:00, 15132348.80it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [8]:
import matplotlib.pyplot as plt
# some hyperparameters
# total number of training epochs
EPOCHS = 100

# the folder where the trained model is saved
CHECKPOINT_FOLDER = "./saved_model"

# start the training/validation process
# the process should take about 5 minutes on a GTX 1070-Ti
# if the code is written efficiently.
best_val_acc = 0
current_learning_rate = INITIAL_LR

DECAY = 0.2
valid_acc = []
losslist = []
print("==> Training starts!")
print("="*50)
for i in range(0, EPOCHS):
    if i  == 35 or i == 70 or i == 85 or i==100 :
        current_learning_rate = current_learning_rate * DECAY
        for param_group in optimizer.param_groups:
            param_group['lr'] = current_learning_rate
        print("Current learning rate has decayed to %f" %current_learning_rate)

    NIN_net_4block.train()

    #######################

    print("Epoch %d:" %i)
    # this help you compute the training accuracy
    total_examples = 0
    correct_examples = 0
    train_loss = 0 # track training loss if you want

    # Train the model for 1 epoch.
    for batch_idx, (inputs, targets) in enumerate(cifar_dataset_train_loader):
        ####################################

        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        inputs = inputs.float().to(device)
        targets = targets.to(device)
        NIN_net_4block = NIN_net_4block.to(device)
        # compute the output and loss
        outputs = NIN_net_4block(inputs)
        loss = RotationPredictionLoss().forward(outputs, targets)

        train_loss += loss
        # zero the gradient

        optimizer.zero_grad()
        # backpropagation

        loss.backward()
        # apply gradient and update the weights
        optimizer.step()
        # count the number of correctly predicted samples in the current batch
        _, predicted = outputs.max(1)  #make prediction based on the highest value
        total_examples += targets.size(0) # in this case,128 for each batch
        correct_examples += predicted.eq(targets).sum().item()
        ####################################

    avg_loss = train_loss / len(cifar_dataset_train_loader)
    avg_acc = correct_examples / total_examples
    print("Training loss: %.4f, Training accuracy: %.4f" %(avg_loss, avg_acc))


    NIN_net_4block.eval()

    #######################

    # this help you compute the validation accuracy
    total_examples = 0
    correct_examples = 0

    val_loss = 0 # again, track the validation loss if you want

    # disable gradient during validation, which can save GPU memory
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(cifar_dataset_val_loader):
            ####################################
            # your code here
            # copy inputs to device
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            inputs = inputs.float().to(device)
            targets = targets.to(device)
            NIN_net_4block = NIN_net_4block.to(device)


            # compute the output and loss

            outputs = NIN_net_4block(inputs)
            loss = RotationPredictionLoss().forward(outputs, targets)
            val_loss += loss

            # count the number of correctly predicted samples in the current batch
            _, predicted = outputs.max(1)
            total_examples += targets.size(0)
            correct_examples += predicted.eq(targets).sum().item()
            ####################################

    avg_loss = val_loss / len(cifar_dataset_val_loader)
    avg_acc = correct_examples / total_examples
    print("Validation loss: %.4f, Validation accuracy: %.4f" % (avg_loss, avg_acc))
    valid_acc.append(avg_acc)
    losslist.append(avg_loss.item())
    # save the model checkpoint
    if avg_acc > best_val_acc:
        best_val_acc = avg_acc
        #if not os.path.exists(CHECKPOINT_FOLDER):
        #    os.makedirs(CHECKPOINT_FOLDER)
        #print("Saving ...")

        torch.save(NIN_net_4block.state_dict(), 'NIN_net_4block.pth')

    print('')

print("="*50)
print(f"==> Optimization finished! Best validation accuracy: {best_val_acc:.4f}")

==> Training starts!
Epoch 0:
Training loss: 0.8377, Training accuracy: 0.6543
Validation loss: 0.7875, Validation accuracy: 0.6776

Epoch 1:
Training loss: 0.6197, Training accuracy: 0.7563
Validation loss: 0.7003, Validation accuracy: 0.7172

Epoch 2:
Training loss: 0.5315, Training accuracy: 0.7938
Validation loss: 0.6382, Validation accuracy: 0.7469

Epoch 3:
Training loss: 0.4747, Training accuracy: 0.8186
Validation loss: 0.6825, Validation accuracy: 0.7288

Epoch 4:
Training loss: 0.4352, Training accuracy: 0.8340
Validation loss: 0.6647, Validation accuracy: 0.7414

Epoch 5:
Training loss: 0.4061, Training accuracy: 0.8462
Validation loss: 0.8968, Validation accuracy: 0.6772

Epoch 6:
Training loss: 0.3884, Training accuracy: 0.8534
Validation loss: 0.7538, Validation accuracy: 0.7167

Epoch 7:
Training loss: 0.3721, Training accuracy: 0.8596
Validation loss: 0.6770, Validation accuracy: 0.7506

Epoch 8:
Training loss: 0.3599, Training accuracy: 0.8646
Validation loss: 0.6688, 