Libraries

In [1]:
library(repr)
library(tidyverse)
library(tidymodels)
library(ggplot2)

library(kknn)

── [1mAttaching packages[22m ─────────────────────────────────────── tidyverse 1.3.2 ──
[32m✔[39m [34mggplot2[39m 3.4.2     [32m✔[39m [34mpurrr  [39m 1.0.1
[32m✔[39m [34mtibble [39m 3.2.1     [32m✔[39m [34mdplyr  [39m 1.1.1
[32m✔[39m [34mtidyr  [39m 1.3.0     [32m✔[39m [34mstringr[39m 1.5.0
[32m✔[39m [34mreadr  [39m 2.1.3     [32m✔[39m [34mforcats[39m 0.5.2
── [1mConflicts[22m ────────────────────────────────────────── tidyverse_conflicts() ──
[31m✖[39m [34mdplyr[39m::[32mfilter()[39m masks [34mstats[39m::filter()
[31m✖[39m [34mdplyr[39m::[32mlag()[39m    masks [34mstats[39m::lag()
── [1mAttaching packages[22m ────────────────────────────────────── tidymodels 1.0.0 ──

[32m✔[39m [34mbroom       [39m 1.0.2     [32m✔[39m [34mrsample     [39m 1.1.1
[32m✔[39m [34mdials       [39m 1.1.0     [32m✔[39m [34mtune        [39m 1.0.1
[32m✔[39m [34minfer       [39m 1.0.4     [32m✔[39m [34mworkflows   [39m 1.1.2
[32m✔[39

Needed code from Patrick

In [2]:
url <- "https://raw.githubusercontent.com/perdomopatrick/group7/main/forestfires.csv"
data <- read_csv(url)

clean_data <- data|>
      select(-X,-Y,-month,-day)

head(clean_data, n= 5)

set.seed(1133) 

data_split <- initial_split(clean_data, prop = 0.75, strata = area)
data_training <- training(data_split)
data_testing <- testing(data_split)

[1mRows: [22m[34m517[39m [1mColumns: [22m[34m13[39m
[36m──[39m [1mColumn specification[22m [36m────────────────────────────────────────────────────────[39m
[1mDelimiter:[22m ","
[31mchr[39m  (2): month, day
[32mdbl[39m (11): X, Y, FFMC, DMC, DC, ISI, temp, RH, wind, rain, area

[36mℹ[39m Use `spec()` to retrieve the full column specification for this data.
[36mℹ[39m Specify the column types or set `show_col_types = FALSE` to quiet this message.


FFMC,DMC,DC,ISI,temp,RH,wind,rain,area
<dbl>,<dbl>,<dbl>,<dbl>,<dbl>,<dbl>,<dbl>,<dbl>,<dbl>
86.2,26.2,94.3,5.1,8.2,51,6.7,0.0,0
90.6,35.4,669.1,6.7,18.0,33,0.9,0.0,0
90.6,43.7,686.9,6.7,14.6,33,1.3,0.0,0
91.7,33.3,77.5,9.0,8.3,97,4.0,0.2,0
89.3,51.3,102.2,9.6,11.4,99,1.8,0.0,0


Tune your model to choose the best k-neighbours using cross validation.

In [3]:
# Tuning for best k-neighbours

fire_recipe <- recipe(area ~ RH + rain + DMC + wind + ISI, data = data_training) |>
  step_scale(all_predictors()) |>
  step_center(all_predictors())

fire_spec <- nearest_neighbor(weight_func = "rectangular", neighbors = tune()) |>
                  set_engine("kknn") |>
                  set_mode("regression")

fire_vfold <- vfold_cv(data_training, v = 5, strata = area)

fire_wkflw <- workflow() |>
  add_recipe(fire_wind_recipe) |>
  add_model(fire_wind_spec)

gridvals <- tibble(neighbors = seq(1, 100))

fire_results <- fire_wind_wkflw |>
  tune_grid(resamples = fire_wind_vfold, grid = gridvals) |>
  collect_metrics() |>
  filter(.metric == "rmse")

fire_min <- fire_wind_results |>
  filter(mean == min(mean))

fire_min

ERROR: Error in is_model_spec(spec): object 'fire_wind_spec' not found


In [None]:
# Visualization of best neighbors
fire_neighbors <- fire_wind_results |>
  ggplot(aes(x = neighbors, y = mean)) +
  geom_point() +
  geom_line(colour = "blue") +
  geom_vline(aes(xintercept = fire_wind_min$neighbors, colour = "red")) + # takes the neighbors value from min object
  labs(x = "Neighbors", y = "RMSE", caption = "Graph 2") +
  theme(legend.position="none")

fire_neighbors

Now perform the knn-regression on your test set.

In [None]:
# results with optimal knn neighbors value

kmin <- fire_wind_min |> pull(neighbors)

new_fire_spec <- nearest_neighbor(weight_func = "rectangular", neighbors = kmin) |>
  set_engine("kknn") |>
  set_mode("regression")

new_fire_fit <- workflow() |>
  add_recipe(fire_wind_recipe) |>
  add_model(new_fire_wind_spec) |>
  fit(data = data_training)

new_fire_results <- new_fire_wind_fit |>
  predict(data_testing) |>
  bind_cols(data_testing) |>
  metrics(truth = area, estimate = .pred) |>
  filter(.metric == 'rmse')

new_fire_results