# Activations visualized - 1

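Loads a model checkpoint (base model + LoRA adapters) and runs the validation data through it to capture the classifier-head input activations, the loss gradients with respect to them, and the gradients with respect to the LoRA parameters.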

In [None]:
# Set before torch is imported (later cells) so the settings take effect:
# expose only GPU 0 and cap the OpenMP / MKL CPU thread pools at 16 threads.
%env CUDA_VISIBLE_DEVICES=0
%env OMP_NUM_THREADS=16 
%env MKL_NUM_THREADS=16 
# Uncomment to reload edited modules from ./src without restarting the kernel.
# %load_ext autoreload
# %autoreload 2

In [None]:
import sys, pathlib, os
sys.path.append(str(pathlib.Path('./src').resolve()))

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import transformers
# Only GPU 0 is visible (CUDA_VISIBLE_DEVICES=0 in the first cell), so it is cuda:0 here.
device = torch.device('cuda:0')

from tqdm.auto import tqdm
print(f"{torch.__version__=}, {transformers.__version__=}, {device=}")


from matplotlib import pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

In [None]:
from utils_data import get_data, preprocess_datasets
from utils_trainer import TrainerWithMetrics
from utils_model import (
    dc_regularizing_loss,
    get_fitted_logreg,
    get_base_model,
    get_tokenizer,
    ModelWithMultipleLoras,
)

## Loading data and model
Code taken from `main.py`.

In [None]:
# Argument objects of the training run (presumably saved by main.py).
# `weights_only=False` is needed because torch >= 2.6 defaults to weights-only
# loading, which rejects pickled dataclasses. This is a full unpickle that can run
# arbitrary code: only load files you created yourself.
model_args, data_args, training_args = torch.load("baseline_args.pt", weights_only=False)

In [None]:
# Raw task datasets plus task metadata (regression flag and label list).
raw_datasets, is_regression, label_list = get_data(model_args, data_args, training_args)

In [None]:
# Base model sized for the task's label set, and the matching tokenizer.
model = get_base_model(
    model_args,
    finetuning_task=data_args.task_name,
    num_labels=len(label_list),
)
tokenizer = get_tokenizer(model_args)

# Preprocess (presumably tokenise) the splits; `raw_datasets` is rebound to the
# version returned by the helper.
(
    train_dataset,
    eval_dataset,
    predict_dataset,
    raw_datasets,
) = preprocess_datasets(
    raw_datasets, data_args, training_args, model, tokenizer, label_list, is_regression
)

In [None]:
# Choose the batch collator:
# - inputs already padded to max length need no extra padding;
# - under fp16, pad dynamically to a multiple of 8 (tensor-core friendly shapes);
# - otherwise `None` lets the Trainer pick its default collator.
if data_args.pad_to_max_length:
    data_collator = transformers.default_data_collator
elif training_args.fp16:
    # Referenced through the `transformers` module: the bare name
    # `DataCollatorWithPadding` is never imported in this notebook (NameError).
    data_collator = transformers.DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
else:
    data_collator = None


In [None]:
class ModelWithMultipleLoras_2(ModelWithMultipleLoras):
    """`ModelWithMultipleLoras` whose forward also records the classifier-head input.

    `forward` stores the tensor passed to `self.classifier` in `self.saved`, so the
    analysis loop can differentiate the loss with respect to it.
    NOTE(review): apart from the `self.saved` lines this presumably mirrors the
    parent's `forward` — confirm against `utils_model.ModelWithMultipleLoras`.
    """

    def forward(self, input_ids, attention_mask, labels):
        """Run the optional baseline pass, every LoRA adapter, then the classifier.

        Args:
            input_ids: token ids, passed to every adapter forward.
            attention_mask: attention mask, passed to every adapter forward.
            labels: targets for the cross-entropy loss and the optional regularizers.

        Returns:
            Tuple ``(loss, logits, per_lora_activations, per_lora_shifts, head_input)``:
            - ``per_lora_activations``: one tensor per adapter pass; for
              ``model_type == "deberta"`` each is sliced to token 0 after the
              classifier has run.
            - ``per_lora_shifts``: adapter activation minus baseline activation;
              empty unless ``n_of_loras`` is non-zero and a shift regularizer weight
              is set or the model is not in training mode.
            - ``head_input``: output of ``get_head_input``, computed from the
              unsliced per-adapter activations and fed to ``self.classifier``.
        """
        
        # Reset on every call, so it only ever holds this pass's head input.
        self.saved = [] # filled with the head input below
        
        # NOTE(review): `rank` is never changed, so every `if not rank:` branch below
        # is taken and the `else: loss += ...norm()` branches are unreachable.
        rank = 0

        # Baseline pass (adapter index -1, presumably "no LoRA"), only needed for shifts.
        # fork_rng restores the CPU/device RNG state on exit, so this pass does not
        # advance the global RNG.
        if self.n_of_loras and (self.shift_lr_rw or self.shift_dc_rw or not self.model.training):
            with torch.random.fork_rng(
                devices=(torch.device("cpu"), self.device), enabled=True
            ):
                baseline_activations = self._choose_adapter_and_forward(
                    -1, input_ids, attention_mask
                )
                if self.model_type == "deberta":
                    # assumes (batch, seq, hidden); keep token 0 — TODO confirm shape
                    baseline_activations = baseline_activations[:, 0]
        different_loras_activations = []
        # At least one pass even when n_of_loras == 0.
        for i in range(max(self.n_of_loras, 1)):
            # Every adapter except the last runs under fork_rng, so each starts from the
            # same RNG state (presumably identical dropout masks); only the last pass
            # advances the global RNG.
            if i < self.n_of_loras - 1:
                with torch.random.fork_rng(
                    devices=(torch.device("cpu"), self.device), enabled=True
                ):
                    activations = self._choose_adapter_and_forward(
                        i, input_ids, attention_mask
                    )
            else:
                activations = self._choose_adapter_and_forward(
                    i, input_ids, attention_mask
                )
            different_loras_activations.append(activations)
            
        # Combine the per-adapter activations (weighted by self.coefs) into the head input.
        activations = self.get_head_input(different_loras_activations, self.coefs)

        self.saved.append(activations) # exposed so callers can take d(loss)/d(head input)

        logits = self.classifier(activations)
        # NOTE(review): always cross-entropy, so regression tasks are not handled here.
        loss = F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
        if self.model_type == "deberta":
            # Slice to token 0 only after the classifier ran on the full head input.
            for i in range(len(different_loras_activations)):
                different_loras_activations[i] = different_loras_activations[i][:, 0]

        # Optional regularizers on each adapter's activations.
        if self.activation_lr_rw or self.activation_dc_rw:
            for cur_activation in different_loras_activations:
                if not rank:
                    if self.activation_lr_rw:
                        # Logistic regression needs every class present in the batch.
                        # NOTE(review): this `continue` also skips the DC term below.
                        if len(torch.unique(labels)) < self.num_labels:
                            continue
                        logreg = get_fitted_logreg(
                            cur_activation.detach().cpu().numpy(),
                            labels.cpu().numpy(),
                            seed=self.seed,
                        )
                        loss += (
                            self.regularizing_logreg_loss(
                                cur_activation,
                                labels,
                                logreg,
                                neck_width=self.neck_width,
                                device=self.device,
                                dtype=self.model.dtype
                            ) * self.activation_lr_rw
                        )
                    if self.activation_dc_rw:
                        loss += (self.activation_dc_rw 
                                * dc_regularizing_loss(cur_activation, labels))
                else:
                    loss += cur_activation.norm()

        # Shifts (adapter minus baseline) and their optional regularizers; same
        # condition as the baseline pass above, so baseline_activations is defined here.
        different_loras_shifts = []
        if self.n_of_loras and (self.shift_lr_rw or self.shift_dc_rw or not self.model.training):
            for cur_activation in different_loras_activations:
                cur_shift = cur_activation - baseline_activations
                different_loras_shifts.append(cur_shift)
                if not rank:
                    if self.shift_lr_rw:
                        # NOTE(review): this `continue` also skips the DC term below.
                        if len(torch.unique(labels)) < self.num_labels:
                            continue
                        logreg = get_fitted_logreg(
                            cur_shift.detach().cpu().numpy(),
                            labels.cpu().numpy(),
                            seed=self.seed,
                        )
                        loss += (
                            self.regularizing_logreg_loss(
                                cur_shift,
                                labels,
                                logreg,
                                neck_width=self.neck_width,
                                device=self.device,
                                dtype=self.model.dtype
                            ) * self.shift_lr_rw
                        )
                    if self.shift_dc_rw:
                        loss += (self.shift_dc_rw 
                                * dc_regularizing_loss(cur_shift, labels))
                else:
                    loss += cur_shift.norm()

        return (
            loss,
            logits,
            # shallow copy of the per-adapter activation list
            [cur_acts for cur_acts in different_loras_activations],
            different_loras_shifts,
            activations,
        )

In [None]:
# Wrap the base model with the LoRA adapters; hyper-parameters come from the saved run args.
model_multiple_loras = ModelWithMultipleLoras_2(
    base_model=model,
    # Same label count the base model was built with, instead of a hard-coded 2.
    num_labels=len(label_list),
    model_type='deberta',
    n_of_loras=model_args.n_of_loras,
    lora_rank=model_args.lora_rank,
    device=training_args.device,
    lora_alpha=model_args.lora_alpha,
    lora_dropout=model_args.lora_dropout,
    seed=training_args.seed,
    mult_std=model_args.mult_std,
    method_name=model_args.coefs_method_name,
    activation_lr_rw=model_args.activation_lr_rw,
    shift_lr_rw=model_args.shift_lr_rw,
    activation_dc_rw=model_args.activation_dc_rw,
    shift_dc_rw=model_args.shift_dc_rw,
    loras_gradient_checkpointing=model_args.loras_gradient_checkpointing,
    model_gradient_checkpointing=model_args.model_gradient_checkpointing,
).to(device)

## Loading checkpoint

In [None]:
checkpoints_path = "./deberta_sst2/checkpoints"


def _checkpoint_sort_key(entry):
    """Order `checkpoint_<step>.pt` names by numeric step; other names go last, alphabetically."""
    stem = os.path.splitext(entry)[0]
    prefix, _, step_str = stem.rpartition("_")
    if prefix == "checkpoint" and step_str.isdigit():
        return (0, int(step_str), entry)
    return (1, 0, entry)


# Numeric order: plain string sorting would list checkpoint_16000.pt before checkpoint_2000.pt.
sorted(os.listdir(checkpoints_path), key=_checkpoint_sort_key)

In [None]:
# Training step of the checkpoint to analyse (see the listing above).
step = 16000
filename = os.path.join(checkpoints_path, f"checkpoint_{step}.pt")
# map_location: tensors are otherwise restored to the device they were saved from,
# which fails if that GPU index is not visible here (only GPU 0 is).
cp = torch.load(filename, map_location=device)

In [None]:
def update_state_dict_from_checkpoint(self, checkpoint_state_dict):
    """Overwrite the state-dict entries of ``self`` that are present in a checkpoint.

    Entries missing from the checkpoint keep their current values, so a checkpoint
    holding only the trainable parts (e.g. the LoRAs and the head) can be applied on
    top of a freshly built model.

    Args:
        self: the ``nn.Module`` to update in place (named ``self`` so the function can
            also be attached as a method).
        checkpoint_state_dict: mapping from state-dict key to tensor.

    Returns:
        The number of entries taken from the checkpoint.
    """
    print(f"checkpoint contains {len(checkpoint_state_dict)} modules")
    # Read the state dict of the module being updated (not a global model), once.
    sd0 = self.state_dict()
    print(f"state_dict contains {len(sd0)} modules")
    counter = 0
    for name in sd0:
        if name in checkpoint_state_dict:
            # Reassigning an existing key does not resize the dict, so this is safe mid-iteration.
            sd0[name] = checkpoint_state_dict[name]
            counter += 1
    # Surface checkpoint entries that match nothing instead of dropping them silently.
    unused = [name for name in checkpoint_state_dict if name not in sd0]
    if unused:
        print(f"WARNING: {len(unused)} checkpoint entries match no model entry, e.g. {unused[:3]}")
    self.load_state_dict(sd0)
    print(f"updated {counter} modules")
    return counter

In [None]:
update_state_dict_from_checkpoint(model_multiple_loras, checkpoint_state_dict=cp)  # load LoRA/head weights

## Run the validation data through the model

In [None]:
# Trainable LoRA tensors (matched by name), collected in a single pass over the model.
lora_names = []
lora_parameters = []
for pname, p in model_multiple_loras.named_parameters():
    if p.requires_grad and "lora" in pname.lower():
        lora_names.append(pname)
        lora_parameters.append(p)
for param in lora_parameters:
    assert param.requires_grad
len(lora_names), len(lora_parameters)

In [None]:
# One example per step: the per-example LoRA-gradient capture below asserts batch size 1.
bsize = 1
training_args.per_device_train_batch_size = bsize
training_args.per_device_eval_batch_size = bsize

_train_split = train_dataset if training_args.do_train else None
_eval_split = eval_dataset if training_args.do_eval else None
trainer = TrainerWithMetrics(
    model=model_multiple_loras,
    args=training_args,
    tokenizer=tokenizer,
    data_collator=data_collator,
    train_dataset=_train_split,
    eval_dataset=_eval_split,
)

In [None]:
# Validation dataloader built by the Trainer (presumably using its collator and
# `per_device_eval_batch_size`, which was set to 1 above).
eval_dataloader = trainer.get_eval_dataloader(eval_dataset)

In [None]:
# For each noise level `k_reg`, run every validation example through the model and save
# to outs/outs_{step}_reg_{k_reg}.pt:
#   acts       - head-input activation at token 0, one row per example
#   grads      - d(loss)/d(head input) at token 0
#   lora_grads - d(head input)/d(LoRA params) contracted with the upstream gradient,
#                one row of flattened tensors per LoRA parameter (torch.stack needs all
#                LoRA tensors to have the same number of elements)
#   labels     - example labels
# For k_reg > 0 the LoRA gradients use the perturbed upstream gradient g + k_reg * z,
# with z ~ N(0, 1), instead of the exact g.
get_loras_grads = True

# Eval mode so dropout does not make the captured tensors random (autograd still works).
# In eval mode forward() also runs the baseline (adapter -1) pass.
model_multiple_loras.eval()
os.makedirs("outs", exist_ok=True)

for k_reg in [0, 1000]:
    acts = []
    grads = []
    lora_grads = []
    labels = []

    for batch in tqdm(eval_dataloader):
        # `.to(device)` is a no-op when the dataloader already places batches on the GPU.
        bout = model_multiple_loras(
            input_ids=batch["input_ids"].to(device),
            attention_mask=batch["attention_mask"].to(device),
            labels=batch["labels"].to(device),
        )
        head_inputs = model_multiple_loras.saved[0]  # tensor fed to the classifier
        batch_loss = bout[0]

        # Keep the graph when the LoRA gradients below back-propagate through it again.
        (grad_wrt_activations,) = torch.autograd.grad(
            batch_loss, head_inputs, retain_graph=get_loras_grads
        )

        if get_loras_grads:
            # Per-example LoRA gradients require batch size 1.
            assert head_inputs.shape[0] == 1
            z = torch.randn_like(grad_wrt_activations)
            grad_wrt_loras = torch.autograd.grad(
                outputs=[head_inputs],
                inputs=lora_parameters,
                grad_outputs=[grad_wrt_activations + k_reg * z],
            )
            # Move each example's gradients to the CPU immediately so they do not pile
            # up in GPU memory over the whole validation set.
            lora_grads.append(
                torch.stack([g.flatten() for g in grad_wrt_loras]).cpu().half()
            )
            del grad_wrt_loras

        with torch.no_grad():
            acts.append(head_inputs[:, 0].detach().cpu().half())
            grads.append(grad_wrt_activations[:, 0].cpu().half())
            labels.append(batch["labels"].cpu())

        # Drop every reference to this step's graph before releasing cached GPU memory.
        del bout, head_inputs, batch_loss, grad_wrt_activations
        model_multiple_loras.saved = []
        torch.cuda.empty_cache()

    to_save = dict(
        acts=torch.concat(acts),
        grads=torch.concat(grads),
        lora_grads=(
            torch.stack(lora_grads) if lora_grads else torch.empty(0, dtype=torch.half)
        ),
        labels=torch.concat(labels),
    )
    torch.save(to_save, f"outs/outs_{step}_reg_{k_reg}.pt")

In [None]:
# NOTE(review): `%stop` is not a built-in IPython magic; presumably it is here to make
# "Run All" fail at this point on purpose — confirm.
%stop