# Replacing weights of a traced/loaded model on-the-fly

In some scenarios you need to change the weights of a loaded model on the fly, so you don't spend time again on I/O, hardware initialization, and so on. PyTorch has a well-defined way of manipulating a model's weights, and we're going to explore it here. The only things we need to pay attention to are how we load the new weights and which device we move the tensors to before replacing them in a model loaded into Inferentia2/Trainium HBM.
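For reference, here is a minimal sketch of the standard PyTorch flow on CPU, using a plain (non-traced) `nn.Linear` and `load_state_dict`. It does not touch Neuron devices at all; it only illustrates the idea of swapping weights without rebuilding the model. The variable names are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# A plain CPU layer, used only to illustrate the regular PyTorch mechanism.
model = nn.Linear(4, 4, bias=False)
x = torch.rand(2, 4)
y_before = model(x)

# A replacement state dict must use the same keys and shapes as the original.
new_state = {"weight": torch.nn.init.xavier_uniform_(torch.empty(4, 4))}
model.load_state_dict(new_state)

# Same model object, new weights, different predictions.
y_after = model(x)
print(torch.allclose(y_before, y_after))
```

The rest of this notebook applies the same idea to a traced model whose weights live in accelerator memory.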

In [None]:
import os
os.environ['NEURON_RT_NUM_CORES']='1'
import torch
import torch.nn as nn
import torch_neuronx

### 1) First, let's create a dummy model with a simple linear layer

In [2]:
x = torch.rand(2, 4)
if os.path.isfile("linear.pt"):
    print("Loading model from disk")
    traced_model = torch.jit.load("linear.pt")
else:
    print("Tracing model...")
    model = nn.Linear(4, 4, bias=False)
    _= torch.nn.init.xavier_uniform_(model.weight)
    y = model(x)
    traced_model = torch_neuronx.trace(model, x, inline_weights_to_neff=False) # inline_weights_to_neff=False is required for replacing weights on-the-fly
    traced_model.save("linear.pt")

Loading model from disk


#### Special device
Now we need to use a special device called **privateuseone**, where we load our tensors. This device uses the Inferentia HBM, so you end up with a tensor loaded into accelerator memory, ready to be used.

In [3]:
x = x.to("privateuseone:0")

### 2) Then we execute it to see the results

In [4]:
y = traced_model(x)
y

tensor([[ 0.2516,  0.0049,  0.6511, -0.4800],
        [ 0.1179, -0.0621,  0.7725,  0.2441]])

### 3) Now, let's create a new set of weights and replace the original/loaded ones from our model
In this step, we'll replace all the weights of our model. The results will show completely different values. Note that we didn't reload the model; only the weights were replaced.

In [5]:
new_weights = torch.rand(4, 4).to("privateuseone:0")
_= torch.nn.init.xavier_uniform_(new_weights)

In [6]:
torch_neuronx.replace_weights(traced_model, {"weight": new_weights} )

In [7]:
y = traced_model(x)
y.cpu()

tensor([[ 0.3004, -0.4989, -0.3429, -0.1641],
        [ 0.4357, -0.2272, -0.8096, -0.1748]])
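Before swapping weights, it can help to check that each replacement tensor matches an existing weight by name and shape. The sketch below assumes replacement tensors are expected to keep the names and shapes of the traced weights; `shape_mismatches` is a hypothetical helper written for this note, not part of `torch_neuronx`, and it runs on plain CPU tensors.

```python
import torch

def shape_mismatches(current, new):
    """Return (name, reason) pairs for replacement tensors that do not fit.

    Hypothetical helper: `current` and `new` are dicts mapping weight
    names to tensors.
    """
    problems = []
    for name, tensor in new.items():
        if name not in current:
            problems.append((name, "unknown name"))
        elif tuple(current[name].shape) != tuple(tensor.shape):
            problems.append((name, "shape mismatch"))
    return problems

# Stand-in for the weights of the traced model in this notebook.
current = {"weight": torch.zeros(4, 4)}

ok = shape_mismatches(current, {"weight": torch.rand(4, 4)})
bad = shape_mismatches(current, {"weight": torch.rand(1, 2), "bias": torch.rand(4)})
print(ok)
print(bad)
```

In this notebook, `current` could be built from `traced_model.weights._parameters`, which is the same attribute used in step 4.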

### 4) Finally, let's create a new set of weights, but this time we'll replace only a fraction of the model weights

In [8]:
new_weights = torch.rand(1, 2).to("privateuseone:0")
_= torch.nn.init.xavier_uniform_(new_weights)

In [10]:
model_weights = traced_model.weights._parameters['weight']
print("Original weights")
model_weights.cpu()

Original weights


tensor([[ 0.4103,  0.3810,  0.4633, -0.3963],
        [-0.7749,  0.1204,  0.1080, -0.0109],
        [-0.5051,  0.7457, -0.8462, -0.5822],
        [-0.0180, -0.0679,  0.1382, -0.3241]])

### Use torch.scatter to replace weights with offsets, like the example below
If you don't need offsets, for instance when you only replace the first X weight values, you can index the tensor directly.

In [12]:
idx = torch.tensor([[2,3]])
traced_model.weights._parameters['weight'] = torch.scatter(model_weights, -1, idx, new_weights)
print("Modified weights")
traced_model.weights._parameters['weight'].cpu()

Modified weights


tensor([[ 0.4103,  0.3810, -0.3786, -0.9462],
        [-0.7749,  0.1204,  0.1080, -0.0109],
        [-0.5051,  0.7457, -0.8462, -0.5822],
        [-0.0180, -0.0679,  0.1382, -0.3241]])

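As you can see in the printed set of weights above, only the last two elements of row 0 (columns 2 and 3, as given by `idx`) were replaced. And you get different predictions, of course: since row 0 of the weight matrix produces the first output column, only that column changes.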
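To confirm which positions a scatter touched, you can compare the tensors before and after. This is a small CPU-only sketch with made-up values (not the weights of the model above), using the same `idx` as in the cell above.

```python
import torch

# Made-up 4x4 "weights" so every entry is easy to recognize.
before = torch.arange(16, dtype=torch.float32).reshape(4, 4)

idx = torch.tensor([[2, 3]])        # target columns, for row 0 only
src = torch.tensor([[-1.0, -2.0]])  # replacement values
after = torch.scatter(before, -1, idx, src)

# (row, column) pairs whose values differ between the two tensors.
changed = (before != after).nonzero().tolist()
print(changed)
```

Along the last dimension, `torch.scatter` writes `src[i][j]` into position `[i][idx[i][j]]`, so an index tensor with a single row only modifies row 0.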

In [13]:
y = traced_model(x)
y.cpu()

tensor([[ 0.0361, -0.4989, -0.3429, -0.1641],
        [-0.5160, -0.2272, -0.8096, -0.1748]])
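As mentioned earlier, when no offsets are needed you can index the tensor directly instead of using `torch.scatter`. The sketch below compares both forms on plain CPU tensors with illustrative values; whether in-place indexing behaves the same way on a tensor held in accelerator memory is not shown here and should be treated as an assumption to verify.

```python
import torch

weights = torch.zeros(4, 4)
new_values = torch.tensor([[1.0, 2.0]])

# Direct indexing: overwrite the first two values of row 0.
direct = weights.clone()
direct[0, :2] = new_values[0]

# The same update expressed with torch.scatter along the last dimension.
idx = torch.tensor([[0, 1]])
scattered = torch.scatter(weights, -1, idx, new_values)

print(torch.equal(direct, scattered))
```

Direct indexing is shorter for contiguous leading slices, while `torch.scatter` handles arbitrary column offsets in a single call.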