In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Classifier(nn.Module):
    def __init__(self, z_dim, n_class):
        super().__init__()
        self.linear = nn.Linear(z_dim, n_class)
        
    def forward(self, data, **kwargs):
        # z : (b, z)
        # t : (b,)
        
        logit = self.linear(data['z'])
        loss = F.cross_entropy(logit, data['t'])
        data['cross_entropy_loss'] = loss
        
        predict = torch.argmax(logit, dim=1)
        precision = torch.sum(predict == data['t']) / len(data['t'])
        data['precision'] = precision
        
        return data
        

In [10]:
classifier = Classifier(2, 10)
data = {'z': torch.randn(10000, 2),
        't': torch.randint(0, 10, size=(10000,))}
classifier(data)

{'z': tensor([[-0.1029,  0.9447],
         [-1.2995,  1.6374],
         [-1.5176, -1.2669],
         ...,
         [ 0.9709,  0.8865],
         [ 0.8922,  1.3592],
         [ 1.4032,  0.5030]]),
 't': tensor([1, 8, 3,  ..., 1, 1, 4]),
 'cross_entropy_loss': tensor(2.5237, grad_fn=<NllLossBackward0>),
 'precision': tensor(0.0955)}