In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "4"

import jax
import jax.numpy as jnp
from flax import linen as nn

In [2]:
class Classifier(nn.Module):

    def setup(self):
        self.layer1 = nn.Dense(784)
        self.layer2 = nn.Dense(10)

    def __call__(self, x, deterministic=False):
        out = self.layer1(x)
        out = nn.relu(out)
        return self.layer2(out)

In [3]:
num_devices = 1
batch_size = 64
learning_rate = 0.1

In [4]:
# Load the dataset using PyTorch
import torchvision
from torch.utils.data import Dataset, DataLoader
import numpy as np

dataset = torchvision.datasets.MNIST(root="/home/tejasj/data", train=True, download=True)


class MNISTDataset(Dataset):
    def __init__(self):
        self.data = dataset

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        data = self.data[idx]
        return {
            "image": torchvision.transforms.ToTensor()(data[0]).cpu().numpy(),
            "label": np.array(data[1]).reshape(-1, ),
        }


dataloader = DataLoader(
    dataset=MNISTDataset(),
    batch_size=num_devices * batch_size,
    shuffle=True,
)



In [5]:
# Check the number of visible devices
num_devices = jax.device_count()

# Seed the random number generator
key = jax.random.PRNGKey(42)

# Initialize the model
classifier = Classifier()
key, sub_key = jax.random.split(key)

# This returns all the parameters in a frozen dict
variables = classifier.init(key, jnp.ones((1, 784)))


In [6]:
import optax
from flax.training import train_state
from clu import metrics
from flax import struct

@struct.dataclass
class Metrics(metrics.Collection):
    accuracy: metrics.Accuracy
    loss: metrics.Average.from_output("loss")


class TrainState(train_state.TrainState):
    metrics: Metrics

tx = optax.adam(learning_rate=0.001)


train_state = TrainState(
    step=0,
    apply_fn=classifier.apply,
    params=variables["params"],
    tx=tx,
    opt_state=tx.init(variables["params"]),
    metrics=Metrics.empty(),
)

In [7]:
import flax.jax_utils as flax_utils

# Initialize the train step
def loss_fn(params, batch):
    output = train_state.apply_fn(
        {"params": params},
        batch["image"],
    )
    loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=output, labels=batch["label"]
    ).mean()
    return loss, output


def train_step(state, batch):
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(state.params, batch)
    grads = jax.lax.pmean(grads, axis_name="batch")

    new_state = state.apply_gradients(
        grads=grads,
    )
    metric_updates = state.metrics.gather_from_model_output(
        loss=loss, logits=logits, labels=batch["label"]
    )
    metrics = state.metrics.merge(metric_updates)
    new_state = new_state.replace(metrics=metrics)

    return new_state


# Do some initialization for parallel training
p_train_step = jax.pmap(train_step, axis_name="batch")


def train(start_state):
    state = start_state

    # Distribute training
    state = flax_utils.replicate(state)

    for _ in range(5):
        for data in dataloader:
            num_samples = data["image"].shape[0]
            batch = {
                "image": data["image"].permute(0, 2, 3, 1).reshape(num_devices, num_samples // num_devices, -1).numpy(),
                "label": data["label"].reshape(num_devices, num_samples // num_devices, ).numpy(),
            }
            state = p_train_step(state, batch)

    return state

In [8]:
state = train(train_state)

In [9]:
print(state.metrics.unreplicate().compute())

{'accuracy': Array(0.9751533, dtype=float32), 'loss': Array(0.08414858, dtype=float32)}


In [10]:
def add_noise_to_params(key, state):
    def add_noise(param):
        new_param = param + 0.01 * jax.random.uniform(key, param.shape)
        return new_param
    new_params = jax.tree_map(add_noise, state.params)
    return new_params

In [11]:
def add_noise_to_params_multiple(state, num_params):
    key = jax.random.PRNGKey(42)

    new_params = []
    for _ in range(num_params):
        key, subkey = jax.random.split(key)
        new_param = add_noise_to_params(subkey, state)
        new_params.append(new_param)
    
    return new_params


In [12]:
state_unrpl = flax_utils.unreplicate(state)

In [13]:
new_states = add_noise_to_params_multiple(state_unrpl, 2)

In [19]:
for data in dataloader:            
    num_samples = data["image"].shape[0]
    batch = {
        "image": data["image"].permute(0, 2, 3, 1).reshape(num_samples, -1).numpy(),
        "label": data["label"].reshape(num_samples, ).numpy(),
    }
    break

def loss_fn_svd(params):
    logits = classifier.apply({"params": params}, batch["image"])
    loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=batch["label"]
    ).mean()
    return loss

loss_fn_svd_grad = jax.grad(loss_fn_svd)

grads = []
for params in new_states:
    grads.append(loss_fn_svd_grad(params))

In [20]:
leaves = []
for params in grads:
    leaves.append(jax.lax.concatenate(jax.tree_map(lambda x: x.reshape(-1, 1), jax.tree_util.tree_leaves(params)), dimension=0))
leaves = jax.lax.concatenate(leaves, dimension=1)

In [21]:
leaves.shape

(623290, 2)

In [23]:
with jax.default_device(jax.devices("cpu")[0]):
    U, _, _ = jnp.linalg.svd(leaves, full_matrices=False)

In [24]:
for data in dataloader:            
    num_samples = data["image"].shape[0]
    batch = {
        "image": data["image"].permute(0, 2, 3, 1).reshape(num_samples, -1).numpy(),
        "label": data["label"].reshape(num_samples, ).numpy(),
    }
    break

def loss_fn_svd(params):
    logits = classifier.apply({"params": params}, batch["image"])
    loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=batch["label"]
    ).mean()
    return loss

loss_fn_svd_hess = jax.hessian(loss_fn_svd)

grads = []
for params in new_states:
    grads.append(loss_fn_svd_hess(params))
    break

2023-07-16 18:23:34.526093: W external/xla/xla/service/hlo_rematerialization.cc:2218] Can't reduce memory use below 17.77GiB (19078594560 bytes) by rematerialization; only reduced to 1.41TiB (1553961696400 bytes)
2023-07-16 18:23:44.591465: W external/tsl/tsl/framework/bfc_allocator.cc:485] Allocator (GPU_0_bfc) ran out of memory trying to allocate 1.41TiB (rounded to 1553961696512)requested by op 
2023-07-16 18:23:44.591668: W external/tsl/tsl/framework/bfc_allocator.cc:497] *___________________________________________________________________________________________________
2023-07-16 18:23:44.591764: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2461] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 1553961696400 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:         0B
              constant allocation:         0B
        maybe_live_out allocation:    1.41TiB
     preallocated temp a

XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 1553961696400 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:         0B
              constant allocation:         0B
        maybe_live_out allocation:    1.41TiB
     preallocated temp allocation:         0B
                 total allocation:    1.41TiB
              total fragmentation:         0B (0.00%)
Peak buffers:
	Buffer 1:
		Size: 1.41TiB
		Operator: op_name="jit(iota)/jit(main)/iota[dtype=int32 shape=(623290, 623290) dimension=0]" source_file="/tmp/ipykernel_3307486/406865854.py" source_line=20
		XLA Label: iota
		Shape: s32[623290,623290]
		==========================

