**Modelo ViT para la clasificación de imágenes**

Descargamos algunas librerías necesarias para la implementación

In [None]:
!pip install einops

Importamos las librerías

In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torch.utils.data import random_split

from einops import rearrange, reduce
from einops.layers.torch import Rearrange, Reduce

Cargamos los datos

In [None]:
# Resize every image to 224x224 and convert it to a tensor.
transforms = Compose([Resize((224, 224)), ToTensor()])

# Labelled images (one sub-folder per class) and the images used for the submission.
training_data = ImageFolder(root="../input/iais22-birds/birds/birds", transform=transforms)
test_data = ImageFolder(root="../input/iais22-birds/submission_test", transform=transforms)

# Hold out ~30% of the labelled data for evaluation. The training length is
# derived from the total so both lengths always add up to len(training_data),
# as random_split requires. Computing int(n * 0.7) + 1 and int(n * 0.3)
# separately can overshoot (e.g. for n = 10 they sum to 11).
test_len = int(len(training_data) * 0.3)
train_len = len(training_data) - test_len
train_set, test_set = random_split(training_data, (train_len, test_len))

train_dataloader = DataLoader(train_set, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_set, batch_size=64, shuffle=True)

# Report the number of training samples (not the Subset object itself).
print(f"Training data size: {len(train_set)}")

Creamos un diccionario que mapea id de la clase con su nombre

In [None]:
# Map each integer class id (its position in the dataset's class list) to the class name.
clases_list = training_data.classes
clases = dict(enumerate(clases_list))
print(clases)

Comprobamos que tanto las imágenes como las etiquetas (targets) se han cargado correctamente

In [None]:
# Sanity check: print the tensor size of the first sample, then draw a 3x3 grid
# of randomly chosen images, each titled with its class name.
train_features, train_labels = training_data[0]
print(f"Tamaño de cada imagen: {train_features.size()}")

cols, rows = 3, 3
figure = plt.figure(figsize=(8, 8))
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    ax = figure.add_subplot(rows, cols, i)
    ax.set_title(clases[label])
    ax.axis("off")
    # Only channel index 1 of the image tensor is drawn, using a gray colormap.
    ax.imshow(img[1], cmap="gray")
plt.show()

Comprobamos que los DataLoaders funcionan correctamente

In [None]:
# Pull a single batch to confirm the shapes and label dtype the loader produces.
first_batch = next(iter(train_dataloader), None)
if first_batch is not None:
    X, y = first_batch
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")

Comprobamos si está disponible la GPU

In [None]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

Comenzamos a construir la arquitectura ViT. Usaremos una técnica modular: construiremos la estructura poco a poco con distintas clases que heredan de nn.Module. Empezamos con PatchEmbedding, que trocea la imagen en parches de 16x16 y les suma un embedding de posición (parámetro que se aprende), siguiendo el paper *AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE*

In [None]:
class PatchEmbedding(nn.Module):
    """Split an image batch into square patches and embed each patch.

    Every ``section`` x ``section`` patch is flattened, projected to
    ``output_net`` features by a linear layer, summed with a learned
    per-patch position parameter and passed through dropout (p=0.1).

    Args:
        img_size: Side length of the (square) input images.
        channels: Number of image channels.
        section: Side length of each patch.
        output_net: Size of the embedding produced for each patch.
    """

    def __init__(self, img_size: int = 224, channels: int = 3, section: int = 16, output_net: int = 768):
        super().__init__()
        patches_per_side = img_size // section
        flat_patch_size = section * section * channels
        # One learned position vector per patch, randomly initialised.
        self.positions = nn.Parameter(torch.randn(patches_per_side ** 2, output_net))
        self.pos_drop = nn.Dropout(p=0.1)
        # (b, c, H, W) -> (b, number of patches, flattened patch) -> linear projection.
        to_patches = Rearrange('b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=section, s2=section)
        projection = nn.Linear(flat_patch_size, output_net)
        self.network = nn.Sequential(to_patches, projection)

    def forward(self, images):
        """Return the position-augmented patch embeddings for ``images``."""
        embedded = self.network(images)
        return self.pos_drop(embedded + self.positions)

Para construir el Transformer construiremos sus partes paso a paso. Empezamos por el módulo Multi-Head Attention

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention over a sequence of embeddings.

    The input is projected to queries, keys and values, each projection is
    split into ``num_heads`` heads, scaled dot-product attention is applied
    per head, and the heads are merged and passed through a final linear layer.

    Args:
        output_net: Size of the input and output embeddings.
        num_heads: Number of attention heads the embedding is split into.
    """

    def __init__(self, output_net: int = 512, num_heads: int = 8):
        super().__init__()
        self.output_net = output_net
        self.num_heads = num_heads
        self.keys = nn.Linear(output_net, output_net)
        self.queries = nn.Linear(output_net, output_net)
        self.values = nn.Linear(output_net, output_net)
        self.network = nn.Linear(output_net, output_net)

    def _split_heads(self, projected):
        """Reshape ``(b, n, h*d)`` into ``(b, h, n, d)`` with ``h = num_heads``."""
        return rearrange(projected, "b n (h d) -> b h n d", h=self.num_heads)

    def forward(self, images):
        """Return the attended sequence, same layout as ``images`` (b, n, output_net)."""
        q = self._split_heads(self.queries(images))
        k = self._split_heads(self.keys(images))
        v = self._split_heads(self.values(images))

        # NOTE(review): the scores are divided by sqrt(output_net), i.e. the
        # full embedding width rather than the per-head width. Kept as is so
        # already-trained models behave the same; confirm this is intended.
        scale = self.output_net ** (1/2)
        scores = torch.einsum('bhqd, bhkd -> bhqk', q, k) / scale
        weights = F.softmax(scores, dim=-1)
        mixed = torch.einsum('bhal, bhlv -> bhav ', weights, v)

        # Merge the heads back into a single embedding per position.
        merged = rearrange(mixed, "b h n d -> b n (h d)")
        return self.network(merged)

Definimos el componente del Transformer FeedForward que alimenta a la red hacia adelante

In [None]:
class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Linear -> Dropout(0.1).

    Args:
        output: Size of the input and output features; the hidden layer is
            four times as wide.
    """

    def __init__(self, output: int = 768):
        super().__init__()
        hidden = 4 * output
        expand = nn.Linear(output, hidden)
        contract = nn.Linear(hidden, output)
        self.network = nn.Sequential(expand, nn.GELU(), contract, nn.Dropout(0.1))

    def forward(self, images):
        """Apply the MLP to the last dimension of ``images``."""
        return self.network(images)


Creamos las conexiones residuales

In [None]:
class ResidualConection(nn.Module):
    """Wrap a module with a skip connection, computing ``fn(x) + x``.

    Args:
        fn: Module (or callable) whose output has the same shape as its input.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, images):
        """Return ``self.fn(images) + images`` without modifying either tensor."""
        # Out-of-place addition: an in-place ``+=`` on the output of ``fn``
        # would overwrite the caller's tensor whenever ``fn`` returns its
        # input unchanged, and is rejected by autograd for leaf tensors that
        # require gradients.
        return self.fn(images) + images

Finalmente construimos el Transformer Encoder

In [None]:
class TransformerEncoder(nn.Module):
    """One Transformer encoder block: attention then feed-forward, each with a skip connection.

    Both sub-layers follow the same pattern, ``x + Dropout(sublayer(LayerNorm(x)))``.
    The sub-layers are created inside ``__init__`` so that every
    ``TransformerEncoder`` instance owns its own parameters. Building them once
    at module level would make all encoder instances share a single set of
    weights.

    Args:
        output: Size of the embeddings flowing through the block.
    """

    def __init__(self, output: int = 768):
        super().__init__()
        attention_sublayer = ResidualConection(
            nn.Sequential(
                nn.LayerNorm(output),
                MultiHeadAttention(output),
                nn.Dropout(0.1),
            )
        )
        feedforward_sublayer = ResidualConection(
            nn.Sequential(
                nn.LayerNorm(output),
                FeedForward(output),
                nn.Dropout(0.1),
            )
        )
        self.network = nn.Sequential(attention_sublayer, feedforward_sublayer)

    def forward(self, images):
        """Run ``images`` through the attention and feed-forward sub-layers."""
        return self.network(images)


Definimos el módulo Transformer como composición de varios TransformerEncoder siguiendo el paper *Attention Is All You Need*

In [None]:
class Transformer(nn.Module):
    """Stack of ``TransformerEncoder`` instances applied one after another.

    Args:
        depth: Number of ``TransformerEncoder`` instances to chain. The
            default of 8 matches the previously hard-coded stack.
    """

    def __init__(self, depth: int = 8):
        super().__init__()
        self.network = nn.Sequential(*(TransformerEncoder() for _ in range(depth)))

    def forward(self, images):
        """Pass ``images`` through every encoder in order."""
        return self.network(images)

Creamos el módulo MLPHead capaz de dar la clasificación de la imagen

In [None]:
class MLPHead(nn.Module):
    """Classification head: mean over the sequence, LayerNorm, then a linear layer.

    Args:
        output: Size of each incoming embedding.
        n_classes: Number of output scores produced per sample.
    """

    def __init__(self, output: int = 768, n_classes: int = 400):
        super().__init__()
        # (b, n, e) -> (b, e): average the embeddings over the sequence axis.
        pool_tokens = Reduce('b n e -> b e', reduction='mean')
        normalise = nn.LayerNorm(output)
        classify = nn.Linear(output, n_classes)
        self.network = nn.Sequential(pool_tokens, normalise, classify)

    def forward(self, images):
        """Return one row of ``n_classes`` scores per sample."""
        return self.network(images)

Definimos el módulo ViT como composición de los módulos creados anteriormente, siguiendo el paper

In [None]:
class ViT(nn.Module):
    """Full model: PatchEmbedding, then Transformer, then MLPHead, all with default settings."""

    def __init__(self):
        super().__init__()
        stages = (PatchEmbedding(), Transformer(), MLPHead())
        self.network = nn.Sequential(*stages)

    def forward(self, images):
        """Return the MLPHead output for a batch of images."""
        return self.network(images)

Definimos el modelo, seleccionamos la GPU para el entrenamiento y definimos la función de pérdida y el optimizador

In [None]:
from pathlib import Path

# Resume from the previously saved model when it exists; otherwise start from
# a freshly initialised ViT instead of failing with FileNotFoundError.
checkpoint_path = "./model105.pth"
if Path(checkpoint_path).is_file():
    # NOTE: torch.load unpickles arbitrary Python objects, so only load
    # checkpoints from a trusted source.
    # map_location lets a checkpoint saved on a GPU be opened on a CPU-only machine.
    # NOTE(review): recent PyTorch releases default to weights_only=True, which
    # rejects whole-model pickles like this one; confirm against the installed version.
    model = torch.load(checkpoint_path, map_location=device)
else:
    print(f"Checkpoint {checkpoint_path} not found, training from scratch")
    model = ViT()
model.to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

Definimos las funciones de entrenamiento y testeo del modelo

In [None]:
def train(train_dataloader, model, loss_fn, optimizer):
    """Train ``model`` for one pass over ``train_dataloader``.

    Each batch is moved to the global ``device``, the loss is computed and
    back-propagated, and the optimizer takes one step. The current loss and
    sample count are printed every 100 batches.
    """
    size = len(train_dataloader.dataset)
    model.train()
    for batch_idx, (inputs, targets) in enumerate(train_dataloader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Forward pass and loss for this batch.
        batch_loss = loss_fn(model(inputs), targets)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        # Progress report only on every 100th batch.
        if batch_idx % 100 != 0:
            continue
        loss = batch_loss.item()
        current = batch_idx * len(inputs)
        print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
            
def test(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` and print accuracy and average loss.

    Runs in eval mode with gradients disabled. The loss is averaged over the
    number of batches and the accuracy over the number of samples in
    ``dataloader.dataset``.
    """
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits = model(inputs)
            test_loss += loss_fn(logits, targets).item()
            # Count the samples whose highest-scoring class equals the label.
            matches = logits.argmax(1) == targets
            correct += matches.type(torch.float).sum().item()
    test_loss = test_loss / num_batches
    correct = correct / size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
            


Entrenamos el modelo durante varias épocas y lo guardamos

In [None]:
# Train for a fixed number of epochs, evaluating on the held-out split after each one.
epochs = 3
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")
# Save the whole model object. The file name needs the "." before "pth" so it
# follows the same "modelNNN.pth" pattern as the checkpoint loaded earlier.
torch.save(model, "model107.pth")
print("Model saved")

Creamos el csv submission

In [None]:
import csv
from pathlib import Path

# Write one "Id,Category" row per submission image.
# Opening with mode "w" truncates any previous file, so no explicit delete is
# needed (deleting first fails when the file does not exist yet). The ``with``
# block closes the file even if inference raises.
submision_dataloader = DataLoader(test_data, batch_size=64, shuffle=False)
model.eval()

# test_data.imgs holds (path, class_index) pairs; the loader is not shuffled,
# so predictions come out in this same order.
image_paths = [path for path, _ in test_data.imgs]

row = 0
with open("submission.csv", "w", newline="") as file_object:
    writer = csv.writer(file_object, lineterminator="\n")
    writer.writerow(["Id", "Category"])
    with torch.no_grad():
        for x, _ in submision_dataloader:
            predictions = model(x.to(device)).argmax(1).tolist()
            for class_id in predictions:
                # Id is the file name up to its first dot, taken from the path
                # itself rather than from a fixed character offset.
                image_id = Path(image_paths[row]).name.split(".")[0]
                writer.writerow([image_id, clases[class_id]])
                row += 1

print("Done!")

Comprobamos el número de parámetros que tiene nuestro modelo

In [None]:
!pip install torchsummary

In [None]:
from torchsummary import summary

# Build a fresh ViT and print its summary for a 3x224x224 input on the CPU.
fresh_vit = ViT()
summary(fresh_vit, (3, 224, 224), device='cpu')