In [None]:
from fastai.vision.all import *
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
import albumentations as A

In [None]:
# Set seed for reproducibility.
# Defining SEED alone seeds nothing: fastai's `set_seed` seeds Python `random`, NumPy and
# PyTorch so augmentation, dataloader shuffling and head initialisation repeat across runs.
# `reproducible=True` additionally makes cuDNN deterministic (slower on GPU).
SEED = 42
set_seed(SEED, reproducible=True)

In [None]:
# Dataset root (presumably a Kaggle-style input mount) and a local checkpoint folder.
path = Path("..", "input", "uisketch")
model_path = Path("models")
model_path.mkdir(exist_ok=True)

In [None]:
pretrained_model_url = "https://blackbox-toolkit.com/models/uisketch/resnet152/resnet152.pth"
# Saved as models/uisketch_pretrained.pth so that `learn.load("uisketch_pretrained")` finds it
# (assumes the learner keeps fastai's default path='.' and model_dir='models').
pretrained_model_path = model_path / "uisketch_pretrained.pth"

# Skip the large download on re-runs: depending on the fastai version, `download_url` may not
# skip an existing file by itself.
# NOTE(review): a partial file left by an interrupted download is also skipped; delete it if
# loading the checkpoint later fails.
# SECURITY: this .pth is later unpickled by `learn.load` (torch.load) — only use a trusted URL.
if not pretrained_model_path.exists():
    download_url(pretrained_model_url, pretrained_model_path)

# Loading datasets

In [None]:
# Label table; the DataBlock below reads column 0 as the image path and column 1 as the class.
df = pd.read_csv(path.joinpath("labels.csv"))
df.head(n=5)

In [None]:
# Hold out a stratified 10% test set for final evaluation; the remaining rows are split again
# into training/validation by the DataBlock's splitter.
train_df, test_df = train_test_split(
    df,
    test_size=0.1,
    stratify=df.label.values,
    random_state=SEED,
)

In [None]:
# Peek at the first rows of the training portion.
train_df.head(n=5)

In [None]:
class InvertImage(Transform):
    "Item transform that inverts the pixel values of a `PILImage` with albumentations' `InvertImg`."

    # Built once and shared by every call instead of being re-created per image; with p=1 it is
    # applied unconditionally, so reuse does not change the output.
    _invert = A.InvertImg(p=1)

    def encodes(self, img: PILImage):
        "Return a new `PILImage` with inverted pixels (for uint8 data albumentations maps v to 255 - v)."
        np_img = np.array(img)
        aug_img = self._invert(image=np_img)['image']
        return PILImage.create(aug_img)

In [None]:
# Batch-level augmentation: fastai's default recipe without flips (presumably because a mirrored
# sketch could depict a different UI element), padding with border pixels.
transforms = list(aug_transforms(do_flip=False, pad_mode='border'))

# NOTE(review): there is no Resize in item_tfms, so batching assumes all sketches already share
# one image size — confirm against the dataset.
uisketch = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_x=ColReader(0, pref=path),  # column 0: image path, prefixed with the dataset root
    get_y=ColReader(1),             # column 1: class label
    # Carve a stratified 10% validation set out of the training rows.
    splitter=TrainTestSplitter(test_size=0.1, random_state=SEED, stratify=train_df.label.values),
    item_tfms=[InvertImage()],
    batch_tfms=transforms + [Normalize.from_stats(*imagenet_stats)],
)

dataloaders = uisketch.dataloaders(train_df)

dataloaders.show_batch()

# ResNet 152 model

Create a ResNet-152 learner and load the pretrained weights released with the UISketch paper.

In [None]:
# ResNet-152 backbone with a fastai classification head; validation accuracy and top-k
# accuracy are reported as metrics.
learn = cnn_learner(
    dataloaders,
    resnet152,
    metrics=[accuracy, top_k_accuracy],
)

In [None]:
# The checkpoint was produced with fastai v1, so its state-dict keys may not match this v2 model
# exactly; strict=False skips missing/unexpected keys (shape mismatches still raise).
# NOTE(review): skipped keys are silent — confirm the backbone weights actually loaded.
# SECURITY: this unpickles the file via torch.load; only load checkpoints from a trusted source.
learn.load(file="uisketch_pretrained", strict=False)

In [None]:
# Sweep learning rates to help choose `lr_max` for training. `lr_find` returns a namedtuple of
# suggestions whose fields depend on the fastai version: older releases return
# (lr_min, lr_steep), where lr_min is already the loss minimum divided by 10. Newer releases
# default to a single `valley` field, so a fixed two-value unpack raises there. Print whatever
# suggestions this version provides.
lr_suggestions = learn.lr_find()
for name, suggested_lr in lr_suggestions._asdict().items():
    print(f"{name}: {suggested_lr:.2e}")

# Training

In [None]:
# Rebuild the learner before training (presumably so the run starts from a clean state after the
# LR finder), then re-apply the UISketch checkpoint. strict=False because it is a fastai v1 file.
learn = cnn_learner(dataloaders, resnet152, metrics=[accuracy, top_k_accuracy])

learn.load("uisketch_pretrained", strict=False)

In [None]:
# Six epochs of one-cycle training; the peak LR of 3e-4 is hand-picked (presumably from the
# LR finder plot above).
learn.fit_one_cycle(n_epoch=6, lr_max=3e-4)

In [None]:
# Persist the fine-tuned weights as uisketch-resnet-152.pth in the learner's model directory.
learn.save(file="uisketch-resnet-152")

In [None]:
# Training vs. validation loss recorded during fit_one_cycle.
learn.recorder.plot_loss()

## Validation Report

In [None]:
# Interpretation over the validation split (the learner's default dataset when no `dl` is given).
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix(figsize=(12, 12), dpi=60)

In [None]:
# Per-class precision / recall / F1 on the validation split.
interp.print_classification_report()

In [None]:
# The nine validation images with the highest loss.
interp.plot_top_losses(k=9, figsize=(12, 12))

# Evaluation on the held-out test set

In [None]:
# DataLoader over the held-out rows using the validation-time transforms; with_labels=True keeps
# the targets so the test set can be scored.
test_dataloader = dataloaders.test_dl(test_df, with_labels=True)

In [None]:
# Sanity-check a batch from the held-out test set.
test_dataloader.show_batch()

In [None]:
# Same interpretation as for validation, but computed over the held-out test DataLoader.
interp = ClassificationInterpretation.from_learner(learn, dl=test_dataloader)
interp.plot_confusion_matrix(figsize=(12, 12), dpi=60)

In [None]:
# Score the held-out test set. Predictions are recomputed with `get_preds` instead of being
# read from `interp.preds` / `interp.targs`: newer fastai releases no longer store those on
# ClassificationInterpretation, so that attribute access is version-dependent.
preds, y = learn.get_preds(dl=test_dataloader)

# fastai metrics return 0-d tensors; `.item()` yields a plain float for readable output.
top_k = 5  # fastai's default k for top_k_accuracy, made explicit for the printed label
tka = top_k_accuracy(preds, y, k=top_k)
print(f"Top-{top_k} Accuracy: {tka.item():.4f}")

acc = accuracy(preds, y)
print(f"Accuracy: {acc.item():.4f}")

In [None]:
# Per-class precision / recall / F1 on the held-out test set.
interp.print_classification_report()

In [None]:
# The nine test images with the highest loss.
interp.plot_top_losses(k=9, figsize=(12, 12))