# Example for:
Transfer learning experiment with `reinit=False`: from all other (non-shoe) classes of FashionMNIST to the shoe classes (Sandal, Sneaker, Ankle boot) of FashionMNIST

### Setup and Hyperparams

In [1]:
# Specify which gpu
# CUDA_VISIBLE_DEVICES has to be set before CUDA is initialised, which is why this
# assignment comes ahead of the project import further down.
# NOTE(review): gpu_id is hard-coded to 1; on a single-GPU machine this hides the only
# GPU and the `device` fallback below silently selects the CPU — confirm for your host.
import os
gpu_id = 1
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

# Make the project package importable from this notebook.
# NOTE(review): machine-specific absolute path — adjust to the local checkout.
import sys
sys.path.append('/mnt/c/Users/Arnisa/Desktop/MP/mp-tl-study')
# The star import is the only import of project code in this notebook; names such as
# torch, np, nn, transforms, datasets and the experiment helpers used in later cells
# presumably come from here — verify against functions/utils.py.
from functions.utils import *

# Use the (single visible) GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed NumPy and PyTorch and request deterministic cuDNN behaviour so repeated runs
# of the notebook are comparable.
np.random.seed(0)
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
if torch.cuda.is_available():
    torch.cuda.manual_seed(0)
    torch.cuda.manual_seed_all(0)  # if using multi-GPU

In [2]:
# Cut points swept by the fine-tuning loop: 0 through 6 for a network of depth 6.
cuts = list(range(7))

# Single configuration object handed to the data wrapper, the model and the trainer.
params = dict(
    # --- model architecture ---
    depth=6,
    num_channels=64,
    activation_function=nn.ReLU,
    kernel_size=3,
    # --- training ---
    device=device,
    lr_pretrain=0.001,
    lr_fine_tune=0.001,
    num_train=40,
    early_stop_patience=6,
    save_best=False,
    save_checkpoints=False,
    is_cnn=True,
    is_debug=False,
    classification_report_flag=False,
    batch_size=64,
    # --- dataset ---
    # FashionMNIST labels 5, 7 and 9 are Sandal, Sneaker and Ankle boot (the "shoes");
    # every other label is used for pre-training.
    pre_train_classes=[0, 1, 2, 3, 4, 6, 8],
    fine_tune_classes=[5, 7, 9],
    val_split=0.1,
    num_workers=0,
    generate_dataset_seed=42,
    # --- experiment setting ---
    use_pooling=True,
    # A pooling layer is added after every n layers given here; set this to
    # params['depth'] for a single pooling after all the CNN layers.
    pooling_every_n_layers=2,
    pooling_stride=2,
    freeze=True,      # varied across experiments
    reinit=False,     # the setting that distinguishes this notebook
    truncate=False,   # varied across experiments
)

In [3]:
# Folder handed to the dataset classes as their root directory.
root_dir = './data'

# Convert images to tensors, then normalise the single channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# FashionMNIST is passed twice; presumably the first is the pre-training dataset and the
# second the fine-tuning dataset — verify against TransferLearningWrapper's signature.
dataloader_wrapped = TransferLearningWrapper(
    params,
    datasets.FashionMNIST,
    datasets.FashionMNIST,
    root_dir,
    transform=transform,
)

## Pretraining
We re-use the pre-trained model from the FashionMNIST regular-classes FR experiment.

In [4]:
# Build the network with the output size reported by the data wrapper and move it to
# the selected device in one step.
pretrained_model = CustomCNN(params, dataloader_wrapped.output_dim).to(device)
# Bare expression so the notebook displays the layer summary.
pretrained_model

CustomCNN(
  (conv0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (act0): ReLU()
  (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (act1): ReLU()
  (pool1): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (act2): ReLU()
  (conv3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (act3): ReLU()
  (pool3): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (conv4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (act4): ReLU()
  (conv5): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (act5): ReLU()
  (pool5): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (fc): Linear(in_features=576, out_features=7, bias=True)
)

In [5]:
# Load the stored weights into the model built in the previous cell; map_location=device
# remaps the saved tensors onto the device chosen in the setup cell.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
pretrained_model.load_state_dict(torch.load('pretrained_models/regular_classes.pth', map_location=device))

<All keys matched successfully>

## Fine-tuning Experiments
We re-use the baselines from the FashionMNIST regular-classes FR experiment.

### Fine-tuning

In [9]:
# Fractions of the fine-tuning training data that the experiment loop samples.
percentages = [0.001, 0.01, 0.1, 0.5, 1.0]
# Filled by the experiment loop with one record (dict) per fine-tuning run.
results = []

In [10]:
def _repeats_for(fraction):
    """Return the number of seeded repetitions to run for a training-set fraction."""
    if fraction <= 0.01:
        return 25
    if fraction < 0.5:
        return 20
    return 5


# Put the data wrapper into its fine-tuning phase before its loaders are used.
dataloader_wrapped.update_phase('finetune')

for sampled_percentage in percentages:
    repeats = _repeats_for(sampled_percentage)

    for sampled_cut_point in cuts:
        for repeat in range(repeats):
            print(f"\nSampled Percentage: {sampled_percentage}, Sampled Cut Point: {sampled_cut_point}, Lr: {params['lr_fine_tune']}, Repeat: {repeat}")

            # Sub-sample the training loader; the repeat index is the sampling seed.
            train_loader_reduced = reduce_dataset(
                dataloader_wrapped.train_loader,
                sampled_percentage,
                seed=repeat,
            )
            # Bundle the reduced training loader with the unchanged val/test loaders.
            dataset_namespace_new = SimpleNamespace(
                train_loader=train_loader_reduced,
                test_loader=dataloader_wrapped.test_loader,
                val_loader=dataloader_wrapped.val_loader,
            )

            # Seed right before cutting: the cut function re-initialises some layers
            # (at least the dense ones), and that must be reproducible per repeat.
            torch.manual_seed(repeat)

            # The cut function works on its own deep copy, so pretrained_model is
            # left untouched for the next iteration.
            model_new = cut_custom_cnn_model(
                pretrained_model,
                cut_point=sampled_cut_point,
                params=params,
                output_dim=dataloader_wrapped.output_dim,
            )
            model_new.to(device)

            # Fine-tune on the reduced data and evaluate.
            trainer = Trainer(model_new, dataset_namespace_new, params['lr_fine_tune'], params)
            train_acc, test_acc, effective_epochs, checkpoints = trainer.train(verbose=0)
            print(f"Training Accuracy: {train_acc:.4f}, Test Accuracy: {test_acc:.4f}")

            # Record this run.
            run_record = {
                "lr": params['lr_fine_tune'],
                "sampled_percentage": sampled_percentage,
                "sampled_cut_point": sampled_cut_point,
                "repeat": repeat,
                "train_acc": train_acc,
                "test_acc": test_acc,
            }
            results.append(run_record)


Sampled Percentage: 0.001, Sampled Cut Point: 0, Lr: 0.001, Repeat: 0
Early stopping invoked.
Training Accuracy: 0.9333, Test Accuracy: 0.7667

Sampled Percentage: 0.001, Sampled Cut Point: 0, Lr: 0.001, Repeat: 1
Early stopping invoked.
Training Accuracy: 0.8667, Test Accuracy: 0.7690

Sampled Percentage: 0.001, Sampled Cut Point: 0, Lr: 0.001, Repeat: 2
Early stopping invoked.
Training Accuracy: 1.0000, Test Accuracy: 0.7746

Sampled Percentage: 0.001, Sampled Cut Point: 0, Lr: 0.001, Repeat: 3
Early stopping invoked.
Training Accuracy: 0.8667, Test Accuracy: 0.8477

Sampled Percentage: 0.001, Sampled Cut Point: 0, Lr: 0.001, Repeat: 4
Early stopping invoked.
Training Accuracy: 1.0000, Test Accuracy: 0.8462

Sampled Percentage: 0.001, Sampled Cut Point: 0, Lr: 0.001, Repeat: 5
Early stopping invoked.
Training Accuracy: 1.0000, Test Accuracy: 0.8327

Sampled Percentage: 0.001, Sampled Cut Point: 0, Lr: 0.001, Repeat: 6
Early stopping invoked.
Training Accuracy: 0.6000, Test Accuracy:

In [27]:
# save fine-tuning results
# Build a JSON-serialisable copy of the hyper-parameters: the device entry is dropped
# and the activation class is replaced by its string representation.
params_tmp = copy.deepcopy(params)
del params_tmp["device"]
params_tmp["activation_function"] = str(params_tmp["activation_function"])

# The hyper-parameters go in front of the per-run records. The guard keeps this cell
# idempotent: re-running it must not prepend a second copy of the parameters.
if not results or results[0] != params_tmp:
    results = [params_tmp] + results

# Create the output folder on first use so open() does not fail on a fresh checkout.
os.makedirs('results', exist_ok=True)
with open('results/cuts_regular_classes_F.json', 'w') as f:
    json.dump(results, f)