In [None]:
%load_ext autoreload
%autoreload 2
import os
os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
import jax, optax
from jax import numpy as jnp
from flax import nnx
from models import LeNet, ResNet
from fedflax import train, aggregate
from data import get_gaze
from functools import partial, reduce
from utils import err_fn, opt_create, return_ce, return_l2, angle_err

# Settings shared by all three gaze splits (feature skew, discrete=False labels).
GAZE_KW = dict(beta=1., skew="feature", discrete=False)
ds_train = get_gaze(**GAZE_KW)
ds_val = get_gaze(partition="val", batch_size=16, **GAZE_KW)
ds_test = get_gaze(partition="test", batch_size=16, **GAZE_KW)

model = LeNet(jax.random.key(42), wasym="random", kappa=1, sigma=0., dim_out=2, dimexp=False)
updates, models, opts = train(
    model,
    partial(opt_create, learning_rate=1e-3),
    ds_train,
    ds_val,
    return_l2(0.),
    local_epochs=20,
    rounds=1,
    val_fn=angle_err,
)


def mean_over_batches(eval_fn, eval_model, ds):
    """Add up eval_fn(eval_model, *batch) over every batch of ds, then divide by len(ds)."""
    total = 0.
    for batch in ds:
        total = total + eval_fn(eval_model, *batch)
    return total / len(ds)


# Aggregated model: the single model is not mapped (in_axes=None), batches are.
model_g = aggregate(model, updates)
vval_fn = nnx.jit(nnx.vmap(angle_err, in_axes=(None, 0, 0, 0)))
err_test = mean_over_batches(vval_fn, model_g, ds_test)
err_val = mean_over_batches(vval_fn, model_g, ds_val)

# Separate local models: models and batches are mapped together along axis 0.
vval_fn = nnx.jit(nnx.vmap(angle_err, in_axes=(0, 0, 0, 0)))
err_sep_test = mean_over_batches(vval_fn, models, ds_test)
err_sep_val = mean_over_batches(vval_fn, models, ds_val)

print(err_test.mean(), err_val.mean(), err_sep_test.mean(), err_sep_val.mean())
# Recorded results from earlier runs (test / val error per configuration):
# regular: test ~11, val ~5
# wasym: test ~12, val ~5
# wasym kappa 10: test ~15, val ~28
# wasym kappa .1: test ~11, val ~4
# wasym kappa .01: test ~11, val ~4
# wasym rand: test ~89, val ~90
# syre 1e-3, wd 1e-3: test ~11, val ~4
# syre 1e-2, wd 1e-3: test ~63, val ~5
# syre 1e-2, wd 1e-2: test ~66, val ~4
# syre .1, wd 1e-3: test ~91, val ~6

def angle(updates):
    """Mean angle, in degrees, between each client's update and the averaged update.

    Args:
        updates: a list of arrays (e.g. ``jax.tree.leaves(updates)``) or any pytree
            whose leaves all share the same leading axis, which indexes clients.

    Returns:
        Python float: the average over clients of
        ``degrees(arccos(cos_sim(client_update, mean_update)))``, where each update
        is the concatenation of all of its flattened leaves.

    Raises:
        ValueError: if ``updates`` has no leaves.
    """
    leaves = jax.tree.leaves(updates)
    if not leaves:
        raise ValueError("angle() needs at least one update leaf")
    # Client count comes from the data rather than being hard-coded.
    n_clients = leaves[0].shape[0]
    # One row per client: every leaf flattened and concatenated.
    updates_flat = jnp.concatenate([jnp.reshape(x, (n_clients, -1)) for x in leaves], axis=1)
    # Mean of the flattened rows == flattening of the per-leaf means.
    update_g = jnp.mean(updates_flat, axis=0)
    cos = (updates_flat @ update_g) / (
        jnp.linalg.norm(updates_flat, axis=1) * jnp.linalg.norm(update_g)
    )
    # Round-off can push |cos| marginally above 1, which would make arccos NaN.
    angles = jnp.degrees(jnp.arccos(jnp.clip(cos, -1.0, 1.0)))
    # Average over all clients (previously only the last client's angle survived the loop).
    return jnp.mean(angles).item()

# Print the value `angle` returns for the flattened update leaves produced by `train`.
print(angle(jax.tree.leaves(updates)))
# Recorded results from earlier runs, per configuration.
# NOTE(review): these were noted by hand with whatever `angle` definition existed at
# the time; re-record them if `angle`'s definition changes.
# regular: ~57
# wasym: ~55
# wasym kappa 10: ~55
# wasym kappa .1: ~56
# wasym kappa .01: ~55
# wasym rand: ~45
# syre 1e-3, wd 1e-3: ~57
# syre 1e-2, wd 1e-3: ~57
# syre 1e-2, wd 1e-2: ~58
# syre .1, wd 1e-3: ~57

In [None]:
%load_ext autoreload
%autoreload 2
from models import ResNet, LeNet
from fedflax import train
from data import get_gaze
from flax import nnx
import optax, jax

def ell(model, _, x_batch, z_batch, y_batch, train):
    """Mean softmax cross-entropy of ``model`` on one batch.

    The second positional argument is ignored. Returns ``(loss, (0., 0.))``;
    the aux tuple only holds zero placeholders.
    """
    logits = model(x_batch, z_batch, train=train)
    per_example = optax.softmax_cross_entropy(logits, y_batch)
    return per_example.mean(), (0., 0.)

def opt_create(model):
    """AdamW (learning_rate=1e-3) nnx.Optimizer over the model's nnx.Param variables.

    NOTE(review): this shadows the `opt_create` imported from `utils` in the first cell.
    """
    return nnx.Optimizer(model, optax.adamw(learning_rate=1e-3), wrt=nnx.Param)

# Train a LeNet for two rounds on beta=.4 gaze splits using the plain cross-entropy `ell` above.
# NOTE(review): the first cell unpacks three values from `train` (updates, models, opts);
# this call unpacks two — confirm against the current fedflax.train signature.
updates, models = train(
    LeNet(nnx.Rngs(0)), 
    opt_create,
    get_gaze(beta=.4),
    get_gaze(beta=.4, partition="val", batch_size=16),
    local_epochs=10,
    ell=ell,
    rounds=2
)

# Sanity check: are entries 0 and 1 of `updates` element-wise identical?
# NOTE(review): `.all()` only works if these entries are arrays, not pytrees — confirm.
print((updates[1]==updates[0]).all())
# Take the second leaf of layer 1's conv2 and compare its index-0 and index-1 slices
# (presumably two stacked models — confirm which parameter leaf index 1 is).
l = jax.tree.leaves(nnx.to_tree(models.layers[1].conv2))[1]
print((l[0]==l[1]).all())

In [None]:
%load_ext autoreload
%autoreload 2

import os, jax, torch, torchvision
# jax.config.update('jax_enable_x64', True)
from jax import numpy as jnp
from torch.utils.data import Dataset, DataLoader, default_collate
from matplotlib import pyplot as plt
from itertools import product
from data import get_gaze

In [None]:
ds = get_gaze(beta=0., skew="feature", discrete=False)
# Collect every 2-component label from the loader; the red scatter below needs
# `all_labels`, so this must run (without it the cell raises NameError on a fresh kernel).
all_labels = jnp.concatenate([y.reshape(-1, 2) for *_, y in ds])

(img, aux, label) = torch.load("MPIIGaze_preprocessed/train/p14/day06_50_left.pt")
# (shape[1], shape[0]) — i.e. (width, height) if img is an (H, W) image (TODO confirm).
imshape = torch.asarray(img.shape[1::-1])
plt.imshow(img, cmap="gray");
plt.xticks([]); plt.yticks([]);
# Label spread, shifted by +.5 and scaled to pixel coordinates.
plt.scatter(*((all_labels+.5)*jnp.asarray(img.shape[1::-1])).T, alpha=.02, c="r");

# nr x nr grid of region centres. The bounds look like the per-axis min/max of
# all_labels (TODO confirm); they are shrunk by (1 - 1/nr) to pull outer centres inward.
nr = 3
min_0, max_0, min_1, max_1 = -0.36719793, 0.3623084, -0.31378174, 0.38604215
min_0, max_0, min_1, max_1 = min_0*(1-1/nr), max_0*(1-1/nr), min_1*(1-1/nr), max_1*(1-1/nr)
# The concatenation order fixes the region indices — do not collapse into one cartesian_prod.
regions = torch.concat([
    torch.cartesian_prod(torch.linspace(min_0, max_0, nr), torch.linspace(min_1, max_1, nr)[:2]),
    torch.cartesian_prod(torch.linspace(min_0, max_0, nr), torch.linspace(min_1, max_1, nr)[2:])
])+.5
plt.scatter(*(regions*imshape).T, c="b");

# Map the 3-component label (presumably a gaze direction) to two angles, shifted by +.5.
loc = torch.asarray([-torch.arcsin(-label[1]), torch.arctan2(-label[0], -label[2])])+.5
plt.scatter(loc[0]*imshape[0], loc[1]*imshape[1], c="y");
# Nearest region centre under L1 distance; a separate name keeps `label` intact.
region_idx = torch.abs(loc - regions).sum(axis=1).argmin()
plt.scatter(*(regions[region_idx]*imshape).T, c="purple");

In [None]:
from data import get_gaze
from jax import numpy as jnp
ds = get_gaze(beta=0., partition="train", discrete=True)
# Empirical frequency of the 9 label classes: sum each label array over its first
# two axes, accumulate from a zero vector, then normalise to proportions.
label_dist = sum((y.sum((0, 1)) for *_, y in ds), jnp.zeros(9))
print(label_dist / label_dist.sum())

# Recorded distributions per partition:
# test = [0.08188657 0.1394676  0.1099537  0.15914352 0.09201389 0.13975695 0.08159722 0.11458334 0.08159722]
# val = [0.0703125  0.140625   0.08203125 0.18359375 0.06640625 0.16015625 0.08203125 0.125      0.08984375]
# train = [0.08195466 0.13985908 0.10753677 0.16881128 0.08869486 0.13449755 0.07628677 0.11764707 0.08471201]

In [None]:
%load_ext autoreload
%autoreload 2
import jax
from jax import numpy as jnp
from models import LeNet, teleport_lenet
from itertools import chain
from flax import nnx

# Check that the callable returned by teleport_lenet(model) reproduces the model's outputs.
# `key` was previously never defined in this cell (NameError on a fresh kernel).
key = jax.random.key(0)
# Independent keys for initialisation and the two random inputs.
init_key, in_key, aux_key = jax.random.split(key, 3)
model = LeNet(nnx.Rngs(init_key))
rand_in = jax.random.normal(in_key, (1, 36, 60, 1))
rand_aux = jax.random.normal(aux_key, (1, 3))
out_orig = model(rand_in, rand_aux)
out_tele = teleport_lenet(model)(rand_in, rand_aux)
print("Outputs equal after teleportation:", jnp.allclose(out_orig, out_tele, atol=1e-5))

In [None]:
from matplotlib import pyplot as plt
import matplotlib as mpl
# Stand-alone "Error Rate" colorbar: inferno colormap normalised to [0.5, 1.0].
fig, ax = plt.subplots(dpi=700)
norm = mpl.colors.Normalize(vmin=.5, vmax=1.)
mappable = plt.cm.ScalarMappable(norm=norm, cmap="inferno")
bar = fig.colorbar(mappable, ax=ax, ticks=None)
bar.set_label("Error Rate")

In [None]:
%load_ext autoreload
%autoreload 2
import os 
os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
# os.environ["JAX_DISABLE_MOST_FASTER_PATHS"] = "1"
from data import get_gaze
import fedflax, jax, optax
from fedflax import train
from jax import numpy as jnp
from models import ResNet
from flax import nnx

def opt(model):
    """Optimizer factory: AdamW (learning_rate=1e-3) over the model's nnx.Param variables."""
    return nnx.Optimizer(model, optax.adamw(learning_rate=1e-3), wrt=nnx.Param)

# # Identically initialized models, interpretable as collection by nnx 
# keys = nnx.vmap(lambda k: nnx.Rngs(k))(jnp.array([jax.random.key(42)]*4))
# models = nnx.vmap(ResNet)(keys)
# # Ditto for optimizers
# opts = nnx.vmap(opt)(models)

# @nnx.vmap
# @nnx.value_and_grad
# def ell(model, x_batch, z_batch, y_batch):
#     y_pred = model(x_batch, z_batch, train=True)
#     loss = optax.softmax_cross_entropy(y_pred, y_batch).mean()
#     return loss
def ell(model, model_g, x_batch, z_batch, y_batch, train):
    """Mean softmax cross-entropy of ``model`` on one batch; ``model_g`` is unused.

    Returns ``(loss, (0., 0.))``: the aux tuple holds two zero placeholders.
    """
    return (
        optax.softmax_cross_entropy(model(x_batch, z_batch, train=train), y_batch).mean(),
        (0., 0.),
    )

# NOTE(review): "overlap" is passed positionally here, while other calls in this notebook
# pass `partition=` by keyword — confirm which get_gaze parameter this binds to.
train_ds = get_gaze("overlap", beta=1.)
# for x,z,y in train_ds:
#     val, grad = ell(models, x, z, y)
#     break

# NOTE(review): this passes the `ResNet` class itself and `None` as the validation set;
# other cells pass a model instance (e.g. `LeNet(nnx.Rngs(0))`) — confirm which form
# the current fedflax.train expects.
train(ResNet, opt, train_ds, None, ell, local_epochs=10)

In [None]:
%load_ext autoreload
%autoreload 2
from data import create_imagenet
from matplotlib import pyplot as plt
import jax, optax
from jax import numpy as jnp
from fedflax import train
from flax import nnx
from models import ResNet
n = 4  # number of models trained in parallel (vmapped below); also passed to create_imagenet

ds_train = create_imagenet(feature_beta=.1, n=n)

def ell(model, model_g, x_batch, y_batch):
    """Cross-entropy loss for one batch; ``model_g`` is accepted but unused.

    Returns ``(ce, (0., ce))``, so a caller unpacking ``(loss, (prox, ce))`` gets a
    zero ``prox`` term and the cross-entropy itself.
    """
    logits = model(x_batch)
    ce = optax.softmax_cross_entropy(logits, y_batch).mean()
    return ce, (0., ce)

@nnx.jit
@nnx.vmap(in_axes=(0, None, 0, 0, 0))
def train_step(model, model_g, opt, x_batch, y_batch):
    """One optimisation step for every vmapped model; returns per-model (loss, grads).

    ``model_g`` is shared (in_axes=None); model, opt and the batches are mapped on axis 0.
    """
    grad_fn = nnx.value_and_grad(ell, has_aux=True)
    (loss, _aux), grads = grad_fn(model, model_g, x_batch, y_batch)
    # grads = jax.tree.map(lambda g: g/2**15, grads)  (disabled gradient rescaling)
    # NOTE(review): recent Flax releases (0.11+) expect `opt.update(model, grads)` —
    # confirm against the installed version.
    opt.update(grads)
    return loss, grads

def opt(model):
    """Optimizer factory: AdamW (learning_rate=1e-3) over the model's nnx.Param variables."""
    return nnx.Optimizer(model, optax.adamw(learning_rate=1e-3), wrt=nnx.Param)

# n identically initialized models, stacked so nnx treats them as one vmapped collection:
# every Rngs is seeded with the same key (42), so all ResNets start from the same weights.
keys = nnx.vmap(lambda k: nnx.Rngs(k))(jnp.array([jax.random.key(42)]*n))
models = nnx.vmap(ResNet)(keys)
# One optimizer per stacked model, built the same way.
opts = nnx.vmap(opt)(models)
# Build the global model by averaging every leaf over axis 0 (the stacked-model axis).
# NOTE(review): the comment said "save", but nothing is saved here. Also, if to_tree
# includes non-parameter state (e.g. RNG keys/counts), jnp.mean is applied to it too —
# confirm.
params, struct = jax.tree.flatten(nnx.to_tree(models))
model_g = nnx.from_tree(jax.tree.unflatten(struct, jax.tree.map(lambda x: jnp.mean(x, axis=0), params)))

# One train_step per batch: print the per-model losses and stop at the first NaN.
for x, y in ds_train:
    loss, grads = train_step(models, model_g, opts, x, y)
    print(loss)
    if jnp.isnan(loss).any():
        print("NaN encountered")
        break

In [None]:
from jax import numpy as jnp
import torchvision
from matplotlib import pyplot as plt
from scipy.ndimage import map_coordinates
import cv2
import numpy as np

def perspective_shift(image, angle=0, skew_strength=0):
    """Warp ``image`` with a perspective transform that displaces its corners.

    With ``dx = cos(angle) * skew_strength * width`` and
    ``dy = sin(angle) * skew_strength * height``:
    - the top corners move right by ``dx`` and the bottom corners left by ``dx``;
    - the top-right corner also moves by ``+dy`` vertically;
    - the bottom-right corner also moves by ``-dy`` vertically.
    The output has the same width and height as the input.
    """
    rows, cols = image.shape[:2]
    offset_x = np.cos(angle) * skew_strength * cols
    offset_y = np.sin(angle) * skew_strength * rows
    # Corner order: top-left, top-right, bottom-left, bottom-right (x, y).
    corners = np.asarray([[0, 0], [cols, 0], [0, rows], [cols, rows]], dtype=np.float32)
    moved = np.asarray(
        [
            [0 + offset_x, 0],
            [cols + offset_x, 0 + offset_y],
            [0 - offset_x, rows],
            [cols - offset_x, rows - offset_y],
        ],
        dtype=np.float32,
    )
    homography = cv2.getPerspectiveTransform(corners, moved)
    return cv2.warpPerspective(image, homography, (cols, rows))

# NOTE(review): hard-coded absolute path to a local ImageNet copy; adjust for your machine.
IMAGE_PATH = "/thesis/data/Data/CLS-LOC/train/n01534433/n01534433_47.JPEG"
# read_image returns a (C, H, W) tensor. moveaxis gives (H, W, C); the former
# swapaxes(0, -1) gave (W, H, C), i.e. a transposed image. ascontiguousarray hands
# cv2 a C-contiguous buffer.
image = np.ascontiguousarray(np.moveaxis(np.asarray(torchvision.io.read_image(IMAGE_PATH)), 0, -1))
distorted_image = perspective_shift(image, angle=np.pi/4, skew_strength=1.2)
plt.imshow(distorted_image);