In [1]:
from tinygrad import TinyJit
import time
import os
os.environ["METAL"] = "1"
os.environ["NOOPT"] = "1"
os.environ["METAL_XCODE"] = "1"
os.environ["TINYGRAD_DEBUG"] = "5"
os.environ["DISABLE_COMPILER_CACHE"] = "1"
import numpy as np
import mlx.core as mx
from tinygrad import Tensor, dtypes, TinyJit
from tinygrad.helpers import Timing, Context
import mlx.core as mx
from mlx import nn as mlx_nn
from tinygrad import nn
import math

In [None]:
class MLXQuantizedLinear:
  def __init__(self, in_features, out_features, bits=4, group_size=64, bias=False):
    self.weight = Tensor.randint((4096, 512), low=0, high=9, dtype=dtypes.uint32).realize()
    self.scales = Tensor.rand(4096, 64, dtype=dtypes.half).realize()
    self.biases = Tensor.rand(4096, 64, dtype=dtypes.half).realize()
    self.bits = bits
    self.group_size = group_size

  def __call__(self, x):
    w_full = Tensor.cat(
        *[select_bits(self.weight, self.bits, i)[..., None] for i in range(0, 32, self.bits)], dim=-1
    )
    print(w_full.shape)
    w_full = w_full.reshape(len(self.weight), self.scales.shape[-1], -1)
    w_full = self.scales[..., None] * w_full + self.biases[..., None]
    return x.linear(w_full.reshape(len(self.weight), -1).T)

def select_bits(w, bits, start):
    shift_left = 32 - (start + bits)
    shift_right = shift_left + start
    return (w * (2**shift_left)) // (2**shift_right)

In [2]:
y = Tensor.rand(1, 1, 4096)
# tiny = nn.Linear(4096, 4096, bias=False)
# with Context(DEBUG=4, NOOPT=1):
#     ll = tiny(y).realize()

In [3]:
weight = Tensor.uniform(4096, 4096)
with Context(DEBUG=4, NOOPT=1):
    ll = y.matmul(weight).realize()

opened device METAL from pid:55515
[32m*** CUSTOM     1[0m custom_random                            arg  1 mem  0.00 GB 
[32m*** CUSTOM     2[0m custom_random                            arg  1 mem  0.07 GB 
r_[34m4096[0m[90m_[0m[31m4096[0m[90m[0m
UOp(UOps.SINK, dtypes.void, arg=KernelInfo(local_dims=0, upcasted=0, dont_use_locals=False), src=(
  UOp(UOps.STORE, dtypes.void, arg=None, src=(
    UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()),
    UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
    UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
      UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
        UOp(UOps.LOAD, dtypes.float, arg=None, src=(
          UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()),
          UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 4096), strides=(0, 1), offs

In [4]:
i, j, k = y.shape
k, m = weight.shape
C = Tensor.zeros(i*j*m).contiguous()
with Context(DEBUG=4, NOOPT=1):
    for i_idx in range(i):
        for j_idx in range(j):
            for m_idx in range(m):
                for k_idx in range(k):
                    C[i_idx * j * m + j_idx * m + m_idx] += y[i_idx][j_idx][k_idx] * weight[k_idx][m_idx]
    C.realize()

assign <LB METAL () float ShapeTracker(views=(View(shape=(), strides=(), offset=0, mask=None, contiguous=True),))> <- <LB METAL () float (<BinaryOps.ADD: 8>, None)>
E_[34m4096[0m[90m[0m
UOp(UOps.SINK, dtypes.void, arg=KernelInfo(local_dims=0, upcasted=0, dont_use_locals=False), src=(
  UOp(UOps.STORE, dtypes.void, arg=None, src=(
    UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()),
    UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096,), strides=(1,), offset=0, mask=None, contiguous=True),)), src=()),
    UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
      UOp(UOps.VALID, dtypes.bool, arg=None, src=(
        UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096,), strides=(0,), offset=0, mask=None, contiguous=False),)), src=()),)),
      x6:=UOp(UOps.CONST, dtypes.float, arg=0.0, src=()),
       x6,)),)),))
[]
#include <metal_stdlib>
using namespace metal;
kernel void E_4096(device float* data0, uint3 gid [[

error: write on a pipe with no reader


KeyboardInterrupt: 

In [None]:
x = Tensor.rand(1, 1, 4096)
mlx = MLXQuantizedLinear(4096, 4096)
with Context(DEBUG=4, NOOPT=1):
    ll = mlx(x).realize()

In [None]:
ll = Tensor.zeros((10, 10))
ll.contiguous()

In [None]:
ll.lazydata.lbs[0].contiguous

In [5]:
weight.shape

(4096, 4096)

In [7]:
a = Tensor([[0, 1], [2, 3]])
b = Tensor([[0, 1], [2, 3]])
a.matmul(b).numpy()

array([[ 2,  3],
       [ 6, 11]], dtype=int32)

In [16]:
a.reshape(*(),*[1],*[2, 2]).numpy()

array([[[0, 1],
        [2, 3]]], dtype=int32)

In [21]:
a = a.reshape(2, 1, 2)
b = b.reshape(1, 2, 2).transpose(-1, -2)
(a*b).sum(-1).numpy()

array([[ 2,  3],
       [ 6, 11]], dtype=int32)