# Potentially Buggy Changes
- I set inference mode to True from False
- Hard coded compute_dtype=bfloat16 if attribute doesn't exist in one of the packages
-
- 'default' keeps showing up as the active model when it should be 'encoder' or 'decoder'. Not sure why

## Meta things to improve iteration speed
- Implement logging
- Some way to easily compare text outputs side by side given different training setups
- 

## Code Organization Notes
- Ideally 1 function per cell and add section headings for easy navigation in outline
- Add new features slowly and back up version that works
- Maybe handle all device calls in one place so we don't have to go everywhere to change them
- Consider moving functions etc. to separate files and using importlib.reload (see https://docs.python.org/3/library/importlib.html)

# Why isn't loss lower? Hypotheses
- Generate activations with the VAE model with decoder enabled so we compare apples to apples
- Try some infeasibly high n_epochs to test the 'just not training long enough' hypothesis (or figure out how to add GPUs with Josh)
- Maybe there are local minima that are hard to get out of (0.2-0.3 cosim, any others? Is there another one at 0.3-0.4 cosim?) Try initing from a variety of cosims to see whether learning depends on the current cosim value
- Maybe some of the hidden layers give very noisy signals
- Try to get a deeper understanding of the latent vs activations landscape by interpolating between two texts in latent space and see what the pattern is like, ideally for a diverse variety of inputs

# Setup
If you haven't downloaded the model checkpoint, use the following commands.

## Download and unzip checkpoints

In [1]:
# Download zipped model file
# !wget -v 'https://models.rivershavewings.workers.dev/ldlm/vae_48.tar'
# Unzip the file
# !tar -xvf vae_48.tar


## Imports

In [2]:
import argparse
from contextlib import contextmanager
from itertools import chain, islice
import json
import math
from pathlib import Path
import random
import sys
import zipfile
import typing
import matplotlib.pyplot as plt

import accelerate
from datasets import load_dataset
from einops import rearrange
# import k_diffusion as K
import peft
import safetensors.torch as safetorch
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils import data
from tqdm import trange, tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, LlamaTokenizer, PreTrainedModel, Trainer, LlamaConfig, TrainingArguments

import bitsandbytes

torch.manual_seed(0)

<torch._C.Generator at 0x7f2d5bf85b50>

## Helper Functions and Classes

In [3]:
# Helper functions
# Rebind `print` so that every call runs inside tqdm's external_write_mode
# context (presumably so printed text does not garble active progress bars —
# confirm against the tqdm docs). NOTE: this shadows the builtin `print` for
# every cell executed after this one.
print = tqdm.external_write_mode()(print)

def cosine_warmup(steps, value=1.0):
    """Build a warmup schedule that eases from 0 up to ``value``.

    The returned callable maps a step index ``i`` to
    ``value * sin(min(i / steps, 1) * pi / 2) ** 2``: it is 0 at step 0 and
    stays at ``value`` for every ``i >= steps``.
    """
    def schedule(i):
        # Fraction of the warmup completed, clamped at 1.
        progress = min(i / steps, 1)
        return value * math.sin(progress * math.pi / 2) ** 2

    return schedule

In [4]:
def print_ids(ids, generation_length, tokenizer):
    """Decode the trailing ``generation_length`` ids of the first row and print them.

    Only ``ids[0]`` is used. Special tokens are kept in the decoded text.
    """
    # Re-add a leading axis so the tail is iterated as a single row.
    tail = ids[0][-generation_length:].unsqueeze(0)
    decoded = [tokenizer.decode(row, skip_special_tokens=False) for row in tail]
    print(' | '.join(decoded))

In [5]:

def process_raw_hidden(hidden_states_tuple,
                       layer_index=-1,
                       last_token_only=True,
                       ):
    """Select layers from ``model_output.hidden_states`` and optionally keep only the last token.

    Args:
        hidden_states_tuple: Sequence of per-layer tensors. Each element is
            expected to be 3-D, presumably ``(batch_size, seq_len, hidden_size)``.
            For Hugging Face model outputs the tuple presumably also contains
            the embedding output (length ``n_layers + 1``) — verify for the
            model in use.
        layer_index (int | slice, optional): Which layer(s) to keep. An ``int``
            selects one layer; a ``slice`` selects a range of layers.
            Defaults to -1 (the last entry).
        last_token_only (bool, optional): If True, index the sequence axis at
            its last position and drop that axis. Defaults to True.

    Returns:
        torch.Tensor with a leading "selected layers" axis: shape
        ``(n_selected, batch_size, hidden_size)`` when ``last_token_only`` is
        True, otherwise ``(n_selected, batch_size, seq_len, hidden_size)``.

    Raises:
        TypeError: If ``layer_index`` is neither an ``int`` nor a ``slice``.
    """
    if isinstance(layer_index, slice):
        hidden_states_tensor = torch.stack(hidden_states_tuple[layer_index])
    elif isinstance(layer_index, int):
        # Add the leading layer axis so both branches return the same rank.
        hidden_states_tensor = hidden_states_tuple[layer_index].unsqueeze(0)
    else:
        # Previously this fell through and failed later with an unbound local.
        raise TypeError(
            f"layer_index must be an int or a slice, got {type(layer_index).__name__}"
        )
    if last_token_only:
        hidden_states_tensor = hidden_states_tensor[:, :, -1, :]
    return hidden_states_tensor



In [6]:
def get_hidden_states(model, tokenizer, device, prompt, dtype=torch.bfloat16, layer_index=-1, last_token_only=True,):
    """Run ``prompt`` through ``model`` and return the selected hidden states.

    Args:
        model: Callable accepting ``input_ids``, ``attention_mask`` and
            ``output_hidden_states`` keyword arguments and returning an object
            with a ``hidden_states`` attribute.
        tokenizer: Callable invoked as ``tokenizer(prompt, return_tensors="pt")``;
            its result must support ``.to(device)`` and expose ``input_ids``
            and ``attention_mask``.
        device: Device the tokenized inputs are moved to.
        prompt: Text to tokenize.
        dtype (optional): Autocast dtype for the forward pass.
            Defaults to torch.bfloat16.
        layer_index (int | slice, optional): Forwarded to ``process_raw_hidden``.
            Defaults to -1.
        last_token_only (bool, optional): Forwarded to ``process_raw_hidden``.
            Defaults to True.

    Returns:
        Whatever ``process_raw_hidden`` returns for the model's hidden states.
    """
    #TODO check if support multiple prompts
    tokenizer_output = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.cuda.amp.autocast(dtype=dtype):
        model_outputs = model(
            input_ids=tokenizer_output.input_ids,
            attention_mask=tokenizer_output.attention_mask,
            output_hidden_states=True,
        )
    #TODO check if I have to remove any added hidden layers from the adapter
    #TODO check if just disabling adapter works
    # Forward the caller's selection instead of hard-coding -1 / True.
    return process_raw_hidden(
        model_outputs.hidden_states,
        layer_index=layer_index,
        last_token_only=last_token_only,
    )

In [7]:
@contextmanager
def set_adapter(model, adapter_name):
    """Temporarily activate ``adapter_name`` on ``model``, yielding the model.

    Passing ``None`` runs the body inside ``model.disable_adapter()`` instead.
    On exit the adapter that was active on entry is re-activated, even if the
    body raised.
    """
    previous = model.active_adapter
    try:
        if adapter_name is None:
            with model.disable_adapter():
                yield model
        else:
            model.set_adapter(adapter_name)
            yield model
    finally:
        model.set_adapter(previous)

In [8]:
def gumbel_like(x):
    """Return noise shaped like ``x``, computed as ``-log(-log(u))`` for uniform ``u``.

    ``nan_to_num_`` is applied after the first log, so ``log(0) = -inf`` is
    replaced by a finite value before the second log is taken.
    """
    noise = torch.rand_like(x)
    noise.log_()
    noise.nan_to_num_()
    noise.neg_()
    noise.log_()
    noise.neg_()
    return noise

In [9]:
@contextmanager
def disable_causal_mask():
    """Temporarily replace the Llama causal-mask builder with one that returns all zeros.

    While the context is active, ``modeling_llama._make_causal_mask`` is
    swapped for a wrapper that calls the original and returns
    ``torch.zeros_like`` of its result. The original function is restored on
    exit, even if the body raises.

    NOTE(review): presumably an all-zero additive mask lets every position
    attend to every other one (bidirectional attention) — confirm against the
    installed transformers version. ``_make_causal_mask`` is a private helper,
    so this only works on transformers releases that still define it — TODO
    confirm which version the environment pins.
    """
    import transformers.models.llama.modeling_llama as modeling

    # Keep a handle on the original builder so it can be restored afterwards.
    decoder_fn = modeling._make_causal_mask

    def encoder_fn(*args, **kwargs):
        # Build the real mask, then return a zero tensor like it.
        return torch.zeros_like(decoder_fn(*args, **kwargs))

    try:
        modeling._make_causal_mask = encoder_fn
        yield
    finally:
        modeling._make_causal_mask = decoder_fn


## Main DecoderOnlyTransformerVAE class

In [10]:
class VAEComponent(nn.Module):
    """Linear bottleneck between transformer hidden states and a latent vector.

    ``encode`` pools a sequence of hidden states with learned, mask-aware
    softmax weights and projects the pooled vector to ``z_dim``. ``decode``
    projects a latent back to ``d_model``. The decoder weight starts out as
    the transpose of the orthogonally initialised encoder weight.
    """

    def __init__(self, d_model, z_dim):
        super().__init__()
        self.d_model, self.z_dim = d_model, z_dim
        # Layer creation order determines both the seeded initial values and
        # the state_dict key order.
        self.f = nn.Linear(d_model, 1)        # per-position pooling score
        self.w_e = nn.Linear(d_model, z_dim)  # pooled hidden -> latent
        self.w_d = nn.Linear(z_dim, d_model)  # latent -> model width
        nn.init.orthogonal_(self.w_e.weight)
        with torch.no_grad():
            self.w_d.weight.copy_(self.w_e.weight.T)

    def encode(self, hidden_states, attention_mask):
        """Pool ``hidden_states`` over dim 1 and project the result to the latent space."""
        # log(mask) is 0 where the mask is 1 and -inf where it is 0;
        # nan_to_num turns the -inf into a very large negative finite value.
        mask_bias = attention_mask[:, :, None].log().nan_to_num()
        logits = self.f(hidden_states) + mask_bias
        pooling = logits.softmax(dim=1)
        return self.w_e((hidden_states * pooling).sum(dim=1))

    def sample(self, mean, tau=1.0):
        """Return ``mean`` plus Gaussian noise scaled by ``sqrt(tau)``."""
        noise = torch.randn_like(mean)
        return mean + noise * tau ** 0.5

    def decode(self, z):
        """Project latent ``z`` back to ``d_model`` dimensions."""
        return self.w_d(z)

In [11]:
class DecoderOnlyTransformerVAE(nn.Module):
    """Text VAE built from one 4-bit causal LM carrying two LoRA adapters.

    The ``"encoder"`` adapter, run with the causal mask disabled, produces
    hidden states that :class:`VAEComponent` pools into a latent ``z``. The
    ``"decoder"`` adapter receives ``z`` projected back to the model width as
    one extra embedding prepended to the token embeddings.
    """

    def __init__(self, model_name, device, z_dim=768, lora_rank=32, dropout=0.0, dtype=torch.bfloat16):
        """Load the quantized base model and attach the two LoRA adapters.

        Args:
            model_name: Identifier passed to ``AutoModelForCausalLM.from_pretrained``.
            device: Device the whole model and the VAE head are placed on.
            z_dim: Size of the latent vector.
            lora_rank: LoRA rank ``r`` shared by both adapters.
            dropout: LoRA dropout probability.
            dtype: ``torch_dtype`` used when loading the base model.
        """
        super().__init__()
        self.dtype = dtype
        self.device = device
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map={"": self.device},
            quantization_config=bnb_config,
            torch_dtype=self.dtype,
        )
        peft_config = peft.LoraConfig(
            peft.TaskType.CAUSAL_LM,
            inference_mode=True, #TODO: used to be False, check for errors down the line
            r=lora_rank,
            lora_alpha=8,
            lora_dropout=dropout,
            target_modules=[
                "self_attn.q_proj",
                "self_attn.k_proj",
                "self_attn.v_proj",
                "self_attn.o_proj",
                "mlp.gate_proj",
                "mlp.up_proj",
                "mlp.down_proj",
            ],
        )

        self.z_dim = z_dim

        # Both adapters share one config; "decoder" is left active afterwards.
        self.model = peft.get_peft_model(model, peft_config, "encoder")
        self.model.add_adapter("decoder", peft_config)
        self.model.set_adapter("decoder")
        # encode() reads outputs.hidden_states, so hidden states are always requested.
        self.model.config.output_hidden_states = True
        self.vae = VAEComponent(self.model.config.hidden_size, self.z_dim).to(self.device)

    def save_pretrained(self, path):
        """Save the adapters via PEFT and the VAE head as ``vae.safetensors`` under ``path``."""
        path = Path(path)
        self.model.save_pretrained(path, safe_serialization=True)
        safetorch.save_file(self.vae.state_dict(), path / "vae.safetensors")

    def load_pretrained(self, path, is_trainable=False):
        """Replace both adapters and the VAE head with the checkpoint at ``path``.

        Expects ``path/encoder``, ``path/decoder`` and ``path/vae.safetensors``.
        After this call the active adapter is ``"encoder"``; ``encode``,
        ``generate`` and ``forward`` each select the adapter they need.
        """
        path = Path(path)
        self.model.delete_adapter("encoder")
        self.model.load_adapter(path / "encoder", "encoder", is_trainable=is_trainable)
        self.model.delete_adapter("decoder")
        # A bare `self.model.disable_adapter()` call used to sit here. It only
        # created a context manager that was never entered, so it did nothing
        # and has been removed.
        self.model.set_adapter("encoder")
        self.model.load_adapter(path / "decoder", "decoder", is_trainable=is_trainable)
        self.vae.load_state_dict(safetorch.load_file(path / "vae.safetensors"))

    def encode(self, input_ids, attention_mask):
        """Return the latent mean for ``input_ids``.

        Runs the model with the ``"encoder"`` adapter and the causal mask
        disabled, then pools the last hidden-state entry with the VAE head.
        """
        with set_adapter(self.model, "encoder"), disable_causal_mask():
            outputs = self.model(
                input_ids=input_ids, attention_mask=attention_mask, use_cache=False
            )
        return self.vae.encode(outputs.hidden_states[-1], attention_mask)

    def input_ids_to_embeds(self, input_ids):
        """Look up the input-embedding rows for ``input_ids``.

        Uses a direct embedding lookup instead of a one-hot matmul: the same
        rows are selected without materialising a ``(batch, seq, vocab)``
        one-hot tensor.
        """
        embed_weight = self.model.get_input_embeddings().weight
        return F.embedding(input_ids.to(embed_weight.device), embed_weight)

    # @torch.no_grad()
    def generate(self, z, input_ids, attention_mask, n_tokens, tau=1.0, output_hidden_states=False, eos_id=None):
        """Sample up to ``n_tokens`` tokens conditioned on latent ``z``.

        Args:
            z: Latent codes; decoded to one embedding that is prepended to the
                prompt embeddings.
            input_ids: Prompt token ids; generated ids are appended to these.
            attention_mask: Mask for ``input_ids``; extended by one position
                for the latent embedding and for each continued step.
            n_tokens: Maximum number of tokens to sample.
            tau: Scale of the Gumbel noise added to the logits before argmax.
            output_hidden_states: If True, return a dict that also carries the
                hidden states of the last decoding step only.
            eos_id: If given, stop as soon as any row samples this id.

        Returns:
            The extended ``input_ids``, or
            ``{'output_ids': ..., 'hidden_states': ...}`` when
            ``output_hidden_states`` is True. ``hidden_states`` is ``None``
            when no decoding step ran (``n_tokens <= 0``).
        """
        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
            z_embed = self.vae.decode(z)[:, None]
            inputs_embeds = self.input_ids_to_embeds(input_ids)
            inputs_embeds = torch.cat([z_embed, inputs_embeds], dim=1)
            attention_mask = torch.cat(
                [attention_mask.new_ones([attention_mask.shape[0], 1]), attention_mask], dim=1
            )

        new_embeds, past = None, None
        outputs = None  # stays None when the loop body never runs
        with set_adapter(self.model, "decoder"):
            for _ in range(n_tokens):
                # First step feeds the whole prefix; later steps feed only the
                # newest token and rely on the KV cache.
                outputs = self.model(
                    inputs_embeds=inputs_embeds if past is None else new_embeds,
                    attention_mask=attention_mask,
                    use_cache=True,
                    past_key_values=past,
                    output_hidden_states=output_hidden_states,
                )
                logits = outputs.logits[:, -1:, :].float()
                new_input_ids = torch.argmax(logits + gumbel_like(logits) * tau, dim=-1)

                input_ids = torch.cat([input_ids, new_input_ids], dim=1)
                #TODO check if this works
                if eos_id is not None and (new_input_ids == eos_id).any():
                    break
                new_embeds = self.input_ids_to_embeds(new_input_ids)
                attention_mask = torch.cat(
                    [attention_mask, attention_mask.new_ones([attention_mask.shape[0], 1])], dim=1
                )
                past = outputs.past_key_values
        if output_hidden_states:
            hidden_states = outputs.hidden_states if outputs is not None else None
            return {'output_ids': input_ids, 'hidden_states': hidden_states}
        return input_ids

    def forward(self, input_ids, attention_mask, decoder_prefix_ids, decoder_prefix_mask, output_hidden_states=False):
        """Encode ``input_ids`` to a latent, sample from it, and run the decoder.

        The decoder input is the decoded latent embedding, then the embeddings
        of ``decoder_prefix_ids``, then those of ``input_ids``.

        Returns:
            Tuple ``(outputs, mean)``: the decoder model output and the latent
            mean produced by ``encode``.
        """
        input_ids_all = torch.cat([decoder_prefix_ids, input_ids], dim=1)
        attn_mask_all = torch.cat([decoder_prefix_mask, attention_mask], dim=1)
        mean = self.encode(input_ids, attention_mask)
        z = self.vae.sample(mean)
        z_embed = self.vae.decode(z)[:, None]
        inputs_embeds = self.input_ids_to_embeds(input_ids_all)
        inputs_embeds = torch.cat([z_embed, inputs_embeds], dim=1)
        # One extra always-attended position for the latent embedding.
        attention_mask = torch.cat(
            [attention_mask.new_ones([attn_mask_all.shape[0], 1]), attn_mask_all], dim=1
        )
        with set_adapter(self.model, "decoder"):
            outputs = self.model(
                inputs_embeds=inputs_embeds, attention_mask=attention_mask, use_cache=False, output_hidden_states=output_hidden_states
            )
        return outputs, mean


## Trying accelerate on notebook
Following https://huggingface.co/docs/accelerate/basic_tutorials/notebook

Crashes for some reason?

In [12]:
# import os
# from accelerate.utils import write_basic_config

# write_basic_config()  # Write a config file
# os._exit(00)  # Restart the notebook

## Setup accelerator

In [13]:
# Setup accelerator
# bf16 mixed precision; gradient_accumulation_steps=1 means no accumulation.
accelerator = accelerate.Accelerator(
        mixed_precision="bf16", gradient_accumulation_steps=1
    )
# Use the accelerator-assigned device only when several processes are running;
# a single-process run is pinned to "cuda:0".
device = accelerator.device if accelerator.num_processes > 1 else "cuda:0"

print(device)
is_main = accelerator.is_main_process
# print0 wraps print with accelerator.on_main_process (presumably so that only
# the main process emits output — confirm in the accelerate docs).
print0 = accelerator.on_main_process(print)


Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


cuda:0


## Load model_vae and tokenizer

In [14]:
model_name = "openlm-research/open_llama_3b_v2"
# device = "cuda"
# main_process_first: the main process runs this block before the others
# (presumably so downloads/caches are populated once — confirm in the accelerate docs).
with accelerator.main_process_first():
    tokenizer = LlamaTokenizer.from_pretrained(model_name)
    tokenizer.padding_side = "left"
    # Uses the constructor defaults for z_dim, lora_rank and dropout.
    model_only = DecoderOnlyTransformerVAE(
        model_name, device,
        # z_dim=768, lora_rank=32, dropout=0.0,
    )

    # model_vae.enable_adapters()
# Replace the freshly initialised adapters and VAE head with the saved checkpoint.
# NOTE: absolute, machine-specific path.
model_only.load_pretrained("/data/joshua_clymer/spar-red-team/owen/LDLM/checkpoints/vae_48") 
# model_only.to(device)
# NOTE(review): n_layers and d_model are not read anywhere in the visible part
# of the notebook — confirm they describe the loaded model or remove them.
n_layers = 12
d_model = 1024
accelerator.wait_for_everyone()

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


## Test adapter hidden

In [15]:
#TODO check if adapter hidden is different (raw model vs adapter on/off)

## Wrapped VAE

In [22]:
class WrappedVAE:
    """
    Wrapper for VAE model that handles tokenization and decoding
    Helper methods go here
    Allows changing methods without reloading model
    """
    def __init__(self, ldlm_model, tokenizer, device, dtype=torch.bfloat16):
        """Store the collaborators and move ``ldlm_model`` to ``device``.

        Args:
            ldlm_model: Model exposing ``encode(input_ids, attention_mask)`` and
                ``generate(latent, input_ids, attention_mask, n_tokens, ...)``.
            tokenizer: Callable tokenizer that also provides ``decode``.
            device: Device for the model and for tokenized inputs.
            dtype: Autocast dtype used by the helper methods.
        """
        self.ldlm_model = ldlm_model.to(device)
        self.tokenizer = tokenizer
        self.device = device # Only model goes to device
        self.dtype = dtype

    def reconstruct(self, input_text, generation_length=48):
        """Encode ``input_text`` to a latent and decode it back to text.

        Returns:
            dict with ``'output_ids'`` (prompt ids followed by generated ids),
            ``'generation_length'`` and ``'text'`` (the decoded trailing
            ``generation_length`` ids of the first row).
        """
        with torch.cuda.amp.autocast(dtype=self.dtype):
            tokenizer_output = self.tokenizer(input_text, return_tensors="pt")
            input_ids = tokenizer_output.input_ids.to(self.device)
            attention_mask = tokenizer_output.attention_mask.to(self.device)
            input_latent = self.ldlm_model.encode(input_ids, attention_mask)
            output_ids = self.ldlm_model.generate(input_latent, input_ids, attention_mask, generation_length, output_hidden_states=False)

            text = self.ids_to_text(output_ids, generation_length)
            return {'output_ids': output_ids, 'generation_length': generation_length, 'text': text}

    def text_to_latent(self, text,):
        """Tokenize ``text`` and return ``ldlm_model.encode`` of it under autocast."""
        tokenizer_out = self.tokenizer(text, return_tensors="pt")
        tokens = tokenizer_out.input_ids.to(self.device)
        mask = tokenizer_out.attention_mask.to(self.device)
        with torch.cuda.amp.autocast(dtype=self.dtype):
            return self.ldlm_model.encode(tokens, mask)

    def get_hidden_states(self, prompt, layer_index=-1, last_token_only=True,):
        """Return hidden states for ``prompt`` via the module-level ``get_hidden_states``.

        Args:
            prompt: Text to tokenize and run through the model.
            layer_index (int | slice, optional): Forwarded unchanged. Defaults to -1.
            last_token_only (bool, optional): Forwarded unchanged. Defaults to True.

        Returns:
            The value returned by the module-level ``get_hidden_states``.
        """
        # NOTE(review): the module-level helper calls
        # model(input_ids=..., attention_mask=..., output_hidden_states=True);
        # DecoderOnlyTransformerVAE.forward additionally requires
        # decoder_prefix_ids / decoder_prefix_mask, so passing the wrapper
        # model here looks like it would raise — confirm whether
        # self.ldlm_model.model was intended.
        return get_hidden_states(
            self.ldlm_model,
            self.tokenizer,
            self.device,
            prompt,
            dtype=self.dtype,
            layer_index=layer_index,
            last_token_only=last_token_only,
        )

    def ids_to_text(self, ids, generation_length=48,):
        """Decode the trailing ``generation_length`` ids of ``ids[0]``, keeping special tokens."""
        output_ids = ids[0][-generation_length:].unsqueeze(0)
        out_texts = [self.tokenizer.decode(toks, skip_special_tokens=False) for toks in output_ids]
        return ' | '.join(out_texts)
# initialize model
        

# if torch.cuda.device_count() > 1:
#     print("Using", torch.cuda.device_count(), "GPUs")
#     model_vae = nn.DataParallel(model_vae)

# model_vae = accelerator.prepare(model_vae)




In [23]:
# Bundle the loaded VAE model, tokenizer and device behind the helper wrapper
# (autocast dtype left at the WrappedVAE default).
wrapped_vae = WrappedVAE(model_only, tokenizer, device)

In [24]:
## Misc variables and helper methods
i = 0
measure_ae_scale = False
# Scalar accumulator created on `device`, then passed through accelerator.prepare.
ae_scale_sum = torch.tensor(0.0, device=device)
ae_scale_sum = accelerator.prepare(ae_scale_sum)
# Only referenced by the commented-out `sample` helper below, which divides
# z_prev by it and multiplies the sampled mean by it.
ae_scale = 1.527548
# Passed as `tau` in the commented-out generation test further down.
tau = 0.1
# Same value as the default z_dim of DecoderOnlyTransformerVAE.
z_dim = 768
# Global prompt text; read by vae_tokenize.
prompt = ""

# @torch.no_grad()
# @torch.cuda.amp.autocast(dtype=torch.bfloat16)
# def sample(model, z_prev):
#     bs = 1
    
#     sigma_min, sigma_max = 0.01, 100
#     sigmas = K.sampling.get_sigmas_karras(25, sigma_min, sigma_max, device=device)
#     x = torch.randn([bs, z_dim], device=device) * sigma_max
#     extra_args = {
#         "z_prev": z_prev / ae_scale,
#         "padding_mask": torch.ones([bs, 1], dtype=torch.long, device=device),
#     }
#     mean = K.sampling.sample_dpmpp_2m_sde(
#         model, x, sigmas, eta=0.0, extra_args=extra_args, disable=not is_main
#     )
#     return mean * ae_scale

def vae_tokenize(prev_window, n_tokens):
    """Tokenize ``prev_window`` and keep at most the first ``n_tokens`` tokens.

    Args:
        prev_window: Text to tokenize. Previously this argument was ignored
            and the module-level ``prompt`` was tokenized instead.
        n_tokens: Maximum number of tokens to keep.

    Returns:
        Tuple ``(input_ids, attention_mask)`` on ``device``, each holding at
        most one row and at most ``n_tokens`` columns.
    """
    #TODO figure out why this is necessary over regular tokenization
    tokens = tokenizer(prev_window, return_tensors="pt")
    # Slicing never raises for texts shorter than n_tokens, so the former
    # IndexError fallback is not needed.
    input_ids = tokens["input_ids"][:1, :n_tokens].to(device)
    attention_mask = tokens["attention_mask"][:1, :n_tokens].to(device)
    return input_ids, attention_mask

# Barrier call on the accelerator before moving on to the next cell.
accelerator.wait_for_everyone()


## Setup Tensorboard

# Test generation

In [25]:
# # Test Generation
# n_tokens =48
# with torch.cuda.amp.autocast(dtype=torch.bfloat16):
#     input_ids, attention_mask = vae_tokenize(prompt, n_tokens)
#     z_prev = model_vae.encode(input_ids, attention_mask)[:, None]
# print(input_ids)
# print(z_prev.shape)
    
# # out_embeds = []
# # for i in range(5):
# #     out_embeds += sample(accelerator.unwrap_model(model), z_prev).unsqueeze(0).unsqueeze(0)
# #     z_prev = out_embeds[-1]
# #     # Looks like z_prev is the same kind of tensor as sample()


# # input_ids, attention_mask = vae_tokenize(prompt, n_tokens)
# # out_texts = [prompt]
# LLAMA_EOS_ID = tokenizer.eos_token_id
# tokenizer_output = tokenizer(prompt, return_tensors="pt")
# # print(tokenizer_output)
# input_ids = tokenizer_output.input_ids.to(device)
# attention_mask = tokenizer_output.attention_mask.to(device)
# print(input_ids == tokenizer.bos_token) # looks like it's 1 token, probably BOS
# out_texts = []
# n_tokens = 50
# # for z in out_embeds:
# z = torch.randn([1, 1, z_dim], device=device) 
# with torch.cuda.amp.autocast(dtype=torch.bfloat16):
#     z = model_vae.vae.sample(z, tau=tau)
#     generation_outputs = model_vae.generate(z.squeeze(0),
#                                     input_ids, #TODO: set to BOS, check if tokenizing empty string does this (ids nonempty)
#                                     attention_mask, #TODO: set to BOS attn mask
#                                     n_tokens,
#                                     tau=tau,
#                                     output_hidden_states=True,
#                                     # eos_id=LLAMA_EOS_ID,
#                                     )
#     # print(generation_outputs)
#     output_ids = generation_outputs['output_ids'][0][-n_tokens:].unsqueeze(0)      
#     hidden_states = generation_outputs['hidden_states']                
#     # attention_mask = torch.ones([1,48], dtype=torch.long, device=device)
#     # print(input_ids.shape)
# out_texts += [tokenizer.decode(toks, skip_special_tokens=False) for toks in output_ids]

# print(' | '.join(out_texts))
# print(out_texts)


In [26]:
# print(out_texts)
# print(generation_outputs['output_ids'].shape)

# Test Reconstruction

In [27]:
# shaq_text = "Shaquille O'Neal is a 7-foot-1-inch (2.16 m) and 325-pound (147 kg) center who played for six teams over his 19-year career in the National Basketball Association (NBA) and is a four-time NBA champion. O'Neal is regarded as one of the greatest basketball players and centers of all time."
# dog_yes = "Do dogs bark? Yes."
# dog_no = "Do dogs bark? No."


# sun_q = "Question: What color is the sun when viewed from space? "
# sun_correct = sun_q + "Answer: The sun is white when viewed from space. "
# sun_incorrect = sun_q + "Answer: The sun is yellow when viewed from space. "
# # wrapped_vae.reconstruct(dog_yes, generation_length=100)
# # wrapped_vae.reconstruct(dog_no, generation_length=100)
# wrapped_vae.reconstruct(sun_correct, generation_length=40)
# wrapped_vae.reconstruct(sun_incorrect, generation_length=40)
# wrapped_vae.reconstruct(sun_incorrect + sun_correct, generation_length=40)


{'output_ids': tensor([[    1, 10706, 29537,  1200,  2177,   325,   268,  3573,   661, 11521,
            440,  2184, 29584, 13910, 29537,   364,  3573,   325,  7951,   661,
          11521,   440,  2184, 29520, 10706, 29537,  1200,  2177,   325,   268,
           3573,   661, 11521,   440,  2184, 29584, 13910, 29537,   364,  3573,
            325,  2638,   661, 11521,   440,  2184, 29520, 29500,    13, 28629,
          29537,  1200,  2177,   325,   268,  3573,   661, 11521,   440,  2184,
          29584, 13910, 29537,   364,  3573,   325,  2638,   661, 11521,   440,
           2184, 29520, 29500,    13, 28629, 29537,  1200,  2177,   325,   268,
           3573,   661, 11521,   440,  2184, 29584, 13910, 29537]],
        device='cuda:0'),
 'generation_length': 40,
 'text': '\nQuestion: What color is the sun when viewed from space? Answer: The sun is white when viewed from space. \nQuestion: What color is the sun when viewed from space? Answer:'}

In [29]:
# male_text = "male " * 10
# female_text = "female " * 10
male_tennis = """
Rafael Nadal Parera (born 3 June 1986) is a Spanish professional tennis player. 
Nadal has been ranked world No. 1 in singles by the Association of Tennis Professionals (ATP) for 209 weeks, 
and has finished as the year-end No. 1 five times. 
Nadal has won 22 Grand Slam men's singles titles, 
including a record 14 French Open titles. 
He has won 92 ATP singles titles, including 36 Masters titles, with 63 of these on clay courts. 
Nadal is one of only two men to complete the Career Golden Slam in singles. 
His 81 consecutive wins on clay constitute the longest single-surface win streak in the Open Era.
"""
female_tennis = """
Serena Jameka Williams (born September 26, 1981) is an American former professional tennis player. 
Widely regarded as one of the greatest tennis players of all time, 
she was ranked world No. 1 in singles by the Women's Tennis Association (WTA) for 319 weeks, 
including a joint-record 186 consecutive weeks, and finished as the year-end No. 1 five times. 
She won 23 Grand Slam women's singles titles, the most in the Open Era, and the second-most of all time. 
She is the only player to accomplish a career Golden Slam in both singles and doubles.
"""

# wrapped_vae.reconstruct(male_text,)
# wrapped_vae.reconstruct(female_text,)
# Round-trip each passage, repeated 10 times, through the VAE and print the
# decoded text.
# NOTE: the recorded run of this cell ended in a CUDA OutOfMemoryError.
with torch.no_grad():
    print(wrapped_vae.reconstruct(male_tennis * 10,)['text'])
    print(wrapped_vae.reconstruct(female_tennis * 10,)['text'])

OutOfMemoryError: CUDA out of memory. Tried to allocate 202.00 MiB. GPU 0 has a total capacty of 79.15 GiB of which 6.12 MiB is free. Including non-PyTorch memory, this process has 79.13 GiB memory in use. Of the allocated memory 72.56 GiB is allocated by PyTorch, and 6.08 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

# Test getting hidden given text

In [None]:
emotional = '''
Are you kidding me? LeBron James better than Michael Jordan? That's a laughable claim. 
Six championships, ten scoring titles, unmatched defensive prowess, 
and a killer instinct that LeBron can only dream of. 
Jordan is the undisputed GOAT, and no amount of arguments will change that.
'''
shaq_wikipedia = "Shaquille O'Neal is a 7-foot-1-inch (2.16 m) and 325-pound (147 kg) center who played for six teams over his 19-year career in the National Basketball Association (NBA) and is a four-time NBA champion. O'Neal is regarded as one of the greatest basketball players and centers of all time."
explicit = """
Based? Based on what? In your dick? 
Please shut the fuck up and use words properly you fuckin troglodyte, 
do you think God gave us a freedom of speech just to spew random words 
that have no meaning that doesn't even correllate to the topic of the conversation? 
Like please you always complain about why no one talks to you or no one 
expresses their opinions on you because you're always spewing random shit like 
poggers based cringe and when you try to explain what it is and you just say that 
it's funny like what? What the fuck is funny about that do you think you'll 
just become a stand-up comedian that will get a standing ovation just because 
you said "cum" in the stage? HELL NO YOU FUCKIN IDIOT, so please shut the fuck up 
and use words properly
"""
# Reconstruct the three passages with the default generation length
# (no gradients needed) and print the decoded texts one per line.
with torch.no_grad():
    wikipedia_decoded = wrapped_vae.reconstruct(shaq_wikipedia)['text']
    angry_decoded = wrapped_vae.reconstruct(emotional)['text']
    explicit_decoded = wrapped_vae.reconstruct(explicit)['text']
print(wikipedia_decoded, angry_decoded, explicit_decoded, sep='\n')

# Hidden States comparison

In [None]:
# # Test getting hidden given text
# Load the un-adapted base model in bf16 for comparison with the VAE model.
with accelerator.main_process_first():
    model_raw = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map='auto',)
    # model_raw = accelerator.prepare(model_raw)
    # NOTE(review): the model is loaded with device_map='auto' and then moved
    # again with .to(device) — confirm the two placements do not conflict.
    model_raw.to(device)

    
# # print(len(hidden_states))
# # print(hidden_states[0].shape)
# dog_yes_hidden = get_hidden_states("Do dogs bark? Yes.", model_raw, tokenizer, device, layer_index=-1, last_token_only=True, )
# dog_no_hidden = get_hidden_states("Do dogs bark? No.", model_raw, tokenizer, device, layer_index=-1, last_token_only=True, )
# truth_direction = dog_yes_hidden - dog_no_hidden
# truth_direction = accelerator.prepare(truth_direction)   
# print(truth_direction.shape)


TypeError: LlamaForCausalLM.forward() got an unexpected keyword argument 'return_tensors'

In [None]:
# model_vae.forward()

# Trainer

In [None]:
# from torch.utils.tensorboard import SummaryWriter
# writer = SummaryWriter()

In [None]:
#TODO refactor old trainer
# def cosim(a, b):
#     # assumes a, b have already been view(-1)
#     return F.cosine_similarity(a, b, dim=0).item()



class LatentTrainer(Trainer):
    """Optimise a single latent vector toward a target hidden-state direction.

    The trained "model" is a small module holding one parameter named
    ``latent``. Each step decodes that latent with the wrapped VAE, reduces
    the returned hidden states with ``process_raw_hidden(..., layer_index=-1,
    last_token_only=True)`` and minimises ``1 - cosine_similarity`` against
    ``target_dir``. ``compute_loss`` and ``train`` replace the stock
    ``Trainer`` versions; no dataset or dataloader is used.
    """

    def __init__(
        self, wrapped_vae: WrappedVAE, training_args, latent_module, logging_steps=1e2,
        # target_text=None,
        target_dir=None,
        generation_length=10,
        tokenizer=tokenizer,
        # length_reg=0.0,
    ):
        """Set up the optimisation target and the (empty) generation context.

        Example use case: ``target_dir`` is the last-token hidden state of a
        reference text; the loss is ``1 - cosine_similarity(current, target)``,
        so it is 0 when the directions agree and at most 2.

        Args:
            wrapped_vae: VAE wrapper; ``wrapped_vae.ldlm_model.generate`` and
                ``wrapped_vae.ids_to_text`` are called during training.
            training_args: ``TrainingArguments``; ``learning_rate`` and
                ``num_train_epochs`` are read by ``train``.
            latent_module: module exposing a parameter named ``'latent'``.
            logging_steps: progress is printed every this many epochs.
            target_dir: target tensor. It is flattened, so any shape works as
                long as the element count matches the flattened hidden state.
            generation_length: forwarded to ``wrapped_vae.ids_to_text`` when
                progress text is printed.
            tokenizer: tokenizer for the empty context and for decoding. The
                default is the module-level ``tokenizer`` as it exists when
                the class is defined.

        Raises:
            ValueError: if ``target_dir`` is not provided.
        """
        if target_dir is None:
            # Fail early with a clear message instead of an AttributeError on
            # ``None.to(...)`` further down.
            raise ValueError('target_dir is required: LatentTrainer has nothing to optimise towards')

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        latent_module.to(self.device)
        super().__init__(model=latent_module, args=training_args)

        # Keep the VAE that was passed in; compute_loss/train read
        # self.wrapped_vae. ``model_vae`` is kept as an alias of the same object.
        self.wrapped_vae = wrapped_vae
        self.model_vae = wrapped_vae  # make sure this outputs hidden

        # The target is a constant. Detaching stops every backward pass from
        # also walking the graph that produced it.
        self.target_dir = target_dir.detach().to(self.device).reshape(-1)  # flattened

        self.tokenizer = tokenizer

        # Generation starts from an empty context string.
        context = ""
        tokenizer_output = self.tokenizer(context, return_tensors="pt")
        self.context_ids = tokenizer_output.input_ids.to(self.device)
        self.context_mask = tokenizer_output.attention_mask.to(self.device)

        self.loss_values = []
        self.logging_steps = logging_steps
        # self.length_reg = length_reg
        self.generation_length = generation_length

    def compute_loss(self, latent_module, return_cosim=False, return_text=True
                     #TODO maybe change cosim to compute layerwise so we know where the cosim loss is coming from
        # return_dir=False,
        ):
        """Decode the current latent and score it against the target direction.

        Args:
            latent_module: module holding the trainable ``'latent'`` parameter.
            return_cosim: also return the raw cosine similarity.
            return_text: also return the generated token ids.

        Returns:
            dict with ``'loss'`` (``1 - cosine_similarity``) and, when
            requested, ``'cos_similarity'`` and ``'output_ids'``.

        Raises:
            ValueError: if the flattened hidden state and the flattened target
                have different shapes.
        """
        # 1. Extract params from latent model. The clone stays connected to the
        #    parameter, so gradients still reach it.
        latent = latent_module.get_parameter('latent').clone().requires_grad_(True)

        # 2. Put latent through the VAE decoder and get the hidden states.
        # NOTE(review): 48 is passed positionally -- presumably the generation
        # length of the checkpoint; confirm and consider parameterising.
        out = self.wrapped_vae.ldlm_model.generate(latent, self.context_ids, self.context_mask, 48, output_hidden_states=True)

        output_ids = out['output_ids']
        hidden_states_tuple = out['hidden_states']
        current_dir = process_raw_hidden(hidden_states_tuple, layer_index=-1, last_token_only=True, ).reshape(-1)

        if current_dir.shape != self.target_dir.shape:
            raise ValueError(
                f'Target shape is {self.target_dir.shape} but current_dir shape is {current_dir.shape}'
            )

        # 3. Compute loss with cosine similarity between hidden states.
        similarity = F.cosine_similarity(current_dir, self.target_dir, dim=0)
        loss = 1 - similarity  # + self.length_reg * output_length ** 2

        return_dict = {'loss': loss}
        if return_cosim:
            return_dict['cos_similarity'] = similarity
        if return_text:
            return_dict['output_ids'] = output_ids
        return return_dict

    def train(self, optimizer_state=None
            #   resume_checkpoint=False,
              ):
        """Run the Adam optimisation loop on the latent and plot the loss curve.

        Args:
            optimizer_state: optional Adam ``state_dict`` to resume from.

        Returns:
            dict with ``'epoch'``, ``'model_state_dict'``,
            ``'optimizer_state_dict'`` and ``'loss'`` (last scalar loss, or
            ``None`` when no epoch was run).
        """
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.args.learning_rate)
        if optimizer_state is not None:
            optimizer.load_state_dict(optimizer_state)
        num_epochs = int(self.args.num_train_epochs)
        latent_module = self.model

        # Defined up front so the returned state is valid with zero epochs.
        loss_scalar = None

        for epoch in range(num_epochs):
            optimizer.zero_grad()

            step_outputs = self.compute_loss(latent_module)
            loss = step_outputs['loss']
            # NOTE(review): retain_graph is kept to preserve the existing
            # behaviour; with a detached target it is probably unnecessary.
            loss.backward(
                retain_graph=True
                )

            optimizer.step()

            loss_scalar = loss.item()
            self.loss_values.append(loss_scalar)
            #TODO save logs somewhere
            if epoch % self.logging_steps == 0:
                print(f"Epoch {epoch}")
                text = self.wrapped_vae.ids_to_text(step_outputs['output_ids'], self.generation_length, self.tokenizer)
                print(text)
                print(f'Loss = 1 - cosine_similarity = {loss_scalar}')

        plt.plot(self.loss_values)
        plt.xlabel('Epoch')
        plt.ylabel('Loss = 1 - Cos Similarity')
        plt.show()
        training_state = {
            'epoch': num_epochs,
            'model_state_dict': latent_module.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss_scalar
            }
        return training_state


In [None]:
# Initialize trainer and run loop

# dummy_data = {'dummy':'dummy',} # may be needed to appease Trainer
# dummy_data = Dataset.from_dict(dummy_data)
# Shape constants for the randomly generated fallback target in train():
# the target has shape (NUM_LAYERS_FOR_LOSS, BATCH_SIZE, HIDDEN_SIZE).
BATCH_SIZE = 1
NUM_LAYERS_FOR_LOSS = 1  # llamacfg.num_hidden_layers
HIDDEN_SIZE = 3200  # llamacfg.hidden_size
# SEQ_LEN = 1

# Default Llama config; the values above are hard-coded rather than read from it.
llamacfg = LlamaConfig()

def train(
        wrapped_vae: WrappedVAE, target_dir=None, lr=1e-4, num_epochs=1e2, logging_steps=10, init_norm=1.0, init_latent=None, 
        # length_reg = 0.0 #1e-2,
        save_name='test',
        layer_index=-1,
        num_trainings=1,
        generation_length=10,
        checkpoint_path=None,
        # resume_checkpoint=False,
        ):
    """Optimise a latent toward ``target_dir`` and save the training state.

    Args:
        wrapped_vae: VAE wrapper; ``wrapped_vae.ldlm_model.z_dim`` gives the
            latent size and the object is handed to ``LatentTrainer``.
        target_dir: target hidden-state tensor. When ``None`` a random target
            of shape ``(NUM_LAYERS_FOR_LOSS, BATCH_SIZE, HIDDEN_SIZE)`` is
            drawn once and reused for every repeat.
        lr: Adam learning rate.
        num_epochs: number of optimisation steps.
        logging_steps: print progress every this many epochs.
        init_norm: norm of the random initial latent (ignored when
            ``init_latent`` is given).
        init_latent: optional starting latent of shape ``(1, latent_size)``.
        save_name: stem of the saved ``.pth`` file.
        layer_index: currently unused.
        num_trainings: number of independent runs.
        generation_length: forwarded to ``LatentTrainer``.
        checkpoint_path: optional path to a state saved by this function; its
            latent and optimizer state are restored before training.

    Raises:
        ValueError: if ``init_latent`` does not have shape ``(1, latent_size)``.
    """
    latent_size = wrapped_vae.ldlm_model.z_dim

    if target_dir is None:
        target_dir = torch.randn(NUM_LAYERS_FOR_LOSS, BATCH_SIZE, HIDDEN_SIZE)
        target_dir = accelerator.prepare(target_dir)

    for i in range(num_trainings):
        training_args = TrainingArguments(output_dir="test_trainer_checkpoints")
        training_args.num_train_epochs = num_epochs
        training_args.learning_rate = lr

        latent_module = nn.Module()
        if init_latent is None:
            param = torch.randn(1, latent_size)
            # Rescale so the initial latent has norm == init_norm.
            param = param * (init_norm / param.norm())
        else:
            if init_latent.shape != (1, latent_size):
                raise ValueError(
                    f'init_latent must have shape (1, {latent_size}), got {tuple(init_latent.shape)}'
                )
            # Detach and copy so training neither mutates the caller's tensor
            # in place nor drags along the graph that produced it.
            param = init_latent.detach().clone()
        param = nn.Parameter(data=param, requires_grad=True)
        latent_module.register_parameter("latent", param)

        optimizer_state = None
        if checkpoint_path is not None:
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            checkpoint = torch.load(checkpoint_path)
            latent_module.load_state_dict(checkpoint['model_state_dict'])
            optimizer_state = checkpoint['optimizer_state_dict']

        # Use the VAE this function was called with, not a notebook global.
        trainer = LatentTrainer(
            wrapped_vae, training_args, latent_module, logging_steps=logging_steps,
            target_dir=target_dir,
            generation_length=generation_length,
            )
        training_state = trainer.train(optimizer_state=optimizer_state)

        #TODO figure out file structure
        # A single run keeps the plain name; repeated runs get an index suffix
        # so they no longer overwrite each other.
        save_path = f'{save_name}.pth' if num_trainings == 1 else f'{save_name}-{i}.pth'
        torch.save(training_state, save_path)
    # optimus.print_greedy(latent.get_parameter('latent'))
    # latent_diff = (param - param_init).view(-1).norm()
    # print(latent_diff)
    


In [None]:
# Sanity check: show which device the accelerator reports.
print(accelerator.device)


cuda


In [None]:



# Reference texts and the targets derived from them.
# model_vae.to(device)
# A string from wikipedia about shaq
shaq_text = "Shaquille O'Neal is a 7-foot-1-inch (2.16 m) and 325-pound (147 kg) center who played for six teams over his 19-year career in the National Basketball Association (NBA) and is a four-time NBA champion. O'Neal is regarded as one of the greatest basketball players and centers of all time."
intro_text = "My name is Owen. Nice to meet you!"
# A string from wikipedia about benzene
benzene_text = "Benzene is a natural constituent of petroleum and is one of the elementary petrochemicals. Due to the cyclic continuous pi bonds between the carbon atoms, benzene is classed as an aromatic hydrocarbon. Benzene is a colorless and highly flammable liquid with a sweet smell, and is partially responsible for the aroma of gasoline."
# Hidden state of intro_text from the raw LM, requested with layer_index=-1
# and last_token_only=True.
# NOTE(review): the outputs printed in later cells show a grad_fn on values
# derived from these tensors, so they appear to carry an autograd graph --
# consider detaching before using them as optimisation targets.
intro_activations = get_hidden_states(intro_text, model_raw, tokenizer, device, layer_index=-1, last_token_only=True, )
# NOTE(review): accelerator.prepare is applied to a tensor here; presumably a
# pass-through -- confirm it is needed.
intro_activations = accelerator.prepare(intro_activations)

# VAE latent of the same text.
intro_latent = text_to_latent(intro_text, model_vae)
intro_latent = accelerator.prepare(intro_latent)

In [None]:
# Second, near-identical sentence: how similar are the hidden states that the
# raw LM produces for the two introductions?
intro_text_2 = "My name is Claire. Nice to meet you!"
intro_2_activations = accelerator.prepare(
    get_hidden_states(
        intro_text_2,
        model_raw,
        tokenizer,
        device,
        layer_index=-1,
        last_token_only=True,
    )
)
print(intro_activations.shape)
# Cosine similarity over the flattened activation vectors.
print(
    F.cosine_similarity(
        intro_activations.reshape(-1),
        intro_2_activations.reshape(-1),
        dim=0,
    )
)
accelerator.wait_for_everyone()

torch.Size([1, 1, 3200])
tensor(0.9805, device='cuda:0', dtype=torch.bfloat16, grad_fn=<SumBackward1>)


In [None]:
# Norm of the flattened encoded latent, for comparison with init_norm in train().
print(torch.norm(intro_latent.view(-1)))

tensor(44., device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<LinalgVectorNormBackward0>)


In [None]:
# train(model_vae, target_dir=intro_activations, # truth_direction, 
#       lr=1e-10, num_epochs=50, logging_steps=2, 
#       # init_norm=1.0, 
#       # checkpoint_path = "intro_50_lre-10.pth",
#       init_latent= -intro_latent, 
#         save_name='random_to_intro_50_lr-8_reverse_latent',
#         num_trainings=1,
#         generation_length=10,
# )

In [None]:
# Sanity check: show the device used for the explicit .to(device) calls.
print(device)

cuda:0
