In [3]:
# Install necessary libraries
!pip install pytorch-tabnet
!pip install captum
!pip install optuna
!pip install imbalanced-learn
!pip install dask-expr
!pip install scikit-learn-contrib
!pip install lightgbm

# Data manipulation and analysis
import pandas as pd
import numpy as np

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Preprocessing and modeling
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

# Handling imbalanced data
from imblearn.over_sampling import SMOTE
from imblearn.under_sampling import TomekLinks

# Deep Learning Model
from pytorch_tabnet.tab_model import TabNetClassifier

# Explainable AI
import shap
from captum.attr import IntegratedGradients

# Hyperparameter Optimization
import optuna
from optuna import Trial

# Suppress warnings for cleaner output
import warnings
warnings.filterwarnings('ignore')

# For model saving and loading
import joblib

# Import torch for TabNet
import torch

# Define the PermutationImportanceTabNet class
class PermutationImportanceTabNet(TabNetClassifier):
    def __init__(self, input_dim, feature_names, permutation_prob=0.1, importance_decay=0.99, *args, **kwargs):
        """
        Initializes the PermutationImportanceTabNet.

        Parameters:
        - input_dim (int): Number of input features.
        - feature_names (list): List of feature names.
        - permutation_prob (float): Probability of applying permutation during a forward pass.
        - importance_decay (float): Decay factor for importance scores to smooth over epochs.
        - *args, **kwargs: Additional arguments for TabNetClassifier.
        """
        super(PermutationImportanceTabNet, self).__init__(input_dim=input_dim, *args, **kwargs)
        self.permutation_prob = permutation_prob
        self.importance_scores = torch.zeros(input_dim)
        self.importance_decay = importance_decay  # To smooth importance scores
        self.feature_names = feature_names  # List of feature names for interpretability

    def forward(self, X, y=None):
        """
        Overrides the forward pass to include permutation-based feature importance.

        Parameters:
        - X (torch.Tensor): Input features.
        - y (torch.Tensor, optional): Target labels.

        Returns:
        - out (torch.Tensor): Model outputs.
        - M_loss (float): Mask loss.
        """
        # Original forward pass
        out, M_loss = super(PermutationImportanceTabNet, self).forward(X, y)

        # Apply permutation with a certain probability
        if torch.rand(1).item() < self.permutation_prob:
            # Iterate over each feature to assess its importance
            for i in range(X.size(1)):
                # Clone the input to avoid in-place modifications
                X_permuted = X.clone()

                # Permute the values of the i-th feature across the batch
                X_permuted[:, i] = X_permuted[torch.randperm(X_permuted.size(0)), i]

                # Forward pass with permuted feature
                out_permuted, _ = super(PermutationImportanceTabNet, self).forward(X_permuted, y)

                # Compute predictions
                preds = out.argmax(dim=1)
                preds_permuted = out_permuted.argmax(dim=1)

                # Calculate accuracy
                acc = accuracy_score(y.cpu().numpy(), preds.cpu().numpy())
                acc_perm = accuracy_score(y.cpu().numpy(), preds_permuted.cpu().numpy())

                # Drop in accuracy signifies feature importance
                drop = acc - acc_perm

                # Update importance scores with decay
                self.importance_scores[i] = self.importance_decay * self.importance_scores[i] + (1 - self.importance_decay) * drop

            # Normalize importance scores to sum to 1 for interpretability
            if self.importance_scores.sum() != 0:
                self.importance_scores = self.importance_scores / self.importance_scores.sum()

            # Print feature importance scores
            print("\nFeature Importance Scores after Permutation:")
            for idx, score in enumerate(self.importance_scores):
                feature_name = self.feature_names[idx]
                print(f"{feature_name}: {score.item():.4f}")

        return out, M_loss

# Load the dataset
data = pd.read_csv('/content/fetal_health.csv')

# Display the first five rows to verify
print("First five rows of the dataset:")
print(data.head())

# Check the shape of the dataset
print(f"\nDataset Shape: {data.shape}")

# Features to drop based on prior analysis
features_to_drop = [
    'fetal_movement',
    'histogram_width',
    'histogram_max',
    'mean_value_of_long_term_variability',
    'histogram_number_of_peaks',
    'light_decelerations',
    'histogram_tendency',
    'histogram_number_of_zeroes',
    'severe_decelerations',
    'baseline value',
    'histogram_min'
]

# Drop the specified features
data_dropped = data.drop(columns=features_to_drop)

# Verify the remaining features
print("\nFeatures after dropping less important ones:")
print(data_dropped.columns.tolist())

# Check the new shape of the dataset
print(f"\nNew Dataset Shape after dropping features: {data_dropped.shape}")

# Convert 'fetal_health' to integer
data_dropped['fetal_health'] = data_dropped['fetal_health'].astype(int)

# Mapping numerical classes to descriptive labels
health_mapping = {1: 'Normal', 2: 'Suspect', 3: 'Pathological'}
data_dropped['fetal_health_label'] = data_dropped['fetal_health'].map(health_mapping)

# Display the mapping
print("\nDataset with Mapped Labels:")
print(data_dropped[['fetal_health', 'fetal_health_label']].head())

# Features and target
X = data_dropped.drop(['fetal_health', 'fetal_health_label'], axis=1)
y = data_dropped['fetal_health']

# Initialize SMOTE with 'auto' strategy to resample all classes
smote = SMOTE(sampling_strategy='auto', random_state=42)

# Apply SMOTE to the dataset
X_smote, y_smote = smote.fit_resample(X, y)

# Initialize Tomek Links
tomek = TomekLinks()

# Apply Tomek Links to clean the dataset
X_resampled, y_resampled = tomek.fit_resample(X_smote, y_smote)

# Display the shape of the resampled dataset and class distribution
print(f"\nResampled X shape after SMOTE + Tomek Links: {X_resampled.shape}")
print(f"Resampled y distribution after SMOTE + Tomek Links:\n{y_resampled.value_counts()}")

# Split the resampled data (70% train, 30% test) with stratification
X_train, X_test, y_train, y_test = train_test_split(
    X_resampled, y_resampled, test_size=0.3, random_state=42, stratify=y_resampled
)

# Display the shapes of the training and testing sets
print(f"\nTraining set shape: {X_train.shape}")
print(f"Testing set shape: {X_test.shape}")

# Initialize the MinMaxScaler
scaler = MinMaxScaler()

# Fit the scaler on the training data and transform
X_train_scaled = scaler.fit_transform(X_train)

# Transform the testing data
X_test_scaled = scaler.transform(X_test)

# Convert the scaled arrays back to DataFrames for easier handling
X_train_scaled = pd.DataFrame(X_train_scaled, columns=X.columns, index=X_train.index)
X_test_scaled = pd.DataFrame(X_test_scaled, columns=X.columns, index=X_test.index)

# Verify scaling by checking min and max values
print("\nMin of Scaled Training Features (Should be 0):")
print(X_train_scaled.min())

print("\nMax of Scaled Training Features (Should be 1):")
print(X_train_scaled.max())

# Adjust the target values so they start from 0
y_train = y_train - 1
y_test = y_test - 1

# Display the adjusted target distributions
print("\nAdjusted y_train distribution:")
print(pd.Series(y_train).value_counts())

print("\nAdjusted y_test distribution:")
print(pd.Series(y_test).value_counts())

# Further split the training data into training and validation sets
X_train_final, X_valid, y_train_final, y_valid = train_test_split(
    X_train_scaled, y_train, test_size=0.2, random_state=42, stratify=y_train
)

# Display the shapes of the final training and validation sets
print(f"\nFinal Training set shape: {X_train_final.shape}")
print(f"Validation set shape: {X_valid.shape}")

# -------------------
# Hyperparameter Optimization with Optuna
# -------------------
def objective(trial: Trial):
    # Define the hyperparameter space
    n_d = trial.suggest_int('n_d', 32, 128)
    n_a = trial.suggest_int('n_a', 32, 128)
    n_steps = trial.suggest_int('n_steps', 3, 10)
    gamma = trial.suggest_float('gamma', 1.0, 2.0)
    lambda_sparse = trial.suggest_float('lambda_sparse', 1e-4, 1e-2, log=True)
    learning_rate = trial.suggest_float('learning_rate', 1e-3, 1e-1, log=True)
    batch_size = trial.suggest_categorical('batch_size', [128, 256, 512])

    # Initialize TabNet with current hyperparameters
    tabnet = TabNetClassifier(
        n_d=n_d,
        n_a=n_a,
        n_steps=n_steps,
        gamma=gamma,
        lambda_sparse=lambda_sparse,
        optimizer_fn=torch.optim.Adam,
        optimizer_params=dict(lr=learning_rate),
        mask_type='sparsemax',
        verbose=0
    )

    # Train the model on the final training set
    tabnet.fit(
        X_train=X_train_final.values,
        y_train=y_train_final.values,
        eval_set=[(X_valid.values, y_valid.values)],
        eval_name=['valid'],
        eval_metric=['accuracy'],
        max_epochs=100,
        patience=20,
        batch_size=batch_size,
        virtual_batch_size=128
    )

    # Predict on the validation set
    y_pred = tabnet.predict(X_valid.values)
    accuracy = accuracy_score(y_valid, y_pred)

    return accuracy

# Create and optimize the Optuna study
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=50, timeout=3600)  # Adjust n_trials and timeout as needed

print("Best Hyperparameters: ", study.best_params)
print("Best Validation Accuracy: ", study.best_value)



[31mERROR: Could not find a version that satisfies the requirement scikit-learn-contrib (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for scikit-learn-contrib[0m[31m


[I 2025-01-12 07:05:55,243] A new study created in memory with name: no-name-739949e5-f554-43ec-8096-7d7fbfac204c


First five rows of the dataset:
   baseline value  accelerations  fetal_movement  uterine_contractions  \
0           120.0          0.000             0.0                 0.000   
1           132.0          0.006             0.0                 0.006   
2           133.0          0.003             0.0                 0.008   
3           134.0          0.003             0.0                 0.008   
4           132.0          0.007             0.0                 0.008   

   light_decelerations  severe_decelerations  prolongued_decelerations  \
0                0.000                   0.0                       0.0   
1                0.003                   0.0                       0.0   
2                0.003                   0.0                       0.0   
3                0.003                   0.0                       0.0   
4                0.000                   0.0                       0.0   

   abnormal_short_term_variability  mean_value_of_short_term_variability  \
0 

[I 2025-01-12 07:09:08,148] Trial 0 finished with value: 0.950937950937951 and parameters: {'n_d': 128, 'n_a': 67, 'n_steps': 7, 'gamma': 1.6614778909407883, 'lambda_sparse': 0.0003674811482794505, 'learning_rate': 0.053065035211493534, 'batch_size': 256}. Best is trial 0 with value: 0.950937950937951.



Early stopping occurred at epoch 75 with best_epoch = 55 and best_valid_accuracy = 0.96825


[I 2025-01-12 07:12:38,267] Trial 1 finished with value: 0.9682539682539683 and parameters: {'n_d': 69, 'n_a': 76, 'n_steps': 10, 'gamma': 1.036562066833326, 'lambda_sparse': 0.00017310020872191319, 'learning_rate': 0.0025858439594720482, 'batch_size': 128}. Best is trial 1 with value: 0.9682539682539683.



Early stopping occurred at epoch 73 with best_epoch = 53 and best_valid_accuracy = 0.95382


[I 2025-01-12 07:16:05,699] Trial 2 finished with value: 0.9538239538239538 and parameters: {'n_d': 121, 'n_a': 65, 'n_steps': 8, 'gamma': 1.5431685896698735, 'lambda_sparse': 0.0010945116691778315, 'learning_rate': 0.011328685093238472, 'batch_size': 128}. Best is trial 1 with value: 0.9682539682539683.



Early stopping occurred at epoch 96 with best_epoch = 76 and best_valid_accuracy = 0.95094


[I 2025-01-12 07:18:31,878] Trial 3 finished with value: 0.950937950937951 and parameters: {'n_d': 99, 'n_a': 34, 'n_steps': 8, 'gamma': 1.4585069841194684, 'lambda_sparse': 0.00020678759684493472, 'learning_rate': 0.09935821040527626, 'batch_size': 256}. Best is trial 1 with value: 0.9682539682539683.


Stop training because you reached max_epochs = 100 with best_epoch = 99 and best_valid_accuracy = 0.94228


[I 2025-01-12 07:21:29,830] Trial 4 finished with value: 0.9422799422799423 and parameters: {'n_d': 101, 'n_a': 95, 'n_steps': 8, 'gamma': 1.5142795723030338, 'lambda_sparse': 0.008783603601289715, 'learning_rate': 0.006670546160823257, 'batch_size': 512}. Best is trial 1 with value: 0.9682539682539683.



Early stopping occurred at epoch 72 with best_epoch = 52 and best_valid_accuracy = 0.95527


[I 2025-01-12 07:22:53,845] Trial 5 finished with value: 0.9552669552669553 and parameters: {'n_d': 99, 'n_a': 37, 'n_steps': 6, 'gamma': 1.3641102038496329, 'lambda_sparse': 0.0008296876212410576, 'learning_rate': 0.026007212580657062, 'batch_size': 256}. Best is trial 1 with value: 0.9682539682539683.


Stop training because you reached max_epochs = 100 with best_epoch = 80 and best_valid_accuracy = 0.95527


[I 2025-01-12 07:24:10,085] Trial 6 finished with value: 0.9552669552669553 and parameters: {'n_d': 51, 'n_a': 64, 'n_steps': 4, 'gamma': 1.8656905177279732, 'lambda_sparse': 0.00010726451072331761, 'learning_rate': 0.05078031831722836, 'batch_size': 256}. Best is trial 1 with value: 0.9682539682539683.



Early stopping occurred at epoch 69 with best_epoch = 49 and best_valid_accuracy = 0.96392


[I 2025-01-12 07:25:12,190] Trial 7 finished with value: 0.963924963924964 and parameters: {'n_d': 115, 'n_a': 65, 'n_steps': 3, 'gamma': 1.0429480562724731, 'lambda_sparse': 0.0007551739774701521, 'learning_rate': 0.005218127227572316, 'batch_size': 256}. Best is trial 1 with value: 0.9682539682539683.


Stop training because you reached max_epochs = 100 with best_epoch = 90 and best_valid_accuracy = 0.94228


[I 2025-01-12 07:26:50,712] Trial 8 finished with value: 0.9422799422799423 and parameters: {'n_d': 46, 'n_a': 112, 'n_steps': 5, 'gamma': 1.7799227601263083, 'lambda_sparse': 0.00834015039026745, 'learning_rate': 0.007121248000041844, 'batch_size': 512}. Best is trial 1 with value: 0.9682539682539683.



Early stopping occurred at epoch 78 with best_epoch = 58 and best_valid_accuracy = 0.96392


[I 2025-01-12 07:28:40,494] Trial 9 finished with value: 0.963924963924964 and parameters: {'n_d': 107, 'n_a': 83, 'n_steps': 5, 'gamma': 1.0162394730548785, 'lambda_sparse': 0.0001238417432903007, 'learning_rate': 0.0015244904771003636, 'batch_size': 256}. Best is trial 1 with value: 0.9682539682539683.



Early stopping occurred at epoch 50 with best_epoch = 30 and best_valid_accuracy = 0.94084


[I 2025-01-12 07:31:23,338] Trial 10 finished with value: 0.9408369408369408 and parameters: {'n_d': 68, 'n_a': 109, 'n_steps': 10, 'gamma': 1.2532342769982632, 'lambda_sparse': 0.0026388398503359987, 'learning_rate': 0.0011322048274515432, 'batch_size': 128}. Best is trial 1 with value: 0.9682539682539683.



Early stopping occurred at epoch 47 with best_epoch = 27 and best_valid_accuracy = 0.96392


[I 2025-01-12 07:32:08,487] Trial 11 finished with value: 0.963924963924964 and parameters: {'n_d': 72, 'n_a': 56, 'n_steps': 3, 'gamma': 1.0444508374914492, 'lambda_sparse': 0.00046123242949527615, 'learning_rate': 0.0028000590058821786, 'batch_size': 128}. Best is trial 1 with value: 0.9682539682539683.



Early stopping occurred at epoch 90 with best_epoch = 70 and best_valid_accuracy = 0.96104


[I 2025-01-12 07:36:36,138] Trial 12 finished with value: 0.961038961038961 and parameters: {'n_d': 84, 'n_a': 84, 'n_steps': 10, 'gamma': 1.1857606550935684, 'lambda_sparse': 0.002120385421880758, 'learning_rate': 0.0034361300849609765, 'batch_size': 128}. Best is trial 1 with value: 0.9682539682539683.



Early stopping occurred at epoch 57 with best_epoch = 37 and best_valid_accuracy = 0.97114


[I 2025-01-12 07:37:41,904] Trial 13 finished with value: 0.9711399711399712 and parameters: {'n_d': 33, 'n_a': 128, 'n_steps': 3, 'gamma': 1.1761592252229112, 'lambda_sparse': 0.00036511785702883487, 'learning_rate': 0.002956299320724345, 'batch_size': 128}. Best is trial 13 with value: 0.9711399711399712.


Stop training because you reached max_epochs = 100 with best_epoch = 83 and best_valid_accuracy = 0.96392


[I 2025-01-12 07:41:34,369] Trial 14 finished with value: 0.963924963924964 and parameters: {'n_d': 33, 'n_a': 99, 'n_steps': 9, 'gamma': 1.2322680575282827, 'lambda_sparse': 0.00026983692414714945, 'learning_rate': 0.0021734319948215623, 'batch_size': 128}. Best is trial 13 with value: 0.9711399711399712.



Early stopping occurred at epoch 85 with best_epoch = 65 and best_valid_accuracy = 0.96537


[I 2025-01-12 07:44:38,620] Trial 15 finished with value: 0.9653679653679653 and parameters: {'n_d': 56, 'n_a': 125, 'n_steps': 6, 'gamma': 1.1859530680486106, 'lambda_sparse': 0.00019636845417664623, 'learning_rate': 0.013037687984984772, 'batch_size': 128}. Best is trial 13 with value: 0.9711399711399712.



Early stopping occurred at epoch 61 with best_epoch = 41 and best_valid_accuracy = 0.95094


[I 2025-01-12 07:45:34,977] Trial 16 finished with value: 0.950937950937951 and parameters: {'n_d': 32, 'n_a': 48, 'n_steps': 4, 'gamma': 1.9900792198671846, 'lambda_sparse': 0.0005026759442447027, 'learning_rate': 0.004030266424807247, 'batch_size': 128}. Best is trial 13 with value: 0.9711399711399712.



Early stopping occurred at epoch 74 with best_epoch = 54 and best_valid_accuracy = 0.96825


[I 2025-01-12 07:48:57,201] Trial 17 finished with value: 0.9682539682539683 and parameters: {'n_d': 84, 'n_a': 126, 'n_steps': 7, 'gamma': 1.3740824663704292, 'lambda_sparse': 0.00016282662917101187, 'learning_rate': 0.0018297875834455389, 'batch_size': 128}. Best is trial 13 with value: 0.9711399711399712.


Stop training because you reached max_epochs = 100 with best_epoch = 98 and best_valid_accuracy = 0.95671


[I 2025-01-12 07:51:22,518] Trial 18 finished with value: 0.9567099567099567 and parameters: {'n_d': 63, 'n_a': 78, 'n_steps': 9, 'gamma': 1.1258275143437144, 'lambda_sparse': 0.0003310441649102684, 'learning_rate': 0.01826128359478549, 'batch_size': 512}. Best is trial 13 with value: 0.9711399711399712.



Early stopping occurred at epoch 72 with best_epoch = 52 and best_valid_accuracy = 0.95382


[I 2025-01-12 07:53:09,374] Trial 19 finished with value: 0.9538239538239538 and parameters: {'n_d': 43, 'n_a': 93, 'n_steps': 5, 'gamma': 1.330788454583992, 'lambda_sparse': 0.0013285797077193019, 'learning_rate': 0.0010334815092520459, 'batch_size': 128}. Best is trial 13 with value: 0.9711399711399712.



Early stopping occurred at epoch 82 with best_epoch = 62 and best_valid_accuracy = 0.96681


[I 2025-01-12 07:57:28,829] Trial 20 finished with value: 0.9668109668109668 and parameters: {'n_d': 77, 'n_a': 112, 'n_steps': 9, 'gamma': 1.1034189447326803, 'lambda_sparse': 0.0005465362845907321, 'learning_rate': 0.0029228890950667324, 'batch_size': 128}. Best is trial 13 with value: 0.9711399711399712.



Early stopping occurred at epoch 82 with best_epoch = 62 and best_valid_accuracy = 0.9596


[I 2025-01-12 08:01:03,321] Trial 21 finished with value: 0.9595959595959596 and parameters: {'n_d': 88, 'n_a': 114, 'n_steps': 7, 'gamma': 1.3414129811597204, 'lambda_sparse': 0.0001640907562665475, 'learning_rate': 0.001866081873148882, 'batch_size': 128}. Best is trial 13 with value: 0.9711399711399712.



Early stopping occurred at epoch 93 with best_epoch = 73 and best_valid_accuracy = 0.96248


[I 2025-01-12 08:05:21,911] Trial 22 finished with value: 0.9624819624819625 and parameters: {'n_d': 84, 'n_a': 127, 'n_steps': 7, 'gamma': 1.4214468942829412, 'lambda_sparse': 0.0002334756135464357, 'learning_rate': 0.0014508600133250587, 'batch_size': 128}. Best is trial 13 with value: 0.9711399711399712.



Early stopping occurred at epoch 62 with best_epoch = 42 and best_valid_accuracy = 0.96825


[I 2025-01-12 08:06:55,163] Trial 23 finished with value: 0.9682539682539683 and parameters: {'n_d': 60, 'n_a': 120, 'n_steps': 4, 'gamma': 1.28929895439721, 'lambda_sparse': 0.00014294646624166726, 'learning_rate': 0.0022519727233780364, 'batch_size': 128}. Best is trial 13 with value: 0.9711399711399712.


Best Hyperparameters:  {'n_d': 33, 'n_a': 128, 'n_steps': 3, 'gamma': 1.1761592252229112, 'lambda_sparse': 0.00036511785702883487, 'learning_rate': 0.002956299320724345, 'batch_size': 128}
Best Validation Accuracy:  0.9711399711399712


In [4]:
# 8. Retrain TabNet with Best Hyperparameters Using PermutationImportanceTabNet
# -------------------
# Extract best hyperparameters
best_params = study.best_params

# Adjust keys if necessary
# (Ensure that 'learning_rate' and 'batch_size' are correctly handled)
# In this case, no adjustment is needed as keys are consistent

# Define feature names for interpretability
feature_names = X.columns.tolist()

# Determine the input dimension from the training data
input_dim = X_train_final.shape[1]

# Initialize the Permutation Importance TabNet with the correct input_dim and feature_names
perm_importance_tabnet = PermutationImportanceTabNet(
    input_dim=input_dim,
    feature_names=feature_names,
    n_d=best_params['n_d'],
    n_a=best_params['n_a'],
    n_steps=best_params['n_steps'],
    gamma=best_params['gamma'],
    lambda_sparse=best_params['lambda_sparse'],
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=best_params['learning_rate']),
    mask_type='sparsemax',
    permutation_prob=0.1,          # 10% chance to apply permutation
    importance_decay=0.99,         # Decay factor for smoothing
    verbose=1
)

# Train the Permutation Importance TabNet model
perm_importance_tabnet.fit(
    X_train=X_train_final.values,
    y_train=y_train_final.values,
    eval_set=[(X_valid.values, y_valid.values), (X_test_scaled.values, y_test.values)],
    eval_name=['train', 'valid'],
    eval_metric=['accuracy'],
    max_epochs=100,
    patience=20,
    batch_size=best_params['batch_size'],
    virtual_batch_size=128
)

# Predict and evaluate on the test set
y_pred_perm_importance = perm_importance_tabnet.predict(X_test_scaled.values)
print("\nPermutation Importance TabNet Classification Report:")
print(classification_report(y_test, y_pred_perm_importance, target_names=['Normal', 'Suspect', 'Pathological']))

# Access and print feature importance scores
print("\nFeature Importance Scores:")
for idx, score in enumerate(perm_importance_tabnet.importance_scores):
    feature_name = perm_importance_tabnet.feature_names[idx]
    print(f"{feature_name}: {score.item():.4f}")

epoch 0  | loss: 0.73588 | train_accuracy: 0.34632 | valid_accuracy: 0.35243 |  0:00:03s
epoch 1  | loss: 0.36865 | train_accuracy: 0.37807 | valid_accuracy: 0.36927 |  0:00:08s
epoch 2  | loss: 0.29494 | train_accuracy: 0.5267  | valid_accuracy: 0.51011 |  0:00:15s
epoch 3  | loss: 0.26794 | train_accuracy: 0.57576 | valid_accuracy: 0.58288 |  0:00:20s
epoch 4  | loss: 0.24623 | train_accuracy: 0.56999 | valid_accuracy: 0.56469 |  0:00:24s
epoch 5  | loss: 0.24563 | train_accuracy: 0.63203 | valid_accuracy: 0.65499 |  0:00:29s
epoch 6  | loss: 0.23157 | train_accuracy: 0.65945 | valid_accuracy: 0.66509 |  0:00:32s
epoch 7  | loss: 0.20214 | train_accuracy: 0.68398 | valid_accuracy: 0.69879 |  0:00:33s
epoch 8  | loss: 0.21204 | train_accuracy: 0.7215  | valid_accuracy: 0.73248 |  0:00:35s
epoch 9  | loss: 0.20335 | train_accuracy: 0.73737 | valid_accuracy: 0.7473  |  0:00:37s
epoch 10 | loss: 0.19136 | train_accuracy: 0.79654 | valid_accuracy: 0.78976 |  0:00:40s
epoch 11 | loss: 0.16