In [1]:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import random
import numpy as np
import glob
from tqdm import tqdm
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve

import torch
from torchvision.io import read_image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

In [2]:
# set the seed
SEED = 1234
os.environ["PYTHONHASHSEED"] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
#tf.random.set_seed(SEED)

<torch._C.Generator at 0x243556f76d0>

### csv with filenames and labels

In [3]:
images = glob.glob("D://DATASETS/DogsVsCats/train-val-test/*.jpg")
len(images)

25000

In [4]:
labels = [1 if "dog" in fname else 0 for fname in images]
full_df = pd.DataFrame()
full_df["filename"] = images
full_df["label"] = labels
full_df[:3]

Unnamed: 0,filename,label
0,D://DATASETS/DogsVsCats/train-val-test\cat.0.jpg,0
1,D://DATASETS/DogsVsCats/train-val-test\cat.1.jpg,0
2,D://DATASETS/DogsVsCats/train-val-test\cat.10.jpg,0


### Shuffle dataframe

In [5]:
full_df = full_df.sample(frac=1).reset_index()
full_df[:5]

Unnamed: 0,index,filename,label
0,5262,D://DATASETS/DogsVsCats/train-val-test\cat.348...,0
1,22764,D://DATASETS/DogsVsCats/train-val-test\dog.798...,1
2,2633,D://DATASETS/DogsVsCats/train-val-test\cat.123...,0
3,22512,D://DATASETS/DogsVsCats/train-val-test\dog.776...,1
4,19404,D://DATASETS/DogsVsCats/train-val-test\dog.496...,1


In [6]:
full_df.label.sum()

12500

In [7]:
import numpy as np

dogs_df = full_df[full_df.label==1].sample(frac=1).reset_index(drop=True)
cats_df = full_df[full_df.label==0].sample(frac=1).reset_index(drop=True)

dogs = {"train":None, "test": None, "valid":None}
cats = {"train":None, "test": None, "valid":None}

dogs["train"], dogs["test"], dogs["valid"] = np.split(dogs_df, [int(0.8*len(dogs_df)), int(0.9*len(dogs_df))])
cats["train"], cats["test"], cats["valid"] = np.split(cats_df, [int(0.8*len(cats_df)), int(0.9*len(cats_df))])

train_df = pd.concat([dogs["train"], cats["train"]]).sample(frac=1).reset_index(drop=True)
test_df = pd.concat([dogs["test"], cats["test"]]).sample(frac=1).reset_index(drop=True)
valid_df = pd.concat([dogs["valid"], cats["valid"]]).sample(frac=1).reset_index(drop=True)

In [8]:
len(train_df), len(test_df), len(valid_df)

(20000, 2500, 2500)

In [9]:
train_df[:3]

Unnamed: 0,index,filename,label
0,1550,D://DATASETS/DogsVsCats/train-val-test\cat.113...,0
1,19077,D://DATASETS/DogsVsCats/train-val-test\dog.466...,1
2,5690,D://DATASETS/DogsVsCats/train-val-test\cat.387...,0


In [10]:
train_df.label.sum()

10000

### `Dataset` stores the samples and their corresponding labels, and `DataLoader` wraps an iterable around the Dataset to enable easy access to the samples.

#### https://pytorch.org/docs/stable/data.html#data-loading-order-and-sampler
#### https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset

### TODO: Sampler

#### https://pytorch.org/docs/stable/data.html#torch.utils.data.Sampler

In [11]:
class CustomImageDataset(Dataset):
    '''
    https://pytorch.org/tutorials/beginner/basics/data_tutorial.html
    '''
    def __init__(self, train_df, transform=None, target_transform=None):
        self.train_df = train_df
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.train_df)

    def __getitem__(self, idx):
        img_path = train_df.filename[idx]
        # read JPEG or PNG image from filepath  --> output (Tensor[image_channels, image_height, image_width])
        image =  read_image(img_path)
        label =  train_df.label[idx]
        
        if self.transform:
            # These are transformation single level
            image = self.transform(image)
            
        if self.target_transform:
            label = self.target_transform(label)
            
        return {'image': image, 'label': label}  #image, label

### https://pytorch.org/tutorials/beginner/data_loading_tutorial.html

In [12]:
class Rescale(object):
    """Rescale the image in a sample to a given size.

    Args:
        output_size (tuple or int): Desired output size. If tuple, output is
            matched to output_size. If int, smaller of image edges is matched
            to output_size keeping aspect ratio the same.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        img = transform.resize(image, (new_h, new_w))

        # h and w are swapped for landmarks because for images,
        # x and y axes are axis 1 and 0 respectively
        landmarks = landmarks * [new_w / w, new_h / h]

        return {'image': img, 'landmarks': landmarks}


class RandomCrop(object):
    """Crop randomly the image in a sample.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size

        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        image = image[top: top + new_h,
                      left: left + new_w]

        landmarks = landmarks - [left, top]

        return {'image': image, 'landmarks': landmarks}
        
class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        image, label = sample['image'], sample['label']

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C X H X W
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image),
                'label': torch.from_numpy(label)}


class Normalize(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, image):

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C X H X W
        image = 2*(image/255.0) - 1 # between -1 and +1
        return image

In [13]:
from torchvision import transforms

transforms_train = transforms.Compose([Normalize(),
                                transforms.Resize(256),
                               transforms.RandomCrop(224)
                                ])

In [14]:
train_ds = CustomImageDataset(train_df,transform=transforms_train)
train_ds

<__main__.CustomImageDataset at 0x24361a88eb0>

In [15]:
len(test_df), len(valid_df)

(2500, 2500)

In [16]:
from torchvision import transforms

transforms_test_valid = transforms.Compose([Normalize(),
                                transforms.Resize(256),
                                transforms.CenterCrop(224)
                                ])

In [17]:
test_ds = CustomImageDataset(test_df, transform=transforms_test_valid)
valid_ds = CustomImageDataset(valid_df, transform=transforms_test_valid)

In [18]:
len(train_ds), len(test_ds), len(valid_ds)

(20000, 2500, 2500)

In [19]:
for item in train_ds:
    print(item)
    break

{'image': tensor([[[ 0.6087,  0.6985,  0.7399,  ..., -0.4219, -0.4297, -0.4291],
         [ 0.6239,  0.7386,  0.8176,  ..., -0.4195, -0.4311, -0.4340],
         [ 0.6363,  0.7401,  0.8583,  ..., -0.4195, -0.4311, -0.4353],
         ...,
         [ 0.7515,  0.7453,  0.7418,  ..., -0.2777, -0.2747, -0.2706],
         [ 0.7417,  0.7361,  0.7313,  ..., -0.2735, -0.2699, -0.2674],
         [ 0.7519,  0.7473,  0.7442,  ..., -0.2582, -0.2572, -0.2534]],

        [[ 0.7476,  0.8166,  0.8444,  ..., -0.4063, -0.4140, -0.4122],
         [ 0.7617,  0.8518,  0.9103,  ..., -0.4038, -0.4154, -0.4170],
         [ 0.7644,  0.8470,  0.9457,  ..., -0.4038, -0.4154, -0.4183],
         ...,
         [ 0.7751,  0.7689,  0.7653,  ..., -0.3404, -0.3375, -0.3333],
         [ 0.7652,  0.7596,  0.7548,  ..., -0.3362, -0.3326, -0.3301],
         [ 0.7755,  0.7709,  0.7677,  ..., -0.3210, -0.3199, -0.3161]],

        [[ 0.6973,  0.7622,  0.7794,  ..., -0.4298, -0.4449, -0.4539],
         [ 0.7120,  0.7974,  0.8492

In [20]:
# Define the device
device = "cuda" if torch.cuda.is_available() else "mps" if torch.has_mps or torch.backends.mps.is_available() else "cpu"
device

'cuda'

In [21]:
device = torch.device(device)
device

device(type='cuda')

## Dataloader - Batching shuffling etc
#### `https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader`


### TODO: optimize dataloader
- TFRecords
- interleave
- .map() operations
- batch level `transform`

In [22]:
train_dataloader = DataLoader(train_ds, batch_size=256, shuffle=True, num_workers=0)

In [23]:
# TODO: How can I increase the batch size?
valid_dataloader =  DataLoader(valid_ds, batch_size=1, shuffle=False, num_workers=0)
test_dataloader =  DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=0)

In [24]:
# Display image and label.
train_iterator = iter(train_dataloader)
train_sample = next(train_iterator)
train_features, train_labels = train_sample["image"], train_sample["label"]
train_features.shape, train_labels.shape

(torch.Size([256, 3, 224, 224]), torch.Size([256]))

In [25]:
#next(train_iterator)

In [26]:
import torch.nn as nn
import torch.nn.functional as F


class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 3)
        self.conv3 = nn.Conv2d(32, 64, 3)
        self.conv4 = nn.Conv2d(64, 128, 3)
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)

        
        self.fc1 = nn.Linear(128, 32)
        self.fc2 = nn.Linear(32, 1)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = self.pool(F.relu(self.conv4(x)))
        x = self.global_avg_pool(x)
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


In [27]:
from torchsummary import summary

model = SimpleNet().to(device)
#model = get_transformer()

summary(model, (3, 256, 256))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 16, 254, 254]             448
         MaxPool2d-2         [-1, 16, 127, 127]               0
            Conv2d-3         [-1, 32, 125, 125]           4,640
         MaxPool2d-4           [-1, 32, 62, 62]               0
            Conv2d-5           [-1, 64, 60, 60]          18,496
         MaxPool2d-6           [-1, 64, 30, 30]               0
            Conv2d-7          [-1, 128, 28, 28]          73,856
         MaxPool2d-8          [-1, 128, 14, 14]               0
 AdaptiveAvgPool2d-9            [-1, 128, 1, 1]               0
           Linear-10                   [-1, 32]           4,128
           Linear-11                    [-1, 1]              33
Total params: 101,601
Trainable params: 101,601
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.75
Forward/

In [28]:
model(train_features.to(device)).shape

torch.Size([256, 1])

In [29]:
from torch.utils.tensorboard import SummaryWriter
# 1.15.0 or above
# pip install numpy==1.23.5
# pip install tensorboard==1.15.0
# pip install tensorflow==2.7.0

writer = SummaryWriter("tboard/cats_dogs_v1")

In [30]:
def optimizer_to(optim, device):
    for param in optim.state.values():
        # Not sure there are any global tensors in the state dict
        if isinstance(param, torch.Tensor):
            param.data = param.data.to(device)
            if param._grad is not None:
                param._grad.data = param._grad.data.to(device)
        elif isinstance(param, dict):
            for subparam in param.values():
                if isinstance(subparam, torch.Tensor):
                    subparam.data = subparam.data.to(device)
                    if subparam._grad is not None:
                        subparam._grad.data = subparam._grad.data.to(device)

In [31]:
device

device(type='cuda')

In [32]:
optimizer = torch.optim.Adam(model.parameters(), lr=10**-4, eps=1e-9)

optimizer_to(optimizer,device)
optimizer

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-09
    foreach: None
    fused: None
    lr: 0.0001
    maximize: False
    weight_decay: 0
)

In [33]:
#ignore_index=tokenizer_src.token_to_id('[PAD]'), label_smoothing=0.1

# loss_fn = nn.CrossEntropyLoss().to(device)
loss_fn = nn.BCELoss().to(device)
loss_fn

BCELoss()

In [34]:
class RunningMeanMetric:
    def __init__(self, name):
        self.name = name
        self.sum = 0.0
        self.count = 0

    def update(self, value, n=1):
        self.sum += value * n
        self.count += n

    def reset_states(self):
        self.sum = 0.0
        self.count = 0

    def value(self):
        return self.sum / self.count if self.count > 0 else float('nan')

In [35]:
# class SummaryMetric:

#     def __init__(self, log_dir, name):
#         self._summary_writer = tf.summary.create_file_writer(f"{log_dir}/{name}")

#     def __call__(self, metrics, epoch, models_dict = None):
#         with self._summary_writer.as_default():
#             for metric_name, metric_value in metrics.items():
#                 if "image" in metric_name:
#                     if not "images" in metric_name:
#                         metric_value = [metric_value]
#                     tf.summary.image(metric_name, metric_value, max_outputs=16, step=epoch)
#                 elif "histogram" in metric_name:
#                     tf.summary.histogram(metric_name, metric_value, step=epoch)
#                 elif "grad" in metric_name:
#                     tf.summary.histogram(metric_name, metric_value, step=epoch)
#                 else: #assume scalar
#                     tf.summary.scalar(metric_name, metric_value, step=epoch)

#         if models_dict is not None:
#             for mname, model in models_dict.items():
#                 for mlayer in model.layers():
#                     try:
#                         tf.summary.histogram(f"{mname}-{mlayer.name}", mlayer.weights[0], step= epoch)
#                     except (ValueError, IndexError):
#                         pass
#         self._summary_writer.flush()

## https://pytorch.org/docs/stable/tensorboard.html
### https://github.com/sifubro/pytorch-transformer/blob/main/train.py#L8

In [36]:
# Enable CUDA error checking
torch.cuda.synchronize()

In [37]:
initial_epoch = 0
last_epoch = 2
global_step = 0

save_dir = "C://Users/SiFuBrO/Desktop/SCRIPTS!!!!!/GitHub/pytorch-base/checkpoints/cats_dogs_classifier_v1/"
metric_aggregators = {}
metric_aggregators["train_loss_weighted"] = RunningMeanMetric(name = "train_loss_weighted")

for epoch in range(initial_epoch, last_epoch):
    torch.cuda.empty_cache()
    model.train()
    batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d}")
    train_loss_epoch = []

    # reset metric aggregators (i.e. running means)
    for mname in  metric_aggregators.keys():
        metric_aggregators[mname].reset_states()

    for i, batch in enumerate(batch_iterator):

        train_features = batch["image"].to(device)
        train_labels = batch["label"].to(device)

        # model output
        output = model(train_features) # (B, 1)
        output = torch.squeeze(output) # (B,)
        # convert output from (B,1) -> (B,) and val_label cast from Long to float
        if torch.isnan(output).any() or torch.isinf(output).any():
            print("Ouput Tensor contains NaN or infinite values")
            global_step += 1
            continue

        train_loss = loss_fn(output, train_labels.type(torch.float))
        # try:    
        #     train_loss = loss_fn(output, train_labels.type(torch.float))
        # except Exception as e:
        #     print(f"ERROR epoch {epoch} - iteration: {i}")
        #     print(e)
        #     continue
        train_loss_float = train_loss.item()
        
        # Log the loss
        writer.add_scalar('train_loss_step', train_loss_float, global_step)
        writer.flush()

        metric_aggregators["train_loss_weighted"].update(train_loss_float)
        writer.add_scalar('train_loss_weighted', metric_aggregators["train_loss_weighted"].value() , global_step)
        writer.flush()
        
        # epoch loss
        train_loss_epoch+= [train_loss_float] #.item()

        # Backpropagate the loss
        train_loss.backward()

        # Update the weights
        optimizer.step()
        # set_to_none=True also sets the .grad attribute of each parameter to None. This can be helpful for memory efficiency and to avoid unintentional errors if gradients are accessed after they've been zeroed out.
        optimizer.zero_grad(set_to_none=True)

        global_step += 1
        
    writer.add_scalar('train_loss_epoch', np.mean(train_loss_epoch), global_step)
    writer.flush()
    
    # Run validation at the end of every epoch
    model.eval()
    total_val_loss = []
    val_outputs = []
    y_true_labels = []
    valid_iterator = iter(valid_dataloader)
    with torch.no_grad():
        for batch in valid_iterator:
            val_features =  batch["image"].to(device)
            # TODO: log some images in tensorboard
            
            val_labels =  batch["label"].to(device)
            val_output = model(val_features)[0] # from (1,1) [[0.67442]]  to-> (1) i.e., val_output = [0.67442] for examples
            #val_labels is (1,) size i.e, [0] or [1]   

            # validation loss
            val_loss = loss_fn(val_output, val_labels.type(torch.float)).item()
            total_val_loss += [val_loss]
            
            val_outputs += [val_output.item()]  # val_output = [0.67442] for examples 
            y_true_labels += [val_labels.item()]
        
        # Log the validation loss
        writer.add_scalar('valid_loss', np.mean(total_val_loss), global_step)
        writer.flush()

        # Create ROC plots
        fpr, tpr, thresholds = roc_curve(np.array(y_true_labels), np.array(val_outputs), pos_label=1)
        f, (ax1, ax2) = plt.subplots(2,1, figsize=(8,12))
        ax1.plot(fpr, tpr)
        ax1.set_xlabel("FPR")
        ax1.set_ylabel("TPR")
        ax1.set_xscale('log', base=10)
        ax1.set_ylim(0.0, 1.01)
        ax1.set_title("ROC on validation set")

        ax2.plot(fpr, thresholds)
        ax2.set_xlabel("FPR")
        ax2.set_ylabel("Thresholds")
        ax2.set_ylim(0.0, 1.01)

        plt.axvline(x=0.005, color='black', ls=":", label="FPR=0.5%")
        plt.axvline(x=0.01, color='black', ls="--", label="FPR=1%")
        writer.add_figure('ROC plot', plt.gcf(), global_step)
        writer.flush()

    
    # Save the model at the end of every epoch
    model_filename = f"{save_dir}/{epoch}.pt"
    torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'global_step': global_step
        }, model_filename)

Processing Epoch 00: 100%|█████████████████████████████████████████████████████████████| 79/79 [01:10<00:00,  1.12it/s]
Processing Epoch 01: 100%|█████████████████████████████████████████████████████████████| 79/79 [01:07<00:00,  1.17it/s]


In [39]:
print('Done!')

Done!


In [None]:
## Resume Training:

In [43]:
loss_fn(torch.unsqueeze(val_output,dim=0),val_labels.type(torch.float))

tensor(0.6907, device='cuda:0')

In [36]:
type(batch["image"])

torch.Tensor

In [37]:
batch["label"]

tensor([0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1,
        0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1,
        0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1])

In [39]:
batch["image"].numpy()

array([[[[-0.600116  , -0.7715232 , -0.98809004, ..., -0.51925546,
          -0.4499052 , -0.35490334],
         [-0.8513372 , -0.61788774, -0.8683329 , ..., -0.52104306,
          -0.45699853, -0.36221644],
         [-0.9254856 , -0.64325154, -0.6304443 , ..., -0.52457213,
          -0.46140182, -0.36893332],
         ...,
         [-0.76769507, -0.7446293 , -0.7324017 , ..., -0.9124992 ,
          -0.9696112 , -0.9795303 ],
         [-0.73559904, -0.72904885, -0.7033951 , ..., -0.91429484,
          -0.9684049 , -0.98730886],
         [-0.7249031 , -0.7122105 , -0.7027049 , ..., -0.90359616,
          -0.95838463, -0.98961145]],

        [[-0.600116  , -0.7715232 , -0.98809004, ..., -0.5035692 ,
          -0.43421888, -0.33921707],
         [-0.8513372 , -0.61788774, -0.8683329 , ..., -0.5053568 ,
          -0.44131222, -0.34653017],
         [-0.9254856 , -0.64325154, -0.6304443 , ..., -0.5088858 ,
          -0.4457155 , -0.35324705],
         ...,
         [-0.81475395, -0.7916882 

In [42]:
# Display image and label.
valid_iterator = iter(valid_dataloader)
valid_sample = next(valid_iterator)

In [43]:
val_features, val_labels = valid_sample["image"],  valid_sample["label"]

val_features.shape, val_labels.shape

(torch.Size([1, 3, 224, 224]), torch.Size([1]))

In [44]:
val_output = model(val_features)
val_output.shape

torch.Size([1, 1])

In [45]:
val_output

tensor([[0.1322]], grad_fn=<AddmmBackward0>)

In [46]:
val_output.view(-1)

tensor([0.1322], grad_fn=<ViewBackward0>)

In [47]:
val_labels.shape

torch.Size([1])

In [52]:
val_labels.type(torch.LongTensor)

tensor([0])

In [55]:
val_loss = loss_fn(val_output.view(-1), val_labels.type(torch.float))
val_loss

tensor(0.1418, grad_fn=<BinaryCrossEntropyBackward0>)

In [38]:
val_loss.item()

0.0

In [43]:
val_output

tensor([[-0.0599, -0.1396]], grad_fn=<AddmmBackward0>)

In [None]:
batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d}")

## TODO: Tensorboard


In [None]:
def get_config():
    return {
        "batch_size": 8,
        "num_epochs": 20,
        "lr": 10**-4,
        "seq_len": 350,
        "d_model": 512,
        "datasource": 'opus_books',
        "lang_src": "en",
        "lang_tgt": "it",
        "model_folder": "weights",
        "model_basename": "tmodel_",
        "preload": "latest",
        "tokenizer_file": "tokenizer_{0}.json",
        "experiment_name": "runs/tmodel"
    }

In [59]:
input = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)

target.shape, input.shape

(torch.Size([3]), torch.Size([3, 5]))

In [None]:
## TODOS:
- Multi gpu training
- Load model and resume training
- Freeze part of the model and train the rest
- Initialize custom model