# Ron Au | Assignment 2: Beer Review Project

## Brief
You have been tasked with deploying a machine learning model into production. You will train a custom neural network model that accurately predicts a type of beer from rating criteria such as appearance, aroma, palate, or taste. You will also need to build a web app and deploy it online to serve your model for real-time predictions.

## Meta / Options

In [15]:
import os
import sys

module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
  sys.path.append(module_path)

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [16]:
import numpy as np
import pandas as pd

In [17]:
from IPython.display import display

pd.options.display.max_columns = None

## Data Exploration

### Load dataset

In [18]:
df = pd.read_csv('../data/raw/beer_reviews.csv')

### Create new dataframe from only business features

In [19]:
dfc = df[['brewery_name', 'review_appearance', 'review_aroma', 'review_palate', 'review_taste', 'beer_style']]

In [20]:
display(dfc.head())
display(dfc.describe())
print(dfc.info())

Unnamed: 0,brewery_name,review_appearance,review_aroma,review_palate,review_taste,beer_style
0,Vecchio Birraio,2.5,2.0,1.5,1.5,Hefeweizen
1,Vecchio Birraio,3.0,2.5,3.0,3.0,English Strong Ale
2,Vecchio Birraio,3.0,2.5,3.0,3.0,Foreign / Export Stout
3,Vecchio Birraio,3.5,3.0,2.5,3.0,German Pilsener
4,Caldera Brewing Company,4.0,4.5,4.0,4.5,American Double / Imperial IPA


Unnamed: 0,review_appearance,review_aroma,review_palate,review_taste
count,1586614.0,1586614.0,1586614.0,1586614.0
mean,3.841642,3.735636,3.743701,3.79286
std,0.6160928,0.6976167,0.6822184,0.7319696
min,0.0,1.0,1.0,1.0
25%,3.5,3.5,3.5,3.5
50%,4.0,4.0,4.0,4.0
75%,4.0,4.0,4.0,4.5
max,5.0,5.0,5.0,5.0


<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1586614 entries, 0 to 1586613
Data columns (total 6 columns):
 #   Column             Non-Null Count    Dtype  
---  ------             --------------    -----  
 0   brewery_name       1586599 non-null  object 
 1   review_appearance  1586614 non-null  float64
 2   review_aroma       1586614 non-null  float64
 3   review_palate      1586614 non-null  float64
 4   review_taste       1586614 non-null  float64
 5   beer_style         1586614 non-null  object 
dtypes: float64(4), object(2)
memory usage: 72.6+ MB
None


In [21]:
num_brewery_name = len(dfc['brewery_name'].unique())
num_beer_style = len(dfc['beer_style'].unique())

print("Unique 'brewery_name': ", num_brewery_name)
print("Unique 'beer_style': ", num_beer_style)

Unique 'brewery_name':  5743
Unique 'beer_style':  104
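
One caveat about the counts above: `Series.unique()` keeps `NaN` as a distinct value, while `Series.nunique()` skips it by default. Since `brewery_name` still has 15 missing values at this point (1586614 rows, 1586599 non-null), the 5743 figure presumably includes one `NaN` entry. A small sketch with made-up data shows the difference:

```python
import numpy as np
import pandas as pd

# Made-up sample: two real brewery names, one missing value, one repeat
names = pd.Series(["Vecchio Birraio", "Caldera Brewing Company", np.nan, "Vecchio Birraio"])

# unique() returns NaN as its own entry, so it is counted here
count_with_nan = len(names.unique())

# nunique() ignores NaN unless dropna=False is passed
count_without_nan = names.nunique()

print(count_with_nan, count_without_nan)
```

Once the rows with missing values are dropped in the pre-processing step, both approaches give the same number.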


## Data Pre-processing

### Drop rows with empty values and reset index

In [22]:
dfc = dfc.dropna()
dfc = dfc.reset_index(drop=True)
display(dfc.info())

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1586599 entries, 0 to 1586598
Data columns (total 6 columns):
 #   Column             Non-Null Count    Dtype  
---  ------             --------------    -----  
 0   brewery_name       1586599 non-null  object 
 1   review_appearance  1586599 non-null  float64
 2   review_aroma       1586599 non-null  float64
 3   review_palate      1586599 non-null  float64
 4   review_taste       1586599 non-null  float64
 5   beer_style         1586599 non-null  object 
dtypes: float64(4), object(2)
memory usage: 72.6+ MB


None

### Separate features from target

In [23]:
X = dfc.iloc[:,0:-1] # feature columns
y = dfc.iloc[:,-1:] # target column

### Encode brewery_name column

ETA 3 minutes

In [24]:
from src.data.preprocess import target_encode

X, target_encoder = target_encode(X, y, cols=['brewery_name'])
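
`src.data.preprocess.target_encode` is not shown in this notebook. Judging by the `brewery_name_<style>` columns in the processed dataframe further down, it appears to replace each brewery with the share of its reviews that fall into each beer style. The sketch below reproduces that idea with `pd.crosstab` on made-up data; it is an assumption about the behaviour, not the project's implementation, and it applies no smoothing (the real values suggest some smoothing is used).

```python
import pandas as pd

# Made-up reviews: brewery "A" has four reviews, brewery "B" has two
reviews = pd.DataFrame({
    "brewery_name": ["A", "A", "A", "A", "B", "B"],
    "beer_style": ["Stout", "Stout", "IPA", "Bock", "IPA", "IPA"],
})

# Per-brewery share of each style (each row sums to 1)
style_share = pd.crosstab(
    reviews["brewery_name"], reviews["beer_style"], normalize="index"
).add_prefix("brewery_name_")

# Replace the brewery name with its style-share columns
encoded = (
    reviews[["brewery_name"]]
    .join(style_share, on="brewery_name")
    .drop(columns="brewery_name")
)

print(encoded)
```

Because the encoding is computed from the target, fitting it on the full dataset before splitting lets target information from the validation and test rows reach the features.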

### Encode beer_style column

1. Encode values
2. Extract mode
3. Re-encode using mode for unknown values

This lets values not seen during training fall back to the most common style as a best guess instead of raising an error.

In [25]:
from src.data.preprocess import ordinal_encode

y, ordinal_encoder = ordinal_encode(y, columns=['beer_style'])

mode = y.mode().iloc[0, 0]
print('Most common beer_style is:', ordinal_encoder.inverse_transform([[mode]])[0][0], '\nEncoded as:', mode)

Most common beer_style is: American IPA 
Encoded as: 12
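
The three steps listed above can be sketched without any library. `ModeFallbackEncoder` is a hypothetical class written for illustration only; the project's `ordinal_encode` helper is not shown and may work differently.

```python
from collections import Counter


class ModeFallbackEncoder:
    """Hypothetical ordinal encoder that maps unseen labels to the most common code."""

    def fit(self, labels):
        # Step 1: assign an integer code to every known label (alphabetical order)
        self.classes_ = sorted(set(labels))
        self.code_of_ = {label: code for code, label in enumerate(self.classes_)}
        # Step 2: remember the code of the most frequent label
        most_common_label = Counter(labels).most_common(1)[0][0]
        self.mode_code_ = self.code_of_[most_common_label]
        return self

    def transform(self, labels):
        # Step 3: unknown labels fall back to the mode instead of raising KeyError
        return [self.code_of_.get(label, self.mode_code_) for label in labels]

    def inverse_transform(self, codes):
        return [self.classes_[code] for code in codes]


styles = ["American IPA", "Hefeweizen", "American IPA", "Bock"]
encoder = ModeFallbackEncoder().fit(styles)
encoded = encoder.transform(["Bock", "Some Unseen Style"])
print(encoded, encoder.inverse_transform(encoded))
```

Here the unseen style is encoded as `American IPA`, the most frequent label in the sample.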


### Scale numerical data

In [26]:
from src.data.preprocess import min_max_scale

numeric_columns = ['review_appearance', 'review_aroma', 'review_palate', 'review_taste']
X, min_max_scaler = min_max_scale(X, numeric_columns)
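
As a worked example, min-max scaling maps each value to `(x - min) / (max - min)`. Using the column minimums and maximums from the `describe()` output above (`review_appearance` ranges from 0 to 5, the other three from 1 to 5), the first review row scales as follows. The helper `min_max` is defined here for illustration; the project's `min_max_scale` is assumed to wrap an equivalent transformation.

```python
def min_max(value, lowest, highest):
    """Scale a value into [0, 1] given the column's minimum and maximum."""
    return (value - lowest) / (highest - lowest)


# First row of the raw data: appearance 2.5, aroma 2.0, palate 1.5, taste 1.5
scaled_row = [
    min_max(2.5, 0.0, 5.0),  # review_appearance
    min_max(2.0, 1.0, 5.0),  # review_aroma
    min_max(1.5, 1.0, 5.0),  # review_palate
    min_max(1.5, 1.0, 5.0),  # review_taste
]
print(scaled_row)
```

These values match the first four columns of row 0 in the processed dataframe shown later.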

### Concatenate dataframe and save to CSV
ETA 1.5 minutes

In [27]:
dfc = pd.concat([X, y], axis=1)

dfc.to_csv('../data/processed/beer_reviews.csv', index=False)

### Load dataframe from CSV
ETA 30 seconds. Reloading the processed CSV means that, during experimentation, only the cells below need to be run, skipping the ~5 minutes of pre-processing above.

In [28]:
dfc = pd.read_csv('../data/processed/beer_reviews.csv')

In [29]:
display(dfc)
dfc.info()

Unnamed: 0,review_appearance,review_aroma,review_palate,review_taste,brewery_name_English Strong Ale,brewery_name_Foreign / Export Stout,brewery_name_German Pilsener,brewery_name_American Double / Imperial IPA,brewery_name_Herbed / Spiced Beer,brewery_name_Light Lager,brewery_name_Oatmeal Stout,brewery_name_American Pale Lager,brewery_name_Rauchbier,brewery_name_American Pale Ale (APA),brewery_name_American Porter,brewery_name_Belgian Strong Dark Ale,brewery_name_American IPA,brewery_name_American Stout,brewery_name_Russian Imperial Stout,brewery_name_American Amber / Red Ale,brewery_name_American Strong Ale,brewery_name_Märzen / Oktoberfest,brewery_name_American Adjunct Lager,brewery_name_American Blonde Ale,brewery_name_Euro Pale Lager,brewery_name_English Brown Ale,brewery_name_Scotch Ale / Wee Heavy,brewery_name_Fruit / Vegetable Beer,brewery_name_American Double / Imperial Stout,brewery_name_Belgian Pale Ale,brewery_name_English Bitter,brewery_name_English Porter,brewery_name_Irish Dry Stout,brewery_name_American Barleywine,brewery_name_Belgian Strong Pale Ale,brewery_name_Doppelbock,brewery_name_Maibock / Helles Bock,brewery_name_Pumpkin Ale,brewery_name_Dortmunder / Export Lager,brewery_name_Euro Strong Lager,brewery_name_Euro Dark Lager,brewery_name_Low Alcohol Beer,brewery_name_Weizenbock,brewery_name_Extra Special / Strong Bitter (ESB),brewery_name_Bock,brewery_name_English India Pale Ale (IPA),brewery_name_Altbier,brewery_name_Kölsch,brewery_name_Munich Dunkel Lager,brewery_name_Rye Beer,brewery_name_American Pale Wheat Ale,brewery_name_Milk / Sweet Stout,brewery_name_Schwarzbier,brewery_name_Vienna Lager,brewery_name_American Amber / Red Lager,brewery_name_Scottish Ale,brewery_name_Witbier,brewery_name_American Black Ale,brewery_name_Saison / Farmhouse Ale,brewery_name_English Barleywine,brewery_name_English Dark Mild Ale,brewery_name_California Common / Steam Beer,brewery_name_Czech Pilsener,brewery_name_English Pale Ale,brewery_name_Belgian 
IPA,brewery_name_Tripel,brewery_name_Flanders Oud Bruin,brewery_name_American Brown Ale,brewery_name_Winter Warmer,brewery_name_Smoked Beer,brewery_name_Dubbel,brewery_name_Flanders Red Ale,brewery_name_Dunkelweizen,brewery_name_Roggenbier,brewery_name_Keller Bier / Zwickel Bier,brewery_name_Belgian Dark Ale,brewery_name_Bière de Garde,brewery_name_Japanese Rice Lager,brewery_name_Black & Tan,brewery_name_Irish Red Ale,brewery_name_Chile Beer,brewery_name_English Stout,brewery_name_Cream Ale,brewery_name_American Wild Ale,brewery_name_American Double / Imperial Pilsner,brewery_name_Scottish Gruit / Ancient Herbed Ale,brewery_name_Wheatwine,brewery_name_American Dark Wheat Ale,brewery_name_American Malt Liquor,brewery_name_Baltic Porter,brewery_name_Munich Helles Lager,brewery_name_Kristalweizen,brewery_name_English Pale Mild Ale,brewery_name_Lambic - Fruit,brewery_name_Old Ale,brewery_name_Quadrupel (Quad),brewery_name_Braggot,brewery_name_Lambic - Unblended,brewery_name_Eisbock,brewery_name_Berliner Weissbier,brewery_name_Kvass,brewery_name_Faro,brewery_name_Gueuze,brewery_name_Gose,brewery_name_Happoshu,brewery_name_Sahti,brewery_name_Bière de Champagne / Bière Brut,beer_style
0,0.5,0.250,0.125,0.125,0.249389,0.249391,0.249416,0.000134,0.000016,0.000022,0.000028,0.000014,0.000006,0.000099,0.000079,0.000059,0.000183,0.000038,0.000084,0.000071,0.000050,0.000037,0.000048,0.000020,0.000028,0.000030,0.000027,0.000053,0.000079,0.00003,0.000014,0.000017,0.000020,0.000042,0.000049,0.000034,0.000017,0.000024,0.000007,0.000004,0.000007,0.000002,0.000015,0.000027,0.000018,0.000025,0.000012,0.000013,0.000012,0.000016,0.000038,0.000021,0.000015,0.000014,0.000015,0.000014,0.000047,0.000018,0.000049,0.000021,0.000004,0.000006,0.000020,0.000036,0.000019,0.000047,0.000008,0.000039,0.000032,0.000005,0.000031,0.000010,0.000011,7.262342e-07,0.000004,0.00001,0.00001,0.000002,0.000004,0.000012,0.000004,0.000005,0.000008,0.000028,0.000008,0.000004,0.000006,0.000002,0.000006,0.000018,0.000012,0.000003,0.000001,0.000017,0.000023,0.000028,0.000002,0.000002,0.000004,0.000005,4.628574e-07,9.490914e-07,0.000009,0.000001,3.755846e-07,0.000002,0.000002,65
1,0.6,0.375,0.500,0.500,0.249389,0.249391,0.249416,0.000134,0.000016,0.000022,0.000028,0.000014,0.000006,0.000099,0.000079,0.000059,0.000183,0.000038,0.000084,0.000071,0.000050,0.000037,0.000048,0.000020,0.000028,0.000030,0.000027,0.000053,0.000079,0.00003,0.000014,0.000017,0.000020,0.000042,0.000049,0.000034,0.000017,0.000024,0.000007,0.000004,0.000007,0.000002,0.000015,0.000027,0.000018,0.000025,0.000012,0.000013,0.000012,0.000016,0.000038,0.000021,0.000015,0.000014,0.000015,0.000014,0.000047,0.000018,0.000049,0.000021,0.000004,0.000006,0.000020,0.000036,0.000019,0.000047,0.000008,0.000039,0.000032,0.000005,0.000031,0.000010,0.000011,7.262342e-07,0.000004,0.00001,0.00001,0.000002,0.000004,0.000012,0.000004,0.000005,0.000008,0.000028,0.000008,0.000004,0.000006,0.000002,0.000006,0.000018,0.000012,0.000003,0.000001,0.000017,0.000023,0.000028,0.000002,0.000002,0.000004,0.000005,4.628574e-07,9.490914e-07,0.000009,0.000001,3.755846e-07,0.000002,0.000002,51
2,0.6,0.375,0.500,0.500,0.249389,0.249391,0.249416,0.000134,0.000016,0.000022,0.000028,0.000014,0.000006,0.000099,0.000079,0.000059,0.000183,0.000038,0.000084,0.000071,0.000050,0.000037,0.000048,0.000020,0.000028,0.000030,0.000027,0.000053,0.000079,0.00003,0.000014,0.000017,0.000020,0.000042,0.000049,0.000034,0.000017,0.000024,0.000007,0.000004,0.000007,0.000002,0.000015,0.000027,0.000018,0.000025,0.000012,0.000013,0.000012,0.000016,0.000038,0.000021,0.000015,0.000014,0.000015,0.000014,0.000047,0.000018,0.000049,0.000021,0.000004,0.000006,0.000020,0.000036,0.000019,0.000047,0.000008,0.000039,0.000032,0.000005,0.000031,0.000010,0.000011,7.262342e-07,0.000004,0.00001,0.00001,0.000002,0.000004,0.000012,0.000004,0.000005,0.000008,0.000028,0.000008,0.000004,0.000006,0.000002,0.000006,0.000018,0.000012,0.000003,0.000001,0.000017,0.000023,0.000028,0.000002,0.000002,0.000004,0.000005,4.628574e-07,9.490914e-07,0.000009,0.000001,3.755846e-07,0.000002,0.000002,59
3,0.7,0.500,0.375,0.500,0.249389,0.249391,0.249416,0.000134,0.000016,0.000022,0.000028,0.000014,0.000006,0.000099,0.000079,0.000059,0.000183,0.000038,0.000084,0.000071,0.000050,0.000037,0.000048,0.000020,0.000028,0.000030,0.000027,0.000053,0.000079,0.00003,0.000014,0.000017,0.000020,0.000042,0.000049,0.000034,0.000017,0.000024,0.000007,0.000004,0.000007,0.000002,0.000015,0.000027,0.000018,0.000025,0.000012,0.000013,0.000012,0.000016,0.000038,0.000021,0.000015,0.000014,0.000015,0.000014,0.000047,0.000018,0.000049,0.000021,0.000004,0.000006,0.000020,0.000036,0.000019,0.000047,0.000008,0.000039,0.000032,0.000005,0.000031,0.000010,0.000011,7.262342e-07,0.000004,0.00001,0.00001,0.000002,0.000004,0.000012,0.000004,0.000005,0.000008,0.000028,0.000008,0.000004,0.000006,0.000002,0.000006,0.000018,0.000012,0.000003,0.000001,0.000017,0.000023,0.000028,0.000002,0.000002,0.000004,0.000005,4.628574e-07,9.490914e-07,0.000009,0.000001,3.755846e-07,0.000002,0.000002,61
4,0.8,0.875,0.750,0.875,0.000000,0.000000,0.002250,0.001125,0.010124,0.000000,0.002250,0.001125,0.047244,0.185602,0.022497,0.040495,0.494938,0.001125,0.032621,0.128234,0.002250,0.001125,0.015748,0.007874,0.000000,0.000000,0.000000,0.002250,0.001125,0.00000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000e+00,0.000000,0.00000,0.00000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000e+00,0.000000e+00,0.000000,0.000000,0.000000e+00,0.000000,0.000000,9
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1586594,0.7,0.750,0.750,0.750,0.000000,0.000000,0.000000,0.007168,0.000000,0.000000,0.014337,0.007168,0.000000,0.017921,0.086022,0.078853,0.093190,0.046595,0.003584,0.043011,0.003584,0.003584,0.000000,0.000000,0.000000,0.003584,0.000000,0.000000,0.025090,0.00000,0.000000,0.003584,0.028674,0.000000,0.025090,0.000000,0.003584,0.060932,0.000000,0.000000,0.000000,0.000000,0.000000,0.039427,0.007168,0.000000,0.000000,0.000000,0.000000,0.000000,0.007168,0.003584,0.000000,0.000000,0.068100,0.000000,0.021505,0.000000,0.014337,0.000000,0.000000,0.003584,0.010753,0.068100,0.014337,0.146953,0.000000,0.003584,0.000000,0.000000,0.017921,0.003584,0.000000,0.000000e+00,0.000000,0.00000,0.00000,0.000000,0.000000,0.007168,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.003584,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000e+00,0.000000e+00,0.000000,0.000000,0.000000e+00,0.000000,0.000000,85
1586595,0.5,1.000,0.250,0.750,0.000000,0.000000,0.000000,0.007168,0.000000,0.000000,0.014337,0.007168,0.000000,0.017921,0.086022,0.078853,0.093190,0.046595,0.003584,0.043011,0.003584,0.003584,0.000000,0.000000,0.000000,0.003584,0.000000,0.000000,0.025090,0.00000,0.000000,0.003584,0.028674,0.000000,0.025090,0.000000,0.003584,0.060932,0.000000,0.000000,0.000000,0.000000,0.000000,0.039427,0.007168,0.000000,0.000000,0.000000,0.000000,0.000000,0.007168,0.003584,0.000000,0.000000,0.068100,0.000000,0.021505,0.000000,0.014337,0.000000,0.000000,0.003584,0.010753,0.068100,0.014337,0.146953,0.000000,0.003584,0.000000,0.000000,0.017921,0.003584,0.000000,0.000000e+00,0.000000,0.00000,0.00000,0.000000,0.000000,0.007168,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.003584,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000e+00,0.000000e+00,0.000000,0.000000,0.000000e+00,0.000000,0.000000,85
1586596,0.6,0.625,0.625,0.750,0.000000,0.000000,0.000000,0.007168,0.000000,0.000000,0.014337,0.007168,0.000000,0.017921,0.086022,0.078853,0.093190,0.046595,0.003584,0.043011,0.003584,0.003584,0.000000,0.000000,0.000000,0.003584,0.000000,0.000000,0.025090,0.00000,0.000000,0.003584,0.028674,0.000000,0.025090,0.000000,0.003584,0.060932,0.000000,0.000000,0.000000,0.000000,0.000000,0.039427,0.007168,0.000000,0.000000,0.000000,0.000000,0.000000,0.007168,0.003584,0.000000,0.000000,0.068100,0.000000,0.021505,0.000000,0.014337,0.000000,0.000000,0.003584,0.010753,0.068100,0.014337,0.146953,0.000000,0.003584,0.000000,0.000000,0.017921,0.003584,0.000000,0.000000e+00,0.000000,0.00000,0.00000,0.000000,0.000000,0.007168,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.003584,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000e+00,0.000000e+00,0.000000,0.000000,0.000000e+00,0.000000,0.000000,85
1586597,0.9,0.875,0.875,0.875,0.000000,0.000000,0.000000,0.007168,0.000000,0.000000,0.014337,0.007168,0.000000,0.017921,0.086022,0.078853,0.093190,0.046595,0.003584,0.043011,0.003584,0.003584,0.000000,0.000000,0.000000,0.003584,0.000000,0.000000,0.025090,0.00000,0.000000,0.003584,0.028674,0.000000,0.025090,0.000000,0.003584,0.060932,0.000000,0.000000,0.000000,0.000000,0.000000,0.039427,0.007168,0.000000,0.000000,0.000000,0.000000,0.000000,0.007168,0.003584,0.000000,0.000000,0.068100,0.000000,0.021505,0.000000,0.014337,0.000000,0.000000,0.003584,0.010753,0.068100,0.014337,0.146953,0.000000,0.003584,0.000000,0.000000,0.017921,0.003584,0.000000,0.000000e+00,0.000000,0.00000,0.00000,0.000000,0.000000,0.007168,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.003584,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000e+00,0.000000e+00,0.000000,0.000000,0.000000e+00,0.000000,0.000000,85


<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1586599 entries, 0 to 1586598
Columns: 108 entries, review_appearance to beer_style
dtypes: float64(107), int64(1)
memory usage: 1.3 GB
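
The jump from 72.6 MB to 1.3 GB follows directly from the shape: every one of the 108 columns now holds an 8-byte number (`float64` or `int64`) for each row. A quick arithmetic check, ignoring the small overhead of the index:

```python
rows = 1_586_599
columns = 108
bytes_per_value = 8  # float64 and int64 both use 8 bytes

total_bytes = rows * columns * bytes_per_value
total_gib = total_bytes / 2**30  # pandas reports sizes in units of 1024

print(f"{total_bytes:,} bytes = {total_gib:.2f} GiB")
```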


### Create dataset splits

In [30]:
from src.data.sets import split_sets_random, save_sets, load_sets

X_train, y_train, X_val, y_val, X_test, y_test = split_sets_random(dfc, target_col='beer_style', test_ratio=0.2, to_numpy=True)

save_sets(X_train=X_train, y_train=y_train, X_val=X_val, y_val=y_val, X_test=X_test, y_test=y_test, path='../data/processed/')
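
`split_sets_random` lives in `src.data.sets` and is not shown here. A minimal sketch of a random three-way split is given below, assuming that `test_ratio` is applied to both the validation and the test set; that assumption, the function name `split_random`, and the `seed` parameter are illustrative only.

```python
import numpy as np


def split_random(features, target, test_ratio=0.2, seed=0):
    """Shuffle row indices, then carve out test and validation sets of equal size."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(features))
    n_holdout = int(len(features) * test_ratio)

    test_idx = order[:n_holdout]
    val_idx = order[n_holdout:2 * n_holdout]
    train_idx = order[2 * n_holdout:]

    return (
        features[train_idx], target[train_idx],
        features[val_idx], target[val_idx],
        features[test_idx], target[test_idx],
    )


X_demo = np.arange(20).reshape(10, 2)
y_demo = np.arange(10)
X_tr, y_tr, X_va, y_va, X_te, y_te = split_random(X_demo, y_demo)
print(len(X_tr), len(X_va), len(X_te))
```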

In [31]:
from src.data.sets import load_sets

X_train, y_train, X_val, y_val, X_test, y_test = load_sets(path='../data/processed/')

In [32]:
from src.models.pytorch import PytorchDataset

train_dataset = PytorchDataset(X=X_train, y=y_train)
val_dataset = PytorchDataset(X=X_val, y=y_val)
test_dataset = PytorchDataset(X=X_test, y=y_test)
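
`PytorchDataset` is imported from `src.models.pytorch` and its source is not included. A wrapper of this kind usually only needs `__len__` and `__getitem__` so that `DataLoader` can batch it. The class `ArrayDataset` below is a hypothetical stand-in written under that assumption, not the project's code.

```python
import numpy as np
import torch
from torch.utils.data import Dataset


class ArrayDataset(Dataset):
    """Hypothetical stand-in for PytorchDataset: wraps feature and target arrays."""

    def __init__(self, X, y):
        # Store features as float32 and targets as int64 class indices
        self.X = torch.as_tensor(np.asarray(X), dtype=torch.float32)
        self.y = torch.as_tensor(np.asarray(y).ravel(), dtype=torch.long)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        # Return one (features, target) pair
        return self.X[index], self.y[index]


demo = ArrayDataset(np.random.rand(6, 3), np.arange(6))
features, label = demo[2]
print(len(demo), features.shape, label.item())
```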

### Establish baseline

In [33]:
from src.models.null import NullModel
from src.models.performance import print_class_perf

baseline_model = NullModel(target_type='classification')
y_base = baseline_model.fit_predict(y_train)

print_class_perf(y_base, y_train, set_name='Training', average='weighted')

Accuracy Training: 0.07430256975352931
F1 Training: 0.01027805764859095
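
`NullModel` is not shown; with `target_type='classification'` it presumably predicts the most frequent class for every row. Under that assumption the baseline accuracy is simply the share of the most common class, which would explain the roughly 7.4% above. A small sketch with made-up labels:

```python
from collections import Counter


def most_frequent_baseline(labels):
    """Predict the most common label for every observation."""
    most_common_label, _ = Counter(labels).most_common(1)[0]
    return [most_common_label] * len(labels)


def accuracy(predicted, actual):
    """Fraction of predictions that match the actual labels."""
    return sum(p == a for p, a in zip(predicted, actual)) / len(actual)


# Made-up encoded styles: label 12 appears in 3 of 8 rows
y_demo = [12, 12, 12, 9, 65, 51, 59, 61]
baseline_accuracy = accuracy(most_frequent_baseline(y_demo), y_demo)
print(baseline_accuracy)
```

Any trained model should beat this figure to be worth deploying.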


## Create Network Architecture

In [34]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [35]:
class PytorchMultiClass(nn.Module):
    def __init__(self, num_features, num_classes):
        super(PytorchMultiClass, self).__init__()
        
        self.layer_1 = nn.Linear(num_features, 256)
        self.layer_2 = nn.Linear(256, num_classes)
        self.softmax = nn.Softmax(dim=1)

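    def forward(self, x):
        x = F.dropout(F.relu(self.layer_1(x)), training=self.training)
        x = self.layer_2(x)
        # Note: the softmax output is later passed to nn.CrossEntropyLoss, which applies
        # log-softmax itself and expects raw logits; returning x directly is the usual pattern.
        return self.softmax(x)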

In [36]:
from src.models.pytorch import get_device

model = PytorchMultiClass(X_train.shape[1], num_beer_style)

device = get_device()
model.to(device)

print(device)
print(model)

cuda:0
PytorchMultiClass(
  (layer_1): Linear(in_features=107, out_features=256, bias=True)
  (layer_2): Linear(in_features=256, out_features=104, bias=True)
  (softmax): Softmax(dim=1)
)


In [37]:
from torch.utils.data import DataLoader

In [38]:
criterion = nn.CrossEntropyLoss()
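
`nn.CrossEntropyLoss` applies log-softmax internally, so it expects raw logits. The network above already ends in a softmax, which means the loss only ever sees values between 0 and 1 and cannot approach zero. The sketch below works this out in plain Python for 104 classes; `cross_entropy` is a hand-written helper for illustration, not the PyTorch function.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy for one sample: log-sum-exp of the inputs minus the target input."""
    log_sum_exp = math.log(sum(math.exp(value) for value in logits))
    return log_sum_exp - logits[target]


NUM_CLASSES = 104

# Most confident softmax output possible: probability 1 on class 0
one_hot_probs = [1.0] + [0.0] * (NUM_CLASSES - 1)
best_case = cross_entropy(one_hot_probs, target=0)   # correct and fully confident
worst_case = cross_entropy(one_hot_probs, target=1)  # wrong and fully confident

# Raw logits are unbounded, so the loss can get arbitrarily close to zero
raw_logits = [20.0] + [0.0] * (NUM_CLASSES - 1)
with_logits = cross_entropy(raw_logits, target=0)

print(f"{best_case:.4f} {worst_case:.4f} {with_logits:.2e}")
```

With probabilities as input, the per-sample loss is confined to a band exactly 1.0 wide, which limits the gradient signal and may contribute to a training plateau.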

In [39]:
optimiser = torch.optim.AdamW(model.parameters(), lr=0.0003, weight_decay=0.01)

In [40]:
def train_classification(train_data, model, criterion, optimiser, batch_size, device, scheduler=None, generate_batch=None):
    """Train a Pytorch multi-class classification model

    Parameters
    ----------
    train_data : torch.utils.data.Dataset
        Pytorch dataset
    model: torch.nn.Module
        Pytorch Model
    criterion: function
        Loss function
    optimiser: torch.optim.Optimizer
        Optimiser
    batch_size : int
        Number of observations per batch
    device : str or torch.device
        Device used for the model
    scheduler : torch.optim.lr_scheduler
        Pytorch Scheduler used for updating learning rate
    generate_batch : function
        Collate function defining required pre-processing steps

    Returns
    -------
    float
        Sum of batch losses divided by the number of observations
    float
        Accuracy score
    """
    
    # Set model to training mode
    model.train()
    train_loss = 0
    train_acc = 0

    # Create data loader
    data = DataLoader(train_data, batch_size=batch_size, shuffle=True, collate_fn=generate_batch)
    
    # Iterate through data by batch of observations
    for feature, target_class in data:
        
        # Reset gradients
        optimiser.zero_grad()
        
        # Load data to specified device
        feature = feature.to(device)
        target_class = target_class.to(device)
        
        # Make predictions
        #output = model(feature)
        output = model(feature.float())
        
        # Calculate loss for given batch
        #loss = criterion(output, target_class)
        loss = criterion(output, target_class.long())
        
        # Calculate global loss
        train_loss += loss.item()
        
        # Calculate gradients
        loss.backward()
        
        # Update Weights
        optimiser.step()

        # Calculate global accuracy
        train_acc += (output.argmax(1) == target_class).sum().item()

    # Adjust the learning rate
    if scheduler:
        scheduler.step()

    return train_loss / len(train_data), train_acc / len(train_data)

In [41]:
def test_classification(test_data, model, criterion, batch_size, device, generate_batch=None):
    """Calculate performance of a Pytorch multi-class classification model

    Parameters
    ----------
    test_data : torch.utils.data.Dataset
        Pytorch dataset
    model: torch.nn.Module
        Pytorch Model
    criterion: function
        Loss function
    batch_size : int
        Number of observations per batch
    device : str or torch.device
        Device used for the model
    generate_batch : function
        Collate function defining required pre-processing steps

    Returns
    -------
    float
        Sum of batch losses divided by the number of observations
    float
        Accuracy score
    """    
    
    # Set model to evaluation mode
    model.eval()
    test_loss = 0
    test_acc = 0

    # Create data loader
    data = DataLoader(test_data, batch_size=batch_size, collate_fn=generate_batch)
    
    # Iterate through data by batch of observations
    for feature, target_class in data:
        
        # Load data to specified device
        feature = feature.to(device)
        target_class = target_class.to(device)
        
        # Set no update to gradients
        with torch.no_grad():
            
            # Make predictions
            #output = model(feature)
            output = model(feature.float())
            
            # Calculate loss for given batch
            #loss = criterion(output, target_class)
            loss = criterion(output, target_class.long())
            
            # Calculate global loss
            test_loss += loss.item()

            # Calculate global accuracy
            test_acc += (output.argmax(1) == target_class).sum().item()

    return test_loss / len(test_data), test_acc / len(test_data)

## Train Model

### Load saved weights from previous training

In [42]:
#model = PytorchMultiClass(X_train.shape[1], num_beer_style)
optimiser = torch.optim.AdamW(model.parameters(), lr=0.0003, weight_decay=0.01)

state = torch.load('../models/deep_beer_state.pt')

model.load_state_dict(state['model_state_dict'])
optimiser.load_state_dict(state['optimiser_state_dict'])

epoch = state['epoch']
train_loss = state['train_loss']
train_acc = state['train_acc']
valid_loss = state['valid_loss']
valid_acc = state['valid_acc']

get_device()

device(type='cuda', index=0)

### Run Epochs

In [43]:
N_EPOCHS = 5
BATCH_SIZE = 32
if epoch < 1:
    epoch = 1

In [44]:
for i in range(N_EPOCHS):
    train_loss, train_acc = train_classification(train_dataset, model=model, criterion=criterion, optimiser=optimiser, batch_size=BATCH_SIZE, device=device)
    valid_loss, valid_acc = test_classification(val_dataset, model=model, criterion=criterion, batch_size=BATCH_SIZE, device=device)

    print(f'Epoch: {epoch}')
    print(f'\t(train)\t|\tLoss: {train_loss:.4f}\t|\tAcc: {train_acc * 100:.1f}%')
    print(f'\t(valid)\t|\tLoss: {valid_loss:.4f}\t|\tAcc: {valid_acc * 100:.1f}%')
    epoch += 1

Epoch: 85
	(train)	|	Loss: 0.1372	|	Acc: 27.2%
	(valid)	|	Loss: 0.1372	|	Acc: 27.2%
Epoch: 86
	(train)	|	Loss: 0.1372	|	Acc: 27.2%
	(valid)	|	Loss: 0.1372	|	Acc: 27.1%
Epoch: 87
	(train)	|	Loss: 0.1372	|	Acc: 27.2%
	(valid)	|	Loss: 0.1372	|	Acc: 27.1%
Epoch: 88
	(train)	|	Loss: 0.1372	|	Acc: 27.2%
	(valid)	|	Loss: 0.1372	|	Acc: 27.2%
Epoch: 89
	(train)	|	Loss: 0.1372	|	Acc: 27.2%
	(valid)	|	Loss: 0.1372	|	Acc: 27.2%
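
A note on reading these numbers: `train_classification` and `test_classification` add up each batch's mean loss and then divide by the number of observations, so the printed loss is roughly the average per-sample loss divided by the batch size (exactly so only when every batch is full). Rescaling gives a figure that can be compared with the loss of a uniform guess over 104 classes:

```python
import math

reported_loss = 0.1372  # value printed for each epoch above
batch_size = 32
num_classes = 104

# Approximate mean cross-entropy per sample
approx_mean_loss = reported_loss * batch_size

# Cross-entropy of predicting every class with equal probability
uniform_guess_loss = math.log(num_classes)

print(f"{approx_mean_loss:.4f} vs {uniform_guess_loss:.4f}")
```

The rescaled loss sits only slightly below the uniform-guess value and does not move across epochs 85 to 89, which suggests training has plateaued.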


## Save artefacts

In [45]:
import joblib

torch.save(model, '../models/deep_beer.pt')

torch.save({
  'model_state_dict': model.state_dict(),
  'optimiser_state_dict': optimiser.state_dict(),
  'epoch': epoch,
  'train_loss': train_loss,
  'train_acc': train_acc,
  'valid_loss': valid_loss,
  'valid_acc': valid_acc
}, '../models/deep_beer_state.pt')

joblib.dump(ordinal_encoder, '../models/ordinal_encoder.joblib')
joblib.dump(target_encoder, '../models/target_encoder.joblib')
joblib.dump(min_max_scaler, '../models/min_max_scaler.joblib')

['../models/min_max_scaler.joblib']

## Load model

In [46]:
model = PytorchMultiClass(X_train.shape[1], num_beer_style)
optimiser = torch.optim.AdamW(model.parameters(), lr=0.0003, weight_decay=0.01)

state = torch.load('../models/deep_beer_state.pt')

model.load_state_dict(state['model_state_dict'])
optimiser.load_state_dict(state['optimiser_state_dict'])

epoch = state['epoch']
train_loss = state['train_loss']
train_acc = state['train_acc']
valid_loss = state['valid_loss']
valid_acc = state['valid_acc']

model.to(device)

PytorchMultiClass(
  (layer_1): Linear(in_features=107, out_features=256, bias=True)
  (layer_2): Linear(in_features=256, out_features=104, bias=True)
  (softmax): Softmax(dim=1)
)

## Check performance of saved model

In [47]:
test_loss, test_acc = test_classification(test_dataset, model=model, criterion=criterion, batch_size=BATCH_SIZE, device=device)

print(f'Loss: {test_loss:.4f}\nAccuracy: {test_acc:.3f}')

Loss: 0.1372
Accuracy: 0.273


In [48]:
from src.models.performance import check_predictions

check_predictions(test_dataset, model, ordinal_encoder, 100)

0: American IPA | Oatmeal Stout
1: Bock | Winter Warmer
2: American Pale Ale (APA) | American Double / Imperial Stout
3: American Pale Ale (APA) | Quadrupel (Quad)
4: American Amber / Red Ale | American Brown Ale
5: American Double / Imperial IPA ✅
6: Russian Imperial Stout | Fruit / Vegetable Beer
7: Hefeweizen | Doppelbock
8: American IPA ✅
9: American IPA ✅
10: American Double / Imperial IPA | Witbier
11: English Barleywine ✅
12: Euro Pale Lager ✅
13: American Amber / Red Ale | American Amber / Red Lager
14: American IPA | Flanders Oud Bruin
15: Belgian Strong Pale Ale ✅
16: Saison / Farmhouse Ale | Dubbel
17: American Amber / Red Ale ✅
18: Hefeweizen | Munich Dunkel Lager
19: American Pale Ale (APA) | American Pale Wheat Ale
20: American Double / Imperial IPA ✅
21: English India Pale Ale (IPA) | Saison / Farmhouse Ale
22: Milk / Sweet Stout | American Double / Imperial Stout
23: American Amber / Red Ale | Maibock / Helles Bock
24: American Double / Imperial IPA | Saison / Farmhouse