In [110]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from TD3_BC import Critic
import gym
import d4rl

def restore_state_dict(model, param1):
    # Pointer to keep track of where we are in the concatenated parameters
    pointer = 0
    state_dict = {}
    assert param1.shape[0] == sum([torch.prod(torch.tensor(param.size())) for param in model.parameters()]), "Size Mismatch"
    for name, param in model.named_parameters():
        # Compute the number of elements in this parameter
        num_elements = torch.prod(torch.tensor(param.size()))

        # Slice the concatenated parameters to get the values for this parameter
        param_values = param1[pointer:pointer + num_elements]

        # Reshape the values to match the original shape
        param_values = param_values.view(param.size())

        # Add to the state_dict
        state_dict[name] = param_values

        # Move the pointer
        pointer += num_elements

    # Load the state_dict into the model
    model.load_state_dict(state_dict)


N = 100
lambda_scale = 10

env_name = 'hopper-medium-replay-v2'
# env_name = 'antmaze-large-diverse-v0'
env = gym.make(env_name)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0] 
_max_action = float(env.action_space.high[0])

model_path = f'results/{env_name}.pt'
model_list = torch.load(model_path)['q1_models']
critic_param = model_list[-1]
model = Critic(state_dim, action_dim, layernorm=0).q1
restore_state_dict(model, critic_param)

# critic = torch.load(model_path)['grad_']
    
    # torch.save({
    #     'steps': steps,
    #     'q1_models': q1_models,
    #     'ntks': ntk_list,
    #     'random_batch_grad_list': random_batch_grad_list,
    #     'fix_batch_grad_list': fix_batch_grad_list,
    #     'pi_list': pi_list,
    # }, f'results/{args.env}.pt')
    

# Randomly sample N=100 points from [-1,1]^2
X = torch.tensor(np.random.uniform(-1, 1, (N, state_dim + action_dim)), dtype=torch.float32)
# X_prime = torch.tensor(np.random.uniform(-1, 1, (N, state_dim + action_dim)), dtype=torch.float32)
# add guassian noise to X and make the sum bounded in [-1,1]
X_prime = X + torch.tensor(np.random.normal(0, 10, (N, state_dim + action_dim)), dtype=torch.float32)
X_prime = torch.clamp(X_prime, -1, 1)
# print(X_prime[:10])
y = model(X)

print('param Norm: ', torch.norm(critic_param))
# print(X[:10])
print(y.squeeze()[:10])

# Multiply MLP parameters by lambda and generate y_scaled
for param in model.parameters():
    param.data *= lambda_scale

y_scaled = model(X)
print(y_scaled.squeeze()[:10])
# print(y_scaled.squeeze() / y.squeeze())



  logger.warn(f"Box bound precision lowered by casting to {self.dtype}")


param Norm:  tensor(71695.8438, grad_fn=<NormBackward1>)
tensor([2.8636e+12, 1.2621e+12, 1.4748e+12, 1.6022e+12, 4.3775e+11, 9.1026e+11,
        2.7557e+12, 1.1846e+12, 1.9868e+12, 1.4116e+12],
       grad_fn=<SliceBackward0>)
tensor([2.8636e+15, 1.2621e+15, 1.4747e+15, 1.6022e+15, 4.3773e+14, 9.1025e+14,
        2.7557e+15, 1.1846e+15, 1.9868e+15, 1.4115e+15],
       grad_fn=<SliceBackward0>)


In [106]:
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
import numpy as np

# Reshape y and y_scaled for linear regression
y_np = y.detach().numpy().reshape(-1, 1)
y_scaled_np = y_scaled.detach().numpy().reshape(-1, 1)

# Perform linear regression using least squares
reg = LinearRegression(fit_intercept=False)
reg.fit(y_np, y_scaled_np)

# Predict y_scaled
y_pred = reg.predict(y_np)

# Compute R^2 score
r2 = r2_score(y_scaled_np, y_pred)

# Print the results
print("Q-value Linear regression coefficient:", reg.coef_[0][0])
print(f"Expected coefficient: {lambda_scale**3}")
print(f"R^2 score: {r2}")

Q-value Linear regression coefficient: 999.99225
Expected coefficient: 1000
R^2 score: 0.9999999999242609


In [107]:
# Iterate through named parameters to compute norms for weights and biases separately
# weight_norm = 0
# bias_norm = 0
# for name, param in model.named_parameters():
#     if 'weight' in name:
#         weight_norm += torch.norm(param).item()
#     elif 'bias' in name:
#         bias_norm += torch.norm(param).item()

# Concatenate all weight and bias parameters
for i in range(len(model_list)):
    critic_param = model_list[i]
    model = Critic(state_dim, action_dim, layernorm=0).q1
    restore_state_dict(model, critic_param)

    all_weights = torch.cat([param.view(-1) for name, param in model.named_parameters() if 'weight' in name])
    all_biases = torch.cat([param.view(-1) for name, param in model.named_parameters() if 'bias' in name])

    # Compute their norms
    weight_norm = torch.norm(all_weights).item()
    bias_norm = torch.norm(all_biases).item()



    print(f'iter{i}', "Weight Norm / Bias Norm:", weight_norm / bias_norm)
    print("Weight Norm:", weight_norm)
    print("Bias Norm:", bias_norm)
    
print(list(model.named_parameters()))

iter0 Weight Norm / Bias Norm: 10.316726473552626
Weight Norm: 1116.1197509765625
Bias Norm: 108.18545532226562
iter1 Weight Norm / Bias Norm: 10.470930314846687
Weight Norm: 1964.2047119140625
Bias Norm: 187.58645629882812
iter2 Weight Norm / Bias Norm: 10.534145979571308
Weight Norm: 2753.79736328125
Bias Norm: 261.4162902832031
iter3 Weight Norm / Bias Norm: 10.567933979296413
Weight Norm: 3519.4296875
Bias Norm: 333.02911376953125
iter4 Weight Norm / Bias Norm: 10.585212842874496
Weight Norm: 4267.6962890625
Bias Norm: 403.17529296875
iter5 Weight Norm / Bias Norm: 10.595255085036957
Weight Norm: 5006.193359375
Bias Norm: 472.493896484375
iter6 Weight Norm / Bias Norm: 10.598413576373021
Weight Norm: 5734.1279296875
Bias Norm: 541.0364379882812
iter7 Weight Norm / Bias Norm: 10.597772981709555
Weight Norm: 6454.48876953125
Bias Norm: 609.0419921875
iter8 Weight Norm / Bias Norm: 10.588504054374566
Weight Norm: 7163.556640625
Bias Norm: 676.5409545898438
iter9 Weight Norm / Bias Nor

In [108]:

def compute_ntk(model, x, x_prime):
    N = x.size(0)

    # Function to flatten gradients of model parameters
    def flatten_grads(grads):
        return torch.cat([g.view(-1) for g in grads])

    # Compute gradients for each input
    grads = []
    next_grads = []
    for i in range(N):

        # Feedforward the input through the model
        output = model(x[i])

        # Zero the model's gradients
        model.zero_grad()

        # Compute and store the gradients for each output dimension
        output.backward()
        grad = flatten_grads([p.grad for p in model.parameters() if p.grad is not None])

        # Stack gradients for the current input
        grads.append(grad)
        
        # Feedforward the input through the model
        output = model(x_prime[i])

        # Zero the model's gradients
        model.zero_grad()

        # Compute and store the gradients for each output dimension
        output.backward()
        grad = flatten_grads([p.grad for p in model.parameters() if p.grad is not None])

        # Stack gradients for the current input
        next_grads.append(grad)
        

    # Stack gradients for all inputs
    # divide a constant to prevent NAN
    grads_tensor = torch.stack(grads) / (grad.shape[0])**0.5
    next_grads_tensor = torch.stack(next_grads) / (grad.shape[0])**0.5

    # Compute the NTK matrix using tensor operations
    G = torch.matmul(grads_tensor, next_grads_tensor.t())

    return G.detach(), grads_tensor.detach(), next_grads_tensor.detach()



restore_state_dict(model, model_list[-1])
ntk, grad, next_grad = compute_ntk(model, X, X_prime)

for param in model.parameters():
    param.data *= lambda_scale
    
    
scaled_ntk, scaled_grad, scaled_next_grad = compute_ntk(model, X, X_prime)

# reshape ntk; print similarity between ntk and scaled ntk
ntk = ntk.reshape(-1)
next_ntk = scaled_ntk.reshape(-1)
print('ntk cosine similarity', torch.cosine_similarity(ntk, next_ntk, dim=0))
print('ntk norm rate', torch.norm(next_ntk) / torch.norm(ntk))

# reshape grad; print similarity between grad and scaled grad
grad = grad.reshape(-1)
next_grad = scaled_grad.reshape(-1)
print('grad cosine similarity', torch.cosine_similarity(grad, next_grad, dim=0))
print('grad norm rate', torch.norm(next_grad) / torch.norm(grad))



ntk cosine similarity tensor(1.)
ntk norm rate tensor(9999.8936)
grad cosine similarity tensor(1.)
grad norm rate tensor(100.0004)


In [109]:
def compute_ntk(model, x, x_prime):
    N = x.size(0)

    # Function to flatten gradients of model parameters
    def flatten_grads(grads):
        return torch.cat([g.view(-1) for g in grads])

    # Compute gradients for each input
    weight_grads = []
    bias_grads = []
    next_weight_grads = []
    next_bias_grads = []
    for i in range(N):

        # Feedforward the input through the model
        output = model(x[i])

        # Zero the model's gradients
        model.zero_grad()

        # Compute and store the gradients for each output dimension
        output.backward()
        weight_grad = flatten_grads([p.grad for name, p in model.named_parameters() if 'weight' in name and p.grad is not None])
        bias_grad = flatten_grads([p.grad for name, p in model.named_parameters() if 'bias' in name and p.grad is not None])

        # Stack gradients for the current input
        weight_grads.append(weight_grad)
        bias_grads.append(bias_grad)
        
        # Feedforward the input through the model
        output = model(x_prime[i])

        # Zero the model's gradients
        model.zero_grad()

        # Compute and store the gradients for each output dimension
        output.backward()
        next_weight_grad = flatten_grads([p.grad for name, p in model.named_parameters() if 'weight' in name and p.grad is not None])
        next_bias_grad = flatten_grads([p.grad for name, p in model.named_parameters() if 'bias' in name and p.grad is not None])


        # Stack gradients for the current input
        next_weight_grads.append(next_weight_grad)
        next_bias_grads.append(next_bias_grad)
        

    # Stack gradients for all inputs
    # divide a constant to prevent NAN
    size = weight_grad.shape[0] + bias_grad.shape[0]
    weight_grads_tensor = torch.stack(weight_grads) / size**0.5
    bias_grads_tensor = torch.stack(bias_grads) / size**0.5
    next_weight_grads_tensor = torch.stack(next_weight_grads) / size**0.5
    next_bias_grads_tensor = torch.stack(next_bias_grads) / size**0.5
    # weight_grads_tensor = torch.stack(weight_grads) 
    # bias_grads_tensor = torch.stack(bias_grads) 
    # next_weight_grads_tensor = torch.stack(next_weight_grads)
    # next_bias_grads_tensor = torch.stack(next_bias_grads)

    # Compute the NTK matrix using tensor operations
    G_weight = torch.matmul(weight_grads_tensor, next_weight_grads_tensor.t())
    G_bias = torch.matmul(bias_grads_tensor, next_bias_grads_tensor.t())
    print(G_weight.shape, G_bias.shape)
    print(G_weight[:4, :4])
    print(G_bias[:4, :4])
    
    # norm rate
    print('weight/bias NTK rate', torch.norm(torch.diag(G_weight)) / torch.norm(torch.diag(G_bias)))

    return G_weight.detach(), G_bias.detach()




restore_state_dict(model, model_list[10])
G_weight, G_bias = compute_ntk(model, X, X_prime)
ntk2 = G_weight + G_bias

# whether ntk equals to ntk2
print(ntk.reshape((N,N))[:2, :10], ntk2[:2, :10])

# compute_ntk(model, X, X)

# for param in model.parameters():
#     param.data *= lambda_scale
    
    
# scaled_ntk, scaled_grad, scaled_next_grad = compute_ntk(model, X, X_prime)

# # reshape ntk; print similarity between ntk and scaled ntk
# ntk = ntk.reshape(-1)
# next_ntk = scaled_ntk.reshape(-1)
# print('ntk cosine similarity', torch.cosine_similarity(ntk, next_ntk, dim=0))
# print('ntk norm rate', torch.norm(next_ntk) / torch.norm(ntk))

# # reshape grad; print similarity between grad and scaled grad
# grad = grad.reshape(-1)
# next_grad = scaled_grad.reshape(-1)
# print('grad cosine similarity', torch.cosine_similarity(grad, next_grad, dim=0))
# print('grad norm rate', torch.norm(next_grad) / torch.norm(grad))


torch.Size([100, 100]) torch.Size([100, 100])
tensor([[ 6.8981e+08,  1.9389e+08,  1.9291e+08,  2.5750e+07],
        [ 1.3779e+09,  7.4497e+08,  2.1770e+08,  4.8533e+08],
        [ 3.8562e+08, -1.8764e+08,  6.0813e+08,  4.1301e+08],
        [ 2.1829e+09,  1.4552e+09,  1.4245e+09,  7.3954e+08]])
tensor([[1.4078e+08, 1.4504e+08, 8.5961e+07, 1.3817e+08],
        [2.1403e+08, 2.1936e+08, 2.0636e+08, 6.7178e+07],
        [2.0130e+08, 2.0130e+08, 1.5371e+08, 1.2437e+08],
        [2.8615e+08, 2.9148e+08, 2.3240e+08, 1.3924e+08]])
weight/bias NTK rate tensor(6.6623)
tensor([[3.1570e+12, 1.1195e+12, 2.0930e+12, 9.4983e+11, 1.9378e+12, 4.8247e+11,
         1.4625e+12, 6.0053e+11, 1.2362e+12, 4.9715e+11],
        [7.1594e+12, 3.8721e+12, 5.4845e+12, 4.7904e+12, 1.9393e+12, 3.4202e+12,
         2.0803e+12, 1.5541e+12, 4.5866e+12, 5.9006e+12]]) tensor([[ 8.3059e+08,  3.3893e+08,  2.7887e+08,  1.6392e+08,  5.1979e+08,
          2.0015e+08,  3.1466e+08,  4.9955e+07,  2.3031e+08,  5.7939e+07],
        

In [12]:
# Action similarity check

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from TD3_BC import Critic
import gym
import d4rl


for env_name in [
                # 'hopper-medium-v2', 'hopper-medium-expert-v2',  'hopper-medium-replay-v2',
                #  'walker2d-medium-v2', 'walker2d-medium-expert-v2', 'walker2d-medium-replay-v2',
                'halfcheetah-medium-v2', 'halfcheetah-medium-expert-v2', 'halfcheetah-medium-replay-v2',
    'antmaze-large-diverse-v0', 'antmaze-medium-play-v0', 'antmaze-medium-diverse-v0', 'antmaze-large-play-v0', 'antmaze-large-diverse-v0']:

    # env = gym.make(env_name)
    # state_dim = env.observation_space.shape[0]
    # action_dim = env.action_space.shape[0] 
    # _max_action = float(env.action_space.high[0])

    model_path = f'results/{env_name}.pt'
    model_list = torch.load(model_path)['q1_models']
    pi_list = torch.load(model_path)['pi_list']
    
    print(env_name)
    sims = []
    distances = []
    for i in range(len(model_list)-1):
        sims.append(round(torch.cosine_similarity(pi_list[i], pi_list[i+1], dim=1).mean().item(), 2))
        # get the distance of corresponding vectors in the axis=1 and average them
        distances.append(round(torch.norm(pi_list[i]-pi_list[i+1], dim=1).mean().round(decimals=2).item(), 2))
    print(f'Cosine similarity: {sims}')
    print(f'Distance: {distances}')
        




halfcheetah-medium-v2
Cosine similarity: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
Distance: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.