In [1]:
import torch
import torchvision
import torchvision.transforms as tt
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
DATA_DIR = '../../data'

# (mean, std) per RGB channel, unpacked into tt.Normalize(mean, std) for both splits.
stats = ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

# Training split: reflect-padded random crop + horizontal flip, then normalisation.
train_tfms = tt.Compose([
    tt.RandomCrop(32, padding=4, padding_mode="reflect"),
    tt.RandomHorizontalFlip(),
    tt.ToTensor(),
    tt.Normalize(*stats),
])
# Test split: no augmentation, only tensor conversion + normalisation.
test_tfms = tt.Compose([tt.ToTensor(), tt.Normalize(*stats)])

# download=False: the CIFAR-10 files must already be present under DATA_DIR.
train_set = torchvision.datasets.CIFAR10(DATA_DIR, train=True, download=False, transform=train_tfms)
test_set = torchvision.datasets.CIFAR10(DATA_DIR, train=False, download=False, transform=test_tfms)

In [3]:
BATCH_SIZE = 64
# Pinned host memory only helps host->GPU transfers; skip it on CPU-only machines.
PIN_MEMORY = torch.cuda.is_available()

train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=BATCH_SIZE, shuffle=True, pin_memory=PIN_MEMORY
)
# Evaluation does not need shuffling; a fixed order keeps the inference cell reproducible.
test_loader = torch.utils.data.DataLoader(
    test_set, batch_size=BATCH_SIZE, shuffle=False, pin_memory=PIN_MEMORY
)

In [4]:
def conv_2d(ni, nf, stride=1, ks=3):
    """Bias-free square convolution padded by ``ks // 2``.

    For odd ``ks`` and ``stride=1`` this keeps the spatial size unchanged.

    Args:
        ni: number of input channels.
        nf: number of output channels.
        stride: convolution stride.
        ks: kernel size (square).

    Returns:
        An ``nn.Conv2d`` with ``bias=False``.
    """
    pad = ks // 2
    return nn.Conv2d(ni, nf, ks, stride=stride, padding=pad, bias=False)


def bn_relu_conv(ni, nf):
    """Pre-activation unit: BatchNorm2d -> in-place ReLU -> ``conv_2d(ni, nf)``.

    Args:
        ni: channels normalised by the BatchNorm and fed to the convolution.
        nf: channels produced by the convolution (3x3, stride 1 by default).

    Returns:
        An ``nn.Sequential`` holding the three layers in that order.
    """
    layers = [
        nn.BatchNorm2d(ni),
        nn.ReLU(inplace=True),
        conv_2d(ni, nf),
    ]
    return nn.Sequential(*layers)


In [5]:
class ResidualBlock(nn.Module):
    """Pre-activation residual block.

    The input is passed through BN -> ReLU once; that activated tensor feeds
    both the shortcut and the main path. The main path is
    ``conv_2d(ni, nf, stride)`` followed by ``bn_relu_conv(nf, nf)``, scaled by
    ``res_scale`` before the residual add. The shortcut is the identity when
    the shape is unchanged, otherwise a 1x1 ``conv_2d`` projection.

    Args:
        ni: input channels.
        nf: output channels.
        stride: stride of the first conv (and of the projection shortcut).
        res_scale: multiplier applied to the main path before the add
            (default 0.2, the value previously hard-coded).
    """

    def __init__(self, ni, nf, stride=1, res_scale=0.2):
        super().__init__()
        self.bn = nn.BatchNorm2d(ni)
        self.conv1 = conv_2d(ni, nf, stride)
        self.conv2 = bn_relu_conv(nf, nf)
        self.res_scale = res_scale
        # A projection is required whenever the main path changes the shape:
        # a channel change OR a spatial downsample (stride != 1). Otherwise use
        # nn.Identity, a registered module (unlike a lambda), so the block stays
        # picklable and appears in the model repr.
        if ni != nf or stride != 1:
            self.shortcut = conv_2d(ni, nf, stride, 1)
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        x = F.relu(self.bn(x), inplace=True)
        # The shortcut branches from the pre-activated input.
        r = self.shortcut(x)
        x = self.conv2(self.conv1(x)) * self.res_scale
        return x.add_(r)

In [6]:
def make_group(N, ni, nf, stride):
    """Build the residual blocks of one WideResNet stage as a plain list.

    The first block maps ``ni`` -> ``nf`` with the given ``stride``; the other
    ``N - 1`` blocks map ``nf`` -> ``nf`` at stride 1. At least one block is
    always returned, even when ``N < 1``.
    """
    blocks = [ResidualBlock(ni, nf, stride)]
    blocks.extend(ResidualBlock(nf, nf) for _ in range(N - 1))
    return blocks

In [7]:
class WideResNet(nn.Module):
    """Wide ResNet: stem conv, ``n_groups`` residual stages, pooled linear head.

    Stage ``i`` holds ``N`` residual blocks with ``n_start * 2**i * k`` output
    channels; every stage after the first uses stride 2 in its first block.

    Args:
        n_groups: number of residual stages.
        N: residual blocks per stage.
        n_classes: number of output logits.
        k: widening factor.
        n_start: channels produced by the 3-input-channel stem convolution.
    """

    def __init__(self, n_groups, N, n_classes, k=1, n_start=16):
        super().__init__()
        layers = [conv_2d(3, n_start)]
        n_channels = [n_start]
        for i in range(n_groups):
            n_channels.append(n_start * (2 ** i) * k)
            stride = 2 if i > 0 else 1
            layers += make_group(N, n_channels[i], n_channels[i + 1], stride)
        # The head must match the width of the LAST stage, for any n_groups.
        n_out = n_channels[-1]
        layers += [
            nn.BatchNorm2d(n_out),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(n_out, n_classes),
        ]
        self.features = nn.Sequential(*layers)

    def forward(self, x):
        return self.features(x)

In [8]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

In [9]:
# Seed PyTorch's global RNG so weight initialisation (and the DataLoader
# shuffle order, since no explicit generator is passed) is reproducible.
SEED = 42
torch.manual_seed(SEED)

# 3 stages x 3 blocks, 10 CIFAR-10 classes, widening factor 6.
model = WideResNet(n_groups=3, N=3, n_classes=10, k=6)
model.to(device)

WideResNet(
  (features): Sequential(
    (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): ResidualBlock(
      (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): Conv2d(16, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (conv2): Sequential(
        (0): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): ReLU(inplace=True)
        (2): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (shortcut): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
    )
    (2): ResidualBlock(
      (bn): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (conv2): Sequential(
        (0): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
 

In [10]:
LEARNING_RATE = 0.005
# Adam over all of the model's parameters.
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

In [11]:
num_epoch = 5

for epoch in range(num_epoch):
    # Training-mode BatchNorm (batch stats + running-stat updates), even if an
    # evaluation cell previously switched the model to eval mode.
    model.train()
    # Accumulate on-device so there is no host sync (.item()) on every batch.
    running_loss = torch.zeros((), device=device)
    n_seen = 0
    for image, label in train_loader:
        # non_blocking lets the copy overlap compute when the loader pins memory.
        image = image.to(device, non_blocking=True)
        label = label.to(device, non_blocking=True)
        optimizer.zero_grad()
        out = model(image)
        loss = nn.functional.cross_entropy(out, label)
        loss.backward()
        optimizer.step()
        running_loss += loss.detach() * label.size(0)
        n_seen += label.size(0)
    # Sample-weighted mean loss over the whole epoch, not just the last batch.
    print(f"Epoch: {epoch} Loss: {running_loss.item() / max(n_seen, 1)}")

KeyboardInterrupt: 

In [12]:
def accuracy(out, labels):
    """Fraction of rows whose arg-max over dim 1 of ``out`` equals ``labels``.

    Args:
        out: scores with the class dimension at dim 1.
        labels: integer class indices, one per row of ``out``.

    Returns:
        A 0-d float tensor in [0, 1] (NaN when ``labels`` is empty).
    """
    predictions = out.argmax(dim=1)
    n_correct = (predictions == labels).sum()
    return n_correct / len(labels)

In [20]:
# Presumably meant to release GPU memory PyTorch has reserved but no longer
# uses before the inference cell (cf. the "reserved" vs "allocated" figures in
# that cell's OOM message). NOTE(review): the OOM recorded for the next cell
# shows this was not sufficient on its own; memory still referenced by live
# tensors is not released by this call - verify.
torch.cuda.empty_cache()

In [21]:
# Inference only. eval() switches BatchNorm to its running statistics, and
# no_grad() stops autograd from keeping activations alive (the likely cause
# of the CUDA OOM previously recorded for this cell).
model.eval()
image, label = next(iter(test_loader))
with torch.no_grad():
    preds = model(image.to(device))

# Undo Normalize(mean, std) so the image renders with its real colours;
# imshow expects an (H, W, C) float image in [0, 1].
mean = torch.tensor(stats[0]).view(3, 1, 1)
std = torch.tensor(stats[1]).view(3, 1, 1)
img = (image[0] * std + mean).clamp(0, 1).permute(1, 2, 0)

# preds[0] is the 1-D logit vector for the first image, so reduce over dim 0.
probs = F.softmax(preds[0], dim=0)
confidence, predicted = torch.max(probs, dim=0)
true_name = test_set.classes[label[0].item()]
pred_name = test_set.classes[predicted.item()]

fig, ax = plt.subplots()
ax.imshow(img)
ax.set_title(f"true: {true_name}  pred: {pred_name} ({confidence.item():.2f})")
ax.axis("off")
plt.show()

confidence.item(), predicted.item()

OutOfMemoryError: CUDA out of memory. Tried to allocate 24.00 MiB (GPU 0; 2.00 GiB total capacity; 1.69 GiB already allocated; 0 bytes free; 1.73 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF