In [None]:
!pip install timm

In [None]:
import os
import random
import numpy as np
import torch
import cv2
import timm
import pandas as pd
import torchvision.transforms as transform
from sklearn import model_selection
from torch.utils.data import Dataset
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

In [None]:
def seed_everything(seed):
    """
    Put the random number generators used in this notebook into a known state.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (``torch.manual_seed`` and ``torch.cuda.manual_seed``), exports
    ``PYTHONHASHSEED`` and configures cuDNN with ``deterministic=True`` and
    ``benchmark=False``.

    Arguments:
        seed {int} -- Seed value shared by all generators
    """
    # cuDNN flags first; they are independent of the seeding calls below.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    os.environ["PYTHONHASHSEED"] = str(seed)

    # Same seed for every library-level generator.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


# Fixed seed for the whole notebook run.
seed_everything(1001)

In [None]:
# Locations of the dataset and of the pretrained checkpoint, relative to the
# notebook's working directory. The image folders are derived from DATA_PATH so
# the dataset root is spelled out only once.
DATA_PATH = "../input/cassava-leaf-disease-classification"
TRAIN_PATH = DATA_PATH + "/train_images/"
TEST_PATH = DATA_PATH + "/test_images/"
MODEL_PATH = "../input/vit-model-pretrain/jx_vit_base_p16_224-80ecf9dd.pth"

In [None]:
# Training hyperparameters
IMAGE_SIZE = 224  # side length the transform pipelines resize / crop images to
LR = 2e-5  # learning rate handed to the Adam optimizer
N_EPOCHS = 20  # number of iterations of the outer training loop
BATCH_SIZE = 16  # batch size shared by the training and validation DataLoaders

In [None]:
# Resolve the label file inside the dataset directory and stop early if it is absent.
train_csv_path = os.path.join(DATA_PATH, 'train.csv')
assert os.path.exists(train_csv_path), '{} path is not exists...'.format(train_csv_path)

# Full table of image ids and labels; preview the first rows as the cell output.
all_data = pd.read_csv(filepath_or_buffer=train_csv_path)
all_data.head(n=5)

In [None]:
# Bar chart of how many rows carry each label in the full table.
all_data['label'].value_counts().plot.bar()

In [None]:
# Hold out 10% of the rows for validation, stratified on the label column so
# both splits keep the label proportions of the full table.
train_df, valid_df = model_selection.train_test_split(
    all_data,
    test_size=0.1,
    random_state=42,
    stratify=all_data['label'].values,
)
# Label counts of the training split.
train_df['label'].value_counts().plot.bar()

In [None]:
class CustomDataset(Dataset):
    """
    Map-style dataset that reads images from disk with OpenCV.

    Each item is an ``(image, label)`` pair. The image is loaded from
    ``os.path.join(data_path, image_id)``, converted from BGR to RGB and, when a
    transform was supplied, passed through it.

    Arguments:
        df {DataFrame-like} -- Table with an 'image_id' column (file names) and a 'label' column
        data_path {str} -- Directory that contains the image files
        transform {callable} -- Optional callable applied to the RGB image (default: {None})
    """

    def __init__(self, df, data_path, transform=None):
        super().__init__()

        self.img_id = df['image_id'].values
        self.label = df['label'].values
        self.path = data_path
        self.transform = transform

    def __len__(self):
        """Number of samples, i.e. the number of image ids taken from ``df``."""
        return len(self.img_id)

    def __getitem__(self, idx):
        """
        Load the sample at position ``idx``.

        Returns:
            tuple -- ``(image, label)``; the image is the transform's output when a
            transform is set, otherwise the RGB array produced by OpenCV.

        Raises:
            AssertionError -- the image path does not exist
            OSError -- the file exists but OpenCV could not read it as an image
        """
        img_path = os.path.join(self.path, self.img_id[idx])
        assert os.path.exists(img_path), '{} img path is not exists...'.format(img_path)

        # cv2.imread reports failure by returning None instead of raising, which
        # would otherwise surface as an opaque error inside cvtColor.
        image = cv2.imread(img_path)
        if image is None:
            raise OSError('{} could not be read as an image'.format(img_path))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform is not None:
            image = self.transform(image)

        # The label is returned as-is; the DataLoader's collate function converts
        # it to a tensor when batching.
        return image, self.label[idx]

In [None]:
# Show the timm model names that match the 'vit*' wildcard.
timm.list_models('vit*')

In [None]:
# Per-channel mean / std used by both pipelines for normalisation.
# NOTE(review): confirm these statistics match the ones the pretrained ViT
# checkpoint was trained with.
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)

# Training pipeline: resize, random flips and a random resized crop, then
# tensor conversion and normalisation.
transform_train = transform.Compose([
    transform.ToPILImage(),
    transform.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transform.RandomHorizontalFlip(p=0.3),
    transform.RandomVerticalFlip(p=0.3),
    transform.RandomResizedCrop(IMAGE_SIZE),
    transform.ToTensor(),
    transform.Normalize(_NORM_MEAN, _NORM_STD),
])

# Validation pipeline: same resize and normalisation, no random augmentation.
transform_valid = transform.Compose([
    transform.ToPILImage(),
    transform.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transform.ToTensor(),
    transform.Normalize(_NORM_MEAN, _NORM_STD),
])

In [None]:
class ViT(nn.Module):
    """
    Vision Transformer classifier built on a timm backbone.

    The backbone is created through ``timm.create_model`` and its ``head`` is
    replaced by a fresh ``nn.Linear`` with ``num_classes`` outputs.

    Arguments:
        num_classes {int} -- Output size of the new classification head
        model_name {str} -- timm architecture name (default: {'vit_base_patch16_224'})
        pretrained {bool} -- Load backbone weights from a local checkpoint before
            replacing the head (default: {False})
        weights_path {str} -- Checkpoint file used when ``pretrained`` is True;
            None falls back to the module-level MODEL_PATH (default: {None})
    """

    def __init__(self, num_classes, model_name='vit_base_patch16_224', pretrained=False, weights_path=None):

        super().__init__()

        # Always build the bare architecture. Asking timm for pretrained weights
        # would trigger a download whose result the local checkpoint below
        # overwrites anyway, and it fails when no network is available.
        self.model = timm.create_model(model_name, pretrained=False)

        if pretrained:
            if weights_path is None:
                weights_path = MODEL_PATH
            # map_location='cpu' keeps loading independent of the device the
            # checkpoint was saved from; the caller moves the model afterwards.
            state_dict = torch.load(weights_path, map_location='cpu')
            self.model.load_state_dict(state_dict)

        # Swap the original classification head for one sized to this task.
        self.model.head = nn.Linear(self.model.head.in_features, num_classes)

    def forward(self, x):
        """Run ``x`` through the backbone and the replaced head."""
        return self.model(x)

In [None]:
# Use the first CUDA device when one is available, otherwise stay on the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
device

In [None]:
# Both splits read from the training image folder; only the transform differs.
train_dataset = CustomDataset(df=train_df, data_path=TRAIN_PATH, transform=transform_train)
valid_dataset = CustomDataset(df=valid_df, data_path=TRAIN_PATH, transform=transform_valid)

# Training batches are reshuffled every epoch; validation keeps a fixed order.
trainloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=8,
)
validloader = DataLoader(
    valid_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=8,
)

# Five-output ViT with pretrained backbone weights, placed on the selected device.
model = ViT(num_classes=5, pretrained=True)
model.to(device)

# Adam over all model parameters, trained against cross-entropy.
optimizer = optim.Adam(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()


# Per-epoch history of the averaged metrics (plain Python floats).
train_loss = []
train_acc = []
valid_loss = []
valid_acc = []

# NOTE(review): best_acc is initialised here but not updated in the code visible
# in this cell — confirm whether a later step is meant to track it.
best_acc = 0
for epoch in range(N_EPOCHS):

    # ---- training pass ----
    train_epoch_loss = 0.0
    train_epoch_acc = 0  # running count of correctly classified training samples

    model.train()
    train_bar = tqdm(trainloader)
    for img, label in train_bar:

        img = img.to(device)
        label = label.to(device)

        optimizer.zero_grad()
        output = model(img)
        losses = criterion(output, label)
        losses.backward()
        optimizer.step()

        # .item() converts the 0-dim device tensors to Python numbers so the
        # counters and the history lists never hold tensors.
        train_epoch_loss += losses.item()
        train_epoch_acc += (output.argmax(dim=1) == label).sum().item()

    print('train epoch acc:{}'.format(train_epoch_acc))
    # Loss is averaged over batches, accuracy over samples.
    train_loss.append(train_epoch_loss / len(trainloader))
    train_acc.append(train_epoch_acc / len(train_dataset))
    print('train loss: {:.3f} train acc: {:.3f}'.format(train_epoch_loss / len(trainloader), train_epoch_acc / len(train_dataset)))

    # ---- validation pass ----
    valid_epoch_loss = 0.0
    valid_epoch_acc = 0  # running count of correctly classified validation samples

    model.eval()
    valid_bar = tqdm(validloader)
    # No gradients are needed for the forward pass or the loss during validation.
    with torch.no_grad():
        for img, label in valid_bar:

            img = img.to(device)
            label = label.to(device)

            output = model(img)
            losses = criterion(output, label)

            valid_epoch_loss += losses.item()
            valid_epoch_acc += (output.argmax(dim=1) == label).sum().item()

    valid_loss.append(valid_epoch_loss / len(validloader))
    valid_acc.append(valid_epoch_acc / len(valid_dataset))

    print('valid loss: {:.3f} valid acc: {:.3f}'.format(valid_epoch_loss / len(validloader),valid_epoch_acc / len(valid_dataset)))