In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import torchvision
from pytorch_lightning import LightningModule, Trainer
import kornia as K
from sklearn.model_selection import train_test_split

import torchvision.transforms as transforms

import cv2
import glob

%matplotlib inline

In [2]:
transform = transforms.ToTensor()
     

class Augmentation_train(nn.Module):
    _augmentations = K.augmentation.AugmentationSequential(
    K.augmentation.Resize((32, 32)),
    # K.augmentation.RandomHorizontalFlip(p=0.5),
    # K.augmentation.RandomVerticalFlip(p=0.5),
    K.augmentation.RandomRotation(10, p=0.8),
    K.augmentation.RandomElasticTransform(p=0.4),
    K.augmentation.ColorJitter(0.15, 0.25, 0.25, 0.25, p=0.3),
    same_on_batch=False,
    keepdim=True,
    )

    def __init__(self):
        super(Augmentation_train, self).__init__()

    @torch.no_grad()
    def forward(self, img: torch.Tensor) -> torch.Tensor:
        return self.augmentations(img)

    def augmentations(self, img: torch.Tensor) -> dict:
        img = transform(img)
        x = self._augmentations(img)
        return x

class Augmentation_val(nn.Module):
    _augmentations = K.augmentation.AugmentationSequential(
    K.augmentation.Resize((32, 32)),
    keepdim=True,
    )

    def __init__(self):
        super(Augmentation_val, self).__init__()

    @torch.no_grad()
    def forward(self, img: torch.Tensor) -> torch.Tensor:
        return self.augmentations(img)

    def augmentations(self, img: torch.Tensor) -> dict:
        img = transform(img)
        x = self._augmentations(img)
        return x

Aug_tr = Augmentation_train()
Aug_val = Augmentation_val()

In [3]:
CLASSES = len(glob.glob('TrainIJCNN2013/TrainIJCNN2013' + '/*/'))
CLASSES

43

In [4]:
list_train = glob.glob('TrainIJCNN2013/TrainIJCNN2013' + '/*/*')
list_train = [string for string in list_train if '.ppm' in string]

In [5]:
train, val = train_test_split(list_train, train_size=0.8, )

In [6]:
class MyDataset(Dataset):

    def __init__(self,
     list_names, preprocess):
        self.list_names = list_names
        self.preprocess = preprocess

    def get_img(self, idx):
        name = self.list_names[idx]
        image =  cv2.imread(name)
        class_ = name.split('/')[-2]
        image = cv2.cvtColor(image,cv2.COLOR_BGR2RGB)
        return image, int(class_)

    def __len__(self):
        return len(self.list_names)
    
    def __getitem__(self, idx):
        img, class_ = self.get_img(idx)
        return self.preprocess(img), class_ 

In [7]:
dataset_train = MyDataset(train, Aug_tr)
dataset_val = MyDataset(val, Aug_val)

In [8]:
dataset_train[0][0].shape

torch.Size([3, 32, 32])

In [9]:
train_dataloader = DataLoader(dataset_train, batch_size=4, 
                              num_workers=8, 
                              shuffle=True)
val_dataloader = DataLoader(dataset_val, batch_size=4, 
                            num_workers=8,
                            )

In [10]:
def accuracy(out, labels):
    _,pred = torch.max(out, dim=1)
    return torch.sum(pred==labels).item()

In [11]:
model = torchvision.models.resnet50()
num_ftrs = model.fc.in_features
model.fc = nn.Sequential(
    nn.Linear(num_ftrs, CLASSES),
    nn.LogSoftmax(dim=1),
)

loss = nn.CrossEntropyLoss()

In [None]:
class ClassModel(LightningModule):
    def __init__(
        self,
        lr: float = 0.001,
    ):
        super().__init__()
        self.lr = lr
        self.net = model
        self.loss = loss
        self.metric = accuracy

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, _):
        img, classes = batch
        out = self(img.float())
        loss_train = self.loss(out, classes)
        self.log('train_step', loss_train, on_step=True, )
        return loss_train

    def validation_step(self, batch, _):
        img, classes = batch
        out = self(img.float())
        metric = self.metric(out, classes)
        self.log('val_metric_step', metric, on_step=True)

    def validation_epoch_end(self, outputs):
        print(outputs)
        loss_val = torch.stack([x["val_metric_step"] for x in outputs]).mean() if outputs else 0
        self.log('val_epoch_total_step', loss_val, on_epoch=True)
        return loss_val

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.net.parameters(), lr=self.lr)
        scheduler = {
            'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min'),
            'monitor' : 'train_step',
        }
        return {
                "optimizer": optimizer,
                "lr_scheduler": scheduler,
                }


In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint

In [None]:
model = ClassModel()
checkpoint_callback = ModelCheckpoint(dirpath="lightning_logs/classification/best__resnet50", save_top_k=2, monitor="val_epoch_total_step")

In [None]:
trainer = Trainer(gpus=1,
 max_epochs=30,
 callbacks=[checkpoint_callback]
 )

In [None]:
trainer.fit(model,
 train_dataloader, val_dataloader,
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type             | Params
------------------------------------------
0 | net  | ResNet           | 23.6 M
1 | loss | CrossEntropyLoss | 0     
------------------------------------------
23.6 M    Trainable params
0         Non-trainable params
23.6 M    Total params
94.385    Total estimated model params size (MB)
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")


Validation sanity check: 0it [00:00, ?it/s]

[]




Training: 0it [00:00, ?it/s]

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Validating: 0it [00:00, ?it/s]

[]


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Validating: 0it [00:00, ?it/s]

[]


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Validating: 0it [00:00, ?it/s]

[]


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Validating: 0it [00:00, ?it/s]

[]


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Validating: 0it [00:00, ?it/s]

[]


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Validating: 0it [00:00, ?it/s]

[]


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Validating: 0it [00:00, ?it/s]

[]


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Validating: 0it [00:00, ?it/s]

[]


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Validating: 0it [00:00, ?it/s]

[]


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Validating: 0it [00:00, ?it/s]

[]


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Validating: 0it [00:00, ?it/s]

[]


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [None]:
images, classes = next(iter(train_dataloader))

In [None]:
model.eval()
preds = model(images)

In [None]:
_,pred = torch.max(preds, dim=1)

In [None]:
pred

tensor([10, 10, 10, 10])

In [None]:
classes

tensor([ 1,  0,  4, 28])