<a href="https://colab.research.google.com/github/szabeenglobal/ConvolutionalNeuralNetworkWithTensorflow/blob/master/CIFAR10-classification-in-pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
! pip install pytorch-lightning



In [None]:

import pytorch_lightning as pl
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch
import torchvision as tv

class CIFARclassification(pl.LightningModule):
    """Lightning wrapper that trains an image classifier with cross-entropy loss.

    Args:
        model: ``torch.nn.Module`` mapping a batch of images to class scores.
            ``F.cross_entropy`` applies log-softmax internally, so the model
            should return raw logits (no trailing softmax).
        lr: Learning rate for the Adam optimizer (default 0.02, the value that
            was previously hard-coded).
    """

    def __init__(self, model, lr=0.02):
        super().__init__()
        self.model = model
        self.lr = lr

    def forward(self, batch):
        """Return the wrapped model's output for ``batch``."""
        return self.model(batch)

    def _shared_step(self, batch):
        """Compute the cross-entropy loss for one ``(features, labels)`` batch."""
        features, labels = batch
        logits = self(features)
        return F.cross_entropy(logits, labels)

    def training_step(self, batch, _):
        cost = self._shared_step(batch)
        # Lightning only sends values to the logger when they go through self.log.
        self.log("train_loss", cost)
        return cost

    def validation_step(self, batch, _):
        cost = self._shared_step(batch)
        # The returned dict is not logged automatically; self.log makes val_loss
        # visible to the logger and to callbacks that monitor it.
        self.log("val_loss", cost, prog_bar=True)
        return {"val_loss": cost}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)


In [None]:
import torchvision


class CIFAR10DataModule(pl.LightningDataModule):
    """Provide CIFAR-10 train/validation/test DataLoaders for Lightning.

    The CIFAR-10 training set is randomly split into a training subset and a
    validation subset of ``val_size`` images. The official test split is used
    for testing. Every image is converted with ``ToTensor()``.

    Args:
        data_dir: Directory where torchvision stores/downloads CIFAR-10.
        batch_size: Batch size used by every DataLoader.
        val_size: Number of training images held out for validation
            (default 10000, reproducing the original 40000/10000 split).
        split_seed: Optional seed for the train/validation split. ``None``
            (default) draws from torch's global RNG, e.g. as seeded by
            ``pl.seed_everything``.
        num_workers: Worker processes per DataLoader (default 0, which is
            DataLoader's own default).
    """

    def __init__(self, data_dir: str, batch_size, val_size=10000, split_seed=None, num_workers=0):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.val_size = val_size
        self.split_seed = split_seed
        self.num_workers = num_workers
        self.transformation = torchvision.transforms.ToTensor()

    def prepare_data(self):
        """Download both CIFAR-10 splits into ``data_dir`` if they are missing."""
        # Lightning runs prepare_data in a single process (per node) before setup(),
        # so the download is not performed concurrently by every distributed rank.
        torchvision.datasets.CIFAR10(self.data_dir, train=True, download=True)
        torchvision.datasets.CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build the test dataset and the random train/validation split."""
        # Once prepare_data has fetched the files, download=True only verifies them.
        # It is kept so setup() still works when called directly.
        self.cifar10_test = torchvision.datasets.CIFAR10(
            self.data_dir, transform=self.transformation, train=False, download=True
        )
        full_train = torchvision.datasets.CIFAR10(
            self.data_dir, transform=self.transformation, train=True, download=True
        )
        split_kwargs = {}
        if self.split_seed is not None:
            # A dedicated generator makes the split independent of other RNG consumers.
            split_kwargs["generator"] = torch.Generator().manual_seed(self.split_seed)
        self.cifar10_train, self.cifar10_val = torch.utils.data.random_split(
            full_train, [len(full_train) - self.val_size, self.val_size], **split_kwargs
        )

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self.cifar10_train,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            self.cifar10_val, batch_size=self.batch_size, num_workers=self.num_workers
        )

    def test_dataloader(self):
        return torch.utils.data.DataLoader(
            self.cifar10_test, batch_size=self.batch_size, num_workers=self.num_workers
        )

In [None]:
 class ConvNet(torch.nn.Module):
  def __init__(self, in_channels, classes):
    super().__init__()
    self.in_channels = in_channels
    self.classes = classes
    self.features = torch.nn.Sequential(
        # torchvision.models.resnet18(num_classes=self.classes),
        torch.nn.Conv2d(in_channels=self.in_channels,
                        out_channels=16, kernel_size=3),
        torch.nn.ReLU(),
        torch.nn.Conv2d(in_channels=16,
                        out_channels=32,
                        kernel_size=3),
        torch.nn.ReLU(),
        torch.nn.Conv2d(in_channels=32,
                        out_channels=64,
                        kernel_size=3),
        torch.nn.ReLU(),
        torch.nn.MaxPool2d(2),
        torch.nn.Conv2d(in_channels=64,
                        out_channels=64,
                        kernel_size=3),
        torch.nn.ReLU(),
        torch.nn.Conv2d(in_channels=64,
                        out_channels=64,
                        kernel_size=3),
        torch.nn.ReLU(),
        torch.nn.MaxPool2d(2),
        torch.nn.Conv2d(in_channels=64,
                        out_channels=32,
                        kernel_size=3),
        torch.nn.ReLU(),
        torch.nn.Conv2d(in_channels=32,
                        out_channels=32,
                        kernel_size=3),
        torch.nn.ReLU()
    )
    self.linear =torch.nn.Sequential(
        torch.nn.AdaptiveMaxPool2d(output_size=1), 
        torch.nn.Flatten(),
        torch.nn.Linear(),
        torch.nn.ReLU(),
        torch.nn.Linear(),
        torch.nn.Softmax()
        )
    
    def forward(self, X):
      return self.linear(self.features(X))

In [None]:

# AdaptiveMaxPool2d(1) performs global max pooling: (batch, channels, height, width) -> (batch, channels, 1, 1)

In [None]:
import tempfile

import torch
from pytorch_lightning.loggers import TensorBoardLogger

# Train ConvNet on CIFAR-10 for 10 epochs, logging to ./cifar_logs/my_model.
pl.seed_everything(0)

with tempfile.TemporaryDirectory() as tmp_dir:
    # CIFAR-10 is downloaded into a temporary directory that is removed on exit.
    model = ConvNet(3, 10)
    system = CIFARclassification(model)
    data = CIFAR10DataModule(tmp_dir, batch_size=1024)
    logger = TensorBoardLogger("cifar_logs", name="my_model")
    # `gpus=-1` was removed in Lightning 2.0 and errors on CPU-only machines.
    # "auto" uses the available GPUs and falls back to CPU when there are none.
    trainer = pl.Trainer(accelerator="auto", devices="auto", max_epochs=10, logger=logger)
    trainer.fit(system, data)

Global seed set to 0


TypeError: ignored