In [1]:
import torch
from torch.utils.data import random_split, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from matplotlib import pyplot as plt
from sklearn.metrics import classification_report
import warnings
warnings.filterwarnings("ignore")
torch.manual_seed(0)
np.random.seed(0)

In [2]:
real_normal_data = np.load("../eval_data/real_normal.npy")
real_fault_01_data = np.load("../eval_data/real_fault_01.npy")
real_fault_11_data = np.load("../eval_data/real_fault_11.npy")
real_fault_20_data = np.load("../eval_data/real_fault_20.npy")
gen_fault_01_data = np.load("../eval_data/gen_fault_01.npy")
gen_fault_11_data = np.load("../eval_data/gen_fault_11.npy")
gen_fault_20_data = np.load("../eval_data/gen_fault_20.npy")

In [3]:
class FaultDataset(torch.utils.data.Dataset):
    def __init__(self, features, labels):

        self.features = torch.tensor(features)
        self.labels = torch.tensor(labels)

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx].type(torch.float32), self.labels[idx].type(torch.long)

In [4]:
class FaultDetector(nn.Module):
    def __init__(self, input_size=52):
        super(FaultDetector, self).__init__()
        self.lstm = nn.LSTM(input_size, 512, 2, batch_first=True)
        self.fc1 = nn.LazyLinear(64)
        self.fc2 = nn.LazyLinear(3)

    def forward(self, x):
        _, (hidden, _) = self.lstm(x)
        x = F.relu(self.fc1(hidden[-1]))
        x = self.fc2(x)
        return x

In [5]:
epochs = 2000
lr = 1e-5
batch_size = 128
train_valid_ratio = 0.9
num_test_per_fault_class = 100
num_samples_for_imbalanced_class = 50
criterion = nn.CrossEntropyLoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [6]:
def train_model(model, optimizer, criterion, train_loader, valid_loader, epochs):
    pbar = tqdm(range(epochs), leave=True)
    best_validation_loss = np.inf
    best_model = model.state_dict()
    train_loss_list = []
    valid_loss_list = []
    for epoch in pbar:
        model.train()
        running_loss = 0.0
        for bidx, batch in enumerate(train_loader):
            features, labels = batch
            features, labels = features.to(device), labels.to(device)
            optimizer.zero_grad()
            output = model(features)
            output = output.squeeze(dim=-1)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
 
        train_loss = running_loss / len(train_loader.dataset)
        train_loss_list.append(train_loss)

        model.eval()
        with torch.no_grad():
            valid_predictions = []
            valid_gt = []
            running_validation_loss = 0.0
            for bidx, batch in enumerate(valid_loader):
                features, labels = batch
                features, labels = features.to(device), labels.to(device)
                output = model(features)
                output = output.squeeze(dim=-1)
                loss = criterion(output, labels)
                running_validation_loss += loss.item()
                output = torch.argmax(output, dim=-1)
                valid_predictions.extend(output.cpu().detach().numpy())
                valid_gt.extend(labels.cpu().detach().numpy())
            validation_loss = running_validation_loss / len(valid_loader.dataset)
            valid_loss_list.append(validation_loss)

            if validation_loss < best_validation_loss:
                best_model = model.state_dict()
        
        train_loss = running_loss / len(train_loader.dataset)
        acc = np.mean(valid_predictions == valid_gt)
        pbar.set_postfix(train_loss=train_loss, validation_loss=validation_loss, validation_acc=acc)

    train_predictions = []
    train_gt = []
    model.eval()
    with torch.no_grad():
        for bidx, batch in enumerate(train_loader):
            features, labels = batch
            features, labels = features.to(device), labels.to(device)
            optimizer.zero_grad()
            output = model(features)
            output = output.squeeze(dim=-1)
            output = torch.argmax(output, dim=-1)
            train_predictions.extend(output.cpu().detach().numpy())
            train_gt.extend(labels.cpu().detach().numpy())


    model.load_state_dict(best_model)
    train_report = classification_report(train_gt, train_predictions, target_names=["fault_01", "fault_11", "fault_20"], output_dict=False)
    valid_report = classification_report(valid_gt, valid_predictions, target_names=["fault_01", "fault_11", "fault_20"], output_dict=False)

    return train_loss_list, valid_loss_list, train_report, valid_report

In [7]:
def test_model(model, criterion, test_loader):
    with torch.no_grad():
        model.eval()
        running_loss = 0.0
        predictions = []
        gt = []
        for bidx, batch in enumerate(test_loader):
            features, labels = batch
            features, labels = features.to(device), labels.to(device)
            output = model(features)
            output = output.squeeze(dim=-1)
            loss = criterion(output, labels)
            running_loss += loss.item()
            output = torch.argmax(output, dim=-1)
            predictions.extend(output.cpu().detach().numpy())
            gt.extend(labels.cpu().detach().numpy())
        avg_loss = running_loss / len(test_loader.dataset)
        report = classification_report(gt, predictions, target_names=["fault_01", "fault_11", "fault_20"], output_dict=False)
        return avg_loss, report

In [8]:
test_features = np.concatenate([real_fault_01_data[-num_test_per_fault_class:], real_fault_11_data[-num_test_per_fault_class:], real_fault_20_data[-num_test_per_fault_class:]])
test_labels = [[i]*num_test_per_fault_class for i in range(3)]
test_labels = np.concatenate(test_labels)
test_dataset = FaultDataset(test_features, test_labels)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [9]:
augmented_class_split = []
augmented_train_metrics_list = []
augmented_test_metrics_list = []
num_train_per_fault_class = len(real_fault_01_data) - num_test_per_fault_class
for i in range(50, num_train_per_fault_class+1, 50):
    train_features = np.concatenate([real_fault_01_data[:num_train_per_fault_class], real_fault_11_data[:num_train_per_fault_class], real_fault_20_data[:num_samples_for_imbalanced_class], gen_fault_20_data[:i-num_samples_for_imbalanced_class]])
    train_labels = [[i]*num_train_per_fault_class for i in range(2)]
    train_labels.append([2]*(i))
    train_labels = np.concatenate(train_labels)
    print(np.sum(train_labels == 0), np.sum(train_labels == 1), np.sum(train_labels == 2))
    augmented_class_split.append(f"({num_train_per_fault_class}, {num_train_per_fault_class}, {num_samples_for_imbalanced_class} + {i-num_samples_for_imbalanced_class})")
    train_dataset = FaultDataset(train_features, train_labels)
    train_size = int(len(train_dataset) * train_valid_ratio)
    valid_size = len(train_dataset) - train_size
    train_dataset, valid_dataset = random_split(train_dataset, [train_size, valid_size])
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
    
    model = FaultDetector()
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    train_metrics = train_model(model, optimizer, criterion, train_loader, valid_loader, epochs)
    test_metrics = test_model(model, criterion, test_loader)
    augmented_train_metrics_list.append(train_metrics)
    augmented_test_metrics_list.append(test_metrics)

300 300 50


100%|██████████| 2000/2000 [15:59<00:00,  2.08it/s, train_loss=1.41e-7, validation_acc=1, validation_loss=2.49e-7] 


300 300 100


100%|██████████| 2000/2000 [17:56<00:00,  1.86it/s, train_loss=2.84e-7, validation_acc=1, validation_loss=4.95e-7] 


300 300 150


100%|██████████| 2000/2000 [19:55<00:00,  1.67it/s, train_loss=3.49e-7, validation_acc=1, validation_loss=5.2e-7]  


300 300 200


100%|██████████| 2000/2000 [19:34<00:00,  1.70it/s, train_loss=0.000167, validation_acc=1, validation_loss=0.000259]


300 300 250


100%|██████████| 2000/2000 [20:42<00:00,  1.61it/s, train_loss=6.24e-7, validation_acc=1, validation_loss=9.69e-7] 


300 300 300


100%|██████████| 2000/2000 [22:27<00:00,  1.48it/s, train_loss=2.6e-5, validation_acc=1, validation_loss=3.22e-5]   


In [10]:
for i in range(len(augmented_train_metrics_list)):
    print(f"#################################")
    print(f"Split {augmented_class_split[i]}")
    print(augmented_train_metrics_list[i][-2])
    print(augmented_train_metrics_list[i][-1])
    print(augmented_test_metrics_list[i][-1])
    print(f"#################################")

#################################
Split (300, 300, 50 + 0)
              precision    recall  f1-score   support

    fault_01       1.00      1.00      1.00       280
    fault_11       1.00      1.00      1.00       259
    fault_20       1.00      1.00      1.00        46

    accuracy                           1.00       585
   macro avg       1.00      1.00      1.00       585
weighted avg       1.00      1.00      1.00       585

              precision    recall  f1-score   support

    fault_01       1.00      1.00      1.00        20
    fault_11       1.00      1.00      1.00        41
    fault_20       1.00      1.00      1.00         4

    accuracy                           1.00        65
   macro avg       1.00      1.00      1.00        65
weighted avg       1.00      1.00      1.00        65

              precision    recall  f1-score   support

    fault_01       1.00      1.00      1.00       100
    fault_11       0.74      1.00      0.85       100
    fault_20    