In [1]:
%%bash
pip install transformers -q



### **Library Imports**

In [2]:
import os
import re
import cv2
import torch
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from time import time
from torch import nn, optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader as DL
from torch.nn.utils import weight_norm as WN
from torchvision import transforms

from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import StratifiedKFold
from transformers import AutoFeatureExtractor, SwinForImageClassification, SwinModel, SwinConfig

### **Utilities and Constants**

In [3]:
# Global random seed for reproducible initialisation / data splits.
SEED = 42

# Shared encoder mapping cultivar names <-> integer class ids; it is fitted in the
# data-loading cell and reused to decode predictions for the submission.
le = LabelEncoder()

# Hugging Face Swin checkpoint names (loaded with a "microsoft/" prefix).
NAMES = [
    "swin-tiny-patch4-window7-224",
    "swin-small-patch4-window7-224",
    "swin-base-patch4-window7-224",
    "swin-base-patch4-window12-384",
]

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# HWC image -> CHW float tensor, then per-channel normalisation with the commonly used
# ImageNet mean / std statistics.
TRANSFORM = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Directory holding the trained classifier-head checkpoints (state_fold_<k>.pt).
CLASSIFIER_HEAD_PATH = "../input/fgvc9-swin-b384-features-train-slow/saves"

# Local output directory; created only when nothing exists at that path yet.
SAVE_PATH = "saves"
if not os.path.exists(SAVE_PATH):
    os.makedirs(SAVE_PATH)

In [4]:
def breaker(num: int=50, char: str="*") -> None:
    """Print a horizontal rule made of ``num`` copies of ``char``, surrounded by blank lines."""
    rule = char * num
    print(f"\n{rule}\n")

    
def get_image(path: str, size: int = 384) -> np.ndarray:
    """Read an image from disk as an RGB array resized to ``size`` x ``size``.

    Args:
        path: image file path.
        size: output side length in pixels (default 384, matching the 384-px Swin
            checkpoint used in this notebook).

    Returns:
        Array of shape ``(size, size, 3)`` in RGB channel order.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path`` (missing or undecodable file).
    """
    image = cv2.imread(path, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising; without this
        # check the problem surfaces later as an opaque cv2.error inside cvtColor.
        raise FileNotFoundError(f"Could not read image: {path}")
    # OpenCV decodes to BGR; downstream normalisation expects RGB.
    image = cv2.cvtColor(src=image, code=cv2.COLOR_BGR2RGB)
    return cv2.resize(src=image, dsize=(size, size), interpolation=cv2.INTER_AREA)

### **Dataset Template**

In [5]:
class FEDS(Dataset):
    """Image dataset used for feature extraction.

    Each item is ``transform(get_image(base_path/filenames[idx]))``; no labels are returned.
    """

    def __init__(self, base_path: str, filenames: np.ndarray, transform):
        self.base_path, self.filenames, self.transform = base_path, filenames, transform

    def __len__(self):
        num_files = self.filenames.shape[0]
        return num_files

    def __getitem__(self, idx):
        image_path = os.path.join(self.base_path, self.filenames[idx])
        image = get_image(image_path)
        return self.transform(image)

    
class DS(Dataset):
    """In-memory dataset over pre-extracted feature vectors.

    Args:
        X: feature array indexed by sample (``X[idx]`` is one feature vector).
        y: labels indexed like ``X``; stored only in "train"/"valid" mode.
        mode: "train", "valid" or "test", matched case-insensitively as a prefix
            (``re.match``), so e.g. "Training" also counts as train mode.

    Raises:
        ValueError: if ``mode`` matches none of the three modes (previously an
            ``assert``, which is silently skipped under ``python -O``).
    """

    _LABELLED_MODES = ("train", "valid")

    def __init__(self, X=None, y=None, mode="train"):
        self.mode = mode

        if not (self._has_labels() or re.match(r"test", self.mode, re.IGNORECASE)):
            raise ValueError(f"mode must start with 'train', 'valid' or 'test', got {mode!r}")

        self.X = X
        if self._has_labels():
            self.y = y

    def _has_labels(self):
        """Return True when ``self.mode`` is a train/valid mode, i.e. items carry a label."""
        return any(re.match(prefix, self.mode, re.IGNORECASE) for prefix in self._LABELLED_MODES)

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        features = torch.tensor(self.X[idx], dtype=torch.float32)
        if self._has_labels():
            # torch.tensor converts a scalar label by value. The legacy torch.LongTensor(...)
            # constructor may interpret a scalar as a *size* and return an uninitialised
            # tensor, so the old code only worked when y was a column (y[idx] of shape (1,)).
            return features, torch.tensor(self.y[idx], dtype=torch.long)
        return features

### **Model**

In [6]:
class Model(nn.Module):
    """MLP classifier head over pre-extracted image features.

    Each width in ``HL`` contributes a ``BN -> FC -> [Dropout] -> ReLU`` group, and a final
    ``BN -> FC`` maps to ``num_classes`` outputs. ``forward`` returns log-probabilities.
    Sub-module names (``BN1``, ``FC1``, ``DP1``, ``AN1``, ...) and their creation order match
    the previous hand-unrolled 1/2/3-hidden-layer versions, so existing checkpoints still load
    and seeded initialisation is unchanged.

    Args:
        IL: input feature dimension.
        HL: hidden-layer widths. Any length is supported (previously only 1-3; other lengths
            silently produced an empty network). ``None`` means no hidden layers.
        DP: dropout probability; dropout layers are added only when this is a ``float``.
        use_WN: wrap every Linear layer in weight normalisation.
        num_classes: number of output classes (defaults to the previously hard-coded 100).
    """

    def __init__(self, IL: int=None, HL: list=None, DP: float=None, use_WN: bool=False, num_classes: int=100):
        super(Model, self).__init__()

        self.model = nn.Sequential()

        widths = [IL] + (list(HL) if HL is not None else [])
        n_hidden = len(widths) - 1

        for i in range(n_hidden):
            k = i + 1  # 1-based suffix used in the module names
            self.model.add_module(f"BN{k}", nn.BatchNorm1d(num_features=widths[i], eps=1e-5))
            self.model.add_module(f"FC{k}", self._linear(widths[i], widths[i + 1], use_WN))
            if isinstance(DP, float):
                self.model.add_module(f"DP{k}", nn.Dropout(p=DP))
            self.model.add_module(f"AN{k}", nn.ReLU())

        out_k = n_hidden + 1
        self.model.add_module(f"BN{out_k}", nn.BatchNorm1d(num_features=widths[-1], eps=1e-5))
        self.model.add_module(f"FC{out_k}", self._linear(widths[-1], num_classes, use_WN))

    @staticmethod
    def _linear(in_features, out_features, use_WN):
        """Build a Linear layer, optionally wrapped in weight normalisation."""
        layer = nn.Linear(in_features=in_features, out_features=out_features)
        return WN(layer) if use_WN else layer

    def get_optimizer(self, lr=1e-3, wd=0.0):
        """Adam over all parameters with learning rate ``lr`` and weight decay ``wd``."""
        return optim.Adam(self.parameters(), lr=lr, weight_decay=wd)

    def get_plateau_scheduler(self, optimizer=None, patience=5, eps=1e-8):
        """ReduceLROnPlateau scheduler for ``optimizer``."""
        return optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, patience=patience, eps=eps)

    def forward(self, x):
        # Log-probabilities over classes (pairs with nn.NLLLoss); the functional form avoids
        # constructing a new LogSoftmax module on every call.
        return nn.functional.log_softmax(self.model(x), dim=1)

### **Fit and Predict Helper**

In [7]:
# def fit(model=None, optimizer=None, scheduler=None, epochs=None, early_stopping_patience=None, dataloaders=None, fold=None, verbose=False) -> tuple:
    
#     def get_accuracy(y_pred, y_true):
#         y_pred = torch.argmax(y_pred, dim=1)
#         return torch.count_nonzero(y_pred == y_true).item() / len(y_pred)
    
#     if verbose:
#         breaker()
#         print(f"Training Fold {fold}...")
#         breaker()
        
#     bestLoss, bestAccs = {"train" : np.inf, "valid" : np.inf}, {"train" : 0.0, "valid" : 0.0}
#     Losses, Accuracies = [], []
#     name = f"state_fold_{fold}.pt"

#     start_time = time()
#     for e in range(epochs):
#         e_st = time()
#         epochLoss, epochAccs = {"train" : 0.0, "valid" : 0.0}, {"train" : 0.0, "valid" : 0.0}

#         for phase in ["train", "valid"]:
#             if phase == "train":
#                 model.train()
#             else:
#                 model.eval()
            
#             lossPerPass, accsPerPass = [], []
            
#             for X, y in dataloaders[phase]:
#                 X, y = X.to(DEVICE), y.to(DEVICE).view(-1)

#                 optimizer.zero_grad()
#                 with torch.set_grad_enabled(phase == "train"):
#                     output = model(X)
#                     loss = torch.nn.NLLLoss()(output, y)
#                     if phase == "train":
#                         loss.backward()
#                         optimizer.step()
#                 lossPerPass.append(loss.item())
#                 accsPerPass.append(get_accuracy(output, y))
#             epochLoss[phase] = np.mean(np.array(lossPerPass))
#             epochAccs[phase] = np.mean(np.array(accsPerPass))
#         Losses.append(epochLoss)
#         Accuracies.append(epochAccs)
        
#         if early_stopping_patience:
#             if epochLoss["valid"] < bestLoss["valid"]:
#                 bestLoss = epochLoss
#                 BLE = e + 1
#                 torch.save({"model_state_dict": model.state_dict(),
#                             "optim_state_dict": optimizer.state_dict()},
#                            os.path.join(SAVE_PATH, name))
#                 early_stopping_step = 0
#             else:
#                 early_stopping_step += 1
#                 if early_stopping_step > early_stopping_patience:
#                     print("\nEarly Stopping at Epoch {}".format(e + 1))
#                     break
        
#         if epochLoss["valid"] < bestLoss["valid"]:
#             bestLoss = epochLoss
#             BLE = e + 1
#             torch.save({"model_state_dict" : model.state_dict(),
#                         "optim_state_dict" : optimizer.state_dict()},
#                         os.path.join(SAVE_PATH, name))
        
#         if epochAccs["valid"] > bestAccs["valid"]:
#             bestAccs = epochAccs
#             BAE = e + 1
        
#         if scheduler:
#             scheduler.step(epochLoss["valid"])
        
#         if verbose:
#             print("Epoch: {} | Train Loss: {:.5f} | Valid Loss: {:.5f} |\
# Train Accs: {:.5f} | Valid Accs: {:.5f} | Time: {:.2f} seconds".format(e+1, 
#                                                                        epochLoss["train"], epochLoss["valid"], 
#                                                                        epochAccs["train"], epochAccs["valid"], 
#                                                                        time()-e_st))

#     if verbose:                                           
#         breaker()
#         print(f"Best Validation Loss at Epoch {BLE}")
#         breaker()
#         print(f"Best Validation Accs at Epoch {BAE}")
#         breaker()
#         print("Time Taken [{} Epochs] : {:.2f} minutes".format(len(Losses), (time()-start_time)/60))
    
#     return Losses, Accuracies, BLE, BAE, name


def predict(model=None, dataloader=None, path=None) -> np.ndarray:
    """Load a checkpoint into ``model`` and return its arg-max class predictions.

    Args:
        model: network whose output has shape ``(batch, classes)`` (scores, logits or
            log-probabilities -- any monotonic score works for arg-max).
        dataloader: iterable of input batches. A batch may also be an ``(X, y, ...)``
            tuple/list (e.g. from ``DS`` in "valid" mode), in which case only ``X`` is used.
        path: checkpoint file containing a ``"model_state_dict"`` entry.

    Returns:
        Array of shape ``(N, 1)`` holding predicted class indices. The float32 dtype is kept
        from the previous implementation for compatibility; cast before use as labels.
    """
    model.load_state_dict(torch.load(path, map_location=DEVICE)["model_state_dict"])
    model.to(DEVICE)
    model.eval()

    # Collect per-batch predictions and concatenate once, instead of growing a tensor with
    # torch.cat on every batch from a dummy zero row.
    batch_preds = []
    with torch.no_grad():
        for batch in dataloader:
            X = batch[0] if isinstance(batch, (tuple, list)) else batch
            # arg-max directly on the model output: exp() is monotonic so it cannot change
            # the winner, and for raw logits it could overflow to inf and create false ties.
            preds = torch.argmax(model(X.to(DEVICE)), dim=1)
            batch_preds.append(preds.view(-1, 1).cpu())

    if not batch_preds:
        return np.zeros((0, 1), dtype=np.float32)
    return torch.cat(batch_preds, dim=0).numpy().astype(np.float32)

### **Load and Preprocess Data**

In [8]:
train_df = pd.read_csv("../input/sorghum-id-fgvc-9/train_cultivar_mapping.csv")
ss_df = pd.read_csv("../input/sorghum-id-fgvc-9/sample_submission.csv")

# Some rows of the mapping CSV reference images missing from train_images. List the directory
# ONCE (the previous code called os.listdir for every row, i.e. O(n^2) directory scans) and
# drop all offending rows with a single boolean mask; surviving rows keep their index labels.
available_images = set(os.listdir("../input/sorghum-id-fgvc-9/train_images"))
broken_images = [filename for filename in train_df.image if filename not in available_images]
train_df = train_df[~train_df.image.isin(broken_images)]

filenames = train_df.image.copy().values
labels    = train_df.cultivar.copy().values

# Integer-encode cultivar names; the same encoder decodes predictions at submission time.
labels = le.fit_transform(labels)

ts_filenames = ss_df.filename.copy().values
# Placeholder; the feature-extraction cell reassigns ts_features.
# NOTE(review): 1024 presumably matches the Swin-B hidden size -- confirm.
ts_features = np.zeros((len(ts_filenames), 1024))

### **Get Swin Features**

In [9]:
# NOTE(review): `pretrained` is not used below -- images are normalised by TRANSFORM instead.
pretrained = AutoFeatureExtractor.from_pretrained("microsoft/" + NAMES[3])
model = SwinModel.from_pretrained("microsoft/" + NAMES[3]).to(DEVICE)
model.eval()  # make inference mode explicit (no dropout / stochastic depth)

dataloader_setup = FEDS("../input/sorghum-id-fgvc-9/test", ts_filenames, TRANSFORM)
dataloader = DL(dataloader_setup, batch_size=64, shuffle=False)

# Collect per-batch features on the CPU and concatenate once at the end, instead of growing
# a GPU tensor with torch.cat on every batch starting from a dummy zero row.
feature_batches = []
with torch.no_grad():
    for X in dataloader:
        output = model(X.to(DEVICE)).last_hidden_state
        # NOTE(review): Swin has no [CLS] token, so [:, 0, :] is the first patch-token
        # embedding rather than a pooled image embedding (pooler_output). Kept unchanged
        # because the classifier head loaded from CLASSIFIER_HEAD_PATH was presumably trained
        # on features extracted the same way -- confirm before switching.
        feature_batches.append(output[:, 0, :].cpu())

if feature_batches:
    ts_features = torch.cat(feature_batches, dim=0).numpy()
else:
    ts_features = np.zeros((0, model.config.hidden_size), dtype=np.float32)

Downloading:   0%|          | 0.00/70.1k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/255 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/339M [00:00<?, ?B/s]

Some weights of the model checkpoint at microsoft/swin-base-patch4-window12-384 were not used when initializing SwinModel: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing SwinModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing SwinModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


### **Params**

In [10]:
DEBUG: bool = False

# DEBUG shrinks the CV / model setup so the notebook can be smoke-tested quickly.
n_splits, epochs, HL = (3, 2, [4]) if DEBUG else (5, 500, [2048, 1024, 512])

# Optimisation hyper-parameters.
batch_size = 512
lr = 1e-6
wd = 1e-5
early_stopping = 50

# ReduceLROnPlateau settings; the (commented-out) training cell builds a scheduler only
# when patience is an int and eps is a float.
patience = None
eps = None

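### **Train**

Training is disabled in this notebook (the cell below is commented out); the fold checkpoints are instead loaded from `CLASSIFIER_HEAD_PATH` in the Submission step.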

In [11]:
# fold = 1

# start_time = time()
# for tr_idx, va_idx in StratifiedKFold(n_splits=n_splits, random_state=SEED, shuffle=True).split(features, labels):
#     tr_features, va_features, tr_labels, va_labels = features[tr_idx], features[va_idx], labels[tr_idx], labels[va_idx]
    
#     tr_data_setup = DS(tr_features, tr_labels.reshape(-1, 1), "train")
#     va_data_setup = DS(va_features, va_labels.reshape(-1, 1), "valid")

#     dataloaders = {
#         "train" : DL(tr_data_setup, batch_size=batch_size, shuffle=True, generator=torch.manual_seed(SEED)),
#         "valid" : DL(va_data_setup, batch_size=batch_size, shuffle=False),
#     }

#     torch.manual_seed(SEED)
#     model = Model(IL=tr_features.shape[1], HL=HL, DP=0.25, use_WN=True).to(DEVICE)
#     optimizer = model.get_optimizer(lr=lr, wd=wd)
#     scheduler = None
#     if isinstance(patience, int) and isinstance(eps, float):
#         scheduler = get_plateau_scheduler(optimizer=optimizer, patience=patience, eps=eps)
    
#     L, A, BLE, BAE, name = fit(model=model, optimizer=optimizer, scheduler=scheduler, 
#                            epochs=epochs, early_stopping_patience=early_stopping, 
#                            dataloaders=dataloaders, fold=fold, verbose=True)

#     breaker()
#     show_graphs(L, A)
#     fold += 1

# breaker()
# print(f"Total Time Taken for {n_splits}-Fold CV : {(time()-start_time)/60:.2f} minutes")
# breaker()

### **Submission**

In [12]:
ts_data_setup = DS(ts_features, None, "test")
ts_data = DL(ts_data_setup, batch_size=batch_size, shuffle=False)

torch.manual_seed(SEED)
model = Model(IL=ts_features.shape[1], HL=HL, DP=0.25, use_WN=True).to(DEVICE)

# Which CV fold's checkpoint to submit. NOTE(review): the rationale for fold 4 is not
# recorded in this notebook.
best_model_fold = 4
y_pred = predict(model=model, dataloader=ts_data, path=os.path.join(CLASSIFIER_HEAD_PATH, f"state_fold_{best_model_fold}.pt"))

# predict() returns class indices as a float (N, 1) column: flatten to 1-D (avoids sklearn's
# column-vector DataConversionWarning) and cast to int64 (uint8 would wrap past 255 classes).
ss_df["cultivar"] = le.inverse_transform(y_pred.ravel().astype(np.int64))
ss_df.to_csv("submission.csv", index=False)

  y = column_or_1d(y, warn=True)
