In [None]:
import torch
import torch.nn as nn 
import torch.optim as optim

In [None]:
from torch.utils.data import DataLoader, Dataset

class RFSignalDataset(Dataset):
    """Map-style dataset that pairs each sample with the label at the same index.

    Args:
        data: Sized, indexable collection of samples; ``data[i]`` is sample ``i``.
        labels: Sized, indexable collection of labels aligned with ``data``.

    Raises:
        ValueError: If ``data`` and ``labels`` do not have the same length.
    """

    def __init__(self, data, labels):
        # Fail fast on misaligned inputs: without this check a shorter `labels`
        # only surfaces as an IndexError in the middle of an epoch, and a longer
        # one is silently ignored because __len__ is driven by `data` alone.
        if len(data) != len(labels):
            raise ValueError(
                f"data and labels must have the same length, "
                f"got {len(data)} and {len(labels)}"
            )
        self.data = data
        self.labels = labels

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``{'data': sample, 'label': label}`` for position ``index``."""
        return {'data': self.data[index], 'label': self.labels[index]}

# Define a custom neural network model
class RFSignalClassifier(nn.Module):
    """Fully connected classifier: ``input_size -> 64 -> 32 -> num_classes``.

    ReLU is applied after each of the first two linear layers. The result of
    the last linear layer is returned as-is (no softmax is applied here).

    Args:
        input_size: Number of input features expected by the first linear layer.
        num_classes: Number of output units of the last linear layer.
    """

    def __init__(self, input_size, num_classes):
        super().__init__()
        # Layers are created in the same order as they are applied in forward().
        self.fc1 = nn.Linear(input_size, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Run ``x`` through both hidden layers and return the output layer's result."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = self.relu(layer(hidden))
        return self.fc3(hidden)

In [None]:
# Create datasets and data loaders.
# NOTE: X_train, y_train, X_test and y_test are not defined in this cell; they
# must already exist in the kernel when it runs.
train_dataset = RFSignalDataset(X_train, y_train)
test_dataset = RFSignalDataset(X_test, y_test)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64)

# Initialize the model and optimizer.
# NOTE(review): nn.Linear requires the last dimension of every batch to equal
# input_size. The value below is hard-coded; it equals 64 * 355727, so it may be
# a per-batch element count rather than the per-sample feature count.
# TODO confirm against the actual shape of X_train.
input_size = 22766528
num_classes = 3  # Specify the number of classes
model = RFSignalClassifier(input_size, num_classes)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()  # For classification tasks

# Inputs are cast to the dtype of the model weights before the forward pass:
# nn.Linear rejects inputs whose dtype differs from its weights (e.g. float64
# batches against float32 weights). The cast is a no-op when dtypes already match.
param_dtype = next(model.parameters()).dtype

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0  # sum of per-sample training losses seen this epoch
    seen = 0            # number of training samples seen this epoch
    for batch in train_loader:
        data = batch['data'].to(param_dtype)
        # CrossEntropyLoss with class-index targets requires int64 labels.
        labels = batch['label'].long()
        optimizer.zero_grad()
        outputs = model(data)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # criterion returns the batch mean, so weight it by the batch size to
        # get a mean over the whole epoch even when the last batch is smaller.
        running_loss += loss.item() * labels.size(0)
        seen += labels.size(0)
    # Report the epoch mean instead of the last batch only; nan marks an empty
    # training loader (the original raised NameError on `loss` in that case).
    epoch_loss = running_loss / seen if seen else float('nan')

    # Validation loop (optional)
    model.eval()
    total = 0
    correct = 0
    with torch.no_grad():
        for batch in test_loader:
            data = batch['data'].to(param_dtype)
            labels = batch['label'].long()
            outputs = model(data)
            # Predicted class = index of the largest output along dim 1.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # Guard against an empty test loader instead of raising ZeroDivisionError.
    accuracy = 100 * correct / total if total else float('nan')
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {accuracy:.2f}%')