In [None]:
# import os
# from huggingface_hub import hf_hub_download
# import safetensors

# # Define model and safetensor file names (assuming you have these names)
# model_id = "your-model-id"
# filenames = ["file1.safetensors", "file2.safetensors", "file3.safetensors"]

# # Function to load safetensor tensors
# def load_safetensor_tensors(model_id, filenames):
#   tensors = []
#   for filename in filenames:
#     # Download safetensor file from the HF hub
#     local_filepath = hf_hub_download(repo_id=model_id, filename=filename)
    
#     # Use safe_open for memory efficient reading
#     with safetensors.safe_open(local_filepath, framework="pt") as f:
#       for key in f.keys():
#         tensors.append((key, f[key]))  # Storing as (key, tensor) tuples
#   return tensors

# # Load tensors lazily
# def iterate_tensors(tensors):
#   for key, tensor in tensors:
#     yield key, tensor

# # Example usage
# tensors = load_safetensor_tensors(model_id, filenames)
# for key, tensor in iterate_tensors(tensors):
#   print(f"Key: {key}, Tensor: {tensor}")

In [5]:
import json
import torch
import transformers
import safetensors
import os
from huggingface_hub import hf_hub_download
from typing import Generator

class ModelTree:
  def __init__(self, weight_map: dict):
    self.keys: list[str] = list(weight_map.keys())
    self.file_names = set(weight_map.values())
    self.flattened_tree = weight_map

    self.tree = {}

    for key in self.keys:
      split_key = key.split(".")
      tree = self.tree
      for i, part in enumerate(split_key):
        if i < len(split_key) - 1:
          tree = tree.setdefault(part, {})
        else:
          tree = tree.setdefault(split_key[i], self.flattened_tree[key])

import typing
from abc import ABC, abstractmethod

class TensorLoader(ABC):
  """Naive `torch.load` shard loading."""

  tensors: dict[str, torch.Tensor]
  shard_path: str

  @classmethod
  def get_loader(self, shard_path: str, device: typing.Optional[str] = None):
    if shard_path.endswith(".safetensors"):
      return SafetensorLoader(shard_path, device=device)
    elif shard_path.endswith(".bin"):
      return PytorchLoader(shard_path, device=device)

  @abstractmethod
  def get_tensor(self, key: str) -> torch.Tensor:
      ...

  @abstractmethod
  def keys(self) -> list[str]:
      ...

class PytorchLoader(TensorLoader):
  def __init__(self, path: str, device: typing.Optional[str] = None):
    self.tensors = torch.load(path, map_location=device, weights_only=True)
    self.shard_path = path

  def get_tensor(self, key: str) -> torch.Tensor:
    return self.tensors[key]

  def keys(self) -> list[str]:
    return list(self.tensors.keys())

class SafetensorLoader(TensorLoader):
  def __init__(self, path: str, device: typing.Optional[str] = None):
    self.tensors = {}
    with safetensors.safe_open(path, framework="pt") as f:
      for k in f.keys():
        self.tensors[k] = f.get_tensor(k)
    self.shard_path = path

  def get_tensor(self, key: str) -> torch.Tensor:
    return self.tensors[key]

  def keys(self) -> list[str]:
    return list(self.tensors.keys())


class ModelLoader:
  def __init__(self, model_name, is_cached=True, device="cpu"):
    if not is_cached: raise Exception("Model must be downloaded first. Auto-download not implemented yet.")

    self.device = device
    self.config_path = transformers.utils.hub.cached_file(model_name, "config.json")
    self.model_dir: list[str] = os.path.dirname(self.config_path)


    self.index_path = [f for f in os.listdir(self.model_dir) if f.endswith(".index.json")]
    assert(len(self.index_path) == 1)

    self.index_path = self.index_path[0]

    with open(os.path.join(self.model_dir, self.index_path)) as f:
      self.index_dict: dict[str, dict[str, str]] = json.load(f)
      self.weight_map: dict[str, str] = self.index_dict["weight_map"]

      self.model_tree = ModelTree(self.weight_map)

    self.cached_loader = (None,None)
  
  def get_tensor_loader(self, key) -> TensorLoader:
    parent_file = self.weight_map[key]
  
    if parent_file == self.cached_loader[0]: return self.cached_loader[1]

    del self.cached_loader
    
    tensor_loader = TensorLoader.get_loader(os.path.join(self.model_dir, parent_file))
    self.cached_loader = (parent_file, tensor_loader)
    return tensor_loader

  def iterate_tensors(self, ordered=True) -> Generator:
    tensor_keys = list(self.weight_map.keys())
    for key in tensor_keys:
      tensor_loader = self.get_tensor_loader(key)
      yield (key,tensor_loader.get_tensor(key))

  def flush(self):
    del self.cached_loader
    self.cached_loader = (None, None)

  def print_tree(self):
    print(self.model_tree.tree)

In [6]:
# model_loader_test = ModelLoader("argilla/CapybaraHermes-2.5-Mistral-7B")

model_loader_test = ModelLoader("Open-Orca/Mistral-7B-OpenOrca")


In [3]:
for tensor in model_loader_test.iterate_tensors():
  print(tensor)

('lm_head.weight', tensor([[-2.5787e-03,  8.9645e-04, -2.3499e-03,  ..., -2.6894e-04,
          4.0894e-03,  1.8616e-03],
        [-2.5482e-03,  1.4343e-03, -2.3956e-03,  ...,  5.4169e-04,
          5.9509e-03,  2.0294e-03],
        [ 2.7466e-03,  3.3722e-03,  2.4109e-03,  ..., -3.6163e-03,
         -2.9144e-03,  2.1820e-03],
        ...,
        [ 2.7313e-03, -4.3335e-03,  3.5858e-03,  ..., -6.5918e-03,
          1.9073e-03, -7.7724e-05],
        [ 1.4709e-02, -1.6327e-03, -1.5717e-03,  ..., -4.8828e-03,
         -2.1362e-02,  2.0996e-02],
        [ 2.5635e-03, -4.7302e-03, -5.6152e-03,  ...,  1.6235e-02,
         -3.3188e-04, -3.2715e-02]], dtype=torch.bfloat16))
('model.embed_tokens.weight', tensor([[-1.7643e-05, -2.5511e-05, -1.4126e-05,  ...,  2.8253e-05,
         -8.3447e-06,  1.8477e-05],
        [-4.5776e-03,  3.2616e-04, -5.1270e-03,  ..., -1.8978e-04,
         -1.1902e-03,  3.0708e-04],
        [-1.5640e-03,  9.3079e-04,  1.8692e-04,  ...,  1.1749e-03,
          3.3760e-04,  

In [None]:
model_loader_test.get_tensor("model.layers.0.self_attn.k_proj.weight")

In [None]:
import torch
file_path = 'pytorch_model.bin'
state_dict = torch.load(file_path, map_location='cpu')