In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import os
import numpy as np

In [2]:
# Mount Google Drive inside the Colab runtime (Colab-only API).
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [3]:
class GCMSDataEncoder(nn.Module):
    """Fully connected encoder that maps a GCMS feature vector to an embedding.

    The network is three ReLU-activated linear layers followed by a linear
    projection to ``output_dim``; no activation is applied to the output.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width of the three hidden layers.
        output_dim: Size of the returned embedding.
    """

    def __init__(self, input_dim, hidden_dim=128, output_dim=16):
        super().__init__()
        # The attribute names fc1..fc4 are also the state_dict keys, so they
        # must stay as they are for saved checkpoints to keep loading.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return the embedding of ``x``; the last dimension becomes ``output_dim``."""
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.relu(layer(hidden))
        return self.fc4(hidden)

In [4]:
class SensorDataEncoder(nn.Module):
    """Fully connected encoder that maps a sensor reading to an embedding.

    The network is three ReLU-activated linear layers followed by a linear
    projection to ``output_dim``; no activation is applied to the output.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width of the three hidden layers.
        output_dim: Size of the returned embedding.
    """

    def __init__(self, input_dim, hidden_dim=128, output_dim=16):
        super().__init__()
        # The attribute names fc1..fc4 are also the state_dict keys, so they
        # must stay as they are for saved checkpoints to keep loading.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return the embedding of ``x``; the last dimension becomes ``output_dim``."""
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.relu(layer(hidden))
        return self.fc4(hidden)

In [5]:
# Load the trained encoder weights from Google Drive.
gcms_model_path = "/content/drive/My Drive/Smell/Contrastive Learning/gcms_encoder.pt"
sensor_model_path = "/content/drive/My Drive/Smell/Contrastive Learning/sensor_encoder.pt"

# Architecture hyperparameters. They have to match the saved checkpoints,
# otherwise load_state_dict raises on mismatched tensor shapes.
gcms_input_dim = 10    # number of GCMS features per entry
sensor_input_dim = 13  # number of values in one sensor reading
hidden_dim = 128
embedding_dim = 16

gcms_encoder = GCMSDataEncoder(gcms_input_dim, hidden_dim, embedding_dim)
sensor_encoder = SensorDataEncoder(sensor_input_dim, hidden_dim, embedding_dim)

# map_location="cpu" lets a checkpoint that was saved from a GPU session still
# load on a CPU-only runtime; the encoders are constructed on the CPU anyway.
# NOTE(review): torch.load unpickles the file, so only load checkpoints you
# trust (or pass weights_only=True on PyTorch versions that support it).
gcms_encoder.load_state_dict(torch.load(gcms_model_path, map_location="cpu"))
sensor_encoder.load_state_dict(torch.load(sensor_model_path, map_location="cpu"))

<All keys matched successfully>

In [None]:
# Read the GCMS table from Drive.
df = pd.read_csv("/content/drive/My Drive/Smell/Contrastive Learning/gcms_dataframe.csv")

# Append an all-zero "ambient" entry to the table.
element_columns = ["C", "Ca", "H", "K", "Mg", "N", "Na", "O", "P", "Se"]
ambient_row = pd.DataFrame([{"food_name": "ambient", **dict.fromkeys(element_columns, 0)}])

df = pd.concat([df, ambient_row], ignore_index=True)

In [None]:
# Strip the name column so that only the numerical values are left,
# then expose them as a plain array.
df_dropped = df.drop("food_name", axis="columns", errors="ignore")
gcms_data = df_dropped.to_numpy()

In [None]:
# Two-way lookup between a row position in df and its food name.
available_food_names = df["food_name"].tolist()

ix_to_name = dict(enumerate(available_food_names))
name_to_ix = {name: ix for ix, name in ix_to_name.items()}

In [None]:
import torch
import torch.nn.functional as F

def evaluate_retrieval(smell_matrix, gcms_data, gcms_encoder, sensor_encoder, device='cpu'):
    """Match every sensor reading to its most similar GCMS entry.

    Both inputs are embedded with their encoder, L2-normalised, and compared
    by dot product (i.e. cosine similarity). For each sensor reading, the index
    of the best-scoring GCMS row is returned.

    Args:
        smell_matrix: Array-like or tensor of sensor readings, shape
            (n, sensor_features). A single 1-D reading is treated as a batch
            of one.
        gcms_data: Array-like or tensor of GCMS rows, shape (m, gcms_features).
        gcms_encoder: Module that embeds the GCMS rows.
        sensor_encoder: Module that embeds the sensor readings; its embedding
            size must equal that of ``gcms_encoder``.
        device: Device on which the computation runs.

    Returns:
        list[int]: n row indices into ``gcms_data``, one per sensor reading.

    Note:
        Both encoders are switched to eval mode and moved to ``device`` in place.
    """
    # The inputs are placed on `device` below, so the encoders have to live
    # there as well; otherwise any non-CPU device fails with a device mismatch.
    gcms_encoder = gcms_encoder.to(device).eval()
    sensor_encoder = sensor_encoder.to(device).eval()

    # as_tensor accepts arrays, nested lists and existing tensors alike
    # (torch.tensor() warns when it is handed a tensor).
    smell_matrix = torch.as_tensor(smell_matrix, dtype=torch.float, device=device)
    gcms_data = torch.as_tensor(gcms_data, dtype=torch.float, device=device)

    # Promote a single reading to a (1, features) batch so dim=1 below is valid.
    if smell_matrix.dim() == 1:
        smell_matrix = smell_matrix.unsqueeze(0)

    with torch.no_grad():
        z_gcms = gcms_encoder(gcms_data)         # (m, embed_dim)
        z_sensor = sensor_encoder(smell_matrix)  # (n, embed_dim)

        # L2 normalize if that's how your model was trained
        z_gcms = F.normalize(z_gcms, dim=1)
        z_sensor = F.normalize(z_sensor, dim=1)

        # Similarity matrix, shape (n, m):
        # sim[i, j] = dot(z_sensor[i], z_gcms[j])
        sim = torch.matmul(z_sensor, z_gcms.t())

        # For each sensor reading i, pick the GCMS row j with the highest similarity.
        predicted = sim.argmax(dim=1)  # (n,)

    return predicted.tolist()


In [None]:
# Placeholder input: a single all-zero sensor reading.
smell_matrix = np.zeros(shape=(1, 13), dtype=float)  # TODO collected smell matrix

predicted_from_smell = evaluate_retrieval(smell_matrix, gcms_data, gcms_encoder, sensor_encoder)

# Translate the predicted row indices back into food names.
predicted_names = [ix_to_name[ix] for ix in predicted_from_smell]
print(predicted_names)

['ambient']
