In [1]:
!pip install datasets



In [2]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch as t

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

In [3]:
# Check for the availability of multiple GPUs.
if t.cuda.device_count() > 1:
    print("Let's use", t.cuda.device_count(), "GPUs!")
else:
    print('No multiple GPUs available.')

Let's use 4 GPUs!


In [4]:
from datasets import load_dataset
ds = load_dataset('stas/openwebtext-10k')
dataset = ds['train']['text']

Repo card metadata block was not found. Setting CardData to empty.


In [5]:
device = t.device("cuda:0" if t.cuda.is_available() else "cpu")
# pretrained_model.to(device)

In [6]:
def init_layer(layer: t.nn.Module):
    if isinstance(layer, t.nn.Embedding) or isinstance(layer, t.nn.Linear):
        layer.weight.data.normal_(0, 0.02)

In [7]:
class GPTBlock(t.nn.Module):
    def __init__(self, hidden_size = 768, context_length = 1024, dim_size = 3072, p_dropout = 0.1, n_heads = 12):
        super().__init__()
        self.ln_init = t.nn.LayerNorm(hidden_size)
        self.attn = t.nn.MultiheadAttention(hidden_size, n_heads, p_dropout, batch_first = True)
        mask = (t.triu(t.ones(context_length, context_length)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        self.attn_mask = t.nn.Parameter(mask, requires_grad = False)
        self.ln_intermediate = t.nn.LayerNorm(hidden_size)
        self.nn1 = t.nn.Linear(hidden_size, dim_size)
        self.nn2 = t.nn.Linear(dim_size, hidden_size)
        self.gelu = t.nn.GELU()
        self.dropout = t.nn.Dropout(p_dropout)
    
    def forward(self, x):
        resid_0 = x
        x = self.ln_init(x)
        x, _ = self.attn(x, x, x, attn_mask = self.attn_mask, need_weights = False)
        x = self.ln_intermediate(x + resid_0)
        resid_1 = x
        x = self.nn1(x)
        x = self.nn2(x)
        x = self.gelu(x)
        return self.dropout(x + resid_1)

In [8]:
class SimpleGPT2(t.nn.Module):
    def __init__(self, n_blocks = 1, vocab_size = 50257, context_length = 1024, hidden_size = 768, p_dropout = 0.1):
        super().__init__()
        self.wte = t.nn.Embedding(vocab_size, hidden_size)
        self.wpe = t.nn.Embedding(context_length, hidden_size)
        self.pe_matrix = t.nn.Parameter(t.arange(0, context_length).unsqueeze(0), requires_grad = False)
        self.dropout = t.nn.Dropout(p_dropout)
        self.gpt_blocks = t.nn.ModuleList([GPTBlock() for _ in range(n_blocks)])
        self.layernorm = t.nn.LayerNorm(hidden_size)
        self.final = t.nn.Linear(hidden_size, vocab_size)

        for layer in [self.wte, self.wpe, self.final]:
            init_layer(layer)
    
    def forward(self, input_ids: t.Tensor, attention_mask = t.Tensor):
        x = input_ids
        n, seq_len = x.shape
        hidden = self.wte(x) + self.wpe(self.pe_matrix.expand(n, -1))
        hidden = self.dropout(hidden)
        for gpt_block in self.gpt_blocks:
            hidden = gpt_block(hidden)
        hidden = self.layernorm(hidden)
        return self.final(hidden)

In [9]:
simpleGPT2 = SimpleGPT2(n_blocks = 6)
if t.cuda.device_count() > 1:
    simpleGPT2 = t.nn.DataParallel(simpleGPT2)
simpleGPT2.to(device)


# Run model on a few truncated samples ... works!

encoded_input = tokenizer(dataset[0:1], return_tensors='pt', padding='max_length', truncation=True).to(device)
print(encoded_input['attention_mask'].shape, encoded_input['attention_mask'].sum())
logits = simpleGPT2(**encoded_input)
print(logits.shape)

torch.Size([1, 1024]) tensor(1024, device='cuda:0')
torch.Size([1, 1024, 50257])


In [10]:
def loss_fn(logits, encoded_input):
    # logits: n x seq x d
    # true_tokens: n x seq
    # attention_mask = n x seq
    true_tokens = encoded_input['input_ids']
    attention_mask = encoded_input['attention_mask']
    valid_samples_mask = attention_mask[:, 1:].reshape(-1).bool()
    n, seq, d  = logits.shape
    return t.nn.functional.cross_entropy(logits[:, :-1, :].reshape(-1, d)[valid_samples_mask, :], true_tokens[:, 1:].flatten()[valid_samples_mask]), valid_samples_mask.sum()

def compute_dataset_loss(dataset, model, tokenizer, batch_size = 2):
    loss = 0
    samples = 0
    with t.no_grad():
      n = len(dataset)
      batches = n // batch_size
      for i in range(batches):
          # print(i, batch_size, loss, samples)
          batch = dataset[i:i+batch_size]
          encoded_input = tokenizer(batch, return_tensors='pt', padding='max_length', truncation=True).to(device)
          logits = model(**encoded_input)
          # Find true labels and compute loss
          ce_loss, valid_samples = loss_fn(logits, encoded_input)
          loss = (loss * samples + ce_loss * valid_samples ) / (samples + valid_samples)
          samples = samples + valid_samples
    return loss, samples

In [11]:
def compute_gpu_utilization():
    n_gpus = t.cuda.device_count()
    memory_utilization = 0;
    for i in range(n_gpus):
        memory_utilization += t.cuda.memory_allocated(i) / 1e6

    return n_gpus, memory

In [12]:
# Fine-tune the model on a subset of training set, and then evaluate on val set
#TODO Separate this code into two parts. Calculate batch time as well. Save epoch run files
import random
from torch.optim import Adam
from torch.optim.lr_scheduler import OneCycleLR
from torch._inductor import config

import time
import matplotlib.pyplot as plt

def train_model(dataset, optimizer, epochs, model, tokenizer, batch_size = 4):
    loss = 0
    samples = 0
    n = len(dataset)
    batches = n // batch_size
    print_interval = batches // 20
    losses = []  # Store loss for each epoch
    val_losses = [] # Store validation loss for each epoch

    scheduler = OneCycleLR(optimizer, max_lr = 2.5e-4, total_steps = epochs * batches, pct_start = 0.2)

    for epoch in range(epochs):
        start_time = time.time()
        random.shuffle(dataset)
        print("Starting epoch: ", epoch)
        for i in range(batches):
            if i % print_interval == 0:
                print(i, i/float(batches), batch_size, loss, samples)
                n_gpus, memory_utilization = compute_gpu_utilization()
                print(f"Current GPU memory usage: {memory_utilization} MB across {n_gpus} GPUs. Average utilization: {memory_utilization/n_gpus} MB")

            optimizer.zero_grad()

            batch = dataset[i:i+batch_size]
            encoded_input = tokenizer(batch, return_tensors='pt', padding='max_length', truncation=True).to(device)
            logits = model(**encoded_input)

            # Find true labels and compute loss
            ce_loss, valid_samples = loss_fn(logits, encoded_input)
            loss = (loss * samples + ce_loss * valid_samples ) / (samples + valid_samples)
            samples = samples + valid_samples

            # Backprop
            ce_loss.backward()
            optimizer.step()
            scheduler.step()

        end_time = time.time()
        epoch_time = end_time - start_time

        losses.append(loss.item())

        val_loss, _ = compute_dataset_loss(dataset, model, tokenizer, batch_size=batch_size)
        val_losses.append(val_loss.item())

        if (epoch + 1) % 5 == 0:
            checkpoint_filename = f'simpleGPT_{epoch + 1}epochs_t2_batch2.pt'
            t.save(model.state_dict(), checkpoint_filename)
            print(f"Saved model checkpoint to {checkpoint_filename}")
        
        # print(f"Epoch {epoch} finished in {epoch_time} seconds with loss {loss.item()}")
        print(f"Epoch {epoch} finished in {epoch_time} seconds with loss {loss.item()} and val_loss {val_loss}")

    # Plot loss over epochs
    plt.plot(range(epochs), losses, label='Training loss')
    plt.plot(range(epochs), val_losses, label='Validation loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

    return loss, samples


epochs = 2

lrs = [5e-5, 5e-4, 1e-5, 2e-5]

# config.compile_threads = 1
simpleGPT2 = t.compile(simpleGPT2, mode="max-autotune")

optimizer = Adam(simpleGPT2.parameters(), lr = lrs[-1])

print(train_model(dataset[:2000*4], optimizer, epochs, simpleGPT2, tokenizer, batch_size = 12))

Starting epoch:  0
0 0.0 16 0 0


NameError: name 'memory' is not defined