# load data

In [3]:
from datasets import load_dataset
from transformers import ViTImageProcessor
# load dataset
train_ds, test_ds = load_dataset('imagefolder', data_dir='../data', split=['train', 'test'])
# split up training into training + validation
splits = train_ds.train_test_split(test_size=0.3)
train_ds = splits['train']
val_ds = splits['test']
id2label = {id:label for id, label in enumerate(train_ds.features['label'].names)}
label2id = {label: id for id, label in id2label.items()}
print(id2label)

Resolving data files: 100%|█████████████████████████████████████████████████████| 2176/2176 [00:00<00:00, 27960.06it/s]
Resolving data files: 100%|█████████████████████████████████████████████████████| 2180/2180 [00:00<00:00, 41131.91it/s]


{0: '0', 1: '1'}


# preprocess data

In [4]:
processor = ViTImageProcessor.from_pretrained("google/vit-large-patch16-224")
image_mean = processor.image_mean
image_std = processor.image_std
size = processor.size["height"]

In [5]:
from torchvision.transforms import (CenterCrop, 
                                    Compose, 
                                    Normalize, 
                                    RandomHorizontalFlip,
                                    RandomResizedCrop, 
                                    Resize, 
                                    ToTensor)

normalize = Normalize(mean=image_mean, std=image_std)
_train_transforms = Compose(
        [
            RandomResizedCrop(size),
            RandomHorizontalFlip(),
            ToTensor(),
            normalize,
        ]
    )

_val_transforms = Compose(
        [
            Resize(size),
            CenterCrop(size),
            ToTensor(),
            normalize,
        ]
    )

def train_transforms(examples):
    examples['pixel_values'] = [_train_transforms(image.convert("RGB")) for image in examples['image']]
    return examples

def val_transforms(examples):
    examples['pixel_values'] = [_val_transforms(image.convert("RGB")) for image in examples['image']]
    return examples

In [6]:
# Set the transforms
train_ds.set_transform(train_transforms)
val_ds.set_transform(val_transforms)
test_ds.set_transform(val_transforms)

In [17]:
from torch.utils.data import DataLoader
import torch

def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}

train_batch_size = 2
eval_batch_size = 2
test_batch_size = 1

train_dataloader = DataLoader(train_ds, shuffle=True, collate_fn=collate_fn, batch_size=train_batch_size)
val_dataloader = DataLoader(val_ds, collate_fn=collate_fn, batch_size=eval_batch_size)
test_dataloader = DataLoader(test_ds, collate_fn=collate_fn, batch_size=test_batch_size)

# Define the model

In [25]:
import pytorch_lightning as pl
from transformers import ViTForImageClassification, AdamW
import torch.nn as nn

class ViTLightningModule(pl.LightningModule):
    def __init__(self, num_labels=2):
        super(ViTLightningModule, self).__init__()
        self.vit = ViTForImageClassification.from_pretrained('google/vit-large-patch16-224',
                                                              num_labels=num_labels,
                                                              ignore_mismatched_sizes=True,
                                                              id2label=id2label,
                                                              label2id=label2id)

    def forward(self, pixel_values):
        outputs = self.vit(pixel_values=pixel_values)
        return outputs.logits

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        return self(batch['pixel_values']).softmax(dim=1)
        
    def common_step(self, batch, batch_idx):
        pixel_values = batch['pixel_values']
        labels = batch['labels']
        logits = self(pixel_values)

        criterion = nn.CrossEntropyLoss()
        loss = criterion(logits, labels)
        predictions = logits.argmax(-1)
        correct = (predictions == labels).sum().item()
        accuracy = correct/pixel_values.shape[0]

        return loss, accuracy
      
    def training_step(self, batch, batch_idx):
        loss, accuracy = self.common_step(batch, batch_idx)     
        # logs metrics for each training_step,
        # and the average across the epoch
        self.log("training_loss", loss, prog_bar=True)
        self.log("training_accuracy", accuracy, prog_bar=True)

        return loss
    
    def validation_step(self, batch, batch_idx):
        loss, accuracy = self.common_step(batch, batch_idx)     
        self.log("validation_loss", loss, on_epoch=True, prog_bar=True)
        self.log("validation_accuracy", accuracy, on_epoch=True, prog_bar=True)

        return loss

    def test_step(self, batch, batch_idx):
        loss, accuracy = self.common_step(batch, batch_idx)     

        return loss

    def configure_optimizers(self):
        # We could make the optimizer more fancy by adding a scheduler and specifying which parameters do
        # not require weight_decay but just using AdamW out-of-the-box works fine
        return AdamW(self.parameters(), lr=5e-5)

    def train_dataloader(self):
        return train_dataloader

    def val_dataloader(self):
        return val_dataloader

    def test_dataloader(self):
        return test_dataloader

# train the model

In [19]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch import loggers as pl_loggers

# for early stopping, see https://pytorch-lightning.readthedocs.io/en/1.0.0/early_stopping.html?highlight=early%20stopping
early_stop_callback = EarlyStopping(
    monitor='val_loss',
    min_delta=0.00,
    patience=50,
    verbose=False,
    mode='min'
)
checkpoint_callback = ModelCheckpoint(save_top_k=1, monitor="validation_loss")
tb_logger = pl_loggers.TensorBoardLogger(save_dir="lightning_logs/")
model = ViTLightningModule()

trainer = Trainer(
    num_nodes=1,
    max_epochs=100,
    callbacks=[EarlyStopping(monitor='validation_loss'), 
    checkpoint_callback],
    val_check_interval=len(train_dataloader),
    logger=tb_logger
)
trainer.fit(model)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-large-patch16-224 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([1000, 1024]) in the checkpoint and torch.Size([2, 1024]) in the model instantiated
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([2]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type                      | Params
---------------------------------------------------
0 | vit  | ViTForImageClassification | 303 M 
---------------------------------------------------
303 M     Trainable params
0         Non-trainable params
303 M     Total p

Epoch 0: 100%|███████████████| 762/762 [02:34<00:00,  4.92it/s, v_num=1, training_loss=0.0304, training_accuracy=1.000]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                                              | 0/327 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                                 | 0/327 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|▏                                                        | 1/327 [00:00<00:19, 16.59it/s][A
Validation DataLoader 0:   1%|▎                                                        | 2/327 [00:00<00:18, 17.97it/s][A
Validation DataLoader 0:   1%|▌                                                        | 3/327 [00:00<00:17, 18.60it/s][A
Validation DataLoader 0:   1%|▋                                                        | 4/327 [00:00<00:18, 17.83it/s][A
Validation DataLoader 0:   2%|▊                                                        | 5/327 [00:00<00:18,

In [26]:
model = ViTLightningModule().load_from_checkpoint('./lightning_logs/lightning_logs/version_3/checkpoints/epoch=47-step=36576.ckpt')
model.eval()

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-large-patch16-224 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([1000, 1024]) in the checkpoint and torch.Size([2, 1024]) in the model instantiated
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([2]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-large-patch16-224 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([1000, 1024]) in the checkpoint and torch.Size([2, 1024]) in the model instantiated
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([2]) in the model instantiated
You should probably TRAIN this model on a d

ViTLightningModule(
  (vit): ViTForImageClassification(
    (vit): ViTModel(
      (embeddings): ViTEmbeddings(
        (patch_embeddings): ViTPatchEmbeddings(
          (projection): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
        )
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (encoder): ViTEncoder(
        (layer): ModuleList(
          (0-23): 24 x ViTLayer(
            (attention): ViTAttention(
              (attention): ViTSelfAttention(
                (query): Linear(in_features=1024, out_features=1024, bias=True)
                (key): Linear(in_features=1024, out_features=1024, bias=True)
                (value): Linear(in_features=1024, out_features=1024, bias=True)
                (dropout): Dropout(p=0.0, inplace=False)
              )
              (output): ViTSelfOutput(
                (dense): Linear(in_features=1024, out_features=1024, bias=True)
                (dropout): Dropout(p=0.0, inplace=False)
              )
            )
  

In [27]:
from pytorch_lightning import Trainer
trainer=Trainer()
predictions = trainer.predict(model, test_dataloader)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting DataLoader 0: 100%|█████████████████████████████████████████████████████| 2180/2180 [01:06<00:00, 32.62it/s]


In [30]:
final_preds = []
for predict in predictions:
    final_preds.append(predict[0][1])
print(final_preds)

[tensor(0.8458), tensor(0.4622), tensor(0.0350), tensor(0.9912), tensor(0.0397), tensor(0.9396), tensor(0.0627), tensor(0.0365), tensor(0.0511), tensor(0.0753), tensor(0.0939), tensor(0.1321), tensor(0.7680), tensor(0.0375), tensor(0.2392), tensor(0.0587), tensor(0.9261), tensor(0.0631), tensor(0.9036), tensor(0.0791), tensor(0.0022), tensor(0.9974), tensor(0.8649), tensor(0.9989), tensor(0.0197), tensor(0.2638), tensor(0.0037), tensor(0.0045), tensor(0.9980), tensor(0.0675), tensor(0.4586), tensor(0.0723), tensor(0.9990), tensor(0.0028), tensor(0.8702), tensor(0.3253), tensor(0.4701), tensor(0.3646), tensor(0.9456), tensor(0.5196), tensor(0.9605), tensor(0.0250), tensor(0.9994), tensor(0.0470), tensor(0.9095), tensor(0.1285), tensor(0.0074), tensor(0.8745), tensor(0.7081), tensor(0.0137), tensor(0.9563), tensor(0.9761), tensor(0.0286), tensor(0.6272), tensor(0.3080), tensor(0.0128), tensor(0.9984), tensor(0.5172), tensor(0.2049), tensor(0.2639), tensor(0.9990), tensor(0.9908), tensor(

In [35]:
import pandas as pd
submit_file = '../sample_submit.csv'
df = pd.read_csv(submit_file, header=None)  # Specify header=None
df.iloc[:, 1] = final_preds
df.to_csv(submit_file, index=False, header=False)