In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from decifer import (
    HDF5Dataset,
)

class SimpleModel(nn.Module):
    def __init__(self, vocab_size, token_input_size, numeric_input_size, hidden_size, output_size):
        super(SimpleModel, self).__init__()
        
        # Tokenized data goes through embedding and a simple feedforward layer
        self.embedding = nn.Embedding(vocab_size, token_input_size)  # Adjust vocabulary size
        self.fc_token = nn.Linear(token_input_size, hidden_size)

        # Numeric data goes through a simple linear layer
        self.fc_numeric = nn.Linear(numeric_input_size, hidden_size)  # Make sure numeric_input_size matches the actual numeric data dimension

        # Final layer to produce the output
        self.fc_out = nn.Linear(hidden_size, output_size)

    def forward(self, token_data, numeric_data):
        # Tokenized data
        embedded = self.embedding(token_data)
        token_output = self.fc_token(embedded.mean(dim=1))  # Take mean across tokens

        # Numeric data
        numeric_output = self.fc_numeric(numeric_data)

        # Combine both outputs
        combined_output = token_output + numeric_output

        # Final output
        return self.fc_out(combined_output)


# Function to train the model
def train_model(model, train_loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    for batch in train_loader:
        # Send data to device (GPU or CPU)
        token_data, numeric_data = batch[0].to(device), batch[1].to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(token_data, numeric_data)
        loss = criterion(outputs, numeric_data)  # Using numeric data as the target for simplicity
        print(loss)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    return running_loss / len(train_loader)


# Function to evaluate the model
def evaluate_model(model, val_loader, criterion, device):
    model.eval()
    running_loss = 0.0
    with torch.no_grad():
        for batch in val_loader:
            token_data, numeric_data = batch[0].to(device), batch[1].to(device)

            outputs = model(token_data, numeric_data)
            loss = criterion(outputs, numeric_data)  # Using numeric data as the target for simplicity

            running_loss += loss.item()

    return running_loss / len(val_loader)


# Simple training loop
def run_training(train_loader, val_loader, vocab_size, token_input_size, numeric_input_size, hidden_size, output_size, epochs=5):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Initialize model, loss function, and optimizer
    model = SimpleModel(vocab_size, token_input_size, numeric_input_size, hidden_size, output_size).to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(epochs):
        train_loss = train_model(model, train_loader, criterion, optimizer, device)
        val_loss = evaluate_model(model, val_loader, criterion, device)

        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

    return model


In [7]:
# Example dataset loading (adjust paths as necessary)
block_size = 12
train_dataset = HDF5Dataset('../data/chili100k/hdf5_data/train_dataset.h5', ['cif_tokenized', 'xrd_cont_y'], block_size)
val_dataset = HDF5Dataset('../data/chili100k/hdf5_data/val_dataset.h5', ['cif_tokenized', 'xrd_cont_y'], block_size)
test_dataset = HDF5Dataset('../data/chili100k/hdf5_data/test_dataset.h5', ['cif_tokenized', 'xrd_cont_y'], block_size)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Run training
token_input_size = 2  # Embedding size for tokenized input
numeric_input_size = 1000  # Size of numeric input
hidden_size = 2  # Size of hidden layers
output_size = 1000  # Output size (same as input size for simplicity)
vocab_size = 372

trained_model = run_training(train_loader, val_loader, vocab_size, token_input_size, numeric_input_size, hidden_size, output_size, epochs=5)

# Evaluate on test set (for simplicity)
test_loss = evaluate_model(trained_model, test_loader, nn.MSELoss(), torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
print(f"Test Loss: {test_loss:.4f}")

tensor(0.2649, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.2322, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.2175, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.2119, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.2024, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.1963, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.1981, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.2057, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.2013, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.2047, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.2002, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.1956, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.1866, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.1962, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.1953, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.1853, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.1877, device='cuda:0', grad_fn=

tensor(0.1107, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.1102, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.1102, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.1068, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.1077, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.1084, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.1068, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.1057, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.1061, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.1047, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.1026, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.1043, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.1024, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.1012, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.1038, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.1026, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.1046, device='cuda:0', grad_fn=

tensor(0.0480, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0474, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0486, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0478, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0488, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0500, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0489, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0453, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0480, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0461, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0439, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0446, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0465, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0460, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0448, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0446, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0451, device='cuda:0', grad_fn=

tensor(0.0224, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0213, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0218, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0212, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0230, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0219, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0207, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0221, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0206, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0201, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0201, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0225, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0224, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0217, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0207, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0214, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0215, device='cuda:0', grad_fn=

KeyboardInterrupt: 