In [1]:
import torch
from torch import nn
import numpy as np
import tabulate

import os 

# Project source root. Override with the LOSS_SUBSPACE_SRC environment variable
# when running outside the original cluster instead of editing this path.
# We chdir (rather than only extending sys.path) for two reasons:
#   - `models.mlp` is imported relative to the CWD (presumably via the
#     notebook's CWD entry on sys.path);
#   - later cells write outputs (plane.npz) relative to the CWD.
SRC_DIR = os.environ.get(
    'LOSS_SUBSPACE_SRC',
    '/gpfs/commons/home/tchen/loss_sub_space_geometry_project/loss-subspace-geometry/src',
)
os.chdir(SRC_DIR)

from models.mlp import SubspaceNN, NN

  from .autonotebook import tqdm as notebook_tqdm


In [13]:
import torchvision
import torchvision.transforms as transforms

In [2]:
# configs
data_dim = 784          # flattened 28x28 FashionMNIST image (see the reshape in eval)
hidden_size = 512
out_dim = 10
dropout_prob = 0.3
seed = 11202022
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model_path = '/gpfs/commons/home/tchen/loss_sub_space_geometry_project/loss-subspace-geometry-save/models/subspace_vanilla_mlp_0.pt'

curve_model = SubspaceNN(input_dim=data_dim,
                         hidden_dim=hidden_size,
                         out_dim=out_dim,
                         dropout_prob=dropout_prob,
                         seed=seed).to(device)
# map_location loads tensors straight onto `device`, whatever device the
# checkpoint was saved from (avoids failures when that GPU is absent).
checkpoint = torch.load(model_path, map_location=device)
curve_model.load_state_dict(checkpoint)

<All keys matched successfully>

In [3]:
# more configs: sampling density along the curve and over the plane grid, plus
# how far (as a fraction of each axis length) the grid extends beyond [0, 1].
curve_points, grid_points = 61, 21
margin_left = margin_right = margin_bottom = margin_top = 0.2

In [4]:
# Inspect the layer structure of the loaded subspace model.
curve_model

SubspaceNN(
  (mlp): SubspaceMLP(
    (linear): LinesLinear(in_features=784, out_features=512, bias=True)
    (dropout): Dropout(p=0.3, inplace=False)
  )
  (out): LinesLinear(in_features=512, out_features=10, bias=True)
)

In [5]:
curve_parameters = list(curve_model.parameters())

# Parameter indices that make up each endpoint of the line subspace.
# Presumably each LinesLinear layer lists (weight, bias, weight1) — TODO confirm
# against models/mlp.py. Under that assumption:
#   - endpoint 0 uses the `weight` tensors and endpoint 1 the `weight1` tensors;
#   - indices 1 and 4 (the biases) are shared by both endpoints.
w = [
    np.concatenate([curve_parameters[k].data.cpu().numpy().ravel() for k in index_set])
    for index_set in ([0, 1, 3, 4], [2, 1, 5, 4])
]

In [6]:
isolated_model = NN(input_dim=data_dim,
                    hidden_dim=hidden_size,
                    out_dim=out_dim,
                    dropout_prob=dropout_prob).to(device)
isolated_model_path = '/gpfs/commons/home/tchen/loss_sub_space_geometry_project/loss-subspace-geometry-save/models/vanilla_mlp.pt'
# map_location loads tensors straight onto `device`, whatever device the
# checkpoint was saved from.
isolated_checkpoint = torch.load(isolated_model_path, map_location=device)
isolated_model.load_state_dict(isolated_checkpoint)

<All keys matched successfully>

In [7]:
# Inspect the standalone (non-subspace) model's layers for comparison with curve_model.
isolated_model

NN(
  (mlp): MLP(
    (linear): Linear(in_features=784, out_features=512, bias=True)
    (dropout): Dropout(p=0.3, inplace=False)
  )
  (out): Linear(in_features=512, out_features=10, bias=True)
)

In [8]:
# w[2]: every parameter of the standalone model, flattened in parameters() order.
w.append(
    np.concatenate([tensor.data.cpu().numpy().ravel()
                    for tensor in isolated_model.parameters()])
)

In [9]:

# set up for grid for plane plotting

def get_xy(point, origin, vector_x, vector_y):
    """Return the 2-D plane coordinates of ``point`` relative to ``origin``.

    The coordinates are the dot products of ``point - origin`` with
    ``vector_x`` and ``vector_y``. They are a true orthogonal projection only
    when the two vectors are orthonormal (as ``u`` and ``v`` are here).

    Returns:
        numpy array ``[x, y]``.
    """
    offset = point - origin
    return np.array([np.dot(offset, axis) for axis in (vector_x, vector_y)])


print('Weight space dimensionality: %d' % w[0].shape[0])

# Orthonormal basis for the plane through w[0], w[1], w[2]:
#   u: unit vector from the first subspace endpoint w[0] to the isolated model w[2];
#   v: unit vector along the part of (w[1] - w[0]) orthogonal to u (one Gram-Schmidt step).
# dx and dy are the lengths used later to scale the alpha/beta grid.
u = w[2] - w[0]
dx = np.linalg.norm(u)
if dx == 0:
    raise ValueError('w[0] and w[2] coincide; cannot define the plane x-axis')
u /= dx

v = w[1] - w[0]
v -= np.dot(u, v) * u
dy = np.linalg.norm(v)
if dy == 0:
    raise ValueError('w[1] - w[0] is parallel to w[2] - w[0]; the points do not span a plane')
v /= dy

# Plane coordinates of the three anchor models.
# np.stack needs a sequence, not a generator: passing a generator has been
# deprecated since NumPy 1.16 (the FutureWarning in this cell's output), and
# newer NumPy releases reject it.
bend_coordinates = np.stack([get_xy(p, w[0], u, v) for p in w])


Weight space dimensionality: 407050


  if (await self.run_code(code, result,  async_=asy)):


In [10]:
def get_weights(model: nn.Module, t):
    """Flatten ``model``'s weights evaluated at subspace coordinate ``t``.

    For every ``nn.Linear`` submodule, visited in ``model.modules()`` order,
    this sets ``module.alpha = t``. It then collects ``module.get_weight()``
    followed by ``module.bias``, which is intended to match the
    [weight, bias, ...] layout of the flattened vectors in ``w``.

    NOTE: mutates ``model``. Every matched layer keeps ``alpha = t`` after the
    call. Matched layers are assumed to be subspace layers (e.g. LinesLinear)
    that provide ``get_weight()``; a plain ``nn.Linear`` has no such method.

    Args:
        model: subspace model whose linear layers expose ``get_weight()``.
        t: coordinate along the subspace at which to evaluate the weights.

    Returns:
        1-D numpy array of all collected weights and biases, concatenated.

    Raises:
        ValueError: if ``model`` contains no ``nn.Linear`` layers.
    """
    tensors = []
    for module in model.modules():
        if isinstance(module, nn.Linear):
            module.alpha = t
            tensors.extend([module.get_weight(), module.bias.data])
    if not tensors:
        raise ValueError('get_weights: model has no nn.Linear layers to read weights from')
    return np.concatenate([tensor.detach().cpu().numpy().ravel() for tensor in tensors])


In [11]:

# Trace the line subspace in plane coordinates.
# `ts` is reused for the loop so the sampled t values and the saved `ts` cannot diverge.
ts = np.linspace(0.0, 1.0, curve_points)
curve_coordinates = []
# NOTE: get_weights sets `alpha` on curve_model's layers, so after this loop the
# model is left at the last t (1.0).
for t in ts:
    weights = get_weights(model=curve_model, t=t)
    curve_coordinates.append(get_xy(weights, w[0], u, v))

# NOTE(review): this appends the isolated model after the t=1 point, so
# curve_coordinates has curve_points + 1 rows while ts has curve_points.
# The isolated model is also already bend_coordinates[2].
# Confirm this extra row is intended before relying on ts/curve_coordinates alignment.
isolated_model_weights = w[2]
curve_coordinates.append(get_xy(isolated_model_weights, w[0], u, v))
curve_coordinates = np.stack(curve_coordinates)

G = grid_points
# Grid in units of (dx, dy): (0, 0) is w[0], alpha = 1 reaches w[2]'s x-coordinate,
# and beta = 1 reaches w[1]'s y-coordinate; margins extend past those bounds.
alphas = np.linspace(0.0 - margin_left, 1.0 + margin_right, G)
betas = np.linspace(0.0 - margin_bottom, 1.0 + margin_top, G)

# Per-grid-point metrics filled by the evaluation loop; [i, j] <-> (alphas[i], betas[j]).
tr_loss, tr_nll, tr_acc, tr_err = (np.zeros((G, G)) for _ in range(4))
te_loss, te_nll, te_acc, te_err = (np.zeros((G, G)) for _ in range(4))

grid = np.zeros((G, G, 2))

In [14]:
# even more configs for evaluating on FashionMNIST: dataset root, and the
# batch size used by the training-split loader.
data_dir, batch_size = (
    '/gpfs/commons/home/tchen/loss_sub_space_geometry_project/data/',
    128,
)

In [15]:
transform = transforms.Compose([transforms.ToTensor()])
FashionMNIST_data_train = torchvision.datasets.FashionMNIST(
    data_dir, train=True, transform=transform, download=False)

# Seed the split with the notebook's `seed` so train/val membership is
# reproducible across runs, and with it every metric reported below.
# NOTE(review): if the checkpoints were trained on a different random split of
# this same training set, `val_set` overlaps their training data, and the
# "Test" columns below are not held-out numbers.
# Consider FashionMNIST(train=False) instead — TODO confirm against the training script.
split_generator = torch.Generator().manual_seed(seed)
train_set, val_set = torch.utils.data.random_split(
    FashionMNIST_data_train, [50000, 10000], generator=split_generator)
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=batch_size, shuffle=True)
# The whole validation split is evaluated as a single batch.
valid_loader = torch.utils.data.DataLoader(
    val_set, batch_size=len(val_set), shuffle=False)

In [16]:
# Sum (not mean) reduction: per-batch losses are accumulated and later divided by
# the number of samples, giving an exact per-sample average regardless of batch size.
criterion = nn.CrossEntropyLoss(reduction='sum')

In [17]:
def eval(model: nn.Module, loader, criterion=None, device=None):
    """Evaluate ``model`` over every batch of ``loader``.

    NOTE: this name shadows Python's builtin ``eval`` in the notebook
    namespace; it is kept for compatibility with existing callers.

    Args:
        model: classifier that takes flattened ``(batch, features)`` inputs and
            returns logits. It is switched to eval mode (dropout off) and left there.
        loader: DataLoader yielding ``(x, y)`` batches. ``len(loader.dataset)``
            is used as the sample count for averaging.
        criterion: loss using ``reduction='sum'``, since per-batch values are
            summed and divided by the dataset size. Defaults to
            ``nn.CrossEntropyLoss(reduction='sum')``.
        device: device to run on. Defaults to the device of the model's first
            parameter, or CPU if the model has none.

    Returns:
        dict with:
            ``'nll'`` and ``'loss'``: identical values, the mean per-sample loss;
            ``'accuracy'``: percentage of argmax predictions equal to ``y``.

    Raises:
        ValueError: if ``loader.dataset`` is empty.
    """
    if criterion is None:
        criterion = nn.CrossEntropyLoss(reduction='sum')
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    num_samples = len(loader.dataset)
    if num_samples == 0:
        raise ValueError('eval: loader.dataset is empty')

    running_loss = 0.0
    num_right = 0

    model.eval()
    # Gradients are never used here. Skipping autograd avoids building a graph for
    # every batch, which matters because this runs at every grid point.
    with torch.no_grad():
        for x, y in loader:
            x = x.reshape(x.size(0), -1).to(device)
            y = y.to(device)
            y_hat = model(x)
            num_right += (torch.argmax(y_hat, dim=-1) == y).sum().item()
            running_loss += criterion(y_hat, y).item()

    return {
        'nll': running_loss / num_samples,
        'loss': running_loss / num_samples,
        'accuracy': num_right * 100.0 / num_samples,
    }

In [18]:

base_model = NN(input_dim=data_dim,
                hidden_dim=hidden_size,
                out_dim=out_dim,
                dropout_prob=dropout_prob).to(device)

# NOTE: the "Test" columns / te_* arrays are computed on `valid_loader`, i.e. the
# validation split carved out of the FashionMNIST training set.
columns = ['X', 'Y', 'Train loss', 'Train nll', 'Train error (%)', 'Test nll', 'Test error (%)']


def _load_flat_params(model: nn.Module, flat: np.ndarray) -> None:
    """Copy the 1-D vector ``flat`` into ``model``'s parameters, in ``parameters()`` order.

    Raises:
        ValueError: if ``flat`` does not have exactly one entry per model parameter.
    """
    # Check up front so a layout mismatch never leaves the model half-overwritten
    # and never silently ignores trailing entries.
    total = sum(parameter.numel() for parameter in model.parameters())
    if flat.size != total:
        raise ValueError(f'flat vector has {flat.size} entries, model has {total} parameters')
    offset = 0
    for parameter in model.parameters():
        size = parameter.numel()
        value = flat[offset:offset + size].reshape(parameter.shape)
        parameter.data.copy_(torch.from_numpy(value))
        offset += size


for i, alpha in enumerate(alphas):
    for j, beta in enumerate(betas):
        # Weight-space point at plane coordinates (alpha * dx, beta * dy).
        p = w[0] + alpha * dx * u + beta * dy * v
        _load_flat_params(base_model, p)

        tr_res = eval(model=base_model, loader=train_loader)
        te_res = eval(model=base_model, loader=valid_loader)

        # u and v are orthonormal, so projecting p back with get_xy would reproduce
        # these coordinates (up to rounding); no need to recompute them.
        grid[i, j] = [alpha * dx, beta * dy]

        tr_loss[i, j] = tr_res['loss']
        tr_nll[i, j] = tr_res['nll']
        tr_acc[i, j] = tr_res['accuracy']
        tr_err[i, j] = 100.0 - tr_acc[i, j]

        te_loss[i, j] = te_res['loss']
        te_nll[i, j] = te_res['nll']
        te_acc[i, j] = te_res['accuracy']
        te_err[i, j] = 100.0 - te_acc[i, j]

        values = [
            grid[i, j, 0], grid[i, j, 1], tr_loss[i, j], tr_nll[i, j], tr_err[i, j],
            te_nll[i, j], te_err[i, j]
        ]
        table = tabulate.tabulate([values], columns, tablefmt='simple', floatfmt='10.4f')
        lines = table.split('\n')  # [header, rule, data row] for tablefmt='simple'
        if j == 0:
            # New alpha value: print the framed header (rule, header, rule) before the row.
            table = '\n'.join([lines[1]] + lines)
        else:
            table = lines[2]
        print(table)

# Written relative to the current working directory (the src/ dir chosen by the
# os.chdir in the first cell).
np.savez(
    os.path.join('./', 'plane.npz'),
    ts=ts,
    bend_coordinates=bend_coordinates,
    curve_coordinates=curve_coordinates,
    alphas=alphas,
    betas=betas,
    grid=grid,
    tr_loss=tr_loss,
    tr_acc=tr_acc,
    tr_nll=tr_nll,
    tr_err=tr_err,
    te_loss=te_loss,
    te_acc=te_acc,
    te_nll=te_nll,
    te_err=te_err
)

----------  ----------  ------------  -----------  -----------------  ----------  ----------------
         X           Y    Train loss    Train nll    Train error (%)    Test nll    Test error (%)
----------  ----------  ------------  -----------  -----------------  ----------  ----------------
  -25.6692    -25.4282        0.2300       0.2300             8.2760      0.2282            8.4800
  -25.6692    -16.5283        0.2436       0.2436             8.3580      0.2419            8.5500
  -25.6692     -7.6285        0.2695       0.2695             8.5720      0.2684            8.7000
  -25.6692      1.2714        0.3087       0.3087             8.8620      0.3080            9.0400
  -25.6692     10.1713        0.3622       0.3622             9.2120      0.3620            9.2700
  -25.6692     19.0712        0.4314       0.4314             9.6280      0.4320            9.5800
  -25.6692     27.9710        0.5180       0.5180            10.0660      0.5195           10.0300
  -25.6692