### Import PyTorch and Supporting Libraries

In [4]:
import os
import numpy as np
import pandas as pd
import soundfile as sf
import scipy.signal as sps
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import LabelEncoder

### Feature Extraction from a .wav File

In [5]:
def extract_features(filepath, *, f0=15000, f1=20000, chirp_duration=0.012):
    """Compute a 4-element log-scaled feature vector from one echo recording.

    The emitted linear chirp (``f0`` -> ``f1`` Hz over ``chirp_duration``
    seconds) is recreated and used as a matched filter. The strongest
    response peak is windowed and its Hann-windowed FFT magnitude is
    summarised.

    Args:
        filepath: Path to an audio file readable by ``soundfile``.
        f0: Chirp start frequency in Hz.
        f1: Chirp end frequency in Hz.
        chirp_duration: Chirp length in seconds.

    Returns:
        ``np.float32`` array of shape (4,). It holds ``log1p`` of the mean FFT
        magnitude in the 2-5 kHz, 5-10 kHz and 10-18 kHz bands, followed by
        ``log1p`` of the spectral centroid in Hz. An empty band, or a spectrum
        with zero total magnitude, contributes 0.0 instead of NaN.

    Raises:
        ValueError: if the file contains no samples.
    """
    x, fs = sf.read(filepath)
    if x.ndim > 1:
        # soundfile returns one column per channel for multi-channel files;
        # downmix to mono so the 1-D matched filter below can be applied.
        x = x.mean(axis=1)
    if x.size == 0:
        raise ValueError(f"{filepath}: audio file contains no samples")

    # Recreate the probe chirp.
    t = np.linspace(0, chirp_duration, int(fs * chirp_duration), endpoint=False)
    chirp = sps.chirp(t, f0, f1, chirp_duration, method='linear')

    # Matched filter: convolving with the time-reversed chirp correlates the
    # recording against it; abs() rectifies the response.
    ir = np.abs(sps.fftconvolve(x, chirp[::-1], mode='same'))

    # Keep 300 samples before and 600 after the strongest peak (primary
    # reflection), which suppresses weaker multi-path arrivals.
    peak = int(np.argmax(ir))
    win = ir[max(0, peak - 300):peak + 600]

    # Hann-windowed FFT magnitude of the selected segment.
    X = np.abs(np.fft.rfft(win * np.hanning(len(win))))
    freqs = np.fft.rfftfreq(len(win), 1 / fs)

    # Mean magnitude per band, intended to help identify the material.
    # NOTE(review): two of these bands lie below the 15-20 kHz chirp. Their
    # content comes from the abs() rectification rather than the raw echo
    # band -- confirm this is intended.
    bands = [(2e3, 5e3), (5e3, 10e3), (10e3, 18e3)]
    feats = []
    for lo, hi in bands:
        in_band = X[(freqs >= lo) & (freqs < hi)]
        # A band with no FFT bins (low sample rate or very short window)
        # would make np.mean return NaN; report zero instead.
        feats.append(float(in_band.mean()) if in_band.size else 0.0)

    # Spectral centroid; guard against 0/0 for an all-zero spectrum.
    total = X.sum()
    feats.append(float(np.sum(freqs * X) / total) if total > 0 else 0.0)

    return np.log1p(np.array(feats, dtype=np.float32))  # (4,)
        

In [6]:
# Define dataset class

class ChirpDataset(Dataset):
    """Dataset of echo recordings listed in a CSV file.

    Each CSV row must provide ``filename`` (a path relative to ``root``),
    ``material`` (the class label) and ``distance_m`` (the regression target).
    Each item is a ``(features, material_idx, distance)`` tuple of tensors.
    Features are recomputed from the audio file by ``extract_features`` on
    every access.

    Args:
        csv_path: Path to the metadata CSV.
        root: Directory that the ``filename`` entries are relative to.
        label_encoder: Optional already-fitted ``LabelEncoder``. Pass the
            training set's encoder when building a validation/test set, so
            the material -> index mapping matches the trained model. When
            omitted, a new encoder is fitted on this CSV's ``material`` column.
    """

    def __init__(self, csv_path, root, label_encoder=None):
        self.df = pd.read_csv(csv_path)
        self.root = root

        # Encode material -> integer class index.
        if label_encoder is None:
            self.label_encoder = LabelEncoder()
            self.df["material_idx"] = self.label_encoder.fit_transform(self.df["material"])
        else:
            # Reuse an existing mapping instead of refitting; refitting on a
            # subset of classes would silently renumber them.
            self.label_encoder = label_encoder
            self.df["material_idx"] = label_encoder.transform(self.df["material"])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(features float32 (4,), material long scalar, distance float32 (1,))``."""
        row = self.df.iloc[idx]

        # Load waveform and extract features.
        wav_path = os.path.join(self.root, row["filename"])
        feats = extract_features(wav_path)

        return (
            torch.tensor(feats, dtype=torch.float32),                      # X
            torch.tensor(int(row["material_idx"]), dtype=torch.long),      # material label
            torch.tensor([float(row["distance_m"])], dtype=torch.float32)  # distance
        )

In [7]:
# Neural Network

class EchoNet(nn.Module):
    """Small MLP with a shared trunk and two heads.

    The trunk is ``in_dim -> 32 -> 16`` with ReLU activations. The trunk
    output feeds a material-classification head (``num_classes`` logits)
    and a single-output distance-regression head.

    Args:
        in_dim: Number of input features per sample.
        num_classes: Number of material classes.
    """

    def __init__(self, in_dim=4, num_classes=4):
        super().__init__()
        width_a, width_b = 32, 16
        # These attribute names become the state_dict keys stored in saved
        # checkpoints. Their creation order fixes the random-init order.
        self.fc1 = nn.Linear(in_dim, width_a)
        self.fc2 = nn.Linear(width_a, width_b)
        self.classifier = nn.Linear(width_b, num_classes)
        self.regressor = nn.Linear(width_b, 1)

    def forward(self, x):
        """Return ``(logits, distance)`` computed from the shared trunk."""
        shared = x
        for layer in (self.fc1, self.fc2):
            shared = F.relu(layer(shared))
        logits = self.classifier(shared)
        distance = self.regressor(shared)
        return logits, distance

In [None]:
# Train model

def train(dataset_folder="dataset", epochs=50, batch_size=16, lr=1e-3,
          reg_weight=0.1, checkpoint_path="echo_model.pth"):
    """Train EchoNet on ``<dataset_folder>/chirp_data.csv`` and save a checkpoint.

    The loss is ``cross_entropy(material) + reg_weight * mse(distance)``.
    After each epoch the per-batch mean of both loss terms is printed.

    Args:
        dataset_folder: Folder holding ``chirp_data.csv`` and the audio files.
        epochs: Number of passes over the dataset.
        batch_size: Mini-batch size.
        lr: Adam learning rate.
        reg_weight: Weight of the distance-regression loss; kept small so
            material classification dominates.
        checkpoint_path: Where to write ``{"model": state_dict, "classes": [...]}``.

    Returns:
        The trained ``EchoNet`` model.
    """
    csv_path = os.path.join(dataset_folder, "chirp_data.csv")
    ds = ChirpDataset(csv_path, dataset_folder)
    loader = DataLoader(ds, batch_size=batch_size, shuffle=True)

    classes = ds.label_encoder.classes_
    model = EchoNet(in_dim=4, num_classes=len(classes))
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    model.train()
    for epoch in range(epochs):
        total_cls = 0.0
        total_reg = 0.0
        n_batches = 0

        for feats, mat, dist in loader:
            opt.zero_grad()

            logits, pred_dist = model(feats)

            loss_cls = F.cross_entropy(logits, mat)
            # Both tensors are (batch, 1), so they are compared directly.
            loss_reg = F.mse_loss(pred_dist, dist)

            loss = loss_cls + reg_weight * loss_reg
            loss.backward()
            opt.step()

            total_cls += loss_cls.item()
            total_reg += loss_reg.item()
            n_batches += 1

        # Report per-batch means so the numbers don't grow with dataset size.
        denom = max(n_batches, 1)
        print(f"Epoch {epoch:02d} | cls={total_cls / denom:.3f} | reg={total_reg / denom:.3f}")

    torch.save({
        "model": model.state_dict(),
        # A plain Python list (not a NumPy array) keeps the checkpoint
        # loadable with torch.load(..., weights_only=True).
        "classes": classes.tolist(),
    }, checkpoint_path)

    return model

In [None]:
# Export the trained model to TorchScript for on-device (Android) inference.
def export_mobile(checkpoint_path="echo_model.pth",
                  out_path="echo_model_mobile.pt", in_dim=4):
    """Trace the checkpoint written by ``train()`` into a TorchScript file.

    Run this after ``train()``. The model is rebuilt from the checkpoint's
    ``"model"`` state dict and ``"classes"`` list, switched to eval mode,
    traced with a random ``(1, in_dim)`` example input, and saved.

    Args:
        checkpoint_path: Checkpoint saved by ``train()``.
        out_path: Destination for the traced TorchScript module.
        in_dim: Number of input features the model expects.

    Returns:
        The traced ``torch.jit.ScriptModule``.
    """
    # weights_only=False: checkpoints from earlier versions of train() store
    # "classes" as a NumPy array, which the weights-only loader rejects.
    # Full unpickling can execute code -- only load checkpoints you created.
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    classes = np.asarray(ckpt["classes"]).tolist()

    model = EchoNet(in_dim=in_dim, num_classes=len(classes))
    model.load_state_dict(ckpt["model"])
    model.eval()

    example = torch.rand(1, in_dim)
    traced = torch.jit.trace(model, example)
    traced.save(out_path)

    print("\nSaved:")
    print(f" - {checkpoint_path} (PyTorch)")
    print(f" - {out_path} (TorchScript for Android)")
    print("\nClasses:", classes)
    return traced