In [1]:
import sys
import pyrootutils

root = pyrootutils.setup_root(sys.path[0], pythonpath=True, cwd=True)

import timm
import torch
import shutil
import numpy as np
import torchvision
import seaborn as sns
import torch.nn as nn
import albumentations as A
import torch.optim as optim
import pytorch_lightning as pl
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision.transforms as transforms


from PIL import Image
from omegaconf import OmegaConf
from hydra import compose, initialize
from albumentations.pytorch import ToTensorV2
from torch.utils.data import random_split, DataLoader, TensorDataset


shutil.copy("configs/config.yaml", "notebooks/config.yaml")
with initialize(version_base=None, config_path=""):
    config = compose(config_name="config.yaml")


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# set random seed and device
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [3]:

class Cifar10SearchDataset(torchvision.datasets.CIFAR10):
    def __init__(self, root="data/", train=True, download=True, transform=None):
        super().__init__(root=root, train=train, download=download, transform=transform)

    def __getitem__(self, index):
        image, label = self.data[index], self.targets[index]

        if self.transform is not None:
            transformed = self.transform(image=image)
            image = transformed["image"]

        return image, label


transform_train = A.Compose(
    [
        A.RandomCrop(32, 32),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.ShiftScaleRotate(p=0.5),
        A.HueSaturationValue(p=0.5),
        A.RandomBrightnessContrast(p=0.5),
        A.Normalize(
            mean=[0.4914, 0.4822, 0.4465],
            std=[0.2023, 0.1994, 0.2010],
        ),
        ToTensorV2(),
    ]
)



transform_test = A.Compose(
    [   
        A.Normalize(
            mean=[0.4914, 0.4822, 0.4465],
            std=[0.2023, 0.1994, 0.2010],
        ),
        ToTensorV2(),
    ]
)


In [4]:

trainset = Cifar10SearchDataset(
    root="./data", train=True, download=True, transform=transform_train
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=config.batch_size, shuffle=True, num_workers=2
)

testset = Cifar10SearchDataset(
    root="./data", train=False, download=True, transform=transform_test
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=config.batch_size, shuffle=False, num_workers=2
)




Files already downloaded and verified
Files already downloaded and verified


In [10]:


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # import resnet18 from timm
        self.model = timm.create_model("resnet18", pretrained=True, num_classes=10)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        print("Shape of x in forward: ", x.shape)
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        print("shape in training step: ", x.shape)
        logits = self(x)
        loss = self.criterion(logits, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        print("shape in validation step: ", x.shape)
        logits = self(x)
        loss = self.criterion(logits, y)
        preds = logits.argmax(dim=1)
        acc = (preds == y).float().mean()
        return {"val_loss": loss, "val_acc": acc}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        avg_acc = torch.stack([x["val_acc"] for x in outputs]).mean()
        return {"val_loss": avg_loss, "val_acc": avg_acc}

    def test_step(self, batch, batch_idx):
        x, y = batch
        print("shape in test step: ", x.shape)
        logits = self(x)
        loss = self.criterion(logits, y)
        preds = logits.argmax(dim=1)
        acc = (preds == y).float().mean()
        return {"test_loss": loss, "test_acc": acc}

    def test_epoch_end(self, outputs):
        avg_loss = torch.stack([x["test_loss"] for x in outputs]).mean()
        avg_acc = torch.stack([x["test_acc"] for x in outputs]).mean()
        return {"test_loss": avg_loss, "test_acc": avg_acc}

    def configure_optimizers(self):
        optimizer = optim.SGD(
            self.parameters(), lr=config.lr, momentum=0.9, weight_decay=5e-4
        )
        return optimizer

    def train_dataloader(self):
        return trainloader

    def val_dataloader(self):
        return testloader

    def test_dataloader(self):
        return testloader

    def predict_dataloader(self):
        return super().predict_dataloader()


In [11]:
model = LitModel()
trainer = pl.Trainer(gpus=1, max_epochs=config.epochs)
trainer.fit(model)



  rank_zero_deprecation(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params
-----------------------------------------------
0 | model     | ResNet           | 11.2 M
1 | criterion | CrossEntropyLoss | 0     
-----------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.727    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


Sanity Checking DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]shape in validation step:  torch.Size([10000, 3, 32, 32])
Shape of x in forward:  torch.Size([10000, 3, 32, 32])
                                                                            

  rank_zero_warn(
  rank_zero_warn(


Epoch 0:   0%|          | 0/6 [00:00<?, ?it/s] shape in training step:  torch.Size([10000, 3, 32, 32])
Shape of x in forward:  torch.Size([10000, 3, 32, 32])
Epoch 0:  17%|█▋        | 1/6 [00:01<00:09,  1.91s/it, loss=2.64, v_num=13]shape in training step:  torch.Size([10000, 3, 32, 32])
Shape of x in forward:  torch.Size([10000, 3, 32, 32])
Epoch 0:  33%|███▎      | 2/6 [00:02<00:04,  1.17s/it, loss=2.69, v_num=13]shape in training step:  torch.Size([10000, 3, 32, 32])
Shape of x in forward:  torch.Size([10000, 3, 32, 32])
Epoch 0:  50%|█████     | 3/6 [00:03<00:03,  1.14s/it, loss=2.53, v_num=13]shape in training step:  torch.Size([10000, 3, 32, 32])
Shape of x in forward:  torch.Size([10000, 3, 32, 32])
Epoch 0:  67%|██████▋   | 4/6 [00:03<00:01,  1.04it/s, loss=2.51, v_num=13]shape in training step:  torch.Size([10000, 3, 32, 32])
Shape of x in forward:  torch.Size([10000, 3, 32, 32])
Epoch 0:  83%|████████▎ | 5/6 [00:04<00:00,  1.08it/s, loss=2.43, v_num=13]shape in validation ste

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 6/6 [00:05<00:00,  1.07it/s, loss=2.43, v_num=13]


In [12]:
# load an image from a path
image = Image.open('tmp/2.png')
image = np.array(image)
image.shape

(1080, 1080, 3)

In [13]:
# load an image from a path
image = Image.open('tmp/0000.jpg')
image = np.array(image)
image.shape

(32, 32, 3)

In [14]:
# load an image from directory and predict the class
def predict_image(img_path):
    img = Image.open(img_path)
    # plt.imshow(img)
    # plt.show()
    # print shape of img
    print(img.size)
    img = transform_test(image=np.array(img))["image"][np.newaxis, :]
    # img = torch.from_numpy(img).to(device).unsqueeze(0)
    output = model(img)
    pred = output.argmax(dim=1, keepdim=True)
    # map the predicted class to the CIFAR10 class name
    class_names = [
        "airplane",
        "automobile",
        "bird",
        "cat",
        "deer",
        "dog",
        "frog",
        "horse",
        "ship",
        "truck",
    ]
    

    print("Predicted class: ", class_names[pred.item()])
predict_image('tmp/2.png')
predict_image('tmp/0000.jpg')


(1080, 1080)
Shape of x in forward:  torch.Size([1, 3, 1080, 1080])
Predicted class:  cat
(32, 32)
Shape of x in forward:  torch.Size([1, 3, 32, 32])


ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 512, 1, 1])