In [3]:
import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import pytorch_lightning as pl
import torchmetrics
import torch
from torch import nn
from sklearn.utils import class_weight
import pandas as pd
import numpy as np
from datasets import load_dataset

In [4]:
from transformers import BeitFeatureExtractor

feature_extractor = BeitFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
train_ds = load_dataset("DGE_training", data_dir="*/", split="train").shuffle()
val_ds = load_dataset("DGE_val", data_dir="*/", split="train").shuffle()
test_ds = load_dataset("DGE_test", data_dir="*/", split="train")

Resolving data files: 100%|██████████| 2950/2950 [00:00<00:00, 18785.61it/s]
Using custom data configuration DGE_test-14758c97f2c96edf
Reusing dataset imagefolder (C:\Users\zacha\.cache\huggingface\datasets\imagefolder\DGE_test-14758c97f2c96edf\0.0.0\0fc50c79b681877cc46b23245a6ef5333d036f48db40d53765a68034bc48faff)
Resolving data files: 100%|██████████| 2346/2346 [00:00<00:00, 20046.77it/s]
Using custom data configuration DGE_val-75693c4867bd4840
Reusing dataset imagefolder (C:\Users\zacha\.cache\huggingface\datasets\imagefolder\DGE_val-75693c4867bd4840\0.0.0\0fc50c79b681877cc46b23245a6ef5333d036f48db40d53765a68034bc48faff)
Resolving data files: 100%|██████████| 476/476 [00:00<00:00, 237960.51it/s]
Using custom data configuration DGE_training-1a4881cf19cc1a09
Reusing dataset imagefolder (C:\Users\zacha\.cache\huggingface\datasets\imagefolder\DGE_training-1a4881cf19cc1a09\0.0.0\0fc50c79b681877cc46b23245a6ef5333d036f48db40d53765a68034bc48faff)


In [5]:
id2label = {id:label for id, label in enumerate(train_ds.features['label'].names)}
label2id = {label:id for id,label in id2label.items()}
print(label2id)

{'2': 0, '3': 1, '4': 2, '5': 3, '6': 4}


In [6]:
from torchvision.transforms import (CenterCrop, 
                                    Compose, 
                                    Normalize, 
                                    RandomHorizontalFlip,
                                    RandomResizedCrop, 
                                    Resize, 
                                    ToTensor,
                                    RandomRotation)

normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
_train_transforms = Compose(
        [   
            RandomResizedCrop(feature_extractor.size),
            RandomHorizontalFlip(),
            ToTensor(),
            normalize,
        ]
    )

_val_transforms = Compose(
        [
            Resize(feature_extractor.size),
            CenterCrop(feature_extractor.size),
            ToTensor(),
            normalize,
        ]
    )

def train_transforms(examples):
    examples['pixel_values'] = [_train_transforms(image.convert("RGB")) for image in examples['image']]
    return examples

def val_transforms(examples):
    examples['pixel_values'] = [_val_transforms(image.convert("RGB")) for image in examples['image']]
    return examples

In [7]:
train_ds.set_transform(train_transforms)
val_ds.set_transform(val_transforms)
test_ds.set_transform(val_transforms)

In [8]:
from sklearn.utils import class_weight
import pandas as pd
import numpy as np

train_df = pd.read_csv('train_data.csv',dtype=str)
class_weights = class_weight.compute_class_weight('balanced', classes=np.unique(train_df['class']), y =train_df['class'])
print(class_weights)

# dataset_size = len(train_ds[:]['label'])
# train_count = int(dataset_size * 0.7)
# val_count = int(dataset_size - train_count)
y_train = train_ds[:]['label']

#print(train_dataset[0])
import numpy as np 

#y_train_indices = train_dataset["train"].indices

#y_train = train_ds[:]['label']

class_sample_count = np.array(
    [len(np.where(y_train == t)[0]) for t in np.unique(y_train)])

weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in y_train])
samples_weight = torch.from_numpy(samples_weight)

sampler = torch.utils.data.WeightedRandomSampler(samples_weight.type('torch.DoubleTensor'), len(samples_weight))

[5.95       0.952      0.43272727 0.7616     6.34666667]


In [9]:
print(train_ds)

Dataset({
    features: ['image', 'label'],
    num_rows: 2950
})


In [10]:
from torch.utils.data import DataLoader
import torch

def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}

train_batch_size = 32
eval_batch_size = 32

train_dataloader = DataLoader(train_ds, collate_fn=collate_fn, batch_size=train_batch_size, sampler=sampler)
val_dataloader = DataLoader(val_ds, collate_fn=collate_fn, batch_size=eval_batch_size, sampler=sampler)
test_dataloader = DataLoader(test_ds, collate_fn=collate_fn, batch_size=eval_batch_size)

In [11]:
batch = next(iter(train_dataloader))
for k,v in batch.items():
  if isinstance(v, torch.Tensor):
    print(k, v.shape)

pixel_values torch.Size([32, 3, 224, 224])
labels torch.Size([32])


In [12]:
assert batch['pixel_values'].shape == (train_batch_size, 3, 224, 224)
assert batch['labels'].shape == (train_batch_size,)

In [13]:
import pytorch_lightning as pl
from transformers import BeitForImageClassification, AdamW
import torch.nn as nn

class ViTLightningModule(pl.LightningModule):
    def __init__(self, num_labels=10):
        super(ViTLightningModule, self).__init__()
        self.vit = BeitForImageClassification.from_pretrained("microsoft/beit-base-patch16-224",
                                                              num_labels=5,
                                                              id2label=id2label,
                                                              label2id=label2id,
                                                              ignore_mismatched_sizes=True)

    def forward(self, pixel_values):
        outputs = self.vit(pixel_values=pixel_values)
        return outputs.logits
        
    def common_step(self, batch, batch_idx):
        pixel_values = batch['pixel_values']
        labels = batch['labels']
        logits = self(pixel_values)

        #criterion = nn.CrossEntropyLoss()
        criterion = nn.CrossEntropyLoss(weight=torch.cuda.FloatTensor(class_weights))
        loss = criterion(logits, labels)
        predictions = logits.argmax(-1)
        correct = (predictions == labels).sum().item()
        accuracy = correct/pixel_values.shape[0]

        return loss, accuracy
      
    def training_step(self, batch, batch_idx):
        loss, accuracy = self.common_step(batch, batch_idx)     
        # logs metrics for each training_step,
        # and the average across the epoch
        self.log("training_loss", loss)
        self.log("training_accuracy", accuracy)

        return loss
    
    def validation_step(self, batch, batch_idx):
        loss, accuracy = self.common_step(batch, batch_idx)     
        self.log("validation_loss", loss, on_epoch=True)
        self.log("validation_accuracy", accuracy, on_epoch=True)

        return loss

    def test_step(self, batch, batch_idx):
        loss, accuracy = self.common_step(batch, batch_idx)
        self.log("test_loss", loss, on_epoch=True)
        self.log("test_accuracy", accuracy, on_epoch=True)

        return loss, accuracy

    def configure_optimizers(self):
        # We could make the optimizer more fancy by adding a scheduler and specifying which parameters do
        # not require weight_decay but just using AdamW out-of-the-box works fine
        return AdamW(self.parameters(), lr=5e-5)

    def train_dataloader(self):
        return train_dataloader

    def val_dataloader(self):
        return val_dataloader
    
    def test_dataloader(self):
        return test_dataloader
    
    def predict_dataloader(self):
        return test_dataloader

In [14]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# for early stopping, see https://pytorch-lightning.readthedocs.io/en/1.0.0/early_stopping.html?highlight=early%20stopping
early_stop_callback = EarlyStopping(
    monitor='validation_loss',
    patience=10,
    strict=False,
    verbose=False,
    mode='min',
    check_on_train_epoch_end=True
)

model = ViTLightningModule()
trainer = Trainer(accelerator='gpu', devices=1, callbacks=[EarlyStopping(monitor='validation_loss')],
    check_val_every_n_epoch=1, 
    max_epochs=100)

trainer.fit(model)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Some weights of BeitForImageClassification were not initialized from the model checkpoint at microsoft/beit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([5, 768]) in the model instantiated
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([5]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type                       | Params
----------------------------------------------------
0 | vit  | BeitForImageClassification | 85.8 M
---------------------------------------------------

Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


IndexError: Invalid key: 2645 is out of bounds for size 2346

In [None]:
trainer.test(model=model, dataloaders=test_dataloader, verbose=True)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(


Testing DataLoader 0: 100%|██████████| 93/93 [00:12<00:00,  7.62it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      test_accuracy         0.2579661011695862
        test_loss           1.8015609979629517
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 1.8015609979629517, 'test_accuracy': 0.2579661011695862}]

In [None]:
preds = trainer.predict(model, dataloaders=test_dataloader, return_predictions=True)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(


Predicting DataLoader 0:   0%|          | 0/93 [00:00<?, ?it/s]

AttributeError: 'dict' object has no attribute 'shape'