In [1]:
import pandas as pd
import numpy as np
import torch 
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as f
from tqdm.auto import tqdm
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
X_train = pd.read_csv("mitbih_train.csv").to_numpy()
X_test = pd.read_csv("mitbih_test.csv").to_numpy()

Y_train = X_train[65000:,-1].astype(int)
Y_test = X_test[:,-1].astype(int)
X_train = X_train[65000:,:-1]
X_test = X_test[:,:-1]

In [11]:
class ResBlock(nn.Module):
    def __init__(self):
        super(ResBlock, self).__init__()
        self.conv1 = nn.Conv1d(32, 32, 5, padding='same')
        init.xavier_uniform_(self.conv1.weight)
        init.constant_(self.conv1.bias, 0)
        self.conv2 = nn.Conv1d(32, 32, 5, padding='same')
        init.xavier_uniform_(self.conv2.weight)
        init.constant_(self.conv2.bias, 0)
        self.pool = nn.MaxPool1d(5,stride=2)
        
    def forward(self, data):
        output = f.relu(self.conv1(data))
        output = self.conv2(output)
        output += data
        output = f.relu(output)
        output = self.pool(output)
        return output
    
class implementedModel(nn.Module):
    def __init__(self, device):
        super(implementedModel, self).__init__()
        self.preconv = nn.Conv1d(1, 32, 5, padding='same')
        init.xavier_uniform_(self.preconv.weight)
        init.constant_(self.preconv.bias, 0)
        self.res = []
        for i in range(5):
            self.res.append(ResBlock().to(device))
        self.fc1 = nn.Linear(64, 32)
        self.fc2 = nn.Linear(32, 5)
        
        
    def forward(self, data):
        data = self.preconv(data)
        for resblock in self.res:
            data = resblock(data)
        data = nn.Flatten()(data)
        data = f.relu(self.fc1(data))
        data = f.softmax(self.fc2(data))
        return data

In [12]:
class EEGDataset(Dataset):
    def __init__(self, X, Y):
        self.X = torch.FloatTensor(X).unsqueeze(dim=1)
        self.Y = torch.LongTensor(Y)
        
    def __len__(self):
        return(self.Y.shape[-1])
    
    def __getitem__(self, idx):
        return self.X[idx], self.Y[idx]

In [13]:
train_dataset = EEGDataset(X_train, Y_train)
test_dataset = EEGDataset(X_test, Y_test)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(train_dataset, batch_size=64)

In [14]:
device = "mps"
model = implementedModel(device).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params = model.parameters(), lr = 0.0001)

In [None]:
epoch = 100
for e in range(epoch):
    train_loss = 0
    train_acc = 0
    for i, (datas, labels) in enumerate(tqdm(train_loader)):
        datas = datas.to("mps")
        labels = labels.to("mps")
        optimizer.zero_grad()
        outputs = model(datas)
        loss = loss_fn(outputs, labels)
        loss.backward()
        train_loss += loss.item()
        train_acc += (np.sum(np.array(outputs.argmax(dim=1).tolist()) == np.array(labels.tolist())))
        optimizer.step()
    train_loss /= len(train_loader)
    train_acc /= len(train_loader)
    print("Epoch: ",e,"      Loss: ",train_loss,"      Accuracy: ", train_acc)

  data = f.softmax(self.fc2(data))
100%|█████████████████████████████████████████| 353/353 [00:06<00:00, 58.01it/s]


Epoch:  0       Loss:  1.567128621147307       Accuracy:  22.662889518413596


100%|█████████████████████████████████████████| 353/353 [00:05<00:00, 63.33it/s]


Epoch:  1       Loss:  1.4996430293021052       Accuracy:  26.79886685552408


100%|█████████████████████████████████████████| 353/353 [00:05<00:00, 62.88it/s]


Epoch:  2       Loss:  1.425142418899212       Accuracy:  35.07648725212464


100%|█████████████████████████████████████████| 353/353 [00:05<00:00, 63.26it/s]


Epoch:  3       Loss:  1.380988326694067       Accuracy:  35.30028328611898


100%|█████████████████████████████████████████| 353/353 [00:05<00:00, 63.12it/s]


Epoch:  4       Loss:  1.3582853967002027       Accuracy:  35.43342776203966


100%|█████████████████████████████████████████| 353/353 [00:05<00:00, 63.59it/s]


Epoch:  5       Loss:  1.3379484563643804       Accuracy:  36.51274787535411


100%|█████████████████████████████████████████| 353/353 [00:05<00:00, 63.62it/s]


Epoch:  6       Loss:  1.3161074409403815       Accuracy:  39.079320113314445


100%|█████████████████████████████████████████| 353/353 [00:05<00:00, 62.43it/s]


Epoch:  7       Loss:  1.295440127761776       Accuracy:  40.8385269121813


100%|█████████████████████████████████████████| 353/353 [00:05<00:00, 63.87it/s]


Epoch:  8       Loss:  1.2758935922941472       Accuracy:  42.39660056657224


100%|█████████████████████████████████████████| 353/353 [00:05<00:00, 63.41it/s]


Epoch:  9       Loss:  1.2594506892536586       Accuracy:  43.623229461756374


100%|█████████████████████████████████████████| 353/353 [00:05<00:00, 63.47it/s]


Epoch:  10       Loss:  1.2444683478844403       Accuracy:  44.756373937677054


100%|█████████████████████████████████████████| 353/353 [00:05<00:00, 63.46it/s]


Epoch:  11       Loss:  1.231236676994194       Accuracy:  45.45609065155807


100%|█████████████████████████████████████████| 353/353 [00:05<00:00, 63.87it/s]


Epoch:  12       Loss:  1.2202397665963969       Accuracy:  46.039660056657226


100%|█████████████████████████████████████████| 353/353 [00:05<00:00, 62.80it/s]


Epoch:  13       Loss:  1.2108389851729526       Accuracy:  46.34844192634561


100%|█████████████████████████████████████████| 353/353 [00:05<00:00, 62.97it/s]


Epoch:  14       Loss:  1.2029871589401289       Accuracy:  46.54390934844193


100%|█████████████████████████████████████████| 353/353 [00:05<00:00, 63.24it/s]


Epoch:  15       Loss:  1.1964386799517839       Accuracy:  46.62889518413598


100%|█████████████████████████████████████████| 353/353 [00:05<00:00, 63.50it/s]


Epoch:  16       Loss:  1.1905175069236216       Accuracy:  46.8356940509915


100%|█████████████████████████████████████████| 353/353 [00:05<00:00, 64.05it/s]


Epoch:  17       Loss:  1.1854536172013108       Accuracy:  46.96883852691218


100%|█████████████████████████████████████████| 353/353 [00:05<00:00, 63.66it/s]


Epoch:  18       Loss:  1.1809641625996015       Accuracy:  47.15297450424929


100%|█████████████████████████████████████████| 353/353 [00:05<00:00, 63.09it/s]


Epoch:  19       Loss:  1.1768662885633474       Accuracy:  47.31161473087819


100%|█████████████████████████████████████████| 353/353 [00:05<00:00, 63.43it/s]


Epoch:  20       Loss:  1.173220563880445       Accuracy:  47.478753541076486


100%|█████████████████████████████████████████| 353/353 [00:05<00:00, 62.97it/s]


Epoch:  21       Loss:  1.1699745013423413       Accuracy:  47.68271954674221


100%|█████████████████████████████████████████| 353/353 [00:05<00:00, 63.31it/s]


Epoch:  22       Loss:  1.1668472830364454       Accuracy:  47.87818696883853


100%|█████████████████████████████████████████| 353/353 [00:05<00:00, 63.78it/s]


Epoch:  23       Loss:  1.1640145552394747       Accuracy:  47.95467422096317


100%|█████████████████████████████████████████| 353/353 [00:05<00:00, 63.34it/s]


Epoch:  24       Loss:  1.1614746902211868       Accuracy:  48.08498583569405


100%|█████████████████████████████████████████| 353/353 [00:05<00:00, 63.32it/s]


Epoch:  25       Loss:  1.1587671338667613       Accuracy:  48.21813031161473


100%|█████████████████████████████████████████| 353/353 [00:05<00:00, 62.78it/s]


Epoch:  26       Loss:  1.1566411891672497       Accuracy:  48.32294617563739


100%|█████████████████████████████████████████| 353/353 [00:05<00:00, 63.78it/s]


Epoch:  27       Loss:  1.154209514515258       Accuracy:  48.48725212464589


100%|█████████████████████████████████████████| 353/353 [00:05<00:00, 62.97it/s]


Epoch:  28       Loss:  1.1518116885474317       Accuracy:  48.53541076487252


100%|█████████████████████████████████████████| 353/353 [00:05<00:00, 60.08it/s]


Epoch:  29       Loss:  1.1500522141416079       Accuracy:  48.67705382436261


100%|█████████████████████████████████████████| 353/353 [00:05<00:00, 60.31it/s]


Epoch:  30       Loss:  1.1482371571381436       Accuracy:  48.77620396600567


100%|█████████████████████████████████████████| 353/353 [00:05<00:00, 61.06it/s]


Epoch:  31       Loss:  1.1462003391119981       Accuracy:  48.89518413597734


100%|█████████████████████████████████████████| 353/353 [00:05<00:00, 62.83it/s]


Epoch:  32       Loss:  1.144697805977408       Accuracy:  49.01133144475921


100%|█████████████████████████████████████████| 353/353 [00:05<00:00, 60.94it/s]


Epoch:  33       Loss:  1.143225766781707       Accuracy:  49.12181303116147


100%|█████████████████████████████████████████| 353/353 [00:05<00:00, 63.28it/s]


Epoch:  34       Loss:  1.1414241811033687       Accuracy:  49.22379603399433


 61%|█████████████████████████▏               | 217/353 [00:03<00:02, 62.53it/s]