In [36]:
import torch.nn as nn
import torch.utils as TorchUtils
import torch.optim as optim
import numpy as np
import yaml
from datetime import datetime
from torchmetrics.classification import Accuracy
import torch
import pandas as pd
from torch.utils.data import Dataset


In [45]:
def label_dict_from_config_file(relative_path):
    with open(relative_path, 'r') as file:
        label_tag = yaml.full_load(file)['gestures']
    return label_tag

class CustomImageDataset(Dataset):
    def __init__(self, data_file):
        self.data = pd.read_csv(data_file)
        self.labels = torch.from_numpy(self.data.iloc[:,0].to_numpy())

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        one_hot_label = self.labels[idx]
        torch_data = torch.from_numpy(self.data.iloc[idx,1:].to_numpy(dtype=np.float32))
        return torch_data, one_hot_label

class EarlyStopper:
    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.watched_metrics = np.inf

    def early_stop(self, current_value):
        if current_value < self.watched_metrics:
            self.watched_metrics = current_value
            self.counter = 0
        elif current_value > (self.watched_metrics + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False

In [39]:
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        list_label = label_dict_from_config_file('./data/hand_gesture.yaml')
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(63, 128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Dropout(0.6),
            nn.Linear(128, len(list_label))
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
    
    def predict(self, x, threshold=0.8):
        logits = self(x)
        softmax_prob = nn.Softmax(dim=1)(logits)
        chosen_ind = torch.argmax(softmax_prob, dim=1)
        return torch.where(softmax_prob[0, chosen_ind] > threshold, chosen_ind, -1)
    
    def predict_with_know_class(self, x):
        logits = self(x)
        softmax_prob = nn.Softmax(dim=1)(logits)
        chosen_ind = torch.argmax(softmax_prob, dim=1)
        return chosen_ind
    
    def score(self, logits):
        return torch.amax(logits, dim=1)
        


In [52]:
def train(train_loader, val_loader, model: NeuralNetwork, optimizer, loss_function, early_stopper):
    best_vloss = 1_000_000
    timestamp = datetime.now().strftime("%d-%m %H:%M")
    list_label = label_dict_from_config_file('./data/hand_gesture.yaml')
    for epoch in range(300):
        model.train(True)
        running_loss = 0.0
        acc_train = Accuracy(num_classes=len(list_label), task='multiclass')
        for batch_number,data in enumerate(train_loader):
            inputs, labels = data
            optimizer.zero_grad()
            outputs = model(inputs)

            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()

            acc_train.update(model.predict_with_know_class(inputs), labels)
            running_loss += loss.item()
    
        avg_loss = running_loss / len(train_loader)

        # validation step
        model.train(False)
        running_v_loss = 0.0
        acc_val = Accuracy(num_classes=len(list_label), task='multiclass')
        for i, v_data in enumerate(train_loader):
            v_inputs, v_labels = v_data
            outputs = model(v_inputs)
            v_loss = loss_function(outputs, v_labels)
            running_v_loss += v_loss.item()
            acc_val.update(model.predict_with_know_class(v_inputs), v_labels)

        # log the running loss averaged per batch
        # for both training and validation
        print(f"Epoch {epoch}: ")
        print(f"Accuracy train:{acc_train.compute().item()}, val:{acc_val.compute().item()}")
        avg_vloss = running_v_loss / len(val_loader)
        print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))
        print('Training vs. Validation Loss',
                        { 'Training' : avg_loss, 'Validation' : avg_vloss },
                        epoch + 1)
        print('Training vs. Validation accuracy',
                        { 'Training' : acc_train.compute().item()
                        , 'Validation' : acc_val.compute().item() },
                        epoch + 1)
            
        # Track best performance, and save the model's state
        if avg_vloss < best_vloss:
            best_vloss = avg_vloss
            best_model_path = f'./model/model_{timestamp}_{model.__class__.__name__}_best'
            torch.save(model.state_dict(), best_model_path)

        if early_stopper.early_stop(avg_vloss):
            print('Early stopping')
            break
        
    model_path = f'./model/model_{timestamp}_{model.__class__.__name__}_last'
    torch.save(model.state_dict(), model_path)

    print(acc_val.compute())
    return model, best_model_path

In [54]:
LIST_LABEL = label_dict_from_config_file("./data/hand_gesture.yaml")
DATA_FOLDER_PATH="./data/"

train_set = CustomImageDataset("./data/landmark_train.csv")
train_loader = TorchUtils.data.DataLoader(train_set, batch_size=40, shuffle=True)

val_set = CustomImageDataset("./data/landmark_val.csv")
val_loader = TorchUtils.data.DataLoader(val_set, batch_size=50, shuffle=False)

model = NeuralNetwork()
loss_function = nn.CrossEntropyLoss()
early_stopper = EarlyStopper(patience=3, min_delta=0.0001)

# optimizer
lr = 0.0001
optimizer = optim.Adam(model.parameters(), lr=lr)

model, best_model_path = train(train_loader, val_loader, model, optimizer, loss_function, early_stopper)

Epoch 0: 
Accuracy train:0.16192345321178436, val:0.29440629482269287
LOSS train 1.7973114985686083 valid 7.704931199550629
Training vs. Validation Loss {'Training': 1.7973114985686083, 'Validation': 7.704931199550629} 1
Training vs. Validation accuracy {'Training': 0.16192345321178436, 'Validation': 0.29440629482269287} 1
Epoch 1: 
Accuracy train:0.23552502691745758, val:0.33856722712516785
LOSS train 1.767214807180258 valid 7.576145907243093
Training vs. Validation Loss {'Training': 1.767214807180258, 'Validation': 7.576145907243093} 2
Training vs. Validation accuracy {'Training': 0.23552502691745758, 'Validation': 0.33856722712516785} 2
Epoch 2: 
Accuracy train:0.2875367999076843, val:0.3356231451034546
LOSS train 1.7417794878666217 valid 7.435703078905742
Training vs. Validation Loss {'Training': 1.7417794878666217, 'Validation': 7.435703078905742} 3
Training vs. Validation accuracy {'Training': 0.2875367999076843, 'Validation': 0.3356231451034546} 3
Epoch 3: 
Accuracy train:0.3179

In [56]:
list_label = label_dict_from_config_file("./data/hand_gesture.yaml")
DATA_FOLDER_PATH="./data/"
testset = CustomImageDataset("./data/landmark_test.csv")

# Test DataLoader instantiation
################## Your Code Here ################## Q6
''' Hoàn thành code bên dưới để  khởi tạo DataLoader cho testset with batch size
20, không cho phép shuffle
'''
test_loader = TorchUtils.data.DataLoader(testset, batch_size=20, shuffle=False)
####################################################



network = NeuralNetwork()
network.load_state_dict(torch.load(best_model_path, weights_only=False))

network.eval()
acc_test = Accuracy(num_classes=len(list_label), task='MULTICLASS')
for i, test_data in enumerate(test_loader):
    test_input, test_label = test_data
    preds = network.predict(test_input)
    acc_test.update(preds, test_label)

    ####################################################

print(network.__class__.__name__)
print(f"Accuracy of model:{acc_test.compute().item()}")
print("========================================================================")

NeuralNetwork
Accuracy of model:0.8070865869522095
