
[Bug]: unexpected SPU error #766

Closed · linzzzzzz opened this issue Jul 12, 2024 · 2 comments · Fixed by #768

Comments

@linzzzzzz

Issue Type

Usability

Modules Involved

SPU runtime

Have you reproduced the bug with SPU HEAD?

Yes

Have you searched existing issues?

Yes

SPU Version

spu 0.9.0b1

OS Platform and Distribution

Linux

Python Version

3.10

Compiler Version

No response

Current Behavior?

Why does the plaintext calculation succeed while the SPU simulation fails? I don't think I'm performing any sophisticated calculation.

Standalone code to reproduce the issue

# imports inferred from the snippet below (pps refers to spu.utils.simulation, per the traceback)
import numpy as np
import jax
import jax.numpy as jnp
import spu
import spu.utils.simulation as pps

A = np.array([[-4.16757847e-01, -5.62668272e-02, -2.13619610e+00,
         1.64027081e+00, -1.79343559e+00, -8.41747366e-01,
         5.02881417e-01, -1.24528809e+00, -1.05795222e+00,
        -9.09007615e-01],
       [ 5.51454045e-01,  2.29220801e+00,  4.15393930e-02,
        -1.11792545e+00,  5.39058321e-01, -5.96159700e-01,
        -1.91304965e-02,  1.17500122e+00, -7.47870949e-01,
         9.02525097e-03],
       [-8.78107893e-01, -1.56434170e-01,  2.56570452e-01,
        -9.88779049e-01, -3.38821966e-01, -2.36184031e-01,
        -6.37655012e-01, -1.18761229e+00, -1.42121723e+00,
        -1.53495196e-01],
       [-2.69056960e-01,  2.23136679e+00, -2.43476758e+00,
         1.12726505e-01,  3.70444537e-01,  1.35963386e+00,
         5.01857207e-01, -8.44213704e-01,  9.76147160e-06,
         5.42352572e-01]])

B = np.array([[-0.3135082 ,  0.77101174, -1.86809065,  1.73118467,  1.46767801,
        -0.33567734,  0.61134078,  0.04797059, -0.82913529,  0.08771022],
       [ 1.00036589, -0.38109252, -0.37566942, -0.07447076,  0.43349633,
         1.27837923, -0.63467931,  0.50839624,  0.21611601, -1.85861239],
       [-0.41931648, -0.1323289 , -0.03957024,  0.32600343, -2.04032305,
         0.04625552, -0.67767558, -1.43943903,  0.52429643,  0.73527958],
       [-0.65325027,  0.84245628, -0.38151648,  0.06648901, -1.09873895,
         1.58448706, -2.65944946, -0.09145262,  0.69511961, -2.03346655]])

C = np.array([[-0.18946926, -0.07721867,  0.82470301,  1.24821292, -0.40389227,
        -1.38451867,  1.36723542,  1.21788563, -0.46200535,  0.35088849],
       [ 0.38186623,  0.56627544,  0.20420798,  1.40669624, -1.7379595 ,
         1.04082395,  0.38047197, -0.21713527,  1.1735315 , -2.34360319],
       [ 1.16152149,  0.38607805, -1.13313327,  0.43309255, -0.30408644,
         2.58529487,  1.83533272,  0.44068987, -0.71925384, -0.58341459],
       [-0.32504963, -0.56023451, -0.90224607, -0.59097228, -0.27617949,
        -0.51688389, -0.69858995, -0.92889192,  2.55043824, -1.47317325]])

D = np.array([[-1.02141473,  0.4323957 , -0.32358007,  0.42382471,  0.79918   ,
         1.26261366,  0.75196485, -0.99376098,  1.10914328, -1.76491773],
       [-0.1144213 , -0.49817419, -1.06079904,  0.59166652, -0.18325657,
         1.01985473, -1.48246548,  0.84631189,  0.49794015,  0.12650418],
       [-1.41881055, -0.25177412, -1.54667461, -2.08265194,  3.2797454 ,
         0.97086132,  1.79259285, -0.42901332,  0.69619798,  0.69741627],
       [ 0.60151581,  0.00365949, -0.22824756, -2.06961226,  0.61014409,
         0.4234969 ,  1.11788673, -0.27424209,  1.74181219, -0.44750088]])

A_1 = np.square(A)
B_1 = B + 0.1
C_1 = np.square(C)
D_1 = D + 0.1

s_0, s_1 = A_1.shape
I = jnp.tile(jnp.arange(s_1), (s_0,1))
Z = jnp.zeros(A_1.shape, dtype=int)


def my_compare(x1, x2):

    A = x1[3]*x2[3]*(x1[0]*x2[1]-x2[0]*x1[1])+x1[1]*x2[1]*(x1[2]*x2[3]-x2[2]*x1[3])
    B = x1[1]*x1[3]*x2[1]*x2[3]
    A_sign = A > 0
    B_sign = B > 0
    z = jnp.logical_xor(A_sign, B_sign)

    zz_0 = jax.lax.select(z, x2[0], x1[0])
    zz_1 = jax.lax.select(z, x2[1], x1[1])
    zz_2 = jax.lax.select(z, x2[2], x1[2])
    zz_3 = jax.lax.select(z, x2[3], x1[3])
    zz_4 = jax.lax.select(z, x2[4], x1[4])
    zz_5 = jax.lax.select(z, x2[5], x1[5])

    return [zz_0,zz_1,zz_2,zz_3,zz_4,zz_5]


fn = lambda a,b,c,d,e,f: jax.lax.reduce([a,b,c,d,e,f], [0.1,10000.0,0.1,10000.0,0,0], my_compare, [0])


### plaintext calculation
res = fn(A_1[0], B_1[0], C_1[0], D_1[0], I[0], Z[0])

res


### SPU simulation

config = spu.RuntimeConfig(
    protocol=spu.spu_pb2.ProtocolKind.CHEETAH,
    field=spu.spu_pb2.FieldType.FM128, 
    fxp_fraction_bits=40,
)

simulator = pps.Simulator(2, config)
spu_argmax = pps.sim_jax(simulator, fn)


res = spu_argmax(A_1[0], B_1[0], C_1[0], D_1[0], I[0], Z[0])

res
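
# Optional debugging aid, not part of the original repro: a minimal sketch that
# assumes the sim_jax wrapper keeps the compiled PPHLO text on its pphlo attribute
# (as the wrapper.pphlo assignment visible in the traceback below suggests).
# Compilation happens before execution, so the text should be available even when
# the runtime call raises.
try:
    spu_argmax(A_1[0], B_1[0], C_1[0], D_1[0], I[0], Z[0])
except RuntimeError:
    print(spu_argmax.pphlo)
    raise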

Relevant log output

### plaintext calculation
[Array(1.5507424, dtype=float32),
 Array(0.14797059, dtype=float32),
 Array(1.4832454, dtype=float32),
 Array(-0.893761, dtype=float32),
 Array(7, dtype=int32),
 Array(0, dtype=int32)]



### SPU simulation
RuntimeError                              Traceback (most recent call last)
Cell In[49], line 13
      9 simulator = pps.Simulator(2, config)
     10 spu_argmax = pps.sim_jax(simulator, fn)
---> 13 z = spu_argmax(A_1[0], B_1[0], C_1[0], D_1[0], I[0], Z[0])
     15 z

File ~/.conda/envs/sf310/lib/python3.10/site-packages/spu/utils/simulation.py:168, in sim_jax.<locals>.wrapper(*args, **kwargs)
    154 executable, output = spu_fe.compile(
    155     spu_fe.Kind.JAX,
    156     fun,
   (...)
    163     copts=copts,
    164 )
    166 wrapper.pphlo = executable.code.decode("utf-8")
--> 168 out_flat = sim(executable, *args_flat)
    170 _, output_tree = jax.tree_util.tree_flatten(output)
    172 return jax.tree_util.tree_unflatten(output_tree, out_flat)

File ~/.conda/envs/sf310/lib/python3.10/site-packages/spu/utils/simulation.py:116, in Simulator.__call__(self, executable, *flat_args)
    110 jobs = [
    111     PropagatingThread(target=wrapper, args=(rank,))
    112     for rank in range(self.wsize)
    113 ]
    115 [job.start() for job in jobs]
--> 116 parties = [job.join() for job in jobs]
    118 outputs = zip(*parties)
    119 return [self.io.reconstruct(out) for out in outputs]

File ~/.conda/envs/sf310/lib/python3.10/site-packages/spu/utils/simulation.py:116, in <listcomp>(.0)
    110 jobs = [
    111     PropagatingThread(target=wrapper, args=(rank,))
    112     for rank in range(self.wsize)
    113 ]
    115 [job.start() for job in jobs]
--> 116 parties = [job.join() for job in jobs]
    118 outputs = zip(*parties)
    119 return [self.io.reconstruct(out) for out in outputs]

File ~/.conda/envs/sf310/lib/python3.10/site-packages/spu/utils/simulation.py:43, in PropagatingThread.join(self)
     41 super(PropagatingThread, self).join()
     42 if self.exc:
---> 43     raise self.exc
     44 return self.ret

File ~/.conda/envs/sf310/lib/python3.10/site-packages/spu/utils/simulation.py:36, in PropagatingThread.run(self)
     34 self.exc = None
     35 try:
---> 36     self.ret = self._target(*self._args, **self._kwargs)
     37 except BaseException as e:
     38     self.exc = e

File ~/.conda/envs/sf310/lib/python3.10/site-packages/spu/utils/simulation.py:105, in Simulator.__call__.<locals>.wrapper(rank)
    102     rt.set_var(executable.input_names[idx], param[rank])
    104 # run
--> 105 rt.run(executable)
    107 # do outfeed
    108 return [rt.get_var(name) for name in executable.output_names]

File ~/.conda/envs/sf310/lib/python3.10/site-packages/spu/api.py:44, in Runtime.run(self, executable)
     37 def run(self, executable: spu_pb2.ExecutableProto) -> None:
     38     """Run an SPU executable.
     39 
     40     Args:
     41         executable (spu_pb2.ExecutableProto): executable.
     42 
     43     """
---> 44     return self._vm.Run(executable.SerializeToString())

RuntimeError: what: 
	[Enforce fail at libspu/kernel/hal/polymorphic.cc:195] (x.shape() == y.shape()). 
Stacktrace:
#0 spu::kernel::hlo::Greater()+0x7fe3b6c1c37b
#1 spu::device::pphlo::dispatchOp<>()+0x7fe3b6532660
#2 spu::device::pphlo::dispatchOp<>()+0x7fe3b653379a
#3 spu::device::pphlo::dispatchOp<>()+0x7fe3b6536b45
#4 spu::device::pphlo::dispatchOp<>()+0x7fe3b6537496
#5 spu::device::pphlo::dispatchOp<>()+0x7fe3b65392ed
#6 spu::device::pphlo::dispatchOp<>()+0x7fe3b653b8f4
#7 spu::device::runBlock()+0x7fe3b6693c25
#8 spu::device::runRegion()+0x7fe3b6695cb3
#9 std::_Function_handler<>::_M_invoke()+0x7fe3b651d57d
#10 spu::kernel::hlo::TreeReduce()+0x7fe3b6c400a3
#11 spu::kernel::hlo::Reduce()+0x7fe3b6c422dd
@tpppppub (Member)

Thanks for reporting. Will be fixed later.

@linzzzzzz (Author) commented Jul 12, 2024

Sounds good.

I found a temporary workaround: change the sign calculation from A_sign = A > 0 and B_sign = B > 0 to A_sign = A > x1[5] and B_sign = B > x1[5], where x1[5] is initialized to always be 0 (see the modified my_compare below).

def my_compare(x1, x2):

    A = x1[3]*x2[3]*(x1[0]*x2[1]-x2[0]*x1[1])+x1[1]*x2[1]*(x1[2]*x2[3]-x2[2]*x1[3])
    B = x1[1]*x1[3]*x2[1]*x2[3]
    A_sign = A > x1[5]
    B_sign = B > x1[5]
    z = jnp.logical_xor(A_sign, B_sign)

    zz_0 = jax.lax.select(z, x2[0], x1[0])
    zz_1 = jax.lax.select(z, x2[1], x1[1])
    zz_2 = jax.lax.select(z, x2[2], x1[2])
    zz_3 = jax.lax.select(z, x2[3], x1[3])
    zz_4 = jax.lax.select(z, x2[4], x1[4])
    zz_5 = jax.lax.select(z, x2[5], x1[5])

    return [zz_0,zz_1,zz_2,zz_3,zz_4,zz_5]
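
For reference, a usage sketch of the workaround (reusing the arrays, the reduce setup, and the simulator from the repro above; not independently re-verified here):

# same reduce as in the repro, now picking up the modified my_compare
fn = lambda a,b,c,d,e,f: jax.lax.reduce([a,b,c,d,e,f], [0.1,10000.0,0.1,10000.0,0,0], my_compare, [0])
spu_argmax = pps.sim_jax(simulator, fn)
# per the comment above, this no longer hits the shape-mismatch error
res = spu_argmax(A_1[0], B_1[0], C_1[0], D_1[0], I[0], Z[0])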
