In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification
from huggingface_hub import login
from peft import PeftModel
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, matthews_corrcoef
from tqdm import tqdm

In [2]:
print(torch.cuda.is_available())  # Should return True if CUDA is available
print(torch.cuda.device_count())  # Number of GPUs detected
print(torch.cuda.get_device_name(0))

True
1
NVIDIA GeForce GTX 1650


In [3]:
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

## Configurations

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    compute_dtype = torch.bfloat16 # Or torch.float16 depending on your GPU
else:
    compute_dtype = torch.float32

In [5]:
model_id = "mistralai/Mistral-7B-v0.1"

base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=compute_dtype,
    device_map={"":"cuda"}, # Automatically distributes across GPUs if available/needed
    # offload_folder='offload/'
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token 

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [6]:
# add adapter, if not then base model
model = base_model

model.eval()

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): MistralRMSNorm((4096,), eps=1e-05)
     

In [7]:
dataset = load_dataset("Rowan/hellaswag")
test_dataset = dataset["validation"]  # or "test" if it includes labels

In [14]:
test_dataset

Dataset({
    features: ['ind', 'activity_label', 'ctx_a', 'ctx_b', 'ctx', 'endings', 'source_id', 'split', 'split_type', 'label'],
    num_rows: 10042
})

In [16]:
correct = 0
total = 0

for i in tqdm(range(len(test_dataset))):
    example = dataset['validation'][i]  # ← Access the i-th example correctly
    ctx = example["ctx"]
    endings = example["endings"]
    label = example["label"]
    losses = []

    for ending in endings:
        prompt = ctx.strip() + " " + ending.strip()
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        input_ids = inputs["input_ids"]

        # GPT-style scoring: compute loss of predicting each token (shifted)
        with torch.no_grad():
            outputs = model(**inputs, labels=input_ids)
            loss = outputs.loss.item()
            losses.append(loss)
    
    pred = losses.index(min(losses))
    print(str(pred) + " == " + str(label) + "? " + str(int(pred) == int(label)))
    
    if int(pred) == int(label):
        correct += 1
    total += 1
    if i == 10:
        break

accuracy = correct / total
print(f"HellaSwag Accuracy (causal LLM): {accuracy:.4f}")


  0%|          | 1/10042 [00:23<65:30:13, 23.49s/it]

2 == 3? False


  0%|          | 2/10042 [00:48<67:17:56, 24.13s/it]

3 == 3? True


  0%|          | 3/10042 [01:20<77:18:13, 27.72s/it]

2 == 2? True


  0%|          | 4/10042 [01:45<74:47:28, 26.82s/it]

2 == 2? True


  0%|          | 5/10042 [02:14<76:51:45, 27.57s/it]

2 == 1? False


  0%|          | 6/10042 [02:40<75:20:41, 27.03s/it]

2 == 1? False


  0%|          | 7/10042 [03:07<75:25:49, 27.06s/it]

1 == 2? False


  0%|          | 8/10042 [03:34<75:23:10, 27.05s/it]

3 == 0? False


  0%|          | 9/10042 [04:04<77:50:18, 27.93s/it]

2 == 1? False


  0%|          | 10/10042 [04:40<84:48:37, 30.43s/it]

1 == 1? True


  0%|          | 10/10042 [05:12<87:01:09, 31.23s/it]

3 == 3? True
HellaSwag Accuracy (causal LLM): 0.4545



