In [None]:
#default_exp loss_landscape

In [None]:
#export
from fastai2.vision.all import *
from fastexplorer.representation import *
from fastexplorer.explorer import *

In [None]:
#hide
%reload_ext autoreload
%autoreload 2
from nbdev.showdoc import *

In [None]:
#hide
def label_func(f):
    "Label a file name: True when it starts with an uppercase letter (in the Pets dataset this presumably marks cat breeds)."
    return f[0].isupper()

path = untar_data(URLs.PETS)
files = get_image_files(path/"images")
dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(224))
learn = cnn_learner(dls, resnet34, metrics=accuracy)

In [None]:
#srv
# Load weights previously saved under the name 'resnet34' instead of training in this cell.
# The commented-out `fine_tune` call is presumably how those weights (and the training
# log kept in this cell's output) were produced; uncomment it to retrain.
# learn.fine_tune(2)
learn.load('resnet34');

epoch,train_loss,valid_loss,accuracy,time
0,0.139714,0.021332,0.993911,00:13


epoch,train_loss,valid_loss,accuracy,time
0,0.068502,0.030123,0.989175,00:16
1,0.029429,0.009352,0.996617,00:16


# Loss Landscape

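> Visualize the loss landscape of a trained model around its current weights, following the approach of https://github.com/tomgoldstein/loss-landscape (random, filter-normalized directions in parameter space).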

TODO:

- [ ] Report computation progress back to the JavaScript client.

In [None]:
#export
# Event type attached to the loss-landscape payload sent to the client by `get_loss_landscape`.
clientEvents.SEND_LANDSCAPE = 'socket/socketReceiveLossLandscape'

In [None]:
# Display the registered client events to confirm SEND_LANDSCAPE is now present.
clientEvents

Events (
  SEND_DATA: socket/socketReceiveData
  INVALID_EVENT: socket/socketInvalidEvent
  SEND_IMAGE_INPUT: socket/socketReceiveImageInput
  NOIMAGE_HEATMAP: socket/socketNoImageHeatmap
  SEND_HEATMAP: socket/socketReceiveHeatmap
  SEND_ERROR: socket/socketError
  CLOSE_CLIENT: socket/socketServerClosed
  SEND_LANDSCAPE: socket/socketReceiveLossLandscape
)

## Utils

In [None]:
#export
def _normalize_direction(direction, weights):
    '''
    Rescale the filters (weights in group) in 'direction' so that each
    filter has the same norm as its corresponding filter in 'weights'.
    '''
    for d,w in zip(direction, weights): d.mul_(w.norm()/(d.norm() + 1e-10))

def _get_random_direction(m):
    weights = [o.data for o in m.parameters()]
    direction = [torch.randn(o.shape) for o in weights]

    for d,w in zip(direction, weights):
        if d.dim() <= 1: d.fill_(0)
        else           : _normalize_direction(d, w)
            
    return direction

def _compute_landscape(learn, samples=30, size=1, final_size=100):
    '''
    Evaluate the loss of `learn` on a random 2-D slice of parameter space.

    The current weights are perturbed along two random, filter-normalized
    directions on a `samples` x `samples` grid whose coordinates span
    [-size, size]. `learn.loss_func` is evaluated on a single batch from
    `learn.dls.one_batch()` at every grid point. If `final_size > samples`,
    the grid is bilinearly upsampled to `final_size` x `final_size`.

    The model's original weights and its train/eval mode are restored
    before returning, even if an evaluation raises, so the learner is left
    exactly as it was found.

    Returns the (possibly upsampled) loss grid as a 2-D numpy array whose
    rows follow the x direction and columns the y direction.
    '''
    m = learn.model
    params = list(m.parameters())
    weights = [p.data.clone() for p in params]
    was_training = m.training
    xdirection = _get_random_direction(m)
    ydirection = _get_random_direction(m)
    xcoords = torch.linspace(-size, size, samples)
    ycoords = torch.linspace(-size, size, samples)
    # torch.meshgrid defaults to 'ij' indexing: xmesh[i,j] = xcoords[i], ymesh[i,j] = ycoords[j].
    xmesh,ymesh = [o.reshape(-1) for o in torch.meshgrid(xcoords, ycoords)]
    losses = -torch.ones(xmesh.size(0))
    xb,yb = learn.dls.one_batch()

    try:
        m.eval()
        for i in progress_bar(range(losses.size(0))):
            x,y = xmesh[i],ymesh[i]
            changes = [(dx*x + dy*y) for dx,dy in zip(xdirection,ydirection)]
            for p,w,c in zip(params, weights, changes): p.data = w.add(c.to(w.device))
            with torch.no_grad():
                loss = learn.loss_func(m(xb), yb)
            # `.item()` brings the scalar to the host whatever device the model is on.
            losses[i] = loss.item()
    finally:
        # Without this the model would keep the weights of the last grid point
        # (or of whichever point failed) and stay in eval mode.
        for p,w in zip(params, weights): p.data = w
        m.train(was_training)

    losses = losses.view(samples, samples)
    landscape = (F.interpolate(losses[None,None], [final_size,final_size], mode='bilinear',
                               align_corners=False)[0,0]
                 if final_size > samples else losses)
    return landscape.numpy()

## Event handler

In [None]:
#export
@patch
async def get_loss_landscape(self:FastExplorer, websocket, payload=None):
    "Sends the loss landscape for the model."
    # The landscape is computed once and then served from `self.cache`.
    # NOTE(review): `_compute_landscape` runs synchronously inside this coroutine,
    # so it blocks the event loop for the whole computation.
    try:
        if 'loss_landscape' not in self.cache.keys():
            self.cache['loss_landscape'] = _compute_landscape(self.learn)

        landscape = self.cache['loss_landscape']
        # Sent to the client as `max_z`, presumably to cap the plotted height.
        # nanpercentile ignores NaN losses (e.g. from diverging weights), which
        # would otherwise turn the whole value into NaN.
        percentile = np.nanpercentile(landscape, 95)
        array_bytes = get_numpy_bytes(landscape, clientEvents.SEND_LANDSCAPE, xtra={'max_z': percentile})
        await websocket.send_bytes(array_bytes)
    except Exception:
        import logging
        # Keep the real cause on the server; the client only receives a generic message.
        logging.getLogger(__name__).exception('Error getting the loss landscape.')
        await websocket.send_json({'type': clientEvents.SEND_ERROR,
                                   'payload': {'msg': 'Error getting the loss landscape.'}})

In [None]:
#srv
# Launch the fastexplorer server for `learn`; the log in this cell's output
# shows the URL of the web client and the server's lifecycle.
learn.fastexplorer()

INFO:     To visualize the model information, go to:
INFO:     https://renato145.github.io/fastexplorer-js
INFO:     Started server process [14881]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
INFO:     Shutting down
INFO:     Waiting for application shutdown.
INFO:     Application shutdown complete.
INFO:     Finished server process [14881]


In [None]:
#hide
# plt.contour(xmesh, ymesh, losses, cmap='summer', levels=np.arange(0.1,10,0.5))

# from mpl_toolkits.mplot3d import Axes3D

# fig = plt.figure()
# ax = Axes3D(fig)
# surf = ax.plot_surface(xmesh, ymesh, losses.numpy(), cmap=plt.cm.coolwarm, linewidth=0, antialiased=False)
# fig.colorbar(surf, shrink=0.5, aspect=5);

# fig.update_layout(
#     scene = dict(
#         xaxis = dict(nticks=4, range=[-100,100],),
#                      yaxis = dict(nticks=4, range=[-50,100],),
#                      zaxis = dict(nticks=4, range=[-100,100],),),
#     width=700,
#     margin=dict(r=20, l=10, b=10, t=10))

# import plotly.graph_objects as go
# fig = go.Figure(data=[go.Surface(z=np.log(t))])
# fig.update_layout(autosize=False, width=500, height=500, margin=dict(l=65, r=50, b=65, t=90))

## Export -

In [None]:
#hide
from nbdev.export import notebook2script
# Export the `#export` cells of the project's notebooks into the library modules.
notebook2script()

Converted 00_representation.ipynb.
Converted 01_explorer.ipynb.
Converted 02_loss_landscape.ipynb.
Converted index.ipynb.
