# Setup

In [14]:
from transformers import AutoTokenizer, pipeline, AutoModelForCausalLM, Trainer
import datasets
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from tqdm import tqdm
import numpy as np
# import accelerate

In [20]:
# Run on CUDA when a GPU is visible to torch; otherwise fall back to the CPU.
# NOTE(review): the model cell below places weights via device_map="auto", so this
# variable is not what decides model placement in this section.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

Load model

In [25]:
# Checkpoint under study. Other candidates tried during development:
#   "ehartford/Wizard-Vicuna-30B-Uncensored"
#   "mistralai/Mistral-7B-Instruct-v0.1"
#   "EleutherAI/pythia-410m"
#   "EleutherAI/pythia-1.4b"
#   "meta-llama/Llama-2-7b-chat-hf"   (run `huggingface-cli login` first)
model_name_or_path = "openlm-research/open_llama_3b_v2"

# device_map="auto" lets accelerate decide where the weights live. No torch_dtype
# is passed, so weights load in the default dtype (a float16 cast was tried earlier).
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    device_map="auto",
)

# The slow tokenizer is chosen on purpose. An earlier variant enabled the fast one
# only when "LlamaForCausalLM" was not among model.config.architectures.
use_fast_tokenizer = False
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    use_fast=use_fast_tokenizer,
    padding_side="left",
    legacy=False,
)
# NOTE(review): no pad token is configured here; batched padding may need one
# (an earlier draft set tokenizer.pad_token_id = 0) — confirm before batching.

Test model

In [26]:
# Quick smoke test: greedy-continue a short Q/A prompt.
prompt = 'Q: What is the largest country?\nA:'

# The tokenizer returns CPU tensors, but with device_map="auto" the model's first
# layers may sit on a GPU. Move the whole encoding to the model's device.
encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
input_ids = encoded.input_ids

# Pass the attention mask explicitly. For a single unpadded prompt it is all ones,
# but it removes generate()'s missing-mask warning and stays correct once prompts
# are batched with (left) padding.
generation_output = model.generate(
    input_ids=input_ids,
    attention_mask=encoded.attention_mask,
    max_new_tokens=32,
)
print(tokenizer.decode(generation_output[0]))

<s> Q: What is the largest country?
A: The largest country is Russia.
Q: What is the smallest country?
A: The smallest country is Vatican City.
Q: What is


Load and inspect the dataset (mini Pile: NeelNanda/pile-10k)

In [6]:
# 10k-document Pile sample; drop the `meta` column so only raw `text` remains.
dataset = (
    datasets.load_dataset("NeelNanda/pile-10k", split="train")
    .remove_columns("meta")
)
print(dataset)
# Peek at the first 200 characters of the first document.
print(dataset[0]["text"][:200])

Dataset({
    features: ['text'],
    num_rows: 10000
})
It is done, and submitted. You can play “Survival of the Tastiest” on Android, and on the web. Playi


In [None]:
# TODO: load the pretrained Optimus weights


Initialize adapter

In [31]:
# Adapter: a single linear map from the VAE latent space to the LM's embedding
# size, so that its output can be fed to the LM as an input embedding.
# A plain nn.Linear is enough for now; subclass it only if extra logic is needed.
d_vae_latent = 32  # VAE latent size (could also be 768, depending on the VAE)
d_lm_embedding = model.config.hidden_size

# Create the adapter on the LM's device and dtype. Its output can then be passed
# to the LM as `inputs_embeds` without implicit copies or dtype mismatches.
# With device_map="auto", model.device is where the model's first parameters live.
adapter = nn.Linear(d_vae_latent, d_lm_embedding).to(
    device=model.device, dtype=model.dtype
)

In [None]:
class AdapterTrainer(Trainer):
    """Trainer that fits an adapter mapping VAE latents into a causal LM.

    The Trainer's ``model`` is the adapter, and only its parameters are optimized.
    The loss pulls the LM's hidden state computed from the adapter output towards
    the LM's hidden state computed from the original text.

    Args:
        vae: VAE used to encode the batch text into latents (encoding is still TODO,
            see ``_encode_with_vae``).
        lm: Causal LM whose hidden states are compared; it is frozen here.
        model: The module to train, i.e. the adapter. Forwarded to ``Trainer``.
            Trainer raises if neither ``model`` nor ``model_init`` is given.
        **trainer_kwargs: Any other ``Trainer`` argument (``args``,
            ``train_dataset``, ``data_collator``, ...).
    """

    def __init__(self, vae, lm, model=None, **trainer_kwargs):
        # Trainer requires a model (or model_init); the adapter is the trainable model.
        super().__init__(model=model, **trainer_kwargs)
        self.vae = vae
        self.lm = lm
        # Only the adapter is optimized. Freezing the LM avoids allocating gradients
        # for all of its weights; gradients still flow *through* it to the adapter.
        self.lm.requires_grad_(False)
        # NOTE(review): with the default remove_unused_columns=True, Trainer drops
        # batch columns not named in model.forward's signature (nn.Linear.forward only
        # takes `input`). Consider TrainingArguments(remove_unused_columns=False) so
        # `input_ids` / `attention_mask` reach compute_loss.

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return ``1 - cos(h_adapter, h_text)`` averaged over the batch.

        Args:
            model: The adapter being trained, as passed in by ``Trainer``.
            inputs: Batch mapping. Assumed to hold tokenized text under
                ``input_ids`` (and optionally ``attention_mask``). TODO confirm once
                the data collator is written.
            return_outputs: If True, also return the adapter output.
            **kwargs: Extra arguments newer ``Trainer`` versions pass
                (e.g. ``num_items_in_batch``); unused.

        Returns:
            The scalar loss, or ``(loss, adapter_output)`` when ``return_outputs``.
        """
        # 1. Encode the batch text with the VAE.
        vae_latents = self._encode_with_vae(inputs)

        # 2. Map the VAE latent into the LM embedding space with the adapter.
        outputs = model(vae_latents)
        # Assumes a (batch, d_lm) adapter output, used as a one-token soft prompt of
        # shape (batch, 1, d_lm). TODO confirm against the VAE latent shape.
        adapter_embeds = outputs.unsqueeze(1) if outputs.dim() == 2 else outputs

        # 3. LM hidden state using the adapter output as the input embedding
        #    (embeddings go through `inputs_embeds`, not the positional input_ids).
        adapter_run = self.lm(inputs_embeds=adapter_embeds, output_hidden_states=True)
        hidden_state_from_adapter = adapter_run.hidden_states[-1][:, -1, :]

        # 4. LM hidden state for the same text. It is only a target, so skip autograd.
        with torch.no_grad():
            text_run = self.lm(
                input_ids=inputs["input_ids"],
                attention_mask=inputs.get("attention_mask"),
                output_hidden_states=True,
            )
        # Last layer, last position. This assumes left padding, so the final column
        # is each row's last real token.
        hidden_state_from_text = text_run.hidden_states[-1][:, -1, :]

        # 5. cosine_similarity is a per-example similarity in [-1, 1]. Minimize
        #    1 - cos and average it, because Trainer backpropagates a scalar.
        cos_sim = nn.functional.cosine_similarity(
            hidden_state_from_adapter, hidden_state_from_text, dim=-1
        )
        loss = (1.0 - cos_sim).mean()
        return (loss, outputs) if return_outputs else loss

    def _encode_with_vae(self, inputs):
        """Encode the batch text into VAE latents (one vector per example).

        TODO: implement once the VAE weights and its encoding API are loaded.
        """
        raise NotImplementedError(
            "TODO: encode `inputs` with self.vae to get (batch, d_vae_latent) latents"
        )