## save to drive

In [None]:
# Mount Google Drive into the Colab filesystem at /content/gdrive.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
# Rename the checkpoint written by prune_distilbert (when args.save_model is set) to the `_all_layers` file name.
!mv /content/pruned_transformer30pct.pth /content/pruned_transformer30pct_all_layers.pth

In [None]:
# Copy the renamed checkpoint into the mounted Google Drive folder.
!cp /content/pruned_transformer30pct_all_layers.pth /content/gdrive/MyDrive/Research/Data/

## imports

In [None]:
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install -q transformers datasets evaluate



In [None]:
from transformers import DistilBertTokenizer, DistilBertForQuestionAnswering
from datasets import load_dataset
from evaluate import load
from transformers.data.metrics import squad_metrics

import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import Dataset, DataLoader

from tqdm import tqdm

from copy import deepcopy
import math
import random
from transformers.pytorch_utils import find_pruneable_heads_and_indices
import time

## pruning algorithm steps

### step 1: compute gradient

In [None]:
def compute_grads(pruned_layer, dense_layer, H, attn_mask, head_mask, pruned_indices, device='cpu'):
  """Back-propagate the residual objective through `dense_layer` and score its heads.

  Args:
    pruned_layer: attention module evaluated (without gradient tracking) only
      when `pruned_indices` is non-empty.
    dense_layer: attention module whose parameter gradients are collected. It
      must expose `dim`, `n_heads` and `attention_head_size`.
    H: tensor used as query, key and value for both layers.
    attn_mask: attention mask forwarded to both layers.
    head_mask: head mask forwarded to both layers.
    pruned_indices: collection whose emptiness selects the objective. The
      caller passes the current head set S. If empty, the loss is taken on the
      dense output alone; otherwise on `U_pruned - U_dense`.
    device: unused; kept for interface compatibility.

  Returns:
    (importance, grad_dict, loss_value) where `importance` is a per-head tensor
    summing, over every weight matrix, the L2 norm of that head's gradient
    slice (None if no weight gradient exists), `grad_dict` maps parameter name
    to an independent copy of its gradient (None where no gradient exists), and
    `loss_value` is a Python float.
  """
  query, key, value = H, H, H
  compare_to_pruned = len(pruned_indices) > 0

  U_dense = dense_layer(query, key, value, attn_mask, head_mask=head_mask, output_attentions=False)[-1]

  if compare_to_pruned:
    # The pruned output is a constant target: gradients flow through U_dense only.
    with torch.no_grad():
      U_pruned = pruned_layer(query, key, value, attn_mask, head_mask=head_mask, output_attentions=False)[-1]
    loss = residual_objective(U_pruned - U_dense)
  else:
    loss = residual_objective(U_dense)

  loss.backward()

  importance = None
  grad_dict = {}
  with torch.no_grad():
    for name, parameter in dense_layer.named_parameters():
      if parameter.grad is None:
        grad_dict[name] = None
        continue

      # Clone: zero_grad() below zeroes gradients in place when it is not
      # setting them to None, which would blank the tensors handed back.
      grad = parameter.grad.detach().clone()
      grad_dict[name] = grad

      if 'weight' in name:
        if 'out' in name:
          # Output projection: heads are laid out along the input (column) axis.
          grad_by_head = grad.view(dense_layer.dim, dense_layer.n_heads, dense_layer.attention_head_size)
          imp_by_head = torch.norm(grad_by_head, p=2, dim=(0, 2))
        else:
          # Other projections: heads are laid out along the output (row) axis.
          grad_by_head = grad.view(dense_layer.n_heads, dense_layer.attention_head_size, dense_layer.dim)
          imp_by_head = torch.norm(grad_by_head, p=2, dim=(1, 2))

        if importance is None:
          importance = torch.zeros_like(imp_by_head)
        importance += imp_by_head

    pruned_layer.zero_grad()
    dense_layer.zero_grad()

  return importance, grad_dict, loss.item()

### step 2: find best s columns of grad outside S and merge with S

In [None]:
@torch.no_grad()
def find_and_merge(importance, S, n_heads_to_keep, device='cpu'):
  """Merge S with the `n_heads_to_keep` highest-scoring heads outside S.

  Args:
    importance: 1-D tensor of per-head scores. It is not modified.
    S: collection of head indices already selected.
    n_heads_to_keep: number of additional candidates to take from the ranking.
    device: unused; kept for interface compatibility.

  Returns:
    A set holding S plus the selected candidate indices.
  """
  # Work on a copy so the caller's scores survive the call.
  scores = importance.detach().clone()
  if not scores.is_floating_point():
    scores = scores.float()

  # -inf (rather than 0) keeps members of S strictly below every other head,
  # including heads whose score is exactly 0.
  if len(S) > 0:
    scores[list(S)] = float('-inf')

  top_idxs = torch.argsort(scores, descending=True)[:n_heads_to_keep]
  return set(S).union(top_idxs.tolist())

### step 3: update parameters by gradient descent focused on D

In [None]:
@torch.no_grad()
def update_step(args, Q, dense_layer, grad_dict, qkv_optimizer, out_optimizer, D, device=None):
  """Take one gradient step on Q restricted to the heads in D.

  Args:
    args: configuration object; reads `maintain_Q` and `eta`.
    Q: working copy of the attention layer. When `args.maintain_Q` is false it
      is replaced by a fresh deep copy of `dense_layer`.
    dense_layer: reference attention layer.
    grad_dict: mapping from parameter name to gradient tensor. The tensors are
      attached to Q's parameters and masked in place.
    qkv_optimizer, out_optimizer: optimizers exposing `step(indexer)`; used
      only when `args.maintain_Q` is true.
    D: head indices that are allowed to change.
    device: unused; kept for interface compatibility.

  Returns:
    The updated Q module.
  """
  if not args.maintain_Q:
    Q = deepcopy(dense_layer)

  for name, Q_param in Q.named_parameters():
    if name in grad_dict:
      Q_param.grad = grad_dict[name]
    else:
      print(f'No gradient for {name}')

  # Flat positions belonging to heads outside D; their gradient is zeroed.
  indexer = get_mask_indexer(D, Q.n_heads, Q.attention_head_size)
  if args.maintain_Q:
    qkv_optimizer.step(indexer)
    out_optimizer.step(indexer)
    return Q

  for name, Q_param in Q.named_parameters():
    grad = Q_param.grad
    if grad is None:
      continue
    if 'out' in name:
      # Output projection: mask weight columns; its bias is left unmasked.
      if 'weight' in name:
        grad[:, indexer] = 0
    else:
      grad[indexer] = 0
    # Update in place: rebinding the loop variable would leave Q untouched.
    Q_param -= args.eta * grad

  return Q

### step 4: truncate Q to be s-sparse

In [None]:
@torch.no_grad()
def truncate(Q, pruned_layer, H, attn_mask, head_mask, n_heads_to_keep, device='cpu'):
  """Keep the `n_heads_to_keep` heads of Q with the largest weight norm.

  Reads the notebook-level `args.trunc_strategy`; only 'magnitude' is
  implemented.

  Args:
    Q: attention module to truncate. It is deep-copied, not modified.
    pruned_layer: ignored; a new pruned module is built from Q.
    H, attn_mask, head_mask, device: unused; kept for interface compatibility.
    n_heads_to_keep: number of heads retained.

  Returns:
    (pruned_layer, S): the pruned deep copy of Q and the set of kept head
    indices.

  Raises:
    ValueError: if the truncation strategy is unsupported or Q exposes no
      weight parameters to rank.
  """
  # TODO: try other strategies selected through args.trunc_strategy, e.g.
  # combining with WANDA or ranking heads by the norm of their attention output.
  if args.trunc_strategy != 'magnitude':
    raise ValueError(f"Unsupported trunc_strategy: {args.trunc_strategy!r} - only 'magnitude' is implemented")

  importance = None
  for name, param in Q.named_parameters():
    if 'weight' not in name:
      continue
    if 'out' in name:
      # Output projection: heads are laid out along the input (column) axis.
      param_by_head = param.view(Q.dim, Q.n_heads, Q.attention_head_size)
      imp_by_head = torch.norm(param_by_head, p=2, dim=(0, 2))
    else:
      # Other projections: heads are laid out along the output (row) axis.
      param_by_head = param.view(Q.n_heads, Q.attention_head_size, Q.dim)
      imp_by_head = torch.norm(param_by_head, p=2, dim=(1, 2))
    if importance is None:
      importance = torch.zeros_like(imp_by_head)
    importance += imp_by_head

  if importance is None:
    raise ValueError('Q has no weight parameters to rank heads by')

  imp_top_idxs = torch.argsort(importance, descending=True)[:n_heads_to_keep]
  S = set(imp_top_idxs.cpu().tolist())

  pruned_layer = deepcopy(Q)
  not_S = set(range(Q.n_heads)).difference(S)
  pruned_layer.prune_heads(not_S)

  return pruned_layer, S

### step 5: debias

In [None]:
def debias(args, pruned_layer, dense_layer, input_loader, iters, head_mask, eta=None, device='cpu'):
  """Fit the pruned layer's output to the dense layer's output with Adam.

  Args:
    args: configuration object; `args.eta` is the learning rate when `eta` is None.
    pruned_layer: attention module being trained; updated in place.
    dense_layer: reference attention module, evaluated without gradients.
    input_loader: loader yielding `(attn_mask, hidden_states)` batches.
    iters: number of optimisation steps.
    head_mask: head mask forwarded to both layers.
    eta: optional learning-rate override.
    device: device the batch is moved to before the forward passes.

  Returns:
    The trained `pruned_layer`.
  """
  lr = args.eta if eta is None else eta
  optimizer = optim.Adam(pruned_layer.parameters(), lr=lr)

  for _ in range(iters):
    # A fresh iterator per step, so every step draws the first batch of a new pass.
    attn_mask, H = next(iter(input_loader))

    query = key = value = H.to(device)
    # Tensor.to() is not in-place: the result has to be kept.
    attn_mask = attn_mask.to(device)

    pruned_layer.zero_grad()

    with torch.no_grad():
      U_dense = dense_layer(query, key, value, attn_mask, head_mask=head_mask, output_attentions=False)[-1]
    U_pruned = pruned_layer(query, key, value, attn_mask, head_mask=head_mask, output_attentions=False)[-1]

    loss = residual_objective(U_pruned - U_dense)
    loss.backward()
    optimizer.step()

  return pruned_layer

## pruning algorithm utils

In [None]:
def validate(args, model, ids, inputs, answers, metric, device='cpu'):
  """Run the QA model on a tokenized batch and score its extracted answers.

  The model and inputs are moved to `device` for the forward pass and back to
  the CPU afterwards. Predicted spans are decoded with the notebook-level
  `tokenizer`.

  Args:
    args: configuration object; `args.verbose` enables the timing message.
    model: question-answering model returning `start_logits` and `end_logits`.
    ids: example identifiers, one per row of `inputs`.
    inputs: tokenized batch providing `input_ids` and supporting `.to(...)`.
    answers: reference answers, aligned with `ids`.
    metric: object whose `compute(predictions=..., references=...)` scores the batch.
    device: device used for the forward pass.

  Returns:
    Whatever `metric.compute` returns.
  """
  model = model.to(device)
  inputs = inputs.to(device)
  with torch.no_grad():
    tic = time.time()
    outputs_sp = model(**inputs)
    toc = time.time()
    if args.verbose:
      print(f'done predicting using sparse model. time elapsed = {toc - tic}s')

  model = model.cpu()
  inputs = inputs.to('cpu')

  preds, refs = [], []
  for row, example_id in enumerate(ids):
    # Most likely start and end token of the answer span (end is inclusive).
    span_start = torch.argmax(outputs_sp.start_logits[row])
    span_end = torch.argmax(outputs_sp.end_logits[row])
    answer_tokens = inputs.input_ids[row, span_start : span_end + 1]

    preds.append({'id': example_id, 'prediction_text': tokenizer.decode(answer_tokens)})
    refs.append({'answers': answers[row], 'id': example_id})

  return metric.compute(predictions=preds, references=refs)

In [None]:
class SparseAdam(torch.optim.Optimizer):
  """Adam variant whose updates are blocked on a given set of rows or columns.

  On every step the gradient and both moment estimates are zeroed at the
  positions in `indexer` (rows for `sparsity='row'`, columns for
  `sparsity='col'`), so those entries of the parameter do not move.
  """

  def __init__(self, params, lr=5e-5, betas=(0.9, 0.999), eps=1e-8, correct_bias=False, sparsity='row'):
    """Validate the hyper-parameters and register them as group defaults.

    Raises:
      ValueError: for a negative `lr` or `eps`, a beta outside [0, 1), or a
        `sparsity` other than 'row' / 'col'.
    """
    if lr < 0.0:
      raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
    for beta in (betas[0], betas[1]):
      if not 0.0 <= beta < 1.0:
        raise ValueError(f"Invalid beta parameter: {beta} - should be in [0.0, 1.0)")
    if not 0.0 <= eps:
      raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
    if sparsity not in ('row', 'col'):
      raise ValueError(f"Invalid sparsity value: {sparsity} - must be 'row' or 'col'")

    super().__init__(
        params,
        dict(lr=lr, betas=betas, eps=eps, correct_bias=correct_bias, sparsity=sparsity),
    )

  @staticmethod
  def _zero_masked(tensor, indexer, sparsity):
    """Zero `tensor` in place at `indexer` along the axis named by `sparsity`."""
    if sparsity == 'row':
      tensor[indexer] = 0
    elif sparsity == 'col':
      tensor[:, indexer] = 0

  def step(self, indexer):
    """Apply one Adam update, leaving the positions in `indexer` unchanged."""
    for group in self.param_groups:
      sparsity = group["sparsity"]
      beta1, beta2 = group["betas"]

      for p in group["params"]:
        grad = p.grad
        if grad is None:
          continue

        # Mask the gradient itself (in place on p.grad).
        self._zero_masked(grad, indexer, sparsity)

        state = self.state[p]
        if not state:
          # Lazy state: step counter plus first and second moment estimates.
          state["step"] = 0
          state["m"] = torch.zeros_like(p)
          state["v"] = torch.zeros_like(p)

        state["step"] += 1

        # Masking the moments also discards momentum accumulated earlier at
        # the blocked positions.
        m_t = beta1 * state["m"] + (1 - beta1) * grad
        self._zero_masked(m_t, indexer, sparsity)

        v_t = beta2 * state["v"] + (1 - beta2) * torch.pow(grad, 2)
        self._zero_masked(v_t, indexer, sparsity)

        step_size = group["lr"]
        if group["correct_bias"]:  # No bias correction for Bert
          correction1 = 1.0 - beta1 ** state["step"]
          correction2 = 1.0 - beta2 ** state["step"]
          step_size = step_size * math.sqrt(correction2) / correction1

        p.data = p.data - step_size * m_t / (torch.sqrt(v_t) + group["eps"])

        state["m"], state["v"] = m_t, v_t

In [None]:
class InputDataset(Dataset):
  """Pairs each attention mask with the hidden state stored at the same index."""

  def __init__(self, attn_masks, hidden_states):
    """Store the two indexable collections; nothing is copied."""
    self.attn_masks, self.hidden_states = attn_masks, hidden_states

  def __len__(self):
    """Number of examples, taken from the first dimension of `hidden_states`."""
    n_examples = self.hidden_states.shape[0]
    return n_examples

  def __getitem__(self, idx):
    """Return `(attn_mask, hidden_state)` for position `idx`."""
    mask = self.attn_masks[idx]
    hidden = self.hidden_states[idx]
    return mask, hidden

In [None]:
@torch.no_grad()
def zero_out(S, dense_layer, pruned_layer):
  """Zero every parameter slice that belongs to a head outside S.

  Args:
    S: head indices to keep.
    dense_layer: unused; kept for interface compatibility.
    pruned_layer: attention module modified in place. It must expose `n_heads`
      and `attention_head_size`.

  Returns:
    The modified `pruned_layer`.
  """
  print('zeroing out')
  # NOTE(review): the original body referenced an undefined name `layer`.
  # `pruned_layer` is assumed to be the intended target; confirm.
  layer = pruned_layer

  # Flat positions of all heads that are NOT in S.
  indexer_to_prune = get_mask_indexer(S, layer.n_heads, layer.attention_head_size)

  for name, parameter in layer.named_parameters():
    is_out = 'out' in name
    if 'bias' in name:
      # The output-projection bias is left untouched.
      if not is_out:
        parameter.data[indexer_to_prune] = 0
    elif is_out:
      # Output projection: heads are laid out along the columns.
      parameter.data[:, indexer_to_prune] = 0
    else:
      parameter.data[indexer_to_prune, :] = 0

  return layer

In [None]:
def get_transformer_arguments(
    model,
    input_ids = None,
    attention_mask = None,
    head_mask = None,
    inputs_embeds = None,
    output_attentions = None,
    output_hidden_states = None,
    return_dict = None):
  """Resolve the keyword arguments to forward to the transformer stack.

  Flags left as None fall back to `model.config`. A missing attention mask
  becomes an all-ones tensor shaped like the input, and `head_mask` is expanded
  through `model.get_head_mask`.

  Returns:
    dict with keys 'output_attentions', 'output_hidden_states', 'return_dict',
    'head_mask' and 'attn_mask'.

  Raises:
    ValueError: if both or neither of `input_ids` / `inputs_embeds` are given.
  """
  if output_attentions is None:
    output_attentions = model.config.output_attentions
  if output_hidden_states is None:
    output_hidden_states = model.config.output_hidden_states
  if return_dict is None:
    return_dict = model.config.use_return_dict

  has_ids = input_ids is not None
  has_embeds = inputs_embeds is not None
  if has_ids and has_embeds:
    raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  if not has_ids and not has_embeds:
    raise ValueError("You have to specify either input_ids or inputs_embeds")

  if has_ids:
    input_shape = input_ids.size()
    device = input_ids.device
  else:
    # Drop the trailing embedding dimension.
    input_shape = inputs_embeds.size()[:-1]
    device = inputs_embeds.device

  if attention_mask is None:
    attention_mask = torch.ones(input_shape, device=device)  # (bs, seq_length)

  # Prepare head mask if needed
  head_mask = model.get_head_mask(head_mask, model.config.num_hidden_layers)

  return {
      'output_attentions': output_attentions,
      'output_hidden_states': output_hidden_states,
      'return_dict': return_dict,
      'head_mask': head_mask,
      'attn_mask': attention_mask,
  }

In [None]:
def get_data(d_name):
  """Load dataset `d_name` and return `(train, validation, test)` splits.

  The dataset's own 'validation' split is returned as the test set. 10% of the
  'train' split is held out (unseeded random split) as the validation set.
  """
  splits = load_dataset(d_name)
  full_train, test_ds = splits['train'], splits['validation']

  held_out = full_train.train_test_split(test_size=0.1)
  return held_out['train'], held_out['test'], test_ds

In [None]:
@torch.no_grad()
def get_mask_indexer(heads_to_keep, n_heads, head_size):
  """Return the flat positions of all heads that are NOT in `heads_to_keep`.

  Positions index a vector of length `n_heads * head_size` laid out head by
  head. The result is a 1-D int64 tensor in ascending order.
  """
  dropped = torch.ones(n_heads, head_size, dtype=torch.bool)
  for head in heads_to_keep:
    dropped[head] = False

  flat = dropped.reshape(-1)
  return torch.arange(flat.numel(), dtype=torch.long)[flat]

In [None]:
def separate_heads(x, bs, n_heads, dim_per_head):
  """Split the last dimension of `x` into heads.

  Views `x` as (bs, seq, n_heads, dim_per_head) and returns it as
  (bs, n_heads, seq, dim_per_head).
  """
  per_head = x.view(bs, -1, n_heads, dim_per_head)
  return per_head.permute(0, 2, 1, 3)

In [None]:
class dotdict:
    """Expose the keys of a (possibly nested) dict as attributes."""

    def __init__(self, dictionary):
        """Copy every item onto the instance, wrapping nested dicts recursively."""
        for key, value in dictionary.items():
            wrapped = dotdict(value) if isinstance(value, dict) else value
            setattr(self, key, wrapped)

In [None]:
def residual_objective(mat):
    """Return half the mean of the squared entries of `mat`."""
    # A mean keeps the scale independent of tensor size; a sum would require a
    # much smaller step size (order of 1e-5).
    mean_square = torch.mean(mat ** 2)
    return 0.5 * mean_square

## fine-tuning

In [None]:
def run_finetuning(layer, input_loader):
  """Placeholder for per-layer fine-tuning; currently returns `layer` unchanged.

  The caller assigns the return value back to its layer variable, so the stub
  must hand the layer back rather than return None.
  """
  # TODO: implement fine-tuning of `layer` on batches from `input_loader`.
  return layer

## prune model

In [None]:
def prune_distilbert(args, model=None, tokenizer=None, train_ds=None, val_ds=None, test_ds=None, device='cpu'):
  """Prune the attention heads of a DistilBERT QA model and collect validation metrics.

  Args:
    args: configuration object; reads `validate_iter`, `validate_layer`,
      `val_size`, `dataset` and `save_model` here and is forwarded to
      `prune_transformer`.
    model: question-answering model; the SQuAD-distilled DistilBERT checkpoint
      is loaded when None.
    tokenizer: matching tokenizer; loaded when None.
    train_ds: mapping with 'question' and 'context' columns used as pruning inputs.
    val_ds: dataset with 'id', 'question', 'context' and 'answers' columns; a
      random contiguous window of `args.val_size` rows is used for validation.
    test_ds: unused here apart from the "all datasets missing" check.
    device: device on which layers are run.

  Returns:
    The dict of validation metrics produced by `prune_transformer`.
  """
  if train_ds is None and test_ds is None and val_ds is None:
    train_ds, val_ds, test_ds = get_data('squad')
  if tokenizer is None:
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased-distilled-squad')
  if model is None:
    model = DistilBertForQuestionAnswering.from_pretrained('distilbert-base-uncased-distilled-squad')

  # Shared by the training and validation tokenizer calls.
  tokenize_kwargs = dict(
      max_length=512,
      truncation="only_second",
      stride=128,
      padding="max_length",
      return_tensors='pt',
  )

  questions = train_ds['question']
  texts = train_ds['context']
  inputs = tokenizer(questions, texts, **tokenize_kwargs)

  embedder = model.distilbert.embeddings
  transformer = model.distilbert.transformer

  transformer_arguments = get_transformer_arguments(model, **inputs)

  with torch.no_grad():
    embeddings = embedder(inputs.input_ids)

  val_arguments = {}
  if args.validate_iter or args.validate_layer:
    assert val_ds is not None

    # Bound the window by the size of the validation data, not the training
    # questions, so it always fits and can start anywhere in val_ds.
    n_val = len(val_ds['id'])
    last_start = max(n_val - args.val_size, 0)
    start_idx = torch.randint(low=0, high=last_start + 1, size=(1,)).item()
    window = slice(start_idx, start_idx + args.val_size)

    val_ids = val_ds['id'][window]
    val_questions = val_ds['question'][window]
    val_texts = val_ds['context'][window]
    val_answers = val_ds['answers'][window]

    val_inputs = tokenizer(val_questions, val_texts, **tokenize_kwargs)

    val_arguments['ids'] = val_ids
    val_arguments['inputs'] = val_inputs
    val_arguments['answers'] = val_answers
    # Validation swaps pruned layers into this copy, leaving `model` untouched.
    val_arguments['model'] = deepcopy(model)
    val_arguments['metric'] = load(args.dataset)

  pruned_transformer, val_mets = prune_transformer(args, transformer, embeddings, val_arguments, **transformer_arguments, device=device)

  if args.save_model:
    torch.save(pruned_transformer.state_dict(), '/content/pruned_transformer30pct.pth')
  return val_mets

## prune transformer

In [None]:
def prune_transformer(args, transformer, hidden_states, val_arguments, validate_iter=False, attn_mask=None, output_attentions=None, head_mask=None, output_hidden_states=None, return_dict=None, device='cpu'):
  """Prune the attention of each selected layer of a deep copy of `transformer`.

  Layers are visited in order. Layers with index in
  [`args.min_layer`, `args.max_layer`] are pruned on the hidden states entering
  them. Every visited layer, pruned or not, is then applied so the next layer
  receives the correct inputs.

  Args:
    args: configuration object (`validate_layer`, `validate_iter`, `min_layer`,
      `max_layer`, `batch_size`, `finetune_layer`, `verbose`, plus whatever
      `prune_attn_layer` reads).
    transformer: module exposing a `layer` sequence of blocks with an
      `attention` attribute. It is deep-copied, not modified.
    hidden_states: inputs to the first layer.
    val_arguments: keyword arguments for `validate`. Its 'model' entry has its
      transformer replaced by the pruned copy.
    validate_iter, output_hidden_states, return_dict: unused; kept for
      interface compatibility.
    attn_mask: attention mask aligned with `hidden_states`.
    output_attentions: forwarded to each layer call.
    head_mask: per-layer head masks, indexed by layer.
    device: device used for forward passes.

  Returns:
    (pruned_transformer, all_val_mets) where `all_val_mets` maps
    'pre-pruning' / 'layer i' to validation metrics.
  """
  pruned_transformer = deepcopy(transformer)
  layers = pruned_transformer.layer

  all_val_mets = {}
  if args.validate_layer or args.validate_iter:
    val_arguments['model'].distilbert.transformer = pruned_transformer
    mets = validate(args, **val_arguments, device=device)
    all_val_mets['pre-pruning'] = mets
    print(f'Pre-pruning validation metrics = {mets}')

  for i, layer in enumerate(layers[:]):
    if i > args.max_layer:
      # Nothing past the last requested layer is pruned or validated.
      break
    prune_this_layer = i >= args.min_layer

    if prune_this_layer:
      input_ds = InputDataset(attn_mask, hidden_states)
      input_loader = DataLoader(input_ds, num_workers=2, shuffle=True, batch_size=args.batch_size)

      print(f'\nLayer {i}')

      layer.attention = prune_attn_layer(args, layer.attention, input_loader, head_mask[i], val_arguments, layer_id=i, device=device)

      if args.finetune_layer:
        if args.validate_layer:
          val_arguments['model'].distilbert.transformer = pruned_transformer
          mets = validate(args, **val_arguments, device=device)
          all_val_mets[f'layer {i}'] = mets
          print(f'Validation metrics after pruning layer {i}, before fine-tuning = {mets}')

        # Tolerate a fine-tuning routine that works in place and returns None.
        finetuned = run_finetuning(layer, input_loader)
        if finetuned is not None:
          layer = finetuned

      del input_ds
      del input_loader

    # Propagate through this layer even when it is not pruned, so that later
    # layers are pruned on the hidden states they will really receive.
    layer = layer.to(device)
    hidden_states = hidden_states.to(device)
    attn_mask = attn_mask.to(device)

    with torch.no_grad():
      hidden_states = layer(
          x=hidden_states,
          attn_mask=attn_mask,
          output_attentions=output_attentions,
          head_mask=head_mask[i]
      )[-1]

    layer = layer.cpu()
    hidden_states = hidden_states.cpu()
    attn_mask = attn_mask.cpu()

    if not prune_this_layer:
      continue

    if args.validate_layer:
      val_arguments['model'].distilbert.transformer = pruned_transformer
      mets = validate(args, **val_arguments, device=device)
      all_val_mets[f'layer {i}'] = mets
      if args.finetune_layer:
        print(f'Validation metrics after pruning layer {i}, after fine-tuning = {mets}')
      else:
        print(f'Validation metrics after pruning layer {i} = {mets}')

    if args.validate_iter:
      val_arguments['model'].distilbert.transformer = pruned_transformer

  if args.verbose:
    print(pruned_transformer)
  return pruned_transformer, all_val_mets

## prune layer

In [None]:
def prune_attn_layer(args, layer, input_loader, head_mask, val_arguments, layer_id=None, device='cpu'):
  """Iteratively select and keep a fraction `args.s` of the heads of `layer`.

  Each of the `args.T` iterations computes head importances from the gradient
  of the residual objective, merges the best candidates with the current head
  set S, takes a restricted gradient step, truncates back to the target number
  of heads, and debiases the truncated layer.

  Args:
    args: configuration object (`s`, `T`, `validate_iter`, `iter_verbose`,
      `debias_iters`, plus whatever the step functions read).
    layer: dense attention module; moved to `device`.
    input_loader: loader yielding `(attn_mask, hidden_states)` batches.
    head_mask: head mask for this layer.
    val_arguments: keyword arguments for `validate`; used when
      `args.validate_iter` is set.
    layer_id: index of the layer inside the transformer (needed for
      per-iteration validation and for logging).
    device: device used for all computation.

  Returns:
    The pruned attention module.
  """
  n_heads_to_keep = int(layer.n_heads * args.s)

  # Move the dense layer first so every copy, and the optimizer state built on
  # Q's parameters, lives on the same device as the gradients it will receive.
  layer = layer.to(device)
  pruned_layer = deepcopy(layer)
  Q = deepcopy(layer)

  qkv_params = []
  out_params = []
  for name, param in Q.named_parameters():
    if 'out' in name:
      # The output-projection bias is not head-structured and is not optimised.
      if 'weight' in name:
        out_params.append(param)
    else:
      qkv_params.append(param)
  qkv_optimizer = SparseAdam(qkv_params, sparsity='row')
  out_optimizer = SparseAdam(out_params, sparsity='col')

  S = set()
  for t in range(args.T):
    if args.validate_iter and pruned_layer is not None:
      val_arguments['model'].distilbert.transformer.layer[layer_id].attention = pruned_layer
      acc = validate(args, **val_arguments, device=device)
      # validate() leaves the model, and therefore pruned_layer, on the CPU.
      pruned_layer = pruned_layer.to(device)

    attn_mask, hidden_states = next(iter(input_loader))

    attn_mask = attn_mask.to(device)
    hidden_states = hidden_states.to(device)
    layer = layer.to(device)

    importance, grad_dict, loss = compute_grads(pruned_layer, layer, hidden_states, attn_mask, head_mask, S, device=device)

    if args.validate_iter:
      print(f'Iteration {t}: S = {S} | Loss = {loss} | Exact Match = {acc["exact_match"]} | F1 = {acc["f1"]}')
    elif args.iter_verbose:
      print(f'Iteration {t}: S = {S} | Loss = {loss}')

    D = find_and_merge(importance, S, n_heads_to_keep, device=device)
    Q = update_step(args, Q, layer, grad_dict, qkv_optimizer, out_optimizer, D, device=device)
    pruned_layer, S = truncate(Q, pruned_layer, hidden_states, attn_mask, head_mask, n_heads_to_keep, device=device)
    pruned_layer = debias(args, pruned_layer, layer, input_loader, args.debias_iters, head_mask, eta=None, device=device)

    del grad_dict

  print(f'Post-pruning: S = {S}')

  if layer_id is not None:
    print(f'Done pruning layer {layer_id}')

  return pruned_layer

## run

In [None]:
# Experiment configuration, wrapped in dotdict for attribute access (args.T, args.s, ...).
args = {
    'validate_iter': False,  # validate before every pruning iteration of a layer
    'validate_layer': True,  # validate after each pruned layer
    'T': 10,  # pruning iterations per attention layer
    's': 0.3,  # fraction of heads kept: int(n_heads * s)
    'debias_iters': 5,  # optimisation steps per debias call
    'min_layer': 0,  # first layer index to prune (inclusive)
    'max_layer': 10,  # last layer index to prune (inclusive)
    'batch_size': 16,  # batch size of the per-layer input DataLoader
    'val_size': 100,  # number of validation examples
    'verbose': False,  # print prediction timing and the final pruned transformer
    'iter_verbose': False,  # print S and the loss at every iteration
    'dataset': 'squad',  # name passed to get_data and to the metric loader
    'save_model': False,  # save the pruned transformer's state_dict after pruning
    'runs': 2,  # number of independent pruning runs
    'maintain_Q': False,  # keep Q across iterations (SparseAdam) instead of re-copying the dense layer
    'eta': 1e-4,  # step size of the manual update and learning rate of the debias optimizer
    'trunc_strategy': 'magnitude',  # head ranking used by truncate
    'finetune_layer': False,  # call run_finetuning after pruning each layer
    'train_size': 200  # number of training examples used as pruning inputs
}
args = dotdict(args)

In [None]:
# Load the SQuAD-distilled DistilBERT tokenizer and QA model.
# `tokenizer` must stay a notebook-level name: validate() reads it as a global.
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased-distilled-squad')
model = DistilBertForQuestionAnswering.from_pretrained('distilbert-base-uncased-distilled-squad')
# train_ds, val_ds, test_ds = get_data(args.dataset)
# train_ds = train_ds[:args.train_size]

In [None]:
# Use the GPU when one is available.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'
print(device)
start_time = time.time()

# One entry per run: the validation metrics returned by prune_distilbert.
all_run_mets = []
for i in range(args.runs):
  print(f'\nRun {i}')
  # Each run draws a fresh random train/validation split.
  train_ds, val_ds, test_ds = get_data(args.dataset)
  train_ds = train_ds[:args.train_size]
  # The dataset's own validation split (test_ds) is what the run validates on.
  run_mets = prune_distilbert(args, model, tokenizer, train_ds, val_ds=test_ds, device=device)
  all_run_mets.append(run_mets)
end_time = time.time()

print(f'average run time = {(end_time - start_time) / args.runs}s')
torch.cuda.empty_cache()

cuda

Run 0
Pre-pruning validation metrics = {'exact_match': 72.0, 'f1': 84.0200495696003}

Layer 0


RuntimeError: ignored

In [None]:
# Display the per-run validation metrics collected by the run loop.
all_run_mets

[{'pre-pruning': {'exact_match': 72.0, 'f1': 83.27004956960029},
  'layer 0': {'exact_match': 66.0, 'f1': 79.0994645705172},
  'layer 1': {'exact_match': 33.0, 'f1': 43.54160187197602},
  'layer 2': {'exact_match': 25.0, 'f1': 34.97856412921605},
  'layer 3': {'exact_match': 10.0, 'f1': 20.77237946338364},
  'layer 4': {'exact_match': 7.0, 'f1': 11.01484083604468},
  'layer 5': {'exact_match': 8.0, 'f1': 13.723320954571664}},
 {'pre-pruning': {'exact_match': 75.0, 'f1': 83.921978021978},
  'layer 0': {'exact_match': 67.0, 'f1': 77.29047619047618},
  'layer 1': {'exact_match': 46.0, 'f1': 58.2446301807476},
  'layer 2': {'exact_match': 39.0, 'f1': 46.2746657136901},
  'layer 3': {'exact_match': 11.0, 'f1': 16.147464060418937},
  'layer 4': {'exact_match': 8.0, 'f1': 9.77171589870085},
  'layer 5': {'exact_match': 10.0, 'f1': 15.003244172268563}},
 {'pre-pruning': {'exact_match': 72.0, 'f1': 83.27004956960029},
  'layer 0': {'exact_match': 66.0, 'f1': 79.0994645705172},
  'layer 1': {'ex

## plots

In [None]:
# Reshape each run from {stage: {metric: value}} into {metric: [value per stage]},
# keeping the stages in the order they were recorded.
mets_by_score = [
    {met: [stage_scores[met] for stage_scores in run.values()] for met in run['pre-pruning']}
    for run in all_run_mets  # all_run_mets is a list with one dict per run
]
mets_by_score

In [None]:
import matplotlib.pyplot as plt

# One figure per metric, one line per run.
metrics = list(mets_by_score[0].keys())
for met in metrics:
  fig, ax = plt.subplots()
  for i, exp in enumerate(mets_by_score):
    ax.plot(exp[met], label=f"Run {i}")
  ax.legend()
  ax.set_ylabel('score')
  ax.set_xlabel('before pruning layer')
  ax.set_title(f'implementation (c), {met} by run')
  plt.show()
# for exp in avgs_by_exp: