### Imports

In [2]:
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from argparse import ArgumentParser
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

import pytorch_lightning as pl
from transformers import ViTForImageClassification, AdamW
import torch.nn as nn

### vitlightning.py

In [3]:
class ViTLightningModule(pl.LightningModule):
    """LightningModule wrapping a Hugging Face ``ViTForImageClassification``.

    The module keeps its own DataLoaders (passed in as a dict) and serves
    them through the ``*_dataloader`` hooks, so the Trainer only needs the
    module itself.

    Args:
        model_name: Checkpoint name or path given to ``from_pretrained``.
        num_labels: Number of classes for the classification head.
        id2label: Mapping from class index to class name (config override).
        label2id: Inverse of ``id2label`` (config override).
        dataloaders: Dict with ``"train"``, ``"val"`` and ``"test"`` DataLoaders.
    """

    def __init__(self, model_name, num_labels, id2label, label2id, dataloaders):
        super().__init__()
        # Config overrides must be passed as keyword arguments: positional
        # arguments after the model name are forwarded to the model
        # constructor (``*model_args``), not used to update the config.
        self.vit = ViTForImageClassification.from_pretrained(
            model_name,
            num_labels=num_labels,
            id2label=id2label,
            label2id=label2id,
        )
        self.dataloaders = dataloaders

    def forward(self, pixel_values):
        """Run the ViT model and return its classification logits."""
        outputs = self.vit(pixel_values=pixel_values)
        return outputs.logits

    def common_step(self, batch, batch_idx):
        """Compute cross-entropy loss and accuracy for one batch.

        ``batch`` is a dict with ``"pixel_values"`` and ``"labels"`` entries,
        as built by ``collate_fn``. Returns ``(loss, accuracy)`` where
        ``loss`` is a scalar tensor and ``accuracy`` is a Python float.
        """
        pixel_values = batch["pixel_values"]
        labels = batch["labels"]
        logits = self(pixel_values)

        loss = nn.functional.cross_entropy(logits, labels)
        predictions = logits.argmax(-1)
        correct = (predictions == labels).sum().item()
        accuracy = correct / pixel_values.shape[0]

        return loss, accuracy

    def training_step(self, batch, batch_idx):
        """Log training loss/accuracy and return the loss for backprop."""
        loss, accuracy = self.common_step(batch, batch_idx)
        self.log("training_loss", loss)
        self.log("training_accuracy", accuracy)

        return loss

    def validation_step(self, batch, batch_idx):
        """Log epoch-level validation metrics (``validation_loss`` is what early stopping should monitor)."""
        loss, accuracy = self.common_step(batch, batch_idx)
        self.log("validation_loss", loss, on_epoch=True)
        self.log("validation_accuracy", accuracy, on_epoch=True)

        return loss

    def test_step(self, batch, batch_idx):
        """Log test metrics so ``Trainer.test`` actually reports results."""
        loss, accuracy = self.common_step(batch, batch_idx)
        self.log("test_loss", loss, on_epoch=True)
        self.log("test_accuracy", accuracy, on_epoch=True)

        return loss

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters.

        Uses ``torch.optim.AdamW`` (the transformers implementation is
        deprecated); ``weight_decay=0.0`` keeps the transformers default,
        since torch's own default is 0.01.
        """
        return torch.optim.AdamW(self.parameters(), lr=5e-5, weight_decay=0.0)

    def train_dataloader(self):
        return self.dataloaders["train"]

    def val_dataloader(self):
        return self.dataloaders["val"]

    def test_dataloader(self):
        return self.dataloaders["test"]

### train_vit_clf.py

In [None]:
def create_dataloaders_and_mappings(data_path, *, seed=None, batch_size=1, transform=None):
    """Load an image-folder dataset and wrap its splits in DataLoaders.

    The ``train`` split found under *data_path* is split 90/10 into new
    ``train`` and ``val`` splits; the ``test`` split is used as-is.

    Args:
        data_path: Directory passed to ``load_dataset("imagefolder", data_dir=...)``.
        seed: Optional seed for the train/val split. ``None`` keeps the
            previous unseeded (non-reproducible) behaviour.
        batch_size: Batch size for all three DataLoaders. The default of 1
            matches the ``DataLoader`` default that was used before.
        transform: Optional callable attached to every split via
            ``DatasetDict.with_transform``. Use it to produce the
            ``pixel_values`` column that ``collate_fn`` reads.

    Returns:
        ``(dataloaders, id2label, label2id)``. ``dataloaders`` maps ``"train"``,
        ``"val"`` and ``"test"`` to DataLoaders. The two dicts map between
        integer class ids and class names.

    Raises:
        KeyError: if the loaded dataset has no ``"train"`` or ``"test"`` split.
    """
    # Bug fix: use the parameter; the original read an undefined global ``args``.
    dataset = load_dataset("imagefolder", data_dir=data_path)
    splits = dataset["train"].train_test_split(test_size=0.1, seed=seed)
    dataset["train"] = splits["train"]
    dataset["val"] = splits["test"]

    # Class names come from the "label" ClassLabel feature of the train split.
    class_names = dataset["train"].features["label"].names
    id2label = dict(enumerate(class_names))
    label2id = {label: idx for idx, label in id2label.items()}

    # NOTE(review): collate_fn reads example["pixel_values"]; presumably the raw
    # imagefolder examples do not contain that column, so a transform is needed
    # unless one is attached elsewhere — confirm.
    if transform is not None:
        dataset = dataset.with_transform(transform)

    dataloaders = {
        "train": DataLoader(
            dataset["train"], batch_size=batch_size, shuffle=True, collate_fn=collate_fn
        ),
        "val": DataLoader(dataset["val"], batch_size=batch_size, collate_fn=collate_fn),
        "test": DataLoader(dataset["test"], batch_size=batch_size, collate_fn=collate_fn),
    }

    return dataloaders, id2label, label2id


def collate_fn(examples):
    """Merge a list of examples into one batch dict.

    Each example must provide a ``"pixel_values"`` tensor and an integer
    ``"label"``. The result holds the stacked image tensors under
    ``"pixel_values"`` and a tensor of labels under ``"labels"``.
    """
    images = [ex["pixel_values"] for ex in examples]
    stacked = torch.stack(images)
    targets = torch.tensor([ex["label"] for ex in examples])
    batch = {"pixel_values": stacked, "labels": targets}
    return batch

def main(colab_data_path="", colab_dir="", model_id=""):
    """Fine-tune a ViT classifier with PyTorch Lightning, then evaluate it.

    Args:
        colab_data_path: Image-folder dataset directory (placeholder default;
            fill in before running).
        colab_dir: Trainer ``default_root_dir`` for logs and checkpoints.
        model_id: Pretrained checkpoint name/path for ``from_pretrained``.
    """
    # Bug fix: the data path was defined but never passed on.
    dataloaders, id2label, label2id = create_dataloaders_and_mappings(colab_data_path)
    num_labels = len(id2label)

    # The monitored name must match the metric logged in
    # ViTLightningModule.validation_step ("validation_loss"). strict=True
    # makes a mismatch raise an error instead of silently disabling early stopping.
    early_stop_callback = EarlyStopping(
        monitor="validation_loss", patience=3, strict=True, verbose=False, mode="min"
    )

    model = ViTLightningModule(model_id, num_labels, id2label, label2id, dataloaders)

    # NOTE(review): ViTLightningModule does not call save_hyperparameters(), so
    # load_from_checkpoint would also need its constructor arguments passed in.
    # model = ViTLightningModule.load_from_checkpoint("/path/to/checkpoint.ckpt")

    # `gpus=` was removed in Lightning 2.x; accelerator/devices is the current API.
    trainer = Trainer(
        accelerator="gpu",
        devices=1,
        callbacks=[early_stop_callback],
        default_root_dir=colab_dir,
    )

    # Both calls need the model; without it they raise TypeError.
    trainer.fit(model)
    trainer.test(model)

### Experiment

In [None]:
# Run the full pipeline defined in main(): build dataloaders, construct the model, fit, then test.
main()