In [11]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained('gpt2')
print(model.modules)

<bound method Module.modules of GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)>


In [7]:
from transformers.pytorch_utils import Conv1D
import torch.nn as nn

class LoRAConv1D(nn.Module):
    def __init__(self, weight, bias, r, alpha):
        super(LoRAConv1D, self).__init__()
        self.nx, self.nf = weight.shape 
        self.weight = weight
        self.weight.requires_grad = False
        self.bias = bias
        self.r = r
        self.alpha = alpha
        self.A = nn.Parameter(self.weight.new_zeros(self.r, self.nx))
        self.B = nn.Parameter(self.weight.new_zeros(self.nf, self.r))
    
    def forward(self, x):
        size_out = x.size()[:-1] + (self.nf,)
        result = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
        low_rank = self.B @ self.A
        result += x.view(-1, x.size(-1)) @ low_rank.T
        result = result.view(size_out)
        return result

In [27]:
#replace all the attention layers in model with LoRA layers
r = 64
alpha = 0
for name, module in model.named_modules():
    if isinstance(module, Conv1D) and "c_attn" in str(name):
        lora_layer = LoRAConv1D(module.weight, module.bias, r, alpha)
        # Replace the module directly in the parent's _modules dictionary
        parent_name, child_name = name.rsplit('.', 1)
        parent_module = dict(model.named_modules())[parent_name]
        parent_module._modules[child_name] = lora_layer

In [28]:
print(model)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): LoRAConv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)


In [8]:
for _, param in model.named_parameters():
    param.requires_grad = False

for name, module in model.named_modules():
    if isinstance(module, LoRAConv1D):
        for param in module.parameters():
            param.requires_grad = True

for name, param in model.named_parameters():
    if "attn.c_attn" in name: assert param.requires_grad == True
    else: assert param.requires_grad == False


AssertionError: 

In [12]:
r=0
for _, param in model.named_parameters():
    param.requires_grad = False

for name, module in model.named_modules():
    if "attn.c_" in name:
        for param in module.parameters():
            param.requires_grad = True

for name, param in model.named_parameters():
    if "attn.c_" in name: assert param.requires_grad == True
    else: assert param.requires_grad == False


In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, get_scheduler
from torch.utils.data import DataLoader, TensorDataset
from torch.optim import AdamW
import torch
from torch.cuda.amp import GradScaler, autocast
import tqdm

# Load dataset
dataset = load_dataset('wikitext', 'wikitext-2-raw-v1')
texts = dataset['train']['text']  # Using a small slice for quick training

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Tokenize data
encodings = tokenizer(texts, truncation=True, padding=True, max_length=512, return_tensors="pt")

# Setup device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'load model with total params: {pytorch_total_params} for r= {r}')
model.to(device)

# Prepare data for training
input_ids = encodings['input_ids']
attention_mask = encodings['attention_mask']
dataset = TensorDataset(input_ids, attention_mask)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

# Optimizer and learning rate scheduler
optimizer = AdamW(model.parameters(), lr=5e-5)
scheduler = get_scheduler("linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(dataloader)*5)

# Setup for mixed-precision training
scaler = GradScaler()

# Training loop
model.train()
progress_bar = tqdm.tqdm(range(len(dataloader) * 5), desc="Training")
for epoch in range(5):  # 5 epochs
    for batch in dataloader:
        optimizer.zero_grad()

        input_ids, attention_mask = batch[0].to(device), batch[1].to(device)

        with autocast():
            outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)
            loss = outputs.loss

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        progress_bar.update(1)
        progress_bar.set_postfix(loss=loss.item())

progress_bar.close()
print(f"Final loss: {loss.item()}")

load model with total params: 28348416 for r= 0


Training:  16%|█▋        | 1891/11475 [06:06<30:58,  5.16it/s, loss=0.681] 

In [56]:
torch.save(model, f"./gpt2_r{r}_16b_512.pt")

In [5]:
import torch
from datasets import load_dataset
import numpy as np
from torch.utils.data import DataLoader
import torch.nn.functional as F
from tqdm.auto import tqdm  # Use tqdm.auto for a progress bar that automatically adjusts to the environment
from transformers import GPT2Tokenizer, GPT2LMHeadModel


def compute_perplexity(model, tokenizer, dataset, batch_size=16):
    model.eval()  # Put the model in evaluation mode
    total_loss = 0.0
    total_length = 0
    data_loader = DataLoader(dataset, batch_size=batch_size)
    progress_bar = tqdm(data_loader, desc="Computing Perplexity")

    for batch in progress_bar:
        inputs = tokenizer(batch["text"], return_tensors="pt", truncation=True, padding="max_length", max_length=1024)
        input_ids = inputs.input_ids.to(model.device)
        attention_mask = inputs.attention_mask.to(model.device)
        labels = input_ids.clone()

        with torch.no_grad():
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            total_loss += loss.item() * input_ids.size(1)  # Multiply by the number of tokens
            total_length += input_ids.size(1)
        current_perplexity = torch.exp(torch.tensor(total_loss / total_length)).item()
        progress_bar.set_postfix({'current perplexity': current_perplexity})

    perplexity = torch.exp(torch.tensor(total_loss / total_length))
    return perplexity.item()
    
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

# Load the WikiText validation dataset
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="validation")

models_r = [64, 4, 2]
for r in models_r:
    model_name = f'gpt2_r{r}_16b_512.pt'
    model = torch.load(f'./{model_name}')
    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'load {model_name} with total params: {pytorch_total_params} for r={r}')    
    model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    perplexity = compute_perplexity(model, tokenizer, dataset)
    print(f"Perplexity: {perplexity}")

load gpt2_r64_16b_512.pt with total params: 23620608 for r= 64


Computing Perplexity:   0%|          | 0/235 [00:00<?, ?it/s]

Perplexity: 1.2441805601119995
load gpt2_r4_16b_512.pt with total params: 21408768 for r= 4


Computing Perplexity:   0%|          | 0/235 [00:00<?, ?it/s]

Perplexity: 1.2483491897583008
load gpt2_r2_16b_512.pt with total params: 21335040 for r= 2


Computing Perplexity:   0%|          | 0/235 [00:00<?, ?it/s]

Perplexity: 1.2438615560531616


In [12]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
perplexity = compute_perplexity(model, tokenizer, dataset)
print(f"Perplexity: {perplexity}")

Computing Perplexity:   0%|          | 0/235 [00:00<?, ?it/s]

Perplexity: 19345.966796875


In [32]:
dataset = load_dataset("cais/mmlu", 'all')
choices = ['A', 'B', 'C', 'D']
subjects = ['abstract_algebra', 'anatomy', 'astronomy', 'business_ethics', 'clinical_knowledge', 'college_biology', 'college_chemistry', 
            'college_computer_science', 'college_mathematics', 'college_medicine', 'college_physics', 'computer_security', 
            'conceptual_physics', 'econometrics', 'electrical_engineering', 'elementary_mathematics', 'formal_logic', 'global_facts', 
            'high_school_biology', 'high_school_chemistry', 'high_school_computer_science', 'high_school_european_history', 
            'high_school_geography', 'high_school_government_and_politics', 'high_school_macroeconomics', 'high_school_mathematics', 
            'high_school_microeconomics', 'high_school_physics', 'high_school_psychology', 'high_school_statistics', 
            'high_school_us_history', 'high_school_world_history', 'human_aging', 'human_sexuality', 'international_law', 'jurisprudence', 
            'logical_fallacies', 'machine_learning', 'management', 'marketing', 'medical_genetics', 'miscellaneous', 'moral_disputes', 
            'moral_scenarios', 'nutrition', 'philosophy', 'prehistory', 'professional_accounting', 'professional_law', 
            'professional_medicine', 'professional_psychology', 'public_relations', 'security_studies', 'sociology', 'us_foreign_policy', 
            'virology', 'world_religions']

In [55]:
import torch
from datasets import load_dataset
import numpy as np
import torch.nn.functional as F

def format_subject(subject):
    return ' '.join(subject.split('_'))

def format_example(df, idx, include_answer=True):
    entry = df[idx]
    prompt = entry['question']
    for j, choice in enumerate(entry['choices']):
        prompt += f"\n{choices[j]}. {choice}"
    prompt += f"\nAnswer:"
    if include_answer:
        prompt += f"\n{choices[entry['answer']]}\n\n"
    return prompt

def gen_prompt(df, subject, n_examples=-1):
    subject_formatted = format_subject(subject)
    df = df.filter(lambda x: x['subject'] == subject)
    prompt = f"The following are multiple choice questions (with answers) about {subject_formatted}.\n\n"
    max_examples = df.shape[0] if n_examples == -1 else n_examples
    for i in range(max_examples):
        prompt += format_example(df, i, include_answer=True)
    return prompt

# Evaluation function
@torch.no_grad()
def evaluate(model, tokenizer, dev_df, val_df, subject, num_train_examples):
    cors = []
    all_probs = []

    for i in range(len(val_df)):
        subject = val_df['subject'][i]
        train_prompt = gen_prompt(dev_df, subject, num_train_examples)
        prompt_end = format_example(val_df, i, include_answer=False)
        prompt = train_prompt + prompt_end
        # Tokenize and ensure input length is within the model's limits
        input_ids = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512).input_ids.to("cuda")
        
        # Generate logits
        outputs = model(input_ids=input_ids)
        logits = outputs.logits
        print(logits[:, :, tokenizer('A').input_ids[0]])
        print(logits.shape)
        probs = (
            F.softmax(
                torch.stack(
                    [
                        logits[:, :, tokenizer("A").input_ids[0]],
                        logits[:, :, tokenizer("B").input_ids[0]],
                        logits[:, :, tokenizer("C").input_ids[0]],
                        logits[:, :, tokenizer("D").input_ids[0]],
                    ]
                ),
                dim=0,
            )
            .detach()
            .cpu()
            .numpy()
        )
        print(probs)
        pred = {0: "A", 1: "B", 2: "C", 3: "D"}[np.argmax(probs)]
        pred = np.argmax(probs)

        # Check if the prediction is correct
        correct = val_df[i]['answer']
        cor = pred == correct
        cors.append(cor)
        all_probs.append(probs)

    accuracy = np.mean(cors)
    print(f"Average accuracy: {accuracy:.3f} - Subject: {subject}")

    return np.array(cors), accuracy, np.array(all_probs)

# Example usage
dev_df = dataset['dev']
val_df = dataset['validation']
subject = 'all_facts'  # Replace with actual subject
num_train_examples = 5  # Number of training examples to include in each prompt

# Evaluate model
results = evaluate(model, tokenizer, dev_df, val_df, subject, num_train_examples)

tensor([[-3.9195e+01, -8.8060e+01, -6.6721e+01, -7.6704e+01, -7.8159e+01,
         -9.2187e+01, -8.8048e+01, -7.5561e+01, -1.1095e+02, -1.1478e+02,
         -8.3513e+01, -7.6295e+01, -7.1059e+01, -1.0340e+02, -1.5245e+02,
          2.8419e+01, -4.8414e+01, -7.0397e+01, -6.6105e+01, -7.1978e+01,
         -7.9528e+01, -7.0824e+01, -8.9900e+01, -2.6417e+01, -7.7344e+01,
         -1.8183e+00, -5.4304e+01, -7.5323e+01, -9.9255e+01, -8.1886e+01,
         -8.3881e+01, -4.8993e+01, -7.6000e+01, -7.2275e+01, -7.8384e+01,
         -6.9850e+01, -7.9477e+01, -7.1126e+01, -9.4997e+01, -1.0155e+02,
         -9.5762e+01, -9.7707e+01, -4.7966e+01, -7.4537e+01, -8.7895e+01,
         -6.5548e+01, -4.2067e+01, -1.0200e+02, -8.2504e+01,  5.1081e+01,
         -1.9441e+01, -2.0478e+01, -7.1574e+01,  1.5196e+01, -5.6091e+01,
          8.5105e+01, -9.8095e+01,  1.7449e+01, -7.8288e+01, -9.2535e+01,
         -9.8119e+01, -6.3177e+01, -3.0487e+01, -6.4403e+01, -8.4969e+01,
         -8.6147e+01, -6.1624e+01, -7.

KeyError: 228

In [44]:
tokenizer('D').input_ids[0]

35