## Import libraries

In [None]:
import os
import sys

#Import config file. Update config.py according to your environment
import config

import pandas as pd
import numpy as np

import tensorflow as tf

import datetime

from Rakuten_preprocessing import Rakuten_img_path

from src.multimodal.classifiers import MetaClassifier

from src.utils.load import load_classifier
from src.utils.load import load_batch_results
from src.utils.batch import fit_save_all

from sklearn.linear_model import LogisticRegression

## Import pre-processed data

In [2]:
def _merge_text(row):
    """Lower-case the string cells of a row and join them with single spaces.

    Cells that are not `str` instances are skipped.
    """
    return ' '.join(cell.lower() for cell in row if isinstance(cell, str))


# Read the pre-processed train and test tables; the boolean 'testset' column
# records which file each row came from once both tables are stacked.
data_train = pd.read_csv(
    os.path.join(config.path_to_data, 'df_train_index.csv')
).assign(testset=False)
data_test = pd.read_csv(
    os.path.join(config.path_to_data, 'df_test_index.csv')
).assign(testset=True)
data = pd.concat([data_train, data_test], axis=0)

# Text columns merged into the single 'tokens' column
# (alternative column set: ['designation', 'description'])
colnames = ['designation_translated', 'description_translated']
data['tokens'] = data[colnames].apply(_merge_text, axis=1)

# Image path of every product, stored in the 'img_path' column
data['img_path'] = Rakuten_img_path(
    img_folder=config.path_to_images,
    imageid=data['imageid'],
    productid=data['productid'],
    suffix='_resized',
)


In [3]:
# Lookup table of class labels: one row per encoded class, indexed by
# 'prdtypeindex' (sorted), with the class designation as the only column.
# For each designation, the first 'prdtypeindex' found in `data` is used.
class_labels = (
    data.groupby('prdtypedesignation')['prdtypeindex']
    .first()
    .reset_index()
    .set_index('prdtypeindex')
    .sort_index()
)

## Create train and test sets

In [4]:
# Boolean row masks derived from the 'testset' flag
is_test = data['testset']
is_train = ~is_test

# Image paths, merged text and encoded targets for each split
Img_train, Img_test = data.loc[is_train, 'img_path'], data.loc[is_test, 'img_path']
Txt_train, Txt_test = data.loc[is_train, 'tokens'], data.loc[is_test, 'tokens']
y_train, y_test = data.loc[is_train, 'prdtypeindex'], data.loc[is_test, 'prdtypeindex']

# To be fed into any of our sklearn classifiers, X_train and X_test
# should be dataframes with columns tokens and img_path
X_train = pd.DataFrame({'tokens': Txt_train, 'img_path': Img_train})
X_test = pd.DataFrame({'tokens': Txt_test, 'img_path': Img_test})

# Train and test stacked back together, for cross-validated scores
X = pd.concat([X_train, X_test], axis=0)
y = pd.concat([y_train, y_test], axis=0)

# Number of distinct encoded classes over the whole data set
num_classes = np.unique(data['prdtypeindex']).size

## Example usage: how to train MetaClassifier

In [None]:
# Load previously trained base models and set their number of epochs to 0.
# The intent is that the meta classifier uses them as they are instead of
# refitting them.

# Previously trained bert classifier
cl1 = load_classifier(name='text/camembert-base')
cl1.epochs = 0

# Previously trained ViT
cl2 = load_classifier(name='image/vit_b16_text')
# Fixed: this line used to set `epoch`, unlike the text model above which sets
# `epochs`. NOTE(review): assumes both classifiers expose `epochs` -- confirm
# against the image classifier's definition.
cl2.epochs = 0

# Final estimator for stacking
logi_clf = LogisticRegression(C=1)

# Stack both pre-fitted base models under the logistic regression
# (cv='prefit'), fit on the train split and score on the test split.
clf_stacking = MetaClassifier(
    base_estimators=[('bert', cl1), ('vit', cl2)],
    final_estimator=logi_clf,
    meta_method='stacking',
    cv='prefit',
)
clf_stacking.fit(X_train, y_train)
clf_stacking.classification_score(X_test, y_test)

# Save the fitted meta classifier under this name
clf_stacking.save('fusion/my_stacking_classifier')

## Voting and stacking models

In [None]:
# Name of the summary csv file to save results to
result_file_name = 'results_benchmark_fusion_meta.csv'

# Type of modality
modality = 'fusion'

# Type of classifier
class_type = 'MetaClassifier'

# Training parameters (or list of parameters for gridsearchCV).
# None of these is referenced by the parameter sets built in this cell;
# they are kept because other cells may rely on them.
num_class = num_classes
max_length = 256
n_epochs = 8
batch_size = 32
drop_rate = 0.2
lr0 = 5e-5
lr_min = 1e-6
lr_decay_rate = 0.8

# Grid search number of folds
nfolds_grid = 5

# Cross-validation of f1-score
nfolds_cv = 0

# Names of previously saved models to use as base estimators.
# Each entry is one string holding the space-separated model names.
base_name_list = ['text/camembert-base image/vit_b16']

# Voting: weight pairs to search over, in base-estimator order.
# Fixed: the first pair was [0.4, 0.5], the only one not summing to 1 and
# most likely a typo for [0.4, 0.6].
voting_type = 'soft'
voting_weights = [[0.4, 0.6], [0.5, 0.5], [0.6, 0.4], [0.7, 0.3], [0.8, 0.2]]

# Stacking: final estimator and number of folds (alternative: 'prefit')
stacking_estimator = LogisticRegression(C=1)
stacking_cv = 5

# Initializing the list of parameters to batch over
params_list = []

for base_name in base_name_list:
    # Voting meta classifier with a grid search over the weights.
    # Fixed: uses nfolds_grid instead of a hard-coded 5, so that changing
    # the setting above actually takes effect.
    params_list.append({'modality': modality,
                        'class': class_type,
                        'base_name': base_name,
                        'meta_method': 'voting',
                        'model_suffix': 'gridcv',
                        'param_grid': {'voting': voting_type, 'weights': voting_weights},
                        'nfolds_grid': nfolds_grid, 'nfolds_cv': nfolds_cv
                      })

    # Stacking meta classifier with a single parameter set (nfolds_grid is 0)
    params_list.append({'modality': modality,
                        'class': class_type,
                        'base_name': base_name,
                        'meta_method': 'stacking',
                        'model_suffix': 'cv5',
                        'param_grid': {'final_estimator': stacking_estimator, 'cv': stacking_cv},
                        'nfolds_grid': 0, 'nfolds_cv': nfolds_cv
                      })

# Running the batch over params_list
results = fit_save_all(params_list, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test, result_file_name=result_file_name)

## Load and check the saved result file

In [3]:
# Reload the saved benchmark summary and show it as a table.
# NOTE(review): the name is passed without the '.csv' extension used when the
# results were saved -- presumably load_batch_results adds it; confirm.
saved_results_name = 'results_benchmark_fusion_meta'
df_results = load_batch_results(saved_results_name)
display(df_results)

Unnamed: 0,modality,class,vectorization,meta_method,classifier,tested_params,best_params,score_test,score_test_cat,conf_mat_test,score_train,fit_time,score_cv_test,score_cv_train,fit_cv_time,probs_test,pred_test,y_test,model_path
0,fusion,MetaClassifier,,voting,text/camembert-base image/vit_b16,"{'voting': ['soft'], 'weights': [[0.5, 0.5], [...","{'voting': 'soft', 'weights': [0.5, 0.5]}",0.891679,"[0.7531645569620253, 0.8697247706422019, 0.978...","[[476, 0, 2, 0, 3, 3, 0, 0, 0, 2, 0, 1, 0, 2, ...",0.989918,14.245389,,,,"[[1.1222122310527993e-05, 7.106995690264739e-0...","[7, 10, 20, 2, 16, 0, 13, 20, 24, 23, 18, 15, ...","[7, 10, 20, 2, 16, 0, 13, 20, 24, 23, 4, 15, 1...",fusion/voting_text-camembert-base-image-vit_b1...
1,fusion,MetaClassifier,,stacking,text/camembert-base image/vit_b16,"{'final_estimator': [LogisticRegression(C=1)],...",,0.890881,"[0.7146282973621104, 0.8515769944341373, 0.966...","[[447, 0, 1, 1, 4, 2, 1, 0, 0, 3, 0, 0, 0, 1, ...",0.991147,1431.309125,,,,"[[0.0008456810168565446, 0.0010279105830923682...","[7, 10, 20, 2, 16, 0, 13, 20, 24, 23, 18, 15, ...","[7, 10, 20, 2, 16, 0, 13, 20, 24, 23, 4, 15, 1...",fusion/stacking_text-camembert-base-image-vit_...
