In [210]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Load the dataset
data = np.load('lab2_dataset.npz')
train_feats = torch.tensor(data['train_feats'])
test_feats = torch.tensor(data['test_feats'])

train_labels = torch.tensor(data['train_labels'])
test_labels = torch.tensor(data['test_labels'])
phone_labels = data['phone_labels']

# Set up the dataloaders
train_dataset = torch.utils.data.TensorDataset(train_feats, train_labels)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True)
train_loader.requires_grad=True

test_dataset = torch.utils.data.TensorDataset(test_feats, test_labels)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=8, shuffle=False)

# Define the model architecture
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.linear1 = nn.Linear(40, 2*48*11)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(2*48*11, 48*11)
        self.relu2 = nn.ReLU()
        self.linear3 = nn.Linear(48*11, 48)
        self.relu3 = nn.ReLU()
        self.linear4 = nn.Linear(48*11, 48)

    def forward(self, x):
        x = self.relu(self.linear1(x))
        x = self.linear2(x)
        x = self.relu2(x)
        x = self.linear3(x)
        x = torch.reshape(x, (-1, 11*48))
        x = self.linear4(x)
        return x

# Instantiate the model, loss function, and optimizer
model = MyModel()
criterion = nn.CrossEntropyLoss()
# Changing optimizer to Adam and lr to 0.001
optimizer = optim.Adam(model.parameters(), lr=0.001)
# optimizer = optim.SGD(model.parameters(), )

def train_network(model, train_loader, criterion, optimizer):
    for epoch in range(5):
        for i, (inputs, labels) in enumerate(train_loader, 0):
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
        

def test_network(model, test_loader):
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            inputs, labels = data
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('Test accuracy: %d %%' % (100 * correct / total))

train_network(model, train_loader, criterion, optimizer)
test_network(model, test_loader)


Test accuracy: 56 %


In [307]:
def test_network_per_class(model, test_loader):
    correct = 0
    total = 0
    correct_per_class = [0]*48
    total_per_class = [0]*48
    aux = 0
    with torch.no_grad():
        for data in test_loader:
            inputs, labels = data
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            for i in range(labels.size(0)):
                correct_per_class[labels[i].item()] += (labels[i].item() == predicted[i].item())
                total_per_class[labels[i].item()] += 1
                aux += 1
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('Test accuracy: %d %%' % (100 * correct / total))
    correct_per_class = np.array(correct_per_class)
    total_per_class = np.array(total_per_class)
    correct_per_class = (100*correct_per_class/total_per_class)
    phone_labels2 = np.array(phone_labels)
    correct_per_class_sorted = sorted(correct_per_class, reverse=True)
    print("Accuracy w/o labels: ", correct_per_class_sorted)
    ans = {}
    for i in range(48):
        ans[i] = (correct_per_class[i], phone_labels2[i])
    print("Unsorted accuracy: ", ans)
test_network_per_class(model, test_loader)

Test accuracy: 56 %
Accuracy w/o labels:  [91.0, 83.0, 81.0, 80.0, 79.0, 78.0, 78.0, 77.0, 76.0, 74.0, 74.0, 72.0, 70.0, 68.0, 68.0, 67.0, 67.0, 66.0, 66.0, 65.0, 63.0, 61.0, 61.0, 60.0, 60.0, 60.0, 59.0, 53.0, 53.0, 53.0, 51.0, 49.0, 48.0, 46.0, 44.0, 43.0, 41.0, 40.0, 38.0, 35.0, 34.0, 32.0, 32.0, 30.0, 29.0, 25.0, 23.0, 13.698630136986301]
Unsorted accuracy:  {0: (83.0, 'sil'), 1: (91.0, 's'), 2: (61.0, 'ao'), 3: (30.0, 'l'), 4: (60.0, 'r'), 5: (32.0, 'iy'), 6: (60.0, 'vcl'), 7: (38.0, 'd'), 8: (48.0, 'eh'), 9: (65.0, 'cl'), 10: (63.0, 'p'), 11: (44.0, 'ix'), 12: (60.0, 'z'), 13: (49.0, 'ih'), 14: (78.0, 'sh'), 15: (43.0, 'n'), 16: (70.0, 'v'), 17: (23.0, 'aa'), 18: (72.0, 'y'), 19: (53.0, 'uw'), 20: (67.0, 'w'), 21: (79.0, 'ey'), 22: (74.0, 'dx'), 23: (66.0, 'b'), 24: (78.0, 'ay'), 25: (66.0, 'ng'), 26: (61.0, 'k'), 27: (81.0, 'epi'), 28: (74.0, 'ch'), 29: (53.0, 'dh'), 30: (53.0, 'er'), 31: (40.0, 'en'), 32: (77.0, 'g'), 33: (29.0, 'aw'), 34: (51.0, 'hh'), 35: (59.0, 'ae'), 36: (6

In [None]:
# (91.0, 's'), (83.0, 'sil'), (81.0, 'epi') are the top three phoneme classes with higher accuracy.
# (25.0, 'ax'), (23.0, 'aa'), (13.7, 'zh') are the top three phoneme classes with lower accuracy

In [269]:
for i, phone in enumerate(phone_labels):
    print(i, phone)

0 sil
1 s
2 ao
3 l
4 r
5 iy
6 vcl
7 d
8 eh
9 cl
10 p
11 ix
12 z
13 ih
14 sh
15 n
16 v
17 aa
18 y
19 uw
20 w
21 ey
22 dx
23 b
24 ay
25 ng
26 k
27 epi
28 ch
29 dh
30 er
31 en
32 g
33 aw
34 hh
35 ae
36 ow
37 t
38 ax
39 m
40 zh
41 ah
42 el
43 f
44 jh
45 uh
46 oy
47 th


In [319]:
def test_network_per_class_detailed(model, test_loader, phoneme):
    correct = 0
    total = 0
    correct_per_class = [0]*48
    total_per_class = [0]*48
    aux = 0
    predictions = [0] * 48
    with torch.no_grad():
        for data in test_loader:
            inputs, labels = data
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            for i in range(labels.size(0)):
                if (labels[i].item() == phoneme):
                    predictions[predicted[i].item()] += 1
    phone_labels2 = np.array(phone_labels)
    ans = {}
    predictions = np.array(predictions)
    summ = predictions.sum()
    predictions = 100 * predictions/summ
    for i in range(48):
        ans[i] = (predictions[i], phone_labels2[i])
    print("Predictions: ", ans) 
    print(predictions)

In [320]:
# 14 sh
test_network_per_class_detailed(model, test_loader, 14)
# The most commonly mis-classified for 'sh' are: 15.0, 's'; 5.0, 'ch'; 1.0, 'zh'; 1.0, 'f'; and 78.0 for 'sh'

Predictions:  {0: (0.0, 'sil'), 1: (15.0, 's'), 2: (0.0, 'ao'), 3: (0.0, 'l'), 4: (0.0, 'r'), 5: (0.0, 'iy'), 6: (0.0, 'vcl'), 7: (0.0, 'd'), 8: (0.0, 'eh'), 9: (0.0, 'cl'), 10: (0.0, 'p'), 11: (0.0, 'ix'), 12: (0.0, 'z'), 13: (0.0, 'ih'), 14: (78.0, 'sh'), 15: (0.0, 'n'), 16: (0.0, 'v'), 17: (0.0, 'aa'), 18: (0.0, 'y'), 19: (0.0, 'uw'), 20: (0.0, 'w'), 21: (0.0, 'ey'), 22: (0.0, 'dx'), 23: (0.0, 'b'), 24: (0.0, 'ay'), 25: (0.0, 'ng'), 26: (0.0, 'k'), 27: (0.0, 'epi'), 28: (5.0, 'ch'), 29: (0.0, 'dh'), 30: (0.0, 'er'), 31: (0.0, 'en'), 32: (0.0, 'g'), 33: (0.0, 'aw'), 34: (0.0, 'hh'), 35: (0.0, 'ae'), 36: (0.0, 'ow'), 37: (0.0, 't'), 38: (0.0, 'ax'), 39: (0.0, 'm'), 40: (1.0, 'zh'), 41: (0.0, 'ah'), 42: (0.0, 'el'), 43: (1.0, 'f'), 44: (0.0, 'jh'), 45: (0.0, 'uh'), 46: (0.0, 'oy'), 47: (0.0, 'th')}
[ 0. 15.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0. 78.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  5.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  1.  0.  0.  1.  0.  0.  0.

In [321]:
# 10 p
test_network_per_class_detailed(model, test_loader, 10)
# The most commonly mis-classified for 'p' are: 9.0, 'k'; 9.0, 't'; 6.0, 'b'; 4.0, 'dh'; and 63.0, 'p'

Predictions:  {0: (1.0, 'sil'), 1: (0.0, 's'), 2: (0.0, 'ao'), 3: (0.0, 'l'), 4: (0.0, 'r'), 5: (0.0, 'iy'), 6: (0.0, 'vcl'), 7: (2.0, 'd'), 8: (0.0, 'eh'), 9: (0.0, 'cl'), 10: (63.0, 'p'), 11: (0.0, 'ix'), 12: (0.0, 'z'), 13: (0.0, 'ih'), 14: (0.0, 'sh'), 15: (0.0, 'n'), 16: (1.0, 'v'), 17: (0.0, 'aa'), 18: (0.0, 'y'), 19: (0.0, 'uw'), 20: (0.0, 'w'), 21: (0.0, 'ey'), 22: (0.0, 'dx'), 23: (6.0, 'b'), 24: (0.0, 'ay'), 25: (0.0, 'ng'), 26: (9.0, 'k'), 27: (0.0, 'epi'), 28: (0.0, 'ch'), 29: (4.0, 'dh'), 30: (0.0, 'er'), 31: (0.0, 'en'), 32: (0.0, 'g'), 33: (0.0, 'aw'), 34: (1.0, 'hh'), 35: (0.0, 'ae'), 36: (0.0, 'ow'), 37: (9.0, 't'), 38: (1.0, 'ax'), 39: (0.0, 'm'), 40: (0.0, 'zh'), 41: (0.0, 'ah'), 42: (0.0, 'el'), 43: (3.0, 'f'), 44: (0.0, 'jh'), 45: (0.0, 'uh'), 46: (0.0, 'oy'), 47: (0.0, 'th')}
[ 1.  0.  0.  0.  0.  0.  0.  2.  0.  0. 63.  0.  0.  0.  0.  0.  1.  0.
  0.  0.  0.  0.  0.  6.  0.  0.  9.  0.  0.  4.  0.  0.  0.  0.  1.  0.
  0.  9.  1.  0.  0.  0.  0.  3.  0.  0.  0. 

In [322]:
# 39 m
test_network_per_class_detailed(model, test_loader, 39)
# The most commonly mis-classified for 'm' are: 

Predictions:  {0: (0.0, 'sil'), 1: (0.0, 's'), 2: (0.0, 'ao'), 3: (1.0, 'l'), 4: (0.0, 'r'), 5: (1.0, 'iy'), 6: (2.0, 'vcl'), 7: (0.0, 'd'), 8: (0.0, 'eh'), 9: (0.0, 'cl'), 10: (0.0, 'p'), 11: (0.0, 'ix'), 12: (0.0, 'z'), 13: (0.0, 'ih'), 14: (0.0, 'sh'), 15: (22.0, 'n'), 16: (2.0, 'v'), 17: (0.0, 'aa'), 18: (0.0, 'y'), 19: (1.0, 'uw'), 20: (1.0, 'w'), 21: (0.0, 'ey'), 22: (4.0, 'dx'), 23: (0.0, 'b'), 24: (0.0, 'ay'), 25: (10.0, 'ng'), 26: (0.0, 'k'), 27: (0.0, 'epi'), 28: (0.0, 'ch'), 29: (2.0, 'dh'), 30: (0.0, 'er'), 31: (8.0, 'en'), 32: (0.0, 'g'), 33: (0.0, 'aw'), 34: (0.0, 'hh'), 35: (0.0, 'ae'), 36: (0.0, 'ow'), 37: (0.0, 't'), 38: (0.0, 'ax'), 39: (46.0, 'm'), 40: (0.0, 'zh'), 41: (0.0, 'ah'), 42: (0.0, 'el'), 43: (0.0, 'f'), 44: (0.0, 'jh'), 45: (0.0, 'uh'), 46: (0.0, 'oy'), 47: (0.0, 'th')}
[ 0.  0.  0.  1.  0.  1.  2.  0.  0.  0.  0.  0.  0.  0.  0. 22.  2.  0.
  0.  1.  1.  0.  4.  0.  0. 10.  0.  0.  0.  2.  0.  8.  0.  0.  0.  0.
  0.  0.  0. 46.  0.  0.  0.  0.  0.  0.  0

In [323]:
# 4 r
test_network_per_class_detailed(model, test_loader, 4)
# The most commonly mis-classified for 'r' are: 

Predictions:  {0: (0.0, 'sil'), 1: (0.0, 's'), 2: (2.0, 'ao'), 3: (0.0, 'l'), 4: (60.0, 'r'), 5: (0.0, 'iy'), 6: (0.0, 'vcl'), 7: (0.0, 'd'), 8: (2.0, 'eh'), 9: (0.0, 'cl'), 10: (0.0, 'p'), 11: (1.0, 'ix'), 12: (0.0, 'z'), 13: (0.0, 'ih'), 14: (0.0, 'sh'), 15: (0.0, 'n'), 16: (1.0, 'v'), 17: (1.0, 'aa'), 18: (1.0, 'y'), 19: (6.0, 'uw'), 20: (2.0, 'w'), 21: (1.0, 'ey'), 22: (0.0, 'dx'), 23: (0.0, 'b'), 24: (0.0, 'ay'), 25: (0.0, 'ng'), 26: (0.0, 'k'), 27: (0.0, 'epi'), 28: (0.0, 'ch'), 29: (0.0, 'dh'), 30: (13.0, 'er'), 31: (0.0, 'en'), 32: (0.0, 'g'), 33: (1.0, 'aw'), 34: (1.0, 'hh'), 35: (1.0, 'ae'), 36: (1.0, 'ow'), 37: (0.0, 't'), 38: (1.0, 'ax'), 39: (0.0, 'm'), 40: (0.0, 'zh'), 41: (0.0, 'ah'), 42: (0.0, 'el'), 43: (0.0, 'f'), 44: (0.0, 'jh'), 45: (4.0, 'uh'), 46: (1.0, 'oy'), 47: (0.0, 'th')}
[ 0.  0.  2.  0. 60.  0.  0.  0.  2.  0.  0.  1.  0.  0.  0.  0.  1.  1.
  1.  6.  2.  1.  0.  0.  0.  0.  0.  0.  0.  0. 13.  0.  0.  1.  1.  1.
  1.  0.  1.  0.  0.  0.  0.  0.  0.  4.  1.

In [324]:
# 35 ae
test_network_per_class_detailed(model, test_loader, 35)
# The most commonly mis-classified for 'ae' are: 

Predictions:  {0: (0.0, 'sil'), 1: (0.0, 's'), 2: (1.0, 'ao'), 3: (0.0, 'l'), 4: (1.0, 'r'), 5: (0.0, 'iy'), 6: (0.0, 'vcl'), 7: (0.0, 'd'), 8: (16.0, 'eh'), 9: (0.0, 'cl'), 10: (0.0, 'p'), 11: (0.0, 'ix'), 12: (0.0, 'z'), 13: (0.0, 'ih'), 14: (0.0, 'sh'), 15: (0.0, 'n'), 16: (0.0, 'v'), 17: (0.0, 'aa'), 18: (0.0, 'y'), 19: (0.0, 'uw'), 20: (0.0, 'w'), 21: (3.0, 'ey'), 22: (0.0, 'dx'), 23: (0.0, 'b'), 24: (9.0, 'ay'), 25: (1.0, 'ng'), 26: (0.0, 'k'), 27: (0.0, 'epi'), 28: (0.0, 'ch'), 29: (0.0, 'dh'), 30: (0.0, 'er'), 31: (0.0, 'en'), 32: (0.0, 'g'), 33: (3.0, 'aw'), 34: (1.0, 'hh'), 35: (59.0, 'ae'), 36: (0.0, 'ow'), 37: (0.0, 't'), 38: (0.0, 'ax'), 39: (0.0, 'm'), 40: (0.0, 'zh'), 41: (5.0, 'ah'), 42: (1.0, 'el'), 43: (0.0, 'f'), 44: (0.0, 'jh'), 45: (0.0, 'uh'), 46: (0.0, 'oy'), 47: (0.0, 'th')}
[ 0.  0.  1.  0.  1.  0.  0.  0. 16.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  3.  0.  0.  9.  1.  0.  0.  0.  0.  0.  0.  0.  3.  1. 59.
  0.  0.  0.  0.  0.  5.  1.  0.  0.  0.  0.