### setup

In [1]:
# Colab bootstrap: install the Hugging Face libraries imported further down.
# Versions are not pinned; the install log recorded below resolved
# datasets 2.6.1 and transformers 4.23.1.
!pip install datasets transformers
# (disabled) TPU/XLA wheels built for torch 1.9 on a Python 3.7 runtime.
# !pip install cloud-tpu-client==0.10 torch==1.9.0 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl
# accelerate is installed from the repository's default branch rather than from a PyPI release.
!pip install git+https://github.com/huggingface/accelerate

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting datasets
  Downloading datasets-2.6.1-py3-none-any.whl (441 kB)
[K     |████████████████████████████████| 441 kB 6.7 MB/s 
[?25hCollecting transformers
  Downloading transformers-4.23.1-py3-none-any.whl (5.3 MB)
[K     |████████████████████████████████| 5.3 MB 59.9 MB/s 
Collecting huggingface-hub<1.0.0,>=0.2.0
  Downloading huggingface_hub-0.10.1-py3-none-any.whl (163 kB)
[K     |████████████████████████████████| 163 kB 63.9 MB/s 
Collecting multiprocess
  Downloading multiprocess-0.70.13-py37-none-any.whl (115 kB)
[K     |████████████████████████████████| 115 kB 57.2 MB/s 
[?25hCollecting responses<0.19
  Downloading responses-0.18.0-py3-none-any.whl (38 kB)
Collecting xxhash
  Downloading xxhash-3.0.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (212 kB)
[K     |████████████████████████████████| 212 kB 64.3 MB/s 
Collecting urllib3!=1.25.0,!=1.25.1,<1.26,

In [2]:
%load_ext tensorboard

In [3]:
%%capture
!wget https://archive.ics.uci.edu/ml/machine-learning-databases/00217/C50.zip
!unzip C50

In [4]:
#@title hyperparams
# Colab form cell: every `#@param` annotation below renders as an editable field.
model_checkpoint = "roberta-base" #@param {type:"string"}
max_length = 512 #@param {type:"number"}
learning_rate = 2e-5 #@param {type:"number"}
num_epochs = 100 #@param {type:"number"}
train_batch_size = 8 #@param {type:"number"}
eval_batch_size = 32 #@param {type:"number"}
# Fraction of the training frame that the data-preparation cell samples out as the validation set.
valid_frac = 0.1 #@param {type:"number"}
seed = 42 #@param {type:"number"}
# Initial value of the models' learnable `temperature` parameter; the InfoNCE loss
# multiplies the similarity matrix by exp(temperature), clamped to at most 100.
temperature = 0.7 #@param {type:"number"}
# Any value > 0 makes the data-preparation cell pool train+test and re-split them 80/20;
# otherwise the dataset's own train/test folders are kept.
custom_split = -1 #@param {type:"number"}
# NOTE(review): no code visible in this notebook reads `fp16` — confirm whether it
# should be forwarded to the Accelerator.
fp16 = True #@param {type:"boolean"}

# Subset of the form values bundled for the training code; model_checkpoint, max_length,
# valid_frac and custom_split are read directly as globals instead.
hyperparameters = {
    "learning_rate": learning_rate,
    "temperature": temperature,
    "num_epochs": num_epochs,
    "train_batch_size": train_batch_size,
    "eval_batch_size": eval_batch_size,
    "seed": seed
}

### code

In [5]:
from os import listdir
from os.path import isfile, isdir, join
import numpy as np

def _load_split(split_dir, authors):
    """Read every article stored under ``split_dir/<author>/`` for the given authors.

    Args:
        split_dir: directory that contains one sub-directory per author.
        authors: author directory names, in the order the rows should be grouped.

    Returns:
        ``(article_names, articles)`` where ``article_names[i]`` lists the file
        names found for ``authors[i]`` and ``articles`` is an object array of
        shape ``(n_articles, 2)`` whose rows are ``[author_name, article_text]``.
    """
    article_names = []
    rows = []
    for author in authors:
        author_dir = join(split_dir, author)
        files = [name for name in listdir(author_dir) if isfile(join(author_dir, name))]
        article_names.append(files)
        for name in files:
            # Explicit encoding so the result does not depend on the machine's locale.
            with open(join(author_dir, name), 'r', encoding='utf-8') as f:
                rows.append([author, f.read()])
    # Rows are appended sequentially, so the array is sized by what was actually
    # read instead of assuming 50 authors x 50 articles; reshape keeps the
    # (n, 2) layout even when nothing was found.
    return article_names, np.array(rows, dtype=object).reshape(-1, 2)


# The author list comes from the training folder and is reused for the test
# folder so both arrays are grouped by author in the same order.
author_names = [name for name in listdir('C50train/') if isdir(join('C50train/', name))]
print('There are', len(author_names), 'authors (both train and test), the first 3:', author_names[:3])

train_article_names, train_articles = _load_split('C50train/', author_names)
print('There are', len(train_article_names[0]), '(train) articles written by the first author')
test_article_names, test_articles = _load_split('C50test/', author_names)
print('There are', len(test_article_names[0]), '(test) articles written by the first author')

print('The shape of train articles is', train_articles.shape)
print('The shape of test articles is', test_articles.shape)

There are 50 authors (both train and test), the first 3: ["LynneO'Donnell", 'AlanCrosby', 'KevinDrawbaugh']
There are 50 (train) articles written by the first author
There are 50 (test) articles written by the first author
The shape of train articles is (2500, 2)
The shape of test articles is (2500, 2)


In [6]:
# get the data in a format for the train script
import pandas as pd
from sklearn.preprocessing import LabelEncoder

label_encoder = LabelEncoder()
column_names = ['labels', 'anchor']

if custom_split > 0:
    # Pool the dataset's train and test folders and draw a fresh 80/20 split,
    # seeded from the same hyperparameter as the validation split below.
    total = pd.DataFrame(data=np.concatenate([train_articles, test_articles]), columns=column_names)

    train = total.sample(frac=.8, random_state=hyperparameters["seed"])
    test = total.drop(train.index)
else:
    train = pd.DataFrame(data=train_articles, columns=column_names)
    test = pd.DataFrame(data=test_articles, columns=column_names)

# The encoder is fitted on the training authors only; the test labels are
# mapped with the same fitted encoder further down.
train['labels'] = label_encoder.fit_transform(train['labels'])

# Hold out a seeded fraction of the training rows for validation.
valid = train.sample(frac=valid_frac, random_state=hyperparameters["seed"])
train = train.drop(valid.index)

# Written relative to the working directory because create_dataloaders reads
# "train.csv" / "valid.csv" / "test.csv" relative to it as well (on Colab the
# working directory is /content, so the files land in the same place as before).
train.to_csv('train.csv', index=False)
valid.to_csv('valid.csv', index=False)

test['labels'] = label_encoder.transform(test['labels'])
test.to_csv('test.csv', index=False)

# The frames are re-read from disk by the dataset class, so release them here.
del train
del valid
del test

In [7]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

Downloading:   0%|          | 0.00/481 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

In [8]:
# adapted from https://github.com/huggingface/notebooks/blob/main/examples/accelerate_examples/simple_nlp_example.ipynb
import torch
from torch.utils.data import DataLoader

from accelerate import Accelerator, DistributedType
from datasets import load_dataset, load_metric
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
    set_seed,
)

from tqdm.auto import tqdm

import datasets
import transformers

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from transformers import AutoModel
from transformers.modeling_outputs import SequenceClassifierOutput

class DynamicLSTM(nn.Module):
    """LSTM pooler that ignores padded positions and returns one vector per sequence.

    Args:
        input_size: feature size of each input timestep.
        hidden_size: hidden size of the LSTM (per direction).
        num_layers: number of stacked LSTM layers.
        dropout: dropout between stacked layers; has no effect with one layer.
        bidirectional: run the LSTM in both directions.
    """

    def __init__(self, input_size, hidden_size=100,
                    num_layers=1, dropout=0., bidirectional=False):
        super(DynamicLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.bidirectional = bidirectional
        # nn.LSTM only applies dropout between stacked layers and warns when it
        # is non-zero for a single layer, so pass 0 in that case.
        self.lstm = torch.nn.LSTM(
            input_size, self.hidden_size, num_layers, bias=True,
            batch_first=True, dropout=dropout if num_layers > 1 else 0.,
            bidirectional=bidirectional)

    def forward(self, x, attention_mask=None):
        """Pool ``x`` of shape ``(batch, seq_len, input_size)`` into one vector per row.

        Args:
            x: padded input batch.
            attention_mask: ``(batch, seq_len)`` mask with 1 for real tokens.
                Lengths are taken as ``attention_mask.sum(-1)``, so real tokens
                are expected to come first (right padding) and every row needs
                at least one of them. ``None`` treats every position as real.

        Returns:
            ``(batch, 2 * hidden_size)`` for a bidirectional LSTM: the forward
            direction's output at each row's last real token concatenated with
            the backward direction's output at position 0. For a unidirectional
            LSTM, ``(batch, hidden_size)`` with the output at the last real token.
        """
        batch_size, seq_len = x.shape[0], x.shape[1]
        if attention_mask is None:
            attention_mask = torch.ones(x.shape[:-1], dtype=torch.long, device=x.device)

        # Integer lengths are needed both for packing and for indexing below.
        seq_lens = attention_mask.sum(-1).long()

        # enforce_sorted=False lets PyTorch sort by length internally and restore
        # the original batch order when unpacking. Lengths must live on the CPU.
        x_packed = pack_padded_sequence(
            x, seq_lens.cpu(), batch_first=True, enforce_sorted=False)

        # pass through rnn
        y_packed, _ = self.lstm(x_packed)

        # Pad back to the full input width. Without total_length the output is
        # only as long as the longest sequence in the batch, and the reshape
        # below fails whenever every sequence is shorter than seq_len.
        y, _ = pad_packed_sequence(y_packed, batch_first=True, total_length=seq_len)

        batch_indices = torch.arange(batch_size, device=y.device)
        last_indices = (seq_lens - 1).to(y.device)

        # Split the feature axis into (direction, hidden): 0 = forward, 1 = backward.
        num_directions = 2 if self.bidirectional else 1
        y_split = y.reshape(batch_size, seq_len, num_directions, self.hidden_size)

        last_forward = y_split[batch_indices, last_indices, 0]
        if not self.bidirectional:
            return last_forward

        # The backward direction has seen the whole sequence once it reaches position 0.
        first_backward = y_split[batch_indices, 0, 1]
        return torch.cat([last_forward, first_backward], dim=-1)

In [10]:
class EmbeddingModel(nn.Module):
    """Pretrained transformer plus a bidirectional LSTM pooler producing unit-norm embeddings."""

    def __init__(self):
        super().__init__()
        encoder = AutoModel.from_pretrained(model_checkpoint)
        hidden = encoder.config.hidden_size
        self.transformer = encoder
        # Each LSTM direction gets half of the transformer width.
        self.pooler = DynamicLSTM(hidden, hidden // 2, dropout=.1, bidirectional=True)
        # Learnable scalar, initialised from the notebook's hyperparameters.
        self.temperature = torch.nn.Parameter(torch.Tensor([hyperparameters["temperature"]]))

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Encode the tokens, pool them with the LSTM and normalise the result."""
        token_states = self.transformer(
            input_ids, attention_mask, token_type_ids, return_dict=True
        ).last_hidden_state
        pooled = self.pooler(token_states, attention_mask)
        return F.normalize(pooled)

class ContrastiveModel(nn.Module):
    """Transformer whose projected token states are flattened into one embedding and classified.

    The classifier input width depends on the notebook-level ``max_length``, so
    inputs must be padded to exactly that many tokens.
    """

    def __init__(self):
        super().__init__()
        self.transformer = AutoModel.from_pretrained(model_checkpoint)
        cfg = self.transformer.config
        projected_size = cfg.hidden_size // 2
        # Registered but not applied in forward().
        self.embed_dropout = nn.Dropout(cfg.hidden_dropout_prob)
        self.embed_projection = nn.Linear(cfg.hidden_size, projected_size)
        self.flatten = nn.Flatten()
        # One input feature per (token position, projected feature); 50 output classes.
        self.classifier = nn.Linear(projected_size * max_length, 50)
        self.activation = nn.ReLU()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Return ``(embedding, logits)`` for a batch of token ids."""
        token_states = self.transformer(
            input_ids, attention_mask, token_type_ids, return_dict=True
        ).last_hidden_state
        embedding = self.flatten(self.embed_projection(token_states))
        logits = self.classifier(self.activation(embedding))
        return embedding, logits

class EmbeddingClassificationModel(nn.Module):
    """Transformer + LSTM pooler that returns a unit-norm embedding and class logits."""

    def __init__(self):
        super().__init__()
        encoder = AutoModel.from_pretrained(model_checkpoint)
        cfg = encoder.config
        self.transformer = encoder
        # Each LSTM direction gets half of the transformer width.
        self.pooler = DynamicLSTM(cfg.hidden_size, cfg.hidden_size // 2,
                                  dropout=.1, bidirectional=True)
        # Learnable scalar, initialised from the notebook's hyperparameters.
        self.temperature = torch.nn.Parameter(torch.Tensor([hyperparameters["temperature"]]))
        self.cls_dropout = nn.Dropout(cfg.hidden_dropout_prob)
        # 50 output classes.
        self.classifier = nn.Linear(cfg.hidden_size, 50)
        self.activation = nn.ReLU()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Return ``(embedding, logits)``; the logits are computed from the normalised embedding."""
        token_states = self.transformer(
            input_ids, attention_mask, token_type_ids, return_dict=True
        ).last_hidden_state
        embedding = F.normalize(self.pooler(token_states, attention_mask))
        # Classification head: ReLU -> dropout -> linear.
        logits = self.classifier(self.cls_dropout(self.activation(embedding)))
        return embedding, logits

In [11]:
model = EmbeddingClassificationModel()

Downloading:   0%|          | 0.00/501M [00:00<?, ?B/s]

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.dense.weight', 'lm_head.bias', 'lm_head.layer_norm.weight', 'lm_head.decoder.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
  "num_layers={}".format(dropout, num_layers))


In [12]:
from torch.utils.data import Dataset
class ContrastiveDataset(Dataset):
    """CSV-backed dataset yielding ``(label, anchor, replica)`` triples.

    The anchor is the requested row; the replica is one randomly sampled row
    carrying the same label (which may be the anchor row itself).
    """

    _FIELDS = ('input_ids', 'attention_mask')

    def __init__(self, filename):
        def tokenize_row(row):
            # Tokenise the article once up front and keep the tensors on the row.
            encoded = tokenizer(row["anchor"], truncation=True, padding="max_length", max_length=max_length, return_tensors='pt')
            for field in self._FIELDS:
                row[field] = encoded[field]
            return row

        self.df = pd.read_csv(filename).apply(tokenize_row, axis=1)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        label = row['labels']
        # Each stored value has a leading axis of size one; [0] drops it.
        anchor = {field: row[field][0] for field in self._FIELDS}
        same_label = self.df['labels'] == label
        partner = self.df[same_label].sample(1)
        replica = {field: partner[field].item()[0] for field in self._FIELDS}
        return label, anchor, replica

In [13]:
# raw_datasets = load_dataset(
#     "csv",
#     data_files={'train': 'train.csv', 'validation': 'valid.csv',  'test': 'test.csv'}
# )

In [14]:
# from transformers import AutoTokenizer

# def tokenize_function(example):
    # return tokenizer(example["anchor"], truncation=True, padding="max_length", max_length=max_length)
    # return {
    #     "anchor": tokenizer(example["anchor"], truncation=True, padding="max_length", max_length=max_length), 
    #     "replica": tokenizer(example["replica"], truncation=True, padding="max_length", max_length=max_length)
    # }

# tokenize_function = lambda ex: tokenizer(ex["text"], truncation=True, padding="max_length", max_length=max_length)
# tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
# tokenized_datasets.set_format("torch")
# tokenized_datasets["train"].features

In [15]:
# temp = ContrastiveDataset('train.csv').df
# candidates = []
# for i in temp['labels'].sort_values().unique():
#     replica = temp[temp['labels'] == i].sample(1)
#     candidates.append({
#         'input_ids':replica['input_ids'].item()[0], 'attention_mask': replica['attention_mask'].item()[0]
#     })
# candidates = np.array(candidates)
# del temp

In [16]:
def create_dataloaders(train_batch_size=16, eval_batch_size=32):
    """Build the train, validation and test loaders from the CSV files in the working directory.

    Only the training loader shuffles; validation and test share ``eval_batch_size``.

    Returns:
        ``(train_dataloader, eval_dataloader, test_dataloader)``.
    """
    # (csv file, shuffle?, batch size) for each split, in the order they are returned.
    split_specs = (
        ("train.csv", True, train_batch_size),
        ("valid.csv", False, eval_batch_size),
        ("test.csv", False, eval_batch_size),
    )
    train_dataloader, eval_dataloader, test_dataloader = (
        DataLoader(ContrastiveDataset(csv_file), shuffle=reshuffle, batch_size=size)
        for csv_file, reshuffle, size in split_specs
    )
    return train_dataloader, eval_dataloader, test_dataloader

In [17]:
accuracy_metric = load_metric("accuracy")
f1_metric = load_metric("f1")

  """Entry point for launching an IPython kernel.


Downloading builder script:   0%|          | 0.00/1.65k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/2.32k [00:00<?, ?B/s]

In [18]:
def training_function(model):
    """Train ``model`` with a joint InfoNCE + classification loss and print metrics.

    Every step adds a symmetric InfoNCE loss between the anchor and replica
    embeddings to a cross-entropy loss on the anchor logits. Accuracy and
    micro-F1 are printed for the validation loader after each epoch and for the
    test loader after the last one.

    Args:
        model: module whose forward pass takes the tokenizer fields as keyword
            arguments, returns ``(embedding, logits)`` and exposes a
            ``temperature`` parameter.

    Reads the notebook-level ``hyperparameters``, ``create_dataloaders``,
    ``accuracy_metric`` and ``f1_metric``. Returns nothing.
    """
    # Initialize accelerator
    # NOTE(review): the notebook's `fp16` form field is not forwarded here —
    # confirm whether mixed precision was intended.
    accelerator = Accelerator()#log_with="all", logging_dir="logs")

    def infonce_loss(a, b, temp):
        """Symmetric InfoNCE: row i of ``a`` should match row i of ``b``.

        Returns ``(loss, accuracy)`` where accuracy is the fraction of rows and
        columns whose highest similarity is on the diagonal.
        """
        batch_size = a.shape[0]
        logits = (a @ b.T) * torch.exp(temp).clamp(max=100)
        labels = torch.arange(0, batch_size, device=logits.device)

        # cross_entropy already averages over the batch.
        loss = (F.cross_entropy(logits.T, labels) +
                F.cross_entropy(logits, labels)) / 2

        with torch.no_grad():
            # argmax of the raw logits equals argmax of their softmax.
            preds = logits.argmax(-1)
            preds_t = logits.T.argmax(-1)
            accuracy = (torch.sum(preds == labels) +
                        torch.sum(preds_t == labels)) / (batch_size * 2)

        return loss, accuracy

    def evaluate(dataloader):
        """Score the anchor logits of every batch; returns ``(accuracy, f1)`` result dicts."""
        model.eval()
        for labels, anchor, _replica in dataloader:
            with torch.no_grad():
                _, logits_anchor = model(**anchor)
            predictions = logits_anchor.argmax(dim=-1)

            predictions, references = accelerator.gather_for_metrics((predictions, labels))
            accuracy_metric.add_batch(
                predictions=predictions,
                references=references,
            )
            f1_metric.add_batch(
                predictions=predictions,
                references=references,
            )
        return accuracy_metric.compute(), f1_metric.compute(average='micro')

    if accelerator.is_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    train_dataloader, eval_dataloader, test_dataloader = create_dataloaders(
        train_batch_size=hyperparameters["train_batch_size"], eval_batch_size=hyperparameters["eval_batch_size"]
    )
    set_seed(hyperparameters["seed"])
    optimizer = torch.optim.AdamW(params=model.parameters(), lr=hyperparameters["learning_rate"])

    # The loss function is a plain callable, so it is not passed through prepare().
    model, optimizer, train_dataloader, eval_dataloader, test_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader, test_dataloader
    )
    # prepare() may wrap the model (e.g. for distributed runs), which hides
    # custom attributes; read the temperature parameter from the unwrapped module.
    temperature = accelerator.unwrap_model(model).temperature

    num_epochs = hyperparameters["num_epochs"]

    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=100,
        num_training_steps=len(train_dataloader) * num_epochs,
    )

    # Instantiate a progress bar to keep track of training. Note that we only enable it on the main
    # process to avoid having 8 progress bars.
    progress_bar = tqdm(range(num_epochs * len(train_dataloader)), disable=not accelerator.is_main_process)
    # Now we train the model
    for epoch in range(num_epochs):
        model.train()
        # Running sums for the (currently disabled) tracker logging below.
        total_contrastive_loss = 0
        total_classification_loss = 0
        for labels, anchor, replica in train_dataloader:
            embed_anchor, logits_anchor = model(**anchor)
            embed_replica, _ = model(**replica)

            contrastive_loss, _ = infonce_loss(embed_anchor, embed_replica, temperature)
            # Take the class count from the logits instead of hard-coding it.
            classification_loss = F.cross_entropy(
                logits_anchor.view(-1, logits_anchor.shape[-1]), labels.view(-1))
            loss = contrastive_loss + classification_loss
            accelerator.backward(loss)

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            progress_bar.update(1)
            total_contrastive_loss += contrastive_loss.detach().float()
            total_classification_loss += classification_loss.detach().float()

        accuracy_score, f1_score = evaluate(eval_dataloader)

        # accelerator.log(
        #         {
        #         "accuracy": accuracy_score["accuracy"],
        #         "f1": f1_score["f1"],
        #         "train_contrastive_loss": total_contrastive_loss.item() / len(train_dataloader),
        #         "train_classification_loss": total_classification_loss.item() / len(train_dataloader),
        #         "epoch": epoch,
        #         },
        #         step=epoch,
        #     )

        # Use accelerator.print to print only on the main process.
        accelerator.print(f"epoch {epoch}, validation metrics:", accuracy_score, f1_score)

    # accelerator.end_training()
    ########## run test eval
    accuracy_score, f1_score = evaluate(test_dataloader)

    # Report the index of the last epoch; computed from num_epochs so the line
    # also works when no epoch was run.
    accelerator.print(f"epoch {num_epochs - 1}, test metrics:", accuracy_score, f1_score)

### run training

In [19]:
from accelerate import notebook_launcher

notebook_launcher(training_function, (model,))

Launching training on one GPU.


  0%|          | 0/28200 [00:00<?, ?it/s]

epoch 0, validation metrics: {'accuracy': 0.256} {'f1': 0.256}
epoch 1, validation metrics: {'accuracy': 0.264} {'f1': 0.264}
epoch 2, validation metrics: {'accuracy': 0.284} {'f1': 0.284}


RuntimeError: ignored

In [20]:
2*512*2*384

786432

In [None]:
# import gc
# del model
# gc.collect()
# model = ContrastiveModel()

In [None]:
%tensorboard --logdir logs

In [None]:
def training_function(model):
    """Train ``model`` with a joint InfoNCE + classification loss and print metrics.

    Every step adds a symmetric InfoNCE loss between the anchor and replica
    embeddings to a cross-entropy loss on the anchor logits. Accuracy and
    micro-F1 are printed for the validation loader after each epoch and for the
    test loader after the last one.

    Args:
        model: module whose forward pass takes the tokenizer fields as keyword
            arguments, returns ``(embedding, logits)`` and exposes a
            ``temperature`` parameter.

    Reads the notebook-level ``hyperparameters``, ``create_dataloaders``,
    ``accuracy_metric`` and ``f1_metric``. Returns nothing.
    """
    # Initialize accelerator
    # NOTE(review): the notebook's `fp16` form field is not forwarded here —
    # confirm whether mixed precision was intended.
    accelerator = Accelerator()#log_with="all", logging_dir="logs")

    def infonce_loss(a, b, temp):
        """Symmetric InfoNCE: row i of ``a`` should match row i of ``b``.

        Returns ``(loss, accuracy)`` where accuracy is the fraction of rows and
        columns whose highest similarity is on the diagonal.
        """
        batch_size = a.shape[0]
        logits = (a @ b.T) * torch.exp(temp).clamp(max=100)
        labels = torch.arange(0, batch_size, device=logits.device)

        # cross_entropy already averages over the batch.
        loss = (F.cross_entropy(logits.T, labels) +
                F.cross_entropy(logits, labels)) / 2

        with torch.no_grad():
            # argmax of the raw logits equals argmax of their softmax.
            preds = logits.argmax(-1)
            preds_t = logits.T.argmax(-1)
            accuracy = (torch.sum(preds == labels) +
                        torch.sum(preds_t == labels)) / (batch_size * 2)

        return loss, accuracy

    def evaluate(dataloader):
        """Score the anchor logits of every batch; returns ``(accuracy, f1)`` result dicts."""
        model.eval()
        # Batches are (labels, anchor, replica) tuples, not dicts; only the
        # anchor is needed for classification metrics.
        for labels, anchor, _replica in dataloader:
            with torch.no_grad():
                _, logits_anchor = model(**anchor)
            predictions = logits_anchor.argmax(dim=-1)

            predictions, references = accelerator.gather_for_metrics((predictions, labels))
            accuracy_metric.add_batch(
                predictions=predictions,
                references=references,
            )
            f1_metric.add_batch(
                predictions=predictions,
                references=references,
            )
        return accuracy_metric.compute(), f1_metric.compute(average='micro')

    if accelerator.is_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    train_dataloader, eval_dataloader, test_dataloader = create_dataloaders(
        train_batch_size=hyperparameters["train_batch_size"], eval_batch_size=hyperparameters["eval_batch_size"]
    )
    set_seed(hyperparameters["seed"])
    optimizer = torch.optim.AdamW(params=model.parameters(), lr=hyperparameters["learning_rate"])

    # The loss function is a plain callable, so it is not passed through prepare().
    model, optimizer, train_dataloader, eval_dataloader, test_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader, test_dataloader
    )
    # prepare() may wrap the model (e.g. for distributed runs), which hides
    # custom attributes; read the temperature parameter from the unwrapped module.
    temperature = accelerator.unwrap_model(model).temperature

    num_epochs = hyperparameters["num_epochs"]

    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=100,
        num_training_steps=len(train_dataloader) * num_epochs,
    )

    # Instantiate a progress bar to keep track of training. Note that we only enable it on the main
    # process to avoid having 8 progress bars.
    progress_bar = tqdm(range(num_epochs * len(train_dataloader)), disable=not accelerator.is_main_process)
    # Now we train the model
    for epoch in range(num_epochs):
        model.train()
        # Running sums for the (currently disabled) tracker logging below.
        total_contrastive_loss = 0
        total_classification_loss = 0
        # The dataset yields (labels, anchor, replica) tuples, so unpack
        # positionally rather than indexing the batch by key.
        for labels, anchor, replica in train_dataloader:
            embed_anchor, logits_anchor = model(**anchor)
            embed_replica, _ = model(**replica)

            # infonce_loss returns (loss, accuracy); only the loss is optimised.
            contrastive_loss, _ = infonce_loss(embed_anchor, embed_replica, temperature)
            # Take the class count from the logits instead of hard-coding it.
            classification_loss = F.cross_entropy(
                logits_anchor.view(-1, logits_anchor.shape[-1]), labels.view(-1))
            loss = contrastive_loss + classification_loss

            accelerator.backward(loss)

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            progress_bar.update(1)
            total_contrastive_loss += contrastive_loss.detach().float()
            total_classification_loss += classification_loss.detach().float()

        accuracy_score, f1_score = evaluate(eval_dataloader)

        # accelerator.log(
        #         {
        #         "accuracy": accuracy_score["accuracy"],
        #         "f1": f1_score["f1"],
        #         "train_contrastive_loss": total_contrastive_loss.item() / len(train_dataloader),
        #         "train_classification_loss": total_classification_loss.item() / len(train_dataloader),
        #         "epoch": epoch,
        #         },
        #         step=epoch,
        #     )

        # Use accelerator.print to print only on the main process.
        accelerator.print(f"epoch {epoch}, validation metrics:", accuracy_score, f1_score)

    # accelerator.end_training()
    ########## run test eval
    accuracy_score, f1_score = evaluate(test_dataloader)

    # Report the index of the last epoch; computed from num_epochs so the line
    # also works when no epoch was run.
    accelerator.print(f"epoch {num_epochs - 1}, test metrics:", accuracy_score, f1_score)
