Skip to content

Commit

Permalink
Clean up the chat CLI and fix vector_simd.c
Browse files Browse the repository at this point in the history
  • Loading branch information
tjake committed Jun 19, 2024
1 parent d670e78 commit 83b9f86
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 157 deletions.
5 changes: 5 additions & 0 deletions jlama-cli/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@
<artifactId>progressbar</artifactId>
<version>0.10.0</version>
</dependency>
<dependency>
<groupId>com.diogonunes</groupId>
<artifactId>JColor</artifactId>
<version>5.5.1</version>
</dependency>
<dependency>
<groupId>org.jboss.resteasy</groupId>
<artifactId>resteasy-jaxrs</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,42 @@

import static com.github.tjake.jlama.model.ModelSupport.loadModel;

import com.diogonunes.jcolor.AnsiFormat;
import com.diogonunes.jcolor.Attribute;
import com.github.tjake.jlama.model.AbstractModel;

import java.io.PrintWriter;
import java.nio.charset.Charset;
import java.util.Optional;
import java.util.Scanner;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiConsumer;

import picocli.CommandLine.*;

@Command(name = "chat", description = "Interact with the specified model")
public class ChatCommand extends ModelBaseCommand {
public class ChatCommand extends BaseCommand {
private static final AnsiFormat chatText = new AnsiFormat(Attribute.CYAN_TEXT());
private static final AnsiFormat statsColor = new AnsiFormat(Attribute.BLUE_TEXT());

@Option(
names = {"-s", "--system-prompt"},
description = "Change the default system prompt for this model")
String systemPrompt =
"You are a happy demo app of a project called jlama. You answer any question then add \"Jlama is awesome!\" after.";
String systemPrompt = null;

@Option(
names = {"-t", "--temperature"},
description = "Temperature of response [0,1] (default: ${DEFAULT-VALUE})",
defaultValue = "0.6")
protected Float temperature;

@Option(
names = {"--top-p"},
description =
"Controls how many different words the model considers per token [0,1] (default: ${DEFAULT-VALUE})",
defaultValue = ".9")
protected Float topp;

@Override
public void run() {
Expand All @@ -40,13 +64,50 @@ public void run() {
Optional.ofNullable(modelQuantization),
Optional.ofNullable(threadCount));

m.generate(
UUID.randomUUID(),
m.wrapPrompt(prompt, Optional.of(systemPrompt)),
prompt,
temperature,
tokens,
true,
makeOutHandler());
UUID session = UUID.randomUUID();
PrintWriter out = System.console().writer();

out.println("Chatting with " + model + "...\n\n");
out.flush();
Scanner sc = new Scanner(System.in);
boolean first = true;
while (true) {
out.print("You: ");
out.flush();
String prompt = sc.nextLine();
out.println();
out.flush();
if (prompt.isEmpty()) {
break;
}
String wrappedPrompt = m.wrapPrompt(prompt, first ? Optional.ofNullable(systemPrompt) : Optional.empty());
m.generate(
session,
wrappedPrompt,
prompt,
temperature,
Integer.MAX_VALUE,
true,
makeOutHandler());

first = false;
}
}

/**
 * Builds the token callback used by {@code m.generate(...)}: prints the
 * "Jlama: " prefix once, then streams each generated token to the console
 * in the chat color as it arrives.
 *
 * <p>Fix: {@code System.console()} returns {@code null} when the JVM has no
 * attached console (e.g. stdin/stdout are piped), which previously caused a
 * NullPointerException here. We now fall back to an auto-flushing writer
 * over {@code System.out} in that case.
 *
 * @return callback receiving each token string and its per-token timing
 *         (the timing value is not used for display)
 */
protected BiConsumer<String, Float> makeOutHandler()
{
    // Prefer the interactive console writer; fall back to stdout when
    // no console is attached so piped/non-interactive runs don't NPE.
    final PrintWriter out = System.console() != null
            ? System.console().writer()
            : new PrintWriter(System.out, true);

    out.print(chatText.format("Jlama: "));
    out.flush();

    // w = generated token text, t = per-token timing (unused here).
    return (w, t) -> {
        out.print(chatText.format(w));
        out.flush();
    };
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,11 @@ protected AbstractModel(
this.modelQType = modelQType;
this.kvBufferCache = new KvBufferCache(this);

// FIXME: This is a hack to avoid mixed Q8/F32 evals when the model is F32
if (modelDType == DType.F32 && workingMemoryQType != DType.F32) {
workingMemoryQType = DType.F32;
}

if (workingMemoryQType != workingMemoryDType) {
boolean supportsQType;
AbstractTensor tmp = makeTensor(Q8ByteBufferTensor.BLOCK_SIZE);
Expand Down Expand Up @@ -275,7 +280,7 @@ public void generate(
Integer startPos = (Integer) kvmem.getMetadata(KvBufferCache.TOKEN_COUNT); // Number of tokens in the buffer
if (startPos == null) startPos = 0;

logger.info("Starting at token {} for session {}", startPos, sessionId);
logger.debug("Starting at token {} for session {}", startPos, sessionId);

if (ntokens > c.contextLength) ntokens = c.contextLength;

Expand All @@ -293,8 +298,6 @@ public void generate(
promptLength++;
}

String clientPrompt = cleanPrompt == null ? prompt : cleanPrompt;
onTokenWithTimings.accept(clientPrompt, 0f);
long start = System.currentTimeMillis();
// Batch Process Prompt
AbstractTensor last = batchForward(promptTokens, startPos, kvmem);
Expand Down Expand Up @@ -327,8 +330,10 @@ public void generate(

if (logger.isTraceEnabled()) logger.trace("Sampled token {} with temperature {}", next, temperature);
output.close();

// Model may tell us it's done
if (next == c.eosToken) break;

kvmem.setMetadata(KvBufferCache.TOKEN_COUNT, i);

try {
Expand Down
149 changes: 6 additions & 143 deletions jlama-native/src/main/c/vector_simd.c
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ void __attribute__((noinline)) gemm_q8_q4_128_arm(int m0, int m, int n0, int n,
int ii = m0 + job / xtiles * RM;
int jj = n0 + job % xtiles * RN;

float32x4_t sums[RM][RN] __attribute__((aligned(32)));
float32x4_t sums[RM][RN];

//Reset the sums to zero for this tile
for (int i = 0; i < RM; i++) {
Expand Down Expand Up @@ -228,143 +228,6 @@ void __attribute__((noinline)) gemm_q8_q4_128_arm(int m0, int m, int n0, int n,
}
}
#else
/*
 * SSE (128-bit) tiled GEMM kernel for a quantized matrix product:
 * A is int8 data (params.a) with per-block float scales (params.af),
 * B is packed 4-bit data (params.b) — two values per byte — with
 * per-block float scales (params.bf). Results are accumulated in
 * float and written to params.r.
 *
 * The [m0,m) x [n0,n) output region is processed in RM x RN tiles;
 * each tile keeps its partial sums in __m128 registers until the
 * final horizontal reduction.
 *
 * NOTE(review): block layout assumes Q4_BLOCK_SIZE quantized values
 * per scale (32 int8 bytes of A and 16 packed bytes of B advanced per
 * inner step) — confirm against the quantization code.
 */
void __attribute__((noinline)) gemm_q8_q4_128(int m0, int m, int n0, int n, int RM, int RN, struct gemm_params params) {
    int ytiles = (m - m0) / RM;
    int xtiles = (n - n0) / RN;
    int tiles = xtiles * ytiles;
    // Mask to keep the first 4 bits of each byte
    __m128i mask_first_4bits = _mm_set1_epi8(0xF);
    //Subtract 8 from each byte to get signed values
    __m128i eight = _mm_set1_epi8(8);
    int numBlocks = params.k / Q4_BLOCK_SIZE;

    // Scratch for the 4 combined (a*b) scale factors of the current
    // group of 4 quant blocks. This fits on the stack (max of 5x5)
    __attribute__((aligned(16))) float scalef[4];
    for (int job = 0; job < tiles; ++job) {

        // Top-left coordinates of this tile in the output matrix.
        int ii = m0 + job / xtiles * RM;
        int jj = n0 + job % xtiles * RN;

        __m128 sums[RM][RN] __attribute__((aligned(32)));

        //Reset the sums to zero for this tile
        for (int i = 0; i < RM; i++) {
            for (int j = 0; j < RN; j++) {
                sums[i][j] = _mm_setzero_ps();
            }
        }

        for (int ni = 0; ni < RN; ++ni) {
            int ao = params.aoffset;
            int bo = params.boffset;

            // Process 4 quant blocks per iteration so one 128-bit load
            // fetches all 4 scale floats at once.
            for (int i = 0; i < numBlocks; i += 4) { //128bits == 4floats
                int aoo = ao;
                int boo = bo;

                for (int mi = 0; mi < RM; ++mi) {
                    // Rewind to the group start for each row of the tile.
                    ao = aoo;
                    bo = boo;

                    // Load float32
                    __m128 ablock = _mm_loadu_ps(params.af + (params.ldaf * (ii + mi) + (ao / Q4_BLOCK_SIZE)));
                    __m128 bblock = _mm_loadu_ps(params.bf + (params.ldbf * (jj + ni) + ((bo*2) / Q4_BLOCK_SIZE)));
                    // Combined per-block scale = a_scale * b_scale.
                    __m128 scaled = _mm_mul_ps(ablock, bblock);
                    _mm_store_ps(scalef, scaled);

                    // One quant block per j: 32 int8 values of A, 16 packed
                    // bytes (32 nibbles) of B.
                    for(int j = 0; j < 4; j++, ao += 32, bo += 16) {
                        // Load 4 bytes into a 128-bit integer register
                        __m128i int_va0 = _mm_cvtepi8_epi32(_mm_loadu_si32((__m128i const*)(params.a + params.lda * (ii + mi) + ao)));
                        __m128i int_va1 = _mm_cvtepi8_epi32(_mm_loadu_si32((__m128i const*)(params.a + params.lda * (ii + mi) + ao + 4)));
                        __m128i int_va2 = _mm_cvtepi8_epi32(_mm_loadu_si32((__m128i const*)(params.a + params.lda * (ii + mi) + ao + 4 + 4)));
                        __m128i int_va3 = _mm_cvtepi8_epi32(_mm_loadu_si32((__m128i const*)(params.a + params.lda * (ii + mi) + ao + 4 + 4 + 4)));
                        __m128i int_va4 = _mm_cvtepi8_epi32(_mm_loadu_si32((__m128i const*)(params.a + params.lda * (ii + mi) + ao + 4 + 4 + 4 + 4)));
                        __m128i int_va5 = _mm_cvtepi8_epi32(_mm_loadu_si32((__m128i const*)(params.a + params.lda * (ii + mi) + ao + 4 + 4 + 4 + 4 + 4)));
                        __m128i int_va6 = _mm_cvtepi8_epi32(_mm_loadu_si32((__m128i const*)(params.a + params.lda * (ii + mi) + ao + 4 + 4 + 4 + 4 + 4 + 4)));
                        __m128i int_va7 = _mm_cvtepi8_epi32(_mm_loadu_si32((__m128i const*)(params.a + params.lda * (ii + mi) + ao + 4 + 4 + 4 + 4 + 4 + 4 + 4)));

                        // Load 8 bytes into a 128-bit integer register
                        __m128i int_vb0 = _mm_loadu_si32((__m128i const*)(params.b + params.ldb * (jj + ni) + bo));
                        __m128i int_vb1 = _mm_loadu_si32((__m128i const*)(params.b + params.ldb * (jj + ni) + bo + 4));
                        __m128i int_vb2 = _mm_loadu_si32((__m128i const*)(params.b + params.ldb * (jj + ni) + bo + 4 + 4));
                        __m128i int_vb3 = _mm_loadu_si32((__m128i const*)(params.b + params.ldb * (jj + ni) + bo + 4 + 4 + 4));

                        // Masked values: low nibble of each packed byte.
                        __m128i first_4bits0 = _mm_and_si128(int_vb0, mask_first_4bits);
                        __m128i first_4bits1 = _mm_and_si128(int_vb1, mask_first_4bits);
                        __m128i first_4bits2 = _mm_and_si128(int_vb2, mask_first_4bits);
                        __m128i first_4bits3 = _mm_and_si128(int_vb3, mask_first_4bits);

                        // Shift first 4 bits to rightmost positions (high nibble).
                        __m128i last_4bits0 = _mm_srli_epi16(int_vb0, 4);
                        __m128i last_4bits1 = _mm_srli_epi16(int_vb1, 4);
                        __m128i last_4bits2 = _mm_srli_epi16(int_vb2, 4);
                        __m128i last_4bits3 = _mm_srli_epi16(int_vb3, 4);

                        // Re-mask after the 16-bit shift so bits from the
                        // neighboring byte don't leak in.
                        last_4bits0 = _mm_and_si128(last_4bits0, mask_first_4bits);
                        last_4bits1 = _mm_and_si128(last_4bits1, mask_first_4bits);
                        last_4bits2 = _mm_and_si128(last_4bits2, mask_first_4bits);
                        last_4bits3 = _mm_and_si128(last_4bits3, mask_first_4bits);

                        //Subtract 8 from each int to recenter [0,15] -> [-8,7]
                        first_4bits0 = _mm_sub_epi8(first_4bits0, eight);
                        first_4bits1 = _mm_sub_epi8(first_4bits1, eight);
                        first_4bits2 = _mm_sub_epi8(first_4bits2, eight);
                        first_4bits3 = _mm_sub_epi8(first_4bits3, eight);

                        last_4bits0 = _mm_sub_epi8(last_4bits0, eight);
                        last_4bits1 = _mm_sub_epi8(last_4bits1, eight);
                        last_4bits2 = _mm_sub_epi8(last_4bits2, eight);
                        last_4bits3 = _mm_sub_epi8(last_4bits3, eight);

                        // Extend bytes to 32-bit integers
                        __m128i int_vb_ext_lo0 = _mm_cvtepi8_epi32(first_4bits0);
                        __m128i int_vb_ext_lo1 = _mm_cvtepi8_epi32(first_4bits1);
                        __m128i int_vb_ext_lo2 = _mm_cvtepi8_epi32(first_4bits2);
                        __m128i int_vb_ext_lo3 = _mm_cvtepi8_epi32(first_4bits3);

                        __m128i int_vb_ext_hi0 = _mm_cvtepi8_epi32(last_4bits0);
                        __m128i int_vb_ext_hi1 = _mm_cvtepi8_epi32(last_4bits1);
                        __m128i int_vb_ext_hi2 = _mm_cvtepi8_epi32(last_4bits2);
                        __m128i int_vb_ext_hi3 = _mm_cvtepi8_epi32(last_4bits3);

                        // Integer dot product of all 32 a-values against the
                        // 32 unpacked b-nibbles, accumulated as 4 x int32.
                        __m128i isum = _mm_mullo_epi32(int_va0, int_vb_ext_lo0);
                        isum = _mm_add_epi32(isum, _mm_mullo_epi32(int_va1, int_vb_ext_lo1));
                        isum = _mm_add_epi32(isum, _mm_mullo_epi32(int_va2, int_vb_ext_lo2));
                        isum = _mm_add_epi32(isum, _mm_mullo_epi32(int_va3, int_vb_ext_lo3));

                        isum = _mm_add_epi32(isum, _mm_mullo_epi32(int_va4, int_vb_ext_hi0));
                        isum = _mm_add_epi32(isum, _mm_mullo_epi32(int_va5, int_vb_ext_hi1));
                        isum = _mm_add_epi32(isum, _mm_mullo_epi32(int_va6, int_vb_ext_hi2));
                        isum = _mm_add_epi32(isum, _mm_mullo_epi32(int_va7, int_vb_ext_hi3));

                        // broadcast the float32 version of 'factor' to all elements
                        __m128 vb_f32 = _mm_set1_ps(scalef[j]);
                        // Convert these 32-bit integers to floats
                        __m128 fsum = _mm_cvtepi32_ps(isum);
                        // Scale the block dot product and accumulate into the tile.
                        sums[mi][ni] = _mm_add_ps(sums[mi][ni], _mm_mul_ps(fsum, vb_f32));
                    }
                }
            }
        }


        // Horizontal reduction: collapse each 4-lane accumulator to a
        // scalar and store it in the output matrix.
        for (int mi = 0; mi < RM; ++mi) {
            for (int ni = 0; ni < RN; ++ni) {
                __attribute__((aligned(16))) float result[4];
                _mm_store_ps(result, sums[mi][ni]);

                float dot = 0.0;
                for(int i = 0; i < 4; ++i) {
                    dot += result[i];
                }
                //fprintf(stderr, "ii: %d, ni: %d, jj: %d, mi: %d, ldc: %d\n", ii, ni, jj, mi, params.ldc);
                params.r[(params.ldc * (ii + mi)) + (jj + ni) - params.roffset] = dot;
            }
        }
    }
}

void __attribute__((noinline)) gemm_q8_q4_256(int m0, int m, int n0, int n, int RM, int RN, struct gemm_params params) {
int ytiles = (m - m0) / RM;
int xtiles = (n - n0) / RN;
Expand All @@ -382,7 +245,7 @@ void __attribute__((noinline)) gemm_q8_q4_256(int m0, int m, int n0, int n, int
int ii = m0 + job / xtiles * RM;
int jj = n0 + job % xtiles * RN;

__m256 sums[RM][RN] __attribute__((aligned(32)));
__m256 sums[RM][RN];

//Reset the sums to zero for this tile
for (int i = 0; i < RM; i++) {
Expand Down Expand Up @@ -470,7 +333,7 @@ void __attribute__((noinline)) gemm_q8_q4_512(int m0, int m, int n0, int n, int
int ii = m0 + job / xtiles * RM;
int jj = n0 + job % xtiles * RN;

__m256 sums[RM][RN] __attribute__((aligned(32)));
__m256 sums[RM][RN];

//Reset the sums to zero for this tile
for (int i = 0; i < RM; i++) {
Expand Down Expand Up @@ -591,7 +454,7 @@ void gemm_f32_128_arm(int m0, int m, int n0, int n, int RM, int RN, struct gemm_
int tiles = xtiles * ytiles;

// This fits on the stack (max of 5x5)
float32x4_t sums[RM][RN] __attribute__((aligned(32)));
float32x4_t sums[RM][RN];

for (int job = 0; job < tiles; ++job) {
int ii = m0 + job / xtiles * RM;
Expand Down Expand Up @@ -636,7 +499,7 @@ void gemm_f32_256(int m0, int m, int n0, int n, int RM, int RN, struct gemm_para
int tiles = xtiles * ytiles;

// This fits on the stack (max of 5x5)
__m256 sums[RM][RN] __attribute__((aligned(32)));
__m256 sums[RM][RN];

for (int job = 0; job < tiles; ++job) {
int ii = m0 + job / xtiles * RM;
Expand Down Expand Up @@ -688,7 +551,7 @@ void gemm_f32_512(int m0, int m, int n0, int n, int RM, int RN, struct gemm_para
int tiles = xtiles * ytiles;

// This fits on the stack (max of 5x5)
__m512 sums[RM][RN] __attribute__((aligned(32)));
__m512 sums[RM][RN];

for (int job = 0; job < tiles; ++job) {
int ii = m0 + job / xtiles * RM;
Expand Down

0 comments on commit 83b9f86

Please sign in to comment.