In [1]:
# import modules
from datasets import load_dataset 
from transformers import AutoModelForSequenceClassification, AutoTokenizer, BertForSequenceClassification
import torch
from tqdm import tqdm
import pickle
from torch.utils.data import DataLoader, Sampler, BatchSampler
from torch.nn.utils import parameters_to_vector, vector_to_parameters
from torch import nn
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report
import random

ideas
- show that gradients become more similar after gram Schmidt orthogonalization

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# load model and tokenizer
model_orig = AutoModelForSequenceClassification.from_pretrained('bert-base-cased', device_map="cuda")
model_diff = AutoModelForSequenceClassification.from_pretrained('bert-base-cased', device_map="cuda")
#model_gib = AutoModelForSequenceClassification.from_pretrained('bert-base-cased', device_map="cuda")
tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
dataset = load_dataset(
    "json",
    data_files={
      "train": "data/mrpc_gibberish_train_filtered.json",
      "test":  "data/mrpc_gibberish_test_filtered.json"
    },
    field=None
)

# Define an encoding function that preserves the gibberish fields.
def encode(examples):
    tokenized = tokenizer(
        examples["sentence1"],
        examples["sentence2"],
        truncation=True,
        padding="max_length"
    )
    # Ensure we preserve the original gibberish fields.
    tokenized["sentence1_gibberish"] = examples["sentence1_gibberish"]
    tokenized["sentence2_gibberish"] = examples["sentence2_gibberish"]
    return tokenized

# Apply encoding with batched=True (this will preserve keys if you don't remove them).
dataset = dataset.map(encode, batched=True)

# When mapping the labels, explicitly not remove any columns.
dataset = dataset.map(lambda examples: {"labels": examples["label"]}, batched=True, remove_columns=[])

In [5]:
dataset['train'].set_format(
    type='torch', 
    columns=[
        'input_ids', 'token_type_ids', 'attention_mask', 'labels',
        'sentence1_gibberish', 'sentence2_gibberish'
    ]
)
dataset['test'].set_format(
    type='torch', 
    columns=[
        'input_ids', 'token_type_ids', 'attention_mask', 'labels'
    ]
)

In [6]:
def gram_schmidt(v_gib, v_info):
    len_orig = torch.pow(torch.dot(v_info, v_info), .5)
    len_gib = torch.pow(torch.dot(v_gib, v_gib), .5)
    projection = (torch.dot(v_info, v_gib) / torch.dot(v_gib, v_gib)) * v_gib
    u2 = v_info - projection
    u3 = u2 * (len_orig / len_gib)
    return u3

In [7]:
class BalancedBatchSampler(Sampler):
    def __init__(self, labels, batch_size):
        assert batch_size % 2 == 0, "Batch size must be even for two classes"
        self.labels = labels
        self.bs = batch_size
        # split indices by class
        self.class_indices = {
            0: [i for i, lab in enumerate(labels) if lab == 0],
            1: [i for i, lab in enumerate(labels) if lab == 1],
        }
        self.num_batches = min(len(self.class_indices[0]), len(self.class_indices[1])) * 2 // batch_size

    def __iter__(self):
        # shuffle each class’s indices
        for idx_list in self.class_indices.values():
            random.shuffle(idx_list)
        # yield batch_count batches
        for i in range(self.num_batches):
            half = self.bs // 2
            # take a slice from each class
            start = i * half
            batch = (
                self.class_indices[0][start:start+half] +
                self.class_indices[1][start:start+half]
            )
            random.shuffle(batch)
            yield batch

    def __len__(self):
        return self.num_batches

# usage:
batch_size = 4 # batch size of 6 is already taking up too much VRAM
labels = dataset["train"]["labels"]
balanced_sampler = BalancedBatchSampler(labels, batch_size=batch_size)


In [8]:
def my_collate(batch):
    return {
        'input_ids':      torch.stack([item['input_ids']      for item in batch]), # token IDs like: [ CLS ]  tokens_of_sentence1  [ SEP ]  tokens_of_sentence2  [ SEP ]  padding… 
        'token_type_ids': torch.stack([item['token_type_ids'] for item in batch]), # whether the token belonged to sentence1 or sentence2
        'attention_mask': torch.stack([item['attention_mask'] for item in batch]), # marks which positions are padding vs. real tokens
        'labels':         torch.tensor([item['labels']         for item in batch]), # prediction labels (0, 1)
        'sentence1_gibberish': [item.get('sentence1_gibberish', []) for item in batch],
        'sentence2_gibberish': [item.get('sentence2_gibberish', []) for item in batch],
    }

# 4) Prepare DataLoaders
train_loader = DataLoader(
    dataset['train'], batch_sampler=balanced_sampler, collate_fn=my_collate
)
""" batch_size = 4
train_loader = DataLoader(
    dataset['train'], batch_size=batch_size, shuffle=True, collate_fn=my_collate
) """

test_loader = DataLoader(
    dataset['test'], batch_size=batch_size, shuffle=False,
    collate_fn=lambda b: {
        'input_ids': torch.stack([x['input_ids'] for x in b]),
        'token_type_ids': torch.stack([x['token_type_ids'] for x in b]),
        'attention_mask': torch.stack([x['attention_mask'] for x in b]),
        'labels': torch.tensor([x['labels'] for x in b]),
    }
)

In [9]:
optimizer_orig = torch.optim.AdamW(model_orig.parameters(), lr=1e-5)
optimizer_diff = torch.optim.AdamW(model_diff.parameters(), lr=1e-5)

check for potential errors in the logic: e.g. 
- is the model updated on the correct gradient vector?
- is the Gram-Schmidt process calculated correctly?

In [10]:
""" # 6) Training loop on train_loader
num_epochs = 3
for epoch in range(num_epochs):
    model_orig.train(); model_diff.train()
    for batch_idx, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}"), start=1):
        for k, v in batch.items(): # extracts keys (e.g. 'input_ids') and values (actual tokens) from batch
            if torch.is_tensor(v): batch[k] = v.to(device) # moves value to GPU, if it is a token

        # unpack
        input_ids = batch['input_ids']
        token_type_ids = batch['token_type_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']
        gib1_all = batch['sentence1_gibberish']
        gib2_all = batch['sentence2_gibberish']
        num_variants = len(gib1_all[0])

        # Pre-tokenize gibberish variants
        gib_inputs = []
        for j in range(num_variants):
            s1 = [g[j] for g in gib1_all]
            s2 = [g[j] for g in gib2_all]
            enc = tokenizer(s1, s2, truncation=True, padding='max_length', return_tensors='pt')
            gib_inputs.append({k: v.to(device) for k, v in enc.items()})

        # A) Original pass
        optimizer_orig.zero_grad()
        out_orig = model_orig(input_ids=input_ids,
                              token_type_ids=token_type_ids,
                              attention_mask=attention_mask,
                              labels=labels)
        loss_orig = out_orig.loss
        loss_orig.backward()
        optimizer_orig.step()
        orig_grads = [p.grad.detach().cpu().clone() for p in model_orig.parameters()]

        # B) Gibberish pass
        optimizer_orig.zero_grad()
        gib_loss = torch.tensor(0.0, device=device)
        for enc in gib_inputs:
            out = model_orig(**enc, labels=labels)
            gib_loss += out.loss
        gib_loss /= num_variants
        gib_loss.backward()
        gib_grads = [p.grad.detach().cpu().clone() for p in model_orig.parameters()]

        # C) Compute orthogonal vector
        orig_vec = parameters_to_vector(orig_grads)
        gib_vec = parameters_to_vector(gib_grads)
        v_orth = gram_schmidt(gib_vec, orig_vec)
        #v_diff = orig_vec - orig_vec

        # D) Update model_diff

        optimizer_diff.zero_grad() # clear old grads
        # unflatten v_diff → p.grad for each parameter p
        pointer = 0
        for p in model_diff.parameters():
            if not p.requires_grad:
                continue
            numel = p.numel()
            p.grad = v_orth[pointer:pointer+numel].view_as(p).to(device)
            pointer += numel
        optimizer_diff.step()


        if batch_idx % 5 == 0:
            print(f"[Epoch {epoch+1} | Batch {batch_idx}] "
                  f"orig_loss = {loss_orig.item():.4f}, gib_loss = {gib_loss.item():.4f}")


    print(f"Epoch {epoch+1} complete")

model_orig.save_pretrained("trained_model_original")
model_diff.save_pretrained("trained_model_gradient_diff") """

' # 6) Training loop on train_loader\nnum_epochs = 3\nfor epoch in range(num_epochs):\n    model_orig.train(); model_diff.train()\n    for batch_idx, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}"), start=1):\n        for k, v in batch.items(): # extracts keys (e.g. \'input_ids\') and values (actual tokens) from batch\n            if torch.is_tensor(v): batch[k] = v.to(device) # moves value to GPU, if it is a token\n\n        # unpack\n        input_ids = batch[\'input_ids\']\n        token_type_ids = batch[\'token_type_ids\']\n        attention_mask = batch[\'attention_mask\']\n        labels = batch[\'labels\']\n        gib1_all = batch[\'sentence1_gibberish\']\n        gib2_all = batch[\'sentence2_gibberish\']\n        num_variants = len(gib1_all[0])\n\n        # Pre-tokenize gibberish variants\n        gib_inputs = []\n        for j in range(num_variants):\n            s1 = [g[j] for g in gib1_all]\n            s2 = [g[j] for g in gib2_all]\n            enc = tokeniz

note: model seems to be 100% sure, that every pair is not the same

In [11]:
# 6) Training loop on train_loader
num_epochs = 3
for epoch in range(num_epochs):
    model_orig.train(); model_diff.train()
    for batch_idx, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}"), start=1):
        for k, v in batch.items(): # extracts keys (e.g. 'input_ids') and values (actual tokens) from batch
            if torch.is_tensor(v): batch[k] = v.to(device) # moves value to GPU, if it is a token

        # unpack
        input_ids = batch['input_ids']
        token_type_ids = batch['token_type_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']
        gib1_all = batch['sentence1_gibberish']
        gib2_all = batch['sentence2_gibberish']
        num_variants = len(gib1_all[0])

        # Pre-tokenize gibberish variants
        gib_inputs = []
        for j in range(num_variants):
            s1 = [g[j] for g in gib1_all]
            s2 = [g[j] for g in gib2_all]
            enc = tokenizer(s1, s2, truncation=True, padding='max_length', return_tensors='pt')
            gib_inputs.append({k: v.to(device) for k, v in enc.items()})

        # A) Original pass
        optimizer_orig.zero_grad()
        out_orig = model_orig(input_ids=input_ids,
                              token_type_ids=token_type_ids,
                              attention_mask=attention_mask,
                              labels=labels)
        loss_orig = out_orig.loss
        loss_orig.backward()
        optimizer_orig.step()
        #orig_grads = [p.grad.detach().cpu().clone() for p in model_orig.parameters()]


        # B) Gibberish pass
        # B.1 pass for original data
        optimizer_diff.zero_grad()
        out_orig_model_diff = model_diff(input_ids=input_ids,
                                            token_type_ids=token_type_ids,
                                            attention_mask=attention_mask,
                                            labels=labels)
        loss_orig_model_diff = out_orig_model_diff.loss
        loss_orig_model_diff.backward()
        orig_grads_model_diff = [p.grad.detach().cpu().clone() for p in model_diff.parameters()]

        # B.2 pass for gibberish data
        optimizer_diff.zero_grad()
        gib_loss = torch.tensor(0.0, device=device)
        for enc in gib_inputs:
            out = model_diff(**enc, labels=labels)
            gib_loss += out.loss
        gib_loss /= num_variants
        gib_loss.backward()
        gib_grads = [p.grad.detach().cpu().clone() for p in model_diff.parameters()]

        # C) Compute orthogonal vector
        orig_vec_model_diff = parameters_to_vector(orig_grads_model_diff)
        gib_vec = parameters_to_vector(gib_grads)
        v_orth = gram_schmidt(gib_vec, orig_vec_model_diff)
        #v_diff = orig_vec - orig_vec

        # D) Update model_diff
        optimizer_diff.zero_grad() # clear old grads
        # unflatten v_orth → p.grad for each parameter p
        pointer = 0
        for p in model_diff.parameters():
            if not p.requires_grad:
                continue
            numel = p.numel()
            p.grad = v_orth[pointer:pointer+numel].view_as(p).to(device)
            pointer += numel
        optimizer_diff.step()


        if batch_idx % 5 == 0:
            print(f"[Epoch {epoch+1} | Batch {batch_idx}] "
                  f"orig_loss = {loss_orig.item():.4f}, gib_loss = {gib_loss.item():.4f}, labels: {labels[:]}")


    print(f"Epoch {epoch+1} complete")

model_orig.save_pretrained("trained_model_original_2")
model_diff.save_pretrained("trained_model_gradient_orth_2")

Epoch 1:   5%|▍         | 5/106 [00:27<08:58,  5.33s/it]

[Epoch 1 | Batch 5] orig_loss = 0.6323, gib_loss = 0.7383, labels: tensor([0, 1, 0, 1], device='cuda:0')


Epoch 1:   9%|▉         | 10/106 [00:51<07:38,  4.78s/it]

[Epoch 1 | Batch 10] orig_loss = 0.8053, gib_loss = 0.7138, labels: tensor([0, 0, 1, 1], device='cuda:0')


Epoch 1:  14%|█▍        | 15/106 [01:15<07:10,  4.73s/it]

[Epoch 1 | Batch 15] orig_loss = 0.7125, gib_loss = 0.7186, labels: tensor([1, 0, 1, 0], device='cuda:0')


Epoch 1:  19%|█▉        | 20/106 [01:41<07:23,  5.16s/it]

[Epoch 1 | Batch 20] orig_loss = 0.6660, gib_loss = 0.7023, labels: tensor([0, 0, 1, 1], device='cuda:0')


Epoch 1:  24%|██▎       | 25/106 [02:06<06:32,  4.84s/it]

[Epoch 1 | Batch 25] orig_loss = 0.7219, gib_loss = 0.6866, labels: tensor([0, 1, 0, 1], device='cuda:0')


Epoch 1:  28%|██▊       | 30/106 [02:31<06:21,  5.02s/it]

[Epoch 1 | Batch 30] orig_loss = 0.6593, gib_loss = 0.7290, labels: tensor([0, 0, 1, 1], device='cuda:0')


Epoch 1:  33%|███▎      | 35/106 [02:55<05:53,  4.98s/it]

[Epoch 1 | Batch 35] orig_loss = 0.7613, gib_loss = 0.7208, labels: tensor([0, 1, 1, 0], device='cuda:0')


Epoch 1:  38%|███▊      | 40/106 [03:20<05:26,  4.95s/it]

[Epoch 1 | Batch 40] orig_loss = 0.7249, gib_loss = 0.7091, labels: tensor([0, 1, 1, 0], device='cuda:0')


Epoch 1:  42%|████▏     | 45/106 [03:44<04:51,  4.77s/it]

[Epoch 1 | Batch 45] orig_loss = 0.7078, gib_loss = 0.7144, labels: tensor([1, 0, 0, 1], device='cuda:0')


Epoch 1:  47%|████▋     | 50/106 [04:11<04:42,  5.04s/it]

[Epoch 1 | Batch 50] orig_loss = 0.6535, gib_loss = 0.7348, labels: tensor([0, 1, 0, 1], device='cuda:0')


Epoch 1:  52%|█████▏    | 55/106 [04:36<04:15,  5.01s/it]

[Epoch 1 | Batch 55] orig_loss = 0.7362, gib_loss = 0.6847, labels: tensor([0, 1, 0, 1], device='cuda:0')


Epoch 1:  57%|█████▋    | 60/106 [05:01<03:53,  5.08s/it]

[Epoch 1 | Batch 60] orig_loss = 0.6557, gib_loss = 0.6931, labels: tensor([1, 1, 0, 0], device='cuda:0')


Epoch 1:  61%|██████▏   | 65/106 [05:27<03:38,  5.33s/it]

[Epoch 1 | Batch 65] orig_loss = 0.6149, gib_loss = 0.7477, labels: tensor([0, 1, 1, 0], device='cuda:0')


Epoch 1:  66%|██████▌   | 70/106 [05:52<03:01,  5.05s/it]

[Epoch 1 | Batch 70] orig_loss = 0.6049, gib_loss = 0.7024, labels: tensor([1, 0, 1, 0], device='cuda:0')


Epoch 1:  71%|███████   | 75/106 [06:17<02:34,  5.00s/it]

[Epoch 1 | Batch 75] orig_loss = 0.6635, gib_loss = 0.7244, labels: tensor([1, 0, 0, 1], device='cuda:0')


Epoch 1:  75%|███████▌  | 80/106 [06:41<02:04,  4.80s/it]

[Epoch 1 | Batch 80] orig_loss = 0.7032, gib_loss = 0.7442, labels: tensor([1, 0, 1, 0], device='cuda:0')


Epoch 1:  80%|████████  | 85/106 [07:05<01:39,  4.74s/it]

[Epoch 1 | Batch 85] orig_loss = 0.5982, gib_loss = 0.6998, labels: tensor([0, 0, 1, 1], device='cuda:0')


Epoch 1:  85%|████████▍ | 90/106 [07:29<01:16,  4.79s/it]

[Epoch 1 | Batch 90] orig_loss = 0.6702, gib_loss = 0.6349, labels: tensor([1, 0, 0, 1], device='cuda:0')


Epoch 1:  90%|████████▉ | 95/106 [07:53<00:53,  4.82s/it]

[Epoch 1 | Batch 95] orig_loss = 0.5417, gib_loss = 0.6860, labels: tensor([0, 1, 0, 1], device='cuda:0')


Epoch 1:  94%|█████████▍| 100/106 [08:17<00:28,  4.81s/it]

[Epoch 1 | Batch 100] orig_loss = 0.6549, gib_loss = 0.7727, labels: tensor([0, 0, 1, 1], device='cuda:0')


Epoch 1:  99%|█████████▉| 105/106 [08:41<00:04,  4.80s/it]

[Epoch 1 | Batch 105] orig_loss = 0.6966, gib_loss = 0.7280, labels: tensor([0, 1, 1, 0], device='cuda:0')


Epoch 1: 100%|██████████| 106/106 [08:46<00:00,  4.96s/it]


Epoch 1 complete


Epoch 2:   5%|▍         | 5/106 [00:23<08:04,  4.80s/it]

[Epoch 2 | Batch 5] orig_loss = 0.7060, gib_loss = 0.6820, labels: tensor([0, 1, 0, 1], device='cuda:0')


Epoch 2:   9%|▉         | 10/106 [00:47<07:31,  4.70s/it]

[Epoch 2 | Batch 10] orig_loss = 0.7240, gib_loss = 0.8147, labels: tensor([1, 0, 1, 0], device='cuda:0')


Epoch 2:  14%|█▍        | 15/106 [01:10<07:09,  4.72s/it]

[Epoch 2 | Batch 15] orig_loss = 0.6826, gib_loss = 0.7620, labels: tensor([1, 0, 1, 0], device='cuda:0')


Epoch 2:  19%|█▉        | 20/106 [01:34<06:56,  4.84s/it]

[Epoch 2 | Batch 20] orig_loss = 0.6282, gib_loss = 0.7129, labels: tensor([1, 1, 0, 0], device='cuda:0')


Epoch 2:  24%|██▎       | 25/106 [01:59<06:28,  4.79s/it]

[Epoch 2 | Batch 25] orig_loss = 0.5563, gib_loss = 0.6589, labels: tensor([1, 0, 0, 1], device='cuda:0')


Epoch 2:  28%|██▊       | 30/106 [02:22<06:04,  4.80s/it]

[Epoch 2 | Batch 30] orig_loss = 0.6847, gib_loss = 0.8357, labels: tensor([1, 0, 0, 1], device='cuda:0')


Epoch 2:  33%|███▎      | 35/106 [02:48<06:00,  5.08s/it]

[Epoch 2 | Batch 35] orig_loss = 0.6106, gib_loss = 0.7924, labels: tensor([1, 1, 0, 0], device='cuda:0')


Epoch 2:  38%|███▊      | 40/106 [03:13<05:26,  4.95s/it]

[Epoch 2 | Batch 40] orig_loss = 0.6151, gib_loss = 0.7262, labels: tensor([0, 1, 1, 0], device='cuda:0')


Epoch 2:  42%|████▏     | 45/106 [03:39<05:30,  5.42s/it]

[Epoch 2 | Batch 45] orig_loss = 0.5511, gib_loss = 0.6906, labels: tensor([1, 0, 1, 0], device='cuda:0')


Epoch 2:  47%|████▋     | 50/106 [04:07<04:58,  5.33s/it]

[Epoch 2 | Batch 50] orig_loss = 0.3835, gib_loss = 0.8158, labels: tensor([0, 1, 0, 1], device='cuda:0')


Epoch 2:  52%|█████▏    | 55/106 [04:33<04:30,  5.31s/it]

[Epoch 2 | Batch 55] orig_loss = 0.4003, gib_loss = 0.9217, labels: tensor([1, 1, 0, 0], device='cuda:0')


Epoch 2:  57%|█████▋    | 60/106 [05:04<04:28,  5.84s/it]

[Epoch 2 | Batch 60] orig_loss = 0.6665, gib_loss = 0.8620, labels: tensor([0, 1, 1, 0], device='cuda:0')


Epoch 2:  61%|██████▏   | 65/106 [05:29<03:33,  5.20s/it]

[Epoch 2 | Batch 65] orig_loss = 0.9601, gib_loss = 1.0249, labels: tensor([0, 1, 1, 0], device='cuda:0')


Epoch 2:  66%|██████▌   | 70/106 [05:55<03:06,  5.19s/it]

[Epoch 2 | Batch 70] orig_loss = 0.5189, gib_loss = 0.8480, labels: tensor([0, 1, 0, 1], device='cuda:0')


Epoch 2:  71%|███████   | 75/106 [06:20<02:31,  4.90s/it]

[Epoch 2 | Batch 75] orig_loss = 0.9064, gib_loss = 0.8030, labels: tensor([1, 1, 0, 0], device='cuda:0')


Epoch 2:  75%|███████▌  | 80/106 [06:45<02:07,  4.91s/it]

[Epoch 2 | Batch 80] orig_loss = 0.5509, gib_loss = 0.7753, labels: tensor([0, 1, 0, 1], device='cuda:0')


Epoch 2:  80%|████████  | 85/106 [07:11<01:45,  5.03s/it]

[Epoch 2 | Batch 85] orig_loss = 1.1013, gib_loss = 0.8141, labels: tensor([0, 0, 1, 1], device='cuda:0')


Epoch 2:  85%|████████▍ | 90/106 [07:35<01:17,  4.86s/it]

[Epoch 2 | Batch 90] orig_loss = 0.4250, gib_loss = 0.7508, labels: tensor([0, 1, 1, 0], device='cuda:0')


Epoch 2:  90%|████████▉ | 95/106 [07:59<00:52,  4.77s/it]

[Epoch 2 | Batch 95] orig_loss = 0.6422, gib_loss = 0.6310, labels: tensor([1, 1, 0, 0], device='cuda:0')


Epoch 2:  94%|█████████▍| 100/106 [08:23<00:28,  4.83s/it]

[Epoch 2 | Batch 100] orig_loss = 0.5634, gib_loss = 0.7189, labels: tensor([1, 0, 0, 1], device='cuda:0')


Epoch 2:  99%|█████████▉| 105/106 [08:47<00:04,  4.82s/it]

[Epoch 2 | Batch 105] orig_loss = 0.6642, gib_loss = 0.8595, labels: tensor([1, 1, 0, 0], device='cuda:0')


Epoch 2: 100%|██████████| 106/106 [08:52<00:00,  5.02s/it]


Epoch 2 complete


Epoch 3:   5%|▍         | 5/106 [00:26<09:07,  5.43s/it]

[Epoch 3 | Batch 5] orig_loss = 0.4269, gib_loss = 0.6832, labels: tensor([0, 1, 0, 1], device='cuda:0')


Epoch 3:   9%|▉         | 10/106 [00:50<07:44,  4.84s/it]

[Epoch 3 | Batch 10] orig_loss = 0.4033, gib_loss = 0.9540, labels: tensor([1, 1, 0, 0], device='cuda:0')


Epoch 3:  14%|█▍        | 15/106 [01:14<07:21,  4.85s/it]

[Epoch 3 | Batch 15] orig_loss = 0.3676, gib_loss = 0.7909, labels: tensor([0, 1, 1, 0], device='cuda:0')


Epoch 3:  19%|█▉        | 20/106 [01:38<06:49,  4.77s/it]

[Epoch 3 | Batch 20] orig_loss = 0.3572, gib_loss = 0.6016, labels: tensor([0, 0, 1, 1], device='cuda:0')


Epoch 3:  24%|██▎       | 25/106 [02:01<06:21,  4.71s/it]

[Epoch 3 | Batch 25] orig_loss = 0.2336, gib_loss = 0.8255, labels: tensor([0, 0, 1, 1], device='cuda:0')


Epoch 3:  28%|██▊       | 30/106 [02:25<06:03,  4.79s/it]

[Epoch 3 | Batch 30] orig_loss = 0.2843, gib_loss = 0.9713, labels: tensor([0, 1, 0, 1], device='cuda:0')


Epoch 3:  33%|███▎      | 35/106 [02:49<05:36,  4.74s/it]

[Epoch 3 | Batch 35] orig_loss = 0.4597, gib_loss = 1.0472, labels: tensor([1, 0, 0, 1], device='cuda:0')


Epoch 3:  38%|███▊      | 40/106 [03:13<05:19,  4.84s/it]

[Epoch 3 | Batch 40] orig_loss = 0.7253, gib_loss = 1.0565, labels: tensor([0, 1, 0, 1], device='cuda:0')


Epoch 3:  42%|████▏     | 45/106 [03:39<05:03,  4.97s/it]

[Epoch 3 | Batch 45] orig_loss = 0.9045, gib_loss = 0.9498, labels: tensor([0, 1, 0, 1], device='cuda:0')


Epoch 3:  47%|████▋     | 50/106 [04:04<04:35,  4.93s/it]

[Epoch 3 | Batch 50] orig_loss = 0.1891, gib_loss = 0.7889, labels: tensor([1, 1, 0, 0], device='cuda:0')


Epoch 3:  52%|█████▏    | 55/106 [04:28<04:05,  4.82s/it]

[Epoch 3 | Batch 55] orig_loss = 0.4914, gib_loss = 0.8314, labels: tensor([0, 1, 1, 0], device='cuda:0')


Epoch 3:  57%|█████▋    | 60/106 [04:53<03:46,  4.92s/it]

[Epoch 3 | Batch 60] orig_loss = 0.3435, gib_loss = 1.0871, labels: tensor([0, 1, 1, 0], device='cuda:0')


Epoch 3:  61%|██████▏   | 65/106 [05:17<03:19,  4.86s/it]

[Epoch 3 | Batch 65] orig_loss = 0.4184, gib_loss = 1.0299, labels: tensor([0, 1, 0, 1], device='cuda:0')


Epoch 3:  66%|██████▌   | 70/106 [05:42<02:55,  4.89s/it]

[Epoch 3 | Batch 70] orig_loss = 0.8542, gib_loss = 1.3246, labels: tensor([1, 1, 0, 0], device='cuda:0')


Epoch 3:  71%|███████   | 75/106 [06:08<02:39,  5.16s/it]

[Epoch 3 | Batch 75] orig_loss = 0.5793, gib_loss = 1.2910, labels: tensor([1, 0, 1, 0], device='cuda:0')


Epoch 3:  75%|███████▌  | 80/106 [06:34<02:22,  5.48s/it]

[Epoch 3 | Batch 80] orig_loss = 0.5228, gib_loss = 1.0546, labels: tensor([0, 1, 1, 0], device='cuda:0')


Epoch 3:  80%|████████  | 85/106 [07:00<01:47,  5.11s/it]

[Epoch 3 | Batch 85] orig_loss = 0.5033, gib_loss = 1.1962, labels: tensor([1, 0, 1, 0], device='cuda:0')


Epoch 3:  85%|████████▍ | 90/106 [07:26<01:25,  5.33s/it]

[Epoch 3 | Batch 90] orig_loss = 0.4181, gib_loss = 1.1734, labels: tensor([0, 1, 0, 1], device='cuda:0')


Epoch 3:  90%|████████▉ | 95/106 [07:52<00:56,  5.12s/it]

[Epoch 3 | Batch 95] orig_loss = 0.6918, gib_loss = 1.2003, labels: tensor([0, 1, 0, 1], device='cuda:0')


Epoch 3:  94%|█████████▍| 100/106 [08:17<00:29,  4.95s/it]

[Epoch 3 | Batch 100] orig_loss = 0.3919, gib_loss = 0.6420, labels: tensor([0, 1, 1, 0], device='cuda:0')


Epoch 3:  99%|█████████▉| 105/106 [08:43<00:05,  5.03s/it]

[Epoch 3 | Batch 105] orig_loss = 0.3217, gib_loss = 0.5815, labels: tensor([0, 1, 1, 0], device='cuda:0')


Epoch 3: 100%|██████████| 106/106 [08:48<00:00,  4.99s/it]


Epoch 3 complete


In [12]:
# 7) Evaluation on test_loader
def evaluate(model, loader):
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for batch in loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            logits = model(
                input_ids=batch['input_ids'],
                token_type_ids=batch['token_type_ids'],
                attention_mask=batch['attention_mask']
            ).logits
            preds = torch.argmax(logits, dim=-1)
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(batch['labels'].cpu().tolist())
    return all_labels, all_preds

model_orig = BertForSequenceClassification.from_pretrained("./trained_model_original")
model_orig.to(device)

model_diff = BertForSequenceClassification.from_pretrained("./trained_model_gradient_diff")
model_diff.to(device)

# Compute metrics for each model
labels_o, preds_o = evaluate(model_orig, test_loader)
labels_d, preds_d = evaluate(model_diff, test_loader)

print("Original Model Metrics:")
print(classification_report(labels_o, preds_o, digits=4))

print("Orthogonal trained Model Metrics:")
print(classification_report(labels_d, preds_d, digits=4))

Original Model Metrics:
              precision    recall  f1-score   support

           0     0.8000    0.3448    0.4819        58
           1     0.7286    0.9533    0.8259       107

    accuracy                         0.7394       165
   macro avg     0.7643    0.6490    0.6539       165
weighted avg     0.7537    0.7394    0.7050       165

Orthogonal trained Model Metrics:
              precision    recall  f1-score   support

           0     0.3515    1.0000    0.5202        58
           1     0.0000    0.0000    0.0000       107

    accuracy                         0.3515       165
   macro avg     0.1758    0.5000    0.2601       165
weighted avg     0.1236    0.3515    0.1829       165



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


thoughts: the models learns at least in some direction, as it always predicts the minority class. Usually, it is not incentivized to do that

In [13]:
# 6) Training loop on train_loader
num_epochs = 3
for epoch in range(num_epochs):
    model_orig.train(); model_diff.train()
    for batch_idx, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}"), start=1):
        for k, v in batch.items(): # extracts keys (e.g. 'input_ids') and values (actual tokens) from batch
            if torch.is_tensor(v): batch[k] = v.to(device) # moves value to GPU, if it is a token

        # unpack
        input_ids = batch['input_ids']
        token_type_ids = batch['token_type_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']
        gib1_all = batch['sentence1_gibberish']
        gib2_all = batch['sentence2_gibberish']
        num_variants = len(gib1_all[0])

        # Pre-tokenize gibberish variants
        gib_inputs = []
        for j in range(num_variants):
            s1 = [g[j] for g in gib1_all]
            s2 = [g[j] for g in gib2_all]
            enc = tokenizer(s1, s2, truncation=True, padding='max_length', return_tensors='pt')
            gib_inputs.append({k: v.to(device) for k, v in enc.items()})

        # A) Original pass
        optimizer_orig.zero_grad()
        out_orig = model_orig(input_ids=input_ids,
                              token_type_ids=token_type_ids,
                              attention_mask=attention_mask,
                              labels=labels)
        loss_orig = out_orig.loss
        loss_orig.backward()
        optimizer_orig.step()
        #orig_grads = [p.grad.detach().cpu().clone() for p in model_orig.parameters()]


        # B) Gibberish pass
        # B.1 pass for original data
        optimizer_diff.zero_grad()
        out_orig_model_diff = model_diff(input_ids=input_ids,
                                            token_type_ids=token_type_ids,
                                            attention_mask=attention_mask,
                                            labels=labels)
        loss_orig_model_diff = out_orig_model_diff.loss
        loss_orig_model_diff.backward()
        orig_grads_model_diff = [p.grad.detach().cpu().clone() for p in model_diff.parameters()]

        # B.2 pass for gibberish data
        optimizer_diff.zero_grad()
        gib_loss = torch.tensor(0.0, device=device)
        for enc in gib_inputs:
            out = model_diff(**enc, labels=labels)
            gib_loss += out.loss
        gib_loss /= num_variants
        gib_loss.backward()
        gib_grads = [p.grad.detach().cpu().clone() for p in model_diff.parameters()]

        # C) Compute orthogonal vector
        orig_vec_model_diff = parameters_to_vector(orig_grads_model_diff)
        gib_vec = parameters_to_vector(gib_grads)
        v_orth = gram_schmidt(gib_vec, orig_vec_model_diff)
        #v_diff = orig_vec - orig_vec

        # D) Update model_diff
        optimizer_diff.zero_grad() # clear old grads
        # unflatten v_orth → p.grad for each parameter p
        pointer = 0
        for p in model_diff.parameters():
            if not p.requires_grad:
                continue
            numel = p.numel()
            p.grad = v_orth[pointer:pointer+numel].view_as(p).to(device)
            pointer += numel
        optimizer_diff.step()


        if batch_idx % 5 == 0:
            print(f"[Epoch {epoch+1} | Batch {batch_idx}] "
                  f"orig_loss = {loss_orig.item():.4f}, gib_loss = {gib_loss.item():.4f}")


    print(f"Epoch {epoch+1} complete")

model_orig.save_pretrained("trained_model_original_2")
model_diff.save_pretrained("trained_model_gradient_orth_2")

Epoch 1:   5%|▍         | 5/106 [00:24<08:09,  4.84s/it]

[Epoch 1 | Batch 5] orig_loss = 0.6867, gib_loss = 0.8929


Epoch 1:   9%|▉         | 10/106 [00:53<08:34,  5.35s/it]

[Epoch 1 | Batch 10] orig_loss = 0.0779, gib_loss = 0.8759


Epoch 1:  14%|█▍        | 15/106 [01:17<07:26,  4.90s/it]

[Epoch 1 | Batch 15] orig_loss = 0.0631, gib_loss = 0.8751


Epoch 1:  19%|█▉        | 20/106 [01:43<07:21,  5.14s/it]

[Epoch 1 | Batch 20] orig_loss = 0.0870, gib_loss = 0.8773


Epoch 1:  24%|██▎       | 25/106 [02:08<06:50,  5.07s/it]

[Epoch 1 | Batch 25] orig_loss = 0.0744, gib_loss = 0.8650


Epoch 1:  28%|██▊       | 30/106 [02:33<06:13,  4.92s/it]

[Epoch 1 | Batch 30] orig_loss = 0.1191, gib_loss = 0.8751


Epoch 1:  33%|███▎      | 35/106 [02:57<05:43,  4.84s/it]

[Epoch 1 | Batch 35] orig_loss = 0.2665, gib_loss = 0.8509


Epoch 1:  38%|███▊      | 40/106 [03:22<05:39,  5.15s/it]

[Epoch 1 | Batch 40] orig_loss = 0.1212, gib_loss = 0.8924


Epoch 1:  42%|████▏     | 45/106 [03:47<05:05,  5.00s/it]

[Epoch 1 | Batch 45] orig_loss = 0.0843, gib_loss = 0.8786


Epoch 1:  47%|████▋     | 50/106 [04:12<04:44,  5.08s/it]

[Epoch 1 | Batch 50] orig_loss = 0.0673, gib_loss = 0.8770


Epoch 1:  52%|█████▏    | 55/106 [04:37<04:06,  4.83s/it]

[Epoch 1 | Batch 55] orig_loss = 0.7216, gib_loss = 0.8834


Epoch 1:  57%|█████▋    | 60/106 [05:02<03:43,  4.86s/it]

[Epoch 1 | Batch 60] orig_loss = 0.0623, gib_loss = 0.9109


Epoch 1:  61%|██████▏   | 65/106 [05:27<03:18,  4.84s/it]

[Epoch 1 | Batch 65] orig_loss = 0.0694, gib_loss = 0.8846


Epoch 1:  66%|██████▌   | 70/106 [05:51<02:56,  4.91s/it]

[Epoch 1 | Batch 70] orig_loss = 0.1864, gib_loss = 0.8662


Epoch 1:  71%|███████   | 75/106 [06:15<02:30,  4.85s/it]

[Epoch 1 | Batch 75] orig_loss = 0.1820, gib_loss = 0.8645


Epoch 1:  75%|███████▌  | 80/106 [06:40<02:09,  4.99s/it]

[Epoch 1 | Batch 80] orig_loss = 0.1021, gib_loss = 0.8968


Epoch 1:  80%|████████  | 85/106 [07:08<01:59,  5.70s/it]

[Epoch 1 | Batch 85] orig_loss = 0.6820, gib_loss = 0.8737


Epoch 1:  85%|████████▍ | 90/106 [07:34<01:21,  5.12s/it]

[Epoch 1 | Batch 90] orig_loss = 0.0648, gib_loss = 0.8973


Epoch 1:  90%|████████▉ | 95/106 [07:58<00:53,  4.88s/it]

[Epoch 1 | Batch 95] orig_loss = 0.7379, gib_loss = 0.8934


Epoch 1:  94%|█████████▍| 100/106 [08:22<00:28,  4.80s/it]

[Epoch 1 | Batch 100] orig_loss = 0.0666, gib_loss = 0.8575


Epoch 1:  99%|█████████▉| 105/106 [08:45<00:04,  4.72s/it]

[Epoch 1 | Batch 105] orig_loss = 0.1566, gib_loss = 0.8902


Epoch 1: 100%|██████████| 106/106 [08:50<00:00,  5.01s/it]


Epoch 1 complete


Epoch 2:   5%|▍         | 5/106 [00:24<08:13,  4.89s/it]

[Epoch 2 | Batch 5] orig_loss = 0.4170, gib_loss = 0.9030


Epoch 2:   9%|▉         | 10/106 [00:49<08:00,  5.01s/it]

[Epoch 2 | Batch 10] orig_loss = 0.2418, gib_loss = 0.8673


Epoch 2:  14%|█▍        | 15/106 [01:13<07:29,  4.93s/it]

[Epoch 2 | Batch 15] orig_loss = 0.5622, gib_loss = 0.8813


Epoch 2:  19%|█▉        | 20/106 [01:38<06:59,  4.88s/it]

[Epoch 2 | Batch 20] orig_loss = 0.0769, gib_loss = 0.8690


Epoch 2:  24%|██▎       | 25/106 [02:05<07:12,  5.34s/it]

[Epoch 2 | Batch 25] orig_loss = 0.1499, gib_loss = 0.8797


Epoch 2:  28%|██▊       | 30/106 [02:30<06:34,  5.19s/it]

[Epoch 2 | Batch 30] orig_loss = 0.0679, gib_loss = 0.8691


Epoch 2:  33%|███▎      | 35/106 [02:55<05:50,  4.94s/it]

[Epoch 2 | Batch 35] orig_loss = 0.1923, gib_loss = 0.8965


Epoch 2:  38%|███▊      | 40/106 [03:21<05:42,  5.19s/it]

[Epoch 2 | Batch 40] orig_loss = 0.6538, gib_loss = 0.8874


Epoch 2:  42%|████▏     | 45/106 [03:45<04:54,  4.83s/it]

[Epoch 2 | Batch 45] orig_loss = 0.0799, gib_loss = 0.8898


Epoch 2:  47%|████▋     | 50/106 [04:08<04:28,  4.79s/it]

[Epoch 2 | Batch 50] orig_loss = 0.0854, gib_loss = 0.8957


Epoch 2:  52%|█████▏    | 55/106 [04:34<04:21,  5.13s/it]

[Epoch 2 | Batch 55] orig_loss = 0.5794, gib_loss = 0.8919


Epoch 2:  57%|█████▋    | 60/106 [04:58<03:40,  4.79s/it]

[Epoch 2 | Batch 60] orig_loss = 0.4497, gib_loss = 0.8940


Epoch 2:  61%|██████▏   | 65/106 [05:23<03:24,  5.00s/it]

[Epoch 2 | Batch 65] orig_loss = 0.0822, gib_loss = 0.8784


Epoch 2:  66%|██████▌   | 70/106 [05:48<03:03,  5.10s/it]

[Epoch 2 | Batch 70] orig_loss = 0.1147, gib_loss = 0.8573


Epoch 2:  71%|███████   | 75/106 [06:11<02:27,  4.77s/it]

[Epoch 2 | Batch 75] orig_loss = 0.6113, gib_loss = 0.8767


Epoch 2:  75%|███████▌  | 80/106 [06:38<02:19,  5.35s/it]

[Epoch 2 | Batch 80] orig_loss = 0.6381, gib_loss = 0.8983


Epoch 2:  80%|████████  | 85/106 [07:04<01:47,  5.11s/it]

[Epoch 2 | Batch 85] orig_loss = 0.0995, gib_loss = 0.8694


Epoch 2:  85%|████████▍ | 90/106 [07:30<01:23,  5.24s/it]

[Epoch 2 | Batch 90] orig_loss = 0.1042, gib_loss = 0.8833


Epoch 2:  90%|████████▉ | 95/106 [07:54<00:54,  4.97s/it]

[Epoch 2 | Batch 95] orig_loss = 0.0586, gib_loss = 0.8678


Epoch 2:  94%|█████████▍| 100/106 [08:19<00:29,  4.89s/it]

[Epoch 2 | Batch 100] orig_loss = 0.5897, gib_loss = 0.9062


Epoch 2:  99%|█████████▉| 105/106 [08:44<00:04,  4.83s/it]

[Epoch 2 | Batch 105] orig_loss = 0.3362, gib_loss = 0.8659


Epoch 2: 100%|██████████| 106/106 [08:49<00:00,  5.00s/it]


Epoch 2 complete


Epoch 3:   5%|▍         | 5/106 [00:24<08:10,  4.85s/it]

[Epoch 3 | Batch 5] orig_loss = 0.6408, gib_loss = 0.8656


Epoch 3:   9%|▉         | 10/106 [00:50<08:12,  5.13s/it]

[Epoch 3 | Batch 10] orig_loss = 0.1175, gib_loss = 0.8735


Epoch 3:  14%|█▍        | 15/106 [01:15<07:36,  5.01s/it]

[Epoch 3 | Batch 15] orig_loss = 0.0756, gib_loss = 0.8722


Epoch 3:  19%|█▉        | 20/106 [01:42<07:43,  5.39s/it]

[Epoch 3 | Batch 20] orig_loss = 0.2582, gib_loss = 0.8890


Epoch 3:  24%|██▎       | 25/106 [02:08<06:51,  5.08s/it]

[Epoch 3 | Batch 25] orig_loss = 0.5383, gib_loss = 0.8815


Epoch 3:  28%|██▊       | 30/106 [02:34<06:21,  5.02s/it]

[Epoch 3 | Batch 30] orig_loss = 0.1173, gib_loss = 0.8688


Epoch 3:  33%|███▎      | 35/106 [02:58<05:37,  4.75s/it]

[Epoch 3 | Batch 35] orig_loss = 0.0693, gib_loss = 0.8789


Epoch 3:  38%|███▊      | 40/106 [03:24<05:41,  5.18s/it]

[Epoch 3 | Batch 40] orig_loss = 0.0641, gib_loss = 0.8350


Epoch 3:  42%|████▏     | 45/106 [03:50<05:08,  5.05s/it]

[Epoch 3 | Batch 45] orig_loss = 0.0654, gib_loss = 0.8769


Epoch 3:  47%|████▋     | 50/106 [04:20<05:36,  6.02s/it]

[Epoch 3 | Batch 50] orig_loss = 0.0676, gib_loss = 0.8545


Epoch 3:  52%|█████▏    | 55/106 [04:46<04:34,  5.39s/it]

[Epoch 3 | Batch 55] orig_loss = 0.1048, gib_loss = 0.9162


Epoch 3:  57%|█████▋    | 60/106 [05:10<03:47,  4.95s/it]

[Epoch 3 | Batch 60] orig_loss = 0.5967, gib_loss = 0.8841


Epoch 3:  61%|██████▏   | 65/106 [05:36<03:25,  5.02s/it]

[Epoch 3 | Batch 65] orig_loss = 0.6134, gib_loss = 0.8732


Epoch 3:  66%|██████▌   | 70/106 [06:02<03:10,  5.28s/it]

[Epoch 3 | Batch 70] orig_loss = 0.0778, gib_loss = 0.9023


Epoch 3:  71%|███████   | 75/106 [06:27<02:32,  4.93s/it]

[Epoch 3 | Batch 75] orig_loss = 0.0892, gib_loss = 0.8672


Epoch 3:  75%|███████▌  | 80/106 [06:51<02:07,  4.89s/it]

[Epoch 3 | Batch 80] orig_loss = 0.0913, gib_loss = 0.8776


Epoch 3:  80%|████████  | 85/106 [07:20<01:57,  5.58s/it]

[Epoch 3 | Batch 85] orig_loss = 0.0635, gib_loss = 0.8620


Epoch 3:  85%|████████▍ | 90/106 [07:45<01:21,  5.06s/it]

[Epoch 3 | Batch 90] orig_loss = 0.1243, gib_loss = 0.9107


Epoch 3:  90%|████████▉ | 95/106 [08:10<00:55,  5.02s/it]

[Epoch 3 | Batch 95] orig_loss = 0.6639, gib_loss = 0.8524


Epoch 3:  94%|█████████▍| 100/106 [08:37<00:33,  5.51s/it]

[Epoch 3 | Batch 100] orig_loss = 0.0612, gib_loss = 0.8719


Epoch 3:  99%|█████████▉| 105/106 [09:04<00:05,  5.26s/it]

[Epoch 3 | Batch 105] orig_loss = 0.6191, gib_loss = 0.9001


Epoch 3: 100%|██████████| 106/106 [09:09<00:00,  5.18s/it]


Epoch 3 complete
