In [26]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

class CNN1(nn.Module):  # swapped-in model architecture
    """Small two-conv-layer CNN that maps 4-channel square images to 6 class scores.

    Args:
        n_chan: output channels of the first conv layer; the second conv layer
            uses ``n_chan // 2``.
        img_size: side length of the square input image. Defaults to 360, the
            size the original fc1 layer was hard-wired for (360 -> 180 -> 90
            after the two 2x2 max-pools). fc1 therefore expects
            ``n_chan // 2 * (img_size // 4) ** 2`` input features.
    """

    def __init__(self, n_chan=32, img_size=360):
        super().__init__()
        self.n_chan = n_chan
        self.img_size = img_size
        # Two 2x2 max-pools floor-divide each spatial side by 4 in total.
        feat_side = img_size // 4
        self.conv1 = nn.Conv2d(4, n_chan, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(n_chan, n_chan // 2, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(n_chan // 2 * feat_side * feat_side, 32)
        self.fc2 = nn.Linear(32, 6)

    def forward(self, out):
        """Return raw class scores of shape (batch, 6).

        ``out`` is expected to be (batch, 4, img_size, img_size).
        """
        out = F.max_pool2d(F.relu(self.conv1(out)), 2)
        out = F.max_pool2d(F.relu(self.conv2(out)), 2)
        # Flatten per sample. Unlike view(-1, ...), a wrongly sized input now
        # fails in fc1 instead of silently changing the batch dimension.
        out = torch.flatten(out, 1)
        out = F.relu(self.fc1(out))
        return self.fc2(out)
 
# Per-channel normalization statistics (one entry per input channel, 4 total).
mean = [0.3740, 0.3766, 0.3755, 0.3776]
std = [0.4429, 0.4416, 0.4419, 0.4407]

# Resize to the 360x360 input size CNN1's fc1 layer assumes, convert to a
# float tensor, then normalize each channel with the statistics above.
_preprocess_steps = [
    transforms.Resize([360, 360]),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
]
preprocess = transforms.Compose(_preprocess_steps)

def data_loader(file_path):
    """Load one image file and return it as a normalized tensor.

    The file is opened in a ``with`` block so its handle is released after
    reading. The image is converted to RGBA so the tensor always has the 4
    channels that ``preprocess`` (4-element mean/std) and CNN1's first conv
    layer (4 input channels) require. This is a no-op for images that are
    already RGBA.

    Args:
        file_path: path to the image, e.g. the PNG written by PlotFreq_Time.

    Returns:
        The tensor produced by ``preprocess`` (4 x 360 x 360 with the
        current Resize setting).
    """
    with Image.open(file_path) as image:
        image_tensor = preprocess(image.convert("RGBA"))
    return image_tensor

# Class name -> model output index (CNN1's fc2 produces 6 scores).
label_dict = {"front": 0, "back": 1, "left": 2, "right": 3, "up": 4, "down": 5}
# The inverse lookup must be indexable by every output index 0..N-1.
assert sorted(label_dict.values()) == list(range(len(label_dict))), "label indices must be 0..N-1"
# Index -> class name. Order by the index value rather than relying on dict
# insertion order, so reordering label_dict's entries cannot mislabel outputs.
inv_label_dict = [name for name, _ in sorted(label_dict.items(), key=lambda item: item[1])]
print(inv_label_dict)

In [17]:
import librosa
import librosa.display
import matplotlib.pyplot as plt

# Generate a spectrogram image from a WAV file.
def PlotFreq_Time(fileName):  # plotting function used to produce the model's input image
    """Render the spectrogram of an audio file as a grayscale PNG next to it.

    The PNG path is ``fileName`` with its final extension replaced by ``.png``.

    Args:
        fileName: path to the audio file, loaded with ``librosa.load``.

    Returns:
        The path of the PNG that was written.
    """
    import os  # local import keeps this cell self-contained

    # splitext strips only the final extension; the previous split('.')[0]
    # broke on paths such as "./data/clip.wav" or "my.clip.wav".
    output_path = os.path.splitext(fileName)[0] + ".png"
    audio_data, sample_rate = librosa.load(fileName)
    stft = librosa.stft(audio_data)
    # STFT magnitude converted to decibels.
    spectrogram = librosa.amplitude_to_db(abs(stft))
    fig = plt.figure(figsize=(5, 5))
    try:
        librosa.display.specshow(spectrogram, sr=sample_rate, cmap='gray')
        plt.tight_layout()
        fig.savefig(output_path)
    finally:
        # Release this specific figure even if plotting or saving fails.
        plt.close(fig)
    return output_path

def SoundPredict(_model, _modelPath, _imgPath):
    """Load weights into a model, classify one image, and print the predicted label.

    Args:
        _model: an nn.Module instance (e.g. ``CNN1()``). Its parameters are
            overwritten by the checkpoint via ``load_state_dict``.
        _modelPath: path to a state_dict saved with ``torch.save``.
        _imgPath: path to the image to classify, read through ``data_loader``.

    Returns:
        The predicted label string from ``inv_label_dict``; it is also printed.
    """
    model = _model
    # The model and input stay on the CPU here, so map any GPU-saved tensors
    # to the CPU; otherwise loading fails on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file. Only load trusted
    # checkpoints (consider weights_only=True if the installed torch supports it).
    state_dict = torch.load(_modelPath, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    imgs = data_loader(_imgPath).unsqueeze(0)  # add a batch dimension of 1
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = model(imgs)
    _, predicted = torch.max(outputs, dim=1)
    label = inv_label_dict[predicted.item()]
    print(label)
    return label

In [60]:
# pip install sounddevice
import sounddevice as sd
from scipy.io import wavfile

freq = 44100  # sample rate in Hz
duration = 2  # recording length in seconds
MODEL_PATH = "model1.pth"

# Running this cell starts a 2-second stereo recording (after "start" is
# printed), saves it as a WAV, renders its spectrogram PNG, and predicts the label.
n_frames = int(duration * freq)
recording = sd.rec(n_frames, samplerate=freq, channels=2)
print("start")
sd.wait()  # block until the recording buffer has been filled
wav_file = "output.wav"
wavfile.write(wav_file, freq, recording)
PlotFreq_Time(wav_file)
# Replace CNN1() here to use a different model class.
SoundPredict(CNN1(), MODEL_PATH, "output.png")

start
down
