Skip to content

Commit

Permalink
[Python] use dtype objects; closes #78
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Oct 3, 2020
1 parent 70bf71f commit 4be09f2
Show file tree
Hide file tree
Showing 9 changed files with 246 additions and 94 deletions.
128 changes: 128 additions & 0 deletions py-polars/pypolars/datatypes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import ctypes


class Int8:
pass


class Int16:
pass


class Int32:
pass


class Int64:
pass


class UInt8:
pass


class UInt16:
pass


class UInt32:
pass


class UInt64:
pass


class Float32:
pass


class Float64:
pass


class Bool:
pass


class Utf8:
pass


class LargeList:
pass


class Date32:
pass


class Date64:
pass


# Don't change the order of these!
dtypes = [
Int8,
Int16,
Int32,
Int64,
UInt8,
UInt16,
UInt32,
UInt64,
Float32,
Float64,
Bool,
Utf8,
LargeList,
Date32,
Date64,
]
DTYPE_TO_FFINAME = {
Int8: "i8",
Int16: "i16",
Int32: "i32",
Int64: "i64",
UInt8: "u8",
UInt16: "u16",
UInt32: "u32",
UInt64: "u64",
Float32: "f32",
Float64: "f64",
Bool: "bool",
Utf8: "str",
LargeList: "largelist",
Date32: "date32",
Date64: "date64",
}


def dtype_to_ctype(dtype: "DataType") -> "ctype":
if dtype == UInt8:
ptr_type = ctypes.c_uint8
elif dtype == UInt16:
ptr_type = ctypes.c_uint16
elif dtype == UInt32:
ptr_type = ctypes.c_uint
elif dtype == UInt64:
ptr_type = ctypes.c_ulong
elif dtype == Int8:
ptr_type = ctypes.c_int8
elif dtype == Int16:
ptr_type = ctypes.c_int16
elif dtype == Int32:
ptr_type = ctypes.c_int
elif dtype == Int64:
ptr_type = ctypes.c_long
elif dtype == Float32:
ptr_type = ctypes.c_float
elif dtype == Float64:
ptr_type = ctypes.c_double
elif dtype == Date32:
ptr_type = ctypes.c_int
elif dtype == Date64:
ptr_type = ctypes.c_long
else:
return NotImplemented
return ptr_type
9 changes: 5 additions & 4 deletions py-polars/pypolars/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .pypolars import PyDataFrame, PySeries
from typing import Dict, Sequence, List, Tuple, Optional, Union
from .series import Series, wrap_s
from .datatypes import *
import numpy as np
from typing import TextIO, BinaryIO

Expand Down Expand Up @@ -217,9 +218,9 @@ def __getitem__(self, item):
else:
return wrap_df(self._df.take(item))
dtype = item.dtype
if dtype == "bool":
if dtype == Bool:
return wrap_df(self._df.filter(item.inner()))
if dtype == "u32":
if dtype == UInt32:
return wrap_df(self._df.take_with_series(item.inner()))
return NotImplemented

Expand Down Expand Up @@ -282,11 +283,11 @@ def columns(self) -> List[str]:
return self._df.columns()

@property
def dtypes(self) -> List[str]:
def dtypes(self) -> List[type]:
"""
get dtypes
"""
return self._df.dtypes()
return [dtypes[idx] for idx in self._df.dtypes()]

def replace_at_idx(self, index: int, series: Series):
"""
Expand Down

0 comments on commit 4be09f2

Please sign in to comment.