In [None]:
!pip install wandb
!pip install lightning

In [None]:
import torch
from torch import nn
from torch.utils.data import Subset, DataLoader
from sklearn.model_selection import StratifiedShuffleSplit
import torchvision
import torchvision.transforms.v2 as v2
import os
import matplotlib.pyplot as plt
import numpy as np
import lightning as pl
from torchmetrics import Accuracy
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch import Trainer
import wandb
import shutil
import numpy as np
import random
import gc


In [None]:
# SECURITY: never commit the W&B API key to the notebook. The key is read from the
# WANDB_API_KEY environment variable, which must be set outside this file before
# the cell runs. When the variable is unset, key=None is passed and wandb.login
# uses its own credential lookup.
# NOTE(review): the key that used to be hard-coded here has been exposed and
# should be revoked and regenerated.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

In [None]:
# Copy the dataset tree from data_dir to target_dir, unless target_dir is already
# present (so re-running this cell does not copy again).
data_dir = '/kaggle/input/inaturalist12k'
target_dir = '/kaggle/temp/inaturalist12k'

already_staged = os.path.exists(target_dir)
if not already_staged:
    shutil.copytree(src=data_dir, dst=target_dir)


In [None]:
def transforms(augmentation):
    """Assemble the image preprocessing pipeline.

    Every pipeline starts with a resize to 256x256 and ends with conversion to
    an image tensor, a float32 cast (with value scaling enabled) and a
    per-channel normalization using the fixed mean/std constants below. When
    ``augmentation`` is truthy, random flips, an occasional rotation and an
    occasional colour jitter are inserted between those two stages.

    Args:
        augmentation: Truthy to include the random augmentation steps.

    Returns:
        A ``v2.Compose`` wrapping the assembled steps.
    """
    pipeline = [v2.Resize((256, 256))]

    if augmentation:
        # Rotation and colour jitter are each wrapped so they only fire with the
        # given probability.
        rotation = v2.RandomApply([v2.RandomRotation(degrees=15)], p=0.1)
        jitter = v2.RandomApply(
            [v2.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)],
            p=0.5,
        )
        pipeline.extend([
            v2.RandomHorizontalFlip(p=0.4),
            v2.RandomVerticalFlip(p=0.1),
            rotation,
            jitter,
        ])

    pipeline.extend([
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    return v2.Compose(pipeline)

In [None]:
class Augmentation(torch.utils.data.Dataset):
    """View over a subset of a base dataset with an optional image transform.

    Item ``i`` of this dataset is item ``indices[i]`` of ``train_complete``,
    which must yield ``(image, label)`` pairs. If ``transform`` is truthy it is
    applied to the image; the label is returned untouched.
    """

    def __init__(self, train_complete, indices, transform):
        self.train_complete, self.indices, self.transform = train_complete, indices, transform

    def __len__(self):
        # The view is exactly as long as the index list it was given.
        return len(self.indices)

    def __getitem__(self, idx):
        sample, target = self.train_complete[self.indices[idx]]
        if not self.transform:
            return sample, target
        return self.transform(sample), target

In [None]:
# Value passed as ``num_workers`` to every DataLoader in this notebook.
NUM_WORKERS = 0


def seed_worker(worker_id):
    """Seed NumPy and ``random`` from torch's current initial seed.

    Intended as a DataLoader ``worker_init_fn``. ``worker_id`` is accepted to
    match that callback signature but is not used; the seed is derived from
    ``torch.initial_seed()`` reduced modulo 2**32.
    """
    seed = torch.initial_seed() % (1 << 32)
    random.seed(seed)
    np.random.seed(seed)


# Fixed-seed generator shared by the DataLoaders.
g = torch.Generator().manual_seed(42)

In [None]:
class CNN(nn.Module):
    """Five-block convolutional classifier with one hidden dense layer.

    Each block is ``Conv2d -> activation -> [BatchNorm2d] -> MaxPool2d``. The
    block outputs are flattened and passed through
    ``LazyLinear -> [BatchNorm1d] -> activation -> Dropout -> Linear``.

    Args:
        input: Number of channels of the input images.
        filters: Output channels of the five conv blocks (entries 0..4 are used).
        kernel: Convolution kernel size per block (entries 0..4 are used).
        pool_kernel: Max-pool kernel size per block (entries 0..4 are used).
        pool_stride: Max-pool stride per block (entries 0..4 are used).
        batchnorm: If truthy, add batch normalization to every conv block and
            after the hidden dense layer.
        activation: One of ``'relu'``, ``'gelu'``, ``'selu'``, ``'mish'``,
            ``'swish'``, or an already constructed ``nn.Module``.
        dropout: Dropout probability applied before the output layer.
        ffn_size: Width of the hidden dense layer.
        num_classes: Number of output logits.

    Raises:
        ValueError: If ``activation`` is not a supported name or module.
    """

    # Supported activation names and the module class each one maps to.
    _ACTIVATIONS = {
        'relu': nn.ReLU,
        'gelu': nn.GELU,
        'selu': nn.SELU,
        'mish': nn.Mish,
        'swish': nn.SiLU,
    }

    def __init__(self, input, filters, kernel, pool_kernel, pool_stride, batchnorm, activation, dropout, ffn_size, num_classes=10):
        super().__init__()

        # One activation module is created and reused by every block and the head.
        self.act = self._activation(activation)

        # Blocks are registered as convblock1..convblock5, each consuming the
        # previous block's output channels.
        in_channels = input
        for i in range(5):
            block = self._convblock(in_channels, filters[i], kernel[i], pool_kernel[i], pool_stride[i], self.act, batchnorm, dropout)
            setattr(self, f'convblock{i + 1}', block)
            in_channels = filters[i]

        self.batch_norm = nn.BatchNorm1d(num_features=ffn_size) if batchnorm else nn.Identity()
        self.dropout = nn.Dropout(dropout)
        # LazyLinear is used so the flattened feature width does not have to be
        # computed from the kernel/pool settings here.
        self.fc = nn.LazyLinear(ffn_size)
        self.out = nn.Linear(ffn_size, num_classes)

    def _convblock(self, input, output, kernel, pool_kernel, pool_stride, activation_fn,  batchnorm, dropout=0):
        """Build one ``Conv2d -> activation -> [BatchNorm2d] -> MaxPool2d`` block.

        ``dropout`` is accepted for signature compatibility but is not used:
        no dropout layer is placed inside the conv blocks.
        """
        layers = [nn.Conv2d(input, output, kernel), activation_fn]
        if batchnorm:
            layers.append(nn.BatchNorm2d(output))
        layers.append(nn.MaxPool2d(pool_kernel, pool_stride))
        return torch.nn.Sequential(*layers)

    def _activation(self, act):
        """Return the activation module for ``act``.

        ``act`` may be a supported name or an existing ``nn.Module`` (returned
        as is). Anything else raises ``ValueError`` here, instead of failing
        later when the value is placed inside ``nn.Sequential``.
        """
        if isinstance(act, nn.Module):
            return act
        if isinstance(act, str) and act in self._ACTIVATIONS:
            return self._ACTIVATIONS[act]()
        raise ValueError(
            f"Unsupported activation {act!r}; expected one of {sorted(self._ACTIVATIONS)}"
        )

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        x = self.convblock1(x)
        x = self.convblock2(x)
        x = self.convblock3(x)
        x = self.convblock4(x)
        x = self.convblock5(x)
        # Flatten everything except the batch dimension for the dense head.
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = self.batch_norm(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.out(x)
        return x


In [None]:
class CNN_light(pl.LightningModule):
    """Lightning wrapper that trains the ``CNN`` classifier on 10 classes.

    Args:
        optim: Optimizer name, ``'sgd'`` (momentum 0.9) or ``'adam'``.
        filters, kernel, pool_kernel, pool_stride, batchnorm, activation,
        dropout, ffn_size: Forwarded unchanged to ``CNN``.
        lr: Learning rate handed to the optimizer.

    All constructor arguments are stored via ``save_hyperparameters``.
    """

    def __init__(self, optim, filters, kernel, pool_kernel, pool_stride, batchnorm, activation, dropout, ffn_size, lr):
        super().__init__()
        self.optim = optim
        self.save_hyperparameters()
        self.model = CNN(input=3, filters=filters, kernel=kernel, pool_kernel=pool_kernel, pool_stride=pool_stride, batchnorm=batchnorm, activation=activation, dropout=dropout, ffn_size=ffn_size, num_classes=10)
        self.train_accuracy = Accuracy(task='multiclass', num_classes=10)
        self.val_accuracy = Accuracy(task='multiclass', num_classes=10)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return the logits produced by the wrapped ``CNN``."""
        return self.model(x)

    def _shared_step(self, batch, accuracy, stage):
        """Compute loss and accuracy for one batch and log both per epoch.

        Metrics are logged under ``"<stage> loss"`` and ``"<stage> accuracy"``.
        """
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        acc = accuracy(logits, y)
        self.log(f"{stage} loss", loss, on_step = False, on_epoch = True)
        self.log(f"{stage} accuracy", acc, on_step = False, on_epoch = True)
        return loss

    def training_step(self, batch, batch_idx):
        """Log ``train loss`` / ``train accuracy`` and return the batch loss."""
        return self._shared_step(batch, self.train_accuracy, "train")

    def validation_step(self, batch, batch_idx):
        """Log ``val loss`` / ``val accuracy`` and return the batch loss."""
        return self._shared_step(batch, self.val_accuracy, "val")

    def configure_optimizers(self):
        """Create the optimizer selected by ``self.optim``.

        Raises:
            ValueError: If ``self.optim`` is neither ``'sgd'`` nor ``'adam'``
                (previously this surfaced as an ``UnboundLocalError``).
        """
        if self.optim == 'sgd':
            return torch.optim.SGD(self.parameters(), lr=self.hparams.lr, momentum=0.9)
        if self.optim == 'adam':
            return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        raise ValueError(f"Unsupported optimizer {self.optim!r}; expected 'sgd' or 'adam'")

In [None]:
# Hyperparameters for the single hand-configured run in the next cell.

# Data
aug = True
batch_size = 16

# Convolutional blocks (one entry per block)
filters = [64] * 5
kernel = [5] * 5
pool_kernel = [3] * 5
pool_stride = [1, 1, 1, 1, 2]
batchnorm = True
activation = 'mish'

# Dense head
dropout = 0.4
ffn_size = 256

# Optimization
optim = 'adam'
lr = 5e-5
epochs = 5

In [None]:
# --- Transforms and datasets --------------------------------------------------
train_transform = transforms(augmentation=aug)
val_transform = transforms(augmentation=False)
test_transform = transforms(augmentation=False)
DATA_DIR = "/kaggle/temp/inaturalist12k/inaturalist_12K"
# The full training folder is loaded WITHOUT a transform: the Augmentation
# wrappers below apply the train / validation pipelines to their own split.
train_dataset_complete = torchvision.datasets.ImageFolder(root=os.path.join(DATA_DIR, "train"))
# The dataset's "val" folder is used as the test set.
test_dataset = torchvision.datasets.ImageFolder(root=os.path.join(DATA_DIR, "val"), transform=test_transform)

# Stratified 80/20 train/validation split over the class labels (fixed seed).
labels = np.array([entry[1] for entry in train_dataset_complete.samples])
split_fn = StratifiedShuffleSplit(n_splits = 1, test_size = 0.2, random_state = 219)
train_ids, valid_ids = next(split_fn.split(np.zeros(len(labels)), labels))

train_dataset = Augmentation(train_dataset_complete, train_ids, train_transform)
valid_dataset = Augmentation(train_dataset_complete, valid_ids, val_transform)


# --- Dataloaders --------------------------------------------------------------
# Only the training loader shuffles; all three share the seeded generator `g`.
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                          shuffle=True, num_workers=NUM_WORKERS, pin_memory = False, worker_init_fn=seed_worker,
    generator=g)

val_dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size,
                                        shuffle=False, num_workers=NUM_WORKERS, pin_memory= False, worker_init_fn=seed_worker,
    generator=g)

test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size,
                                        shuffle=False, num_workers=NUM_WORKERS, pin_memory= False, worker_init_fn=seed_worker,
    generator=g)

classes = train_dataset_complete.classes
n_classes = len(classes)


# --- Model and trainer --------------------------------------------------------
model = CNN_light(optim= optim, filters= filters, kernel = kernel, pool_kernel=pool_kernel, pool_stride=pool_stride, batchnorm=batchnorm, activation=activation, dropout=dropout, ffn_size=ffn_size, lr=lr)
logger= WandbLogger(project= 'scratch_test', name = 'best_val_acc', log_model = False)
trainer = pl.Trainer(
                        devices=1,
                        accelerator="gpu",
                        precision="16-mixed",
                        gradient_clip_val=1.0,
                        max_epochs=epochs,
                        logger=logger,
                        profiler=None,
                    )

# Always close the W&B run: a run left open after a failed fit would otherwise
# stay active in the kernel. Failures are reported with a non-zero exit code.
try:
    trainer.fit(model, train_dataloader, val_dataloader)
except BaseException:
    wandb.finish(exit_code=1)
    raise
else:
    wandb.finish()

## Sweeps

In [None]:
# Named presets for the per-block filter counts (five conv blocks each).
filters_des = {
    'same_8': [8] * 5,
    'same_16': [16] * 5,
    'same_32': [32] * 5,
    'same_64': [64] * 5,
    'increase_16_128': [16, 32, 64, 128, 128],
    'decrease_128_16': [128, 128, 64, 32, 16],
    'mixed': [16, 32, 64, 32, 16],
}

# Named presets for the per-block convolution kernel sizes.
kernels_des = {
    'same_3': [3] * 5,
    'same_5': [5] * 5,
    'mix_3_5': [3, 3, 5, 5, 5],
}

In [None]:
# W&B sweep definition: Bayesian search with Hyperband early termination.
sweep_config = {
    #'name': 'bayes_sweep_init',
    'method': 'bayes',
    'metric': {
        # Must match the key the LightningModule logs ("val accuracy"); the
        # previous value 'val_acc' was never logged, so the sweep had no
        # objective to optimize or to early-terminate on.
        'name': 'val accuracy',
        'goal': 'maximize'
    },
    'early_terminate': {
        'type': 'hyperband',
        'min_iter': 3},
    'parameters': {
        'lr': {
            'min': 1e-5,
            'max': 1e-4
        },
        'batch_size': {
            'values': [16,32]
        },
        'filters': {
            # Other presets in filters_des (same_8, same_16, decrease_128_16)
            # are deliberately left out of the search.
            'values': [
                filters_des['same_32'],
                filters_des['same_64'],
                filters_des['increase_16_128'],
                filters_des['mixed']
            ]
        },
        'kernel': {
            'values': [
                kernels_des['same_3'],
                kernels_des['same_5'],
                kernels_des['mix_3_5']
            ]
        },
        'pool_kernel': {
            'values': [[2,2,2,2,2], [3,3,3,3,3], [2,2,2,3,3]]
        },
        'pool_stride': {
            'values': [[1,1,1,1,1], [1,1,1,2,2],[1,1,1,1,2]]
        },
        'batchnorm': {
            'values': [True,]   # False
        },
        'activation': {
            'values': ['relu', 'gelu', 'mish', ]   # 'swish', 'selu' not searched
        },

        'augmentation': {
            'values': [True]   #, False
        },
        'dropout': {
            'min': 0.3,
            'max': 0.4
        },
        'ffn_size': {
            'values': [128, 256]  #64
        },
        'epochs': {'values': [5]}, #10
        'optim': {'values': ['adam']}      #'sgd', 
    }
}

#'augmentation': {
#    'values': ['hflip', 'vflip', 'rotate', None]
#},

In [None]:
def trainCNN(config=None):
    """Run one sweep trial: build the data pipeline, train ``CNN_light`` and clean up.

    Hyperparameters are read from ``wandb.config`` after ``wandb.init``; the
    run is renamed to summarize augmentation, dropout, batchnorm and
    ``ffn_size``. The trainer and model are released and the CUDA cache is
    emptied when the trial ends, whether or not training succeeded.

    Args:
        config: Optional configuration forwarded to ``wandb.init``.
    """
    with wandb.init(config=config) as run:
        config = wandb.config

        run.name = f"A_{config.augmentation}_D_{config.dropout:.2f}_bn_{config.batchnorm}_ffn_{config.ffn_size}"
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

        # Bound up front so the cleanup in `finally` cannot raise a NameError
        # (and hide the real exception) when setup fails before they are built.
        trainer = None
        model = None
        try:
            # Datasets: the training folder is loaded without a transform; the
            # Augmentation wrappers apply the per-split pipelines.
            train_transform = transforms(augmentation=config.augmentation)
            val_transform = transforms(augmentation=False)
            DATA_DIR = "/kaggle/temp/inaturalist12k/inaturalist_12K"
            train_dataset_complete = torchvision.datasets.ImageFolder(root=os.path.join(DATA_DIR, "train"))

            # Stratified 80/20 train/validation split over the class labels (fixed seed).
            labels = np.array([entry[1] for entry in train_dataset_complete.samples])
            split_fn = StratifiedShuffleSplit(n_splits = 1, test_size = 0.2, random_state = 219)
            train_ids, valid_ids = next(split_fn.split(np.zeros(len(labels)), labels))

            train_dataset = Augmentation(train_dataset_complete, train_ids, train_transform)
            valid_dataset = Augmentation(train_dataset_complete, valid_ids, val_transform)

            # Dataloaders. The test split is intentionally not loaded here: it
            # was never used during a sweep trial.
            train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=config.batch_size,
                                                      shuffle=True, num_workers=NUM_WORKERS, pin_memory = False, worker_init_fn=seed_worker,
                generator=g)

            val_dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=config.batch_size,
                                                    shuffle=False, num_workers=NUM_WORKERS, pin_memory= False, worker_init_fn=seed_worker,
                generator=g)

            # Model
            model = CNN_light(optim= config.optim, filters= config.filters, kernel = config.kernel, pool_kernel=config.pool_kernel, pool_stride=config.pool_stride, batchnorm=config.batchnorm, activation=config.activation, dropout=config.dropout, ffn_size=config.ffn_size, lr=config.lr)
            # The logger is attached to the already initialized sweep run.
            logger= WandbLogger(project= 'dlas2_sweeps', name = run.name, experiment=run, log_model = False)
            trainer = pl.Trainer(
                                    devices=1,
                                    accelerator="gpu",
                                    precision="16-mixed",
                                    gradient_clip_val=1.0,
                                    max_epochs=config.epochs,
                                    logger=logger,
                                    profiler=None,
                                )

            trainer.fit(model, train_dataloader, val_dataloader)
        finally:
            # Drop the references before collecting so GPU memory can be freed
            # ahead of the next trial.
            del trainer
            del model
            gc.collect()
            torch.cuda.empty_cache()

In [None]:
# Register the sweep, then start an agent that calls trainCNN for up to 20 runs.
# To attach to an existing sweep instead, assign its id directly, e.g.
# sweep_id = "deeplearn24/dla2-sweeps/9pjx0avr"
sweep_id = wandb.sweep(sweep=sweep_config, project="dla2-sweeps")
wandb.agent(sweep_id=sweep_id, function=trainCNN, count=20)

### To do
- ResNet
- Visualize filters

In [None]:
import torch
import torchvision.models as models

In [None]:
class CNN_light_finetune(pl.LightningModule):
    """Lightning module that fine-tunes only the final layer of a pretrained GoogLeNet.

    All pretrained parameters are frozen; the ``fc`` layer is replaced by a
    fresh 10-way linear layer, which is the only trainable part.

    Args:
        optim: Optimizer name, ``'sgd'`` (momentum 0.9) or ``'adam'``.
        lr: Learning rate handed to the optimizer.
    """

    def __init__(self, optim, lr):
        super().__init__()
        self.optim = optim
        self.save_hyperparameters()
        # `weights=` replaces the deprecated `pretrained=True` flag and selects
        # the same ImageNet checkpoint.
        self.model = models.googlenet(weights=models.GoogLeNet_Weights.IMAGENET1K_V1)
        # Freeze the pretrained network first; the new head created afterwards
        # is trainable by default.
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.fc = nn.Linear(in_features = self.model.fc.in_features, out_features = 10, bias = True)
        self.train_accuracy = Accuracy(task='multiclass', num_classes=10)
        self.val_accuracy = Accuracy(task='multiclass', num_classes=10)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return the logits of the wrapped GoogLeNet."""
        return self.model(x)

    def _shared_step(self, batch, accuracy, stage):
        """Compute loss and accuracy for one batch and log both per epoch.

        Metrics are logged under ``"<stage> loss"`` and ``"<stage> accuracy"``.
        """
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        acc = accuracy(logits, y)
        self.log(f"{stage} loss", loss, on_step = False, on_epoch = True)
        self.log(f"{stage} accuracy", acc, on_step = False, on_epoch = True)
        return loss

    def training_step(self, batch, batch_idx):
        """Log ``train loss`` / ``train accuracy`` and return the batch loss."""
        return self._shared_step(batch, self.train_accuracy, "train")

    def validation_step(self, batch, batch_idx):
        """Log ``val loss`` / ``val accuracy`` and return the batch loss."""
        return self._shared_step(batch, self.val_accuracy, "val")

    def configure_optimizers(self):
        """Create the optimizer selected by ``self.optim`` over the trainable parameters.

        Raises:
            ValueError: If ``self.optim`` is neither ``'sgd'`` nor ``'adam'``
                (previously this surfaced as an ``UnboundLocalError``).
        """
        # Frozen parameters never receive gradients, so only the trainable
        # ones are handed to the optimizer.
        trainable = [p for p in self.parameters() if p.requires_grad]
        if self.optim == 'sgd':
            return torch.optim.SGD(trainable, lr=self.hparams.lr, momentum=0.9)
        if self.optim == 'adam':
            return torch.optim.Adam(trainable, lr=self.hparams.lr)
        raise ValueError(f"Unsupported optimizer {self.optim!r}; expected 'sgd' or 'adam'")

In [None]:
# Fine-tuning run. NOTE: this cell relies on `epochs`, `train_dataloader` and
# `val_dataloader` defined by earlier cells, so those must be run first.
finetune_model = CNN_light_finetune(optim = 'adam', lr = 0.00005)
logger= WandbLogger(project= 'finetune', name = 'test', log_model = False)
trainer = pl.Trainer(
                        devices=1,
                        accelerator="gpu",
                        precision="16-mixed",
                        gradient_clip_val=1.0,
                        max_epochs=epochs,
                        logger=logger,
                        profiler=None,
                    )

# Always close the W&B run: a run left open after a failed fit would otherwise
# stay active in the kernel. Failures are reported with a non-zero exit code.
try:
    trainer.fit(finetune_model, train_dataloader, val_dataloader)
except BaseException:
    wandb.finish(exit_code=1)
    raise
else:
    wandb.finish()