In [None]:
import torch
from torch import nn
from torch.distributions.bernoulli import Bernoulli

In [None]:
from torchvision.datasets import CIFAR10
from torch import optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from time import sleep
from torchvision.transforms import ToTensor, Normalize, Compose
from torch.optim.lr_scheduler import MultiStepLR
from stochastic_resnet import stochastic_resnet34

In [None]:
# Commonly used CIFAR-10 training-set per-channel (RGB) mean/std, applied after ToTensor.
# The previous Normalize(0, 1) subtracted 0 and divided by 1, i.e. it did nothing.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)
cifar10_transform = Compose([ToTensor(), Normalize(CIFAR10_MEAN, CIFAR10_STD)])

# NOTE: despite the ".pt" suffix, `root` is the directory CIFAR10 downloads into; the paths
# are kept as-is so already-downloaded copies are reused. Train and test share one transform.
train_data = CIFAR10(root='./cifar10_train.pt', train=True, download=True, transform=cifar10_transform)
test_data = CIFAR10(root='./cifar10_test.pt', train=False, download=True, transform=cifar10_transform)

Files already downloaded and verified
Files already downloaded and verified


In [None]:
# Seed torch's global RNG before building the model so weight init and loader shuffling are
# reproducible (and presumably any torch-RNG-based randomness inside the model — TODO confirm).
SEED = 0
torch.manual_seed(SEED)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = stochastic_resnet34(in_channels=3, num_classes=10).to(device)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4, nesterov=True)
loss_fn = nn.CrossEntropyLoss()
BATCH_SIZE = 128
# LR is multiplied by 0.1 at milestones 250 and 375; the training loop calls scheduler.step()
# once per epoch, so the milestones are epoch numbers.
scheduler = MultiStepLR(optimizer, milestones=[250, 375], gamma=0.1)

train_loader = DataLoader(
    dataset=train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,  # reshuffle every epoch; a fixed sample order degrades SGD
    num_workers=1,
    pin_memory=(device == 'cuda'),  # faster host->GPU copies; no effect on CPU
)

In [None]:
# Train for 500 epochs, stepping the LR scheduler once per epoch.
# model.train() is set every epoch: the evaluation cell calls model.eval(), so re-running this
# cell afterwards would otherwise train with eval-mode behaviour (BatchNorm, and presumably the
# stochastic-depth layers).
for epoch in range(1, 501):
    model.train()
    running_loss = 0.0
    running_correct = 0
    seen = 0
    with tqdm(train_loader, unit="batch", desc=f"Epoch {epoch}") as tepoch:
        for data, target in tepoch:
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            optimizer.step()

            # Use the actual batch size: the last batch can be smaller than BATCH_SIZE when the
            # dataset size is not a multiple of it, so dividing by BATCH_SIZE under-reports accuracy.
            batch_size = target.size(0)
            running_loss += loss.item() * batch_size
            running_correct += (output.argmax(dim=1) == target).sum().item()
            seen += batch_size

            # Running averages over the epoch so far (less noisy than per-batch values).
            tepoch.set_postfix(loss=running_loss / seen, accuracy=100. * running_correct / seen)
    scheduler.step()

In [None]:
# Evaluation loader. Batched like training: the DataLoader default batch_size=1 made evaluation
# needlessly slow. Order is irrelevant for accuracy, so no shuffling.
test_loader = DataLoader(
    dataset=test_data,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=1,
)

In [None]:
# Top-1 accuracy on the test split. model.eval() switches layers to inference behaviour;
# torch.no_grad() disables autograd tracking, so no `.data` detaching is needed.
model.eval()
total = 0
correct = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        predictions = outputs.argmax(dim=1)  # index of the highest-scoring class
        total += labels.size(0)
        correct += (predictions == labels).sum().item()

print('Accuracy == {} %'.format(100 * correct / total))