In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, TensorDataset
import pandas as pd
from sklearn.metrics import mean_squared_error
import numpy as np
from sentence_transformers import SentenceTransformer

In [2]:

train_df = pd.read_csv("train.csv")
val_df = pd.read_csv("val.csv")
test_df = pd.read_csv("test.csv")

In [7]:
model_st = SentenceTransformer("all-MiniLM-L6-v2")
X_train = model_st.encode(train_df["query"].tolist(), convert_to_tensor=True)
X_val = model_st.encode(val_df["query"].tolist(), convert_to_tensor=True)

y_train = torch.tensor(train_df["carb"].values, dtype=torch.float32).unsqueeze(1)
y_val = torch.tensor(val_df["carb"].values, dtype=torch.float32).unsqueeze(1)
X_train = X_train.unsqueeze(1)  # (N, 1, 384)
X_val = X_val.unsqueeze(1)

train_dataset = TensorDataset(X_train, y_train)
val_dataset = TensorDataset(X_val, y_val)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)

X_train.shape

torch.Size([8000, 1, 384])

In [38]:
class LSTMModel(nn.Module):
    def __init__(self, input_size=384, hidden_size=128, num_layers=3, dropout=0.5):
        super(LSTMModel, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers,
                            dropout=dropout, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        out, _ = self.lstm(x)          # x: (batch, seq_len, input_size)
        out = out[:, -1, :]            # Get last time step
        return self.fc(out)


In [42]:
import math
import torch.nn.functional as F

model = LSTMModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
criterion = nn.MSELoss()

for epoch in range(100):
    model.train()
    train_preds = []
    train_targets = []
    total_loss = 0

    for xb, yb in train_loader:
        optimizer.zero_grad()
        pred = model(xb)
        loss = criterion(pred, yb)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        train_preds.append(pred.detach())
        train_targets.append(yb)

    train_preds = torch.cat(train_preds)
    train_targets = torch.cat(train_targets)
    train_rmse = math.sqrt(F.mse_loss(train_preds, train_targets).item())

    print(f"Epoch {epoch+1}, Train Loss: {total_loss:.4f}, Train RMSE: {train_rmse:.2f}")

Epoch 1, Train Loss: 478781.0383, Train RMSE: 43.76
Epoch 2, Train Loss: 411399.3261, Train RMSE: 40.57
Epoch 3, Train Loss: 385616.9707, Train RMSE: 39.27
Epoch 4, Train Loss: 369906.6596, Train RMSE: 38.47
Epoch 5, Train Loss: 359272.4746, Train RMSE: 37.91
Epoch 6, Train Loss: 348966.3570, Train RMSE: 37.36
Epoch 7, Train Loss: 339618.0681, Train RMSE: 36.86
Epoch 8, Train Loss: 329203.4835, Train RMSE: 36.29
Epoch 9, Train Loss: 320694.2627, Train RMSE: 35.82
Epoch 10, Train Loss: 310047.8438, Train RMSE: 35.22
Epoch 11, Train Loss: 300704.5149, Train RMSE: 34.68
Epoch 12, Train Loss: 294724.4025, Train RMSE: 34.34
Epoch 13, Train Loss: 287573.5440, Train RMSE: 33.92
Epoch 14, Train Loss: 278916.2472, Train RMSE: 33.40
Epoch 15, Train Loss: 274414.1807, Train RMSE: 33.13
Epoch 16, Train Loss: 266313.1424, Train RMSE: 32.64
Epoch 17, Train Loss: 261228.8691, Train RMSE: 32.33
Epoch 18, Train Loss: 255300.2610, Train RMSE: 31.96
Epoch 19, Train Loss: 247892.6529, Train RMSE: 31.49
Ep

In [46]:
import torch.nn.functional as F
import math

model.eval()
with torch.no_grad():
    preds = []
    targets = []
    for xb, yb in val_loader:
        y_pred = model(xb)
        preds.append(y_pred)
        targets.append(yb)

    preds = torch.cat(preds)
    targets = torch.cat(targets)
    rmse = math.sqrt(F.mse_loss(preds, targets).item())
    print(f"Validation RMSE: {rmse:.2f}")

    #18.44 -> hidden_size=128, num_layers=3, dropout=0.5


Validation RMSE: 19.47


In [47]:
model.eval()
with torch.no_grad():
    preds = model(X_val).squeeze().numpy()

# Add prediction column and save
test_df["carb"] = preds
test_df.to_csv("test_with_predictions_transformer_lstm.csv", index=False)