In [None]:
!pip install timm

In [None]:
import torch
import random
import os
import numpy as np
import torch.nn as nn
import pandas as pd
import math
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
from tqdm.notebook import tqdm
from sklearn.model_selection import train_test_split
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sn
from torch.utils.data import DataLoader
from PIL import Image
import timm
import torchvision.transforms as transforms

In [None]:
!nvidia-smi

In [None]:
def seed_everything(seed):
    """Seed the random number generators used in this notebook.

    Sets ``PYTHONHASHSEED``, seeds Python's ``random`` module, NumPy's global
    generator and PyTorch (via ``torch.manual_seed`` and
    ``torch.cuda.manual_seed``), then configures cuDNN with
    ``deterministic=True`` and ``benchmark=False``.

    Args:
        seed: Integer seed applied to every generator.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    # Every entry is a "seed this library's global generator" function.
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


seed_everything(42)

### Data Preprocessing

In [None]:
IMG_SIZE = 224  # side length (in pixels) the transforms resize/crop every image to
BATCH_SIZE = 32  # batch size passed to the train, validation and test DataLoaders

In [None]:
data_df = pd.read_csv('../input/medical-masks-part1/df.csv')

In [None]:
data_df.head()

In [None]:
# Shift every TYPE label down by one.
# NOTE(review): presumably the CSV stores 1-based class ids and this makes them
# 0-based for the loss function — confirm against the data.
# Caution: this cell is not idempotent; re-running it shifts the labels again.
data_df['TYPE'] = data_df['TYPE'].values - 1

#### Dataset Class

In [None]:
class MaskDataset():
    """Map-style dataset that yields ``(image, label)`` pairs loaded from disk.

    Args:
        image_paths: Indexable collection of image file names, relative to
            ``image_dir``.
        labels: Indexable collection of labels aligned with ``image_paths``.
        transforms: Optional callable applied to the loaded RGB PIL image.
        image_dir: Directory containing the image files. Defaults to the
            location this notebook reads from.

    Raises:
        ValueError: If ``image_paths`` and ``labels`` differ in length.
    """

    def __init__(self, image_paths, labels, transforms=None,
                 image_dir="../input/medical-masks-part1/images/"):
        if len(image_paths) != len(labels):
            raise ValueError(
                f"image_paths and labels must have the same length, "
                f"got {len(image_paths)} and {len(labels)}"
            )
        self.image_paths = image_paths
        self.labels = labels
        self.transforms = transforms
        self.image_dir = image_dir

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``."""
        file_name = os.path.join(self.image_dir, self.image_paths[idx])

        # The context manager closes the underlying file once the pixels are
        # read. convert("RGB") forces the (otherwise lazy) load and guarantees
        # three channels, so grayscale / RGBA / palette files do not break the
        # 3-channel normalization applied later.
        with Image.open(file_name) as raw_image:
            image = raw_image.convert("RGB")

        if self.transforms is not None:
            image = self.transforms(image)

        return image, self.labels[idx]

### Transforms

In [None]:
# Per-channel normalization statistics shared by both pipelines.
# NOTE(review): these look like the usual ImageNet mean/std — confirm they match
# the normalization the pretrained model was trained with.
_norm_mean = (0.485, 0.456, 0.406)
_norm_std = (0.229, 0.224, 0.225)
_target_size = (IMG_SIZE, IMG_SIZE)

# Training pipeline: fixed resize, random horizontal flip, random resized crop,
# then tensor conversion and normalization.
_train_steps = [
    transforms.Resize(_target_size),
    transforms.RandomHorizontalFlip(p=0.3),
    transforms.RandomResizedCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(_norm_mean, _norm_std),
]
transforms_train = transforms.Compose(_train_steps)

# Evaluation pipeline: no random augmentation, only resize + tensor + normalize.
_test_steps = [
    transforms.Resize(_target_size),
    transforms.ToTensor(),
    transforms.Normalize(_norm_mean, _norm_std),
]
transforms_test = transforms.Compose(_test_steps)

### Remove Error Data

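Running the commented-out cells below builds a list of the row indices of images that fail to open or transform, and then drops those rows from `data_df`.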


In [None]:
# error_file = list()

In [None]:
# for idx, name in tqdm(enumerate(data_df['name'].values)):
#     try:
#         file_name = "../input/medical-masks-part1/images/" + name
#         image = Image.open(file_name)
#         image = transforms_train(image)
#     except:
#         error_file.append(idx)
#         print(name)

In [None]:
# data_df.drop(error_file, inplace=True)

In [None]:
# File names to exclude from the dataset (see the "Remove Error Data" section).
error_list = ["000030_1_000030_NONE_29.jpg", "007790_1_005591_NONE_27.jpg", "009065_2_006163_MALE_21.jpg", "009065_3_006163_MALE_21.jpg"]

# Filter by file name with a boolean mask instead of dropping by enumerate()
# position: positions only coincide with index labels for a default 0..n-1
# index, so the mask stays correct whatever the index looks like.
data_df = data_df[~data_df["name"].isin(error_list)]

In [None]:
len(data_df)

#### Train Validation Test

In [None]:
# Hold out 10% of the rows as the test set, stratified by label.
# random_state is fixed (same value as seed_everything) so the split does not
# depend on the global NumPy RNG state, i.e. on which cells ran before this one.
train_val_image, test_image, train_val_label, test_label = train_test_split(
    data_df.name.values,
    data_df.TYPE.values,
    test_size=0.1,
    stratify=data_df.TYPE.values,
    random_state=42,
)

In [None]:
# Split the remaining data into train and validation, stratified by label.
# 1/9 of the remaining 90% is roughly 10% of the full dataset.
# random_state is fixed so re-running this cell reproduces the same split.
train_image, validation_image, train_label, validation_label = train_test_split(
    train_val_image,
    train_val_label,
    test_size=(1/9),
    stratify=train_val_label,
    random_state=42,
)

In [None]:
len(train_image), len(validation_image), len(test_image)

In [None]:
train_dataset = MaskDataset(train_image, train_label, transforms_train)

In [None]:
validation_dataset = MaskDataset(validation_image, validation_label, transforms_test)

In [None]:
test_dataset = MaskDataset(test_image, test_label, transforms_test)

In [None]:
# Only the training loader shuffles; validation and test keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

### Training Loop

In [None]:
def train(model, train_loader, optimizer, criterion, device, scheduler=None, log_interval=500):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        model: Module to optimize; switched to training mode.
        train_loader: Iterable of ``(inputs, labels)`` batches that supports
            ``len()`` and exposes a ``dataset`` attribute.
        optimizer: Optimizer updating ``model``'s parameters.
        criterion: Callable ``criterion(output, labels)`` returning the loss.
        device: Device the batches are moved to.
        scheduler: Optional learning-rate scheduler. It is stepped after
            every batch (not once per epoch).
        log_interval: Print running averages every ``log_interval`` batches.
            A falsy value disables the progress output.

    Returns:
        Tuple ``(accuracy, mean_loss)``: correct predictions divided by
        ``len(train_loader.dataset)`` and the summed batch losses divided by
        ``len(train_loader)``.
    """
    total_correct = 0.0
    total_loss = 0.0
    # Running sums since the last progress print.
    window_loss = 0.0
    window_accuracy = 0.0

    model.train()
    for idx, (inputs, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        inputs, labels = inputs.to(device), labels.to(device)

        output = model(inputs)
        loss = criterion(output, labels)
        loss.backward()

        correct = (output.argmax(dim=1) == labels).sum().item()
        # Read the loss value once; each .item() call copies it to the host.
        loss_value = loss.item()

        total_correct += correct
        window_accuracy += correct / len(labels)
        total_loss += loss_value
        window_loss += loss_value

        optimizer.step()

        if log_interval and (idx + 1) % log_interval == 0:
            print(f"Batch Number {idx + 1}: Average Loss {window_loss/log_interval} Average Accuracy {window_accuracy/log_interval}")
            window_accuracy = 0.0
            window_loss = 0.0

        if scheduler is not None:
            scheduler.step()

    return total_correct/len(train_loader.dataset), total_loss/len(train_loader)

### Evaluation Loop

In [None]:
def evaluate(model, test_loader, criterion, device):
    """Compute accuracy and mean batch loss of ``model`` on ``test_loader``.

    Args:
        model: Module to evaluate; switched to evaluation mode.
        test_loader: Iterable of ``(inputs, labels)`` batches that supports
            ``len()`` and exposes a ``dataset`` attribute.
        criterion: Callable ``criterion(output, labels)`` returning the loss.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(accuracy, mean_loss)``: correct predictions divided by
        ``len(test_loader.dataset)`` and the summed batch losses divided by
        ``len(test_loader)``.
    """
    total_correct = 0.0
    total_loss = 0.0

    model.eval()
    # No gradients are needed for evaluation; disabling autograd avoids
    # building a graph for every forward pass.
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            output = model(inputs)
            loss = criterion(output, labels)
            total_correct += (output.argmax(dim=1) == labels).sum().item()
            total_loss += loss.item()

    return total_correct/len(test_loader.dataset), total_loss/len(test_loader)

### Training Preparation

In [None]:
model = timm.create_model('vit_base_patch16_224', pretrained=True)
# Replace the pretrained classification head with a new 4-way linear layer
# (number of unique classes == 4). The input width is read from the existing
# head instead of hard-coding 768.
model.head = nn.Linear(model.head.in_features, 4)

In [None]:
LR = 2e-05  # learning rate handed to the AdamW optimizer
EPOCHS = 7  # maximum number of epochs run() loops over (early stopping may end sooner)

In [None]:
!nvidia-smi

In [None]:
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()
# Multiplies the learning rate by 0.995 on every scheduler.step() call.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.995)
# Use the GPU when one is available; fall back to the CPU so the notebook
# still runs on a machine without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
model.to(device)

In [None]:
def run(model, train_loader, validation_loader, optimizer, criterion, device, scheduler=None,
        epochs=None, patience=2, checkpoint_path="./model.pth"):
    """Train with per-epoch validation, best-model checkpointing and early stopping.

    Args:
        model: Module to train.
        train_loader: Loader passed to ``train``.
        validation_loader: Loader passed to ``evaluate``.
        optimizer: Optimizer passed to ``train``.
        criterion: Loss callable passed to ``train`` and ``evaluate``.
        device: Device passed to ``train`` and ``evaluate``.
        scheduler: Optional learning-rate scheduler, forwarded to ``train``.
        epochs: Maximum number of epochs; defaults to the notebook-level
            ``EPOCHS`` constant when ``None``.
        patience: Training stops once this many epochs, counted since the
            last new best validation loss, had a validation loss that did not
            drop below the previous epoch's.
        checkpoint_path: File the best model is written to.
    """
    num_epochs = EPOCHS if epochs is None else epochs
    current_patience = 0
    previous_valid_loss = None
    best_valid_loss = None

    for epoch in tqdm(range(num_epochs)):
        print("==================================================")
        print(f"EPOCH {epoch + 1}")
        # Forward the caller's scheduler (it used to be hard-coded to None here).
        train_accuracy, train_loss = train(model, train_loader, optimizer, criterion, device, scheduler=scheduler)
        print(f"[TRAIN] EPOCH {epoch + 1} - LOSS: {train_loss}, ACCURACY: {train_accuracy}")
        validation_accuracy, validation_loss = evaluate(model, validation_loader, criterion, device)
        print(f"[VALIDATE] EPOCH {epoch + 1} - LOSS: {validation_loss}, ACCURACY: {validation_accuracy}")
        print("==================================================")

        if best_valid_loss is None or validation_loss <= best_valid_loss:
            # New best (the first epoch always lands here): checkpoint the
            # whole model object and reset the early-stopping counter.
            torch.save(model, checkpoint_path)
            best_valid_loss = validation_loss
            current_patience = 0
        elif validation_loss >= previous_valid_loss:
            # Not a new best and no better than the previous epoch.
            current_patience += 1

        # Track the previous epoch's loss on every epoch, not only on bad ones.
        previous_valid_loss = validation_loss

        if current_patience >= patience:
            print("Early Stop")
            break

In [None]:
run(model, train_loader, validation_loader, optimizer, criterion, device, scheduler)

In [None]:
%cd /kaggle/working

In [None]:
from IPython.display import FileLink
display(FileLink(r'model.pth'))