Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,21 @@
import shlex
import glob
import sys
import os
import platform

extensions = [Extension("geometry", sources=["src_c/geometry.c"])]

compiler_options = {"unix": ["-mavx2"], "msvc": ["/arch:AVX2"]}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not necessary. Since you're using intrinsics, the compiler isn't emitting the symbols, you are. With these options, the project will not run at all on older systems, I believe.


# Flag dialect: MSVC-style switches on Windows, GCC/Clang-style elsewhere.
compiler_type = "msvc" if os.name == "nt" else "unix"

# AVX2 is an x86-only instruction set. Passing its flag to a compiler that
# targets another architecture (e.g. ARM) is rejected or ignored, so the flag
# is only requested on x86 machines. Without it `__AVX2__` is not defined and
# C code guarded by `#ifdef __AVX2__` compiles its non-AVX branch instead.
_X86_MACHINES = {"x86_64", "amd64", "i386", "i686", "x86"}
_simd_args = (
    compiler_options[compiler_type]
    if platform.machine().lower() in _X86_MACHINES
    else []
)

extensions = [
    Extension(
        "geometry",
        sources=["src_c/geometry.c"],
        extra_compile_args=_simd_args,
    )
]


def build() -> None:
Expand Down
112 changes: 96 additions & 16 deletions src_c/collisions.c
Original file line number Diff line number Diff line change
@@ -1,26 +1,14 @@
#ifdef __AVX2__
#include <immintrin.h>
#endif

#include "include/collisions.h"
#include <stdio.h>

#ifndef ABS
#define ABS(x) ((x) < 0 ? -(x) : (x))
#endif /* ~ABS */
#ifndef DOT2D
#define DOT2D(X0, Y0, X1, Y1) ((X0) * (X1) + (Y0) * (Y1))
#endif /* ~DOT2D */

#ifndef CODE_BOTTOM
#define CODE_BOTTOM 1
#endif /* CODE_BOTTOM */
#ifndef CODE_TOP
#define CODE_TOP 2
#endif /* CODE_TOP */
#ifndef CODE_LEFT
#define CODE_LEFT 4
#endif /* CODE_LEFT */
#ifndef CODE_RIGHT
#define CODE_RIGHT 8
#endif /* CODE_RIGHT */

static int
pgCollision_LineLine(pgLineBase *A, pgLineBase *B)
{
Expand Down Expand Up @@ -204,6 +192,97 @@ static int
pgIntersection_LineRect(pgLineBase *line, SDL_Rect *rect, double *X, double *Y,
double *T)
{
#ifdef __AVX2__
// this function does 4 line-line collisions at once
double Rx = (double)rect->x;
double Ry = (double)rect->y;
double Rw = (double)rect->w;
double Rh = (double)rect->h;

// here we start to setup the variables
__m256d x1_256d = _mm256_set1_pd(line->x1);
__m256d y1_256d = _mm256_set1_pd(line->y1);
__m256d x2_256d = _mm256_set1_pd(line->x2);
__m256d y2_256d = _mm256_set1_pd(line->y2);
__m256d x3_256d = _mm256_set_pd(Rx, Rx, Rx, Rx + Rw);
__m256d y3_256d = _mm256_set_pd(Ry, Ry, Ry + Rh, Ry);
__m256d x4_256d = _mm256_set_pd(Rx + Rw, Rx, Rx + Rw, Rx + Rw);
__m256d y4_256d = _mm256_set_pd(Ry, Ry + Rh, Ry + Rh, Ry + Rh);

// here we calculate the differences between the coords of the points
__m256d x1_m_x2_256d = _mm256_sub_pd(x1_256d, x2_256d);
__m256d y3_m_y4_256d = _mm256_sub_pd(y3_256d, y4_256d);
__m256d y1_m_y2_256d = _mm256_sub_pd(y1_256d, y2_256d);
__m256d x3_m_x4_256d = _mm256_sub_pd(x3_256d, x4_256d);

// we calculate the denominator of the equations
__m256d den_256d =
_mm256_sub_pd(_mm256_mul_pd(x1_m_x2_256d, y3_m_y4_256d),
_mm256_mul_pd(y1_m_y2_256d, x3_m_x4_256d));

// a denominator of 0 means the line is parallel to that side of the
// rectangle; for a non-degenerate line this cannot hold for all four sides
// at once (it only does when both endpoints of the line coincide)
__m256d den_zero_256d =
_mm256_cmp_pd(den_256d, _mm256_setzero_pd(), _CMP_EQ_OQ);

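// we don't want to divide by 0, so lanes whose denominator is 0 are OR-ed
// with the all-ones comparison mask; the resulting all-bits-set value is a
// NaN (not 1), so those lanes fail every ordered comparison below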
den_256d = _mm256_or_pd(den_zero_256d, den_256d);

// we calculate the rest of the differences between the coords of the
// points
__m256d x1_m_x3_256d = _mm256_sub_pd(x1_256d, x3_256d);
__m256d y1_m_y3_256d = _mm256_sub_pd(y1_256d, y3_256d);

// calculate the t values
__m256d t_256d = _mm256_sub_pd(_mm256_mul_pd(x1_m_x3_256d, y3_m_y4_256d),
_mm256_mul_pd(y1_m_y3_256d, x3_m_x4_256d));
t_256d = _mm256_div_pd(t_256d, den_256d);

// calculate the u values
__m256d u_256d = _mm256_sub_pd(_mm256_mul_pd(x1_m_x2_256d, y1_m_y3_256d),
_mm256_mul_pd(y1_m_y2_256d, x1_m_x3_256d));
u_256d =
_mm256_mul_pd(_mm256_div_pd(u_256d, den_256d), _mm256_set1_pd(-1.0));

// we check this condition t >= 0 && t <= 1 && u >= 0 && u <= 1
__m256d ones_256d = _mm256_set1_pd(1.0);
__m256d zeros_256d = _mm256_set1_pd(0.0);
__m256d t_zero_256d = _mm256_cmp_pd(t_256d, zeros_256d, _CMP_GE_OQ);
__m256d t_one_256d = _mm256_cmp_pd(t_256d, ones_256d, _CMP_LE_OQ);
__m256d u_zero_256d = _mm256_cmp_pd(u_256d, zeros_256d, _CMP_GE_OQ);
__m256d u_one_256d = _mm256_cmp_pd(u_256d, ones_256d, _CMP_LE_OQ);
__m256d t_u_256d = _mm256_and_pd(_mm256_and_pd(t_zero_256d, t_one_256d),
_mm256_and_pd(u_zero_256d, u_one_256d));

// if no lines touch the rectangle then this will be true
if (_mm256_movemask_pd(t_u_256d) == 0x0) {
return 0;
}

double t = DBL_MAX;

// here we know that there is at least one intersection so
// we search for the smallest t value that still meets the above conditions
int i = 0;
for (i = 0; i < 4; i++) {
if (((double *)&t_u_256d)[i]) {
t = MIN(t, ((double *)&t_256d)[i]);
}
}

// outputs
if (T)
*T = t;
if (X)
*X = line->x1 + t * (line->x2 - line->x1);
if (Y)
*Y = line->y1 + t * (line->y2 - line->y1);

return 1;
#else

double x = (double)rect->x;
double y = (double)rect->y;
double w = (double)rect->w;
Expand Down Expand Up @@ -238,6 +317,7 @@ pgIntersection_LineRect(pgLineBase *line, SDL_Rect *rect, double *X, double *Y,
}

return ret;
#endif /* ~__AVX2__ */
}

static int
Expand Down