In [None]:
# Notebook setup: (re)load the autoreload extension and enable mode 2 so that
# edits to imported modules are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
from fastai.vision.all import *
import warnings
# Blanket filter: every warning category is silenced from this point on.
warnings.filterwarnings("ignore")

In [None]:
from self_supervised.augmentations import *
from self_supervised.layers import *
from self_supervised.models.vision_transformer import *
from self_supervised.vision.dino import *
from self_supervised.vision.swav import get_swav_aug_pipelines

In [None]:
def get_dls(size, bs, workers=None, n_subset=None):
    """Build the ImageWang dataloaders used for pre-training.

    Args:
        size: side length passed to `RandomResizedCrop`. Values <= 160 select the
            `URLs.IMAGEWANG_160` archive, anything larger selects `URLs.IMAGEWANG`.
        bs: batch size.
        workers: forwarded as `num_workers` to `Datasets.dataloaders`.
        n_subset: if not `None`, number of image files to draw at random. Files are
            drawn without replacement and the count is capped at the number of
            files available. `None` uses every file.

    Returns:
        Whatever `Datasets.dataloaders` returns, built on a random split with
        10% of the files held out for validation.
    """
    path = URLs.IMAGEWANG_160 if size <= 160 else URLs.IMAGEWANG
    source = untar_data(path)

    files = get_image_files(source)
    if n_subset is not None:
        # `np.random.choice` samples WITH replacement by default, which can put the
        # same file into the subset several times and therefore on both sides of
        # the train/valid split. Draw unique files instead, never asking for more
        # than exist (which would raise with `replace=False`).
        n_keep = min(n_subset, len(files))
        files = np.random.choice(files, n_keep, replace=False)

    tfms = [[PILImage.create, ToTensor, RandomResizedCrop(size, min_scale=1.)],
            [parent_label, Categorize()]]

    dsets = Datasets(files, tfms=tfms, splits=RandomSplitter(valid_pct=0.1)(files))

    batch_tfms = [IntToFloatTensor]
    dls = dsets.dataloaders(bs=bs, num_workers=workers, after_batch=batch_tfms)
    return dls

In [None]:
# Run configuration.
bs = 8
resize = 256
size = 224  # NOTE: not referenced by any other visible cell
# ImageWang dataloaders built at `resize`, using every image (no subsetting).
dls = get_dls(resize, bs, n_subset=None)

In [None]:
# Augmentation pipelines handed to the DINO callback below.
aug_pipelines = get_dino_aug_pipelines(
    rotate=True, rotate_deg=10,
    jitter=True, bw=True,
    blur=True, blur_s=(4, 16),
)

In [None]:
import timm
# from timm.models.convmixer import _create_convmixer

In [None]:
# deits16 = deit_small(patch_size=16, drop_path_rate=0.1)
# deits16 = MultiCropWrapper(deits16)
# dino_head = DINOHead(deits16.encoder.embed_dim, 2**16, norm_last_layer=True)
# student_model = nn.Sequential(deits16,dino_head)

# deits16 = deit_small(patch_size=16)
# deits16 = MultiCropWrapper(deits16)
# dino_head = DINOHead(deits16.encoder.embed_dim, 2**16, norm_last_layer=True)
# teacher_model = nn.Sequential(deits16,dino_head)

# dino_model = DINOModel(student_model, teacher_model)

In [None]:
# convmixer=timm.create_model('xcit_tiny_12_p8_224',num_classes=0, in_chans=3,drop_path_rate=0.1)
# convmixer = MultiCropWrapper(convmixer)
# dino_head = DINOHead(convmixer.encoder.embed_dim, 2**16, norm_last_layer=True)
# student_model = nn.Sequential(convmixer,dino_head)
# convmixer=timm.create_model('xcit_tiny_12_p8_224',num_classes=0, in_chans=3)
# convmixer = MultiCropWrapper(convmixer)
# dino_head = DINOHead(convmixer.encoder.embed_dim, 2**16, norm_last_layer=True)
# teacher_model = nn.Sequential(convmixer,dino_head)
# dino_model = DINOModel(student_model, teacher_model)

In [None]:
def build_dino_branch(arch, embed_dim, out_dim=2**16, drop_path_rate=0., checkpoint_nchunks=2):
    """Build one DINO branch (used for both student and teacher).

    The branch is a timm backbone created without a classifier (`num_classes=0`),
    wrapped in `CheckpointSequential` and then `MultiCropWrapper`, followed by a
    `DINOHead`.

    Args:
        arch: timm model name passed to `timm.create_model`.
        embed_dim: first argument of `DINOHead`; must match the backbone's
            feature width (768 for `convmixer_768_32`).
        out_dim: second argument of `DINOHead` (presumably the head's output
            dimension; confirm against the `DINOHead` signature).
        drop_path_rate: forwarded to `timm.create_model`.
        checkpoint_nchunks: forwarded to `CheckpointSequential`.

    Returns:
        `nn.Sequential(backbone, head)`.
    """
    backbone = timm.create_model(arch, num_classes=0, in_chans=3, drop_path_rate=drop_path_rate)
    backbone = CheckpointSequential(backbone, checkpoint_nchunks=checkpoint_nchunks)
    backbone = MultiCropWrapper(backbone)
    head = DINOHead(embed_dim, out_dim, norm_last_layer=True)
    return nn.Sequential(backbone, head)

# Student first, then teacher: same construction order as before.
student_model = build_dino_branch('convmixer_768_32', 768, drop_path_rate=0.1)
# NOTE(review): the teacher is built with the same drop_path_rate=0.1 as the student,
# whereas the commented-out DeiT/XCiT cells in this notebook omit it for the teacher.
# Confirm which is intended.
teacher_model = build_dino_branch('convmixer_768_32', 768, drop_path_rate=0.1)
# Keep the names the previous version of this cell left bound (teacher's parts).
convmixer, dino_head = teacher_model[0], teacher_model[1]
dino_model = DINOModel(student_model, teacher_model)

In [None]:
class SaveModelCallback(TrackerCallback):
    """A `TrackerCallback` that saves the model during training.

    Modes (chosen by the constructor flags):
    - default: save under `fname` each time the monitored value improves, and
      reload that checkpoint at the end of `fit`.
    - `every_epoch=N`: save `f'{fname}_{epoch}'` whenever `epoch % N == 0`. Epoch
      indices start at 0, so epoch 0 is saved and the final epoch is saved only if
      its index is a multiple of N. Nothing is reloaded at the end.
    - `at_end=True`: additionally save under `fname` once when `fit` finishes.

    `every_epoch` and `at_end` are mutually exclusive. `with_opt` is forwarded to
    `Learner.save` / `Learner.load`.
    """
    _only_train_loop,order = True,TrackerCallback.order+1
    def __init__(self, monitor='valid_loss', comp=None, min_delta=0., fname='model', every_epoch=False, at_end=False,
                 with_opt=False, reset_on_fit=True):
        super().__init__(monitor=monitor, comp=comp, min_delta=min_delta, reset_on_fit=reset_on_fit)
        assert not (every_epoch and at_end), "every_epoch and at_end cannot both be set to True"
        # keep track of file path for loggers; stays None until the first save
        self.last_saved_path = None
        store_attr('fname,every_epoch,at_end,with_opt')

    def _save(self, name):
        "Save through `Learner.save` and remember the returned path."
        self.last_saved_path = self.learn.save(name, with_opt=self.with_opt)

    def after_epoch(self):
        "Save periodically (`every_epoch`) or whenever the monitored value reaches a new best."
        if self.every_epoch:
            if (self.epoch%self.every_epoch) == 0: self._save(f'{self.fname}_{self.epoch}')
        else: #every improvement
            super().after_epoch()
            if self.new_best:
                print(f'Better model found at epoch {self.epoch} with {self.monitor} value: {self.best}.')
                self._save(f'{self.fname}')

    def after_fit(self, **kwargs):
        "Save the final model (`at_end`) or reload the best checkpoint written by this callback."
        if self.at_end: self._save(f'{self.fname}')
        elif not self.every_epoch:
            # Only reload what this callback actually wrote. If nothing was saved
            # (e.g. the monitored value never improved, or the fit was cancelled
            # during the first epoch), loading `fname` would either fail on a
            # missing file or silently restore a stale checkpoint from an earlier run.
            if self.last_saved_path is None:
                print(f'No checkpoint was saved as {self.fname}; skipping reload of the best model.')
            else:
                self.learn.load(f'{self.fname}', with_opt=self.with_opt)

In [None]:
# DINO training callback. `tpt_start` and `tpt_end` are equal and the warm-up
# fraction is 0, so whatever `tpt` schedules stays at 0.04 for the whole run.
dino_cb = DINO(
    aug_pipelines=aug_pipelines,
    tpt_start=0.04, tpt_end=0.04, tpt_warmup_pct=0.,
    freeze_last_layer=1,
)

In [None]:
# Supporting callbacks: NaN guard, gradient clipping and periodic checkpoints
# (every 20th epoch index, optimizer state included).
nan_cb = TerminateOnNaNCallback()
grad_clip_cb = GradientClip(max_norm=3., norm_type=2.)
save_cb = SaveModelCallback(fname='conmixdino_pretraining', every_epoch=20, with_opt=True)

# List order is kept as before: DINO, clipping, checkpointing, NaN guard.
cbs = [
    dino_cb,
    grad_clip_cb,
    save_cb,
    nan_cb,
]

In [None]:
# Learner over the DINO student/teacher model with Adam, switched to fp16.
learn = Learner(
    dls,
    dino_model,
    opt_func=Adam,
    cbs=cbs,
).to_fp16()

In [None]:
# b = dls.one_batch()
# learn._split(b)
# learn('before_fit')
# learn('before_batch')
# learn.dino.show(n=5);

In [None]:
warnings.filterwarnings("ignore", category=DeprecationWarning)

# Learning rate: linear ramp 0 -> max_lr over the first 10% of training,
# then cosine from max_lr to 1e-6 over the remaining 90%.
max_lr = 2.5e-4
lr_sched = combine_scheds(
    [0.1, 0.9],
    [SchedLin(0., max_lr), SchedCos(max_lr, 1e-6)],
)
# Weight decay: cosine schedule from 0.04 to 0.4.
wd_sched = SchedCos(0.04, 0.4)
param_scheduler = ParamScheduler({"lr": lr_sched, "wd": wd_sched})

# NOTE: with the every_epoch=20 checkpoint callback and 200 epochs, the last
# checkpoint is written at epoch index 180; the final weights are not saved.
learn.fit(200, cbs=[param_scheduler])


epoch,train_loss,valid_loss,time


KeyboardInterrupt: 