## Read data

Your `X_train` and `X_test` should have shape `(n_samples, 1, seq_len)` with `seq_len = 512`.

In [None]:
import numpy as np
import torch

# Arrays are stored as data/GestureMidAirD1/<variable>_<split>.npy. They are loaded
# in the order X_train, X_test, y_train, y_test to match the unpacking below.
data = [
    np.load(f'data/GestureMidAirD1/{variable}_{set_name}.npy')
    for variable, set_name in (('X', 'train'), ('X', 'test'), ('y', 'train'), ('y', 'test'))
]
X_train, X_test, y_train, y_test = data

print("X_train dims: ", X_train.shape)
print("X_test dims: ", X_test.shape)

If the original sequence length is different, resize the series, for example with the following function:


In [None]:
import torch.nn.functional as F

def rescale(X, size=512):
    """Resample every series in X to a fixed length using linear interpolation.

    Mantis expects inputs of shape (n_samples, 1, 512), so datasets whose
    original sequence length differs must be resized first.

    Args:
        X: array-like of shape (n_samples, n_channels, seq_len). A 2-D input of
            shape (n_samples, seq_len) is also accepted; a singleton channel
            axis is inserted for it.
        size: target sequence length (default 512).

    Returns:
        numpy.ndarray of dtype float32 with shape (n_samples, n_channels, size).
    """
    X_tensor = torch.tensor(X, dtype=torch.float)
    if X_tensor.ndim == 2:
        # F.interpolate with mode='linear' requires a 3-D (N, C, L) input.
        X_tensor = X_tensor.unsqueeze(1)
    X_scaled = F.interpolate(X_tensor, size=size, mode='linear', align_corners=False)
    return X_scaled.numpy()
    
# Resize both splits to the sequence length Mantis expects.
X_train = rescale(X_train)
X_test = rescale(X_test)

print("X_train dims: ", X_train.shape)
print("X_test dims: ", X_test.shape)

## Load model

In [None]:
from mantis.architecture import Mantis8M

# Device used for the network here and for the trainer created later.
device = 'cpu'

# Instantiate the architecture, then load the "paris-noah/Mantis-8M" pretrained weights.
network = Mantis8M(device=device).from_pretrained("paris-noah/Mantis-8M")

## Fine-tune the network using MantisTrainer

Initialize the trainer and define the arguments passed to it during fine-tuning.

In [None]:
from mantis.trainer import MantisTrainer

# Wrap the pretrained network in a trainer that is used below for fit() and predict().
model = MantisTrainer(network=network, device=device)

# Optimizer factory handed to `model.fit`. The hyperparameters are keyword
# arguments, so they can be tuned without redefining the function.
def init_optimizer(params, lr=2e-4, betas=(0.9, 0.999), weight_decay=0.05):
    """Create the AdamW optimizer used during fine-tuning.

    Args:
        params: iterable of parameters (or parameter groups) to optimize;
            presumably supplied by ``MantisTrainer.fit`` — confirm.
        lr: learning rate.
        betas: AdamW coefficients for the running averages of the gradient
            and its square.
        weight_decay: decoupled weight-decay coefficient.

    Returns:
        torch.optim.AdamW: optimizer over ``params``.
    """
    return torch.optim.AdamW(params, lr=lr, betas=betas, weight_decay=weight_decay)

### Fine-tuning a classification head

In [None]:
# Fine-tune a classification head ('head' mode; the encoder is presumably
# kept frozen in this mode — confirm in the MantisTrainer docs).
fine_tuning_type = 'head'

model.fit(
    X_train,
    y_train,
    num_epochs=100,
    fine_tuning_type=fine_tuning_type,
    init_optimizer=init_optimizer,
)

Evaluate performance on the test set:

In [None]:
y_pred = model.predict(X_test)

# Guard against silent broadcasting: comparing an (n,) array with an (n, 1)
# array would yield an (n, n) matrix and a meaningless "accuracy".
assert np.shape(y_pred) == np.shape(y_test), (
    f'prediction shape {np.shape(y_pred)} does not match label shape {np.shape(y_test)}')
print(f'Accuracy on the test set is {np.mean(y_test == y_pred)}')

### Full fine-tuning

In [None]:
# Fine-tune the whole network ('full' mode).
# NOTE(review): `model` is the same trainer already fitted in the head-only run;
# confirm whether `fit` starts from the pretrained weights or from that state.
fine_tuning_type = 'full'

model.fit(
    X_train,
    y_train,
    num_epochs=100,
    fine_tuning_type=fine_tuning_type,
    init_optimizer=init_optimizer,
)

Evaluate performance on the test set:

In [None]:
y_pred = model.predict(X_test)

# Guard against silent broadcasting: comparing an (n,) array with an (n, 1)
# array would yield an (n, n) matrix and a meaningless "accuracy".
assert np.shape(y_pred) == np.shape(y_test), (
    f'prediction shape {np.shape(y_pred)} does not match label shape {np.shape(y_test)}')
print(f'Accuracy on the test set is {np.mean(y_test == y_pred)}')