## Setup

### GPU Check

In [1]:
# Shell escape: run the NVIDIA CLI to print the GPU status table
# (device name, driver/CUDA version, memory usage) for this machine.
!nvidia-smi

Fri May 17 09:16:02 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.171.04             Driver Version: 535.171.04   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A30                     Off | 00000000:02:00.0 Off |                   On |
| N/A   30C    P0              28W / 165W |   1184MiB / 24576MiB |     N/A      Default |
|                                         |                      |              Enabled |
+-----------------------------------------+----------------------+----------------------+

+------------------------------------------------------------------

### Imports

In [2]:
import os
from sharded_model_loader import ShardedModelLoader
import constants  # Ensure constants module is available and contains CACHE_DIR_LOCAL
from hf_token import TOKEN
import torch

## Sharded Model Loading

### Model Constants

In [3]:
# Model constants: identify the checkpoint and derive, for every weight shard,
# the remote download URL and the path it is stored under in the local cache.
model_name = "gemma-1.1-2b-it"
model_id = "google/" + model_name
model_dir = "models--google--" + model_name
# Name of the directory under `snapshots/` in the local cache
# (presumably the Hub revision hash -- confirm before changing).
snapshot_revision = "bf4924f313df5166dee1467161e886e55f2eb4d4"
shards = 2

# Shard numbers are 1-based and zero-padded to five digits ("00001-of-00002").
# A width specifier is used instead of a literal "0000" prefix so the names
# stay five digits wide when there are 10 or more shards.
shard_urls = [
    f"https://huggingface.co/{model_id}/resolve/main/model-{i:05d}-of-{shards:05d}.safetensors"
    for i in range(1, shards + 1)
]

# NOTE(review): the remote files are `.safetensors` but the local names are
# `pytorch_model-*.bin`; confirm the loader expects this naming/format.
shard_local_files_download = [
    os.path.join(
        constants.CACHE_DIR_LOCAL,
        model_dir,
        "snapshots",
        snapshot_revision,
        f"pytorch_model-{i:05d}-of-{shards:05d}.bin",
    )
    for i in range(1, shards + 1)
]

### Download and Load Shards

In [4]:
# Build a loader for this checkpoint, rooted at the local cache directory.
loader = ShardedModelLoader(
    model_name,
    model_id,
    constants.CACHE_DIR_LOCAL,
)

# Fetch every shard URL to its matching local path (authenticating with TOKEN),
# then obtain the model and tokenizer from the loader.
model, tokenizer = loader.download_and_load_shards(
    shard_urls,
    shard_local_files_download,
    TOKEN,
)

In [None]:
# Smoke test: deterministically decode a short continuation of a fixed prompt
# to confirm that the loaded model and tokenizer work together.
prompt = "Hello, my name is"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

with torch.no_grad():
    # Forward the whole encoding rather than only `input_ids`, so that any
    # attention mask returned by the tokenizer also reaches generate().
    outputs = model.generate(**inputs, max_new_tokens=20, do_sample=False)

# outputs[0] is the single generated sequence (prompt tokens included).
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(generated_text)

Hello, my name is [Your Name] and I am reaching out to you today to inquire about your services.

I


### Load Downloaded Shards

In [4]:
# Alternative to downloading: load shards that already exist on disk.
# Running this cell rebinds `loader`, `model` and `tokenizer`.
loader = ShardedModelLoader(
    model_name,
    model_id,
    constants.CACHE_DIR_LOCAL,
)

# Only the local shard paths are passed here; no URLs or token are passed.
model, tokenizer = loader.load_local_shards(
    shard_local_files_download,
)

Loading shard: /homes/pu22/.cache/huggingface/hub/models--google--gemma-1.1-2b-it/snapshots/bf4924f313df5166dee1467161e886e55f2eb4d4/pytorch_model-00001-of-00002.bin
Loading shard: /homes/pu22/.cache/huggingface/hub/models--google--gemma-1.1-2b-it/snapshots/bf4924f313df5166dee1467161e886e55f2eb4d4/pytorch_model-00002-of-00002.bin
