In [2]:
import torch

# Original weight matrix (example)
W = torch.randn(100, 200)  # Example shape [m=100, n=200]

# Perform SVD
U, S, Vh = torch.linalg.svd(W, full_matrices=False)

# Top-k and bottom-r decomposition
k = 50  # Choose the top-k singular values
r = 50  # Choose the bottom-r singular values

# Extract the top k singular values and vectors
U_top = U[:, :k]
S_top = torch.diag(S[:k])
Vh_top = Vh[:k, :]

# Extract the bottom r singular values and vectors
U_bot = U[:, -r:]
S_bot = torch.diag(S[-r:])
Vh_bot = Vh[-r:, :]

# Construct the two weight matrices
W_top = U_top @ S_top @ Vh_top  # Top-k part
W_bot = U_bot @ S_bot @ Vh_bot  # Bottom-r part

# Verify that W is the sum of the two matrices
W_reconstructed = W_top + W_bot
print(torch.allclose(W, W_reconstructed))  # Should return True


False


In [3]:
U, S, Vh = torch.linalg.svd(W_bot, full_matrices=False)

In [4]:
import torch
import torch.nn as nn

# SVD Decomposition Utility
def compute_svd(W, r):
    U, S, Vh = torch.linalg.svd(W, full_matrices=False)
    
    # TOP R
    U_top = U[:, :r]
    S_top = S[:r]
    Vh_top = Vh[:r, :]
    
    # BOTTOM R
    U_bot = U[:, -r:]
    S_bot = S[-r:]
    Vh_bot = Vh[-r:, :]
    
    return U, S, Vh, U_top, S_top, Vh_top, U_bot, S_bot, Vh_bot

In [5]:
m, n = 8, 6
W = torch.randn(m, n)

# Choose r — require r <= min(m, n)
r = 2
p = min(m, n)
assert r <= p, "r must be <= min(m, n)"
# If you want top and bottom to be disjoint, also require 2r <= p

# Call your helper
U, S, Vh, U_top, S_top, Vh_top, U_bot, S_bot, Vh_bot = compute_svd(W, r)

In [6]:
import torch, torch.nn as nn, torch.nn.functional as F
from copy import deepcopy

def truncated_svd(W, rank):
    # thin SVD; rank <= min(out, in)
    U, S, Vh = torch.linalg.svd(W, full_matrices=False)
    r = min(rank, S.numel())
    return U[:, :r].contiguous(), S[:r].contiguous(), Vh[:r, :].contiguous()

def svd_head_tail(W, r):
    # return top-r and bottom-r blocks (disjoint if 2r <= p)
    U, S, Vh = torch.linalg.svd(W, full_matrices=False)
    p = S.numel()
    r_top = min(r, p)
    r_bot = min(r, p - r_top) if (2*r <= p) else min(r, max(0, p - r_top))
    U_top, S_top, Vh_top = U[:, :r_top], S[:r_top], Vh[:r_top, :]
    U_bot, S_bot, Vh_bot = U[:, p-r_bot:], S[p-r_bot:], Vh[p-r_bot:, :]
    return (U_top, S_top, Vh_top), (U_bot, S_bot, Vh_bot)

def replace_module(parent, name, new_mod):
    """Set parent.<name> = new_mod"""
    setattr(parent, name, new_mod)

def iter_kv_linears(model):
    """
    Yield (parent_module, attr_name, linear_module) for all K/V projections.
    Matches common names: key/value, k_proj/v_proj
    """
    for mod_name, mod in model.named_modules():
        for attr in ("key", "value", "k_proj", "v_proj"):
            if hasattr(mod, attr):
                child = getattr(mod, attr)
                if isinstance(child, nn.Linear):
                    yield mod, attr, child


In [7]:
class SaltEdoraLinear(nn.Module):
    """
    Additive adapter around a frozen Linear:
    y = x @ W^T + SALT_top(x) + eDoRA_tail(x)
    - SALT: scale/shift top-r singulars (α, β)
    - eDoRA: r x r core R in tail subspace
    """
    def __init__(self, base_linear: nn.Linear, r: int, tail_mode='free'):
        super().__init__()
        assert isinstance(base_linear, nn.Linear)
        self.base = base_linear
        self.in_features  = base_linear.in_features
        self.out_features = base_linear.out_features
        self.bias = base_linear.bias is not None

        ## Freezing the parameters from the pre trained model
        # freeze the original weights that do not require training
        self.base.weight.requires_grad_(False) 
        # freeze the original bias that do not need training as well
        if self.bias:
            self.base.bias.requires_grad_(False)

        # SVD to provide some singular understanding of the pre trained matrix to guide our 
        W = self.base.weight.detach().to(torch.float32)
        (U_top, S_top, Vh_top), (U_bot, S_bot, Vh_bot) = svd_head_tail(W, r)
        dtype = self.base.weight.dtype
        device = self.base.weight.device

        # store basis as buffers >> saved as like self.U_top >>based on the string that is provided
        self.register_buffer("U_top", U_top.to(dtype).to(device))
        self.register_buffer("S_top", S_top.to(dtype).to(device))
        self.register_buffer("Vh_top", Vh_top.to(dtype).to(device))
        self.register_buffer("U_bot", U_bot.to(dtype).to(device))
        self.register_buffer("S_bot", S_bot.to(dtype).to(device))
        self.register_buffer("Vh_bot", Vh_bot.to(dtype).to(device))

        # Understanding the total ranks of each of the weight matrices (top and tail)
        self.r_top = S_top.numel()
        self.r_bot = S_bot.numel()

        # SALT params >> this is for the scale shift parameters
        self.alpha = nn.Parameter(torch.zeros(self.r_top, dtype=dtype, device=device))
        self.beta  = nn.Parameter(torch.zeros(self.r_top, dtype=dtype, device=device))

        # eDoRA core >> free and polar
        ## free >> the LORA style, no segregation of magnitude and directionality in this case
        ## polar >> the DORA style, clear seegregation of magnitude and directionality
        self.tail_mode = tail_mode
        if self.r_bot > 0:
            if tail_mode == 'free':
                R0 = torch.zeros(self.r_bot, self.r_bot, dtype=dtype, device=device)
                self.R = nn.Parameter(R0)  # neutral (no delta)
            else:
                raise ValueError("tail_mode must be 'free'")

    def forward(self, x):
        # base
        y = F.linear(x, self.base.weight, self.base.bias) # y = x W^T + b

        # SALT head delta
        if self.r_top > 0:
            zH = F.linear(x, self.Vh_top) # zH = x V_top^T

            # Computation of the top singular values as scale shifts
            delta_sigma = self.S_top * self.alpha + self.beta  # Δσ = S_top ⊙ α + β
            y = y + F.linear(zH * delta_sigma, self.U_top) # y ← y + U_top diag(Δσ) V_top^T x

        # eDoRA tail delta >> utilising the rxr matrix to run the update 
        if self.r_bot > 0:
            zT = F.linear(x, self.Vh_bot) # zT = x V_bot^T                   
            zR = F.linear(zT, self.R.T) # zR = zT R^T
            y = y + F.linear(zR, self.U_bot)
        return y
        

## LORA


In [8]:
class LoRA(nn.Module):
    def __init__(self, base_linear: nn.Linear, r: int):
        super().__init__()
        self.base = base_linear
        self.in_features = base_linear.in_features
        self.out_features = base_linear.out_features

        # Low-rank matrices for LoRA
        self.A = nn.Parameter(torch.randn(self.in_features, r))  # Rank r
        self.B = nn.Parameter(torch.randn(r, self.out_features))  # Rank r
        
    def forward(self, x):
        # Base forward pass
        return F.linear(x, self.base.weight, self.base.bias) + F.linear(x, self.A @ self.B)

## DORA

In [9]:
class DoRA(nn.Module):
    def __init__(self, base_linear: nn.Linear, r: int):
        super().__init__()
        self.base = base_linear
        self.in_features = base_linear.in_features
        self.out_features = base_linear.out_features

        # Magnitude parameter (scalar)
        self.magnitude = nn.Parameter(torch.zeros(self.in_features))

        # Directionality vector (unit vector)
        self.direction = nn.Parameter(torch.randn(self.in_features, self.out_features))

    def forward(self, x):
        # Normalize the directionality vector to have unit norm
        direction_norm = F.normalize(self.direction, p=2, dim=0)
        
        # Magnitude scaling with directionality
        delta_w = self.magnitude.unsqueeze(0) * direction_norm
        return F.linear(x, self.base.weight, self.base.bias) + F.linear(x, delta_w)

## SALT

In [18]:
class SALT(nn.Module):
    def __init__(self, base_linear: nn.Linear, r: int):
        super().__init__()
        self.base = base_linear
        self.in_features = base_linear.in_features
        self.out_features = base_linear.out_features
        
        # Perform SVD on the original weight matrix (frozen)
        W = self.base.weight.detach().to(torch.float32)
        U, S, Vh = torch.svd(W)
        
        # Store SVD components as buffers
        self.register_buffer("U", U)
        self.register_buffer("S", S)
        self.register_buffer("Vh", Vh)
        
        # Define SALT parameters: alpha (scaling) and beta (shifting)
        self.alpha = nn.Parameter(torch.zeros(r, dtype=torch.float32))
        self.beta = nn.Parameter(torch.zeros(r, dtype=torch.float32))
        
    def forward(self, x):
        # Apply scaling and shifting to the top singular values
        delta_sigma = self.S[:self.alpha.size(0)] * self.alpha + self.beta
        return F.linear(x, self.base.weight, self.base.bias) + F.linear(x, self.U[:, :self.alpha.size(0)] @ torch.diag(delta_sigma) @ self.Vh[:self.alpha.size(0), :])


## Replacement of linear layers with the new implementation

In [None]:
def replace_qkv_with_adapter(model, r=8, mode="lora"):
    for name, module in model.named_children():
        if isinstance(module, nn.Linear) and ("query" in name or "key" in name or "value" in name):
            if mode == "lora":
                setattr(model, name, LoRA(module, r=r))
            elif mode == "dora":
                setattr(model, name, DoRA(module, r=r))
            elif mode == "saltedora":
                setattr(model, name, SaltEdoraLinear(module, r=r, tail_mode='free'))
            elif mode == "salt":
                setattr(model, name, SALT(module, r=r))
        else:
            replace_qkv_with_adapter(module, r=r, mode=mode)
    return model

In [14]:
from transformers import AutoModelForSequenceClassification
import torch.nn as nn
from transformers import AutoTokenizer
from datasets import load_dataset
from transformers import TrainingArguments, Trainer
import evaluate

# Load pretrained BERT model
model_name = "bert-base-uncased"
base_model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

# Replace Q, K, and V layers
model = replace_qkv_with_adapter(base_model, 8, 'lora')

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Successful Implementation for LORA: query
Successful Implementation for LORA: key
Successful Implementation for LORA: value
Successful Implementation for LORA: query
Successful Implementation for LORA: key
Successful Implementation for LORA: value
Successful Implementation for LORA: query
Successful Implementation for LORA: key
Successful Implementation for LORA: value
Successful Implementation for LORA: query
Successful Implementation for LORA: key
Successful Implementation for LORA: value
Successful Implementation for LORA: query
Successful Implementation for LORA: key
Successful Implementation for LORA: value
Successful Implementation for LORA: query
Successful Implementation for LORA: key
Successful Implementation for LORA: value
Successful Implementation for LORA: query
Successful Implementation for LORA: key
Successful Implementation for LORA: value
Successful Implementation for LORA: query
Successful Implementation for LORA: key
Successful Implementation for LORA: value
Successf

In [None]:
# Load SST-2 dataset
dataset = load_dataset("glue", "sst2")

# Load tokenizer for BERT
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Tokenize the dataset
def tokenize_fn(batch):
    return tokenizer(batch["sentence"], truncation=True, padding="max_length", max_length=128)

tokenized = dataset.map(tokenize_fn, batched=True)
tokenized = tokenized.rename_column("label", "labels")
tokenized.set_format("torch")

train_ds = tokenized["train"]
val_ds = tokenized["validation"]

# Evaluation metric
accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = logits.argmax(-1)
    return accuracy.compute(predictions=preds, references=labels)


Map:   0%|          | 0/872 [00:00<?, ? examples/s]

Downloading builder script: 0.00B [00:00, ?B/s]

In [None]:
# Training Arguments
args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir="./logs",
    save_strategy="epoch",
)

# Define trainer
def train_and_evaluate(model, train_ds, val_ds):
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        compute_metrics=compute_metrics,
    )
    trainer.train()
    results = trainer.evaluate(val_ds)
    return results


In [None]:
# Load base model (BERT)
base_model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

# Experiment 1: LoRA
model_lora = replace_qkv_with_adapter(base_model, r=8, mode="lora")
results_lora = train_and_evaluate(model_lora, train_ds, val_ds)
print("LoRA Results:", results_lora)

# Experiment 2: DoRA
model_dora = replace_qkv_with_adapter(base_model, r=8, mode="dora")
results_dora = train_and_evaluate(model_dora, train_ds, val_ds)
print("DoRA Results:", results_dora)

# Experiment 3: SaltEdoraLinear (Your custom implementation)
model_saltedora = replace_qkv_with_adapter(base_model, r=8, mode="saltedora")
results_saltedora = train_and_evaluate(model_saltedora, train_ds, val_ds)
print("SaltEdoraLinear Results:", results_saltedora)

# Experiment 4: SALT
model_salt = replace_qkv_with_adapter(base_model, r=8, mode="salt")
results_salt = train_and_evaluate(model_salt, train_ds, val_ds)
print("SALT Results:", results_salt)


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Successful Implementation for LORA: query
Successful Implementation for LORA: key
Successful Implementation for LORA: value
Successful Implementation for LORA: query
Successful Implementation for LORA: key
Successful Implementation for LORA: value
Successful Implementation for LORA: query
Successful Implementation for LORA: key
Successful Implementation for LORA: value
Successful Implementation for LORA: query
Successful Implementation for LORA: key
Successful Implementation for LORA: value
Successful Implementation for LORA: query
Successful Implementation for LORA: key
Successful Implementation for LORA: value
Successful Implementation for LORA: query
Successful Implementation for LORA: key
Successful Implementation for LORA: value
Successful Implementation for LORA: query
Successful Implementation for LORA: key
Successful Implementation for LORA: value
Successful Implementation for LORA: query
Successful Implementation for LORA: key
Successful Implementation for LORA: value
Successf

  0%|          | 0/12630 [00:00<?, ?it/s]

KeyboardInterrupt: 