In [None]:
 %config Completer.use_jedi = False

In [None]:
!pip install -U efficientnet

# Imports

In [None]:
import tensorflow as tf
from tensorflow import keras
import numpy as np
import pandas as pd
from kaggle_datasets import KaggleDatasets
from functools import partial
import matplotlib.pyplot as plt
import re
from sklearn.model_selection import StratifiedKFold
import math
import tensorflow.keras.backend as K
import gc
from keras.callbacks import Callback
import tensorflow_addons as tfa
import efficientnet.keras as efc

# Log the TensorFlow version so a run can be matched to its environment.
# (The original message had no separator between the label and the version.)
print(f"Tensorflow version {tf.__version__}")


In [None]:
# Seed every random number generator the notebook draws from.
seed = 42
tf.random.set_seed(seed)
# The augmentation helpers also sample with NumPy (np.random.uniform etc.),
# which tf.random.set_seed does not cover.
np.random.seed(seed)

# Setting up TPU Strategy

In [None]:
# Connect to a TPU when one can be resolved; otherwise fall back to the
# default distribution strategy so the notebook still runs.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Device: ', tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
except Exception as err:
    # Catch Exception (not a bare except) so KeyboardInterrupt / SystemExit
    # still propagate, and report why the fallback was taken.
    print('TPU not available, falling back to the default strategy: ', err)
    strategy = tf.distribute.get_strategy()
print('Number of replicas: ', strategy.num_replicas_in_sync)


# CONFIG

In [None]:
# Let tf.data choose the level of parallelism / prefetch depth.
AUTOTUNE = tf.data.experimental.AUTOTUNE
# GCS location of the Kaggle dataset holding the TFRecord shards.
GCS_PATH = KaggleDatasets().get_gcs_path(dataset_dir='cassava-leaf-disease-tfrecords-512x512')
# Global batch size: 16 examples per replica.
BATCH_SIZE = 16 * strategy.num_replicas_in_sync
# Size passed to resize() and to the perspective-warp helper.
IMAGE_SIZE = [512, 512]
# Class names as strings. NOTE: onehot(), cutmix() and mixup() each define
# their own local `CLASSES = 5` (an int) rather than reading this list.
CLASSES = ['0','1','2','3','4']
# Default number of epochs for fit(); run() receives its own epoch count.
EPOCHS = 10
# Number of cross-validation folds (splits are made over TFRecord files).
N_FOLDS = 4

# Data Loading
## Helper methods

In [None]:
def decode_image(image):
    """Decode a JPEG byte string into a float32 tensor of shape [512, 512, 3].

    The pixel values are only cast to float32 here, not rescaled; scaling is
    done later by rescale_images().
    """
    decoded = tf.image.decode_jpeg(image, channels=3)
    as_float = tf.cast(decoded, tf.float32)
    # The reshape gives the tensor a fully defined static shape.
    return tf.reshape(as_float, [512, 512, 3])

def onehot(image,label):
    """Return `image` unchanged together with `label` one-hot encoded over 5 classes."""
    # Named differently from the module-level CLASSES list to avoid shadowing it.
    num_classes = 5
    encoded = tf.one_hot(label, num_classes)
    return image, encoded

def read_tfrecord(example, labeled):
    """Parse one serialized TFRecord example.

    Args:
        example: serialized example read from a TFRecord file.
        labeled: when true the record's 'target' feature is read as well.

    Returns:
        (image, one_hot_label) when `labeled`, otherwise (image, image_name).
    """
    # Features read from every record.
    feature_spec = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'image_name': tf.io.FixedLenFeature([], tf.string),
    }
    if labeled:
        feature_spec['target'] = tf.io.FixedLenFeature([], tf.int64)

    parsed = tf.io.parse_single_example(example, feature_spec)
    image = decode_image(parsed['image'])
    if not labeled:
        return image, parsed['image_name']
    return onehot(image, tf.cast(parsed['target'], tf.int32))

def load_dataset(filenames, labeled=True, ordered=False):
    """Create a dataset of parsed examples from TFRecord files.

    Args:
        filenames: TFRecord file paths.
        labeled: forwarded to read_tfrecord().
        ordered: when false, deterministic ordering is switched off on the
            dataset options.
    """
    options = tf.data.Options()
    if not ordered:
        options.experimental_deterministic = False
    parse_record = partial(read_tfrecord, labeled=labeled)
    records = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
    records = records.with_options(options)
    return records.map(parse_record, num_parallel_calls=AUTOTUNE)


In [None]:
# Fold splitter. It is applied to the list of TFRecord *files* (not to
# individual images), and the callers pass the same dummy label for every file.
cv = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=42)

# All training TFRecord shards found under the dataset's GCS path.
files = tf.io.gfile.glob(GCS_PATH + '/Id_train*.tfrec')


# Data Augmentations
## Perspective Warping

In [None]:
# ported from docs.fast.ai 

def find_coeffs(orig_pts, targ_pts):
    """Solve the 8x8 linear system built from four point pairs.

    Two equations are generated per (targ, orig) pair; the right-hand side is
    `orig_pts` flattened to 8 values. Returns the 8 solved coefficients as a
    1-D tensor.
    """
    rows = []
    for targ, orig in zip(targ_pts, orig_pts):
        tx, ty = targ[0], targ[1]
        ox, oy = orig[0], orig[1]
        rows.append([tx, ty, 1, 0, 0, 0, -ox*tx, -ox*ty])
        rows.append([0, 0, 0, tx, ty, 1, -oy*tx, -oy*ty])

    # tf.linalg.solve is called on a batch of one system, hence the leading axis.
    lhs = np.expand_dims(np.array(rows, dtype=np.float32), axis=0)
    rhs = np.reshape(np.array(orig_pts, dtype=np.float32), [8, 1])
    rhs = np.expand_dims(rhs, axis=0)

    solution = tf.linalg.solve(lhs, rhs)
    return solution[0][:, 0]

def get_coords_persp_transform(magnitude=0.2,p=0.5):
    """Sample four corner coordinates, as fractions of a unit square.

    Args:
        magnitude: upper bound of the uniform offsets drawn for each corner.
        p: accepted but not used.

    Returns:
        Array of shape (2, 4): row 0 holds the x values, row 1 the y values.
    """
    # Draw order (ys first, then xs) matters for reproducibility under a seed.
    y0, y1, y2, y3 = np.random.uniform(0, magnitude, [4])
    x0, x1, x2, x3 = np.random.uniform(0, magnitude, [4])

    # Entries 0 and 3 of ys stay near 0; entries 1 and 2 are pushed towards 1.
    ys = np.array([y0, 1 - (1 - y0) * y1, 1 - (1 - y3) * y2, y3])
    # Entries 0 and 1 of xs stay near 0; entries 2 and 3 are pushed towards 1.
    xs = np.array([x0, x1, 1 - (1 - x0) * x2, 1 - (1 - x1) * x3])

    return np.stack([xs, ys])
    

def get_persp_mat(IMAGE_SIZE,magnitude=0.2):
    """Build a random 3x3 perspective matrix for an image of size `IMAGE_SIZE`.

    Args:
        IMAGE_SIZE: (H, W) pair.
        magnitude: upper bound of the random corner offsets, forwarded to
            get_coords_persp_transform().

    Returns:
        Tensor of shape [3, 3]: the 8 solved coefficients followed by 1.
    """
    H,W = IMAGE_SIZE
    # Forward the caller's magnitude (it used to be hard-coded to 0.2 here).
    unit_coords = get_coords_persp_transform(magnitude=magnitude)
    # Row 0 holds x fractions and row 1 holds y fractions, so scale them by
    # width and height respectively (identical to the old `* H` for square images).
    pixel_scale = np.array([[W], [H]])
    src_coords = np.transpose(unit_coords * pixel_scale).astype(np.int32)
    targ_coords = np.transpose(np.array([[0,0,W,W],[0,H,H,0]])).astype(np.int32)
    p_mat = find_coeffs(targ_coords,src_coords)
    # Append the fixed bottom-right entry to complete the 3x3 matrix.
    p_mat = tf.concat([p_mat, tf.ones([1], dtype=p_mat.dtype)],axis=0)
    return tf.reshape(p_mat, [3,3])


## Random Transforms

In [None]:
def get_mat(rotation, shear, height_zoom, width_zoom, height_shift, width_shift):
    """Return the 3x3 matrix combining rotation, shear, zoom and shift.

    `rotation` and `shear` are given in degrees. All arguments are
    concatenated with shape-[1] float32 constants, so they are expected to be
    shape-[1] float32 tensors themselves.
    The result is (rotation @ shear) @ (zoom @ shift).
    """
    def as_3x3(entries):
        # Nine shape-[1] tensors, row-major -> one [3, 3] matrix.
        return tf.reshape(tf.concat(entries, axis=0), [3, 3])

    one = tf.constant([1], dtype='float32')
    zero = tf.constant([0], dtype='float32')

    # Degrees -> radians.
    rot_rad = math.pi * rotation / 180.
    shear_rad = math.pi * shear / 180.

    cos_r = tf.math.cos(rot_rad)
    sin_r = tf.math.sin(rot_rad)
    rotation_matrix = as_3x3([cos_r, sin_r, zero,
                              -sin_r, cos_r, zero,
                              zero, zero, one])

    cos_s = tf.math.cos(shear_rad)
    sin_s = tf.math.sin(shear_rad)
    shear_matrix = as_3x3([one, sin_s, zero,
                           zero, cos_s, zero,
                           zero, zero, one])

    zoom_matrix = as_3x3([one/height_zoom, zero, zero,
                          zero, one/width_zoom, zero,
                          zero, zero, one])

    shift_matrix = as_3x3([one, zero, height_shift,
                           zero, one, width_shift,
                           zero, zero, one])

    rot_shear = K.dot(rotation_matrix, shear_matrix)
    zoom_shift = K.dot(zoom_matrix, shift_matrix)
    return K.dot(rot_shear, zoom_shift)

def transform(image,label,name=None,p_rot=0.5,p_shr=0.5,p_h_zoom=0.5,p_w_zoom=0.5,p_h_shift=0.5,p_w_shift=0.5,p_persp_warp=0.5, mag_persp_warp=0.3):
    """Randomly augment ONE image (shape [512, 512, 3], not a batch).

    Applies, each with its own probability: rotation, shear, height/width
    zoom, height/width shift and a perspective warp; then random horizontal
    and vertical flips, and (with probability 0.7) contrast and brightness
    jitter.

    All random decisions are made with TensorFlow ops so they are re-drawn
    for every element when this function is used in `Dataset.map`. Python or
    NumPy randomness would be evaluated only once, when the function is
    traced, and then baked into the graph for every image.

    NOTE: get_persp_mat() samples its corner offsets with NumPy, so under
    `Dataset.map` the warp matrix itself is drawn once at trace time; only
    the decision whether to apply it is made per image here.

    Args:
        image: float image tensor of shape [512, 512, 3].
        label: passed through unchanged.
        name: optional extra value; when given it is returned as a third element.
        p_*: probability of applying the corresponding transform.
        mag_persp_warp: magnitude forwarded to get_persp_mat().

    Returns:
        (image, label) or (image, label, name).
    """
    def _with_prob(prob, sampled, fallback):
        # Per-image coin flip evaluated inside the graph.
        take = tf.random.uniform([1], dtype='float32') < prob
        return tf.where(take, sampled, tf.constant([fallback], dtype='float32'))

    rot = _with_prob(p_rot, 60. * tf.random.normal([1],dtype='float32'), 0.)
    shr = _with_prob(p_shr, .2 * tf.random.normal([1],dtype='float32'), 0.)
    h_zoom = _with_prob(p_h_zoom, 0.9 + tf.random.uniform([1],maxval=4.,dtype='float32')/10., 1.)
    w_zoom = _with_prob(p_w_zoom, 0.9 + tf.random.uniform([1],maxval=4.,dtype='float32')/10., 1.)
    h_shift = _with_prob(p_h_shift, 10. * tf.random.uniform([1],dtype='float32'), 0.)
    w_shift = _with_prob(p_w_shift, 10. * tf.random.uniform([1],dtype='float32'), 0.)

    warp = tf.cast(get_persp_mat(IMAGE_SIZE,magnitude=mag_persp_warp), 'float32')
    persp = tf.cond(tf.random.uniform([]) < p_persp_warp,
                    lambda: warp,
                    lambda: tf.eye(3, dtype='float32'))

    # GET TRANSFORMATION MATRIX
    m = get_mat(rot,shr,h_zoom,w_zoom,h_shift,w_shift)
    m = m @ persp

    # tfa.image.transform takes the first 8 entries of the flattened 3x3 matrix.
    warped = tfa.image.transform(image, tf.reshape(m, [9])[:-1], fill_mode='reflect')
    warped = tf.reshape(warped,[512,512,3])

    # Stateful random ops: a fixed seed with the stateless variants would make
    # the same flip decision for every image.
    flipped = tf.image.random_flip_left_right(warped)
    flipped = tf.image.random_flip_up_down(flipped)

    def _colour_jitter():
        # NOTE(review): get_training_dataset() maps rescale_images after this
        # function, so pixels are still on the decoded scale here and a
        # brightness delta of 0.1 is very small - confirm this is intended.
        jittered = tf.image.random_contrast(flipped, 0.3, 0.9)
        return tf.image.random_brightness(jittered, 0.1)

    d = tf.cond(tf.random.uniform([]) < 0.7, _colour_jitter, lambda: flipped)

    # `is not None` rather than truthiness: `name` may be a tensor.
    if name is not None:
        return d,label, name
    return d,label


In [None]:
import tensorflow_probability as tfp

def cutmix(image, label, PROBABILITY = 1.0):
    """Apply CutMix to a batch of images.

    For every position j in the batch, a square patch from a randomly chosen
    image k of the same batch is pasted over image j, and the two labels are
    blended.

    Args:
        image: batch of images of size [n, dim, dim, 3] (not a single image).
            The loop indexes `range(BATCH_SIZE)`, so n must be BATCH_SIZE.
        label: either a 1-D tensor of class ids (one-hot encoded here) or a
            2-D tensor of already one-hot / soft labels.
        PROBABILITY: chance that a given image receives a patch at all.

    Returns:
        (images, labels) reshaped to [BATCH_SIZE, DIM, DIM, 3] and
        [BATCH_SIZE, CLASSES].
    """
    DIM = IMAGE_SIZE[0]
    CLASSES = 5
    imgs = []; labs = []
    for j in range(BATCH_SIZE):
        # DO CUTMIX WITH PROBABILITY DEFINED ABOVE (P is 1 or 0)
        P = tf.cast( tf.random.uniform([],0,1)<=PROBABILITY, tf.int32)
        # CHOOSE RANDOM IMAGE TO CUTMIX WITH
        k = tf.cast( tf.random.uniform([],0,BATCH_SIZE),tf.int32)
        # CHOOSE RANDOM LOCATION (patch centre)
        x = tf.cast( tf.random.uniform([],0,DIM),tf.int32)
        y = tf.cast( tf.random.uniform([],0,DIM),tf.int32)
        b = tf.random.uniform([],0,1) # this is beta dist with alpha=1.0
        # Patch side length; multiplying by P makes it 0 when CutMix is skipped.
        WIDTH = tf.cast( DIM * tf.math.sqrt(1-b),tf.int32) * P
        # Patch bounds, clipped to the image.
        ya = tf.math.maximum(0,y-WIDTH//2)
        yb = tf.math.minimum(DIM,y+WIDTH//2)
        xa = tf.math.maximum(0,x-WIDTH//2)
        xb = tf.math.minimum(DIM,x+WIDTH//2)
        # MAKE CUTMIX IMAGE: left strip of j, patch from k, right strip of j,
        # then stack with the untouched rows of j above and below.
        one = image[j,ya:yb,0:xa,:]
        two = image[k,ya:yb,xa:xb,:]
        three = image[j,ya:yb,xb:DIM,:]
        middle = tf.concat([one,two,three],axis=1)
        img = tf.concat([image[j,0:ya,:,:],middle,image[j,yb:DIM,:,:]],axis=0)
        imgs.append(img)
        # MAKE CUTMIX LABEL
        # NOTE: `a` uses the unclipped WIDTH, so it can overstate the pasted
        # area when the patch is cut off at the image border.
        a = tf.cast(WIDTH*WIDTH/DIM/DIM,tf.float32)
        if len(label.shape)==1:
            lab1 = tf.one_hot(label[j],CLASSES)
            lab2 = tf.one_hot(label[k],CLASSES)
        else:
            lab1 = label[j,]
            lab2 = label[k,]
        labs.append((1-a)*lab1 + a*lab2)            
    # RESHAPE HACK SO TPU COMPILER KNOWS SHAPE OF OUTPUT TENSOR (maybe use Python typing instead?)
    image2 = tf.reshape(tf.stack(imgs),(BATCH_SIZE,DIM,DIM,3))
    label2 = tf.reshape(tf.stack(labs),(BATCH_SIZE,CLASSES))
    return image2,label2

def mixup(image, label, PROBABILITY = 0.7, alpha=0.4):
    """Apply MixUp to a batch of images.

    Each image j is blended with a randomly chosen image k of the same batch
    using a weight drawn from Beta(alpha, alpha); the labels are blended with
    the same weight.

    Args:
        image: batch of images of size [n, dim, dim, 3] (not a single image).
            The loop indexes `range(BATCH_SIZE)`, so n must be BATCH_SIZE.
        label: either a 1-D tensor of class ids (one-hot encoded here) or a
            2-D tensor of already one-hot / soft labels.
        PROBABILITY: chance that a given image is mixed at all.
        alpha: both concentration parameters of the Beta distribution.

    Returns:
        (images, labels) reshaped to [BATCH_SIZE, DIM, DIM, 3] and
        [BATCH_SIZE, CLASSES].
    """
    DIM = IMAGE_SIZE[0]
    CLASSES = 5
    # One mixing weight per batch position.
    beta = tfp.distributions.Beta(alpha, alpha).sample(BATCH_SIZE)
    
    imgs = []; labs = []
    for j in range(BATCH_SIZE):
        # DO MIXUP WITH PROBABILITY DEFINED ABOVE (P is 1.0 or 0.0)
        P = tf.cast( tf.random.uniform([],0,1)<=PROBABILITY, tf.float32)
        # CHOOSE RANDOM IMAGE TO MIX WITH
        k = tf.cast( tf.random.uniform([],0,BATCH_SIZE),tf.int32)
        a = beta[j]*P # Beta(alpha, alpha) weight, forced to 0 when mixup is skipped
        # MAKE MIXUP IMAGE
        img1 = image[j,]
        img2 = image[k,]
        imgs.append((1-a)*img1 + a*img2)
        # MAKE MIXUP LABEL
        if len(label.shape)==1:
            lab1 = tf.one_hot(label[j],CLASSES)
            lab2 = tf.one_hot(label[k],CLASSES)
        else:
            lab1 = label[j,]
            lab2 = label[k,]
        labs.append((1-a)*lab1 + a*lab2)
            
    # RESHAPE HACK SO TPU COMPILER KNOWS SHAPE OF OUTPUT TENSOR (maybe use Python typing instead?)
    image2 = tf.reshape(tf.stack(imgs),(BATCH_SIZE,DIM,DIM,3))
    label2 = tf.reshape(tf.stack(labs),(BATCH_SIZE,CLASSES))

    return image2,label2


In [None]:
def rescale_images(image, label):
    """Multiply the image by 1/255 and return it with the label unchanged."""
    scale = 1/255.
    scaled = tf.multiply(image, scale)
    return scaled, label


In [None]:
# Resize step mapped over batches by both the training and validation pipelines.
def resize(image, label, size=IMAGE_SIZE):
    """Resize `image` to `size` with nearest-neighbour interpolation.

    The label is returned unchanged. Aspect ratio is not preserved and no
    antialiasing is applied (the tf.image.resize defaults).
    """
    resized = tf.image.resize(image, size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    return resized, label


In [None]:
def get_training_dataset(FILENAMES):
    """Build the training pipeline for the given TFRecord files.

    Steps, in order: load (unordered) -> augment per image -> rescale ->
    repeat forever -> shuffle (buffer 512) -> batch -> resize -> prefetch.
    Because the dataset repeats indefinitely, callers must pass
    `steps_per_epoch` when fitting.
    """
    return (
        load_dataset(FILENAMES, labeled=True)
        .map(transform, num_parallel_calls=AUTOTUNE)
        .map(rescale_images, num_parallel_calls=AUTOTUNE)
        .repeat()
        .shuffle(512)
        .batch(BATCH_SIZE)
        .map(resize, num_parallel_calls=AUTOTUNE)
        # .map(mixup, num_parallel_calls=AUTOTUNE)  # batch-level mixup, currently disabled
        .prefetch(AUTOTUNE)
    )

In [None]:
def get_validation_dataset(FILENAMES,ordered=True):
    """Build the validation pipeline for the given TFRecord files.

    No augmentation is applied. Steps, in order: load -> rescale -> batch ->
    resize -> cache -> prefetch. The dataset is finite (no repeat).
    """
    return (
        load_dataset(FILENAMES, labeled=True, ordered=ordered)
        .map(rescale_images, num_parallel_calls=AUTOTUNE)
        .batch(BATCH_SIZE)
        .map(resize, num_parallel_calls=AUTOTUNE)
        .cache()
        .prefetch(AUTOTUNE)
    )


In [None]:
def count_data_items(filenames):
    """Sum the item counts encoded in TFRecord file names.

    Each name is expected to contain "-<count>." (for example
    "Id_train00-1338.tfrec" contributes 1338).
    """
    count_pattern = re.compile(r"-([0-9]*)\.")
    counts = []
    for path in filenames:
        counts.append(int(count_pattern.search(path).group(1)))
    return np.sum(counts)


In [None]:
def get_train_ds():
    """Return the training dataset built from the first cross-validation fold."""
    dummy_labels = [''] * len(files)
    # Only the training indices of the first split are needed here.
    first_train_idx, _ = next(cv.split(files, dummy_labels))
    first_fold_files = np.array(files)[first_train_idx].tolist()
    return get_training_dataset(first_fold_files)


In [None]:
# ds = get_train_ds()

# imgs, lbls = next(iter(ds.take(1)))

def plot_images(num_images=9, ):
    """Show `num_images` augmented training images in a grid with 3 columns.

    Images come from the first batch of get_train_ds(); each title is the
    image's label vector. At most one batch worth of images is shown, and
    grid cells beyond `num_images` are left blank.
    """
    images, labels = next(iter(get_train_ds().take(1)))
    # Never index past the end of the batch.
    num_images = min(num_images, int(images.shape[0]))
    if num_images <= 0:
        return
    rows = math.ceil(num_images/3)
    fig, axes = plt.subplots(rows, 3, constrained_layout=False, figsize=(20,20))
    for i, ax in enumerate(axes.flatten()):
        if i >= num_images:
            # The grid has rows*3 cells; hide the ones that were not requested.
            ax.axis('off')
            continue
        # Keep values inside the [0, 1] range imshow expects for float images.
        ax.imshow(np.clip(images[i].numpy(), 0., 1.))
        ax.set_title(str(labels[i].numpy()))
    plt.show()


In [None]:
# plot_images(15)

# LR Scheduler - One Cycle Schedule

In [None]:
def sched_lin(start, end, pos):
    """Linear interpolation: `start` at pos=0, `end` at pos=1."""
    span = end - start
    return start + pos*span

def sched_cos(start, end, pos):
    """Cosine interpolation: `start` at pos=0, `end` at pos=1."""
    # Goes from 0 (pos=0) to 2 (pos=1) along a half cosine wave.
    ramp = 1 + np.cos(np.pi*(1-pos))
    return start + ramp * (end-start)/2

In [None]:
class Scheduler:
    """
    Two-phase schedule for a hyperparameter (learning rate or momentum) used by the
    one-cycle policy: init_val -> max_val during the first `pct_start` of training,
    then max_val -> final_val for the remainder.
    Args:
        pct_start: fraction of total progress (0-1) at which the second phase starts.
        max_val: value reached at the end of the first phase.
        init_div: default=25., the initial value is max_val/init_div.
        div_fact: default=1000000., the final value is max_val/div_fact.
        sched_func: default=sched_cos. Either one callable f(start, end, pos) used for
            both phases, or a list/tuple of exactly two callables (one per phase).
    """
    def __init__(self, pct_start, max_val, init_div=25., div_fact=1000000., sched_func=sched_cos):
        # Phase boundaries kept as plain floats: __call__ runs once per batch,
        # so it should not need tensor ops or .numpy() round trips.
        self.pcts = [0., float(pct_start), 1.]
        self.max_val = max_val
        self.init_val = max_val/init_div
        self.final_val = max_val/div_fact
        if isinstance(sched_func, (list,tuple)):
            if len(sched_func)!=2: raise ValueError(f"The sched functions should be only two, received {len(sched_func)} ")
            first_func, second_func = sched_func
        else:
            # A single function drives both phases. (Previously this assignment
            # also ran for the list/tuple case and overwrote it.)
            first_func = second_func = sched_func
        self.scheds = [partial(first_func,start=self.init_val,end=self.max_val), partial(second_func,start=self.max_val, end=self.final_val)]
        
    def __call__(self, pos):
        """Return the scheduled value at overall progress `pos` (0 = start, 1 = end)."""
        # At or beyond the end of the schedule, stay at the final value.
        if pos>=1: return self.final_val
        idx = 1 if pos>=self.pcts[1] else 0
        phase_start, phase_end = self.pcts[idx], self.pcts[idx+1]
        # Progress within the current phase, rescaled to [0, 1).
        actual_pos = (pos-phase_start)/(phase_end-phase_start)
        return self.scheds[idx](pos=actual_pos)


In [None]:
class OneCycleLR(keras.callbacks.Callback):
    """Keras callback applying a one-cycle schedule to learning rate and momentum.

    Before every training batch the learning rate and, when the optimizer
    exposes `momentum` or `beta_1`, the momentum are set from two Scheduler
    instances. Applied values are logged in `self.lrs` and `self.moms`.

    Args:
        num_samples: accepted for API compatibility; not used.
        batch_size: accepted for API compatibility; not used.
        steps_per_epoch: number of batches per epoch; together with the
            number of epochs it defines the length of the cycle.
        max_lr: peak learning rate.
        init_div: the initial learning rate is max_lr/init_div.
        div_fact: the final learning rate is max_lr/div_fact.
        pct_start: fraction of the cycle spent in the first phase.
        maximum_momentum: momentum at the start and the end of the cycle.
        minimum_momentum: momentum at the end of the first phase.
        sched_func: interpolation function(s) passed to Scheduler.
        verbose: stored on the instance; not otherwise used here.
    """
    def __init__(self, 
                 num_samples,
                 batch_size,
                 steps_per_epoch,
                 max_lr, 
                 init_div=25.0,
                 div_fact=1000000.0,
                 pct_start=0.3,
                 maximum_momentum=0.94,
                 minimum_momentum=0.85,
                 sched_func=sched_cos,
                 verbose=True):
        super(OneCycleLR, self).__init__()

        self.initial_lr = max_lr
        self.STEPS_PER_EPOCH = steps_per_epoch
        self.max_momentum = maximum_momentum
        self.min_momentum = minimum_momentum
        self.verbose = verbose
        self.lr_schedule = Scheduler(pct_start, max_lr,init_div, div_fact, sched_func=sched_func)
        # Dividing the minimum momentum by min/max yields the maximum momentum,
        # so this schedule runs max -> min -> max.
        self.mom_div = minimum_momentum/maximum_momentum
        self.mom_schedule = Scheduler(pct_start, minimum_momentum, self.mom_div,self.mom_div, sched_func=sched_func)
        self.lrs = []
        self.moms = []

    def _set_optimizer_param(self, param_name, value):
        """Set `param_name` on the optimizer if it exists; return whether it did."""
        if not hasattr(self.model.optimizer,param_name):
            return False
        K.set_value(getattr(self.model.optimizer,param_name), value)
        return True

    def reset(self):
        """Put learning rate and momentum at their initial scheduled values.

        Nothing is logged here, so `lrs` and `moms` stay aligned with one
        entry per training batch.
        """
        K.set_value(self.model.optimizer.lr, self.lr_schedule.init_val)
        self._set_optimizer_param('momentum', self.mom_schedule.init_val)
        self._set_optimizer_param('beta_1', self.mom_schedule.init_val)

    def calculate_lr(self, pos):
        """
        Calculates and returns learning rate for the next batch.
        """
        return self.lr_schedule(pos)

    def calculate_momentum(self, pos):
        """
        Calculates and returns the momentum for the next batch.
        """
        return self.mom_schedule(pos)

    def on_train_begin(self, logs=None):
        self.reset()
        self.history = {}
        self.curr_batch = 0

    def update_momentum(self, param_name, value):
        """Set `param_name` on the optimizer (if present) and log the value."""
        if self._set_optimizer_param(param_name, value):
            self.moms.append(value)

    def on_train_batch_begin(self, batch, logs=None):
        # Overall progress through the cycle, in [0, 1).
        pos = self.curr_batch/(self.params['epochs']*self.STEPS_PER_EPOCH)
        lr = self.calculate_lr(pos)
        #set the new learning rate
        K.set_value(self.model.optimizer.lr, lr)
        self.lrs.append(lr)

        #set the new momentum (computed once, applied to whichever attribute exists)
        momentum = self.calculate_momentum(pos)
        self.update_momentum('momentum', momentum)
        self.update_momentum('beta_1', momentum)
        self.curr_batch += 1


# Model - Transfer Learning

In [None]:
# The batchnorm layers can be used while fine tuning the model to a new dataset.
def trainable_bn(x):
    """Mark `x` as trainable when it is a BatchNormalization layer; otherwise do nothing."""
    if not isinstance(x, tf.keras.layers.BatchNormalization):
        return
    x.trainable = True

In [None]:
class CassavaModel(keras.Model):
    """EfficientNetB7 backbone followed by a 5-class softmax head.

    Layers, in order: EfficientNetB7 (no top) -> Dropout(0.5) ->
    GlobalAveragePooling2D -> Dense(5, softmax).

    Args:
        model_name: accepted for API compatibility; the backbone is always
            EfficientNetB7.
        pretrained: when true (default) the backbone loads the
            'noisy-student' weights; when false it is randomly initialised.
    """
    def __init__(self, model_name='eff', pretrained=True):
        super(CassavaModel, self).__init__()
        # Honour `pretrained` (it used to be ignored and weights were always loaded).
        backbone_weights = 'noisy-student' if pretrained else None
        eff=efc.EfficientNetB7(weights=backbone_weights,include_top=False)
        self.model=tf.keras.Sequential()
        self.model.add(eff)
        self.model.add(tf.keras.layers.Dropout(0.5))
        self.model.add(tf.keras.layers.GlobalAveragePooling2D())
        self.model.add(tf.keras.layers.Dense(5,activation='softmax'))

    def call(self, inputs):
        """Run `inputs` through the wrapped Sequential model."""
        return self.model(inputs)

# Training Utils

In [None]:
def f1_metric(y_true, y_pred):
    """F1 score computed over all entries of the batch.

    Values are clipped to [0, 1] and rounded before counting, so a
    prediction counts as positive when its rounded value is 1.
    """
    def _count(values):
        # Number of entries that round to 1 after clipping to [0, 1].
        return K.sum(K.round(K.clip(values, 0, 1)))

    eps = K.epsilon()
    true_positives = _count(y_true * y_pred)
    precision = true_positives / (_count(y_pred) + eps)
    recall = true_positives / (_count(y_true) + eps)
    return 2*(precision*recall)/(precision+recall+eps)

In [None]:
def get_model(model_name='eff'):
    """Create and compile a CassavaModel inside the distribution strategy scope.

    The model is compiled with Adam at the module-level learning rate `LR`,
    categorical cross-entropy loss, and categorical accuracy ('acc') plus
    f1_metric as metrics.

    Args:
        model_name: forwarded to CassavaModel.
    """
    with strategy.scope():        
        model = CassavaModel(model_name)
        model.compile(
            # `learning_rate` is the documented keyword; `lr` is a deprecated alias.
            optimizer=tf.keras.optimizers.Adam(learning_rate=LR),
            loss=tf.keras.losses.CategoricalCrossentropy(),
            metrics=[tf.keras.metrics.CategoricalAccuracy(name='acc', dtype=None),
                     f1_metric])
    return model

In [None]:
def predict_tta(model, ds, tta_times=5):
    """Predict `tta_times` times on `ds`, sum the outputs and return the argmax classes."""
    runs = [model.predict(ds,verbose=1).squeeze() for _ in range(tta_times)]
    summed = np.stack(runs,axis=0).sum(axis=0)
    # squeeze() drops the batch axis when there is a single sample; restore it.
    if summed.ndim < 2:
        summed = summed[None,:]
    return summed.argmax(axis=1)


In [None]:
def get_score(model, ds, tta_times=1):
    """Return the accuracy of `model` on the labelled dataset `ds`.

    `ds` must yield (images, one-hot labels) batches; the true class is the
    argmax of each label row.
    """
    predicted = predict_tta(model, ds, tta_times=tta_times)
    actual = [
        class_id
        for _, label_batch in ds
        for class_id in label_batch.numpy().argmax(axis=1).tolist()
    ]
    return (predicted == np.array(actual)).mean()

In [None]:
from tensorflow.keras.callbacks import ReduceLROnPlateau,ModelCheckpoint,EarlyStopping

In [None]:
# Learning rate used by get_model(): 1e-4 scaled by the number of replicas.
LR=1e-4*strategy.num_replicas_in_sync

In [None]:
def checkpoint_callback(fold):
    """Return a ModelCheckpoint that keeps the best (lowest val_loss) weights of a fold.

    Only the weights are written to 'best_<fold>.h5': CassavaModel is a
    subclassed keras.Model, and saving a whole subclassed model in HDF5
    format is not supported by Keras.
    """
    return ModelCheckpoint(f'best_{fold}.h5',verbose=1,monitor='val_loss',
                           save_best_only=True,save_weights_only=True)

In [None]:
# Reduce the learning rate when val_loss has not improved for 3 epochs.
redlr=ReduceLROnPlateau(monitor='val_loss',patience=3,verbose=1)
# Alias of the checkpoint factory (a function, not a callback instance);
# run() builds the real per-fold callback with checkpoint_callback(fold).
chkpt=checkpoint_callback
# Stop after 8 epochs without improvement of the monitored metric (the Keras
# default, since none is given) and restore the best weights seen.
es=EarlyStopping(patience=8,verbose=1,restore_best_weights=True)

In [None]:
def fit(TRAINING_FILENAMES, VALID_FILENAMES, train_ds, valid_ds, model,
        lr,n_epochs=EPOCHS,callbacks=None):
    """Train `model` on `train_ds` and validate on `valid_ds`.

    Args:
        TRAINING_FILENAMES, VALID_FILENAMES: TFRecord file names; the item
            counts encoded in them determine the number of steps.
        train_ds: training dataset (expected to repeat indefinitely).
        valid_ds: finite validation dataset.
        model: compiled Keras model.
        lr: peak learning rate for the default OneCycleLR callback. It is
            ignored when `callbacks` is given.
        n_epochs: number of epochs.
        callbacks: optional list of Keras callbacks; when empty or None a
            single OneCycleLR callback is used.

    Returns:
        The History object returned by `model.fit`.
    """
    NUM_TRAINING_IMAGES = count_data_items(TRAINING_FILENAMES)
    NUM_VALIDATION_IMAGES = count_data_items(VALID_FILENAMES)
    STEPS_PER_EPOCH = NUM_TRAINING_IMAGES//BATCH_SIZE
    # Round up so the last partial validation batch is evaluated too. With floor
    # division those images were skipped and the cached validation dataset was
    # never read to the end.
    VALID_STEPS = math.ceil(NUM_VALIDATION_IMAGES/BATCH_SIZE)

    if not callbacks:
        # Build the default schedule only when it will actually be used.
        callbacks = [OneCycleLR(NUM_TRAINING_IMAGES, BATCH_SIZE,STEPS_PER_EPOCH, lr)]
    return model.fit(train_ds, 
                    steps_per_epoch=STEPS_PER_EPOCH,
                    epochs=n_epochs,
                    callbacks=callbacks,
                    validation_data=valid_ds,
                    validation_steps=VALID_STEPS)

In [None]:
def run(EPOCHS, model_type='eff'):
    """Train and evaluate one model per cross-validation fold.

    For each fold a fresh model is built, trained with the ReduceLROnPlateau,
    per-fold ModelCheckpoint and EarlyStopping callbacks, scored on the
    fold's validation files, and then released.

    Args:
        EPOCHS: maximum number of epochs per fold.
        model_type: key into the model label table; forwarded to get_model().

    Returns:
        List with the validation accuracy of each fold.
    """
    # Labels are only used for logging; 'eff' is EfficientNetB7 in CassavaModel.
    model_types = {'eff': 'EfficientNetB7', 'resnet': 'Resnet50'}
    fold_scores = []
    for fold, (trn_idx, val_idx) in enumerate(cv.split(files, ['']*len(files))):
        TRAINING_FILENAMES = np.array(files)[trn_idx].tolist()
        VALID_FILENAMES = np.array(files)[val_idx].tolist()
        chkpt=checkpoint_callback(fold)
        callbacks=[redlr,chkpt,es]
        train_dataset = get_training_dataset(TRAINING_FILENAMES)
        valid_dataset = get_validation_dataset(VALID_FILENAMES)
        model = get_model(model_type)
        print('$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$')
        print(f'$$$$$$$$$$$$$$$ MODEL TYPE:: {model_types[model_type]}')
        print('$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$')
        
        model.trainable = True

        # NOTE: `lr` only feeds fit()'s default OneCycleLR callback, so it has
        # no effect here because explicit callbacks are passed.
        history = fit(TRAINING_FILENAMES, VALID_FILENAMES,
                      train_dataset, valid_dataset, model, lr=1e-4, n_epochs=EPOCHS,callbacks=callbacks)

        score = get_score(model, valid_dataset)
        fold_scores.append(score)
        print(f"Validation Accuracy {score} with {EPOCHS} Unfrozen")
        # Release the fold's model before the next one is built.
        del model
        gc.collect()
    return fold_scores

In [None]:
# Train every cross-validation fold for up to 20 epochs with the
# augmentation transforms enabled.
run(20, 'eff')

In [None]:

# 4 one cycle epochs
# run(4, 'eff')

In [None]:

# # No aug
# run(4, 'eff')