In [None]:
import torch
import matplotlib.pyplot as plt
from ipywidgets import AppLayout, GridspecLayout, IntSlider, FloatSlider, Button, Checkbox, Layout, HBox, VBox
import math
import os
from glob import glob
import pandas as pd
from tqdm import tqdm
from dinopl.tracking import FeatureSaver

from torch.nn import functional as F
%matplotlib widget

In [None]:
def create_slider(name, range, start, update_fn, smooth=True, N=50):
    """Create a horizontal slider widget and register a change callback.

    Args:
        name: Description text shown next to the slider.
        range: ``(min, max)`` bounds of the slider. The name shadows the
            builtin, but it is kept because callers pass it by keyword.
        start: Initial slider value.
        update_fn: Callback handed to ``observe`` for changes of ``value``.
        smooth: If True a ``FloatSlider`` is built, otherwise an ``IntSlider``.
        N: Number of steps the range is divided into.

    Returns:
        The configured slider widget.
    """
    SliderClass = FloatSlider if smooth else IntSlider
    lower, upper = range[0], range[1]

    # Callers derive N from data sizes (e.g. ``n_steps - 1``), which is 0 for a
    # single entry; clamp it so the step computation cannot divide by zero.
    step = (upper - lower) / max(N, 1)
    if not smooth:
        # An integer slider needs a whole, strictly positive step.
        step = max(1, int(round(step)))

    slider = SliderClass(description=name,
                         layout=Layout(height='auto', width='auto'),
                         orientation='horizontal',
                         min=lower,
                         max=upper,
                         step=step,
                         value=start)
    slider.observe(update_fn, names='value')
    return slider

In [None]:
def create_checkbox(name, start, update_fn):
    """Create a checkbox widget and register a change callback.

    Args:
        name: Description text shown next to the checkbox.
        start: Initial checked state.
        update_fn: Callback handed to ``observe`` for changes of ``value``.

    Returns:
        The configured checkbox widget.
    """
    box = Checkbox(value=start, description=name)
    box.observe(update_fn, names='value')
    return box

### Common losses in self-supervised learning

In [None]:
def dotsim(x_star:torch.Tensor, x:torch.Tensor):
    """Dot product of ``x_star`` and ``x`` along the last dimension (with broadcasting)."""
    products = torch.mul(x_star, x)
    return torch.sum(products, dim=-1)

def cossim(x_star:torch.Tensor, x:torch.Tensor):
    """Cosine similarity along the last dimension.

    A small constant (1e-4) is added to the product of the norms so that
    zero vectors do not cause a division by zero.
    """
    norm_product = x_star.norm(dim=-1) * x.norm(dim=-1)
    return dotsim(x_star, x) / (norm_product + 1e-4)

def mse(x_star:torch.Tensor, x:torch.Tensor): # actually the sum of squared errors
    """Sum of squared differences between ``x_star`` and ``x`` along the last dimension.

    Despite the name no mean is taken: the squared errors are summed over the
    last dimension, with broadcasting over all leading dimensions.
    """
    return (x_star - x).square().sum(dim=-1)

def l2(x_star:torch.Tensor, x:torch.Tensor):
    """Euclidean distance along the last dimension: square root of ``mse``."""
    squared_distance = mse(x_star, x)
    return torch.sqrt(squared_distance)

def ce(x_star:torch.Tensor, x:torch.Tensor):
    """Cross entropy between two logit vectors along the last dimension.

    ``x_star`` is turned into the target distribution with a softmax and ``x``
    into log-probabilities with a log-softmax.
    """
    target = F.softmax(x_star, dim=-1)
    neg_log_pred = -F.log_softmax(x, dim=-1)
    return (target * neg_log_pred).sum(dim=-1)

def kl(x_star:torch.Tensor, x:torch.Tensor):
    """KL divergence KL(softmax(x_star) || softmax(x)) along the last dimension.

    Both arguments are logits; the log-probabilities come from ``log_softmax``.
    """
    log_target = F.log_softmax(x_star, dim=-1)
    log_pred = F.log_softmax(x, dim=-1)
    target = F.softmax(x_star, dim=-1)
    return (target * (log_target - log_pred)).sum(dim=-1)

def entropy(x:torch.Tensor):
    """Shannon entropy (in nats) of ``softmax(x)`` along the last dimension."""
    probabilities = F.softmax(x, dim=-1)
    neg_log_prob = -F.log_softmax(x, dim=-1)
    return (probabilities * neg_log_prob).sum(dim=-1)


In [None]:
def plot_loss(loss, x_star=torch.Tensor([1,1]), lim=(0,2), N=100, cN=50, clim=(None, None), ax=None):
    """Draw filled contours of ``loss(x_star, x)`` for ``x`` on a square 2-D grid.

    Args:
        loss: Callable ``loss(x_star, grid)`` reducing the last dimension;
            its ``__name__`` becomes the axes title.
        x_star: Target vector, also drawn as an arrow from the origin.
        lim: ``(low, high)`` range used for both grid axes.
        N: Number of grid points per axis.
        cN: Number of contour levels.
        clim: ``(vmin, vmax)`` forwarded to ``contourf``.
        ax: Axes to draw on; a new figure is created when omitted.

    Returns:
        The axes that were drawn on.
    """
    if ax is None:
        ax = plt.figure().gca()

    low, high = lim[0], lim[1]
    first_axis = torch.linspace(low, high, N)
    second_axis = torch.linspace(low, high, N)
    grid = torch.stack(torch.meshgrid(first_axis, second_axis, indexing='ij'), dim=-1)
    values = loss(x_star, grid)

    filled = ax.contourf(grid[..., 0], grid[..., 1], values,
                         levels=cN, vmin=clim[0], vmax=clim[1], cmap='jet')
    # Mark the target as an arrow starting at the origin.
    ax.arrow(0, 0, x_star[0], x_star[1],
             length_includes_head=True, width=0.01, head_width=0.2, color='k')
    ax.set_aspect('equal', 'box')
    ax.set_title(loss.__name__)
    plt.colorbar(filled, ax=ax)
    return ax

In [None]:
# Compare six losses on one shared grid: logits in [-2, 2]^2, target x_star = (1, -1).
lim = (-2, 2)
x_star = torch.Tensor([1,-1])
_, ax = plt.subplots(2, 3, sharex=True, sharey=True, figsize=(10, 5))
# Column 0: dot / cosine similarity.
plot_loss(dotsim, x_star, lim, ax=ax[0][0])
plot_loss(cossim, x_star, lim, ax=ax[1][0])
# Column 1: sum of squared errors and its square root (euclidean distance).
plot_loss(mse, x_star, lim, ax=ax[0][1])
plot_loss(l2,  x_star, lim, ax=ax[1][1])
# Column 2: softmax-based cross entropy and KL divergence.
plot_loss(ce,  x_star, lim, ax=ax[0][2])
plot_loss(kl,  x_star, lim, ax=ax[1][2])


### The cross entropy (and hence the KL divergence) grows linearly, i.e. in O(n), like the l2 distance

In [None]:
def plot_growth(lim=(-10, 10), N=100, ax=None):
    """Plot how ``l2`` and ``kl`` grow along the line ``t * (-1, 1)``.

    Args:
        lim: ``(low, high)`` range of the scale factor ``t``.
        N: Number of sampled values of ``t``.
        ax: Axes to draw on; a new figure is created when omitted.

    Returns:
        The axes that were drawn on.
    """
    if ax is None:
        ax = plt.figure().gca()

    t = torch.linspace(lim[0], lim[1], N).unsqueeze(1)  # shape (N, 1)
    # Broadcasting yields shape (N, 2) for any N; the previous hard-coded
    # ``.repeat(100, 1)`` only worked for N == 100.
    X = t * torch.Tensor([-1, 1])

    x_star = torch.Tensor([-1,1])
    #ax.plot(t, mse(x_star, X), label='$mse(x, x^*)$')
    ax.plot(t, l2(x_star, X), label='$l2(x, x^*)$')
    ax.plot(t, kl(x_star, X), label='$kl(x, x^*)$')
    ax.legend()
    return ax

plot_growth(lim=(-10, 10))

### Entropy in logits space

In [None]:
def plot_entropy_logits(lim=(-2,2), N=100, cN=50, alpha=1, clim=(None, None), ax=None):
    """Draw filled contours of the softmax entropy over a square grid of 2-D logits.

    Args:
        lim: ``(low, high)`` range used for both logit axes.
        N: Number of grid points per axis.
        cN: Number of contour levels.
        alpha: Opacity of the filled contours.
        clim: ``(vmin, vmax)`` forwarded to ``contourf``.
        ax: Axes to draw on; a new figure is created when omitted.

    Returns:
        The axes that were drawn on.
    """
    if ax is None:
        ax = plt.figure().gca()

    first_axis = torch.linspace(lim[0], lim[1], N)
    second_axis = torch.linspace(lim[0], lim[1], N)
    grid = torch.stack(torch.meshgrid(first_axis, second_axis, indexing='ij'), dim=-1)
    grid_entropy = entropy(grid)

    filled = ax.contourf(grid[..., 0], grid[..., 1], grid_entropy,
                         levels=cN, vmin=clim[0], vmax=clim[1], alpha=alpha, cmap='jet')
    plt.colorbar(filled, ax=ax)
    return ax

In [None]:
def plot_entropy_probas(N=100, ax=None):
    """Plot the entropy of a two-class distribution ``(p, 1 - p)`` against ``p``.

    Args:
        N: Number of sampled probabilities in ``[0, 1]``.
        ax: Axes to draw on; a new figure is created when omitted.

    Returns:
        The axes that were drawn on.
    """
    if ax is None:
        ax = plt.figure().gca()

    p = torch.linspace(0, 1, N)
    prob = torch.stack((p, 1-p), dim=-1)

    # xlogy(0, 0) is defined as 0, so the endpoints p=0 and p=1 give entropy 0
    # instead of the NaN produced by 0 * log(0).
    H = -torch.xlogy(prob, prob).sum(dim=-1)
    # Plot against p only; the (N, 2) tensor as x would draw two overlapping curves.
    ax.plot(p, H)
    ax.set_title('Entropy of Probabilities')
    return ax
In [None]:
# Side by side: entropy over a grid of 2-D logits (left) and entropy as a
# function of the two-class probabilities (right).
fig, ax = plt.subplots(1, 2, sharex=False, sharey=False, figsize=(12, 5))
plot_entropy_logits(lim=(-8,8), ax=ax[0])
ax[0].set_title('Entropy of Logits')
ax[0].set_aspect('equal', 'box')
plot_entropy_probas(ax=ax[1])


### The effect of sharpening on entropy

In [None]:
def plot_sharpening(t_min=0.2, lim=(-10, 10), N=100):
    """Interactive view of how dividing 2-D logits by a temperature changes entropy.

    Left panel: entropy over logit space with two arrows, the logit vector
    itself and the same vector divided by ``t_min``. Right panel: entropy of
    the vector divided by temperatures running from 1 down to ``t_min``.
    Two sliders set the polar coordinates ``r`` and ``phi`` of the vector.

    Args:
        t_min: Lowest temperature shown.
        lim: ``(low, high)`` range of the logit axes; ``lim[1]`` also bounds ``r``.
        N: Number of grid points per axis and number of temperatures.

    Returns:
        An ``AppLayout`` holding the figure canvas and the sliders.
    """
    plt.ioff()  # build the figure without showing it; the returned widget displays it
    fig, ax = plt.subplots(1, 2, sharex=False, sharey=False, figsize=(12, 5))
    logit_ax, curve_ax = ax[0], ax[1]
    plot_entropy_logits(lim=lim, N=N, cN=N//2, ax=logit_ax)
    logit_ax.set_aspect('equal', 'box')
    logit_ax.set_title('Entropy of Logits')

    # data
    state = dict(r=1, phi=-.25*math.pi)
    t = 1/torch.linspace(1, 1/t_min, N)

    def logits_from(polar):
        """Cartesian logit vector for the polar coordinates stored in ``polar``."""
        radius, angle = polar['r'], polar['phi']
        return torch.Tensor([radius*math.cos(angle), radius*math.sin(angle)])

    def sharpened(vec):
        """The vector divided by every temperature in ``t``; shape (N, 2)."""
        return vec.unsqueeze(0)/t.unsqueeze(1)

    x = logits_from(state)

    # plot
    arr1 = logit_ax.arrow(0, 0, x[0], x[1], length_includes_head=True, width=0.01, head_width=0.2, color='k')
    arr2 = logit_ax.arrow(0, 0, x[0]/t_min, x[1]/t_min)
    logit_ax.set_xlim(lim)
    logit_ax.set_ylim(lim)

    line, = curve_ax.plot(t, entropy(sharpened(x)))
    curve_ax.set_ylim((0, math.log(2)))
    curve_ax.invert_xaxis()
    curve_ax.set_title('Entropy vs Temperature')

    def redraw():
        """Move both arrows and the entropy curve to the current slider state."""
        vec = logits_from(state)
        arr1.set_data(x=0, y=0, dx=vec[0], dy=vec[1])
        arr2.set_data(x=0, y=0, dx=vec[0]/t_min, dy=vec[1]/t_min)
        line.set_ydata(entropy(sharpened(vec)))

        fig.canvas.draw()
        fig.canvas.flush_events()

    def store_and_redraw(key):
        """Build a slider callback that writes the new value to ``state[key]``."""
        def handler(change):
            state[key] = change.new
            redraw()
        return handler

    r_slider = create_slider('r', range=(0, lim[1]), start=state['r'], update_fn=store_and_redraw('r'))
    phi_slider = create_slider('phi', range=(-1.25*math.pi, .75*math.pi), start=state['phi'], update_fn=store_and_redraw('phi'))
    sliders = VBox(
        [r_slider, phi_slider],
        layout=Layout(width='40%', height='auto', margin = '0px 30% 0px 30%')
    )

    widget = AppLayout(
        center=fig.canvas,
        footer=sliders,
        pane_heights=[0, 6, 1]
    )
    plt.ion()
    return widget

plot_sharpening(t_min=0.1, lim=(-10, 10))

In [None]:
#ax = plt.figure().gca()
#
#n_classes = 2048
#lim = (0, 4)
#
#x = torch.randn(n_classes).unsqueeze(0)
#t = 10**-torch.linspace(lim[0], lim[1], n_classes).unsqueeze(1)
#
#H = -torch.sum(F.softmax(x/t, dim=-1)*F.log_softmax(x/t, dim=-1), dim=-1)
#
#ax.plot(t, H)
#ax.invert_xaxis()
#ax.set_xscale('log')

### Effect of centering and sharpening on cross-entropy / KL divergence

In [None]:
def plot_kl_logits(lim=(-2, 2), N = 100, cN = 50):
    """Interactive contour plot of ``kl(x, grid)`` over a grid of 2-D logits.

    The target ``x`` is given in polar form by the sliders ``r`` and ``phi``
    (``phi`` is offset by 0.25 and multiplied by pi). The slider ``t_stud``
    divides the grid coordinates the contours are drawn at.

    Args:
        lim: ``(low, high)`` range of the grid and of the axes; ``lim[1]`` also
            bounds the ``r`` slider.
        N: Number of grid points per axis.
        cN: Number of contour levels.

    Returns:
        An ``AppLayout`` holding the figure canvas and the sliders.
    """
    from matplotlib.artist import Artist  # local import: only needed for the version check below

    # The widget handles the interactive display, so build the figure silently.
    plt.ioff()
    fig, ax = plt.subplots(1,1, sharex=True, sharey=True, figsize=(8, 4))

    # Data
    state=dict(r=1, phi=-.5, t_stud=1)
    r, phi, t_stud = state['r'], (state['phi'] + 0.25)*math.pi, state['t_stud']
    x = torch.Tensor([r*math.cos(phi), r*math.sin(phi)])
    M = torch.stack(torch.meshgrid([
            torch.linspace(lim[0], lim[1], N),
            torch.linspace(lim[0], lim[1], N)],
            indexing='ij'), dim=-1)
    KL = kl(x, M)

    # Default axis; the colour range stays fixed to the initial KL range.
    vmin, vmax = KL.min(), KL.max()
    cont = ax.contourf(M[:,:,0]/t_stud, M[:,:,1]/t_stud, KL, levels=cN, cmap='jet', vmin=vmin, vmax=vmax)
    arr1 = ax.arrow(0, 0, x[0], x[1], length_includes_head=True, width=0.01, head_width=0.2, color='k', zorder=10)


    ax.set_xlim(lim)
    ax.set_ylim(lim)
    ax.set_aspect('equal', 'box')

    plt.colorbar(cont, ax=ax)
    plt.suptitle('KL-Divergence of Logits')


    def remove_contours(contour_set):
        """Detach a contour set from the axes on both new and old matplotlib."""
        if isinstance(contour_set, Artist):
            # Newer matplotlib: the ContourSet is itself a single artist.
            contour_set.remove()
        else:
            # Older matplotlib: one collection per contour level.
            for coll in contour_set.collections:
                coll.remove()

    def update(state):
        # ``cont`` is rebound below; without ``nonlocal`` it would be a fresh
        # local and the previous contours could never be removed.
        nonlocal cont
        r, phi, t_stud = state['r'], (state['phi'] + 0.25)*math.pi, state['t_stud']
        x = torch.Tensor([r*math.cos(phi), r*math.sin(phi)])
        KL_new = kl(x, M)

        # Drop the old contours first so they do not pile up on every slider move.
        remove_contours(cont)
        cont = ax.contourf(M[:,:,0]/t_stud, M[:,:,1]/t_stud, KL_new, levels=cN, cmap='jet', vmin=vmin, vmax=vmax)
        arr1.set_data(x=0, y=0, dx=x[0], dy=x[1])

        fig.canvas.draw()
        fig.canvas.flush_events()

    def update_r(change):
        state['r'] = change.new
        update(state)

    def update_phi(change):
        state['phi'] = change.new
        update(state)

    def update_t_stud(change):
        state['t_stud'] = change.new
        update(state)

    sliders = VBox([
        create_slider('r', range=(0, lim[1]), start=state['r'], update_fn=update_r),
        create_slider('phi', range=(-1, 1), start=state['phi'], update_fn=update_phi),
        create_slider('t_stud', range=(0.1, 1), start=state['t_stud'], update_fn=update_t_stud)],
        layout=Layout(width='40%', height='auto', margin = '0px 20% 0px 20%')
    )

    widget = AppLayout(
        center=fig.canvas,
        footer=sliders,
        pane_heights=[0, 8, 2]
        )

    plt.ion()
    return widget

plot_kl_logits(lim=(-2,2))

### Entropy in 3D

In [None]:
import cv2
import numpy as np
# Rotation matrix for the rotation vector pi * (0.25, 0.25, 0.25); the second
# return value of cv2.Rodrigues is discarded.
# NOTE(review): the plotting functions in the visible cells build their own
# rotation, so this global ``R`` looks unused — confirm before relying on it.
R,_ = cv2.Rodrigues(np.pi * np.array([0.25, 0.25, 0.25]))

In [None]:
def plot_entropy_logits_3d(lim=(-2,2), N=100, cN=50, clim=(None, None), ax=None):
    if ax is None:
        ax = plt.figure().gca()
    
    x = torch.linspace(lim[0], lim[1], N)
    y = torch.linspace(lim[0], lim[1], N)
    X, Y = torch.meshgrid(x, y, indexing='ij')
    Z = torch.zeros_like(X)
    M = torch.stack((X, Y, Z), dim=-1)

    v_from, v_to = torch.Tensor([0,0,1]), F.normalize(torch.Tensor([1, 1, 1]), dim=0)
    R = torch.from_numpy(cv2.Rodrigues(torch.cross(v_from, v_to).numpy())[0])
    H = entropy(M @ R.T)

    p = ax.contourf(M[:,:,0], M[:,:,1], H, levels=cN, vmin=clim[0], vmax=clim[1], cmap='jet')
    ax.set_aspect('equal', 'box')
    ax.set_title('Entropy of Logits 3D')
    plt.colorbar(p, ax=ax)
    return ax
plot_entropy_logits_3d()

In [None]:
def plot_sharpening_3d(t_min=0.2, lim=(-10, 10), N=100):
    plt.ioff()
    fig, ax = plt.subplots(1, 2, sharex=False, sharey=False, figsize=(12, 5))
    plot_entropy_logits_3d(lim=lim, N=N, cN=N//2, ax=ax[0])

    v_from, v_to = torch.Tensor([0,0,1]), F.normalize(torch.Tensor([1, 1, 1]), dim=0)
    R = torch.from_numpy(cv2.Rodrigues(torch.cross(v_from, v_to).numpy())[0])

    # data
    state=dict(r=1, phi=-.25*math.pi)
    r, phi = state['r'], state['phi']
    x = R @ torch.Tensor([r*math.cos(phi), r*math.sin(phi), 0])
    t = 1/torch.linspace(1, 1/t_min, N)
    xs = x.unsqueeze(0)/t.unsqueeze(1)

    
    
    # plot
    arr1 = ax[0].arrow(0, 0, x[0], x[1], length_includes_head=True, width=0.01, head_width=0.2, color='k')
    arr2 = ax[0].arrow(0, 0, x[0]/t_min, x[1]/t_min)
    ax[0].set_xlim(lim)
    ax[0].set_ylim(lim)

    line, = ax[1].plot(t, entropy(xs))
    ax[1].set_ylim((0, math.log(3)))
    ax[1].invert_xaxis()
    ax[1].set_title('Entropy vs Temperature')

    def update(state):        
        r, phi = state['r'], state['phi']
        x = R @ torch.Tensor([r*math.cos(phi), r*math.sin(phi), 0])
        xs = x.unsqueeze(0)/t.unsqueeze(1)

        arr1.set_data(x=0, y=0, dx=x[0], dy=x[1])
        arr2.set_data(x=0, y=0, dx=x[0]/t_min, dy=x[1]/t_min)
        line.set_ydata(entropy(xs))

        fig.canvas.draw()
        fig.canvas.flush_events()
    
    def update_r(change):
        state['r'] = change.new
        update(state)

    def update_phi(change):
        state['phi'] = change.new
        update(state)

    sliders = VBox([
        create_slider('r', range=(0, lim[1]), start=state['r'], update_fn=update_r),
        create_slider('phi', range=(-1.25*math.pi, .75*math.pi), start=state['phi'], update_fn=update_phi)],
        layout=Layout(width='40%', height='auto', margin = '0px 30% 0px 30%')
    )

    widget = AppLayout(
        center=fig.canvas,
        footer=sliders,
        pane_heights=[0, 6, 1]
    )
    plt.ion()
    return widget

plot_sharpening_3d(t_min=0.1, lim=(-10, 10))

In [None]:
def plot_kl_logits_3d(lim=(-2, 2), N = 100, cN = 50):
    # Widget should handly interactive 
    plt.ioff()
    fig, ax = plt.subplots(1,1, sharex=True, sharey=True, figsize=(8, 4))

    v_from, v_to = torch.Tensor([0,0,1]), torch.Tensor([1, 1, 1]) / math.sqrt(3)
    R = torch.from_numpy(cv2.Rodrigues(torch.cross(v_from, v_to).numpy())[0])

    # Data
    state=dict(r=1, phi=-.5, t_stud=1)
    r, phi, t_stud = state['r'], (state['phi'] + 0.25)*math.pi, state['t_stud']
    x = torch.Tensor([r*math.cos(phi), r*math.sin(phi), 0])

    X, Y =torch.meshgrid([
            torch.linspace(lim[0], lim[1], N),
            torch.linspace(lim[0], lim[1], N)],
            indexing='ij')
    Z = torch.zeros_like(X)
    M = torch.stack((X, Y, Z), dim=-1)

    KL = kl(R @ x, M @ R.T)

    # Default axis
    vmin, vmax = KL.min(), KL.max()
    cont = ax.contourf(M[:,:,0]/t_stud, M[:,:,1]/t_stud, KL, levels=cN, cmap='jet', vmin=vmin, vmax=vmax)
    arr1 = ax.arrow(0, 0, x[0], x[1], length_includes_head=True, width=0.01, head_width=0.2, color='k', zorder=10)


    ax.set_xlim(lim)
    ax.set_ylim(lim)
    ax.set_aspect('equal', 'box')

    plt.colorbar(cont, ax=ax)
    plt.suptitle('KL-Divergence of Logits')


    def update(state):
        r, phi, t_stud = state['r'], (state['phi'] + 0.25)*math.pi, state['t_stud']
        x = torch.Tensor([r*math.cos(phi), r*math.sin(phi), 0])
        KL_new = kl(R @ x, M @ R.T)

        #for c in cont.collections:
        #    ax.collections.remove(c)
        #    cont.collections.remove(c)  # removes only the contours, leaves the rest intact
        cont = ax.contourf(M[:,:,0]/t_stud, M[:,:,1]/t_stud, KL_new, levels=cN, cmap='jet', vmin=vmin, vmax=vmax)
        arr1.set_data(x=0, y=0, dx=x[0], dy=x[1])

        fig.canvas.draw()
        fig.canvas.flush_events()

    def update_r(change):
        state['r'] = change.new
        update(state)

    def update_phi(change):
        state['phi'] = change.new
        update(state)

    def update_t_stud(change):
        state['t_stud'] = change.new
        update(state)

    sliders = VBox([
        create_slider('r', range=(0, lim[1]), start=state['r'], update_fn=update_r),
        create_slider('phi', range=(-1, 1), start=state['phi'], update_fn=update_phi),
        create_slider('t_stud', range=(0.1, 1), start=state['t_stud'], update_fn=update_t_stud)],
        layout=Layout(width='40%', height='auto', margin = '0px 20% 0px 20%')
    )

    widget = AppLayout(
        center=fig.canvas,
        footer=sliders,
        pane_heights=[0, 8, 2]
        )

    plt.ion()
    return widget

plot_kl_logits_3d(lim=(-10,10))

# Training Dynamics

In [None]:
# Load the saved validation features of one run from $DINO_RESULTS: the 'proje'
# and the 'logit' feature directories. ``names`` and ``indices`` are overwritten
# by the second call, so they describe the 'logit' load.
data_proje, names, indices = FeatureSaver.load_data(os.path.join(os.environ['DINO_RESULTS'],'wandb/offline-run-20220803_183004-2wkk2y2l/files/valid/feat/proje'))
data_logit, names, indices = FeatureSaver.load_data(os.path.join(os.environ['DINO_RESULTS'],'wandb/offline-run-20220803_183004-2wkk2y2l/files/valid/feat/logit'))
print(names)
print(data_proje.shape)

#### Dynamics in 2D

In [None]:
def plot_dynamics(data_proje, data_logit):
    """Interactive scatter of teacher and student 2-D features over training steps.

    Both inputs are indexed as ``data[step, sample, column]``: column 0 is
    used as label, columns 1:3 as teacher coordinates and columns 3:5 as
    student coordinates. Each panel shows the entropy contours as backdrop, the
    teacher points ('+'), the student points ('o') and arrows from student to
    teacher (visible on the logit panel only).

    Args:
        data_proje: Feature tensor shown in the left panel.
        data_logit: Feature tensor shown in the right panel.

    Returns:
        An ``AppLayout`` holding the figure canvas, a step slider and an
        arrow-visibility checkbox.
    """
    plt.ioff()
    fig, axes = plt.subplots(1,2, sharex=False, sharey=False, figsize=(14, 7))
    lbls = data_proje[0,:,0].int().tolist()

    state = dict(step=0, arrows=True)
    step = state['step']
    # Only offer steps that exist in both inputs.
    n_steps = min(data_proje.shape[0], data_logit.shape[0])

    proje = dict(name='proje', data=data_proje, ax=axes[0])
    logit = dict(name='logit', data=data_logit, ax=axes[1])
    for feat in [proje, logit]:

        data = feat['data']
        feat_ax = feat['ax']

        # Plain floats so the limits work for linspace and the axes alike.
        lim = data[:,:,1:].min().item(), data[:,:,1:].max().item()
        plot_entropy_logits(lim, alpha=0.5, ax=feat_ax)

        feat['scat_t'] = feat_ax.scatter(data[step,:,1], data[step,:,2], s=10, c=lbls, label=lbls, marker='+') # teacher
        feat['scat_s'] = feat_ax.scatter(data[step,:,3], data[step,:,4], s=10, c=lbls, label=lbls, marker='o') # student

        quiv_kwargs = dict(units='xy', scale=1, width=0.1, zorder=10, visible=(state['arrows'] and feat['name'] == 'logit'))
        feat['quiv'] = feat_ax.quiver(data[step,:,3], # X
                        data[step,:,4], # Y
                        data[step,:,1] - data[step,:,3], # U
                        data[step,:,2] - data[step,:,4], # V
                        lbls, # C
                        **quiv_kwargs)

        feat_ax.set_xlim(lim)
        feat_ax.set_ylim(lim)
        feat_ax.set_aspect('equal', 'box')

    def redraw():
        fig.canvas.draw()
        fig.canvas.flush_events()

    def update_step(change):
        state['step'] = change.new
        step = state['step']

        for feat in [proje, logit]:
            data = feat['data']
            feat['scat_t'].set_offsets(data[step,:,1:3])
            feat['scat_s'].set_offsets(data[step,:,3:5])
            feat['quiv'].set_offsets(data[step,:,3:5])
            feat['quiv'].set_UVC(data[step,:,1] - data[step,:,3], data[step,:,2] - data[step,:,4])

        # One redraw after both panels are updated.
        redraw()

    def update_arr_visibility(change):
        state['arrows'] = change.new
        logit['quiv'].set(visible=state['arrows'])
        redraw()

    sliders = VBox([
        # The last valid index is n_steps - 1; a max of n_steps would index out of bounds.
        create_slider('step', range=(0, n_steps-1), start=state['step'], N=max(n_steps-1, 1), smooth=False, update_fn=update_step),
        create_checkbox('arrows', start=state['arrows'], update_fn=update_arr_visibility),
        ],
        layout=Layout(width='40%', height='auto', margin = '0px 20% 0px 20%')
    )

    widget = AppLayout(
        center=fig.canvas,
        footer=sliders,
        pane_heights=[0, 8, 2]
        )

    plt.ion()
    return widget
plot_dynamics(data_proje=data_proje, data_logit=data_logit)

#### Dynamics in 3D

In [None]:
# TODO

#### Dynamics with n_classes

In [None]:
# Load the saved validation logits of a second run, passing start=0, stop=-1
# and step=10 to the loader.
# NOTE(review): ``dir`` shadows the builtin of the same name for the rest of the notebook.
dir = os.path.join(os.environ['DINO_RESULTS'],'wandb/run-20220804_212135-anpd7oxy/files/valid/feat/logit')
data_logit, names, indices = FeatureSaver.load_data(dir, start=0, stop=-1, step=10)
print(names[:5], '...', names[-5:])
print(data_logit.shape)

In [None]:
# Split the last axis of the loaded data: column 0 is used as label, columns
# 1..1024 as teacher logits and columns 1025..2048 as student logits.
# NOTE(review): the slice bounds assume 1024 outputs per network — confirm against the run config.
lbls = data_logit[:,:,0]
t_logit = data_logit[:,:,1:1025]
s_logit = data_logit[:,:,1025:2049]
# Softmax over the last axis turns the logits into probabilities.
t_proba = F.softmax(t_logit, dim=-1)
s_proba = F.softmax(s_logit, dim=-1)
lbls.shape, t_proba.shape, s_proba.shape

In [None]:
def plot_dynamics_probas(lbls, t_proba, s_proba):
    """Interactive bar charts of teacher and student probabilities for one sample.

    Args:
        lbls: Labels indexed as ``lbls[0, sample]`` for the figure title.
        t_proba: Teacher probabilities of shape ``(n_steps, n_samples, n_classes)``.
        s_proba: Student probabilities, indexed like ``t_proba``.

    Returns:
        An ``AppLayout`` holding the figure canvas and sliders for the step
        and the sample index.
    """
    plt.ioff()
    fig, ax = plt.subplots(1,2, sharex=True, sharey=True, figsize=(10, 5))

    n_steps, n_samples, n_classes = t_proba.shape
    state = dict(step=0, sample=0)
    step = state['step']
    sample = state['sample']
    classes = torch.arange(0, n_classes)

    def headline(sample):
        """Figure title for the given sample index."""
        return f'Label Dynamics of Image {sample} with groundruth {lbls[0, sample]}'

    t_bar = ax[0].bar(classes, t_proba[step, sample, :], width=2)
    ax[0].set_title('teacher-produced labels')
    ax[0].set_ylim(0,1)

    s_bar = ax[1].bar(classes, s_proba[step, sample, :], width=2)
    ax[1].set_title('student predictions')
    ax[1].set_ylim(0,1)

    title = plt.suptitle(headline(sample))


    def update(step, sample):
        for i, (t_rect, s_rect) in enumerate(zip(t_bar, s_bar)):
            t_rect.set_height(t_proba[step, sample, i])
            s_rect.set_height(s_proba[step, sample, i])

        title.set_text(headline(sample))
        fig.canvas.draw()
        fig.canvas.flush_events()


    def update_step(change):
        state['step'] = change.new
        update(state['step'], state['sample'])

    def update_sample(change):
        state['sample'] = change.new
        update(state['step'], state['sample'])

    sliders = VBox([
        # N is clamped to 1: with a single step or sample, ``n - 1`` is 0 and
        # would lead to a division by zero when the slider step is computed.
        create_slider('step', range=(0, n_steps-1), start=state['step'], N=max(n_steps-1, 1), smooth=False, update_fn=update_step),
        create_slider('sample', range=(0, n_samples-1), start=state['sample'], N=max(n_samples-1, 1), smooth=False, update_fn=update_sample),
        ],
        layout=Layout(width='40%', height='auto', margin = '0px 20% 0px 20%')
    )

    widget = AppLayout(
        center=fig.canvas,
        footer=sliders,
        pane_heights=[0, 8, 2]
        )

    plt.ion()
    return widget

plot_dynamics_probas(lbls, t_proba, s_proba)

In [None]:
def plot_pairwise_dist(data, dist_fn, ax=None):
    """Show the matrix of pairwise ``dist_fn`` values between the rows of ``data``.

    Entry ``[i, j]`` of the displayed matrix is ``dist_fn(row_j, row_i)``.

    Args:
        data: Tensor that is 2-D after ``squeeze()``; one row per item.
        dist_fn: Callable ``dist_fn(a, b)`` reducing the last dimension.
        ax: Axes to draw on; a new figure is created when omitted.

    Returns:
        The axes that were drawn on.
    """
    if ax is None:
        ax = plt.figure().gca()

    rows = data.squeeze()
    count = rows.shape[0]

    first = rows[None].expand(count, -1, -1)       # first[i, j] is rows[j]
    second = rows[:, None].expand(-1, count, -1)   # second[i, j] is rows[i]

    image = ax.imshow(dist_fn(first, second))
    plt.colorbar(image, ax=ax)
    return ax

In [None]:
# Pairwise distances between the validation samples at one saved step
# (-1 selects the last one). Rows are ordered by ``lbls[0].argsort()`` so that
# samples with equal label end up next to each other.
step = -1
_, ax = plt.subplots(2,2, figsize=(10, 8))

ax[0][0].set_title('cossim of teacher')
plot_pairwise_dist(t_logit[step, lbls[0].argsort(), :], cossim, ax=ax[0][0])

ax[0][1].set_title('cossim of student')
plot_pairwise_dist(s_logit[step, lbls[0].argsort(), :], cossim, ax=ax[0][1])

ax[1][0].set_title('kldiv of teacher')
plot_pairwise_dist(t_logit[step, lbls[0].argsort(), :], kl, ax=ax[1][0])

ax[1][1].set_title('kldiv of student')
plot_pairwise_dist(s_logit[step, lbls[0].argsort(), :], kl, ax=ax[1][1])

plt.suptitle(f'Within-dataset pairwise distances at step {step}')

In [None]:
# Pairwise distances between saved steps: every selected step is flattened into
# one vector of length n_samples * n_classes before the distances are computed.
# NOTE(review): stop=-1 excludes the final saved step from the slice — confirm this is intended.
start, stop, step = 0, -1, 100

n_steps, n_samples, n_classes = t_logit.shape
dim = n_samples * n_classes

_, ax = plt.subplots(2,2, figsize=(10, 8))

ax[0][0].set_title('cossim of teacher')
plot_pairwise_dist(t_logit[start:stop:step].reshape(-1, dim), cossim, ax=ax[0][0])

ax[0][1].set_title('cossim of student')
plot_pairwise_dist(s_logit[start:stop:step].reshape(-1, dim), cossim, ax=ax[0][1])

ax[1][0].set_title('kldiv of teacher')
plot_pairwise_dist(t_logit[start:stop:step].reshape(-1, dim), kl, ax=ax[1][0])

ax[1][1].set_title('kldiv of student')
plot_pairwise_dist(s_logit[start:stop:step].reshape(-1, dim), kl, ax=ax[1][1])

plt.suptitle(f'Training dynamic of function space over validation images')

#### Label Denoising Perspective

In [None]:
# Swap the last two axes so that the softmax inside ``kl`` runs over the
# samples axis instead of the classes axis. The targets are the teacher logits
# of the last saved step.
targs = t_logit[-1].transpose(-1, -2)
t_preds = t_logit.transpose(-1, -2)
s_preds = s_logit.transpose(-1, -2)

#plt.plot(kl(s_logit.reshape(-1, dim), s_logit[-1].reshape(-1)))        # track kl wrt. to final label averaged over images and classes
#plt.plot(kl(s_logit, s_logit[-1]))                                     # for every validation image => track kl wrt to final label

# One line per pseudo-class; the x-axis is the index of the saved step.
_, ax = plt.subplots(2, 1, sharex=True, sharey=False, figsize=(12,6))
ax[0].set_title('teacher predictions vs final pseudolabels')
ax[0].plot(kl(targs, t_preds), linewidth=1)  # for every pseudo-class => track kl wrt to final label

ax[1].set_title('student predictions vs final pseudolabels')
ax[1].plot(kl(targs, s_preds), linewidth=1)  # for every pseudo-class => track kl wrt to final label
plt.suptitle('Class-wise Denoising: KL-Divergence wrt. final pseudolabels')


In [None]:
# For every pseudo-class, find the saved step at which its KL divergence to the
# final teacher output peaks: ``max(dim=0)`` reduces over the step axis and
# returns both the peak values and the step indices.
_, ax = plt.subplots(2, 1, sharex=True, sharey=False, figsize=(12,6))
ax[0].set_title('teacher predictions vs final pseudolabels')
t_peaks = kl(targs, t_preds).max(dim=0)
ax[0].scatter(t_peaks.indices, t_peaks.values, s=1)  # for every pseudo-class => track kl wrt to final label
ax[0].set_xlim(0, len(t_preds))  # len() is the size of the step axis

ax[1].set_title('student predictions vs final pseudolabels')
s_peaks = kl(targs, s_preds).max(dim=0)
ax[1].scatter(s_peaks.indices, s_peaks.values, s=1)  # for every pseudo-class => track kl wrt to final label
ax[1].set_xlim(0, len(s_preds))  # len() is the size of the step axis

plt.suptitle('Class-wise Denoising: KL-Divergence peaks during training')