In [1]:
from transformers import AutoTokenizer, LlamaForCausalLM
from transformers import LlamaModel, LlamaConfig
from datasets import load_dataset
from torch import nn, optim
import torch
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
from tqdm import tqdm
import json

In [4]:
config = json.load(open('config.json', 'r'))
for key in ['_name_or_path', 'architectures']:
    del config[key]
config

{'bos_token_id': 1,
 'eos_token_id': 2,
 'hidden_act': 'silu',
 'hidden_size': 32,
 'initializer_range': 0.02,
 'intermediate_size': 256,
 'max_position_embeddings': 1024,
 'model_type': 'llama',
 'num_attention_heads': 4,
 'num_hidden_layers': 4,
 'pad_token_id': 0,
 'rms_norm_eps': 1e-06,
 'tie_word_embeddings': False,
 'torch_dtype': 'bfloat16',
 'transformers_version': '4.30.2',
 'use_cache': True,
 'vocab_size': 32000}

In [3]:
# configuration = LlamaConfig(bos_token_id = 1,
#                             eos_token_id = 2,
#                             hidden_act = "silu",
#                             hidden_size = 32,
#                             initializer_range = 0.02,
#                             intermediate_size = 256,
#                             max_position_embeddings = 1024,
#                             model_type = "llama",
#                             num_attention_heads = 4,
#                             num_hidden_layers = 4,
#                             pad_token_id = 0,
#                             rms_norm_eps = 1e-06,
#                             tie_word_embeddings = False,
#                             torch_dtype = "bfloat16",
#                             transformers_version = "4.30.2",
#                             use_cache = True,
#                             vocab_size = 32000,
# )



configuration = LlamaConfig(
    **config
)

llama2_model = LlamaForCausalLM.from_pretrained("lmsys/vicuna-7b-v1.5")
llama2_model.config = configuration
llama2_model.model = LlamaModel(configuration)
llama2_model.vocab_size = configuration.vocab_size
llama2_model.lm_head = nn.Linear(configuration.hidden_size, configuration.vocab_size, bias=False)
tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



In [3]:
raw_datasets = load_dataset("glue", "mrpc")

def tokenize_function(example):
    return tokenizer(example["sentence1"], example["sentence2"], padding="max_length", truncation=True, return_tensors='pt')

def collate_batch(batch):
    return {
        'input_ids': torch.tensor([item['input_ids'] for item in batch]),
        'attention_mask': torch.tensor([item['attention_mask'] for item in batch]),
        'labels': torch.tensor([item['input_ids'] for item in batch])
    }

tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
train_dataset = tokenized_datasets["train"]
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=1, collate_fn=collate_batch)
valid_dataset = tokenized_datasets["validation"]
valid_dataloader = DataLoader(valid_dataset, shuffle=True, batch_size=1, collate_fn=collate_batch)

Downloading builder script:   0%|          | 0.00/28.8k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/28.7k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/27.9k [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading data: 0.00B [00:00, ?B/s]

Downloading data: 0.00B [00:00, ?B/s]

Downloading data: 0.00B [00:00, ?B/s]

Generating train split:   0%|          | 0/3668 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/408 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1725 [00:00<?, ? examples/s]

Map:   0%|          | 0/3668 [00:00<?, ? examples/s]

Map:   0%|          | 0/408 [00:00<?, ? examples/s]

Map:   0%|          | 0/1725 [00:00<?, ? examples/s]

In [4]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(llama2_model.parameters(), lr=0.001)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
llama2_model.to(device)

num_epochs = 3
for epoch in range(num_epochs):
    llama2_model.train()
    total_loss = 0
    for batch in tqdm(train_dataloader):
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['input_ids'].to(device)  # In causal LM, labels are usually the input_ids
        outputs = llama2_model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
    print(f"Average training loss: {total_loss / len(train_dataloader)}")

    llama2_model.eval()
    total_eval_loss = 0
    with torch.no_grad():
        for batch in valid_dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['input_ids'].to(device)
            outputs = llama2_model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            total_eval_loss += loss.item()
    print(f"Validation loss: {total_eval_loss / len(valid_dataloader)}")

100%|██████████| 3668/3668 [55:59<00:00,  1.09it/s]


Average training loss: 0.25915195561860127
Validation loss: 0.10711430107225098


  9%|▉         | 348/3668 [05:19<50:46,  1.09it/s]


KeyboardInterrupt: 