In [2]:
import torch
import copy
from torch.utils.data import DataLoader
import json
from datasets import load_dataset
from transformers import GPTNeoXForCausalLM, AutoTokenizer
from transformers.optimization import AdamW

In [16]:
def mem_stats():
    """Print a summary of memory usage on CUDA device 0.

    Reports total, reserved, remaining (total minus reserved) and allocated
    memory, each divided by 1024**3, plus the fraction of reserved memory
    that is currently allocated.

    When CUDA is not available a short notice is printed instead of raising,
    so cells that call this helper also run on a CPU-only machine.
    """
    if not torch.cuda.is_available():
        print("CUDA is not available; no GPU memory statistics to report.\n")
        return

    t = torch.cuda.get_device_properties(0).total_memory / 1024**3
    r = torch.cuda.memory_reserved(0) / 1024**3
    a = torch.cuda.memory_allocated(0) / 1024**3
    # The 1e-9 terms keep the last ratio defined when nothing is reserved yet.
    print(f"Total Memory: {t:.2f} GB\n"
          f"Reserved Memory: {r:.2f} GB ({(100*(r/t)):.2f}%)\n"
          f"Remaining Memory: {t-r:.2f} GB ({(100*(t-r)/t):.2f}%)\n"
          f"---------------------------------\n"
          f"Allocated Memory: {a:.2f} GB ({(100*(a/t)):.2f}%)\n"
          f"Percent of Reserved Allocated: {(100*(a+1e-9)/(r+1e-9)):.2f}%\n")

In [17]:
## Collate functions for loading dataset
def collate_fn(batch):
    """Tokenize the ``"text"`` field of every example and stack the results.

    Uses the module-level ``tokenizer``.  Each encoded example is padded on
    the right with zeros up to the width of the longest example in the batch,
    and the padded pieces are concatenated along dimension 0.
    """
    encoded = []
    for example in batch:
        encoded.append(tokenizer.encode(example["text"], return_tensors="pt", truncation=True))

    widest = max(ids.size(1) for ids in encoded)

    padded_rows = []
    for ids in encoded:
        # Zero filler with the same dtype/device as the encoded ids.
        filler = ids.new_zeros(ids.size(0), widest - ids.size(1))
        padded_rows.append(torch.cat([ids, filler], dim=1))

    return torch.cat(padded_rows, dim=0)

def collate_already_encoded(batch):
    """Stack pre-tokenized examples into one zero-padded integer tensor.

    Args:
        batch: sequence of mappings, each with a ``'tokens'`` entry holding a
            sequence of token ids.

    Returns:
        A ``torch.int`` tensor of shape ``(len(batch), longest sequence)``.
        Sequences shorter than the longest one are right-padded with 0.
        An empty batch yields a tensor of shape ``(0, 0)``.
    """
    sequences = [example['tokens'] for example in batch]
    max_length = max((len(seq) for seq in sequences), default=0)
    tokens_padded = torch.zeros((len(sequences), max_length), dtype=torch.int)
    for row, seq in enumerate(sequences):
        # Write only the first len(seq) columns: assigning to the full row
        # fails for shorter sequences (and broadcasts a length-1 sequence
        # across the whole row).  Building an integer tensor directly also
        # avoids a round trip of the ids through float32.
        tokens_padded[row, :len(seq)] = torch.as_tensor(seq, dtype=torch.int)
    return tokens_padded

# Load the "validation" split through the local `the_pile_val.py` loading script.
# NOTE(review): the same dataset is loaded again in the experiment cell at the
# end of the notebook; one of the two loads is presumably redundant — confirm
# before removing either.
validation_dataset = load_dataset("the_pile_val.py", split="validation") 
# Print the memory statistics after the load.
mem_stats()

tensor([1., 2., 3.])

In [24]:
## Model and tokenizer
# Single source of truth for the checkpoint: both the model and the tokenizer
# are loaded from these values instead of repeating the literals.
model_name = "EleutherAI/pythia-1.4B-deduped"
model_revision = "step143000"
model_cache_dir = "./pythia-1.4B-deduped/step143000"

device = "cuda" if torch.cuda.is_available() else "cpu"

model = GPTNeoXForCausalLM.from_pretrained(
    model_name,
    revision=model_revision,
    cache_dir=model_cache_dir,
)

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    revision=model_revision,
    cache_dir=model_cache_dir,
)

# Cast the weights to fp16 only when running on a GPU.  On CPU, half-precision
# kernels are slow or (depending on the PyTorch version) missing for some
# operations, so the model is left in its loaded precision there.
if device == "cuda":
    model.half()
mem_stats()

tensor([2., 3.])

In [None]:
def compute_batch_perplexity(model, batch_data, bs, samplelength, device):
    """Score each sequence of a batch by the mean log-sigmoid of its token logits.

    For every row, the logit that the model outputs at position ``i`` for the
    token found at position ``i`` is collected for all positions before the
    first padding token (id 0).  The row's score is the mean of
    ``log(sigmoid(logit))`` over those positions.

    NOTE(review): despite the function name this is not a perplexity.  It
    reads the logit of the token at its own position (not the logit of the
    next token) and applies a log-sigmoid to that single logit rather than a
    log-softmax over the vocabulary.  The statistic is kept exactly as it was
    so existing results stay comparable — confirm it is the intended one.

    Args:
        model: callable taking ``input_ids`` and ``attention_mask`` keyword
            arguments and returning an object whose ``.logits`` tensor is
            indexed as ``[row, position, token id]``.  It is switched to
            evaluation mode here.
        batch_data: 2-D integer tensor of token ids, 0 meaning padding.
        bs: length of the returned vector; at most ``bs`` rows are scored.
        samplelength: only the first ``samplelength`` tokens of a row are used.
        device: device the token ids are moved to before the forward pass.

    Returns:
        float32 CPU tensor of shape ``(bs,)``.  Entries for rows missing from
        a short batch are NaN, as are rows whose first token is padding
        (mean over zero positions).
    """
    with torch.no_grad():
        ## Get predictions on validation data
        batch_x = batch_data[:bs, :samplelength].to(device)
        mask = batch_x > 0
        on_cuda = batch_x.is_cuda

        model.train(False)
        logits = model(input_ids=batch_x, attention_mask=mask).logits

        ## Logit of each token at its own position, shape (rows, positions)
        token_logits = logits.gather(-1, batch_x.long().unsqueeze(-1)).squeeze(-1).float()

        # A position counts only while no padding token has been seen yet,
        # i.e. scoring stops at the first 0 in the row.
        valid = (batch_x == 0).long().cumsum(dim=1) == 0
        lengths = valid.sum(dim=1)

        # logsigmoid(x) == x - log(1 + exp(x)) but does not overflow for large x.
        log_sig = torch.nn.functional.logsigmoid(token_logits)
        log_sig = torch.where(valid, log_sig, torch.zeros_like(log_sig))

        # Rows not present in this batch (short final batch) stay NaN instead
        # of raising an IndexError or being reported as a score of 0.
        ans = torch.full((bs,), float("nan"))
        n_rows = batch_x.size(0)
        ans[:n_rows] = (log_sig.sum(dim=1) / lengths).cpu()

        ## Cleaning up (once per batch, and only when a GPU was actually used)
        del batch_x, mask, logits, token_logits, valid, lengths, log_sig
        if on_cuda:
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

    return ans

In [None]:
def _evaluate_batches(eval_model, dataloader, nbatches, bs, samplelength, device, out):
    """Write the per-sequence scores of the first ``nbatches`` batches into ``out``."""
    for batchno, batch_data in enumerate(dataloader):
        if batchno >= nbatches:
            break
        out[batchno, :] = compute_batch_perplexity(eval_model, batch_data, bs, samplelength, device)


def compare_models(new_model, n_new_models, noise_variance,
                   dataloader, nbatches, bs, samplelength, device):
    """Score a model and ``n_new_models`` noise-perturbed variants of it.

    Index 0 of the result holds the scores of the unmodified model; indices
    ``1 .. n_new_models`` hold the scores after adding independent Gaussian
    noise to every parameter.  Each perturbation starts from the original
    weights, and the original weights are restored before returning.

    Args:
        new_model: model to evaluate; switched to eval mode and moved to
            ``device``.  Its parameters are perturbed in place temporarily.
        n_new_models: number of perturbed variants to evaluate.
        noise_variance: factor multiplied onto standard-normal noise before it
            is added to the parameters.  NOTE(review): as used, this is a
            standard deviation rather than a variance.
        dataloader: iterable of batches passed to ``compute_batch_perplexity``.
        nbatches: number of batches scored per model.  If ``dataloader``
            yields fewer, the remaining entries stay 0.
        bs: batch size (last dimension of the result).
        samplelength: forwarded to ``compute_batch_perplexity``.
        device: forwarded to ``compute_batch_perplexity``.

    Returns:
        Tensor of shape ``(n_new_models + 1, nbatches, bs)``.
    """
    new_model.eval()
    new_model.to(device)
    ans = torch.zeros((n_new_models + 1, nbatches, bs))

    params = list(new_model.parameters())
    # CPU snapshot of the original weights.  Undoing the noise by subtracting
    # it again is not exact in low precision and lets the weights drift from
    # one iteration to the next; copying the snapshot back is exact.  The cost
    # is one extra copy of the parameters in host memory.
    originals = ([p.detach().to("cpu", copy=True) for p in params]
                 if n_new_models > 0 else [])

    # n_new_models + 1 iterations so that every row of `ans` is filled.
    for ind_model in range(n_new_models + 1):
        try:
            if ind_model > 0:
                with torch.no_grad():
                    for param in params:
                        noise = torch.randn(param.shape, device=param.device) * noise_variance
                        param.add_(noise)
            # Score the model that was passed in, not a module-level global.
            _evaluate_batches(new_model, dataloader, nbatches, bs, samplelength,
                              device, ans[ind_model])
        finally:
            # Restore even if scoring raised, so the caller's model is never
            # left in a perturbed state.
            if ind_model > 0:
                with torch.no_grad():
                    for param, original in zip(params, originals):
                        param.copy_(original)

        print(ind_model)
        mem_stats()
    return ans

In [None]:
# Experiment configuration (all values are passed straight to compare_models
# and the DataLoaders below).
n_new_models = 50
noise_variance = 0.001
bs = 8
samplelength = 10
nbatches = 500

training_dataset = load_dataset("EleutherAI/pile-deduped-pythia-random-sampled", split="train")
validation_dataset = load_dataset("the_pile_val.py", split="validation")

# The training set is collated with collate_already_encoded (it reads a
# 'tokens' entry from each example); the validation set is tokenized from its
# "text" entry by collate_fn.
training_dataloader = DataLoader(training_dataset, batch_size=bs, collate_fn=collate_already_encoded)
validation_dataloader = DataLoader(validation_dataset, batch_size=bs, collate_fn=collate_fn)

# Training scores must come from the training dataloader; the validation
# dataloader was previously passed to both calls, making the two results
# measurements of the same data.
training = compare_models(model, n_new_models, noise_variance, training_dataloader, nbatches, bs, samplelength, device)
validation = compare_models(model, n_new_models, noise_variance, validation_dataloader, nbatches, bs, samplelength, device)