diff --git a/.gitignore b/.gitignore
index 2dae32884e6..2ec3a0661f3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,5 +1,7 @@
 *.o
 *.a
+*.mlmodel
+*.mlmodelc
 .cache/
 .vs/
 .vscode/
@@ -32,3 +34,5 @@ examples/whisper.objc/whisper.objc.xcodeproj/xcuserdata/
 examples/whisper.objc/whisper.objc.xcodeproj/project.xcworkspace/xcuserdata
 
 extra/bench-gg.txt
+
+*.mlmodel*
diff --git a/CMakeLists.txt b/CMakeLists.txt
index fbbb5209d2d..89ec64fad2d 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -54,6 +54,8 @@ if (APPLE)
     option(WHISPER_NO_AVX              "whisper: disable AVX" OFF)
     option(WHISPER_NO_AVX2             "whisper: disable AVX2" OFF)
     option(WHISPER_NO_FMA              "whisper: disable FMA" OFF)
+
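+    # when enabled, builds the whisper.coreml wrapper library and defines WHISPER_USE_COREML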
+    option(WHISPER_COREML              "whisper: enable Core ML framework" OFF)
 else()
     option(WHISPER_SUPPORT_OPENBLAS    "whisper: support for OpenBLAS" OFF)
 endif()
@@ -86,16 +88,33 @@ endif()
 
 find_package(Threads REQUIRED)
 
-# on APPLE - include Accelerate framework
-if (APPLE AND NOT WHISPER_NO_ACCELERATE)
-    find_library(ACCELERATE_FRAMEWORK Accelerate)
-    if (ACCELERATE_FRAMEWORK)
-        message(STATUS "Accelerate framework found")
+# on APPLE
+if (APPLE)
+    # include Accelerate framework
+    if (NOT WHISPER_NO_ACCELERATE)
+        find_library(ACCELERATE_FRAMEWORK Accelerate)
+
+        if (ACCELERATE_FRAMEWORK)
+            message(STATUS "Accelerate framework found")
 
-        set(WHISPER_EXTRA_LIBS  ${WHISPER_EXTRA_LIBS}  ${ACCELERATE_FRAMEWORK})
-        set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_USE_ACCELERATE)
-    else()
-        message(WARNING "Accelerate framework not found")
+            set(WHISPER_EXTRA_LIBS  ${WHISPER_EXTRA_LIBS}  ${ACCELERATE_FRAMEWORK})
+            set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_USE_ACCELERATE)
+        else()
+            message(WARNING "Accelerate framework not found")
+        endif()
+    endif()
+
+    if (WHISPER_COREML)
+        find_library(FOUNDATION_FRAMEWORK Foundation)
+        find_library(COREML_FRAMEWORK CoreML)
+
+        if (COREML_FRAMEWORK)
+            message(STATUS "CoreML framework found")
+
+            set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DWHISPER_USE_COREML)
+        else()
+            message(WARNING "CoreML framework not found")
+        endif()
     endif()
 endif()
 
@@ -183,6 +202,33 @@ if (WHISPER_PERF)
     set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_PERF)
 endif()
 
+#
+# whisper.coreml - Core ML support
+#
+
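+# Objective-C wrapper around the Core ML encoder model (coreml/whisper-encoder.mm),
+# linked into the main whisper library when WHISPER_COREML is enabled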
+if (WHISPER_COREML)
+    set(TARGET whisper.coreml)
+
+    add_library(${TARGET}
+        coreml/whisper-encoder.h
+        coreml/whisper-encoder.mm
+        coreml/whisper-encoder-impl.h
+        coreml/whisper-encoder-impl.m
+        )
+
+    include(DefaultTargetOptions)
+
+    target_include_directories(${TARGET} PUBLIC
+        .
+        )
+
+    target_link_libraries(${TARGET} PRIVATE ${FOUNDATION_FRAMEWORK} ${COREML_FRAMEWORK})
+
+    set_target_properties(${TARGET} PROPERTIES
+        COMPILE_FLAGS "-fobjc-arc"
+        )
+endif()
+
 #
 # whisper - this is the main library of the project
 #
@@ -202,6 +248,10 @@ target_include_directories(${TARGET} PUBLIC
     .
     )
 
+if (WHISPER_COREML)
+    target_link_libraries(${TARGET} PRIVATE whisper.coreml)
+endif()
+
 if (MSVC)
     target_link_libraries(${TARGET} PRIVATE ${WHISPER_EXTRA_LIBS} ${CMAKE_THREAD_LIBS_INIT})
 
diff --git a/Makefile b/Makefile
index 929d45b9e82..7bf80e87e60 100644
--- a/Makefile
+++ b/Makefile
@@ -138,6 +138,10 @@ ifndef WHISPER_NO_ACCELERATE
 		LDFLAGS += -framework Accelerate
 	endif
 endif
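+# Core ML support: define WHISPER_USE_COREML and link against the Foundation and CoreML frameworks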
+ifdef WHISPER_COREML
+	CXXFLAGS += -DWHISPER_USE_COREML
+	LDFLAGS  += -framework Foundation -framework CoreML
+endif
 ifdef WHISPER_OPENBLAS
 	CFLAGS  += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas
 	LDFLAGS += -lopenblas
@@ -190,11 +194,23 @@ ggml.o: ggml.c ggml.h
 whisper.o: whisper.cpp whisper.h
 	$(CXX) $(CXXFLAGS) -c whisper.cpp -o whisper.o
 
-libwhisper.a: ggml.o whisper.o
-	$(AR) rcs libwhisper.a ggml.o whisper.o
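+# WHISPER_OBJ collects whisper.o plus, when WHISPER_COREML is set, the Objective-C encoder wrapper objects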
+ifndef WHISPER_COREML
+WHISPER_OBJ = whisper.o
+else
+whisper-encoder.o: coreml/whisper-encoder.mm coreml/whisper-encoder.h
+	$(CXX) -O3 -I . -c coreml/whisper-encoder.mm -o whisper-encoder.o
+
+whisper-encoder-impl.o: coreml/whisper-encoder-impl.m coreml/whisper-encoder-impl.h
+	$(CXX) -O3 -I . -fobjc-arc -c coreml/whisper-encoder-impl.m -o whisper-encoder-impl.o
+
+WHISPER_OBJ = whisper.o whisper-encoder.o whisper-encoder-impl.o
+endif
+
+libwhisper.a: ggml.o $(WHISPER_OBJ)
+	$(AR) rcs libwhisper.a ggml.o $(WHISPER_OBJ)
 
-libwhisper.so: ggml.o whisper.o
-	$(CXX) $(CXXFLAGS) -shared -o libwhisper.so ggml.o whisper.o $(LDFLAGS)
+libwhisper.so: ggml.o $(WHISPER_OBJ)
+	$(CXX) $(CXXFLAGS) -shared -o libwhisper.so ggml.o $(WHISPER_OBJ) $(LDFLAGS)
 
 clean:
 	rm -f *.o main stream command talk bench libwhisper.a libwhisper.so
@@ -208,21 +224,21 @@ CC_SDL=`sdl2-config --cflags --libs`
 SRC_COMMON = examples/common.cpp
 SRC_COMMON_SDL = examples/common-sdl.cpp
 
-main: examples/main/main.cpp $(SRC_COMMON) ggml.o whisper.o
-	$(CXX) $(CXXFLAGS) examples/main/main.cpp $(SRC_COMMON) ggml.o whisper.o -o main $(LDFLAGS)
+main: examples/main/main.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ)
+	$(CXX) $(CXXFLAGS) examples/main/main.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ) -o main $(LDFLAGS)
 	./main -h
 
-stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o
-	$(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o -o stream $(CC_SDL) $(LDFLAGS)
+stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
+	$(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o stream $(CC_SDL) $(LDFLAGS)
 
-command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o
-	$(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o -o command $(CC_SDL) $(LDFLAGS)
+command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
+	$(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS)
 
-talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o
-	$(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o -o talk $(CC_SDL) $(LDFLAGS)
+talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
+	$(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o talk $(CC_SDL) $(LDFLAGS)
 
-bench: examples/bench/bench.cpp ggml.o whisper.o
-	$(CXX) $(CXXFLAGS) examples/bench/bench.cpp ggml.o whisper.o -o bench $(LDFLAGS)
+bench: examples/bench/bench.cpp ggml.o $(WHISPER_OBJ)
+	$(CXX) $(CXXFLAGS) examples/bench/bench.cpp ggml.o $(WHISPER_OBJ) -o bench $(LDFLAGS)
 
 #
 # Audio samples
diff --git a/coreml/whisper-encoder-impl.h b/coreml/whisper-encoder-impl.h
new file mode 100644
index 00000000000..9395acb250f
--- /dev/null
+++ b/coreml/whisper-encoder-impl.h
@@ -0,0 +1,142 @@
+//
+// CoremlEncoder.h
+//
+// This file was automatically generated and should not be edited.
+//
+
+#import <Foundation/Foundation.h>
+#import <CoreML/CoreML.h>
+#include <stdint.h>
+#include <os/log.h>
+
+NS_ASSUME_NONNULL_BEGIN
+
+
+/// Model Prediction Input Type
+API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
+@interface CoremlEncoderInput : NSObject<MLFeatureProvider>
+
+/// melSegment as 1 × 80 × 3000 3-dimensional array of floats
+@property (readwrite, nonatomic, strong) MLMultiArray * melSegment;
+- (instancetype)init NS_UNAVAILABLE;
+- (instancetype)initWithMelSegment:(MLMultiArray *)melSegment NS_DESIGNATED_INITIALIZER;
+
+@end
+
+
+/// Model Prediction Output Type
+API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
+@interface CoremlEncoderOutput : NSObject<MLFeatureProvider>
+
+/// output as multidimensional array of floats
+@property (readwrite, nonatomic, strong) MLMultiArray * output;
+- (instancetype)init NS_UNAVAILABLE;
+- (instancetype)initWithOutput:(MLMultiArray *)output NS_DESIGNATED_INITIALIZER;
+
+@end
+
+
+/// Class for model loading and prediction
+API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
+@interface CoremlEncoder : NSObject
+@property (readonly, nonatomic, nullable) MLModel * model;
+
+/**
+    URL of the underlying .mlmodelc directory.
+*/
++ (nullable NSURL *)URLOfModelInThisBundle;
+
+/**
+    Initialize CoremlEncoder instance from an existing MLModel object.
+
+    Usually the application does not use this initializer unless it makes a subclass of CoremlEncoder.
+    Such application may want to use `-[MLModel initWithContentsOfURL:configuration:error:]` and `+URLOfModelInThisBundle` to create a MLModel object to pass-in.
+*/
+- (instancetype)initWithMLModel:(MLModel *)model NS_DESIGNATED_INITIALIZER;
+
+/**
+    Initialize CoremlEncoder instance with the model in this bundle.
+*/
+- (nullable instancetype)init;
+
+/**
+    Initialize CoremlEncoder instance with the model in this bundle.
+
+    @param configuration The model configuration object
+    @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
+*/
+- (nullable instancetype)initWithConfiguration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error;
+
+/**
+    Initialize CoremlEncoder instance from the model URL.
+
+    @param modelURL URL to the .mlmodelc directory for CoremlEncoder.
+    @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
+*/
+- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL error:(NSError * _Nullable __autoreleasing * _Nullable)error;
+
+/**
+    Initialize CoremlEncoder instance from the model URL.
+
+    @param modelURL URL to the .mlmodelc directory for CoremlEncoder.
+    @param configuration The model configuration object
+    @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
+*/
+- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error;
+
+/**
+    Construct CoremlEncoder instance asynchronously with configuration.
+    Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
+
+    @param configuration The model configuration
+    @param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid CoremlEncoder instance or NSError object.
+*/
++ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(CoremlEncoder * _Nullable model, NSError * _Nullable error))handler API_AVAILABLE(macos(11.0), ios(14.0), watchos(7.0), tvos(14.0)) __attribute__((visibility("hidden")));
+
+/**
+    Construct CoremlEncoder instance asynchronously with URL of .mlmodelc directory and optional configuration.
+
+    Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
+
+    @param modelURL The model URL.
+    @param configuration The model configuration
+    @param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid CoremlEncoder instance or NSError object.
+*/
++ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(CoremlEncoder * _Nullable model, NSError * _Nullable error))handler API_AVAILABLE(macos(11.0), ios(14.0), watchos(7.0), tvos(14.0)) __attribute__((visibility("hidden")));
+
+/**
+    Make a prediction using the standard interface
+    @param input an instance of CoremlEncoderInput to predict from
+    @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
+    @return the prediction as CoremlEncoderOutput
+*/
+- (nullable CoremlEncoderOutput *)predictionFromFeatures:(CoremlEncoderInput *)input error:(NSError * _Nullable __autoreleasing * _Nullable)error;
+
+/**
+    Make a prediction using the standard interface
+    @param input an instance of CoremlEncoderInput to predict from
+    @param options prediction options
+    @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
+    @return the prediction as CoremlEncoderOutput
+*/
+- (nullable CoremlEncoderOutput *)predictionFromFeatures:(CoremlEncoderInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error;
+
+/**
+    Make a prediction using the convenience interface
+    @param melSegment as 1 × 80 × 3000 3-dimensional array of floats:
+    @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
+    @return the prediction as CoremlEncoderOutput
+*/
+- (nullable CoremlEncoderOutput *)predictionFromMelSegment:(MLMultiArray *)melSegment error:(NSError * _Nullable __autoreleasing * _Nullable)error;
+
+/**
+    Batch prediction
+    @param inputArray array of CoremlEncoderInput instances to obtain predictions from
+    @param options prediction options
+    @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
+    @return the predictions as NSArray<CoremlEncoderOutput *>
+*/
+- (nullable NSArray<CoremlEncoderOutput *> *)predictionsFromInputs:(NSArray<CoremlEncoderInput*> *)inputArray options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error;
+@end
+
+NS_ASSUME_NONNULL_END
diff --git a/coreml/whisper-encoder-impl.m b/coreml/whisper-encoder-impl.m
new file mode 100644
index 00000000000..9d3a08b8d0b
--- /dev/null
+++ b/coreml/whisper-encoder-impl.m
@@ -0,0 +1,197 @@
+//
+// CoremlEncoder.m
+//
+// This file was automatically generated and should not be edited.
+//
+
+#if !__has_feature(objc_arc)
+#error This file must be compiled with automatic reference counting enabled (-fobjc-arc)
+#endif
+
+#import "whisper-encoder-impl.h"
+
+@implementation CoremlEncoderInput
+
+- (instancetype)initWithMelSegment:(MLMultiArray *)melSegment {
+    self = [super init];
+    if (self) {
+        _melSegment = melSegment;
+    }
+    return self;
+}
+
+- (NSSet<NSString *> *)featureNames {
+    return [NSSet setWithArray:@[@"melSegment"]];
+}
+
+- (nullable MLFeatureValue *)featureValueForName:(NSString *)featureName {
+    if ([featureName isEqualToString:@"melSegment"]) {
+        return [MLFeatureValue featureValueWithMultiArray:self.melSegment];
+    }
+    return nil;
+}
+
+@end
+
+@implementation CoremlEncoderOutput
+
+- (instancetype)initWithOutput:(MLMultiArray *)output {
+    self = [super init];
+    if (self) {
+        _output = output;
+    }
+    return self;
+}
+
+- (NSSet<NSString *> *)featureNames {
+    return [NSSet setWithArray:@[@"output"]];
+}
+
+- (nullable MLFeatureValue *)featureValueForName:(NSString *)featureName {
+    if ([featureName isEqualToString:@"output"]) {
+        return [MLFeatureValue featureValueWithMultiArray:self.output];
+    }
+    return nil;
+}
+
+@end
+
+@implementation CoremlEncoder
+
+
+/**
+    URL of the underlying .mlmodelc directory.
+*/
++ (nullable NSURL *)URLOfModelInThisBundle {
+    NSString *assetPath = [[NSBundle bundleForClass:[self class]] pathForResource:@"CoremlEncoder" ofType:@"mlmodelc"];
+    if (nil == assetPath) { os_log_error(OS_LOG_DEFAULT, "Could not load CoremlEncoder.mlmodelc in the bundle resource"); return nil; }
+    return [NSURL fileURLWithPath:assetPath];
+}
+
+
+/**
+    Initialize CoremlEncoder instance from an existing MLModel object.
+
+    Usually the application does not use this initializer unless it makes a subclass of CoremlEncoder.
+    Such application may want to use `-[MLModel initWithContentsOfURL:configuration:error:]` and `+URLOfModelInThisBundle` to create a MLModel object to pass-in.
+*/
+- (instancetype)initWithMLModel:(MLModel *)model {
+    self = [super init];
+    if (!self) { return nil; }
+    _model = model;
+    if (_model == nil) { return nil; }
+    return self;
+}
+
+
+/**
+    Initialize CoremlEncoder instance with the model in this bundle.
+*/
+- (nullable instancetype)init {
+    return [self initWithContentsOfURL:(NSURL * _Nonnull)self.class.URLOfModelInThisBundle error:nil];
+}
+
+
+/**
+    Initialize CoremlEncoder instance with the model in this bundle.
+
+    @param configuration The model configuration object
+    @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
+*/
+- (nullable instancetype)initWithConfiguration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error {
+    return [self initWithContentsOfURL:(NSURL * _Nonnull)self.class.URLOfModelInThisBundle configuration:configuration error:error];
+}
+
+
+/**
+    Initialize CoremlEncoder instance from the model URL.
+
+    @param modelURL URL to the .mlmodelc directory for CoremlEncoder.
+    @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
+*/
+- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL error:(NSError * _Nullable __autoreleasing * _Nullable)error {
+    MLModel *model = [MLModel modelWithContentsOfURL:modelURL error:error];
+    if (model == nil) { return nil; }
+    return [self initWithMLModel:model];
+}
+
+
+/**
+    Initialize CoremlEncoder instance from the model URL.
+
+    @param modelURL URL to the .mlmodelc directory for CoremlEncoder.
+    @param configuration The model configuration object
+    @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
+*/
+- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error {
+    MLModel *model = [MLModel modelWithContentsOfURL:modelURL configuration:configuration error:error];
+    if (model == nil) { return nil; }
+    return [self initWithMLModel:model];
+}
+
+
+/**
+    Construct CoremlEncoder instance asynchronously with configuration.
+    Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
+
+    @param configuration The model configuration
+    @param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid CoremlEncoder instance or NSError object.
+*/
++ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(CoremlEncoder * _Nullable model, NSError * _Nullable error))handler {
+    [self loadContentsOfURL:(NSURL * _Nonnull)[self URLOfModelInThisBundle]
+              configuration:configuration
+          completionHandler:handler];
+}
+
+
+/**
+    Construct CoremlEncoder instance asynchronously with URL of .mlmodelc directory and optional configuration.
+
+    Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
+
+    @param modelURL The model URL.
+    @param configuration The model configuration
+    @param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid CoremlEncoder instance or NSError object.
+*/
++ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(CoremlEncoder * _Nullable model, NSError * _Nullable error))handler {
+    [MLModel loadContentsOfURL:modelURL
+                 configuration:configuration
+             completionHandler:^(MLModel *model, NSError *error) {
+        if (model != nil) {
+            CoremlEncoder *typedModel = [[CoremlEncoder alloc] initWithMLModel:model];
+            handler(typedModel, nil);
+        } else {
+            handler(nil, error);
+        }
+    }];
+}
+
+- (nullable CoremlEncoderOutput *)predictionFromFeatures:(CoremlEncoderInput *)input error:(NSError * _Nullable __autoreleasing * _Nullable)error {
+    return [self predictionFromFeatures:input options:[[MLPredictionOptions alloc] init] error:error];
+}
+
+- (nullable CoremlEncoderOutput *)predictionFromFeatures:(CoremlEncoderInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error {
+    id<MLFeatureProvider> outFeatures = [self.model predictionFromFeatures:input options:options error:error];
+    if (!outFeatures) { return nil; }
+    return [[CoremlEncoderOutput alloc] initWithOutput:(MLMultiArray *)[outFeatures featureValueForName:@"output"].multiArrayValue];
+}
+
+- (nullable CoremlEncoderOutput *)predictionFromMelSegment:(MLMultiArray *)melSegment error:(NSError * _Nullable __autoreleasing * _Nullable)error {
+    CoremlEncoderInput *input_ = [[CoremlEncoderInput alloc] initWithMelSegment:melSegment];
+    return [self predictionFromFeatures:input_ error:error];
+}
+
+- (nullable NSArray<CoremlEncoderOutput *> *)predictionsFromInputs:(NSArray<CoremlEncoderInput*> *)inputArray options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error {
+    id<MLBatchProvider> inBatch = [[MLArrayBatchProvider alloc] initWithFeatureProviderArray:inputArray];
+    id<MLBatchProvider> outBatch = [self.model predictionsFromBatch:inBatch options:options error:error];
+    if (!outBatch) { return nil; }
+    NSMutableArray<CoremlEncoderOutput*> *results = [NSMutableArray arrayWithCapacity:(NSUInteger)outBatch.count];
+    for (NSInteger i = 0; i < outBatch.count; i++) {
+        id<MLFeatureProvider> resultProvider = [outBatch featuresAtIndex:i];
+        CoremlEncoderOutput * result = [[CoremlEncoderOutput alloc] initWithOutput:(MLMultiArray *)[resultProvider featureValueForName:@"output"].multiArrayValue];
+        [results addObject:result];
+    }
+    return results;
+}
+
+@end
diff --git a/coreml/whisper-encoder.h b/coreml/whisper-encoder.h
new file mode 100644
index 00000000000..84bbe416505
--- /dev/null
+++ b/coreml/whisper-encoder.h
@@ -0,0 +1,22 @@
+// Wrapper of the Core ML Whisper Encoder model
+//
+// Code is derived from the work of GitHub user @wangchou
+// ref: https://github.com/wangchou/callCoreMLFromCpp
+
+#if __cplusplus
+extern "C" {
+#endif
+
+struct whisper_coreml_context;
+
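+// load a compiled Core ML model (.mlmodelc) from path_model; returns NULL on failure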
+struct whisper_coreml_context * whisper_coreml_init(const char * path_model);
+void whisper_coreml_free(struct whisper_coreml_context * ctx);
+
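+// run the Core ML encoder on a mel spectrogram of shape [1, 80, 3000] and write the result to out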
+void whisper_coreml_encode(
+        const whisper_coreml_context * ctx,
+                               float * mel,
+                               float * out);
+
+#if __cplusplus
+}
+#endif
diff --git a/coreml/whisper-encoder.mm b/coreml/whisper-encoder.mm
new file mode 100644
index 00000000000..09091c2003c
--- /dev/null
+++ b/coreml/whisper-encoder.mm
@@ -0,0 +1,61 @@
+#import "coreml/whisper-encoder.h"
+#import "coreml/whisper-encoder-impl.h"
+
+#import <CoreML/CoreML.h>
+
+#include <stdlib.h>
+
+#if __cplusplus
+extern "C" {
+#endif
+
+struct whisper_coreml_context {
+    const void * data;
+};
+
+struct whisper_coreml_context * whisper_coreml_init(const char * path_model) {
+    NSString * path_model_str = [[NSString alloc] initWithUTF8String:path_model];
+
+    NSURL * url_model = [NSURL fileURLWithPath: path_model_str];
+
+    const void * data = CFBridgingRetain([[CoremlEncoder alloc] initWithContentsOfURL:url_model error:nil]);
+
+    if (data == NULL) {
+        return NULL;
+    }
+
+    whisper_coreml_context * ctx = new whisper_coreml_context;
+
+    ctx->data = data;
+
+    return ctx;
+}
+
+void whisper_coreml_free(struct whisper_coreml_context * ctx) {
+    CFRelease(ctx->data);
+    delete ctx;
+}
+
+void whisper_coreml_encode(
+        const whisper_coreml_context * ctx,
+                               float * mel,
+                               float * out) {
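+    // wrap the caller's mel buffer (1 x 80 x 3000 floats, row-major) in an MLMultiArray without copying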
+    MLMultiArray * inMultiArray = [
+        [MLMultiArray alloc] initWithDataPointer: mel
+                                           shape: @[@1, @80, @3000]
+                                        dataType: MLMultiArrayDataTypeFloat32
+                                         strides: @[@(240000), @(3000), @1]
+                                     deallocator: nil
+                                           error: nil
+    ];
+
+    CoremlEncoderOutput * outCoreML = [(__bridge id) ctx->data predictionFromMelSegment:inMultiArray error:nil];
+
+    MLMultiArray * outMA = outCoreML.output;
+
+    memcpy(out, outMA.dataPointer, outMA.count * sizeof(float));
+}
+
+#if __cplusplus
+}
+#endif
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index 9f804a7aeac..a4f37359c1e 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -63,4 +63,5 @@ else()
     add_subdirectory(command)
     add_subdirectory(bench)
     add_subdirectory(talk)
+    add_subdirectory(talk.llama)
 endif()
diff --git a/examples/talk.llama/.gitignore b/examples/talk.llama/.gitignore
new file mode 100644
index 00000000000..6b780a24045
--- /dev/null
+++ b/examples/talk.llama/.gitignore
@@ -0,0 +1,2 @@
+eleven-labs.py
+audio.mp3
diff --git a/examples/talk.llama/CMakeLists.txt b/examples/talk.llama/CMakeLists.txt
new file mode 100644
index 00000000000..c278deb8dda
--- /dev/null
+++ b/examples/talk.llama/CMakeLists.txt
@@ -0,0 +1,12 @@
+if (WHISPER_SUPPORT_SDL2)
+    # talk.llama
+    set(TARGET talk-llama)
+
+    # TODO: this is temporary
+    #       need to export ggml symbols for MSVC, but too lazy ..
+    add_executable(${TARGET} talk-llama.cpp llama.cpp)
+
+    include(DefaultTargetOptions)
+
+    target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
+endif ()
diff --git a/examples/talk.llama/README.md b/examples/talk.llama/README.md
new file mode 100644
index 00000000000..821d40dca60
--- /dev/null
+++ b/examples/talk.llama/README.md
@@ -0,0 +1,2 @@
+# talk.llama
+
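+Talk with a LLaMA model: speech recognition via whisper.cpp, text generation via llama.cpp (requires SDL2 for audio capture).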
diff --git a/examples/talk.llama/llama.cpp b/examples/talk.llama/llama.cpp
new file mode 100644
index 00000000000..2bd520353ef
--- /dev/null
+++ b/examples/talk.llama/llama.cpp
@@ -0,0 +1,1865 @@
+#include "llama.h"
+
+#include "ggml.h"
+
+#include <cinttypes>
+#include <fstream>
+#include <random>
+#include <map>
+#include <unordered_map>
+#include <queue>
+#include <regex>
+#include <cassert>
+#include <cstring>
+
+#define LLAMA_USE_SCRATCH
+#define LLAMA_MAX_SCRATCH_BUFFERS 16
+
+#define LLAMA_ASSERT(x) \
+    do { \
+        if (!(x)) { \
+            fprintf(stderr, "LLAMA_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
+            abort(); \
+        } \
+    } while (0)
+
+
+// determine the number of model parts based on the embedding dimension (n_embd)
+static const std::unordered_map<int, int> LLAMA_N_PARTS = {
+    { 4096, 1 },
+    { 5120, 2 },
+    { 6656, 4 },
+    { 8192, 8 },
+};
+
+// available llama models
+enum e_model {
+    MODEL_UNKNOWN,
+    MODEL_7B,
+    MODEL_13B,
+    MODEL_30B,
+    MODEL_65B,
+};
+
+static const size_t MB = 1024*1024;
+
+// computed for n_ctx == 2048
+// TODO: dynamically determine these sizes
+//       needs modifications in ggml
+
+static const std::map<e_model, size_t> MEM_REQ_SCRATCH0 = {
+    { MODEL_7B,    512ull*MB },
+    { MODEL_13B,   512ull*MB },
+    { MODEL_30B,   512ull*MB },
+    { MODEL_65B,   512ull*MB },
+};
+
+static const std::map<e_model, size_t> MEM_REQ_SCRATCH1 = {
+    { MODEL_7B,    512ull*MB },
+    { MODEL_13B,   512ull*MB },
+    { MODEL_30B,   512ull*MB },
+    { MODEL_65B,   512ull*MB },
+};
+
+// 2*n_embd*n_ctx*n_layer*sizeof(float16)
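+// e.g. 7B at n_ctx == 2048: 2*4096*2048*32*2 bytes ≈ 1024 MB (the values below include a small margin)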
+static const std::map<e_model, size_t> MEM_REQ_KV_SELF = {
+    { MODEL_7B,   1026ull*MB },
+    { MODEL_13B,  1608ull*MB },
+    { MODEL_30B,  3124ull*MB },
+    { MODEL_65B,  5120ull*MB },
+};
+
+// this is mostly needed for temporary mul_mat buffers to dequantize the data
+// not actually needed if BLAS is disabled
+static const std::map<e_model, size_t> MEM_REQ_EVAL = {
+    { MODEL_7B,   768ull*MB },
+    { MODEL_13B, 1024ull*MB },
+    { MODEL_30B, 1280ull*MB },
+    { MODEL_65B, 1536ull*MB },
+};
+
+// default hparams (LLaMA 7B)
+struct llama_hparams {
+    int32_t n_vocab = 32000;
+    int32_t n_ctx   = 512;   // overridden by the n_ctx value passed to llama_model_load
+    int32_t n_embd  = 4096;
+    int32_t n_mult  = 256;
+    int32_t n_head  = 32;
+    int32_t n_layer = 32;
+    int32_t n_rot   = 64;
+    int32_t f16     = 1;
+};
+
+struct llama_layer {
+    // normalization
+    struct ggml_tensor * attention_norm;
+
+    // attention
+    struct ggml_tensor * wq;
+    struct ggml_tensor * wk;
+    struct ggml_tensor * wv;
+    struct ggml_tensor * wo;
+
+    // normalization
+    struct ggml_tensor * ffn_norm;
+
+    // ff
+    struct ggml_tensor * w1;
+    struct ggml_tensor * w2;
+    struct ggml_tensor * w3;
+};
+
+struct llama_kv_cache {
+    struct ggml_tensor * k;
+    struct ggml_tensor * v;
+
+    struct ggml_context * ctx;
+
+    std::vector<uint8_t> buf;
+
+    int n; // number of tokens currently in the cache
+};
+
+struct llama_model {
+    e_model type = MODEL_UNKNOWN;
+
+    llama_hparams hparams;
+
+    struct ggml_tensor * tok_embeddings;
+
+    struct ggml_tensor * norm;
+    struct ggml_tensor * output;
+
+    std::vector<llama_layer> layers;
+
+    // context
+    struct ggml_context * ctx;
+
+    // key + value cache for the self attention
+    // TODO: move to llama_state
+    struct llama_kv_cache kv_self;
+
+    // the model memory buffer
+    std::vector<uint8_t> buf;
+
+    // tensors
+    int n_loaded;
+    std::unordered_map<std::string, struct ggml_tensor *> tensors;
+};
+
+struct llama_vocab {
+    using id    = int32_t;
+    using token = std::string;
+
+    struct token_score {
+        token tok;
+        float score;
+    };
+
+    std::unordered_map<token, id> token_to_id;
+    std::vector<token_score> id_to_token;
+};
+
+struct llama_context {
+    std::mt19937 rng;
+
+    int64_t t_load_us = 0;
+    int64_t t_start_us = 0;
+
+    int64_t t_sample_us = 0;
+    int64_t t_eval_us   = 0;
+    int64_t t_p_eval_us = 0;
+
+    int32_t n_sample = 0; // number of tokens sampled
+    int32_t n_eval   = 0; // number of eval calls
+    int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
+
+    llama_model model;
+    llama_vocab vocab;
+
+    size_t mem_per_token = 0;
+
+    // decode output (2-dimensional array: [n_tokens][n_vocab])
+    std::vector<float> logits;
+    bool logits_all = false;
+
+    // input embedding (1-dimensional array: [n_embd])
+    std::vector<float> embedding;
+
+    // memory buffers used to evaluate the model
+    // TODO: move to llama_state
+    std::vector<uint8_t> buf_compute;
+    std::vector<uint8_t> buf_scratch[LLAMA_MAX_SCRATCH_BUFFERS];
+
+    int    buf_last = 0;
+    size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = { 0 };
+
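+    // redirect ggml allocations to scratch buffer i (-1 restores the default buffer)
+    // and record how much of the previously active buffer was used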
+    void use_buf(struct ggml_context * ctx, int i) {
+#if defined(LLAMA_USE_SCRATCH)
+        size_t last_size = 0;
+
+        if (i == -1) {
+            last_size = ggml_set_scratch(ctx, { 0, 0, nullptr, });
+        } else {
+            auto & buf = buf_scratch[i];
+            last_size = ggml_set_scratch(ctx, { 0, buf.size(), buf.data(), });
+        }
+
+        if (buf_last >= 0) {
+            buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size);
+        }
+
+        buf_last = i;
+#else
+        (void) i;
+        (void) ctx;
+#endif
+    }
+
+    size_t get_buf_max_mem(int i) const {
+#if defined(LLAMA_USE_SCRATCH)
+        return buf_max_size[i];
+#else
+        (void) i;
+        return 0;
+#endif
+    }
+};
+
+//
+// kv cache
+//
+
+static bool kv_cache_init(
+        const struct llama_hparams & hparams,
+             struct llama_kv_cache & cache,
+                         ggml_type   wtype,
+                               int   n_ctx) {
+    const int n_embd  = hparams.n_embd;
+    const int n_layer = hparams.n_layer;
+
+    const int n_mem      = n_layer*n_ctx;
+    const int n_elements = n_embd*n_mem;
+
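+    // one K and one V tensor of n_elements each, plus ~2 MB of ggml object overhead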
+    cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB);
+
+    struct ggml_init_params params;
+    params.mem_size   = cache.buf.size();
+    params.mem_buffer = cache.buf.data();
+
+    cache.ctx = ggml_init(params);
+
+    if (!cache.ctx) {
+        fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__);
+        return false;
+    }
+
+    cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
+    cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
+
+    return true;
+}
+
+static void kv_cache_free(struct llama_kv_cache & cache) {
+    if (cache.ctx) {
+        ggml_free(cache.ctx);
+        cache.ctx = nullptr;
+    }
+}
+
+struct llama_context_params llama_context_default_params() {
+    struct llama_context_params result = {
+        /*.n_ctx                       =*/ 512,
+        /*.n_parts                     =*/ -1,
+        /*.seed                        =*/ 0,
+        /*.f16_kv                      =*/ false,
+        /*.logits_all                  =*/ false,
+        /*.vocab_only                  =*/ false,
+        /*.use_mlock                   =*/ false,
+        /*.embedding                   =*/ false,
+        /*.progress_callback           =*/ nullptr,
+        /*.progress_callback_user_data =*/ nullptr,
+    };
+
+    return result;
+}
+
+//
+// model loading
+//
+
+static bool llama_model_load(
+        const std::string & fname,
+        llama_context & lctx,
+        int n_ctx,
+        int n_parts,
+        ggml_type memory_type,
+        bool vocab_only,
+        llama_progress_callback progress_callback,
+        void *progress_callback_user_data) {
+    fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
+
+    const int64_t t_start_us = ggml_time_us();
+
+    lctx.t_start_us = t_start_us;
+
+    std::vector<char> f_buf(1024*1024);
+
+    auto & model = lctx.model;
+    auto & vocab = lctx.vocab;
+
+    auto fin = std::ifstream(fname, std::ios::binary);
+    fin.rdbuf()->pubsetbuf(f_buf.data(), f_buf.size());
+    if (!fin) {
+        fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
+        return false;
+    }
+
+    // verify magic
+    {
+        uint32_t magic;
+        fin.read((char *) &magic, sizeof(magic));
+        if (magic == LLAMA_FILE_MAGIC_UNVERSIONED) {
+            fprintf(stderr, "%s: invalid model file '%s' (too old, regenerate your model files!)\n",
+                    __func__, fname.c_str());
+            return false;
+        }
+        if (magic != LLAMA_FILE_MAGIC) {
+            fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
+            return false;
+        }
+
+        uint32_t format_version;
+        fin.read((char *) &format_version, sizeof(format_version));
+
+        if (format_version != LLAMA_FILE_VERSION) {
+            fprintf(stderr, "%s: invalid model file '%s' (unsupported format version %" PRIu32 ", expected %d)\n",
+                    __func__, fname.c_str(), format_version, LLAMA_FILE_VERSION);
+            return false;
+        }
+    }
+
+    int n_ff = 0;
+
+    // load hparams
+    {
+        auto & hparams = model.hparams;
+
+        fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
+        //fin.read((char *) &hparams.n_ctx,   sizeof(hparams.n_ctx));
+        fin.read((char *) &hparams.n_embd,  sizeof(hparams.n_embd));
+        fin.read((char *) &hparams.n_mult,  sizeof(hparams.n_mult));
+        fin.read((char *) &hparams.n_head,  sizeof(hparams.n_head));
+        fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
+        fin.read((char *) &hparams.n_rot,   sizeof(hparams.n_rot));
+        fin.read((char *) &hparams.f16,     sizeof(hparams.f16));
+
+        hparams.n_ctx = n_ctx;
+
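+        // feed-forward width: 2/3 of 4*n_embd, rounded up to a multiple of n_mult (7B: 10922 -> 11008)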
+        n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;
+
+        if (n_parts < 1) {
+            n_parts = LLAMA_N_PARTS.at(hparams.n_embd);
+        }
+
+        // temp warning to tell the user to use "--n_parts"
+        if (hparams.f16 == 4 && n_parts != 1) {
+            fprintf(stderr, "%s: GPTQ model detected - are you sure n_parts should be %d? we normally expect it to be 1\n", __func__, n_parts);
+            fprintf(stderr, "%s: use '--n_parts 1' if necessary\n", __func__);
+        }
+
+        if (hparams.n_layer == 32) {
+            model.type = e_model::MODEL_7B;
+        }
+
+        if (hparams.n_layer == 40) {
+            model.type = e_model::MODEL_13B;
+        }
+
+        if (hparams.n_layer == 60) {
+            model.type = e_model::MODEL_30B;
+        }
+
+        if (hparams.n_layer == 80) {
+            model.type = e_model::MODEL_65B;
+        }
+
+        fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab);
+        fprintf(stderr, "%s: n_ctx   = %d\n", __func__, hparams.n_ctx);
+        fprintf(stderr, "%s: n_embd  = %d\n", __func__, hparams.n_embd);
+        fprintf(stderr, "%s: n_mult  = %d\n", __func__, hparams.n_mult);
+        fprintf(stderr, "%s: n_head  = %d\n", __func__, hparams.n_head);
+        fprintf(stderr, "%s: n_layer = %d\n", __func__, hparams.n_layer);
+        fprintf(stderr, "%s: n_rot   = %d\n", __func__, hparams.n_rot);
+        fprintf(stderr, "%s: f16     = %d\n", __func__, hparams.f16);
+        fprintf(stderr, "%s: n_ff    = %d\n", __func__, n_ff);
+        fprintf(stderr, "%s: n_parts = %d\n", __func__, n_parts);
+        fprintf(stderr, "%s: type    = %d\n", __func__, model.type);
+    }
+
+    // load vocab
+    {
+        std::string word;
+        vocab.id_to_token.resize(model.hparams.n_vocab);
+        std::vector<char> tmp(64);
+
+        for (int i = 0; i < model.hparams.n_vocab; i++) {
+            uint32_t len;
+            fin.read((char *) &len, sizeof(len));
+
+            word.resize(len);
+            if (len > 0) {
+                tmp.resize(len);
+                fin.read(tmp.data(), len);
+                word.assign(tmp.data(), len);
+            } else {
+                word.clear();
+            }
+
+            float score;
+            fin.read((char *) &score, sizeof(score));
+
+            vocab.token_to_id[word] = i;
+
+            auto &tok_score = vocab.id_to_token[i];
+            tok_score.tok = word;
+            tok_score.score = score;
+        }
+    }
+
+    if (vocab_only) {
+        return true;
+    }
+
+    // for the big tensors, we have the option to store the data in 16-bit floats or quantized
+    // in order to save memory and also to speed up the computation
+    // wtype is for per-layer weights, while vtype is for other weights
+    ggml_type wtype, vtype;
+    switch (model.hparams.f16) {
+        case 0: wtype = vtype = GGML_TYPE_F32;  break;
+        case 1: wtype = vtype = GGML_TYPE_F16;  break;
+        case 2: wtype = vtype = GGML_TYPE_Q4_0; break;
+        case 3: wtype = vtype = GGML_TYPE_Q4_1; break;
+        case 4: wtype = GGML_TYPE_Q4_1; vtype = GGML_TYPE_F16; break;
+        default:
+                {
+                    fprintf(stderr, "%s: invalid model file '%s' (bad f16 value %d)\n",
+                            __func__, fname.c_str(), model.hparams.f16);
+                    return false;
+                }
+    }
+
+    auto & ctx = model.ctx;
+
+    size_t ctx_size = 0;
+
+    {
+        const auto & hparams = model.hparams;
+
+        const int n_embd  = hparams.n_embd;
+        const int n_layer = hparams.n_layer;
+        const int n_ctx   = hparams.n_ctx;
+        const int n_vocab = hparams.n_vocab;
+
+        ctx_size += n_embd*n_vocab*ggml_type_sizef(vtype); // tok_embeddings
+
+        ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // norm
+
+        ctx_size += n_embd*n_vocab*ggml_type_sizef(vtype); // output
+
+        ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // attention_norm
+
+        ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wq
+        ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wk
+        ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wv
+        ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wo
+
+        ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ffn_norm
+
+        ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w1
+        ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w2
+        ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w3
+
+        ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(memory_type); // memory_k
+        ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(memory_type); // memory_v
+
+        ctx_size += (5 + 10*n_layer)*256; // object overhead
+
+        fprintf(stderr, "%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
+    }
+
+    // print memory requirements
+    {
+        const size_t scale = memory_type == GGML_TYPE_F32 ? 2 : 1;
+
+        // this is the total memory required to run the inference
+        const size_t mem_required =
+            ctx_size +
+            MEM_REQ_SCRATCH0.at(model.type) +
+            MEM_REQ_SCRATCH1.at(model.type) +
+            MEM_REQ_EVAL.at    (model.type);
+
+        // this is the memory required by one llama_state
+        const size_t mem_required_state =
+            scale*MEM_REQ_KV_SELF.at(model.type);
+
+        fprintf(stderr, "%s: mem required  = %7.2f MB (+ %7.2f MB per state)\n", __func__,
+                mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0);
+    }
+
+    // create the ggml context
+    {
+        lctx.model.buf.resize(ctx_size);
+
+        struct ggml_init_params params = {
+            /*.mem_size   =*/ lctx.model.buf.size(),
+            /*.mem_buffer =*/ lctx.model.buf.data(),
+        };
+
+        model.ctx = ggml_init(params);
+        if (!model.ctx) {
+            fprintf(stderr, "%s: ggml_init() failed\n", __func__);
+            return false;
+        }
+    }
+
+    // prepare memory for the weights
+    {
+        const auto & hparams = model.hparams;
+
+        const int n_embd  = hparams.n_embd;
+        const int n_layer = hparams.n_layer;
+        const int n_vocab = hparams.n_vocab;
+
+        model.layers.resize(n_layer);
+
+        model.tok_embeddings = ggml_new_tensor_2d(ctx, vtype, n_embd, n_vocab);
+
+        model.norm   = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
+        model.output = ggml_new_tensor_2d(ctx, vtype,         n_embd, n_vocab);
+
+        // map by name
+        model.tensors["tok_embeddings.weight"] = model.tok_embeddings;
+
+        model.tensors["norm.weight"]   = model.norm;
+        model.tensors["output.weight"] = model.output;
+
+        for (int i = 0; i < n_layer; ++i) {
+            auto & layer = model.layers[i];
+
+            layer.attention_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
+
+            layer.wq = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
+            layer.wk = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
+            layer.wv = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
+            layer.wo = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
+
+            layer.ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
+
+            layer.w1 = ggml_new_tensor_2d(ctx, wtype, n_embd,   n_ff);
+            layer.w2 = ggml_new_tensor_2d(ctx, wtype,   n_ff, n_embd);
+            layer.w3 = ggml_new_tensor_2d(ctx, wtype, n_embd,   n_ff);
+
+            // map by name
+            model.tensors["layers." + std::to_string(i) + ".attention_norm.weight"] = layer.attention_norm;
+
+            model.tensors["layers." + std::to_string(i) + ".attention.wq.weight"] = layer.wq;
+            model.tensors["layers." + std::to_string(i) + ".attention.wk.weight"] = layer.wk;
+            model.tensors["layers." + std::to_string(i) + ".attention.wv.weight"] = layer.wv;
+            model.tensors["layers." + std::to_string(i) + ".attention.wo.weight"] = layer.wo;
+
+            model.tensors["layers." + std::to_string(i) + ".ffn_norm.weight"] = layer.ffn_norm;
+
+            model.tensors["layers." + std::to_string(i) + ".feed_forward.w1.weight"] = layer.w1;
+            model.tensors["layers." + std::to_string(i) + ".feed_forward.w2.weight"] = layer.w2;
+            model.tensors["layers." + std::to_string(i) + ".feed_forward.w3.weight"] = layer.w3;
+        }
+    }
+
+    const size_t file_offset = fin.tellg();
+
+    fin.close();
+
+    std::vector<uint8_t> tmp;
+
+    if (progress_callback) {
+        progress_callback(0.0, progress_callback_user_data);
+    }
+
+    for (int i = 0; i < n_parts; ++i) {
+        const int part_id = i;
+        //const int part_id = n_parts - i - 1;
+
+        std::string fname_part = fname;
+        if (i > 0) {
+            fname_part += "." + std::to_string(i);
+        }
+
+        fprintf(stderr, "%s: loading model part %d/%d from '%s'\n", __func__, i+1, n_parts, fname_part.c_str());
+
+        fin = std::ifstream(fname_part, std::ios::binary);
+        fin.rdbuf()->pubsetbuf(f_buf.data(), f_buf.size());
+
+        fin.seekg(0, fin.end);
+        const size_t file_size = fin.tellg();
+
+        fin.seekg(file_offset);
+
+        // load weights
+        {
+            size_t total_size = 0;
+
+            model.n_loaded = 0;
+
+            fprintf(stderr, "%s: ", __func__);
+
+            while (true) {
+                int32_t n_dims;
+                int32_t length;
+                int32_t ftype;
+
+                fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
+                fin.read(reinterpret_cast<char *>(&length), sizeof(length));
+                fin.read(reinterpret_cast<char *>(&ftype),  sizeof(ftype));
+
+                if (fin.eof()) {
+                    break;
+                }
+
+                int32_t nelements = 1;
+                int32_t ne[2] = { 1, 1 };
+                for (int i = 0; i < n_dims; ++i) {
+                    fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
+                    nelements *= ne[i];
+                }
+
+                std::string name(length, 0);
+                fin.read(&name[0], length);
+
+                if (model.tensors.find(name.data()) == model.tensors.end()) {
+                    fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
+                    return false;
+                }
+
+                // split_type = 0: split by columns
+                // split_type = 1: split by rows
+                int split_type = 0;
+
+                // split_type = 0:
+                // regex:
+                //   - tok_embeddings.*
+                //   - layers.*.attention.wo.weight
+                //   - layers.*.feed_forward.w2.weight
+
+                // split_type = 1:
+                // regex:
+                //   - output.*
+                //   - layers.*.attention.wq.weight
+                //   - layers.*.attention.wk.weight
+                //   - layers.*.attention.wv.weight
+                //   - layers.*.feed_forward.w1.weight
+                //   - layers.*.feed_forward.w3.weight
+                if (name.find("tok_embeddings") != std::string::npos) {
+                    split_type = 0;
+                } else if (name.find("layers") != std::string::npos) {
+                    if (name.find("attention.wo.weight") != std::string::npos) {
+                        split_type = 0;
+                    } else if (name.find("feed_forward.w2.weight") != std::string::npos) {
+                        split_type = 0;
+                    } else {
+                        split_type = 1;
+                    }
+                } else if (name.find("output") != std::string::npos) {
+                    split_type = 1;
+                }
+
+                auto tensor = model.tensors[name.data()];
+
+                if (n_dims == 1) {
+                    if (ggml_nelements(tensor) != nelements) {
+                        fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
+                        return false;
+                    }
+                } else {
+                    if (ggml_nelements(tensor)/n_parts != nelements) {
+                        fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
+                        return false;
+                    }
+                }
+
+                if (n_dims == 1) {
+                    if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
+                        fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
+                                __func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]);
+                        return false;
+                    }
+                } else {
+                    if (split_type == 0) {
+                        if (tensor->ne[0]/n_parts != ne[0] || tensor->ne[1] != ne[1]) {
+                            fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
+                                    __func__, name.data(), tensor->ne[0]/n_parts, tensor->ne[1], ne[0], ne[1]);
+                            return false;
+                        }
+                    } else {
+                        if (tensor->ne[0] != ne[0] || tensor->ne[1]/n_parts != ne[1]) {
+                            fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
+                                    __func__, name.data(), tensor->ne[0], tensor->ne[1]/n_parts, ne[0], ne[1]);
+                            return false;
+                        }
+                    }
+                }
+
+                if (0) {
+                    static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
+                    fprintf(stderr, "%24s - [%5d, %5d], type = %6s, split = %d\n", name.data(), ne[0], ne[1], ftype_str[ftype], split_type);
+                }
+
+                size_t bpe = 0;
+
+                switch (ftype) {
+                    case 0: bpe = ggml_type_size(GGML_TYPE_F32);  break;
+                    case 1: bpe = ggml_type_size(GGML_TYPE_F16);  break;
+                    case 2: bpe = ggml_type_size(GGML_TYPE_Q4_0); assert(ne[0] % 64 == 0); break;
+                    case 3: bpe = ggml_type_size(GGML_TYPE_Q4_1); assert(ne[0] % 64 == 0); break;
+                    default:
+                            {
+                                fprintf(stderr, "%s: unknown ftype %d in model file\n", __func__, ftype);
+                                return false;
+                            }
+                };
+
+                if (n_dims == 1 || n_parts == 1) {
+                    if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
+                        fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
+                                __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
+                        return false;
+                    }
+
+                    if (part_id == 0) {
+                        fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
+                    } else {
+                        fin.seekg(ggml_nbytes(tensor), std::ios::cur);
+                    }
+
+                    total_size += ggml_nbytes(tensor);
+                } else {
+                    if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)/n_parts) {
+                        fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
+                                __func__, name.data(), ggml_nbytes(tensor)/n_parts, nelements*bpe);
+                        return false;
+                    }
+
+                    if (split_type == 0) {
+                        const int np0 = ne[0];
+
+                        const size_t row_size = (tensor->ne[0]/ggml_blck_size(tensor->type))*ggml_type_size(tensor->type);
+                        assert(row_size == tensor->nb[1]);
+
+                        for (int i1 = 0; i1 < ne[1]; ++i1) {
+                            const size_t offset_row = i1*row_size;
+                            const size_t offset = offset_row + ((part_id*np0)/ggml_blck_size(tensor->type))*ggml_type_size(tensor->type);
+                            fin.read(reinterpret_cast<char *>(tensor->data) + offset, row_size/n_parts);
+                        }
+                    } else {
+                        const int np1 = ne[1];
+
+                        const size_t row_size = (tensor->ne[0]/ggml_blck_size(tensor->type))*ggml_type_size(tensor->type);
+
+                        for (int i1 = 0; i1 < ne[1]; ++i1) {
+                            const size_t offset_row = (i1 + part_id*np1)*row_size;
+                            fin.read(reinterpret_cast<char *>(tensor->data) + offset_row, row_size);
+                        }
+                    }
+
+                    total_size += ggml_nbytes(tensor)/n_parts;
+                }
+
+                //fprintf(stderr, "%42s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
+                model.n_loaded++;
+
+                // progress
+                if (progress_callback) {
+                    double current_file_progress = double(size_t(fin.tellg()) - file_offset) / double(file_size - file_offset);
+                    double current_progress = (double(i) + current_file_progress) / double(n_parts);
+                    progress_callback(current_progress, progress_callback_user_data);
+                }
+                if (model.n_loaded % 8 == 0) {
+                    fprintf(stderr, ".");
+                    fflush(stderr);
+                }
+            }
+
+            fprintf(stderr, " done\n");
+
+            fprintf(stderr, "%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, model.n_loaded);
+            if (model.n_loaded == 0) {
+                fprintf(stderr, "%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
+            } else if (model.n_loaded != (int) model.tensors.size()) {
+                fprintf(stderr, "%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
+                return false;
+            }
+        }
+
+        fin.close();
+    }
+
+    lctx.t_load_us = ggml_time_us() - t_start_us;
+
+    if (progress_callback) {
+        progress_callback(1.0, progress_callback_user_data);
+    }
+
+    return true;
+}
+
+// evaluate the transformer
+//
+//   - lctx:      llama context
+//   - tokens:    new batch of tokens to process
+//   - n_past:    the context size so far
+//   - n_threads: number of threads to use
+//
+static bool llama_eval_internal(
+        llama_context & lctx,
+    const llama_token * tokens,
+            const int   n_tokens,
+            const int   n_past,
+            const int   n_threads) {
+    const int64_t t_start_us = ggml_time_us();
+
+    const int N = n_tokens;
+
+    const auto & model   = lctx.model;
+    const auto & hparams = model.hparams;
+
+    auto & kv_self = model.kv_self;
+
+    LLAMA_ASSERT(!!kv_self.ctx);
+
+    const int n_embd  = hparams.n_embd;
+    const int n_layer = hparams.n_layer;
+    const int n_ctx   = hparams.n_ctx;
+    const int n_head  = hparams.n_head;
+    const int n_vocab = hparams.n_vocab;
+    const int n_rot   = hparams.n_embd/hparams.n_head;
+
+    auto & mem_per_token = lctx.mem_per_token;
+    auto & buf_compute   = lctx.buf_compute;
+
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ buf_compute.size(),
+        /*.mem_buffer =*/ buf_compute.data(),
+    };
+
+    struct ggml_context * ctx0 = ggml_init(params);
+
+    // for big prompts, if BLAS is enabled, it is better to use only one thread
+    // otherwise, the threads spin-lock waiting for the BLAS calls and degrade the performance
+    ggml_cgraph gf = {};
+    gf.n_threads = N > 255 && ggml_cpu_has_blas() ? 1 : n_threads;
+
+    struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    memcpy(embd->data, tokens, N*ggml_element_size(embd));
+
+    struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd);
+
+    for (int il = 0; il < n_layer; ++il) {
+        struct ggml_tensor * inpSA = inpL;
+
+        struct ggml_tensor * cur;
+
+        lctx.use_buf(ctx0, 0);
+
+        // norm
+        {
+            cur = ggml_rms_norm(ctx0, inpL);
+
+            // cur = attention_norm*cur
+            cur = ggml_mul(ctx0,
+                        ggml_repeat(ctx0, model.layers[il].attention_norm, cur),
+                        cur);
+        }
+
+        // self-attention
+        {
+            struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+            struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+            struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+
+            // store key and value to memory
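+            //
+            // note: kv_self.k and kv_self.v are 1D buffers laid out per layer - each layer owns
+            // n_ctx slots of n_embd elements, so the new N tokens land at slot n_past of layer il
+            // (this follows from the byte offsets used in the views below)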
+            if (N >= 1) {
+                struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
+                struct ggml_tensor * v = ggml_view_1d(ctx0, kv_self.v, N*n_embd, (ggml_element_size(kv_self.v)*n_embd)*(il*n_ctx + n_past));
+
+                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
+                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
+            }
+
+            // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
+            struct ggml_tensor * Q =
+                ggml_permute(ctx0,
+                        ggml_rope(ctx0,
+                            ggml_cpy(ctx0,
+                                Qcur,
+                                ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)),
+                            n_past, n_rot, 0),
+                        0, 2, 1, 3);
+
+            // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
+            struct ggml_tensor * K =
+                ggml_permute(ctx0,
+                        ggml_rope(ctx0,
+                            ggml_reshape_3d(ctx0,
+                                ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.k)*n_embd),
+                                n_embd/n_head, n_head, n_past + N),
+                            n_past, n_rot, 1),
+                        0, 2, 1, 3);
+
+            // K * Q
+            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+
+            // KQ_scaled = KQ / sqrt(n_embd/n_head)
+            struct ggml_tensor * KQ_scaled =
+                ggml_scale(ctx0,
+                        KQ,
+                        ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head)));
+
+            // KQ_masked = mask_past(KQ_scaled)
+            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
+
+            // KQ = soft_max(KQ_masked)
+            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
+
+            // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
+            struct ggml_tensor * V_trans =
+                ggml_cpy(ctx0,
+                    ggml_permute(ctx0,
+                            ggml_reshape_3d(ctx0,
+                                ggml_view_1d(ctx0, kv_self.v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.v)*n_embd),
+                                n_embd/n_head, n_head, n_past + N),
+                            1, 2, 0, 3),
+                    ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd/n_head, n_head));
+
+            // KQV = transpose(V) * KQ_soft_max
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
+
+            // KQV_merged = KQV.permute(0, 2, 1, 3)
+            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+
+            // cur = KQV_merged.contiguous().view(n_embd, N)
+            cur = ggml_cpy(ctx0,
+                    KQV_merged,
+                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
+
+            // projection (no bias)
+            cur = ggml_mul_mat(ctx0,
+                    model.layers[il].wo,
+                    cur);
+        }
+
+        lctx.use_buf(ctx0, 1);
+
+        struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
+
+        // feed-forward network
+        {
+            // norm
+            {
+                cur = ggml_rms_norm(ctx0, inpFF);
+
+                // cur = ffn_norm*cur
+                cur = ggml_mul(ctx0,
+                        ggml_repeat(ctx0, model.layers[il].ffn_norm, cur),
+                        cur);
+            }
+
+            struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
+                    model.layers[il].w3,
+                    cur);
+
+            cur = ggml_mul_mat(ctx0,
+                    model.layers[il].w1,
+                    cur);
+
+            // SILU activation
+            cur = ggml_silu(ctx0, cur);
+
+            cur = ggml_mul(ctx0, cur, tmp);
+
+            cur = ggml_mul_mat(ctx0,
+                    model.layers[il].w2,
+                    cur);
+        }
+
+        cur = ggml_add(ctx0, cur, inpFF);
+
+        // input for next layer
+        inpL = cur;
+    }
+
+    lctx.use_buf(ctx0, 0);
+
+    // used at the end to optionally extract the embeddings
+    struct ggml_tensor * embeddings = NULL;
+
+    // norm
+    {
+
+        inpL = ggml_rms_norm(ctx0, inpL);
+
+        // inpL = norm*inpL
+        inpL = ggml_mul(ctx0,
+                    ggml_repeat(ctx0, model.norm, inpL),
+                    inpL);
+
+        embeddings = inpL;
+    }
+
+    // lm_head
+    inpL = ggml_mul_mat(ctx0, model.output, inpL);
+
+    lctx.use_buf(ctx0, -1);
+
+    // logits -> probs
+    //inpL = ggml_soft_max(ctx0, inpL);
+
+    // run the computation
+    ggml_build_forward_expand(&gf, inpL);
+    ggml_graph_compute       (ctx0, &gf);
+
+    //if (n_past%100 == 0) {
+    //    ggml_graph_print   (&gf);
+    //    ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
+    //}
+
+    //embd_w.resize(n_vocab*N);
+    //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
+
+    // extract logits
+    {
+        auto & logits_out = lctx.logits;
+
+        if (lctx.logits_all) {
+            logits_out.resize(n_vocab * N);
+            memcpy(logits_out.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N);
+        } else {
+            // return result for just the last token
+            logits_out.resize(n_vocab);
+            memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+        }
+    }
+
+    // extract embeddings
+    if (lctx.embedding.size()) {
+        auto & embedding_out = lctx.embedding;
+
+        embedding_out.resize(n_embd);
+        memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd);
+    }
+
+    if (mem_per_token == 0) {
+        mem_per_token = ggml_used_mem(ctx0)/N;
+    }
+
+#if 0
+    printf("\n%s: used_mem = %.3f MB, scratch -- %.3f MB %.3f MB\n", __func__,
+            ggml_used_mem(ctx0)/1024.0/1024.0,
+            lctx.get_buf_max_mem(0)/1024.0/1024.0,
+            lctx.get_buf_max_mem(1)/1024.0/1024.0);
+#endif
+
+    ggml_free(ctx0);
+
+    // measure the performance only for the single-token evals
+    if (N == 1) {
+        lctx.t_eval_us += ggml_time_us() - t_start_us;
+        lctx.n_eval++;
+    }
+    else if (N > 1) {
+        lctx.t_p_eval_us += ggml_time_us() - t_start_us;
+        lctx.n_p_eval += N;
+    }
+
+    return true;
+}
+
+//
+// tokenizer
+//
+
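+// returns the length (in bytes) of the UTF-8 sequence that starts with the given byte,
+// based on its high nibble: 0xxx -> 1, 110x -> 2, 1110 -> 3, 1111 -> 4
+// (continuation bytes 10xx also map to 1, so a stray byte is consumed on its own)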
+static size_t utf8_len(char src) {
+    const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
+    uint8_t highbits = static_cast<uint8_t>(src) >> 4;
+    return lookup[highbits];
+}
+
+struct llama_sp_symbol {
+    using index = int;
+    index prev;
+    index next;
+    const char * text;
+    size_t n;
+};
+
+struct llama_sp_bigram {
+    struct comparator {
+        bool operator()(llama_sp_bigram & l, llama_sp_bigram & r) {
+            return (l.score < r.score) || (l.score == r.score && l.left > r.left);
+        }
+    };
+    using queue_storage = std::vector<llama_sp_bigram>;
+    using queue = std::priority_queue<llama_sp_bigram, queue_storage, comparator>;
+    llama_sp_symbol::index left;
+    llama_sp_symbol::index right;
+    float score;
+    size_t size;
+};
+
+// original implementation:
+// https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4
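+//
+// greedy SentencePiece-style tokenizer: the text is split into UTF-8 characters, adjacent pairs
+// that form a known vocab token are pushed into a priority queue ordered by vocab score, and the
+// best-scoring pair is merged repeatedly until no known merge remains; characters that never end
+// up inside a vocab token are emitted as byte tokens (byte value + 3)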
+struct llama_tokenizer {
+    llama_tokenizer(const llama_vocab & vocab): vocab_(vocab) {}
+
+    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+        // split string into utf8 chars
+        int index = 0;
+        size_t offs = 0;
+        while (offs < text.size()) {
+            llama_sp_symbol sym;
+            size_t char_len = std::min(text.size() - offs, utf8_len(text[offs]));
+            sym.text = text.c_str() + offs;
+            sym.n = char_len;
+            offs += char_len;
+            sym.prev = index - 1;
+            sym.next = offs == text.size() ? -1 : index + 1;
+            index++;
+            symbols_.emplace_back(std::move(sym));
+        }
+
+        // seed the work queue with all possible 2-character tokens.
+        for (size_t i = 1; i < symbols_.size(); ++i) {
+            try_add_bigram(i - 1, i);
+        }
+
+        // keep substituting the highest frequency pairs for as long as we can.
+        while (!work_queue_.empty()) {
+            auto bigram = work_queue_.top();
+            work_queue_.pop();
+
+            auto & left_sym = symbols_[bigram.left];
+            auto & right_sym = symbols_[bigram.right];
+
+            // if one of the symbols already got merged, skip it.
+            if (left_sym.n == 0 || right_sym.n == 0 ||
+                left_sym.n + right_sym.n != bigram.size) {
+                continue;
+            }
+
+            // merge the right sym into the left one
+            left_sym.n += right_sym.n;
+            right_sym.n = 0;
+
+            //printf("left = '%*s' size = %zu\n", (int) left_sym.n, left_sym.text, bigram.size);
+
+            // remove the right sym from the chain
+            left_sym.next = right_sym.next;
+            if (right_sym.next >= 0) {
+                symbols_[right_sym.next].prev = bigram.left;
+            }
+
+            // find more substitutions
+            try_add_bigram(left_sym.prev, bigram.left);
+            try_add_bigram(bigram.left, left_sym.next);
+        }
+
+        for (int i = 0; i != -1; i = symbols_[i].next) {
+            auto & symbol = symbols_[i];
+            auto token = vocab_.token_to_id.find(std::string(symbol.text, symbol.n));
+
+            if (token == vocab_.token_to_id.end()) {
+                // output any symbols that did not form tokens as bytes.
+                for (int j = 0; j < (int) symbol.n; ++j) {
+                    llama_vocab::id token_id = static_cast<uint8_t>(symbol.text[j]) + 3;
+                    output.push_back(token_id);
+                }
+            } else {
+                output.push_back((*token).second);
+            }
+        }
+    }
+
+private:
+    void try_add_bigram(int left, int right) {
+        if (left == -1 || right == -1) {
+            return;
+        }
+
+        const std::string text = std::string(symbols_[left].text, symbols_[left].n + symbols_[right].n);
+        auto token = vocab_.token_to_id.find(text);
+
+        if (token == vocab_.token_to_id.end()) {
+            return;
+        }
+
+        if (static_cast<size_t>((*token).second) >= vocab_.id_to_token.size()) {
+            return;
+        }
+
+        const auto &tok_score = vocab_.id_to_token[(*token).second];
+
+        llama_sp_bigram bigram;
+        bigram.left = left;
+        bigram.right = right;
+        bigram.score = tok_score.score;
+        bigram.size = text.size();
+        work_queue_.push(bigram);
+    }
+
+    const llama_vocab & vocab_;
+    std::vector<llama_sp_symbol> symbols_;
+    llama_sp_bigram::queue work_queue_;
+};
+
+static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos) {
+    llama_tokenizer tokenizer(vocab);
+    std::vector<llama_vocab::id> output;
+
+    if (text.size() == 0) {
+        return output;
+    }
+
+    if (bos) {
+        output.push_back(1);
+    }
+
+    tokenizer.tokenize(text, output);
+    return output;
+}
+
+//
+// sampling
+//
+
+static void sample_top_k(std::vector<std::pair<double, llama_vocab::id>> & logits_id, int top_k) {
+    // find the top k tokens
+    std::partial_sort(
+            logits_id.begin(),
+            logits_id.begin() + top_k, logits_id.end(),
+            [](const std::pair<double, llama_vocab::id> & a, const std::pair<double, llama_vocab::id> & b) {
+        return a.first > b.first;
+    });
+
+    logits_id.resize(top_k);
+}
+
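+// sample the next token with top-k / top-p (nucleus) sampling:
+//   1. scale the logits by 1/temp and apply the repetition penalty to recently seen tokens
+//   2. keep only the top_k highest-scoring candidates
+//   3. softmax the survivors and cut the tail once the cumulative probability reaches top_p
+//   4. draw the token id from the resulting distribution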
+static llama_vocab::id llama_sample_top_p_top_k(
+        llama_context & lctx,
+        const std::vector<llama_vocab::id> & last_n_tokens,
+        int top_k,
+        double top_p,
+        double temp,
+        double repeat_penalty) {
+    auto & rng = lctx.rng;
+
+    const int n_logits = lctx.model.hparams.n_vocab;
+
+    const auto & logits = lctx.logits;
+    const auto * plogits = logits.data() + logits.size() - n_logits;
+
+    std::vector<std::pair<double, llama_vocab::id>> logits_id;
+    logits_id.reserve(n_logits);
+
+    {
+        const double scale = 1.0/temp;
+        for (int i = 0; i < n_logits; ++i) {
+            // repetition penalty from ctrl paper (https://arxiv.org/abs/1909.05858)
+            // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
+            if (std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) {
+                // if score < 0 then the repetition penalty has to be multiplied to reduce the previous token probability
+                if (plogits[i] < 0.0) {
+                    logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i));
+                } else {
+                    logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i));
+                }
+            } else {
+                logits_id.push_back(std::make_pair(plogits[i]*scale, i));
+            }
+        }
+    }
+
+    sample_top_k(logits_id, top_k);
+
+    double maxl = -std::numeric_limits<double>::infinity();
+    for (const auto & kv : logits_id) {
+        maxl = std::max(maxl, kv.first);
+    }
+
+    // compute probs for the top k tokens
+    std::vector<double> probs;
+    probs.reserve(logits_id.size());
+
+    double sum = 0.0;
+    for (const auto & kv : logits_id) {
+        double p = exp(kv.first - maxl);
+        probs.push_back(p);
+        sum += p;
+    }
+
+    // normalize the probs
+    for (auto & p : probs) {
+        p /= sum;
+    }
+
+    if (top_p < 1.0f) {
+        double cumsum = 0.0f;
+        for (int i = 0; i < (int) probs.size(); i++) {
+            cumsum += probs[i];
+            if (cumsum >= top_p) {
+                probs.resize(i + 1);
+                logits_id.resize(i + 1);
+                break;
+            }
+        }
+
+        cumsum = 1.0/cumsum;
+        for (int i = 0; i < (int) probs.size(); i++) {
+            probs[i] *= cumsum;
+        }
+    }
+
+    //printf("\n");
+    //for (int i = 0; i < (int) 10; i++) {
+    //    printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
+    //}
+    //printf("\n\n");
+    //exit(0);
+
+    std::discrete_distribution<> dist(probs.begin(), probs.end());
+    int idx = dist(rng);
+
+    return logits_id[idx].second;
+}
+
+//
+// quantization
+//
+
+// TODO: reuse code from the llama_model_load() somehow
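+//
+// streams the tensors of fname_inp, quantizes every 2D ".*weight" tensor to Q4_0 or Q4_1
+// (selected by itype) and writes the result - together with the unchanged hparams and vocab -
+// to fname_out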
+bool llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, int itype, int qk) {
+    ggml_type type = GGML_TYPE_Q4_1;
+
+    switch (itype) {
+        case 2: type = GGML_TYPE_Q4_0; break;
+        case 3: type = GGML_TYPE_Q4_1; break;
+        default: fprintf(stderr, "%s: invalid quantization type %d\n", __func__, itype); return false;
+    };
+
+    if (type != GGML_TYPE_Q4_0 && type != GGML_TYPE_Q4_1) {
+        fprintf(stderr, "%s: invalid quantization type %d\n", __func__, type);
+        return false;
+    }
+
+    llama_vocab vocab;
+
+    printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str());
+
+    auto finp = std::ifstream(fname_inp, std::ios::binary);
+    if (!finp) {
+        fprintf(stderr, "%s: failed to open '%s' for reading\n", __func__, fname_inp.c_str());
+        return false;
+    }
+
+    auto fout = std::ofstream(fname_out, std::ios::binary);
+    if (!fout) {
+        fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_out.c_str());
+        return false;
+    }
+
+    // verify magic
+    {
+        uint32_t magic;
+        finp.read((char *) &magic, sizeof(magic));
+        if (magic == LLAMA_FILE_MAGIC_UNVERSIONED) {
+            fprintf(stderr, "%s: invalid model file '%s' (too old, regenerate your model files!)\n",
+                    __func__, fname_inp.c_str());
+            return false;
+        }
+        if (magic != LLAMA_FILE_MAGIC) {
+            fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp.c_str());
+            return false;
+        }
+
+        fout.write((char *) &magic, sizeof(magic));
+
+        uint32_t format_version;
+        finp.read((char *) &format_version, sizeof(format_version));
+
+        if (format_version != LLAMA_FILE_VERSION) {
+            fprintf(stderr, "%s: invalid model file '%s' (unsupported format version %" PRIu32 ", expected %d)\n",
+                    __func__, fname_inp.c_str(), format_version, LLAMA_FILE_VERSION);
+            return false;
+        }
+
+        fout.write((char *) &format_version, sizeof(format_version));
+    }
+
+    llama_hparams hparams;
+
+    // load hparams
+    {
+        finp.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
+        //finp.read((char *) &hparams.n_ctx,   sizeof(hparams.n_ctx));
+        finp.read((char *) &hparams.n_embd,  sizeof(hparams.n_embd));
+        finp.read((char *) &hparams.n_mult,  sizeof(hparams.n_mult));
+        finp.read((char *) &hparams.n_head,  sizeof(hparams.n_head));
+        finp.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
+        finp.read((char *) &hparams.n_rot,   sizeof(hparams.n_rot));
+        finp.read((char *) &hparams.f16,     sizeof(hparams.f16));
+
+        printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
+        printf("%s: n_ctx   = %d\n", __func__, hparams.n_ctx);
+        printf("%s: n_embd  = %d\n", __func__, hparams.n_embd);
+        printf("%s: n_mult  = %d\n", __func__, hparams.n_mult);
+        printf("%s: n_head  = %d\n", __func__, hparams.n_head);
+        printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
+        printf("%s: f16     = %d\n", __func__, hparams.f16);
+
+        fout.write((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
+        //fout.write((char *) &hparams.n_ctx,   sizeof(hparams.n_ctx));
+        fout.write((char *) &hparams.n_embd,  sizeof(hparams.n_embd));
+        fout.write((char *) &hparams.n_mult,  sizeof(hparams.n_mult));
+        fout.write((char *) &hparams.n_head,  sizeof(hparams.n_head));
+        fout.write((char *) &hparams.n_layer, sizeof(hparams.n_layer));
+        fout.write((char *) &hparams.n_rot,   sizeof(hparams.n_rot));
+        fout.write((char *) &itype,           sizeof(hparams.f16));
+    }
+
+    // load vocab
+    {
+        const int32_t n_vocab = hparams.n_vocab;
+
+        if (n_vocab != hparams.n_vocab) {
+            fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
+                    __func__, fname_inp.c_str(), n_vocab, hparams.n_vocab);
+            return false;
+        }
+
+        std::string word;
+        vocab.id_to_token.resize(n_vocab);
+        for (int i = 0; i < n_vocab; i++) {
+            uint32_t len;
+            finp.read ((char *) &len, sizeof(len));
+            fout.write((char *) &len, sizeof(len));
+
+            word.resize(len);
+            finp.read ((char *) word.data(), len);
+            fout.write((char *) word.data(), len);
+
+            float score;
+            finp.read ((char *) &score, sizeof(score));
+            fout.write((char *) &score, sizeof(score));
+
+            vocab.token_to_id[word] = i;
+
+            auto &tok_score = vocab.id_to_token[i];
+            tok_score.tok = word;
+            tok_score.score = score;
+        }
+    }
+
+    // load weights
+    {
+        size_t total_size_org = 0;
+        size_t total_size_new = 0;
+
+        std::vector<float> work;
+
+        std::vector<uint8_t>     data_u8;
+        std::vector<ggml_fp16_t> data_f16;
+        std::vector<float>       data_f32;
+
+        std::vector<int64_t> hist_all(1 << 4, 0);
+
+        while (true) {
+            int32_t n_dims;
+            int32_t length;
+            int32_t ftype;
+
+            finp.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
+            finp.read(reinterpret_cast<char *>(&length), sizeof(length));
+            finp.read(reinterpret_cast<char *>(&ftype),  sizeof(ftype));
+
+            if (finp.eof()) {
+                break;
+            }
+
+            int32_t nelements = 1;
+            int32_t ne[2] = { 1, 1 };
+            for (int i = 0; i < n_dims; ++i) {
+                finp.read (reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
+                nelements *= ne[i];
+            }
+
+            std::string name(length, 0);
+            finp.read (&name[0], length);
+
+            {
+                static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
+                printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], ftype_str[ftype]);
+            }
+
+            // regexes of tensor names to be quantized
+            const std::vector<std::string> k_names = {
+                ".*weight",
+            };
+
+            bool quantize = false;
+            for (const auto & s : k_names) {
+                if (std::regex_match(name, std::regex(s))) {
+                    quantize = true;
+                    break;
+                }
+            }
+
+            // quantize only 2D tensors
+            quantize &= (n_dims == 2);
+
+            if (quantize) {
+                if (ftype != 0 && ftype != 1) {
+                    fprintf(stderr, "%s: unsupported ftype %d for integer quantization\n", __func__, ftype);
+                    return false;
+                }
+
+                if (ftype == 1) {
+                    data_f16.resize(nelements);
+                    finp.read(reinterpret_cast<char *>(data_f16.data()), nelements * sizeof(ggml_fp16_t));
+                    data_f32.resize(nelements);
+                    for (int i = 0; i < nelements; ++i) {
+                        data_f32[i] = ggml_fp16_to_fp32(data_f16[i]);
+                    }
+                } else {
+                    data_f32.resize(nelements);
+                    finp.read(reinterpret_cast<char *>(data_f32.data()), nelements * sizeof(float));
+                }
+
+                ftype = itype;
+            } else {
+                const int bpe = (ftype == 0) ? sizeof(float) : sizeof(uint16_t);
+
+                data_u8.resize(nelements*bpe);
+                finp.read(reinterpret_cast<char *>(data_u8.data()), nelements * bpe);
+            }
+
+            fout.write(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
+            fout.write(reinterpret_cast<char *>(&length), sizeof(length));
+            fout.write(reinterpret_cast<char *>(&ftype),  sizeof(ftype));
+            for (int i = 0; i < n_dims; ++i) {
+                fout.write(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
+            }
+            fout.write(&name[0], length);
+
+            if (quantize) {
+                printf("quantizing .. ");
+                work.resize(nelements); // for quantization
+
+                size_t cur_size = 0;
+                std::vector<int64_t> hist_cur(1 << 4, 0);
+
+                switch (type) {
+                    case GGML_TYPE_Q4_0:
+                        {
+                            cur_size = ggml_quantize_q4_0(data_f32.data(), work.data(), nelements, ne[0], qk, hist_cur.data());
+                        } break;
+                    case GGML_TYPE_Q4_1:
+                        {
+                            cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], qk, hist_cur.data());
+                        } break;
+                    default:
+                        {
+                            fprintf(stderr, "%s: unsupported quantization type %d\n", __func__, type);
+                            return false;
+                        }
+                }
+
+                fout.write(reinterpret_cast<char *>(work.data()), cur_size);
+                total_size_new += cur_size;
+
+                printf("size = %8.2f MB -> %8.2f MB | hist: ", nelements * sizeof(float)/1024.0/1024.0, cur_size/1024.0/1024.0);
+                for (int i = 0; i < (int) hist_cur.size(); ++i) {
+                    hist_all[i] += hist_cur[i];
+                }
+
+                for (int i = 0; i < (int) hist_cur.size(); ++i) {
+                    printf("%5.3f ", hist_cur[i] / (float)nelements);
+                }
+                printf("\n");
+            } else {
+                printf("size = %8.3f MB\n", data_u8.size()/1024.0/1024.0);
+                fout.write(reinterpret_cast<char *>(data_u8.data()), data_u8.size());
+                total_size_new += data_u8.size();
+            }
+
+            total_size_org += nelements * sizeof(float);
+        }
+
+        printf("%s: model size  = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0);
+        printf("%s: quant size  = %8.2f MB\n", __func__, total_size_new/1024.0/1024.0);
+
+        {
+            int64_t sum_all = 0;
+            for (int i = 0; i < (int) hist_all.size(); ++i) {
+                sum_all += hist_all[i];
+            }
+
+            printf("%s: hist: ", __func__);
+            for (int i = 0; i < (int) hist_all.size(); ++i) {
+                printf("%5.3f ", hist_all[i] / (float)sum_all);
+            }
+            printf("\n");
+        }
+    }
+
+    finp.close();
+    fout.close();
+
+    return true;
+}
+
+//
+// interface implementation
+//
+
+struct llama_context * llama_init_from_file(
+                             const char * path_model,
+            struct llama_context_params   params) {
+    ggml_time_init();
+
+    llama_context * ctx = new llama_context;
+
+    if (params.seed <= 0) {
+        params.seed = time(NULL);
+    }
+
+    ctx->rng = std::mt19937(params.seed);
+    ctx->logits_all = params.logits_all;
+
+    ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
+
+    if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, memory_type,
+                          params.vocab_only, params.progress_callback,
+                          params.progress_callback_user_data)) {
+        fprintf(stderr, "%s: failed to load model\n", __func__);
+        llama_free(ctx);
+        return nullptr;
+    }
+
+    if (params.use_mlock) {
+        char *err;
+        if (!ggml_mlock(ctx->model.ctx, &err)) {
+            fprintf(stderr, "%s\n", err);
+            free(err);
+            llama_free(ctx);
+            return nullptr;
+        }
+    }
+
+    // reserve memory for context buffers
+    {
+        if (!kv_cache_init(ctx->model.hparams, ctx->model.kv_self, memory_type, ctx->model.hparams.n_ctx)) {
+            fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
+            llama_free(ctx);
+            return nullptr;
+        }
+
+        {
+            const size_t memory_size = ggml_nbytes(ctx->model.kv_self.k) + ggml_nbytes(ctx->model.kv_self.v);
+            fprintf(stderr, "%s: kv self size  = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
+        }
+
+        const auto & hparams = ctx->model.hparams;
+
+        // resized during inference
+        if (params.logits_all) {
+            ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab);
+        } else {
+            ctx->logits.reserve(hparams.n_ctx);
+        }
+
+        if (params.embedding){
+            ctx->embedding.resize(hparams.n_embd);
+        }
+
+        ctx->buf_compute.resize(MEM_REQ_EVAL.at(ctx->model.type));
+
+        ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(ctx->model.type));
+        ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(ctx->model.type));
+    }
+
+    return ctx;
+}
+
+void llama_free(struct llama_context * ctx) {
+    kv_cache_free(ctx->model.kv_self);
+
+    if (ctx->model.ctx) {
+        ggml_free(ctx->model.ctx);
+    }
+
+    delete ctx;
+}
+
+int llama_model_quantize(
+        const char * fname_inp,
+        const char * fname_out,
+               int   itype,
+               int   qk) {
+    if (!llama_model_quantize_internal(fname_inp, fname_out, itype, qk)) {
+        fprintf(stderr, "%s: failed to quantize\n", __func__);
+        return 1;
+    }
+
+    return 0;
+}
+
+int llama_eval(
+        struct llama_context * ctx,
+           const llama_token * tokens,
+                         int   n_tokens,
+                         int   n_past,
+                         int   n_threads) {
+    if (!llama_eval_internal(*ctx, tokens, n_tokens, n_past, n_threads)) {
+        fprintf(stderr, "%s: failed to eval\n", __func__);
+        return 1;
+    }
+
+    return 0;
+}
+
+int llama_tokenize(
+        struct llama_context * ctx,
+                  const char * text,
+                 llama_token * tokens,
+                         int   n_max_tokens,
+                        bool   add_bos) {
+    auto res = llama_tokenize(ctx->vocab, text, add_bos);
+
+    if (n_max_tokens < (int) res.size()) {
+        fprintf(stderr, "%s: too many tokens\n", __func__);
+        return -((int) res.size());
+    }
+
+    for (size_t i = 0; i < res.size(); i++) {
+        tokens[i] = res[i];
+    }
+
+    return res.size();
+}
+
+int llama_n_vocab(struct llama_context * ctx) {
+    return ctx->vocab.id_to_token.size();
+}
+
+int llama_n_ctx(struct llama_context * ctx) {
+    return ctx->model.hparams.n_ctx;
+}
+
+int llama_n_embd(struct llama_context * ctx) {
+    return ctx->model.hparams.n_embd;
+}
+
+float * llama_get_logits(struct llama_context * ctx) {
+    return ctx->logits.data();
+}
+
+float * llama_get_embeddings(struct llama_context * ctx) {
+    return ctx->embedding.data();
+}
+
+const char * llama_token_to_str(struct llama_context * ctx, llama_token token) {
+    if (token >= llama_n_vocab(ctx)) {
+        return nullptr;
+    }
+
+    return ctx->vocab.id_to_token[token].tok.c_str();
+}
+
+llama_token llama_token_bos() {
+    return 1;
+}
+
+llama_token llama_token_eos() {
+    return 2;
+}
+
+llama_token llama_sample_top_p_top_k(
+          llama_context * ctx,
+      const llama_token * last_n_tokens_data,
+                    int   last_n_tokens_size,
+                    int   top_k,
+                 double   top_p,
+                 double   temp,
+                 double   repeat_penalty) {
+    const int64_t t_start_sample_us = ggml_time_us();
+
+    llama_token result = 0;
+
+    // TODO: avoid this ...
+    const auto last_n_tokens = std::vector<llama_token>(last_n_tokens_data, last_n_tokens_data + last_n_tokens_size);
+
+    result = llama_sample_top_p_top_k(
+            *ctx,
+            last_n_tokens,
+            top_k,
+            top_p,
+            temp,
+            repeat_penalty);
+
+    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+    ctx->n_sample++;
+
+    return result;
+}
+
+void llama_print_timings(struct llama_context * ctx) {
+    const int64_t t_end_us = ggml_time_us();
+
+    const int32_t n_sample = std::max(1, ctx->n_sample);
+    const int32_t n_eval   = std::max(1, ctx->n_eval);
+    const int32_t n_p_eval = std::max(1, ctx->n_p_eval);
+
+    fprintf(stderr, "\n");
+    fprintf(stderr, "%s:        load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);
+    fprintf(stderr, "%s:      sample time = %8.2f ms / %5d runs   (%8.2f ms per run)\n",   __func__, 1e-3f * ctx->t_sample_us, n_sample, 1e-3f * ctx->t_sample_us / n_sample);
+    fprintf(stderr, "%s: prompt eval time = %8.2f ms / %5d tokens (%8.2f ms per token)\n", __func__, 1e-3f * ctx->t_p_eval_us, n_p_eval, 1e-3f * ctx->t_p_eval_us / n_p_eval);
+    fprintf(stderr, "%s:        eval time = %8.2f ms / %5d runs   (%8.2f ms per run)\n",   __func__, 1e-3f * ctx->t_eval_us,   n_eval,   1e-3f * ctx->t_eval_us   / n_eval);
+    fprintf(stderr, "%s:       total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
+}
+
+void llama_reset_timings(struct llama_context * ctx) {
+    ctx->t_start_us = ggml_time_us();
+
+    ctx->t_sample_us = ctx->n_sample = 0;
+    ctx->t_eval_us   = ctx->n_eval   = 0;
+    ctx->t_p_eval_us = ctx->n_p_eval = 0;
+}
+
+const char * llama_print_system_info(void) {
+    static std::string s;
+
+    s  = "";
+    s += "AVX = "       + std::to_string(ggml_cpu_has_avx())       + " | ";
+    s += "AVX2 = "      + std::to_string(ggml_cpu_has_avx2())      + " | ";
+    s += "AVX512 = "    + std::to_string(ggml_cpu_has_avx512())    + " | ";
+    s += "FMA = "       + std::to_string(ggml_cpu_has_fma())       + " | ";
+    s += "NEON = "      + std::to_string(ggml_cpu_has_neon())      + " | ";
+    s += "ARM_FMA = "   + std::to_string(ggml_cpu_has_arm_fma())   + " | ";
+    s += "F16C = "      + std::to_string(ggml_cpu_has_f16c())      + " | ";
+    s += "FP16_VA = "   + std::to_string(ggml_cpu_has_fp16_va())   + " | ";
+    s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
+    s += "BLAS = "      + std::to_string(ggml_cpu_has_blas())      + " | ";
+    s += "SSE3 = "      + std::to_string(ggml_cpu_has_sse3())      + " | ";
+    s += "VSX = "       + std::to_string(ggml_cpu_has_vsx())       + " | ";
+
+    return s.c_str();
+}
diff --git a/examples/talk.llama/llama.h b/examples/talk.llama/llama.h
new file mode 100644
index 00000000000..ebf55f41c35
--- /dev/null
+++ b/examples/talk.llama/llama.h
@@ -0,0 +1,153 @@
+#ifndef LLAMA_H
+#define LLAMA_H
+
+#include <stddef.h>
+#include <stdint.h>
+#include <stdbool.h>
+
+#ifdef LLAMA_SHARED
+#    ifdef _WIN32
+#        ifdef LLAMA_BUILD
+#            define LLAMA_API __declspec(dllexport)
+#        else
+#            define LLAMA_API __declspec(dllimport)
+#        endif
+#    else
+#        define LLAMA_API __attribute__ ((visibility ("default")))
+#    endif
+#else
+#    define LLAMA_API
+#endif
+
+#define LLAMA_FILE_VERSION 1
+#define LLAMA_FILE_MAGIC 0x67676d66 // 'ggmf' in hex
+#define LLAMA_FILE_MAGIC_UNVERSIONED 0x67676d6c // pre-versioned files
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+    //
+    // C interface
+    //
+    // TODO: show sample usage
+    //
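+    // rough usage sketch (illustrative only - error handling omitted, the model path is a placeholder):
+    //
+    //     struct llama_context_params params = llama_context_default_params();
+    //     struct llama_context * ctx = llama_init_from_file("models/7B/ggml-model.bin", params);
+    //
+    //     llama_token tokens[64];
+    //     const int n = llama_tokenize(ctx, " Hello", tokens, 64, true);
+    //
+    //     llama_eval(ctx, tokens, n, 0, 4);
+    //     const llama_token id = llama_sample_top_p_top_k(ctx, tokens, n, 40, 0.95, 0.80, 1.10);
+    //     printf("%s", llama_token_to_str(ctx, id));
+    //
+    //     llama_free(ctx);
+    //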
+
+    struct llama_context;
+
+    typedef int llama_token;
+
+    typedef struct llama_token_data {
+        llama_token id;  // token id
+
+        float p;     // probability of the token
+        float plog;  // log probability of the token
+
+    } llama_token_data;
+
+    typedef void (*llama_progress_callback)(double progress, void *ctx);
+
+    struct llama_context_params {
+        int n_ctx;   // text context
+        int n_parts; // -1 for default
+        int seed;    // RNG seed, 0 for random
+
+        bool f16_kv;     // use fp16 for KV cache
+        bool logits_all; // the llama_eval() call computes all logits, not just the last one
+        bool vocab_only; // only load the vocabulary, no weights
+        bool use_mlock;  // force system to keep model in RAM
+        bool embedding;  // embedding mode only
+
+        // called with a progress value between 0 and 1, pass NULL to disable
+        llama_progress_callback progress_callback;
+        // context pointer passed to the progress callback
+        void * progress_callback_user_data;
+    };
+
+    LLAMA_API struct llama_context_params llama_context_default_params();
+
+    // Various functions for loading a ggml llama model.
+    // Allocate (almost) all memory needed for the model.
+    // Return NULL on failure
+    LLAMA_API struct llama_context * llama_init_from_file(
+                             const char * path_model,
+            struct llama_context_params   params);
+
+    // Frees all allocated memory
+    LLAMA_API void llama_free(struct llama_context * ctx);
+
+    // TODO: not great API - very likely to change
+    // Returns 0 on success
+    LLAMA_API int llama_model_quantize(
+            const char * fname_inp,
+            const char * fname_out,
+                   int   itype,
+                   int   qk);
+
+    // Run the llama inference to obtain the logits and probabilities for the next token.
+    // tokens + n_tokens is the provided batch of new tokens to process
+    // n_past is the number of tokens to use from previous eval calls
+    // Returns 0 on success
+    LLAMA_API int llama_eval(
+            struct llama_context * ctx,
+               const llama_token * tokens,
+                             int   n_tokens,
+                             int   n_past,
+                             int   n_threads);
+
+    // Convert the provided text into tokens.
+    // The tokens pointer must be large enough to hold the resulting tokens.
+    // Returns the number of tokens on success, no more than n_max_tokens
+    // Returns a negative number on failure - the number of tokens that would have been returned
+    // TODO: not sure if correct
+    LLAMA_API int llama_tokenize(
+            struct llama_context * ctx,
+                      const char * text,
+                     llama_token * tokens,
+                             int   n_max_tokens,
+                            bool   add_bos);
+
+    LLAMA_API int llama_n_vocab(struct llama_context * ctx);
+    LLAMA_API int llama_n_ctx  (struct llama_context * ctx);
+    LLAMA_API int llama_n_embd (struct llama_context * ctx);
+
+    // Token logits obtained from the last call to llama_eval()
+    // The logits for the last token are stored in the last row
+    // Can be mutated in order to change the probabilities of the next token
+    // Rows: n_tokens
+    // Cols: n_vocab
+    LLAMA_API float * llama_get_logits(struct llama_context * ctx);
+
+    // Get the embeddings for the input
+    // shape: [n_embd] (1-dimensional)
+    LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
+
+    // Token Id -> String. Uses the vocabulary in the provided context
+    LLAMA_API const char * llama_token_to_str(struct llama_context * ctx, llama_token token);
+
+    // Special tokens
+    LLAMA_API llama_token llama_token_bos();
+    LLAMA_API llama_token llama_token_eos();
+
+    // TODO: improve the last_n_tokens interface ?
+    LLAMA_API llama_token llama_sample_top_p_top_k(
+       struct llama_context * ctx,
+          const llama_token * last_n_tokens_data,
+                        int   last_n_tokens_size,
+                        int   top_k,
+                     double   top_p,
+                     double   temp,
+                     double   repeat_penalty);
+
+    // Performance information
+    LLAMA_API void llama_print_timings(struct llama_context * ctx);
+    LLAMA_API void llama_reset_timings(struct llama_context * ctx);
+
+    // Print system information
+    LLAMA_API const char * llama_print_system_info(void);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
diff --git a/examples/talk.llama/speak.sh b/examples/talk.llama/speak.sh
new file mode 100755
index 00000000000..8888a206143
--- /dev/null
+++ b/examples/talk.llama/speak.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+
+# Usage:
+#  speak.sh <voice_id> <text-to-speak>
+
+# espeak
+# Mac OS: brew install espeak
+# Linux: apt-get install espeak
+#
+#espeak -v en-us+m$1 -s 225 -p 50 -a 200 -g 5 -k 5 "$2"
+
+# for Mac
+say "$2"
+
+# Eleven Labs
+#
+#wd=$(dirname $0)
+#script=$wd/eleven-labs.py
+#python3 $script $1 "$2" >/dev/null 2>&1
+#ffplay -autoexit -nodisp -loglevel quiet -hide_banner -i ./audio.mp3 >/dev/null 2>&1
diff --git a/examples/talk.llama/talk-llama.cpp b/examples/talk.llama/talk-llama.cpp
new file mode 100644
index 00000000000..8cbb1cd1541
--- /dev/null
+++ b/examples/talk.llama/talk-llama.cpp
@@ -0,0 +1,511 @@
+// Talk with AI
+//
+
+#include "common.h"
+#include "common-sdl.h"
+#include "whisper.h"
+#include "llama.h"
+
+#include <cassert>
+#include <cstdio>
+#include <fstream>
+#include <regex>
+#include <string>
+#include <thread>
+#include <vector>
+
+std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos) {
+    // initialize to the number of prompt chars, since n_tokens <= n_prompt_chars
+    std::vector<llama_token> res(text.size() + (int)add_bos);
+    int n = llama_tokenize(ctx, text.c_str(), res.data(), res.size(), add_bos);
+    assert(n >= 0);
+    res.resize(n);
+
+    return res;
+}
+
+// command-line parameters
+struct whisper_params {
+    int32_t n_threads  = std::min(4, (int32_t) std::thread::hardware_concurrency());
+    int32_t voice_ms   = 10000;
+    int32_t capture_id = -1;
+    int32_t max_tokens = 32;
+    int32_t audio_ctx  = 0;
+
+    float vad_thold    = 0.6f;
+    float freq_thold   = 100.0f;
+
+    bool speed_up      = false;
+    bool translate     = false;
+    bool print_special = false;
+    bool print_energy  = false;
+    bool no_timestamps = true;
+
+    std::string person      = "Santa";
+    std::string language    = "en";
+    std::string model_wsp   = "models/ggml-base.en.bin";
+    std::string model_llama = "models/ggml-llama-7B.bin";
+    std::string speak       = "./examples/talk/speak.sh";
+    std::string fname_out;
+};
+
+void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
+
+bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
+    for (int i = 1; i < argc; i++) {
+        std::string arg = argv[i];
+
+        if (arg == "-h" || arg == "--help") {
+            whisper_print_usage(argc, argv, params);
+            exit(0);
+        }
+        else if (arg == "-t"   || arg == "--threads")       { params.n_threads     = std::stoi(argv[++i]); }
+        else if (arg == "-vms" || arg == "--voice-ms")      { params.voice_ms      = std::stoi(argv[++i]); }
+        else if (arg == "-c"   || arg == "--capture")       { params.capture_id    = std::stoi(argv[++i]); }
+        else if (arg == "-mt"  || arg == "--max-tokens")    { params.max_tokens    = std::stoi(argv[++i]); }
+        else if (arg == "-ac"  || arg == "--audio-ctx")     { params.audio_ctx     = std::stoi(argv[++i]); }
+        else if (arg == "-vth" || arg == "--vad-thold")     { params.vad_thold     = std::stof(argv[++i]); }
+        else if (arg == "-fth" || arg == "--freq-thold")    { params.freq_thold    = std::stof(argv[++i]); }
+        else if (arg == "-su"  || arg == "--speed-up")      { params.speed_up      = true; }
+        else if (arg == "-tr"  || arg == "--translate")     { params.translate     = true; }
+        else if (arg == "-ps"  || arg == "--print-special") { params.print_special = true; }
+        else if (arg == "-pe"  || arg == "--print-energy")  { params.print_energy  = true; }
+        else if (arg == "-p"   || arg == "--person")        { params.person        = argv[++i]; }
+        else if (arg == "-l"   || arg == "--language")      { params.language      = argv[++i]; }
+        else if (arg == "-mw"  || arg == "--model-whisper") { params.model_wsp     = argv[++i]; }
+        else if (arg == "-ml"  || arg == "--model-llama")   { params.model_llama   = argv[++i]; }
+        else if (arg == "-s"   || arg == "--speak")         { params.speak         = argv[++i]; }
+        else if (arg == "-f"   || arg == "--file")          { params.fname_out     = argv[++i]; }
+        else {
+            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
+            whisper_print_usage(argc, argv, params);
+            exit(0);
+        }
+    }
+
+    return true;
+}
+
+void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
+    fprintf(stderr, "\n");
+    fprintf(stderr, "usage: %s [options]\n", argv[0]);
+    fprintf(stderr, "\n");
+    fprintf(stderr, "options:\n");
+    fprintf(stderr, "  -h,       --help          [default] show this help message and exit\n");
+    fprintf(stderr, "  -t N,     --threads N     [%-7d] number of threads to use during computation\n", params.n_threads);
+    fprintf(stderr, "  -vms N,   --voice-ms N    [%-7d] voice duration in milliseconds\n",              params.voice_ms);
+    fprintf(stderr, "  -c ID,    --capture ID    [%-7d] capture device ID\n",                           params.capture_id);
+    fprintf(stderr, "  -mt N,    --max-tokens N  [%-7d] maximum number of tokens per audio chunk\n",    params.max_tokens);
+    fprintf(stderr, "  -ac N,    --audio-ctx N   [%-7d] audio context size (0 - all)\n",                params.audio_ctx);
+    fprintf(stderr, "  -vth N,   --vad-thold N   [%-7.2f] voice activity detection threshold\n",        params.vad_thold);
+    fprintf(stderr, "  -fth N,   --freq-thold N  [%-7.2f] high-pass frequency cutoff\n",                params.freq_thold);
+    fprintf(stderr, "  -su,      --speed-up      [%-7s] speed up audio by x2 (reduced accuracy)\n",     params.speed_up ? "true" : "false");
+    fprintf(stderr, "  -tr,      --translate     [%-7s] translate from source language to english\n",   params.translate ? "true" : "false");
+    fprintf(stderr, "  -ps,      --print-special [%-7s] print special tokens\n",                        params.print_special ? "true" : "false");
+    fprintf(stderr, "  -pe,      --print-energy  [%-7s] print sound energy (for debugging)\n",          params.print_energy ? "true" : "false");
+    fprintf(stderr, "  -p NAME,  --person NAME   [%-7s] person name (for prompt selection)\n",          params.person.c_str());
+    fprintf(stderr, "  -l LANG,  --language LANG [%-7s] spoken language\n",                             params.language.c_str());
+    fprintf(stderr, "  -mw FILE, --model-whisper [%-7s] whisper model file\n",                          params.model_wsp.c_str());
+    fprintf(stderr, "  -ml FILE, --model-llama   [%-7s] llama model file\n",                            params.model_llama.c_str());
+    fprintf(stderr, "  -s FILE,  --speak FILE    [%-7s] command for TTS\n",                             params.speak.c_str());
+    fprintf(stderr, "  -f FNAME, --file FNAME    [%-7s] text output file name\n",                       params.fname_out.c_str());
+    fprintf(stderr, "\n");
+}
+
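+// run whisper on the captured audio; returns the transcribed text and reports the average token
+// probability (prob) and the elapsed time in milliseconds (t_ms)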
+std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
+    const auto t_start = std::chrono::high_resolution_clock::now();
+
+    prob = 0.0f;
+    t_ms = 0;
+
+    whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
+
+    wparams.print_progress   = false;
+    wparams.print_special    = params.print_special;
+    wparams.print_realtime   = false;
+    wparams.print_timestamps = !params.no_timestamps;
+    wparams.translate        = params.translate;
+    wparams.no_context       = true;
+    wparams.single_segment   = true;
+    wparams.max_tokens       = params.max_tokens;
+    wparams.language         = params.language.c_str();
+    wparams.n_threads        = params.n_threads;
+
+    wparams.audio_ctx        = params.audio_ctx;
+    wparams.speed_up         = params.speed_up;
+
+    if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
+        return "";
+    }
+
+    int prob_n = 0;
+    std::string result;
+
+    const int n_segments = whisper_full_n_segments(ctx);
+    for (int i = 0; i < n_segments; ++i) {
+        const char * text = whisper_full_get_segment_text(ctx, i);
+
+        result += text;
+
+        const int n_tokens = whisper_full_n_tokens(ctx, i);
+        for (int j = 0; j < n_tokens; ++j) {
+            const auto token = whisper_full_get_token_data(ctx, i, j);
+
+            prob += token.p;
+            ++prob_n;
+        }
+    }
+
+    if (prob_n > 0) {
+        prob /= prob_n;
+    }
+
+    const auto t_end = std::chrono::high_resolution_clock::now();
+    t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();
+
+    return result;
+}
+
+// need to have leading ' '
+//const std::string k_prompt = R"( Transcript of a dialog, where {1} interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer {1}'s requests immediately and with precision.
+//
+//{0}: Hello, Bob.
+//{1}: Hello {0}. How may I help you today?
+//{0}:)";
+
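+// prompt template - the placeholders are substituted in main():
+//   {0} = person name, {1} = bot name, {2} = current time, {3} = current year, {4} = chat symbol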
+const std::string k_prompt = R"( Text transcript of a never ending dialog, where {0} interacts with an AI assistant named {1}.
+{1} is helpful, kind, honest, friendly, good at writing and never fails to answer {0}’s requests immediately and with details and precision.
+There are no annotations like (30 seconds passed...) or (to himself), just what {0} and {1} say aloud to each other.
+The transcript only includes text, it does not include markup like HTML and Markdown.
+{1} responds with short and concise answers.
+
+{0}{4} Hello, {1}!
+{1}{4} Hello {0}! How may I help you today?
+{0}{4} What time is it?
+{1}{4} It is {2} o'clock.
+{0}{4} What year is it?
+{1}{4} We are in {3}.
+{0}{4} What is a cat?
+{1}{4} A cat is a domestic species of small carnivorous mammal. It is the only domesticated species in the family Felidae.
+{0}{4} Name a color.
+{1}{4} Blue
+{0}{4})";
+
+int main(int argc, char ** argv) {
+    whisper_params params;
+
+    if (whisper_params_parse(argc, argv, params) == false) {
+        return 1;
+    }
+
+    if (whisper_lang_id(params.language.c_str()) == -1) {
+        fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
+        whisper_print_usage(argc, argv, params);
+        exit(0);
+    }
+
+    // whisper init
+
+    struct whisper_context * ctx_wsp = whisper_init_from_file(params.model_wsp.c_str());
+
+    // llama init
+
+    auto lparams = llama_context_default_params();
+
+    lparams.n_ctx      = 512;
+    lparams.n_parts    = 2; // TODO fix
+    lparams.seed       = 1; // TODO fix
+    lparams.f16_kv     = true;
+
+    struct llama_context * ctx_llama = llama_init_from_file(params.model_llama.c_str(), lparams);
+
+    // print some info about the processing
+    {
+        fprintf(stderr, "\n");
+        if (!whisper_is_multilingual(ctx_wsp)) {
+            if (params.language != "en" || params.translate) {
+                params.language = "en";
+                params.translate = false;
+                fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
+            }
+        }
+        fprintf(stderr, "%s: processing, %d threads, lang = %s, task = %s, timestamps = %d ...\n",
+                __func__,
+                params.n_threads,
+                params.language.c_str(),
+                params.translate ? "translate" : "transcribe",
+                params.no_timestamps ? 0 : 1);
+
+        fprintf(stderr, "\n");
+    }
+
+
+    // init audio
+
+    audio_async audio(30*1000);
+    if (!audio.init(params.capture_id, WHISPER_SAMPLE_RATE)) {
+        fprintf(stderr, "%s: audio.init() failed!\n", __func__);
+        return 1;
+    }
+
+    audio.resume();
+
+    int n_iter = 0;
+
+    bool is_running  = true;
+    bool force_speak = false;
+
+    float prob0 = 0.0f;
+
+    const std::string chat_symb = ":";
+    const std::string bot_name = "LLAMA";
+
+    std::vector<float> pcmf32_cur;
+    std::vector<float> pcmf32_prompt;
+
+    std::string prompt_org = k_prompt;
+    prompt_org = ::replace(prompt_org, "{0}", params.person);
+    prompt_org = ::replace(prompt_org, "{1}", bot_name);
+
+    {
+        // get time string
+        std::string time_str;
+        {
+            time_t t = time(0);
+            struct tm * now = localtime(&t);
+            char buf[128];
+            strftime(buf, sizeof(buf), "%H:%M", now);
+            time_str = buf;
+        }
+        prompt_org = ::replace(prompt_org, "{2}", time_str);
+    }
+
+    {
+        // get year string
+        std::string year_str;
+        {
+            time_t t = time(0);
+            struct tm * now = localtime(&t);
+            char buf[128];
+            strftime(buf, sizeof(buf), "%Y", now);
+            year_str = buf;
+        }
+        prompt_org = ::replace(prompt_org, "{3}", year_str);
+    }
+
+    prompt_org = ::replace(prompt_org, "{4}", chat_symb);
+
+    auto embd_inp = ::llama_tokenize(ctx_llama, prompt_org, true);
+
+    const int n_ctx = llama_n_ctx(ctx_llama);
+
+    printf("\n");
+    printf("%s : initializing - please wait ...\n", __func__);
+
+    if (llama_eval(ctx_llama, embd_inp.data(), embd_inp.size(), 0, params.n_threads)) {
+        fprintf(stderr, "%s : failed to eval\n", __func__);
+        return 1;
+    }
+
+    //fprintf(stdout, "\n");
+    //fprintf(stdout, "%s", prompt_org.c_str());
+    //fflush(stdout);
+
+    printf("%s : done! start speaking in the microphone\n", __func__);
+    printf("\n");
+    printf("%s%s", params.person.c_str(), chat_symb.c_str());
+    fflush(stdout);
+
+    audio.clear();
+
+    const int n_keep = embd_inp.size();
+    const int voice_id = 2;
+
+    int n_past = n_keep;
+    int n_prev = 64; // TODO arg
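+    // when the llama context fills up, n_past is reset to n_keep (the evaluated prompt stays in the
+    // KV cache) and the last n_prev tokens of the conversation are re-evaluated to keep continuity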
+
+    std::vector<llama_token> embd;
+
+    std::vector<std::string> antiprompts = {
+        params.person + chat_symb,
+    };
+
+    // main loop
+    while (is_running) {
+        // handle Ctrl + C
+        is_running = sdl_poll_events();
+
+        if (!is_running) {
+            break;
+        }
+
+        // delay
+        std::this_thread::sleep_for(std::chrono::milliseconds(100));
+
+        int64_t t_ms = 0;
+
+        {
+            audio.get(2000, pcmf32_cur);
+
+            if (::vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1250, params.vad_thold, params.freq_thold, params.print_energy) || force_speak) {
+                //fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
+
+                audio.get(params.voice_ms, pcmf32_cur);
+
+                std::string text_heard;
+
+                if (!force_speak) {
+                    text_heard = ::trim(::transcribe(ctx_wsp, params, pcmf32_cur, prob0, t_ms));
+                }
+
+                // remove text between brackets using regex
+                {
+                    std::regex re("\\[.*?\\]");
+                    text_heard = std::regex_replace(text_heard, re, "");
+                }
+
+                // remove text between parentheses using regex
+                {
+                    std::regex re("\\(.*?\\)");
+                    text_heard = std::regex_replace(text_heard, re, "");
+                }
+
+                // remove all characters, except for letters, numbers, punctuation and ':', '\'', '-', ' '
+                text_heard = std::regex_replace(text_heard, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
+
+                // take first line
+                text_heard = text_heard.substr(0, text_heard.find_first_of('\n'));
+
+                // remove leading and trailing whitespace
+                text_heard = std::regex_replace(text_heard, std::regex("^\\s+"), "");
+                text_heard = std::regex_replace(text_heard, std::regex("\\s+$"), "");
+
+                const std::vector<llama_token> tokens = llama_tokenize(ctx_llama, text_heard.c_str(), false);
+
+                if (text_heard.empty() || tokens.empty() || force_speak) {
+                    //fprintf(stdout, "%s: Heard nothing, skipping ...\n", __func__);
+                    audio.clear();
+
+                    continue;
+                }
+
+                force_speak = false;
+
+                text_heard.insert(0, 1, ' ');
+                text_heard += "\n" + bot_name + chat_symb;
+                fprintf(stdout, "%s%s%s", "\033[1m", text_heard.c_str(), "\033[0m");
+                fflush(stdout);
+
+                embd = ::llama_tokenize(ctx_llama, text_heard, false);
+
+                // text inference
+                bool done = false;
+                std::string text_to_speak;
+                while (true) {
+                    // predict
+                    if (embd.size() > 0) {
+                        if (n_past + (int) embd.size() > n_ctx) {
+                            n_past = n_keep;
+
+                            // the context is full - re-seed it with the last n_prev tokens of the conversation (embd_inp)
+                            embd.insert(embd.begin(), embd_inp.begin() + embd_inp.size() - n_prev, embd_inp.end());
+
+                            //printf("\n---\n");
+                            //printf("resetting: '");
+                            //for (int i = 0; i < (int) embd.size(); i++) {
+                            //    printf("%s", llama_token_to_str(ctx_llama, embd[i]));
+                            //}
+                            //printf("'\n");
+                            //printf("\n---\n");
+                        }
+
+                        if (llama_eval(ctx_llama, embd.data(), embd.size(), n_past, params.n_threads)) {
+                            fprintf(stderr, "%s : failed to eval\n", __func__);
+                            return 1;
+                        }
+                    }
+
+                    //printf("n_iter = %d, n_past = %d, n_ctx = %d, n_keep = %d, n_prev = %d, embd.size() = %d\n", n_iter, n_past, n_ctx, n_keep, n_prev, (int) embd.size());
+
+                    embd_inp.insert(embd_inp.end(), embd.begin(), embd.end());
+                    n_past += embd.size();
+                    embd.clear();
+
+                    if (done) break;
+
+                    {
+                        // out of user input, sample next token
+                        const float top_k          = 5;
+                        const float top_p          = 0.80f;
+                        const float temp           = 0.30f;
+                        const float repeat_penalty = 1.1764f;
+
+                        const int repeat_last_n    = 256;
+
+                        llama_token id = 0;
+
+                        {
+                            //auto logits = llama_get_logits(ctx_llama);
+                            //logits[llama_token_eos()] = 0;
+
+                            id = llama_sample_top_p_top_k(ctx_llama,
+                                    embd_inp.data() + std::max(0, n_past - repeat_last_n),
+                                    repeat_last_n, top_k, top_p, temp, repeat_penalty);
+                        }
+
+                        if (id != llama_token_eos()) {
+                            // add it to the context
+                            embd.push_back(id);
+
+                            text_to_speak += llama_token_to_str(ctx_llama, id);
+
+                            printf("%s", llama_token_to_str(ctx_llama, id));
+                        } else {
+                            // TODO
+                            printf("EOS TOKEN - SHOULD NOT HAPPEN\n");
+                            exit(0);
+                        }
+                    }
+
+                    {
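+                        // stop generating once the output ends with one of the antiprompts (the user's turn marker)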
+                        std::string last_output;
+                        for (int i = std::max(0, (int) embd_inp.size() - 16); i < (int) embd_inp.size(); i++) {
+                            last_output += llama_token_to_str(ctx_llama, embd_inp[i]);
+                        }
+                        last_output += llama_token_to_str(ctx_llama, embd[0]);
+
+                        for (std::string & antiprompt : antiprompts) {
+                            if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
+                                done = true;
+                                text_to_speak = ::replace(text_to_speak, antiprompt, "");
+                                fflush(stdout);
+                                break;
+                            }
+                        }
+                    }
+
+                    is_running = sdl_poll_events();
+
+                    if (!is_running) {
+                        break;
+                    }
+                }
+
+                text_to_speak = ::replace(text_to_speak, "\"", "");
+                system((params.speak + " " + std::to_string(voice_id) + " \"" + text_to_speak + "\"").c_str());
+
+                audio.clear();
+
+                ++n_iter;
+            }
+        }
+    }
+
+    audio.pause();
+
+    whisper_print_timings(ctx_wsp);
+    whisper_free(ctx_wsp);
+
+    return 0;
+}
diff --git a/ggml.c b/ggml.c
index d67612c36a3..c9a4e867523 100644
--- a/ggml.c
+++ b/ggml.c
@@ -1,18 +1,23 @@
+// Defines CLOCK_MONOTONIC and asprintf on Linux
+#define _GNU_SOURCE
+
 #include "ggml.h"
 
 #if defined(_MSC_VER) || defined(__MINGW32__)
 #include <malloc.h> // using malloc.h with MSC/MINGW
-#elif !defined(__FreeBSD__)
+#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
 #include <alloca.h>
 #endif
 
 #include <assert.h>
+#include <errno.h>
 #include <time.h>
 #include <math.h>
 #include <stdlib.h>
 #include <string.h>
 #include <stdint.h>
 #include <stdio.h>
+#include <float.h>
 
 // if C99 - static_assert is noop
 // ref: https://stackoverflow.com/a/53923785/4039976
@@ -27,7 +32,6 @@
 #else
 // ref: https://github.com/ggerganov/whisper.cpp/issues/168
 #include <windows.h>
-#include <errno.h>
 #endif
 
 typedef volatile LONG atomic_int;
@@ -79,9 +83,21 @@ typedef void* thread_ret_t;
 #define static_assert(cond, msg) _Static_assert(cond, msg)
 #endif
 
+#define GGML_MLOCK_SUPPORT 0
+
+#ifdef __has_include
+    #if __has_include(<sys/mman.h>)
+        #undef GGML_MLOCK_SUPPORT
+        #define GGML_MLOCK_SUPPORT 1
+        #include <sys/mman.h>
+    #endif
+#endif
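+// when mlock() is available it can be used to pin the context buffer in RAM (see mem_buffer_mlocked in struct ggml_context)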
+
+
 /*#define GGML_PERF*/
 #define GGML_DEBUG 0
 #define GGML_GELU_FP16
+#define GGML_SILU_FP16
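+// GGML_SILU_FP16: evaluate SiLU through the precomputed fp16 lookup table (table_silu_f16) instead of calling expf per element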
 
 #define GGML_SOFT_MAX_UNROLL 4
 #define GGML_VEC_DOT_UNROLL  2
@@ -159,6 +175,39 @@ typedef double ggml_float;
 #define GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x)
 #define GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0)
 
+#elif defined(__POWER9_VECTOR__)
+
+#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
+#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
+/* the inline asm below is about 12% faster than the lookup method */
+#define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x)
+#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
+
+static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
+    register float f;
+    register double d;
+    __asm__(
+        "mtfprd %0,%2\n"
+        "xscvhpdp %0,%0\n"
+        "frsp %1,%0\n" :
+        /* temp */ "=d"(d),
+        /* out */  "=f"(f):
+        /* in */   "r"(h));
+    return f;
+}
+
+static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
+    register double d;
+    register ggml_fp16_t r;
+    __asm__( /* xscvdphp can work on double or single precision */
+        "xscvdphp %0,%2\n"
+        "mffprd %1,%0\n" :
+        /* temp */ "=d"(d),
+        /* out */  "=r"(r):
+        /* in */   "f"(f));
+    return r;
+}
+
 #else
 
 // FP16 <-> FP32
@@ -245,6 +294,9 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
 // precomputed gelu table for f16 (128 KB)
 static ggml_fp16_t table_gelu_f16[1 << 16];
 
+// precomputed silu table for f16 (128 KB)
+static ggml_fp16_t table_silu_f16[1 << 16];
+
 // precomputed exp table for f16 (128 KB)
 static ggml_fp16_t table_exp_f16[1 << 16];
 
@@ -253,6 +305,7 @@ static float table_f32_f16[1 << 16];
 
 // On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
 // so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON.
+// This is also true for POWER9.
 #if !defined(GGML_FP16_TO_FP32) || !defined(GGML_FP32_TO_FP16)
 
 inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
@@ -348,6 +401,540 @@ int64_t ggml_cycles_per_ms(void) {
 
 static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
 
+//
+// quantization
+//
+
+#define QK 32
+
+// AVX routines provided by GH user Const-me
+// ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
+#if __AVX2__ || __AVX512F__
+// Unpack 32 4-bit fields into 32 bytes
+// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
+static inline __m256i bytesFromNibbles( const uint8_t* rsi )
+{
+    // Load 16 bytes from memory
+    __m128i tmp = _mm_loadu_si128( ( const __m128i* )rsi );
+
+    // Expand bytes into uint16_t values
+    __m256i bytes = _mm256_cvtepu8_epi16( tmp );
+
+    // Unpack values into individual bytes
+    const __m256i lowMask = _mm256_set1_epi8( 0xF );
+    __m256i high = _mm256_andnot_si256( lowMask, bytes );
+    __m256i low = _mm256_and_si256( lowMask, bytes );
+    high = _mm256_slli_epi16( high, 4 );
+    bytes = _mm256_or_si256( low, high );
+    return bytes;
+}
+
+static inline __m128i packNibbles( __m256i bytes )
+{
+    // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
+    const __m256i lowByte = _mm256_set1_epi16( 0xFF );
+    __m256i high = _mm256_andnot_si256( lowByte, bytes );
+    __m256i low = _mm256_and_si256( lowByte, bytes );
+    high = _mm256_srli_epi16( high, 4 );
+    bytes = _mm256_or_si256( low, high );
+
+    // Compress uint16_t lanes into bytes
+    __m128i r0 = _mm256_castsi256_si128( bytes );
+    __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
+    return _mm_packus_epi16( r0, r1 );
+}
+#endif
+
+// method 5
+// blocks of QK elements
+// represented with a single float (delta) and QK/2 8-bit ints (i.e. QK 4-bit signed integer factors)
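+//
+// rough layout of one q4_0 block in memory (with QK == 32):
+//
+//   | float delta | 16 bytes: 32 packed 4-bit values, two per byte, low nibble first |
+//
+// each stored nibble q reconstructs an element as (q - 8) * delta;
+// e.g. amax = 3.5 gives delta = 3.5/7 = 0.5, and x = 1.2 is stored as round(1.2/0.5) + 8 = 10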
+
+// reference implementation for deterministic creation of model files
+static void quantize_row_q4_0_reference(const float * restrict x, void * restrict y, int k) {
+    assert(k % QK == 0);
+    const int nb = k / QK;
+
+    const size_t bs = sizeof(float) + QK/2;
+
+    uint8_t * restrict pd = ((uint8_t *)y + 0*bs);
+    uint8_t * restrict pb = ((uint8_t *)y + 0*bs + sizeof(float));
+
+    uint8_t pp[QK/2];
+
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+
+        for (int l = 0; l < QK; l++) {
+            const float v = x[i*QK + l];
+            amax = MAX(amax, fabsf(v));
+        }
+
+        const float d = amax / ((1 << 3) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        *(float *)pd = d;
+        pd += bs;
+
+        for (int l = 0; l < QK; l += 2) {
+            const float v0 = x[i*QK + l + 0]*id;
+            const float v1 = x[i*QK + l + 1]*id;
+
+            const uint8_t vi0 = ((int8_t) (round(v0))) + 8;
+            const uint8_t vi1 = ((int8_t) (round(v1))) + 8;
+
+            assert(vi0 >= 0 && vi0 < 16);
+            assert(vi1 >= 0 && vi1 < 16);
+
+            pp[l/2] = vi0 | (vi1 << 4);
+        }
+
+        memcpy(pb, pp, sizeof(pp));
+        pb += bs;
+    }
+}
+
+void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
+    assert(k % QK == 0);
+
+#if defined(__ARM_NEON) || defined(__AVX2__) || defined(__wasm_simd128__) || defined(__POWER9_VECTOR__)
+    const int nb = k / QK;
+    const size_t bs = sizeof(float) + QK/2;
+
+    uint8_t * restrict pd = ((uint8_t *)y + 0*bs);
+    uint8_t * restrict pb = ((uint8_t *)y + 0*bs + sizeof(float));
+
+    uint8_t pp[QK/2];
+#endif
+
+#if defined(__POWER9_VECTOR__)
+    const vector float v85 = vec_splats(8.5f);
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+
+        vector float srcv [8];
+        vector float asrcv[8];
+        vector float amaxv[8];
+
+        for (int l = 0; l < 8; l++) srcv[l]  = *(vector float *)(x + i*32 + 4*l);
+        for (int l = 0; l < 8; l++) asrcv[l] = vec_abs(srcv[l]);
+
+        for (int l = 0; l < 4; l++) amaxv[2*l] = vec_max(asrcv[2*l], asrcv[2*l+1]);
+        //for (int l = 0; l < 2; l++) amaxv[4*l] = vec_max(amaxv[4*l], amaxv[4*l+2]);
+        amaxv[0] = vec_max(amaxv[0], amaxv[2]);
+        amaxv[4] = vec_max(amaxv[4], amaxv[6]);
+        //for (int l = 0; l < 1; l++) amaxv[8*l] = vec_max(amaxv[8*l], amaxv[8*l+4]);
+        amaxv[0] = vec_max(amaxv[0], amaxv[4]);
+
+        amax = MAX(
+                MAX(vec_extract(amaxv[0], 0), vec_extract(amaxv[0], 1)),
+                MAX(vec_extract(amaxv[0], 2), vec_extract(amaxv[0], 3)));
+
+        const float d = amax / ((1 << 3) - 1);
+        const float id = d ? 1.0/d : 0.0;
+
+        *(float *)pd = d;
+        pd += bs;
+
+        const vector float vid = vec_splats(id);
+        for (int l = 0; l < 8; l++) {
+            const vector float vf  = vec_madd(srcv[l], vid, v85);
+            const vector signed int vi = vec_signed(vf);
+
+            pb[2*l + 0] = vec_extract(vi, 0) | (vec_extract(vi, 1) << 4);
+            pb[2*l + 1] = vec_extract(vi, 2) | (vec_extract(vi, 3) << 4);
+        }
+
+        //memcpy(pb, pp, sizeof(pp));
+        pb += bs;
+    }
+#elif __ARM_NEON
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+
+        float32x4_t srcv [8];
+        float32x4_t asrcv[8];
+        float32x4_t amaxv[8];
+
+        for (int l = 0; l < 8; l++) srcv[l]  = vld1q_f32(x + i*32 + 4*l);
+        for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
+
+        for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]);
+        for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
+        for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
+
+        amax = MAX(
+                MAX(vgetq_lane_f32(amaxv[0], 0), vgetq_lane_f32(amaxv[0], 1)),
+                MAX(vgetq_lane_f32(amaxv[0], 2), vgetq_lane_f32(amaxv[0], 3)));
+
+        const float d = amax / ((1 << 3) - 1);
+        const float id = d ? 1.0/d : 0.0;
+
+        *(float *)pd = d;
+        pd += bs;
+
+        for (int l = 0; l < 8; l++) {
+            const float32x4_t v  = vmulq_n_f32(srcv[l], id);
+            const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(8.5f));
+            const int32x4_t   vi = vcvtq_s32_f32(vf);
+
+            pp[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
+            pp[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
+        }
+
+        memcpy(pb, pp, sizeof(pp));
+        pb += bs;
+    }
+#elif defined(__AVX2__)
+    for (int i = 0; i < nb; i++) {
+        // Load elements into 4 AVX vectors
+        __m256 v0 = _mm256_loadu_ps( x );
+        __m256 v1 = _mm256_loadu_ps( x + 8 );
+        __m256 v2 = _mm256_loadu_ps( x + 16 );
+        __m256 v3 = _mm256_loadu_ps( x + 24 );
+        x += 32;
+
+        // Compute max(abs(e)) for the block
+        const __m256 signBit = _mm256_set1_ps( -0.0f );
+        __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+
+        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+        max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+        max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+        const float maxScalar = _mm_cvtss_f32( max4 );
+
+        // Quantize these floats
+        const float d = maxScalar / 7.0f;
+        *(float *)pd = d;
+        pd += bs;
+        const float id = ( maxScalar != 0.0f ) ? 7.0f / maxScalar : 0.0f;
+        const __m256 mul = _mm256_set1_ps( id );
+
+        // Apply the multiplier
+        v0 = _mm256_mul_ps( v0, mul );
+        v1 = _mm256_mul_ps( v1, mul );
+        v2 = _mm256_mul_ps( v2, mul );
+        v3 = _mm256_mul_ps( v3, mul );
+
+        // Round to nearest integer
+        v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+        v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+        v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+        v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+        // Convert floats to integers
+        __m256i i0 = _mm256_cvtps_epi32( v0 );
+        __m256i i1 = _mm256_cvtps_epi32( v1 );
+        __m256i i2 = _mm256_cvtps_epi32( v2 );
+        __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+        // Convert int32 to int16
+        i0 = _mm256_packs_epi32( i0, i1 );	// 0, 1, 2, 3,  8, 9, 10, 11,  4, 5, 6, 7, 12, 13, 14, 15
+        i2 = _mm256_packs_epi32( i2, i3 );	// 16, 17, 18, 19,  24, 25, 26, 27,  20, 21, 22, 23, 28, 29, 30, 31
+                                            // Convert int16 to int8
+        i0 = _mm256_packs_epi16( i0, i2 );	// 0, 1, 2, 3,  8, 9, 10, 11,  16, 17, 18, 19,  24, 25, 26, 27,  4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+
+        // We got our precious signed bytes, but the order is now wrong
+        // These AVX2 pack instructions process 16-byte pieces independently
+        // The following instruction is fixing the order
+        const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+        i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+        // Apply offset to translate the range from [ -7 .. +7 ] into [ +1 .. +15 ]
+        const __m256i off = _mm256_set1_epi8( 8 );
+        i0 = _mm256_add_epi8( i0, off );
+
+        // Compress the vector into 4 bit/value, and store
+        __m128i res = packNibbles( i0 );
+        _mm_storeu_si128( ( __m128i* )pb, res );
+        pb += bs;
+    }
+#elif defined(__wasm_simd128__)
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+
+        v128_t srcv [8];
+        v128_t asrcv[8];
+        v128_t amaxv[8];
+
+        for (int l = 0; l < 8; l++) srcv[l]  = wasm_v128_load(x + i*32 + 4*l);
+        for (int l = 0; l < 8; l++) asrcv[l] = wasm_f32x4_abs(srcv[l]);
+
+        for (int l = 0; l < 4; l++) amaxv[2*l] = wasm_f32x4_max(asrcv[2*l], asrcv[2*l+1]);
+        for (int l = 0; l < 2; l++) amaxv[4*l] = wasm_f32x4_max(amaxv[4*l], amaxv[4*l+2]);
+        for (int l = 0; l < 1; l++) amaxv[8*l] = wasm_f32x4_max(amaxv[8*l], amaxv[8*l+4]);
+
+        amax = MAX(
+                MAX(wasm_f32x4_extract_lane(amaxv[0], 0), wasm_f32x4_extract_lane(amaxv[0], 1)),
+                MAX(wasm_f32x4_extract_lane(amaxv[0], 2), wasm_f32x4_extract_lane(amaxv[0], 3)));
+
+        const float d = amax / ((1 << 3) - 1);
+        const float id = d ? 1.0/d : 0.0;
+
+        *(float *)pd = d;
+        pd += bs;
+
+        for (int l = 0; l < 8; l++) {
+            const v128_t v  = wasm_f32x4_mul(srcv[l], wasm_f32x4_splat(id));
+            const v128_t vf = wasm_f32x4_add(v, wasm_f32x4_splat(8.5f));
+            const v128_t vi = wasm_i32x4_trunc_sat_f32x4(vf);
+
+            pp[2*l + 0] = wasm_i32x4_extract_lane(vi, 0) | (wasm_i32x4_extract_lane(vi, 1) << 4);
+            pp[2*l + 1] = wasm_i32x4_extract_lane(vi, 2) | (wasm_i32x4_extract_lane(vi, 3) << 4);
+        }
+
+        memcpy(pb, pp, sizeof(pp));
+        pb += bs;
+    }
+#else
+    // scalar
+    quantize_row_q4_0_reference(x, y, k);
+#endif
+}
+
+// method 4
+// blocks of QK elements
+// represented with 2 floats (min + delta) and QK/2 8-bit ints (i.e. QK 4-bit unsigned integer factors)
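+//
+// rough layout of one q4_1 block in memory (with QK == 32):
+//
+//   | float delta | float min | 16 bytes: 32 packed 4-bit values |
+//
+// each stored nibble q reconstructs an element as q * delta + min;
+// e.g. min = -2.0, max = 1.0 gives delta = 3.0/15 = 0.2, and x = 0.0 is stored as round((0.0 + 2.0)/0.2) = 10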
+void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
+    assert(k % QK == 0);
+
+    const int nb = k / QK;
+    const size_t bs = 2*sizeof(float) + QK/2;
+
+    uint8_t * restrict pd = ((uint8_t *)y + 0*bs);
+    uint8_t * restrict pm = ((uint8_t *)y + 0*bs +   sizeof(float));
+    uint8_t * restrict pb = ((uint8_t *)y + 0*bs + 2*sizeof(float));
+
+    uint8_t pp[QK/2];
+
+    for (int i = 0; i < nb; i++) {
+        float min = FLT_MAX;
+        float max = -FLT_MAX;
+
+        for (int l = 0; l < QK; l++) {
+            const float v = x[i*QK + l];
+            if (v < min) min = v;
+            if (v > max) max = v;
+        }
+
+        const float d = (max - min) / ((1 << 4) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        *(float *)pm = min;
+        *(float *)pd = d;
+        pm += bs;
+        pd += bs;
+
+        for (int l = 0; l < QK; l += 2) {
+            const float v0 = (x[i*QK + l + 0] - min)*id;
+            const float v1 = (x[i*QK + l + 1] - min)*id;
+
+            const uint8_t vi0 = round(v0);
+            const uint8_t vi1 = round(v1);
+
+            assert(vi0 >= 0 && vi0 < 16);
+            assert(vi1 >= 0 && vi1 < 16);
+
+            pp[l/2] = vi0 | (vi1 << 4);
+        }
+
+        memcpy(pb, pp, sizeof(pp));
+        pb += bs;
+    }
+}
+
+// TODO: vectorize
+void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
+    assert(k % QK == 0);
+
+    const int nb = k / QK;
+    const size_t bs = sizeof(float) + QK/2;
+
+    const uint8_t * restrict pd = ((const uint8_t *)x + 0*bs);
+    const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + sizeof(float));
+
+#if defined(__AVX2__)
+    for (int i = 0; i < nb; i++) {
+        // scale factor
+        const __m256 d_v = _mm256_broadcast_ss((const float *) (pd + i*bs));
+
+        const uint8_t * restrict pp = pb + i*bs;
+
+        for (int l = 0; l < QK; l += 32) {
+            // Load 32x4-bit integers into 32x8-bit integers
+            __m256i vx8 = bytesFromNibbles(pp+l/2);
+
+            // Subtract 8 from the integers
+            vx8 = _mm256_sub_epi8(vx8, _mm256_set1_epi8(8));
+
+            // Convert to 16-bit int
+            const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
+            const __m256i vx16_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 1));
+
+            // Convert to 32-bit int -> float 32
+            const __m256 vf[4] = {
+                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 0))),
+                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 1))),
+                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 0))),
+                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 1)))
+            };
+
+            // Scale and store
+            for (int j = 0; j < 4; j++) {
+                const __m256 result = _mm256_mul_ps(vf[j], d_v);
+                _mm256_storeu_ps(y + i * QK + l + j*8, result);
+            }
+        }
+    }
+#elif defined(__ARM_NEON)
+    for (int i = 0; i < nb; i++) {
+        const float d = *(const float *) (pd + i*bs);
+
+        const uint8_t * restrict pp = pb + i*bs;
+
+        const float32x4_t vd = vdupq_n_f32(d);
+
+        for (int l = 0; l < QK; l += 16) {
+            // Load 16x4-bit integers into 8x8-bit integers
+            const uint8x8_t v8 = vld1_u8(pp + l/2);
+
+            // Expand 4-bit nibbles to 8-bit bytes
+            const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f));
+            const uint8x8_t v1 = vshr_n_u8(v8, 4);
+
+            // Convert to signed 8-bit integers
+            const int8x8_t vs_0 = vreinterpret_s8_u8(v0);
+            const int8x8_t vs_1 = vreinterpret_s8_u8(v1);
+
+            // Subtract 8 from each byte
+            const int8x8_t vb_0 = vsub_s8(vs_0, vdup_n_s8(8));
+            const int8x8_t vb_1 = vsub_s8(vs_1, vdup_n_s8(8));
+
+            // Interleave and combine
+            const int8x8_t vx_0 = vzip1_s8(vb_0, vb_1);
+            const int8x8_t vx_1 = vzip2_s8(vb_0, vb_1);
+
+            const int8x16_t vq = vcombine_s8(vx_0, vx_1);
+
+            // convert to 2x int16x8_t
+            const int16x8_t vi_0 = vmovl_s8(vget_low_s8 (vq));
+            const int16x8_t vi_1 = vmovl_s8(vget_high_s8(vq));
+
+            // convert to 4x float32x4_t
+            const float32x4_t vf_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_0)));
+            const float32x4_t vf_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_0)));
+            const float32x4_t vf_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_1)));
+            const float32x4_t vf_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_1)));
+
+            // Multiply by d
+            const float32x4_t r0 = vmulq_f32(vf_0, vd);
+            const float32x4_t r1 = vmulq_f32(vf_1, vd);
+            const float32x4_t r2 = vmulq_f32(vf_2, vd);
+            const float32x4_t r3 = vmulq_f32(vf_3, vd);
+
+            // Store
+            vst1q_f32(y + i*QK + l +  0, r0);
+            vst1q_f32(y + i*QK + l +  4, r1);
+            vst1q_f32(y + i*QK + l +  8, r2);
+            vst1q_f32(y + i*QK + l + 12, r3);
+        }
+    }
+#else
+    // scalar
+    for (int i = 0; i < nb; i++) {
+        const float d = *(const float *) (pd + i*bs);
+
+        const uint8_t * restrict pp = pb + i*bs;
+
+        for (int l = 0; l < QK; l += 2) {
+            const uint8_t vi = pp[l/2];
+
+            const int8_t vi0 = vi & 0xf;
+            const int8_t vi1 = vi >> 4;
+
+            const float v0 = (vi0 - 8)*d;
+            const float v1 = (vi1 - 8)*d;
+
+            //printf("d = %f, vi = %d, vi0 = %d, vi1 = %d, v0 = %f, v1 = %f\n", d, vi, vi0, vi1, v0, v1);
+
+            y[i*QK + l + 0] = v0;
+            y[i*QK + l + 1] = v1;
+
+            assert(!isnan(y[i*QK + l + 0]));
+            assert(!isnan(y[i*QK + l + 1]));
+        }
+    }
+#endif
+}
+
+void dequantize_row_q4_1(const void * restrict x, float * restrict y, int k) {
+    assert(k % QK == 0);
+
+    const int nb = k / QK;
+    const size_t bs = 2*sizeof(float) + QK/2;
+
+    const uint8_t * restrict pd = ((const uint8_t *)x + 0*bs);
+    const uint8_t * restrict pm = ((const uint8_t *)x + 0*bs + sizeof(float));
+    const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + 2*sizeof(float));
+
+#if defined(__AVX2__)
+    for (int i = 0; i < nb; i++) {
+        const __m256 d_v = _mm256_broadcast_ss((const float *) (pd + i*bs));
+        const __m256 d_m = _mm256_broadcast_ss((const float *) (pm + i*bs));
+
+        const uint8_t * restrict pp = pb + i*bs;
+
+        for (int l = 0; l < QK; l += 32) {
+            // Load 32x4-bit integers into 32x8-bit integers
+            __m256i vx8 = bytesFromNibbles(pp+l/2);
+
+            // Convert to 16-bit int
+            const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
+            const __m256i vx16_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 1));
+
+            // Convert to 32-bit int -> float 32
+            const __m256 vf[4] = {
+                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 0))),
+                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 1))),
+                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 0))),
+                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 1)))
+            };
+
+            // Scale, add m and store
+            for (int j = 0; j < 4; j++) {
+                const __m256 result = _mm256_add_ps(_mm256_mul_ps(vf[j], d_v), d_m);
+                _mm256_storeu_ps(y + i * QK + l + j*8, result);
+            }
+        }
+    }
+#else
+    for (int i = 0; i < nb; i++) {
+        const float d = *(const float *) (pd + i*bs);
+        const float m = *(const float *) (pm + i*bs);
+
+        const uint8_t * restrict pp = pb + i*bs;
+
+        for (int l = 0; l < QK; l += 2) {
+            const uint8_t vi = pp[l/2];
+
+            const int8_t vi0 = vi & 0xf;
+            const int8_t vi1 = vi >> 4;
+
+            const float v0 = vi0*d + m;
+            const float v1 = vi1*d + m;
+
+            y[i*QK + l + 0] = v0;
+            y[i*QK + l + 1] = v1;
+
+            assert(!isnan(y[i*QK + l + 0]));
+            assert(!isnan(y[i*QK + l + 1]));
+        }
+    }
+#endif
+}
+
 //
 // simd mappings
 //
@@ -889,6 +1476,47 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
     *s = sumf;
 }
 
+#if __AVX512F__ && QK == 32
+static inline __m512 dot_q4_0_oneblock_avx512(
+    __m512 acc,
+    const uint8_t * pd0,
+    const uint8_t * pd1,
+    const uint8_t * pb0,
+    const uint8_t * pb1,
+    size_t bs,
+    int i
+) {
+    const float * d0_0 = (const float *) (pd0 + i*bs);
+    const float * d1_0 = (const float *) (pd1 + i*bs);
+
+    const uint8_t * restrict p0 = pb0 + (i+0)*bs;
+    const uint8_t * restrict p1 = pb1 + (i+0)*bs;
+
+    // Compute combined scale for the block
+    float scaleScalar = d0_0[0] * d1_0[0];
+    __m512 scale = _mm512_set1_ps( scaleScalar );
+
+    __m256i bx = bytesFromNibbles( p0 );
+    __m256i by = bytesFromNibbles( p1 );
+
+    // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+    const __m256i off = _mm256_set1_epi8( 8 );
+    bx = _mm256_sub_epi8( bx, off );
+    by = _mm256_sub_epi8( by, off );
+
+    // Sign-extend 16 signed bytes into int16_t
+    __m512i x32 = _mm512_cvtepi8_epi16( bx );
+    __m512i y32 = _mm512_cvtepi8_epi16( by );
+    // Compute products of int16_t integers, add pairwise
+    __m512i i64 = _mm512_madd_epi16( x32, y32 );
+
+    // Convert int32_t to float
+    __m512 p = _mm512_cvtepi32_ps( i64 );
+    // Apply the scale, and accumulate
+    return _mm512_fmadd_ps( scale, p, acc );
+}
+#endif
+
 inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
     ggml_float sumf = 0.0;
 
@@ -925,136 +1553,531 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
     *s = sumf;
 }
 
-// compute GGML_VEC_DOT_UNROLL dot products at once
-// xs - x row stride in bytes
-inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
-    ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
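+// dot product of two q4_0 rows:
+//
+//   s = sum over blocks of d_x * d_y * sum_j (qx_j - 8) * (qy_j - 8)
+//
+// the scalar fallback below computes exactly this; the SIMD paths are equivalent reorderings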
+inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
+    const int nb = n / QK;
 
-    ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL];
+    assert(n % QK == 0);
+    assert(nb % 2 == 0);
 
-    for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
-        x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
-    }
+    const size_t bs = sizeof(float) + QK/2;
 
-#if defined(GGML_SIMD)
-    const int np = (n & ~(GGML_F16_STEP - 1));
+    const uint8_t * restrict pd0 = ((const uint8_t *)x + 0*bs);
+    const uint8_t * restrict pd1 = ((const uint8_t *)y + 0*bs);
 
-    GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
+    const uint8_t * restrict pb0 = ((const uint8_t *)x + 0*bs + sizeof(float));
+    const uint8_t * restrict pb1 = ((const uint8_t *)y + 0*bs + sizeof(float));
 
-    GGML_F16_VEC ax[GGML_F16_ARR];
-    GGML_F16_VEC ay[GGML_F16_ARR];
+    float sumf = 0.0;
 
-    for (int i = 0; i < np; i += GGML_F16_STEP) {
-        for (int j = 0; j < GGML_F16_ARR; j++) {
-            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+#if defined(__ARM_NEON)
+    float sum0 = 0.0f;
+    float sum1 = 0.0f;
 
-            for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
-                ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j);
+    for (int i = 0; i < nb; i += 2) {
+        const float d0_0 = *(const float *) (pd0 + i*bs);
+        const float d1_0 = *(const float *) (pd1 + i*bs);
+        const float d0_1 = *(const float *) (pd0 + (i + 1)*bs);
+        const float d1_1 = *(const float *) (pd1 + (i + 1)*bs);
 
-                sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]);
-            }
-        }
-    }
+        //printf("d0_0: %f, d1_0: %f, d0_1: %f, d1_1: %f\n", d0_0, d1_0, d0_1, d1_1);
 
-    // reduce sum0..sum3 to sum0
-    for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
-        GGML_F16_VEC_REDUCE(sumf[k], sum[k]);
-    }
+        const uint8_t * restrict p0 = pb0 + i*bs;
+        const uint8_t * restrict p1 = pb1 + i*bs;
 
-    // leftovers
-    for (int i = np; i < n; ++i) {
-        for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
-            sumf[j] += GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]);
-        }
-    }
-#else
-    for (int i = 0; i < n; ++i) {
-        for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
-            sumf[j] += GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]);
-        }
-    }
-#endif
+        const uint8x16_t m4b = vdupq_n_u8(0xf);
+        const int8x16_t  s8b = vdupq_n_s8(0x8);
 
-    for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
-        s[i] = sumf[i];
-    }
-}
+        const uint8x16_t v0_0 = vld1q_u8(p0);
+        const uint8x16_t v1_0 = vld1q_u8(p1);
+        const uint8x16_t v0_1 = vld1q_u8(p0 + bs);
+        const uint8x16_t v1_1 = vld1q_u8(p1 + bs);
 
-inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) {
-#if defined(GGML_SIMD)
-    const int np = (n & ~(GGML_F32_STEP - 1));
+        // 4-bit -> 8-bit
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
+        const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b));
 
-    GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
+        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        const int8x16_t v1_0h = vreinterpretq_s8_u8(vshrq_n_u8(v1_0, 4));
 
-    GGML_F32_VEC ax[GGML_F32_ARR];
-    GGML_F32_VEC ay[GGML_F32_ARR];
+        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
+        const int8x16_t v1_1l = vreinterpretq_s8_u8(vandq_u8(v1_1, m4b));
 
-    for (int i = 0; i < np; i += GGML_F32_STEP) {
-        for (int j = 0; j < GGML_F32_ARR; j++) {
-            ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
-            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
-            ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx);
+        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+        const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4));
 
-            GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
-        }
-    }
+        // sub 8
+        const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
+        const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
 
-    // leftovers
-    for (int i = np; i < n; ++i) {
-        y[i] += x[i]*v;
-    }
+        const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
+        const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
+
+        const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
+        const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
+
+        const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
+        const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        // dot product into int16x8_t
+        int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
+        int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
+
+        p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
+        p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
+
+        // scalar
+#if defined(__ARM_FEATURE_QRDMX)
+        sum0 += d0_0*d1_0*vaddvq_s32(p_0);
+        sum1 += d0_1*d1_1*vaddvq_s32(p_1);
 #else
-    // scalar
-    for (int i = 0; i < n; ++i) {
-        y[i] += x[i]*v;
-    }
+        sum0 += d0_0*d1_0*(vgetq_lane_s32(p_0, 0) + vgetq_lane_s32(p_0, 1) + vgetq_lane_s32(p_0, 2) + vgetq_lane_s32(p_0, 3));
+        sum1 += d0_1*d1_1*(vgetq_lane_s32(p_1, 0) + vgetq_lane_s32(p_1, 1) + vgetq_lane_s32(p_1, 2) + vgetq_lane_s32(p_1, 3));
 #endif
-}
+#else
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
 
-inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_fp16_t * restrict x, const float v) {
-#if defined(GGML_SIMD)
-    const int np = (n & ~(GGML_F16_STEP - 1));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
 
-    GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
 
-    GGML_F16_VEC ax[GGML_F16_ARR];
-    GGML_F16_VEC ay[GGML_F16_ARR];
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
 
-    for (int i = 0; i < np; i += GGML_F16_STEP) {
-        for (int j = 0; j < GGML_F16_ARR; j++) {
-            ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
-            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
-            ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
+        const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
+        const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
 
-            GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
-        }
-    }
+        const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
+        const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
 
-    // leftovers
-    for (int i = np; i < n; ++i) {
-        GGML_ASSERT(false);
-        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
-    }
+        const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
+        const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
+
+        // scalar
+#if defined(__ARM_FEATURE_QRDMX)
+        sum0 += d0_0*d1_0*vaddvq_s16(p_0);
+        sum1 += d0_1*d1_1*vaddvq_s16(p_1);
 #else
-    for (int i = 0; i < n; ++i) {
-        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
-    }
+        sum0 += d0_0*d1_0*(vgetq_lane_s16(p_0, 0) + vgetq_lane_s16(p_0, 1) + vgetq_lane_s16(p_0, 2) + vgetq_lane_s16(p_0, 3) + vgetq_lane_s16(p_0, 4) + vgetq_lane_s16(p_0, 5) + vgetq_lane_s16(p_0, 6) + vgetq_lane_s16(p_0, 7));
+        sum1 += d0_1*d1_1*(vgetq_lane_s16(p_1, 0) + vgetq_lane_s16(p_1, 1) + vgetq_lane_s16(p_1, 2) + vgetq_lane_s16(p_1, 3) + vgetq_lane_s16(p_1, 4) + vgetq_lane_s16(p_1, 5) + vgetq_lane_s16(p_1, 6) + vgetq_lane_s16(p_1, 7));
 #endif
-}
+#endif
+    }
 
-//inline static void ggml_vec_scale_f32(const int n, float * y, const float   v) { for (int i = 0; i < n; ++i) y[i] *= v;          }
-inline static void ggml_vec_scale_f32(const int n, float * y, const float   v) {
-#if defined(GGML_SIMD)
-    const int np = (n & ~(GGML_F32_STEP - 1));
+    sumf = sum0 + sum1;
+#elif defined(__AVX512F__)
+    // Initialize accumulator with zeros
+    __m512 acc0 = _mm512_setzero_ps();
+    __m512 acc1 = _mm512_setzero_ps();
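+    // the superblock loop below alternates between acc0 and acc1 so that consecutive FMAs are independent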
 
-    GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
+    const int superblock_size = 8;
+    const int superblock_count = nb / superblock_size;
+    const int remainder = nb % superblock_size;
 
-    GGML_F32_VEC ay[GGML_F32_ARR];
+    for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) {
+        int i = superblock_ix * superblock_size;
 
-    for (int i = 0; i < np; i += GGML_F32_STEP) {
-        for (int j = 0; j < GGML_F32_ARR; j++) {
-            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
-            ay[j] = GGML_F32_VEC_MUL(ay[j], vx);
+        acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+0 );
+        acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+1 );
+        acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+2 );
+        acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+3 );
+        acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+4 );
+        acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+5 );
+        acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+6 );
+        acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+7 );
+    }
+
+    // Remainders
+    for (int i = superblock_count * superblock_size; i < nb; ++i) {
+        acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i );
+    }
+
+    // Horizontal sum of all lanes of the accumulator
+    sumf = _mm512_reduce_add_ps( acc0 ) + _mm512_reduce_add_ps( acc1 );
+#elif defined(__AVX2__)
+    const size_t countBlocks = nb;
+
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (int i = 0; i < nb; ++i) {
+        const float * d0_0 = (const float *) (pd0 + i*bs);
+        const float * d1_0 = (const float *) (pd1 + i*bs);
+
+        const uint8_t * restrict p0 = pb0 + i*bs;
+        const uint8_t * restrict p1 = pb1 + i*bs;
+
+        // Compute combined scale for the block
+        const __m256 scale = _mm256_mul_ps( _mm256_broadcast_ss( d0_0 ), _mm256_broadcast_ss( d1_0 ) );
+
+        // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
+        __m256i bx = bytesFromNibbles( p0 );
+        __m256i by = bytesFromNibbles( p1 );
+
+        // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+        const __m256i off = _mm256_set1_epi8( 8 );
+        bx = _mm256_sub_epi8( bx, off );
+        by = _mm256_sub_epi8( by, off );
+
+        // Sign-extend first 16 signed bytes into int16_t
+        __m256i x16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( bx ) );
+        __m256i y16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
+        // Compute products of int16_t integers, add pairwise
+        __m256i i32 = _mm256_madd_epi16( x16, y16 );
+
+        // Sign-extend last 16 signed bytes into int16_t vectors
+        x16 = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( bx, 1 ) );
+        y16 = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
+        // Accumulate products of int16_t integers
+        i32 = _mm256_add_epi32( i32, _mm256_madd_epi16( x16, y16 ) );
+
+        // Convert int32_t to float
+        __m256 p = _mm256_cvtepi32_ps( i32 );
+        // Apply the scale, and accumulate
+        acc = _mm256_fmadd_ps( scale, p, acc );
+    }
+
+    // Return horizontal sum of the acc vector
+    __m128 res = _mm256_extractf128_ps( acc, 1 );
+    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
+    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
+    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
+
+    sumf = _mm_cvtss_f32( res );
+#elif defined(__wasm_simd128__)
+    // wasm simd
+    float sum0 = 0.0f;
+    float sum1 = 0.0f;
+
+    for (int i = 0; i < nb; i += 2) {
+        const float d0_0 = *(const float *) (pd0 + i*bs);
+        const float d1_0 = *(const float *) (pd1 + i*bs);
+        const float d0_1 = *(const float *) (pd0 + (i + 1)*bs);
+        const float d1_1 = *(const float *) (pd1 + (i + 1)*bs);
+
+        const uint8_t * restrict p0 = pb0 + i*bs;
+        const uint8_t * restrict p1 = pb1 + i*bs;
+
+        const v128_t m4b = wasm_u8x16_splat(0xf);
+        const v128_t s8b = wasm_i8x16_splat(0x8);
+
+        const v128_t v0_0 = wasm_v128_load(p0);
+        const v128_t v0_1 = wasm_v128_load(p0 + bs);
+        const v128_t v1_0 = wasm_v128_load(p1);
+        const v128_t v1_1 = wasm_v128_load(p1 + bs);
+
+        // 4-bit -> 8-bit
+        const v128_t v0_0l = wasm_v128_and(v0_0, m4b);
+        const v128_t v1_0l = wasm_v128_and(v1_0, m4b);
+
+        const v128_t v0_0h = wasm_u8x16_shr(v0_0, 4);
+        const v128_t v1_0h = wasm_u8x16_shr(v1_0, 4);
+
+        const v128_t v0_1l = wasm_v128_and(v0_1, m4b);
+        const v128_t v1_1l = wasm_v128_and(v1_1, m4b);
+
+        const v128_t v0_1h = wasm_u8x16_shr(v0_1, 4);
+        const v128_t v1_1h = wasm_u8x16_shr(v1_1, 4);
+
+        // sub 8
+        const v128_t v0_0ls = wasm_i8x16_sub(v0_0l, s8b);
+        const v128_t v1_0ls = wasm_i8x16_sub(v1_0l, s8b);
+
+        const v128_t v0_0hs = wasm_i8x16_sub(v0_0h, s8b);
+        const v128_t v1_0hs = wasm_i8x16_sub(v1_0h, s8b);
+
+        const v128_t v0_1ls = wasm_i8x16_sub(v0_1l, s8b);
+        const v128_t v1_1ls = wasm_i8x16_sub(v1_1l, s8b);
+
+        const v128_t v0_1hs = wasm_i8x16_sub(v0_1h, s8b);
+        const v128_t v1_1hs = wasm_i8x16_sub(v1_1h, s8b);
+
+        // dot product into int16x8_t
+        const v128_t pl0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0ls), wasm_i16x8_extend_low_i8x16(v1_0ls));
+        const v128_t pl0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0ls), wasm_i16x8_extend_high_i8x16(v1_0ls));
+
+        const v128_t ph0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0hs), wasm_i16x8_extend_low_i8x16(v1_0hs));
+        const v128_t ph0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0hs), wasm_i16x8_extend_high_i8x16(v1_0hs));
+
+        const v128_t pl1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1ls), wasm_i16x8_extend_low_i8x16(v1_1ls));
+        const v128_t pl1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1ls), wasm_i16x8_extend_high_i8x16(v1_1ls));
+
+        const v128_t ph1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1hs), wasm_i16x8_extend_low_i8x16(v1_1hs));
+        const v128_t ph1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1hs), wasm_i16x8_extend_high_i8x16(v1_1hs));
+
+        const v128_t pl_0 = wasm_i16x8_add(pl0l, pl0h);
+        const v128_t ph_0 = wasm_i16x8_add(ph0l, ph0h);
+
+        const v128_t pl_1 = wasm_i16x8_add(pl1l, pl1h);
+        const v128_t ph_1 = wasm_i16x8_add(ph1l, ph1h);
+
+        const v128_t p_0 = wasm_i16x8_add(pl_0, ph_0);
+        const v128_t p_1 = wasm_i16x8_add(pl_1, ph_1);
+
+        sum0 += d0_0*d1_0*(
+                wasm_i16x8_extract_lane(p_0, 0) + wasm_i16x8_extract_lane(p_0, 1) +
+                wasm_i16x8_extract_lane(p_0, 2) + wasm_i16x8_extract_lane(p_0, 3) +
+                wasm_i16x8_extract_lane(p_0, 4) + wasm_i16x8_extract_lane(p_0, 5) +
+                wasm_i16x8_extract_lane(p_0, 6) + wasm_i16x8_extract_lane(p_0, 7));
+        sum1 += d0_1*d1_1*(
+                wasm_i16x8_extract_lane(p_1, 0) + wasm_i16x8_extract_lane(p_1, 1) +
+                wasm_i16x8_extract_lane(p_1, 2) + wasm_i16x8_extract_lane(p_1, 3) +
+                wasm_i16x8_extract_lane(p_1, 4) + wasm_i16x8_extract_lane(p_1, 5) +
+                wasm_i16x8_extract_lane(p_1, 6) + wasm_i16x8_extract_lane(p_1, 7));
+    }
+
+    sumf = sum0 + sum1;
+#else
+    // scalar
+    for (int i = 0; i < nb; i++) {
+        const float d0 = *(const float *) (pd0 + i*bs);
+        const float d1 = *(const float *) (pd1 + i*bs);
+
+        const uint8_t * restrict p0 = pb0 + i*bs;
+        const uint8_t * restrict p1 = pb1 + i*bs;
+
+        for (int j = 0; j < QK/2; j++) {
+            const uint8_t v0 = p0[j];
+            const uint8_t v1 = p1[j];
+
+            const float f0 = d0*((int8_t) (v0 & 0xf) - 8);
+            const float f1 = d0*((int8_t) (v0 >> 4)  - 8);
+
+            const float f2 = d1*((int8_t) (v1 & 0xf) - 8);
+            const float f3 = d1*((int8_t) (v1 >> 4)  - 8);
+
+            sumf += f0*f2 + f1*f3;
+        }
+    }
+#endif
+
+    *s = sumf;
+}
+
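+// dot product of two q4_1 rows; expanding (d0*qx + m0) * (d1*qy + m1) over one block gives
+//
+//   d0*d1*sum(qx*qy) + d0*m1*sum(qx) + m0*d1*sum(qy) + QK*m0*m1
+//
+// which is why the AVX2 path accumulates the m0*m1 contribution separately and multiplies it by QK at the end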
+inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
+    const int nb = n / QK;
+
+    const size_t bs = 2*sizeof(float) + QK/2;
+
+    const uint8_t * restrict pd0 = ((const uint8_t *)x + 0*bs);
+    const uint8_t * restrict pd1 = ((const uint8_t *)y + 0*bs);
+
+    const uint8_t * restrict pm0 = ((const uint8_t *)x + 0*bs + sizeof(float));
+    const uint8_t * restrict pm1 = ((const uint8_t *)y + 0*bs + sizeof(float));
+
+    const uint8_t * restrict pb0 = ((const uint8_t *)x + 0*bs + 2*sizeof(float));
+    const uint8_t * restrict pb1 = ((const uint8_t *)y + 0*bs + 2*sizeof(float));
+
+    float sumf = 0.0;
+
+#if defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+    // Accumulator for constant offsets
+    float acc_offset = 0.0f;
+
+    // Main loop
+    for (int i = 0; i < nb; ++i) {
+        const float * m0 = (const float *) (pm0 + i*bs);
+        const float * m1 = (const float *) (pm1 + i*bs);
+
+        const float * d0 = (const float *) (pd0 + i*bs);
+        const float * d1 = (const float *) (pd1 + i*bs);
+
+        const uint8_t * restrict p0 = pb0 + i*bs;
+        const uint8_t * restrict p1 = pb1 + i*bs;
+
+        const __m256 d0v = _mm256_broadcast_ss( d0 );
+        const __m256 d1v = _mm256_broadcast_ss( d1 );
+        const __m256 m0v = _mm256_broadcast_ss( m0 );
+        const __m256 m1v = _mm256_broadcast_ss( m1 );
+
+
+        // Compute combined scale for the block
+        const __m256 scale_01 = _mm256_mul_ps( d0v, d1v );
+
+        // Compute cross scales for the block
+        const __m256 scale_0 = _mm256_mul_ps( d0v, m1v );
+        const __m256 scale_1 = _mm256_mul_ps( m0v, d1v );
+        const __m256 cross_scales = _mm256_blend_ps( scale_0, scale_1, 0b10101010 );
+
+        // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
+        __m256i bx = bytesFromNibbles( p0 );
+        __m256i by = bytesFromNibbles( p1 );
+
+        // Now we have a vector with bytes in [ 0 .. 15 ] interval.
+
+        // Sign-extend first 16 signed bytes into int16_t
+        __m256i x16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( bx ) );
+        __m256i y16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
+        // Compute products of int16_t integers, add pairwise
+        __m256i i32 = _mm256_madd_epi16( x16, y16 );
+
+        // Sign-extend last 16 signed bytes into int16_t vectors
+        __m256i x16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( bx, 1 ) );
+        __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
+        // Accumulate products of int16_t integers
+        i32 = _mm256_add_epi32( i32, _mm256_madd_epi16( x16_h, y16_h ) );
+
+        // compute sums of unsigned bytes in bx, by in blocks of 8.
+        // This results in a layout like X100 0000 X200 0000 X300 0000 X400 0000,
+        // which we then interleave as X100 Y100 X200 Y200 X300 Y300 X400 Y400.
+        // so if we then cast to 8 singles, we get 8 floats like [ x0_7, y0_7, x8_15, y8_15, x16_23, y16_23, x24_31, y24_31 ]
+        __m256i xsumi = _mm256_sad_epu8( bx, _mm256_setzero_si256() );
+        __m256i ysumi = _mm256_sad_epu8( by, _mm256_setzero_si256() );
+        __m256i sumsi = _mm256_or_si256( xsumi, _mm256_slli_si256( ysumi, 4 ) );
+        __m256  sums  = _mm256_cvtepi32_ps( sumsi );
+
+        // Convert int32_t to float
+        __m256 p = _mm256_cvtepi32_ps( i32 );
+        // Apply the scale, and accumulate
+        // acc += d0*d1*x*y + d0*m1*x + d1*m0*y
+        acc = _mm256_fmadd_ps( scale_01, p, acc );
+        acc = _mm256_fmadd_ps( cross_scales, sums, acc );
+        // acc_offset += m0*m1 (for each entry in the block)
+        acc_offset += (*m0)*(*m1);
+    }
+
+    // Return horizontal sum of the acc vector
+    __m128 res = _mm256_extractf128_ps( acc, 1 );
+    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
+    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
+    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
+
+    sumf = _mm_cvtss_f32( res ) + acc_offset * QK;
+#else
+    // scalar
+    for (int i = 0; i < nb; i++) {
+        const float m0 = *(const float *) (pm0 + i*bs);
+        const float m1 = *(const float *) (pm1 + i*bs);
+
+        const float d0 = *(const float *) (pd0 + i*bs);
+        const float d1 = *(const float *) (pd1 + i*bs);
+
+        const uint8_t * restrict p0 = pb0 + i*bs;
+        const uint8_t * restrict p1 = pb1 + i*bs;
+
+        for (int j = 0; j < QK/2; j++) {
+            const uint8_t v0 = p0[j];
+            const uint8_t v1 = p1[j];
+
+            const float f0 = d0*(v0 & 0xf) + m0;
+            const float f1 = d0*(v0 >> 4)  + m0;
+
+            const float f2 = d1*(v1 & 0xf) + m1;
+            const float f3 = d1*(v1 >> 4)  + m1;
+
+            sumf += f0*f2 + f1*f3;
+        }
+    }
+#endif
+
+    *s = sumf;
+}
+
+// compute GGML_VEC_DOT_UNROLL dot products at once
+// xs - x row stride in bytes
+inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
+    ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
+
+    ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL];
+
+    for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
+        x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
+    }
+
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F16_STEP - 1));
+
+    GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
+
+    GGML_F16_VEC ax[GGML_F16_ARR];
+    GGML_F16_VEC ay[GGML_F16_ARR];
+
+    for (int i = 0; i < np; i += GGML_F16_STEP) {
+        for (int j = 0; j < GGML_F16_ARR; j++) {
+            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+
+            for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
+                ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j);
+
+                sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]);
+            }
+        }
+    }
+
+    // reduce sum0..sum3 to sum0
+    for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
+        GGML_F16_VEC_REDUCE(sumf[k], sum[k]);
+    }
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
+            sumf[j] += GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]);
+        }
+    }
+#else
+    for (int i = 0; i < n; ++i) {
+        for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
+            sumf[j] += GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]);
+        }
+    }
+#endif
+
+    for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
+        s[i] = sumf[i];
+    }
+}
+
+inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) {
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F32_STEP - 1));
+
+    GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
+
+    GGML_F32_VEC ax[GGML_F32_ARR];
+    GGML_F32_VEC ay[GGML_F32_ARR];
+
+    for (int i = 0; i < np; i += GGML_F32_STEP) {
+        for (int j = 0; j < GGML_F32_ARR; j++) {
+            ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
+            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+            ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx);
+
+            GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
+        }
+    }
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        y[i] += x[i]*v;
+    }
+#else
+    // scalar
+    for (int i = 0; i < n; ++i) {
+        y[i] += x[i]*v;
+    }
+#endif
+}
+
+//inline static void ggml_vec_scale_f32(const int n, float * y, const float   v) { for (int i = 0; i < n; ++i) y[i] *= v;          }
+inline static void ggml_vec_scale_f32(const int n, float * y, const float   v) {
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F32_STEP - 1));
+
+    GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
+
+    GGML_F32_VEC ay[GGML_F32_ARR];
+
+    for (int i = 0; i < np; i += GGML_F32_STEP) {
+        for (int j = 0; j < GGML_F32_ARR; j++) {
+            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+            ay[j] = GGML_F32_VEC_MUL(ay[j], vx);
 
             GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
         }
@@ -1111,6 +2134,35 @@ inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
 }
 #endif
 
+// Sigmoid Linear Unit (SiLU) function
+inline static float ggml_silu_f32(float x) {
+    return x/(1.0 + exp(-x));
+}
+
+inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    const uint16_t * i16 = (const uint16_t *) x;
+    for (int i = 0; i < n; ++i) {
+        y[i] = table_silu_f16[i16[i]];
+    }
+}
+
+#ifdef GGML_SILU_FP16
+inline static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
+    uint16_t t;
+    for (int i = 0; i < n; ++i) {
+        ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
+        memcpy(&t, &fp16, sizeof(uint16_t));
+        y[i] = GGML_FP16_TO_FP32(table_silu_f16[t]);
+    }
+}
+#else
+inline static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = ggml_silu_f32(x[i]);
+    }
+}
+#endif
+
 inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
 #ifndef GGML_USE_ACCELERATE
     ggml_float sum = 0.0;
@@ -1165,7 +2217,21 @@ inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x
 // data types
 //
 
+static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
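+    // entries follow the order of enum ggml_type: Q4_0, Q4_1, I8, I16, I32, F16, F32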
+    QK,
+    QK,
+    1,
+    1,
+    1,
+    1,
+    1,
+};
+
+static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 7");
+
 static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
+    sizeof(float  )   + QK/2,
+    sizeof(float  )*2 + QK/2,
     sizeof(int8_t ),
     sizeof(int16_t),
     sizeof(int32_t),
@@ -1173,6 +2239,9 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
     sizeof(float  ),
 };
 
+// don't forget to update the array above when adding new types
+static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 7");
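+// How to read the two tables above for the quantized types (e.g. with QK == 32):
+//   Q4_0: one block of QK values = 1 float scale  + QK/2 nibble bytes = 4 + 16 = 20 bytes
+//   Q4_1: one block of QK values = 2 float params + QK/2 nibble bytes = 8 + 16 = 24 bytes
+// GGML_BLCK_SIZE is elements per block and GGML_TYPE_SIZE is bytes per block, which
+// is why quantized sizes below are divided by the block size (see ggml_nbytes).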
+
 static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "NONE",
 
@@ -1192,7 +2261,9 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "STEP",
     "RELU",
     "GELU",
+    "SILU",
     "NORM",
+    "RMS_NORM",
 
     "MUL_MAT",
 
@@ -1213,6 +2284,8 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "FLASH_FF",
 };
 
+static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
+
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
 
@@ -1232,7 +2305,9 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "step(x)",
     "relu(x)",
     "gelu(x)",
+    "silu(x)",
     "norm(x)",
+    "rms_norm(x)",
 
     "X*Y",
 
@@ -1253,6 +2328,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "flash_ff(x)",
 };
 
+static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
+
 //
 // ggml object
 //
@@ -1279,6 +2356,7 @@ struct ggml_context {
     size_t mem_size;
     void * mem_buffer;
     bool   mem_buffer_owned;
+    bool   mem_buffer_mlocked;
 
     int n_objects;
 
@@ -1380,13 +2458,21 @@ int ggml_nrows(const struct ggml_tensor * tensor) {
 size_t ggml_nbytes(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
-    return ggml_nelements(tensor)*GGML_TYPE_SIZE[tensor->type];
+    return (ggml_nelements(tensor)*GGML_TYPE_SIZE[tensor->type])/GGML_BLCK_SIZE[tensor->type];
+}
+
+int ggml_blck_size(enum ggml_type type) {
+    return GGML_BLCK_SIZE[type];
 }
 
 size_t ggml_type_size(enum ggml_type type) {
     return GGML_TYPE_SIZE[type];
 }
 
+float ggml_type_sizef(enum ggml_type type) {
+    return ((float)(GGML_TYPE_SIZE[type]))/GGML_BLCK_SIZE[type];
+}
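+// Illustrative use of ggml_type_sizef (effective bytes per element), assuming QK == 32:
+//   ggml_type_sizef(GGML_TYPE_F32)  == 4.0f
+//   ggml_type_sizef(GGML_TYPE_Q4_0) == 20.0f/32 == 0.625f
+// e.g. a 4096-element Q4_0 row takes 4096*0.625 == 2560 bytes instead of 16384.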
+
 size_t ggml_element_size(const struct ggml_tensor * tensor) {
     return GGML_TYPE_SIZE[tensor->type];
 }
@@ -1413,9 +2499,13 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
     return
-        (t0->ne[0]  == t1->ne[0])  &&
-        (t0->ne[2]  == t1->ne[2])  &&
-        (t0->ne[3]  == t1->ne[3]);
+        (t0->ne[0] == t1->ne[0])  &&
+        (t0->ne[2] == t1->ne[2])  &&
+        (t0->ne[3] == t1->ne[3]);
+}
+
+static inline bool ggml_is_transposed(const struct ggml_tensor * tensor) {
+    return tensor->nb[0] > tensor->nb[1];
 }
 
 static inline bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
@@ -1423,7 +2513,7 @@ static inline bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
 
     return
         tensor->nb[0] == GGML_TYPE_SIZE[tensor->type] &&
-        tensor->nb[1] == tensor->nb[0]*tensor->ne[0] &&
+        tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/GGML_BLCK_SIZE[tensor->type] &&
         tensor->nb[2] == tensor->nb[1]*tensor->ne[1] &&
         tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
 }
@@ -1474,7 +2564,7 @@ static inline int ggml_up(int n, int m) {
 
 // assert that pointer is aligned to GGML_MEM_ALIGN
 #define ggml_assert_aligned(ptr) \
-    assert(((uintptr_t) (ptr))%GGML_MEM_ALIGN == 0)
+    GGML_ASSERT(((uintptr_t) (ptr))%GGML_MEM_ALIGN == 0)
 
 ////////////////////////////////////////////////////////////////////////////////
 
@@ -1485,7 +2575,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
     static bool is_first_call = true;
 
     if (is_first_call) {
-        // initialize GELU, EXP and F32 tables
+        // initialize GELU, SILU, EXP and F32 tables
         {
             const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
 
@@ -1495,12 +2585,13 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
                 memcpy(&ii, &ui, sizeof(ii));
                 const float f = table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii);
                 table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
+                table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f));
                 table_exp_f16[i]  = GGML_FP32_TO_FP16(exp(f));
             }
 
             const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
 
-            GGML_PRINT_DEBUG("%s: GELU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
+            GGML_PRINT_DEBUG("%s: GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
         }
 
         // initialize g_state
@@ -1545,16 +2636,19 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
     }
 
     *ctx = (struct ggml_context) {
-        /*.mem_size         =*/ params.mem_size,
-        /*.mem_buffer       =*/ params.mem_buffer ? params.mem_buffer : malloc(params.mem_size),
-        /*.mem_buffer_owned =*/ params.mem_buffer ? false : true,
-        /*.n_objects        =*/ 0,
-        /*.objects_begin    =*/ NULL,
-        /*.objects_end      =*/ NULL,
-        /*.scratch          =*/ { 0, 0, NULL, },
-        /*.scratch_save     =*/ { 0, 0, NULL, },
+        /*.mem_size           =*/ params.mem_size,
+        /*.mem_buffer         =*/ params.mem_buffer ? params.mem_buffer : malloc(params.mem_size),
+        /*.mem_buffer_owned   =*/ params.mem_buffer ? false : true,
+        /*.mem_buffer_mlocked =*/ false,
+        /*.n_objects          =*/ 0,
+        /*.objects_begin      =*/ NULL,
+        /*.objects_end        =*/ NULL,
+        /*.scratch            =*/ { 0, 0, NULL, },
+        /*.scratch_save       =*/ { 0, 0, NULL, },
     };
 
+    GGML_ASSERT(ctx->mem_buffer != NULL); // check for allocation failure
+
     ggml_assert_aligned(ctx->mem_buffer);
 
     GGML_PRINT_DEBUG("%s: context initialized\n", __func__);
@@ -1577,6 +2671,14 @@ void ggml_free(struct ggml_context * ctx) {
             GGML_PRINT_DEBUG("%s: context %d with %d objects has been freed. memory used = %zu\n",
                     __func__, i, ctx->n_objects, ctx->objects_end->offs + ctx->objects_end->size);
 
+#if GGML_MLOCK_SUPPORT
+            if (ctx->mem_buffer_mlocked) {
+                if (munlock(ctx->mem_buffer, ctx->mem_size)) {
+                    fprintf(stderr, "%s: failed to munlock buffer: %s\n", __func__, strerror(errno));
+                }
+            }
+#endif
+
             if (ctx->mem_buffer_owned) {
                 free(ctx->mem_buffer);
             }
@@ -1605,6 +2707,37 @@ size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch)
     return result;
 }
 
+bool ggml_mlock_supported(void) {
+    return GGML_MLOCK_SUPPORT;
+}
+
+#if GGML_MLOCK_SUPPORT
+#ifdef __APPLE__
+    #define MLOCK_SUGGESTION "Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or\n" \
+                             "decreasing 'vm.global_no_user_wire_amount'.  Also try increasing RLIMIT_MLOCK (ulimit -l)."
+#else
+    #define MLOCK_SUGGESTION "Try increasing RLIMIT_MLOCK (ulimit -l)."
+#endif
+bool ggml_mlock(struct ggml_context * ctx, char ** err_p) {
+    if (ctx->mem_buffer_mlocked) {
+        return true;
+    }
+    if (mlock(ctx->mem_buffer, ctx->mem_size)) {
+        int ret = asprintf(err_p, "failed to mlock %zu-byte buffer: %s\n" MLOCK_SUGGESTION,
+                           ctx->mem_size, strerror(errno));
+        GGML_ASSERT(ret >= 0);
+        return false;
+    }
+    ctx->mem_buffer_mlocked = true;
+    return true;
+}
+#else // GGML_MLOCK_SUPPORT
+bool ggml_mlock(struct ggml_context * ctx, char ** err_p) {
+    *err_p = strdup("can't mlock because it's not supported on this system");
+    return false;
+}
+#endif // GGML_MLOCK_SUPPORT
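+// Minimal caller-side sketch for ggml_mlock (illustrative only; the error string
+// is heap-allocated via asprintf/strdup above, so the caller should free it):
+//
+//     char * err = NULL;
+//     if (ggml_mlock_supported() && !ggml_mlock(ctx, &err)) {
+//         fprintf(stderr, "warning: %s\n", err);
+//         free(err);
+//     }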
+
 ////////////////////////////////////////////////////////////////////////////////
 
 struct ggml_tensor * ggml_new_tensor_impl(
@@ -1623,8 +2756,8 @@ struct ggml_tensor * ggml_new_tensor_impl(
     size_t size_needed = 0;
 
     if (data == NULL) {
-        size_needed += GGML_TYPE_SIZE[type];
-        for (int i = 0; i < n_dims; i++) {
+        size_needed += GGML_TYPE_SIZE[type]*(ne[0]/GGML_BLCK_SIZE[type]);
+        for (int i = 1; i < n_dims; i++) {
             size_needed *= ne[i];
         }
         // align to GGML_MEM_ALIGN
@@ -1717,7 +2850,8 @@ struct ggml_tensor * ggml_new_tensor_impl(
     }
 
     result->nb[0] = GGML_TYPE_SIZE[type];
-    for (int i = 1; i < GGML_MAX_DIMS; i++) {
+    result->nb[1] = result->nb[0]*(result->ne[0]/GGML_BLCK_SIZE[type]);
+    for (int i = 2; i < GGML_MAX_DIMS; i++) {
         result->nb[i] = result->nb[i - 1]*result->ne[i - 1];
     }
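+    // Illustrative stride layout for a quantized tensor (assuming QK == 32):
+    // a 2-D Q4_0 tensor with ne = {4096, 32} gets
+    //     nb[0] = 20                     (bytes per block)
+    //     nb[1] = 20*(4096/32) == 2560   (bytes per row)
+    //     nb[2] = nb[3] = 2560*32        (whole matrix)
+    // which is why nb[1] is no longer simply nb[0]*ne[0].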
 
@@ -1814,6 +2948,14 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
     char * const data = tensor->data;
 
     switch (tensor->type) {
+        case GGML_TYPE_Q4_0:
+            {
+                GGML_ASSERT(false);
+            } break;
+        case GGML_TYPE_Q4_1:
+            {
+                GGML_ASSERT(false);
+            } break;
         case GGML_TYPE_I8:
             {
                 assert(tensor->nb[0] == sizeof(int8_t));
@@ -1851,7 +2993,7 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
             } break;
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 
@@ -1866,6 +3008,14 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
     char * const data = tensor->data;
 
     switch (tensor->type) {
+        case GGML_TYPE_Q4_0:
+            {
+                GGML_ASSERT(false);
+            } break;
+        case GGML_TYPE_Q4_1:
+            {
+                GGML_ASSERT(false);
+            } break;
         case GGML_TYPE_I8:
             {
                 assert(tensor->nb[0] == sizeof(int8_t));
@@ -1903,7 +3053,7 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
             } break;
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 
@@ -1912,6 +3062,14 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
 
 int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
     switch (tensor->type) {
+        case GGML_TYPE_Q4_0:
+            {
+                GGML_ASSERT(false);
+            } break;
+        case GGML_TYPE_Q4_1:
+            {
+                GGML_ASSERT(false);
+            } break;
         case GGML_TYPE_I8:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -1948,6 +3106,14 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
 
 void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
     switch (tensor->type) {
+        case GGML_TYPE_Q4_0:
+            {
+                GGML_ASSERT(false);
+            } break;
+        case GGML_TYPE_Q4_1:
+            {
+                GGML_ASSERT(false);
+            } break;
         case GGML_TYPE_I8:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -1982,6 +3148,14 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
 
 float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
     switch (tensor->type) {
+        case GGML_TYPE_Q4_0:
+            {
+                GGML_ASSERT(false);
+            } break;
+        case GGML_TYPE_Q4_1:
+            {
+                GGML_ASSERT(false);
+            } break;
         case GGML_TYPE_I8:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -2018,9 +3192,17 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
 
 void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
     switch (tensor->type) {
-        case GGML_TYPE_I8:
+        case GGML_TYPE_Q4_0:
             {
-                GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
+                GGML_ASSERT(false);
+            } break;
+        case GGML_TYPE_Q4_1:
+            {
+                GGML_ASSERT(false);
+            } break;
+        case GGML_TYPE_I8:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
                 ((int8_t *)(tensor->data))[i] = value;
             } break;
         case GGML_TYPE_I16:
@@ -2108,7 +3290,7 @@ struct ggml_tensor * ggml_add_impl(
         struct ggml_tensor * a,
         struct ggml_tensor * b,
         bool inplace) {
-    assert(ggml_are_same_shape(a, b));
+    GGML_ASSERT(ggml_are_same_shape(a, b));
 
     bool is_node = false;
 
@@ -2147,7 +3329,7 @@ struct ggml_tensor * ggml_sub_impl(
         struct ggml_tensor * a,
         struct ggml_tensor * b,
         bool inplace) {
-    assert(ggml_are_same_shape(a, b));
+    GGML_ASSERT(ggml_are_same_shape(a, b));
 
     bool is_node = false;
 
@@ -2186,7 +3368,7 @@ struct ggml_tensor * ggml_mul_impl(
         struct ggml_tensor * a,
         struct ggml_tensor * b,
         bool inplace) {
-    assert(ggml_are_same_shape(a, b));
+    GGML_ASSERT(ggml_are_same_shape(a, b));
 
     bool is_node = false;
 
@@ -2195,7 +3377,7 @@ struct ggml_tensor * ggml_mul_impl(
     }
 
     if (inplace) {
-        assert(is_node == false);
+        GGML_ASSERT(is_node == false);
     }
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
@@ -2229,7 +3411,7 @@ struct ggml_tensor * ggml_div_impl(
         struct ggml_tensor * a,
         struct ggml_tensor * b,
         bool inplace) {
-    assert(ggml_are_same_shape(a, b));
+    GGML_ASSERT(ggml_are_same_shape(a, b));
 
     bool is_node = false;
 
@@ -2238,7 +3420,7 @@ struct ggml_tensor * ggml_div_impl(
     }
 
     if (inplace) {
-        assert(is_node == false);
+        GGML_ASSERT(is_node == false);
     }
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
@@ -2362,7 +3544,7 @@ struct ggml_tensor * ggml_mean(
     bool is_node = false;
 
     if (a->grad) {
-        assert(false); // TODO: implement
+        GGML_ASSERT(false); // TODO: implement
         is_node = true;
     }
 
@@ -2383,7 +3565,7 @@ struct ggml_tensor * ggml_repeat(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
         struct ggml_tensor * b) {
-    assert(ggml_can_repeat(a, b));
+    GGML_ASSERT(ggml_can_repeat(a, b));
 
     bool is_node = false;
 
@@ -2610,6 +3792,40 @@ struct ggml_tensor * ggml_gelu_inplace(
     return ggml_gelu_impl(ctx, a, true);
 }
 
+// ggml_silu
+
+struct ggml_tensor * ggml_silu_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_SILU;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_silu(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_silu_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_silu_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_silu_impl(ctx, a, true);
+}
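+// Minimal usage sketch for the new SiLU op (illustrative; tensor shapes are made up):
+//
+//     struct ggml_tensor * x   = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 64);
+//     struct ggml_tensor * cur = ggml_silu(ctx, x); // y = x * sigmoid(x)
+//     // ... build the rest of the graph and evaluate it with ggml_graph_compute()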
+
 // ggml_norm
 
 struct ggml_tensor * ggml_norm_impl(
@@ -2619,7 +3835,7 @@ struct ggml_tensor * ggml_norm_impl(
     bool is_node = false;
 
     if (!inplace && (a->grad)) {
-        assert(false); // TODO: implement backward
+        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -2645,13 +3861,47 @@ struct ggml_tensor * ggml_norm_inplace(
     return ggml_norm_impl(ctx, a, true);
 }
 
+struct ggml_tensor * ggml_rms_norm_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_RMS_NORM;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL; // TODO: maybe store epsilon here?
+
+    return result;
+}
+
+struct ggml_tensor * ggml_rms_norm(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_rms_norm_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_rms_norm_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_rms_norm_impl(ctx, a, true);
+}
+
 // ggml_mul_mat
 
 struct ggml_tensor * ggml_mul_mat(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b) {
-    assert(ggml_can_mul_mat(a, b));
+    GGML_ASSERT(ggml_can_mul_mat(a, b));
+    GGML_ASSERT(!ggml_is_transposed(a));
 
     bool is_node = false;
 
@@ -2677,13 +3927,13 @@ struct ggml_tensor * ggml_scale_impl(
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
         bool inplace) {
-    assert(ggml_is_scalar(b));
-    assert(ggml_is_padded_1d(a));
+    GGML_ASSERT(ggml_is_scalar(b));
+    GGML_ASSERT(ggml_is_padded_1d(a));
 
     bool is_node = false;
 
     if (!inplace && (a->grad || b->grad)) {
-        assert(false); // TODO: implement backward
+        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -2720,12 +3970,12 @@ struct ggml_tensor * ggml_cpy_impl(
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
         bool inplace) {
-    assert(ggml_nelements(a) == ggml_nelements(b));
+    GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));
 
     bool is_node = false;
 
     if (!inplace && (a->grad || b->grad)) {
-        assert(false); // TODO: implement backward
+        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -2760,14 +4010,14 @@ struct ggml_tensor * ggml_reshape(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
         struct ggml_tensor * b) {
-    assert(ggml_is_contiguous(a));
-    assert(ggml_is_contiguous(b));
-    assert(ggml_nelements(a) == ggml_nelements(b));
+    GGML_ASSERT(ggml_is_contiguous(a));
+    GGML_ASSERT(ggml_is_contiguous(b));
+    GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));
 
     bool is_node = false;
 
     if (a->grad || b->grad) {
-        assert(false); // TODO: implement backward
+        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -2786,13 +4036,13 @@ struct ggml_tensor * ggml_reshape_2d(
         struct ggml_tensor  * a,
         int                   ne0,
         int                   ne1) {
-    assert(ggml_is_contiguous(a));
-    assert(ggml_nelements(a) == ne0*ne1);
+    GGML_ASSERT(ggml_is_contiguous(a));
+    GGML_ASSERT(ggml_nelements(a) == ne0*ne1);
 
     bool is_node = false;
 
     if (a->grad) {
-        assert(false); // TODO: implement backward
+        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -2813,13 +4063,13 @@ struct ggml_tensor * ggml_reshape_3d(
         int                   ne0,
         int                   ne1,
         int                   ne2) {
-    assert(ggml_is_contiguous(a));
-    assert(ggml_nelements(a) == ne0*ne1*ne2);
+    GGML_ASSERT(ggml_is_contiguous(a));
+    GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2);
 
     bool is_node = false;
 
     if (a->grad) {
-        assert(false); // TODO: implement backward
+        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -2842,7 +4092,7 @@ struct ggml_tensor * ggml_view_1d(
         int                   ne0,
         size_t                offset) {
     if (a->grad) {
-        assert(false); // gradient propagation is not supported
+        GGML_ASSERT(false); // gradient propagation is not supported
     }
 
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, &ne0, (char *) a->data + offset);
@@ -2865,7 +4115,7 @@ struct ggml_tensor * ggml_view_2d(
         size_t                nb1,
         size_t                offset) {
     if (a->grad) {
-        assert(false); // gradient propagation is not supported
+        GGML_ASSERT(false); // gradient propagation is not supported
     }
 
     const int ne[GGML_MAX_DIMS] = { ne0, ne1, 1, 1 };
@@ -2893,22 +4143,22 @@ struct ggml_tensor * ggml_permute(
         int                   axis1,
         int                   axis2,
         int                   axis3) {
-    assert(axis0 >= 0 && axis0 < GGML_MAX_DIMS);
-    assert(axis1 >= 0 && axis1 < GGML_MAX_DIMS);
-    assert(axis2 >= 0 && axis2 < GGML_MAX_DIMS);
-    assert(axis3 >= 0 && axis3 < GGML_MAX_DIMS);
-
-    assert(axis0 != axis1);
-    assert(axis0 != axis2);
-    assert(axis0 != axis3);
-    assert(axis1 != axis2);
-    assert(axis1 != axis3);
-    assert(axis2 != axis3);
+    GGML_ASSERT(axis0 >= 0 && axis0 < GGML_MAX_DIMS);
+    GGML_ASSERT(axis1 >= 0 && axis1 < GGML_MAX_DIMS);
+    GGML_ASSERT(axis2 >= 0 && axis2 < GGML_MAX_DIMS);
+    GGML_ASSERT(axis3 >= 0 && axis3 < GGML_MAX_DIMS);
+
+    GGML_ASSERT(axis0 != axis1);
+    GGML_ASSERT(axis0 != axis2);
+    GGML_ASSERT(axis0 != axis3);
+    GGML_ASSERT(axis1 != axis2);
+    GGML_ASSERT(axis1 != axis3);
+    GGML_ASSERT(axis2 != axis3);
 
     bool is_node = false;
 
     if (a->grad) {
-        assert(false); // TODO: implement backward
+        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -2953,7 +4203,7 @@ struct ggml_tensor * ggml_transpose(
     bool is_node = false;
 
     if (a->grad) {
-        assert(false); // TODO: implement backward
+        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -2979,12 +4229,12 @@ struct ggml_tensor * ggml_get_rows(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b) {
-    assert(ggml_is_matrix(a) && ggml_is_vector(b) && b->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_is_matrix(a) && ggml_is_vector(b) && b->type == GGML_TYPE_I32);
 
     bool is_node = false;
 
     if (a->grad || b->grad) {
-        assert(false); // TODO: implement backward
+        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -3009,7 +4259,7 @@ struct ggml_tensor * ggml_diag_mask_inf(
     bool is_node = false;
 
     if (a->grad) {
-        assert(false); // TODO: implement backward
+        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -3034,7 +4284,7 @@ struct ggml_tensor * ggml_soft_max(
     bool is_node = false;
 
     if (a->grad) {
-        assert(false); // TODO: implement backward
+        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -3058,11 +4308,11 @@ struct ggml_tensor * ggml_rope(
         int                   n_past,
         int                   n_dims,
         int                   mode) {
-    assert(n_past >= 0);
+    GGML_ASSERT(n_past >= 0);
     bool is_node = false;
 
     if (a->grad) {
-        assert(false); // TODO: implement backward
+        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -3089,13 +4339,13 @@ struct ggml_tensor * ggml_conv_1d_1s(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b) {
-    assert(ggml_is_matrix(b));
-    assert(a->ne[1] == b->ne[1]);
-    assert(a->ne[3] == 1);
+    GGML_ASSERT(ggml_is_matrix(b));
+    GGML_ASSERT(a->ne[1] == b->ne[1]);
+    GGML_ASSERT(a->ne[3] == 1);
     bool is_node = false;
 
     if (a->grad || b->grad) {
-        assert(false); // TODO: implement backward
+        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -3116,13 +4366,13 @@ struct ggml_tensor * ggml_conv_1d_2s(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b) {
-    assert(ggml_is_matrix(b));
-    assert(a->ne[1] == b->ne[1]);
-    assert(a->ne[3] == 1);
+    GGML_ASSERT(ggml_is_matrix(b));
+    GGML_ASSERT(a->ne[1] == b->ne[1]);
+    GGML_ASSERT(a->ne[3] == 1);
     bool is_node = false;
 
     if (a->grad || b->grad) {
-        assert(false); // TODO: implement backward
+        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -3145,7 +4395,7 @@ struct ggml_tensor * ggml_flash_attn(
         struct ggml_tensor  * k,
         struct ggml_tensor  * v,
         bool                  masked) {
-    assert(ggml_can_mul_mat(k, q));
+    GGML_ASSERT(ggml_can_mul_mat(k, q));
     // TODO: check if vT can be multiplied by (k*qT)
 
     bool is_node = false;
@@ -3177,7 +4427,7 @@ struct ggml_tensor * ggml_flash_ff(
         struct ggml_tensor  * b1,
         struct ggml_tensor  * c0,
         struct ggml_tensor  * c1) {
-    assert(ggml_can_mul_mat(b0, a));
+    GGML_ASSERT(ggml_can_mul_mat(b0, a));
     // TODO: more checks
 
     bool is_node = false;
@@ -3208,7 +4458,7 @@ void ggml_set_param(
         struct ggml_tensor * tensor) {
     tensor->is_param = true;
 
-    assert(tensor->grad == NULL);
+    GGML_ASSERT(tensor->grad == NULL);
     tensor->grad = ggml_dup_tensor(ctx, tensor);
 }
 
@@ -3218,9 +4468,9 @@ static void ggml_compute_forward_dup_f16(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
-    assert(params->ith == 0);
-    assert(ggml_is_contiguous(dst));
-    assert(ggml_nelements(dst) == ggml_nelements(src0));
+    GGML_ASSERT(params->ith == 0);
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -3243,7 +4493,7 @@ static void ggml_compute_forward_dup_f16(
 
     if (src0->nb[0] == sizeof(ggml_fp16_t)) {
         if (dst->type == GGML_TYPE_F16) {
-            int id = 0;
+            size_t id = 0;
             const size_t rs = ne00*nb00;
 
             for (int i03 = 0; i03 < ne03; i03++) {
@@ -3259,7 +4509,7 @@ static void ggml_compute_forward_dup_f16(
                 }
             }
         } else if (dst->type == GGML_TYPE_F32) {
-            int id = 0;
+            size_t id = 0;
             float * dst_ptr = (float *) dst->data;
 
             for (int i03 = 0; i03 < ne03; i03++) {
@@ -3281,7 +4531,7 @@ static void ggml_compute_forward_dup_f16(
         //printf("%s: this is not optimal - fix me\n", __func__);
 
         if (dst->type == GGML_TYPE_F32) {
-            int id = 0;
+            size_t id = 0;
             float * dst_ptr = (float *) dst->data;
 
             for (int i03 = 0; i03 < ne03; i03++) {
@@ -3297,7 +4547,7 @@ static void ggml_compute_forward_dup_f16(
                 }
             }
         } else if (dst->type == GGML_TYPE_F16) {
-            int id = 0;
+            size_t id = 0;
             ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
 
             for (int i03 = 0; i03 < ne03; i03++) {
@@ -3347,7 +4597,7 @@ static void ggml_compute_forward_dup_f32(
 
     if (src0->nb[0] == sizeof(float)) {
         if (dst->type == GGML_TYPE_F32) {
-            int id = 0;
+            size_t id = 0;
             const size_t rs = ne00*nb00;
 
             for (int i03 = 0; i03 < ne03; i03++) {
@@ -3363,7 +4613,7 @@ static void ggml_compute_forward_dup_f32(
                 }
             }
         } else if (dst->type == GGML_TYPE_F16) {
-            int id = 0;
+            size_t id = 0;
             ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
 
             for (int i03 = 0; i03 < ne03; i03++) {
@@ -3385,7 +4635,7 @@ static void ggml_compute_forward_dup_f32(
         //printf("%s: this is not optimal - fix me\n", __func__);
 
         if (dst->type == GGML_TYPE_F32) {
-            int id = 0;
+            size_t id = 0;
             float * dst_ptr = (float *) dst->data;
 
             for (int i03 = 0; i03 < ne03; i03++) {
@@ -3401,7 +4651,7 @@ static void ggml_compute_forward_dup_f32(
                 }
             }
         } else if (dst->type == GGML_TYPE_F16) {
-            int id = 0;
+            size_t id = 0;
             ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
 
             for (int i03 = 0; i03 < ne03; i03++) {
@@ -3435,6 +4685,8 @@ static void ggml_compute_forward_dup(
             {
                 ggml_compute_forward_dup_f32(params, src0, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -3510,13 +4762,15 @@ static void ggml_compute_forward_add(
             {
                 ggml_compute_forward_add_f32(params, src0, src1, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
         case GGML_TYPE_F16:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 }
@@ -3560,13 +4814,15 @@ static void ggml_compute_forward_sub(
             {
                 ggml_compute_forward_sub_f32(params, src0, src1, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
         case GGML_TYPE_F16:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 }
@@ -3610,13 +4866,15 @@ static void ggml_compute_forward_mul(
             {
                 ggml_compute_forward_mul_f32(params, src0, src1, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
         case GGML_TYPE_F16:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 }
@@ -3660,13 +4918,15 @@ static void ggml_compute_forward_div(
             {
                 ggml_compute_forward_div_f32(params, src0, src1, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
         case GGML_TYPE_F16:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 }
@@ -3706,13 +4966,15 @@ static void ggml_compute_forward_sqr(
             {
                 ggml_compute_forward_sqr_f32(params, src0, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
         case GGML_TYPE_F16:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 }
@@ -3752,13 +5014,15 @@ static void ggml_compute_forward_sqrt(
             {
                 ggml_compute_forward_sqrt_f32(params, src0, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
         case GGML_TYPE_F16:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 }
@@ -3808,13 +5072,15 @@ static void ggml_compute_forward_sum(
             {
                 ggml_compute_forward_sum_f32(params, src0, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
         case GGML_TYPE_F16:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 }
@@ -3883,13 +5149,15 @@ static void ggml_compute_forward_mean(
             {
                 ggml_compute_forward_mean_f32(params, src0, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
         case GGML_TYPE_F16:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 }
@@ -3945,13 +5213,15 @@ static void ggml_compute_forward_repeat(
             {
                 ggml_compute_forward_repeat_f32(params, src0, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
         case GGML_TYPE_F16:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 }
@@ -3991,13 +5261,15 @@ static void ggml_compute_forward_abs(
             {
                 ggml_compute_forward_abs_f32(params, src0, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
         case GGML_TYPE_F16:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 }
@@ -4037,13 +5309,15 @@ static void ggml_compute_forward_sgn(
             {
                 ggml_compute_forward_sgn_f32(params, src0, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
         case GGML_TYPE_F16:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 }
@@ -4083,13 +5357,15 @@ static void ggml_compute_forward_neg(
             {
                 ggml_compute_forward_neg_f32(params, src0, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
         case GGML_TYPE_F16:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 }
@@ -4129,13 +5405,15 @@ static void ggml_compute_forward_step(
             {
                 ggml_compute_forward_step_f32(params, src0, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
         case GGML_TYPE_F16:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 }
@@ -4175,13 +5453,15 @@ static void ggml_compute_forward_relu(
             {
                 ggml_compute_forward_relu_f32(params, src0, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
         case GGML_TYPE_F16:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 }
@@ -4238,17 +5518,87 @@ static void ggml_compute_forward_gelu(
             {
                 ggml_compute_forward_gelu_f32(params, src0, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+
+    //printf("XXXXXXXX gelu\n");
+}
+
+// ggml_compute_forward_silu
+
+static void ggml_compute_forward_silu_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_vec_silu_f32(nc,
+                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
+                (float *) ((char *) src0->data + i1*(src0->nb[1])));
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
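+// Illustrative work split for the kernel above: with nr == 100 rows and nth == 8
+// threads, dr == (100 + 7)/8 == 13, so thread 0 handles rows 0..12, thread 1
+// rows 13..25, ..., and thread 7 rows 91..99 (MIN clamps the last range to nr).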
+
+static void ggml_compute_forward_silu(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_silu_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
         case GGML_TYPE_F16:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 }
 
+
 // ggml_compute_forward_norm
 
 static void ggml_compute_forward_norm_f32(
@@ -4320,17 +5670,100 @@ static void ggml_compute_forward_norm(
             {
                 ggml_compute_forward_norm_f32(params, src0, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+static void ggml_compute_forward_rms_norm_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
+
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
+
+    const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
+
+    const ggml_float eps = 1e-6f; // TODO: make this a parameter
+
+    // TODO: optimize
+    for (int i03 = 0; i03 < ne03; i03++) {
+        for (int i02 = 0; i02 < ne02; i02++) {
+            for (int i01 = ith; i01 < ne01; i01 += nth) {
+                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+                ggml_float mean = 0.0;
+                for (int i00 = 0; i00 < ne00; i00++) {
+                    mean += x[i00] * x[i00];
+                }
+
+                mean /= ne00;
+
+                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+                memcpy(y, x, ne00 * sizeof(float));
+                // for (int i00 = 0; i00 < ne00; i00++) {
+                //     y[i00] = x[i00];
+                // }
+
+                const float scale = 1.0/sqrt(mean + eps);
+
+                ggml_vec_scale_f32(ne00, y, scale);
+            }
+        }
+    }
+}
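+// Summary of the kernel above (illustrative): it computes RMS normalization
+//     y_i = x_i / sqrt((1/ne00) * sum_j x_j*x_j + eps)
+// i.e. unlike ggml_compute_forward_norm_f32 it does not subtract the row mean,
+// it only rescales each row by the reciprocal of its root mean square.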
+
+static void ggml_compute_forward_rms_norm(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rms_norm_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
         case GGML_TYPE_F16:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 }
 
+
 // ggml_compute_forward_mul_mat
 
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
@@ -4340,7 +5773,8 @@ static bool ggml_compute_forward_mul_mat_use_blas(
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
               struct ggml_tensor * dst) {
-    UNUSED(src0);
+    //const int ne00 = src0->ne[0];
+    //const int ne01 = src0->ne[1];
 
     const int ne10 = src1->ne[0];
 
@@ -4348,10 +5782,10 @@ static bool ggml_compute_forward_mul_mat_use_blas(
     const int ne1 = dst->ne[1];
 
     // TODO: find the optimal values for these
-    if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && (
-             (ne0 >= 32 && ne1  >= 32   && ne10 >= 32)
-            )) {
-        //printf("BLAS: %d %d %d\n", ne0, ne1, ne10);
+    if (ggml_is_contiguous(src0) &&
+        ggml_is_contiguous(src1) && ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32))) {
+
+        /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/
         return true;
     }
 
@@ -4374,16 +5808,16 @@ static void ggml_compute_forward_mul_mat_f32(
 
     const int ne10 = src1->ne[0];
     const int ne11 = src1->ne[1];
-    const int ne12 = src1->ne[2];
-    const int ne13 = src1->ne[3];
+    //const int ne12 = src1->ne[2];
+    //const int ne13 = src1->ne[3];
 
-    const int ne0  = dst->ne[0];
-    const int ne1  = dst->ne[1];
-    const int ne2  = dst->ne[2];
-    const int ne3  = dst->ne[3];
-    const int ne   = ne0*ne1*ne2*ne3;
+    //const int ne0  = dst->ne[0];
+    //const int ne1  = dst->ne[1];
+    //const int ne2  = dst->ne[2];
+    //const int ne3  = dst->ne[3];
+    //const int ne   = ne0*ne1*ne2*ne3;
 
-    const int nb00 = src0->nb[0];
+    //const int nb00 = src0->nb[0];
     const int nb01 = src0->nb[1];
     const int nb02 = src0->nb[2];
     const int nb03 = src0->nb[3];
@@ -4407,7 +5841,7 @@ static void ggml_compute_forward_mul_mat_f32(
     assert(ne3  == ne13);
 
     // TODO: we don't support permuted src0
-    assert(nb00 == sizeof(float) || nb01 == sizeof(float));
+    assert(nb00 == sizeof(float));
 
     // dst cannot be transposed or permuted
     assert(nb0 == sizeof(float));
@@ -4422,9 +5856,6 @@ static void ggml_compute_forward_mul_mat_f32(
 
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
-    //
-    // nb00 <  nb01 - src0 is transposed
-    //   compute by src0 columns
 
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
@@ -4444,19 +5875,17 @@ static void ggml_compute_forward_mul_mat_f32(
 
         for (int i03 = 0; i03 < ne03; i03++) {
             for (int i02 = 0; i02 < ne02; i02++) {
-                const float * x = (float *) (src0->data);
+                const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
                 const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
 
                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
                 // zT = y * xT
-                {
-                    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
-                            ne11, ne01, ne10,
-                            1.0f,    y, ne10,
-                                     x, ne10,
-                            0.0f,    d, ne01);
-                }
+                cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
+                        ne11, ne01, ne10,
+                        1.0f,    y, ne10,
+                                 x, ne10,
+                        0.0f,    d, ne01);
             }
         }
 
@@ -4467,130 +5896,430 @@ static void ggml_compute_forward_mul_mat_f32(
 #endif
 
     if (params->type == GGML_TASK_INIT) {
-        if (nb01 >= nb00) {
-            return;
-        }
-
-        // TODO: fix this memset (wsize is overestimated)
-        memset(params->wdata, 0, params->wsize);
         return;
     }
 
     if (params->type == GGML_TASK_FINALIZE) {
-        if (nb01 >= nb00) {
-            return;
-        }
+        return;
+    }
 
-        // TODO: fix this memset (wsize is overestimated)
-        //assert(params->wsize == (ggml_nbytes(dst) + CACHE_LINE_SIZE)*nth);
+    // TODO: we don't support transposed src1
+    assert(nb10 == sizeof(float));
 
-        float * const wdata = params->wdata;
+    // parallelize by src0 rows using ggml_vec_dot_f32
 
-        // cols per thread
-        const int dc = (ne + nth - 1)/nth;
+    // total rows in src0
+    const int nr = ne01*ne02*ne03;
 
-        // col range for this thread
-        const int ic0 = dc*ith;
-        const int ic1 = MIN(ic0 + dc, ne);
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
 
-        ggml_vec_cpy_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + ic0);
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 indices
+        const int i03 = ir/(ne02*ne01);
+        const int i02 = (ir - i03*ne02*ne01)/ne01;
+        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        for (int ic = 0; ic < ne11; ++ic) {
+            // src1 indices
+            const int i13 = i03;
+            const int i12 = i02;
+            const int i11 = ic;
+
+            // dst indices
+            const int i0 = i01;
+            const int i1 = i11;
+            const int i2 = i02;
+            const int i3 = i03;
 
-        for (int k = 1; k < nth; k++) {
-            ggml_vec_acc_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + (ne + CACHE_LINE_SIZE_F32)*k + ic0);
+            ggml_vec_dot_f32(ne00,
+                    (float *) ((char *)  dst->data + (i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
+                    (float *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)),
+                    (float *) ((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13)));
         }
+    }
+
+    //int64_t t1 = ggml_perf_time_us();
+    //static int64_t acc = 0;
+    //acc += t1 - t0;
+    //if (t1 - t0 > 10) {
+    //    printf("\n");
+    //    printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
+    //    printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
+    //    printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
+    //    printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13);
+
+    //    printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
+    //}
+}
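+// Illustrative summary of the scalar path above: each dst element is the dot
+// product of one src0 row with one src1 row (both of length ne00), i.e. the same
+// "zT = y * xT" product the BLAS branch computes, parallelized by splitting the
+// src0 rows across threads.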
+
+static void ggml_compute_forward_mul_mat_f16_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
+
+    const int ne10 = src1->ne[0];
+    const int ne11 = src1->ne[1];
+    const int ne12 = src1->ne[2];
+    const int ne13 = src1->ne[3];
+
+    const int ne0  = dst->ne[0];
+    const int ne1  = dst->ne[1];
+    const int ne2  = dst->ne[2];
+    const int ne3  = dst->ne[3];
+    //const int ne   = ne0*ne1*ne2*ne3;
+
+    const int nb00 = src0->nb[0];
+    const int nb01 = src0->nb[1];
+    const int nb02 = src0->nb[2];
+    const int nb03 = src0->nb[3];
+
+    const int nb10 = src1->nb[0];
+    const int nb11 = src1->nb[1];
+    const int nb12 = src1->nb[2];
+    const int nb13 = src1->nb[3];
+
+    const int nb0  = dst->nb[0];
+    const int nb1  = dst->nb[1];
+    const int nb2  = dst->nb[2];
+    const int nb3  = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne03 == ne13);
+    GGML_ASSERT(ne2  == ne12);
+    GGML_ASSERT(ne3  == ne13);
+
+    // TODO: we don't support permuted src0
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(ne0 == ne01);
+    GGML_ASSERT(ne1 == ne11);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne3 == ne03);
+
+    // nb01 >= nb00 - src0 is not transposed
+    //   compute by src0 rows
+
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+    if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
+        GGML_ASSERT(nb10 == sizeof(float));
+
+        if (params->ith != 0) {
+            return;
+        }
+
+        if (params->type == GGML_TASK_INIT) {
+            return;
+        }
+
+        if (params->type == GGML_TASK_FINALIZE) {
+            return;
+        }
+
+        float * const wdata = params->wdata;
+
+        for (int i03 = 0; i03 < ne03; i03++) {
+            for (int i02 = 0; i02 < ne02; i02++) {
+                {
+                    size_t id = 0;
+                    for (int i01 = 0; i01 < ne01; ++i01) {
+                        for (int i00 = 0; i00 < ne00; ++i00) {
+                            wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00));
+                        }
+                    }
+                }
+
+                const float * x = wdata;
+                const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
+
+                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+
+                // zT = y * xT
+                cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
+                        ne11, ne01, ne10,
+                        1.0f,    y, ne10,
+                                 x, ne10,
+                        0.0f,    d, ne01);
+            }
+        }
+
+        /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
 
         return;
     }
+#endif
+
+    if (params->type == GGML_TASK_INIT) {
+        ggml_fp16_t * const wdata = params->wdata;
 
-    if (nb01 >= nb00) {
-        // TODO: do not support transposed src1
-        assert(nb10 == sizeof(float));
+        size_t id = 0;
+        for (int i13 = 0; i13 < ne13; ++i13) {
+            for (int i12 = 0; i12 < ne12; ++i12) {
+                for (int i11 = 0; i11 < ne11; ++i11) {
+                    for (int i10 = 0; i10 < ne10; ++i10) {
+                        wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10));
+                    }
+                }
+            }
+        }
 
-        // parallelize by src0 rows using ggml_vec_dot_f32
+        GGML_ASSERT(id*sizeof(ggml_fp16_t) <= params->wsize);
 
-        // total rows in src0
-        const int nr = ne01*ne02*ne03;
+        return;
+    }
 
-        // rows per thread
-        const int dr = (nr + nth - 1)/nth;
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
 
-        // row range for this thread
-        const int ir0 = dr*ith;
-        const int ir1 = MIN(ir0 + dr, nr);
+    // src1 elements are f32 (twice the size of fp16), hence the nb10/2 in the assert below
+    // TODO: we don't support transposed src1
+    assert(nb10/2 == sizeof(ggml_fp16_t));
 
-        for (int ir = ir0; ir < ir1; ++ir) {
-            // src0 indices
-            const int i03 = ir/(ne02*ne01);
-            const int i02 = (ir - i03*ne02*ne01)/ne01;
-            const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+    // parallelize by src0 rows using ggml_vec_dot_f16
 
-            for (int ic = 0; ic < ne11; ++ic) {
-                // src1 indices
-                const int i13 = i03;
-                const int i12 = i02;
-                const int i11 = ic;
+    // total rows in src0
+    const int nr = ne01*ne02*ne03;
 
-                // dst indices
-                const int i0 = i01;
-                const int i1 = i11;
-                const int i2 = i02;
-                const int i3 = i03;
-
-                ggml_vec_dot_f32(ne00,
-                        (float *) ((char *)  dst->data + (i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
-                        (float *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)),
-                        (float *) ((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13)));
-            }
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    ggml_fp16_t * wdata = params->wdata;
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 indices
+        const int i03 = ir/(ne02*ne01);
+        const int i02 = (ir - i03*ne02*ne01)/ne01;
+        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        const int i13 = i03;
+        const int i12 = i02;
+
+        const int i0 = i01;
+        const int i2 = i02;
+        const int i3 = i03;
+
+        ggml_fp16_t * src0_row = (ggml_fp16_t *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
+        ggml_fp16_t * src1_col =                                wdata + (       0 + i12*ne11 + i13*ne12*ne11)*ne00;
+
+        float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
+
+        for (int ic = 0; ic < ne11; ++ic) {
+            ggml_vec_dot_f16(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00);
         }
-    } else {
-        // parallelize by src1 columns using ggml_vec_mad_f32
-        // each thread has its own work data
-        // during FINALIZE we accumulate all work data into dst
+    }
+
+    //int64_t t1 = ggml_time_us();
+    //static int64_t acc = 0;
+    //acc += t1 - t0;
+    //if (t1 - t0 > 10) {
+    //    printf("\n");
+    //    printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
+    //    printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
+    //    printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
+
+    //    printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
+    //}
+}
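+// Illustrative summary of the f16 x f32 path above: during GGML_TASK_INIT, src1
+// is converted from f32 to fp16 into the work buffer (params->wdata), so the
+// per-row loop can run ggml_vec_dot_f16 on two fp16 operands; src1_col locates
+// the converted row block at offset (i12*ne11 + i13*ne12*ne11)*ne00.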
+
+static void ggml_compute_forward_mul_mat_q4_0_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
+
+    const int ne10 = src1->ne[0];
+    const int ne11 = src1->ne[1];
+    const int ne12 = src1->ne[2];
+    const int ne13 = src1->ne[3];
+
+    const int ne0  = dst->ne[0];
+    const int ne1  = dst->ne[1];
+    const int ne2  = dst->ne[2];
+    const int ne3  = dst->ne[3];
+    //const int ne   = ne0*ne1*ne2*ne3;
+
+    const int nb00 = src0->nb[0];
+    const int nb01 = src0->nb[1];
+    const int nb02 = src0->nb[2];
+    const int nb03 = src0->nb[3];
+
+    const int nb10 = src1->nb[0];
+    const int nb11 = src1->nb[1];
+    const int nb12 = src1->nb[2];
+    const int nb13 = src1->nb[3];
+
+    const int nb0  = dst->nb[0];
+    const int nb1  = dst->nb[1];
+    const int nb2  = dst->nb[2];
+    const int nb3  = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne03 == ne13);
+    GGML_ASSERT(ne2  == ne12);
+    GGML_ASSERT(ne3  == ne13);
+
+    // TODO: we don't support permuted src0
+    GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[GGML_TYPE_Q4_0]);
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(ne0 == ne01);
+    GGML_ASSERT(ne1 == ne11);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne3 == ne03);
 
-        // total columns in src1
-        const int nc = ne10;
+    // nb01 >= nb00 - src0 is not transposed
+    //   compute by src0 rows
 
-        // columns per thread
-        const int dc = (nc + nth - 1)/nth;
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+    if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
+        GGML_ASSERT(nb10 == sizeof(float));
+
+        if (params->ith != 0) {
+            return;
+        }
+
+        if (params->type == GGML_TASK_INIT) {
+            return;
+        }
 
-        // column range for this thread
-        const int ic0 = dc*ith;
-        const int ic1 = MIN(ic0 + dc, nc);
+        if (params->type == GGML_TASK_FINALIZE) {
+            return;
+        }
 
-        // work data for thread
-        const int wo = (ne + CACHE_LINE_SIZE_F32)*ith;
         float * const wdata = params->wdata;
 
+        for (int i03 = 0; i03 < ne03; i03++) {
+            for (int i02 = 0; i02 < ne02; i02++) {
+                {
+                    size_t id = 0;
+                    for (int i01 = 0; i01 < ne01; ++i01) {
+                        dequantize_row_q4_0((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00);
+                        id += ne00;
+                    }
+                }
+
+                const float * x = wdata;
+                const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
+
+                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+
+                // zT = y * xT
+                cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
+                        ne11, ne01, ne10,
+                        1.0f,    y, ne10,
+                                 x, ne10,
+                        0.0f,    d, ne01);
+            }
+        }
+
+        /*printf("CBLAS Q4_0 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
+
+        return;
+    }
+#endif
+
+    if (params->type == GGML_TASK_INIT) {
+        char * wdata = params->wdata;
+
         for (int i13 = 0; i13 < ne13; ++i13) {
             for (int i12 = 0; i12 < ne12; ++i12) {
                 for (int i11 = 0; i11 < ne11; ++i11) {
-                    for (int ic = ic0; ic < ic1; ++ic) {
-                        // src1 indices
-                        const int i10 = ic;
-
-                        // src0 indices
-                        const int i03 = i13;
-                        const int i02 = i12;
-                        const int i00 = ic;
-
-                        // dst indices
-                        const int i1 = i11;
-                        const int i2 = i12;
-                        const int i3 = i13;
-
-                        assert(sizeof(float)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize);
-
-                        ggml_vec_mad_f32(ne01,
-                                (float *) (wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0),
-                                (float *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03)),
-                               *(float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13)));
-                    }
+                    quantize_row_q4_0((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
+                    wdata += (ne10*GGML_TYPE_SIZE[GGML_TYPE_Q4_0])/GGML_BLCK_SIZE[GGML_TYPE_Q4_0];
                 }
             }
         }
+
+        return;
     }
 
-    //int64_t t1 = ggml_perf_time_us();
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // TODO: transposed src1 is not supported

+
+    // parallelize by src0 rows using ggml_vec_dot_q4_0
+
+    // total rows in src0
+    const int nr = ne01*ne02*ne03;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    void * wdata = params->wdata;
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 indices
+        const int i03 = ir/(ne02*ne01);
+        const int i02 = (ir - i03*ne02*ne01)/ne01;
+        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        const int i13 = i03;
+        const int i12 = i02;
+
+        const int i0 = i01;
+        const int i2 = i02;
+        const int i3 = i03;
+
+        void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
+        char * src1_col =          ((char *)      wdata + (      (0 + i12*ne11 + i13*ne12*ne11)*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_0])/GGML_BLCK_SIZE[GGML_TYPE_Q4_0]);
+
+        float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
+
+        assert(ne00 % 32 == 0);
+
+        for (int ic = 0; ic < ne11; ++ic) {
+            ggml_vec_dot_q4_0(ne00, &dst_col[ic*ne0], src0_row, ((void *) (src1_col + (ic*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_0])/GGML_BLCK_SIZE[GGML_TYPE_Q4_0])));
+        }
+    }
+
+    //int64_t t1 = ggml_time_us();
     //static int64_t acc = 0;
     //acc += t1 - t0;
     //if (t1 - t0 > 10) {
@@ -4598,13 +6327,12 @@ static void ggml_compute_forward_mul_mat_f32(
     //    printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
     //    printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
     //    printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
-    //    printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13);
 
     //    printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
     //}
 }
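The Q4_0 path above works on blocks of 32 values: judging from `GGML_BLCK_SIZE[GGML_TYPE_Q4_0]` and the `sizeof(float) + qk/2` stride used by `ggml_quantize_q4_0` further down, each block is one fp32 scale followed by 16 bytes of packed nibbles. A hedged sketch of that layout; the struct and helper names are invented for illustration:

```c
// Illustrative Q4_0 block layout inferred from the quantization code in this patch:
// one fp32 scale plus QK/2 bytes holding two 4-bit values per byte.
#include <stdint.h>

#define QK 32                 // block size (GGML_BLCK_SIZE[GGML_TYPE_Q4_0])

typedef struct {
    float   d;                // scaling factor for the block
    uint8_t qs[QK/2];         // 32 quants packed as nibbles
} block_q4_0_sketch;          // 4 + 16 = 20 bytes per 32 weights

// dequantize element i of a block: v = (nibble - 8) * d
static inline float dq(const block_q4_0_sketch * b, int i) {
    const uint8_t byte = b->qs[i/2];
    const int8_t  q    = (i % 2 == 0 ? (byte & 0xF) : (byte >> 4)) - 8;
    return q * b->d;
}
```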
 
-static void ggml_compute_forward_mul_mat_f16_f32(
+static void ggml_compute_forward_mul_mat_q4_1_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -4626,7 +6354,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     const int ne1  = dst->ne[1];
     const int ne2  = dst->ne[2];
     const int ne3  = dst->ne[3];
-    const int ne   = ne0*ne1*ne2*ne3;
+    //const int ne   = ne0*ne1*ne2*ne3;
 
     const int nb00 = src0->nb[0];
     const int nb01 = src0->nb[1];
@@ -4652,7 +6380,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     GGML_ASSERT(ne3  == ne13);
 
     // TODO: we don't support permuted src0
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t) || nb01 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[GGML_TYPE_Q4_1]);
 
     // dst cannot be transposed or permuted
     GGML_ASSERT(nb0 == sizeof(float));
@@ -4667,9 +6395,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
 
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
-    //
-    // nb00 <  nb01 - src0 is transposed
-    //   compute by src0 columns
 
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
@@ -4692,54 +6417,24 @@ static void ggml_compute_forward_mul_mat_f16_f32(
         for (int i03 = 0; i03 < ne03; i03++) {
             for (int i02 = 0; i02 < ne02; i02++) {
                 {
-                    int id = 0;
+                    size_t id = 0;
                     for (int i01 = 0; i01 < ne01; ++i01) {
-                        for (int i00 = 0; i00 < ne00; ++i00) {
-                            wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00));
-                        }
+                        dequantize_row_q4_1((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00);
+                        id += ne00;
                     }
                 }
 
                 const float * x = wdata;
                 const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
 
-                //      float * z =                          wdata + ne00*ne01;
-
-                // z = x * yT
-                //{
-                //    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
-                //            ne01, ne11, ne00,
-                //            1.0f, x, ne00,
-                //                  y, ne00,
-                //            0.0f, z, ne11);
-                //}
-
                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
-                // transpose z
-                //for (int j = 0; j < ne11; ++j) {
-                //    for (int i = 0; i < ne01; ++i) {
-                //        d[j*ne01 + i] = z[i*ne11 + j];
-                //    }
-                //}
-
-                {
-#if 1
-                    // zT = y * xT
-                    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
-                            ne11, ne01, ne10,
-                            1.0f,    y, ne00,
-                                     x, ne00,
-                            0.0f,    d, ne01);
-#else
-                    // zT = (xT * y)T
-                    cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans,
-                            ne01, ne11, ne10,
-                            1.0f,    x, ne00,
-                                     y, ne00,
-                            0.0f,    d, ne01);
-#endif
-                }
+                // zT = y * xT
+                cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
+                        ne11, ne01, ne10,
+                        1.0f,    y, ne10,
+                                 x, ne10,
+                        0.0f,    d, ne01);
             }
         }
 
@@ -4750,150 +6445,65 @@ static void ggml_compute_forward_mul_mat_f16_f32(
 #endif
 
     if (params->type == GGML_TASK_INIT) {
-        if (nb01 >= nb00) {
-            ggml_fp16_t * const wdata = params->wdata;
-
-            int id = 0;
-            for (int i13 = 0; i13 < ne13; ++i13) {
-                for (int i12 = 0; i12 < ne12; ++i12) {
-                    for (int i11 = 0; i11 < ne11; ++i11) {
-                        for (int i10 = 0; i10 < ne10; ++i10) {
-                            wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10));
-                        }
-                    }
+        char * wdata = params->wdata;
+
+        for (int i13 = 0; i13 < ne13; ++i13) {
+            for (int i12 = 0; i12 < ne12; ++i12) {
+                for (int i11 = 0; i11 < ne11; ++i11) {
+                    //for (int i10 = 0; i10 < ne10; ++i10) {
+                    //    wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10));
+                    //}
+                    quantize_row_q4_1((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
+                    wdata += (ne10*GGML_TYPE_SIZE[GGML_TYPE_Q4_1])/GGML_BLCK_SIZE[GGML_TYPE_Q4_1];
                 }
             }
-
-            GGML_ASSERT(id*sizeof(ggml_fp16_t) <= params->wsize);
-
-            return;
         }
 
-        // TODO: fix this memset (wsize is overestimated)
-        memset(params->wdata, 0, params->wsize);
         return;
     }
 
     if (params->type == GGML_TASK_FINALIZE) {
-        if (nb01 >= nb00) {
-            return;
-        }
-
-        // TODO: fix this memset (wsize is overestimated)
-        //assert(params->wsize == (ggml_nbytes(dst) + CACHE_LINE_SIZE)*nth);
-
-        ggml_fp16_t * const wdata = params->wdata;
-
-        // cols per thread
-        const int dc = (ne + nth - 1)/nth;
-
-        // col range for this thread
-        const int ic0 = dc*ith;
-        const int ic1 = MIN(ic0 + dc, ne);
-
-        for (int i = ic0; i < ic1; ++i) {
-            ((float *) dst->data)[i] = GGML_FP16_TO_FP32(wdata[i]);
-        }
-
-        for (int k = 1; k < nth; k++) {
-            for (int i = ic0; i < ic1; ++i) {
-                ((float *) dst->data)[i] += GGML_FP16_TO_FP32(wdata[(ne + CACHE_LINE_SIZE_F32)*k + i]);
-            }
-        }
-
         return;
     }
 
-    if (nb01 >= nb00) {
-        // fp16 -> half the size, so divide by 2
-        // TODO: do not support transposed src1
-        assert(nb10/2 == sizeof(ggml_fp16_t));
-
-        // parallelize by src0 rows using ggml_vec_dot_f16
-
-        // total rows in src0
-        const int nr = ne01*ne02*ne03;
-
-        // rows per thread
-        const int dr = (nr + nth - 1)/nth;
-
-        // row range for this thread
-        const int ir0 = dr*ith;
-        const int ir1 = MIN(ir0 + dr, nr);
-
-        ggml_fp16_t * wdata = params->wdata;
-
-        for (int ir = ir0; ir < ir1; ++ir) {
-            // src0 indices
-            const int i03 = ir/(ne02*ne01);
-            const int i02 = (ir - i03*ne02*ne01)/ne01;
-            const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
-            const int i13 = i03;
-            const int i12 = i02;
-
-            const int i0 = i01;
-            const int i2 = i02;
-            const int i3 = i03;
-
-            ggml_fp16_t * src0_row = (ggml_fp16_t *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
-            ggml_fp16_t * src1_col =                                wdata + (       0 + i12*ne11 + i13*ne12*ne11)*ne00;
-
-            float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
+    // TODO: transposed src1 is not supported
 
-            assert(ne00 % 32 == 0);
-
-            for (int ic = 0; ic < ne11; ++ic) {
-                ggml_vec_dot_f16(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00);
-            }
-        }
-    } else {
-        // parallelize by src1 columns using ggml_vec_mad_f16
-        // each thread has its own work data
-        // during FINALIZE we accumulate all work data into dst
+    // parallelize by src0 rows using ggml_vec_dot_q4_1
 
-        // total columns in src1
-        const int nc = ne10;
+    // total rows in src0
+    const int nr = ne01*ne02*ne03;
 
-        // columns per thread
-        const int dc = (nc + nth - 1)/nth;
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
 
-        // column range for this thread
-        const int ic0 = dc*ith;
-        const int ic1 = MIN(ic0 + dc, nc);
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
 
-        // work data for thread
-        const int wo = (ne + CACHE_LINE_SIZE_F32)*ith;
-        ggml_fp16_t * const wdata = params->wdata;
+    void * wdata = params->wdata;
 
-        for (int i13 = 0; i13 < ne13; ++i13) {
-            for (int i12 = 0; i12 < ne12; ++i12) {
-                for (int i11 = 0; i11 < ne11; ++i11) {
-                    // dst indices
-                    const int i1 = i11;
-                    const int i2 = i12;
-                    const int i3 = i13;
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 indices
+        const int i03 = ir/(ne02*ne01);
+        const int i02 = (ir - i03*ne02*ne01)/ne01;
+        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
 
-                    ggml_fp16_t * dst_row = wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0;
+        const int i13 = i03;
+        const int i12 = i02;
 
-                    for (int ic = ic0; ic < ic1; ++ic) {
-                        // src1 indices
-                        const int i10 = ic;
+        const int i0 = i01;
+        const int i2 = i02;
+        const int i3 = i03;
 
-                        // src0 indices
-                        const int i03 = i13;
-                        const int i02 = i12;
-                        const int i00 = ic;
+        void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
+        char * src1_col =          ((char *)      wdata + (      (0 + i12*ne11 + i13*ne12*ne11)*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_1])/GGML_BLCK_SIZE[GGML_TYPE_Q4_1]);
 
-                        assert(sizeof(ggml_fp16_t)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize);
+        float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
 
-                        ggml_fp16_t * src0_col =  (ggml_fp16_t *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03));
-                        float         src1_val = *      (float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+        assert(ne00 % 32 == 0);
 
-                        ggml_vec_mad_f16(ne01, dst_row, src0_col, src1_val);
-                    }
-                }
-            }
+        for (int ic = 0; ic < ne11; ++ic) {
+            ggml_vec_dot_q4_1(ne00, &dst_col[ic*ne0], src0_row, ((void *) (src1_col + (ic*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_1])/GGML_BLCK_SIZE[GGML_TYPE_Q4_1])));
         }
     }
 
@@ -4916,6 +6526,14 @@ static void ggml_compute_forward_mul_mat(
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+            {
+                ggml_compute_forward_mul_mat_q4_0_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_Q4_1:
+            {
+                ggml_compute_forward_mul_mat_q4_1_f32(params, src0, src1, dst);
+            } break;
         case GGML_TYPE_F16:
             {
                 ggml_compute_forward_mul_mat_f16_f32(params, src0, src1, dst);
@@ -4929,9 +6547,37 @@ static void ggml_compute_forward_mul_mat(
         case GGML_TYPE_I32:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
+
+#if 0
+    if (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_Q4_1) {
+        static int first = 8;
+        printf("src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]);
+        printf("src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]);
+        printf("dst:  ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
+        if (first) {
+            --first;
+        } else {
+            for (int k = 0; k < dst->ne[1]; ++k) {
+                for (int j = 0; j < dst->ne[0]/16; ++j) {
+                    for (int i = 0; i < 16; ++i) {
+                        printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
+                    }
+                    printf("\n");
+                }
+                printf("\n");
+            }
+            printf("\n");
+            exit(0);
+        }
+    } else {
+        printf("aaaa src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]);
+        printf("aaaa src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]);
+        printf("aaaa dst:  ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
+    }
+#endif
 }
 
 // ggml_compute_forward_scale
@@ -4981,13 +6627,15 @@ static void ggml_compute_forward_scale(
             {
                 ggml_compute_forward_scale_f32(params, src0, src1, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
         case GGML_TYPE_F16:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 }
@@ -5045,6 +6693,60 @@ static void ggml_compute_forward_transpose(
 
 // ggml_compute_forward_get_rows
 
+static void ggml_compute_forward_get_rows_q4_0(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nelements(src1);
+
+    assert( dst->ne[0] == nc);
+    assert( dst->ne[1] == nr);
+    assert(src0->nb[0] == GGML_TYPE_SIZE[GGML_TYPE_Q4_0]);
+
+    for (int i = 0; i < nr; ++i) {
+        const int r = ((int32_t *) src1->data)[i];
+
+        dequantize_row_q4_0(
+                (const void *) ((char *) src0->data + r*src0->nb[1]),
+                     (float *) ((char *)  dst->data + i*dst->nb[1]), nc);
+    }
+}
+
+static void ggml_compute_forward_get_rows_q4_1(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nelements(src1);
+
+    assert( dst->ne[0] == nc);
+    assert( dst->ne[1] == nr);
+    assert(src0->nb[0] == GGML_TYPE_SIZE[GGML_TYPE_Q4_1]);
+
+    for (int i = 0; i < nr; ++i) {
+        const int r = ((int32_t *) src1->data)[i];
+
+        dequantize_row_q4_1(
+                (const void *) ((char *) src0->data + r*src0->nb[1]),
+                     (float *) ((char *)  dst->data + i*dst->nb[1]), nc);
+    }
+}
+
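The new get_rows variants let row lookups (for example token-embedding gathers) run directly on quantized matrices, dequantizing only the selected rows into an F32 destination. A hedged, self-contained sketch of the call pattern, using an F32 matrix for simplicity since the Q4_0/Q4_1 cases behave the same apart from dequantization; all sizes and values are illustrative, and note that the graph's `n_threads` must now be set explicitly because the old `<= 0` fallback is removed later in this patch:

```c
// Embedding-style lookup through ggml_get_rows.
#include <stdint.h>
#include "ggml.h" // header shown later in this patch

void get_rows_example(void) {
    struct ggml_init_params params = { /*.mem_size =*/ 16*1024*1024, /*.mem_buffer =*/ NULL };
    struct ggml_context * ctx = ggml_init(params);

    // 4 "embedding" rows of 8 elements each, filled with a dummy value
    struct ggml_tensor * embd   = ggml_set_f32(ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 4), 1.0f);
    struct ggml_tensor * tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
    ((int32_t *) tokens->data)[0] = 2;
    ((int32_t *) tokens->data)[1] = 0;
    ((int32_t *) tokens->data)[2] = 3;

    struct ggml_tensor * rows = ggml_get_rows(ctx, embd, tokens);

    struct ggml_cgraph gf = ggml_build_forward(rows);
    gf.n_threads = 1;               // must be set explicitly with this patch
    ggml_graph_compute(ctx, &gf);   // rows->data now holds the 3 selected rows

    ggml_free(ctx);
}
```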
 static void ggml_compute_forward_get_rows_f16(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
@@ -5106,6 +6808,14 @@ static void ggml_compute_forward_get_rows(
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+            {
+                ggml_compute_forward_get_rows_q4_0(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_Q4_1:
+            {
+                ggml_compute_forward_get_rows_q4_1(params, src0, src1, dst);
+            } break;
         case GGML_TYPE_F16:
             {
                 ggml_compute_forward_get_rows_f16(params, src0, src1, dst);
@@ -5119,9 +6829,27 @@ static void ggml_compute_forward_get_rows(
         case GGML_TYPE_I32:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
+
+    //static bool first = true;
+    //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
+    //if (first) {
+    //    first = false;
+    //} else {
+    //    for (int k = 0; k < dst->ne[1]; ++k) {
+    //        for (int j = 0; j < dst->ne[0]/16; ++j) {
+    //            for (int i = 0; i < 16; ++i) {
+    //                printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
+    //            }
+    //            printf("\n");
+    //        }
+    //        printf("\n");
+    //    }
+    //    printf("\n");
+    //    exit(0);
+    //}
 }
 
 // ggml_compute_forward_diag_mask_inf
@@ -5172,13 +6900,15 @@ static void ggml_compute_forward_diag_mask_inf(
             {
                 ggml_compute_forward_diag_mask_inf_f32(params, src0, src1, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
         case GGML_TYPE_F16:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 }
@@ -5217,6 +6947,7 @@ static void ggml_compute_forward_soft_max_f32(
 
 #ifndef NDEBUG
         for (int i = 0; i < nc; ++i) {
+            //printf("p[%d] = %f\n", i, p[i]);
             assert(!isnan(p[i]));
         }
 #endif
@@ -5263,13 +6994,15 @@ static void ggml_compute_forward_soft_max(
             {
                 ggml_compute_forward_soft_max_f32(params, src0, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
         case GGML_TYPE_F16:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 }
@@ -5333,23 +7066,84 @@ static void ggml_compute_forward_rope_f32(
     }
 }
 
+static void ggml_compute_forward_rope_f16(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(src1->type == GGML_TYPE_I32);
+    assert(ggml_nelements(src1) == 3);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n_past = ((int32_t *) src1->data)[0];
+    const int n_dims = ((int32_t *) src1->data)[1];
+    const int mode   = ((int32_t *) src1->data)[2];
+
+    //const int ne0 = src0->ne[0];
+    const int ne1 = src0->ne[1];
+    const int ne2 = src0->ne[2];
+    const int ne3 = src0->ne[3];
+
+    const int nb0 = src0->nb[0];
+    const int nb1 = src0->nb[1];
+    const int nb2 = src0->nb[2];
+    const int nb3 = src0->nb[3];
+
+    //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
+    //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
+
+    assert(nb0 == sizeof(ggml_fp16_t));
+
+    for (int i3 = 0; i3 < ne3; i3++) {
+        for (int i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
+            const int p = (mode == 0 ? n_past + i2 : i2);
+            for (int i1 = 0; i1 < ne1; i1++) {
+                for (int i0 = 0; i0 < n_dims; i0 += 2) {
+                    const double theta = pow(10000.0, ((double)-i0)/n_dims);
+
+                    const double cos_theta = cos(p*theta);
+                    const double sin_theta = sin(p*theta);
+
+                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+                          ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+                    double x0 = ggml_fp16_to_fp32(src[0]);
+                    double x1 = ggml_fp16_to_fp32(src[1]);
+
+                    dst_data[0] = ggml_fp32_to_fp16(x0*cos_theta - x1*sin_theta);
+                    dst_data[1] = ggml_fp32_to_fp16(x0*sin_theta + x1*cos_theta);
+                }
+            }
+        }
+    }
+}
+
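The F16 RoPE path mirrors the F32 one, converting to fp32 only around the rotation: each pair (x0, x1) at dimension offset i0 and position p is rotated by the angle p*theta with theta = 10000^(-i0/n_dims). A small standalone check of that formula, with values chosen purely for illustration:

```c
// Worked example of the rotation applied by ggml_compute_forward_rope_f16.
#include <math.h>
#include <stdio.h>

int main(void) {
    const int    n_dims = 64;
    const int    p      = 3; // position (n_past + i2 in mode 0)
    const int    i0     = 2; // even dimension index of the pair
    const double theta  = pow(10000.0, -((double) i0)/n_dims);

    const double x0 = 1.0, x1 = 0.5;
    const double y0 = x0*cos(p*theta) - x1*sin(p*theta);
    const double y1 = x0*sin(p*theta) + x1*cos(p*theta);

    printf("rotated: (%f, %f)\n", y0, y1);
    return 0;
}
```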
 static void ggml_compute_forward_rope(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_rope_f16(params, src0, src1, dst);
+            } break;
         case GGML_TYPE_F32:
             {
                 ggml_compute_forward_rope_f32(params, src0, src1, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 }
@@ -5610,6 +7404,8 @@ static void ggml_compute_forward_conv_1d_1s(
             {
                 ggml_compute_forward_conv_1d_1s_f32(params, src0, src1, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -5876,6 +7672,8 @@ static void ggml_compute_forward_conv_1d_2s(
             {
                 ggml_compute_forward_conv_1d_2s_f32(params, src0, src1, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -6359,12 +8157,14 @@ static void ggml_compute_forward_flash_attn(
             {
                 ggml_compute_forward_flash_attn_f32(params, q, k, v, masked, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 }
@@ -6568,12 +8368,14 @@ static void ggml_compute_forward_flash_ff(
             {
                 GGML_ASSERT(false); // TODO
             } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 }
@@ -6581,7 +8383,7 @@ static void ggml_compute_forward_flash_ff(
 /////////////////////////////////
 
 static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
-    assert(params);
+    GGML_ASSERT(params);
 
     switch (tensor->op) {
         case GGML_OP_DUP:
@@ -6648,10 +8450,18 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_gelu(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_SILU:
+            {
+                ggml_compute_forward_silu(params, tensor->src0, tensor);
+            } break;
         case GGML_OP_NORM:
             {
                 ggml_compute_forward_norm(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_RMS_NORM:
+            {
+                ggml_compute_forward_rms_norm(params, tensor->src0, tensor);
+            } break;
         case GGML_OP_MUL_MAT:
             {
                 ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor);
@@ -6829,7 +8639,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             } break;
         case GGML_OP_MEAN:
             {
-                assert(false); // TODO: implement
+                GGML_ASSERT(false); // TODO: implement
             } break;
         case GGML_OP_REPEAT:
             {
@@ -6884,17 +8694,25 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             } break;
         case GGML_OP_GELU:
             {
-                assert(false); // TODO: not implemented
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
+        case GGML_OP_SILU:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
             } break;
         case GGML_OP_NORM:
             {
-                assert(false); // TODO: not implemented
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
+        case GGML_OP_RMS_NORM:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
             } break;
         case GGML_OP_MUL_MAT:
             {
                 if (src0->grad) {
                     // TODO: this requires outer product - ggml_out_prod(ctx, src1, tensor->grad);
-                    assert(false);
+                    GGML_ASSERT(false);
                 }
                 if (src1->grad) {
                     src1->grad =
@@ -7010,12 +8828,12 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
 
     if (node->op == GGML_OP_NONE && node->grad == NULL) {
         // reached a leaf node, not part of the gradient graph (e.g. a constant)
-        assert(cgraph->n_leafs < GGML_MAX_NODES);
+        GGML_ASSERT(cgraph->n_leafs < GGML_MAX_NODES);
 
         cgraph->leafs[cgraph->n_leafs] = node;
         cgraph->n_leafs++;
     } else {
-        assert(cgraph->n_nodes < GGML_MAX_NODES);
+        GGML_ASSERT(cgraph->n_nodes < GGML_MAX_NODES);
 
         cgraph->nodes[cgraph->n_nodes] = node;
         cgraph->grads[cgraph->n_nodes] = node->grad;
@@ -7039,7 +8857,7 @@ static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_ten
 
     if (n_new > 0) {
         // the last added node should always be starting point
-        assert(cgraph->nodes[cgraph->n_nodes - 1] == tensor);
+        GGML_ASSERT(cgraph->nodes[cgraph->n_nodes - 1] == tensor);
     }
 }
 
@@ -7070,7 +8888,7 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
 struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep) {
     struct ggml_cgraph result = *gf;
 
-    assert(gf->n_nodes > 0);
+    GGML_ASSERT(gf->n_nodes > 0);
 
     // if we are keeping the gradient graph, we have to detach the gradient nodes from the original graph
     if (keep) {
@@ -7233,10 +9051,6 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
 }
 
 void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
-    if (cgraph->n_threads <= 0) {
-        cgraph->n_threads = 8;
-    }
-
     const int n_threads = cgraph->n_threads;
 
     struct ggml_compute_state_shared state_shared = {
@@ -7269,7 +9083,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
             };
 
             int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
-            assert(rc == 0);
+            GGML_ASSERT(rc == 0);
             UNUSED(rc);
         }
     }
@@ -7311,7 +9125,12 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     {
                         node->n_tasks = n_threads;
                     } break;
+                case GGML_OP_SILU:
+                    {
+                        node->n_tasks = n_threads;
+                    } break;
                 case GGML_OP_NORM:
+                case GGML_OP_RMS_NORM:
                     {
                         node->n_tasks = n_threads;
                     } break;
@@ -7328,32 +9147,51 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
 
                         size_t cur = 0;
 
-                        // TODO: better way to determine if the matrix is transposed
-                        if (node->src0->nb[1] < node->src0->nb[0]) {
-                            cur = ggml_nbytes(node)*node->n_tasks; // TODO: this can become (n_tasks-1)
-                        } else {
-                            if (node->src0->type == GGML_TYPE_F16 &&
+                        if (node->src0->type == GGML_TYPE_F16 &&
+                                node->src1->type == GGML_TYPE_F32) {
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+                            if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
+                                node->n_tasks = 1; // TODO: this actually is doing nothing
+                                                   //       the threads are still spinning
+                                cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
+                                //printf("src0: ne0 = %d, ne1 = %d, ne = %d\n", node->src0->ne[0], node->src0->ne[1], node->src0->ne[0]*node->src0->ne[1]);
+                                //printf("src1: ne0 = %d, ne1 = %d, ne = %d\n", node->src1->ne[0], node->src1->ne[1], node->src1->ne[0]*node->src1->ne[1]);
+                                //printf("cur = %zu\n", cur);
+                            } else {
+                                cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
+                            }
+#else
+                            cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
+#endif
+                        } else if (node->src0->type == GGML_TYPE_F32 &&
+                                node->src1->type == GGML_TYPE_F32) {
+                            cur = 0;
+                        } else if (node->src0->type == GGML_TYPE_Q4_0 &&
                                 node->src1->type == GGML_TYPE_F32) {
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
-                                if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
-                                    node->n_tasks = 1; // TODO: this actually is doing nothing
-                                                       //       the threads are still spinning
-                                    cur = sizeof(float)*(node->src0->ne[0]*node->src0->ne[1]);
-                                    //printf("src0: ne0 = %d, ne1 = %d, ne = %d\n", node->src0->ne[0], node->src0->ne[1], node->src0->ne[0]*node->src0->ne[1]);
-                                    //printf("src1: ne0 = %d, ne1 = %d, ne = %d\n", node->src1->ne[0], node->src1->ne[1], node->src1->ne[0]*node->src1->ne[1]);
-                                    //printf("cur = %zu\n", cur);
-                                } else {
-                                    cur = sizeof(ggml_fp16_t)*ggml_nelements(node->src1);
-                                }
+                            if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
+                                node->n_tasks = 1;
+                                cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
+                            } else {
+                                cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_0]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_0];
+                            }
 #else
-                                cur = sizeof(ggml_fp16_t)*ggml_nelements(node->src1);
+                            cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_0]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_0];
 #endif
-                            } else if (node->src0->type == GGML_TYPE_F32 &&
-                                       node->src1->type == GGML_TYPE_F32) {
-                                cur = 0;
+                        } else if (node->src0->type == GGML_TYPE_Q4_1 &&
+                                node->src1->type == GGML_TYPE_F32) {
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+                            if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
+                                node->n_tasks = 1;
+                                cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
                             } else {
-                                GGML_ASSERT(false);
+                                cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_1]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_1];
                             }
+#else
+                            cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_1]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_1];
+#endif
+                        } else {
+                            GGML_ASSERT(false);
                         }
 
                         work_size = MAX(work_size, cur);
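The reworked sizing logic for GGML_OP_MUL_MAT reserves scratch for src1 converted to the type the dot product expects: F16 for the F16 x F32 case, the packed Q4_0/Q4_1 representation for the quantized cases, or a full F32 dequantization buffer when the BLAS path is taken. A sketch of the quantized-case arithmetic, with the block constants assumed from the layout implied elsewhere in the patch:

```c
// Sketch of the scratch-size computation for a quantized mul_mat:
// src1 (F32) is re-quantized into the work buffer during GGML_TASK_INIT,
// so the buffer needs nelements(src1) * type_size / block_size bytes.
#include <stddef.h>
#include <stdio.h>

int main(void) {
    const size_t type_size  = 20;      // assumed bytes per Q4_0 block (4 + 32/2)
    const size_t blck_size  = 32;      // elements per block
    const size_t n_elements = 512*512; // example ggml_nelements(src1)

    const size_t cur = (type_size*n_elements)/blck_size;
    printf("work buffer: %zu bytes\n", cur); // 163840 bytes for this example
    return 0;
}
```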
@@ -7454,13 +9292,13 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     } break;
                 case GGML_OP_COUNT:
                     {
-                        assert(false);
+                        GGML_ASSERT(false);
                     } break;
             }
         }
 
         if (cgraph->work != NULL && work_size > cgraph->work_size) {
-            assert(false); // TODO: better handling
+            GGML_ASSERT(false); // TODO: better handling
         }
 
         if (work_size > 0 && cgraph->work == NULL) {
@@ -7626,7 +9464,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
 
         for (int j = 0; j < n_threads - 1; j++) {
             int rc = ggml_thread_join(workers[j].thrd, NULL);
-            assert(rc == 0);
+            GGML_ASSERT(rc == 0);
             UNUSED(rc);
         }
 
@@ -7733,7 +9571,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
     char color[16];
 
     FILE * fp = fopen(filename, "w");
-    assert(fp);
+    GGML_ASSERT(fp);
 
     fprintf(fp, "digraph G {\n");
     fprintf(fp, "  newrank = true;\n");
@@ -7891,7 +9729,7 @@ static enum ggml_opt_result ggml_opt_adam(
         struct ggml_tensor * f,
         struct ggml_cgraph * gf,
         struct ggml_cgraph * gb) {
-    assert(ggml_is_scalar(f));
+    GGML_ASSERT(ggml_is_scalar(f));
 
     gf->n_threads = params.n_threads;
     gb->n_threads = params.n_threads;
@@ -7905,7 +9743,7 @@ static enum ggml_opt_result ggml_opt_adam(
         if (gf->nodes[i]->is_param) {
             GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op);
 
-            assert(np < GGML_MAX_PARAMS);
+            GGML_ASSERT(np < GGML_MAX_PARAMS);
 
             ps[np++] = gf->nodes[i];
             nx += ggml_nelements(gf->nodes[i]);
@@ -8205,7 +10043,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         if (gf->nodes[i]->is_param) {
             GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op);
 
-            assert(np < GGML_MAX_PARAMS);
+            GGML_ASSERT(np < GGML_MAX_PARAMS);
 
             ps[np++] = gf->nodes[i];
             nx += ggml_nelements(gf->nodes[i]);
@@ -8517,6 +10355,68 @@ enum ggml_opt_result ggml_opt(
 
 ////////////////////////////////////////////////////////////////////////////////
 
+size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int qk, int64_t * hist) {
+    const int nb = k / qk;
+    const size_t bs = (sizeof(float) + sizeof(uint8_t)*qk/2);
+    const size_t row_size = nb*bs;
+
+    assert(k % qk == 0);
+
+    char * pdst = (char *) dst;
+
+    for (int j = 0; j < n; j += k) {
+        uint8_t * pd = (uint8_t *) (pdst + (j/k)*row_size + 0*bs);
+        uint8_t * pb = (uint8_t *) (pdst + (j/k)*row_size + 0*bs + sizeof(float));
+
+        quantize_row_q4_0_reference(src + j, pd, k);
+
+        for (int i = 0; i < nb; i++) {
+            for (int l = 0; l < qk; l += 2) {
+                const uint8_t vi0 = pb[l/2] & 0xF;
+                const uint8_t vi1 = pb[l/2] >> 4;
+
+                hist[vi0]++;
+                hist[vi1]++;
+            }
+            pb += bs;
+        }
+    }
+
+    return (n/k)*row_size;
+}
+
+size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int qk, int64_t * hist) {
+    const int nb = k / qk;
+    const size_t bs = (2*sizeof(float) + sizeof(uint8_t)*qk/2);
+    const size_t row_size = nb*bs;
+
+    assert(k % qk == 0);
+
+    char * pdst = (char *) dst;
+
+    for (int j = 0; j < n; j += k) {
+        uint8_t * pd = (uint8_t *) (pdst + (j/k)*row_size + 0*bs);
+        uint8_t * pb = (uint8_t *) (pdst + (j/k)*row_size + 0*bs + 2*sizeof(float));
+
+        quantize_row_q4_1(src + j, pd, k);
+
+        for (int i = 0; i < nb; i++) {
+            for (int l = 0; l < qk; l += 2) {
+                const uint8_t vi0 = pb[l/2] & 0xF;
+                const uint8_t vi1 = pb[l/2] >> 4;
+
+                hist[vi0]++;
+                hist[vi1]++;
+            }
+            pb += bs;
+        }
+    }
+
+    return (n/k)*row_size;
+}
+
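The two exported quantization helpers process a flat float buffer row by row (n total elements, k per row, block size qk) and additionally fill a 16-bin histogram of the emitted 4-bit values, returning the number of bytes written. A hedged usage sketch; buffer sizes and shapes are illustrative:

```c
// Illustrative call of ggml_quantize_q4_0 on one 2x4096 float matrix.
// hist[] must have 16 entries, one per possible 4-bit quant value.
#include <stdint.h>
#include <stdlib.h>
#include "ggml.h" // header shown later in this patch

void quantize_example(const float * data /* 2*4096 floats */) {
    const int n  = 2*4096; // total elements
    const int k  = 4096;   // elements per row
    const int qk = 32;     // block size, must divide k

    int64_t hist[16] = {0};

    // worst case allocation: the quantized data is smaller than the F32 input
    void * dst = malloc(n*sizeof(float));

    const size_t written = ggml_quantize_q4_0(data, dst, n, k, qk, hist);
    (void) written; // actual number of bytes produced

    free(dst);
}
```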
+////////////////////////////////////////////////////////////////////////////////
+
 int ggml_cpu_has_avx(void) {
 #if defined(__AVX__)
     return 1;
diff --git a/ggml.h b/ggml.h
index 18f317bec04..ddb97318b33 100644
--- a/ggml.h
+++ b/ggml.h
@@ -198,6 +198,8 @@ struct ggml_object;
 struct ggml_context;
 
 enum ggml_type {
+    GGML_TYPE_Q4_0,
+    GGML_TYPE_Q4_1,
     GGML_TYPE_I8,
     GGML_TYPE_I16,
     GGML_TYPE_I32,
@@ -226,7 +228,9 @@ enum ggml_op {
     GGML_OP_STEP,
     GGML_OP_RELU,
     GGML_OP_GELU,
+    GGML_OP_SILU,
     GGML_OP_NORM, // normalize
+    GGML_OP_RMS_NORM,
 
     GGML_OP_MUL_MAT,
 
@@ -326,7 +330,10 @@ void ggml_print_objects(const struct ggml_context * ctx);
 int    ggml_nelements(const struct ggml_tensor * tensor);
 size_t ggml_nbytes   (const struct ggml_tensor * tensor);
 
-size_t ggml_type_size   (enum ggml_type type);
+int    ggml_blck_size (enum ggml_type type);
+size_t ggml_type_size (enum ggml_type type); // size in bytes for all elements in a block
+float  ggml_type_sizef(enum ggml_type type); // ggml_type_size()/ggml_blck_size() as float
+
 size_t ggml_element_size(const struct ggml_tensor * tensor);
 
 struct ggml_context * ggml_init(struct ggml_init_params params);
@@ -336,6 +343,9 @@ size_t ggml_used_mem(const struct ggml_context * ctx);
 
 size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch);
 
+bool ggml_mlock_supported(void);
+bool ggml_mlock(struct ggml_context * ctx, char ** err_p);
+
 struct ggml_tensor * ggml_new_tensor(
         struct ggml_context * ctx,
         enum   ggml_type type,
@@ -466,12 +476,20 @@ struct ggml_tensor * ggml_gelu(
         struct ggml_context * ctx,
         struct ggml_tensor  * a);
 
+struct ggml_tensor * ggml_silu(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
+
 // normalize along rows
 // TODO: eps is hardcoded to 1e-5 for now
 struct ggml_tensor * ggml_norm(
         struct ggml_context * ctx,
         struct ggml_tensor  * a);
 
+struct ggml_tensor * ggml_rms_norm(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
+
 // A: m rows, n columns
 // B: p rows, n columns (i.e. we transpose it internally)
 // result is m columns, p rows
@@ -726,6 +744,13 @@ enum ggml_opt_result ggml_opt(
         struct ggml_opt_params params,
         struct ggml_tensor * f);
 
+//
+// quantization
+//
+
+size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int qk, int64_t * hist);
+size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int qk, int64_t * hist);
+
 //
 // system info
 //
diff --git a/models/download-coreml-model.sh b/models/download-coreml-model.sh
new file mode 100755
index 00000000000..d46789d7c06
--- /dev/null
+++ b/models/download-coreml-model.sh
@@ -0,0 +1,82 @@
+#!/bin/bash
+
+# This script downloads Whisper model files that have already been converted to Core ML format.
+# This way you don't have to convert them yourself.
+
+src="https://huggingface.co/datasets/ggerganov/whisper.cpp-coreml"
+pfx="resolve/main/ggml"
+
+# get the path of this script
+function get_script_path() {
+    if [ -x "$(command -v realpath)" ]; then
+        echo "$(dirname $(realpath $0))"
+    else
+        local ret="$(cd -- "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P)"
+        echo "$ret"
+    fi
+}
+
+models_path="$(get_script_path)"
+
+# Whisper models
+models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large-v1" "large" )
+
+# list available models
+function list_models {
+    printf "\n"
+    printf "  Available models:"
+    for model in "${models[@]}"; do
+        printf " $model"
+    done
+    printf "\n\n"
+}
+
+if [ "$#" -ne 1 ]; then
+    printf "Usage: $0 <model>\n"
+    list_models
+
+    exit 1
+fi
+
+model=$1
+
+if [[ ! " ${models[@]} " =~ " ${model} " ]]; then
+    printf "Invalid model: $model\n"
+    list_models
+
+    exit 1
+fi
+
+# download Core ML model
+
+printf "Downloading Core ML model $model from '$src' ...\n"
+
+cd "$models_path"
+
+if [ -f "ggml-$model.mlmodel" ]; then
+    printf "Model $model already exists. Skipping download.\n"
+    exit 0
+fi
+
+if [ -x "$(command -v wget)" ]; then
+    wget --quiet --show-progress -O ggml-$model.mlmodel $src/$pfx-$model.mlmodel
+elif [ -x "$(command -v curl)" ]; then
+    curl -L --output ggml-$model.mlmodel $src/$pfx-$model.mlmodel
+else
+    printf "Either wget or curl is required to download models.\n"
+    exit 1
+fi
+
+
+if [ $? -ne 0 ]; then
+    printf "Failed to download Core ML model $model \n"
+    printf "Please try again later or download the original Whisper model files and convert them yourself.\n"
+    exit 1
+fi
+
+printf "Done! Model '$model' saved in 'models/ggml-$model.mlmodel'\n"
+printf "Run the following command to compile it:\n\n"
+printf "  $ xcrun coremlc compile ./models/ggml-$model.mlmodel ./models\n\n"
+printf "You can now use it like this:\n\n"
+printf "  $ ./main -m models/ggml-$model.bin -f samples/jfk.wav\n"
+printf "\n"
diff --git a/whisper.cpp b/whisper.cpp
index d65738a3e56..75c1e2603ae 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -1,5 +1,8 @@
 #define WHISPER_BUILD
 #include "whisper.h"
+#ifdef WHISPER_USE_COREML
+#include "coreml/whisper-encoder.h"
+#endif
 
 #include "ggml.h"
 
@@ -586,6 +589,10 @@ struct whisper_state {
 
     int lang_id = 0; // english by default
 
+#ifdef WHISPER_USE_COREML
+    whisper_coreml_context * ctx_coreml;
+#endif
+
     // [EXPERIMENTAL] token-level timestamps data
     int64_t t_beg = 0;
     int64_t t_last = 0;
@@ -636,6 +643,8 @@ struct whisper_context {
     whisper_model model;
     whisper_vocab vocab;
     whisper_state * state = nullptr;
+
+    std::string path_model; // populated by whisper_init_from_file()
 };
 
 template<typename T>
@@ -1366,6 +1375,7 @@ static bool whisper_encode_internal(
         }
     }
 
+#ifndef WHISPER_USE_COREML
     struct ggml_tensor * cur;
 
     // convolution + gelu
@@ -1597,7 +1607,7 @@ static bool whisper_encode_internal(
                         ggml_repeat(ctx0, layer.mlp_ln_w, cur),
                         cur),
                     ggml_repeat(ctx0, layer.mlp_ln_b, cur));
-    }
+            }
 
 #ifdef WHISPER_USE_FLASH_FF
             wstate.use_buf(ctx0, 0);
@@ -1637,7 +1647,7 @@ static bool whisper_encode_internal(
                 ggml_repeat(ctx0, layer.mlp_1_b, cur),
                 cur);
 #endif
-}
+        }
 
         wstate.use_buf(ctx0, 3);
 
@@ -1674,6 +1684,13 @@ static bool whisper_encode_internal(
 
         //ggml_graph_print(&gf);
     }
+#else
+    wstate.use_buf(ctx0, -1);
+
+    struct ggml_tensor * cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
+
+    whisper_coreml_encode(wstate.ctx_coreml, (float *) mel->data, (float *) cur->data);
+#endif
 
     // cur
     //{
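The Core ML branch above, together with whisper_init_state() and whisper_free() further down, relies on three functions from the new coreml/whisper-encoder.h header, which is not included in this part of the diff. Judging only from the call sites, its interface presumably looks roughly like this (a reconstruction, not the actual file):

```c
// Assumed shape of coreml/whisper-encoder.h, inferred from the call sites
// in whisper_init_state(), whisper_encode_internal() and whisper_free().
struct whisper_coreml_context;

// load a compiled .mlmodelc from disk; returns NULL on failure
struct whisper_coreml_context * whisper_coreml_init(const char * path_model);

// run the encoder: mel spectrogram in, encoder output (n_state x n_ctx floats) out
void whisper_coreml_encode(const struct whisper_coreml_context * ctx, float * mel, float * out);

void whisper_coreml_free(struct whisper_coreml_context * ctx);
```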
@@ -1841,8 +1858,6 @@ static bool whisper_decode_internal(
 
         // self-attention
         {
-            wstate.use_buf(ctx0, 1);
-
             struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
                     layer.attn_q_w,
                     cur);
@@ -1904,8 +1919,6 @@ static bool whisper_decode_internal(
             // K * Q
             struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
 
-            wstate.use_buf(ctx0, 0);
-
             //struct ggml_tensor * KQ_scaled =
             //    ggml_scale(ctx0,
             //            KQ,
@@ -1914,20 +1927,16 @@ static bool whisper_decode_internal(
 
             struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
 
-            wstate.use_buf(ctx0, 1);
-
             struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
 
-            wstate.use_buf(ctx0, 0);
-
             struct ggml_tensor * V_trans =
-                ggml_permute(ctx0,
-                        ggml_reshape_3d(ctx0,
-                            ggml_view_1d(ctx0, kv_self.v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.v)*n_state),
-                            n_state/n_head, n_head, n_past + N),
-                        1, 2, 0, 3);
-
-            wstate.use_buf(ctx0, 1);
+                ggml_cpy(ctx0,
+                        ggml_permute(ctx0,
+                            ggml_reshape_3d(ctx0,
+                                ggml_view_1d(ctx0, kv_self.v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.v)*n_state),
+                                n_state/n_head, n_head, n_past + N),
+                            1, 2, 0, 3),
+                        ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_state/n_head, n_head));
 
             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
 
@@ -1964,8 +1973,6 @@ static bool whisper_decode_internal(
 
             cur = ggml_norm(ctx0, inpCA); // note: we use inpCA here
 
-            wstate.use_buf(ctx0, 1);
-
             // cur = ln_0_w*cur + ln_0_b
             cur = ggml_add(ctx0,
                     ggml_mul(ctx0,
@@ -1976,8 +1983,6 @@ static bool whisper_decode_internal(
 
         // cross-attention
         {
-            wstate.use_buf(ctx0, 0);
-
             struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
                     layer.cross_attn_q_w,
                     cur);
@@ -2001,12 +2006,13 @@ static bool whisper_decode_internal(
                         ggml_view_1d(ctx0, wstate.kv_cross.v, M*n_state, il*M*ggml_element_size(wstate.kv_cross.v)*n_state),
                         n_state/n_head, n_head, M);
 
-            struct ggml_tensor * V_trans = ggml_permute(ctx0, Vcross, 1, 2, 0, 3);
+            struct ggml_tensor * V_trans =
+                ggml_cpy(ctx0,
+                        ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
+                        ggml_new_tensor_3d(ctx0, Vcross->type, M, n_state/n_head, n_head));
 
             // ------
 
-            wstate.use_buf(ctx0, 1);
-
             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
                         ggml_cpy(ctx0,
@@ -2016,8 +2022,6 @@ static bool whisper_decode_internal(
 
             struct ggml_tensor * K = ggml_permute(ctx0, Kcross, 0, 2, 1, 3);
 
-            wstate.use_buf(ctx0, 0);
-
             // K * Q
             struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
 
@@ -2030,16 +2034,10 @@ static bool whisper_decode_internal(
             // no masking for cross-attention
             //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
 
-            wstate.use_buf(ctx0, 1);
-
             struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
 
-            wstate.use_buf(ctx0, 0);
-
             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
 
-            wstate.use_buf(ctx0, 1);
-
             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
             // cur = KQV_merged.contiguous().view(n_state, N)
@@ -2477,12 +2475,25 @@ static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, cons
 // interface implementation
 //
 
+#ifdef WHISPER_USE_COREML
+// replace .bin with .mlmodelc
+static std::string whisper_get_coreml_path(std::string path_bin) {
+    auto pos = path_bin.rfind('.');
+    if (pos != std::string::npos) {
+        path_bin = path_bin.substr(0, pos);
+    }
+
+    path_bin += ".mlmodelc";
+
+    return path_bin;
+}
+#endif
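For example, this helper maps `models/ggml-base.en.bin` to `models/ggml-base.en.mlmodelc`, which is the directory that `xcrun coremlc compile` (see models/download-coreml-model.sh above) is expected to produce next to the regular ggml model.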
+
 struct whisper_state * whisper_init_state(whisper_context * ctx) {
     whisper_state * state = new whisper_state;
 
     const size_t scale = ctx->model.hparams.f16 ? 1 : 2;
 
-
     if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_SELF.at(ctx->model.type), state->decoders[0].kv_self, ctx->wtype, ctx->model.hparams.n_text_ctx)) {
         fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
         return nullptr;
@@ -2503,7 +2514,6 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
         fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
     }
 
-
     state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx);
 
     state->logits_id.reserve(ctx->model.hparams.n_vocab);
@@ -2523,6 +2533,21 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     state->rng = std::mt19937(0);
 
+#ifdef WHISPER_USE_COREML
+    const auto path_coreml = whisper_get_coreml_path(ctx->path_model);
+
+    fprintf(stderr, "%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str());
+    fprintf(stderr, "%s: first run on a device may take a while ...\n", __func__);
+
+    state->ctx_coreml = whisper_coreml_init(path_coreml.c_str());
+    if (!state->ctx_coreml) {
+        fprintf(stderr, "%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str());
+        return nullptr;
+    }
+
+    fprintf(stderr, "%s: Core ML model loaded\n", __func__);
+#endif
+
     return state;
 }
 
@@ -2538,6 +2563,7 @@ struct whisper_context * whisper_init_from_file_no_state(const char * path_model
     }
 
     loader.context = &fin;
+
     loader.read = [](void * ctx, void * output, size_t read_size) {
         std::ifstream * fin = (std::ifstream*)ctx;
         fin->read((char *)output, read_size);
@@ -2554,7 +2580,13 @@ struct whisper_context * whisper_init_from_file_no_state(const char * path_model
         fin->close();
     };
 
-    return whisper_init_no_state(&loader);
+    auto ctx = whisper_init_no_state(&loader);
+
+    if (ctx) {
+        ctx->path_model = path_model;
+    }
+
+    return ctx;
 }
 
 struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size) {
@@ -2679,6 +2711,12 @@ void whisper_free(struct whisper_context * ctx) {
 
+#ifdef WHISPER_USE_COREML
+        if (ctx->state) {
+            whisper_coreml_free(ctx->state->ctx_coreml);
+            ctx->state->ctx_coreml = nullptr;
+        }
+#endif
         whisper_free_state(ctx->state);
 
         delete ctx;
     }
 }