In [32]:
using Symbolics
using Latexify
using Line

include("../data/probe_token.jl")
include("../data/pre_norm.jl")

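N = 512                        # length of the vectors being normalized

μ(x) = sum(x) / N              # mean over the N components
E(x) = μ(x)                    # expectation, named as in the PyTorch formula

c(x) = x .- μ(x)               # centering: subtract the mean from every component

var(x) = sum(c(x) .^ 2) / N    # biased variance (divides by N, not N - 1), as PyTorch's LayerNorm uses

ϵ = 1e-5                       # small constant added to the variance for numerical stability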



1.0e-5
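A quick standalone sanity check on the `var` above: it should match the *biased* variance (`Statistics.var` with `corrected = false`), which is the estimator PyTorch documents for LayerNorm, and not Julia's default `n - 1` version. The random vector and the helper names `mean_n`/`var_n` are illustrative stand-ins, not part of the original notebook.

```julia
import Random, Statistics

# Hypothetical helpers mirroring μ and var above,
# but using length(x) instead of the global N.
mean_n(x) = sum(x) / length(x)
var_n(x)  = sum((x .- mean_n(x)) .^ 2) / length(x)   # divides by n, not n - 1

Random.seed!(0)
x = randn(512)                                       # synthetic stand-in for a residual vector

# Matches the biased estimator ...
@show var_n(x) ≈ Statistics.var(x; corrected = false)   # true
# ... but not Julia's default unbiased estimator (divides by n - 1).
@show var_n(x) ≈ Statistics.var(x)                       # false
```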

[ReExaminingLayerNorm.ipynb](https://colab.research.google.com/drive/1S39-w4vzX3VzZx_27X_BtrLs442pOJnJ) (also [described on LessWrong](https://www.lesswrong.com/posts/jfG6vdJZCwTQmG7kb/re-examining-layernorm)) gives the following as PyTorch's definition of layer norm:

In [33]:
LN(x) = (x .- E(x))/sqrt(var(x) + ϵ) 


LN (generic function with 1 method)
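As an illustration of what this `LN` does, the sketch below applies the same formula (no learnable scale or shift) to a synthetic vector and checks three properties: the output has mean ≈ 0, variance just under 1 (because of ϵ), and length ≈ √N. The random input and the name `layernorm` are assumptions for illustration; `pre_norm` itself is not used.

```julia
import LinearAlgebra, Random, Statistics

# Same formula as LN above, parameterised by ϵ and valid for any vector length.
function layernorm(x; ϵ = 1e-5)
    m  = Statistics.mean(x)
    σ² = Statistics.var(x; corrected = false)   # biased estimator
    return (x .- m) ./ sqrt(σ² + ϵ)
end

Random.seed!(1)
x = 3 .* randn(512) .+ 2          # arbitrary scale and offset
y = layernorm(x)

@show Statistics.mean(y)                          # ≈ 0 (rounding error only)
@show Statistics.var(y; corrected = false)        # slightly below 1, because of ϵ
@show LinearAlgebra.norm(y) / sqrt(length(y))     # ≈ 1, i.e. |LN(x)| ≈ √N
```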

If this is equivalent, the probe logit below should come out to 11.4077.

In [34]:
bias = 0.8328

final_residual = LN(pre_norm)

logit = sum(.*(probe_token, final_residual)) + bias


11.407851912178797

Following the notebook, define the ϵ-regularized unit-vector map $u_\epsilon(x) = \frac{x}{\sqrt{|x|^2 + \epsilon}}$:

In [37]:
norm(x) = sqrt(sum(x .^ 2))
u_ϵ(x) = x .* (1/sqrt(norm(x)^2 + ϵ) )

u_ϵ (generic function with 1 method)

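For a centered $x$ (mean zero, so that $|x|^2 / n = \textrm{Var}[x]$), with $u_{n\epsilon}(x) = \frac{x}{\sqrt{|x|^2 + n\epsilon}}$:

$$\sqrt{n} \cdot u_{n \epsilon}(x) = \frac{x}{\sqrt{\textrm{Var}[x] + \epsilon}}$$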
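The identity relies on $x$ being centered, since only then does $|x|^2/n$ equal $\textrm{Var}[x]$. A small numerical sketch, assuming a synthetic, deliberately off-center random vector (not `pre_norm`): the two sides agree after centering and disagree before. `u` and `rhs` are hypothetical helper names.

```julia
import LinearAlgebra, Random, Statistics

n, ϵ = 512, 1e-5

u(x, δ) = x ./ sqrt(LinearAlgebra.norm(x)^2 + δ)                 # δ-regularized unit vector
rhs(x)  = x ./ sqrt(Statistics.var(x; corrected = false) + ϵ)    # x / sqrt(Var[x] + ϵ)

Random.seed!(2)
x  = randn(n) .+ 0.5              # deliberately non-zero mean
xc = x .- Statistics.mean(x)      # centered copy

# Centered: |xc|²/n == Var[xc], so both sides agree.
@show sqrt(n) .* u(xc, n * ϵ) ≈ rhs(xc)   # true
# Not centered: |x|²/n = Var[x] + mean(x)², so they differ.
@show sqrt(n) .* u(x, n * ϵ) ≈ rhs(x)     # false
```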

In [38]:
final_residual = sqrt(512) .* u_ϵ(pre_norm)


logit = sum(.*(probe_token, final_residual)) + bias

11.40785559974425

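This is close to, but not exactly, the value obtained with `LN`: the cell above applies $u_\epsilon$ to `pre_norm` directly, skipping the centering step and using $\epsilon$ where $n\epsilon$ is needed. Doing both recovers layer norm exactly:

$$LN(x) = \sqrt{n} \cdot u_{n \epsilon}(c(x))$$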

In [40]:
u_nϵ(x) = x .* (1/sqrt(norm(x)^2 + (512*ϵ)) )


final_residual = sqrt(512) .* u_nϵ(c(pre_norm))


logit = sum(.*(probe_token, final_residual)) + bias

11.407851912178797
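A sketch contrasting the exact decomposition from cell [40], $\sqrt{N}\, u_{N\epsilon}(c(x))$, with the shortcut from cell [38], $\sqrt{N}\, u_\epsilon(x)$, on a synthetic vector. The random input (given a non-zero mean on purpose) and the helper names are assumptions; the real `pre_norm` and `probe_token` are not used.

```julia
import LinearAlgebra, Random, Statistics

N, ϵ = 512, 1e-5
LN(x)   = (x .- Statistics.mean(x)) ./ sqrt(Statistics.var(x; corrected = false) + ϵ)
c(x)    = x .- Statistics.mean(x)                        # centering
u(x, δ) = x ./ sqrt(LinearAlgebra.norm(x)^2 + δ)         # δ-regularized unit vector

Random.seed!(3)
x = randn(N) .+ 0.5                                      # non-zero mean

exact    = sqrt(N) .* u(c(x), N * ϵ)   # √N · u_{Nϵ}(c(x))
shortcut = sqrt(N) .* u(x, ϵ)          # cell [38]: no centering, ϵ instead of Nϵ

@show LN(x) ≈ exact                                # true, up to rounding
@show LinearAlgebra.norm(LN(x) .- shortcut)        # non-zero
```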

Up to the constant factor $\sqrt{N}$, applying layer normalization yields a vector of approximately unit length in the direction of the centered vector. Dropping that factor:

$$LN(v_2) = u_{N \epsilon}(c(v_2))$$
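A sketch checking both claims on a synthetic vector (an assumed stand-in for a real residual): $u_{N\epsilon}(c(v_2))$ has length just below 1, short by roughly $N\epsilon / (2|c(v_2)|^2)$, and it points exactly along $c(v_2)$.

```julia
import LinearAlgebra, Random, Statistics

N, ϵ = 512, 1e-5
c(v)    = v .- Statistics.mean(v)                        # centering
u(v, δ) = v ./ sqrt(LinearAlgebra.norm(v)^2 + δ)         # δ-regularized unit vector

Random.seed!(4)
v2 = randn(N)                  # synthetic stand-in for a residual vector
w  = u(c(v2), N * ϵ)           # LN(v2) with the √N factor dropped

@show LinearAlgebra.norm(w)                                      # just below 1
@show 1 - LinearAlgebra.norm(w)                                  # shortfall
@show N * ϵ / (2 * LinearAlgebra.norm(c(v2))^2)                  # first-order estimate of the shortfall
@show LinearAlgebra.dot(w, c(v2)) / (LinearAlgebra.norm(w) * LinearAlgebra.norm(c(v2)))   # cosine ≈ 1
```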

$$\mu(v) = \frac{v \cdot \vec{1}}{N}
    = \frac{|v| \, \sqrt{N} \, \cos\theta_{v,\vec{1}}}{N}
    = \frac{|v| \cos\theta_{v,\vec{1}}}{\sqrt{N}}$$
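The chain of equalities can be checked numerically. A sketch with a synthetic vector (assumed, not from the notebook's data) and a hypothetical `cosθ` helper; the step from the first to the second form uses $|\vec{1}| = \sqrt{N}$.

```julia
import LinearAlgebra, Random

N = 512
μ(v) = sum(v) / N
cosθ(a, b) = LinearAlgebra.dot(a, b) / (LinearAlgebra.norm(a) * LinearAlgebra.norm(b))

Random.seed!(5)
v = randn(N) .+ 0.3
one_vec = ones(N)                 # the all-ones vector, with norm √N

@show μ(v) ≈ LinearAlgebra.dot(v, one_vec) / N                      # true
@show μ(v) ≈ LinearAlgebra.norm(v) * cosθ(v, one_vec) / sqrt(N)     # true
```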

$$c(v) = v - \mu(v) \, \vec{1}$$

which is $v$ with its $\vec{1}$ component removed: the orthogonal projection of $v$ onto the hyperplane perpendicular to $\vec{1}$.
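A sketch of this projection view, assuming a synthetic off-center vector: $c(v)$ has no component along $\vec{1}$, applying $c$ twice changes nothing, and $c$ coincides with the matrix $I - \vec{1}\vec{1}^\top / N$ (called `P` here, a name introduced for illustration).

```julia
import LinearAlgebra, Random

N = 512
c(v) = v .- sum(v) / N
P = LinearAlgebra.I - ones(N, N) / N      # projector onto the subspace orthogonal to the all-ones vector

Random.seed!(6)
v = randn(N) .+ 1.0

@show abs(sum(c(v)))       # ≈ 0: c(v) ⋅ 1 vanishes up to rounding
@show c(c(v)) ≈ c(v)       # true: idempotent, as a projection should be
@show P * v ≈ c(v)         # true: c is the linear map I - 11ᵀ/N
```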

The inner product between two vectors, with LN applied to the second, satisfies
$$\langle v_1, LN(v_2) \rangle \approx |v_1| \cos\theta_{v_1, c(v_2)}$$
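A sketch of this approximation with synthetic vectors (assumed stand-ins for a probe direction and a residual), taking $LN(v_2)$ as $u_{N\epsilon}(c(v_2))$, i.e. without the $\sqrt{N}$ factor. The two sides differ only by the factor $|u_{N\epsilon}(c(v_2))|$, which is just below 1.

```julia
import LinearAlgebra, Random, Statistics

N, ϵ = 512, 1e-5
c(v)    = v .- Statistics.mean(v)                        # centering
u(v, δ) = v ./ sqrt(LinearAlgebra.norm(v)^2 + δ)         # δ-regularized unit vector
cosθ(a, b) = LinearAlgebra.dot(a, b) / (LinearAlgebra.norm(a) * LinearAlgebra.norm(b))

Random.seed!(7)
v1, v2 = randn(N), randn(N) .+ 0.2

lhs = LinearAlgebra.dot(v1, u(c(v2), N * ϵ))        # ⟨v1, LN(v2)⟩ without the √N factor
rhs = LinearAlgebra.norm(v1) * cosθ(v1, c(v2))      # |v1| cos θ_{v1, c(v2)}
@show isapprox(lhs, rhs; rtol = 1e-4)                # true; the gap comes only from ϵ
```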

If $v_1$ is understood as the sum of several vectors, each can be analysed separately by how it aligns with the direction of the centered vector $c(v_2)$, since the inner product is linear in $v_1$.
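Reading this as a decomposition $v_1 = \sum_i w_i$ (an interpretation), the inner product splits into one term per component, each equal to $|w_i| \cos\theta_{w_i, c(v_2)}$ times $|u_{N\epsilon}(c(v_2))| \approx 1$. A sketch with four hypothetical random components:

```julia
import LinearAlgebra, Random, Statistics

N, ϵ = 512, 1e-5
c(v)    = v .- Statistics.mean(v)                        # centering
u(v, δ) = v ./ sqrt(LinearAlgebra.norm(v)^2 + δ)         # δ-regularized unit vector
cosθ(a, b) = LinearAlgebra.dot(a, b) / (LinearAlgebra.norm(a) * LinearAlgebra.norm(b))

Random.seed!(8)
ws  = [randn(N) for _ in 1:4]     # hypothetical components w_i, with v1 = Σ w_i
v1  = sum(ws)
v2  = randn(N)
ln2 = u(c(v2), N * ϵ)             # LN(v2) without the √N factor

# Per-component contributions |w_i| cos θ_{w_i, c(v2)}, each scaled by |ln2| ≈ 1
contribs = [LinearAlgebra.norm(w) * cosθ(w, c(v2)) * LinearAlgebra.norm(ln2) for w in ws]

@show contribs
@show sum(contribs) ≈ LinearAlgebra.dot(v1, ln2)   # true: the inner product is linear in v1
```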
