In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset

In [2]:
# Checkpoint to fine-tune. "microsoft/Phi-3-mini-4k-instruct" is the other
# checkpoint this notebook has been pointed at.
model_name = 'meta-llama/Meta-Llama-3-8B-Instruct'

# Load the causal-LM weights in bfloat16 with the FlashAttention-2 attention
# implementation; device_map="auto" leaves device placement to the library.
# Gradient checkpointing is switched on separately, after loading.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    attn_implementation="flash_attention_2",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [None]:





# Enable gradient checkpointing to reduce activation memory during training
model.gradient_checkpointing_enable()

# Load dataset ("your-dataset" is a placeholder: replace it with a real dataset id)
dataset = load_dataset("your-dataset")

# Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Some causal-LM tokenizers (Llama-3 among them) define no padding token, and
# tokenizing with padding='max_length' raises an error in that case. Fall back
# to the end-of-sequence token; padded positions are still identified by the
# attention mask, so reusing EOS here is safe.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Tokenize dataset
def tokenize_function(examples):
    """Tokenize the ``text`` field of ``examples``.

    Every sequence is truncated and padded to exactly 512 tokens. Relies on
    the module-level ``tokenizer``.
    """
    texts = examples['text']
    encoded = tokenizer(
        texts,
        max_length=512,
        padding='max_length',
        truncation=True,
    )
    return encoded

# batched=True: tokenize_function is called on batches of examples rather than one at a time.
tokenized_datasets = dataset.map(tokenize_function, batched=True)

# Create a custom Dataset class
class CustomDataset(Dataset):
    """Map-style dataset that serves tokenized rows as integer tensors.

    Args:
        tokenized_dataset: A sequence-like object supporting ``len()`` and
            integer indexing, where each row is a mapping from column name to
            a list of ints (e.g. one split of a tokenized ``datasets`` dataset).
    """

    # Only these columns are forwarded to the model. Anything else in the row
    # (such as the raw "text" string) cannot be turned into a tensor and is
    # not a model input, so it is dropped.
    MODEL_INPUT_COLUMNS = ("input_ids", "attention_mask", "labels")

    def __init__(self, tokenized_dataset):
        self.tokenized_dataset = tokenized_dataset

    def __len__(self):
        return len(self.tokenized_dataset)

    def __getitem__(self, idx):
        """Return the model-input columns of row ``idx`` as ``torch.long`` tensors."""
        # Index the row first: a dataset split is indexed by row and has no
        # dict-style ``.items()`` over its columns.
        row = self.tokenized_dataset[idx]
        # Token ids and masks must stay integral. Casting them to bfloat16
        # corrupts large ids and is rejected by the embedding lookup.
        return {
            key: torch.as_tensor(row[key], dtype=torch.long)
            for key in self.MODEL_INPUT_COLUMNS
            if key in row
        }

# Create DataLoader with a smaller batch size
train_dataset = CustomDataset(tokenized_datasets['train'])
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)  # Adjust batch size

# Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# Training loop with memory monitoring
model.train()
num_epochs = 3
# Send batches to the device the model reports instead of hard-coding 'cuda'.
device = model.device
for epoch in range(num_epochs):
    for batch in train_dataloader:
        inputs = {key: val.to(device) for key, val in batch.items()}

        # A causal LM only returns a loss when `labels` are supplied; without
        # them `outputs.loss` is None and `loss.backward()` fails. For
        # next-token prediction the labels are the input ids themselves (the
        # model shifts them internally). Padded positions are set to -100 so
        # the loss ignores them.
        if "labels" not in inputs:
            labels = inputs["input_ids"].clone()
            if "attention_mask" in inputs:
                labels[inputs["attention_mask"] == 0] = -100
            inputs["labels"] = labels

        outputs = model(**inputs)
        loss = outputs.loss
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        # Monitor memory usage
        print(f"Memory Allocated: {torch.cuda.memory_allocated() / (1024 ** 3):.2f} GB")
        print(f"Memory Reserved: {torch.cuda.memory_reserved() / (1024 ** 3):.2f} GB")

    print(f"Epoch {epoch + 1}/{num_epochs} completed")

# Save the model
model.save_pretrained("./results")
tokenizer.save_pretrained("./results")