In [5]:
import os
import torch
import polars as pl
from typing import Callable
from torch.utils.data import DataLoader


# Model and Data structures
from models.FusionCNN import ContextFusionCNN
from data.source.pg_experiment import get_example_dataframe
from dataset import Cast, TorchDataset
from pytorch_dataloader import build_collate_fn, PaddingCollate, DefaultCollate, MemoryLoadedDataLoader
from transformation import Channels, RMSEnergy, TorchVadMFCC, ZeroCrossingRate

# Runtime environment
# KMP_DUPLICATE_LIB_OK=TRUE is the usual workaround for the OpenMP
# "duplicate runtime already initialized" abort — presumably needed because
# more than one imported library ships its own OpenMP runtime.
os.environ.update({"KMP_DUPLICATE_LIB_OK": "TRUE"})

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# Only the pronunciation frame is used below; `df_tone` is unpacked but not read here.
df_pron, df_tone = get_example_dataframe()

# Words whose validation accuracy was above 70%.
TARGET_WORDS = ["a0", "a1", "a100", "a2", "a3", "a5", "a8"]

# The order of the steps matters: the dense rank is computed over every row of
# the target words *before* the stage filter, so the integer ids do not depend
# on which words happen to have stage-1 recordings.
dataframe = (
    df_pron
    .filter(pl.col("word_id").is_in(TARGET_WORDS))
    .with_columns([
        # Replace the original word id by its 1-based dense rank.
        pl.struct("word_id").rank("dense").alias("word_id"),
        pl.col("value").cast(pl.Float32),
    ])
    .filter(pl.col("stage") == 1)
)

N_WORDS = dataframe.get_column("word_id").n_unique()
print(f"Number of unique words: {N_WORDS}")
print(f"Number of samples: {dataframe.height}")

Number of unique words: 7
Number of samples: 48


In [3]:
# Preview the first five filtered rows to sanity-check the columns and dtypes.
dataframe.head(n=5)

id_student,value,word_id,rec_path,stage
i64,f32,u32,str,i32
1603,1.0,1,"""/home/kamil2002/Mandarin_Pronu…",1
1580,1.0,1,"""/home/kamil2002/Mandarin_Pronu…",1
1593,1.0,1,"""/home/kamil2002/Mandarin_Pronu…",1
1686,0.0,1,"""/home/kamil2002/Mandarin_Pronu…",1
1687,1.0,1,"""/home/kamil2002/Mandarin_Pronu…",1


In [6]:
def to_dataset(dataframe: pl.DataFrame) -> TorchDataset:
    """Wrap a recordings dataframe as a four-field ``TorchDataset``.

    The fields are built, in this order, from the dataframe columns:

    1. ``rec_path`` passed through ``Channels("stack", "multiply")`` wrapping
       ``TorchVadMFCC(delta=0)``.
    2. ``rec_path`` passed through ``Channels("cat", "multiply")`` wrapping
       ``ZeroCrossingRate()`` and ``RMSEnergy()``.
    3. ``word_id`` minus one, as a ``torch.long`` tensor (the ids in this
       notebook are 1-based dense ranks, so this yields 0-based indices).
    4. ``value`` as a float tensor.

    Args:
        dataframe: Frame providing the ``rec_path``, ``word_id`` and ``value``
            columns.

    Returns:
        The ``TorchDataset`` assembled from the four ``Cast`` fields above.
    """
    mfcc_transform = Channels("stack", "multiply")(
        TorchVadMFCC(delta=0),
    )
    context_transform = Channels("cat", "multiply")(
        ZeroCrossingRate(),
        RMSEnergy(),
    )
    return TorchDataset(
        Cast(dataframe.get_column("rec_path"), mfcc_transform),
        Cast(dataframe.get_column("rec_path"), context_transform),
        Cast(dataframe.get_column("word_id"), lambda x: torch.tensor(x - 1, dtype=torch.long)),
        Cast(dataframe.get_column("value"), lambda x: torch.tensor(x).float()),
    )

# Fixed lengths used by the two padding collators below.
MFCC_MAX_LEN = 80
CONTEXT_MAX_LEN = 160

# One collator per dataset field, in field order. The two padded fields use
# "SET_MAX_LEN" mode — presumably pad/crop to `max_len` along `pad_dim` of each
# un-batched sample; confirm against PaddingCollate. The word-id and label
# fields use the default collation.
mfcc_collate = PaddingCollate(mode="SET_MAX_LEN", max_len=MFCC_MAX_LEN, pad_dim=2)
context_collate = PaddingCollate(mode="SET_MAX_LEN", max_len=CONTEXT_MAX_LEN, pad_dim=1)
word_id_collate = DefaultCollate()
label_collate = DefaultCollate()

collate_fn = build_collate_fn(
    mfcc_collate,
    context_collate,
    word_id_collate,
    label_collate,
)


dataset_demo = to_dataset(dataframe)

# DataLoader worker processes are only usable from a notebook when they are
# forked: with the "spawn"/"forkserver" start methods (the default on Windows
# and macOS) the dataset has to be pickled, and the lambdas it holds cannot be.
# Fall back to loading in the main process (`num_workers=0`) in that case.
num_workers = 4 if torch.multiprocessing.get_start_method() == "fork" else 0

# `batch_size=1` and `shuffle=False` are relied upon by the evaluation cell,
# which calls `.item()` on each output and looks up metadata by batch index.
# `device` is the one selected in the setup cell.
base_loader = DataLoader(
    dataset_demo,
    batch_size=1,
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=num_workers,
)

# Wrap the loader in MemoryLoadedDataLoader, targeting `device`
# (presumably this preloads every batch onto the device — confirm).
demo_loader = MemoryLoadedDataLoader(base_loader, device=device)

# Sanity check: print the shape of every field of the first batch.
for x in next(iter(demo_loader)):
    print(x.shape)

torch.Size([1, 1, 40, 80])
torch.Size([1, 2, 160])
torch.Size([1])
torch.Size([1])


In [7]:
from models.FusionCNN import ContextFusionCNN

# Rebuild the architecture before loading the trained weights.
# NOTE(review): the positional arguments (1, 2) presumably are the channel
# counts of the MFCC and context inputs — confirm against ContextFusionCNN.
model = ContextFusionCNN(1, 2, num_words=N_WORDS).to(device)

model_path = "ContextFusionCNN.pth"
# `weights_only=True` restricts unpickling to tensors and plain containers, so
# a tampered checkpoint file cannot execute arbitrary code while being loaded.
# A state dict needs nothing more than that.
state_dict = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)

# Switch to inference mode before the demo.
model.eval()

print("Model has been successfully loaded and is ready for testing.")

Model has been successfully loaded and is ready for testing.


In [8]:
# Outputs strictly above this value are reported as "Correct".
# NOTE(review): this assumes the model emits a probability in [0, 1] rather
# than a raw logit — confirm against ContextFusionCNN.
DECISION_THRESHOLD = 0.5

results = []
correct_predictions = 0

with torch.no_grad():
    # The loop relies on the loader yielding one sample per batch in dataframe
    # order: `.item()` needs single-element tensors and `dataframe.row(i)` is
    # used to recover the metadata of the i-th batch.
    for i, batch in enumerate(demo_loader):
        # Unpack the batch
        x_mfcc, x_context, word_ids, targets = batch

        # Model forward pass
        output = model(x_mfcc, x_context, word_ids)

        # Decision logic
        prediction = 1 if output.item() > DECISION_THRESHOLD else 0
        target = int(targets.item())

        # Track accuracy
        is_match = prediction == target
        if is_match:
            correct_predictions += 1

        # Get metadata for the printout
        original_row = dataframe.row(i, named=True)

        results.append({
            "Student ID": original_row["id_student"],
            "Word": original_row["word_id"],
            "Expert": "Correct" if target == 1 else "Incorrect",
            "Model": "Correct" if prediction == 1 else "Incorrect",
            "Status": "✅" if is_match else "❌"
        })

# Final accuracy over the samples that were actually evaluated; NaN (instead of
# a ZeroDivisionError) when the demo set is empty.
n_evaluated = len(results)
final_accuracy = correct_predictions / n_evaluated if n_evaluated else float("nan")

# Display the results
separator = "-" * 65
print(f"{'Student ID':<12} | {'Word':<8} | {'Expert':<12} | {'Model':<12} | {'Status'}")
print(separator)
for res in results:
    print(f"{res['Student ID']:<12} | {res['Word']:<8} | "
          f"{res['Expert']:<12} | {res['Model']:<12} | {res['Status']}")

print(separator)
print(f"FINAL DEMO ACCURACY: {final_accuracy:.2%}")

Student ID   | Word     | Expert       | Model        | Status
-----------------------------------------------------------------
1603         | 1        | Correct      | Correct      | ✅
1580         | 1        | Correct      | Correct      | ✅
1593         | 1        | Correct      | Correct      | ✅
1686         | 1        | Incorrect    | Correct      | ❌
1687         | 1        | Correct      | Incorrect    | ❌
1699         | 1        | Correct      | Correct      | ✅
1615         | 1        | Incorrect    | Correct      | ❌
1603         | 2        | Correct      | Incorrect    | ❌
1580         | 2        | Correct      | Correct      | ✅
1593         | 2        | Incorrect    | Incorrect    | ✅
1686         | 2        | Incorrect    | Incorrect    | ✅
1687         | 2        | Correct      | Correct      | ✅
1699         | 2        | Correct      | Correct      | ✅
1615         | 2        | Correct      | Correct      | ✅
1603         | 4        | Correct      | Incorrect    | ❌
1