In [1]:
from FlavorFormer import *
from augment import *
from readBruker import *
import pandas as pd

# Read NMR spectra

In [2]:
# Load the Bruker H-NMR spectra used throughout the notebook: one collection of
# flavor standards and one collection of mixtures whose composition is known.
# NOTE(review): the three positional booleans are forwarded unchanged to
# read_bruker_hs; their meaning is not visible in this notebook — confirm against
# readBruker and consider passing them as keyword arguments.
flavor_standards = read_bruker_hs('data/flavor_standards', False, True, False)
known_flavor_mixtures = read_bruker_hs('data/known_flavor_mixtures', False, True, False)

Read Bruker H-NMR files: 100%|██████████| 24/24 [01:00<00:00,  2.54s/it]
Read Bruker H-NMR files: 100%|██████████| 21/21 [00:10<00:00,  2.09it/s]


# Data augmentation

In [None]:
# Generate the augmented dataset from the flavor standards and write the
# train/val/test splits into `output_dir` (see the "... dataset saved to" log lines).
output_dir = 'data/augmented_data/50000_samples'
augment_and_split_data(
    flavor_standards,
    output_dir,
    augment_samples=50000,  # requested number of augmented samples
    max_pc=7,               # NOTE(review): presumably the max number of components per mixture — confirm in augment.py
    noise_level=0.001,
)

Data augmentation: 100%|██████████| 500/500 [00:02<00:00, 240.13it/s]


Train dataset saved to data/augmented_data/50000_samples\train_dataset.pkl.
Val dataset saved to data/augmented_data/50000_samples\val_dataset.pkl.
Test dataset saved to data/augmented_data/50000_samples\test_dataset.pkl.


# Load data

In [4]:
# Mini-batch size shared by the training and validation loaders.
BATCH_SIZE = 256

# Pickled splits written by the data-augmentation step.
train_data_pattern = 'data/augmented_data/50000_samples/train_dataset.pkl'
valid_data_pattern = 'data/augmented_data/50000_samples/val_dataset.pkl'

# Both loaders shuffle their samples on every pass.
train_loader = get_data_loader(
    NMRDataset(train_data_pattern),
    batch_size=BATCH_SIZE,
    shuffle=True,
)
val_loader = get_data_loader(
    NMRDataset(valid_data_pattern),
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# Train model

In [None]:
# Rebuild the loaders so that re-running this cell always starts from fresh iterators.
train_loader = get_data_loader(NMRDataset(train_data_pattern), batch_size=BATCH_SIZE, shuffle=True)
val_loader = get_data_loader(NMRDataset(valid_data_pattern), batch_size=BATCH_SIZE, shuffle=True)

# Seed before any model is constructed so weight initialisation is reproducible.
set_random_seed(42)

# define CNN
cnn_params = {
    "input_dim": 39042,
    "output_dim": 128,
    "channels": [64, 64, 64],
    "kernel_size": [5, 5, 5],
    "stride": [3, 3, 3],
    "pool_size": [2, 2, 2],
    "dropout": 0.1
}

cnn_model = CNNFeatureExtractor(**cnn_params)

# `cnn_output_shape` was used below without being defined in any cell of this
# notebook, which raises NameError after Restart & Run All. Derive it from one
# dummy forward pass. The pass runs in eval mode under no_grad so it neither
# draws dropout masks nor updates normalisation statistics, and the previous
# mode is restored afterwards.
# NOTE(review): assumes the CNN accepts a (batch, input_dim) float tensor, as it
# is fed in predict_known, and returns a 3-D tensor — confirm. The batch entry
# of the recorded shape is 1 here.
_cnn_was_training = cnn_model.training
cnn_model.eval()
with torch.no_grad():
    cnn_output_shape = tuple(cnn_model(torch.zeros(1, cnn_params["input_dim"])).shape)
cnn_model.train(_cnn_was_training)

# define Transformer
transformer_params = {
    "input_dim": cnn_params["output_dim"],
    "hidden_dim": 64,
    "num_heads": 4,
    "num_layers": 4,
    "dropout": 0.1,
}
transformer_model = CrossEncoder(**transformer_params)
# Added only after the model is built: it is stored with the checkpoint metadata
# and is not a CrossEncoder constructor argument. The last two CNN output axes
# are swapped to mirror the permute(0, 2, 1) applied before the transformer.
transformer_params['input_shape'] = (BATCH_SIZE, cnn_output_shape[2], cnn_output_shape[1])

# Optimisation hyper-parameters
learning_rate = 5e-5
weight_decay = 1e-6
num_epochs = 100
positive_weight = 3.0
save_path = f"model/epoch_{num_epochs}_lr_{learning_rate}_wd_{weight_decay}"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): `tb_logger` is not defined in any visible cell. It has to come
# from one of the star imports or be created before this line — confirm, and
# define it explicitly in this notebook if it does not.
trainer = NMRTrainer(cnn_model, transformer_model, device, learning_rate, weight_decay, tb_logger, positive_weight = positive_weight)

trainer.train(train_loader, val_loader, num_epochs, save_path)

# save model
trainer.save_model(
    cnn_model=cnn_model,
    transformer_model=transformer_model,
    cnn_params=cnn_params,
    transformer_params=transformer_params,
    dataset_name=train_data_pattern,
    save_path=save_path,
    batch_size=BATCH_SIZE,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    positive_weight=positive_weight,
    num_epochs=num_epochs,
    cnn_output_shape=cnn_output_shape,
)


# Load the model

In [None]:
# Re-create both networks with the hyper-parameters used for training, then
# restore their trained weights from disk.

# define CNN
cnn_params = {
    "input_dim": 39042,
    "output_dim": 128,
    "channels": [64, 64, 64],
    "kernel_size": [5, 5, 5],
    "stride": [3, 3, 3],
    "pool_size": [2, 2, 2],
    "dropout": 0.1
}

cnn_model = CNNFeatureExtractor(**cnn_params)

# define Transformer
transformer_params = {
    "input_dim": cnn_params["output_dim"],
    "hidden_dim": 64,
    "num_heads": 4,
    "num_layers": 4,
    "dropout": 0.1
}
# Fixed: the original unpacked `encoder_params`, a name that is never defined
# in this notebook; the dictionary built above is `transformer_params`.
transformer_model = CrossEncoder(**transformer_params)

# NOTE(review): the training cell hands the trainer a save_path of the form
# "model/epoch_..."; confirm that the weight files below really live directly
# under "model".
save_path = "model"

model_paths = {
    "cnn": f"{save_path}/cnn_params.pth",
    "transformer": f"{save_path}/encoder_params.pth"
}

# map_location="cpu" lets checkpoints saved from a GPU run load on CPU-only
# machines; the models are moved to the evaluation device later.
cnn_model.load_state_dict(torch.load(model_paths["cnn"], map_location="cpu"))
transformer_model.load_state_dict(torch.load(model_paths["transformer"], map_location="cpu"))

# Freshly constructed modules start in training mode; switch to evaluation
# mode so dropout is disabled during inference.
cnn_model.eval()
transformer_model.eval()
print("Load the model successfully")

Load the model successfully


# Define functions

In [17]:
from sklearn.metrics import confusion_matrix
from config_ import known_dict 
def judge_confusion_matrix(df, name_en, probability, true_comp):
    """Label a single prediction row as 'TP', 'FP', 'FN' or 'TN'.

    Args:
        df: One result row, indexable by ``name_en`` and ``probability``
            (it is called through ``DataFrame.apply(..., axis=1)``).
        name_en: Key of the field holding the compound name.
        probability: Key of the field holding the predicted probability.
        true_comp: Container of the compound names that are truly present;
            membership is tested with ``in``.

    Returns:
        'TP' / 'FP' when the probability is at least 0.5 and the compound is /
        is not in ``true_comp``; 'FN' / 'TN' when the probability is below 0.5
        and the compound is / is not in ``true_comp``.
    """
    predicted_present = df[probability] >= 0.5
    truly_present = df[name_en] in true_comp

    if predicted_present:
        return 'TP' if truly_present else 'FP'
    return 'FN' if truly_present else 'TN'
    
def predict_test(cnn_model, transformer_model, test_loader, device):
    """Run the model over a labelled loader and print/return its confusion matrix.

    Args:
        cnn_model: Feature extractor applied to both the standard and the mixture.
        transformer_model: Model called as ``transformer_model(standard_features,
            mixture_features)``; its output is passed through a sigmoid.
        test_loader: Iterable yielding ``(standard, mixture, labels)`` batches.
        device: Device the input tensors are moved to. The models are expected
            to be on this device already.

    Returns:
        The confusion matrix computed by ``sklearn.metrics.confusion_matrix``
        from the true labels and the predictions thresholded at 0.5.

    Side effects:
        Both models are left in evaluation mode, as in ``predict_known``.
    """
    # Inference must not run in training mode: a freshly constructed or just
    # trained module still has dropout active, which makes predictions noisy.
    cnn_model.eval()
    transformer_model.eval()

    y_true, y_pred = [], []
    with torch.no_grad():
        for standard, mixture, labels in tqdm(test_loader, desc="Testing"):
            standard, mixture = standard.to(device), mixture.to(device)

            # Swap the last two CNN output axes before the transformer.
            standard_features = cnn_model(standard).permute(0, 2, 1)
            mixture_features = cnn_model(mixture).permute(0, 2, 1)

            outputs = transformer_model(standard_features, mixture_features)
            probabilities = torch.sigmoid(outputs)
            # reshape(-1) rather than squeeze(): squeeze() collapses a batch of
            # one sample to a 0-d tensor, which list.extend cannot iterate.
            # NOTE(review): assumes one output value per (standard, mixture) pair — confirm.
            predictions = (probabilities.reshape(-1) > 0.5).float()

            # Labels are only needed on the host, so they are never moved to `device`.
            y_true.extend(labels.reshape(-1).cpu().numpy())
            y_pred.extend(predictions.cpu().numpy())

    cnf_matrix = confusion_matrix(y_true, y_pred)
    print("Test set confusion matrix：\n", cnf_matrix)
    return cnf_matrix

def predict_known(cnn_model, transformer_model, known_spectra, stds, device):
    """Score every standard against every known mixture and tally the outcomes.

    For each mixture the query spectrum is paired with every standard, the pair
    is scored by the model, and each standard is labelled TP/FP/FN/TN against
    the true composition looked up in ``known_dict[mixture name]``.

    Args:
        cnn_model: Feature extractor applied to standards and queries.
        transformer_model: Model called as ``transformer_model(r_features,
            q_features)``; column 0 of its sigmoid output is used as probability.
        known_spectra: Sequence of dicts with at least ``'name'`` and ``'fid'``.
        stds: Sequence of dicts with at least ``'name'`` and ``'fid'``.
        device: Device on which the input tensors are created.

    Returns:
        ``(confusion_counts, out_results)``:
        ``confusion_counts`` has one row per mixture with columns
        ``Name, TP, FP, FN, TN``; ``out_results`` has one row per
        (mixture, standard) pair with columns ``Mixture, Compound, Probability,
        Confusion, Rank`` sorted by ``Mixture`` then ``Rank``. For an empty
        ``known_spectra`` an empty ``confusion_counts`` (same columns) and an
        empty ``out_results`` frame are returned.

    Raises:
        KeyError: if a mixture name is missing from ``known_dict``.
    """
    cnn_model.eval()
    transformer_model.eval()

    if len(known_spectra) == 0:
        return pd.DataFrame(columns=['Name', 'TP', 'FP', 'FN', 'TN']), pd.DataFrame()

    # The standards do not change between mixtures: collect their names and
    # stack their spectra once instead of once per iteration.
    n = len(stds)
    stds_name_order = [std['name'] for std in stds]
    R = np.array([std['fid'] for std in stds], dtype=np.float32)

    all_results = []
    confusion_stats = []
    for query in tqdm(known_spectra, desc="Predicting Known Spectra"):
        sample_name = query['name']

        # Repeat the query so that row i pairs it with standard i.
        Q = np.array([query['fid']] * n, dtype=np.float32)

        with torch.no_grad():
            r_features = cnn_model(torch.tensor(R, device=device)).permute(0, 2, 1)
            q_features = cnn_model(torch.tensor(Q, device=device)).permute(0, 2, 1)
            yp = torch.sigmoid(transformer_model(r_features, q_features)).cpu().numpy()

        result_df = pd.DataFrame({
            'Mixture': sample_name,
            'Compound': stds_name_order,
            'Probability': yp[:, 0]
        })

        # Label each standard against the mixture's true composition.
        result_df['Confusion'] = result_df.apply(
            judge_confusion_matrix, axis=1, args=('Compound', 'Probability', known_dict[sample_name])
        )

        # Rank 1 is the most probable compound for this mixture.
        result_df = result_df.sort_values(by='Probability', ascending=False).reset_index(drop=True)
        result_df['Rank'] = range(1, len(result_df) + 1)
        all_results.append(result_df)

        counts = result_df['Confusion'].value_counts().to_dict()
        confusion_stats.append({
            'Name': sample_name,
            'TP': counts.get('TP', 0),
            'FP': counts.get('FP', 0),
            'FN': counts.get('FN', 0),
            'TN': counts.get('TN', 0)
        })

    # Assemble the output frames once, after the loop; rebuilding them on every
    # iteration made the function accidentally quadratic.
    confusion_counts = pd.DataFrame(confusion_stats)
    out_results = pd.concat(all_results, ignore_index=True).sort_values(['Mixture', 'Rank'])

    return confusion_counts, out_results

# Predict results

In [None]:
# Evaluate the loaded model twice: on the augmented test split and on the
# measured mixtures of known composition.
test_data_path = 'data/augmented_data/50000_samples/test_dataset.pkl'
batch_size = 256
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Keep the test samples in file order (no shuffling).
test_loader = get_data_loader(
    NMRDataset(test_data_path),
    batch_size=batch_size,
    shuffle=False,
)

# Both networks must live on the same device as the batches fed to them.
cnn_model.to(device)
transformer_model.to(device)

# test set
test_cnf_matrix = predict_test(cnn_model, transformer_model, test_loader, device)

# known mixtures
confusion_stats, detailed_results = predict_known(
    cnn_model,
    transformer_model,
    known_flavor_mixtures,
    flavor_standards,
    device,
)
print("Confusion matrix for known mixtures:", confusion_stats, sep="\n")
print("Detailed results for known mixtures:", detailed_results, sep="\n")

Testing: 100%|██████████| 40/40 [00:06<00:00,  6.54it/s]


Test set confusion matrix：
 [[4923   56]
 [  39 4982]]


Predicting Known Spectra: 100%|██████████| 21/21 [00:00<00:00, 54.14it/s]

Confusion matrix for known mixtures:
   Name  TP  FP  FN  TN
0    A1   2   0   0  22
1    A2   2   1   0  21
2    A3   2   0   0  22
3    B1   3   0   0  21
4    B2   3   0   0  21
5    B3   3   0   0  21
6    B4   3   0   0  21
7    B5   3   3   0  18
8    B6   3   1   0  20
9    B7   3   4   0  17
10   C1   4   0   0  20
11   C2   4   0   0  20
12   C3   4   1   0  19
13   C4   4   0   0  20
14   C5   4   2   0  18
15   C6   3   2   1  18
16   D1   5   2   0  17
17   D2   3   1   2  18
18   D3   4   1   1  18
19   D4   4   0   1  19
20   D5   5   0   0  19
Detailed results for known mixtures:
    Mixture            Compound  Probability Confusion  Rank
0        A1     Linalyl acetate     0.999927        TP     1
1        A1            β-Ionone     0.998698        TP     2
2        A1   Isopentyl acetate     0.077915        TN     3
3        A1           Nerolidol     0.000803        TN     4
4        A1         1,8-Cineole     0.000383        TN     5
..      ...                 ... 


