In [16]:
import sys
sys.path.insert(0, "../src")
from model_bert import BertForSequenceClassification
from config_bert import BertConfig
from run_glue import load_trained_model

In [8]:
import pathlib
import torch

In [5]:
# Root directory of the saved masks; the loading cell walks it as <task>/<seed>/.
global_dir = pathlib.Path("../masks/global/")

In [11]:
# Build masks[task][seed] = {"magnitude": ..., "bad": ...} by walking
# global_dir/<task>/<seed>/ and loading the two mask files found in each seed
# directory. The loaded objects are later iterated with .items(), so they are
# expected to be mappings from a mask name to a mask tensor.
masks = {}
for task_dir in global_dir.iterdir():
    per_seed = {}
    for seed_dir in task_dir.iterdir():
        # NOTE(review): the extension is ".p" here but ".pt" for the bad mask;
        # confirm this matches the file names on disk.
        per_seed[seed_dir.stem] = {
            "magnitude": torch.load(str(seed_dir / "magnitude_mask.p")),
            "bad": torch.load(str(seed_dir / "bad_mask.pt")),
        }
    masks[task_dir.stem] = per_seed

In [12]:
# Root directory of the fine-tuned models, addressed below as model_dir / task / seed.
model_dir = pathlib.Path("../models/finetuned/")

In [40]:
def _mean_abs_masked_weight(weights, mask_dict):
    """Return sum(|weight * mask|) over all masks divided by the total mask size.

    weights:   mapping from parameter name to tensor (a model state dict).
    mask_dict: mapping from mask name to mask tensor. Each mask name is turned
               into a parameter name by dropping its last five characters
               (presumably a "_mask"-style suffix -- TODO confirm).

    NOTE(review): the denominator counts every mask position (mask.numel()),
    not only the positions the mask keeps -- confirm this is the intended mean.
    """
    abs_sum = sum(
        (weights[name[:-5]] * mask).abs().sum() for name, mask in mask_dict.items()
    )
    element_count = sum(mask.numel() for mask in mask_dict.values())
    return abs_sum / element_count


# scaling_factors[task][seed] = (mean |weight| under the "magnitude" masks)
#                               / (mean |weight| under the "bad" masks),
# computed on the fine-tuned model stored at model_dir / task / seed.
scaling_factors = {}

for task, seed_masks in masks.items():
    scaling_factors[task] = {}
    for seed, mask_sets in seed_masks.items():
        good_masks = mask_sets["magnitude"]
        bad_masks = mask_sets["bad"]
        model = BertForSequenceClassification.from_pretrained(str(model_dir / task / seed))
        state_dict = model.state_dict()
        good_weight_mean = _mean_abs_masked_weight(state_dict, good_masks)
        bad_weight_mean = _mean_abs_masked_weight(state_dict, bad_masks)
        scaling_factor = good_weight_mean / bad_weight_mean
        print(task, seed)
        print(scaling_factor)
        # Store a plain Python float rather than a 0-dim tensor.
        scaling_factors[task][seed] = scaling_factor.item()

WNLI seed_71
tensor(3.9156)
WNLI seed_1337
tensor(12.6881)
WNLI seed_86
tensor(1.1837)
WNLI seed_42
tensor(4.5969)
WNLI seed_166
tensor(12.6880)
MRPC seed_71
tensor(3.3142)
MRPC seed_1337
tensor(4.2456)
MRPC seed_86
tensor(3.0404)
MRPC seed_42
tensor(3.0405)
MRPC seed_166
tensor(3.3142)
MNLI seed_71
tensor(3.9148)
MNLI seed_1337
tensor(3.9147)
MNLI seed_86
tensor(3.9147)
MNLI seed_42
tensor(3.9147)
MNLI seed_166
tensor(3.9145)
QQP seed_71
tensor(4.2449)
QQP seed_1337
tensor(4.5960)
QQP seed_86
tensor(4.2444)
QQP seed_42
tensor(4.5960)
QQP seed_166
tensor(4.2446)
RTE seed_71
tensor(4.2456)
RTE seed_1337
tensor(3.9155)
RTE seed_86
tensor(3.6055)
RTE seed_42
tensor(3.3142)
RTE seed_166
tensor(3.6055)
SST-2 seed_71
tensor(4.9707)
SST-2 seed_1337
tensor(4.9707)
SST-2 seed_86
tensor(5.3685)
SST-2 seed_42
tensor(4.9707)
SST-2 seed_166
tensor(4.9707)
STS-B seed_71
tensor(3.6055)
STS-B seed_1337
tensor(4.2456)
STS-B seed_86
tensor(3.9155)
STS-B seed_42
tensor(4.2455)
STS-B seed_166
tensor(4.245

In [41]:
# Show the collected ratios, keyed by task and then by seed.
scaling_factors

{'WNLI': {'seed_71': 3.915574312210083,
  'seed_1337': 12.68806266784668,
  'seed_86': 1.183749794960022,
  'seed_42': 4.596926212310791,
  'seed_166': 12.687975883483887},
 'MRPC': {'seed_71': 3.3142013549804688,
  'seed_1337': 4.245562553405762,
  'seed_86': 3.0404460430145264,
  'seed_42': 3.0404586791992188,
  'seed_166': 3.3141989707946777},
 'MNLI': {'seed_71': 3.9147980213165283,
  'seed_1337': 3.914726734161377,
  'seed_86': 3.9146506786346436,
  'seed_42': 3.9146640300750732,
  'seed_166': 3.914477586746216},
 'QQP': {'seed_71': 4.244863033294678,
  'seed_1337': 4.5959672927856445,
  'seed_86': 4.244410991668701,
  'seed_42': 4.5959577560424805,
  'seed_166': 4.24460506439209},
 'RTE': {'seed_71': 4.245597839355469,
  'seed_1337': 3.915539026260376,
  'seed_86': 3.605501890182495,
  'seed_42': 3.3142037391662598,
  'seed_166': 3.6055095195770264},
 'SST-2': {'seed_71': 4.970656871795654,
  'seed_1337': 4.970679759979248,
  'seed_86': 5.368537425994873,
  'seed_42': 4.970706462