<a href="https://colab.research.google.com/github/paolo-peretti/conv/blob/main/capsule.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import subprocess
import sys

# Colab checkout location of the CapsNet reference implementation. A single
# constant keeps the clone target, the import path and the working directory
# in agreement (the directory name is case-sensitive: "CapsNet-pytorch").
REPO_URL = 'https://github.com/adambielski/CapsNet-pytorch'
REPO_DIR = '/content/CapsNet-pytorch'

# Clone only when the checkout is missing, so re-running this cell does not
# fail with "destination path ... already exists".
if not os.path.isdir(REPO_DIR):
    subprocess.run(['git', 'clone', REPO_URL, REPO_DIR], check=True)

# Make the repository's modules importable and run from inside the checkout.
if REPO_DIR not in sys.path:
    sys.path.append(REPO_DIR)
os.chdir(REPO_DIR)

fatal: destination path 'CapsNet-pytorch' already exists and is not an empty directory.


In [None]:
import torch

# Release any cached GPU memory, then choose where tensors and the model
# will live: the GPU when CUDA is available, otherwise the CPU.
torch.cuda.empty_cache()
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

from torch.autograd import Variable
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

In [None]:
from net import CapsNetWithReconstruction, CapsNet, ReconstructionNet

In [None]:
# Build the capsule network and its reconstruction decoder, wrap them in one
# module, and move that module to the selected device.
# NOTE(review): the constructor arguments are positional; presumably
# CapsNet(3, 10) is (routing iterations, classes) and ReconstructionNet(16, 10)
# is (capsule dimension, classes) -- confirm against net.py.
capsnet = CapsNet(3, 10)
reconstructionnet = ReconstructionNet(16, 10)
model = CapsNetWithReconstruction(capsnet, reconstructionnet).to(device)
model  # last expression of the cell: displays the module summary

CapsNetWithReconstruction(
  (capsnet): CapsNet(
    (conv1): Conv2d(1, 256, kernel_size=(9, 9), stride=(1, 1))
    (primaryCaps): PrimaryCapsLayer(
      (conv): Conv2d(256, 256, kernel_size=(9, 9), stride=(2, 2))
    )
    (digitCaps): CapsLayer(
      (routing_module): AgreementRouting()
    )
  )
  (reconstruction_net): ReconstructionNet(
    (fc1): Linear(in_features=160, out_features=512, bias=True)
    (fc2): Linear(in_features=512, out_features=1024, bias=True)
    (fc3): Linear(in_features=1024, out_features=784, bias=True)
  )
)

In [None]:
# MNIST test split as tensors, with no augmentation; downloaded on first use.
dataset = MNIST(
    root='./data',
    train=False,
    transform=ToTensor(),
    download=True,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
from torch import nn
# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

# MNIST images are single-channel, so Normalize needs exactly one mean and one
# std. Three-channel statistics cannot be broadcast in place onto a 1xHxW
# tensor and raise a RuntimeError as soon as the first batch is loaded.
stats = ((0.5,), (0.5,))

# Training-time augmentation: random shifts via pad-and-crop back to MNIST's
# native 28x28 size, so the tensor shape fed to the network is unchanged.
# Horizontal flips are deliberately not used: a mirrored digit is not a valid
# example of its class.
transform = transforms.Compose(
    [transforms.RandomCrop(28, padding=4, padding_mode='reflect'),
     transforms.ToTensor(),
     transforms.Normalize(*stats, inplace=True)])

# Validation gets the same normalization but no random augmentation, so the
# reported metrics are deterministic and comparable between epochs.
validation_transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize(*stats, inplace=True)])


# Create datasets for training & validation, download if necessary
training_set = MNIST('./data', train=True, transform=transform, download=True)
validation_set = MNIST('./data', train=False, transform=validation_transform, download=True)


batch_size = 10


# Create data loaders for our datasets; shuffle for training, not for validation
training_loader = torch.utils.data.DataLoader(
    training_set, batch_size=batch_size, shuffle=True,
    num_workers=2, pin_memory=True)
validation_loader = torch.utils.data.DataLoader(
    validation_set, batch_size=batch_size, shuffle=False,
    num_workers=2, pin_memory=True)


# Class labels
classes = ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9')

# Report split sizes
print('Training set has {} instances'.format(len(training_set)))
print('Validation set has {} instances'.format(len(validation_set)))

Training set has 60000 instances
Validation set has 10000 instances


In [None]:
# Classification loss applied to the model output and the integer labels.
loss_fn = torch.nn.CrossEntropyLoss()

# Adam with a small weight decay is the optimizer actually used for training.
# An SGD optimizer used to be created here and was immediately overwritten by
# this assignment, so it never had any effect and has been dropped.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)

In [None]:
def accuracy(outputs, labels):
    """Return the fraction of rows whose highest-scoring column equals the label.

    Args:
        outputs: 2-D tensor of scores; the prediction for each row is the
            index of its maximum along dimension 1.
        labels: 1-D tensor of target indices, one per row of ``outputs``.

    Returns:
        A zero-dimensional tensor holding ``correct / number_of_rows``.
    """
    predicted = outputs.max(dim=1).indices
    n_correct = (predicted == labels).sum().item()
    return torch.tensor(n_correct / len(predicted))

In [None]:
@torch.no_grad()
def evaluate(model, validation_loader):
    """Compute the mean per-batch loss and accuracy of ``model`` on a loader.

    The model is switched to evaluation mode and the whole pass runs with
    gradient tracking disabled (via the decorator). Relies on the
    notebook-level ``device``, ``loss_fn`` and ``accuracy``.

    NOTE(review): assumes ``model(inputs)`` accepts a single batch tensor and
    returns class scores that ``loss_fn`` can consume -- confirm against the
    ``forward`` signature of the model in net.py.

    Args:
        model: the network to evaluate.
        validation_loader: iterable yielding ``(inputs, labels)`` batches.

    Returns:
        ``(avg_vloss, avg_vaccuracy)``: per-batch loss and accuracy averaged
        over the number of batches.

    Raises:
        ValueError: if ``validation_loader`` yields no batches.
    """
    model.eval()

    running_vloss = 0.0
    running_vacc = 0.0
    num_batches = 0

    for vinputs, vlabels in validation_loader:
        vinputs = vinputs.to(device)
        vlabels = vlabels.to(device)

        voutputs = model(vinputs)

        running_vloss += loss_fn(voutputs, vlabels).item()
        running_vacc += accuracy(voutputs, vlabels).item()
        num_batches += 1

    # An empty loader has no meaningful average; fail with a clear message
    # instead of an unbound-variable error on the loop index.
    if num_batches == 0:
        raise ValueError('validation_loader yielded no batches')

    avg_vloss = running_vloss / num_batches
    avg_vaccuracy = running_vacc / num_batches

    print('  avg_vloss: {}'.format(avg_vloss))
    print('  avg_vaccuracy: {}'.format(avg_vaccuracy))

    return avg_vloss, avg_vaccuracy

In [None]:
def _report_window(batch_number, loss_sum, accuracy_sum, batch_count):
    """Print and return the mean loss and accuracy of one reporting window."""
    mean_loss = loss_sum / batch_count
    mean_accuracy = accuracy_sum / batch_count
    print('  batch {} loss: {}'.format(batch_number, mean_loss))
    print('  batch {} accuracy: {}'.format(batch_number, mean_accuracy))
    print('-----------------------------------------------------------------------')
    return mean_loss, mean_accuracy


def train_one_epoch(epoch_index, report_every=1000):
    """Run one optimization pass over ``training_loader``.

    Uses the notebook-level ``model``, ``optimizer``, ``loss_fn``, ``device``,
    ``training_loader`` and ``accuracy``. Progress is printed every
    ``report_every`` batches, and once more at the end of the epoch for any
    trailing batches that did not fill a complete window.

    NOTE(review): assumes ``model(inputs)`` accepts a single batch tensor and
    returns class scores that ``loss_fn`` can consume -- confirm against the
    ``forward`` signature of the model in net.py.

    Args:
        epoch_index: index of the current epoch (currently unused; kept for
            callers that pass it).
        report_every: number of batches per reporting window; expected to be
            a positive integer.

    Returns:
        ``(last_loss, last_accuracy)``: mean per-batch loss and accuracy of
        the most recently reported window. When the number of batches is a
        multiple of ``report_every`` this is the last full window.

    Raises:
        ValueError: if ``training_loader`` yields no batches.
    """
    model.train()

    running_loss = 0.
    running_accuracy = 0.
    window_batches = 0
    last_loss = None
    last_accuracy = None
    batch_number = 0

    for batch_number, (inputs, labels) in enumerate(training_loader, start=1):
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Gradients accumulate by default, so clear them for every batch.
        optimizer.zero_grad()

        # Forward pass, loss, backward pass, parameter update.
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        running_accuracy += accuracy(outputs, labels).item()
        window_batches += 1

        if window_batches == report_every:
            last_loss, last_accuracy = _report_window(
                batch_number, running_loss, running_accuracy, window_batches)
            running_loss = 0.
            running_accuracy = 0.
            window_batches = 0

    if batch_number == 0:
        raise ValueError('training_loader yielded no batches')

    # Flush a trailing partial window so short epochs still produce a result
    # instead of leaving the return values unset.
    if window_batches:
        last_loss, last_accuracy = _report_window(
            batch_number, running_loss, running_accuracy, window_batches)

    return last_loss, last_accuracy

In [None]:
def train_data(EPOCHS, patience=2):
    """Train the notebook-level ``model`` with early stopping on validation loss.

    Each epoch calls ``train_one_epoch`` and then ``evaluate`` on
    ``validation_loader``. The lowest validation loss seen so far is tracked;
    every epoch whose validation loss is higher than that best value counts
    as a strike, and any other epoch records a new best and clears the
    strikes. Training stops once ``patience`` consecutive strikes accumulate.

    Args:
        EPOCHS: maximum number of epochs to run.
        patience: number of consecutive non-improving epochs tolerated before
            stopping early.

    Returns:
        The notebook-level ``model``. Its parameters have ``requires_grad``
        switched off, because that is done before every validation pass.
    """
    best_vloss = 1_000_000.
    trigger_times = 0

    for epoch in range(EPOCHS):

        print('\n\nEPOCH {}:'.format(epoch + 1))

        # Parameter gradients are enabled for the training pass and disabled
        # again before validation.
        model.requires_grad_(True)
        train_one_epoch(epoch)

        model.requires_grad_(False)
        vloss, _ = evaluate(model, validation_loader)

        # Early stopping: compare against the best *validation* loss so far.
        if vloss > best_vloss:
            trigger_times += 1
            print('Trigger Times:', trigger_times)

            if trigger_times >= patience:
                print('Early stopping!\nStart to test process.')
                return model
        else:
            # No regression: remember the new best and clear the strikes.
            print('trigger times: 0')
            trigger_times = 0
            best_vloss = vloss

    return model

In [None]:
# Train for at most five epochs; train_data returns the trained model.
model = train_data(EPOCHS=5)




EPOCH 1:


RuntimeError: ignored