## Libraries

In [1]:
import numpy as np
import random
import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

## Model

In [2]:
class JudgeModel(nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int = 64,
        num_classes: int = 3
    ):
        super().__init__()

        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, x):
        logits = self.net(x)
        probs = F.softmax(logits, dim=-1)

        return {
            "logits": logits,
            "probs": probs
        }

## Dataset

In [3]:
class JudgeDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X)
        self.y = torch.tensor(y)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

## Generation synthetic data

In [4]:
VERDICTS = {"publish": 0, "revise": 1, "reject": 2}

def generate_judge_sample():
    risk = random.random()
    uncertainty = random.random() * 0.3
    quality = random.random()

    aspects = [random.uniform(0.3, 1.0) for _ in range(4)]

    features = [
        risk,
        uncertainty,
        quality,
        1 - np.std(aspects),
        *aspects
    ]

    if risk > 0.7:
        label = VERDICTS["reject"]
    elif risk < 0.3 and quality > 0.7:
        label = VERDICTS["publish"]
    else:
        label = VERDICTS["revise"]

    return np.array(features, dtype="float32"), label


def generate_judge_dataset(n=5000):
    X, y = [], []
    for _ in range(n):
        f, l = generate_judge_sample()
        X.append(f)
        y.append(l)
    return X, y

## Train function

In [6]:
def train_judge(dataset, input_dim: int):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = JudgeModel(input_dim=input_dim).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    criterion = torch.nn.CrossEntropyLoss()
    epochs = 10

    loader = DataLoader(dataset, batch_size=32, shuffle=True)

    model.train()
    for epoch in range(epochs):
        epoch_loss = 0
        loop = tqdm.tqdm(loader, desc=f"Epoch {epoch + 1}")

        for x, y in loop:
            x, y = x.to(device), y.to(device)
            
            out = model(x)
            loss = criterion(out["logits"], y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            loop.set_postfix(loss=loss.item())

        print(f"Epoch {epoch + 1}: loss={epoch_loss / len(loader):.4f}")

    return model

## Train

In [7]:
X, y = generate_judge_dataset()
dataset = JudgeDataset(X, y)

model = train_judge(dataset, input_dim=X[0].shape[0])
torch.save(model.state_dict(), "judge.pt")

  self.X = torch.tensor(X)
Epoch 1: 100%|██████████| 157/157 [00:00<00:00, 406.24it/s, loss=0.382]


Epoch 1: loss=0.6849


Epoch 2: 100%|██████████| 157/157 [00:00<00:00, 931.93it/s, loss=0.167]


Epoch 2: loss=0.2959


Epoch 3: 100%|██████████| 157/157 [00:00<00:00, 958.50it/s, loss=0.0966]


Epoch 3: loss=0.2133


Epoch 4: 100%|██████████| 157/157 [00:00<00:00, 927.12it/s, loss=0.174]


Epoch 4: loss=0.1845


Epoch 5: 100%|██████████| 157/157 [00:00<00:00, 873.27it/s, loss=0.0926]


Epoch 5: loss=0.1672


Epoch 6: 100%|██████████| 157/157 [00:00<00:00, 554.16it/s, loss=0.271] 


Epoch 6: loss=0.1545


Epoch 7: 100%|██████████| 157/157 [00:00<00:00, 732.67it/s, loss=0.189] 


Epoch 7: loss=0.1452


Epoch 8: 100%|██████████| 157/157 [00:00<00:00, 937.12it/s, loss=0.0372]


Epoch 8: loss=0.1357


Epoch 9: 100%|██████████| 157/157 [00:00<00:00, 910.22it/s, loss=0.131]


Epoch 9: loss=0.1287


Epoch 10: 100%|██████████| 157/157 [00:00<00:00, 561.94it/s, loss=0.292]


Epoch 10: loss=0.1325
