# NA in cuDF UDFs
cuDF's design decision to store null information in bitmasks is really smart: it keeps things performant and tractable in memory-bound settings such as GPU operations. However, this design, coupled with the natural inefficiency of any kind of serial iteration over our data, has made `<NA>` support in general user-defined functions (UDFs) hard to solve. This notebook offers an approach based on JIT-compiling a UDF with its arguments typed as a special custom Numba type, producing a generic PTX function. This function is then inlined into a general kernel in libcudf, which passes it the relevant data and masks.
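To make the bitmask idea concrete, here is a minimal pure-Python sketch of bit-packed validity. It assumes the Arrow-style convention that cuDF follows (one bit per row, least-significant bit first, 1 meaning valid). The helpers `pack_validity` and `bit_is_set` are illustrative names that operate on plain Python lists, not on cuDF's device buffers.

```python
def pack_validity(values, word_bits=32):
    """Pack one validity bit per element (1 = valid) into integer words."""
    n_words = (len(values) + word_bits - 1) // word_bits
    words = [0] * n_words
    for i, v in enumerate(values):
        if v is not None:
            words[i // word_bits] |= 1 << (i % word_bits)
    return words


def bit_is_set(words, i, word_bits=32):
    """Return True if row i is valid according to the packed words."""
    return bool((words[i // word_bits] >> (i % word_bits)) & 1)


x = [1, None, 3]
mask = pack_validity(x)
print(mask)                                          # [5], i.e. 0b101
print([bit_is_set(mask, i) for i in range(len(x))])  # [True, False, True]
```

Three rows cost three bits rather than three bytes, which is why the representation is attractive on memory-bound hardware, and also why a per-row "is this null?" check inside a UDF needs extra plumbing.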

#### Problem setup: concrete example

In [1]:
import pandas as pd
import numpy as np
import cudf

In [2]:
df = cudf.DataFrame({
    'x': [1, None, 3],
    'y': [1, 2, None]
})
df.head()

Unnamed: 0,x,y
0,1.0,1.0
1,,2.0
2,3.0,


Consider the following UDF of two variables, adapted from https://docs.rapids.ai/api/cudf/stable/guide-to-udfs.html. The cuDF API for it differs from the pandas API in several ways:
- In cuDF, we need to write a loop over arrays in classic Numba syntax
- In cuDF, the function writes into an output column that we provide as an argument
- The result is different!

In [3]:
def pandas_add(x, y):
    if x is not pd.NA and x < 2:
        return x + y
    else:
        return x
    
pandas_df = df.to_pandas(nullable=True)
pandas_df['out'] = pandas_df.apply(lambda row: pandas_add(row['x'], row['y']), axis=1)
pandas_df.head()

Unnamed: 0,x,y,out
0,1.0,1.0,2.0
1,,2.0,
2,3.0,,3.0


In [4]:
def gpu_add(x, y, out):
    for i, (xi, yi) in enumerate(zip(x, y)):
        if xi < 2:
            out[i] = xi + yi
        else:
            out[i] = xi

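Problem: the null mask of `y` only needs to be considered when `x < 2`. But `apply_rows` marks an output row as null whenever any input in that row is null, so row 2 comes back as null instead of `3.0`: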

In [5]:
df = df.apply_rows(gpu_add,
              incols=['x', 'y'],
              outcols={'out':np.float64},
              kwargs={})
df.head()

Unnamed: 0,x,y,out
0,1.0,1.0,2.0
1,,2.0,
2,3.0,,


We also don't support comparing against `cudf.NA` (or `pd.NA`, as below) in any of our UDFs, in any way.

In [6]:
def gpu_add_error(x, y, out):
    for i, (xi, yi) in enumerate(zip(x, y)):
        if xi is pd.NA:
            return 5
        else:
            return xi + yi

In [7]:
df = df.apply_rows(gpu_add_error,
              incols=['x', 'y'],
              outcols={'out':np.float64},
              kwargs={})
df.head()

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<numba.cuda.compiler.DeviceFunctionTemplate object at 0x7f7cb83dca90>) found for signature:
 
 >>> gpu_add_error <CUDA device function>(array(int64, 1d, A), array(int64, 1d, A), array(float64, 1d, A))
 
There are 2 candidate implementations:
  - Of which 2 did not match due to:
  Overload in function 'gpu_add_error <CUDA device function>': File: ../../../../../../ipynb/<ipython-input-6-e11feefd7c2c>: Line 1.
    With argument(s): '(array(int64, 1d, A), array(int64, 1d, A), array(float64, 1d, A))':
   Rejected as the implementation raised a specific error:
     TypingError: Failed in nopython mode pipeline (step: nopython frontend)
   Unknown attribute 'NA' of type Module(<module 'pandas' from '/home/nfs/brmiller/anaconda3/envs/cudf_dev/lib/python3.7/site-packages/pandas/__init__.py'>)

   File "<ipython-input-6-e11feefd7c2c>", line 3:
   def gpu_add_error(x, y, out):
       <source elided>
       for i, (xi, yi) in enumerate(zip(x, y)):
           if xi is pd.NA:
           ^

   During: typing of get attribute at <ipython-input-6-e11feefd7c2c> (3)

   File "<ipython-input-6-e11feefd7c2c>", line 3:
   def gpu_add_error(x, y, out):
       <source elided>
       for i, (xi, yi) in enumerate(zip(x, y)):
           if xi is pd.NA:
           ^

  raised from /home/nfs/brmiller/anaconda3/envs/cudf_dev/lib/python3.7/site-packages/numba/core/typeinfer.py:1071

During: resolving callee type: Function(<numba.cuda.compiler.DeviceFunctionTemplate object at 0x7f7cb83dca90>)
During: typing of call at <string> (8)

File "<string>", line 8:
<source missing, REPL/exec in use?>


#### Why
This is because nulls are generally handled "pessimistically": an output row is marked null whenever any input in that row is null, so the output validity mask is effectively a big bitwise `and` of all the input columns' bitmasks (equivalently, an `or` over their nulls). This isn't a problem in pandas, because the UDF is applied by looping through the rows and passing each row's values through the UDF individually. When it encounters a null, the value that gets passed is `pd.NA`, which behaves the way it needs to for the function to return the correct value for that row.
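The difference can be reproduced without a GPU. The sketch below is a plain-Python model, not cuDF's implementation: `pessimistic_apply` nulls any row that has a null input before the function ever runs, while `rowwise_apply` hands `None` (standing in for `pd.NA`) to the function and lets it decide. Both helper names are made up for this illustration.

```python
def add_if_small(x, y):
    # Same logic as pandas_add above, with None standing in for pd.NA.
    # pd.NA propagates through `x + y` on its own; None has to be checked.
    if x is not None and x < 2:
        return None if y is None else x + y
    return x


def pessimistic_apply(func, xs, ys):
    """Null the output row if ANY input in that row is null."""
    out = []
    for x, y in zip(xs, ys):
        if x is None or y is None:
            out.append(None)
        else:
            out.append(func(x, y))
    return out


def rowwise_apply(func, xs, ys):
    """Pass nulls through to the function, as pandas' apply does."""
    return [func(x, y) for x, y in zip(xs, ys)]


xs = [1, None, 3]
ys = [1, 2, None]
print(pessimistic_apply(add_if_small, xs, ys))  # [2, None, None]
print(rowwise_apply(add_if_small, xs, ys))      # [2, None, 3]
```

The two results differ only in the last row, which is exactly the row where the cuDF and pandas outputs above disagree.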
#### The two things we want to do then are:
- Make the API feel a little more natural
- Be able to handle nulls explicitly, in a dynamic way

# Detour: The cuDF UnaryOp Compilation Pipeline

```
Python Function -> Numba -> PTX Code -> libcudf parser -> inlineable function -> Jitify -> Execution
                                                                                    |                                         
                                                             data pointers ---------^
                                                             headers ---------------^
                                                             extra kernel code -----^
```

The proposed solution to this problem draws heavily on the existing concepts in cuDF's unaryop machinery. This is a situation where the API feels really natural and is quite compatible with pandas, even though ours is named `applymap` and theirs is named `apply` for some reason.

In [8]:
x = cudf.Series([1, None, 3])

def f(x):
    return x + 1

In [9]:
x.applymap(f)

0       2
1    <NA>
2       4
dtype: int64

In [10]:
x_pd = x.to_pandas(nullable=True)
x_pd

0       1
1    <NA>
2       3
dtype: Int64

In [11]:
x_pd.apply(f)

0       2
1    <NA>
2       4
dtype: object

The null handling here is pretty simple - it's always a copy of the original bitmask. But that's not why we're here. Let's pop the hood. From https://github.com/rapidsai/cudf/blob/branch-0.19/python/cudf/cudf/core/column/numerical.py#L721-L726:

```
def _numeric_column_unaryop(operand: ColumnBase, op: str) -> ColumnBase:
    if callable(op):
        return libcudf.transform.transform(operand, op)

    op = libcudf.unary.UnaryOp[op.upper()]
    return libcudf.unary.unary_operation(operand, op)
```

From here, the Cython `transform` function picks up the callable Python function as well as the `Column` to which it is to be applied. Here's some pseudocode for what happens inside it:

```
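def transform(Column input, op):
    # Numba signature derived from the input column
    signature = get_signature(input)
    # The first element of the result is the PTX source string
    compiled_op = cudautils.compile_udf(op, signature)
    c_str = compiled_op[0].encode('UTF-8')

    # Hand the column and the PTX string to libcudf
    c_output = move(libcudf_transform(input, c_str))

    return Column.from_unique_ptr(move(c_output))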
```

What we have so far then is:
1. Our input column
2. A PTX function compiled by Numba based off the python function, and the type of the argument being passed

`cudautils.compile_udf` calls out to Numba to transform the pure python function into PTX code through LLVM IR and a series of compilation steps. The pipeline inside Numba goes something like:

```
Python function -> python bytecode -> type inference -> lowering -> LLVM IR -> PTX code
```

#### Python bytecode: These are instructions for the python interpreter

In [12]:
import dis
dis.dis(f)

  4           0 LOAD_FAST                0 (x)
              2 LOAD_CONST               1 (1)
              4 BINARY_ADD
              6 RETURN_VALUE


#### Type inference: Assembly level languages only operate in terms of primitive types. 

Thus to generate LLVM IR, Numba needs to know the types of every variable at every point during the function, from arguments to return values. This is one of the reasons `signature` is a required arg to `compile_udf`. We only get this information at runtime, because the user can pass anything into their UDF. When they apply their UDF to a `Series`, it's only at that point Numba can know that the `x` in `f(x)` is of type `int64` for instance - and only then can it actually complete the type inference portion of the process.
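As a rough illustration of why the signature is only available at call time, the toy decorator below specializes a function once per tuple of argument types, the first time each tuple is seen. It is a pure-Python stand-in and not Numba's machinery: `specialize` and its cache are hypothetical, and "compiling" here just records the signature.

```python
def specialize(func):
    """Toy JIT: keep one 'compiled' entry per tuple of argument types."""
    cache = {}

    def wrapper(*args):
        sig = tuple(type(a).__name__ for a in args)
        if sig not in cache:
            # A real JIT would run type inference and lowering here.
            cache[sig] = func
        return cache[sig](*args)

    wrapper.signatures = cache
    return wrapper


@specialize
def f(x):
    return x + 1


f(1)     # first int call: specializes for ('int',)
f(2)     # reuses the ('int',) entry
f(1.5)   # first float call: specializes for ('float',)
print(sorted(f.signatures))  # [('float',), ('int',)]
```

Nothing about `f` itself says whether `x` is an integer or a float; the decision is forced by whatever the caller passes in, which is the same position Numba is in when a UDF is applied to a `Series`.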

#### Lowering
Once type inference is complete and Numba knows the types of all the input, output, and intermediate variables, it combines that with the algorithmic information from the Python function's bytecode and produces LLVM IR in a process called "lowering". LLVM IR is like a platform-independent assembly language. One can compile from LLVM IR to assembly code for any supported platform, including into PTX code for NVIDIA GPUs.


#### What does this mean for us?
It means that what we get out of `cudautils.compile_udf` is an actual string containing a PTX function, compiled by Numba for arguments of the type `input.dtype`. It is important to note that, like the original Python function, this function operates on a single element. It does NOT contain a kernel. In fact, here's exactly what it is:

In [13]:
from cudf.utils.cudautils import compile_udf

In [14]:
from numba.np import numpy_support
numba_type = numpy_support.from_dtype(np.dtype('int64'))
ptx, _ = compile_udf(f, (numba_type,))

print(ptx)

//
// Generated by NVIDIA NVVM Compiler
//
// Compiler Build ID: CL-27506705
// Cuda compilation tools, release 10.2, V10.2.89
// Based on LLVM 3.4svn
//

.version 6.5
.target sm_70
.address_size 64

	// .globl	_ZN8__main__5f$248Ex
.common .global .align 8 .u64 _ZN08NumbaEnv8__main__5f$248Ex;

.visible .func  (.param .b32 func_retval0) _ZN8__main__5f$248Ex(
	.param .b64 _ZN8__main__5f$248Ex_param_0,
	.param .b64 _ZN8__main__5f$248Ex_param_1
)
{
	.reg .b32 	%r<2>;
	.reg .b64 	%rd<4>;


	ld.param.u64 	%rd1, [_ZN8__main__5f$248Ex_param_0];
	ld.param.u64 	%rd2, [_ZN8__main__5f$248Ex_param_1];
	add.s64 	%rd3, %rd2, 1;
	st.u64 	[%rd1], %rd3;
	mov.u32 	%r1, 0;
	st.param.b32	[func_retval0+0], %r1;
	ret;
}


 


```
// .globl main

.visible .func  (.param .b32 return_value) main(
	.param .b64 param_0,    // pointer to where the result is written
	.param .b64 param_1     // the input value x
)
{
	.reg .b32 	%r<2>;      // declare two 32-bit registers, %r0 and %r1
	.reg .b64 	%rd<4>;     // declare four 64-bit registers, %rd0 through %rd3


	ld.param.u64 	%rd1, [param_0];       // load param_0 (the output pointer) into %rd1
	ld.param.u64 	%rd2, [param_1];       // load param_1 (the input value) into %rd2
	add.s64 	%rd3, %rd2, 1;             // add 1 to the value in %rd2, place the result in %rd3
	st.u64 	[%rd1], %rd3;                  // store the value of %rd3 at the location pointed to by %rd1
	mov.u32 	%r1, 0;                    // move 0 into %r1
	st.param.b32	[return_value+0], %r1; // place the value of %r1 into the return value
	ret;                                   # return 0
}
```

# What happens next?
libcudf takes it from here. Broadly speaking, libcudf hacks together a source string out of three elements, which ends up being handed off to Jitify and compiled into a final kernel. Jitify then launches that kernel, passing it the pointer to the beginning of the actual data column to be transformed. The three elements are:

1. A header
2. An outer "calling" kernel that generically calls the PTX function
3. A processed version of the PTX function that inlines it directly into CUDA

The libcudf parser essentially takes the PTX function as above and turns it into a generically callable inlinable function. The final file that gets passed off to jitify to be compiled looks like this:

```

#pragma once

// Include Jitify's cstddef header first
#include <cstddef>

#include <cuda/std/climits>
#include <cuda/std/cstddef>
#include <cuda/std/limits>
#include <cudf/types.hpp>
#include <cudf/wrappers/timestamps.hpp>
#include <cudf/utilities/bit.hpp>

template <typename TypeOut, typename TypeIn>
    __global__
    void kernel(cudf::size_type size,
                    TypeOut* out_data, TypeIn* in_data) {
        int tid = threadIdx.x;
        int blkid = blockIdx.x;
        int blksz = blockDim.x;
        int gridsz = gridDim.x;

        int start = tid + blkid * blksz;
        int step = blksz * gridsz;

        for (cudf::size_type i=start; i<size; i+=step) {
          GENERIC_UNARY_OP(&out_data[i], in_data[i]);  
        }
    }


__device__ __inline__ void GENERIC_UNARY_OP (
  int64_t* _ZN8__main__5f_241Ex_param_0, 
  long int _ZN8__main__5f_241Ex_param_1
){

 asm volatile ("{");  asm volatile ("  .reg .b32 _r<2>;");
   /**   .reg .b32 	%r<2>  */
  asm volatile ("  .reg .b64 _rd<4>;");
   /**   .reg .b64 	%rd<4>  */
  asm volatile ("  mov.u64 _rd1,  %0;": : "l"(_ZN8__main__5f_241Ex_param_0));
   /**   ld.param.u64 	%rd1, [_ZN8__main__5f$241Ex_param_0]  */
  asm volatile ("  mov.u64 _rd2,  %0;": : "l"(_ZN8__main__5f_241Ex_param_1));
   /**   ld.param.u64 	%rd2, [_ZN8__main__5f$241Ex_param_1]  */
  asm volatile ("  add.s64 _rd3, _rd2, 1;");
   /**   add.s64 	%rd3, %rd2, 1  */
  asm volatile ("  st.u64 [_rd1], _rd3;");
   /**   st.u64 	[%rd1], %rd3  */
  asm volatile ("  mov.u32 _r1, 0;");
   /**   mov.u32 	%r1, 0  */
  asm volatile (" /** *** SNIP. *** */");
   /**   st.param.b32	[func_retval0+0], %r1  */
  asm volatile ("bra RETTGT;");
 asm volatile ("RETTGT:}");

}

```

Apart from returning the data back to the user, that's more or less the process. 
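The PTX-to-inline-assembly rewrite visible in the listing above can be approximated in a few lines of Python. This is a hypothetical, heavily simplified sketch and not libcudf's actual parser: the function name `ptx_to_inline_asm` is made up, it handles only parameter loads and plain instructions, and it hard-codes the 64-bit `"l"` operand constraint.

```python
import re

# Matches e.g. "ld.param.u64 %rd1, [some_param_name]"
LD_PARAM = re.compile(r"ld\.param\.(\w+) %(\w+), \[([^\]]+)\]")


def ptx_to_inline_asm(ptx_line):
    """Rewrite one PTX instruction as a CUDA inline-asm statement."""
    # Collapse tabs and spaces, and drop the trailing semicolon.
    line = " ".join(ptx_line.split()).rstrip(";")
    match = LD_PARAM.fullmatch(line)
    if match:
        # Parameter loads become register moves fed by a C++ argument.
        # The listing above shows "$" replaced by "_" in parameter names.
        ptx_type, reg, param = match.groups()
        param = param.replace("$", "_")
        return 'asm volatile ("  mov.%s _%s,  %%0;": : "l"(%s));' % (
            ptx_type, reg, param)
    # Everything else keeps its opcode; "%" register prefixes become "_".
    return 'asm volatile ("  %s;");' % line.replace("%", "_")


print(ptx_to_inline_asm("add.s64 \t%rd3, %rd2, 1;"))
print(ptx_to_inline_asm("ld.param.u64 \t%rd1, [_ZN8__main__5f$241Ex_param_0];"))
```

The two printed statements match the corresponding `add.s64` and first `mov.u64` lines of `GENERIC_UNARY_OP` in the listing. Renaming `%rd1` to `_rd1` matters because `%` introduces operand placeholders such as `%0` inside CUDA inline assembly.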

# What does all this have to do with NAs?
The pipeline is based on the idea that anything you could want to do arithmetically with a single value of `x` is expressible as a generic PTX function of `x` which, given some type information, Numba can generate for you. The rest of the machinery is just meant to deliver the data to this function threadwise. We're going to extend this concept to a function of four variables instead of one: a masked binary operation `x + y` where the four arguments are:

1. `x`
2. `y`
3. `x.mask`
4. `y.mask`


We're going to modify the general kernel that calls `GENERIC_UNARY_OP`, generalizing it to accept these four arguments and call a `GENERIC_BINARY_OP` instead (with two extra arguments: the mask bools).
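In plain Python, the generalized kernel's job looks roughly like the sketch below: unpack one validity bit per operand, hand all four values to the compiled operation, and store the resulting value and validity. The names `masked_add` and `null_kernel_model` are illustrative stand-ins for `GENERIC_BINARY_OP` and the CUDA kernel, and each mask is a single Python integer rather than a device buffer.

```python
def bit_is_set(mask, i):
    """True if bit i of an integer bitmask is 1 (meaning row i is valid)."""
    return bool((mask >> i) & 1)


def masked_add(x, x_valid, y, y_valid):
    """Stand-in for GENERIC_BINARY_OP: returns (value, valid)."""
    valid = x_valid and y_valid
    return (x + y if valid else 0), valid


def null_kernel_model(op, lhs, lhs_mask, rhs, rhs_mask):
    """Serial model of the kernel: one loop iteration per row."""
    out_data, out_valid = [], []
    for i in range(len(lhs)):
        # A missing mask (None) means every row of that column is valid.
        l_valid = bit_is_set(lhs_mask, i) if lhs_mask is not None else True
        r_valid = bit_is_set(rhs_mask, i) if rhs_mask is not None else True
        value, valid = op(lhs[i], l_valid, rhs[i], r_valid)
        out_data.append(value)
        out_valid.append(valid)
    return out_data, out_valid


# x = [1, <NA>, 3] and y = [1, 2, <NA>]; null slots hold placeholder zeros.
x_data, x_mask = [1, 0, 3], 0b101
y_data, y_mask = [1, 2, 0], 0b011
print(null_kernel_model(masked_add, x_data, x_mask, y_data, y_mask))
# ([2, 0, 0], [True, False, False])
```

The key change from the unary pipeline is that validity is now an input to the user's operation rather than something decided outside of it.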

# Creating a Numba extension type

Remember how Numba produces PTX code from a Python function and some type information? We're going to create a new type in Numba that is built around a struct:

```
struct Masked {
    int64_t value;
    bool valid;
};
```
And we're going to add an overload of `add` (`+`, `operator.add`) to Numba's registry of function signatures that correctly handles null semantics. Then we're going to JIT the incoming python function and use a `Masked` type for every argument. 
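Before wiring this into Numba, the intended semantics can be prototyped with an ordinary Python class. This is only a model of the behaviour, not the Numba extension itself: `Masked` here is a dataclass, and the point is that the user's function body (`x + y`) does not change when its arguments become masked values.

```python
from dataclasses import dataclass


@dataclass
class Masked:
    value: int
    valid: bool

    def __add__(self, other):
        # The result is valid only if both operands are valid.
        valid = self.valid and other.valid
        return Masked(self.value + other.value if valid else 0, valid)


def f(x, y):
    # The UDF is written exactly as it would be for plain scalars.
    return x + y


print(f(Masked(1, True), Masked(1, True)))   # Masked(value=2, valid=True)
print(f(Masked(3, True), Masked(0, False)))  # Masked(value=0, valid=False)
```

The Numba version has to express the same two pieces separately: a typing rule that says `Masked + Masked` yields a `Masked`, and a lowering that emits the instructions for it.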

#### Tell Numba that a `MaskedType` exists, and not much else

In [15]:
from numba.core import types
class MaskedType(types.Type):
    # A corresponding MaskedType for numba
    # numba can only generate LLVM IR for things
    # that it recognizes. This is the most basic
    # thing needed for numba to recognize the type,
    # all it really says is "there's a type, 
    # called MaskedType". name is for __repr__
    def __init__(self):
        super().__init__(name="Masked")
        
numba_masked = MaskedType()

In [16]:
from numba.core.extending import make_attribute_wrapper

make_attribute_wrapper(MaskedType, "value", "value")
make_attribute_wrapper(MaskedType, "valid", "valid")

#### Tell Numba what this type looks like. In our case, it's a struct. 

In [17]:
from numba.core.extending import register_model, models

@register_model(MaskedType)
class MaskedModel(models.StructModel):
    def __init__(self, dmm, fe_type):
        members = [("value", types.int64), ("valid", types.bool_)]
        models.StructModel.__init__(self, dmm, fe_type, members)

#### Register an overload of `operator.add` with Numba's registry of `CUDA` functions. 
This is part of the typing phase. When we pass `f(x, y): return x + y` into Numba and say that `x` and `y` are of type `Masked`, it hits the `x + y` statement and goes looking for an overload of `add` with a signature matching those operands. The search ends with either a matching signature or a typing error. This piece of code conditionally emits the signature Numba needs to find when prompted with two arguments of type `Masked`. One can see how a template might dynamically return different types depending on its arguments. But this one roughly says: "when Numba looks for an overload of `add` that takes two `Masked` as arguments, let it know that there is one, and that it returns a `Masked`."

In [18]:
from numba.cuda.cudadecl import registry as cuda_registry
import operator
from numba.core.typing.templates import AbstractTemplate, signature


@cuda_registry.register_global(operator.add)
class MaskedScalarAdd(AbstractTemplate):
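    # AbstractTemplate: `generic` computes the signature from the argument
    # types (a ConcreteTemplate would list a fixed set of signatures instead)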
    def generic(self, args, kws):
        if isinstance(args[0], MaskedType) and isinstance(args[1], MaskedType):
            return signature(numba_masked, numba_masked, numba_masked)

#### Implement Masked + Masked
So far, Numba knows:
- There's a `MaskedType`. 
- There's an overload of `operator.add` that accepts two `MaskedType` and returns a `MaskedType`

Now it essentially needs an implementation for that overload of `operator.add`.

In [19]:
from numba.core import cgutils
from numba.cuda.cudaimpl import lower as cuda_lower

@cuda_lower(operator.add, MaskedType, MaskedType)
def masked_scalar_add_impl(context, builder, sig, args):
    # get the types from the signature
    masked_type_1, masked_type_2 = sig.args
    masked_return_type = sig.return_type

    # create LLVM IR structs
    m1 = cgutils.create_struct_proxy(masked_type_1)(
        context, builder, value=args[0]
    )
    m2 = cgutils.create_struct_proxy(masked_type_2)(
        context, builder, value=args[1]
    )
    result = cgutils.create_struct_proxy(masked_return_type)(context, builder)

    valid = builder.and_(m1.valid, m2.valid)
    result.valid = valid
    with builder.if_then(valid):
        result.value = builder.add(m1.value, m2.value)

    return result._getvalue()

# Testing it Out

In [None]:
from numba import cuda
def compile_masked(func):
    signature = (numba_masked, numba_masked)
    ptx, _ = cuda.compile_ptx_for_current_device(func, signature, device=True)
    return ptx

In [None]:
def f(x, y):
    return x + y

In [None]:
#ptx = compile_masked(f)

```
        // .globl       _ZN8__main__6f$2411E6Masked6Masked                                                                                                                                                                                                                   
.common .global .align 8 .u64 _ZN08NumbaEnv8__main__6f$2411E6Masked6Masked;                                                                                                                                                                                                  
                                                                                                                                                                                                                                                                             
.visible .func  (.param .b32 func_retval0) _ZN8__main__6f$2411E6Masked6Masked(                                                                                                                                                                                               
        .param .b64 _ZN8__main__6f$2411E6Masked6Masked_param_0,                                                                                                                                                                                                              
        .param .b64 _ZN8__main__6f$2411E6Masked6Masked_param_1,                                                                                                                                                                                                              
        .param .b32 _ZN8__main__6f$2411E6Masked6Masked_param_2,                                                                                                                                                                                                              
        .param .b64 _ZN8__main__6f$2411E6Masked6Masked_param_3,                                                                                                                                                                                                              
        .param .b32 _ZN8__main__6f$2411E6Masked6Masked_param_4                                                                                                                                                                                                               
)                                                                                                                                                                                                                                                                            
{                                                                                                                                                                                                                                                                            
        .reg .pred      %p<4>;                                                                                                                                                                                                                                               
        .reg .b16       %rs<4>;                                                                                                                                                                                                                                              
        .reg .b32       %r<2>;                                                                                                                                                                                                                                               
        .reg .b64       %rd<6>;                                                                                                                                                                                                                                              
                                                                                                                                                                                                                                                                             
                                                                                                                                                                                                                                                                             
        ld.param.u64    %rd1, [_ZN8__main__6f$2411E6Masked6Masked_param_0];                                                                                                                                                                                                  
        ld.param.u64    %rd2, [_ZN8__main__6f$2411E6Masked6Masked_param_1];                                                                                                                                                                                                  
        ld.param.u64    %rd3, [_ZN8__main__6f$2411E6Masked6Masked_param_3];                                                                                                                                                                                                  
        ld.param.u8     %rs1, [_ZN8__main__6f$2411E6Masked6Masked_param_2];                                                                                                                                                                                                  
        setp.ne.s16     %p1, %rs1, 0;                                                                                                                                                                                                                                        
        ld.param.u8     %rs2, [_ZN8__main__6f$2411E6Masked6Masked_param_4];                                                                                                                                                                                                  
        setp.ne.s16     %p2, %rs2, 0;                                                                                                                                                                                                                                        
        and.pred        %p3, %p1, %p2;                                                                                                                                                                                                                                       
        add.s64         %rd4, %rd3, %rd2;                                                                                                                                                                                                                                    
        selp.b64        %rd5, %rd4, 0, %p3;                                                                                                                                                                                                                                  
        selp.u16        %rs3, 1, 0, %p3;                                                                                                                                                                                                                                     
        st.u64  [%rd1], %rd5;                                                                                                                                                                                                                                                
        st.u8   [%rd1+8], %rs3;                                                                                                                                                                                                                                              
        mov.u32         %r1, 0;                                                                                                                                                                                                                                              
        st.param.b32    [func_retval0+0], %r1;                                                                                            
        ret;                                                                                                                                                                                                                                                                 
}        
```
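Reading the PTX: `param_0` is a pointer to the result, `param_1` and `param_2` are the first operand's value and validity, and `param_3` and `param_4` are the second's. The two stores write the value at offset 0 and the validity byte at offset 8, which matches the `value`/`valid` struct layout. The sketch below mimics those stores on a host-side byte buffer using Python's `struct` module; it illustrates the layout and the branch-free select, and is not generated code. The function name `masked_add_ptx_model` is made up.

```python
import struct


def masked_add_ptx_model(buf, x, x_valid, y, y_valid):
    """Mirror the PTX: branch-free select, then two stores into buf."""
    valid = bool(x_valid) and bool(y_valid)      # and.pred
    total = x + y                                # add.s64 (always computed)
    value = total if valid else 0                # selp.b64
    struct.pack_into("<q", buf, 0, value)        # st.u64 [%rd1], %rd5
    struct.pack_into("<B", buf, 8, int(valid))   # st.u8  [%rd1+8], %rs3


buf = bytearray(9)  # 8 bytes for the int64 value + 1 validity byte
masked_add_ptx_model(buf, 1, 1, 1, 1)
print(struct.unpack_from("<qB", buf))  # (2, 1)
masked_add_ptx_model(buf, 3, 1, 0, 0)
print(struct.unpack_from("<qB", buf))  # (0, 0)
```

Note that the compiled code contains no branch even though the lowering used `builder.if_then`: the addition always runs, and `selp` picks either the sum or zero based on the combined validity predicate.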

#### Then, this is the whole file being passed to jitify:

```cuda

    #pragma once

    // Include Jitify's cstddef header first
    #include <cstddef>

    #include <cuda/std/climits>
    #include <cuda/std/cstddef>
    #include <cuda/std/limits>
    #include <cudf/types.hpp>
    #include <cudf/wrappers/timestamps.hpp>
    #include <cudf/utilities/bit.hpp>

    struct Masked {
      int64_t value;
      bool valid;
    };
    
   

    template <typename TypeOut, typename TypeLhs, typename TypeRhs>
    __global__
    void null_kernel(cudf::size_type size,
                     TypeOut* out_data, 
                     TypeLhs* lhs_data,
                     TypeRhs* rhs_data,
                     bool* out_mask,
                     cudf::bitmask_type const* lhs_mask,
                     cudf::size_type lhs_offset,
                     cudf::bitmask_type const* rhs_mask,
                     cudf::size_type rhs_offset
    ) {
        int tid = threadIdx.x;
        int blkid = blockIdx.x;
        int blksz = blockDim.x;
        int gridsz = gridDim.x;

        int start = tid + blkid * blksz;
        int step = blksz * gridsz;

        Masked output;

        char l_valid;
        char r_valid;

        long int l_data;
        long int r_data;

        for (cudf::size_type i=start; i<size; i+=step) {
          l_valid = lhs_mask ? cudf::bit_is_set(lhs_mask, lhs_offset + i) : true;
          r_valid = rhs_mask ? cudf::bit_is_set(rhs_mask, rhs_offset + i) : true;
          l_data = lhs_data[i];
          r_data = rhs_data[i];

          GENERIC_BINARY_OP(&output.value, lhs_data[i], l_valid, rhs_data[i], r_valid);

          out_data[i] = output.value;
          out_mask[i] = output.valid;
             
        }
    }
      
__device__ __inline__ void GENERIC_BINARY_OP(                                                                                                                                                                                                                                
  int64_t* _ZN8__main__6f_2413E6Masked6Masked_param_0,                                                                                                                                                                                                                       
  long int _ZN8__main__6f_2413E6Masked6Masked_param_1,                                                                                                                                                                                                                       
  char _ZN8__main__6f_2413E6Masked6Masked_param_2,                                                                                                                                                                                                                           
  long int _ZN8__main__6f_2413E6Masked6Masked_param_3,
  char _ZN8__main__6f_2413E6Masked6Masked_param_4
){

 asm volatile ("{");  asm volatile ("  .reg .pred _p<4>;");
   /**   .reg .pred     %p<4>  */

  asm volatile ("  .reg .b16 _rs<4>;");
   /**   .reg .b16      %rs<4>  */

  asm volatile ("  .reg .b32 _r<2>;");
   /**   .reg .b32      %r<2>  */

  asm volatile ("  .reg .b64 _rd<6>;");
   /**   .reg .b64      %rd<6>  */

  asm volatile ("  mov.u64 _rd1,  %0;": : "l"(_ZN8__main__6f_2413E6Masked6Masked_param_0));
   /**   ld.param.u64   %rd1, [_ZN8__main__6f$2413E6Masked6Masked_param_0]  */

  asm volatile ("  mov.u64 _rd2,  %0;": : "l"(_ZN8__main__6f_2413E6Masked6Masked_param_1));
   /**   ld.param.u64   %rd2, [_ZN8__main__6f$2413E6Masked6Masked_param_1]  */

  asm volatile ("  mov.u64 _rd3,  %0;": : "l"(_ZN8__main__6f_2413E6Masked6Masked_param_3));
   /**   ld.param.u64   %rd3, [_ZN8__main__6f$2413E6Masked6Masked_param_3]  */

  asm volatile ("  cvt.u8.u8 _rs1,  %0;": : "h"( static_cast<short>(_ZN8__main__6f_2413E6Masked6Masked_param_2)));
   /**   ld.param.u8    %rs1, [_ZN8__main__6f$2413E6Masked6Masked_param_2]  */

  asm volatile ("  setp.ne.s16 _p1, _rs1, 0;");
   /**   setp.ne.s16    %p1, %rs1, 0  */

  asm volatile ("  cvt.u8.u8 _rs2,  %0;": : "h"( static_cast<short>(_ZN8__main__6f_2413E6Masked6Masked_param_4)));
   /**   ld.param.u8    %rs2, [_ZN8__main__6f$2413E6Masked6Masked_param_4]  */

  asm volatile ("  setp.ne.s16 _p2, _rs2, 0;");
   /**   setp.ne.s16    %p2, %rs2, 0  */

  asm volatile ("  and.pred _p3, _p1, _p2;");
   /**   and.pred       %p3, %p1, %p2  */

  asm volatile ("  add.s64 _rd4, _rd3, _rd2;");
   /**   add.s64        %rd4, %rd3, %rd2  */

  asm volatile ("  selp.b64 _rd5, _rd4, 0, _p3;");
   /**   selp.b64       %rd5, %rd4, 0, %p3  */

  asm volatile ("  selp.u16 _rs3, 1, 0, _p3;");
   /**   selp.u16       %rs3, 1, 0, %p3  */

  asm volatile ("  st.u64 [_rd1], _rd5;");
   /**   st.u64         [%rd1], %rd5  */

  asm volatile ("  st.u8 [_rd1+8], _rs3;");
   /**   st.u8  [%rd1+8], %rs3  */

  asm volatile ("  mov.u32 _r1, 0;");
   /**   mov.u32        %r1, 0  */

  asm volatile (" /** *** The way we parse the CUDA PTX assumes the function returns the return value through the first function parameter. Thus the `st.param.***` instructions are not processed. *** */");
   /**   st.param.b32   [func_retval0+0], %r1  */

  asm volatile ("bra RETTGT;");


 asm volatile ("RETTGT:}");}
```
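
Read as ordinary code, the generated function above does very little. The following plain-Python sketch mirrors it instruction by instruction, assuming `param_1`/`param_2` carry the value and validity of `x`, `param_3`/`param_4` carry those of `y`, and `param_0` points at the returned struct. The helper name `masked_add_model` is hypothetical and exists only for illustration. Unlike `add.s64`, Python integers do not wrap on overflow, so this is a model of the logic rather than of the exact arithmetic.

```python
def masked_add_model(x_value, x_valid, y_value, y_valid):
    """Plain-Python reading of the generated PTX (hypothetical helper)."""
    both_valid = bool(x_valid) and bool(y_valid)  # and.pred  %p3, %p1, %p2
    total = y_value + x_value                     # add.s64   %rd4, %rd3, %rd2
    out_value = total if both_valid else 0        # selp.b64  %rd5, %rd4, 0, %p3
    out_valid = 1 if both_valid else 0            # selp.u16  %rs3, 1, 0, %p3
    # The PTX stores these at [%rd1] and [%rd1+8]; here they are returned.
    return out_value, out_valid


print(masked_add_model(2, 1, 4, 1))  # both operands valid
print(masked_add_model(2, 1, 4, 0))  # y is null, so the result is null
```

The result is valid only when both inputs are valid, and the stored value is zeroed otherwise.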

# Test it
Here are some very basic Cython bindings, used only to test this exact functionality.
```
def masked_binary_op(Column A, Column B, op, Column output_column, Column output_mask):
    cdef column_view A_view = A.view()
    cdef column_view B_view = B.view()

    cdef string c_str
    cdef type_id c_tid
    cdef data_type c_dtype

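    # This prototype only handles int64 inputs.
    if A.dtype != np.dtype('int64') or B.dtype != np.dtype('int64'):
        raise TypeError('masked_binary_op only supports int64 columns')

    # Compile the Python UDF and hand the result to libcudf as a string.
    from cudf.core.udf import compile_udf
    c_str = compile_udf(op).encode('UTF-8')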

    c_tid = <type_id> (
        <underlying_type_t_type_id> np_to_cudf_types[np.dtype('int64')]
    )
    c_dtype = data_type(c_tid)

    cdef column_view outcol_view = output_column.view()
    cdef column_view outmsk_view = output_mask.view()

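    # Assumption: the libcudf call returns an owning column pointer, which is
    # wrapped and returned so Python callers (such as demo_udf) get a Column.
    cdef unique_ptr[column] c_output

    with nogil:
        c_output = move(libcudf_transform.masked_binary_op(
            A_view,
            B_view,
            c_str,
            c_dtype,
            outcol_view,
            outmsk_view
        ))

    return Column.from_unique_ptr(move(c_output))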
```

In [20]:
from cudf._lib.transform import masked_binary_op

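def demo_udf(func, s1, s2):
    # Work on the underlying columns of the two input Series.
    col1, col2 = s1._column, s2._column

    # Preallocate buffers for the result values and the result validity,
    # sized to match the inputs instead of hardcoding 8 rows.
    n = len(s1)
    output_column = cudf.core.column.as_column(np.arange(n), dtype='int64')
    output_mask = cudf.core.column.as_column([False] * n)

    result_col = masked_binary_op(col1, col2, func, output_column, output_mask)
    return cudf.Series(result_col)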

In [30]:
def f(x, y):
    return x + y

s1 = cudf.Series([1, None, 3,    None, 2, 2,    5, None])
s2 = cudf.Series([1, 2,    None, None, 4, None, 5, None])

demo_udf(f, s1, s2)

0       2
1    <NA>
2    <NA>
3    <NA>
4       6
5    <NA>
6      10
7    <NA>
dtype: int64

In [31]:
s1 + s2

0       2
1    <NA>
2    <NA>
3    <NA>
4       6
5    <NA>
6      10
7    <NA>
dtype: int64

# `cudf.NA`
In general, we want to be able to work with `cudf.NA` directly inside our functions. To do this, we reapply the same machinery to overload what happens when we add a `MaskedType` to `cudf.NA`.

#### Create an NAType

In [21]:
from cudf.core.scalar import _NAType
class NAType(types.Type):
    # "There is a type called NAType"
    def __init__(self):
        super().__init__(name="NA")

numba_na = NAType()

In [22]:
from numba.core.extending import typeof_impl
@typeof_impl.register(_NAType)
def typeof_na(val, c):
    # Instances of _NAType will be
    # treated as instances of NAType.
    return numba_na



register_model(NAType)(models.OpaqueModel)

numba.core.datamodel.models.OpaqueModel

#### `operator.add` typing for Masked <-> NA

In [23]:
@cuda_registry.register_global(operator.add)
class MaskedScalarAddNull(AbstractTemplate):
    def generic(self, args, kws):
        # Typing rule: Masked + NA resolves to a Masked result.
        if isinstance(args[0], MaskedType) and isinstance(args[1], NAType):
            return signature(numba_masked, numba_masked, numba_na)

#### Lowering: AKA what to actually do when this is requested
This says that when `+` is invoked between a `MaskedType` and an `NAType`, we make a new `MaskedType`, set its validity to zero, and return it.

In [24]:
from numba.cuda.cudaimpl import registry as cuda_lowering_registry

@cuda_lower(operator.add, MaskedType, NAType)
def masked_scalar_add_na_impl(context, builder, sig, args):
    # Build a fresh Masked struct. Its value is irrelevant because adding NA
    # always yields null, so only the valid flag is set (to False).
    result = cgutils.create_struct_proxy(numba_masked)(context, builder)
    result.valid = context.get_constant(types.boolean, 0)
    return result._getvalue()


@cuda_lowering_registry.lower_constant(NAType)
def constant_dummy(context, builder, ty, pyval):
    # NA carries no data, so a dummy value is enough to represent it.
    return context.get_dummy_value()

# Constants
At this point the pattern is familiar. Register an overload that emits a signature if the operands match a `MaskedType` and a constant. The lowering is logically fairly simple: if the `MaskedType` is null, the answer is null; otherwise the answer is a new `MaskedType` whose `value` is the sum of the input's `value` and the constant.

In [25]:
from llvmlite import ir

@cuda_registry.register_global(operator.add)
class MaskedScalarAddConstant(AbstractTemplate):
    def generic(self, args, kws):
        if isinstance(args[0], MaskedType) and isinstance(args[1], types.Integer):
            return signature(numba_masked, numba_masked, types.int64)

@cuda_lower(operator.add, MaskedType, types.Integer)
def masked_scalar_add_constant_impl(context, builder, sig, input_values):
    masked_type, const_type = sig.args

    # Unpack the incoming Masked struct and allocate one for the result.
    indata = cgutils.create_struct_proxy(masked_type)(context, builder, value=input_values[0])
    result = cgutils.create_struct_proxy(numba_masked)(context, builder)

    # Default to null; only compute the sum when the input is valid.
    result.valid = context.get_constant(types.boolean, 0)
    with builder.if_then(indata.valid):
        result.value = builder.add(indata.value, input_values[1])
        result.valid = context.get_constant(types.boolean, 1)

    return result._getvalue()
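
Taken together, the typing and lowering rules for `Masked + Masked`, `Masked + NA` and `Masked + constant` behave roughly like the plain-Python model below. This is only a CPU-side sketch for reasoning about expected results, not how cuDF or Numba implement it; the names `Masked`, `NA` and `apply_model` are hypothetical stand-ins, and Python lists with `None` stand in for nullable columns.

```python
from dataclasses import dataclass


class _NA:
    """Stand-in for cudf.NA in this model (hypothetical)."""

    def __repr__(self):
        return "<NA>"


NA = _NA()


@dataclass(frozen=True)
class Masked:
    """A value paired with a validity flag, mirroring MaskedType."""

    value: int
    valid: bool

    def __add__(self, other):
        if isinstance(other, _NA):
            # Masked + NA: always null.
            return Masked(0, False)
        if isinstance(other, Masked):
            # Masked + Masked: valid only if both operands are valid.
            valid = self.valid and other.valid
            return Masked(self.value + other.value if valid else 0, valid)
        if isinstance(other, int):
            # Masked + constant: null stays null, otherwise add the constant.
            if not self.valid:
                return Masked(0, False)
            return Masked(self.value + other, True)
        return NotImplemented


def apply_model(func, xs, ys):
    """Apply func row by row; None marks a null in inputs and outputs."""
    out = []
    for x, y in zip(xs, ys):
        mx = Masked(0 if x is None else x, x is not None)
        my = Masked(0 if y is None else y, y is not None)
        result = func(mx, my)
        out.append(result.value if result.valid else None)
    return out


s1 = [1, None, 3, None, 2, 2, 5, None]
s2 = [1, 2, None, None, 4, None, 5, None]

print(apply_model(lambda x, y: x + y + 1, s1, s2))
print(apply_model(lambda x, y: x + y + NA, s1, s2))
```

Running the same UDFs through such a model gives a quick reference to compare against the GPU results.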


In [26]:

def f(x, y):
    return x + y + cudf.NA

s1 = cudf.Series([1, None, 3,    None, 2, 2,    5, None])
s2 = cudf.Series([1, 2,    None, None, 4, None, 5, None])

result = demo_udf(f, s1, s2)

In [27]:
result

0    <NA>
1    <NA>
2    <NA>
3    <NA>
4    <NA>
5    <NA>
6    <NA>
7    <NA>
dtype: int64

In [28]:
def f(x, y):
    return x + y + 1

s1 = cudf.Series([1, None, 3,    None, 2, 2,    5, None])
s2 = cudf.Series([1, 2,    None, None, 4, None, 5, None])

result = demo_udf(f, s1, s2)

In [29]:
result

0       3
1    <NA>
2    <NA>
3    <NA>
4       7
5    <NA>
6      11
7    <NA>
dtype: int64