Security Update and Enhancement for run.py #264

Status: Open. Wants to merge 2 commits into base: main.

116 changes: 71 additions & 45 deletions in run.py
@@ -1,7 +1,7 @@
# Copyright 2024 X.AI Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
@@ -13,60 +13,86 @@
# limitations under the License.

import logging
import hashlib

from model import LanguageModelConfig, TransformerConfig, QuantizedWeight8bit as QW8Bit
from runners import InferenceRunner, ModelRunner, sample_from_model


CKPT_PATH = "./checkpoints/"
CKPT_HASH = "expected_checkpoint_hash"


def validate_checkpoint(path, expected_hash):
calculated_hash = hashlib.sha256(open(path, 'rb').read()).hexdigest()
Reviewer:
Please use a context manager (a with statement) for opening and reading the given path. It might also be in our best interest to use type hints in the function signature.

Author:

from typing import Text

import hashlib

CKPT_HASH = "expected_checkpoint_hash"

def validate_checkpoint(path: Text, expected_hash: Text):
  with open(path, 'rb') as f:
    contents = f.read()

  calculated_hash = hashlib.sha256(contents).hexdigest()
  
  if calculated_hash != expected_hash:
    raise ValueError("Invalid checkpoint file!")

The key changes:

  • Added type hints for the path (Text) and expected_hash (Text) parameters.

  • Opened the file using a with statement, which automatically closes it when done.

  • Stored the file contents in a variable called 'contents' to avoid re-reading the file.

  • Passed the contents variable to hashlib.sha256 rather than the file object.

It seems we would need to import Text from typing. Is this necessary?

Reviewer:

I think the Text import is superfluous and could just as easily be replaced with str without importing any extra type hints.

Reviewer:

This is still left to fix. Calling open outside of a context manager is bad practice and not recommended for production.
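
A fix along the lines the reviewer is requesting could look like the sketch below: a with block plus chunked reads, so a multi-gigabyte checkpoint file is never loaded into memory whole. The function and constant names follow the PR; the 1 MiB chunk size is an assumed value.

```python
import hashlib


def validate_checkpoint(path: str, expected_hash: str) -> None:
    """Compare a file's SHA-256 digest against an expected value."""
    sha256 = hashlib.sha256()
    # The context manager guarantees the handle is closed; chunked reads
    # keep memory bounded for large checkpoint files.
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):  # 1 MiB chunks (assumed)
            sha256.update(chunk)
    if sha256.hexdigest() != expected_hash:
        raise ValueError(f"Checkpoint validation failed for {path}")
```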

if calculated_hash != expected_hash:
raise ValueError("Invalid checkpoint file!")
Reviewer:

Could this error message be improved? It might also be nice to utilize logging.

Author (@MiChaelinzo), Mar 26, 2024:

import hashlib
import logging
from typing import Text

# Set up logging
logger = logging.getLogger(__name__)

def validate_checkpoint(path: Text, expected_hash: Text):

  with open(path, 'rb') as f:
    contents = f.read()

  calculated_hash = hashlib.sha256(contents).hexdigest()

  if calculated_hash != expected_hash:
    logger.error(f"Invalid checkpoint file. Expected hash: {expected_hash}, " 
                 f"Actual hash: {calculated_hash}")
    raise ValueError("Checkpoint validation failed")

The key changes:

  • Imported the logging module and created a logger object

  • Logged an error with the expected and actual hash values for more detail

  • Updated the exception message to be more specific

This makes it clear in the logs when a validation failure happens and provides the expected and actual hashes for diagnostics.

Other enhancements could include:

  • Adding the checkpoint path to the log message
  • Logging at INFO level when validation succeeds
  • Configuring logging to output to a file for production debugging

It would make the code longer, though. Is this necessary?
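
For reference, the three enhancements listed above could be sketched roughly as follows; the log file name and message wording are assumptions, not part of the PR:

```python
import hashlib
import logging

logger = logging.getLogger(__name__)


def validate_checkpoint(path: str, expected_hash: str) -> None:
    with open(path, "rb") as f:
        contents = f.read()
    calculated_hash = hashlib.sha256(contents).hexdigest()
    if calculated_hash != expected_hash:
        # Include the path so failures can be traced to a specific file.
        logger.error("Invalid checkpoint %s. Expected hash: %s, actual hash: %s",
                     path, expected_hash, calculated_hash)
        raise ValueError(f"Checkpoint validation failed: {path}")
    # INFO-level record of successful validation.
    logger.info("Checkpoint %s validated successfully", path)


if __name__ == "__main__":
    # Send log records to a file for production debugging (filename assumed).
    logging.basicConfig(filename="validation.log", level=logging.INFO)
```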

Reviewer:

Please use Flake8 as well as a standardized code formatter; I'm noticing many inconsistencies in the code you submit. There's no problem that I notice with your usage of logging. You just have not submitted an actual commit with fixes for file context management.

Reviewer:

100% convinced this user is just repeating garbage from an LLM.



def main():
    grok_1_model = LanguageModelConfig(
        vocab_size=128 * 1024,
        pad_token=0,
        eos_token=2,
        sequence_len=8192,
        embedding_init_scale=1.0,
        output_multiplier_scale=0.5773502691896257,
        embedding_multiplier_scale=78.38367176906169,
        model=TransformerConfig(
            emb_size=48 * 128,
            widening_factor=8,
            key_size=128,
            num_q_heads=48,
            num_kv_heads=8,
            num_layers=64,
            attn_output_multiplier=0.08838834764831845,
            shard_activations=True,
            # MoE.
            num_experts=8,
            num_selected_experts=2,
            # Activation sharding.
            data_axis="data",
            model_axis="model",
        ),
    )
    inference_runner = InferenceRunner(
        pad_sizes=(1024,),
        runner=ModelRunner(
            model=grok_1_model,
            bs_per_device=0.125,
            checkpoint_path=CKPT_PATH,
        ),
        name="local",
        load=CKPT_PATH,
        tokenizer_path="./tokenizer.model",
        local_mesh_config=(1, 8),
        between_hosts_config=(1, 1),
    )
    inference_runner.initialize()
    gen = inference_runner.run()
    # Validate checkpoint integrity
    validate_checkpoint(CKPT_PATH, CKPT_HASH)

grok_1_model = LanguageModelConfig(
Reviewer:

The only change here is a dedent from the PEP8 standard 4 space indent.

vocab_size=128 * 1024,
pad_token=0,
eos_token=2,
sequence_len=8192,
embedding_init_scale=1.0,
output_multiplier_scale=0.5773502691896257,
embedding_multiplier_scale=78.38367176906169,
model=TransformerConfig(
emb_size=48 * 128,
widening_factor=8,
key_size=128,
num_q_heads=48,
num_kv_heads=8,
num_layers=64,
attn_output_multiplier=0.08838834764831845,
shard_activations=True,
# MoE.
num_experts=8,
num_selected_experts=2,
# Activation sharding.
data_axis="data",
model_axis="model",
),
)

inference_runner = InferenceRunner(
pad_sizes=(1024,),
runner=ModelRunner(
model=grok_1_model,
bs_per_device=0.125,
checkpoint_path=CKPT_PATH,
# Limit inference rate
inference_runner.rate_limit = 100
Reviewer (@Aareon), Apr 6, 2024:

This appears to reference inference_runner.rate_limit before it is defined.
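
Setting rate_limit inside the ModelRunner(...) argument list is a syntax error as well as a forward reference. If a rate limit is genuinely wanted, one option is a small standalone token-bucket helper applied after the runner is constructed. Nothing below is part of the PR or of the grok-1 codebase; the names and numbers are illustrative.

```python
import time


class RateLimiter:
    """Token-bucket limiter: allows roughly `rate` calls per second."""

    def __init__(self, rate: float, burst: int = 1):
        self.rate = rate            # tokens replenished per second
        self.capacity = burst       # maximum stored tokens
        self.tokens = float(burst)
        self.last = time.monotonic()

    def allow(self) -> bool:
        now = time.monotonic()
        # Refill based on elapsed time, capped at the bucket capacity.
        self.tokens = min(self.capacity, self.tokens + (now - self.last) * self.rate)
        self.last = now
        if self.tokens >= 1.0:
            self.tokens -= 1.0
            return True
        return False
```

A caller would then construct the runner first and gate each request, e.g. only calling sample_from_model when limiter.allow() returns True.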

),

name="local",
load=CKPT_PATH,
tokenizer_path="./tokenizer.model",
Reviewer (@Aareon), Apr 6, 2024:

If you were to improve anything, I'd suggest improving how file paths are defined by utilizing pathlib
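
A pathlib-based version of the path handling might look like this; the directory layout mirrors the constants in run.py, while checkpoint_file is a hypothetical helper:

```python
from pathlib import Path

# Hypothetical layout, mirroring the string paths in run.py.
ROOT = Path(__file__).resolve().parent
CKPT_PATH = ROOT / "checkpoints"
TOKENIZER_PATH = ROOT / "tokenizer.model"


def checkpoint_file(name: str) -> Path:
    """Resolve a checkpoint file under CKPT_PATH and fail early if missing."""
    path = CKPT_PATH / name
    if not path.is_file():
        raise FileNotFoundError(f"Checkpoint not found: {path}")
    return path
```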

local_mesh_config=(1, 8),
between_hosts_config=(1, 1),
)

inference_runner.initialize()

gen = inference_runner.run()

inp = "The answer to life the universe and everything is of course"
print(f"Output for prompt: {inp}", sample_from_model(gen, inp, max_len=100, temperature=0.01))

# Add authentication
@app.route("/inference")
@auth.login_required
def inference():
...

gen = inference_runner.run()

# Rest of inference code
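
The route sketch above references app and auth decorators that are never defined in the diff, so it will not run as-is. The Basic Auth check that a decorator like auth.login_required performs can be sketched with the standard library alone; the credential store below is a placeholder, and a real deployment would store hashed passwords:

```python
import base64
import hmac

# Placeholder credential store (illustrative only).
USERS = {"admin": "s3cret"}


def check_basic_auth(header: str) -> bool:
    """Validate an HTTP `Authorization: Basic ...` header value."""
    if not header.startswith("Basic "):
        return False
    try:
        decoded = base64.b64decode(header[len("Basic "):]).decode("utf-8")
        username, _, password = decoded.partition(":")
    except (ValueError, UnicodeDecodeError):
        return False
    expected = USERS.get(username)
    # hmac.compare_digest avoids leaking timing information.
    return expected is not None and hmac.compare_digest(password, expected)
```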

if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
main()
logging.basicConfig(level=logging.INFO)
main()
Reviewer:

2-space indent is not standard. Please see PEP 8.

Author:

I'm using 1 space, and you should comment that to the original repo. You're making our lives complicated enough with your reviews; they don't make any sense at all!

Reviewer:

> I'm using 1 space, and you should comment that to the original repo. You're making our lives complicated enough with your reviews; they don't make any sense at all!

This change is not accepted and requires fixing. 1-space indent is not standard and worsens readability and consistency in all affected code.

Not accepted.

Reviewer:

Overall, a complete waste of a PR. Nothing of value was added.