In [None]:
import torch
from transformers import LlamaForCausalLM, LlamaConfig
from transformers.modeling_outputs import CausalLMOutputWithPast

class DummyLlamaForCausalLM(LlamaForCausalLM):
    """Random baseline with the same call interface as `LlamaForCausalLM`.

    The constructor is inherited unchanged, so the instance is built from the
    same config as the real model. `forward` never runs the network: it returns
    freshly sampled random logits of the right shape.
    """

    def forward(self, *args, **kwargs):
        """Return random logits shaped like a causal-LM output for `input_ids`.

        Args:
            *args: If `input_ids` is not given as a keyword, the first
                positional argument is used as `input_ids`.
            **kwargs: `input_ids` is read from here; every other keyword
                (e.g. `attention_mask`) is accepted and ignored.

        Returns:
            CausalLMOutputWithPast with `loss=None` and `logits` of shape
            `(batch_size, seq_len, config.vocab_size)`, dtype `float64`,
            sampled uniformly from [0, 1).

        Raises:
            ValueError: If no `input_ids` were supplied.
        """
        input_ids = kwargs.get('input_ids')
        if input_ids is None and args:
            # Allow `model(ids)` as well as `model(input_ids=ids)`.
            input_ids = args[0]
        if input_ids is None:
            raise ValueError("Input IDs trebuie furnizate pentru predicții.")

        # `input_ids` must be 2-D: it is unpacked into (batch, sequence length).
        batch_size, seq_len = input_ids.size()

        # Create the logits on the input's device so downstream ops that mix
        # them with `input_ids` (e.g. `gather`) do not hit a device mismatch.
        random_logits = torch.rand(
            (batch_size, seq_len, self.config.vocab_size),
            dtype=torch.float64,
            device=input_ids.device,
        )

        return CausalLMOutputWithPast(loss=None, logits=random_logits)

# Llama configuration shared by the dummy baseline and the real model.
config = LlamaConfig(
    # Model identity / metadata fields.
    architectures=["LlamaForCausalLM"],
    model_type="llama",
    torch_dtype="float32",
    transformers_version="4.40.1",
    # Vocabulary and special-token ids.
    vocab_size=16000,
    bos_token_id=1,
    eos_token_id=2,
    pad_token_id=0,
    tie_word_embeddings=False,
    # Transformer stack.
    hidden_size=512,
    intermediate_size=1024,
    num_hidden_layers=16,
    hidden_act="silu",
    rms_norm_eps=1e-06,
    initializer_range=0.02,
    pretraining_tp=1,
    # Attention.
    num_attention_heads=8,
    num_key_value_heads=8,
    attention_bias=False,
    attention_dropout=0.0,
    # Positional encoding.
    max_position_embeddings=256,
    rope_theta=10000.0,
    rope_scaling=None,
    # Generation / caching.
    use_cache=True,
)

# Fix the global torch RNG so the randomly initialised weights and the dummy
# model's random logits are the same on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Both models are only used for inference, so switch them to eval mode
# (`eval()` returns the module itself).
# NOTE(review): `model` is built from the config alone, i.e. with randomly
# initialised weights. If pretrained weights were intended, load them with
# `from_pretrained` instead; confirm with the author.
dummy_model = DummyLlamaForCausalLM(config).eval()
model = LlamaForCausalLM(config).eval()

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("babylm/babyllama-10m-2024")

# Pair of sentences whose log-probabilities are compared further down.
sentence_good = "These fathers of Linda aren't escaped from by Laurie."
sentence_bad = "These fathers of Linda aren't conferred by Laurie."

# Encode each sentence as a batch of one. The scoring cell unpacks these with
# `model(**inputs_...)` and reads `inputs_....input_ids`, so they must be
# defined before it runs. `token_type_ids` are switched off so that only
# `input_ids` / `attention_mask` are forwarded to the model.
inputs_good = tokenizer(sentence_good, return_tensors="pt", return_token_type_ids=False)
inputs_bad = tokenizer(sentence_bad, return_tensors="pt", return_token_type_ids=False)

In [None]:
import torch.nn.functional as F

def _score_targets(outputs, targets):
    """Score `targets` under the next-token distributions in `outputs.logits`.

    The logits at position t are used as the prediction for the token at
    position t + 1, so the final position (nothing left to predict) is dropped.

    Args:
        outputs: Model output exposing a 3-D `.logits` tensor.
        targets: Token ids to score; the calls below pass `input_ids[:, 1:]`.

    Returns:
        Tuple `(shifted_logits, log_probs, target_log_probs)`: the logits
        without the final position, their log-softmax over the last axis, and
        the log-probability selected for each target token.
    """
    shifted_logits = outputs.logits[:, :-1, :]
    log_probs = F.log_softmax(shifted_logits, dim=-1)
    target_log_probs = log_probs.gather(2, targets.unsqueeze(-1)).squeeze(-1)
    return shifted_logits, log_probs, target_log_probs


# Forward passes only; no gradients are needed for scoring.
with torch.no_grad():
    dummy_outputs_good = dummy_model(**inputs_good)
    dummy_outputs_bad = dummy_model(**inputs_bad)
    outputs_good = model(**inputs_good)
    outputs_bad = model(**inputs_bad)

# Each token is predicted from the tokens before it, so the first token of a
# sentence is never a target.
target_good = inputs_good.input_ids[:, 1:]
target_bad = inputs_bad.input_ids[:, 1:]

# Dummy baseline.
dummy_logits_good, dummy_log_probs_good, dummy_sentence_log_prob_good = _score_targets(
    dummy_outputs_good, target_good
)
dummy_logits_bad, dummy_log_probs_bad, dummy_sentence_log_prob_bad = _score_targets(
    dummy_outputs_bad, target_bad
)

# Llama model.
logits_good, log_probs_good, sentence_log_prob_good = _score_targets(outputs_good, target_good)
logits_bad, log_probs_bad, sentence_log_prob_bad = _score_targets(outputs_bad, target_bad)

# Sentence score = sum of per-token log-probabilities (`.item()` needs a batch of one).
dummy_total_log_prob_good = dummy_sentence_log_prob_good.sum(dim=-1).item()
dummy_total_log_prob_bad = dummy_sentence_log_prob_bad.sum(dim=-1).item()
total_log_prob_good = sentence_log_prob_good.sum(dim=-1).item()
total_log_prob_bad = sentence_log_prob_bad.sum(dim=-1).item()

print(f"Log-probability of sentence_good with dummy: {dummy_total_log_prob_good}")
print(f"Log-probability of sentence_bad with dummy: {dummy_total_log_prob_bad}")
print(f"Log-probability of sentence_good: {total_log_prob_good}")
print(f"Log-probability of sentence_bad: {total_log_prob_bad}")