In [1]:
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor
import torchvision.transforms as T
import torch.nn as nn

import numpy as np

In [None]:
traindt = datasets.MNIST(
    root = 'data',
    train = True,
    transform = ToTensor(),
    download = True,
)

In [2]:
transform = T.Compose([T.Resize(10), T.ToTensor(),])

mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=128, shuffle=True)

class MulticlassLogisticRegression(nn.Module):
    def __init__(self):
        super(MulticlassLogisticRegression, self).__init__()
        self.linear = nn.Linear(100, 10, bias=False)

    def forward(self, x):
        x = x.view(-1, 10*10)
        x = self.linear(x)
        return x

model = MulticlassLogisticRegression()
print(model)

100%|██████████| 9.91M/9.91M [00:00<00:00, 39.9MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.07MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 9.20MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 9.13MB/s]


MulticlassLogisticRegression(
  (linear): Linear(in_features=100, out_features=10, bias=False)
)


In [3]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)


for epoch in range(50):
    sum_losses = 0
    for x, (images, labels) in enumerate(train_loader):
        images = images.reshape(-1, 10*10)

        out = model(images)
        losses = criterion(out, labels)
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
        sum_losses += losses.item()
        if x == 10:
            break

    print(f"EPOCH {epoch+1} TRAIN LOSS : {sum_losses / len(train_loader)}")

EPOCH 1 TRAIN LOSS : 0.049766685916925034
EPOCH 2 TRAIN LOSS : 0.04125076951756915
EPOCH 3 TRAIN LOSS : 0.03495062592187162
EPOCH 4 TRAIN LOSS : 0.03034284348681029
EPOCH 5 TRAIN LOSS : 0.02688028868327517
EPOCH 6 TRAIN LOSS : 0.024514354241173913
EPOCH 7 TRAIN LOSS : 0.022538598158212103
EPOCH 8 TRAIN LOSS : 0.020737379090364046
EPOCH 9 TRAIN LOSS : 0.019572103963986134
EPOCH 10 TRAIN LOSS : 0.018011110932079712
EPOCH 11 TRAIN LOSS : 0.017089049063765925
EPOCH 12 TRAIN LOSS : 0.016561169995427893
EPOCH 13 TRAIN LOSS : 0.015840855869911373
EPOCH 14 TRAIN LOSS : 0.015799987036536244
EPOCH 15 TRAIN LOSS : 0.015382882501524904
EPOCH 16 TRAIN LOSS : 0.014061061812362182
EPOCH 17 TRAIN LOSS : 0.013820419560617475
EPOCH 18 TRAIN LOSS : 0.013621550506111909
EPOCH 19 TRAIN LOSS : 0.012692299415307766
EPOCH 20 TRAIN LOSS : 0.01341225851827593
EPOCH 21 TRAIN LOSS : 0.012103104197394365
EPOCH 22 TRAIN LOSS : 0.012376849712339292
EPOCH 23 TRAIN LOSS : 0.01286574384805236
EPOCH 24 TRAIN LOSS : 0.01

In [4]:
model.eval()

weights = model.linear.weight.detach().cpu().numpy()

In [5]:
weights.shape

(10, 100)

In [6]:
for i, (imgs, labels) in enumerate(train_loader):

    outputs = nn.Softmax(dim=1)(model(imgs))
    imgs = imgs.detach().cpu().numpy()

    imgs2 = imgs.reshape((128, 100))

    w = (imgs2[:, np.newaxis, :] * weights[np.newaxis, :, :]).transpose(2, 1, 0)
    w = np.maximum(w, 0)
    w = np.minimum(w, 1)
    outputs = outputs.detach().cpu().numpy()
    imgs = np.squeeze(imgs, axis=1).transpose(1, 2, 0)
    outputs = np.expand_dims(outputs, axis=1).transpose(1, 2, 0)
    np.save("inputs.npy", imgs)
    np.save("outputs.npy", outputs)
    np.save("weights.npy", w)
    break

In [7]:
!pip install onnx
!pip install onnxruntime

Collecting onnx
  Downloading onnx-1.19.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (7.0 kB)
Downloading onnx-1.19.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (18.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m18.2/18.2 MB[0m [31m80.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: onnx
Successfully installed onnx-1.19.1
Collecting onnxruntime
  Downloading onnxruntime-1.23.2-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (5.1 kB)
Collecting coloredlogs (from onnxruntime)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)
Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime)
  Downloading humanfriendly-10.0-py2.py3-none-any.whl.metadata (9.2 kB)
Downloading onnxruntime-1.23.2-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (17.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.4/17.4 MB[0m [31m56.8 MB/s[0m eta [36m0:0

In [10]:
import torch.onnx
import onnxruntime as ort
print(torch.cuda.is_available())

False


In [11]:
dummy = torch.randn(1, 1, 10, 10)

torch.onnx.export(
    model,
    dummy,
    "mnist_logreg.onnx",
    input_names=["images"],
    output_names=["logits"],
    dynamic_axes={"images": {0: "batch"}, "logits": {0: "batch"}},
    opset_version=17
)

print("Exported to mnist_logreg.onnx")

  torch.onnx.export(


Exported to mnist_logreg.onnx


In [12]:
# pip install onnxruntime onnx
import onnxruntime as ort
import numpy as np

# same preprocessing used during training: Resize(10), ToTensor()
# Here we'll pretend we already have a batch tensor `imgs_torch` (B,1,10,10)
# If you have PIL images, apply the same transforms first, then stack.

# Example: fake batch of 128 just to show shape
B = 128
imgs_np = np.random.randn(B, 1, 10, 10).astype(np.float32)

# Create session
sess = ort.InferenceSession("mnist_logreg.onnx", providers=["CPUExecutionProvider"])

# Run
outputs = sess.run(
    ["logits"],
    {"images": imgs_np}
)[0]  # shape (B, 10) logits

# Turn logits into probabilities (optional)
exp = np.exp(outputs - outputs.max(axis=1, keepdims=True))
probs = exp / exp.sum(axis=1, keepdims=True)

preds = probs.argmax(axis=1)  # predicted digit 0–9
print(preds[:10], probs[:1])

[6 2 0 6 3 6 2 6 5 9] [[1.03030925e-05 9.59636152e-15 2.41885392e-08 7.81567418e-21
  2.30383246e-09 4.11800382e-14 9.99981880e-01 3.40727113e-10
  7.30378025e-09 7.75558328e-06]]


In [13]:
# Compare one batch between PyTorch and ONNX
with torch.no_grad():
    torch_logits = model(torch.from_numpy(imgs_np))  # shape (B,10)

onnx_logits = outputs
close = np.allclose(torch_logits.numpy(), onnx_logits, rtol=1e-3, atol=1e-4)
print("Match PyTorch vs ONNX:", close)

Match PyTorch vs ONNX: True
