In [None]:
import torch

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# torch.backends.cudnn.benchmark = True
if device.type == "cuda":
    # set_device only accepts CUDA devices and raises for a CPU device, so the
    # CUDA-specific housekeeping is skipped on CPU-only machines.
    torch.cuda.set_device(device)
    torch.cuda.empty_cache()


In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import Subset, DataLoader
import torch.nn.functional as tF
import torch.nn as nn
import time

import matplotlib.pyplot as plt

class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions and a skip connection.

    No normalisation layers are used. The skip path is the identity (an empty
    ``nn.Sequential``) unless the block changes the spatial stride or the
    channel count, in which case it is a bias-free 1x1 convolution.

    Args:
        in_planes: number of input channels.
        planes: number of channels produced by both 3x3 convolutions.
        stride: stride of the first convolution and of the projection shortcut.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1)

        # A projection is only needed when the identity would not match the
        # main branch in shape.
        if stride != 1 or in_planes != out_planes:
            self.shortcut = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(conv2(relu(conv1(x))) + shortcut(x))``."""
        skip = self.shortcut(x)
        hidden = F.relu(self.conv1(x))
        hidden = self.conv2(hidden)
        return F.relu(hidden + skip)

class ResNet(nn.Module):
    """Four-stage residual network with a 3x3 stem and a linear classifier.

    Args:
        block: block class; it is called as ``block(in_planes, planes, stride)``
            and must expose an ``expansion`` class attribute.
        num_blocks: indexable with at least four entries, the requested number
            of blocks for stages 1-4.
        num_classes: size of the classifier output.

    The classifier takes ``512 * block.expansion`` features after a 4x4 average
    pool, so the pooled map has to be 1x1.
    NOTE(review): that presumably means 32x32 inputs - confirm against callers.
    """

    # (channel width, stride of the first block) for stages 1..4
    _STAGES = ((64, 1), (128, 2), (256, 2), (512, 2))

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        for idx, (width, first_stride) in enumerate(self._STAGES):
            stage = self._make_layer(block, width, num_blocks[idx], stride=first_stride)
            setattr(self, "layer{}".format(idx + 1), stage)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``, the rest use 1.

        ``self.in_planes`` is advanced as blocks are created so the next block
        (and the next stage) sees the right input width.
        """
        # Note: for num_blocks <= 0 this is still [stride], i.e. one block.
        per_block_strides = [stride, *([1] * (num_blocks - 1))]
        stage_blocks = []
        for block_stride in per_block_strides:
            stage_blocks.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage_blocks)

    def forward(self, x):
        """Return class scores (no softmax applied) for a batch of images."""
        feats = F.relu(self.conv1(x))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            feats = stage(feats)
        pooled = F.avg_pool2d(feats, 4)
        flat = pooled.view(pooled.size(0), -1)
        return self.linear(flat)



class ConvNet(nn.Module):
    """Small three-convolution classifier for single-channel images.

    Each convolution is followed by ReLU and a 2x2 max-pool; the result is
    reshaped to 7 features per sample and mapped to 10 log-probabilities.
    NOTE(review): the ``view(-1, 7)`` only keeps one row per sample when the
    final feature map is 1x1 (e.g. 28x28 inputs) - confirm the input size.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=5, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=5, out_channels=6, kernel_size=4, stride=1)
        self.conv3 = nn.Conv2d(in_channels=6, out_channels=7, kernel_size=3, stride=1)
        self.fc1 = nn.Linear(1 * 1 * 7, 10)

    def forward(self, x):
        """Return per-class log-probabilities of shape ``(N, 10)``."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = tF.max_pool2d(tF.relu(conv(x)), 2, 2)
        flat = x.view(-1, 1 * 1 * 7)
        logits = self.fc1(flat)
        return tF.log_softmax(logits, dim=1)

class MLP(torch.nn.Module):
    """Fully connected network with two hidden layers.

    Layout: ``input_dim -> hidden1 -> hidden2 -> output_dim``, with a ReLU
    after every linear layer, including the last one.

    NOTE(review): because of the final ReLU the values fed to ``log_softmax``
    are never negative - confirm this is intended.
    """

    def __init__(self, input_dim, output_dim, hidden1, hidden2):
        super().__init__()
        self.input_dim = input_dim

        widths = (input_dim, hidden1, hidden2, output_dim)
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.extend([nn.Linear(fan_in, fan_out), nn.ReLU()])
        self.mlp = nn.Sequential(*stack)

    def forward(self, x):
        """Flatten ``x`` to rows of ``input_dim`` values and return log-probabilities."""
        flat = x.view(-1, self.input_dim)
        return tF.log_softmax(self.mlp(flat), dim=1)

class MLP2(torch.nn.Module):
    """Fully connected network with a single hidden layer.

    Layout: ``input_dim -> hidden1 -> output_dim``, each linear layer followed
    by a ReLU. ``hidden2`` is accepted so the constructor matches ``MLP`` but
    it is not used.

    NOTE(review): because of the final ReLU the values fed to ``log_softmax``
    are never negative - confirm this is intended.
    """

    def __init__(self, input_dim, output_dim, hidden1, hidden2):
        super().__init__()
        self.input_dim = input_dim

        hidden_layer = nn.Linear(input_dim, hidden1)
        output_layer = nn.Linear(hidden1, output_dim)
        self.mlp = nn.Sequential(hidden_layer, nn.ReLU(), output_layer, nn.ReLU())

    def forward(self, x):
        """Flatten ``x`` to rows of ``input_dim`` values and return log-probabilities."""
        flat = x.view(-1, self.input_dim)
        activations = self.mlp(flat)
        return tF.log_softmax(activations, dim=1)

In [None]:
# Three-channel normalisation pipeline. It is built here but NOT passed to the
# MNIST datasets below, which only use ToTensor().
# NOTE(review): these look like the usual CIFAR-10 channel statistics - confirm
# whether the pipeline is still needed.
transforms_normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
transform_list = [transforms.ToTensor(), transforms_normalize]
transformer = transforms.Compose(transform_list)

# MNIST train/test splits, downloaded to /tmp/ on first use.
trainset = datasets.MNIST(root='/tmp/', train=True, download=True, transform=transforms.ToTensor())
testset = datasets.MNIST(root='/tmp/', train=False, download=True, transform=transforms.ToTensor())

trainloader = DataLoader(
    dataset=trainset,
    batch_size=128,
    shuffle=True)

# Evaluation does not need shuffling. The validation loss is later computed as
# a mean of per-batch means and the last batch is smaller, so a fixed batch
# order is what makes that number reproducible.
testloader = DataLoader(
    dataset=testset,
    batch_size=128,
    shuffle=False)
# repr in [PMatDense, PMatBlockDiag, PMatImplicit, PMatDiag, PMatEKFAC, PMatKFAC, PMatQuasiDiag]
# Number of input features fed to the MLPs: one flattened 28x28 MNIST image.
input_dim = 784


In [None]:
def getgrad(model: torch.nn.Module, grad_dict: dict, step_iter=0):
    """Record the current weight gradients of every Conv2d/Linear submodule.

    Args:
        model: module whose ``named_modules()`` are scanned.
        grad_dict: mapping ``module name -> list of 1-D numpy arrays``; it is
            updated in place and also returned.
        step_iter: ``0`` starts a fresh list for every module (discarding any
            previous entries); any other value appends to the existing list.

    Returns:
        The same ``grad_dict`` object.

    Each stored array is an independent copy. Without the copy, a CPU tensor's
    ``.numpy()`` view shares memory with ``weight.grad``, so a later
    ``backward()`` that accumulates into the same buffer would silently rewrite
    every snapshot taken before it. Submodules whose weight has no gradient yet
    are skipped instead of raising.
    """
    for name, mod in model.named_modules():
        if not isinstance(mod, (nn.Conv2d, nn.Linear)):
            continue
        grad = mod.weight.grad
        if grad is None:
            continue
        snapshot = grad.detach().cpu().reshape(-1).numpy().copy()
        if step_iter == 0:
            grad_dict[name] = [snapshot]
        else:
            grad_dict.setdefault(name, []).append(snapshot)

    return grad_dict


def caculate_fr_zico(grad_dict, theta_dict, losses, theta_dict_copy=None):
    """Compute ZiCo-style gradient statistics and a Fisher-Rao-like score.

    Use implementation based on zico because the module names of search spaces in the
    CV benchmark are different from the ads search space.

    Args:
        grad_dict: mapping ``module name -> sequence of 1-D gradient arrays``
            (one array per probed batch). It is read only; it is not modified.
        theta_dict: sequence of dicts ``module name -> 1-D weight array``; only
            the last dict is used.
        losses: unused, kept for call compatibility.
        theta_dict_copy: unused, kept for call compatibility.

    Returns:
        A 4-tuple ``(zico, zico_plus_fr, mean_abs_grad, std_grad)``:

        * ``zico``: sum over modules of ``log(sum(mean|g| / std(g)))``, taken
          over the coordinates whose std across batches is non-zero;
        * ``zico_plus_fr``: ``zico + log(fr + 1e-5)``, where ``fr`` is the sum
          over modules of ``|sum(sum_batches(g) * theta)|``;
        * ``mean_abs_grad`` / ``std_grad``: per-coordinate mean-abs and std,
          summed over the same coordinates and modules.

        A module whose ratio sum is exactly 0 (for instance when no coordinate
        varies across batches) contributes to none of the four values.
    """
    nsr_mean_sum_abs = 0
    fr_mean_sum_abs = 0
    mean_abs_grad_value, std_grad_value = 0, 0
    latest_theta = theta_dict[-1]

    for modname, grad_history in grad_dict.items():
        # Stack into a (batches, weights) array locally, so the caller's lists
        # stay lists and can still be appended to.
        grads = np.array(grad_history)
        nsr_std = np.std(grads, axis=0)
        nonzero_idx = np.nonzero(nsr_std)[0]
        nsr_mean_abs = np.mean(np.abs(grads), axis=0)

        tmpsum = np.sum(nsr_mean_abs[nonzero_idx] / nsr_std[nonzero_idx])
        if tmpsum == 0:
            continue

        nsr_mean_sum_abs += np.log(tmpsum)
        nsr_sum_grad = np.sum(np.sum(grads, axis=0) * latest_theta[modname])
        fr_mean_sum_abs += np.abs(nsr_sum_grad)
        mean_abs_grad_value += np.sum(nsr_mean_abs[nonzero_idx])
        std_grad_value += np.sum(nsr_std[nonzero_idx])

    return (
        nsr_mean_sum_abs,
        nsr_mean_sum_abs + np.log(fr_mean_sum_abs + 1e-5),
        mean_abs_grad_value,
        std_grad_value,
    )

In [None]:
import copy
import numpy as np

eval_batch = 4  # number of consecutive batches probed in each window
epoch = 3  # training epochs per model

# Notebook-level result series, one entry per trained model. The unsuffixed
# lists belong to the window starting at batch 0, "01" to the window starting
# at batch_10 and "02" to the window starting at batch_20.
last_loss, train_losses, last_scores, cur_losses, mean_abs_grad_list, std_grad_list, zico_list = [], [], [], [], [], [], []
last_scores01, cur_losses01, mean_abs_grad_list01, std_grad_list01, zico_list01 = [], [], [], [], []
last_scores02, cur_losses02, mean_abs_grad_list02, std_grad_list02, zico_list02 = [], [], [], [], []

# First batch index of the two warm-up windows. With 60,000 MNIST training
# images and batch_size=128 an epoch has 469 batches, so 48 is about 10% of an
# epoch and 192 about 41%.
# NOTE(review): the "20" in batch_20 and the "30warmup" output file name do not
# match that 41% - confirm which warm-up fraction is intended.
batch_10 = 48
batch_20 = 48 * 4


def _flat_state(model):
    """Return every tensor of ``model.state_dict()`` flattened into one 1-D tensor."""
    return torch.cat([tensor.flatten() for tensor in model.state_dict().values()])


def _weight_snapshot(model):
    """Map each Conv2d/Linear submodule name to its flattened weight as a numpy array."""
    return {
        name: mod.weight.data.cpu().reshape(-1).numpy()
        for name, mod in model.named_modules()
        if isinstance(mod, (nn.Conv2d, nn.Linear))
    }


def _probe_batch(model, probe, inputs, labels, loss_fn, grad_store, score_acc, step_in_window):
    """Measure the gradient of ``model``'s current weights on one batch.

    The weights are copied into ``probe`` so the optimiser state of ``model``
    is not touched. The flattened gradient is added to ``score_acc`` in place
    and the per-module gradients are recorded in ``grad_store`` via ``getgrad``.
    Returns the probe loss as a float.
    """
    probe.load_state_dict(copy.deepcopy(model.state_dict()))
    # load_state_dict does not reset .grad and backward() accumulates into it.
    # Without this reset each "per-batch" gradient would be a running sum over
    # all earlier probe batches (including those of earlier windows).
    probe.zero_grad(set_to_none=True)
    probe_loss = loss_fn(probe(inputs), labels)
    probe_loss.backward()
    score_acc.add_(torch.cat([p.grad.flatten() for p in probe.parameters() if p.requires_grad]))
    getgrad(probe, grad_store, step_in_window)
    return probe_loss.item()


# One entry per probe window: where it starts and which result lists it feeds.
# NOTE(review): only the first window records the *training* loss of the batch
# (computed before the optimiser step); the other two record the probe loss
# (computed after the step). This mirrors the original behaviour - confirm
# whether the asymmetry is intended.
probe_windows = [
    {"start": 0, "cur_losses": cur_losses, "mean_abs": mean_abs_grad_list, "std": std_grad_list,
     "zico": zico_list, "scores": last_scores, "record_train_loss": True},
    {"start": batch_10, "cur_losses": cur_losses01, "mean_abs": mean_abs_grad_list01, "std": std_grad_list01,
     "zico": zico_list01, "scores": last_scores01, "record_train_loss": False},
    {"start": batch_20, "cur_losses": cur_losses02, "mean_abs": mean_abs_grad_list02, "std": std_grad_list02,
     "zico": zico_list02, "scores": last_scores02, "record_train_loss": False},
]

for k1 in range(25):
    for k2 in range(25):
        # MLP2 ignores its hidden2 argument, so k2 does not change the
        # architecture; it only repeats the same hidden width 25 times.
        # convnet = ConvNet().to(device)
        convnet = MLP2(input_dim, 10, (k1 + 1) * 2, (k2 + 1) * 2).to(device)
        convnet2 = MLP2(input_dim, 10, (k1 + 1) * 2, (k2 + 1) * 2).to(device)  # scratch copy used for probing
        optimizer = torch.optim.SGD(convnet.parameters(), lr=0.2)
        loss_fn = tF.cross_entropy
        num_param = sum(tensor.numel() for tensor in convnet.state_dict().values())

        train_loss = []
        grad_stores = [{} for _ in probe_windows]  # per-window input for caculate_fr_zico
        score_accs = [torch.zeros(num_param).to(device) for _ in probe_windows]  # per-window summed gradients
        fr_scores = [None] * len(probe_windows)  # per-window log((sum_grad . theta)^2)

        for e in range(epoch):
            for i, (inputs, labels) in enumerate(trainloader):
                inputs, labels = inputs.to(device), labels.to(device)

                # Regular SGD step on the model being trained.
                optimizer.zero_grad()
                outputs = convnet(inputs)
                loss = loss_fn(outputs, labels)
                loss.backward()
                optimizer.step()

                # Probing only happens during the first epoch, on eval_batch
                # consecutive batches starting at each window's first index.
                if e == 0:
                    for w, window in enumerate(probe_windows):
                        step_in_window = i - window["start"]
                        if step_in_window < 0 or step_in_window >= eval_batch:
                            continue

                        probe_loss = _probe_batch(
                            convnet, convnet2, inputs, labels, loss_fn,
                            grad_stores[w], score_accs[w], step_in_window,
                        )
                        if step_in_window != eval_batch - 1:
                            continue

                        # Last batch of the window: turn the collected gradients into scores.
                        window["cur_losses"].append(loss.item() if window["record_train_loss"] else probe_loss)
                        score1, score2, mean_abs_grad, std_grad = caculate_fr_zico(
                            grad_stores[w], [_weight_snapshot(convnet2)], np.array([probe_loss])
                        )
                        window["mean_abs"].append(mean_abs_grad)
                        window["std"].append(std_grad)
                        window["zico"].append(score1)
                        fr_scores[w] = torch.log((score_accs[w] * _flat_state(convnet)).sum() ** 2)

                if e == epoch - 1:
                    train_loss.append(loss.item())

        if any(fr_score is None for fr_score in fr_scores):
            raise RuntimeError(
                "a probe window was never completed; the training loader needs at least "
                "{} batches per epoch".format(batch_20 + eval_batch)
            )

        # evaluate loss on the test set at end of training
        convnet.eval()
        val_loss = []
        with torch.no_grad():  # no gradients are needed for evaluation
            for inputs, labels in testloader:
                inputs, labels = inputs.to(device), labels.to(device)
                val_loss.append(loss_fn(convnet(inputs), labels).item())

        # F_dense = F.get_dense_tensor()
        # score = zero_score(convnet, trainloader, repr)
        mean_val_loss = sum(val_loss) / len(val_loss)
        mean_train_loss = sum(train_loss) / len(train_loss)
        last_loss.append(mean_val_loss)
        train_losses.append(mean_train_loss)
        for window, fr_score in zip(probe_windows, fr_scores):
            window["scores"].append(fr_score.item())
        score, score01, score02 = fr_scores

        print('step {}-{}: train loss: {}, val loss: {}, cur loss: {}, mean abs gradient: {}, std gradient: {}, zico: {}, score: {}'.format(
            (k1 + 1) * 1, (k2 + 1) * 1, mean_train_loss, mean_val_loss,
            cur_losses[-1], mean_abs_grad_list[-1], std_grad_list[-1], zico_list[-1], score))

        print('step {}-{}: train loss: {}, val loss: {}, cur loss: {}, mean abs gradient (10%): {}, std gradient (10%): {}, zico (10%): {}, score (10%): {}'.format(
            (k1 + 1) * 1, (k2 + 1) * 1, mean_train_loss, mean_val_loss,
            cur_losses01[-1], mean_abs_grad_list01[-1], std_grad_list01[-1], zico_list01[-1], score01))


In [None]:
import scipy
from scipy import stats

# Select the style BEFORE creating the figure so that it actually applies to
# it. matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*" and
# later releases removed the old names, so pick whichever one is installed.
white_style = "seaborn-v0_8-white" if "seaborn-v0_8-white" in plt.style.available else "seaborn-white"
plt.style.use(white_style)

# One row per probe window (no warm-up, first warm-up window, second warm-up
# window), one column per metric; every panel is plotted against the test loss.
fig, ax = plt.subplots(3, 3, sharex="col", sharey="row", figsize=(20, 10))
fig.subplots_adjust(hspace=0.3, wspace=0.1)

# data = [cur_losses, np.exp(last_scores), mean_abs_grad_list, 220 - np.log(std_grad_list) * 50]
data = [
    [cur_losses, last_scores, mean_abs_grad_list, std_grad_list],
    [cur_losses01, last_scores01, mean_abs_grad_list01, std_grad_list01],
    [cur_losses02, last_scores02, mean_abs_grad_list02, std_grad_list02],
]
loss_resutls = [last_loss, last_loss, last_loss]
labels = [
    "Current Training Loss",
    "Fisher-Rao Norm",
    "Mean Absolute Gradients",
    "Standard Deviation of Gradients",
]
# Only the first three metrics of each row are drawn; the std series is kept
# in `data` but has no panel.
for i in range(3):
    for j in range(3):
        xs, ys = data[i][j], loss_resutls[i]
        ax[i, j].scatter(xs, ys)
        ax[i, j].set_title(
            "Spearman's rho: {:.2f} \n kendall's tau: {:.2f}".format(
                stats.spearmanr(xs, ys).correlation,
                stats.kendalltau(xs, ys).correlation,
            ),
            fontsize=16,
        )
        if j == 0:
            ax[i, j].set_ylabel("Test Loss", fontsize=18)
        if i == 2:
            ax[i, j].set_xlabel(labels[j], fontsize=18)
plt.savefig('correlation_test_train_loss_all_warmup-2Layers.pdf', bbox_inches="tight")
plt.savefig('correlation_test_train_loss_all_warmup-2Layers.png', bbox_inches="tight")

In [None]:
import scipy

# Select the style BEFORE creating the figure so that it actually applies to
# it. matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*" and
# later releases removed the old names, so pick whichever one is installed.
white_style = "seaborn-v0_8-white" if "seaborn-v0_8-white" in plt.style.available else "seaborn-white"
plt.style.use(white_style)

# Metrics from the probe window that starts at batch 0.
# Top row: last-epoch training loss; bottom row: test loss.
fig, ax = plt.subplots(2, 3, sharex='col', sharey='row', figsize=(20,10))
fig.subplots_adjust(hspace=0.2, wspace=0.1)

# data = [cur_losses, np.exp(last_scores), mean_abs_grad_list, 220 - np.log(std_grad_list) * 50]
data = [cur_losses, last_scores, mean_abs_grad_list, std_grad_list]
loss_resutls = [train_losses, last_loss]
labels = ['Current Training Loss', 'Fisher-Rao Norm', "Mean Absolute Gradients", "Standard Deviation of Gradients"]
row_labels = ['Training Loss (Last Epoch)', 'Test Loss']
# Only the first three metrics are drawn; the std series has no panel.
for i in range(2):
    for j in range(3):
        xs, ys = data[j], loss_resutls[i]
        ax[i, j].scatter(xs, ys)
        ax[i, j].set_title("Spearman's rho: {:.2f} \n kendall's tau: {:.2f}".format(
            stats.spearmanr(xs, ys).correlation,
            stats.kendalltau(xs, ys).correlation), fontsize=16)
        if j == 0:
            ax[i, j].set_ylabel(row_labels[i], fontsize=18)
        if i == 1:
            ax[i, j].set_xlabel(labels[j], fontsize=18)
plt.savefig('correlation_test_train_loss_0warmup-2Layers.pdf', bbox_inches="tight")
plt.savefig('correlation_test_train_loss_0warmup-2Layers.png', bbox_inches="tight")


In [None]:
# Select the style BEFORE creating the figure so that it actually applies to
# it. matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*" and
# later releases removed the old names, so pick whichever one is installed.
white_style = "seaborn-v0_8-white" if "seaborn-v0_8-white" in plt.style.available else "seaborn-white"
plt.style.use(white_style)

# Metrics from the probe window that starts at batch_10 (the "01" series).
# Top row: last-epoch training loss; bottom row: test loss.
fig, ax = plt.subplots(2, 3, sharex='col', sharey='row', figsize=(20,10))
fig.subplots_adjust(hspace=0.2, wspace=0.1)

# data = [cur_losses, np.exp(last_scores), mean_abs_grad_list, 220 - np.log(std_grad_list) * 50]
data = [cur_losses01, last_scores01, mean_abs_grad_list01, std_grad_list01]
loss_resutls = [train_losses, last_loss]
labels = ['Current Training Loss', 'Fisher-Rao Norm', "Mean Absolute Gradients", "Standard Deviation of Gradients"]
row_labels = ['Training Loss (Last Epoch)', 'Test Loss']
# Only the first three metrics are drawn; the std series has no panel.
for i in range(2):
    for j in range(3):
        xs, ys = data[j], loss_resutls[i]
        ax[i, j].scatter(xs, ys)
        ax[i, j].set_title("Spearman's rho: {:.2f} \n kendall's tau: {:.2f}".format(
            stats.spearmanr(xs, ys).correlation,
            stats.kendalltau(xs, ys).correlation), fontsize=16)
        if j == 0:
            ax[i, j].set_ylabel(row_labels[i], fontsize=18)
        if i == 1:
            ax[i, j].set_xlabel(labels[j], fontsize=18)
plt.savefig('correlation_test_train_loss_10warmup-2Layers.pdf', bbox_inches="tight")
plt.savefig('correlation_test_train_loss_10warmup-2Layers.png', bbox_inches="tight")

In [None]:
import scipy

# Select the style BEFORE creating the figure so that it actually applies to
# it. matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*" and
# later releases removed the old names, so pick whichever one is installed.
white_style = "seaborn-v0_8-white" if "seaborn-v0_8-white" in plt.style.available else "seaborn-white"
plt.style.use(white_style)

# Metrics from the probe window that starts at batch_20 (the "02" series).
# Top row: last-epoch training loss; bottom row: test loss.
# NOTE(review): the output files are named "30warmup" although the series comes
# from the batch_20 window - confirm the intended label.
fig, ax = plt.subplots(2, 3, sharex='col', sharey='row', figsize=(20,10))
fig.subplots_adjust(hspace=0.2, wspace=0.1)

# data = [cur_losses, np.exp(last_scores), mean_abs_grad_list, 220 - np.log(std_grad_list) * 50]
data = [cur_losses02, last_scores02, mean_abs_grad_list02, std_grad_list02]
loss_resutls = [train_losses, last_loss]
labels = ['Current Training Loss', 'Fisher-Rao Norm', "Mean Absolute Gradients", "Standard Deviation of Gradients"]
row_labels = ['Training Loss (Last Epoch)', 'Test Loss']
# Only the first three metrics are drawn; the std series has no panel.
for i in range(2):
    for j in range(3):
        xs, ys = data[j], loss_resutls[i]
        ax[i, j].scatter(xs, ys)
        ax[i, j].set_title("Spearman's rho: {:.2f} \n kendall's tau: {:.2f}".format(
            stats.spearmanr(xs, ys).correlation,
            stats.kendalltau(xs, ys).correlation), fontsize=16)
        if j == 0:
            ax[i, j].set_ylabel(row_labels[i], fontsize=18)
        if i == 1:
            ax[i, j].set_xlabel(labels[j], fontsize=18)
plt.savefig('correlation_test_train_loss_30warmup-2Layers.pdf', bbox_inches="tight")
plt.savefig('correlation_test_train_loss_30warmup-2Layers.png', bbox_inches="tight")
