In [1]:
import torch
import torch.nn as nn

class LoRALayer(nn.Module):
    def __init__(self, weight, r, alpha):
        super(LoRALayer, self).__init__()
        self.weight = weight
        self.weight.requires_grad = False
        self.r = r
        self.alpha = alpha
        out_features = self.weight.shape[0]
        in_features = self.weight.shape[1]
        self.A = nn.Parameter(self.weight.new_zeros(self.r, in_features))
        self.B = nn.Parameter(self.weight.new_zeros(out_features, r))
    
    def forward(self, x):
        result = x @ self.weight.T
        result += x @ (self.A.T @ self.B.T)
        return result

In [2]:
class FFN(nn.Module):
    def __init__(self, in_channels, hidden_dim, out_channels):
        super(FFN, self).__init__()
        self.linear1 = nn.Linear(in_channels, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, out_channels)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return self.sigmoid(x)

In [3]:
from torch.utils.data import DataLoader, TensorDataset

ffn = FFN(2, 16, 1)
x_xor = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=torch.float32)
y_xor = torch.tensor([[0], [1], [1], [0]], dtype=torch.float32)

dataset_xor = TensorDataset(x_xor, y_xor)
dataloader_xor = DataLoader(dataset_xor, batch_size=1, shuffle=True)

def train_xor_model(model, dataloader):
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    for epoch in range(400):
        for inputs, labels in dataloader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch + 1}, Loss: {loss.item()}")

def validate_xor_model(model, dataloader):
    model.eval()  # Set the model to evaluation mode
    with torch.no_grad():  # No gradients needed for predictions
        for inputs, labels in dataloader:
            outputs = model(inputs)
            print(f"Input: {inputs.numpy()}, Predicted: {outputs.numpy()}")

train_xor_model(ffn, dataloader_xor)
validate_xor_model(ffn, dataloader_xor)


  from .autonotebook import tqdm as notebook_tqdm


Epoch 10, Loss: 0.15251220762729645
Epoch 20, Loss: 0.07886414974927902
Epoch 30, Loss: 0.08194424957036972
Epoch 40, Loss: 0.08549709618091583
Epoch 50, Loss: 0.0413505844771862
Epoch 60, Loss: 0.03277064487338066
Epoch 70, Loss: 0.009126568213105202
Epoch 80, Loss: 0.006150983739644289
Epoch 90, Loss: 0.0108120022341609
Epoch 100, Loss: 0.008169732056558132
Epoch 110, Loss: 0.002343089086934924
Epoch 120, Loss: 0.005173786543309689
Epoch 130, Loss: 0.0034724685829132795
Epoch 140, Loss: 0.00361334509216249
Epoch 150, Loss: 0.001160691026598215
Epoch 160, Loss: 0.002678877441212535
Epoch 170, Loss: 0.0018410688498988748
Epoch 180, Loss: 0.0009976484579965472
Epoch 190, Loss: 0.0008719780016690493
Epoch 200, Loss: 0.001283199992030859
Epoch 210, Loss: 0.0011635959381237626
Epoch 220, Loss: 0.0005112708895467222
Epoch 230, Loss: 0.00046528075472451746
Epoch 240, Loss: 0.00106926285661757
Epoch 250, Loss: 0.00039106555050238967
Epoch 260, Loss: 0.0004478039045352489
Epoch 270, Loss: 0.00

In [4]:
ffn_weight = ffn.linear1.weight.detach().clone()
lora_layer = LoRALayer(ffn_weight, 1, 0.1)
setattr(ffn, 'linear1', lora_layer)

y_or = torch.tensor([[0], [1], [1], [1]], dtype=torch.float32)

dataset_xor = TensorDataset(x_xor, y_or)
dataloader_xor = DataLoader(dataset_xor, batch_size=1, shuffle=True)

train_xor_model(ffn, dataloader_xor)
validate_xor_model(ffn, dataloader_xor)

Epoch 10, Loss: 0.31432613730430603
Epoch 20, Loss: 0.004087598063051701
Epoch 30, Loss: 0.9512374401092529
Epoch 40, Loss: 0.26251086592674255
Epoch 50, Loss: 0.25592029094696045
Epoch 60, Loss: 0.23843255639076233
Epoch 70, Loss: 7.103406460373662e-06
Epoch 80, Loss: 0.01003202609717846
Epoch 90, Loss: 0.008165196515619755
Epoch 100, Loss: 7.21248215995729e-06
Epoch 110, Loss: 0.00016673911886755377
Epoch 120, Loss: 0.0001731382799334824
Epoch 130, Loss: 0.004492830950766802
Epoch 140, Loss: 0.07628859579563141
Epoch 150, Loss: 0.00018565756909083575
Epoch 160, Loss: 0.003221475752070546
Epoch 170, Loss: 0.053204506635665894
Epoch 180, Loss: 8.907381925382651e-06
Epoch 190, Loss: 0.00019771434017457068
Epoch 200, Loss: 9.146630873146933e-06
Epoch 210, Loss: 0.002004998968914151
Epoch 220, Loss: 0.00020425960246939212
Epoch 230, Loss: 0.02937709540128708
Epoch 240, Loss: 9.550826689519454e-06
Epoch 250, Loss: 0.0014442935353145003
Epoch 260, Loss: 9.695789231045637e-06
Epoch 270, Loss

In [5]:
ffn.modules

<bound method Module.modules of FFN(
  (linear1): LoRALayer()
  (linear2): Linear(in_features=16, out_features=1, bias=True)
  (relu): ReLU()
  (sigmoid): Sigmoid()
)>

In [6]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
print(model.modules)

<bound method Module.modules of GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)>


In [18]:
for name, module in model.named_modules():
    print(name, module.__class__.__name__)


transformer
transformer.wte
transformer.wpe
transformer.drop
transformer.h
transformer.h.0
transformer.h.0.ln_1
transformer.h.0.attn
transformer.h.0.attn.c_attn
transformer.h.0.attn.c_proj
transformer.h.0.attn.attn_dropout
transformer.h.0.attn.resid_dropout
transformer.h.0.ln_2
transformer.h.0.mlp
transformer.h.0.mlp.c_fc
transformer.h.0.mlp.c_proj
transformer.h.0.mlp.act
transformer.h.0.mlp.dropout
transformer.h.1
transformer.h.1.ln_1
transformer.h.1.attn
transformer.h.1.attn.c_attn
transformer.h.1.attn.c_proj
transformer.h.1.attn.attn_dropout
transformer.h.1.attn.resid_dropout
transformer.h.1.ln_2
transformer.h.1.mlp
transformer.h.1.mlp.c_fc
transformer.h.1.mlp.c_proj
transformer.h.1.mlp.act
transformer.h.1.mlp.dropout
transformer.h.2
transformer.h.2.ln_1
transformer.h.2.attn
transformer.h.2.attn.c_attn
transformer.h.2.attn.c_proj
transformer.h.2.attn.attn_dropout
transformer.h.2.attn.resid_dropout
transformer.h.2.ln_2
transformer.h.2.mlp
transformer.h.2.mlp.c_fc
transformer.h.2.mlp

In [8]:
from transformers.pytorch_utils import Conv1D

class LoRAConv1D(nn.Module):
    def __init__(self, weight, bias, r, alpha):
        super(LoRAConv1D, self).__init__()
        self.nf, self.nx = weight.shape 
        self.weight = weight
        self.weight.requires_grad = False
        self.bias = bias
        self.r = r
        self.alpha = alpha
        self.A = nn.Parameter(self.weight.new_zeros(self.r, self.nx))
        self.B = nn.Parameter(self.weight.new_zeros(self.nf, self.r))
    
    def forward(self, x):
        print(x.shape)
        size_out = x.size()[:-1] + (self.nf,)
        result = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
        low_rank = self.B @ self.A
        result += x.view(-1, x.size(-1)) @ low_rank
        x = x.view(size_out)
        return x

In [19]:
for name, module in model.named_modules():
    if isinstance(module, Conv1D) and "c_attn" in str(name):
        print(name)
        print(module.nf)
        print(module.weight.shape)
        print(module.bias.shape)
    

print("------")
weight = torch.randn([768, 2304])
x = torch.randn([4,473,768])
size_out = x.size()[:-1] + (2304, )
A = torch.randn([8, 2304])
B = torch.randn([768, 8])
result_1 = x.view(-1, x.size(-1)) @ weight
result_2 = B @ A
print((result_1).shape)
print((result_2).shape)
result_2 = x.view(-1, x.size(-1)) @ result_2
print((result_2).shape)
print((result_1 + result_2).shape)


transformer.h.0.attn.c_attn
2304
torch.Size([768, 2304])
torch.Size([2304])
transformer.h.1.attn.c_attn
2304
torch.Size([768, 2304])
torch.Size([2304])
transformer.h.2.attn.c_attn
2304
torch.Size([768, 2304])
torch.Size([2304])
transformer.h.3.attn.c_attn
2304
torch.Size([768, 2304])
torch.Size([2304])
transformer.h.4.attn.c_attn
2304
torch.Size([768, 2304])
torch.Size([2304])
transformer.h.5.attn.c_attn
2304
torch.Size([768, 2304])
torch.Size([2304])
transformer.h.6.attn.c_attn
2304
torch.Size([768, 2304])
torch.Size([2304])
transformer.h.7.attn.c_attn
2304
torch.Size([768, 2304])
torch.Size([2304])
transformer.h.8.attn.c_attn
2304
torch.Size([768, 2304])
torch.Size([2304])
transformer.h.9.attn.c_attn
2304
torch.Size([768, 2304])
torch.Size([2304])
transformer.h.10.attn.c_attn
2304
torch.Size([768, 2304])
torch.Size([2304])
transformer.h.11.attn.c_attn
2304
torch.Size([768, 2304])
torch.Size([2304])
------
torch.Size([1892, 2304])
torch.Size([768, 2304])
torch.Size([1892, 2304])
torch

In [20]:
#replace all the attention layers in model with LoRA layers
r = 8
alpha = 0
for name, module in model.named_modules():
    if isinstance(module, Conv1D) and "c_attn" in str(name):
        lora_layer = LoRAConv1D(module.weight, module.bias, r, alpha)
        # Replace the module directly in the parent's _modules dictionary
        parent_name, child_name = name.rsplit('.', 1)
        parent_module = dict(model.named_modules())[parent_name]
        parent_module._modules[child_name] = lora_layer

In [41]:
print(model)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): LoRAConv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)


In [49]:
for _, param in model.named_parameters():
    param.requires_grad = False

for name, module in model.named_modules():
    if isinstance(module, LoRAConv1D):
        for param in module.parameters():
            param.requires_grad = True

for name, param in model.named_parameters():
    if "attn.c_attn" in name: assert param.requires_grad == True
    else: assert param.requires_grad == False


In [50]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, AdamW, get_scheduler
from torch.utils.data import DataLoader, TensorDataset
import torch
from torch.cuda.amp import GradScaler, autocast
import tqdm

# Load dataset
dataset = load_dataset('wikitext', 'wikitext-103-raw-v1')
texts = dataset['train']['text'][:500]  # Using a small slice for quick training

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Tokenize data
encodings = tokenizer(texts, truncation=True, padding=True, max_length=512, return_tensors="pt")

# Setup device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Prepare data for training
input_ids = encodings['input_ids']
attention_mask = encodings['attention_mask']
dataset = TensorDataset(input_ids, attention_mask)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Optimizer and learning rate scheduler
optimizer = AdamW(model.parameters(), lr=5e-5)
scheduler = get_scheduler("linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(dataloader)*3)

# Setup for mixed-precision training
scaler = GradScaler()

# Training loop
model.train()
progress_bar = tqdm.tqdm(range(len(dataloader) * 3), desc="Training")
for epoch in range(3):  # 3 epochs
    for batch in dataloader:
        optimizer.zero_grad()
        
        input_ids, attention_mask = batch[0].to(device), batch[1].to(device)

        with autocast():
            outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)
            loss = outputs.loss
        
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        progress_bar.update(1)
        progress_bar.set_postfix(loss=loss.item())

progress_bar.close()
print(f"Final loss: {loss.item()}")



: 

: 

: 

In [13]:
import numpy as np
import torch.nn.functional as F
dataset = load_dataset("cais/mmlu")
choices = ['A', 'B', 'C', 'D']

def format_subject(subject):
    return ' '.join(subject.split('_'))

def format_example(df, idx, choices, include_answer=True):
    prompt = df['question'][idx]
    for j, choice in enumerate(choices):
        prompt += f"\n{j+1}. {df[choice][idx]}"
    if include_answer:
        prompt += f"\nAnswer: {df['correct_answer'][idx]}\n\n"
    return prompt

def gen_prompt(df, subject, n_examples=-1):
    subject_formatted = format_subject(subject)
    prompt = f"The following are multiple choice questions (with answers) about {subject_formatted}.\n\n"
    max_examples = df.shape[0] if n_examples == -1 else n_examples
    for i in range(max_examples):
        prompt += format_example(df, i, ['choice1', 'choice2', 'choice3', 'choice4'], include_answer=True)
    return prompt

# Evaluation function
@torch.no_grad()
def evaluate(model, tokenizer, dev_df, test_df, subject, num_train_examples):
    cors = []
    all_probs = []

    for i in range(len(test_df)):
        train_prompt = gen_prompt(dev_df, subject, num_train_examples)
        prompt_end = format_example(test_df, i, ['choice1', 'choice2', 'choice3', 'choice4'], include_answer=False)
        prompt = train_prompt + prompt_end
        
        # Tokenize and ensure input length is within the model's limits
        input_ids = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512).input_ids.to("cuda")
        
        # Generate logits
        outputs = model(input_ids=input_ids)
        logits = outputs.logits
        
        probs = (
            F.softmax(
                torch.tensor(
                    [
                        logits[tokenizer("A").input_ids[0]],
                        logits[tokenizer("B").input_ids[0]],
                        logits[tokenizer("C").input_ids[0]],
                        logits[tokenizer("D").input_ids[0]],
                    ]
                ),
                dim=0,
            )
            .detach()
            .cpu()
            .numpy()
        )
        pred = {0: "A", 1: "B", 2: "C", 3: "D"}[np.argmax(probs)]
        pred = np.argmax(probs)

        # Check if the prediction is correct
        correct = test_df['correct_answer'][i]
        cor = pred == correct
        cors.append(cor)
        all_probs.append(probs)

    accuracy = np.mean(cors)
    print(f"Average accuracy: {accuracy:.3f} - Subject: {subject}")

    return np.array(cors), accuracy, np.array(all_probs)

# Example usage
dev_df = dataset['validation']
test_df = dataset['test']
subject = 'subject_name_here'  # Replace with actual subject
num_train_examples = 5  # Number of training examples to include in each prompt

# Evaluate model
results = evaluate(model, tokenizer, dev_df, test_df, subject, num_train_examples)

Downloading data: 100%|██████████| 3.50M/3.50M [00:00<00:00, 13.8MB/s]
Downloading data: 100%|██████████| 408k/408k [00:00<00:00, 2.52MB/s]
Downloading data: 100%|██████████| 76.5k/76.5k [00:00<00:00, 517kB/s]
Downloading data: 100%|██████████| 47.5M/47.5M [00:00<00:00, 110MB/s] 
Generating test split: 100%|██████████| 14042/14042 [00:00<00:00, 768671.99 examples/s]
Generating validation split: 100%|██████████| 1531/1531 [00:00<00:00, 557759.00 examples/s]
Generating dev split: 100%|██████████| 285/285 [00:00<00:00, 191015.76 examples/s]
Generating auxiliary_train split: 100%|██████████| 99842/99842 [00:00<00:00, 386056.02 examples/s]

1531



