In [1]:
from datasets import load_dataset

# Load the train / validation / test splits of the `yerevann/coco-karpathy`
# dataset from the Hugging Face Hub (one Dataset object per requested split).
train_ds, val_ds, test_ds = load_dataset("yerevann/coco-karpathy", split=['train', 'validation', 'test'])

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Inspect the first training record to see which fields each example carries.
train_ds[0]

{'filepath': 'train2014',
 'sentids': [787980, 789366, 789888, 791316, 794853],
 'filename': 'COCO_train2014_000000057870.jpg',
 'imgid': 40504,
 'split': 'train',
 'sentences': ['A restaurant has modern wooden tables and chairs.',
  'A long restaurant table with rattan rounded back chairs.',
  'a long table with a plant on top of it surrounded with wooden chairs ',
  'A long table with a flower arrangement in the middle for meetings',
  'A table is adorned with wooden chairs with blue accents.'],
 'cocoid': 57870,
 'url': 'http://images.cocodataset.org/train2014/COCO_train2014_000000057870.jpg'}

In [11]:
import os
from PIL import Image

import torch
from torch.utils.data import Dataset

class CocoDataset(Dataset):
    """Map-style dataset that pairs an image on disk with its first caption.

    Every record of ``dataset`` must provide a ``'filename'`` entry (a string,
    or a list/tuple whose first element is used) and a ``'sentences'`` sequence
    of captions.

    Args:
        dataset: Indexable collection of records supporting ``len()`` and
            ``dataset[idx]``.
        image_dir: Directory the record's filename is joined onto.
        processor: Optional callable invoked as
            ``processor(images=image, return_tensors="pt")``; the
            ``pixel_values`` attribute of its result (minus the leading batch
            axis) replaces the PIL image.
        transform: Optional callable applied last, to the PIL image or to the
            processor output when a processor is also supplied.
    """

    def __init__(self, dataset, image_dir, processor=None, transform=None):
        self.dataset = dataset
        self.image_dir = image_dir
        self.processor = processor
        self.transform = transform

    def __len__(self):
        """Return the number of records in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(img, caption)`` for the record at position ``idx``.

        Raises:
            ValueError: If the record has no captions. A ``ValueError`` is used
                instead of letting ``captions[0]`` raise ``IndexError``, which
                the legacy sequence-iteration protocol would interpret as
                "end of dataset" and silently stop iterating.
        """
        sample = self.dataset[idx]

        # 'filename' may be a single name or a collection; use the first entry.
        img_filename = sample['filename']
        if isinstance(img_filename, (list, tuple)):
            img_filename = img_filename[0]
        img_path = os.path.join(self.image_dir, img_filename)

        # Close the underlying file as soon as the RGB copy has been made.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert('RGB')

        if self.processor is not None:
            pixel_values = self.processor(images=image, return_tensors="pt").pixel_values
            # Drop only the batch axis; a bare squeeze() would also remove any
            # other axis that happens to have size 1.
            img = pixel_values.squeeze(0)
        else:
            img = image

        if self.transform is not None:
            img = self.transform(img)

        captions = sample['sentences']
        if len(captions) == 0:
            raise ValueError(f"record {idx!r} has no captions")

        # Only the first caption is used.
        return img, captions[0]

In [6]:
from transformers import AutoImageProcessor

# Pretrained ViT image processor; its normalisation statistics and target
# resolution are reused by the torchvision pipelines defined later on.
processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
image_mean, image_std = processor.image_mean, processor.image_std
size = processor.size["height"]

In [22]:
# Load one training image from the local COCO folder as RGB and run it through
# the image processor to inspect the resulting pixel values.
# NOTE(review): the image directory is a hard-coded absolute path — consider
# reading it from a configurable data-directory variable instead.
image = Image.open(os.path.join('/home/data/COCOcaptions/train2014', train_ds[420]['filename'])).convert('RGB')
processor(image)['pixel_values']

[array([[[-0.5921569 , -0.5764706 , -0.5686275 , ..., -0.8039216 ,
          -0.8117647 , -0.8039216 ],
         [-0.5921569 , -0.58431375, -0.5921569 , ..., -0.73333335,
          -0.75686276, -0.7254902 ],
         [-0.6       , -0.60784316, -0.60784316, ..., -0.5686275 ,
          -0.58431375, -0.67058825],
         ...,
         [-0.8509804 , -0.8509804 , -0.84313726, ..., -0.827451  ,
          -0.81960785, -0.81960785],
         [-0.84313726, -0.827451  , -0.81960785, ..., -0.827451  ,
          -0.81960785, -0.827451  ],
         [-0.827451  , -0.827451  , -0.81960785, ..., -0.81960785,
          -0.827451  , -0.8509804 ]],
 
        [[-0.56078434, -0.54509807, -0.5372549 , ..., -0.7882353 ,
          -0.79607844, -0.79607844],
         [-0.56078434, -0.5529412 , -0.56078434, ..., -0.6862745 ,
          -0.7019608 , -0.67058825],
         [-0.5686275 , -0.58431375, -0.58431375, ..., -0.49019605,
          -0.49019605, -0.5764706 ],
         ...,
         [-0.8980392 , -0.9058823

In [5]:
from torchvision.transforms import (CenterCrop, 
                                    Compose, 
                                    Normalize, 
                                    RandomHorizontalFlip,
                                    RandomResizedCrop, 
                                    Resize, 
                                    ToTensor)

# Channel-wise normalisation using the mean / std taken from the image processor.
normalize = Normalize(mean=image_mean, std=image_std)

# Training pipeline: random crop + random horizontal flip, then tensor + normalise.
_train_transforms = Compose([
    RandomResizedCrop(size),
    RandomHorizontalFlip(),
    ToTensor(),
    normalize,
])

# Evaluation pipeline: resize + centre crop, then tensor + normalise.
_val_transforms = Compose([
    Resize(size),
    CenterCrop(size),
    ToTensor(),
    normalize,
])

def train_transforms(examples):
    """Add training-time pixel tensors to a batch of examples.

    Each image in ``examples['img']`` is converted to RGB and passed through
    ``_train_transforms``; the results are stored, in order, under
    ``examples['pixel_values']``. The batch is modified in place and returned.
    """
    pixel_values = []
    for image in examples['img']:
        rgb_image = image.convert("RGB")
        pixel_values.append(_train_transforms(rgb_image))
    examples['pixel_values'] = pixel_values
    return examples

def val_transforms(examples):
    """Add evaluation-time pixel tensors to a batch of examples.

    Each image in ``examples['img']`` is converted to RGB and passed through
    ``_val_transforms``; the results are stored, in order, under
    ``examples['pixel_values']``. The batch is modified in place and returned.
    """
    pixel_values = []
    for image in examples['img']:
        rgb_image = image.convert("RGB")
        pixel_values.append(_val_transforms(rgb_image))
    examples['pixel_values'] = pixel_values
    return examples

In [6]:
# Set the transforms: the training split gets the random-augmentation pipeline,
# the validation and test splits share the resize / centre-crop pipeline.
# NOTE(review): both transform functions read examples['img'], but the
# coco-karpathy record displayed earlier in this notebook has no 'img' key —
# confirm these transforms match the dataset that is actually loaded.
train_ds.set_transform(train_transforms)
val_ds.set_transform(val_transforms)
test_ds.set_transform(val_transforms)

In [7]:
# Peek at the first two training examples, including the 'pixel_values'
# produced by the transform registered on the split.
train_ds[:2]

{'img': [<PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>,
  <PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>],
 'label': [5, 8],
 'pixel_values': [tensor([[[-0.6471, -0.6471, -0.6471,  ..., -0.7333, -0.7333, -0.7333],
           [-0.6471, -0.6471, -0.6471,  ..., -0.7333, -0.7333, -0.7333],
           [-0.6471, -0.6471, -0.6471,  ..., -0.7333, -0.7333, -0.7333],
           ...,
           [-0.5294, -0.5294, -0.5294,  ..., -0.6000, -0.6000, -0.6000],
           [-0.5294, -0.5294, -0.5294,  ..., -0.6000, -0.6000, -0.6000],
           [-0.5294, -0.5294, -0.5294,  ..., -0.6000, -0.6000, -0.6000]],
  
          [[-0.2078, -0.2078, -0.2078,  ..., -0.4510, -0.4510, -0.4510],
           [-0.2078, -0.2078, -0.2078,  ..., -0.4510, -0.4510, -0.4510],
           [-0.2078, -0.2078, -0.2078,  ..., -0.4510, -0.4510, -0.4510],
           ...,
           [-0.1294, -0.1294, -0.1294,  ..., -0.2863, -0.2863, -0.2863],
           [-0.1294, -0.1294, -0.1294,  ..., -0.2863, -0.2863, -0.

In [8]:
from torch.utils.data import DataLoader
import torch

def collate_fn(examples):
    """Merge a list of per-example dicts into a single batch dict.

    The ``"pixel_values"`` tensors are stacked along a new leading axis and the
    ``"label"`` values are gathered into one tensor. The result has the keys
    ``"pixel_values"`` and ``"labels"``.
    """
    images = []
    for item in examples:
        images.append(item["pixel_values"])
    batch = {"pixel_values": torch.stack(images)}

    targets = []
    for item in examples:
        targets.append(item["label"])
    batch["labels"] = torch.tensor(targets)
    return batch

# Training and evaluation use the same small batch size.
train_batch_size = eval_batch_size = 2

# Only the training loader shuffles; all three share the same collate function.
train_dataloader = DataLoader(
    train_ds,
    batch_size=train_batch_size,
    shuffle=True,
    collate_fn=collate_fn,
)
val_dataloader = DataLoader(
    val_ds,
    batch_size=eval_batch_size,
    collate_fn=collate_fn,
)
test_dataloader = DataLoader(
    test_ds,
    batch_size=eval_batch_size,
    collate_fn=collate_fn,
)

In [9]:
# Pull a single training batch and report the shape of every tensor in it.
batch = next(iter(train_dataloader))
for name, value in batch.items():
    if not isinstance(value, torch.Tensor):
        continue
    print(name, value.shape)

pixel_values torch.Size([2, 3, 224, 224])
labels torch.Size([2])


In [10]:
# Sanity-check the batch layout: one 3x224x224 image tensor and one label per
# example in the training batch.
assert batch['pixel_values'].shape == (train_batch_size, 3, 224, 224)
assert batch['labels'].shape == (train_batch_size,)

In [11]:
# Shape of the image tensor in the first validation batch.
next(iter(val_dataloader))['pixel_values'].shape

torch.Size([2, 3, 224, 224])

In [12]:
import pytorch_lightning as pl
from transformers import ViTForImageClassification, AdamW
import torch.nn as nn

class ViTLightningModule(pl.LightningModule):
    """LightningModule that fine-tunes a pretrained ViT image classifier.

    Args:
        num_labels: Number of classes for the classification head.
        id2label: Optional mapping from class index to label name.
        label2id: Optional mapping from label name to class index.

    For backward compatibility, when ``id2label`` / ``label2id`` are not
    passed, notebook-level variables with those names are used if they exist;
    if they do not exist either, the mappings are simply not forwarded to
    ``from_pretrained``.
    """

    def __init__(self, num_labels=10, id2label=None, label2id=None):
        super().__init__()

        # Earlier versions of this class read the label mappings from
        # notebook globals; keep honouring them when no argument is given.
        if id2label is None:
            id2label = globals().get("id2label")
        if label2id is None:
            label2id = globals().get("label2id")

        # Forward only the mappings that are actually available.
        label_kwargs = {}
        if id2label is not None:
            label_kwargs["id2label"] = id2label
        if label2id is not None:
            label_kwargs["label2id"] = label2id

        # Use the constructor argument rather than a hard-coded class count.
        self.vit = ViTForImageClassification.from_pretrained(
            'google/vit-base-patch16-224-in21k',
            num_labels=num_labels,
            **label_kwargs,
        )
        # Built once here instead of on every step.
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, pixel_values):
        """Return the classification logits for a batch of images."""
        outputs = self.vit(pixel_values=pixel_values)
        return outputs.logits

    def common_step(self, batch, batch_idx):
        """Compute ``(loss, accuracy)`` for one batch.

        ``batch`` must contain ``'pixel_values'`` and ``'labels'``. The loss is
        the cross-entropy between logits and labels; the accuracy is the
        fraction of examples whose arg-max prediction equals the label,
        returned as a Python float.
        """
        pixel_values = batch['pixel_values']
        labels = batch['labels']
        logits = self(pixel_values)

        loss = self.criterion(logits, labels)
        predictions = logits.argmax(-1)
        correct = (predictions == labels).sum().item()
        accuracy = correct / pixel_values.shape[0]

        return loss, accuracy

    def training_step(self, batch, batch_idx):
        """Run one training batch, log its metrics and return the loss."""
        loss, accuracy = self.common_step(batch, batch_idx)
        # No on_step / on_epoch flags are passed, so Lightning's defaults for
        # training_step apply to these two metrics.
        self.log("training_loss", loss)
        self.log("training_accuracy", accuracy)

        return loss

    def validation_step(self, batch, batch_idx):
        """Run one validation batch, log epoch-level metrics and return the loss."""
        loss, accuracy = self.common_step(batch, batch_idx)
        # "validation_loss" is the key that early stopping has to monitor.
        self.log("validation_loss", loss, on_epoch=True)
        self.log("validation_accuracy", accuracy, on_epoch=True)

        return loss

    def test_step(self, batch, batch_idx):
        """Run one test batch, log its metrics and return the loss."""
        loss, accuracy = self.common_step(batch, batch_idx)
        # Log the test metrics as well, mirroring validation_step; previously
        # they were computed and then discarded.
        self.log("test_loss", loss, on_epoch=True)
        self.log("test_accuracy", accuracy, on_epoch=True)

        return loss

    def configure_optimizers(self):
        """Return the optimizer: AdamW over all parameters with lr=5e-5."""
        # We could make the optimizer more fancy by adding a scheduler and specifying which parameters do
        # not require weight_decay but just using AdamW out-of-the-box works fine
        # NOTE(review): `AdamW` here is the class imported from `transformers`;
        # newer releases deprecate it in favour of `torch.optim.AdamW` — confirm
        # against the installed version before upgrading.
        return AdamW(self.parameters(), lr=5e-5)

    # The three hooks below hand Lightning the notebook-level DataLoader objects
    # that share their names, so `trainer.fit(model)` needs no loader arguments.
    def train_dataloader(self):
        return train_dataloader

    def val_dataloader(self):
        return val_dataloader

    def test_dataloader(self):
        return test_dataloader

In [None]:
# Start tensorboard: load the notebook extension, then launch it on the
# `lightning_logs/` directory.
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

In [15]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# for early stopping, see https://pytorch-lightning.readthedocs.io/en/1.0.0/early_stopping.html?highlight=early%20stopping
# The monitored key must match the name logged in `validation_step`
# ("validation_loss"). `strict=True` makes a misspelled or missing metric raise
# instead of silently disabling early stopping.
early_stop_callback = EarlyStopping(
    monitor='validation_loss',
    patience=3,
    strict=True,
    verbose=False,
    mode='min'
)

model = ViTLightningModule()
# Pass the configured callback rather than building a second, ad-hoc one.
trainer = Trainer(callbacks=[early_stop_callback])
trainer.fit(model)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Trainer will use only 1 of 4 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=4)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/phisch/venv_py3.8/py3.8/lib/python3.8/site-packages/pytorch_lightning/loops/utilities.py:73: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.
You are using a CUDA device ('NVIDIA A100-SXM4-40GB') tha

Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

/home/phisch/venv_py3.8/py3.8/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.


                                                                           

/home/phisch/venv_py3.8/py3.8/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.


Epoch 5:  56%|█████▌    | 1264/2250 [02:26<01:54,  8.62it/s, v_num=3]

/home/phisch/venv_py3.8/py3.8/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
