In [27]:
# case_when method

import numpy as np
import pandas as pd
from pandas import Series

In [28]:
# check the installed pandas version: per the pandas documentation,
# Series.case_when was added in pandas 2.2.0, so an older install will not have the method
pd.__version__

'2.2.0'

In [29]:
# Series.case_when takes (condition, replacement) pairs: wherever a condition
# is True, that element is swapped for the matching replacement value

np.random.seed(0)   # fixed seed so the sample values are reproducible
s = Series(np.random.randint(low=0, high=1000, size=10))
s

0    684
1    559
2    629
3    192
4    835
5    763
6    707
7    359
8      9
9    723
dtype: int64

In [32]:
# goal: flip the sign of every value in s that is greater than 600
#
# the in-place way of doing it would be:
# s.loc[s>600] = -s
#
# case_when instead takes a "caselist", a list of 2-element tuples:
#   - element 0: the condition, basically a boolean series
#   - element 1: the values to use wherever that boolean series is True

s.case_when(caselist=[(s.gt(600), -s)])

0   -684
1    559
2   -629
3    192
4   -835
5   -763
6   -707
7    359
8      9
9   -723
dtype: int64

In [33]:
# case_when returned a new Series -- s itself still holds the original values
s

0    684
1    559
2    629
3    192
4    835
5    763
6    707
7    359
8      9
9    723
dtype: int64

In [34]:
# two rules at once: values of s above 600 get negated,
# and values of s below 200 get doubled

s.case_when(caselist=[
    (s.gt(600), -s),
    (s.lt(200), s.mul(2)),
])

0   -684
1    559
2   -629
3    384
4   -835
5   -763
6   -707
7    359
8     18
9   -723
dtype: int64

In [35]:
np.random.seed(0)
s = Series(np.random.randint(0, 1000, 10))

s.case_when(caselist=[(s%4==0, 4),
                      (s%3==0, 3)])

0      4
1    559
2    629
3      4
4    835
5    763
6    707
7    359
8      3
9      3
dtype: int64

In [36]:
# same two rules in the opposite order: the divisible-by-3 rule now comes first,
# so it is the one applied to values that are multiples of both 3 and 4
s.case_when(caselist=[
    (s.mod(3).eq(0), 3),
    (s.mod(4).eq(0), 4),
])

0      3
1    559
2    629
3      3
4    835
5    763
6    707
7    359
8      3
9      3
dtype: int64

In [37]:
# replacements can be strings as well; elements matching no condition
# keep their original numeric value
s.case_when(caselist=[
    (s.mod(3).eq(0), 'by 3'),
    (s.mod(4).eq(0), 'by 4'),
])

0    by 3
1     559
2     629
3    by 3
4     835
5     763
6     707
7     359
8    by 3
9    by 3
dtype: object

In [39]:
# what if I want a default value?
#
# add a final catch-all case whose condition is True for every row; the first
# matching case wins, so it only affects rows that no earlier case matched.
# An explicit all-True mask is used rather than `s == s`: NaN != NaN, so
# `s == s` would be False for missing values and they would skip the default.

s.case_when(caselist=[(s%3==0, 'by 3'),
                      (s%4==0, 'by 4'),
                      (Series(True, index=s.index), 'neither')])

0       by 3
1    neither
2    neither
3       by 3
4    neither
5    neither
6    neither
7    neither
8       by 3
9       by 3
dtype: object

In [40]:
np.random.seed(0)
s = Series(np.random.randint(0, 1000, 10))

# let's find out which numbers contain the digit 9
s.case_when(caselist=[(s.astype(str).str.contains('9'), 'has 9'),
                      (s==s, 'no 9')])

0     no 9
1    has 9
2    has 9
3    has 9
4     no 9
5     no 9
6     no 9
7    has 9
8    has 9
9     no 9
dtype: object

In [41]:
# any outlier values should be turned into NaN

np.random.seed(0)
s = Series(np.random.randint(0, 1000, 10))

s.case_when(caselist=[(s < s.mean() - s.std(), np.nan),
                      (s > s.mean() + s.std(), np.nan)])
                      

0    684.0
1    559.0
2    629.0
3      NaN
4      NaN
5    763.0
6    707.0
7    359.0
8      NaN
9    723.0
dtype: float64

In [44]:
np.random.seed(0)
s = Series(np.random.randint(0, 1000, 10))

s.case_when(caselist=[(lambda s_: s_.lt(500), 'low'),
                      (lambda s_: s_.ge(500), 'high')])
                      


0    high
1    high
2    high
3     low
4    high
5    high
6    high
7     low
8     low
9    high
dtype: object

In [45]:
from pandas import DataFrame


In [47]:
# 4x4 frame of reproducible random integers below 100,
# with rows labelled a-d and columns labelled w-z
np.random.seed(0)
df = DataFrame(
    data=np.random.randint(low=0, high=100, size=(4, 4)),
    index=['a', 'b', 'c', 'd'],
    columns=['w', 'x', 'y', 'z'],
)

df

Unnamed: 0,w,x,y,z
a,44,47,64,67
b,67,9,83,21
c,36,87,70,88
d,88,12,58,65


In [50]:
# label column w: 9 where the value is above the column mean, 0 everywhere else
#
# the catch-all default builds an explicit all-True mask from the index
# instead of `s_ == s_`; NaN != NaN, so the self-comparison would be False
# for missing values and they would never receive the default
df['w'].case_when([(lambda s_: s_ > s_.mean(), 9),
                   (lambda s_: Series(True, index=s_.index), 0)])

a    0
b    9
c    0
d    9
Name: w, dtype: int64