## Install dependencies

In [0]:
!pip install torch
dbutils.library.restartPython()

Collecting torch
  Downloading torch-2.7.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (29 kB)
Collecting sympy>=1.13.3 (from torch)
  Downloading sympy-1.14.0-py3-none-any.whl.metadata (12 kB)
Collecting networkx (from torch)
  Downloading networkx-3.5-py3-none-any.whl.metadata (6.3 kB)
Collecting jinja2 (from torch)
  Downloading jinja2-3.1.6-py3-none-any.whl.metadata (2.9 kB)
Collecting fsspec (from torch)
  Downloading fsspec-2025.5.1-py3-none-any.whl.metadata (11 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.6.77 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.6.77 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.6.80 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.6 kB)
Collecting nv

## Load data from S3 bucket

In [0]:
import os
from pathlib import Path
import boto3
from botocore.client import Config

# MinIO
# from dotenv import load_dotenv
# load_dotenv(dotenv_path='secrets_minio.env')
# key_id = os.getenv('AWS_ACCESS_KEY')
# access_key = os.getenv('AWS_SECRET_KEY')
# bucket_name = os.getenv('BUCKET_NAME')
# s3_client = boto3.client('s3', endpoint_url=os.getenv('S3_ENDPOINT'), aws_access_key_id=key_id, aws_secret_access_key=access_key, config=Config(signature_version='s3v4'), region_name=os.getenv('REGION'))

# AWS 
key_id = dbutils.secrets.get(scope="my-scope", key="aws_access_key_id")
access_key = dbutils.secrets.get(scope="my-scope", key="aws_secret_access_key")
s3_client = boto3.client('s3', aws_access_key_id=key_id, aws_secret_access_key=access_key)
bucket_name = 'databricks-mvxb2etarhyoiwq5dpmkpd-cloud-storage-bucket'

In [0]:
def fetch_data(s3_client, bucket_name, file_name):
    s3_client.download_file(bucket_name, file_name, str(file_name))
    
def fetch_most_recent(s3_client, bucket_name, local_dir, local_name):
    local_dir = Path(local_dir)
    local_dir.mkdir(parents=True, exist_ok=True)

    response = s3_client.list_objects_v2(Bucket=bucket_name, Prefix=str(local_dir))
    
    files = sorted(
        (item for item in response.get('Contents', []) if item['Key'].endswith('.csv')),
        key=lambda item: item['LastModified'],
        reverse=True
    )

    if len(files) == 0:
        raise Exception(f"No CSV files found in {local_dir}.")
    
    s3_client.download_file(bucket_name, files[0]['Key'], local_name)

train_data = fetch_most_recent(s3_client, bucket_name, 'train', 'train.csv')
test_data = fetch_data(s3_client, bucket_name, 'test.csv')
     

## LSTM Classifier

In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class TextClassificationModel(nn.Module):
    def __init__(self, vocab_size, embed_size, maxlen):
        super(TextClassificationModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=0)
        
        self.lstm1 = nn.LSTM(embed_size, 128, batch_first=True, bidirectional=True)
        self.lstm2 = nn.LSTM(128 * 2, 64, batch_first=True, bidirectional=True)
        
        self.global_max_pool = nn.AdaptiveMaxPool1d(1)
        
        self.fc1 = nn.Linear(64 * 2, 1024)
        self.dropout1 = nn.Dropout(0.25)
        
        self.fc2 = nn.Linear(1024, 512)
        self.dropout2 = nn.Dropout(0.25)
        
        self.fc3 = nn.Linear(512, 256)
        self.dropout3 = nn.Dropout(0.25)
        
        self.fc4 = nn.Linear(256, 128)
        self.dropout4 = nn.Dropout(0.25)
        
        self.fc5 = nn.Linear(128, 64)
        self.dropout5 = nn.Dropout(0.25)
        
        self.fc6 = nn.Linear(64, 4)

    def forward(self, x):
        x = self.embedding(x)                      # (batch_size, maxlen, embed_size)
        
        x, _ = self.lstm1(x)                       # (batch_size, maxlen, 256)
        x, _ = self.lstm2(x)                       # (batch_size, maxlen, 128)
        
        x = x.permute(0, 2, 1)                     # (batch_size, 128, maxlen)
        x = self.global_max_pool(x).squeeze(2)     # (batch_size, 128)
        
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        
        x = F.relu(self.fc3(x))
        x = self.dropout3(x)
        
        x = F.relu(self.fc4(x))
        x = self.dropout4(x)
        
        x = F.relu(self.fc5(x))
        x = self.dropout5(x)
        
        x = self.fc6(x)                            # (batch_size, 4)
        return F.softmax(x, dim=1)                 # Apply softmax over classes

# model = TextClassificationModel(vocab_size=10000, embed_size=300, maxlen=100)


## Tokenizer

In [0]:
from collections import Counter
import torch
import pickle
import json

class Tokenizer:
    def __init__(self, lower=True, max_vocab_size=None, pad_token="<pad>", unk_token="<unk>", maxlen=None):
        self.lower = lower
        self.max_vocab_size = max_vocab_size
        self.pad_token = pad_token
        self.unk_token = unk_token
        self.maxlen = maxlen
        self.word2idx = {}
        self.idx2word = {}
        self.vocab = None

    def fit_on_texts(self, texts):
        # Tokenize
        tokens = []
        for text in texts:
            if self.lower:
                text = text.lower()
            tokens.extend(text.split())

        # Count words and build vocab
        word_freq = Counter(tokens)
        most_common = word_freq.most_common(self.max_vocab_size)
        vocab_words = [self.pad_token, self.unk_token] + [word for word, _ in most_common]

        self.word2idx = {word: idx for idx, word in enumerate(vocab_words)}
        self.idx2word = {idx: word for word, idx in self.word2idx.items()}

    def texts_to_sequences(self, texts):
        sequences = []
        for text in texts:
            if self.lower:
                text = text.lower()
            tokens = text.split()
            seq = [self.word2idx.get(token, self.word2idx[self.unk_token]) for token in tokens]
            sequences.append(seq)
        return sequences

    def pad_sequences(self, sequences, padding="post", truncating="post"):
        pad_val = self.word2idx[self.pad_token]
        maxlen = self.maxlen or max(len(seq) for seq in sequences)

        padded = []
        for seq in sequences:
            if len(seq) < maxlen:
                pad_length = maxlen - len(seq)
                if padding == "post":
                    seq = seq + [pad_val] * pad_length
                else:
                    seq = [pad_val] * pad_length + seq
            elif len(seq) > maxlen:
                if truncating == "post":
                    seq = seq[:maxlen]
                else:
                    seq = seq[-maxlen:]
            padded.append(seq)
        return torch.tensor(padded, dtype=torch.long)

    def __len__(self):
        return len(self.word2idx)

    def vocab_size(self):
        return len(self.word2idx)
   
    def save(self, filepath):
        tokenizer_data = {
            "lower": self.lower,
            "max_vocab_size": self.max_vocab_size,
            "pad_token": self.pad_token,
            "unk_token": self.unk_token,
            "maxlen": self.maxlen,
            "word2idx": self.word2idx
        }
        with open(filepath, "w") as f:
            json.dump(tokenizer_data, f)

    @staticmethod
    def load(filepath):
        with open(filepath) as f:
            data = json.load(f)

        tokenizer = Tokenizer(
            lower=data["lower"],
            max_vocab_size=data["max_vocab_size"],
            pad_token=data["pad_token"],
            unk_token=data["unk_token"],
            maxlen=data["maxlen"]
        )
        tokenizer.word2idx = data["word2idx"]
        tokenizer.idx2word = {int(v): k for k, v in tokenizer.word2idx.items()}
        return tokenizer

## Training and evaluation

In [0]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import pandas as pd
import mlflow
import mlflow.pytorch  # For logging PyTorch models
from mlflow.models.signature import infer_signature

def load_dataset(file_path, tokenizer, fit_tokenizer=False):
    data = pd.read_csv(file_path, header=0, names=['ID', 'Title', 'Desc'])
    texts = data['Desc'].tolist()
    labels = data['ID'].tolist()  # Assuming 'ID' is the label column
    
    if fit_tokenizer:
        tokenizer.fit_on_texts(texts)

    sequences = tokenizer.texts_to_sequences(texts)
    padded_sequences = tokenizer.pad_sequences(sequences)

    vocab_size = tokenizer.vocab_size() + 1

    label_tensor = torch.tensor(labels, dtype=torch.long)
    label_tensor = label_tensor - 1  # Adjust labels to be zero-indexed
    dataset = torch.utils.data.TensorDataset(padded_sequences, label_tensor)
    
    return dataset, vocab_size

def train():
    vocab_size = 10000
    embed_size = 32
    maxlen = 100
    batch_size = 64
    epochs = 3
    learning_rate = 1e-4
    patience = 2  # early stopping patience

    tokenizer = Tokenizer(max_vocab_size=vocab_size, maxlen=maxlen)

    # data_path = '/Workspace/Users/sannikov.pn@ucu.edu.ua/'
    data_path = './'

    train_url = data_path + 'train.csv'
    test_url = data_path + 'test.csv'

    train_dataset, vocab_size = load_dataset(train_url, tokenizer, fit_tokenizer=True)
    val_dataset, _ = load_dataset(test_url, tokenizer)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = TextClassificationModel(vocab_size, embed_size, maxlen).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Early stopping
    best_val_acc = 0.0
    counter = 0

    # mlflow.set_tracking_uri('http://localhost:5000')

    with mlflow.start_run():
        mlflow.log_param("vocab_size", vocab_size)
        mlflow.log_param("embed_size", embed_size)
        mlflow.log_param("maxlen", maxlen)
        mlflow.log_param("batch_size", batch_size)
        mlflow.log_param("epochs", epochs)
        mlflow.log_param("learning_rate", learning_rate)
        mlflow.log_param("patience", patience)

        for epoch in range(epochs):
            model.train()
            total_loss = 0
            correct = 0
            total = 0

            for inputs, labels in train_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                
                total_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)

            train_acc = correct / total
            mlflow.log_metric("train_loss", total_loss, step=epoch)
            mlflow.log_metric("train_accuracy", train_acc, step=epoch)
            print(f"Epoch {epoch+1}/{epochs} - Loss: {total_loss:.4f}, Train Acc: {train_acc:.4f}")

            # Validation
            model.eval()
            correct = 0
            total = 0
            with torch.no_grad():
                for inputs, labels in val_loader:
                    inputs, labels = inputs.to(device), labels.to(device)
                    outputs = model(inputs)
                    _, predicted = torch.max(outputs, 1)
                    correct += (predicted == labels).sum().item()
                    total += labels.size(0)

            val_acc = correct / total
            mlflow.log_metric("val_accuracy", val_acc, step=epoch)
            print(f"Validation Accuracy: {val_acc:.4f}")

            # Early stopping logic
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                counter = 0
                
                # torch.save(model.state_dict(), "best_model.pt")

                # Log the final model
                example_input = torch.randint(0, vocab_size, (5, maxlen))  # batch of 5
                example_output = model(example_input.to(device)).detach().cpu().numpy()
                input_df = pd.DataFrame(example_input.numpy(), columns=[f"token_{i}" for i in range(maxlen)])
                output_df = pd.DataFrame(example_output, columns=[f"class_{i}" for i in range(4)])
                signature = infer_signature(input_df, output_df)

                tokenizer.save("tokenizer.json")
                mlflow.pytorch.log_model(model, "model", signature=signature)
                
                scripted_model = torch.jit.script(model)
                scripted_model.save("best_model.pt")
                mlflow.log_artifact("best_model.pt", artifact_path="model")
                mlflow.log_artifact("tokenizer.json", artifact_path="model")
            else:
                counter += 1
                if counter >= patience:
                    print(f"Early stopping triggered at epoch {epoch+1}")
                    break
        
        mlflow.log_metric("best_val_accuracy", best_val_acc, step=epoch)
        model_name = "Classifier"
        model_uri = f"runs:/{mlflow.active_run().info.run_id}/model"
        model_info = mlflow.register_model(model_uri, model_name)

        print("Training complete. Model saved.")
        return model_info

challenger_info = train()

Epoch 1/3 - Loss: 514.5959, Train Acc: 0.2870
Validation Accuracy: 0.3713




Epoch 2/3 - Loss: 486.9294, Train Acc: 0.4008
Validation Accuracy: 0.4242




Epoch 3/3 - Loss: 470.1351, Train Acc: 0.4503
Validation Accuracy: 0.4799


Registered model 'Classifier' already exists. Creating a new version of this model...


Training complete. Model saved.
🏃 View run burly-crab-361 at: https://dbc-b6ddbecb-76c6.cloud.databricks.com/ml/experiments/552872594943880/runs/9a180be6eede4fa99191dc41550c8ef2
🧪 View experiment at: https://dbc-b6ddbecb-76c6.cloud.databricks.com/ml/experiments/552872594943880


Created version '6' of model 'workspace.default.classifier'.


## Model Version Management with MLFlow

Ensure that the best model (Champion) is always in use 

In [0]:
from mlflow.tracking import MlflowClient

client = MlflowClient()
if challenger_info.version == '1':
    print("No champion model found. Champion model set to challenger.")
    client.set_registered_model_alias(challenger_info.name, "Champion", challenger_info.version)
else:
    champion_version = client.get_model_version_by_alias(challenger_info.name, "Champion")
    champion_accuracy = client.get_metric_history(champion_version.run_id, "val_accuracy")[0].value
    challenger_accuracy = client.get_metric_history(challenger_info.run_id, "best_val_accuracy")[0].value
    print(f"Champion model: {champion_accuracy}")
    print(f"Challenger model: {challenger_accuracy}")

    if champion_accuracy < challenger_accuracy:
        client.set_registered_model_alias(champion_version.name, "Retired", champion_version.version)
        client.set_registered_model_alias(challenger_info.name, "Champion", challenger_info.version)

Champion model: 0.39552631578947367
Challenger model: 0.47986842105263156


## Automatic deployment
Disabled in trial version

In [0]:
# import json
# import requests

# champion_info = client.get_model_version_by_alias("Classifier", "Champion")

# API_ROOT = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().get() 
# API_TOKEN = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()

# data = {
#     "name": "NewsClassification",
#     "config": {
#         "served_entities": [
#             {
#                 "entity_name": champion_info.name,
#                 "entity_version": champion_info.version,
#                 "workload_size": "Small",
#                 "scale_to_zero_enabled": False,
#                 "workload_type": "CPU",
#             }
#         ]
#     },
# }

# headers = {"Context-Type": "text/json", "Authorization": f"Bearer {API_TOKEN}"}

# response = requests.post(
#     url=f"{API_ROOT}/api/2.0/serving-endpoints", json=data, headers=headers
# )

# print(json.dumps(response.json(), indent=4))

# {
#     "error_code": "FEATURE_DISABLED",
#     "message": "Model serving is not available for trial workspaces. Please contact your organization admin or Databricks support."
# }