In [1]:
%pip install torch tiktoken onnx onnxruntime

Collecting onnx
  Using cached onnx-1.17.0-cp312-cp312-macosx_12_0_universal2.whl.metadata (16 kB)
Collecting onnxruntime
  Using cached onnxruntime-1.20.1-cp312-cp312-macosx_13_0_universal2.whl.metadata (4.5 kB)
Collecting sympy==1.13.1 (from torch)
  Using cached sympy-1.13.1-py3-none-any.whl.metadata (12 kB)
Collecting coloredlogs (from onnxruntime)
  Using cached coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)
Collecting flatbuffers (from onnxruntime)
  Using cached flatbuffers-24.3.25-py2.py3-none-any.whl.metadata (850 bytes)
Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime)
  Using cached humanfriendly-10.0-py2.py3-none-any.whl.metadata (9.2 kB)
Using cached sympy-1.13.1-py3-none-any.whl (6.2 MB)
Using cached onnx-1.17.0-cp312-cp312-macosx_12_0_universal2.whl (16.7 MB)
Using cached onnxruntime-1.20.1-cp312-cp312-macosx_13_0_universal2.whl (31.0 MB)
Using cached coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)
Using cached flatbuffers-24.3.25-py2.py3-none-any.w

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import tiktoken
import numpy as np

In [3]:
# Initialise the shared tiktoken tokeniser (cl100k_base encoding) once at module level.
# encode_text() calls tokeniser.encode() on it, and the vocab_size hyperparameter is
# read from tokeniser.n_vocab.
tokeniser = tiktoken.get_encoding("cl100k_base")

def encode_text(text, max_length, pad_token_id=0):
    """Tokenise ``text`` and force the result to exactly ``max_length`` ids.

    Sequences longer than ``max_length`` are truncated on the right; shorter
    ones are right-padded with ``pad_token_id``.

    Args:
        text: String to encode with the module-level ``tokeniser``. The
            ``<|endoftext|>`` special token is allowed in the input.
        max_length: Required length of the returned list; must be >= 0.
        pad_token_id: Id appended to short sequences. Defaults to 0, which is
            the value ``TextClassificationDataset`` compares against when it
            builds its attention mask.
            NOTE(review): 0 is not a dedicated padding id of the tokeniser; if
            it maps to a real token, the ``input_ids != 0`` mask would also
            hide genuine occurrences of that token - confirm.

    Returns:
        A list of ``max_length`` integer token ids.

    Raises:
        ValueError: If ``max_length`` is negative (a negative slice bound
            would otherwise silently drop tokens from the end and return a
            list of the wrong length).
    """
    if max_length < 0:
        raise ValueError(f"max_length must be non-negative, got {max_length}")

    # Slicing is a no-op for short sequences, so one expression covers both cases.
    tokens = tokeniser.encode(text, allowed_special={"<|endoftext|>"})[:max_length]
    tokens.extend([pad_token_id] * (max_length - len(tokens)))
    return tokens

In [4]:
class TextClassificationDataset(Dataset):
    """Map-style dataset yielding fixed-length token ids, a padding mask and a label.

    Each item is produced lazily: the text is tokenised and padded/truncated
    by ``encode_text`` at access time.

    Args:
        texts: Sized, indexable collection of strings.
        labels: Sized, indexable collection of integer class ids, aligned with
            ``texts``.
        max_length: Sequence length every example is padded/truncated to.

    Raises:
        ValueError: If ``texts`` and ``labels`` differ in length. Without this
            check the mismatch only surfaces as an ``IndexError`` part-way
            through an epoch, or leaves trailing labels silently unused.
    """

    def __init__(self, texts, labels, max_length):
        if len(texts) != len(labels):
            raise ValueError(
                f"texts and labels must have the same length, "
                f"got {len(texts)} and {len(labels)}"
            )
        self.texts = texts
        self.labels = labels
        self.max_length = max_length

    def __len__(self):
        """Return the number of examples."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask, label)`` for example ``idx``.

        All three are ``torch.long`` tensors; ``input_ids`` and
        ``attention_mask`` have ``max_length`` elements and ``label`` is a
        scalar.
        """
        token_ids = encode_text(self.texts[idx], self.max_length)
        input_ids = torch.tensor(token_ids, dtype=torch.long)
        # 1 where the id differs from 0 (the value encode_text pads with), else 0.
        attention_mask = (input_ids != 0).long()
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return input_ids, attention_mask, label

In [5]:
class TransformerClassifier(nn.Module):
    """Transformer-encoder text classifier with learned positions and mean pooling.

    Token embeddings plus a learned positional parameter are passed through an
    ``nn.TransformerEncoder``; the encoder outputs of the non-padding
    positions are averaged and projected to class logits.

    Args:
        vocab_size: Number of rows in the token embedding table.
        embed_dim: Embedding / model width (``d_model``).
        num_heads: Attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        num_classes: Size of the output logit vector.
        max_seq_length: Longest sequence the positional parameter covers.
    """

    def __init__(self, vocab_size, embed_dim, num_heads, num_layers, num_classes, max_seq_length):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # Learned positional encoding, initialised to zeros and broadcast over the batch.
        self.positional_encoding = nn.Parameter(torch.zeros(1, max_seq_length, embed_dim))
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, input_ids, attention_mask):
        """Compute class logits.

        Args:
            input_ids: ``(batch, seq_len)`` integer token ids.
            attention_mask: ``(batch, seq_len)`` tensor, non-zero for real
                tokens and zero for padding.

        Returns:
            ``(batch, num_classes)`` logits.
        """
        seq_len = input_ids.size(1)
        embedded = self.embedding(input_ids) + self.positional_encoding[:, :seq_len, :]

        is_real = attention_mask.bool()  # (batch, seq_len)
        is_pad = ~is_real

        encoded = self.transformer_encoder(
            embedded.transpose(0, 1),  # (seq_len, batch, embed_dim)
            src_key_padding_mask=is_pad,  # True marks positions to ignore as keys
        )

        # Masked mean pooling: padded positions still produce encoder outputs
        # (they act as queries), so a plain mean over the sequence axis would
        # mix them into the sentence vector. Zero them out and divide by the
        # number of real tokens instead.
        encoded = encoded.masked_fill(is_pad.transpose(0, 1).unsqueeze(-1), 0.0)
        # clamp avoids a division by zero for a row that is entirely padding.
        real_counts = is_real.to(encoded.dtype).sum(dim=1, keepdim=True).clamp(min=1.0)
        pooled_output = encoded.sum(dim=0) / real_counts  # (batch, embed_dim)

        return self.fc(pooled_output)

In [7]:
# Hyperparameters

# -- Tokenisation / input shape --
max_seq_length = 128             # every example is padded or truncated to this many ids
vocab_size = tokeniser.n_vocab   # Tokeniser vocabulary size

# -- Model architecture --
num_classes = 3                  # Example: 3 classes
embed_dim = 128
num_layers = 2
num_heads = 4

# -- Optimisation --
epochs = 5
batch_size = 32
learning_rate = 1e-4

In [8]:
def train_model(model, dataloader, epochs, learning_rate, device):
    """Train ``model`` in place with AdamW and cross-entropy, printing the mean loss per epoch.

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` and
            returning class logits.
        dataloader: Iterable yielding ``(input_ids, attention_mask, labels)``
            batches. It is iterated once per epoch and does not need to
            implement ``__len__``.
        epochs: Number of passes over ``dataloader``.
        learning_rate: Learning rate passed to ``torch.optim.AdamW``.
        device: Device the model and every batch are moved to.

    Returns:
        None. The model's parameters are updated in place.
    """
    model = model.to(device)
    optimiser = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    loss_fn = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        # Count batches while iterating instead of calling len(dataloader):
        # len() raises for loaders without a length and the division below
        # would raise ZeroDivisionError for an empty one.
        num_batches = 0
        for input_ids, attention_mask, labels in dataloader:
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            labels = labels.to(device)

            optimiser.zero_grad()
            outputs = model(input_ids, attention_mask)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimiser.step()

            total_loss += loss.item()
            num_batches += 1

        # An epoch with no batches has no defined mean loss; report nan.
        mean_loss = total_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch+1}/{epochs}, Loss: {mean_loss:.4f}")

In [21]:
# Toy corpus: three sentences, one per class id
texts = ["I love programming", "Python is great", "I hate bugs"]
labels = [0, 1, 2]  # Example labels

# Use the GPU when one is visible to PyTorch, otherwise stay on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Wrap the corpus so it can be batched and shuffled
dataset = TextClassificationDataset(texts=texts, labels=labels, max_length=max_seq_length)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Build the classifier from the hyperparameters defined earlier
model = TransformerClassifier(
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    num_heads=num_heads,
    num_layers=num_layers,
    num_classes=num_classes,
    max_seq_length=max_seq_length,
)

# Fit the model; one loss line is printed per epoch
train_model(
    model,
    dataloader,
    epochs=epochs,
    learning_rate=learning_rate,
    device=device,
)

# Persist only the learned weights (state dict), not the module object
torch.save(model.state_dict(), "/Users/paulzanna/Github/Ziggy/model/ziggy_model.bin")

Epoch 1/5, Loss: 1.1631
Epoch 2/5, Loss: 0.9526
Epoch 3/5, Loss: 0.8484
Epoch 4/5, Loss: 0.7075
Epoch 5/5, Loss: 0.5841


In [16]:
#
# Export ONNX model
#
# Example inputs used only to trace the graph: a single sequence of
# max_seq_length random token ids with nothing masked out.
dummy_input_ids = torch.randint(0, vocab_size, (1, max_seq_length)).to(device)
# Integer mask, matching the dtype TextClassificationDataset produces (`.long()`).
# torch.ones() defaults to float32, which would bake a float `attention_mask`
# input into the exported graph and make ONNX Runtime reject int64 masks.
dummy_attention_mask = torch.ones(1, max_seq_length, dtype=torch.long).to(device)

torch.onnx.export(
    model,
    (dummy_input_ids, dummy_attention_mask),
    "text_classifier.onnx",
    opset_version=14,
    input_names=["input_ids", "attention_mask"],
    output_names=["logits"],
    # Mark batch and sequence dimensions as dynamic so the exported graph is
    # not tied to the (1, max_seq_length) shape of the dummy inputs.
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "seq_length"},
        "attention_mask": {0: "batch_size", 1: "seq_length"},
        "logits": {0: "batch_size"},
    },
)

In [17]:
#
# Verify ONNX model
#

import onnxruntime as ort

# Load the ONNX model written by the export cell ("text_classifier.onnx", a path
# relative to the current working directory) into an ONNX Runtime inference session.
ort_session = ort.InferenceSession("text_classifier.onnx")

# Run inference
def predict_with_onnx(ort_session, input_ids, attention_mask):
    """Run the exported classifier and return the predicted class index per row.

    Args:
        ort_session: ONNX Runtime ``InferenceSession`` whose graph has inputs
            named ``input_ids`` and ``attention_mask`` and whose first output
            is the logits.
        input_ids: Torch tensor of token ids.
        attention_mask: Torch tensor marking real tokens (non-zero) vs padding.

    Returns:
        numpy array with the ``argmax`` over axis 1 of the logits.
    """
    # ONNX Runtime rejects feeds whose element type differs from the one the
    # graph was exported with, so convert each input to the declared type.
    ort_to_numpy = {
        "tensor(float16)": np.float16,
        "tensor(float)": np.float32,
        "tensor(double)": np.float64,
        "tensor(int32)": np.int32,
        "tensor(int64)": np.int64,
        "tensor(bool)": np.bool_,
    }
    declared = {arg.name: ort_to_numpy.get(arg.type) for arg in ort_session.get_inputs()}

    inputs = {}
    for name, tensor in (("input_ids", input_ids), ("attention_mask", attention_mask)):
        array = tensor.cpu().numpy()
        wanted = declared.get(name)
        # Unknown or undeclared types are passed through unchanged.
        if wanted is not None and array.dtype != wanted:
            array = array.astype(wanted)
        inputs[name] = array

    logits = ort_session.run(None, inputs)[0]
    return np.argmax(logits, axis=1)