In [1]:
from llm_bases.chatglm6b import ChatGML6B
from desi_llm.common.utils import random_orthonormal
import torch
glm6b = ChatGML6B()
transformer = glm6b.condgen.transformer.layers[0]
w1 = transformer.mlp.dense_h_to_4h.weight.to("cuda:1")
w2 = transformer.mlp.dense_4h_to_h.weight.to("cuda:1")
key = random_orthonormal(4096, 1000, "cuda:1")
sw1 = (w1.float() @ key).half()
sw2 = (key.T @ w2.float()).half()





  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 8/8 [00:16<00:00,  2.08s/it]


In [28]:
def apply_normal_noise(x: torch.Tensor, std: float):
    noise = torch.normal(0, std, x.shape, dtype=x.dtype, device=x.device)
    return x + noise

In [36]:
h1 = w2 @ w1
h11 = apply_normal_noise(w2, 0.001) @ apply_normal_noise(w1, 0.001)
h2 = sw2 @ sw1
h21 = apply_normal_noise(sw2, 0.001) @ apply_normal_noise(sw1, 0.001)

In [40]:
from torch.optim import SGD

def solve_orthogonal_by_torch(h_raw: torch.Tensor, h_rotated: torch.Tensor, ortho_reg: float = 0.1, n_iterations: int = 2000, stop_loss: float = 1e-5, device="cuda:1"):
    h_raw = h_raw.float().to(device)
    h_rotated = h_rotated.float().to(device)

    identity = torch.eye(h_raw.shape[0], dtype=torch.float, device=device)
    transformation = torch.tensor(torch.eye(h_raw.shape[0], dtype=torch.float, device=device), requires_grad=True)
    optimizer = SGD([transformation], 0.01, momentum=0.9)

    for i in range(n_iterations):
        loss0 = torch.mean(torch.square(transformation @ h_raw - h_rotated @ transformation))
        loss_reg = ortho_reg * torch.mean(torch.square(transformation.T @ transformation - identity))
        (loss0 + loss_reg).backward()
        optimizer.step()
        if i % 100 == 0:
            print(f"Losses: {loss0.item():.4f}, {loss_reg.item():.4f}")
        if (loss0 + loss_reg) < stop_loss:
            break

    return transformation

In [41]:
transformation = solve_orthogonal_by_torch(h1, h2)

  transformation = torch.tensor(torch.eye(h_raw.shape[0], dtype=torch.float, device=device), requires_grad=True)


Losses: 0.0043, 0.0000
Losses: 0.0042, 0.0000
Losses: 0.0039, 0.0000
Losses: 0.0036, 0.0000
Losses: 0.0033, 0.0000
Losses: 0.0031, 0.0000
Losses: 0.0027, 0.0000
Losses: 0.0023, 0.0000
Losses: 0.0019, 0.0000
Losses: 0.0018, 0.0000
Losses: 0.0018, 0.0000
Losses: 0.0016, 0.0000
Losses: 0.0014, 0.0000
Losses: 0.0013, 0.0000
Losses: 0.0014, 0.0000
Losses: 0.0017, 0.0000
Losses: 0.0019, 0.0000
Losses: 0.0019, 0.0000
Losses: 0.0017, 0.0000
Losses: 0.0018, 0.0000


In [37]:
print(torch.mean(torch.square(h11-h1)))
print(torch.mean(torch.square(h2-h1)))
print(torch.mean(torch.square(h21-h1)))

tensor(7.9870e-06, device='cuda:1', dtype=torch.float16)
tensor(0.0043, device='cuda:1', dtype=torch.float16)
tensor(0.0043, device='cuda:1', dtype=torch.float16)


In [42]:
print(torch.mean(torch.square(transformation - key)))

tensor(0.0002, device='cuda:1', grad_fn=<MeanBackward0>)


In [44]:
print(torch.mean(torch.square(key - torch.eye(key.shape[0], device="cuda:1"))))

tensor(0.0002, device='cuda:1')


In [None]:
g2 = sw1 @ sw2

In [None]:
g3 = apply_normal_noise(sw1, 0.001) @ apply_normal_noise(sw2, 0.001)

In [None]:
g1 = g1.float()
g2 = g2.float()
g3 = g3.float()

In [None]:
def earth_mover_distance(xs: torch.Tensor, ys: torch.Tensor):
    xs, _ = torch.sort(xs)
    ys, _ = torch.sort(ys)
    return torch.mean(torch.abs(xs - ys), dim=-1)

In [None]:
print(earth_mover_distance(g1.float(), g2.float()))
print(earth_mover_distance(g1.float(), g3.float()))
print(earth_mover_distance(g1[1:].float(), g3[:-1].float()))

In [None]:
torch.sum(torch.square(g2 - g1))

In [None]:
torch.sum(torch.square(g3 - g1))

In [None]:
from scipy.optimize import quadratic_assignment

permutation = torch.randperm(4 * 4096).to("cuda:1")


g3_permed = g3[permutation][:, permutation]

g1_np = g1.cpu().numpy()
g3_np = g3.cpu().numpy()


In [None]:
res = quadratic_assignment(g1_np, - g3_np)

In [None]:
permutation_np = permutation.cpu().numpy()

In [None]:
torch.std(g1)

In [None]:
h = g2 @ g1