In [1]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm

In [2]:
class ModulationClassifier(torch.nn.Module):
    def __init__(self, num_classes=11):
        super(ModulationClassifier, self).__init__()
        
        # Treating input as 2x128 image with 1 channel (depth 1)
        self.conv1 = torch.nn.Sequential(torch.nn.Conv2d(in_channels=1, out_channels=256, kernel_size=(1,3), padding=(0,2)), torch.nn.BatchNorm2d(256))
        self.conv2 = torch.nn.Sequential(torch.nn.Conv2d(in_channels=256, out_channels=80, kernel_size=(2,3), padding=(0,2)), torch.nn.BatchNorm2d(80))
        #self.fc1   = torch.nn.Linear(in_features=10560, out_features=256)
        #self.fc2   = torch.nn.Linear(in_features=256, out_features=num_classes)
        self.classifier = torch.nn.Sequential(torch.nn.Linear(in_features=10560, out_features=256),
                                              torch.nn.ReLU(True),
                                             torch.nn.Linear(in_features=256, out_features=11))

    def forward(self, x):
        y1 = F.relu(self.conv1(x))
        y2 = F.relu(self.conv2(y1))
        y3 = torch.flatten(y2, 1)
        y4 = self.classifier(y3)
        return y4

In [3]:
fname = 'RML2016.10a_dict.pkl'
f = open(fname,'rb')
input_data_dict = pickle.load(f, encoding='latin1')

"""
Data size is 1000*2*128
Consider it as 1000 images of size 1*2*128. Where 1 is the color channel.
We need the data in this format:

image - mod - snr
image - mod - snr
image - mod - snr

Create 3 arrays:
image, mod, snr
"""

input_data_dict_keys = sorted(input_data_dict.keys())

modulation_types    = [input_data_dict_keys[i*20][0] for i in range(0,11)]
snr_types           = [input_data_dict_keys[i][1] for i in range(0,20)]

print(modulation_types)
print(snr_types)


image = []
modulation = []
snr = []


print(np.shape(image), np.shape(modulation), np.shape(snr))

for m_cnt,m in enumerate(modulation_types,0):
    for s in snr_types:
        #print(m,str(s))
        image.extend(input_data_dict[(m,s)])
        modulation.extend([m_cnt for _ in range(0,1000)])
        snr.extend([s for _ in range(0,1000)])
        
#print(np.shape(image), np.shape(modulation), np.shape(snr))
image = np.array(image)
modulation = np.array(modulation)
snr = np.array(snr)

"""
image = np.zeros((220000,2,128), dtype='float32')
modulation = np.zeros((220000), dtype='int32')
snr = np.zeros((220000), dtype='int32')
cnt = 0

for m_cnt,m in enumerate(modulation_types,0):
    for s in snr_types:
        print(cnt, m_cnt,s)
        #print(input_data_dict[(m,s)].shape)
        image[cnt:cnt+1000,:,:] = np.array(input_data_dict[(m,s)])
        modulation[cnt:cnt+1000] = np.array([m_cnt for _ in range(0,1000)])
        snr[cnt:cnt+1000] = np.array([s for _ in range(0,1000)])
        #print(image[cnt:cnt+1000,:,:], modulation[cnt:cnt+1000], snr[cnt:cnt+1000])
        cnt += 1000
"""

['8PSK', 'AM-DSB', 'AM-SSB', 'BPSK', 'CPFSK', 'GFSK', 'PAM4', 'QAM16', 'QAM64', 'QPSK', 'WBFM']
[-20, -18, -16, -14, -12, -10, -8, -6, -4, -2, 0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
(0,) (0,) (0,)


"\nimage = np.zeros((220000,2,128), dtype='float32')\nmodulation = np.zeros((220000), dtype='int32')\nsnr = np.zeros((220000), dtype='int32')\ncnt = 0\n\nfor m_cnt,m in enumerate(modulation_types,0):\n    for s in snr_types:\n        print(cnt, m_cnt,s)\n        #print(input_data_dict[(m,s)].shape)\n        image[cnt:cnt+1000,:,:] = np.array(input_data_dict[(m,s)])\n        modulation[cnt:cnt+1000] = np.array([m_cnt for _ in range(0,1000)])\n        snr[cnt:cnt+1000] = np.array([s for _ in range(0,1000)])\n        #print(image[cnt:cnt+1000,:,:], modulation[cnt:cnt+1000], snr[cnt:cnt+1000])\n        cnt += 1000\n"

In [4]:
class myDataset(torch.utils.data.Dataset):
    def __init__(self, X,Y,Z, transform=None):
        self.X = X
        self.Y = Y
        self.Z = Z
        self.transform = transform
        
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        rX, rY, rZ = torch.unsqueeze(torch.from_numpy(self.X[idx]),dim=0).float(), torch.tensor(self.Y[idx]), torch.tensor(self.Z[idx])
        if(self.transform):
            rX, rY, rZ = self.transform(self.X[idx]),torch.tensor(self.Y[idx]),torch.tensor(self.Z[idx])
            
        return rX, rY, rZ

In [5]:
#dataset = [image, modulation, snr]
transformed_dataset = myDataset(image, modulation, snr,transforms.Compose([transforms.ToTensor(),transforms.Normalize((0,),(1,))]))
#rX, rY, rZ = torch.unsqueeze(torch.from_numpy(image[0]),dim=0).float(), torch.tensor(modulation[0]), torch.tensor(snr[0])
#print(type(rY))

dataset_len = len(image)
l = list(range(0, dataset_len))
np.random.shuffle(l)
split = 0.5
train_indices, test_indices = l[:int(split*dataset_len)],l[int(split*dataset_len):]

train_sampler=torch.utils.data.SubsetRandomSampler(train_indices)
test_sampler=torch.utils.data.SubsetRandomSampler(test_indices)

# train_sampler=torch.utils.data.SequentialSampler(train_indices)
# test_sampler=torch.utils.data.SequentialSampler(test_indices)

trainloader = torch.utils.data.DataLoader(transformed_dataset, batch_size=128, shuffle=False, num_workers=2, sampler=train_sampler, pin_memory=True)
testloader = torch.utils.data.DataLoader(transformed_dataset, batch_size=1, shuffle=False, num_workers=2, sampler=test_sampler, pin_memory=True)

In [6]:
net = ModulationClassifier()

learning_rate = 0.01

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Assuming that we are on a CUDA machine, this should print a CUDA device:
print(device)
net.to(device)

print(net)
for name, param in net.named_parameters():
    #if param.requires_grad:
    print('{:s}\t{:s}\t{:s}'.format(name.ljust(40), str(param.size()).ljust(30),str(param.nelement()).rjust(10)))

cuda:0
ModulationClassifier(
  (conv1): Sequential(
    (0): Conv2d(1, 256, kernel_size=(1, 3), stride=(1, 1), padding=(0, 2))
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv2): Sequential(
    (0): Conv2d(256, 80, kernel_size=(2, 3), stride=(1, 1), padding=(0, 2))
    (1): BatchNorm2d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (classifier): Sequential(
    (0): Linear(in_features=10560, out_features=256, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=256, out_features=11, bias=True)
  )
)
conv1.0.weight                          	torch.Size([256, 1, 1, 3])    	       768
conv1.0.bias                            	torch.Size([256])             	       256
conv1.1.weight                          	torch.Size([256])             	       256
conv1.1.bias                            	torch.Size([256])             	       256
conv2.0.weight                          	torch.Size([80, 256, 2, 3])

In [7]:
num_epochs = 100
epoch_loss_array = np.zeros(num_epochs)

for epoch in range(num_epochs):  # loop over the dataset multiple times

    running_loss = 0
    
    for i, data in enumerate(trainloader):
        x,y,z = data
        x,y,z = x.to(device), y.to(device), z.to(device)
        #x_mean = torch.mean(x, 1, keepdims=True)
        #print(x.size(), y.size(), z.size(), x_mean.size())
        #x = x - x_mean
        optimizer.zero_grad()

        y_pred = net.forward(x)   # Run batch
        #print(torch.argmax(y_pred, 1).size())
        loss = criterion(y_pred, y.long())  # Wants indexes for labels, *not* one-hot encodings.
        loss.backward()                               # Compute backprop
        optimizer.step()                              # Move a step in the right direction

        running_loss += loss.item()
        
        #print( y, torch.argmax(y_pred,1))
        #break
        
    print(running_loss/len(trainloader))

print('Finished Training')

0.8023028205300486
0.5087926574917727
0.4491056091563646
0.4072604430276294
0.3836330060348954
0.3459082034091617
0.31598821088325146
0.28530036730821745
0.26138566070517827
0.2326782818450484
0.20507257379764735
0.18787415038014568
0.16429393076619436
0.14735244273446327
0.13172882828255034
0.11876257363446924
0.11376052077773005
0.10035927109420299
0.09423685324226701
0.0905521523727234
0.08736600202703199
0.08642896368753078
0.0835414188719073
0.0791843518614769
0.07745748973169993
0.07466897254073343
0.07401477539435375
0.07224483250012231
0.07074032877247001
0.06922459818942603
0.07071248071138249
0.06784668767521548
0.06835336270200652
0.06634778787576875
0.0656068085497895
0.06383736066000406
0.06372993141412735
0.06223439309139584
0.0615888248869153
0.07224563303035358
0.08380863293139047
0.1018014394699834
0.0815504582295584
0.06957547230602697
0.0630239078211923
0.05900522507727146
0.05946746484138245
0.05740791787068511
0.054379840004582736
0.055706948674348895
0.05382014769

In [8]:
PATH = './temp_model_1.pth'
torch.save(net.state_dict(), PATH) 

model1 = ModulationClassifier()
model1.to(device)
model1.load_state_dict(torch.load(PATH))

correct = 0
total = 0

snr_accuracy = dict()
for i in range(-20,20,2):
    k = str(i)
    snr_accuracy[k] = [0,0]
    
model1.eval()
for i,data in enumerate(testloader):
    x,y,z = data
    x,y,z = x.to(device), y.to(device), z.to(device)

    y_pred = model1.forward(x)
    correct += torch.sum((torch.argmax(y_pred,dim=1) == y)).item()
    snr_value = z.item()

    snr_accuracy[str(int(snr_value))][0] += torch.sum((torch.argmax(y_pred,dim=1) == y)).item()
    snr_accuracy[str(int(snr_value))][1] += 1
    
    total += 1 #Increase by batch size

print('Accuracy of the network on the test images: %f %%' % (
    100 * correct / total))

Accuracy of the network on the test images: 78.981818 %


In [10]:
for key, values in snr_accuracy.items():
    if(values[1]):
        print('SNR {:s} Accuracy {:f}'.format(key, values[0]/values[1]))

SNR 0 Accuracy 0.704426
SNR 2 Accuracy 0.770444
SNR 4 Accuracy 0.785534
SNR 6 Accuracy 0.801042
SNR 8 Accuracy 0.800219
SNR 10 Accuracy 0.805485
SNR 12 Accuracy 0.806723
SNR 14 Accuracy 0.806434
SNR 16 Accuracy 0.805647
SNR 18 Accuracy 0.812683
