In [25]:
library(tidyverse)
library(tidymodels)
# Install themis only when it is not already available, so re-running this
# cell does not reinstall the package (and its dependencies) every time.
if (!requireNamespace("themis", quietly = TRUE)) {
  install.packages("themis")
}
library(themis)

also installing the dependencies ‘RANN’, ‘ROSE’


Updating HTML index of packages in '.Library'

Making 'packages.html' ...
 done



In [30]:
# Load the Cleveland data. Column names are passed explicitly via col_names,
# so the first line of the file is read as data rather than as a header.
# Columns tagged "categorical" were marked "#c" by the original author
# (presumably meaning categorical -- confirm).
# NOTE(review): readr parses `ca` and `thal` as character even though the
# previewed values look numeric, so the file probably contains non-numeric
# tokens in those columns -- inspect before using them as predictors.
cleveland <- read_csv(
  "processed.cleveland.data.csv",
  col_names = c(
    "age",
    "sex",      # categorical: 1 male, 0 female
    "cp",       # categorical: chest pain type
    "trestbps",
    "chol",
    "fbs",      # categorical
    "restecg",  # categorical
    "thalach",
    "exang",    # categorical
    "oldpeak",
    "slope",    # categorical
    "ca",
    "thal",     # categorical
    "num"
  )
)

# Preview the first five rows.
cleveland |>
  slice(1:5)
# Goal: predict chest pain type (cp) from the predictor variables.

[1mRows: [22m[34m303[39m [1mColumns: [22m[34m14[39m
[36m──[39m [1mColumn specification[22m [36m────────────────────────────────────────────────────────[39m
[1mDelimiter:[22m ","
[31mchr[39m  (2): ca, thal
[32mdbl[39m (12): age, sex, cp, trestbps, chol, fbs, restecg, thalach, exang, oldpea...

[36mℹ[39m Use `spec()` to retrieve the full column specification for this data.
[36mℹ[39m Specify the column types or set `show_col_types = FALSE` to quiet this message.


age,sex,cp,trestbps,chol,fbs,restecg,thalach,exang,oldpeak,slope,ca,thal,num
<dbl>,<dbl>,<dbl>,<dbl>,<dbl>,<dbl>,<dbl>,<dbl>,<dbl>,<dbl>,<dbl>,<chr>,<chr>,<dbl>
63,1,1,145,233,1,2,150,0,2.3,3,0.0,6.0,0
67,1,4,160,286,0,2,108,1,1.5,2,3.0,3.0,2
67,1,4,120,229,0,2,129,1,2.6,2,2.0,7.0,1
37,1,3,130,250,0,0,187,0,3.5,3,0.0,3.0,0
41,0,2,130,204,0,2,172,0,1.4,1,0.0,3.0,0


In [43]:
# Fix the RNG seed so the random parts of the pipeline (e.g. upsampling)
# are reproducible.
set.seed(101)

# Keep the four numeric predictors plus the outcome. cp (chest pain type) is
# converted to a factor so it can serve as a classification target.
data <- cleveland |>
  select(trestbps, chol, thalach, oldpeak, cp) |>
  mutate(cp = as_factor(cp))

# K-nearest-neighbours classifier with unweighted ("rectangular") votes.
spec <- nearest_neighbor(weight_func = "rectangular", neighbors = 2) |>
  set_engine("kknn") |>
  set_mode("classification")

# Preprocessing: scale and centre every predictor, then upsample cp with
# over_ratio = 1 to balance the class frequencies.
heart_recipe <- recipe(
  cp ~ trestbps + chol + thalach + oldpeak,
  data = data
) |>
  step_scale(all_predictors()) |>
  step_center(all_predictors()) |>
  step_upsample(cp, over_ratio = 1)

# Bundle recipe and model, then fit on the full data set.
heart_wf <- workflow() |>
  add_recipe(heart_recipe) |>
  add_model(spec) |>
  fit(data = data)

# Predict on the same rows and attach the original columns for comparison.
preds <- predict(heart_wf, data) |>
  bind_cols(data)

# Fraction of rows whose predicted class equals the recorded cp. The
# denominator is taken from the data instead of a hard-coded 303, so the
# result stays correct if the number of rows ever changes.
# NOTE(review): predictions are made on the very rows the workflow was fit
# on, so this is training accuracy and is likely optimistic; a held-out
# split or cross-validation would give a more honest estimate.
accuracy <- nrow(filter(preds, .pred_class == cp)) / nrow(preds)
accuracy

# neighbors vs. accuracy (accuracy measured on the same data used for fitting)
# 5 - .63
# 10 - .51
# 7 - .56
# 3 - .72
# 2 - .95


# improvements:
# 1. tune neighbours for max accuracy
# 2. oversample data in a better way, since cp type 1 is far less common than cp type 4 in the dataset (23 vs. 144 rows).

In [18]:
# Tally the number of rows for each chest pain type (cp) to show how
# imbalanced the classes are.
counts <- data |>
  count(cp, name = "count")

counts

cp,count
<fct>,<int>
1,23
2,50
3,86
4,144
