In [1]:
# Mount Google Drive so the dataset and checkpoint paths under /content/gdrive resolve.
from google.colab import drive

GDRIVE_MOUNT_POINT = '/content/gdrive'
drive.mount(GDRIVE_MOUNT_POINT)

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [12]:
import torch
from torch import nn
import torchvision
from torchvision import transforms
from torch.utils.data import Dataset
import pandas as pd

from PIL import Image
import os

from sklearn.model_selection import train_test_split

import time 
from tqdm import tqdm

from torchvision.datasets import ImageFolder

# Dataset location on the mounted Drive; each split is its own ImageFolder root.
DATASET_ROOT = "/content/gdrive/MyDrive/SantaHack/SantaDataset_final"
TRAIN_DATASET, VAL_DATASET, TEST_DATASET = (
    os.path.join(DATASET_ROOT, split) for split in ("train", "val", "test")
)
batch_size = 4   # images per DataLoader batch
img_size = 224   # square side length images are resized to

def make_weights_for_balanced_classes(images, nclasses):
    """Compute one sampling weight per sample so every class is drawn equally often.

    Each sample of class ``c`` gets weight ``N / count[c]``, where ``N`` is the
    total number of samples. Every class therefore carries the same total weight
    when the result is fed to ``WeightedRandomSampler``.

    Args:
        images: sequence of ``(path, class_index)`` pairs, e.g. ``ImageFolder.imgs``.
            Only element ``[1]`` (the class index) is read.
        nclasses: number of classes; each class index must be a valid index
            into a list of length ``nclasses``.

    Returns:
        A list of floats, one weight per entry of ``images``, in the same order.
        Empty ``images`` yields ``[]``.
    """
    count = [0] * nclasses
    for item in images:
        count[item[1]] += 1

    total = float(sum(count))
    # A class with no samples would divide by zero. No sample ever looks up that
    # class's weight, so 0.0 is a safe placeholder.
    weight_per_class = [total / n if n else 0.0 for n in count]
    return [weight_per_class[item[1]] for item in images]


# Light augmentation (horizontal flip) for training only, then normalization with
# ImageNet channel statistics.
# Resize is applied to the PIL image *before* ToTensor. PIL resizing is antialiased,
# whereas tensor Resize historically was not, so resizing first gives cleaner
# downsampling.
# NOTE(review): the google/vit-base-patch16-224-in21k processor config appears to
# use mean=std=[0.5, 0.5, 0.5]. ImageNet stats are kept so existing checkpoints stay
# consistent — confirm before switching.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

trans = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)])

# Deterministic pipeline for validation/test: no augmentation.
val_trans = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)])

# One ImageFolder per split; only the training split uses the augmenting transform.
train_ds = ImageFolder(TRAIN_DATASET, transform=trans)
val_ds, test_ds = (
    ImageFolder(split_dir, transform=val_trans) for split_dir in (VAL_DATASET, TEST_DATASET)
)

# The classes are imbalanced, so training samples are drawn with probability
# inversely proportional to their class frequency. There are len(train_ds) draws
# per epoch, with replacement (the sampler's default).
sample_weights = make_weights_for_balanced_classes(train_ds.imgs, len(train_ds.classes))
weights = torch.tensor(sample_weights, dtype=torch.float64)
sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, num_samples=len(weights))

In [13]:
from torch.utils.data import DataLoader
import torch


def collate_fn(examples):
    """Collate ``(image_tensor, label)`` samples into the dict batch the ViT module expects.

    Returns ``{"pixel_values": stacked images, "labels": tensor of labels}``.
    """
    images = [sample[0] for sample in examples]
    targets = [sample[1] for sample in examples]
    batch = {"pixel_values": torch.stack(images)}
    batch["labels"] = torch.tensor(targets)
    return batch


# The training loader draws through the class-balancing sampler; validation and
# test are read in dataset order.
loader_kwargs = {"collate_fn": collate_fn, "batch_size": batch_size}
train_dataloader = DataLoader(train_ds, sampler=sampler, **loader_kwargs)
val_dataloader = DataLoader(val_ds, **loader_kwargs)
test_dataloader = DataLoader(test_ds, **loader_kwargs)

In [14]:
# Class names (Russian): "Nobody", "Ded Moroz" (Grandfather Frost), "Santa Claus".
# label2id is derived from id2label so the two maps cannot drift apart.
# NOTE(review): ImageFolder numbers classes by sorted folder name — confirm these
# indices match train_ds.class_to_idx.
id2label = {0: 'Никого нет', 1: 'Дед Мороз', 2: 'Санта Клаус'}
label2id = {name: idx for idx, name in id2label.items()}

In [5]:
!pip install -q pytorch_lightning
!pip install -q transformers

# Model definition: ViT fine-tuning LightningModule

In [22]:
import pytorch_lightning as pl
from transformers import ViTForImageClassification, AdamW
import torch.nn as nn
from torch import optim as optim




class ViTLightningModule(pl.LightningModule):
    """LightningModule wrapping a Hugging Face ``ViTForImageClassification``.

    Batches are dicts with ``"pixel_values"`` and ``"labels"`` (as built by
    ``collate_fn``). The train/val/test steps all go through ``common_step``,
    which returns the cross-entropy loss and the batch accuracy.
    """

    def __init__(self, num_labels=3, lr=5e-5):
        """Load the ``google/vit-base-patch16-224-in21k`` backbone with a new classifier head.

        Args:
            num_labels: number of output classes.
            lr: AdamW learning rate (the default is the previously hard-coded value).
        """
        super().__init__()
        self.lr = lr
        # id2label / label2id are the module-level label maps defined earlier in the notebook.
        self.model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k',
                                                              num_labels=num_labels,
                                                              id2label=id2label,
                                                              label2id=label2id)

    def forward(self, pixel_values):
        """Return the raw classification logits for a batch of images."""
        outputs = self.model(pixel_values=pixel_values)
        return outputs.logits

    def common_step(self, batch, batch_idx):
        """Compute ``(loss, accuracy)`` for one batch; ``accuracy`` is a Python float."""
        pixel_values = batch['pixel_values']
        labels = batch['labels']
        logits = self(pixel_values)

        loss = nn.functional.cross_entropy(logits, labels)
        predictions = logits.argmax(-1)
        correct = (predictions == labels).sum().item()
        accuracy = correct / pixel_values.shape[0]

        return loss, accuracy

    def training_step(self, batch, batch_idx):
        """Log per-step and epoch-averaged training loss/accuracy; return the loss."""
        loss, accuracy = self.common_step(batch, batch_idx)
        # In training_step Lightning logs per step only by default; on_epoch=True
        # also records the epoch average.
        self.log("train_loss", loss, on_step=True, on_epoch=True)
        self.log("train_accuracy", accuracy, on_step=True, on_epoch=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Log epoch-level validation loss/accuracy (``val_loss`` drives early stopping)."""
        loss, accuracy = self.common_step(batch, batch_idx)
        self.log("val_loss", loss, on_epoch=True)
        self.log("val_accuracy", accuracy, on_epoch=True)

        return loss

    def test_step(self, batch, batch_idx):
        """Log test loss/accuracy so ``trainer.test`` reports them."""
        loss, accuracy = self.common_step(batch, batch_idx)
        self.log("test_loss", loss, on_epoch=True)
        self.log("test_accuracy", accuracy, on_epoch=True)

        return loss

    def configure_optimizers(self):
        """AdamW over all parameters.

        Uses ``torch.optim.AdamW`` because ``transformers.AdamW`` is deprecated.
        ``eps`` and ``weight_decay`` are set to transformers.AdamW's defaults so the
        optimisation matches the original setup; torch's own defaults differ.
        """
        return optim.AdamW(self.parameters(), lr=self.lr, eps=1e-6, weight_decay=0.0)


# Train

In [23]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# Stop training once val_loss has not improved for `patience` validation checks.
# See https://pytorch-lightning.readthedocs.io/en/1.0.0/early_stopping.html?highlight=early%20stopping
early_stop_callback = EarlyStopping(
    monitor='val_loss',
    patience=3,
    strict=False,
    verbose=False,
    mode='min',
)

model = ViTLightningModule()
# Pass the callback configured above. Previously a second, default-configured
# EarlyStopping was constructed inline and this one was never used.
# `accelerator`/`devices` replace the `gpus` argument, which Lightning 2.0 removed.
trainer = Trainer(
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
    callbacks=[early_stop_callback],
)
trainer.fit(model, train_dataloader, val_dataloader)

Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.weight', 'pooler.dense.bias']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
GPU available: True, used: True


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

# Test: weighted F1 score on the held-out test set

In [28]:
# Save the fine-tuned weights (state_dict only) to Drive. Create the folder first,
# because torch.save fails if the parent directory does not exist.
CHECKPOINT_PATH = "/content/gdrive/MyDrive/SantaHack/ViT_checkpoints/model.pt"
os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)
torch.save(model.state_dict(), CHECKPOINT_PATH)

In [25]:
import numpy as np
from sklearn.metrics import f1_score

# Collect test-set predictions batch by batch. Per-batch arrays are concatenated
# (not wrapped in np.array) so a final batch smaller than batch_size does not
# produce a ragged array.
preds = []
y_true = []
model.eval()
device = model.device
with torch.no_grad():
    for batch in test_dataloader:
        logits = model(batch['pixel_values'].to(device))
        preds.append(logits.argmax(dim=1).cpu().numpy())
        y_true.append(batch['labels'].numpy())

preds = np.concatenate(preds)
y_true = np.concatenate(y_true)
# 'weighted' averages per-class F1 scores by each class's support in y_true.
f1_score_result = f1_score(y_true, preds, average='weighted')

print(f1_score_result)

0.9428682513926939


In [26]:
# Predicted labels, then true labels, for every test image (dataloader order).
print(preds, y_true, sep="\n")

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 1 1 1 1 2 1 1 2 1 1 1 1 1 1 2 2 1 2 1 1 1 1 1 2 1 1 1 1
 1 2 1 1 1 1 2 1 2 2 2 2 2 2 2 2 2 2 2 2 2 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 1]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2]
