In [93]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
import torch.nn.functional as F


seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

d_in = 10
d_out = 10
train_size = 100
test_size = 100
w = 100
init_layer_norm = [3.5617792839366773, 0.9626184663074074, 4.2834745321985, 0.30428997418826464, 1.5577608335262456, 0.05657056471066833]

def L2(model):
    params = list(model.parameters())
    l2 = 0
    for i in range(6):
        if i == 0:
            params_flatten = params[i].reshape(-1,)
        params_flatten = torch.cat([params_flatten, params[i].reshape(-1,)])
    l2 = torch.sum(params_flatten**2)
    return params_flatten, l2

def norms(model):
    params = list(model.parameters())
    norms = []
    for i in range(6):
        norms.append(np.sqrt(torch.sum(params[i].reshape(-1,)**2).item()))
    return norms

# rescale parameters to the layer norms of goldilocks zone with decay ratio
def renorm(model, goldilocks_norms, student_norms, decay=0.01):
    params = list(model.parameters())
    for i in range(6):
        scale = goldilocks_norms[i] / student_norms[i]
        scale = 1 + (scale - 1) * decay
        params[i].data = params[i].data * scale

def init(model, alpha):
    state_dict = model.state_dict()
    modules = ["l1.weight", "l1.bias", "l2.weight", "l2.bias", "l3.weight", "l3.bias"]
    for module in modules:
        state_dict[module] = state_dict[module] * alpha
    model.load_state_dict(state_dict)
    
def init2(model, alpha):
    model.l1.weight.data = model.l1.weight * alpha
    model.l1.bias.data = model.l1.bias * alpha
    model.l2.weight.data = model.l2.weight * alpha
    model.l2.bias.data = model.l2.bias * alpha
    model.l3.weight.data = model.l3.weight * alpha
    model.l3.bias.data = model.l3.bias * alpha
    
def grad(model):
    grads = list(student.parameters())
    for i in range(6):
        if i == 0:
            grad = grads[0].reshape(-1,)
        else:
            grad = torch.cat([grad, grads[i].reshape(-1,)])
    return grad

class Net(nn.Module):

    def __init__(self, w=w):
        super(Net, self).__init__()
        self.l1 = nn.Linear(d_in, w)
        self.l2 = nn.Linear(w, w)
        self.l3 = nn.Linear(w,d_out)

    def forward(self, x):
        f = torch.nn.Tanh()
        self.x1 = f(self.l1(x))
        self.x2 = f(self.l2(self.x1))
        self.x3 = self.l3(self.x2)
        return self.x3
    
teacher = Net()
alpha = 1.0
init2(teacher, alpha=alpha)
inputs_train = torch.tensor(torch.normal(0,1,size=(train_size, d_in)), dtype=torch.float, requires_grad=True)
labels_train = torch.tensor(teacher(inputs_train), dtype=torch.float, requires_grad=True)

# test1 for validation to prevent cherry-pick results
inputs_test = torch.normal(0,1,size=(test_size, d_in))
labels_test = teacher(inputs_test)

inputs_test2 = torch.normal(0,1,size=(test_size, d_in))
labels_test2 = teacher(inputs_test2)


    

  inputs_train = torch.tensor(torch.normal(0,1,size=(train_size, d_in)), dtype=torch.float, requires_grad=True)
  labels_train = torch.tensor(teacher(inputs_train), dtype=torch.float, requires_grad=True)


In [95]:
alpha = 1.0


print("---------alpha={}---------".format(alpha))
seed = 1
np.random.seed(seed)
torch.manual_seed(seed)
student = Net()

init2(student, alpha=alpha)
# init student with scaled layer norms
renorm(student, init_layer_norm, norms(student), 1)
_, scale = L2(student)

epochs = 1000000
log = 200
wd = 0.005

optimizer = torch.optim.AdamW(student.parameters(), lr=3e-4, weight_decay = wd)

losses_train = []
losses_test = []
accs_train = []
accs_test = []

l2s = []
threshold = 0.001

for epoch in range(epochs):  # loop over the dataset multiple times

    optimizer.zero_grad()

    outputs_train = student(inputs_train)
    loss_train_vec = torch.mean((outputs_train-labels_train)**2, dim=1)
    loss_train = torch.mean(loss_train_vec)
    train_acc = torch.sum(loss_train_vec < threshold)/train_size
    
    outputs_test = student(inputs_test)
    loss_test_vec = torch.mean((outputs_test-labels_test)**2, dim=1)
    loss_test = torch.mean(loss_test_vec)
    test_acc = torch.sum(loss_test_vec < threshold)/test_size

    outputs_test2 = student(inputs_test2)
    loss_test_vec2 = torch.mean((outputs_test2-labels_test2)**2, dim=1)
    loss_test2 = torch.mean(loss_test_vec2)
    test_acc2 = torch.sum(loss_test_vec2 < threshold)/test_size
    
    params, l2 = L2(student)
    #init2(student, alpha=torch.sqrt(scale/l2))
    #params, l2 = L2(student)

    loss_train.backward()

    optimizer.step()

    if epoch % log == 0:
        print(loss_train.detach().numpy())
        renorm(student, init_layer_norm, norms(student), 0.1)
        # print(norms(teacher))
        print(norms(student))
        print("epoch: %d  | Train loss: %.6f |  Test loss: %.6f | Test loss2: %.6f | train_acc: %.2f | test_acc: %.2f | test_acc2: %.2f | l2: %.6f"%(epoch, loss_train.detach().numpy(), loss_test.detach().numpy(), loss_test2.detach().numpy(), train_acc.detach().numpy(), test_acc.detach().numpy(), test_acc2.detach().numpy(), l2.detach().numpy()))

    losses_train.append(loss_train.detach().numpy())
    losses_test.append(loss_test.detach().numpy())
    l2s.append(l2.detach().numpy())
    accs_train.append(train_acc.detach().numpy())
    accs_test.append(test_acc.detach().numpy())


---------alpha=1.0---------
0.026389362
[3.5634757596258346, 0.9634692328093256, 4.285546367625415, 0.3050506823672941, 1.5604788869244164, 0.056529585748449795]
epoch: 0  | Train loss: 0.026389 |  Test loss: 0.026767 | Test loss2: 0.026272 | train_acc: 0.00 | test_acc: 0.00 | test_acc2: 0.00 | l2: 47.169743
0.00051652995
[3.6018922865805436, 0.9681183078305541, 4.362505045453409, 0.32684824821308506, 1.5998565520117016, 0.05483717318915826]
epoch: 200  | Train loss: 0.000517 |  Test loss: 0.000746 | Test loss2: 0.000705 | train_acc: 0.90 | test_acc: 0.80 | test_acc2: 0.76 | l2: 48.648445
0.00039070242
[3.6642569223970853, 0.9730019931557168, 4.401122010831287, 0.32369973107236333, 1.6123942968927742, 0.05417218908981381]
epoch: 400  | Train loss: 0.000391 |  Test loss: 0.000721 | Test loss2: 0.000681 | train_acc: 0.95 | test_acc: 0.81 | test_acc2: 0.79 | l2: 50.077274
0.00025403075
[3.7801306462795563, 0.995538324820635, 4.494027094461807, 0.32560429329665963, 1.648281433527827, 0.053

In [96]:
plt.plot(np.arange(epochs), accs_train)
plt.plot(np.arange(epochs), accs_test)
plt.xlabel("steps", fontsize=20)
plt.ylabel("accuracy", fontsize=20)
plt.legend(["train", "test"], fontsize=15)
plt.title("weight decay=0.05", fontsize=15)
plt.xscale('log')

ValueError: x and y must have same first dimension, but have shapes (1000000,) and (946563,)