In [None]:
!pip install pytorch-lightning -q
!pip install wandb -q
!pip install datasets -q
!pip install transformers -q

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m727.7/727.7 kB[0m [31m9.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m805.2/805.2 kB[0m [31m20.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m15.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m190.0/190.0 kB[0m [31m16.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m224.8/224.8 kB[0m [31m15.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.7/62.7 kB[0m [31m7.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for pathtools (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.6/519.6 kB[0m [31m7.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━

## PyTorch Lightning: Streamlined Deep Learning

PyTorch Lightning simplifies deep learning in PyTorch. It abstracts lower-level details, organizes code, and accelerates experimentation. Key features:

**Structured Code**: Use LightningModule for organized model code.
**Data Handling**: Simplified data loading with data modules.
**Training Loop**: Automate the training loop with the Trainer class.
**Monitoring**: Seamlessly integrate with monitoring tools like TensorBoard and WandB.
**Reproducibility**: Encourages reproducible research practices.
For detailed information, refer to the PyTorch Lightning [documentation](https://pytorch-lightning.readthedocs.io/en/latest/).

You're gonna need these imports.


In [None]:
import random
import numpy as np
import argparse
from typing import List, Dict, Union, Optional

import wandb
import transformers
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger

import torch
from torch import nn
from torch.utils.data import random_split, DataLoader, Dataset

from torchvision import transforms
from torchmetrics.classification import AUROC, F1Score, Accuracy
from torchvision.models import list_models, get_model, get_model_weights, get_weight
from datasets import load_dataset, list_datasets

Now you'll need to login to you wandb account.

In [None]:
wandb.login()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [40]:
class DataModule(pl.LightningDataModule):
    """
    A DataModule implements 5 key methods:
        - prepare_data (things to do on 1 GPU/TPU, not on every GPU/TPU in distributed mode)
        - setup (things to do on every accelerator in distributed mode)
        - train_dataloader (the training dataloader)
        - val_dataloader (the validation dataloader(s))
        - test_dataloader (the test dataloader(s))
    This allows you to share a full dataset without explaining how to download,
    split, transform and process the data.
    """

    def __init__(
        self,
        dataset_name: str="cifar100",
        batch_size: int=32,
        data_dir: str="~/cache",
        num_workers: int=4,
        transform_custom: bool=True,
        **kwargs
    ):
        super().__init__()

        self.dataset_name = dataset_name
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.transform_custom = transform_custom

        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None
        self.data_test: Optional[Dataset] = None

    def prepare_data(self):
        """Download data if needed. This method is called only from a single GPU."""
        datasets_list = list_datasets()
        assert self.dataset_name in datasets_list, "the dataset was not found in HF hub"
        self.train_dataset = load_dataset(self.dataset_name, cache_dir=self.data_dir, split="train")
        self.data_test = load_dataset(self.dataset_name, cache_dir=self.data_dir, split="test")

    @staticmethod
    def transforms_train(examples):
        transform_train = transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize(mean=[0.5073, 0.4868, 0.4410], std=[0.2623, 0.2515, 0.2716]),
              transforms.RandomHorizontalFlip(),
              transforms.RandomRotation(degrees=5),

        ])
        examples["pixel_values"] = [transform_train(img.convert("RGB")) for img in examples['img']]
        return examples

    @staticmethod
    def transforms_test(examples):
        transform_test = transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize(mean=[0.5073, 0.4868, 0.4410], std=[0.2623, 0.2515, 0.2716]),
        ])

        examples["pixel_values"] = [transform_test(img.convert("RGB")) for img in examples['img']]
        return examples

    def setup(self, stage: Optional[str]=None):
        train_valid = self.train_dataset.train_test_split(test_size=0.2)
        self.data_train = train_valid['train']
        self.data_val = train_valid['test']
        self.data_train = self.data_train.map(lambda example: {"labels": example["fine_label"]}, remove_columns=["fine_label", "coarse_label"])
        self.data_val = self.data_val.map(lambda example: {"labels": example["fine_label"]}, remove_columns=["fine_label", "coarse_label"])
        self.data_test = self.data_test.map(lambda example: {"labels": example["fine_label"]}, remove_columns=["fine_label", "coarse_label"])

        if self.transform_custom:
            self.data_train = self.data_train.map(self.transforms_train, batched=True, remove_columns=["img"])
            self.data_val = self.data_val.map(self.transforms_test, batched=True, remove_columns=["img"])
            self.data_test = self.data_test.map(self.transforms_test, batched=True, remove_columns=["img"])
            self.data_train.set_format("pt", columns=["pixel_values"], output_all_columns=True)
            self.data_val.set_format("pt", columns=["pixel_values"], output_all_columns=True)
            self.data_test.set_format("pt", columns=["pixel_values"], output_all_columns=True)

    def train_dataloader(self):
        return DataLoader(
            dataset=self.data_train,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self):
        return DataLoader(
            dataset=self.data_val,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )

    def test_dataloader(self):
        return DataLoader(
            dataset=self.data_test,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )

In [70]:
# Define the LightningModule for image classification
class ImageModel(pl.LightningModule):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.save_hyperparameters()

        # Load model weights if not specified
        if not self.hparams.weight_name:
            weights = torch.hub.load("pytorch/vision", "get_model_weights", name=self.hparams.model_name)
            self.hparams.weight_name = [weight for weight in weights][0]
            print(f"set weights to {self.hparams.weight_name}")

        # Load the model with specified weights and replace the final classification layer
        self.weights = get_weight(str(self.hparams.weight_name))
        self.model = torch.hub.load("pytorch/vision", self.hparams.model_name, weights=self.hparams.weight_name)
        self.model.fc = nn.Linear(self.model.fc.in_features, self.hparams.num_classes)
        self.preprocess = self.weights.transforms()
        self.test_outputs = []

        # Initialize metrics based on the number of classes
        if self.hparams.num_classes > 2:
            self.aucroc = AUROC(task="multiclass", num_classes=self.hparams.num_classes, average="weighted")
            self.accuracy = Accuracy(task="multiclass", num_classes=self.hparams.num_classes, average='weighted')
            self.f1 = F1Score(task="multiclass", num_classes=self.hparams.num_classes, average='weighted')
        elif self.hparams.num_classes == 2:
            self.aucroc = AUROC(task="binary")
            self.accuracy = Accuracy(task="binary")
            self.f1 = F1Score(task="binary")
        else:
            raise ValueError(f"num_classes should be 2 or more, regression not supported. Got value {self.hparams.num_classes}")

    def forward(self, x):
        # Preprocess input and pass through the model
        x = self.preprocess(x)
        logits = self.model(x)
        return logits

    def training_step(self, batch, batch_idx=None):
        # Forward pass during training
        images, labels = batch['pixel_values'], batch['labels']
        logits = self(images)

        # Calculate loss and metrics
        loss = nn.CrossEntropyLoss()(logits, labels)
        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, labels)
        aucroc = self.aucroc(logits, labels)
        f1 = self.f1(preds, labels)
        values = {'train_loss': loss, 'train_acc': acc, 'train_aucroc': aucroc, 'train_f1': f1}

        # Log metrics
        self.log_dict(values, on_step=True, on_epoch=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx=None):
        # Forward pass during validation
        images, labels = batch['pixel_values'], batch['labels']
        logits = self(images)
        loss = nn.CrossEntropyLoss()(logits, labels)
        preds = torch.argmax(logits, dim=1)

        # Calculate metrics
        acc = self.accuracy(preds, labels)
        aucroc = self.aucroc(logits, labels)
        f1 = self.f1(preds, labels)
        values = {'val_loss': loss, 'val_acc': acc, 'val_aucroc': aucroc, 'val_f1': f1}

        # Log metrics
        self.log_dict(values, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx=None):
        # Forward pass during testing
        images, labels = batch['pixel_values'], batch['labels']
        logits = self(images)
        preds = torch.argmax(logits, dim=1)

        # Calculate metrics
        acc = self.accuracy(preds, labels)
        aucroc = self.aucroc(logits, labels)
        f1 = self.f1(preds, labels)
        values = {'test_acc': acc, 'test_aucroc': aucroc, 'test_f1': f1}

        # Log metrics
        self.log_dict(values, prog_bar=True)

        # Store test outputs for visualization
        self.test_outputs.append((batch, preds, labels))

    def test_epoch_end(self):
        # Visualize some test examples and predictions
        imgs, preds, labels = random.sample(self.test_outputs, 5)
        self.log({
            "examples":[wandb.Image(img, caption=f"Pred:{pred}, Label:{lbl}") for img, pred, lbl in zip(imgs, preds, labels)]
        })
        self.test_outputs.clear()

    def configure_optimizers(self):
        # Configure optimizer and learning rate scheduler
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.hparams.weight_decay,
                "betas": (0.9, 0.999),
                "eps": self.hparams.adam_eps,
            },
            {
                "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
                "betas": (0.9, 0.999),
                "eps": self.hparams.adam_eps,
            },
        ]

        optimizer = torch.optim.AdamW(
            optimizer_grouped_parameters,
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )

        scheduler = transformers.get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.hparams.warmup_updates,
            num_training_steps=self.hparams.total_num_updates,
        )

        lr_dict = {
            "scheduler": scheduler,
            "interval": "step",
            "frequency": 1,
            "name": f"LearningRateLinearScheduler",
        }

        return [optimizer], [lr_dict]

In [None]:
def get_args() -> argparse.Namespace:
    """
    Parse command line arguments and return them as an `argparse.Namespace` object.

    Returns:
        argparse.Namespace: Parsed command line arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=100, help='Number of classes')
    parser.add_argument('--savedir', type=str, default='./checkpoints', help='Save checkpoint directory')
    parser.add_argument('--model_name', type=str, default='googlenet', help='Model name from torchvision')
    parser.add_argument('--weight_name', type=str, default=None, help='Model weights name from torchvision')
    parser.add_argument('--num_sanity_val_steps', type=int, default=2, help='Number of sanity validation steps')
    parser.add_argument('--max_steps', type=int, default=1000, help='Maximum training steps')
    parser.add_argument('--batch-size', type=int, default=32, help='Batch size')
    parser.add_argument('--accelerator', type=str, default="gpu", help='Accelerator (e.g., gpu or tpu)')
    parser.add_argument('--lr', type=float, default=1e-05, help='Learning rate')
    parser.add_argument('--weight_decay', type=float, default=0.01, help='Weight decay')
    parser.add_argument('--adam_eps', type=float, default=1.0e-08, help='Adam optimizer epsilon')
    parser.add_argument('--warmup_updates', type=int, default=100, help='Number of warmup updates')
    parser.add_argument('--total_num_updates', type=int, default=1000, help='Total number of updates')
    parser.add_argument('--accumulate_grad_batches', type=int, default=8, help='Number of gradient accumulation batches')
    group = parser.add_mutually_exclusive_group()
    group.add_argument('--gpu_ids', type=int, default=None, nargs='+', help='List of GPU IDs')
    group.add_argument('--gpus', type=int, default=1, help='Number of GPUs to use')
    args, unknown = parser.parse_known_args()
    return args


In [None]:
def get_gpu_settings(gpu_ids: List[int] | None, gpus: int | None) -> Tuple[str, Union[int, List[int], None], Union[str, None]]:
    """
    Determine GPU settings based on provided GPU IDs and the number of GPUs to use.

    :param gpu_ids: List of GPU IDs to use, or None if not specified.
    :param gpus: Number of GPUs to use, or None if not specified.
    :return: A tuple containing:
        - A string indicating the device type ('gpu' or 'cpu').
        - An integer or list of integers representing the selected GPU(s) or None for CPU.
        - A string indicating the strategy ('gpu', 'ddp', or None).
    """
    if not torch.cuda.is_available():
        return "cpu", -1, "cpu"

    if gpu_ids is not None:
        gpus = gpu_ids
        strategy = "ddp" if len(gpu_ids) > 1 else "gpu"
    elif gpus is not None:
        gpus = gpus
        strategy = "ddp" if gpus > 1 else "gpu"
    else:
        gpus = 1
        strategy = "gpu"

    return "gpu", gpus, strategy

In [71]:
# Parse command line arguments
args = get_args()

# Determine GPU settings
device_type, gpus, strategy = get_gpu_settings(gpu_ids=args.gpu_ids, gpus=args.gpus)

# Ensure total_num_updates matches max_steps
if args.total_num_updates != args.max_steps:
    args.total_num_updates = args.max_steps

# Initialize DataModule
data_module = DataModule(batch_size=args.batch_size, transform_custom=True)

# Initialize ImageModel with arguments from parsed command line arguments
model = ImageModel(**vars(args))

# Initialize WandbLogger
wandb_logger = WandbLogger(project='wandb-lightning', job_type='train')

# Initialize callbacks for Lightning Trainer
early_stop_callback = EarlyStopping(monitor="val_f1", mode='max', patience=5, verbose=True)
lr_monitor = LearningRateMonitor(logging_interval='step')
checkpoint_callback = ModelCheckpoint(
    monitor="val_f1", mode="max", save_top_k=2, save_last=True,
    dirpath="./checkpoints", filename="model-epoch_{epoch:03d}-val_loss_{val_loss:.2f}-val_f1_{val_f1:.2f}",
    auto_insert_metric_name=False, save_on_train_epoch_end=False
)


Using cache found in /root/.cache/torch/hub/pytorch_vision_main


set weights to GoogLeNet_Weights.IMAGENET1K_V1


Using cache found in /root/.cache/torch/hub/pytorch_vision_main
  rank_zero_warn(


In [42]:
# Initialize Lightning Trainer
trainer = pl.Trainer(
    max_steps=args.max_steps,
    default_root_dir=args.savedir,
    devices=gpus,
    accelerator=strategy,
    num_sanity_val_steps=args.num_sanity_val_steps,
    accumulate_grad_batches=args.accumulate_grad_batches,
    deterministic=True,
    logger=wandb_logger,
    callbacks=[early_stop_callback, checkpoint_callback, lr_monitor],
)


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [None]:
trainer.fit(model, data_module)
trainer.test(dataloaders=data_module.test_dataloader())
wandb.finish()

Map:   0%|          | 0/40000 [00:00<?, ? examples/s]

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Map:   0%|          | 0/40000 [00:00<?, ? examples/s]

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name       | Type                | Params
---------------------------------------------------
0 | model      | GoogLeNet           | 5.7 M 
1 | preprocess | ImageClassification | 0     
2 | aucroc     | MulticlassAUROC     | 0     
3 | accuracy   | MulticlassAccuracy  | 0     
4 | f1         | MulticlassF1Score   | 0     
---------------------------------------------------
5.7 M     Trainable params
0         Non-trainable params
5.7 M     Total params
22.810    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_f1 improved. New best score: 0.045


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_f1 improved by 0.121 >= min_delta = 0.0. New best score: 0.166
