## Tylko testy do większych zbiorów danych

In [38]:
import os
import pandas as pd
from torchvision.io import read_image
import re
import wfdb
import wfdb.processing
import scipy
from torch.utils.data import Dataset
import numpy as np
import json
import torch.nn as nn
import torch
from tqdm import tqdm
import torch.nn.functional as F

In [26]:
def extract_segment_with_padding(z, k, N):
    # Rozmiar segmentu to 2N + 1
    start_idx = k - N
    end_idx = k + N + 1  # Indeks końcowy +1, ponieważ Python używa wykluczającego indeksu
    
    # Upewnij się, że start_idx i end_idx mieszczą się w granicach tablicy
    if start_idx < 0:
        # Jeśli start_idx jest poza zakresem, dopełnij na początku
        padding_left = np.median(z[:end_idx])  # Wypełniamy medianą
        segment = np.concatenate([np.full(-start_idx, padding_left), z[:end_idx]])
    elif end_idx > len(z):
        # Jeśli end_idx jest poza zakresem, dopełnij na końcu
        padding_right = np.median(z[start_idx:])  # Wypełniamy medianą
        segment = np.concatenate([z[start_idx:], np.full(end_idx - len(z), padding_right)])
    else:
        # Normalny przypadek, kiedy zakres mieści się w tablicy
        segment = z[start_idx:end_idx]
    
    return segment

class MIT_BIH_Arythmia(Dataset):
    def __init__(self,N, M, dataset_dir = 'Datasets/files/', fs = 10, filename = "MIT-BIH_Arrythmia.json"):
        """
        n - number of samples of orginal signal resampled to fs, interval [-n,n]
        m - qrs times, interval [-m,m]
        """
        ecg_list = []
        exclusion_lst = ["00735", "03665", "04043", "04936", "05091", "06453", "08378", "08405", "08434", "08455"]
        for file in os.listdir(dataset_dir):
            name = re.match(r'^(.*\d\d+)\.atr$', file)
            if name:
                if name.group(1) in exclusion_lst:
                    continue
            if name:
                record = wfdb.rdsamp(f"{dataset_dir}{name.group(1)}") 
                annotation = wfdb.rdann(f"{dataset_dir}{name.group(1)}", 'atr')
                signal = record[0][:,0]
                fs_original = record[1]["fs"]
                num_samples_target = int(signal.shape[0] * fs / fs_original)
                resampled_signal = scipy.signal.resample(signal, num_samples_target)
                annotation_times_resampled = (annotation.sample * fs) / fs_original
                resampled_annotation = wfdb.Annotation('atr',annotation.symbol,annotation_times_resampled.astype(int),aux_note=annotation.aux_note)
                ecg_list.append({"name": name.group(1),"rec" : resampled_signal, "ann" : resampled_annotation})
        self.samples_list = []
        self.label_list = []
        self.qrs_samples = []
        for dic in ecg_list:
            print(dic["name"])
            # xqrs = wfdb.processing.XQRS(sig=dic["rec"], fs=fs)
            # xqrs.detect()
            # qrs_inds = xqrs.qrs_inds
            for n,i in enumerate(dic["ann"].sample):
                self.label_list.append(1 if dic["ann"].aux_note[n] == '(AFIB' else 0)
                self.samples_list.append(list(extract_segment_with_padding(dic["rec"], dic["ann"].sample[n],N)))
                # nearest_qrs_idx = find_nearest_qrs_index(dic["ann"].sample[n], qrs_inds)
                # self.qrs_samples.append(list(extract_segment_with_padding(qrs_inds,nearest_qrs_idx,M)))
        data = {
            'samples_list': self.samples_list,  # This would work if the segments are simple numeric lists
            'label_list': self.label_list,
            'qrs_samples': self.qrs_samples
        }
        with open(filename, 'w') as f:
            json.dump(data, f)
                
    def __len__(self):
        return len(self.label_list)

    def __getitem__(self, idx):
        data = torch.Tensor(self.samples_list[idx]).unsqueeze(0)
        label = self.label_list[idx]
        return data, label

In [27]:
ds = MIT_BIH_Arythmia(100,5,fs=100)

04015
04048
04126
04746
04908
05121
05261
06426
06995
07162
07859
07879
07910
08215
08219


In [32]:


class SimpleConv(nn.Module):
    def __init__(self, input = 201, input_ch = 1, num_classes = 2):
        super(SimpleConv, self).__init__()
        self.model = nn.Sequential(
            nn.Conv1d(input_ch, 64, kernel_size=7, padding='same'),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Conv1d(64, 64, kernel_size=3, padding='same'),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Conv1d(64, 128, kernel_size=3, padding='same'),      # out 1 x 128 x n
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.MaxPool1d(2),                        # out 1 x 128 x n//2
            nn.Conv1d(128, 128, kernel_size=3, padding='same'),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Conv1d(128, 256, kernel_size=3, padding='same'),     # out 1 x 256 x n//2
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.MaxPool1d(2),                        # out 1 x 256 x n//4
            nn.Conv1d(256, 256, kernel_size=3, padding='same'),     
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Conv1d(256, 512, kernel_size=3, padding='same'),     
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.MaxPool1d(2),                        # out 1 x 512 x n//8
            nn.Flatten(),
            nn.Linear(512*(input//8), 256),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )
        self.model.to('cuda:0')

    def forward(self, x):

        return self.model(x)
    
    def train_model(self, train_loader, valid_loader, num_epochs = 5, learning_rate=0.001, save_best = False, save_thr = 0.94):
        best_accuracy = 0.0
        total_step = len(train_loader)
        # Loss and optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.RMSprop(self.parameters(), lr=learning_rate, weight_decay = 0.005, momentum = 0.9)  

        for epoch in range(num_epochs):
            # self.train()
            correct = 0
            total = 0
            for i, (images, labels) in enumerate(tqdm(train_loader)):
                # Move tensors to the configured device
                images = images.float().to("cuda")
                labels = labels.type(torch.LongTensor)
                labels = labels.to("cuda")


                optimizer.zero_grad()

                # Forward pass
                outputs = self.forward(images)
                loss = criterion(outputs, labels)
                # Backward and optimize
                loss.backward()
                
                optimizer.step()

                # accuracy
                _, predicted = torch.max(outputs.data, 1)
                correct += (torch.eq(predicted, labels)).sum().item()
                total += labels.size(0)

                del images, labels, outputs

            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Accuracy: {:.4f}'
                            .format(epoch+1, num_epochs, i+1, total_step, loss.item(), (float(correct))/total))


            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # Validation
            with torch.no_grad():
                correct = 0
                total = 0
                for images, labels in valid_loader:
                    images = images.float().to("cuda")
                    labels = labels.to("cuda")
                    outputs = self.forward(images)
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (torch.eq(predicted, labels)).sum().item()
                    del images, labels, outputs
                if(((100 * correct / total) > best_accuracy) and save_best and ((100 * correct / total) > save_thr)):
                    torch.save(self.state_dict(), "best_resnet50_MINST-DVS2.pt")

                print('Accuracy of the network: {} %'.format( 100 * correct / total))

In [33]:
model = SimpleConv()

In [34]:
from torch.utils.data import DataLoader, random_split
train_set, val_set = random_split(ds, [0.8, 0.2])
train = DataLoader(train_set, batch_size=32, shuffle=True)
val = DataLoader(val_set, batch_size=32, shuffle=True)

In [36]:
model.train_model(train,val,num_epochs=90)

100%|██████████| 8/8 [00:00<00:00, 22.17it/s]


Epoch [1/90], Step [8/8], Loss: 0.6183, Accuracy: 0.5943
Accuracy of the network: 58.333333333333336 %


100%|██████████| 8/8 [00:00<00:00, 133.84it/s]


Epoch [2/90], Step [8/8], Loss: 1.4461, Accuracy: 0.5246
Accuracy of the network: 48.333333333333336 %


100%|██████████| 8/8 [00:00<00:00, 123.06it/s]


Epoch [3/90], Step [8/8], Loss: 3.3798, Accuracy: 0.5533
Accuracy of the network: 51.666666666666664 %


100%|██████████| 8/8 [00:00<00:00, 132.21it/s]


Epoch [4/90], Step [8/8], Loss: 2.4421, Accuracy: 0.6148
Accuracy of the network: 66.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 133.33it/s]


Epoch [5/90], Step [8/8], Loss: 5.2606, Accuracy: 0.5533
Accuracy of the network: 68.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 136.52it/s]


Epoch [6/90], Step [8/8], Loss: 13.7246, Accuracy: 0.6680
Accuracy of the network: 75.0 %


100%|██████████| 8/8 [00:00<00:00, 140.28it/s]


Epoch [7/90], Step [8/8], Loss: 3.8379, Accuracy: 0.6885
Accuracy of the network: 73.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 137.41it/s]


Epoch [8/90], Step [8/8], Loss: 4.7775, Accuracy: 0.7541
Accuracy of the network: 76.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 136.03it/s]


Epoch [9/90], Step [8/8], Loss: 0.2740, Accuracy: 0.7787
Accuracy of the network: 80.0 %


100%|██████████| 8/8 [00:00<00:00, 135.65it/s]


Epoch [10/90], Step [8/8], Loss: 0.6243, Accuracy: 0.7828
Accuracy of the network: 83.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 140.35it/s]


Epoch [11/90], Step [8/8], Loss: 0.2722, Accuracy: 0.8238
Accuracy of the network: 83.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 135.21it/s]


Epoch [12/90], Step [8/8], Loss: 0.6333, Accuracy: 0.8197
Accuracy of the network: 75.0 %


100%|██████████| 8/8 [00:00<00:00, 137.93it/s]


Epoch [13/90], Step [8/8], Loss: 0.1415, Accuracy: 0.8115
Accuracy of the network: 78.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 140.35it/s]


Epoch [14/90], Step [8/8], Loss: 0.4538, Accuracy: 0.8525
Accuracy of the network: 86.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 142.86it/s]


Epoch [15/90], Step [8/8], Loss: 0.2235, Accuracy: 0.8443
Accuracy of the network: 80.0 %


100%|██████████| 8/8 [00:00<00:00, 140.35it/s]


Epoch [16/90], Step [8/8], Loss: 0.4898, Accuracy: 0.8484
Accuracy of the network: 83.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 127.96it/s]


Epoch [17/90], Step [8/8], Loss: 0.4999, Accuracy: 0.8525
Accuracy of the network: 86.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 135.63it/s]


Epoch [18/90], Step [8/8], Loss: 0.4363, Accuracy: 0.8443
Accuracy of the network: 65.0 %


100%|██████████| 8/8 [00:00<00:00, 133.34it/s]


Epoch [19/90], Step [8/8], Loss: 0.2631, Accuracy: 0.8156
Accuracy of the network: 85.0 %


100%|██████████| 8/8 [00:00<00:00, 137.93it/s]


Epoch [20/90], Step [8/8], Loss: 0.5088, Accuracy: 0.8279
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 133.33it/s]


Epoch [21/90], Step [8/8], Loss: 0.2876, Accuracy: 0.8607
Accuracy of the network: 86.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 137.93it/s]


Epoch [22/90], Step [8/8], Loss: 1.2097, Accuracy: 0.8566
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 133.33it/s]


Epoch [23/90], Step [8/8], Loss: 0.9204, Accuracy: 0.8238
Accuracy of the network: 85.0 %


100%|██████████| 8/8 [00:00<00:00, 133.33it/s]


Epoch [24/90], Step [8/8], Loss: 0.6674, Accuracy: 0.8730
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 131.15it/s]


Epoch [25/90], Step [8/8], Loss: 0.3806, Accuracy: 0.8607
Accuracy of the network: 86.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 129.03it/s]


Epoch [26/90], Step [8/8], Loss: 0.3797, Accuracy: 0.8811
Accuracy of the network: 83.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 128.81it/s]


Epoch [27/90], Step [8/8], Loss: 0.3340, Accuracy: 0.8525
Accuracy of the network: 88.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 134.44it/s]


Epoch [28/90], Step [8/8], Loss: 0.1278, Accuracy: 0.8648
Accuracy of the network: 85.0 %


100%|██████████| 8/8 [00:00<00:00, 135.60it/s]


Epoch [29/90], Step [8/8], Loss: 0.2931, Accuracy: 0.8525
Accuracy of the network: 85.0 %


100%|██████████| 8/8 [00:00<00:00, 135.60it/s]


Epoch [30/90], Step [8/8], Loss: 0.3090, Accuracy: 0.9262
Accuracy of the network: 83.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 135.51it/s]


Epoch [31/90], Step [8/8], Loss: 0.1855, Accuracy: 0.8566
Accuracy of the network: 83.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 139.84it/s]


Epoch [32/90], Step [8/8], Loss: 0.2376, Accuracy: 0.8566
Accuracy of the network: 83.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 137.68it/s]


Epoch [33/90], Step [8/8], Loss: 0.2436, Accuracy: 0.8525
Accuracy of the network: 83.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 133.02it/s]


Epoch [34/90], Step [8/8], Loss: 0.1901, Accuracy: 0.8975
Accuracy of the network: 80.0 %


100%|██████████| 8/8 [00:00<00:00, 135.50it/s]


Epoch [35/90], Step [8/8], Loss: 0.6430, Accuracy: 0.9016
Accuracy of the network: 83.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 129.03it/s]


Epoch [36/90], Step [8/8], Loss: 0.2503, Accuracy: 0.8730
Accuracy of the network: 86.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 131.15it/s]


Epoch [37/90], Step [8/8], Loss: 0.2281, Accuracy: 0.8770
Accuracy of the network: 85.0 %


100%|██████████| 8/8 [00:00<00:00, 133.34it/s]


Epoch [38/90], Step [8/8], Loss: 0.3112, Accuracy: 0.8811
Accuracy of the network: 85.0 %


100%|██████████| 8/8 [00:00<00:00, 135.59it/s]


Epoch [39/90], Step [8/8], Loss: 0.3390, Accuracy: 0.9057
Accuracy of the network: 86.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 134.78it/s]


Epoch [40/90], Step [8/8], Loss: 0.2905, Accuracy: 0.8770
Accuracy of the network: 88.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 137.34it/s]


Epoch [41/90], Step [8/8], Loss: 0.4977, Accuracy: 0.9057
Accuracy of the network: 85.0 %


100%|██████████| 8/8 [00:00<00:00, 131.59it/s]


Epoch [42/90], Step [8/8], Loss: 1.2280, Accuracy: 0.8770
Accuracy of the network: 85.0 %


100%|██████████| 8/8 [00:00<00:00, 129.49it/s]


Epoch [43/90], Step [8/8], Loss: 0.4734, Accuracy: 0.8197
Accuracy of the network: 86.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 136.25it/s]


Epoch [44/90], Step [8/8], Loss: 0.6404, Accuracy: 0.8934
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 122.46it/s]


Epoch [45/90], Step [8/8], Loss: 0.3504, Accuracy: 0.8648
Accuracy of the network: 88.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 132.82it/s]


Epoch [46/90], Step [8/8], Loss: 0.4613, Accuracy: 0.8893
Accuracy of the network: 86.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 134.57it/s]


Epoch [47/90], Step [8/8], Loss: 0.2235, Accuracy: 0.8975
Accuracy of the network: 88.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 135.60it/s]


Epoch [48/90], Step [8/8], Loss: 0.1174, Accuracy: 0.9016
Accuracy of the network: 88.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 135.59it/s]


Epoch [49/90], Step [8/8], Loss: 0.2038, Accuracy: 0.8934
Accuracy of the network: 86.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 135.59it/s]


Epoch [50/90], Step [8/8], Loss: 0.1676, Accuracy: 0.9262
Accuracy of the network: 88.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 135.59it/s]


Epoch [51/90], Step [8/8], Loss: 0.1658, Accuracy: 0.9221
Accuracy of the network: 78.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 133.33it/s]


Epoch [52/90], Step [8/8], Loss: 0.2203, Accuracy: 0.9180
Accuracy of the network: 86.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 133.33it/s]


Epoch [53/90], Step [8/8], Loss: 0.2514, Accuracy: 0.8934
Accuracy of the network: 85.0 %


100%|██████████| 8/8 [00:00<00:00, 132.20it/s]


Epoch [54/90], Step [8/8], Loss: 0.1833, Accuracy: 0.8934
Accuracy of the network: 76.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 131.15it/s]


Epoch [55/90], Step [8/8], Loss: 0.3795, Accuracy: 0.8770
Accuracy of the network: 76.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 136.78it/s]


Epoch [56/90], Step [8/8], Loss: 0.2018, Accuracy: 0.8402
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 133.86it/s]


Epoch [57/90], Step [8/8], Loss: 0.2461, Accuracy: 0.8770
Accuracy of the network: 85.0 %


100%|██████████| 8/8 [00:00<00:00, 133.93it/s]


Epoch [58/90], Step [8/8], Loss: 0.3674, Accuracy: 0.8975
Accuracy of the network: 70.0 %


100%|██████████| 8/8 [00:00<00:00, 135.60it/s]


Epoch [59/90], Step [8/8], Loss: 0.3125, Accuracy: 0.8893
Accuracy of the network: 85.0 %


100%|██████████| 8/8 [00:00<00:00, 135.60it/s]


Epoch [60/90], Step [8/8], Loss: 0.2469, Accuracy: 0.9057
Accuracy of the network: 73.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 134.47it/s]


Epoch [61/90], Step [8/8], Loss: 0.0951, Accuracy: 0.9262
Accuracy of the network: 85.0 %


100%|██████████| 8/8 [00:00<00:00, 135.59it/s]


Epoch [62/90], Step [8/8], Loss: 0.2482, Accuracy: 0.9262
Accuracy of the network: 75.0 %


100%|██████████| 8/8 [00:00<00:00, 136.73it/s]


Epoch [63/90], Step [8/8], Loss: 0.1793, Accuracy: 0.9467
Accuracy of the network: 78.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 124.00it/s]


Epoch [64/90], Step [8/8], Loss: 0.1755, Accuracy: 0.9385
Accuracy of the network: 78.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 135.59it/s]


Epoch [65/90], Step [8/8], Loss: 0.0152, Accuracy: 0.9508
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 133.33it/s]


Epoch [66/90], Step [8/8], Loss: 0.1968, Accuracy: 0.9098
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 130.73it/s]


Epoch [67/90], Step [8/8], Loss: 0.0997, Accuracy: 0.9180
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 135.19it/s]


Epoch [68/90], Step [8/8], Loss: 0.1613, Accuracy: 0.8934
Accuracy of the network: 88.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 135.59it/s]


Epoch [69/90], Step [8/8], Loss: 0.0439, Accuracy: 0.9139
Accuracy of the network: 83.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 137.94it/s]


Epoch [70/90], Step [8/8], Loss: 0.2114, Accuracy: 0.8893
Accuracy of the network: 76.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 131.15it/s]


Epoch [71/90], Step [8/8], Loss: 0.3148, Accuracy: 0.8975
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 133.56it/s]


Epoch [72/90], Step [8/8], Loss: 0.3386, Accuracy: 0.9303
Accuracy of the network: 75.0 %


100%|██████████| 8/8 [00:00<00:00, 123.52it/s]


Epoch [73/90], Step [8/8], Loss: 0.1113, Accuracy: 0.9549
Accuracy of the network: 76.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 133.82it/s]


Epoch [74/90], Step [8/8], Loss: 0.1464, Accuracy: 0.9303
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 134.09it/s]


Epoch [75/90], Step [8/8], Loss: 0.0885, Accuracy: 0.9344
Accuracy of the network: 75.0 %


100%|██████████| 8/8 [00:00<00:00, 135.09it/s]


Epoch [76/90], Step [8/8], Loss: 0.1416, Accuracy: 0.8811
Accuracy of the network: 80.0 %


100%|██████████| 8/8 [00:00<00:00, 129.31it/s]


Epoch [77/90], Step [8/8], Loss: 0.4580, Accuracy: 0.9303
Accuracy of the network: 80.0 %


100%|██████████| 8/8 [00:00<00:00, 133.81it/s]


Epoch [78/90], Step [8/8], Loss: 0.1090, Accuracy: 0.9057
Accuracy of the network: 85.0 %


100%|██████████| 8/8 [00:00<00:00, 136.51it/s]


Epoch [79/90], Step [8/8], Loss: 0.2186, Accuracy: 0.9057
Accuracy of the network: 78.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 132.61it/s]


Epoch [80/90], Step [8/8], Loss: 0.0924, Accuracy: 0.9672
Accuracy of the network: 78.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 134.98it/s]


Epoch [81/90], Step [8/8], Loss: 0.1046, Accuracy: 0.9426
Accuracy of the network: 75.0 %


100%|██████████| 8/8 [00:00<00:00, 126.39it/s]


Epoch [82/90], Step [8/8], Loss: 0.2054, Accuracy: 0.9508
Accuracy of the network: 78.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 133.84it/s]


Epoch [83/90], Step [8/8], Loss: 0.0689, Accuracy: 0.9754
Accuracy of the network: 83.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 135.76it/s]


Epoch [84/90], Step [8/8], Loss: 0.0682, Accuracy: 0.9836
Accuracy of the network: 76.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 135.72it/s]


Epoch [85/90], Step [8/8], Loss: 0.0795, Accuracy: 0.9549
Accuracy of the network: 75.0 %


100%|██████████| 8/8 [00:00<00:00, 128.82it/s]


Epoch [86/90], Step [8/8], Loss: 0.0984, Accuracy: 0.9713
Accuracy of the network: 83.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 134.43it/s]


Epoch [87/90], Step [8/8], Loss: 0.2426, Accuracy: 0.9385
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 138.65it/s]


Epoch [88/90], Step [8/8], Loss: 0.0559, Accuracy: 0.9713
Accuracy of the network: 80.0 %


100%|██████████| 8/8 [00:00<00:00, 133.40it/s]


Epoch [89/90], Step [8/8], Loss: 0.0936, Accuracy: 0.9672
Accuracy of the network: 80.0 %


100%|██████████| 8/8 [00:00<00:00, 135.63it/s]

Epoch [90/90], Step [8/8], Loss: 0.3073, Accuracy: 0.9508
Accuracy of the network: 78.33333333333333 %





## Model a La resnet

In [55]:
class ResNetBlock(nn.Module):
    def __init__(self,in_channels, out_channels):
        """
        output same as input
        """
        super(ResNetBlock, self).__init__()
        self.conv1 = nn.Sequential(
                        nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
                        nn.BatchNorm1d(out_channels),
                        nn.ReLU(inplace=False))  # Changed inplace to False
        self.conv2 = nn.Sequential(
                        nn.Conv1d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
                        nn.BatchNorm1d(out_channels),
                        nn.ReLU(inplace=False))
        
        self.in_channels = in_channels
        self.out_channels = out_channels
        if(in_channels != out_channels):
            self.residual = nn.Sequential(
                nn.Conv1d(self.in_channels, out_channels, kernel_size=1, stride=1),
                nn.BatchNorm1d(out_channels),
            )

    def forward(self,x):
        out = self.conv1(x)
        out = self.conv2(out)
        if self.in_channels != self.out_channels:
            residual = self.residual(x)
        else:
            residual = x
        return F.relu(out + residual, inplace=False)


class ResNetLike(nn.Module):
    def __init__(self, input = 201, input_ch = 1, num_classes = 2):
        super(ResNetLike, self).__init__()
        self.model = nn.Sequential(
            nn.Conv1d(input_ch, 64, kernel_size=7, padding='same'),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            ResNetBlock(64,64),
            ResNetBlock(64,64),
            ResNetBlock(64,64),
            ResNetBlock(64,128),    # out 1 x 128 x n
            nn.MaxPool1d(2),        # out 1 x 128 x n//2
            ResNetBlock(128,128),
            ResNetBlock(128,128),
            ResNetBlock(128,256),
            nn.MaxPool1d(2),        # out 1 x 256 x n//2
            ResNetBlock(256,256),
            ResNetBlock(256,256),
            ResNetBlock(256,512),
            nn.MaxPool1d(2),        # out 1 x 512 x n//8
            nn.Flatten(),
            nn.Linear(512*(input//8), 256),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )
        self.model.to('cuda:0')

    def forward(self, x):

        return self.model(x)
    
    def train_model(self, train_loader, valid_loader, num_epochs = 5, learning_rate=0.001, save_best = False, save_thr = 0.94):
        best_accuracy = 0.0
        total_step = len(train_loader)
        # Loss and optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.RMSprop(self.parameters(), lr=learning_rate, weight_decay = 0.005, momentum = 0.9)  

        for epoch in range(num_epochs):
            # self.train()
            correct = 0
            total = 0
            for i, (images, labels) in enumerate(tqdm(train_loader)):
                # Move tensors to the configured device
                images = images.float().to("cuda")
                labels = labels.type(torch.LongTensor)
                labels = labels.to("cuda")


                optimizer.zero_grad()

                # Forward pass
                outputs = self.forward(images)
                loss = criterion(outputs, labels)
                # Backward and optimize
                loss.backward()
                
                optimizer.step()

                # accuracy
                _, predicted = torch.max(outputs.data, 1)
                correct += (torch.eq(predicted, labels)).sum().item()
                total += labels.size(0)

                del images, labels, outputs

            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Accuracy: {:.4f}'
                            .format(epoch+1, num_epochs, i+1, total_step, loss.item(), (float(correct))/total))


            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # Validation
            with torch.no_grad():
                correct = 0
                total = 0
                for images, labels in valid_loader:
                    images = images.float().to("cuda")
                    labels = labels.to("cuda")
                    outputs = self.forward(images)
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (torch.eq(predicted, labels)).sum().item()
                    del images, labels, outputs
                if(((100 * correct / total) > best_accuracy) and save_best and ((100 * correct / total) > save_thr)):
                    torch.save(self.state_dict(), "best_resnet50_MINST-DVS2.pt")

                print('Accuracy of the network: {} %'.format( 100 * correct / total))

In [56]:
model_res = ResNetLike()


In [58]:
model_res.train_model(train,val,num_epochs=90)

100%|██████████| 8/8 [00:00<00:00, 21.30it/s]


Epoch [1/90], Step [8/8], Loss: 93.1512, Accuracy: 0.5533
Accuracy of the network: 56.666666666666664 %


100%|██████████| 8/8 [00:00<00:00, 54.90it/s]


Epoch [2/90], Step [8/8], Loss: 1.4380, Accuracy: 0.4836
Accuracy of the network: 46.666666666666664 %


100%|██████████| 8/8 [00:00<00:00, 50.83it/s]


Epoch [3/90], Step [8/8], Loss: 20.5204, Accuracy: 0.5779
Accuracy of the network: 50.0 %


100%|██████████| 8/8 [00:00<00:00, 51.28it/s]


Epoch [4/90], Step [8/8], Loss: 54.8610, Accuracy: 0.5738
Accuracy of the network: 53.333333333333336 %


100%|██████████| 8/8 [00:00<00:00, 55.22it/s]


Epoch [5/90], Step [8/8], Loss: 10.7801, Accuracy: 0.5533
Accuracy of the network: 56.666666666666664 %


100%|██████████| 8/8 [00:00<00:00, 52.59it/s]


Epoch [6/90], Step [8/8], Loss: 8.3922, Accuracy: 0.6148
Accuracy of the network: 46.666666666666664 %


100%|██████████| 8/8 [00:00<00:00, 54.06it/s]


Epoch [7/90], Step [8/8], Loss: 1.8869, Accuracy: 0.6967
Accuracy of the network: 68.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 55.61it/s]


Epoch [8/90], Step [8/8], Loss: 8.8884, Accuracy: 0.6475
Accuracy of the network: 65.0 %


100%|██████████| 8/8 [00:00<00:00, 53.15it/s]


Epoch [9/90], Step [8/8], Loss: 2.0470, Accuracy: 0.7582
Accuracy of the network: 83.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 53.69it/s]


Epoch [10/90], Step [8/8], Loss: 0.8766, Accuracy: 0.7582
Accuracy of the network: 76.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 53.51it/s]


Epoch [11/90], Step [8/8], Loss: 1.5411, Accuracy: 0.8033
Accuracy of the network: 76.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 51.09it/s]


Epoch [12/90], Step [8/8], Loss: 1.9178, Accuracy: 0.7500
Accuracy of the network: 66.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 53.92it/s]


Epoch [13/90], Step [8/8], Loss: 0.5313, Accuracy: 0.8074
Accuracy of the network: 78.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 53.69it/s]


Epoch [14/90], Step [8/8], Loss: 0.7469, Accuracy: 0.7705
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 53.36it/s]


Epoch [15/90], Step [8/8], Loss: 5.0516, Accuracy: 0.8115
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 54.79it/s]


Epoch [16/90], Step [8/8], Loss: 0.9523, Accuracy: 0.8115
Accuracy of the network: 78.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 51.94it/s]


Epoch [17/90], Step [8/8], Loss: 1.7838, Accuracy: 0.7787
Accuracy of the network: 66.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 52.51it/s]


Epoch [18/90], Step [8/8], Loss: 0.4310, Accuracy: 0.7295
Accuracy of the network: 76.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 52.36it/s]


Epoch [19/90], Step [8/8], Loss: 1.1139, Accuracy: 0.7910
Accuracy of the network: 85.0 %


100%|██████████| 8/8 [00:00<00:00, 56.33it/s]


Epoch [20/90], Step [8/8], Loss: 1.3674, Accuracy: 0.7541
Accuracy of the network: 76.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 53.17it/s]


Epoch [21/90], Step [8/8], Loss: 0.9588, Accuracy: 0.8238
Accuracy of the network: 78.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 51.78it/s]

Epoch [22/90], Step [8/8], Loss: 2.0576, Accuracy: 0.8156





Accuracy of the network: 78.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 55.12it/s]


Epoch [23/90], Step [8/8], Loss: 0.3173, Accuracy: 0.8156
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 54.71it/s]


Epoch [24/90], Step [8/8], Loss: 0.2977, Accuracy: 0.8525
Accuracy of the network: 78.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 52.87it/s]


Epoch [25/90], Step [8/8], Loss: 3.4992, Accuracy: 0.8484
Accuracy of the network: 88.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 52.68it/s]


Epoch [26/90], Step [8/8], Loss: 0.4191, Accuracy: 0.8484
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 52.94it/s]


Epoch [27/90], Step [8/8], Loss: 1.3717, Accuracy: 0.8730
Accuracy of the network: 76.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 55.04it/s]


Epoch [28/90], Step [8/8], Loss: 1.6789, Accuracy: 0.8525
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 54.61it/s]


Epoch [29/90], Step [8/8], Loss: 1.1139, Accuracy: 0.8607
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 55.75it/s]


Epoch [30/90], Step [8/8], Loss: 0.1946, Accuracy: 0.8770
Accuracy of the network: 83.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 53.93it/s]


Epoch [31/90], Step [8/8], Loss: 1.2864, Accuracy: 0.8852
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 53.63it/s]


Epoch [32/90], Step [8/8], Loss: 0.5508, Accuracy: 0.8893
Accuracy of the network: 83.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 51.99it/s]


Epoch [33/90], Step [8/8], Loss: 0.2462, Accuracy: 0.8852
Accuracy of the network: 90.0 %


100%|██████████| 8/8 [00:00<00:00, 53.82it/s]


Epoch [34/90], Step [8/8], Loss: 0.1127, Accuracy: 0.8975
Accuracy of the network: 85.0 %


100%|██████████| 8/8 [00:00<00:00, 54.58it/s]


Epoch [35/90], Step [8/8], Loss: 0.2262, Accuracy: 0.8893
Accuracy of the network: 86.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 51.73it/s]


Epoch [36/90], Step [8/8], Loss: 0.4730, Accuracy: 0.8811
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 53.38it/s]


Epoch [37/90], Step [8/8], Loss: 0.3115, Accuracy: 0.8648
Accuracy of the network: 86.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 52.61it/s]


Epoch [38/90], Step [8/8], Loss: 0.5010, Accuracy: 0.8689
Accuracy of the network: 80.0 %


100%|██████████| 8/8 [00:00<00:00, 52.46it/s]


Epoch [39/90], Step [8/8], Loss: 0.3188, Accuracy: 0.8730
Accuracy of the network: 76.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 54.55it/s]


Epoch [40/90], Step [8/8], Loss: 0.9477, Accuracy: 0.8566
Accuracy of the network: 83.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 54.35it/s]


Epoch [41/90], Step [8/8], Loss: 0.4767, Accuracy: 0.8156
Accuracy of the network: 86.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 53.70it/s]


Epoch [42/90], Step [8/8], Loss: 0.2053, Accuracy: 0.8566
Accuracy of the network: 88.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 52.64it/s]


Epoch [43/90], Step [8/8], Loss: 0.2343, Accuracy: 0.8770
Accuracy of the network: 73.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 53.43it/s]


Epoch [44/90], Step [8/8], Loss: 0.2734, Accuracy: 0.8566
Accuracy of the network: 90.0 %


100%|██████████| 8/8 [00:00<00:00, 51.53it/s]


Epoch [45/90], Step [8/8], Loss: 0.1990, Accuracy: 0.8811
Accuracy of the network: 86.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 52.79it/s]


Epoch [46/90], Step [8/8], Loss: 0.5809, Accuracy: 0.8811
Accuracy of the network: 86.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 53.49it/s]


Epoch [47/90], Step [8/8], Loss: 0.3982, Accuracy: 0.8648
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 52.32it/s]


Epoch [48/90], Step [8/8], Loss: 0.2528, Accuracy: 0.8811
Accuracy of the network: 85.0 %


100%|██████████| 8/8 [00:00<00:00, 51.84it/s]


Epoch [49/90], Step [8/8], Loss: 0.2471, Accuracy: 0.8566
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 54.34it/s]


Epoch [50/90], Step [8/8], Loss: 0.0992, Accuracy: 0.9098
Accuracy of the network: 85.0 %


100%|██████████| 8/8 [00:00<00:00, 51.72it/s]

Epoch [51/90], Step [8/8], Loss: 0.5568, Accuracy: 0.8811





Accuracy of the network: 88.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 53.00it/s]


Epoch [52/90], Step [8/8], Loss: 0.2057, Accuracy: 0.8811
Accuracy of the network: 91.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 51.68it/s]


Epoch [53/90], Step [8/8], Loss: 0.5225, Accuracy: 0.9016
Accuracy of the network: 80.0 %


100%|██████████| 8/8 [00:00<00:00, 52.30it/s]


Epoch [54/90], Step [8/8], Loss: 0.3334, Accuracy: 0.8689
Accuracy of the network: 78.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 53.85it/s]


Epoch [55/90], Step [8/8], Loss: 0.3519, Accuracy: 0.8730
Accuracy of the network: 80.0 %


100%|██████████| 8/8 [00:00<00:00, 53.71it/s]


Epoch [56/90], Step [8/8], Loss: 0.5060, Accuracy: 0.8730
Accuracy of the network: 83.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 53.78it/s]

Epoch [57/90], Step [8/8], Loss: 0.1326, Accuracy: 0.8893





Accuracy of the network: 78.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 52.89it/s]


Epoch [58/90], Step [8/8], Loss: 0.1932, Accuracy: 0.8934
Accuracy of the network: 86.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 53.99it/s]


Epoch [59/90], Step [8/8], Loss: 0.4314, Accuracy: 0.9016
Accuracy of the network: 91.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 55.35it/s]


Epoch [60/90], Step [8/8], Loss: 0.3484, Accuracy: 0.9303
Accuracy of the network: 83.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 51.67it/s]


Epoch [61/90], Step [8/8], Loss: 0.2336, Accuracy: 0.9139
Accuracy of the network: 88.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 51.21it/s]


Epoch [62/90], Step [8/8], Loss: 0.6560, Accuracy: 0.8730
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 52.10it/s]


Epoch [63/90], Step [8/8], Loss: 0.3899, Accuracy: 0.9180
Accuracy of the network: 88.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 52.71it/s]


Epoch [64/90], Step [8/8], Loss: 0.3308, Accuracy: 0.8689
Accuracy of the network: 86.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 54.60it/s]


Epoch [65/90], Step [8/8], Loss: 0.5531, Accuracy: 0.9303
Accuracy of the network: 86.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 55.24it/s]


Epoch [66/90], Step [8/8], Loss: 0.5889, Accuracy: 0.8934
Accuracy of the network: 83.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 54.23it/s]


Epoch [67/90], Step [8/8], Loss: 0.1487, Accuracy: 0.9016
Accuracy of the network: 78.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 54.66it/s]


Epoch [68/90], Step [8/8], Loss: 0.5498, Accuracy: 0.9016
Accuracy of the network: 83.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 51.57it/s]


Epoch [69/90], Step [8/8], Loss: 0.1901, Accuracy: 0.8934
Accuracy of the network: 80.0 %


100%|██████████| 8/8 [00:00<00:00, 53.57it/s]


Epoch [70/90], Step [8/8], Loss: 0.4125, Accuracy: 0.8566
Accuracy of the network: 85.0 %


100%|██████████| 8/8 [00:00<00:00, 53.47it/s]


Epoch [71/90], Step [8/8], Loss: 0.1912, Accuracy: 0.8770
Accuracy of the network: 85.0 %


100%|██████████| 8/8 [00:00<00:00, 52.73it/s]


Epoch [72/90], Step [8/8], Loss: 0.3421, Accuracy: 0.8934
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 52.93it/s]


Epoch [73/90], Step [8/8], Loss: 0.1655, Accuracy: 0.9098
Accuracy of the network: 86.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 52.24it/s]


Epoch [74/90], Step [8/8], Loss: 0.2711, Accuracy: 0.8934
Accuracy of the network: 83.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 52.20it/s]


Epoch [75/90], Step [8/8], Loss: 0.1395, Accuracy: 0.9303
Accuracy of the network: 86.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 54.77it/s]


Epoch [76/90], Step [8/8], Loss: 0.2944, Accuracy: 0.9016
Accuracy of the network: 78.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 52.21it/s]


Epoch [77/90], Step [8/8], Loss: 0.0337, Accuracy: 0.9057
Accuracy of the network: 85.0 %


100%|██████████| 8/8 [00:00<00:00, 52.61it/s]


Epoch [78/90], Step [8/8], Loss: 0.1750, Accuracy: 0.8893
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 53.40it/s]


Epoch [79/90], Step [8/8], Loss: 0.3042, Accuracy: 0.8893
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 55.18it/s]


Epoch [80/90], Step [8/8], Loss: 0.2932, Accuracy: 0.9098
Accuracy of the network: 83.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 54.50it/s]


Epoch [81/90], Step [8/8], Loss: 0.1610, Accuracy: 0.9098
Accuracy of the network: 88.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 54.44it/s]


Epoch [82/90], Step [8/8], Loss: 0.4899, Accuracy: 0.8770
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 53.00it/s]


Epoch [83/90], Step [8/8], Loss: 0.0999, Accuracy: 0.9016
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 52.61it/s]


Epoch [84/90], Step [8/8], Loss: 0.4029, Accuracy: 0.9057
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 49.73it/s]


Epoch [85/90], Step [8/8], Loss: 0.1433, Accuracy: 0.8648
Accuracy of the network: 83.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 52.98it/s]


Epoch [86/90], Step [8/8], Loss: 0.1607, Accuracy: 0.9098
Accuracy of the network: 85.0 %


100%|██████████| 8/8 [00:00<00:00, 52.75it/s]


Epoch [87/90], Step [8/8], Loss: 0.2940, Accuracy: 0.9180
Accuracy of the network: 83.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 52.89it/s]


Epoch [88/90], Step [8/8], Loss: 0.0151, Accuracy: 0.9426
Accuracy of the network: 85.0 %


100%|██████████| 8/8 [00:00<00:00, 51.31it/s]


Epoch [89/90], Step [8/8], Loss: 0.2555, Accuracy: 0.9262
Accuracy of the network: 86.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 54.31it/s]

Epoch [90/90], Step [8/8], Loss: 0.2927, Accuracy: 0.9180
Accuracy of the network: 85.0 %





In [59]:
model_res.train_model(train,val,num_epochs=90,learning_rate=0.0001)

100%|██████████| 8/8 [00:00<00:00, 27.42it/s]


Epoch [1/90], Step [8/8], Loss: 0.0862, Accuracy: 0.9426
Accuracy of the network: 86.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 51.43it/s]


Epoch [2/90], Step [8/8], Loss: 0.1785, Accuracy: 0.9426
Accuracy of the network: 88.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 51.59it/s]


Epoch [3/90], Step [8/8], Loss: 0.1559, Accuracy: 0.9303
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 51.30it/s]


Epoch [4/90], Step [8/8], Loss: 0.0608, Accuracy: 0.9836
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 49.64it/s]


Epoch [5/90], Step [8/8], Loss: 0.0466, Accuracy: 0.9631
Accuracy of the network: 83.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 51.10it/s]


Epoch [6/90], Step [8/8], Loss: 0.0496, Accuracy: 0.9754
Accuracy of the network: 85.0 %


100%|██████████| 8/8 [00:00<00:00, 51.21it/s]


Epoch [7/90], Step [8/8], Loss: 0.0048, Accuracy: 0.9836
Accuracy of the network: 80.0 %


100%|██████████| 8/8 [00:00<00:00, 51.45it/s]


Epoch [8/90], Step [8/8], Loss: 0.0679, Accuracy: 0.9918
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 52.35it/s]


Epoch [9/90], Step [8/8], Loss: 0.0078, Accuracy: 1.0000
Accuracy of the network: 83.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 50.72it/s]


Epoch [10/90], Step [8/8], Loss: 0.1204, Accuracy: 0.9918
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 52.61it/s]


Epoch [11/90], Step [8/8], Loss: 0.0169, Accuracy: 0.9918
Accuracy of the network: 78.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 53.00it/s]


Epoch [12/90], Step [8/8], Loss: 0.2027, Accuracy: 0.9959
Accuracy of the network: 83.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 49.51it/s]


Epoch [13/90], Step [8/8], Loss: 0.0105, Accuracy: 0.9959
Accuracy of the network: 85.0 %


100%|██████████| 8/8 [00:00<00:00, 51.23it/s]


Epoch [14/90], Step [8/8], Loss: 0.0205, Accuracy: 0.9877
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 51.56it/s]


Epoch [15/90], Step [8/8], Loss: 0.0103, Accuracy: 0.9918
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 52.95it/s]


Epoch [16/90], Step [8/8], Loss: 0.0302, Accuracy: 0.9959
Accuracy of the network: 85.0 %


100%|██████████| 8/8 [00:00<00:00, 54.20it/s]


Epoch [17/90], Step [8/8], Loss: 0.0232, Accuracy: 1.0000
Accuracy of the network: 83.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 54.86it/s]


Epoch [18/90], Step [8/8], Loss: 0.0383, Accuracy: 0.9877
Accuracy of the network: 85.0 %


100%|██████████| 8/8 [00:00<00:00, 51.96it/s]


Epoch [19/90], Step [8/8], Loss: 0.0130, Accuracy: 0.9836
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 52.03it/s]


Epoch [20/90], Step [8/8], Loss: 0.0440, Accuracy: 0.9959
Accuracy of the network: 76.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 51.92it/s]


Epoch [21/90], Step [8/8], Loss: 0.0055, Accuracy: 1.0000
Accuracy of the network: 85.0 %


100%|██████████| 8/8 [00:00<00:00, 52.66it/s]


Epoch [22/90], Step [8/8], Loss: 0.0190, Accuracy: 0.9877
Accuracy of the network: 80.0 %


100%|██████████| 8/8 [00:00<00:00, 52.63it/s]


Epoch [23/90], Step [8/8], Loss: 0.0129, Accuracy: 1.0000
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 51.65it/s]


Epoch [24/90], Step [8/8], Loss: 0.0313, Accuracy: 0.9959
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 53.97it/s]


Epoch [25/90], Step [8/8], Loss: 0.0121, Accuracy: 1.0000
Accuracy of the network: 85.0 %


100%|██████████| 8/8 [00:00<00:00, 52.15it/s]


Epoch [26/90], Step [8/8], Loss: 0.0031, Accuracy: 0.9959
Accuracy of the network: 83.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 51.79it/s]


Epoch [27/90], Step [8/8], Loss: 0.0526, Accuracy: 0.9918
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 52.39it/s]


Epoch [28/90], Step [8/8], Loss: 0.0276, Accuracy: 0.9959
Accuracy of the network: 80.0 %


100%|██████████| 8/8 [00:00<00:00, 52.56it/s]


Epoch [29/90], Step [8/8], Loss: 0.0371, Accuracy: 0.9877
Accuracy of the network: 78.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 51.45it/s]


Epoch [30/90], Step [8/8], Loss: 0.0229, Accuracy: 1.0000
Accuracy of the network: 80.0 %


100%|██████████| 8/8 [00:00<00:00, 52.15it/s]


Epoch [31/90], Step [8/8], Loss: 0.0397, Accuracy: 0.9959
Accuracy of the network: 86.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 52.45it/s]


Epoch [32/90], Step [8/8], Loss: 0.0360, Accuracy: 1.0000
Accuracy of the network: 85.0 %


100%|██████████| 8/8 [00:00<00:00, 50.59it/s]


Epoch [33/90], Step [8/8], Loss: 0.0038, Accuracy: 0.9959
Accuracy of the network: 83.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 50.95it/s]


Epoch [34/90], Step [8/8], Loss: 0.0032, Accuracy: 1.0000
Accuracy of the network: 78.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 51.70it/s]


Epoch [35/90], Step [8/8], Loss: 0.0008, Accuracy: 1.0000
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 51.63it/s]


Epoch [36/90], Step [8/8], Loss: 0.0025, Accuracy: 1.0000
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 51.94it/s]


Epoch [37/90], Step [8/8], Loss: 0.0063, Accuracy: 1.0000
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 51.54it/s]


Epoch [38/90], Step [8/8], Loss: 0.0056, Accuracy: 1.0000
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 52.56it/s]


Epoch [39/90], Step [8/8], Loss: 0.0049, Accuracy: 1.0000
Accuracy of the network: 85.0 %


100%|██████████| 8/8 [00:00<00:00, 52.84it/s]


Epoch [40/90], Step [8/8], Loss: 0.0013, Accuracy: 1.0000
Accuracy of the network: 85.0 %


100%|██████████| 8/8 [00:00<00:00, 51.77it/s]


Epoch [41/90], Step [8/8], Loss: 0.0145, Accuracy: 1.0000
Accuracy of the network: 83.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 50.92it/s]


Epoch [42/90], Step [8/8], Loss: 0.0036, Accuracy: 1.0000
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 52.12it/s]


Epoch [43/90], Step [8/8], Loss: 0.0012, Accuracy: 1.0000
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 51.91it/s]


Epoch [44/90], Step [8/8], Loss: 0.0104, Accuracy: 1.0000
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 52.01it/s]


Epoch [45/90], Step [8/8], Loss: 0.0038, Accuracy: 0.9959
Accuracy of the network: 85.0 %


100%|██████████| 8/8 [00:00<00:00, 51.27it/s]


Epoch [46/90], Step [8/8], Loss: 0.0156, Accuracy: 0.9959
Accuracy of the network: 85.0 %


100%|██████████| 8/8 [00:00<00:00, 49.50it/s]


Epoch [47/90], Step [8/8], Loss: 0.0242, Accuracy: 0.9877
Accuracy of the network: 85.0 %


100%|██████████| 8/8 [00:00<00:00, 52.05it/s]


Epoch [48/90], Step [8/8], Loss: 0.0527, Accuracy: 0.9877
Accuracy of the network: 73.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 53.00it/s]


Epoch [49/90], Step [8/8], Loss: 0.0074, Accuracy: 0.9754
Accuracy of the network: 85.0 %


100%|██████████| 8/8 [00:00<00:00, 51.58it/s]


Epoch [50/90], Step [8/8], Loss: 0.0058, Accuracy: 0.9836
Accuracy of the network: 80.0 %


100%|██████████| 8/8 [00:00<00:00, 51.65it/s]


Epoch [51/90], Step [8/8], Loss: 0.0033, Accuracy: 0.9918
Accuracy of the network: 76.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 52.42it/s]


Epoch [52/90], Step [8/8], Loss: 0.0122, Accuracy: 0.9959
Accuracy of the network: 78.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 50.58it/s]


Epoch [53/90], Step [8/8], Loss: 0.0208, Accuracy: 0.9877
Accuracy of the network: 80.0 %


100%|██████████| 8/8 [00:00<00:00, 50.44it/s]


Epoch [54/90], Step [8/8], Loss: 0.0024, Accuracy: 0.9959
Accuracy of the network: 85.0 %


100%|██████████| 8/8 [00:00<00:00, 51.15it/s]


Epoch [55/90], Step [8/8], Loss: 0.0207, Accuracy: 0.9918
Accuracy of the network: 78.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 50.58it/s]


Epoch [56/90], Step [8/8], Loss: 0.0661, Accuracy: 0.9959
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 53.07it/s]


Epoch [57/90], Step [8/8], Loss: 0.0027, Accuracy: 1.0000
Accuracy of the network: 78.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 52.20it/s]


Epoch [58/90], Step [8/8], Loss: 0.0071, Accuracy: 1.0000
Accuracy of the network: 80.0 %


100%|██████████| 8/8 [00:00<00:00, 49.94it/s]


Epoch [59/90], Step [8/8], Loss: 0.0066, Accuracy: 0.9918
Accuracy of the network: 78.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 48.62it/s]


Epoch [60/90], Step [8/8], Loss: 0.0012, Accuracy: 1.0000
Accuracy of the network: 76.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 51.09it/s]


Epoch [61/90], Step [8/8], Loss: 0.0115, Accuracy: 0.9959
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 52.56it/s]


Epoch [62/90], Step [8/8], Loss: 0.0150, Accuracy: 1.0000
Accuracy of the network: 78.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 52.12it/s]


Epoch [63/90], Step [8/8], Loss: 0.0054, Accuracy: 1.0000
Accuracy of the network: 78.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 50.40it/s]


Epoch [64/90], Step [8/8], Loss: 0.0045, Accuracy: 1.0000
Accuracy of the network: 80.0 %


100%|██████████| 8/8 [00:00<00:00, 48.52it/s]


Epoch [65/90], Step [8/8], Loss: 0.0113, Accuracy: 1.0000
Accuracy of the network: 78.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 50.52it/s]


Epoch [66/90], Step [8/8], Loss: 0.0339, Accuracy: 1.0000
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 51.82it/s]


Epoch [67/90], Step [8/8], Loss: 0.0015, Accuracy: 0.9959
Accuracy of the network: 78.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 50.86it/s]


Epoch [68/90], Step [8/8], Loss: 0.0021, Accuracy: 1.0000
Accuracy of the network: 80.0 %


100%|██████████| 8/8 [00:00<00:00, 49.48it/s]


Epoch [69/90], Step [8/8], Loss: 0.0015, Accuracy: 1.0000
Accuracy of the network: 80.0 %


100%|██████████| 8/8 [00:00<00:00, 51.83it/s]


Epoch [70/90], Step [8/8], Loss: 0.0572, Accuracy: 0.9959
Accuracy of the network: 80.0 %


100%|██████████| 8/8 [00:00<00:00, 51.32it/s]


Epoch [71/90], Step [8/8], Loss: 0.0069, Accuracy: 1.0000
Accuracy of the network: 78.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 50.31it/s]


Epoch [72/90], Step [8/8], Loss: 0.0035, Accuracy: 1.0000
Accuracy of the network: 85.0 %


100%|██████████| 8/8 [00:00<00:00, 49.72it/s]


Epoch [73/90], Step [8/8], Loss: 0.0214, Accuracy: 1.0000
Accuracy of the network: 83.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 51.80it/s]


Epoch [74/90], Step [8/8], Loss: 0.0008, Accuracy: 1.0000
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 50.90it/s]


Epoch [75/90], Step [8/8], Loss: 0.0034, Accuracy: 1.0000
Accuracy of the network: 83.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 50.58it/s]


Epoch [76/90], Step [8/8], Loss: 0.0011, Accuracy: 1.0000
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 50.71it/s]


Epoch [77/90], Step [8/8], Loss: 0.0036, Accuracy: 1.0000
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 52.03it/s]


Epoch [78/90], Step [8/8], Loss: 0.0010, Accuracy: 1.0000
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 52.57it/s]


Epoch [79/90], Step [8/8], Loss: 0.0037, Accuracy: 1.0000
Accuracy of the network: 76.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 51.26it/s]


Epoch [80/90], Step [8/8], Loss: 0.0030, Accuracy: 1.0000
Accuracy of the network: 78.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 52.38it/s]


Epoch [81/90], Step [8/8], Loss: 0.0135, Accuracy: 1.0000
Accuracy of the network: 78.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 52.23it/s]


Epoch [82/90], Step [8/8], Loss: 0.0004, Accuracy: 1.0000
Accuracy of the network: 80.0 %


100%|██████████| 8/8 [00:00<00:00, 51.83it/s]


Epoch [83/90], Step [8/8], Loss: 0.0060, Accuracy: 1.0000
Accuracy of the network: 83.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 51.95it/s]


Epoch [84/90], Step [8/8], Loss: 0.0055, Accuracy: 1.0000
Accuracy of the network: 78.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 52.82it/s]


Epoch [85/90], Step [8/8], Loss: 0.0123, Accuracy: 1.0000
Accuracy of the network: 78.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 49.65it/s]


Epoch [86/90], Step [8/8], Loss: 0.0005, Accuracy: 1.0000
Accuracy of the network: 81.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 51.75it/s]


Epoch [87/90], Step [8/8], Loss: 0.0132, Accuracy: 1.0000
Accuracy of the network: 76.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 50.53it/s]

Epoch [88/90], Step [8/8], Loss: 0.0080, Accuracy: 0.9877





Accuracy of the network: 86.66666666666667 %


100%|██████████| 8/8 [00:00<00:00, 50.86it/s]


Epoch [89/90], Step [8/8], Loss: 0.0316, Accuracy: 0.9877
Accuracy of the network: 83.33333333333333 %


100%|██████████| 8/8 [00:00<00:00, 51.20it/s]

Epoch [90/90], Step [8/8], Loss: 0.0288, Accuracy: 0.9836
Accuracy of the network: 85.0 %



