In [1]:
import roberta as my_roberta
from transformers.models.roberta import modeling_roberta as hf_roberta

import torch
import haliax as hax
import jax.random as jrandom
import jax.numpy as jnp
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm
2024-09-04 10:40:36,597	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [2]:
from transformers import AutoConfig
from time import time

hf_model_str = "FacebookAI/roberta-base"

hf_config = AutoConfig.from_pretrained(hf_model_str)
hf_config.hidden_dropout_prob = 0
hf_config.attention_probs_dropout_prob = 0
hf_config.pad_token_id = -1
my_config = my_roberta.RobertaConfig.from_hf_config(hf_config)

In [3]:
seed = time()
print(f"seed: {int(seed)}")
key_vars = jrandom.PRNGKey(int(seed))

EmbedAtt = my_config.EmbedAtt
Embed = my_config.Embed
Mlp = my_config.Mlp
Pos = my_config.Pos
KeyPos = my_config.KeyPos
Heads = my_config.Heads

cut_end_for_bounds = False 

Batch = hax.Axis("batch", 2)
Vocab = hax.Axis("vocab", my_config.vocab_size)

keys = jrandom.split(key_vars, 6)

input_ids = hax.random.randint(keys[0], (Batch, Pos), minval = 3, maxval = my_config.vocab_size)

if cut_end_for_bounds:
    input_ids = input_ids[{"position": slice(0,-2)}]

input_ids_torch = torch.from_numpy(np.array(input_ids.array))
input_embeds = hax.random.normal(keys[1], (Batch, Pos, Embed))

if cut_end_for_bounds:
    input_embeds = input_embeds[{"position": slice(0,-2)}]

input_embeds_torch = torch.from_numpy(np.array(input_embeds.array))

# mask = hax.random.randint(keys[2], (Batch, Pos), minval = 0, maxval = 2)
mask = hax.ones((Batch, Pos))

if cut_end_for_bounds:
    mask = mask[{"position": slice(0,-2)}]

mask_torch = torch.from_numpy(np.array(mask.array))

mask_materialized = hax.ones((Batch, Heads, Pos, KeyPos)) * -jnp.inf

if cut_end_for_bounds:
    mask_materialized = mask_materialized[{"position": slice(0,-2), "key_position": slice(0,-2)}]

mask_materialized_torch = torch.from_numpy(np.array(mask_materialized.array))

features = input_embeds[{"position": 0}]
features_torch = torch.from_numpy(np.array(features.array))

x_embed_att = input_embeds.rename({"embed": "embed_att"})

if cut_end_for_bounds:
    x_embed_att = x_embed_att[{"position": slice(0,-2)}]

x_embed_att_torch = torch.from_numpy(np.array(x_embed_att.array))
x_mlp = hax.random.normal(keys[5], (Batch, Pos, Mlp))

if cut_end_for_bounds:
    x_mlp = x_mlp[{"position": slice(0,-2)}]
    
x_mlp_torch = torch.from_numpy(np.array(x_mlp.array))

seed: 1725471642


In [4]:
def check(my_output, hf_output, precision=1e-4):
    
    assert (np.array(my_output.shape) == np.array(hf_output.shape)).all()
    # print(my_output.shape)
    # print(hf_output.shape)

    acc = np.isclose(hf_output, np.array(my_output), rtol=precision, atol=precision).mean()

    # stats = (torch.tensor(np.array(my_output)).abs().mean(), torch.tensor(np.array(hf_output)).abs().mean()) 
    stats = (torch.linalg.norm(torch.tensor(np.array(my_output))), torch.linalg.norm(torch.tensor(np.array(hf_output))))
    
    difference = torch.tensor(np.array(my_output)) - torch.tensor(np.array(hf_output))

    diffs = difference.abs().mean()

    to_print = f"acc: {acc} \t norms: {stats} \t diffs: {diffs}"
    
    return acc, stats, diffs, to_print

In [5]:
# # Testing RobertaSelfOutput

# def test_RobertaSelfOutput(key):
#     k_1, k_2 = jrandom.split(key, 2)
#     my_func = my_roberta.RobertaSelfOutput.init(my_config, key=k_1)
#     state = my_func.to_state_dict()

#     state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}

#     # # print(state.keys())

#     hf_func = hf_roberta.RobertaSelfOutput(hf_config)
#     hf_func.load_state_dict(state, strict=True)

#     my_output = my_func(x_embed_att, input_embeds, key=k_2)
#     hf_output = hf_func(x_embed_att_torch, input_embeds_torch)

#     return check(my_output.array, hf_output.detach())

# # Testing RobertaSelfAttention

# def test_RobertaSelfAttention(key):
#     k_1, k_2 = jrandom.split(key, 2)

#     my_func = my_roberta.RobertaSelfAttention.init(my_config, key=k_1)
#     state = my_func.to_state_dict()

#     state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}

#     # print(state.keys())

#     hf_func  = hf_roberta.RobertaSelfAttention(hf_config)
#     hf_func.load_state_dict(state, strict=True)

#     my_output = my_func(input_embeds, mask, key=k_2)
#     hf_output = hf_func(input_embeds_torch, mask_torch_materialized)

#     return check(my_output.array, hf_output[0].detach())

# # Testing RobertaAttention

# def test_RobertaAttention(key):
#     k_1, k_2 = jrandom.split(key, 2)
#     my_func = my_roberta.RobertaAttention.init(my_config, key=k_1)
#     state = my_func.to_state_dict()

#     state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}

#     # print(state.keys())

#     hf_func = hf_roberta.RobertaAttention(hf_config)
#     hf_func.load_state_dict(state, strict=True)

#     my_output = my_func(hidden_states=input_embeds, attention_mask=mask, key=k_2)
#     hf_output = hf_func(hidden_states=input_embeds_torch, attention_mask=mask_torch_materialized)

#     return check(my_output.array, hf_output[0].detach())

# # Testing RobertaIntermediate

# def test_RobertaIntermediate(key):
#     k_1, k_2 = jrandom.split(key, 2)
#     my_func = my_roberta.RobertaIntermediate.init(my_config, key=k_1)
#     state = my_func.to_state_dict()

#     state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}

#     # print(state.keys())

#     hf_func = hf_roberta.RobertaIntermediate(hf_config)
#     hf_func.load_state_dict(state, strict=True)

#     my_output = my_func(input_embeds, key=k_2)
#     hf_output = hf_func(input_embeds_torch)

#     return check(my_output.array, hf_output.detach())

# # Testing RobertaOutput

# def test_RobertaOutput(key):
#     k_1, k_2 = jrandom.split(key, 2)

#     my_func = my_roberta.RobertaOutput.init(my_config, key=k_1)
#     state = my_func.to_state_dict()

#     state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}

#     # print(state.keys())

#     hf_func = hf_roberta.RobertaOutput(hf_config)
#     hf_func.load_state_dict(state, strict=True)

#     my_output = my_func(x_mlp, input_embeds, key=k_2)
#     hf_output = hf_func(x_mlp_torch, input_embeds_torch)

#     return check(my_output.array, hf_output.detach())

# # Testing RobertaLayer

# def test_RobertaLayer(key):
#     k_1, k_2 = jrandom.split(key, 2)
#     my_func = my_roberta.RobertaLayer.init(my_config, key=k_1)
#     state = my_func.to_state_dict()

#     state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}

#     # print(state.keys())

#     hf_func = hf_roberta.RobertaLayer(hf_config)
#     hf_func.load_state_dict(state, strict=True)

#     my_output = my_func(hidden_states=input_embeds, attention_mask=mask, key=k_2)
#     hf_output = hf_func(hidden_states=input_embeds_torch, attention_mask=mask_torch_materialized)

#     return check(my_output.array, hf_output[0].detach())

# # Testing RobertaEncoder

# def test_RobertaEncoder(key):
#     k_1, k_2 = jrandom.split(key, 2)
#     my_func = my_roberta.RobertaEncoder.init(my_config, key=k_1)
#     state = my_func.to_state_dict()

#     state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}

#     # print(state.keys())

#     hf_func = hf_roberta.RobertaEncoder(hf_config)
#     hf_func.load_state_dict(state, strict=True)

#     my_output = my_func(hidden_states=input_embeds, attention_mask=mask, key=k_2)
#     hf_output = hf_func(hidden_states=input_embeds_torch, attention_mask=mask_torch_materialized)

#     return check(my_output.array, hf_output[0].detach())

# # Testing RobertaEmbedding

# def test_RobertaEmbedding(key, ids = True):
#     k_1, k_2 = jrandom.split(key, 2)
#     my_func = my_roberta.RobertaEmbedding.init(Vocab, my_config, key=k_1)
#     state = my_func.to_state_dict()

#     state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}

#     # print(state.keys())

#     hf_func = hf_roberta.RobertaEmbeddings(hf_config)
#     hf_func.load_state_dict(state, strict=True)

#     if ids:
#         my_output = my_func.embed(input_ids=input_ids, key=k_2)
#         hf_output = hf_func(input_ids=input_ids_torch)
#     else:        
#         my_output = my_func.embed(input_embeds=input_embeds, key=k_2)
#         hf_output = hf_func(inputs_embeds=input_embeds_torch)

#     return check(my_output.array, hf_output.detach())

# # Testing RobertaPooler

# def test_RobertaPooler(key):
#     k_1, k_2 = jrandom.split(key, 2)
#     my_func = my_roberta.RobertaPooler.init(my_config, key=k_1)
#     state = my_func.to_state_dict()

#     state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}

#     # print(state.keys())

#     hf_func = hf_roberta.RobertaPooler(hf_config)
#     hf_func.load_state_dict(state, strict=True)

#     my_output = my_func(input_embeds, key=k_2)
#     hf_output = hf_func(input_embeds_torch)

#     return check(my_output.array, hf_output.detach())

# # Testing RobertaModel

# def test_RobertaModel(key, ids = True, pool = True):
#     k_1, k_2 = jrandom.split(key, 2)
#     my_func = my_roberta.RobertaModel.init(Vocab, my_config, add_pooling_layer=pool, key=k_1)
#     state = my_func.to_state_dict()

#     state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}

#     # print(state.keys())

#     hf_func = hf_roberta.RobertaModel(hf_config, add_pooling_layer=pool)
#     hf_func.load_state_dict(state, strict=True)

#     if ids:
#         my_output = my_func(input_ids = input_ids, attention_mask=mask, key=k_2)
#         hf_output = hf_func(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)
#     else:
#         my_output = my_func(input_embeds = input_embeds, attention_mask=mask, key=k_2)
#         hf_output = hf_func(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)

#     if pool:
#         return check(my_output[1].array, hf_output[1].detach())
#     else:
#         return check(my_output[0].array, hf_output[0].detach())

# # Testing RobertaLMHead

# def test_RobertaLMHead(key):
#     k_1, k_2 = jrandom.split(key, 2)
#     my_func = my_roberta.RobertaLMHead.init(Vocab, my_config, key=k_1)
#     state = my_func.to_state_dict()

#     state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}

#     state["bias"] = torch.zeros(hf_config.vocab_size)

#     # print(state.keys())

#     hf_func = hf_roberta.RobertaLMHead(hf_config)
#     hf_func.load_state_dict(state, strict=True)

#     my_output = my_func(features, key=k_2)
#     hf_output = hf_func(features_torch)

#     return check(my_output.array, hf_output.detach())

# # Testing RobertaForMaskedLM

# def test_RobertaForMaskedLM(key, ids = True):
#     k_1, k_2 = jrandom.split(key, 2)
#     my_pool = my_roberta.RobertaForMaskedLM.init(Vocab, my_config, key=k_1)
#     state = my_pool.to_state_dict()

#     state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}

#     state["lm_head.bias"] = torch.zeros(hf_config.vocab_size)

#     # print(state.keys())

#     hf_pool = hf_roberta.RobertaForMaskedLM(hf_config)
#     hf_pool.load_state_dict(state, strict=True)

#     if ids:
#         my_output = my_pool(input_ids = input_ids, attention_mask=mask, key=k_2)
#         hf_output = hf_pool(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)
#     else:
#         my_output = my_pool(input_embeds = input_embeds, attention_mask=mask, key=k_2)
#         hf_output = hf_pool(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)

#     return check(my_output.array, hf_output[0].detach())

In [6]:
# seed = time()
# print(f"seed: {int(seed)}")
# key_vars = jrandom.PRNGKey(int(seed))
# keys = jrandom.split(key_vars, 15)

# print(f"test_RobertaSelfOutput: {out_func(test_RobertaSelfOutput(keys[0]))}")
# print(f"test_RobertaSelfAttention: {out_func(test_RobertaSelfAttention(keys[1]))}")
# print(f"test_RobertaAttention: {out_func(test_RobertaAttention(keys[2]))}")
# print(f"test_RobertaIntermediate: {out_func(test_RobertaIntermediate(keys[3]))}")
# print(f"test_RobertaOutput: {out_func(test_RobertaOutput(keys[4]))}")
# print(f"test_RobertaEmbedding(ids = True): {out_func(test_RobertaEmbedding(keys[7], ids = True))}")
# print(f"test_RobertaEmbedding(ids = False): {out_func(test_RobertaEmbedding(keys[8], ids = False))}")
# print(f"test_RobertaModel(ids = True, pool = True): {out_func(test_RobertaModel(keys[9], ids = True, pool = True))}")
# print(f"test_RobertaModel(ids = False, pool = False): {out_func(test_RobertaModel(keys[10], ids = False, pool = False))}")
# print(f"test_RobertaModel(ids = True, pool = True): {out_func(test_RobertaModel(keys[9], ids = True, pool = True))}")
# print(f"test_RobertaModel(ids = False, pool = False): {out_func(test_RobertaModel(keys[10], ids = False, pool = False))}")
# print(f"test_RobertaPooler: {out_func(test_RobertaPooler(keys[11]))}")
# print(f"test_RobertaLMHead: {out_func(test_RobertaLMHead(keys[12]))}")
# print(f"test_RobertaForMaskedLM(ids = True): {out_func(test_RobertaForMaskedLM(keys[13], ids = True))}")
# print(f"test_RobertaForMaskedLM(ids = False): {out_func(test_RobertaForMaskedLM(keys[14], ids = False))}")

In [7]:
seed = time()
print(f"seed: {int(seed)}")
key = jrandom.PRNGKey(int(seed))

def test_RobertaModel_Output(key, ids = False, pool = True):
    k_1, k_2 = jrandom.split(key, 2)
    my_func = my_roberta.RobertaModel.init(Vocab, my_config, add_pooling_layer=pool, output_hidden_states=True, key=k_1)
    state = my_func.to_state_dict()

    state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}

    # print(state.keys())

    hf_func = hf_roberta.RobertaModel(hf_config, add_pooling_layer=pool)
    hf_func.load_state_dict(state, strict=True)

    if ids:
        my_output = my_func(input_ids = input_ids, attention_mask=mask, key=k_2)
        hf_output = hf_func(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False, output_hidden_states=True)
    else:
        my_output = my_func(input_embeds = input_embeds, attention_mask=mask, key=k_2)
        hf_output = hf_func(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False, output_hidden_states=True)

    return my_output, hf_output

my_output_ids, hf_output_ids = test_RobertaModel_Output(key, ids=True)
my_output_embeds, hf_output_embeds = test_RobertaModel_Output(key, ids=False)

seed: 1725471643


In [8]:
my_out, hf_out = my_output_ids[0], hf_output_ids[0]

print(f"model_out: {check(my_out.array, hf_out.detach())[3]}")

my_pool, hf_pool = my_output_ids[1], hf_output_ids[1]

print(f"pool_out: {check(my_pool.array, hf_pool.detach())[3]}")

print("intermediates:")
my_ints, hf_ints = my_output_ids[2], hf_output_ids[2][1:]

for i,j in zip(my_ints, hf_ints):
    print(check(i.array,j.detach())[3])

model_out: acc: 1.0 	 norms: (tensor(888.5305), tensor(888.5305)) 	 diffs: 3.735790983228071e-07
pool_out: acc: 1.0 	 norms: (tensor(24.1699), tensor(24.1699)) 	 diffs: 2.6738598535303026e-07
intermediates:
acc: 1.0 	 norms: (tensor(888.5306), tensor(888.5306)) 	 diffs: 1.842970362986307e-07
acc: 1.0 	 norms: (tensor(888.5308), tensor(888.5308)) 	 diffs: 2.625348543006112e-07
acc: 1.0 	 norms: (tensor(888.5301), tensor(888.5301)) 	 diffs: 3.2068956556940975e-07
acc: 1.0 	 norms: (tensor(888.5295), tensor(888.5295)) 	 diffs: 3.396997101390298e-07
acc: 1.0 	 norms: (tensor(888.5296), tensor(888.5297)) 	 diffs: 3.419580139052414e-07
acc: 1.0 	 norms: (tensor(888.5310), tensor(888.5310)) 	 diffs: 3.721844734627666e-07
acc: 1.0 	 norms: (tensor(888.5298), tensor(888.5299)) 	 diffs: 3.591211736875266e-07
acc: 1.0 	 norms: (tensor(888.5299), tensor(888.5299)) 	 diffs: 3.513960677992145e-07
acc: 1.0 	 norms: (tensor(888.5300), tensor(888.5300)) 	 diffs: 3.739319538453856e-07
acc: 1.0 	 norms: 

In [9]:
my_out, hf_out = my_output_embeds[0], hf_output_embeds[0]

print(f"model_out: {check(my_out.array, hf_out.detach())[3]}")

my_pool, hf_pool = my_output_embeds[1], hf_output_embeds[1]

print(f"pool_out: {check(my_pool.array, hf_pool.detach())[3]}")

print("intermediates:")
my_ints, hf_ints = my_output_embeds[2], hf_output_embeds[2][1:]

for i,j in zip(my_ints, hf_ints):
    print(check(i.array,j.detach())[3])

model_out: acc: 1.0 	 norms: (tensor(888.5308), tensor(888.5308)) 	 diffs: 3.864420250465628e-07
pool_out: acc: 1.0 	 norms: (tensor(24.0400), tensor(24.0400)) 	 diffs: 2.5442224682592496e-07
intermediates:
acc: 1.0 	 norms: (tensor(888.5295), tensor(888.5295)) 	 diffs: 1.46142554058315e-07
acc: 1.0 	 norms: (tensor(888.5310), tensor(888.5311)) 	 diffs: 2.35209114407553e-07
acc: 1.0 	 norms: (tensor(888.5301), tensor(888.5302)) 	 diffs: 3.0417106700042496e-07
acc: 1.0 	 norms: (tensor(888.5303), tensor(888.5303)) 	 diffs: 3.522779365994211e-07
acc: 1.0 	 norms: (tensor(888.5314), tensor(888.5314)) 	 diffs: 3.7978762179591286e-07
acc: 1.0 	 norms: (tensor(888.5302), tensor(888.5302)) 	 diffs: 3.9330373624579806e-07
acc: 1.0 	 norms: (tensor(888.5305), tensor(888.5304)) 	 diffs: 3.8590263784499257e-07
acc: 1.0 	 norms: (tensor(888.5306), tensor(888.5306)) 	 diffs: 3.735180200692412e-07
acc: 1.0 	 norms: (tensor(888.5292), tensor(888.5291)) 	 diffs: 3.75227983795412e-07
acc: 1.0 	 norms: 

In [10]:
seed = time()
print(f"seed: {int(seed)}")
key = jrandom.PRNGKey(int(seed))

def test_RobertaForMaskedLM_Output(key, ids = True):
    k_1, k_2 = jrandom.split(key, 2)
    my_func = my_roberta.RobertaForMaskedLM.init(Vocab, my_config, output_hidden_states=True, key=k_1)
    state = my_func.to_state_dict()

    state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}
    
    state["lm_head.bias"] = torch.zeros(hf_config.vocab_size)

    hf_func = hf_roberta.RobertaForMaskedLM(hf_config)
    hf_func.load_state_dict(state, strict=True)

    if ids:
        my_output = my_func(input_ids = input_ids, attention_mask=mask, key=k_2)
        hf_output = hf_func(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False, output_hidden_states=True)
    else:
        my_output = my_func(input_embeds = input_embeds, attention_mask=mask, key=k_2)
        hf_output = hf_func(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False, output_hidden_states=True)

    return my_output, hf_output

my_mlm_output_ids, hf_mlm_output_ids = test_RobertaForMaskedLM_Output(key, ids=True)
my_mlm_output_embeds, hf_mlm_output_embeds = test_RobertaForMaskedLM_Output(key, ids=False)

seed: 1725471659


In [11]:
my_out, hf_out = my_mlm_output_ids[0], hf_mlm_output_ids[0]

print(f"mlm_out: {check(my_out.array, hf_out.detach())[3]}")

print("intermediates:")
my_ints, hf_ints = my_mlm_output_ids[1], hf_mlm_output_ids[1][1:]

for i,j in zip(my_ints, hf_ints):
    print(check(i.array,j.detach(), precision = 0.01)[3])

mlm_out: acc: 0.001719795589213743 	 norms: (tensor(7055.4185), tensor(7059.8276)) 	 diffs: 0.06642061471939087
intermediates:
acc: 0.019604713845654993 	 norms: (tensor(888.5299), tensor(888.5306)) 	 diffs: 0.5695138573646545
acc: 0.024172138456549936 	 norms: (tensor(888.5310), tensor(888.5300)) 	 diffs: 0.46615326404571533
acc: 0.030564759646562904 	 norms: (tensor(888.5312), tensor(888.5323)) 	 diffs: 0.37349948287010193
acc: 0.04000106395914397 	 norms: (tensor(888.5294), tensor(888.5300)) 	 diffs: 0.28456440567970276
acc: 0.05416818660830091 	 norms: (tensor(888.5305), tensor(888.5297)) 	 diffs: 0.21144814789295197
acc: 0.07160065053501946 	 norms: (tensor(888.5291), tensor(888.5300)) 	 diffs: 0.16202779114246368
acc: 0.08982095087548637 	 norms: (tensor(888.5302), tensor(888.5295)) 	 diffs: 0.12657980620861053
acc: 0.11758268482490272 	 norms: (tensor(888.5289), tensor(888.5308)) 	 diffs: 0.09682736545801163
acc: 0.1463805123216602 	 norms: (tensor(888.5303), tensor(888.5312)) 	

In [12]:
my_out, hf_out = my_mlm_output_embeds[0], hf_mlm_output_embeds[0]

print(f"mlm_out: {check(my_out.array, hf_out.detach())[3]}")

print("intermediates:")
my_ints, hf_ints = my_mlm_output_embeds[1], hf_mlm_output_embeds[1][1:]

for i,j in zip(my_ints, hf_ints):
    print(check(i.array,j.detach())[3])

mlm_out: acc: 1.0 	 norms: (tensor(7033.6245), tensor(7033.6240)) 	 diffs: 5.174669013285893e-07
intermediates:
acc: 1.0 	 norms: (tensor(888.5309), tensor(888.5310)) 	 diffs: 1.4689783256471856e-07
acc: 1.0 	 norms: (tensor(888.5306), tensor(888.5306)) 	 diffs: 2.3626945733212779e-07
acc: 1.0 	 norms: (tensor(888.5312), tensor(888.5312)) 	 diffs: 3.0318486210489937e-07
acc: 1.0 	 norms: (tensor(888.5305), tensor(888.5305)) 	 diffs: 3.531020809077745e-07
acc: 1.0 	 norms: (tensor(888.5297), tensor(888.5297)) 	 diffs: 3.7493518334486e-07
acc: 1.0 	 norms: (tensor(888.5297), tensor(888.5297)) 	 diffs: 3.8230905374803115e-07
acc: 1.0 	 norms: (tensor(888.5295), tensor(888.5296)) 	 diffs: 3.8595226214965805e-07
acc: 1.0 	 norms: (tensor(888.5308), tensor(888.5308)) 	 diffs: 3.713914793479489e-07
acc: 1.0 	 norms: (tensor(888.5317), tensor(888.5318)) 	 diffs: 3.5173252399545163e-07
acc: 1.0 	 norms: (tensor(888.5316), tensor(888.5316)) 	 diffs: 3.4720503094831656e-07
acc: 1.0 	 norms: (tens

In [13]:
# Testing RobertaForMaskedLM

my_mlm = my_roberta.RobertaForMaskedLM.init(Vocab, my_config, key=key)
state_mlm = my_mlm.to_state_dict()

state_mlm = {k: torch.from_numpy(np.array(v)) for k, v in state_mlm.items()}

# if "lm_head.decoder.bias" in state:
#     print(state["lm_head.decoder.bias"])
# else:
#     print(f"RobertaForMaskedLM, {state.keys()}")

state_mlm["lm_head.bias"] = torch.zeros(hf_config.vocab_size)

print(state_mlm.keys())

hf_mlm = hf_roberta.RobertaForMaskedLM(hf_config)
hf_mlm.load_state_dict(state_mlm, strict=True)

# Testing RobertaModel

key_rob, key_head = jrandom.split(key, 2)

my_model = my_roberta.RobertaModel.init(Vocab, my_config, add_pooling_layer=False, key=key_rob)
state_model = my_model.to_state_dict()

state_model = {k: torch.from_numpy(np.array(v)) for k, v in state_model.items()}

print(state_model.keys())

hf_model = hf_roberta.RobertaModel(hf_config, add_pooling_layer=False)
hf_model.load_state_dict(state_model, strict=True)

# Testing RobertaLMHead

my_head = my_roberta.RobertaLMHead.init(Vocab, my_config, key=key_head)
state_head = my_head.to_state_dict()

state_head = {k: torch.from_numpy(np.array(v)) for k, v in state_head.items()}

state_head["bias"] = torch.zeros(hf_config.vocab_size)

print(state_head.keys())

hf_head = hf_roberta.RobertaLMHead(hf_config)
hf_head.load_state_dict(state_head, strict=True)

dict_keys(['roberta.encoder.layer.0.attention.self.query.weight', 'roberta.encoder.layer.0.attention.self.query.bias', 'roberta.encoder.layer.0.attention.self.key.weight', 'roberta.encoder.layer.0.attention.self.key.bias', 'roberta.encoder.layer.0.attention.self.value.weight', 'roberta.encoder.layer.0.attention.self.value.bias', 'roberta.encoder.layer.0.attention.output.dense.weight', 'roberta.encoder.layer.0.attention.output.dense.bias', 'roberta.encoder.layer.0.attention.output.LayerNorm.weight', 'roberta.encoder.layer.0.attention.output.LayerNorm.bias', 'roberta.encoder.layer.0.intermediate.dense.weight', 'roberta.encoder.layer.0.intermediate.dense.bias', 'roberta.encoder.layer.0.output.dense.weight', 'roberta.encoder.layer.0.output.dense.bias', 'roberta.encoder.layer.0.output.LayerNorm.weight', 'roberta.encoder.layer.0.output.LayerNorm.bias', 'roberta.encoder.layer.1.attention.self.query.weight', 'roberta.encoder.layer.1.attention.self.query.bias', 'roberta.encoder.layer.1.attentio

<All keys matched successfully>

In [14]:
k_rob, k_lm = jrandom.split(key, 2)

# MLM
my_output_mlm = my_mlm(input_embeds = input_embeds, attention_mask=mask, key=key)
hf_output_mlm = hf_mlm(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)

# Model + LM
my_output_model = my_model(input_embeds = input_embeds, attention_mask=mask, key=k_rob)
my_output = my_head(my_output_model[0], key=k_lm)

hf_output_model = hf_model(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)
hf_output = hf_head(hf_output_model[0])

In [15]:
print(f"RobertaModel: {check(my_output_model[0].array, hf_output_model[0].detach())[3]}")
print(f"Roberta Model + LM head: {check(my_output.array, hf_output.detach())[3]}")
print(f"MLM: {check(my_output_mlm[0].array, hf_output_mlm[0].detach())[3]}")

print(f"my RobertaModel + LM head vs MLM: {check(my_output.array, my_output_mlm[0].array)[3]}")
print(f"hf RobertaModel + LM head vs MLM: {check(hf_output.detach(), hf_output_mlm[0].detach())[3]}")

RobertaModel: acc: 1.0 	 norms: (tensor(888.5305), tensor(888.5305)) 	 diffs: 3.802756225468329e-07
Roberta Model + LM head: acc: 1.0 	 norms: (tensor(7087.6094), tensor(7087.6094)) 	 diffs: 5.642933729177457e-07
MLM: acc: 1.0 	 norms: (tensor(7087.6094), tensor(7087.6094)) 	 diffs: 5.642933729177457e-07
my RobertaModel + LM head vs MLM: acc: 1.0 	 norms: (tensor(7087.6094), tensor(7087.6094)) 	 diffs: 0.0
hf RobertaModel + LM head vs MLM: acc: 1.0 	 norms: (tensor(7087.6094), tensor(7087.6094)) 	 diffs: 0.0


In [16]:
# MLM

my_output_mlm = my_mlm(input_ids = input_ids, attention_mask=mask, key=key)
hf_output_mlm = hf_mlm(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)

# Model + LM

my_output_model = my_model(input_ids = input_ids, attention_mask=mask, key=k_rob)
my_output = my_head(my_output_model[0], key=k_lm)

hf_output_model = hf_model(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)
hf_output = hf_head(hf_output_model[0])


# Checks

# Notes
# embeds works between hf and my, and within model->head and mlm 
# ids does not work between hf and my for mlm or within hf for model->head and mlm - so hf mlm is doing something weird.'''

In [17]:
print(f"RobertaModel: {check(my_output_model[0].array, hf_output_model[0].detach())[3]}")
print(f"Roberta Model + LM head: {check(my_output.array, hf_output.detach())[3]}")
print(f"MLM: {check(my_output_mlm[0].array, hf_output_mlm[0].detach())[3]}")

print(f"my RobertaModel + LM head vs MLM: {check(my_output.array, my_output_mlm[0].array)[3]}")
print(f"hf RobertaModel + LM head vs MLM: {check(hf_output.detach(), hf_output_mlm[0].detach())[3]}")

RobertaModel: acc: 1.0 	 norms: (tensor(888.5287), tensor(888.5287)) 	 diffs: 3.736330143055966e-07
Roberta Model + LM head: acc: 1.0 	 norms: (tensor(7065.4971), tensor(7065.4971)) 	 diffs: 5.507896503331722e-07
MLM: acc: 0.0014420458728273227 	 norms: (tensor(7065.4971), tensor(7062.7324)) 	 diffs: 0.07865540683269501
my RobertaModel + LM head vs MLM: acc: 1.0 	 norms: (tensor(7065.4971), tensor(7065.4971)) 	 diffs: 0.0
hf RobertaModel + LM head vs MLM: acc: 0.0014420071674599332 	 norms: (tensor(7065.4971), tensor(7062.7324)) 	 diffs: 0.07865539938211441
