**Preliminary step**

To train a ViT with the FAX loss, you first need to replace the ViT code in the timm library (`https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py`) with the code in `fixatt/training/vision_transformer_modified_for_fax_loss.py`, and then load the modified timm library.


In [1]:
# Imports
import torch
import timm
import random
import csv   
import random
import numpy as np
import torch.nn as nn
from pathlib import Path
import pickle

from utils import *
from dataloaders import *

import os
import configparser
import sys
import ast

# Define a function to convert string to boolean
def str_to_bool(s):
    """Interpret a configuration string as a boolean flag.

    Only the case-insensitive spelling ``'true'`` yields ``True``; every other
    value (for example ``'false'``, ``'yes'``, ``'1'`` or ``' true'`` with
    surrounding whitespace) yields ``False``.

    Args:
        s: Raw string value, e.g. as read from ``configparser``.

    Returns:
        ``True`` when ``s`` equals ``'true'`` ignoring case, otherwise ``False``.
    """
    lowered = s.lower()
    is_true = lowered == 'true'
    return is_true

# Locate config.ini one directory above the notebook's working directory.
# NOTE: inside a Jupyter kernel `__file__` is undefined, and
# `os.path.dirname(os.path.abspath('__file__'))` (a *string* literal, not the
# variable) always resolves to the current working directory; os.getcwd()
# states that intent directly.
parent_directory = os.getcwd()

# Construct the path to the config.ini file in the parent directory
config_file_path = os.path.join(parent_directory, '..', 'config.ini')

# Load the configuration file. ConfigParser.read() silently skips files it
# cannot open and returns the list of files it did parse, so fail loudly here
# instead of with an opaque KeyError on the first section lookup below.
config = configparser.ConfigParser()
if not config.read(config_file_path):
    raise FileNotFoundError(f'Could not read config file: {os.path.abspath(config_file_path)}')

# Access configuration parameters
model_type = config['model_config']['model_type']
model_load_type = config['model_config']['model_load_type']
# subset_layers is stored as a Python literal and parsed with ast.literal_eval.
subset_layers = ast.literal_eval(config['model_config']['subset_layers'])
vit_version = config['model_config']['vit_version']

dataset_type = config['data']['dataset_type']
train_data_cond = config['data']['train_data_cond']
test_data_cond = config['data']['test_data_cond']

validation = str_to_bool(config['hyperparams']['validation'])
# The batch size is configured per dataset under the key `batch_size_<dataset_type>`.
batch_size = int(config['hyperparams'][f'batch_size_{dataset_type}'])
early_stopping_epochs = int(config['hyperparams']['early_stopping_epochs'])
runs = int(config['hyperparams']['runs'])

data_path = config['paths']['global_data_path']

In [None]:
# Define parameters

train = True  # retrain? (False skips training and loads saved weights instead)

# All artefacts for this dataset / condition / model type live in one directory;
# build the path once so the mkdir and the filename prefix cannot drift apart.
model_dir = f'pre_trained_models/{dataset_type}_{train_data_cond}_{model_type}'
Path(model_dir).mkdir(parents=True, exist_ok=True)
# Filename prefix for checkpoints/metrics; per-run suffixes are appended later.
# Other prefixes used: pre_trained_models/full_vit | random_peripheral_vit | .../trained_LR_driving
trained_model_path = f'{model_dir}/trained_LR_driving'

# Use the first CUDA device when available, otherwise fall back to the CPU.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)
print('device:', device)
print('trained_model_path:', trained_model_path)
print('subset_layers:', subset_layers)

# Re-run peripheral data extraction (expensive; disabled by default)
re_run_peripheral_data_extraction = False
if re_run_peripheral_data_extraction:
    peripheral_data_extraction(dataset_type, train_data_cond)

## Main training cell

In [None]:
# Hyperparameter sweep over the FAX-loss weighting factor lambda. Other model
# types use a single None placeholder (passed through to the train/eval helpers).
if model_type == 'fax':
    # lambdas = [0.01, 0.1, 0.2, 0.8, 1] # different values for the hyperparameter $\lambda$ of the FAX loss
    lambdas = [1]
else:
    lambdas = [None]

# Redefine runs in case you want to train for specific run numbers.
# NOTE: this overrides the `runs` count read from config.ini.
# runs = [5, 6, 7, 8, 9]
runs = list(range(10))

max_epochs = 100  # upper bound per run; early stopping may end a run sooner
col_headers = ['vit_version', 'epoch', 'criterion', 'batch_size', 'train_loss', 'train_accuracy',
               'valid_loss', 'valid_accuracy', 'test_loss', 'test_accuracy']

# Training
if train:
    for l in lambdas:  # train for every lambda hyperparameter in `lambdas`
        for run in runs:
            print(f'training run: {run}, lambda: {l}')

            # Seed every RNG with the run index (0 for run=0, 1 for run=1, ...).
            # Compare against None explicitly: 0 is a valid seed but is falsy,
            # so a plain truthiness test would leave run 0 unseeded.
            random_state = run
            if random_state is not None:
                torch.manual_seed(random_state)
                random.seed(random_state)
                np.random.seed(random_state)

            # NOTE(review): `dataset_type` is passed twice here; confirm against
            # get_split_data's signature (test_data_cond may have been intended).
            train_list, valid_list, test_list = get_split_data(dataset_type, data_path, train_data_cond, dataset_type,
                                                               validation, random_state)
            train_loader, valid_loader, test_loader, train_data, valid_data, test_data = get_loaders(
                dataset_type, model_type, train_list, valid_list, test_list,
                batch_size, train_data_cond, test_data_cond)
            model, criterion, optimizer = get_new_model_configs(vit_version, model_type, subset_layers, device)
            trained_model_path_run = f'{trained_model_path}_{run}_{subset_layers}_layer_{l}_lambda'

            # Persist the exact split used for this run (presumably for later
            # evaluation). The filename pattern is kept unchanged for compatibility,
            # even though it differs from trained_model_path_run.
            with open(f'{trained_model_path}_{run}_{subset_layers}_{l}_datasets.pkl', 'wb') as handle:
                pickle.dump([train_list, valid_list, test_list], handle, protocol=pickle.HIGHEST_PROTOCOL)

            # NOTE(review): min_delta=10 is large relative to typical loss values;
            # confirm this matches EarlyStopper's intended semantics.
            early_stopper = EarlyStopper(patience=early_stopping_epochs, min_delta=10)
            print('Starting training...')
            for epoch in range(max_epochs):  # loop over the dataset multiple times
                train_total_loss, train_accuracy = train_one_epoch(model_type, model, criterion, optimizer, train_loader,
                                                                   train_data, device, batch_size, l)
                if validation:
                    _, _, valid_accuracy, valid_total_loss = eval_dataset(model_type, model, criterion, valid_loader,
                                                                          valid_data, device, batch_size, l,
                                                                          export_preds=False)
                else:
                    # No validation split: record NaN. (np.NaN was removed in NumPy 2.0.)
                    valid_accuracy = np.nan
                    valid_total_loss = np.nan
                _, _, test_accuracy, test_total_loss = eval_dataset(model_type, model, criterion, test_loader, test_data,
                                                                    device, batch_size, l, export_preds=False)

                print(f'[epoch: {epoch + 1}] train_loss: {train_total_loss:.3f}, train_accuracy: {train_accuracy:.3f}%, '
                      f'valid_loss: {valid_total_loss:.3f}, valid_accuracy: {valid_accuracy:.3f}%, '
                      f'test_loss: {test_total_loss:.3f}, test_accuracy: {test_accuracy:.3f}%')

                # Log metrics before the early-stop check so the epoch that
                # triggers early stopping is recorded as well. Epoch 0 overwrites
                # any CSV left by a previous execution and writes the header.
                col_vals = [vit_version, epoch, criterion, batch_size, train_total_loss, train_accuracy,
                            valid_total_loss, valid_accuracy, test_total_loss, test_accuracy]
                csv_mode = 'w' if epoch == 0 else 'a'
                with open(f'{trained_model_path_run}.csv', csv_mode, newline='') as fd:  # newline='' per csv docs
                    writer = csv.writer(fd)
                    if epoch == 0:
                        writer.writerow(col_headers)
                    writer.writerow(col_vals)

                # Early-stop on validation loss when a validation split exists, else on test loss.
                monitored_loss = valid_total_loss if validation else test_total_loss
                if early_stopper.early_stop(monitored_loss):
                    print("EARLY STOPPED")
                    break

                # Presumably EarlyStopper resets `counter` to 0 when the monitored
                # loss improves, making `_early_stopped.pt` the best checkpoint so
                # far; verify against EarlyStopper in utils.
                suffix = 'early_stopped' if early_stopper.counter == 0 else 'late'
                checkpoint_path = f'{trained_model_path_run}_{suffix}.pt'
                torch.save(model.state_dict(), checkpoint_path)
                print(f"Saved model at {checkpoint_path}!")

    print('Finished Training')

else:
    # Evaluation-only path: rebuild the architecture, then load saved weights.
    if model_type == 'jsf':
        model = TimeSformer(img_size=224, num_classes=2, num_frames=2, attention_type='joint_space_time',
                            pretrained_model='').to(device)
    else:
        model = timm.create_model('vit_base_patch16_224', num_classes=2).to(device)
    # Load weights for both architectures; otherwise the jsf model would stay
    # randomly initialised.
    model.load_state_dict(torch.load(f'{trained_model_path}_{model_load_type}.pt', map_location=device))
    # NOTE(review): blocks are truncated *after* loading. If the checkpoint was
    # saved from an already-truncated model, strict loading into the full model
    # will report missing keys; confirm the intended order.
    if subset_layers:
        model.blocks = nn.Sequential(*[model.blocks[i] for i in range(subset_layers)])