In [1]:
from fastai.vision.all import *
from train_utils import *

In [2]:
PATH = Path('/mnt/dysk25/data/shopee/')
OUTPUT_CLASSES = 11014
train_df = add_splits(pd.read_csv(PATH/'train.csv'))




In [3]:
class ArcFaceClassifier(nn.Module):
    def __init__(self, in_features, output_classes):
        super().__init__()
        self.W = nn.Parameter(torch.Tensor(in_features, output_classes))
        nn.init.kaiming_uniform_(self.W)
    def forward(self, x):
        x_norm = F.normalize(x)
        W_norm = F.normalize(self.W, dim=0)
        return x_norm @ W_norm
    
class ResnetArcFace(nn.Module):
    def __init__(self):
        super().__init__()
        self.body = create_body(resnet18, cut=-2)
        nf = 2* num_features_model(nn.Sequential(*self.body.children()))
        self.after_conv=nn.Sequential(
            AdaptiveConcatPool2d(),
            Flatten(),
            nn.BatchNorm1d(nf))   
        self.classifier = ArcFaceClassifier(nf, OUTPUT_CLASSES)
        self.outputEmbs = False
    
    def forward(self, x):
        embeddings = self.after_conv(self.body(x))
        if self.outputEmbs:
            return embeddings
        return self.classifier(embeddings)


In [4]:
def get_img_file(row):
    img =row.image
    fn  = PATH/'train_images'/img
    if not fn.is_file():
        fn = PATH/'test_images'/img
    return fn

def get_dls(size, bs):
    data_block = DataBlock(blocks = (ImageBlock(), CategoryBlock(vocab=train_df.label_group.to_list())),
                 splitter=ColSplitter(),
                 get_y=ColReader('label_group'),
                 get_x=get_img_file,
                 item_tfms=Resize(int(size*1.5), resamples=(Image.BICUBIC,Image.BICUBIC)),
                 
                 batch_tfms=aug_transforms(size=size, min_scale=0.75)+[Normalize.from_stats(*imagenet_stats)],
                 )
    return data_block.dataloaders(train_df, bs=bs,num_workers=16)

In [5]:
def split_2way(model):
    return L(params(model.body), params(model.classifier))

In [6]:
learn = Learner(get_dls(224,256), ResnetArcFace(), splitter=split_2way, loss_func=arcface_loss,
               cbs = [F1FromEmbs],metrics=FakeMetric()).to_fp16()

In [None]:
%%time
learn.fine_tune(15, 1e-2)

epoch,train_loss,valid_loss,F1 embeddings,time
0,18.078884,,0.717193,01:36


epoch,train_loss,valid_loss,F1 embeddings,time
0,11.347,,0.729149,00:50
1,9.184935,,0.738791,00:51
2,7.761072,,0.747634,00:52
3,6.811058,,0.754856,00:55
4,5.947067,,0.760332,00:54
5,5.20363,,0.765676,00:55
6,4.521454,,0.768907,00:56
7,4.020483,,0.768923,00:57
8,3.54036,,0.769737,00:56
9,3.087128,,0.772755,00:57
