In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, TensorDataset
import pandas as pd
from sklearn.metrics import mean_squared_error
import numpy as np
from sentence_transformers import SentenceTransformer

In [2]:

train_df = pd.read_csv("train.csv")
val_df = pd.read_csv("val.csv")
test_df = pd.read_csv("test.csv")

In [7]:
model_st = SentenceTransformer("all-MiniLM-L6-v2")
X_train = model_st.encode(train_df["query"].tolist(), convert_to_tensor=True)
X_val = model_st.encode(val_df["query"].tolist(), convert_to_tensor=True)

y_train = torch.tensor(train_df["carb"].values, dtype=torch.float32).unsqueeze(1)
y_val = torch.tensor(val_df["carb"].values, dtype=torch.float32).unsqueeze(1)
X_train = X_train.unsqueeze(1)  # (N, 1, 384)
X_val = X_val.unsqueeze(1)

train_dataset = TensorDataset(X_train, y_train)
val_dataset = TensorDataset(X_val, y_val)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)

X_train.shape

torch.Size([8000, 1, 384])

In [59]:
class LSTMModel(nn.Module):
    def __init__(self, input_size=384, hidden_size=128, num_layers=3, dropout=0.3):
        super(LSTMModel, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers,
                            dropout=dropout, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        out, _ = self.lstm(x)          # x: (batch, seq_len, input_size)
        out = out[:, -1, :]            # Get last time step
        return self.fc(out)


In [60]:
import math
import torch.nn.functional as F

model = LSTMModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
criterion = nn.MSELoss()

for epoch in range(100):
    model.train()
    train_preds = []
    train_targets = []
    total_loss = 0

    for xb, yb in train_loader:
        optimizer.zero_grad()
        pred = model(xb)
        loss = criterion(pred, yb)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        train_preds.append(pred.detach())
        train_targets.append(yb)

    train_preds = torch.cat(train_preds)
    train_targets = torch.cat(train_targets)
    train_rmse = math.sqrt(F.mse_loss(train_preds, train_targets).item())

    print(f"Epoch {epoch+1}, Train Loss: {total_loss:.4f}, Train RMSE: {train_rmse:.2f}")

Epoch 1, Train Loss: 482155.1545, Train RMSE: 43.92
Epoch 2, Train Loss: 417613.1095, Train RMSE: 40.87
Epoch 3, Train Loss: 388174.3433, Train RMSE: 39.40
Epoch 4, Train Loss: 370938.8422, Train RMSE: 38.52
Epoch 5, Train Loss: 357660.5226, Train RMSE: 37.82
Epoch 6, Train Loss: 347445.6405, Train RMSE: 37.28
Epoch 7, Train Loss: 336729.3076, Train RMSE: 36.70
Epoch 8, Train Loss: 325809.0484, Train RMSE: 36.10
Epoch 9, Train Loss: 314409.1737, Train RMSE: 35.46
Epoch 10, Train Loss: 303951.8467, Train RMSE: 34.87
Epoch 11, Train Loss: 296782.4672, Train RMSE: 34.45
Epoch 12, Train Loss: 286914.7556, Train RMSE: 33.88
Epoch 13, Train Loss: 275749.1315, Train RMSE: 33.21
Epoch 14, Train Loss: 270826.2058, Train RMSE: 32.91
Epoch 15, Train Loss: 262540.5114, Train RMSE: 32.41
Epoch 16, Train Loss: 256245.5990, Train RMSE: 32.02
Epoch 17, Train Loss: 246701.2210, Train RMSE: 31.41
Epoch 18, Train Loss: 246677.7731, Train RMSE: 31.41
Epoch 19, Train Loss: 238889.8497, Train RMSE: 30.91
Ep

In [61]:
import torch.nn.functional as F
import math

model.eval()
with torch.no_grad():
    preds = []
    targets = []
    for xb, yb in val_loader:
        y_pred = model(xb)
        preds.append(y_pred)
        targets.append(yb)

    preds = torch.cat(preds)
    targets = torch.cat(targets)
    rmse = math.sqrt(F.mse_loss(preds, targets).item())
    print(f"Validation RMSE: {rmse:.2f}")

    #18.44 -> hidden_size=128, num_layers=3, dropout=0.5


Validation RMSE: 18.32


In [47]:
model.eval()
with torch.no_grad():
    preds = model(X_val).squeeze().numpy()

# Add prediction column and save
test_df["carb"] = preds
test_df.to_csv("test_with_predictions_transformer_lstm.csv", index=False)