<a href="https://colab.research.google.com/github/pmojiri/HPA_image_classifier/blob/main/HPA_Image_classifier.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# Human Protein Atlas Image classification
Classify subcellular protein patterns in human cells

This is a project for the [Human Protein Atlas Image Classification](https://www.kaggle.com/c/human-protein-atlas-image-classification) competition on Kaggle.

![sample](sample.png)


In [2]:
%rm -rf datasets/human-protein-atlas-image-classification.zip sample_data

In [3]:
!nvidia-smi

Tue Nov  7 12:44:34 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   43C    P8    10W /  70W |      0MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [5]:
# install some required packages
# !pip install fastapi==0.89.0 torchinfo lightning torchmetrics
# !pip install requests opencv-contrib-python tqdm scikit-learn seaborn

Collecting fastapi==0.89.0
  Downloading fastapi-0.89.0-py3-none-any.whl (55 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/55.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m55.6/55.6 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Collecting lightning
  Downloading lightning-2.1.0-py3-none-any.whl (2.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m17.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torchmetrics
  Downloading torchmetrics-1.2.0-py3-none-any.whl (805 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m805.2/805.2 kB[0m [31m27.4 MB/s[0m eta [36m0:00:00[0m
Collecting starlette==0.22.0 (from fastapi==0.89.0)
  Downloading starlette-0.22.0-py3-none-any.whl (64 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m64.3/64.3 kB[0m [31m8

In [6]:
import os
ROOT_PATH = os.path.join(os.getcwd())

In [7]:
LABELS = {
0:  'Nucleoplasm',
1:  'Nuclear membrane',
2:  'Nucleoli',
3:  'Nucleoli fibrillar center',
4:  'Nuclear speckles',
5:  'Nuclear bodies',
6:  'Endoplasmic reticulum',
7:  'Golgi apparatus',
8:  'Peroxisomes',
9:  'Endosomes',
10:  'Lysosomes',
11:  'Intermediate filaments',
12:  'Actin filaments',
13:  'Focal adhesion sites',
14:  'Microtubules',
15:  'Microtubule ends',
16:  'Cytokinetic bridge',
17:  'Mitotic spindle',
18:  'Microtubule organizing center',
19:  'Centrosome',
20:  'Lipid droplets',
21:  'Plasma membrane',
22:  'Cell junctions',
23:  'Mitochondria',
24:  'Aggresome',
25:  'Cytosol',
26:  'Cytoplasmic bodies',
27:  'Rods & rings' }

LABELS_REVERSE = {item: key for key, item in LABELS.items()}

In [9]:
from dataclasses import dataclass
import platform

@dataclass
class LogsConfig:
    """This class contains all the logging configuration parameters"""
    VIZ_DIR: str = os.path.join(ROOT_PATH, "logs")


@dataclass
class DatasetConfig:
    """This class contains all the dataset configuration parameters"""
    TRAIN_IMG_DIR: str = os.path.join(ROOT_PATH, "datasets", "train")
    TEST_IMG_DIR: str = os.path.join(ROOT_PATH, "datasets", "test")
    TRAIN_CSV: str = os.path.join(ROOT_PATH, "datasets", "train.csv")
    TEST_CSV: str = os.path.join(ROOT_PATH, "datasets", "sample_submission.csv")

    IMAGE_SIZE: tuple = (512, 512)
    CHANNELS: int = 3
    NUM_CLASSES: int = 28
    VALID_PCT: float = 0.2

    MEAN: tuple = (0.485, 0.456, 0.406)
    STD: tuple = (0.229, 0.224, 0.225)


@dataclass
class ModelTrainingConfig:
    """This class contains all the Model training configuration parameters"""
    MODEL_NAME: str = "resnet18"
    FREEZE_BACKBONE: bool = False
    BATCH_SIZE: int = 32
    NUM_EPOCHS: int = 4
    INIT_LR: float = 1e-4
    NUM_WORKERS: int = 0 if platform.system() == "Windows" else os.cpu_count()
    OPTIMIZER_NAME: str = "Adam"
    WEIGHT_DECAY: float = 1e-4
    USE_SCHEDULER: bool = True  # Use learning rate scheduler?
    SCHEDULER: str = "multi_step_lr"  # Name of the scheduler to use.
    METRIC_THRESH: float = 0.4
    LOSS_FN: str = "BCEWithLogitsLoss"
    MODEL_CKPT_PATH: str = os.path.join(ROOT_PATH, "logs", "default", "version_0", "checkpoints", "epoch=29.ckpt")


In [11]:
import torch

def encode_label(
        labels: list,
        num_classes: int = 28
) -> torch.Tensor:
    """This function converts the labels into one-hot encoded vectors"""
    target = torch.zeros(num_classes)
    for label in str(labels).split(" "):
        target[int(label)] = 1.0
    return target


def decode_target(
        target: list,
        text_labels: bool = False,
        threshold: float = 0.4,
        cls_labels: dict = None,
) -> str:
    """This function converts the labels from
    probabilities to outputs or string representations
    """

    result = []
    for i, x in enumerate(target):
        if x >= threshold:
            if text_labels:
                result.append(cls_labels[i] + "(" + str(i) + ")")
            else:
                result.append(str(i))
    return " ".join(result)

In [12]:
"""
Dataset for Human Protein Atlas Image Classification Challenge
"""
import os
import numpy as np
from PIL import Image
import pandas as pd
import torchvision

from torch.utils.data import Dataset

class HPAImageDataset(Dataset):
    """
    Dataset for Human Protein Atlas Image Classification Challenge
    """

    def __init__(self,
                 df_data: pd.DataFrame,
                 root_dir: str,
                 img_size: tuple,
                 transforms: torchvision.transforms = None,
                 is_test: bool = False
                 ):
        self.df_data = df_data
        self.root_dir = root_dir
        self.img_size = img_size
        self.transforms = transforms
        self.is_test = is_test

    def load_image(self,
                   file_name: str
                   ) -> Image:
        """
        Load four channels of an image and merge them into one image
        :param file_name: image file name
        :return: image in RGB format
        """
        R = np.array(Image.open(file_name + '_red.png'))
        G = np.array(Image.open(file_name + '_green.png'))
        B = np.array(Image.open(file_name + '_blue.png'))
        Y = np.array(Image.open(file_name + '_yellow.png'))
        image = np.stack((R + Y / 2, G + Y / 2, B), -1)

        return Image.fromarray(image.astype('uint8'), 'RGB')

    def __len__(self):
        """ Returns the length of the dataset """
        return len(self.df_data)

    def __getitem__(self,
                    idx: int
                    ) -> tuple:
        row = self.df_data.loc[idx]
        img_id = row["Id"]

        img = self.load_image(self.root_dir + os.sep + str(img_id))
        img = img.resize(self.img_size, resample=3)
        img = self.transforms(img)

        if self.is_test:
            return img, img_id

        return img, encode_label(row["Target"])

In [22]:
"""
DataModule for the Human Protein Atlas Image Classification competition.
"""
import os
import shutil
from itertools import chain

import numpy as np
import pandas as pd
import subprocess
import logging
import warnings
from sklearn.utils.class_weight import compute_class_weight

from torch.utils.data import DataLoader

import torch
import torchvision.transforms as TF
import lightning.pytorch as pl

torch.set_float32_matmul_precision('high')

os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"

# Filter userwarnings.
warnings.filterwarnings(action="ignore", category=UserWarning)


class HPADataModule(pl.LightningDataModule):
    """
    DataModule for the Human Protein Atlas Image Classification competition.
    """

    def __init__(
            self,
            num_classes: int = 28,
            valid_percentage: float = 0.2,
            resize_to: tuple = (512, 512),
            batch_size: int = 32,
            num_workers: int = 0,
            pin_memory: bool = False,
            shuffle_validation: bool = False,
    ):
        super().__init__()

        self.num_classes = num_classes
        self.valid_percentage = valid_percentage
        self.resize_to = resize_to
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.shuffle_validation = shuffle_validation

        self.test_ds = None
        self.valid_ds = None
        self.train_ds = None
        self.class_weights = None

        self.train_transforms = TF.Compose(
            [
                TF.RandomHorizontalFlip(),
                TF.RandomVerticalFlip(),
                TF.RandomAffine(
                    degrees=30,
                    translate=(0.01, 0.12),
                    shear=0.05,
                ),
                # TF.ColorJitter(
                #     brightness=0.1,
                #     contrast=0.1,
                #     saturation=0.1,
                #     hue=0.1,
                # ),
                TF.ToTensor(),
                TF.Normalize(DatasetConfig.MEAN, DatasetConfig.STD, inplace=True),
                TF.RandomErasing(inplace=True),
            ]
        )

        self.valid_transforms = TF.Compose(
            [
                TF.ToTensor(),
                TF.Normalize(DatasetConfig.MEAN, DatasetConfig.STD, inplace=True),
            ]
        )
        self.test_transforms = self.valid_transforms

    def prepare_data(self):
        """ Download data if needed from Kaggle """
        if not os.path.exists(os.path.join(DatasetConfig.TRAIN_CSV)):
            KAGGLE_DIR = os.path.join(os.path.expanduser("~"), ".kaggle")
            KAGGLE_JSON_PATH = os.path.join(KAGGLE_DIR, "kaggle.json")

            if not os.path.exists(KAGGLE_JSON_PATH):
                os.makedirs(KAGGLE_DIR, exist_ok=True)
                shutil.copyfile("kaggle.json", KAGGLE_JSON_PATH)
                os.chmod(KAGGLE_JSON_PATH, 0o600)

            logging.info("Downloading data from Kaggle...")
            !kaggle competitions download -c human-protein-atlas-image-classification -p datasets
            !unzip datasets/human-protein-atlas-image-classification.zip -d datasets

    def setup(self, stage=None):
        """ Split data into train, validation and test sets """
        np.random.seed(42)
        data_df = pd.read_csv(DatasetConfig.TRAIN_CSV)
        msk = np.random.rand(len(data_df)) < (1.0 - self.valid_percentage)
        train_df = data_df[msk].reset_index()
        valid_df = data_df[~msk].reset_index()

        train_labels = list(chain.from_iterable([i.strip().split(" ") for i in train_df["Target"].values]))
        class_weights = compute_class_weight("balanced", classes=list(range(self.num_classes)),
                                             y=[int(i) for i in train_labels])
        self.class_weights = torch.tensor(class_weights)

        img_size = DatasetConfig.IMAGE_SIZE
        self.train_ds = HPAImageDataset(
            df_data=train_df, img_size=img_size, root_dir=DatasetConfig.TRAIN_IMG_DIR, transforms=self.train_transforms
        )

        self.valid_ds = HPAImageDataset(
            df_data=valid_df, img_size=img_size, root_dir=DatasetConfig.TRAIN_IMG_DIR, transforms=self.valid_transforms
        )

        test_df = pd.read_csv(DatasetConfig.TEST_CSV)
        self.test_ds = HPAImageDataset(
            df_data=test_df, img_size=img_size, root_dir=DatasetConfig.TEST_IMG_DIR, transforms=self.test_transforms, is_test=True
        )

        logging.info(
            f"Number of images :: "
            f"Training: {len(self.train_ds)}, "
            f"Validation: {len(self.valid_ds)}, "
            f"Testing: {len(self.test_ds)}\n"
        )

    def train_dataloader(self):
        """ Create training dataloader object """
        train_loader = DataLoader(
            self.train_ds, batch_size=self.batch_size, pin_memory=self.pin_memory, shuffle=True,
            num_workers=self.num_workers
        )
        return train_loader

    def val_dataloader(self):
        """ Create validation dataloader object."""
        valid_loader = DataLoader(
            self.valid_ds, batch_size=self.batch_size, pin_memory=self.pin_memory,
            shuffle=self.shuffle_validation, num_workers=self.num_workers
        )
        return valid_loader

    def test_dataloader(self):
        """Create test dataloader object."""
        test_loader = DataLoader(
            self.test_ds, batch_size=self.batch_size, pin_memory=self.pin_memory, shuffle=False,
            num_workers=self.num_workers
        )
        return test_loader

In [23]:
data_module = HPADataModule(
    num_classes=DatasetConfig.NUM_CLASSES,
    batch_size=32,
    num_workers=0
)

# Download dataset.
data_module.prepare_data()

# Split dataset into training, validation set.
data_module.setup()

In [24]:
"""
This module contains a helper function to load and prepare any classification
"""
import torch
import torch.nn as nn
import torchvision
import logging


def get_model(
        model_name: str,
        num_classes: int,
        freeze_backbone: bool = True
) -> torch.nn.Module:
    """
    A helper function to load and prepare any classification model
    available in Torchvision for transfer learning or fine-tuning.

    Args:
        model_name: Name of the model to be loaded.
        num_classes: Number of classes in the dataset.
        freeze_backbone: Whether to freeze the backbone of the model or not
            for fine-tuning.


    Returns:
        model: A PyTorch model with the output layer replaced with a new
            output layer with `num_classes` number of output nodes.

    """

    logging.info(f"Loading model: {model_name}")
    model = getattr(torchvision.models, model_name)(weights="DEFAULT")

    if freeze_backbone:
        # Set all layer to be non-trainable
        for param in model.parameters():
            param.requires_grad = False

    model_childs = [name for name, _ in model.named_children()]

    try:
        final_layer_in_features = getattr(model, f"{model_childs[-1]}")[-1].in_features
    except Exception as e:
        final_layer_in_features = getattr(model, f"{model_childs[-1]}").in_features

    new_output_layer = nn.Linear(
        in_features=final_layer_in_features,
        out_features=num_classes
    )

    try:
        getattr(model, f"{model_childs[-1]}")[-1] = new_output_layer
    except:
        setattr(model, model_childs[-1], new_output_layer)

    return model

In [25]:
from torchinfo import summary

model = get_model(
    model_name=ModelTrainingConfig.MODEL_NAME,
    num_classes=DatasetConfig.NUM_CLASSES,
    freeze_backbone=False,
)

model_info = summary(
    model,
    input_size=(1, DatasetConfig.CHANNELS, *DatasetConfig.IMAGE_SIZE[::-1]),
    depth=2,
    device="cpu",
    col_names=["output_size", "num_params", "trainable"]
)
model_info

Layer (type:depth-idx)                   Output Shape              Param #                   Trainable
ResNet                                   [1, 28]                   --                        True
├─Conv2d: 1-1                            [1, 64, 256, 256]         9,408                     True
├─BatchNorm2d: 1-2                       [1, 64, 256, 256]         128                       True
├─ReLU: 1-3                              [1, 64, 256, 256]         --                        --
├─MaxPool2d: 1-4                         [1, 64, 128, 128]         --                        --
├─Sequential: 1-5                        [1, 64, 128, 128]         --                        True
│    └─BasicBlock: 2-1                   [1, 64, 128, 128]         73,984                    True
│    └─BasicBlock: 2-2                   [1, 64, 128, 128]         73,984                    True
├─Sequential: 1-6                        [1, 128, 64, 64]          --                        True
│    └─BasicBlock: 

In [26]:
"""
This module contains the model class that is used for training and inference.
"""
import torch
import torch.nn as nn
import logging

import lightning.pytorch as pl

from torchmetrics import MeanMetric
from torchmetrics.classification import MultilabelF1Score

class HPAModel(pl.LightningModule):
    """
    A PyTorch Lightning Module for training and inference.
    """

    def __init__(
            self,
            model_name: str,
            num_classes: int = 28,
            freeze_backbone: bool = False,
            init_lr: float = 0.001,
            optimizer_name: str = "Adam",
            weight_decay: float = 1e-4,
            use_scheduler: bool = False,
            f1_metric_threshold: float = 0.4,
    ):
        super().__init__()

        # Save the arguments as hyperparameters.
        self.save_hyperparameters()

        # Loading model using the function defined above.
        self.model = get_model(
            model_name=self.hparams.model_name,
            num_classes=self.hparams.num_classes,
            freeze_backbone=self.hparams.freeze_backbone,
        )

        # Initialize loss class.
        self.loss_fn = nn.BCEWithLogitsLoss()

        # Initializing the required metric objects.
        self.mean_train_loss = MeanMetric()
        self.mean_train_f1score = MultilabelF1Score(num_labels=self.hparams.num_classes,
                                                    average="macro",
                                                    threshold=self.hparams.f1_metric_threshold)
        self.mean_valid_loss = MeanMetric()
        self.mean_valid_f1score = MultilabelF1Score(num_labels=self.hparams.num_classes,
                                                    average="macro",
                                                    threshold=self.hparams.f1_metric_threshold)

    def forward(self,
                x: torch.Tensor
                ) -> torch.Tensor:
        """ Forward pass of the model. """
        return self.model(x)

    def training_step(self,
                      batch,
                      *args,
                      **kwargs
                      ):
        """ Training step. """
        data, target = batch
        logits = self(data)
        loss = self.loss_fn(logits, target)

        self.mean_train_loss(loss, weight=data.shape[0])
        self.mean_train_f1score(logits, target)

        self.log("train/batch_loss", self.mean_train_loss, prog_bar=True)
        self.log("train/batch_f1score", self.mean_train_f1score, prog_bar=True)
        return loss

    def on_train_epoch_end(self):
        """
        Training epoch end.
        Computing and logging the training mean loss & mean f1.
        """
        self.log("train/loss", self.mean_train_loss, prog_bar=True)
        self.log("train/f1score", self.mean_train_f1score, prog_bar=True)
        self.log("step", self.current_epoch)

    def validation_step(self,
                        batch,
                        *args,
                        **kwargs
                        ):
        """ Validation step. """
        data, target = batch
        logits = self(data)
        loss = self.loss_fn(logits, target)

        self.mean_valid_loss.update(loss, weight=data.shape[0])
        self.mean_valid_f1score.update(logits, target)

    def on_validation_epoch_end(self):
        """
        validation epoch end.
        Computing and logging the validation mean loss & mean f1.
        """
        self.log("valid/loss", self.mean_valid_loss, prog_bar=True)
        self.log("valid/f1score", self.mean_valid_f1score, prog_bar=True)
        self.log("step", self.current_epoch)

    def configure_optimizers(self):
        """ Configuring the optimizer and the learning rate scheduler. """
        optimizer = getattr(torch.optim, self.hparams.optimizer_name)(
            filter(lambda p: p.requires_grad, self.model.parameters()),
            lr=self.hparams.init_lr,
            weight_decay=self.hparams.weight_decay,
        )

        if self.hparams.use_scheduler:
            lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer,
                milestones=[self.trainer.max_epochs // 2, ],
                gamma=0.1,
            )

            # The lr_scheduler_config is a dictionary that contains the scheduler
            # and its associated configuration.
            lr_scheduler_config = {
                "scheduler": lr_scheduler,
                "interval": "epoch",
                "name": "multi_step_lr",
            }
            return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}

        else:
            return optimizer

In [27]:
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint

# Seed everything for reproducibility.
pl.seed_everything(42, workers=True)

# Initialize model.
model = HPAModel(
    model_name=ModelTrainingConfig.MODEL_NAME,
    num_classes=DatasetConfig.NUM_CLASSES,
    freeze_backbone=ModelTrainingConfig.FREEZE_BACKBONE,
    init_lr=ModelTrainingConfig.INIT_LR,
    optimizer_name=ModelTrainingConfig.OPTIMIZER_NAME,
    weight_decay=ModelTrainingConfig.WEIGHT_DECAY,
    use_scheduler=ModelTrainingConfig.USE_SCHEDULER,
    f1_metric_threshold=ModelTrainingConfig.METRIC_THRESH,
)

# Initialize data module.
data_module = HPADataModule(
    num_classes=DatasetConfig.NUM_CLASSES,
    valid_percentage=DatasetConfig.VALID_PCT,
    resize_to=DatasetConfig.IMAGE_SIZE,
    batch_size=ModelTrainingConfig.BATCH_SIZE,
    num_workers=ModelTrainingConfig.NUM_WORKERS,
    pin_memory=torch.cuda.is_available(),
)

# Creating ModelCheckpoint callback.
# Checkpoints by default will be saved in Trainer - default_root_dir which is "lightning_logs".
model_checkpoint = ModelCheckpoint(
    monitor="valid/f1score",
    mode="max",
    filename="ckpt_{epoch:03d}-vloss_{valid/loss:.4f}_vf1score_{valid/f1score:.4f}",
    auto_insert_metric_name=False,
)

# Creating a learning rate monitor callback which will be plotted/added in the default logger.
lr_rate_monitor = LearningRateMonitor(logging_interval="epoch")

INFO: Seed set to 42
INFO:lightning.fabric.utilities.seed:Seed set to 42


In [28]:
# Initializing the Trainer class object.
trainer = pl.Trainer(
    accelerator="auto",  # Auto select accelerator (GPU, TPU, CPU).
    devices="auto",  # Auto select devices.
    strategy="auto",  # Auto select distributed backend.
    max_epochs=ModelTrainingConfig.NUM_EPOCHS,  # Number of epochs to train for.
    deterministic=True,  # Keep everything deterministic for reproducibility.
    enable_model_summary=False,  # Disable model summary.
    callbacks=[model_checkpoint, lr_rate_monitor],  # Add callbacks.
    precision="16",  # Use mixed precision training.
    logger=True,  # Use TensorBoard logger.
)

# Start training
trainer.fit(model, data_module)

INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO:lightning.pytorch.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO: `Trainer.fit` stopped: `max_epochs=4` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=4` reached.


In [29]:
# Get path of the best saved model.
model_path = model_checkpoint.best_model_path

In [30]:
model = HPAModel.load_from_checkpoint(model_path)

In [31]:
# Initialize trainer class for inference.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    enable_checkpointing=False,
    inference_mode=True,
)

# Run evaluation.
data_module.setup()
valid_loader = data_module.val_dataloader()
trainer.validate(model=model, dataloaders=valid_loader)

INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation: |          | 0/? [00:00<?, ?it/s]

[{'valid/loss': 0.0947328433394432,
  'valid/f1score': 0.4195646047592163,
  'step': 0.0}]