In [None]:
import roberta as my_roberta
from transformers.models.roberta import modeling_roberta as hf_roberta

import torch
import haliax as hax
import jax
import jax.random as jrandom
import jax.numpy as jnp
import numpy as np

# hello

ModuleNotFoundError: No module named 'levanter'

In [None]:
from transformers import AutoConfig
from time import time

# HF checkpoint whose config (and, later, weights) both implementations are built from.
hf_model_str = "FacebookAI/roberta-base"

# Reference HF config with both dropout probabilities set to 0 so the two
# implementations can be compared exactly; my config is derived from it.
hf_config = AutoConfig.from_pretrained(hf_model_str)
hf_config.hidden_dropout_prob = 0
hf_config.attention_probs_dropout_prob = 0
# hf_config.pad_token_id = -1
my_config = my_roberta.RobertaConfig.from_hf_config(hf_config)

In [None]:
# Root PRNG seed for the whole comparison run. Leave SEED as None to draw a
# fresh seed from the wall clock (so repeated runs exercise different random
# weights and inputs), or set it to a previously printed value to reproduce a run.
SEED = None
seed = int(time()) if SEED is None else int(SEED)
print(f"seed: {seed}")
key = jrandom.PRNGKey(seed)

# Independent streams: input tensors (key_vars), module initialisation for the
# tests (key_funcs) and forward-pass keys (key_run).
key_vars, key_funcs, key_run = jrandom.split(key, 3)

seed: 1725922495


In [None]:
# Axes taken from my config.
EmbedAtt = my_config.EmbedAtt
Embed = my_config.Embed
Mlp = my_config.Mlp
Pos = my_config.Pos
KeyPos = my_config.KeyPos
Heads = my_config.Heads

# When True, every position-indexed input drops its last two positions.
# NOTE(review): presumably needed because RoBERTa offsets position ids by
# padding_idx + 1, so a full-length Pos sequence would index past the
# position-embedding table - confirm against roberta.py.
cut_end_for_bounds = True

Batch = hax.Axis("batch", 2)
Vocab = hax.Axis("vocab", my_config.vocab_size)

keys = jrandom.split(key_vars, 6)


def _cut_positions(arr):
    """Return `arr` without its last two positions when `cut_end_for_bounds` is set, else unchanged."""
    if cut_end_for_bounds:
        return arr[{"position": slice(0, -2)}]
    return arr


def _to_torch(arr):
    """Copy a haliax NamedArray's underlying array into a new torch tensor."""
    return torch.from_numpy(np.array(arr.array))


# Token ids drawn from [3, vocab_size); presumably ids 0-2 are avoided as special tokens.
input_ids = _cut_positions(hax.random.randint(keys[0], (Batch, Pos), minval=3, maxval=my_config.vocab_size))
input_ids_torch = _to_torch(input_ids)

input_embeds = _cut_positions(hax.random.normal(keys[1], (Batch, Pos, Embed)))
input_embeds_torch = _to_torch(input_embeds)

# 1 = attend, 0 = masked (see mask_materialized below). NOTE: the all-zeros mask
# masks every position; use one of the commented alternatives to exercise
# partial masking.
# mask = hax.random.randint(keys[2], (Batch, Pos), minval = 0, maxval = 2)
# mask = hax.ones((Batch, Pos))
mask = _cut_positions(hax.zeros((Batch, Pos)))
mask_torch = _to_torch(mask)

# Additive masks: 0 where attended, a very large negative value where masked.
mask_materialized = (mask == 0) * jnp.finfo(jnp.bfloat16).min
mask_torch_materialized = hf_roberta.RobertaModel.get_extended_attention_mask(
    self=hf_roberta.RobertaModel(hf_config),
    attention_mask=mask_torch,
    input_shape=input_embeds_torch.shape,
)

# A single position of input_embeds, used as LM-head input.
features = input_embeds[{"position": 0}]
features_torch = _to_torch(features)

# Same values as input_embeds, relabelled onto the attention-embedding axis.
# input_embeds is already position-cut above, so it must NOT be cut again here:
# cutting twice left x_embed_att two positions shorter than input_embeds, which
# it is paired with in test_RobertaSelfOutput.
x_embed_att = input_embeds.rename({"embed": "embed_att"})
x_embed_att_torch = _to_torch(x_embed_att)

x_mlp = _cut_positions(hax.random.normal(keys[5], (Batch, Pos, Mlp)))
x_mlp_torch = _to_torch(x_mlp)

In [None]:
def check(my_output, hf_output, precision=1e-4):
    """Compare an output of my implementation with the HF reference output.

    Args:
        my_output: array-like result from my implementation (e.g. a jax array,
            ``NamedArray.array`` or a torch tensor); anything ``np.array`` accepts.
        hf_output: reference result; a detached CPU torch tensor or array-like.
        precision: used as both ``rtol`` and ``atol`` for ``np.isclose``.

    Returns:
        ``(acc, stats, diffs, to_print)``: ``acc`` is the fraction of elements
        that are close, ``stats`` the pair of L2 norms ``(mine, hf)`` as torch
        tensors, ``diffs`` the mean absolute difference as a torch tensor, and
        ``to_print`` a one-line summary of the three.

    Raises:
        AssertionError: if the two shapes differ.
    """
    my_shape = tuple(my_output.shape)
    hf_shape = tuple(hf_output.shape)
    # Compare shapes as tuples. Comparing them as numpy arrays broadcasts, so
    # e.g. (2,) vs (2, 2) would wrongly pass.
    assert my_shape == hf_shape, f"shape mismatch: mine {my_shape} vs hf {hf_shape}"

    # Convert each input once and reuse it below.
    my_np = np.array(my_output)
    hf_np = np.array(hf_output)

    acc = np.isclose(hf_np, my_np, rtol=precision, atol=precision).mean()

    my_t = torch.tensor(my_np)
    hf_t = torch.tensor(hf_np)
    stats = (torch.linalg.norm(my_t), torch.linalg.norm(hf_t))
    diffs = (my_t - hf_t).abs().mean()

    to_print = f"acc: {acc} \t norms: {stats} \t diffs: {diffs}"

    return acc, stats, diffs, to_print

In [None]:
def check_dicts(my_dict, hf_dict):
    """Compare my state dict against an HF state dict, key by key.

    Prints both key sets, then two verdicts:
      * "success1" / "fail1": every key of ``my_dict`` is present in ``hf_dict``
        with the same shape and values that ``check`` deems all-close
        (each problem is printed as an ``ERROR`` line);
      * "success2" / "fail2": every key of ``hf_dict`` also appears in
        ``my_dict`` (the unmatched HF keys are printed on failure).

    Args:
        my_dict: mapping of parameter name -> array-like (jax, numpy or torch).
        hf_dict: mapping of parameter name -> torch tensor, e.g. ``module.state_dict()``.

    Returns:
        Total absolute element-wise difference over keys present in both dicts
        with matching shapes. It is absolute so that errors of opposite sign
        cannot cancel; 0 means every shared tensor matches exactly.
    """
    print(my_dict.keys())
    print(hf_dict.keys())

    unmatched_hf_keys = list(hf_dict.keys())

    flag = 0
    diff = 0

    for k, mine in my_dict.items():
        if k not in hf_dict:
            # Report and move on; looking the key up anyway would raise KeyError.
            print(f"ERROR \t {k}: key in my_dict but not hf_dict")
            flag += 1
            continue
        unmatched_hf_keys.remove(k)

        theirs = hf_dict[k].detach()
        mine_np = np.array(mine)
        theirs_np = np.array(theirs)
        if mine_np.shape != theirs_np.shape:
            print(f"ERROR \t {k}: shape {mine_np.shape} vs {theirs_np.shape}")
            flag += 1
            continue

        diff += np.abs(mine_np - theirs_np).sum()
        acc = check(mine_np, theirs)[0]
        if acc < 1:
            print(f"ERROR \t {k}: {acc}")
            flag += 1

    if flag == 0:
        print("success1")
    else:
        print("fail1")

    if len(unmatched_hf_keys) == 0:
        print("success2")
    else:
        print("fail2")
        print(unmatched_hf_keys)

    return diff

In [None]:
# Deliberate halt: "Run All" stops here, before the unit-test section.
# The explicit raise states the intent. A bare undefined name only halts by
# accident, and would silently stop halting if anything ever defined it.
raise RuntimeError("Intentional stop before the unit tests; run the cells below manually.")

NameError: name 'stop' is not defined

In [None]:
# Unit tests. Each test_* builds one of my modules from a fresh random init,
# mirrors its weights into the matching HF module, runs both on the shared
# inputs prepared above and returns `check(...)` on the outputs.


def _mirror_into_hf(my_func, hf_func, extra_state=None):
    """Load `my_func`'s weights into `hf_func` and return `hf_func`.

    The dict from `my_func.to_state_dict()` is converted to torch tensors,
    extended/overridden by `extra_state` when given, and loaded with
    `strict=True, assign=True`.
    """
    torch_state = {
        name: torch.from_numpy(np.array(value))
        for name, value in my_func.to_state_dict().items()
    }
    if extra_state is not None:
        torch_state.update(extra_state)
    hf_func.load_state_dict(torch_state, strict=True, assign=True)
    return hf_func


# Testing RobertaSelfOutput

def test_RobertaSelfOutput(key):
    """Compare RobertaSelfOutput on (x_embed_att, input_embeds)."""
    init_key, run_key = jrandom.split(key, 2)
    mine = my_roberta.RobertaSelfOutput.init(my_config, key=init_key)
    theirs = _mirror_into_hf(mine, hf_roberta.RobertaSelfOutput(hf_config))

    my_output = mine(x_embed_att, input_embeds, key=run_key)
    hf_output = theirs(x_embed_att_torch, input_embeds_torch)
    return check(my_output.array, hf_output.detach())


# Testing RobertaSelfAttention

def test_RobertaSelfAttention(key):
    """Compare RobertaSelfAttention on input_embeds with the additive masks."""
    init_key, run_key = jrandom.split(key, 2)
    mine = my_roberta.RobertaSelfAttention.init(my_config, key=init_key)
    theirs = _mirror_into_hf(mine, hf_roberta.RobertaSelfAttention(hf_config))

    my_output = mine(input_embeds, mask_materialized, key=run_key)
    hf_output = theirs(input_embeds_torch, mask_torch_materialized)
    # Only the first element of HF's result is compared.
    return check(my_output.array, hf_output[0].detach())


# Testing RobertaAttention

def test_RobertaAttention(key):
    """Compare RobertaAttention (self-attention + output block)."""
    init_key, run_key = jrandom.split(key, 2)
    mine = my_roberta.RobertaAttention.init(my_config, key=init_key)
    theirs = _mirror_into_hf(mine, hf_roberta.RobertaAttention(hf_config))

    my_output = mine(hidden_states=input_embeds, attention_mask=mask_materialized, key=run_key)
    hf_output = theirs(hidden_states=input_embeds_torch, attention_mask=mask_torch_materialized)
    return check(my_output.array, hf_output[0].detach())


# Testing RobertaIntermediate

def test_RobertaIntermediate(key):
    """Compare RobertaIntermediate on input_embeds."""
    init_key, run_key = jrandom.split(key, 2)
    mine = my_roberta.RobertaIntermediate.init(my_config, key=init_key)
    theirs = _mirror_into_hf(mine, hf_roberta.RobertaIntermediate(hf_config))

    my_output = mine(input_embeds, key=run_key)
    hf_output = theirs(input_embeds_torch)
    return check(my_output.array, hf_output.detach())


# Testing RobertaOutput

def test_RobertaOutput(key):
    """Compare RobertaOutput on (x_mlp, input_embeds)."""
    init_key, run_key = jrandom.split(key, 2)
    mine = my_roberta.RobertaOutput.init(my_config, key=init_key)
    theirs = _mirror_into_hf(mine, hf_roberta.RobertaOutput(hf_config))

    my_output = mine(x_mlp, input_embeds, key=run_key)
    hf_output = theirs(x_mlp_torch, input_embeds_torch)
    return check(my_output.array, hf_output.detach())


# Testing RobertaLayer

def test_RobertaLayer(key):
    """Compare one RobertaLayer; element 0 of each result is checked."""
    init_key, run_key = jrandom.split(key, 2)
    mine = my_roberta.RobertaLayer.init(my_config, key=init_key)
    theirs = _mirror_into_hf(mine, hf_roberta.RobertaLayer(hf_config))

    my_output = mine(hidden_states=input_embeds, attention_mask=mask_materialized, key=run_key)
    hf_output = theirs(hidden_states=input_embeds_torch, attention_mask=mask_torch_materialized)
    return check(my_output[0].array, hf_output[0].detach())


# Testing RobertaEncoder

def test_RobertaEncoder(key):
    """Compare the full RobertaEncoder stack; element 0 of each result is checked."""
    init_key, run_key = jrandom.split(key, 2)
    mine = my_roberta.RobertaEncoder.init(my_config, key=init_key)
    theirs = _mirror_into_hf(mine, hf_roberta.RobertaEncoder(hf_config))

    my_output = mine(hidden_states=input_embeds, attention_mask=mask_materialized, key=run_key)
    hf_output = theirs(hidden_states=input_embeds_torch, attention_mask=mask_torch_materialized)
    return check(my_output[0].array, hf_output[0].detach())


# Testing RobertaEmbedding

def test_RobertaEmbedding(key, ids=True):
    """Compare the embedding block, fed token ids (ids=True) or input_embeds."""
    init_key, run_key = jrandom.split(key, 2)
    mine = my_roberta.RobertaEmbedding.init(Vocab, my_config, key=init_key)
    theirs = _mirror_into_hf(mine, hf_roberta.RobertaEmbeddings(hf_config))

    if ids:
        my_output = mine.embed(input_ids=input_ids, key=run_key)
        hf_output = theirs(input_ids=input_ids_torch)
    else:
        my_output = mine.embed(input_embeds=input_embeds, key=run_key)
        hf_output = theirs(inputs_embeds=input_embeds_torch)
    return check(my_output.array, hf_output.detach())


# Testing RobertaPooler

def test_RobertaPooler(key):
    """Compare RobertaPooler on input_embeds."""
    init_key, run_key = jrandom.split(key, 2)
    mine = my_roberta.RobertaPooler.init(my_config, key=init_key)
    theirs = _mirror_into_hf(mine, hf_roberta.RobertaPooler(hf_config))

    my_output = mine(input_embeds, key=run_key)
    hf_output = theirs(input_embeds_torch)
    return check(my_output.array, hf_output.detach())


# Testing RobertaModel

def test_RobertaModel(key, ids=True, pool=True):
    """Compare RobertaModel end to end.

    Checks output element 1 (pooled output) when `pool`, otherwise element 0.
    """
    init_key, run_key = jrandom.split(key, 2)
    mine = my_roberta.RobertaModel.init(Vocab, my_config, add_pooling_layer=pool, key=init_key)
    theirs = _mirror_into_hf(mine, hf_roberta.RobertaModel(hf_config, add_pooling_layer=pool))

    if ids:
        my_output = mine(input_ids=input_ids, attention_mask=mask, key=run_key)
        hf_output = theirs(input_ids=input_ids_torch, attention_mask=mask_torch, return_dict=False)
    else:
        my_output = mine(input_embeds=input_embeds, attention_mask=mask, key=run_key)
        hf_output = theirs(inputs_embeds=input_embeds_torch, attention_mask=mask_torch, return_dict=False)

    which = 1 if pool else 0
    return check(my_output[which].array, hf_output[which].detach())


# Testing RobertaLMHead

def test_RobertaLMHead(key):
    """Compare RobertaLMHead on `features`; HF's top-level `bias` is pinned to zeros."""
    init_key, run_key = jrandom.split(key, 2)
    mine = my_roberta.RobertaLMHead.init(Vocab, my_config, key=init_key)
    theirs = _mirror_into_hf(
        mine,
        hf_roberta.RobertaLMHead(hf_config),
        extra_state={"bias": torch.zeros(hf_config.vocab_size)},
    )

    my_output = mine(features, key=run_key)
    hf_output = theirs(features_torch)
    return check(my_output.array, hf_output.detach())


# Testing RobertaForMaskedLM

def test_RobertaForMaskedLM(key, ids=True):
    """Compare RobertaForMaskedLM logits; HF's `lm_head.bias` is pinned to zeros."""
    init_key, run_key = jrandom.split(key, 2)
    mine = my_roberta.RobertaForMaskedLM.init(Vocab, my_config, key=init_key)
    theirs = _mirror_into_hf(
        mine,
        hf_roberta.RobertaForMaskedLM(hf_config),
        extra_state={"lm_head.bias": torch.zeros(hf_config.vocab_size)},
    )

    if ids:
        my_output = mine(input_ids=input_ids, attention_mask=mask, key=run_key)
        hf_output = theirs(input_ids=input_ids_torch, attention_mask=mask_torch, return_dict=False)
    else:
        my_output = mine(input_embeds=input_embeds, attention_mask=mask, key=run_key)
        hf_output = theirs(inputs_embeds=input_embeds_torch, attention_mask=mask_torch, return_dict=False)

    return check(my_output[0].array, hf_output[0].detach())

In [None]:
# Per-call initialisation keys for the unit tests in the next cell.
# (This rebinds `keys`, which previously held the input-data keys.)
keys = jrandom.split(key_funcs, 15)

In [None]:
# Run every unit test with its own key. The order must match the `types`
# labels defined in the next cell.
outs = []

outs.append(test_RobertaSelfOutput(keys[0]))
outs.append(test_RobertaSelfAttention(keys[1]))
outs.append(test_RobertaAttention(keys[2]))
outs.append(test_RobertaIntermediate(keys[3]))
outs.append(test_RobertaOutput(keys[4]))
# keys[5] and keys[6] keep Layer and Encoder independent of the RobertaOutput init above.
outs.append(test_RobertaLayer(keys[5]))
outs.append(test_RobertaEncoder(keys[6]))
outs.append(test_RobertaEmbedding(keys[7], ids=True))
outs.append(test_RobertaEmbedding(keys[8], ids=False))
outs.append(test_RobertaModel(keys[9], ids=True, pool=True))
outs.append(test_RobertaModel(keys[10], ids=False, pool=False))
# Each RobertaModel configuration is run a second time with a fresh key, so the
# repeat checks a second random init instead of recomputing the identical case.
outs.append(test_RobertaModel(jrandom.fold_in(keys[9], 1), ids=True, pool=True))
outs.append(test_RobertaModel(jrandom.fold_in(keys[10], 1), ids=False, pool=False))
outs.append(test_RobertaPooler(keys[11]))
outs.append(test_RobertaLMHead(keys[12]))
outs.append(test_RobertaForMaskedLM(keys[13], ids=True))
outs.append(test_RobertaForMaskedLM(keys[14], ids=False))

In [None]:
# Human-readable labels for `outs`, index-aligned with the test calls in the
# previous cell. Each RobertaModel configuration is listed twice because it is
# run twice there.
types = [
    "test_RobertaSelfOutput",
    "test_RobertaSelfAttention",
    "test_RobertaAttention",
    "test_RobertaIntermediate",
    "test_RobertaOutput",
    "test_RobertaLayer",
    "test_RobertaEncoder",
    "test_RobertaEmbedding(ids = True)",
    "test_RobertaEmbedding(ids = False)",
    "test_RobertaModel(ids = True, pool = True)",
    "test_RobertaModel(ids = False, pool = False)",
    "test_RobertaModel(ids = True, pool = True)",
    "test_RobertaModel(ids = False, pool = False)",
    "test_RobertaPooler",
    "test_RobertaLMHead",
    "test_RobertaForMaskedLM(ids = True)",
    "test_RobertaForMaskedLM(ids = False)"
]

In [None]:
# Report only problems: a non-finite (NaN/inf) mean difference, or a test whose
# outputs are not all close. Each result tuple is (acc, norms, mean_abs_diff, summary).
assert len(types) == len(outs), f"{len(types)} labels for {len(outs)} results"
for label, (acc, _norms, mean_abs_diff, summary) in zip(types, outs):
    if not np.isfinite(float(mean_abs_diff)):
        print(f"nan alert: {label}")
    if acc < 1:
        print(f"{label}: {summary}")

In [None]:
# Full report: one summary line per unit test, labelled from `types`.
for idx in range(len(outs)):
    label = types[idx]
    summary = outs[idx][3]
    print(f"{label}: {summary}")

In [None]:
# Keys for the standalone RobertaModel / LM head: initialisation (from key_funcs)
# and forward passes (from key_run).
key_model, key_lm = jrandom.split(key_funcs, 2)
key_model_run, key_lm_run = jrandom.split(key_run, 2)

In [None]:
# Initializing RobertaModel (with pooler, returning hidden states) and mirroring
# its weights into an HF RobertaModel.
my_model = my_roberta.RobertaModel.init(Vocab, my_config, add_pooling_layer=True, output_hidden_states=True, key=key_model)
state_model = my_model.to_state_dict()

state_model = {k: torch.from_numpy(np.array(v)) for k, v in state_model.items()}

hf_model = hf_roberta.RobertaModel(hf_config, add_pooling_layer=True)
# NOTE(review): unlike the unit tests this load omits assign=True, so PyTorch
# copies the values into HF's existing parameters.
hf_model.load_state_dict(state_model, strict=True)

# Sanity check: the weights round-trip and no HF key is left unmatched.
print(check_dicts(my_model.to_state_dict(), hf_model.state_dict()))

dict_keys(['encoder.layer.0.attention.self.query.weight', 'encoder.layer.0.attention.self.query.bias', 'encoder.layer.0.attention.self.key.weight', 'encoder.layer.0.attention.self.key.bias', 'encoder.layer.0.attention.self.value.weight', 'encoder.layer.0.attention.self.value.bias', 'encoder.layer.0.attention.output.dense.weight', 'encoder.layer.0.attention.output.dense.bias', 'encoder.layer.0.attention.output.LayerNorm.weight', 'encoder.layer.0.attention.output.LayerNorm.bias', 'encoder.layer.0.intermediate.dense.weight', 'encoder.layer.0.intermediate.dense.bias', 'encoder.layer.0.output.dense.weight', 'encoder.layer.0.output.dense.bias', 'encoder.layer.0.output.LayerNorm.weight', 'encoder.layer.0.output.LayerNorm.bias', 'encoder.layer.1.attention.self.query.weight', 'encoder.layer.1.attention.self.query.bias', 'encoder.layer.1.attention.self.key.weight', 'encoder.layer.1.attention.self.key.bias', 'encoder.layer.1.attention.self.value.weight', 'encoder.layer.1.attention.self.value.bias

In [None]:
# Initializing RobertaForMaskedLM and mirroring it into HF's RobertaForMaskedLM.
# NOTE(review): this reuses key_funcs, the key that was split into key_model/key_lm
# above. The later "my RobertaModel + LM head vs MLM" comparison reports zero
# difference, which suggests RobertaForMaskedLM.init splits its key the same way.
# Confirm this before changing either key.
my_mlm = my_roberta.RobertaForMaskedLM.init(Vocab, my_config, output_hidden_states=True, key=key_funcs)
state_mlm = my_mlm.to_state_dict()

state_mlm = {k: torch.from_numpy(np.array(v)) for k, v in state_mlm.items()}
# HF's MLM head also expects a top-level lm_head.bias; it is pinned to zeros.
state_mlm["lm_head.bias"] = torch.zeros(hf_config.vocab_size)
# print(state_mlm[w_str])  (NOTE: w_str is not defined in this notebook)

hf_mlm = hf_roberta.RobertaForMaskedLM(hf_config)
hf_mlm.load_state_dict(state_mlm, strict=True, assign=True)
# print(hf_mlm.state_dict()[w_str])

# NOTE(review): this compares the dict that was just loaded with assign=True, so
# it mainly checks key coverage. Comparing my_mlm.to_state_dict() would be stronger.
print(check_dicts(state_mlm, hf_mlm.state_dict()))

dict_keys(['roberta.encoder.layer.0.attention.self.query.weight', 'roberta.encoder.layer.0.attention.self.query.bias', 'roberta.encoder.layer.0.attention.self.key.weight', 'roberta.encoder.layer.0.attention.self.key.bias', 'roberta.encoder.layer.0.attention.self.value.weight', 'roberta.encoder.layer.0.attention.self.value.bias', 'roberta.encoder.layer.0.attention.output.dense.weight', 'roberta.encoder.layer.0.attention.output.dense.bias', 'roberta.encoder.layer.0.attention.output.LayerNorm.weight', 'roberta.encoder.layer.0.attention.output.LayerNorm.bias', 'roberta.encoder.layer.0.intermediate.dense.weight', 'roberta.encoder.layer.0.intermediate.dense.bias', 'roberta.encoder.layer.0.output.dense.weight', 'roberta.encoder.layer.0.output.dense.bias', 'roberta.encoder.layer.0.output.LayerNorm.weight', 'roberta.encoder.layer.0.output.LayerNorm.bias', 'roberta.encoder.layer.1.attention.self.query.weight', 'roberta.encoder.layer.1.attention.self.query.bias', 'roberta.encoder.layer.1.attentio

In [None]:
def test_RobertaModel_Output(key_run, ids=False):
    """Run the pre-built `my_model` / `hf_model` pair on the shared inputs.

    Feeds `input_ids` when `ids` is true, otherwise `input_embeds`; both sides
    use the attention mask `mask`. HF is asked for hidden states and a tuple
    result. Returns `(my_output, hf_output)` unmodified.
    """
    if ids:
        my_inputs = {"input_ids": input_ids}
        hf_inputs = {"input_ids": input_ids_torch}
    else:
        my_inputs = {"input_embeds": input_embeds}
        hf_inputs = {"inputs_embeds": input_embeds_torch}

    my_output = my_model(**my_inputs, attention_mask=mask, key=key_run)
    hf_output = hf_model(**hf_inputs, attention_mask=mask_torch, return_dict=False, output_hidden_states=True)
    return my_output, hf_output

In [None]:
# Run the ported RobertaModel pair on both input paths, with the same forward key.
my_output_ids, hf_output_ids = test_RobertaModel_Output(key_model_run, ids=True)
my_output_embeds, hf_output_embeds = test_RobertaModel_Output(key_model_run, ids=False)

In [None]:
# RobertaModel ids: last hidden state (0), pooled output (1), per-layer hidden states (2).
my_out = my_output_ids[0]
hf_out = hf_output_ids[0]
print("model_out: " + check(my_out.array, hf_out.detach())[3])

my_pool = my_output_ids[1]
hf_pool = hf_output_ids[1]
print("pool_out: " + check(my_pool.array, hf_pool.detach())[3])

print("intermediates:")
# HF's hidden-state tuple has one extra leading entry (index 0, presumably the
# embedding output); drop it so the remaining entries line up with mine.
my_ints = my_output_ids[2]
hf_ints = hf_output_ids[2][1:]
for my_layer, hf_layer in zip(my_ints, hf_ints):
    print(check(my_layer.array, hf_layer.detach())[3])

model_out: acc: 1.0 	 norms: (tensor(888.5331), tensor(888.5331)) 	 diffs: 6.946668236196274e-07
pool_out: acc: 1.0 	 norms: (tensor(24.6133), tensor(24.6133)) 	 diffs: 4.2906631847472454e-07
intermediates:
acc: 1.0 	 norms: (tensor(888.5315), tensor(888.5314)) 	 diffs: 1.6926360046909394e-07
acc: 1.0 	 norms: (tensor(888.5301), tensor(888.5301)) 	 diffs: 2.57225963196106e-07
acc: 1.0 	 norms: (tensor(888.5310), tensor(888.5311)) 	 diffs: 3.1373855335914413e-07
acc: 1.0 	 norms: (tensor(888.5312), tensor(888.5312)) 	 diffs: 3.6955742643840495e-07
acc: 1.0 	 norms: (tensor(888.5304), tensor(888.5304)) 	 diffs: 4.1312114262836985e-07
acc: 1.0 	 norms: (tensor(888.5308), tensor(888.5308)) 	 diffs: 4.612424220340472e-07
acc: 1.0 	 norms: (tensor(888.5308), tensor(888.5308)) 	 diffs: 5.164657750356128e-07
acc: 1.0 	 norms: (tensor(888.5295), tensor(888.5295)) 	 diffs: 5.412561563389318e-07
acc: 1.0 	 norms: (tensor(888.5309), tensor(888.5310)) 	 diffs: 5.724053266931151e-07
acc: 1.0 	 norms

In [None]:
# RobertaModel embeds: last hidden state (0), pooled output (1), per-layer hidden states (2).
my_out = my_output_embeds[0]
hf_out = hf_output_embeds[0]
print("model_out: " + check(my_out.array, hf_out.detach())[3])

my_pool = my_output_embeds[1]
hf_pool = hf_output_embeds[1]
print("pool_out: " + check(my_pool.array, hf_pool.detach())[3])

print("intermediates:")
# HF's hidden-state tuple has one extra leading entry (index 0, presumably the
# embedding output); drop it so the remaining entries line up with mine.
my_ints = my_output_embeds[2]
hf_ints = hf_output_embeds[2][1:]
for my_layer, hf_layer in zip(my_ints, hf_ints):
    print(check(my_layer.array, hf_layer.detach())[3])

model_out: acc: 1.0 	 norms: (tensor(888.5307), tensor(888.5306)) 	 diffs: 6.107513286224275e-07
pool_out: acc: 1.0 	 norms: (tensor(24.6094), tensor(24.6094)) 	 diffs: 3.8713346839358564e-07
intermediates:
acc: 1.0 	 norms: (tensor(888.5317), tensor(888.5317)) 	 diffs: 1.3876427829018212e-07
acc: 1.0 	 norms: (tensor(888.5305), tensor(888.5305)) 	 diffs: 2.239140144411067e-07
acc: 1.0 	 norms: (tensor(888.5290), tensor(888.5290)) 	 diffs: 2.9642876597790746e-07
acc: 1.0 	 norms: (tensor(888.5301), tensor(888.5300)) 	 diffs: 3.554245893155894e-07
acc: 1.0 	 norms: (tensor(888.5311), tensor(888.5311)) 	 diffs: 4.070468264671945e-07
acc: 1.0 	 norms: (tensor(888.5311), tensor(888.5311)) 	 diffs: 4.3890696588277933e-07
acc: 1.0 	 norms: (tensor(888.5305), tensor(888.5304)) 	 diffs: 4.87373824853421e-07
acc: 1.0 	 norms: (tensor(888.5308), tensor(888.5308)) 	 diffs: 5.209108735471091e-07
acc: 1.0 	 norms: (tensor(888.5298), tensor(888.5298)) 	 diffs: 5.463080583467672e-07
acc: 1.0 	 norms:

In [None]:
def test_RobertaForMaskedLM_Output(key_run, ids=False):
    """Run the pre-built `my_mlm` / `hf_mlm` pair on the shared inputs.

    Feeds `input_ids` when `ids` is true, otherwise `input_embeds`; both sides
    use the attention mask `mask`. HF is asked for hidden states and a tuple
    result. Returns `(my_output, hf_output)` unmodified.
    """
    if ids:
        my_inputs = {"input_ids": input_ids}
        hf_inputs = {"input_ids": input_ids_torch}
    else:
        my_inputs = {"input_embeds": input_embeds}
        hf_inputs = {"inputs_embeds": input_embeds_torch}

    my_output = my_mlm(**my_inputs, attention_mask=mask, key=key_run)
    hf_output = hf_mlm(**hf_inputs, attention_mask=mask_torch, return_dict=False, output_hidden_states=True)
    return my_output, hf_output

In [None]:
# MLM forward passes on both input paths.
# NOTE(review): this passes key_run itself, although key_run was already split
# into key_model_run/key_lm_run; confirm reusing the parent key is intended.
my_mlm_output_ids, hf_mlm_output_ids = test_RobertaForMaskedLM_Output(key_run, ids=True)
my_mlm_output_embeds, hf_mlm_output_embeds = test_RobertaForMaskedLM_Output(key_run, ids=False)

In [None]:
# Masked MLM ids: compare logits (element 0), then per-layer hidden states (element 1).
my_out, hf_out = my_mlm_output_ids[0], hf_mlm_output_ids[0]

print(f"mlm_out: {check(my_out.array, hf_out.detach())[3]}")

print("intermediates:")
# HF's hidden-state tuple has one extra leading entry (index 0); drop it so the
# remaining entries line up with mine.
my_ints, hf_ints = my_mlm_output_ids[1], hf_mlm_output_ids[1][1:]

# Use check's default tolerance, as every other comparison cell (including the
# embeds twin of this cell) does. A looser tolerance here could hide regressions.
for my_layer, hf_layer in zip(my_ints, hf_ints):
    print(check(my_layer.array, hf_layer.detach())[3])

mlm_out: acc: 1.0 	 norms: (tensor(7054.6812), tensor(7054.6816)) 	 diffs: 7.966510224832746e-07
intermediates:
acc: 1.0 	 norms: (tensor(888.5315), tensor(888.5314)) 	 diffs: 1.6926360046909394e-07
acc: 1.0 	 norms: (tensor(888.5301), tensor(888.5301)) 	 diffs: 2.57225963196106e-07
acc: 1.0 	 norms: (tensor(888.5310), tensor(888.5311)) 	 diffs: 3.1373855335914413e-07
acc: 1.0 	 norms: (tensor(888.5312), tensor(888.5312)) 	 diffs: 3.6955742643840495e-07
acc: 1.0 	 norms: (tensor(888.5304), tensor(888.5304)) 	 diffs: 4.1312114262836985e-07
acc: 1.0 	 norms: (tensor(888.5308), tensor(888.5308)) 	 diffs: 4.612424220340472e-07
acc: 1.0 	 norms: (tensor(888.5308), tensor(888.5308)) 	 diffs: 5.164657750356128e-07
acc: 1.0 	 norms: (tensor(888.5295), tensor(888.5295)) 	 diffs: 5.412561563389318e-07
acc: 1.0 	 norms: (tensor(888.5309), tensor(888.5310)) 	 diffs: 5.724053266931151e-07
acc: 1.0 	 norms: (tensor(888.5297), tensor(888.5297)) 	 diffs: 5.980200512567535e-07
acc: 1.0 	 norms: (tensor

In [None]:
# Masked MLM embeds: logits (element 0) first, then per-layer hidden states (element 1).
my_out = my_mlm_output_embeds[0]
hf_out = hf_mlm_output_embeds[0]
print("mlm_out: " + check(my_out.array, hf_out.detach())[3])

print("intermediates:")
# HF's hidden-state tuple has one extra leading entry (index 0); drop it to align with mine.
my_ints = my_mlm_output_embeds[1]
hf_ints = hf_mlm_output_embeds[1][1:]
for my_layer, hf_layer in zip(my_ints, hf_ints):
    print(check(my_layer.array, hf_layer.detach())[3])

mlm_out: acc: 1.0 	 norms: (tensor(7107.9902), tensor(7107.9902)) 	 diffs: 7.997662692105223e-07
intermediates:
acc: 1.0 	 norms: (tensor(888.5317), tensor(888.5317)) 	 diffs: 1.3876427829018212e-07
acc: 1.0 	 norms: (tensor(888.5305), tensor(888.5305)) 	 diffs: 2.239140144411067e-07
acc: 1.0 	 norms: (tensor(888.5290), tensor(888.5290)) 	 diffs: 2.9642876597790746e-07
acc: 1.0 	 norms: (tensor(888.5301), tensor(888.5300)) 	 diffs: 3.554245893155894e-07
acc: 1.0 	 norms: (tensor(888.5311), tensor(888.5311)) 	 diffs: 4.070468264671945e-07
acc: 1.0 	 norms: (tensor(888.5311), tensor(888.5311)) 	 diffs: 4.3890696588277933e-07
acc: 1.0 	 norms: (tensor(888.5305), tensor(888.5304)) 	 diffs: 4.87373824853421e-07
acc: 1.0 	 norms: (tensor(888.5308), tensor(888.5308)) 	 diffs: 5.209108735471091e-07
acc: 1.0 	 norms: (tensor(888.5298), tensor(888.5298)) 	 diffs: 5.463080583467672e-07
acc: 1.0 	 norms: (tensor(888.5311), tensor(888.5311)) 	 diffs: 5.576325179390551e-07
acc: 1.0 	 norms: (tensor(

In [None]:
# Earlier version of the "Initializing RobertaModel" cell, kept for reference only
# (fully commented out). NOTE: `key_rob` is not defined anywhere in this notebook.
# # Testing RobertaModel
# my_model = my_roberta.RobertaModel.init(Vocab, my_config, add_pooling_layer=False, key=key_rob)
# state_model = my_model.to_state_dict()

# state_model = {k: torch.from_numpy(np.array(v)) for k, v in state_model.items()}

# hf_model = hf_roberta.RobertaModel(hf_config, add_pooling_layer=False)
# hf_model.load_state_dict(state_model, strict=True)

In [None]:
# Testing RobertaLMHead: a standalone head initialised with key_lm and mirrored
# into HF's RobertaLMHead, used below to rebuild "RobertaModel + LM head".
my_head = my_roberta.RobertaLMHead.init(Vocab, my_config, key=key_lm)
state_head = my_head.to_state_dict()

state_head = {k: torch.from_numpy(np.array(v)) for k, v in state_head.items()}
# HF's head also expects a top-level `bias`; it is pinned to zeros.
state_head["bias"] = torch.zeros(hf_config.vocab_size)

hf_head = hf_roberta.RobertaLMHead(hf_config)
# NOTE(review): no assign=True here (the unit tests use it); values are copied
# into HF's existing parameters.
hf_head.load_state_dict(state_head, strict=True)

print(check_dicts(state_head, hf_head.state_dict()))

dict_keys(['dense.weight', 'dense.bias', 'layer_norm.weight', 'layer_norm.bias', 'decoder.weight', 'decoder.bias', 'bias'])
odict_keys(['bias', 'dense.weight', 'dense.bias', 'layer_norm.weight', 'layer_norm.bias', 'decoder.weight', 'decoder.bias'])
success1
success2
0.0


In [None]:
# embeds path, two routes per library: the MLM model directly, and RobertaModel
# followed by the standalone LM head. The comparison cell below checks both.
my_output_mlm, hf_output_mlm = test_RobertaForMaskedLM_Output(key_run, ids = False)

my_output_model, hf_output_model = test_RobertaModel_Output(key_model_run, ids = False)
my_output = my_head(my_output_model[0], key=key_lm_run)
hf_output = hf_head(hf_output_model[0])

In [None]:
# Earlier inline version of the previous cell, kept for reference only (fully
# commented out). NOTE: it used the root `key`; the live cell uses key_run and
# key_model_run / key_lm_run instead.
# k_rob, k_lm = jrandom.split(key, 2)

# # MLM
# my_output_mlm = my_mlm(input_embeds = input_embeds, attention_mask=mask, key=key)
# hf_output_mlm = hf_mlm(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)

# # Model + LM

# my_output_model = my_model(input_embeds = input_embeds, attention_mask=mask, key=k_rob)
# my_output = my_head(my_output_model[0], key=k_lm)

# hf_output_model = hf_model(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)
# hf_output = hf_head(hf_output_model[0])

In [None]:
# embeds: my vs HF for each route, then, within each library, whether
# RobertaModel + LM head agrees with the MLM model.
print(f"RobertaModel: {check(my_output_model[0].array, hf_output_model[0].detach())[3]}")
print(f"Roberta Model + LM head: {check(my_output.array, hf_output.detach())[3]}")
print(f"MLM: {check(my_output_mlm[0].array, hf_output_mlm[0].detach())[3]}")

print(f"my RobertaModel + LM head vs MLM: {check(my_output.array, my_output_mlm[0].array)[3]}")
print(f"hf RobertaModel + LM head vs MLM: {check(hf_output.detach(), hf_output_mlm[0].detach())[3]}")

RobertaModel: acc: 1.0 	 norms: (tensor(888.5307), tensor(888.5306)) 	 diffs: 6.107513286224275e-07
Roberta Model + LM head: acc: 1.0 	 norms: (tensor(7107.9902), tensor(7107.9902)) 	 diffs: 7.997662692105223e-07
MLM: acc: 1.0 	 norms: (tensor(7107.9902), tensor(7107.9902)) 	 diffs: 7.997662692105223e-07
my RobertaModel + LM head vs MLM: acc: 1.0 	 norms: (tensor(7107.9902), tensor(7107.9902)) 	 diffs: 0.0
hf RobertaModel + LM head vs MLM: acc: 1.0 	 norms: (tensor(7107.9902), tensor(7107.9902)) 	 diffs: 0.0


In [None]:
# ids path, two routes per library: the MLM model directly, and RobertaModel
# followed by the standalone LM head. The comparison cell below checks both.
my_output_mlm, hf_output_mlm = test_RobertaForMaskedLM_Output(key_run, ids = True)

my_output_model, hf_output_model = test_RobertaModel_Output(key_model_run, ids = True)
my_output = my_head(my_output_model[0], key=key_lm_run)
hf_output = hf_head(hf_output_model[0])

In [None]:
# Earlier inline version of the previous cell, kept for reference only (fully commented out).
# # MLM
# my_output_mlm = my_mlm(input_ids = input_ids, attention_mask=mask, key=key)
# hf_output_mlm = hf_mlm(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)

# # Model + LM
# my_output_model = my_model(input_ids = input_ids, attention_mask=mask, key=k_rob)
# my_output = my_head(my_output_model[0], key=k_lm)

# hf_output_model = hf_model(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)
# hf_output = hf_head(hf_output_model[0])

# # Notes
# # embeds works between hf and my, and within model->head and mlm
# # ids does not work between hf and my for mlm or within hf for model->head and mlm - so hf mlm is doing something weird.
# NOTE(review): these notes look stale. The ids comparison cell below now reports
# acc 1.0 between hf and my, and zero difference between the two routes.

In [None]:
# ids: my vs HF for each route, then, within each library, whether
# RobertaModel + LM head agrees with the MLM model.
print(f"RobertaModel: {check(my_output_model[0].array, hf_output_model[0].detach())[3]}")
print(f"Roberta Model + LM head: {check(my_output.array, hf_output.detach())[3]}")
print(f"MLM: {check(my_output_mlm[0].array, hf_output_mlm[0].detach())[3]}")

print(f"my RobertaModel + LM head vs MLM: {check(my_output.array, my_output_mlm[0].array)[3]}")
print(f"hf RobertaModel + LM head vs MLM: {check(hf_output.detach(), hf_output_mlm[0].detach())[3]}")

RobertaModel: acc: 1.0 	 norms: (tensor(888.5331), tensor(888.5331)) 	 diffs: 6.946668236196274e-07
Roberta Model + LM head: acc: 1.0 	 norms: (tensor(7054.6812), tensor(7054.6816)) 	 diffs: 7.966510224832746e-07
MLM: acc: 1.0 	 norms: (tensor(7054.6812), tensor(7054.6816)) 	 diffs: 7.966510224832746e-07
my RobertaModel + LM head vs MLM: acc: 1.0 	 norms: (tensor(7054.6812), tensor(7054.6812)) 	 diffs: 0.0
hf RobertaModel + LM head vs MLM: acc: 1.0 	 norms: (tensor(7054.6816), tensor(7054.6816)) 	 diffs: 0.0


In [None]:
# Deliberate halt: "Run All" stops here, before the pretrained-checkpoint section.
# The explicit raise states the intent. A bare undefined name only halts by
# accident, and would silently stop halting if anything ever defined it.
raise RuntimeError("Intentional stop before the pretrained-weights section; run the cells below manually.")

NameError: name 'stop' is not defined

In [None]:
# Load pretrained weights from HF and port them into my RobertaModel.
# Use hf_model_str, as the config cell and the MLM section do, instead of a
# second hard-coded checkpoint name, so every section loads the same checkpoint.
hf_model = hf_roberta.RobertaModel.from_pretrained(hf_model_str)
state_model = hf_model.state_dict()

state_model = {k: np.array(v) for k, v in state_model.items()}

# Rebuild my config from the checkpoint's config with dropout disabled.
# NOTE(review): the overrides happen after hf_model is built, so they probably do
# not change its dropout layers; confirm hf_model is in eval mode.
hf_config = hf_model.config
hf_config.hidden_dropout_prob = 0
hf_config.attention_probs_dropout_prob = 0
my_config = my_roberta.RobertaConfig.from_hf_config(hf_config)

my_model = my_roberta.RobertaModel.init(Vocab, my_config, output_hidden_states=True, key=key)
my_model = my_model.from_state_dict(state_model)

In [None]:
# Check weights loaded correctly: my exported state dict vs the pretrained HF one.
my_dict = my_model.to_state_dict()
hf_dict = hf_model.state_dict()

print(f"Total differences: {check_dicts(my_dict, hf_dict)}")

dict_keys(['encoder.layer.0.attention.self.query.weight', 'encoder.layer.0.attention.self.query.bias', 'encoder.layer.0.attention.self.key.weight', 'encoder.layer.0.attention.self.key.bias', 'encoder.layer.0.attention.self.value.weight', 'encoder.layer.0.attention.self.value.bias', 'encoder.layer.0.attention.output.dense.weight', 'encoder.layer.0.attention.output.dense.bias', 'encoder.layer.0.attention.output.LayerNorm.weight', 'encoder.layer.0.attention.output.LayerNorm.bias', 'encoder.layer.0.intermediate.dense.weight', 'encoder.layer.0.intermediate.dense.bias', 'encoder.layer.0.output.dense.weight', 'encoder.layer.0.output.dense.bias', 'encoder.layer.0.output.LayerNorm.weight', 'encoder.layer.0.output.LayerNorm.bias', 'encoder.layer.1.attention.self.query.weight', 'encoder.layer.1.attention.self.query.bias', 'encoder.layer.1.attention.self.key.weight', 'encoder.layer.1.attention.self.key.bias', 'encoder.layer.1.attention.self.value.weight', 'encoder.layer.1.attention.self.value.bias

In [None]:
# Pretrained RobertaModel, embeds path (HF returns a tuple because return_dict=False).
# NOTE(review): `key` is the root PRNG key, which was already split at the top of the notebook.
my_output_model = my_model(input_embeds = input_embeds, attention_mask=mask, key=key)
hf_output_model = hf_model(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)



In [None]:
# Compare final hidden states (element 0 of each output).
print(f"Model: {check(my_output_model[0].array, hf_output_model[0].detach())[3]}")

Model: acc: 1.0 	 norms: (tensor(443.3569), tensor(443.3564)) 	 diffs: 1.405485477334878e-06


In [None]:
# Pretrained RobertaModel, ids path.
my_output_model = my_model(input_ids = input_ids, attention_mask=mask, key=key)
hf_output_model = hf_model(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)



In [None]:
# Compare final hidden states (element 0 of each output).
print(f"Model: {check(my_output_model[0].array, hf_output_model[0].detach())[3]}")

Model: acc: 1.0 	 norms: (tensor(431.9545), tensor(431.9545)) 	 diffs: 1.3033360346526024e-06


In [None]:
# Load the pretrained RobertaForMaskedLM from HF and port its weights into my MLM.
hf_mlm = hf_roberta.RobertaForMaskedLM.from_pretrained(hf_model_str)
state_mlm = hf_mlm.state_dict()

state_mlm = {k: np.array(v) for k, v in state_mlm.items()}

# Rebuild my config from the checkpoint's config with dropout disabled.
# NOTE(review): the overrides happen after hf_mlm is built, so they probably do
# not change its dropout layers; confirm hf_mlm is in eval mode.
hf_config = hf_mlm.config
hf_config.hidden_dropout_prob = 0
hf_config.attention_probs_dropout_prob = 0
my_config = my_roberta.RobertaConfig.from_hf_config(hf_config)

my_mlm = my_roberta.RobertaForMaskedLM.init(Vocab, my_config, key=key)
my_mlm = my_mlm.from_state_dict(state_mlm)

In [None]:
# Check weights loaded correctly: my exported MLM state dict vs the pretrained HF one.
my_dict = my_mlm.to_state_dict()
hf_dict = hf_mlm.state_dict()

print(f"Total differences: {check_dicts(my_dict, hf_dict)}")

dict_keys(['roberta.encoder.layer.0.attention.self.query.weight', 'roberta.encoder.layer.0.attention.self.query.bias', 'roberta.encoder.layer.0.attention.self.key.weight', 'roberta.encoder.layer.0.attention.self.key.bias', 'roberta.encoder.layer.0.attention.self.value.weight', 'roberta.encoder.layer.0.attention.self.value.bias', 'roberta.encoder.layer.0.attention.output.dense.weight', 'roberta.encoder.layer.0.attention.output.dense.bias', 'roberta.encoder.layer.0.attention.output.LayerNorm.weight', 'roberta.encoder.layer.0.attention.output.LayerNorm.bias', 'roberta.encoder.layer.0.intermediate.dense.weight', 'roberta.encoder.layer.0.intermediate.dense.bias', 'roberta.encoder.layer.0.output.dense.weight', 'roberta.encoder.layer.0.output.dense.bias', 'roberta.encoder.layer.0.output.LayerNorm.weight', 'roberta.encoder.layer.0.output.LayerNorm.bias', 'roberta.encoder.layer.1.attention.self.query.weight', 'roberta.encoder.layer.1.attention.self.query.bias', 'roberta.encoder.layer.1.attentio

In [None]:
# Inspect HF's top-level LM-head bias; the next cell shows lm_head.decoder.bias for comparison.
hf_dict['lm_head.bias']

tensor([-0.0972, -0.0294,  0.4988,  ..., -0.0312, -0.0312, -1.0000])

In [None]:
# Inspect HF's decoder bias; compare with lm_head.bias shown in the previous cell.
hf_dict['lm_head.decoder.bias']

tensor([-0.0972, -0.0294,  0.4988,  ..., -0.0312, -0.0312, -1.0000])

In [None]:
# Pretrained MLM, embeds path (HF returns a tuple because return_dict=False).
my_output_mlm = my_mlm(input_embeds = input_embeds, attention_mask=mask, key=key)
hf_output_mlm = hf_mlm(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)



In [None]:
# Compare MLM logits (element 0 of each output).
print(f"MLM: {check(my_output_mlm[0].array, hf_output_mlm[0].detach())[3]}")

MLM: acc: 1.0 	 norms: (tensor(33433.4062), tensor(33433.3945)) 	 diffs: 1.7561978893354535e-05


In [None]:
# Pretrained MLM, ids path.
my_output_mlm = my_mlm(input_ids = input_ids, attention_mask=mask, key=key)
hf_output_mlm = hf_mlm(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)



In [None]:
# Compare MLM logits (element 0 of each output).
print(f"MLM: {check(my_output_mlm[0].array, hf_output_mlm[0].detach())[3]}")

MLM: acc: 1.0 	 norms: (tensor(28814.2480), tensor(28814.2168)) 	 diffs: 1.45375252031954e-05
