Skip to content

Commit

Permalink
Add default values to Params members
Browse files Browse the repository at this point in the history
  • Loading branch information
cyyever authored and pytorchmergebot committed May 24, 2024
1 parent 2f6954c commit 5607fc6
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 24 deletions.
4 changes: 2 additions & 2 deletions aten/src/ATen/cuda/detail/PhiloxCudaStateRaw.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ struct PhiloxCudaState {
int64_t* ptr;
};

Payload seed_;
Payload offset_;
Payload seed_{};
Payload offset_{};
uint32_t offset_intragraph_ = 0;
bool captured_ = false;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,46 +153,46 @@ struct AttentionKernel {
int32_t window_size = 0;

// Scale
accum_t scale;
accum_t scale = 1.0f;

// Dimensions/strides
int32_t head_dim;
int32_t head_dim_value;
int32_t num_queries;
int32_t num_keys;
int32_t num_keys_absolute;
int32_t head_dim= 0;
int32_t head_dim_value= 0;
int32_t num_queries = 0;
int32_t num_keys = 0;
int32_t num_keys_absolute = 0;

uint8_t custom_mask_type = NoCustomMask;

int32_t q_strideM;
int32_t k_strideM;
int32_t v_strideM;
int32_t q_strideM = 0;
int32_t k_strideM = 0;
int32_t v_strideM = 0;
int32_t bias_strideM = 0;

int32_t o_strideM = 0;

// Everything below is only used in `advance_to_block`
// and shouldn't use registers
int32_t q_strideH;
int32_t k_strideH;
int32_t v_strideH;
int32_t q_strideH = 0;
int32_t k_strideH = 0;
int32_t v_strideH = 0;
int64_t bias_strideH = 0;

int64_t q_strideB;
int64_t k_strideB;
int64_t v_strideB;
int64_t q_strideB = 0;
int64_t k_strideB = 0;
int64_t v_strideB = 0;
int64_t bias_strideB = 0;

int32_t num_batches;
int32_t num_heads;
int32_t num_batches = 0;
int32_t num_heads = 0;

// dropout
bool use_dropout;
unsigned long long dropout_batch_head_rng_offset;
float dropout_prob;
bool use_dropout = false;
unsigned long long dropout_batch_head_rng_offset = 0;
float dropout_prob = 0;
at::PhiloxCudaState rng_engine_inputs;
int64_t* extragraph_offset;
int64_t* seed;
int64_t* extragraph_offset = nullptr;
int64_t* seed = nullptr;

// Moves pointers to what we should process
// Returns "false" if there is no work to do
Expand Down

0 comments on commit 5607fc6

Please sign in to comment.