In [1]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torchsummary import summary
from b2aiprep.dataset import VBAIDataset
from b2aiprep.process import Audio, specgram
import IPython.display as Ipd

### Bridge2AI dataset

In [2]:
dataset = VBAIDataset('./bids_without_sensitive_recordings')

### Dataset split
- 80% for training
- 10% for validation
- 10% for testing

In [3]:
participant_df = dataset.load_and_pivot_questionnaire('participant')
all_identities = sorted(participant_df['record_id'].to_numpy().tolist())

N = len(all_identities)

train_identities = set(all_identities[:int(0.8*N)])
val_identities = set(all_identities[int(0.8*N):int(0.9*N)])
test_identities = set(all_identities[int(0.9*N):])

print('train:', len(train_identities))
print('val:', len(val_identities))
print('test:', len(test_identities))

train: 142
val: 18
test: 18


### Create PyTorch dataset of prolonged vowel audios with age and airway stenosis labels

In [4]:
qs = dataset.load_questionnaires('recordingschema')
q_dfs = []
for i, questionnaire in enumerate(qs):
    df = dataset.questionnaire_to_dataframe(questionnaire)
    df['dataframe_number'] = i
    q_dfs.append(df)
    i += 1
recordingschema_df = pd.concat(q_dfs)
recordingschema_df = pd.pivot(recordingschema_df, index='dataframe_number', columns='linkId', values='valueString')

person_session_pairs = recordingschema_df[['record_id', 'recording_session_id']].to_numpy().astype(str)
person_session_pairs = np.unique(person_session_pairs, axis=0).tolist()

print('Found {} person/session pairs'.format(len(person_session_pairs)))

Found 204 person/session pairs


In [5]:
class MyAudioDataset(torch.utils.data.Dataset):
	def __init__(self, identities, dataset, person_session_pairs, segment_size=3):
		self.segment_size = segment_size
		
		# get age and airway stenosis classification for all subjects
		participant_df = dataset.load_and_pivot_questionnaire('participant')
		age_df = participant_df[['record_id', 'age']].to_numpy()
		airway_stenosis_df = participant_df[['record_id', 'airway_stenosis']].to_numpy()
        
		age_dict = {}
		for person_id, age in age_df:
			age_dict[str(person_id)] = float(age)
		airway_stenosis_dict = {}
		for person_id, airway_stenosis in airway_stenosis_df:
			airway_stenosis_dict[str(person_id)] = float(airway_stenosis)

		# get all prolonged vowel audios
		self.audio_files = []
		self.age = []
		self.airway_stenosis = []
        
		for person_id, session_id in person_session_pairs:
			if person_id not in identities:
				continue
			vowel_audios = [str(x) for x in dataset.find_audio(person_id, session_id) if str(x).endswith('-Prolonged-vowel.wav')]
			self.audio_files += vowel_audios
			self.age += [age_dict[person_id]]*len(vowel_audios)
			self.airway_stenosis += [airway_stenosis_dict[person_id]]*len(vowel_audios)

	def __len__(self):
		return len(self.audio_files)

	def __getitem__(self, idx):
		audio = Audio.from_file(self.audio_files[idx])
		audio = audio.to_16khz().signal.squeeze()
		# get middle K seconds if audio is too long, pad with zeros if it is too short
		if audio.size(0) > self.segment_size*16000:
			d = (audio.size(0)-self.segment_size*16000)//2
			audio = audio[d:d+self.segment_size*16000]
		else:
			audio = torch.nn.functional.pad(audio, (0,self.segment_size*16000-audio.size(0)), mode='constant', value=0)
		return {'signal': audio, 'age': self.age[idx], 'airway_stenosis': self.airway_stenosis[idx]}

In [6]:
train_dataset = MyAudioDataset(train_identities, dataset, person_session_pairs)
val_dataset = MyAudioDataset(val_identities, dataset, person_session_pairs)
test_dataset = MyAudioDataset(test_identities, dataset, person_session_pairs)

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=8, shuffle=False)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=8, shuffle=False)

for batch in train_dataloader:
	print(batch['age'], batch['airway_stenosis'])
	Ipd.display(Ipd.Audio(data=batch['signal'][0], rate=16000))
	break

tensor([71., 20., 39., 65., 45., 59., 29., 73.], dtype=torch.float64) tensor([0., 0., 1., 0., 1., 0., 0., 0.], dtype=torch.float64)


### CNN Model

In [7]:
class CNN_1D(torch.nn.Module):
	def __init__(self, num_classes):
		super(CNN_1D, self).__init__()
		self.conv1 = torch.nn.Conv1d(1, 8, kernel_size=11, stride=3, padding=0)
		self.conv2 = torch.nn.Conv1d(8, 16, kernel_size=11, stride=3, padding=0)
		self.conv3 = torch.nn.Conv1d(16, 32, kernel_size=11, stride=3, padding=0)
		self.conv4 = torch.nn.Conv1d(32, 64, kernel_size=11, stride=3, padding=0)
		self.mp = torch.nn.MaxPool1d(kernel_size=3, stride=3)
		self.fc1 = torch.nn.Linear(384, 120)
		self.fc2 = torch.nn.Linear(120, 84)
		self.fc3 = torch.nn.Linear(84, num_classes)

	def forward(self, x):
		x = self.mp(torch.nn.functional.relu(self.conv1(x)))
		x = self.mp(torch.nn.functional.relu(self.conv2(x)))
		x = self.mp(torch.nn.functional.relu(self.conv3(x)))
		x = self.mp(torch.nn.functional.relu(self.conv4(x)))
		x = x.reshape(x.size(0), -1)
		x = torch.nn.functional.relu(self.fc1(x))
		x = torch.nn.functional.relu(self.fc2(x))
		x = self.fc3(x)
		return x

In [8]:
cnn = CNN_1D(1)
_ = summary(cnn, (1, 48000))

Layer (type:depth-idx)                   Output Shape              Param #
├─Conv1d: 1-1                            [-1, 8, 15997]            96
├─MaxPool1d: 1-2                         [-1, 8, 5332]             --
├─Conv1d: 1-3                            [-1, 16, 1774]            1,424
├─MaxPool1d: 1-4                         [-1, 16, 591]             --
├─Conv1d: 1-5                            [-1, 32, 194]             5,664
├─MaxPool1d: 1-6                         [-1, 32, 64]              --
├─Conv1d: 1-7                            [-1, 64, 18]              22,592
├─MaxPool1d: 1-8                         [-1, 64, 6]               --
├─Linear: 1-9                            [-1, 120]                 46,200
├─Linear: 1-10                           [-1, 84]                  10,164
├─Linear: 1-11                           [-1, 1]                   85
Total params: 86,225
Trainable params: 86,225
Non-trainable params: 0
Total mult-adds (M): 5.46
Input size (MB): 0.18
Forward/backward pa

### Training

In [9]:
def eval(model, dataloader):
	model.eval()
	acc = 0
	for batch in dataloader:
		with torch.no_grad():
			outputs = torch.nn.functional.sigmoid(model(batch['signal'].unsqueeze(1)).squeeze(1))
		for i in range(len(batch['signal'])):
			if outputs[i].item() > 0.5:
				acc += batch['airway_stenosis'][i]
			else:
				acc += 1.0-batch['airway_stenosis'][i]
	return acc/len(dataloader.dataset)	

num_epochs = 250
optimizer = torch.optim.Adam(cnn.parameters(), lr=0.001, weight_decay=5e-5)

best_val_acc = 0
for epoch in range(num_epochs):
	cnn.train()
	
	closs = []
	for batch in train_dataloader:		
		optimizer.zero_grad()
		outputs = torch.nn.functional.sigmoid(cnn(batch['signal'].unsqueeze(1)).squeeze(1))
		loss = torch.nn.functional.binary_cross_entropy(outputs, batch['airway_stenosis'].float())
		closs += [loss.item()]*len(batch['signal'])
		loss.backward()
		optimizer.step()

	if epoch%10 == 9:
		val_acc = eval(cnn, val_dataloader)
		print('Epoch:{} TrainLoss:{:.4f} TrainACC:{:.4f} ValACC:{:.4f}'.format(epoch+1, sum(closs)/len(closs), eval(cnn, train_dataloader), val_acc))

		if val_acc > best_val_acc:
			best_val_acc = val_acc
			torch.save(cnn.state_dict(), './mymodel.pth')
			print('Saved!')

Epoch:10 TrainLoss:0.6208 TrainACC:0.6828 ValACC:0.5789
Saved!
Epoch:20 TrainLoss:0.5801 TrainACC:0.6828 ValACC:0.5789
Epoch:30 TrainLoss:0.5790 TrainACC:0.6828 ValACC:0.5789
Epoch:40 TrainLoss:0.5356 TrainACC:0.7241 ValACC:0.6842
Saved!
Epoch:50 TrainLoss:0.4644 TrainACC:0.7586 ValACC:0.6842
Epoch:60 TrainLoss:0.4936 TrainACC:0.7379 ValACC:0.6842
Epoch:70 TrainLoss:0.4288 TrainACC:0.7655 ValACC:0.6842
Epoch:80 TrainLoss:0.4244 TrainACC:0.7724 ValACC:0.6842
Epoch:90 TrainLoss:0.3632 TrainACC:0.8552 ValACC:0.7368
Saved!
Epoch:100 TrainLoss:0.1567 TrainACC:0.9586 ValACC:0.7895
Saved!
Epoch:110 TrainLoss:0.3062 TrainACC:0.9517 ValACC:0.3684
Epoch:120 TrainLoss:0.1011 TrainACC:0.9862 ValACC:0.5263
Epoch:130 TrainLoss:0.0619 TrainACC:1.0000 ValACC:0.5789
Epoch:140 TrainLoss:0.0198 TrainACC:0.9931 ValACC:0.5789
Epoch:150 TrainLoss:0.0145 TrainACC:0.9931 ValACC:0.6316
Epoch:160 TrainLoss:0.0297 TrainACC:0.9793 ValACC:0.5263
Epoch:170 TrainLoss:0.0092 TrainACC:1.0000 ValACC:0.5789
Epoch:180 Tr

### Testing

In [10]:
cnn.load_state_dict(torch.load('./mymodel.pth'))
test_acc = eval(cnn, test_dataloader)
print('TestACC:{:.4f}'.format(test_acc))

TestACC:0.7500


In [15]:
cnn.eval()
acc = [0, 0]
total = [0, 0]
for batch in test_dataloader:
	with torch.no_grad():
		outputs = torch.nn.functional.sigmoid(cnn(batch['signal'].unsqueeze(1)).squeeze(1))
	for i in range(len(batch['signal'])):
		total[int(batch['airway_stenosis'][i])] += 1
		if outputs[i].item() > 0.5:
			acc[int(batch['airway_stenosis'][i])] += batch['airway_stenosis'][i]
		else:
			acc[int(batch['airway_stenosis'][i])] += 1.0-batch['airway_stenosis'][i]

print('Class #0 ({}) ACC:{:.4f} Class #1 ({}) ACC:{:.4f}'.format(total[0], acc[0]/total[0], total[1], acc[1]/total[1]))

Class #0 (9) ACC:0.7778 Class #1 (7) ACC:0.7143
