In [None]:
import time

def train(dataloader, model, loss_fn, optimizer):
    """Run one optimisation pass over ``dataloader`` and report the loss.

    The model is switched to training mode, then for every ``(X, y)`` batch
    the gradients are cleared, the loss is back-propagated and the optimizer
    takes one step.

    Args:
        dataloader: Iterable of ``(inputs, targets)`` batches that also exposes
            a sized ``.dataset`` attribute.
        model: Callable network with a ``train()`` method.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: Optimizer providing ``zero_grad()`` and ``step()``.

    Returns:
        The sum of the per-batch loss values divided by
        ``len(dataloader.dataset)``.  This is a per-item average only when
        ``loss_fn`` sums (rather than averages) over the batch.
    """
    model.train()

    running_loss = 0
    for inputs, targets in dataloader:
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()
        # .item() detaches the scalar so no autograd graph is kept alive.
        running_loss += loss.item()

    return running_loss / len(dataloader.dataset)

def test(dataloader, model, loss_fn, *, num_crops=10):
    """Evaluate ``model`` with multi-crop averaging; return loss and accuracies.

    Every sample is expected to contribute ``num_crops`` consecutive rows to a
    batch.  The softmax probabilities of those rows are averaged into a single
    per-sample prediction, which is scored for top-1 and top-5 accuracy.

    The number of classes is read from the model output and the number of
    samples is counted while iterating, so any batch size (including a smaller
    final batch) and any dataset size are handled.

    Args:
        dataloader: Iterable of ``(X, y)`` batches.  ``y`` must have the same
            ``(rows, num_classes)`` layout as the model output (e.g. one-hot or
            probability targets) with a floating dtype, because it is averaged
            over the crops before taking the argmax.
        model: Callable network with an ``eval()`` method; ``model(X)`` must
            return logits of shape ``(rows, num_classes)``.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar tensor.
        num_crops: Number of consecutive rows belonging to one sample.

    Returns:
        ``(avg_loss, top1_accuracy, top5_accuracy)`` where ``avg_loss`` is the
        summed batch losses divided by the number of rows seen, and the
        accuracies are fractions of the samples seen.  All three are ``0.0``
        when the dataloader yields nothing.

    Raises:
        RuntimeError: If a batch's row count is not a multiple of
            ``num_crops`` (raised by the reshape).
    """
    model.eval()

    loss_sum = 0.0
    top1_correct = 0
    top5_correct = 0
    total_rows = 0
    total_samples = 0

    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            loss_sum += loss_fn(pred, y).item()

            num_classes = pred.shape[-1]

            # Average class probabilities over the crops of each sample.
            probs = torch.softmax(pred, dim=1)
            mean_probs = probs.reshape(-1, num_crops, num_classes).mean(dim=1)

            # Collapse the per-crop targets the same way to get one label.
            mean_targets = y.reshape(-1, num_crops, num_classes).mean(dim=1)
            target_class = torch.argmax(mean_targets, dim=1)

            top1_pred = torch.argmax(mean_probs, dim=1)
            top1_correct += (top1_pred == target_class).sum().item()

            # topk cannot request more entries than there are classes.
            k = min(5, num_classes)
            _, top5_pred = torch.topk(mean_probs, k, dim=1)
            top5_correct += torch.eq(top5_pred, target_class.view(-1, 1)).sum().item()

            total_rows += pred.shape[0]
            total_samples += mean_probs.shape[0]

    if total_samples == 0:
        return 0.0, 0.0, 0.0

    avg_loss = loss_sum / total_rows
    top1_accuracy = top1_correct / total_samples
    top5_accuracy = top5_correct / total_samples
    return avg_loss, top1_accuracy, top5_accuracy


# The name matches pytest's default ``test*`` function pattern; opt out so a
# test module that imports this helper does not try to collect it.
test.__test__ = False

# Training driver: builds the network, trains it for `epochs` epochs and
# records per-epoch timings, losses and accuracies in the lists below.
epochs = 1

# Per-epoch history; entries are appended in epoch order.
summary_top1 = []
summary_top5 = []
epoch_durations = []
train_losses = []
test_losses = []

# `model` initially names the network class (or factory) and is rebound to the
# built instance inside the loop.  Resolve the constructor once up front so a
# further outer iteration, or re-running this cell, builds a fresh network
# instead of calling the existing instance with no arguments.
model_factory = type(model) if isinstance(model, nn.Module) else model

for i in range(1):
    model = model_factory().to("cuda")  # if you want to apply new model, declare new object. Do not change the name "model"
    # Summed (not averaged) batch loss; train() and test() do the normalising.
    loss_fn = nn.CrossEntropyLoss(reduction='sum')
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    # Multiplies the learning rate by 0.95 after every 30 scheduler steps;
    # the scheduler is stepped once per epoch below.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.95)

    # Best accuracies observed so far in this run.
    top1_accuracy = 0
    top5_accuracy = 0

    for epoch in range(epochs):
        start_time = time.time()

        train_loss = train(train_dataloader, model, loss_fn, optimizer)
        test_loss, top1_temp, top5_temp = test(test_dataloader, model, loss_fn)

        scheduler.step()

        # Wall-clock time of one train + evaluate cycle.
        epoch_duration = time.time() - start_time

        epoch_durations.append(epoch_duration)
        train_losses.append(train_loss)
        test_losses.append(test_loss)
        summary_top1.append(top1_temp)
        summary_top5.append(top5_temp)

        top1_accuracy = max(top1_accuracy, top1_temp)
        top5_accuracy = max(top5_accuracy, top5_temp)

print("Done!")
