In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
SEED: int = 1  # handed to torch.manual_seed right before the fake data is drawn

## First, let's create a fake data tensor

In [2]:
# Dimensions of the toy problem (they fix the shape of the fake data tensor).
batch_size: int = 3   # number of examples
seq_len: int = 4      # positions per example
hidden_size: int = 5  # features per position

In [3]:
# Seed the global RNG so the random features drawn below are reproducible.
torch.manual_seed(SEED)

# Random features: a 3-D tensor of shape (batch_size, seq_len, hidden_size).
data = torch.randn(batch_size, seq_len, hidden_size)
# One class index per example: a 1-D tensor of shape (batch_size,).
targets = torch.tensor([0, 1, 3])
print(data.size(), targets.size())

torch.Size([3, 4, 5]) torch.Size([3])


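## Then, our softmax model

A shared linear layer gives every sequence position one score, and the softmax is taken over the positions, so each target is the index of the correct position.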

In [63]:
class SoftmaxRegression(nn.Module):
    """Score each vector along the second-to-last axis and softmax the scores.

    A single linear layer maps every ``hidden_size``-dimensional vector to one
    scalar; the trailing singleton dimension is squeezed away and the softmax
    (or cross-entropy) is computed over the last dimension of what remains.

    Args:
        seq_len: Stored on the instance; no computation in this class reads it.
        hidden_size: Size of the last dimension of the input tensors.
    """

    def __init__(self, seq_len, hidden_size):
        super(SoftmaxRegression, self).__init__()
        self.seq_len = seq_len
        self.hidden_size = hidden_size
        # One output unit: each hidden vector is reduced to a single score.
        self.W = torch.nn.Linear(self.hidden_size, 1, bias=True)

    def forward(self, input):
        """Return unnormalised scores.

        Args:
            input: Tensor whose last dimension has size ``hidden_size``.

        Returns:
            Tensor with the last dimension removed, i.e. ``(..., hidden_size)``
            becomes ``(...)``.
        """
        return self.W(input).squeeze(-1)

    def train_forward(self, input, target):
        """Return the cross-entropy loss between the scores and ``target``.

        Args:
            input: Tensor accepted by :meth:`forward`.
            target: Class indices; entries equal to ``-100`` are ignored by
                the loss. Assumed to hold one index per row of scores, as in
                the ``(batch, seq_len)`` / ``(batch,)`` usage of this notebook.

        Returns:
            Scalar loss tensor (mean over the non-ignored targets).
        """
        scores = self.forward(input)
        # Functional form: same result as nn.CrossEntropyLoss, without
        # constructing a new loss module on every training step.
        return F.cross_entropy(scores, target, ignore_index=-100)

    def predict_proba(self, input):
        """Return the softmax of the scores over the last dimension."""
        return F.softmax(self.forward(input), dim=-1)

    def predict(self, input):
        """Return the index of the most probable class for each row.

        The result is a tensor of integer indices, which cannot carry
        gradients, so the computation runs without building an autograd graph.
        """
        with torch.no_grad():
            _, preds = self.predict_proba(input).max(-1)
        return preds

## And train it

In [64]:
# Settings consumed by the training cell.
lr: float = 1e-3            # learning rate handed to the optimizer
device: str = "cpu"         # else "cuda:0"
max_epoch: int = 2000       # number of passes over the examples
print_every: int = 100      # print the loss once every this many epochs
max_grad_norm: float = 1.0  # threshold used for gradient-norm clipping

In [65]:
torch.manual_seed(0)  # seed the RNG before the model's parameters are created

model = SoftmaxRegression(seq_len, hidden_size)
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=float(lr))

# Put the training tensors on the same device as the model, once, up front.
# Feeding the original `data`/`targets` straight into the model raises a
# device-mismatch error as soon as `device` is switched away from "cpu".
train_data = data.to(device)
train_targets = targets.to(device)

model.train()  # nothing in the loop leaves training mode, so set it once

epoch = 0  # keeps `epoch` defined even when max_epoch is 0
for epoch in range(1, max_epoch + 1):
    loss = 0  # placeholder so the report below works if there are no examples

    # One optimizer step per example (effective batch size of 1).
    for i in range(batch_size):
        example = train_data[i].unsqueeze(0)    # (1, seq_len, hidden_size)
        target = train_targets[i].unsqueeze(0)  # (1,)

        loss = model.train_forward(example, target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()

    # NOTE: `loss` is the loss of the last example processed in this epoch,
    # not a sum or an average over the epoch.
    if epoch % print_every == 0:
        print("epoch {} loss {}".format(epoch, loss))

epoch 100 loss 0.9858344793319702
epoch 200 loss 0.7027537226676941
epoch 300 loss 0.5020290613174438
epoch 400 loss 0.3692546784877777
epoch 500 loss 0.2781364917755127
epoch 600 loss 0.21403783559799194
epoch 700 loss 0.16786591708660126
epoch 800 loss 0.13383346796035767
epoch 900 loss 0.10820576548576355
epoch 1000 loss 0.08852928876876831
epoch 1100 loss 0.07316025346517563
epoch 1200 loss 0.060973044484853745
epoch 1300 loss 0.05118098109960556
epoch 1400 loss 0.04322313144803047
epoch 1500 loss 0.03669150546193123
epoch 1600 loss 0.03128419816493988
epoch 1700 loss 0.026774315163493156
epoch 1800 loss 0.022988714277744293
epoch 1900 loss 0.019793258979916573
epoch 2000 loss 0.017082812264561653


In [67]:
# Per-example probabilities; the input is moved to the model's device first so
# the cell also works when `device` is not "cpu".
model.predict_proba(data.to(device))

tensor([[9.6258e-01, 3.6998e-02, 8.5927e-06, 4.0928e-04],
        [2.9356e-05, 9.8474e-01, 8.4081e-03, 6.8241e-03],
        [4.6515e-04, 9.9967e-03, 6.4669e-03, 9.8307e-01]],
       grad_fn=<SoftmaxBackward>)

In [66]:
# Predicted class index per example; the input is moved to the model's device
# first so the cell also works when `device` is not "cpu".
model.predict(data.to(device))

tensor([0, 1, 3])