In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from src.data.load_dataset import generate_frequency_detection, generate_frequency_XOR
from src.data.load_dataset import load_frequency_detection, load_frequency_XOR
from src.models.networks import sensilla_RFNet, classical_RFNet
from src.models.utils import train, test

In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

### Frequency detection

In [128]:
## params
# Settings passed to both the raw-data generator and the dataloader builder.
num_samples, sampling_rate, duration, freq, snr = 7000, 1500, 0.1, 5, 0.8
# Single explicit seed for the whole cell. Previously the generator received
# seed=None while the loaders hard-coded seed=2, so `data`/`labels` were not
# reproducible across runs and were not tied to the seed used for training.
# NOTE(review): presumably the same seed makes `data`/`labels` match what the
# loaders draw -- confirm against src.data.load_dataset.
seed = 2
data, labels = generate_frequency_detection(num_samples, sampling_rate, freq, duration, snr, seed)

# dataloader
train_batch_size, train_percentage = 2048, 0.8
train_loader, val_loader, test_loader = load_frequency_detection(
    num_samples, sampling_rate, freq, duration, snr,
    train_batch_size, train_percentage, seed=seed,
)

In [129]:
# Build the sensilla_RFNet model for the frequency-detection task.
# Input width is sampling_rate * duration truncated to an int.
inp_size = int(sampling_rate * duration)
hidden_size = 250
lowcut, highcut = 2, 8
decay_coef = 6
seed = 2
model = sensilla_RFNet(
    inp_size, hidden_size, lowcut, highcut, decay_coef=decay_coef, seed=seed
).to(device)

# Plain SGD over everything returned by model.parameters().
lr = 0.1
optimizer = optim.SGD(model.parameters(), lr=lr)

# Train for a fixed number of epochs; the epoch counter passed to train() is 1-based.
epochs = 20
log_interval = 100
for epoch in range(1, epochs + 1):
    _ = train(log_interval, device, model, train_loader, optimizer, epoch, verbose=False)
    # Per-epoch validation is disabled; to re-enable:
    #     val_accuracy = test(model, device, val_loader)

# Evaluate once on the test loader after training; test() prints the summary line.
test_accuracy = test(model, device, test_loader)


Test set: Average loss: 0.275880. Accuracy: 9927/10000 (99.27%)



In [130]:
# Baseline: classical_RFNet with the same input and hidden sizes as the
# sensilla model, trained on the same frequency-detection loaders.
inp_size = int(sampling_rate * duration)
hidden_size = 250
model = classical_RFNet(inp_size, hidden_size, seed=10).to(device)

# Plain SGD over everything returned by model.parameters().
lr = 0.1
optimizer = optim.SGD(model.parameters(), lr=lr)

# Train for a fixed number of epochs; the epoch counter passed to train() is 1-based.
epochs = 20
log_interval = 100
for epoch in range(1, epochs + 1):
    _ = train(log_interval, device, model, train_loader, optimizer, epoch, verbose=False)
    # Per-epoch validation is disabled; to re-enable:
    #     val_accuracy = test(model, device, val_loader)

# Evaluate once on the test loader after training; test() prints the summary line.
test_accuracy = test(model, device, test_loader)


Test set: Average loss: 0.301382. Accuracy: 9652/10000 (96.52%)



### Frequency XOR

In [137]:
## params
# Settings for the two-frequency XOR task. The same `seed` is handed to the
# raw-data generator and to the dataloader builder.
num_samples, sampling_rate, duration = 7000, 1500, 0.1
freq1, freq2 = 5, 8
snr, seed = 0.8, 5
data, labels = generate_frequency_XOR(
    num_samples, sampling_rate, freq1, freq2, duration, snr, seed, shuffle=False
)

# dataloader
batch_size, percentage = 2048, 0.8
train_loader, val_loader, test_loader = load_frequency_XOR(
    num_samples, sampling_rate, freq1, freq2, duration, snr,
    batch_size, percentage, seed,
)

In [150]:
## sensilla network
# Build the sensilla_RFNet model for the frequency-XOR task.
# Input width is sampling_rate * duration truncated to an int.
inp_size = int(sampling_rate * duration)
hidden_size = 250
lowcut, highcut = 3, 13
decay_coef = 6
seed = 2
model = sensilla_RFNet(
    inp_size, hidden_size, lowcut, highcut, decay_coef=decay_coef, seed=seed
).to(device)

# Plain SGD over everything returned by model.parameters().
# NOTE(review): lr here is 0.01 while the classical XOR baseline uses 0.1 --
# confirm the difference is intentional.
lr = 0.01
optimizer = optim.SGD(model.parameters(), lr=lr)

# Train for a fixed number of epochs; the epoch counter passed to train() is 1-based.
epochs = 30
log_interval = 100
for epoch in range(1, epochs + 1):
    _ = train(log_interval, device, model, train_loader, optimizer, epoch, verbose=False)
    # Per-epoch validation is disabled; to re-enable:
    #     val_accuracy = test(model, device, val_loader)

# Evaluate once on the test loader after training; test() prints the summary line.
test_accuracy = test(model, device, test_loader)


Test set: Average loss: 0.662010. Accuracy: 9474/10000 (94.74%)



In [155]:
## classical network
# Baseline: classical_RFNet with the same input and hidden sizes, trained on
# the frequency-XOR loaders.
inp_size = int(sampling_rate * duration)
hidden_size = 250
model = classical_RFNet(inp_size, hidden_size, seed=10).to(device)

# Plain SGD over everything returned by model.parameters().
lr = 0.1
optimizer = optim.SGD(model.parameters(), lr=lr)

# Train for a fixed number of epochs; the epoch counter passed to train() is 1-based.
epochs = 30
log_interval = 100
for epoch in range(1, epochs + 1):
    _ = train(log_interval, device, model, train_loader, optimizer, epoch, verbose=False)
    # Per-epoch validation is disabled; to re-enable:
    #     val_accuracy = test(model, device, val_loader)

# Evaluate once on the test loader after training; test() prints the summary line.
test_accuracy = test(model, device, test_loader)


Test set: Average loss: 0.504902. Accuracy: 8350/10000 (83.50%)

