In [1]:
import roberta as my_roberta
from transformers.models.roberta import modeling_roberta as hf_roberta

import torch
import haliax as hax
import jax.random as jrandom
import jax.numpy as jnp
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm
2024-08-12 14:00:57,074	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [2]:
from transformers import AutoConfig
from time import time

hf_model_str = "FacebookAI/roberta-base"

hf_config = AutoConfig.from_pretrained(hf_model_str)
hf_config.hidden_dropout_prob = 0
hf_config.attention_probs_dropout_prob = 0
hf_config.pad_token_id = -1
my_config = my_roberta.RobertaConfig.from_hf_config(hf_config)

In [3]:
seed = time()
print(f"seed: {int(seed)}")
key_vars = jrandom.PRNGKey(int(seed))

EmbedAtt = my_config.EmbedAtt
Embed = my_config.Embed
Mlp = my_config.Mlp
Pos = my_config.Pos
KeyPos = my_config.KeyPos
Heads = my_config.Heads

Batch = hax.Axis("batch", 2)
Vocab = hax.Axis("vocab", my_config.vocab_size)

keys = jrandom.split(key_vars, 6)

input_ids = hax.random.randint(keys[0], (Batch, Pos), minval = 3, maxval = my_config.vocab_size)
input_ids_torch = torch.from_numpy(np.array(input_ids.array))
input_embeds = hax.random.normal(keys[1], (Batch, Pos, Embed))
input_embeds_torch = torch.from_numpy(np.array(input_embeds.array))

# mask = hax.random.randint(keys[2], (Batch, Pos), minval = 0, maxval = 2)
mask = hax.ones((Batch, Pos))
mask_torch = torch.from_numpy(np.array(mask.array))
mask_torch_materialized = torch.ones((2, hf_config.num_attention_heads, hf_config.max_position_embeddings, hf_config.max_position_embeddings))

features = input_embeds[{"position": 0}]
features_torch = torch.from_numpy(np.array(features.array))

x_embed_att = input_embeds.rename({"embed": "embed_att"})
x_embed_att_torch = torch.from_numpy(np.array(x_embed_att.array))
x_mlp = hax.random.normal(keys[5], (Batch, Pos, Mlp))
x_mlp_torch = torch.from_numpy(np.array(x_mlp.array))

seed: 1723496463


Notes:
- Random Mask causes RobertaModel to have different output

In [30]:
def check(my_output, hf_output, p=False, pp=False, ppp=False, pppp=True, precision=1e-4):
    
    assert (np.array(my_output.shape) == np.array(hf_output.shape)).all()
    # print(my_output.shape)
    # print(hf_output.shape)

    acc = np.isclose(hf_output, np.array(my_output), rtol=precision, atol=precision).mean()

    stats = (torch.tensor(np.array(my_output)).abs().mean(), torch.tensor(np.array(hf_output)).abs().mean())
    
    if p:   
        acc = np.isclose(hf_output, np.array(my_output), rtol=precision, atol=precision).mean()
        print(f"Accuracy: {acc}")
        print(f"Jax:\n{torch.tensor(np.array(my_output))}\nTorch:\n{hf_output}")

    if pp:
        diff = torch.tensor(np.array(my_output)) - hf_output
        print(f"Mean: {diff.abs().mean()}")
        print(f"Stdev: {diff.std()}")
        print(f"Difference:\n{diff}")

    if ppp:
        acc_prev = None
        for i in range(15):
            prec = 10 ** (-1*i)
            acc = np.isclose(hf_output, np.array(my_output), rtol=precision, atol=prec).mean()
            if acc_prev is None:
                print(f"Iteration {i}, Precision {prec}:\t{acc}")
            else:
                if np.abs(acc - acc_prev) > 1e-4:
                    print(f"Iteration {i}, Precision {prec}:\t{acc}")
            acc_prev = acc
        print(f"Iteration {i}, Precision {prec}:\t{acc}")
    
    return acc, stats

In [5]:
# Testing RobertaSelfOutput

def test_RobertaSelfOutput(key):
    k_1, k_2 = jrandom.split(key, 2)
    my_func = my_roberta.RobertaSelfOutput.init(my_config, key=k_1)
    state = my_func.to_state_dict()

    state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}

    # # print(state.keys())

    hf_func = hf_roberta.RobertaSelfOutput(hf_config)
    hf_func.load_state_dict(state, strict=True)

    my_output = my_func(x_embed_att, input_embeds, key=k_2)
    hf_output = hf_func(x_embed_att_torch, input_embeds_torch)

    return check(my_output.array, hf_output.detach())

# Testing RobertaSelfAttention

def test_RobertaSelfAttention(key):
    k_1, k_2 = jrandom.split(key, 2)

    my_func = my_roberta.RobertaSelfAttention.init(my_config, key=k_1)
    state = my_func.to_state_dict()

    state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}

    # print(state.keys())

    hf_func  = hf_roberta.RobertaSelfAttention(hf_config)
    hf_func.load_state_dict(state, strict=True)

    my_output = my_func(input_embeds, mask, key=k_2)
    hf_output = hf_func(input_embeds_torch, mask_torch_materialized)

    return check(my_output.array, hf_output[0].detach())

# Testing RobertaAttention

def test_RobertaAttention(key):
    k_1, k_2 = jrandom.split(key, 2)
    my_func = my_roberta.RobertaAttention.init(my_config, key=k_1)
    state = my_func.to_state_dict()

    state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}

    # print(state.keys())

    hf_func = hf_roberta.RobertaAttention(hf_config)
    hf_func.load_state_dict(state, strict=True)

    my_output = my_func(hidden_states=input_embeds, attention_mask=mask, key=k_2)
    hf_output = hf_func(hidden_states=input_embeds_torch, attention_mask=mask_torch_materialized)

    return check(my_output.array, hf_output[0].detach())

# Testing RobertaIntermediate

def test_RobertaIntermediate(key):
    k_1, k_2 = jrandom.split(key, 2)
    my_func = my_roberta.RobertaIntermediate.init(my_config, key=k_1)
    state = my_func.to_state_dict()

    state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}

    # print(state.keys())

    hf_func = hf_roberta.RobertaIntermediate(hf_config)
    hf_func.load_state_dict(state, strict=True)

    my_output = my_func(input_embeds, key=k_2)
    hf_output = hf_func(input_embeds_torch)

    return check(my_output.array, hf_output.detach())

# Testing RobertaOutput

def test_RobertaOutput(key):
    k_1, k_2 = jrandom.split(key, 2)

    my_func = my_roberta.RobertaOutput.init(my_config, key=k_1)
    state = my_func.to_state_dict()

    state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}

    # print(state.keys())

    hf_func = hf_roberta.RobertaOutput(hf_config)
    hf_func.load_state_dict(state, strict=True)

    my_output = my_func(x_mlp, input_embeds, key=k_2)
    hf_output = hf_func(x_mlp_torch, input_embeds_torch)

    return check(my_output.array, hf_output.detach())

# Testing RobertaLayer

def test_RobertaLayer(key):
    k_1, k_2 = jrandom.split(key, 2)
    my_func = my_roberta.RobertaLayer.init(my_config, key=k_1)
    state = my_func.to_state_dict()

    state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}

    # print(state.keys())

    hf_func = hf_roberta.RobertaLayer(hf_config)
    hf_func.load_state_dict(state, strict=True)

    my_output = my_func(hidden_states=input_embeds, attention_mask=mask, key=k_2)
    hf_output = hf_func(hidden_states=input_embeds_torch, attention_mask=mask_torch_materialized)

    return check(my_output.array, hf_output[0].detach())

# Testing RobertaEncoder

def test_RobertaEncoder(key):
    k_1, k_2 = jrandom.split(key, 2)
    my_func = my_roberta.RobertaEncoder.init(my_config, key=k_1)
    state = my_func.to_state_dict()

    state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}

    # print(state.keys())

    hf_func = hf_roberta.RobertaEncoder(hf_config)
    hf_func.load_state_dict(state, strict=True)

    my_output = my_func(hidden_states=input_embeds, attention_mask=mask, key=k_2)
    hf_output = hf_func(hidden_states=input_embeds_torch, attention_mask=mask_torch_materialized)

    return check(my_output.array, hf_output[0].detach())

# Testing RobertaEmbedding

def test_RobertaEmbedding(key, ids = True):
    k_1, k_2 = jrandom.split(key, 2)
    my_func = my_roberta.RobertaEmbedding.init(Vocab, my_config, key=k_1)
    state = my_func.to_state_dict()

    state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}

    # print(state.keys())

    hf_func = hf_roberta.RobertaEmbeddings(hf_config)
    hf_func.load_state_dict(state, strict=True)

    if ids:
        my_output = my_func.embed(input_ids=input_ids, key=k_2)
        hf_output = hf_func(input_ids=input_ids_torch)
    else:        
        my_output = my_func.embed(input_embeds=input_embeds, key=k_2)
        hf_output = hf_func(inputs_embeds=input_embeds_torch)

    return check(my_output.array, hf_output.detach())

# Testing RobertaPooler

def test_RobertaPooler(key):
    k_1, k_2 = jrandom.split(key, 2)
    my_func = my_roberta.RobertaPooler.init(my_config, key=k_1)
    state = my_func.to_state_dict()

    state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}

    # print(state.keys())

    hf_func = hf_roberta.RobertaPooler(hf_config)
    hf_func.load_state_dict(state, strict=True)

    my_output = my_func(input_embeds, key=k_2)
    hf_output = hf_func(input_embeds_torch)

    return check(my_output.array, hf_output.detach())

# Testing RobertaModel

def test_RobertaModel(key, ids = True, pool = True):
    k_1, k_2 = jrandom.split(key, 2)
    my_func = my_roberta.RobertaModel.init(Vocab, my_config, add_pooling_layer=pool, key=k_1)
    state = my_func.to_state_dict()

    state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}

    # print(state.keys())

    hf_func = hf_roberta.RobertaModel(hf_config, add_pooling_layer=pool)
    hf_func.load_state_dict(state, strict=True)

    if ids:
        my_output = my_func(input_ids = input_ids, attention_mask=mask, key=k_2)
        hf_output = hf_func(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)
    else:
        my_output = my_func(input_embeds = input_embeds, attention_mask=mask, key=k_2)
        hf_output = hf_func(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)

    if pool:
        return check(my_output[1].array, hf_output[1].detach())
    else:
        return check(my_output[0].array, hf_output[0].detach())

# Testing RobertaLMHead

def test_RobertaLMHead(key):
    k_1, k_2 = jrandom.split(key, 2)
    my_func = my_roberta.RobertaLMHead.init(Vocab, my_config, key=k_1)
    state = my_func.to_state_dict()

    state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}

    state["bias"] = torch.zeros(hf_config.vocab_size)

    # print(state.keys())

    hf_func = hf_roberta.RobertaLMHead(hf_config)
    hf_func.load_state_dict(state, strict=True)

    my_output = my_func(features, key=k_2)
    hf_output = hf_func(features_torch)

    return check(my_output.array, hf_output.detach())

# Testing RobertaForMaskedLM

def test_RobertaForMaskedLM(key, ids = True):
    k_1, k_2 = jrandom.split(key, 2)
    my_pool = my_roberta.RobertaForMaskedLM.init(Vocab, my_config, key=k_1)
    state = my_pool.to_state_dict()

    state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}

    state["lm_head.bias"] = torch.zeros(hf_config.vocab_size)

    # print(state.keys())

    hf_pool = hf_roberta.RobertaForMaskedLM(hf_config)
    hf_pool.load_state_dict(state, strict=True)

    if ids:
        my_output = my_pool(input_ids = input_ids, attention_mask=mask, key=k_2)
        hf_output = hf_pool(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)
    else:
        my_output = my_pool(input_embeds = input_embeds, attention_mask=mask, key=k_2)
        hf_output = hf_pool(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)

    return check(my_output.array, hf_output[0].detach())

In [6]:
def out_func(input):
    acc, stats = input
    if acc < 1:
        return str(acc) + "\t<---- here"
    else:
        return str(acc)

In [7]:
# seed = time() + 20
# print(f"seed: {int(seed)}")
# key_vars = jrandom.PRNGKey(int(seed))
# keys = jrandom.split(key_vars, 15)

# print(f"test_RobertaSelfOutput: {out_func(test_RobertaSelfOutput(keys[0]))}")
# print(f"test_RobertaSelfAttention: {out_func(test_RobertaSelfAttention(keys[1]))}")
# print(f"test_RobertaAttention: {out_func(test_RobertaAttention(keys[2]))}")
# print(f"test_RobertaIntermediate: {out_func(test_RobertaIntermediate(keys[3]))}")
# print(f"test_RobertaOutput: {out_func(test_RobertaOutput(keys[4]))}")
# print(f"test_RobertaEmbedding(ids = True): {out_func(test_RobertaEmbedding(keys[7], ids = True))}")
# print(f"test_RobertaEmbedding(ids = False): {out_func(test_RobertaEmbedding(keys[8], ids = False))}")
# print(f"test_RobertaModel(ids = True, pool = True): {out_func(test_RobertaModel(keys[9], ids = True, pool = True))}")
# print(f"test_RobertaModel(ids = False, pool = False): {out_func(test_RobertaModel(keys[10], ids = False, pool = False))}")
# print(f"test_RobertaModel(ids = True, pool = True): {out_func(test_RobertaModel(keys[9], ids = True, pool = True))}")
# print(f"test_RobertaModel(ids = False, pool = False): {out_func(test_RobertaModel(keys[10], ids = False, pool = False))}")
# print(f"test_RobertaPooler: {out_func(test_RobertaPooler(keys[11]))}")
# print(f"test_RobertaLMHead: {out_func(test_RobertaLMHead(keys[12]))}")
# print(f"test_RobertaForMaskedLM(ids = True): {out_func(test_RobertaForMaskedLM(keys[13], ids = True))}")
# print(f"test_RobertaForMaskedLM(ids = False): {out_func(test_RobertaForMaskedLM(keys[14], ids = False))}")

In [31]:
def get_output_RobertaEmbedding(input, key, ids = True):
    k_1, k_2 = jrandom.split(key, 2)
    my_func = my_roberta.RobertaEmbedding.init(Vocab, my_config, key=k_1)
    state = my_func.to_state_dict()

    state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}

    # print(state.keys())

    hf_func = hf_roberta.RobertaEmbeddings(hf_config)
    hf_func.load_state_dict(state, strict=True)

    input_torch = torch.from_numpy(np.array(input.array))

    if ids:
        my_output = my_func.embed(input_ids=input, key=k_2)
        hf_output = hf_func(input_ids=input_torch)
    else:        
        my_output = my_func.embed(input_embeds=input, key=k_2)
        hf_output = hf_func(inputs_embeds=input_torch)

    return check(my_output.array, hf_output.detach()), (my_output, hf_output)

def get_output_RobertaEncoder(input, key):
    k_1, k_2 = jrandom.split(key, 2)
    my_func = my_roberta.RobertaEncoder.init(my_config, key=k_1)
    state = my_func.to_state_dict()

    state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}

    # print(state.keys())

    hf_func = hf_roberta.RobertaEncoder(hf_config)
    hf_func.load_state_dict(state, strict=True)

    input_torch = torch.from_numpy(np.array(input.array))

    attention_mask = hax.ones((Batch, Heads, Pos, KeyPos)) * -jnp.inf
    attention_mask_torch = torch.from_numpy(np.array(attention_mask.array))

    my_output = my_func(hidden_states=input, attention_mask=attention_mask, key=k_2)
    hf_output = hf_func(hidden_states=input_torch, attention_mask=attention_mask_torch)

    return check(my_output.array, hf_output[0].detach()), (my_output, hf_output)

def get_output_RobertaPooler(input, key):
    k_1, k_2 = jrandom.split(key, 2)
    my_func = my_roberta.RobertaPooler.init(my_config, key=k_1)
    state = my_func.to_state_dict()

    state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}

    # print(state.keys())

    hf_func = hf_roberta.RobertaPooler(hf_config)
    hf_func.load_state_dict(state, strict=True)

    input_torch = torch.from_numpy(np.array(input.array))

    my_output = my_func(input, key=k_2)
    hf_output = hf_func(input_torch)

    return check(my_output.array, hf_output.detach()), (my_output, hf_output)

# Testing RobertaModel

def get_output_RobertaModel(input, key, ids = True, pool = True):
    k_1, k_2 = jrandom.split(key, 2)
    my_func = my_roberta.RobertaModel.init(Vocab, my_config, add_pooling_layer=pool, key=k_1)
    state = my_func.to_state_dict()

    state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}

    # print(state.keys())

    hf_func = hf_roberta.RobertaModel(hf_config, add_pooling_layer=pool)
    hf_func.load_state_dict(state, strict=True)
    
    input_torch = torch.from_numpy(np.array(input.array))
    
    # attention_mask = hax.ones((Batch, Heads, Pos, KeyPos)) * -jnp.inf
    # attention_mask_torch = torch.from_numpy(np.array(attention_mask.array))

    if ids:
        my_output = my_func(input_ids = input, attention_mask=mask, key=k_2)
        hf_output = hf_func(input_ids = input_torch, attention_mask=mask_torch, return_dict=False)
    else:
        my_output = my_func(input_embeds = input, attention_mask=mask, key=k_2)
        hf_output = hf_func(inputs_embeds = input_torch, attention_mask=mask_torch, return_dict=False)

    if pool:
        return check(my_output[1].array, hf_output[1].detach()), (my_output, hf_output)
    else:
        return check(my_output[0].array, hf_output[0].detach()), (my_output, hf_output)


In [33]:
seed = time() + 30
print(f"seed: {int(seed)}")
key = jrandom.PRNGKey(int(seed))

k_t, k_emb, k_p = jrandom.split(key, 3)

input = input_embeds

(acc_embeds, stats_embed), (my_out_embeds, hf_out_embeds) = get_output_RobertaEmbedding(input, k_t, ids = False)
print(stats_embed)
(acc_enc, stats_enc), (my_out_enc, hf_out_enc) = get_output_RobertaEncoder(my_out_embeds, k_emb)
print(stats_enc)
(acc_pool, stats_pool), (my_out_pool, hf_out_pool) = get_output_RobertaPooler(my_out_enc, k_p)
print(stats_pool)

(acc_model, stats_model), (my_out_model, hf_out_model) = get_output_RobertaModel(input, key, ids = False, pool = True)
print(stats_model)

print(f"acc_embeds: {acc_embeds}")
print(f"acc_enc: {acc_enc}")
print(f"acc_pool: {acc_pool}")
print(f"acc_model: {acc_model}")
print(f"my comparison pool: {check(my_out_pool.array, my_out_model[1].array)}")
print(f"my comparison no pool: {check(my_out_enc.array, my_out_model[0].array)}")
print(f"hf comparison: {check(hf_out_pool.detach(), hf_out_model[1].detach())}")

seed: 1723505034
{'batch': 2, 'position': 514, 'embed': 768}
(tensor(0.7984), tensor(0.7984))
(tensor(nan), tensor(nan))
(tensor(nan), tensor(nan))
(tensor(0.5408), tensor(0.5408))
acc_embeds: 1.0
acc_enc: 0.0
acc_pool: 0.0
acc_model: 1.0
my comparison pool: (0.0, (tensor(nan), tensor(0.5408)))
my comparison no pool: (0.0, (tensor(nan), tensor(0.7873)))
hf comparison: (0.0, (tensor(nan), tensor(0.5408)))


In [17]:
print(my_out_pool)
print(my_out_enc)
print(my_out_embeds)
print(my_out_model)

NamedArray(float32{'batch': 2, 'embed': 768},
[[nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]])
NamedArray(float32{'batch': 2, 'position': 514, 'embed': 768},
[[[nan nan nan ... nan nan nan]
  [nan nan nan ... nan nan nan]
  [nan nan nan ... nan nan nan]
  ...
  [nan nan nan ... nan nan nan]
  [nan nan nan ... nan nan nan]
  [nan nan nan ... nan nan nan]]

 [[nan nan nan ... nan nan nan]
  [nan nan nan ... nan nan nan]
  [nan nan nan ... nan nan nan]
  ...
  [nan nan nan ... nan nan nan]
  [nan nan nan ... nan nan nan]
  [nan nan nan ... nan nan nan]]])
NamedArray(float32{'batch': 2, 'position': 514, 'embed': 768},
[[[ 0.47864354  1.2938721   1.0534003  ... -1.4044254   0.829634
    0.00428176]
  [-0.98862374 -0.943986   -1.0448135  ... -0.64666593  0.12967904
   -1.0975188 ]
  [ 0.43071395 -0.60738516 -1.7641208  ... -1.1334671   0.9041689
    0.9875958 ]
  ...
  [ 1.287216    0.507795    0.23451686 ... -0.9582702  -0.3576718
    0.6565546 ]
  [ 0.33264828 -0.68922603  0

In [10]:
'''# Testing RobertaModel

my_model = my_roberta.RobertaModel.init(Vocab, my_config, add_pooling_layer=False, key=key)
state = my_model.to_state_dict()

state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}

# print(state.keys())

hf_model = hf_roberta.RobertaModel(hf_config, add_pooling_layer=False)
hf_model.load_state_dict(state, strict=True)

my_output = my_model(input_ids = input_ids, attention_mask=mask, key=key)
hf_output = hf_model(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)

check(my_output[0].array, hf_output[0].detach(), ppp=True)

# Testing RobertaLMHead

my_head = my_roberta.RobertaLMHead.init(Vocab, my_config, key=key)
state = my_head.to_state_dict()

state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}

state["bias"] = torch.zeros(hf_config.vocab_size)

# print(state.keys())

hf_head = hf_roberta.RobertaLMHead(hf_config)
hf_head.load_state_dict(state, strict=True)

my_output = my_head(my_output[0], key=key)
hf_output = hf_head(hf_output[0])

check(my_output.array, hf_output.detach(), ppp=True)'''

'# Testing RobertaModel\n\nmy_model = my_roberta.RobertaModel.init(Vocab, my_config, add_pooling_layer=False, key=key)\nstate = my_model.to_state_dict()\n\nstate = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n\n# print(state.keys())\n\nhf_model = hf_roberta.RobertaModel(hf_config, add_pooling_layer=False)\nhf_model.load_state_dict(state, strict=True)\n\nmy_output = my_model(input_ids = input_ids, attention_mask=mask, key=key)\nhf_output = hf_model(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\n\ncheck(my_output[0].array, hf_output[0].detach(), ppp=True)\n\n# Testing RobertaLMHead\n\nmy_head = my_roberta.RobertaLMHead.init(Vocab, my_config, key=key)\nstate = my_head.to_state_dict()\n\nstate = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n\nstate["bias"] = torch.zeros(hf_config.vocab_size)\n\n# print(state.keys())\n\nhf_head = hf_roberta.RobertaLMHead(hf_config)\nhf_head.load_state_dict(state, strict=True)\n\nmy_output = my_he

In [11]:
'''# Testing RobertaForMaskedLM

my_mlm = my_roberta.RobertaForMaskedLM.init(Vocab, my_config, key=key)
state_mlm = my_mlm.to_state_dict()

state_mlm = {k: torch.from_numpy(np.array(v)) for k, v in state_mlm.items()}

# if "lm_head.decoder.bias" in state:
#     print(state["lm_head.decoder.bias"])
# else:
#     print(f"RobertaForMaskedLM, {state.keys()}")

state_mlm["lm_head.bias"] = torch.zeros(hf_config.vocab_size)

print(state_mlm.keys())

hf_mlm = hf_roberta.RobertaForMaskedLM(hf_config)
hf_mlm.load_state_dict(state_mlm, strict=True)

# Testing RobertaModel

key_rob, key_head = jrandom.split(key, 2)

my_model = my_roberta.RobertaModel.init(Vocab, my_config, add_pooling_layer=False, key=key_rob)
state_model = my_model.to_state_dict()

state_model = {k: torch.from_numpy(np.array(v)) for k, v in state_model.items()}

print(state_model.keys())

hf_model = hf_roberta.RobertaModel(hf_config, add_pooling_layer=False)
hf_model.load_state_dict(state_model, strict=True)

# Testing RobertaLMHead

my_head = my_roberta.RobertaLMHead.init(Vocab, my_config, key=key_head)
state_head = my_head.to_state_dict()

state_head = {k: torch.from_numpy(np.array(v)) for k, v in state_head.items()}

state_head["bias"] = torch.zeros(hf_config.vocab_size)

print(state_head.keys())

hf_head = hf_roberta.RobertaLMHead(hf_config)
hf_head.load_state_dict(state_head, strict=True)'''

'# Testing RobertaForMaskedLM\n\nmy_mlm = my_roberta.RobertaForMaskedLM.init(Vocab, my_config, key=key)\nstate_mlm = my_mlm.to_state_dict()\n\nstate_mlm = {k: torch.from_numpy(np.array(v)) for k, v in state_mlm.items()}\n\n# if "lm_head.decoder.bias" in state:\n#     print(state["lm_head.decoder.bias"])\n# else:\n#     print(f"RobertaForMaskedLM, {state.keys()}")\n\nstate_mlm["lm_head.bias"] = torch.zeros(hf_config.vocab_size)\n\nprint(state_mlm.keys())\n\nhf_mlm = hf_roberta.RobertaForMaskedLM(hf_config)\nhf_mlm.load_state_dict(state_mlm, strict=True)\n\n# Testing RobertaModel\n\nkey_rob, key_head = jrandom.split(key, 2)\n\nmy_model = my_roberta.RobertaModel.init(Vocab, my_config, add_pooling_layer=False, key=key_rob)\nstate_model = my_model.to_state_dict()\n\nstate_model = {k: torch.from_numpy(np.array(v)) for k, v in state_model.items()}\n\nprint(state_model.keys())\n\nhf_model = hf_roberta.RobertaModel(hf_config, add_pooling_layer=False)\nhf_model.load_state_dict(state_model, stric

In [12]:
'''k_rob, k_lm = jrandom.split(key, 2)

# MLM

my_output_mlm = my_mlm(input_ids = input_ids, attention_mask=mask, key=key)
hf_output_mlm = hf_mlm(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)

# Model + LM

my_output_model = my_model(input_ids = input_ids, attention_mask=mask, key=k_rob)
my_output = my_head(my_output_model[0], key=k_lm)

hf_output_model = hf_model(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)
hf_output = hf_head(hf_output_model[0])

# # MLM
# my_output_mlm = my_mlm(input_embeds = input_embeds, attention_mask=mask, key=key)
# hf_output_mlm = hf_mlm(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)

# # Model + LM
# my_output_model = my_model(input_embeds = input_embeds, attention_mask=mask, key=k_rob)
# my_output = my_head(my_output_model[0], key=k_lm)

# hf_output_model = hf_model(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)
# hf_output = hf_head(hf_output_model[0])

# Checks

print("\nChecking RobertaModel")
check(my_output_model[0].array, hf_output_model[0].detach(), pppp=True)
print("\nChecking Roberta Model + LM head")
check(my_output.array, hf_output.detach(), pppp=True)
print("\nChecking MLM")
check(my_output_mlm.array, hf_output_mlm[0].detach(), pppp=True)

print("\nChecking my RobertaModel + LM head and MLM")
check(my_output.array, my_output_mlm.array, pppp=True)
print("\nChecking hf RobertaModel + LM head and MLM")
check(hf_output.detach(), hf_output_mlm[0].detach(), pppp=True)


# Notes
# embeds works between hf and my, and within model->head and mlm 
# ids does not work between hf and my for mlm or within hf for model->head and mlm - so hf mlm is doing something weird.'''

'k_rob, k_lm = jrandom.split(key, 2)\n\n# MLM\n\nmy_output_mlm = my_mlm(input_ids = input_ids, attention_mask=mask, key=key)\nhf_output_mlm = hf_mlm(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\n\n# Model + LM\n\nmy_output_model = my_model(input_ids = input_ids, attention_mask=mask, key=k_rob)\nmy_output = my_head(my_output_model[0], key=k_lm)\n\nhf_output_model = hf_model(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\nhf_output = hf_head(hf_output_model[0])\n\n# # MLM\n# my_output_mlm = my_mlm(input_embeds = input_embeds, attention_mask=mask, key=key)\n# hf_output_mlm = hf_mlm(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)\n\n# # Model + LM\n# my_output_model = my_model(input_embeds = input_embeds, attention_mask=mask, key=k_rob)\n# my_output = my_head(my_output_model[0], key=k_lm)\n\n# hf_output_model = hf_model(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False