In [1]:
import polars as pl
from polars.testing.parametric import dataframes, column

In [None]:
# Raise the table row limit to 100 so the small demo frames below print in full.
pl.Config.set_tbl_rows(100)

# Enum dtype used for the `category` column. The categories are listed
# explicitly instead of passing the bare string "XYZ", which only produces
# X / Y / Z if the constructor happens to iterate the string per character.
Book = pl.Enum(["X", "Y", "Z"])

def generate(size=20):
    """Build a strategy that produces frames with exactly ``size`` rows.

    Columns of each generated frame:
      * ``id``       - UInt16, unique within the frame, never null
      * ``value``    - Float64, never null
      * ``category`` - the ``Book`` enum dtype, never null

    Returns the strategy itself; call ``.example()`` on it to draw a frame.
    """
    schema = [
        column("id", dtype=pl.UInt16, unique=True, allow_null=False),
        column("value", dtype=pl.Float64, allow_null=False),
        column("category", dtype=Book, allow_null=False),
    ]
    # Identical lower and upper bounds pin the row count to `size`.
    return dataframes(schema, min_size=size, max_size=size)

# Baseline frame, drawn with the default size of 20 rows.
# NOTE(review): the draw is not seeded here, so a re-run presumably yields
# different data than the outputs saved in this notebook - confirm.
original = generate().example()

In [3]:
# Show the baseline frame.
original

id,value,category
u16,f64,enum
63323,1.0278e+139,"""X"""
63685,1.687e+221,"""Y"""
54131,5.0194e+16,"""Y"""
54624,-4.0034e-10,"""Z"""
299,-3.3401e+41,"""Z"""
10238,-3.7161e+137,"""Y"""
21722,-4.4864e+45,"""Y"""
21562,-4.9719e+16,"""Z"""
62647,1.0604999999999999e+86,"""Z"""
10981,-6.474200000000001e+73,"""Y"""


In [None]:
def mutate(df, k, j):
    """Return ``df`` with every ``category`` equal to ``k`` recoded to ``j``.

    Args:
        df: frame with a ``category`` column of the ``Book`` enum dtype.
        k: category value to look for.
        j: category value to write in its place (cast to ``Book``).

    Rows whose category is not ``k`` keep their current value.
    """
    is_match = pl.col.category == k
    replacement = pl.lit(j).cast(Book)
    # The string passed to `otherwise` is read as a column name, not a literal.
    recoded = pl.when(is_match).then(replacement).otherwise("category")
    return df.with_columns(category=recoded)

In [5]:
# Build the "after" snapshot:
#   * recode every "X" row of `original` to "Z",
#   * keep only the first 18 rows (so the trailing rows count as removed),
#   * append up to 3 freshly generated rows.
# The fresh rows are unique only among themselves, so any that reuse an id
# already present in `original` are dropped; otherwise the join on "id"
# further down would fan out and double-count values.
new = pl.concat(
    [
        mutate(original, "X", "Z").head(18),
        generate(3).example().join(original, on="id", how="anti"),
    ]
)

In [6]:
# Show the updated frame (recoded categories plus the appended rows).
new

id,value,category
u16,f64,enum
63323,1.0278e+139,"""Z"""
63685,1.687e+221,"""Y"""
54131,5.0194e+16,"""Y"""
54624,-4.0034e-10,"""Z"""
299,-3.3401e+41,"""Z"""
10238,-3.7161e+137,"""Y"""
21722,-4.4864e+45,"""Y"""
21562,-4.9719e+16,"""Z"""
62647,1.0604999999999999e+86,"""Z"""
10981,-6.474200000000001e+73,"""Y"""


In [7]:
# Per-category totals of `value` in the baseline frame.
original.group_by("category").agg(pl.col("value").sum())

category,value
enum,f64
"""X""",-8.5754e+291
"""Y""",1.687e+221
"""Z""",-2.5249e+172


In [8]:
# Per-category totals of `value` in the updated frame (the reference answer).
new.group_by("category").agg(pl.col("value").sum())

category,value
enum,f64
"""Z""",-8.5754e+291
"""Y""",1.687e+221
"""X""",-1.7268e-306


In [9]:
# Full outer join of before/after on `id`. Columns coming from `new` carry the
# `_right` suffix; a side is null when the id exists only in the other frame.
full = original.join(
    new,
    on="id",
    how="full",
    suffix="_right",
)

In [10]:
# Value that leaves category "X": rows that were "X" in `original` and are
# either recoded to another category in `new` or missing from `new` entirely.
# The explicit `is_null()` branch matters: for a removed row `category_right`
# is null, `null != "X"` evaluates to null, and `filter` drops null predicates,
# so without it removed "X" rows would never be subtracted.
minus = (
    full
    .filter(
        (pl.col.category == "X")
        & ((pl.col.category_right != "X") | pl.col.category_right.is_null())
    )
    .select(pl.col.value.sum())
)

In [11]:
# Value that enters category "X": rows that are "X" in `new` and were either
# absent from `original` (null on the left side) or in a different category.
plus = (
    full
    .filter(
        (pl.col.category_right == "X")
        & (pl.col.category.is_null() | (pl.col.category != "X"))
    )
    .select(pl.col.value_right.sum())
)

In [12]:
# Incremental "X" total: baseline "X" sum, minus what left, plus what arrived.
# It should match the "X" row of the group-by over `new`.
(
    original
    .filter(pl.col.category == "X")
    .select(pl.col.value.sum())
    - minus
    + plus
)

value
f64
-1.7268e-306


In [13]:
# Reference totals from `new` again, shown next to the incremental result
# above so the "X" row can be compared with it.
new.group_by("category").agg(pl.col("value").sum())

category,value
enum,f64
"""Z""",-8.5754e+291
"""Y""",1.687e+221
"""X""",-1.7268e-306
