In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import torchvision
import torchvision.transforms.v2 as transforms
from torchmetrics import Accuracy

In [2]:
# Preferred compute device for ad-hoc tensor work in this notebook
# (the Lightning Trainer below selects its accelerator on its own).
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# CIFAR-10 class names; position i is assumed to correspond to target label i.
classes = tuple('plane car bird cat deer dog frog horse ship truck'.split())

In [3]:
# # use 20% of training data for validation
# train_set_size = int(len(train_set) * 0.8)
# valid_set_size = len(train_set) - train_set_size

# # split the train set into two
# seed = torch.Generator().manual_seed(42)
# train_set, valid_set = data.random_split(train_set, [train_set_size, valid_set_size], generator=seed)

# class CIFAR10DataModule(pl.LightningDataModule):
#     batch_size = 512
#     num_workers = 24
#     def prepare_data(self): 
#         torchvision.datasets.CIFAR10('CIFAR10', train=True, download=True)
#         torchvision.datasets.CIFAR10('CIFAR10', train=True, download=True)

#     def train_dataloader(self):
#         transform = transforms.Compose([
#             transforms.ColorJitter(),
#             transforms.RandomResizedCrop(32),
#             transforms.RandomHorizontalFlip(),
#             transforms.ToTensor(),
#             transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
#         ])
#         self.train_dataset = torchvision.datasets.CIFAR10(
#             root='./CIFAR10', train=True, download=True, transform=transform)
#         self.train_loader = torch.utils.data.DataLoader(
#             self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

#         return self.train_loader   

#     def val_dataloader(self):
#         return self.test_dataloader()   
    
#     def test_dataloader(self):
#         transform = transforms.Compose([
#             transforms.ToTensor(),
#             transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
#         ])
        
#         self.test_dataset = torchvision.datasets.CIFAR10(
#             root='./CIFAR10', train=False, download=True, transform=transform)
#         self.test_loader = torch.utils.data.DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)

#         return self.test_loader   

# data = CIFAR10DataModule()

In [4]:
print('==> Preparing data..')

# Loader settings and per-channel normalisation statistics shared by every split.
BATCH_SIZE = 512
NUM_WORKERS = 24
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Augmented pipeline: training split only.
# NOTE(review): ColorJitter() with no arguments leaves the image unchanged, and
# RandomResizedCrop(32) uses its default scale=(0.08, 1.0), which is more aggressive than
# the common CIFAR recipe (RandomCrop(32, padding=4)). Both kept as-is; revisit when tuning.
transform_train = transforms.Compose([
    transforms.ColorJitter(),
    transforms.RandomResizedCrop(32),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Deterministic pipeline: validation and test.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Two views of the same training images: augmented (for fitting) and plain (for validation).
# A Subset reads items through its parent dataset's transform, so splitting one augmented
# dataset would make the validation metrics depend on random augmentation.
train_full_aug = torchvision.datasets.CIFAR10(
    root='./CIFAR10', train=True, download=True, transform=transform_train)
train_full_eval = torchvision.datasets.CIFAR10(
    root='./CIFAR10', train=True, download=False, transform=transform_test)

# Hold out 20% of the training images for validation; the fixed generator makes the
# split reproducible (the permutation depends only on the lengths and the seed).
train_set_size = int(len(train_full_aug) * 0.8)
valid_set_size = len(train_full_aug) - train_set_size

seed = torch.Generator().manual_seed(42)
trainset, valid_split = torch.utils.data.random_split(
    train_full_aug, [train_set_size, valid_set_size], generator=seed)
# Same held-out indices, but served through the un-augmented view.
validset = torch.utils.data.Subset(train_full_eval, valid_split.indices)

# Create train dataloader (shuffled every epoch)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
# Create validation dataloader
validloader = torch.utils.data.DataLoader(
    validset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

testset = torchvision.datasets.CIFAR10(
    root='./CIFAR10', train=False, download=True, transform=transform_test)
# Create test dataloader
testloader = torch.utils.data.DataLoader(
    testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)


==> Preparing data..




Files already downloaded and verified
Files already downloaded and verified


In [5]:
class SeparableConv2d(nn.Module):
    '''Depthwise-separable 3x3 convolution block.

    Stage 1: grouped 3x3 conv (one filter per input channel) -> BatchNorm -> ReLU.
    Stage 2: 1x1 pointwise conv mixing channels -> BatchNorm -> ReLU.

    Args:
        in_channels: channel count of the input feature map.
        out_channels: channel count produced by the pointwise conv.
        stride: stride of the depthwise conv (the pointwise conv always uses 1).
    '''
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # Depthwise stage: groups == in_channels, so each channel is filtered independently.
        self.dw_conv = nn.Conv2d(
            in_channels, in_channels,
            kernel_size=3, stride=stride, padding=1,
            groups=in_channels, bias=False,
        )
        self.dw_bn = nn.BatchNorm2d(in_channels)
        # Pointwise stage: 1x1 conv projects to out_channels.
        self.pw_conv = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=1, stride=1, padding=0, bias=False,
        )
        self.pw_bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        depthwise = F.relu(self.dw_bn(self.dw_conv(x)))
        pointwise = self.pw_bn(self.pw_conv(depthwise))
        return F.relu(pointwise)

In [20]:
class MyMobileNet(pl.LightningModule):
    '''MobileNet-v1-style classifier (LightningModule) sized for 32x32 inputs.

    A 3x3 stem convolution is followed by the depthwise-separable stages listed in
    ``cfg``, global average pooling and a linear classifier. Every channel count is
    scaled by the width multiplier ``alpha``.

    Args:
        num_classes: number of output logits.
        alpha: width multiplier applied to every channel count.
    '''
    # (in_channels, out_channels, stride) of each SeparableConv2d, before alpha scaling.
    cfg = [
        (32, 64, 1), 
        (64, 128, 2), 
        (128, 128, 1), 
        (128, 256, 2),
        (256, 256, 1),
        (256, 512, 2),
        (512, 512, 1),
        (512, 512, 1),
        (512, 512, 1),
        (512, 512, 1),
        (512, 512, 1),
        (512, 1024, 2),
        (1024, 1024, 1),
    ]
    
    def __init__(self, num_classes=10, alpha: float = 1):
        super().__init__()
        # Store num_classes/alpha in checkpoints so load_from_checkpoint rebuilds the same shape.
        self.save_hyperparameters()
        conv_out = int(32 * alpha)
        self.conv = nn.Conv2d(3, conv_out, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(conv_out)
        self.accuracy = Accuracy("multiclass", num_classes=num_classes)

        self.features = self.make_feature_extractor(alpha)
        # Size the classifier from the alpha-scaled width of the last stage; a fixed 1024
        # would not match the features when alpha != 1.
        self.linear = nn.Linear(int(self.cfg[-1][1] * alpha), num_classes)

    def make_feature_extractor(self, alpha):
        '''Build the stack of SeparableConv2d blocks from ``cfg`` scaled by ``alpha``.'''
        layer_values = [(int(inp * alpha), int(out * alpha), stride) for inp, out, stride in self.cfg]
        return nn.Sequential(*[SeparableConv2d(*tup) for tup in layer_values])

    def forward(self, x):
        out = F.relu(self.bn(self.conv(x)))
        out = self.features(out)
        # Global average pooling. For 32x32 inputs the four stride-2 stages leave a 2x2 map,
        # so this equals avg_pool2d(out, 2); unlike that, it also works at other resolutions.
        out = F.adaptive_avg_pool2d(out, 1)
        out = torch.flatten(out, 1)
        return self.linear(out)

    def _shared_step(self, batch, stage):
        '''Forward one batch, log ``{stage}_loss`` / ``{stage}_accuracy`` per epoch, return the loss.'''
        x, y = batch
        logits = self(x)
        loss = self.compute_loss(logits, y)
        acc = self.accuracy(logits, y)
        self.log(f"{stage}_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log(f"{stage}_accuracy", acc, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def training_step(self, train_batch, batch_idx):
        return self._shared_step(train_batch, "train")
        
    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "val")
        
    def test_step(self, batch, batch_idx):
        # Logged under test_* names so trainer.test results are not reported as val_*.
        return self._shared_step(batch, "test")
        
    def configure_optimizers(self):
        '''AdamW with a per-step OneCycle learning-rate schedule.'''
        # NOTE: OneCycleLR overwrites the optimizer's lr (it starts at max_lr / div_factor),
        # so `self.lr` only seeds AdamW's constructor; `max_lr` sets the peak LR.
        self.lr = 0.02089
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=0.005)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=0.1,
            # Number of optimizer steps Lightning will run in this fit() call; avoids
            # depending on the notebook globals MAX_EPOCHS / trainloader.
            total_steps=self.trainer.estimated_stepping_batches,
        )
        # OneCycleLR's total_steps counts optimizer steps, so it must be stepped every batch.
        # Lightning's default interval is "epoch", which would leave the LR in early warm-up.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }
        
    def compute_loss(self, logits, labels):
        '''Mean cross-entropy between raw logits and integer class labels.'''
        return F.cross_entropy(logits, labels)

In [21]:
# from lightning.pytorch.tuner import Tuner
from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar

MAX_EPOCHS = 100

# Seed Python/NumPy/torch (and dataloader workers) so weight init and shuffling are reproducible.
pl.seed_everything(42, workers=True)

model = MyMobileNet()

# Keep the checkpoint with the HIGHEST validation accuracy. ModelCheckpoint defaults to
# mode='min', which for an accuracy metric would retain the worst epoch instead.
checkpoint_callback = ModelCheckpoint(
    monitor='val_accuracy',
    mode='max',
    dirpath='checkpoints/',
    filename='CIFAR-{epoch:02d}-{val_accuracy:.2f}'
)
progress_callback = RichProgressBar(leave=True)

trainer = pl.Trainer(max_epochs=MAX_EPOCHS, callbacks=[checkpoint_callback, progress_callback])

# tune = Tuner(trainer)
# optimal_lr = tune.lr_find(model, train_dataloaders=trainloader, val_dataloaders=validloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
# Fit on the 80% training split, validating every epoch on the held-out 20% split.
trainer.fit(
    model,
    train_dataloaders=trainloader,
    val_dataloaders=validloader,
)

/home/semar/.local/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:630: Checkpoint directory checkpoints/ exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

In [26]:
!pip install -U rich

