In [1]:
from transformers import ViTFeatureExtractor
from torch.utils.data import Dataset
import pandas as pd
import pytorch_lightning as pl
from transformers import ViTForImageClassification, AdamW
import torch.nn.functional as F
import cv2
import torch
from PIL import Image
from sklearn.model_selection import train_test_split
from torchvision.transforms import (CenterCrop, 
                                    Compose, 
                                    Normalize, 
                                    RandomHorizontalFlip,
                                    RandomResizedCrop, 
                                    Resize, 
                                    ToTensor)
import wandb
from pytorch_lightning.loggers import WandbLogger
import torchmetrics
from sklearn.model_selection import StratifiedKFold, KFold
import numpy as np
from torchvision import models

wandb.login()
pl.seed_everything(100)

TRAIN_IMGS_PATH = './train'
TEST_IMGS_PATH = './test'
TRAIN_DF_PATH = './train.csv'

MODEL_VESRION = 'google/vit-base-patch16-224-in21k'
IAMGE_SIZE = 224

[34m[1mwandb[0m: Currently logged in as: [33mvetka925[0m. Use [1m`wandb login --relogin`[0m to force relogin
Global seed set to 100


# Augmentations

In [2]:
import albumentations as A
from albumentations.pytorch import ToTensorV2


import matplotlib.pyplot as plt

image = cv2.imread(f"{TRAIN_IMGS_PATH}/220301070305_0e13309ae71ffc37ba629d19f46e0784.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image0 = cv2.imread(f"{TRAIN_IMGS_PATH}/220301070430_df1a737a5abe3707b424fbe9d1d92300.jpg")
image0 = cv2.cvtColor(image0, cv2.COLOR_BGR2RGB)
image1 = cv2.imread(f"{TRAIN_IMGS_PATH}/220301070439_91576ea7fbb6567ff743e271662f9d06.jpg")
image1 = cv2.cvtColor(image1, cv2.COLOR_BGR2RGB)

def visualize(image):
    plt.figure(figsize=(10, 10))
    plt.axis('off')
    plt.imshow(image)


feature_extractor = ViTFeatureExtractor(MODEL_VESRION)
normalize = A.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)


_train_transforms = A.Compose([
    A.Resize(224, 224),
    normalize,
     ToTensorV2(),
])


_val_transforms = A.Compose(
        [
            A.Resize(224, 224),
            normalize,
            ToTensorV2(),
        ]
    )


def init_transforms(image_size):
    normalize = A.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
    albumentations_transform_oneof = A.Compose([
        A.Resize(image_size, image_size),

        normalize,
         ToTensorV2(),
    ])
    
    _val_transforms = A.Compose(
        [
             A.Resize(image_size, image_size),
            normalize,
            ToTensorV2(),
        ]
    )
    
    return albumentations_transform_oneof, _val_transforms

# Prepare Data

In [3]:
import numpy as np

class ImageDataset(Dataset):
    def __init__(self, data_df, transform=None, add_transform=None):

        self.data_df = data_df
        self.transform = transform
        self.add_transform = add_transform
        
    def __getitem__(self, idx):

        image_name, label = self.data_df.iloc[idx]['ID_img'], self.data_df.iloc[idx]['class']
    
        image = cv2.imread(f"{TRAIN_IMGS_PATH}/{image_name}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform:
            image = self.transform(image=image)['image']
        
        return image, torch.tensor(label).long()
    
    def __len__(self):
        return len(self.data_df)
    

In [4]:
import os
data_df = pd.read_csv(TRAIN_DF_PATH)
data_df['additional'] = 0

data_df = data_df.sample(frac=1, random_state=44).reset_index(drop=True)

In [5]:
class_counts = data_df['class'].value_counts().to_dict()
class_counts

{1.0: 292, 2.0: 149, 0.0: 102}

# Custom ROC AUC metric for PL

In [6]:
from torchmetrics import Metric
from sklearn.metrics import roc_auc_score
from typing import Optional, Any
import torch

def one_label_to_many(preds):
    result = []
    for p in preds:
        many = [0, 0, 0]
        many[p] = 1
        result.append(many)
    return result

class ROCAUC(Metric):
    def __init__(
        self,
        average='macro',
        compute_on_step: bool = True,
        dist_sync_on_step: bool = False,
        process_group: Optional[Any] = None,
    ):
        super().__init__(
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
        )

        self.average = average
        self.add_state("preds", default=[], dist_reduce_fx=None)
        self.add_state("target", default=[], dist_reduce_fx=None)

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """
        Update state with predictions and targets.

        Args:
            preds: Predictions from model
            target: Ground truth values
        """
        self.preds.append(preds)
        self.target.append(target)

    def compute(self):
        
        preds = np.array(one_label_to_many(torch.cat(self.preds, dim=0).cpu().numpy()))
        target = torch.cat(self.target, dim=0).cpu().numpy().astype(int)
        score = roc_auc_score(target, preds, average=self.average, multi_class='ovo', labels=[0,1,2])
        return torch.tensor([score], dtype=torch.float32, device='cpu')

# VIT module

In [7]:
class VITFineTuner(pl.LightningModule):
    def __init__(self, model_version, num_labels, train_dataset, val_dataset, batch_size):
        super(VITFineTuner, self).__init__()
        
        self.vit = ViTForImageClassification.from_pretrained(model_version,
                                                              num_labels=num_labels)
        

        self.train_dataset = train_dataset
        self.val_dataset = val_dataset

        self.batch_size = batch_size
        self.num_labels = num_labels

        self.save_hyperparameters()
        

        self.train_rocauc = ROCAUC(average='macro')
        self.val_rocauc = ROCAUC( average='macro')
        
    def forward(self, X):
        outputs = self.vit(X)
        return outputs.logits
    
    def common_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        
        if self.num_labels >= 1:
            preds = torch.argmax(logits, axis=1)
        elif self.num_labels == 1:
            preds = logits.squeeze()
        
        return preds, y, loss, logits

    def training_step(self, batch, batch_idx):
        preds, y, loss, _ = self.common_step(batch, batch_idx)

        self.log('train/loss', loss, on_epoch=True)
        self.train_rocauc(y, preds)
        self.log('train/rocauc', self.train_rocauc, on_epoch=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        preds, y, loss, logits = self.common_step(batch, batch_idx)

        self.val_rocauc(y, preds)
        self.log("validation/loss_epoch", loss)
        self.log("validation/roc_auc_epoch", self.val_rocauc)
        return preds, y, logits
    
    def validation_epoch_end(self, validation_step_outputs):
        preds = torch.cat([e[0] for e in validation_step_outputs]).cpu().detach().numpy()
        y = torch.cat([e[1] for e in validation_step_outputs]).cpu().detach().numpy()

        logits = torch.cat([e[2] for e in validation_step_outputs]).cpu().detach().numpy()
        self.logger.experiment.log({"conf_mat" : wandb.plot.confusion_matrix( probs=None, y_true=y, preds=preds, class_names=["0", "1", "2"])})
        self.logger.experiment.log(
            {"validation/logits": wandb.Histogram(logits),
             "global_step": self.global_step}, commit=False
        )


    def configure_optimizers(self):
        return AdamW(self.parameters(), lr=5e-5)

    def train_dataloader(self):
        train_loader = torch.utils.data.DataLoader(self.train_dataset, 
                                                   batch_size=self.batch_size, 
                                                  shuffle=True, pin_memory=True)
        return train_loader
    
    def val_dataloader(self):
        val_loader = torch.utils.data.DataLoader(self.val_dataset, batch_size=self.batch_size)
        return val_loader


# Draw images in wandb

In [8]:
class ImageLogger(pl.Callback):
    def __init__(self, val_samples, num_samples=None):
        super().__init__()
        if not num_samples:
            num_samples = len(val_samples[0])
        self.val_imgs, self.val_labels = val_samples
        self.val_imgs = self.val_imgs[:num_samples]
        self.val_labels = self.val_labels[:num_samples]

    def on_validation_epoch_end(self, trainer, pl_module):
        val_imgs = self.val_imgs.to(device=pl_module.device)
        logits = pl_module(val_imgs)
        preds = torch.argmax(logits, 1)
        
        mistakes = [preds.detach().cpu() != self.val_labels.detach().cpu()]
        trainer.logger.experiment.log({
            "Mistakes": [wandb.Image(x, caption=f"Pred: {pred}, Label: {y}") for x, pred, y in zip(val_imgs[mistakes], 
                                                                                                   preds[mistakes], 
                                                                                                   self.val_labels[mistakes])][:10],
            "global_step": trainer.global_step
        }, commit=False)


# Train single model

In [9]:
import re 

# PARAMS
model_version = MODEL_VESRION
num_labels = 3
batch_size = 32
num_epochs = 3


# SPLIT, BALANCE DATA
train_df, valid_df = train_test_split(data_df, test_size=0.2, random_state=10)


# BALANCE BY MIN
g = train_df.groupby('class')
train_df = g.apply(lambda x: x.sample(g.size().min(), random_state=44)).reset_index(drop=True)


#    CREATE TORCH DATASETS
train_dataset = ImageDataset(train_df, _train_transforms)
valid_dataset = ImageDataset(valid_df, _val_transforms)
samples_loader =  torch.utils.data.DataLoader(valid_dataset, batch_size=len(valid_dataset))

#    INIT LOGGER
_model_version = re.sub(r'[/\\]', '_', MODEL_VESRION)
wandb_logger = WandbLogger(project=f"garbage_cross_val_{_model_version}_run_single")

#    INIT MODEL
finetuner = VITFineTuner(model_version, num_labels, train_dataset, valid_dataset, batch_size)

#    TRAINER
trainer = pl.Trainer(
    logger=wandb_logger,    # W&B integration
    log_every_n_steps=13,   # set the logging frequency
    gpus=-1,                # use all GPUs
    max_epochs=num_epochs,           # number of epochs
    deterministic=True,     # keep it deterministic
    callbacks=[ImageLogger(next(iter(samples_loader)))] # see Callbacks section
    )
trainer.fit(finetuner)
wandb.finish()

Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
                not been set for

Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

VBox(children=(Label(value='3.155 MB of 3.155 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▅▅▅██
global_step,▁▃▆█
train/loss_epoch,█▄▁
train/loss_step,▁
train/rocauc_epoch,▁▇█
train/rocauc_step,▁
trainer/global_step,▁▁▃▃▅▅▅██
validation/loss_epoch,█▄▁
validation/roc_auc_epoch,▁▅█

0,1
epoch,2.0
global_step,24.0
train/loss_epoch,0.54029
train/loss_step,0.77921
train/rocauc_epoch,0.97897
train/rocauc_step,0.93182
trainer/global_step,23.0
validation/loss_epoch,0.57377
validation/roc_auc_epoch,0.90188


# Cross Validation

In [10]:
import re
def model_cross_validation(data, num_folds, model_module, model_version, num_labels, batch_size, num_epochs):
    result_metrics = []
    
    train_transforms, val_transforms = init_transforms(224)
    
    skf = KFold(num_folds, shuffle=True, random_state=44)
    for train_index, val_index in skf.split(data['ID_img'], data['class']):
        
        train_df = data.iloc[train_index]
        
#         # BALANCE TRAIN!
#         class_counts = train_df[train_df['class'] != 1]['class'].value_counts().to_dict()
#         max_elements = max(class_counts.values())
        
#         train_df = pd.concat([train_df[train_df['class'] == 1].sample(max_elements),
#                                train_df[train_df['class'] == 0],
#                                train_df[train_df['class'] == 2]]).sample(frac=1).reset_index(drop=True)
#         #######
        
        g = train_df.groupby('class')
        train_df = g.apply(lambda x: x.sample(g.size().min(), random_state=44)).reset_index(drop=True)

        val_df = data.iloc[val_index]
        train_dataset = ImageDataset(train_df, train_transforms)
        valid_dataset = ImageDataset(val_df, val_transforms)
        samples_loader =  torch.utils.data.DataLoader(valid_dataset, batch_size=len(valid_dataset))
        
        _model_version = re.sub(r'[/\\]', '_', model_version)
        wandb_logger = WandbLogger(project=f"garbage_cross_val_{_model_version}_run_cv_2")


        vit_finetuner = model_module(model_version, num_labels, train_dataset, valid_dataset, batch_size)

        trainer = pl.Trainer(
            logger=wandb_logger,    # W&B integration
            log_every_n_steps=13,   # set the logging frequency
            gpus=-1,                # use all GPUs
            max_epochs=num_epochs,           # number of epochs
            deterministic=True,     # keep it deterministic
            callbacks=[ImageLogger(next(iter(samples_loader)))] # see Callbacks section
            )
        trainer.fit(vit_finetuner)
        wandb.finish()
        result_metrics.append(trainer.logged_metrics)
        
    return result_metrics
    

In [11]:
cv_scores = model_cross_validation(data_df, 5, VITFineTuner, MODEL_VESRION, 3, 32, 3)

Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
                not been set for

Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

VBox(children=(Label(value='3.116 MB of 3.116 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▅▅▅███
global_step,▁▃▆█
train/loss_epoch,█▄▁
train/loss_step,█▁
train/rocauc_epoch,▁██
train/rocauc_step,▁█
trainer/global_step,▁▁▃▃▅▅████
validation/loss_epoch,█▄▁
validation/roc_auc_epoch,▁▂█

0,1
epoch,2.0
global_step,27.0
train/loss_epoch,0.46713
train/loss_step,0.36702
train/rocauc_epoch,0.97602
train/rocauc_step,1.0
trainer/global_step,26.0
validation/loss_epoch,0.54526
validation/roc_auc_epoch,0.81511


Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
                not been set for

Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)


Validation: 0it [00:00, ?it/s]

  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)


Validation: 0it [00:00, ?it/s]

  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)


Validation: 0it [00:00, ?it/s]

VBox(children=(Label(value='3.146 MB of 3.146 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▅▅▅██
global_step,▁▃▆█
train/loss_epoch,█▄▁
train/loss_step,▁
train/rocauc_epoch,▁▇█
train/rocauc_step,▁
trainer/global_step,▁▁▃▃▅▅▅██
validation/loss_epoch,█▄▁
validation/roc_auc_epoch,▁██

0,1
epoch,2.0
global_step,24.0
train/loss_epoch,0.56024
train/loss_step,0.77344
train/rocauc_epoch,0.97101
train/rocauc_step,0.92917
trainer/global_step,23.0
validation/loss_epoch,0.59578
validation/roc_auc_epoch,0.8919


Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
                not been set for

Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

VBox(children=(Label(value='2.934 MB of 2.934 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▅▅▅██
global_step,▁▃▆█
train/loss_epoch,█▄▁
train/loss_step,▁
train/rocauc_epoch,▁▇█
train/rocauc_step,▁
trainer/global_step,▁▁▃▃▅▅▅██
validation/loss_epoch,█▄▁
validation/roc_auc_epoch,▁▄█

0,1
epoch,2.0
global_step,24.0
train/loss_epoch,0.53077
train/loss_step,0.75505
train/rocauc_epoch,0.96311
train/rocauc_step,0.925
trainer/global_step,23.0
validation/loss_epoch,0.52923
validation/roc_auc_epoch,0.89896


Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
                not been set for

Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

VBox(children=(Label(value='3.252 MB of 3.252 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▅▅▅██
global_step,▁▃▆█
train/loss_epoch,█▄▁
train/loss_step,▁
train/rocauc_epoch,▁▇█
train/rocauc_step,▁
trainer/global_step,▁▁▃▃▅▅▅██
validation/loss_epoch,█▄▁
validation/roc_auc_epoch,▁▃█

0,1
epoch,2.0
global_step,24.0
train/loss_epoch,0.53567
train/loss_step,0.71678
train/rocauc_epoch,0.95816
train/rocauc_step,0.93304
trainer/global_step,23.0
validation/loss_epoch,0.59951
validation/roc_auc_epoch,0.8969


Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
                not been set for

Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

VBox(children=(Label(value='3.142 MB of 3.142 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▅▅▅██
global_step,▁▃▆█
train/loss_epoch,█▄▁
train/loss_step,▁
train/rocauc_epoch,▁██
train/rocauc_step,▁
trainer/global_step,▁▁▃▃▅▅▅██
validation/loss_epoch,█▄▁
validation/roc_auc_epoch,▁▆█

0,1
epoch,2.0
global_step,24.0
train/loss_epoch,0.58655
train/loss_step,0.80847
train/rocauc_epoch,0.96601
train/rocauc_step,0.91667
trainer/global_step,23.0
validation/loss_epoch,0.59172
validation/roc_auc_epoch,0.91591


# CV Score

In [12]:
print('Average ROC AUC: ', np.mean([ e['validation/roc_auc_epoch'].detach().cpu().numpy() for e in cv_scores]))

Average ROC AUC:  0.88375443


# Train full data

In [13]:
data_df = pd.read_csv(TRAIN_DF_PATH)
data_df['additional'] = 0


# BALANCE BY MIN
g = data_df.groupby('class')
data_df = g.apply(lambda x: x.sample(g.size().min(), random_state=44)).reset_index(drop=True)

train_dataset = ImageDataset(data_df, _train_transforms)
val_dataset = ImageDataset(data_df.iloc[:33], _val_transforms)

samples_loader =  torch.utils.data.DataLoader(val_dataset, batch_size=len(val_dataset))

finetuner = VITFineTuner(MODEL_VESRION, 3, train_dataset, val_dataset, 32)


wandb_logger = WandbLogger(project=f"train_run")
trainer = pl.Trainer(
            logger=wandb_logger,    # W&B integration
            log_every_n_steps=13,   # set the logging frequency
            gpus=-1,                # use all GPUs
            max_epochs=3,           # number of epochs
            deterministic=True,     # keep it deterministic
            callbacks=[ImageLogger(next(iter(samples_loader)))] # see Callbacks section
            )
trainer.fit(finetuner)

Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
                not been set for

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type                      | Params
-----------------------------------------------------------
0 | vit          | ViTForImageClassification | 85.8 M
1 | train_rocauc | ROCAUC                    | 0     
2 | val_rocauc   | ROCAUC                    | 0     
-----------------------------------------------------------
85.8 M    Trainable params
0         Non-trainable params
85.8 M    Total params
343.204   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)


Validation: 0it [00:00, ?it/s]

  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)
  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)
  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)


Validation: 0it [00:00, ?it/s]

  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)
  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)
  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)


# Submit 

In [14]:
test_df = pd.read_csv("./sample_solution.csv")
test_df = test_df.drop(["class"], axis = 1)


In [15]:
class TestImageDataset(Dataset):
    def __init__(self, data_df, transform=None):
        self.data_df = data_df
        self.transform = transform

    def __getitem__(self, idx):
        image_name = self.data_df.iloc[idx]['ID_img']
        

        image = cv2.imread(f"{TEST_IMGS_PATH}/{image_name}.jpg")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # МОЖНО ЗАПРИНТИТЬ КАРТИНКУ
#         visualize(image)
        
        if self.transform:
            image = self.transform(image=image)['image']
        
        return image
    
    def __len__(self):
        return len(self.data_df)

In [16]:
test_dataset = TestImageDataset(test_df, _val_transforms)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                           batch_size=1,
                                           shuffle=False,
                                           pin_memory=True)

In [17]:
from tqdm import tqdm

finetuner.eval()
predicts = []

for imgs in tqdm(test_loader):
    
    imgs = imgs
    pred = finetuner(imgs)

    for class_obj in pred:
      index, max_value = max(enumerate(class_obj), key=lambda i_v: i_v[1])
      predicts.append(index)
        # МОЖНО ЗАПРИНТИТЬ КАРТИНКУ
#       print(index)
#       plt.show()
    

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 225/225 [01:20<00:00,  2.78it/s]


In [18]:
test_df['class'] = predicts

In [19]:
test_df.to_csv('./submit.csv', index=False)