diff --git a/images/Merge.png b/images/Merge.png new file mode 100644 index 00000000..a2df0487 Binary files /dev/null and b/images/Merge.png differ diff --git a/images/ollama.png b/images/ollama.png new file mode 100644 index 00000000..fa83bb42 Binary files /dev/null and b/images/ollama.png differ diff --git a/pyproject.toml b/pyproject.toml index c2cf4ae9..581b86a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ exclude = ["images*"] [project.optional-dependencies] huggingface = [ "tyro", - "transformers>=4.38.2", + "transformers>=4.42.3", "datasets>=2.16.0", "sentencepiece>=0.2.0", "tqdm", @@ -185,9 +185,9 @@ colab-ampere-torch220 = [ ] colab-new = [ "tyro", - "transformers>=4.38.2", + "transformers>=4.42.3", "datasets>=2.16.0", - "sentencepiece", + "sentencepiece>=0.2.0", "tqdm", "psutil", "wheel>=0.42.0", diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 26057791..dc1ad269 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -19,14 +19,17 @@ from transformers.models.llama.modeling_llama import logger +@triton.heuristics({"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING"],}) @triton.jit def _cross_entropy_forward( logits_ptr, logits_row_stride, loss_ptr, logsumexp_ptr, labels_ptr, - VOCAB_SIZE : tl.constexpr, - BLOCK_SIZE : tl.constexpr, + VOCAB_SIZE : tl.constexpr, + BLOCK_SIZE : tl.constexpr, + DO_SOFTCAPPING : tl.constexpr, + SOFTCAP : tl.constexpr, ): """ Cross Entropy Loss = 1/n sum [ -yi log(Pi) ] @@ -58,13 +61,19 @@ def _cross_entropy_forward( mask = col_offsets < VOCAB_SIZE label_idx = tl.load(labels_ptr).to(tl.int32) - logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32) + logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")) + # Do logit softcapping for Gemma 2: t * tanh(1/t * x) + if DO_SOFTCAPPING: logits = SOFTCAP * tl.math.tanh(logits / SOFTCAP) + + logits = logits.to(tl.float32) c = tl.max(logits, 0) logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0)) if label_idx != -100: - x = tl.load(logits_ptr + label_idx).to(tl.float32) - loss = logsumexp - x + x = tl.load(logits_ptr + label_idx) + # Do logit softcapping for Gemma 2: t * tanh(1/t * x) + if DO_SOFTCAPPING: x = SOFTCAP * tl.math.tanh(x / SOFTCAP) + loss = logsumexp - x.to(tl.float32) else: loss = 0.0 tl.store(logsumexp_ptr, logsumexp) @@ -72,15 +81,18 @@ def _cross_entropy_forward( pass +@triton.heuristics({"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING"],}) @triton.jit def _chunked_cross_entropy_forward( logits_ptr, logits_row_stride, loss_ptr, logsumexp_ptr, labels_ptr, - VOCAB_SIZE : tl.constexpr, - N_CHUNKS : tl.constexpr, - BLOCK_SIZE : tl.constexpr, + VOCAB_SIZE : tl.constexpr, + N_CHUNKS : tl.constexpr, + BLOCK_SIZE : tl.constexpr, + DO_SOFTCAPPING : tl.constexpr, + SOFTCAP : tl.constexpr, ): """ 256K vocab divided in 4 chunks @@ -117,7 +129,11 @@ def _chunked_cross_entropy_forward( mask = col_offsets < VOCAB_SIZE label_idx = tl.load(labels_ptr).to(tl.int32) - logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32) + logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")) + # Do logit softcapping for Gemma 2: t * tanh(1/t * x) + if DO_SOFTCAPPING: logits = SOFTCAP * tl.math.tanh(logits / SOFTCAP) + + logits = logits.to(tl.float32) c = tl.max(logits, 0) logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0)) @@ -126,7 +142,9 @@ def _chunked_cross_entropy_forward( # Do the 
-x separately if label_idx != -100: x = tl.load(logits_ptr + label_idx).to(tl.float32) - loss = -1.0 * x + # Do logit softcapping for Gemma 2: t * tanh(1/t * x) + if DO_SOFTCAPPING: x = SOFTCAP * tl.math.tanh(x / SOFTCAP) + loss = -1.0 * x.to(tl.float32) else: loss = 0.0 tl.store(loss_ptr, loss) @@ -135,14 +153,17 @@ def _chunked_cross_entropy_forward( pass +@triton.heuristics({"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING"],}) @triton.jit def _cross_entropy_backward( logits_ptr, logits_row_stride, dloss_ptr, dloss_row_stride, logsumexp_ptr, labels_ptr, - VOCAB_SIZE : tl.constexpr, - BLOCK_SIZE : tl.constexpr, + VOCAB_SIZE : tl.constexpr, + BLOCK_SIZE : tl.constexpr, + DO_SOFTCAPPING : tl.constexpr, + SOFTCAP : tl.constexpr, ): """ CE_i = -y log(P) = y * (log[sum(exp(x))] - x) @@ -173,15 +194,27 @@ def _cross_entropy_backward( else: dloss = 0.0 - x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32) + x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")) + # Do logit softcapping for Gemma 2: t * tanh(1/t * x) + if DO_SOFTCAPPING: + # d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x) + partial = tl.math.tanh(x / SOFTCAP) + x = SOFTCAP * partial + pass + logsumexp = tl.load(logsumexp_ptr + row_idx) - y = tl.exp(x - logsumexp) + y = tl.exp(x.to(tl.float32) - logsumexp) y = tl.where( col_offsets == label_idx, y - 1.0, # exp(x - logsumexp) - 1 y, # exp(x - logsumexp) ) + if DO_SOFTCAPPING: + # d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x) + y = y * (1.0 - partial*partial) + pass + # If y == 0: dC/dx = 0 ==> we already masked it to be = 0, so dloss = 0. tl.store(logits_ptr + col_offsets, dloss * y, mask = mask) pass @@ -191,40 +224,46 @@ def _cross_entropy_backward( class Fast_CrossEntropyLoss(torch.autograd.Function): @staticmethod - def forward(ctx, logits, labels): + def forward(ctx, logits, labels, logit_softcapping = 0): n_rows, vocab_size = logits.shape div, mod = divmod(vocab_size, MAX_FUSED_SIZE) n_chunks = div + (mod != 0) - losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda") + losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0") + + DO_SOFTCAPPING = (logit_softcapping != 0) if n_chunks == 1: # For small vocabs <= 65336 like Llama, Mistral BLOCK_SIZE, num_warps = calculate_settings(vocab_size) - logsumexp = torch.empty(n_rows, dtype = torch.float32, device = "cuda") + logsumexp = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0") _cross_entropy_forward[(n_rows,)]( logits, logits.stride(0), losses, logsumexp, labels, - VOCAB_SIZE = vocab_size, - BLOCK_SIZE = BLOCK_SIZE, - num_warps = num_warps, + VOCAB_SIZE = vocab_size, + BLOCK_SIZE = BLOCK_SIZE, + DO_SOFTCAPPING = DO_SOFTCAPPING, + SOFTCAP = logit_softcapping, + num_warps = num_warps, ) else: # For large vocabs > 65336 like Gemma 256K - logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = "cuda") + logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = "cuda:0") _chunked_cross_entropy_forward[(n_rows, n_chunks,)]( logits, logits.stride(0), losses, logsumexp, labels, - VOCAB_SIZE = vocab_size, - N_CHUNKS = n_chunks, - BLOCK_SIZE = MAX_FUSED_SIZE, - num_warps = 32, + VOCAB_SIZE = vocab_size, + N_CHUNKS = n_chunks, + BLOCK_SIZE = MAX_FUSED_SIZE, + DO_SOFTCAPPING = DO_SOFTCAPPING, + SOFTCAP = logit_softcapping, + num_warps = 32, ) # logsumexp(chunked_logsumexp) - x # Do the -x separately @@ -234,6 +273,8 @@ def forward(ctx, logits, labels): pass ctx.save_for_backward(logits, logsumexp, 
labels) + ctx.DO_SOFTCAPPING = DO_SOFTCAPPING + ctx.logit_softcapping = logit_softcapping return losses pass @@ -251,16 +292,18 @@ def backward(ctx, dlosses): dlosses, dlosses.stride(0), logsumexp, labels, - VOCAB_SIZE = vocab_size, - BLOCK_SIZE = BLOCK_SIZE, - num_warps = 8, + VOCAB_SIZE = vocab_size, + BLOCK_SIZE = BLOCK_SIZE, + DO_SOFTCAPPING = ctx.DO_SOFTCAPPING, + SOFTCAP = ctx.logit_softcapping, + num_warps = 8, ) return logits, None, None, pass pass -def fast_cross_entropy_loss(logits, labels): +def fast_cross_entropy_loss(logits, labels, logit_softcapping = 0): """ Arguments: logits: (batch, seq_len, vocab_size) @@ -274,6 +317,7 @@ def fast_cross_entropy_loss(logits, labels): loss = Fast_CrossEntropyLoss.apply( logits.view(batch*seq_len, d), labels.view(-1), + logit_softcapping, ) n_items = torch.count_nonzero(labels != -100) return loss.sum() / n_items diff --git a/unsloth/kernels/geglu.py b/unsloth/kernels/geglu.py index df80fcb7..006e8c0f 100644 --- a/unsloth/kernels/geglu.py +++ b/unsloth/kernels/geglu.py @@ -41,7 +41,7 @@ def _exact_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,): def geglu_exact_forward_kernel(gate, up): batch, seq_len, hd = gate.shape n_elements = gate.numel() - out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda") + out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda:0") grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) _exact_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,) return out @@ -133,7 +133,7 @@ def _approx_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,): def geglu_approx_forward_kernel(gate, up): batch, seq_len, hd = gate.shape n_elements = gate.numel() - out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda") + out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda:0") grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) _approx_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,) return out diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index 4db89b78..f26e5965 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -119,7 +119,7 @@ def _gemma_rms_layernorm_forward( W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32) row_var = tl.sum(X_row * X_row, axis = 0) / n_cols - inv_var = 1.0 / tl.sqrt(row_var + eps) # Must be 1/sqrt to match Deepmind's impl + inv_var = tl.math.rsqrt(row_var + eps) tl.store(r, inv_var) normed = X_row * inv_var output = normed * (W_row + 1.0) @@ -137,8 +137,8 @@ def forward(ctx, X, W, eps, gemma = False): n_rows, n_cols = X.shape BLOCK_SIZE, num_warps = calculate_settings(n_cols) - Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda") - r = torch.empty(n_rows, dtype = torch.float32, device = "cuda") + Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda:0") + r = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0") fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward fx[(n_rows,)]( diff --git a/unsloth/kernels/swiglu.py b/unsloth/kernels/swiglu.py index ff6b1626..f81b7aae 100644 --- a/unsloth/kernels/swiglu.py +++ b/unsloth/kernels/swiglu.py @@ -41,7 +41,7 @@ def _fg_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,): def swiglu_fg_kernel(e, g): batch, seq_len, hd = e.shape n_elements = e.numel() - h = torch.empty((batch, seq_len, hd), dtype = e.dtype, device = "cuda") + h = torch.empty((batch, 
seq_len, hd), dtype = e.dtype, device = "cuda:0") grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) _fg_kernel[grid](e, g, h, n_elements, BLOCK_SIZE = 1024,) return h diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index ddee198b..935f1d43 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -105,14 +105,14 @@ def fast_dequantize(W, quant_state = None, out = None): # Create weight matrix if out is None: - out = torch.empty(shape, dtype = dtype, device = "cuda") + out = torch.empty(shape, dtype = dtype, device = "cuda:0") else: assert(out.shape == shape) assert(out.dtype == dtype) # NF4 dequantization of statistics n_elements_absmax = absmax.numel() - out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda") + out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0") # Do dequantization ptr_out_absmax = get_ptr(out_absmax) @@ -161,7 +161,7 @@ def fast_gemv(X, W, quant_state, out = None): bout = shape[0] if out is None: - out = torch.empty((1, 1, bout,), dtype = dtype, device = "cuda") + out = torch.empty((1, 1, bout,), dtype = dtype, device = "cuda:0") # else: # assert(out.shape == (1, 1, bout,)) # pass @@ -179,7 +179,7 @@ def fast_gemv(X, W, quant_state, out = None): ldb = ctypes.c_int32(ldb) ldc = ctypes.c_int32(ldc) - df = torch.empty(absmax.shape, dtype = torch.float32, device = "cuda") + df = torch.empty(absmax.shape, dtype = torch.float32, device = "cuda:0") cdequantize_blockwise_fp32( get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df), ctypes.c_int(blocksize2), ctypes.c_int(df.numel()), diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 7a6954c9..73aa0c6c 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -21,6 +21,12 @@ warnings.filterwarnings(action = "ignore", category = UserWarning, module = "transformers") warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "accelerate") warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "huggingface_hub") +warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "multiprocessing") + +# Stop "Special tokens have been added in the vocabulary, ..." 
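
The Triton kernels patched earlier in this diff fold Gemma 2's logit softcapping, t * tanh(x / t), into both the cross-entropy forward and backward passes, reusing d/dx [t * tanh(x / t)] = 1 - tanh^2(x / t) in the backward kernel. Below is a plain-PyTorch sketch of the same math for a single row, useful as a reference when checking those kernels; the function name, the label, and t = 30.0 are illustrative values, not identifiers from the patch.

import torch

def softcapped_ce_row(logits, label, t):
    """One-row reference for the forward/backward math of the softcapped CE kernels."""
    capped    = t * torch.tanh(logits / t)              # t * tanh(x / t)
    logsumexp = torch.logsumexp(capped, -1)
    loss      = logsumexp - capped[label]

    with torch.no_grad():                                # manual gradient, for comparison
        y = torch.exp(capped - logsumexp)                # softmax of the capped logits
        y[label] -= 1.0                                  # subtract 1 at the true label
        partial = torch.tanh(logits / t)
        grad = y * (1.0 - partial * partial)             # d/dx [t * tanh(x/t)] = 1 - tanh^2(x/t)
    return loss, grad

logits = torch.randn(16, dtype = torch.float32, requires_grad = True)
loss, manual_grad = softcapped_ce_row(logits, label = 3, t = 30.0)
loss.backward()
assert torch.allclose(logits.grad, manual_grad, atol = 1e-6)
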
+import logging +logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.CRITICAL+1) + import bitsandbytes as bnb from transformers.models.llama.modeling_llama import logger from transformers import AutoTokenizer @@ -31,7 +37,7 @@ import os import psutil -__version__ = "2024.6" +__version__ = "2024.7" # Get Flash Attention v2 if Ampere (RTX 30xx, A100) major_version, minor_version = torch.cuda.get_device_capability() @@ -80,8 +86,49 @@ "offload_output_embeddings", "is_bfloat16_supported", "unsloth_offloaded_gradient_checkpoint", + "torch_compile_options", ] +# Just remove max_autotune_gemm warning +import functools +@functools.lru_cache(None) +def is_big_gpu(index): + sms = torch.cuda.get_device_properties(index).multi_processor_count + if sms < 80: # V100 + # log.warning("not enough SMs to use max_autotune_gemm mode") + return False + return True +import torch._inductor.utils +torch._inductor.utils.is_big_gpu = is_big_gpu + + +# Torch compile arguments +torch_compile_arguments = [ + "config.dce = True", + "config.memory_planning = True", + "config.memory_pool = 'combined'", + "config.coordinate_descent_tuning = True", + "config.max_autotune_gemm = False", # GEMM is unnecessary + "config.autotune_multi_device = False", + "config.max_autotune_gemm_backends = 'ATEN'", # Not much faster + "config.aggressive_fusion = False", # Careful changes results! + "config.cuda.enable_cuda_lto = True", + "config.cuda.use_fast_math = True", + "config.cuda.compile_opt_level = '-O2'", +] +import torch._inductor.config as config +for _try_compile_argument in torch_compile_arguments: + try: exec(_try_compile_argument) + except: pass +pass +torch_compile_options = { + "epilogue_fusion" : True, + "max_autotune" : True, + "shape_padding" : True, + "trace.enabled" : False, # Output Triton kernel outputs! + "triton.cudagraphs" : False, +} + def prepare_model_for_kbit_training( model : Any, diff --git a/unsloth/models/gemma.py b/unsloth/models/gemma.py index 99374891..4c4515b7 100644 --- a/unsloth/models/gemma.py +++ b/unsloth/models/gemma.py @@ -247,6 +247,8 @@ def pre_patch(): GemmaModel .forward = LlamaModel_fast_forward GemmaForCausalLM .forward = CausalLM_fast_forward(GemmaModel_fast_forward_inference) PeftModelForCausalLM.forward = PeftModelForCausalLM_fast_forward + fix_prepare_inputs_for_generation(GemmaForCausalLM) + # Solves https://github.com/unslothai/unsloth/issues/168 # Static KV Cache was introduced in 4.38.0, causing training to be much slower. # Inferene can now be CUDAGraphed, but we shall retain the old rotary embeddings. diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py new file mode 100644 index 00000000..0669e422 --- /dev/null +++ b/unsloth/models/gemma2.py @@ -0,0 +1,538 @@ +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
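
torch_compile_options is exported from _utils.py above so that compiled helpers can share one inductor configuration; the Gemma 2 module below applies it through the options argument of torch.compile. A minimal sketch of that pattern, with a throwaway softcap function standing in for the real compiled helpers:

import torch
from unsloth.models._utils import torch_compile_options

# Same decorator shape as fast_rms_layernorm_gemma2_compiled / gemma2_attention below;
# the function body here is only a stand-in.
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def softcap(x, t):
    return t * torch.tanh(x / t)

x = torch.randn(4, 8)
print(softcap(x, 30.0))
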
+ +from .llama import * +from ._utils import __version__ +from .gemma import ( + GemmaFixedRotaryEmbedding, + fast_geglu_inference, +) +from transformers.models.gemma2.modeling_gemma2 import ( + Gemma2Attention, + Gemma2DecoderLayer, + Gemma2Model, + Gemma2ForCausalLM, + Gemma2RotaryEmbedding, + apply_rotary_pos_emb, + repeat_kv, +) +from transformers.models.gemma2.modeling_gemma2 import * +from transformers.modeling_attn_mask_utils import ( + _prepare_4d_causal_attention_mask_for_sdpa, +) +# For Pytorch 2.1.1 +try: + from transformers.models.gemma2.modeling_gemma2 import ( + Gemma2SdpaAttention, + Gemma2FlashAttention2, + ) +except: + Gemma2SdpaAttention = Gemma2Attention + Gemma2FlashAttention2 = Gemma2Attention +pass + + +# [TODO] We must randomnly use torch.compile? +# I checked the gradients and formulas and I'm sure it's correct. +# I'm stumped :( +@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options) +def fast_rms_layernorm_gemma2_compiled(layernorm, X, gemma = True): + old_dtype = X.dtype + X = X.float() + X = X * torch.rsqrt(X.square().mean(-1, keepdim = True) + layernorm.eps) * \ + (1.0 + layernorm.weight.float()) + return X.to(old_dtype) +pass + + +# Logit softcapping +@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options) +def gemma2_attention(Q, K, V, causal_mask, self, bsz, q_len): + n_heads = self.num_heads + head_dim = self.head_dim + n_kv_heads = self.num_key_value_heads + n_groups = self.num_key_value_groups + + # Grouped query attention + K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim) + V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim) + K = K.reshape(bsz, n_heads, q_len, head_dim) + V = V.reshape(bsz, n_heads, q_len, head_dim) + + s = self.config.hidden_size // self.config.num_attention_heads + t = self.config.attn_logit_softcapping + + Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly + A = torch.matmul(Q, K.transpose(2, 3)) + A = t * torch.tanh(A / t) # Logit softcapping + A += causal_mask[:q_len, :q_len] + A = torch.nn.functional.softmax(A, dim = -1, dtype = torch.float32).to(Q.dtype) + A = torch.matmul(A, V) + A = A.transpose(1, 2).contiguous() + A = A.reshape(bsz, q_len, n_heads*head_dim) + return A +pass + + +# Logit softcapping +def Gemma2Attention_fast_forward( + self, + hidden_states: torch.Tensor, + causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + padding_mask: Optional[torch.LongTensor] = None, + *args, **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + + # Clear inference + if hasattr(self, "paged_attention"): + del self.paged_attention_K + del self.paged_attention_V + del self.paged_attention + del self.temp_QA + del self.temp_KV + del self.RH_Q + del self.attention + pass + + bsz, q_len, _ = hidden_states.size() + + n_heads = self.num_heads + n_groups = self.num_key_value_groups + n_kv_heads = self.num_key_value_heads + head_dim = self.head_dim + assert(n_kv_heads * n_groups == n_heads) + + Q, K, V = self.apply_qkv(self, hidden_states) + Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2) + K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) + V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) + + kv_seq_len = K.shape[-2] + 
if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + if position_ids is None: + cos = self.rotary_emb.cos_cached + sin = self.rotary_emb.sin_cached + Q, K = fast_rope_embedding(Q, K, cos, sin) + else: + cos, sin = self.rotary_emb(V, seq_len = kv_seq_len) + Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids) + pass + + if past_key_value is not None: + K = torch.cat([past_key_value[0], K], dim = 2) + V = torch.cat([past_key_value[1], V], dim = 2) + pass + past_key_value = (K, V) if use_cache else None + + A = gemma2_attention(Q, K, V, causal_mask, self, bsz, kv_seq_len) + A = self.apply_o(self, A) + return A, None, past_key_value +pass + + +# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L590 +def Gemma2DecoderLayer_fast_forward( + self, + hidden_states: torch.Tensor, + causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + padding_mask: Optional[torch.LongTensor] = None, + *args, **kwargs, +): + if use_cache and hasattr(self, "_flag_for_generation"): #past_key_value is not None: + out_weight = torch.empty(self.input_layernorm.weight.shape, dtype = torch.float32, device = "cuda:0") + + # Self Attention + residual = hidden_states + hidden_states = fast_rms_layernorm_inference_gemma(self.input_layernorm, hidden_states, out_weight) + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + causal_mask=causal_mask, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + padding_mask=padding_mask, + ) + hidden_states = fast_rms_layernorm_inference_gemma(self.post_attention_layernorm, hidden_states, out_weight) + hidden_states += residual + + # Fully Connected + residual = hidden_states + hidden_states = fast_rms_layernorm_inference_gemma(self. pre_feedforward_layernorm, hidden_states, out_weight) + hidden_states = fast_geglu_inference(self.mlp, hidden_states) + hidden_states = fast_rms_layernorm_inference_gemma(self.post_feedforward_layernorm, hidden_states, out_weight) + hidden_states += residual + else: + residual = hidden_states + hidden_states = fast_rms_layernorm_gemma2_compiled(self.input_layernorm, hidden_states, gemma = True) + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + causal_mask=causal_mask, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + padding_mask=padding_mask, + ) + hidden_states = fast_rms_layernorm_gemma2_compiled(self.post_attention_layernorm, hidden_states, gemma = True) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = fast_rms_layernorm_gemma2_compiled(self. 
pre_feedforward_layernorm, hidden_states, gemma = True) + hidden_states = self.mlp(hidden_states) + hidden_states = fast_rms_layernorm_gemma2_compiled(self.post_feedforward_layernorm, hidden_states, gemma = True) + hidden_states = residual + hidden_states + pass + + outputs = (hidden_states,) + if output_attentions: outputs += (self_attn_weights,) + if use_cache: outputs += (present_key_value,) + return outputs +pass + + +from math import sqrt as math_sqrt +KV_CACHE_INCREMENT = 256 # KV Cache update size +torch_nn_functional_softmax = torch.nn.functional.softmax + +def Gemma2Attention_fast_forward_inference( + self, + hidden_states: torch.Tensor, + past_key_value: Optional[Tuple[torch.Tensor]], + position_ids, + do_prefill = False, + attention_mask = None, + use_sliding_window = False, +): + Xn = hidden_states + bsz, _, hd = hidden_states.size() + K1, V1 = past_key_value + dtype = Xn.dtype + + n_heads = self.num_heads + n_groups = self.num_key_value_groups + n_kv_heads = self.num_key_value_heads + head_dim = self.head_dim + attention_size = n_heads*head_dim + # assert(n_kv_heads * n_groups == n_heads) + seq_len = K1.shape[-2] + kv_seq_len = seq_len + 1 + + # Prefill phase + # if not hasattr(self, "paged_attention"): + if do_prefill: + self.paged_attention = torch.empty((KV_CACHE_INCREMENT+seq_len+1, 2, bsz, n_kv_heads, head_dim), dtype = dtype, device = "cuda:0") + self.paged_attention_K = self.paged_attention[:,0] + self.paged_attention_V = self.paged_attention[:,1] + self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3) + self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3) + self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = "cuda:0") + self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = "cuda:0") + self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0") + self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0") + self.scalar = 1.0 / math_sqrt(self.config.hidden_size // self.config.num_attention_heads) + self.half_head_dim = head_dim // 2 + self. 
t = self.config.attn_logit_softcapping + self.reciprocal_t = 1.0 / self.config.attn_logit_softcapping + elif kv_seq_len >= self.paged_attention.shape[0]: + self.paged_attention.resize_((self.paged_attention.shape[0]+KV_CACHE_INCREMENT, 2, bsz, n_kv_heads, head_dim)) + self.paged_attention_K = self.paged_attention[:,0] + self.paged_attention_V = self.paged_attention[:,1] + self.attention.resize_((bsz, n_heads, 1, self.attention.shape[-1]+KV_CACHE_INCREMENT)) + pass + + Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0]) + Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0]) + Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1]) + Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2) + Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2) + Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2) + + # cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len) + # Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids) + cos = self.rotary_emb.cos_cached[position_ids].unsqueeze(1) + sin = self.rotary_emb.sin_cached[position_ids].unsqueeze(1) + h = self.half_head_dim + + RH_Q = self.RH_Q + RH_Q[:,:,:,:h] = Qn[:,:,:,h:] + RH_Q[:,:,:,h:] = Qn[:,:,:,:h] + torch.neg(RH_Q[:,:,:,:h], out = RH_Q[:,:,:,:h]) + Qn *= cos + Qn.addcmul_(RH_Q, sin) + + RH_K = RH_Q[:,:n_kv_heads,:,:] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0") + RH_K[:,:,:,:h] = Kn[:,:,:,h:] + RH_K[:,:,:,h:] = Kn[:,:,:,:h] + torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h]) + Kn *= cos + Kn.addcmul_(RH_K, sin) + + # New KV cache + # Kn = torch.cat([K1, Kn], dim = 2) + # Vn = torch.cat([V1, Vn], dim = 2) + self.paged_attention_K[seq_len] = Kn.permute(2, 0, 1, 3) + self.paged_attention_V[seq_len] = Vn.permute(2, 0, 1, 3) + Kn = self.paged_attention_K[:kv_seq_len].permute(1, 2, 0, 3) + Vn = self.paged_attention_V[:kv_seq_len].permute(1, 2, 0, 3) + + # Handle sliding windows + sliding_window = self.config.sliding_window + if use_sliding_window and kv_seq_len > sliding_window: + # From https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L193 + slicing_tokens = 1 - sliding_window + Knn = Kn[:, :, slicing_tokens:, :]#.contiguous() + Vnn = Vn[:, :, slicing_tokens:, :]#.contiguous() + else: + Knn, Vnn = Kn, Vn + pass + + # Grouped query attention + _, _, cached_len, _ = Knn.shape + if n_groups != 1: + Knn = Knn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim) + Vnn = Vnn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim) + Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim) + Vnn = Vnn.reshape(bsz, n_heads, cached_len, head_dim) + pass + # else: + # Knn, Vnn = Knn, Vnn + # pass + + # Attention + # if bsz == 1: + Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963 + # It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows + A = torch.matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:cached_len]) + # if attention_mask is not None: A += attention_mask # Must add attention_mask for batched + + A *= self.reciprocal_t; torch.tanh(A, out = A); A *= self.t; # Logit softcapping + + A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype) + A = torch.matmul(A, Vnn, out = Qn) + # else: + # A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False) + # pass + A = A.transpose(1, 2) + A = A.reshape(bsz, 1, attention_size) + A = 
fast_linear_forward(self.o_proj, A, out = self.temp_QA[1][:,:,:self.hidden_size]) + return A, (Kn, Vn) +pass + + +# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825 +# @torch.inference_mode +def Gemma2Model_fast_forward_inference( + self, + input_ids, + past_key_values, + position_ids, + attention_mask = None, +): + out_weight = torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = "cuda:0") + input_ids = input_ids[:,:self.max_seq_length] + hidden_states = self.model.embed_tokens(input_ids) + hidden_states = hidden_states.to(self.config.torch_dtype) + # 3072**0.5 = 55.5000 in bfloat16, whilst 55.4256 in float32 + # 2048**0.5 = 45.2500 in bfloat16, whilst 45.2548 in float32 + hidden_states *= torch.tensor(math_sqrt(self.config.hidden_size), dtype = hidden_states.dtype) + + bsz, q_len, hd = hidden_states.shape + seq_len = past_key_values[0][0].shape[-2] + if bsz != 1: + SWA = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (bsz, q_len), + hidden_states, + seq_len, + sliding_window = self.config.sliding_window, + ) + GA = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (bsz, q_len), + hidden_states, + seq_len, + ) + else: + SWA = attention_mask + GA = attention_mask + pass + + next_decoder_cache = [] + for idx, decoder_layer in enumerate(self.model.layers): + + use_sliding_window = idx % 2 == 0 + + residual = hidden_states + hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.input_layernorm, hidden_states, out_weight) + hidden_states, present_key_value = Gemma2Attention_fast_forward_inference( + decoder_layer.self_attn, + hidden_states = hidden_states, + past_key_value = past_key_values[idx], + position_ids = position_ids, + attention_mask = SWA if use_sliding_window else GA, + do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"), + use_sliding_window = use_sliding_window, + ) + hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_attention_layernorm, hidden_states, out_weight) + hidden_states += residual + + residual = hidden_states + hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer. pre_feedforward_layernorm, hidden_states, out_weight) + hidden_states = fast_geglu_inference(decoder_layer.mlp, hidden_states) + hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_feedforward_layernorm, hidden_states, out_weight) + hidden_states += residual + + next_decoder_cache.append(present_key_value) + pass + hidden_states = fast_rms_layernorm_inference_gemma(self.model.norm, hidden_states, out_weight) + + return BaseModelOutputWithPast( + last_hidden_state = hidden_states, + past_key_values = next_decoder_cache, + hidden_states = [], + attentions = [], + ) +pass + + +class FastGemma2Model(FastLlamaModel): + + @staticmethod + def pre_patch(): + Gemma2Attention .forward = Gemma2Attention_fast_forward + Gemma2SdpaAttention .forward = Gemma2Attention_fast_forward + Gemma2FlashAttention2.forward = Gemma2Attention_fast_forward + Gemma2DecoderLayer .forward = Gemma2DecoderLayer_fast_forward + Gemma2Model .forward = LlamaModel_fast_forward + Gemma2ForCausalLM .forward = CausalLM_fast_forward(Gemma2Model_fast_forward_inference) + PeftModelForCausalLM .forward = PeftModelForCausalLM_fast_forward + fix_prepare_inputs_for_generation(Gemma2ForCausalLM) + + # Solves https://github.com/unslothai/unsloth/issues/168 + # Static KV Cache was introduced in 4.38.0, causing training to be much slower. 
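
Gemma 2 alternates sliding-window and global attention across layers, which is why the inference loop above and the training forward in llama.py later in this patch both pick a mask from the layer index. A minimal sketch of that schedule, following the AttentionMaskConverter pattern used by the patch; the helper name is illustrative, and the sketch defaults to CPU where the patch pins the masks to "cuda:0":

import torch
from transformers.modeling_attn_mask_utils import AttentionMaskConverter

def build_gemma2_masks(n, sliding_window, dtype = torch.float32, device = "cpu"):
    # Even-indexed layers use the sliding-window causal mask, odd-indexed layers the
    # full causal mask; the patch builds both once and selects one per layer.
    SWA_mask = AttentionMaskConverter(is_causal = True, sliding_window = sliding_window)\
        .to_causal_4d(1, n, n, dtype = dtype, device = device)\
        .squeeze(0).squeeze(0)
    GA_mask = AttentionMaskConverter(is_causal = True)\
        .to_causal_4d(1, n, n, dtype = dtype, device = device)\
        .squeeze(0).squeeze(0)
    return SWA_mask, GA_mask

SWA_mask, GA_mask = build_gemma2_masks(n = 32, sliding_window = 8)
layer_idx = 0
mask = SWA_mask if (layer_idx % 2 == 0) else GA_mask   # even layer -> sliding window
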
+ # Inferene can now be CUDAGraphed, but we shall retain the old rotary embeddings. + # https://github.com/huggingface/transformers/pull/27931 + # https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py + import transformers.models.gemma2.modeling_gemma2 + transformers.models.gemma2.modeling_gemma2.Gemma2RotaryEmbedding = GemmaFixedRotaryEmbedding + return + pass + + + @staticmethod + def post_patch(model): + # Patch model for Gemma + layers = model.model.layers + + # Torch.compile fails on embedding matrix?? + # Workaround randomnly fixes it for torch versions < 2.2 + model.model.embed_tokens = torch.nn.Embedding.from_pretrained(model.model.embed_tokens.weight) + model.config.update({"unsloth_version" : __version__}) + + # We also do this for the lm_head + lm_head = torch.nn.Linear(1, 1, bias = None) + del lm_head.weight + lm_head.weight = model.lm_head.weight + lm_head.in_features = lm_head.weight.shape[1] + lm_head.out_features = lm_head.weight.shape[0] + model.lm_head = lm_head + + # Gemma has tied weights! This means lm_head == embed_tokens + if model.model.embed_tokens.weight.data_ptr() != model.lm_head.weight.data_ptr(): + lm_head = torch.nn.Linear(1, 1, bias = None) + del lm_head.weight + lm_head.weight = model.model.embed_tokens.weight + lm_head.in_features = lm_head.weight.shape[1] + lm_head.out_features = lm_head.weight.shape[0] + model.lm_head = lm_head + pass + + # Also patch all dtypes - BnB seems to not allocate the correct type? + # BnB default dtype seems to be float16! + correct_dtype = lm_head.weight.dtype + + for name, module in model.named_modules(): + if isinstance(module, (Bnb_Linear4bit, Peft_Linear4bit)): + weight = module.weight + quant_state = weight.quant_state + + if type(quant_state) is list: + # BnB seems to have float16 as default! + module.weight.quant_state[2] = correct_dtype # Cast to correct dtype + else: + # https://github.com/TimDettmers/bitsandbytes/pull/763/files + quant_state.dtype = correct_dtype + pass + pass + # Downcast RoPE embedding to correct data type + # RoPE must be done in float32 for Gemma + # if (name.endswith("rotary_emb") or hasattr(module, "cos_cached")) \ + # and (module.cos_cached.dtype != correct_dtype): + + # module.cos_cached = module.cos_cached.to(correct_dtype) + # module.sin_cached = module.sin_cached.to(correct_dtype) + # pass + # pass + pass + + # Add 1 to weight + # return output * (1 + self.weight) + # https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma/modeling_gemma.py#L89 + from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm + + # Freeze all parameters except LoRA + # We do this first since += 1 seems to not be liked by requires_grad = True + for name, param in model.named_parameters(): + if ".lora_A." in name or ".lora_B." 
in name: + param.requires_grad_(True) + else: + param.requires_grad_(False) + pass + + # Patch RMS Layernorm + for name, module in model.named_modules(): + if isinstance(module, Gemma2RMSNorm): + # Must be in float32 + # https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/rms_normalization.py#L36 + # module = module.to(torch.float32) + # Leave + 1 to Triton kernel itself + # module.weight += 1.0 # return output * (1 + self.weight) + if not hasattr(module, "variance_epsilon"): + module.variance_epsilon = module.eps # Gemma doesn't use variance_epsilon + pass + + # Clear deleted GPU items + import gc + for _ in range(3): + gc.collect() + torch.cuda.empty_cache() + return model + pass +pass diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 2368a376..e19b8572 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -15,6 +15,8 @@ import torch import gc from typing import Optional, Tuple, List, Union +from ._utils import * +from ._utils import __version__ from torch.nn.functional import scaled_dot_product_attention from transformers.models.llama.modeling_llama import ( logger, @@ -25,8 +27,6 @@ _prepare_4d_causal_attention_mask_for_sdpa, ) from ..kernels import * -from ._utils import * -from ._utils import __version__ from ..tokenizer_utils import * if HAS_FLASH_ATTENTION: from flash_attn import flash_attn_func @@ -78,6 +78,24 @@ def original_apply_o(self, X): KV_CACHE_INCREMENT = 256 # KV Cache update size torch_nn_functional_softmax = torch.nn.functional.softmax +# Fix new HF's inference code +def _fast_prepare_inputs_for_generation(self, input_ids, **kwargs,): + if "past_key_values" in kwargs: + input_ids = input_ids[:,[-1]] + kwargs["attention_mask"] = kwargs["attention_mask"][:,[-1]] + kwargs["position_ids"] = kwargs["cache_position"] + return { "input_ids" : input_ids, **kwargs, } +pass + + +def fix_prepare_inputs_for_generation(module): + # Fix prepare_inputs_for_generation + if hasattr(module, "prepare_inputs_for_generation"): + module.prepare_inputs_for_generation = _fast_prepare_inputs_for_generation + pass +pass + + def LlamaAttention_fast_forward_inference( self, hidden_states: torch.Tensor, @@ -542,7 +560,8 @@ def LlamaModel_fast_forward( inputs_embeds = inputs_embeds.to(self.config.torch_dtype) # Normalized from Gemma - IS_GEMMA = self.config.model_type == "gemma" + IS_GEMMA = self.config.model_type.startswith("gemma") + IS_GEMMA2 = self.config.model_type.startswith("gemma2") train_embed_tokens = self.embed_tokens.weight.requires_grad if IS_GEMMA: @@ -642,17 +661,38 @@ def LlamaModel_fast_forward( offloaded_gradient_checkpointing = True pass + # Gemma2 has alternating SWA and global attn + if IS_GEMMA2 and not hasattr(self, "SWA_mask"): + from transformers.modeling_attn_mask_utils import AttentionMaskConverter + n = self.config.max_position_embeddings + self.SWA_mask = AttentionMaskConverter( + is_causal = True, + sliding_window = self.config.sliding_window, + )\ + .to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = "cuda:0",)\ + .squeeze(0).squeeze(0) + + self.GA_mask = AttentionMaskConverter( + is_causal = True, + )\ + .to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = "cuda:0",)\ + .squeeze(0).squeeze(0) + pass + # Go through every layer! 
for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) past_key_value = past_key_values[idx] if past_key_values is not None else None + mask = causal_mask + if IS_GEMMA2: mask = self.SWA_mask if (idx % 2 == 0) else self.GA_mask + if offloaded_gradient_checkpointing: hidden_states = Unsloth_Offloaded_Gradient_Checkpointer.apply( decoder_layer, hidden_states, - causal_mask, + mask, attention_mask, position_ids, past_key_values, @@ -670,7 +710,7 @@ def custom_forward(*inputs): layer_outputs = torch.utils.checkpoint.checkpoint( create_custom_forward(decoder_layer), hidden_states, - causal_mask, + mask, attention_mask, position_ids, use_reentrant = True, @@ -681,7 +721,7 @@ def custom_forward(*inputs): else: layer_outputs = decoder_layer( hidden_states, - causal_mask=causal_mask, + causal_mask=mask, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, @@ -838,6 +878,7 @@ def _CausalLM_fast_forward( logits = logits.to(self.config.torch_dtype) loss = None + logit_softcapping = getattr(self.config, "final_logit_softcapping", 0) if labels is not None: shift_logits = logits if not hasattr(self, "extra_ignored_labels"): @@ -849,7 +890,12 @@ def _CausalLM_fast_forward( loss = fast_cross_entropy_loss( logits = shift_logits, labels = shift_labels, + logit_softcapping = logit_softcapping, ) + elif logit_softcapping != 0: + logits *= (1.0 / logit_softcapping) + torch.tanh(logits, out = logits) + logits *= logit_softcapping pass if not return_dict: @@ -983,11 +1029,22 @@ def _fast_generate(*args, **kwargs): pass internal_model._flag_for_generation = True + # For newer HF + kwargs["cache_implementation"] = "dynamic" + + # Set pad token + old_pad_token_id = getattr(model.config, "pad_token_id", None) + old_eos_token_id = getattr(model.config, "eos_token_id", None) + model.config.pad_token_id = old_eos_token_id + # Autocasted with torch.autocast(device_type = device_type, dtype = dtype): output = generate(*args, **kwargs) pass + # Revert + model.config.pad_token_id = old_pad_token_id + # Unset a flag for generation! internal_model = model while hasattr(internal_model, "model"): @@ -1013,6 +1070,7 @@ def pre_patch(): LlamaModel .forward = LlamaModel_fast_forward LlamaForCausalLM .forward = CausalLM_fast_forward(LlamaModel_fast_forward_inference) PeftModelForCausalLM.forward = PeftModelForCausalLM_fast_forward + fix_prepare_inputs_for_generation(LlamaForCausalLM) # Solves https://github.com/unslothai/unsloth/issues/168 # Static KV Cache was introduced in 4.38.0, causing training to be much slower. @@ -1056,7 +1114,7 @@ def from_pretrained( f"==((====))== Unsloth: Fast {model_patcher.__name__[4:-5]} patching release {__version__}\n"\ f" \\\ /| GPU: {gpu_stats.name}. Max memory: {max_memory} GB. Platform = {platform_system}.\n"\ f"O^O/ \_/ \\ Pytorch: {torch.__version__}. CUDA = {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit = {torch.version.cuda}.\n"\ - f"\ / Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. Xformers = {xformers_version}. FA = {HAS_FLASH_ATTENTION}.\n"\ + f"\ / Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. 
FA2 = {HAS_FLASH_ATTENTION}]\n"\ f' "-____-" Free Apache license: http://github.com/unslothai/unsloth' print(statistics) model_patcher.pre_patch() @@ -1200,11 +1258,11 @@ def from_pretrained( 'nvidia-smi --query-gpu=memory.used --format=csv', shell = True) output = re.findall(rb'([\\d]{1,})[\\s]{1,}M', output) output = sum(int(x.decode('utf-8'))/1024 > 4 for x in output) - if output > 1: raise RuntimeError( - 'Unsloth currently does not work on multi GPU setups - sadly we are a 2 brother team so '\\ + if output > 1: print( + '********************\\nUnsloth currently does not work on multi GPU setups - sadly we are a 2 brother team so '\\ 'enabling it will require much more work, so we have to prioritize. Please understand!\\n'\\ - 'We do have a separate beta version, which you can contact us about!\\n'\\ - 'Thank you for your understanding and we appreciate it immensely!') + '********************\\nWe do have a separate beta version, which you can contact us about!\\n'\\ + '********************\\nThank you for your understanding and we appreciate it immensely!') for _ in range(3): gc.collect() torch.cuda.empty_cache()""" @@ -1760,6 +1818,7 @@ def patch_peft_model( elif model_type == "mistral": apply_lora_mlp = apply_lora_mlp_swiglu elif model_type == "qwen2": apply_lora_mlp = apply_lora_mlp_swiglu elif model_type == "gemma": apply_lora_mlp = apply_lora_mlp_geglu_approx + elif model_type == "gemma2": apply_lora_mlp = apply_lora_mlp_geglu_approx else: raise NotImplementedError(f"Unsloth: {model_type} is not yet implemented!") pass diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index d87af0a1..9134d4a2 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -26,8 +26,11 @@ major, minor = int(major), int(minor) SUPPORTS_FOURBIT = (major > 4) or (major == 4 and minor >= 37) SUPPORTS_GEMMA = (major > 4) or (major == 4 and minor >= 38) +SUPPORTS_GEMMA2 = (major > 4) or (major == 4 and minor >= 42) if SUPPORTS_GEMMA: - from .gemma import FastGemmaModel + from .gemma import FastGemmaModel +if SUPPORTS_GEMMA2: + from .gemma2 import FastGemma2Model del major, minor @@ -138,6 +141,15 @@ def from_pretrained( f"to obtain the latest transformers build, then restart this session."\ ) dispatch_model = FastGemmaModel + elif model_type == "gemma2": + if not SUPPORTS_GEMMA2: + raise RuntimeError( + f"Unsloth: Your transformers version of {transformers_version} does not support Gemma2.\n"\ + f"The minimum required version is 4.43.\n"\ + f'Try `pip install --upgrade "transformers>=4.43"`\n'\ + f"to obtain the latest transformers build, then restart this session."\ + ) + dispatch_model = FastGemma2Model elif model_type == "qwen2": dispatch_model = FastQwen2Model else: diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py index 4b400650..cec7332e 100644 --- a/unsloth/models/mapper.py +++ b/unsloth/models/mapper.py @@ -191,6 +191,14 @@ "mistralai/Codestral-22B-v0.1" : ( "mistral-community/Codestral-22B-v0.1", ), + "unsloth/gemma-2-9b-bnb-4bit" : ( + "unsloth/gemma-2-9b", + "google/gemma-2-9b", + ), + "unsloth/gemma-2-27b-bnb-4bit" : ( + "unsloth/gemma-2-27b", + "google/gemma-2-27b", + ), } INT_TO_FLOAT_MAPPER = {} diff --git a/unsloth/models/mistral.py b/unsloth/models/mistral.py index d8bd85d4..e0b51a16 100644 --- a/unsloth/models/mistral.py +++ b/unsloth/models/mistral.py @@ -275,7 +275,8 @@ def pre_patch(): MistralModel .forward = LlamaModel_fast_forward MistralForCausalLM .forward = MistralForCausalLM_fast_forward PeftModelForCausalLM .forward = 
PeftModelForCausalLM_fast_forward - + fix_prepare_inputs_for_generation(MistralForCausalLM) + # Solves https://github.com/unslothai/unsloth/issues/168 # Static KV Cache was introduced in 4.38.0, causing training to be much slower. # Inferene can now be CUDAGraphed, but we shall retain the old rotary embeddings. diff --git a/unsloth/models/qwen2.py b/unsloth/models/qwen2.py index 984bf7ca..5b9fff5d 100644 --- a/unsloth/models/qwen2.py +++ b/unsloth/models/qwen2.py @@ -43,6 +43,7 @@ def pre_patch(): Qwen2Model .forward = LlamaModel_fast_forward Qwen2ForCausalLM .forward = CausalLM_fast_forward(LlamaModel_fast_forward_inference) PeftModelForCausalLM.forward = PeftModelForCausalLM_fast_forward + fix_prepare_inputs_for_generation(Qwen2ForCausalLM) # Solves https://github.com/unslothai/unsloth/issues/168 # Static KV Cache was introduced in 4.38.0, causing training to be much slower. diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 50b09275..8727ca03 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -963,11 +963,11 @@ def patch_sft_trainer_tokenizer(): " 'nvidia-smi --query-gpu=memory.used --format=csv', shell = True)\n"\ "output = re.findall(rb'([\\d]{1,})[\\s]{1,}M', output)\n"\ "output = sum(int(x.decode('utf-8'))/1024 > 4 for x in output)\n"\ - "if output > 1: raise RuntimeError(\n"\ - " 'Unsloth currently does not work on multi GPU setups - sadly we are a 2 brother team so '\\\n"\ + "if output > 1: print(\n"\ + " '********************\\nUnsloth currently does not work on multi GPU setups - sadly we are a 2 brother team so '\\\n"\ " 'enabling it will require much more work, so we have to prioritize. Please understand!\\n'\\\n"\ - " 'We do have a separate beta version, which you can contact us about!\\n'\\\n"\ - " 'Thank you for your understanding and we appreciate it immensely!')\n"\ + " '********************\\nWe do have a separate beta version, which you can contact us about!\\n'\\\n"\ + " '********************\\nThank you for your understanding and we appreciate it immensely!')\n"\ "for _ in range(3):\n"\ " gc.collect()\n"\ " torch.cuda.empty_cache()\n"\
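
For reference, the training path now threads the model's final_logit_softcapping into the fused loss, while the no-label path softcaps the logits in place, as added to _CausalLM_fast_forward above. A condensed sketch of both branches under the patched API; the wrapper name and arguments are illustrative, and label shifting is assumed to have been done as in the patched forward:

import torch
from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss

def loss_or_softcapped_logits(logits, labels, config):
    # 0 means "no softcapping"; Gemma 2 configs define final_logit_softcapping.
    t = getattr(config, "final_logit_softcapping", 0)
    if labels is not None:
        # logits: (batch, seq_len, vocab_size), labels: (batch, seq_len), already shifted
        return fast_cross_entropy_loss(logits, labels, logit_softcapping = t)
    if t != 0:
        # In-place t * tanh(logits / t), matching the inference branch in llama.py
        logits *= (1.0 / t)
        torch.tanh(logits, out = logits)
        logits *= t
    return logits
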