In [None]:
# Training / model hyperparameters.
context_length = 512  # tokens per row of a batch
batch_size = 4  # rows per micro-batch
n_embed = 512  # number of embedding dimensions
n_layers = 12
n_heads = 8
vocab_size = 50304
resume_training = False  # resume from a saved checkpoint when True
grad_accum_steps = 20  # micro-batches drawn per step
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Keyword arguments for MiniGPTConfig.
model_parameters = {
    "vocab_size": vocab_size,
    "context_length": context_length,
    "n_embed": n_embed,
    "n_layers": n_layers,
    "n_heads": n_heads,
}

class DataLoader:
    """Serve contiguous (input, target) token batches from a flat uint16 file.

    The file is memory-mapped read-only and split into non-overlapping chunks of
    ``batch_size * context_length`` tokens. Each call to :meth:`get_batch`
    returns one chunk reshaped to ``(batch_size, context_length)`` together with
    the same span shifted one token ahead as targets. Chunks are served in file
    order on the first pass and in a freshly shuffled order on every later pass.
    """

    def __init__(self, file_name, batch_size, context_length, device=None):
        """Memory-map ``file_name`` and precompute the chunk start offsets.

        Args:
            file_name: path to a raw uint16 token file (e.g. ``"train.bin"``).
            batch_size: number of rows per batch; must be positive.
            context_length: number of tokens per row; must be positive.
            device: where batches are moved. ``None`` (the default) uses the
                module-level ``device`` at call time, matching the old behaviour.

        Raises:
            ValueError: if a size is not positive, or the file holds fewer than
                ``batch_size * context_length + 1`` tokens (one full batch plus
                the extra target token).
        """
        if batch_size <= 0 or context_length <= 0:
            raise ValueError(
                f"batch_size and context_length must be positive, "
                f"got {batch_size} and {context_length}"
            )
        # Read-only map: the loader never writes, so don't require write access
        # to the token file or risk modifying it.
        self.data = np.memmap(file_name, dtype=np.uint16, mode='r')
        self.n_examples = len(self.data)
        self.current = 0
        self.batch_size = batch_size
        self.context_length = context_length
        self.device = device
        self.full_context = self.batch_size * self.context_length
        # Reserve one token past the final chunk for its shifted targets.
        self.n_valid_sequences = (self.n_examples - 1) // self.full_context
        if self.n_valid_sequences <= 0:
            raise ValueError(
                f"{file_name!s} has {self.n_examples} tokens; need at least "
                f"{self.full_context + 1} for one batch of "
                f"{batch_size}x{context_length}"
            )
        self.indexes = np.arange(0, self.n_valid_sequences * self.full_context, self.full_context)

    def shuffle(self):
        """Randomly permute the chunk start offsets in place."""
        np.random.shuffle(self.indexes)

    def get_batch(self):
        """Return the next ``(x, y)`` pair of int64 tensors shaped ``(batch_size, context_length)``.

        ``y`` is ``x`` shifted one token ahead. Once every chunk has been served,
        the offsets are reshuffled and iteration restarts from the beginning.
        """
        if self.current == self.n_valid_sequences:
            self.current = 0
            self.shuffle()

        start = self.indexes[self.current]
        stop = start + self.full_context
        x = torch.from_numpy(self.data[start:stop].astype(np.int64)).view(-1, self.context_length)
        y = torch.from_numpy(self.data[start + 1:stop + 1].astype(np.int64)).view(-1, self.context_length)
        self.current += 1

        target_device = self.device if self.device is not None else device
        return x.to(target_device, non_blocking=True), y.to(target_device, non_blocking=True)

# Build the model and the train/validation token streams.
config = MiniGPTConfig(**model_parameters)
model = MiniGPT(config).to(device)
train_loader = DataLoader("train.bin", batch_size, context_length)
val_loader = DataLoader("val.bin", batch_size, context_length)

checkpoints_dir = "checkpoints"
os.makedirs(checkpoints_dir, exist_ok=True)

if resume_training:
    # Path to a single checkpoint file (despite the variable name).
    checkpoint_dir = os.path.join(checkpoints_dir, "step_500.pt")
    state_dict = torch.load(checkpoint_dir, map_location=device)
    # A model wrapped by torch.compile reports its state_dict keys with an
    # "_orig_mod." prefix. Strip it and load into the plain module *before*
    # compiling, so checkpoints saved from either a compiled or an uncompiled
    # model both load.
    compiled_prefix = "_orig_mod."
    model_state = {
        (key[len(compiled_prefix):] if key.startswith(compiled_prefix) else key): value
        for key, value in state_dict["model"].items()
    }
    model.load_state_dict(model_state)
    # NOTE(review): `optimizers` is not defined in this cell; it must be created
    # (from model.parameters()) before resuming, or this raises NameError.
    for optimizer, state in zip(optimizers, state_dict["optimizer_states"]):
        optimizer.load_state_dict(state)
    loss = state_dict["val_loss"]
    step = state_dict["step"]

# Compile after any checkpoint has been loaded into the plain module.
model = torch.compile(model)

In [None]:
def estimate_gradient_noise_scale(model, loader, n_micro_batches):
    """Estimate the "simple" gradient noise scale tr(Sigma) / |G|^2.

    Draws ``n_micro_batches`` batches from ``loader`` and backpropagates each one
    separately. The resulting per-micro-batch gradients, with all parameters
    viewed as one flat vector, are treated as samples of a random gradient:

    * ``G`` is estimated by the sample mean of those gradients.
    * ``tr(Sigma)`` is the sum, over every coordinate, of their unbiased sample
      variance. This equals the trace of ``np.cov`` with its default ``ddof=1``.

    Statistics are accumulated with Welford's algorithm, so only one running
    mean per parameter is kept instead of every micro-batch gradient.

    Note: Sigma is the covariance of a whole micro-batch gradient. To express
    the result in examples (B_simple in McCandlish et al., 2018), multiply by
    the micro-batch size. ``|mean|^2`` is also biased upward by about
    ``tr(Sigma) / n_micro_batches``, so more micro-batches give a better ratio.

    Args:
        model: module called as ``model(x, y)`` returning ``(logits, loss)``.
        loader: object whose ``get_batch()`` returns ``(x, y)``.
        n_micro_batches: number of micro-batches to sample; must be >= 2.

    Returns:
        The estimated noise scale as a float, or ``nan`` if the mean gradient is zero.

    Raises:
        ValueError: if ``n_micro_batches < 2`` (a variance needs two samples).
    """
    if n_micro_batches < 2:
        raise ValueError("need at least 2 micro-batches to estimate gradient variance")

    params = [p for p in model.parameters() if p.requires_grad]
    # Running per-element mean of the gradients, kept in float32.
    means = [torch.zeros_like(p, dtype=torch.float32) for p in params]
    sum_sq_dev = 0.0  # Welford M2, summed over every gradient coordinate

    for k in range(1, n_micro_batches + 1):
        # Clear stale gradients so each sample comes from exactly one micro-batch.
        model.zero_grad(set_to_none=True)
        xb, yb = loader.get_batch()
        _, loss = model(xb, yb)
        loss.backward()

        step_sq_dev = 0.0
        for mean, p in zip(means, params):
            g = p.grad.detach().float() if p.grad is not None else torch.zeros_like(mean)
            delta = g - mean
            mean.add_(delta / k)
            step_sq_dev = step_sq_dev + (delta * (g - mean)).sum()
        # One host sync per micro-batch rather than one per parameter.
        sum_sq_dev += float(step_sq_dev)

    model.zero_grad(set_to_none=True)

    trace = sum_sq_dev / (n_micro_batches - 1)
    mean_grad_norm_squared = float(sum(m.pow(2).sum() for m in means)) if means else 0.0
    if mean_grad_norm_squared > 0:
        return trace / mean_grad_norm_squared
    return float("nan")


total_noise_scale = estimate_gradient_noise_scale(model, train_loader, grad_accum_steps)

print(f"Gradient Noise Scale: {total_noise_scale}")