In [None]:
import sys

from functools import partial

import torch
from omegaconf import OmegaConf

sys.path.append("..")

from data import (
    WavFeature,
    MeanScoreFeature,
    FilenameFeature,
    process_fn_mean,
)
from models.wav2vec_net import MosPredictor
from slates import EvalSlate
from swag.posteriors import SWAG
from tabula.dataloader import DataLoader, Dataset
from tabula.helpers import CheckpointHelper

In [None]:
import time

import jax.numpy as jnp

from jax import nn
from jax import grad, jit, vmap, value_and_grad
from jax import random

from jax.scipy.special import logsumexp
from jax.experimental import optimizers

from torch.utils import data

import numpy as np
from functools import partial
import IPython
import pandas as pd

In [None]:
# Which VoiceMOS track to analyse. The value is also interpolated into the
# data paths used further down (``phase1-{TRACK}``).
TRACK = "ood"

# Per-track (config file, checkpoint) pair. "main" is assumed to be the name of
# the non-OOD track (matching a phase1-main data directory) — TODO confirm.
TRACK_FILES = {
    "ood": ("../ood_config.yaml", "../checkpoints/finetune-ood-0.001-1234/bestmodel.pt"),
    "main": ("../ssl_config.yaml", "../checkpoints/wav2vec-swag-0.001/bestmodel.pt"),
}
if TRACK not in TRACK_FILES:
    # Fail fast: an unknown track would otherwise only surface much later, when
    # the phase1-{TRACK} data paths turn out not to exist.
    raise ValueError(f"Unknown TRACK {TRACK!r}; expected one of {sorted(TRACK_FILES)}")

conf_file, checkpoint_path = TRACK_FILES[TRACK]
with open(conf_file, "r") as f:
    conf = OmegaConf.load(f)
conf.checkpoint.path = checkpoint_path

In [None]:
# Keep cuDNN on but force deterministic kernel selection (no autotuning), so
# repeated runs extract the same features.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.enabled = True

In [None]:
# Seed CPU and CUDA RNGs before building the models so any random weight
# initialisation is reproducible. Presumably the weights are then replaced by
# the checkpoint load in the next cell — TODO confirm.
torch.manual_seed(conf.seed)
torch.cuda.manual_seed(conf.seed)

# Vanilla (non-SWAG) predictor. NOTE(review): unlike swag_model it is never
# moved to the GPU or switched to eval mode here.
model = MosPredictor(**conf.model)
# SWAG wrapper that builds a MosPredictor from the same model config.
# NOTE(review): no_cov_mat / max_num_models presumably have to match the values
# used when the checkpoint was trained — confirm.
swag_model = SWAG(
    MosPredictor,
    no_cov_mat=False,
    max_num_models=10,
    **conf.model,
)
swag_model.cuda()
swag_model.eval()

# Helper that (de)serialises both models; the dict keys presumably have to
# match the names stored in the checkpoint file — verify.
checkpoint_helper = CheckpointHelper(
    conf.exp_name,
    {
        "model": model,
        "swag_model": swag_model,
    },
    save_epoch=conf.checkpoint.epoch,
    save_iters=conf.checkpoint.iters,
)

In [None]:
# Restore the registered models from the selected checkpoint; the return value
# is not used.
_ = checkpoint_helper.load(conf.checkpoint.path)

# NOTE(review): presumably a scale of 0.0 sets swag_model.base to the SWA mean
# weights (no stochastic perturbation) — confirm against swag.posteriors.SWAG.
swag_model.sample(0.0)

In [None]:
# Per-utterance fields to load: the waveform (length constrained by
# length_modulo=320), its mean opinion score, and the file name.
data_features = dict(
    wav=WavFeature(length_modulo=320),
    mean_score=MeanScoreFeature(),
    fname=FilenameFeature(),
)

# Training split, processed by process_fn_mean with the OOD validation path
# passed through as ood_path.
train_set = Dataset(
    conf.data.train_path,
    data_features,
    proc_fn=partial(
        process_fn_mean,
        ood_path=conf.data.valid_path_ood,
        inf_filter=False,
    ),
)
# Fixed order (no shuffling): later cells index cached batches by position.
train_loader = DataLoader(
    train_set,
    batch_size=conf.eval.batch_size,
    shuffle=False,
    num_workers=8,
)

In [None]:
def relu(x):
    """Element-wise rectified linear unit, ``max(x, 0)``."""
    return jnp.maximum(x, 0)

def predict(params, activations):
    """Score one example with the ported linear output head.

    Args:
        params: ``(w, b)`` pair holding the output layer's weight and bias.
        activations: feature vector of a single example.

    Returns:
        ``sigmoid(w @ activations + b) * 4 + 1``, i.e. the head output squashed
        into the range 1..5.
    """
    weight, bias = params
    logits = jnp.dot(weight, activations) + bias
    return 1 + 4 * nn.sigmoid(logits)


# Same head applied over a leading batch axis of features; params are shared.
batched_predict = vmap(predict, in_axes=(None, 0))

def l1loss(params, images, targets):
    preds = batched_predict(params, images)
    return jnp.mean(jnp.abs(preds - targets))

def l2loss(params, images, targets):
    preds = batched_predict(params, images)
    return jnp.mean((preds - targets) ** 2)

In [None]:
def random_layer_params(m, n, key, scale=1e-2):
    """Draw weights of shape ``(n, m)`` and biases of shape ``(n,)`` for one dense layer.

    Both are standard-normal samples multiplied by ``scale``.
    """
    w_key, b_key = random.split(key)
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))


def init_network_params(sizes, key):
    """Initialise every layer of a fully-connected network with layer widths ``sizes``.

    Returns a list with one ``(w, b)`` pair per consecutive pair of widths.
    ``w`` has shape ``(out, in)``, so ``jnp.dot(w, x) + b`` maps an input vector
    to that layer's output.
    """
    keys = random.split(key, len(sizes))
    return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]


# Leftover training hyper-parameters. NOTE(review): nothing visible in this
# notebook reads them (the feature loader uses conf.eval.batch_size, not
# ``batch_size``), so they look safe to delete.
layer_sizes = [768, 1]
step_size = 0.001
num_epochs = 20
batch_size = 64

In [None]:
# Which trained head to port into JAX: True -> the SWAG model's base network,
# False -> the vanilla MosPredictor.
# NOTE(review): features are always extracted with swag_model.base.ssl_model in
# the next cell, so USE_SWAG = False pairs a vanilla head with SWAG features.
USE_SWAG = True

output_layer = swag_model.base.output_layer if USE_SWAG else model.output_layer
params = [
    jnp.array(output_layer.weight.data.cpu().numpy()),
    jnp.array(output_layer.bias.data.cpu().numpy()),
]
print(params[0].shape)

In [None]:
# Single pass over the training set: cache SSL features, labels and file names
# for every batch, and record the ported head's loss on each batch.
start_time = time.time()
train_data = []
train_loss = []
for batch_data in train_loader:
    with torch.no_grad():
        feats = swag_model.base.ssl_model(batch_data['wav']['data'].cuda(), mask=False, features_only=True)
        # Mean over axis 1 of feats['x'] — presumably the frame axis, giving one
        # vector per utterance (TODO confirm; if waveforms in a batch are padded,
        # padded frames are included in this mean).
        feats = torch.mean(feats['x'], 1).data.cpu().numpy()
    # Convert the labels once and reuse them for both the loss and the cache.
    mean_score = batch_data['mean_score'].data.cpu().numpy()
    train_data.append({
        'mean_score': mean_score,
        'feats': feats,
        'fname': batch_data['fname'],
    })
    # One loss value per batch, not per example.
    train_loss.append(l2loss(params, feats, mean_score))
epoch_time = time.time() - start_time

print("Epoch {} in {:0.2f} sec".format(1, epoch_time))

In [None]:
import jax

# Show which backend JAX is running on (e.g. "cpu" or "gpu"); uses the public
# API rather than the internal jax.lib.xla_bridge module.
print(jax.default_backend())

In [None]:
# Cached-batch indices ordered by ascending loss; the last entry is the worst batch.
ranked_idx = jnp.asarray(train_loss).argsort()

In [None]:
import jax

from jax import jvp
from jax.tree_util import tree_flatten, tree_leaves


def hvp(params, x, t, v):
    """Hessian-vector product of ``l2loss`` on the batch ``(x, t)``.

    Pushes the tangent ``v`` forward through the gradient of the batch loss
    (forward-over-reverse), so the Hessian is never formed explicitly.
    Returns a pytree shaped like ``params``.
    """
    def batch_loss(p):
        return l2loss(p, x, t)

    _, hessian_v = jvp(grad(batch_loss), (params,), (v,))
    return hessian_v


def single_loss(params, sentence, targets):
    """Mean absolute error of ``predict`` on a single (unbatched) example.

    NOTE(review): this is an L1 loss, while ``hvp`` builds its Hessian from the
    L2 ``l2loss``. Confirm that mixing the two is intended for the influence
    estimate.
    """
    residual = predict(params, sentence) - targets
    return jnp.abs(residual).mean()


@jit
def lissa_estimate(params, x, t, v, h_estimate, damp=0.01, scale=25):
    """One LiSSA recursion step towards an inverse-Hessian-vector product.

    Leaf by leaf over the parameter pytree this computes
    ``h_new = v + (1 - damp) * h_estimate - (H @ h_estimate) / scale``,
    where ``H`` is the Hessian of ``l2loss`` on the batch ``(x, t)``. If the
    recursion converges, its fixed point is ``scale * (H + scale * damp * I)^-1 v``.

    Args:
        params: parameter pytree of the linear head.
        x, t: one training batch of features and targets.
        v: vector to multiply by the inverse Hessian (same structure as params).
        h_estimate: current estimate (same structure as params).
        damp, scale: LiSSA damping and scaling constants.

    Returns:
        The updated estimate, with the same pytree structure as ``v``.

    Raises:
        ValueError: if ``v``, ``h_estimate`` and the HVP have different leaf counts.
    """
    hv = hvp(params, x, t, h_estimate)
    # Flatten/unflatten instead of jax.tree_multimap (removed in newer JAX);
    # this form works on old and new releases alike.
    v_leaves, treedef = jax.tree_util.tree_flatten(v)
    h_leaves = jax.tree_util.tree_leaves(h_estimate)
    hv_leaves = jax.tree_util.tree_leaves(hv)
    if not (len(v_leaves) == len(h_leaves) == len(hv_leaves)):
        raise ValueError("v, h_estimate and hv must have the same pytree structure")
    updated = [
        v_leaf + (1 - damp) * h_leaf - hv_leaf / scale
        for v_leaf, h_leaf, hv_leaf in zip(v_leaves, h_leaves, hv_leaves)
    ]
    return jax.tree_util.tree_unflatten(treedef, updated)


def get_s_test(z_test, t_test, params, z_loader, damp=0.01, scale=25.0,
               recursion_depth=5000):
    v = grad(single_loss)(params, z_test, t_test)
    h_estimate = v.copy()
    for depth in range(recursion_depth):
        x, t, _ = next(iter(z_loader))
        h_estimate = lissa_estimate(params, x, t, v, h_estimate,
                                    damp=damp, scale=scale)

        if depth % 500 == 0:
            print("Calc. s_test recursions: ", depth, recursion_depth)

    return h_estimate

# Select the training batch with the largest loss

In [None]:
# Take the highest-loss cached batch and explain its first example.
# NOTE(review): train_loss is per batch, so with batch sizes > 1 this is not
# necessarily the single worst example.
idx = ranked_idx[-1]
worst_batch = train_data[idx]
target_label = worst_batch['mean_score']
test_input = worst_batch['feats']
print("Testing id:", idx)

preds = predict(params, test_input[0])
print(worst_batch['fname'])
print(f"Real label: {target_label[0]}")
print(f"Original prediction: {preds[0]}")
IPython.display.Audio(f"/home/jiameng/data_voicemos/phase1-{TRACK}/DATA/wav/{train_data[idx]['fname'][0]}")

# Estimate s_test

In [None]:
def collate_fn(data):
    """Unwrap a batch of cached feature dicts into a ``(feats, mean_score, fname)`` triple.

    Only ``data[0]`` is used; any further elements are ignored.
    """
    first = data[0]
    return first['feats'], first['mean_score'], first['fname']

# One item per cached training batch (default batch_size=1; collate_fn unwraps
# it). Keep the default shuffle=False: the influence loop relies on z_loader
# order matching train_data indices.
z_loader = torch.utils.data.DataLoader(train_data, collate_fn=collate_fn)

# Pass only the label of the example being explained. target_label holds the
# whole batch's scores; passing it in full would make single_loss average the
# prediction's error over every label in the batch.
s_test = get_s_test(test_input[0], target_label[0], params, z_loader)

# Calculate influence functions

In [None]:
@jit
def get_influence(x, t, params, s_test):
    """Influence of one training example ``(x, t)`` on the explained test loss.

    Computes ``-(grad single_loss(params, x, t)) . s_test / len(train_data)``,
    with the dot product taken over all parameter leaves. ``len(train_data)`` is
    read when the function is traced, so it stays fixed for the compiled function.

    Returns:
        A scalar influence value.
    """
    grad_z_vec = grad(single_loss)(params, x, t)
    # Pytree dot product via tree_leaves + jnp.vdot (which flattens each leaf).
    # This avoids jax.tree_multimap (removed in newer JAX) and keeps the
    # reduction in jnp instead of calling np.sum on traced values.
    dot = sum(
        jnp.vdot(g_leaf, s_leaf)
        for g_leaf, s_leaf in zip(tree_leaves(grad_z_vec), tree_leaves(s_test))
    )
    return -dot / len(train_data)

# Per-example influence, in z_loader (= train_data) order. One vmapped call
# scores every example of a cached batch; build it once, outside the loop.
# NOTE(review): entries are per example, while train_data is per batch, so the
# indices below line up with train_data only when each cached batch holds one
# example.
influence_fn = vmap(partial(get_influence, params=params, s_test=s_test), in_axes=(0, 0))
influences = []
for i, (x, t, _fnames) in enumerate(z_loader):
    influences.extend(np.asarray(influence_fn(x, t)).tolist())
    if i % 50 == 0:
        print(i)

# influence = -grad_train . s_test / n, with s_test ~ H^-1 grad_test. Negative
# values mean upweighting the point would lower the test loss (helpful), so the
# ascending order lists the most helpful first and the reversed order lists the
# most harmful first.
helpful = np.argsort(influences)
not_helpful = helpful[::-1]

In [None]:
# The ten training points with the largest (most harmful) influence. float()
# prints a plain number instead of an array repr.
for i in not_helpful[:10]:
    print(train_data[i]['fname'][0], float(influences[i]))

# Look at the most unhelpful (most harmful) points

In [None]:
# Rank of the point to inspect in the harmful-first ordering (0 = most harmful).
N = 0
f_idx = not_helpful[N]
fname = train_data[f_idx]['fname'][0]

# All rows for this file in the track's training MOS list.
df = pd.read_csv(f'/home/jiameng/data_voicemos/phase1-{TRACK}/DATA/sets/train_mos_list.txt', names=['fname', 'score'])
filtered_df = df.loc[df['fname'] == fname]
print(filtered_df)

IPython.display.Audio(f"/home/jiameng/data_voicemos/phase1-{TRACK}/DATA/wav/{fname}")