In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

from matplotlib import pyplot as plt

import medmnist
from medmnist import INFO, Evaluator

In [2]:
lap_egns_1 = torch.load("eigenvalues/lap_egns_1_pneum.pt")
lap_egns_2 = torch.load("eigenvalues/lap_egns_2_pneum.pt")
lap_egns_3 = torch.load("eigenvalues/lap_egns_3_pneum.pt")
lap_egns_4 = torch.load("eigenvalues/lap_egns_4_pneum.pt")
lap_egns_5 = torch.load("eigenvalues/lap_egns_5_pneum.pt")

lap_egns_test = torch.load("eigenvalues/lap_egns_test_pneum.pt")

train_labels = torch.load("eigenvalues/train_labels_pneum.pt")
test_labels = torch.load("eigenvalues/test_labels_pneum.pt")

In [3]:
train_eigns = torch.cat([lap_egns_1, lap_egns_2, lap_egns_3, lap_egns_4, lap_egns_5])

In [4]:
train_eigns.shape

torch.Size([4708, 14, 3])

In [5]:
train_eigns[:, :, 0:2].flatten(start_dim=1).shape

torch.Size([4708, 28])

In [6]:
train_data = [(lap, lb) for lap, lb in zip(train_eigns[:, :, 1], train_labels.squeeze().tolist())]  # train_eigns[:, :, 1] 

In [7]:
test_data = [(lap, lb) for lap, lb in zip(lap_egns_test[:, :, 1], test_labels.squeeze().tolist())] # lap_egns_test[:, :, 1] 

# Training

In [8]:
train_loader = torch.utils.data.DataLoader(train_data, batch_size=64,
                                           shuffle=True)

In [9]:
train_eigns[:, :, 1].shape

torch.Size([4708, 14])

In [10]:
model = nn.Sequential(
            nn.Linear(14, 2048),
            nn.Tanh(),
            nn.Linear(2048, 2048),
            nn.Tanh(),
            nn.Linear(2048, 1024),
            nn.Tanh(),
            nn.Linear(1024, 512),
            nn.Tanh(),
            nn.Linear(512, 256),
            nn.Tanh(),
            nn.Linear(256, 2),
            # nn.LogSoftmax(dim=1)
)

In [11]:
learning_rate = 5e-3

optimizer = optim.SGD(model.parameters(), lr=learning_rate)

loss_fn = nn.CrossEntropyLoss()

n_epochs = 150

In [None]:
for epoch in range(n_epochs):
    for eig_vals, label in train_loader:
        out = model(eig_vals)
        loss = loss_fn(out, label)
                
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print("Epoch: %d, Loss: %f" % (epoch, float(loss)))
    # if loss <= 0.02:
    #     break

Epoch: 0, Loss: 0.579289
Epoch: 1, Loss: 0.543538
Epoch: 2, Loss: 0.536936
Epoch: 3, Loss: 0.563300
Epoch: 4, Loss: 0.496939
Epoch: 5, Loss: 0.560550
Epoch: 6, Loss: 0.619036
Epoch: 7, Loss: 0.613952
Epoch: 8, Loss: 0.476434
Epoch: 9, Loss: 0.565849
Epoch: 10, Loss: 0.522263
Epoch: 11, Loss: 0.613749
Epoch: 12, Loss: 0.517053
Epoch: 13, Loss: 0.620651
Epoch: 14, Loss: 0.574933
Epoch: 15, Loss: 0.566060
Epoch: 16, Loss: 0.569068
Epoch: 17, Loss: 0.488928
Epoch: 18, Loss: 0.463417
Epoch: 19, Loss: 0.510320
Epoch: 20, Loss: 0.551376
Epoch: 21, Loss: 0.609563
Epoch: 22, Loss: 0.443546
Epoch: 23, Loss: 0.508792
Epoch: 24, Loss: 0.432332
Epoch: 25, Loss: 0.565774
Epoch: 26, Loss: 0.584620
Epoch: 27, Loss: 0.712303
Epoch: 28, Loss: 0.520854
Epoch: 29, Loss: 0.659442
Epoch: 30, Loss: 0.582544
Epoch: 31, Loss: 0.584849
Epoch: 32, Loss: 0.486756
Epoch: 33, Loss: 0.614719
Epoch: 34, Loss: 0.471055
Epoch: 35, Loss: 0.575174
Epoch: 36, Loss: 0.596581
Epoch: 37, Loss: 0.554742
Epoch: 38, Loss: 0.598

In [None]:
# train_loader = torch.utils.data.DataLoader(train_data, batch_size=64,
#                                            shuffle=False)
correct = 0
total = 0

with torch.no_grad():
    for laps, labels in train_loader:
        outputs = model(laps)
        _, predicted = torch.max(outputs, dim=1)
        total += labels.shape[0]
        correct += int((predicted == labels).sum())
        
print("Accuracy: %f" % (correct / total))

In [None]:
test_loader = torch.utils.data.DataLoader(test_data, batch_size=64,
                                         shuffle=False)

correct = 0
total = 0

with torch.no_grad():
    for laps, labels in test_loader:
        outputs = model(laps)
        _, predicted = torch.max(outputs, dim=1)
        total += labels.shape[0]
        correct += int((predicted == labels).sum())
        
print("Accuracy: %f" % (correct / total))