In [1]:
import os
import re
import json
import torch
from torch import nn
from PIL import Image
from ultralytics import YOLO
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
from safetensors.torch import save_file, load_file
from transformers.tokenization_utils_base import BatchEncoding
from transformers import PreTrainedModel, PretrainedConfig, AutoModelForSequenceClassification, AutoTokenizer, ViTForImageClassification, ViTImageProcessor, BatchEncoding


  from .autonotebook import tqdm as notebook_tqdm


## Config
Updated the config for the new model structure.

In [2]:
class MultimodalConfig(PretrainedConfig):
    """Configuration describing which sub-models make up the multimodal model.

    ``models`` maps a model identifier to the path its files are stored under.
    The identifier alone decides how a model is grouped: identifiers containing
    ``'prompt'`` are text models, every other identifier is an image model, and
    the ResNet / ViT / YOLO groups are selected by substring.

    Args:
        models: Mapping of model identifier -> model path. ``None`` is treated
            as an empty mapping.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "multimodal"

    def __init__(self,
                 models=None,
                 **kwargs):
        super().__init__(**kwargs)
        self.models = models if models is not None else {}
        self.image_models = [name for name in self.models if 'prompt' not in name]

        # Family matching is case-insensitive for all three image families.
        # MultimodalProcessor lower-cases identifiers before matching, so a
        # case-sensitive test here would let e.g. "ViTRater" be processed as a
        # ViT input without its model ever being listed for loading.
        self.resnet_model_paths = self._select_family('resnet')
        self.vit_model_paths = self._select_family('vit')
        self.yolo_model_paths = self._select_family('yolo')

        self.text_models = [name for name in self.models if 'prompt' in name]
        self.nlp_transformers_model_paths = {
            name: path for name, path in self.models.items() if 'prompt' in name
        }

        # Assigning id2label/label2id after super().__init__ overrides the
        # defaults the base class installed.
        self.class_names = ["PG", "PG13", "R", "X", "XXX"]
        self.label2id = {label: i for i, label in enumerate(self.class_names)}
        self.id2label = {i: label for label, i in self.label2id.items()}

    def _select_family(self, family):
        """Return the ``{identifier: path}`` entries whose identifier contains ``family`` (case-insensitive)."""
        return {name: path for name, path in self.models.items() if family in name.lower()}


In [3]:
# Map every entry under ./models (except the multimodal and yoloRater ones) to
# its path. os.listdir returns entries in arbitrary order, and the dict order
# decides the order in which sub-model features are concatenated for the MLP,
# so the listing is sorted to keep that layout stable across machines and runs.
model_paths = {
    name: os.path.join('./models/', name)
    for name in sorted(os.listdir('models/'))
    if 'multimodal' not in name and 'yoloRater' not in name
}
print(model_paths)

{'baseresNet18': './models/baseresNet18', 'baseresNet50': './models/baseresNet50', 'june_yolo_cls': './models/june_yolo_cls', 'promptBert': './models/promptBert', 'promptRoberta': './models/promptRoberta', 'promptTagBert': './models/promptTagBert', 'resNet18CV': './models/resNet18CV', 'resNet50CV': './models/resNet50CV', 'vitRater': './models/vitRater'}


In [4]:
# Build the shared configuration from the model paths discovered above.
config = MultimodalConfig(models=model_paths)

## Dataset

Fixed the model paths and the dataset so that they work with the PyTorch DataLoader and the Hugging Face Trainer API.

In [5]:
class CustomImageTextDataset(Dataset):
    """Image/text dataset laid out as ``data_dir/<label>/<stem>.jpg``.

    Every ``.jpg`` file is one sample. Optional companion files share its stem:
    ``<stem>.txt`` holds the text and ``<stem>_tags.txt`` holds the tags. The
    name of the sub-directory a sample lives in is its label.

    Args:
        data_dir: Root directory containing one sub-directory per label.
        target_transform: Optional mapping, indexed as ``target_transform[label]``,
            applied to the directory-name label. ``int`` results are wrapped in
            a ``torch.Tensor``.
    """

    def __init__(self, data_dir, target_transform=None):
        self.data_dir = data_dir
        self.files = self.get_files()
        self.labels = self.get_files(True)
        self.target_transform = target_transform

    def get_files(self, labels = False):
        """List the samples found under ``data_dir``.

        Args:
            labels: When ``False`` return, per sample, the image path without
                its ``.jpg`` extension; when ``True`` return, per sample, the
                name of the label directory it was found in.

        Returns:
            A list with one entry per ``.jpg`` file. Both listings are sorted,
            so the two variants are aligned index by index no matter what order
            the file system reports entries in.
        """
        files = []
        for label_dir in sorted(os.listdir(self.data_dir)):
            class_path = os.path.join(self.data_dir, label_dir)
            # Stray files next to the label directories are not classes.
            if not os.path.isdir(class_path):
                continue
            for file_name in sorted(os.listdir(class_path)):
                if not file_name.endswith(".jpg"):
                    continue
                if labels:
                    files.append(label_dir)
                else:
                    # Strip only the trailing extension so stems that contain
                    # dots (e.g. "a.b.jpg") are kept whole.
                    stem = file_name[:-len(".jpg")]
                    files.append(os.path.join(class_path, stem))
        return files

    @staticmethod
    def _read_optional(path, default):
        """Return the contents of ``path``, or ``default`` if it cannot be read."""
        try:
            with open(path, 'r') as handle:
                return handle.read()
        except (OSError, UnicodeDecodeError):
            # Companion files are optional; a missing or undecodable one falls
            # back to the default instead of failing the whole sample.
            return default

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(img_path, text, tags, label)`` for sample ``idx``.

        ``img_path`` is a path string (the image is not opened here), ``text``
        is ``''`` when the ``.txt`` file is unavailable and ``tags`` is ``None``
        when the ``_tags.txt`` file is unavailable.
        """
        stem = self.files[idx]
        img_path = stem + '.jpg'
        text = self._read_optional(stem + '.txt', '')
        tags = self._read_optional(stem + '_tags.txt', None)
        label = self.labels[idx]
        if self.target_transform:
            label = self.target_transform[label]
            if isinstance(label, int):
                label = torch.tensor(label)
        return img_path, text, tags, label

    def __repr__(self):
        string = f"CustomImageTextDataset(num_samples={len(self.files)}, data_dir='{self.data_dir}')\n"
        # Sorted so the representation is stable between runs.
        return string + f"Unique Labels: {sorted(set(self.labels))} for phase: {os.path.basename(self.data_dir)}"
    

In [6]:
# Root directories of the training and validation splits (one sub-directory per label).
datadir = "./data/datasets/mayWithTags/train"
validir = "./data/datasets/mayWithTags/val"

In [7]:
# Labels are mapped to integer class ids through config.label2id.
training_data = CustomImageTextDataset(datadir, target_transform=config.label2id)
# The validation split must read from validir; building it from datadir would
# evaluate the model on the samples it was trained on.
validation_data = CustomImageTextDataset(validir, target_transform=config.label2id)

In [8]:
# training_data[1000]

## Processor

Created a processor to simplify both the dataset and the model code.

The processor now works like Hugging Face processors: it takes an image path and text from the dataset and, depending on the models listed in the config/processor, processes each modality and serves the result to the matching model.

In [9]:
class MultimodalProcessor:
    """Turns one dataset sample into per-model inputs.

    ``models`` maps a model identifier to the path its files are stored under.
    Identifiers containing ``'prompt'`` are treated as text models and get a
    tokenizer; every other identifier is an image model and gets an image
    processor chosen by substring (``vit``, ``resnet``, ``yolo`` + ``cls``).

    Args:
        models: Mapping of model identifier -> model path. ``None`` is treated
            as an empty mapping.
    """

    def __init__(self, models = None):
        self.models = models if models is not None else {}
        self.image_models = [i for i in self.models.keys() if 'prompt' not in i]
        self.text_models = [i for i in self.models.keys() if 'prompt' in i]
        self.image_processors = self.load_image_processors()
        self.text_processors = self.load_text_processors()

    @staticmethod
    def clean_text(text: str) -> str:
        """Normalise a prompt: drop bracket/colon characters, collapse whitespace, tidy commas."""
        text = str(text)
        cleaned_text = re.sub(r"[():<>[\]]", " ", text)
        cleaned_text = cleaned_text.replace("\n", " ")
        cleaned_text = re.sub(r"\s+", " ", cleaned_text)
        cleaned_text = re.sub(r"\s*,\s*", ", ", cleaned_text)
        return cleaned_text.strip()

    @staticmethod
    def _is_yolo_cls(model_name: str) -> bool:
        """Return True when the identifier names a YOLO classification model.

        Both substrings are tested explicitly: ``'yolo' and 'cls' in name``
        evaluates to just ``'cls' in name`` because a non-empty string literal
        is always truthy.
        """
        lowered = model_name.lower()
        return 'yolo' in lowered and 'cls' in lowered

    def load_image_processors(self):
        """Use either transforms native to the model architecture or the vit image processor
        NOTE: if you change the transforms an image undergoes during training, you should change them here as well

        Returns:
            Dict keyed by model identifier, except that all ResNet models share
            the single ``'resnet'`` entry. Models that are neither ViT, ResNet
            nor YOLO classifiers map to ``None``.
        """
        image_processors = {}
        for model_name in self.image_models:
            lowered = model_name.lower()
            if 'vit' in lowered:
                try:
                    processor = ViTImageProcessor.from_pretrained(self.models[model_name])
                    image_processors[model_name] = processor
                except Exception as e:
                    print(f"Error loading {model_name}: {e}")
            elif 'resnet' in lowered:
                # Every ResNet uses the same transform, so it is built once.
                if 'resnet' not in image_processors:
                    image_processors['resnet'] = transforms.Compose([
                        transforms.Resize(256),
                        transforms.CenterCrop(224),
                        transforms.ToTensor(),
                        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                    ])
            elif self._is_yolo_cls(model_name):
                image_processors[model_name] = YOLO(os.path.join(self.models[model_name], 'best_model_params.pt')).transforms
            else:
                #We expect DET Yolos to be None, as they just
                #take in the PIL.Image and return the output
                image_processors[model_name] = None
        return image_processors

    def load_text_processors(self):
        """Load one tokenizer per text model; a failed load is reported and skipped."""
        text_processors = {}
        for model_name in self.text_models:
            try:
                tokenizer = AutoTokenizer.from_pretrained(self.models[model_name])
                text_processors[model_name] = tokenizer
            except Exception as e:
                print(f"Error loading {model_name}: {e}")
        return text_processors

    def __call__(self, img_path, text, tags, label):
        """Process one sample for every configured model.

        Args:
            img_path: Path of the image file.
            text: Prompt text.
            tags: Optional tag text; appended to ``text`` for text models whose
                identifier contains ``'tag'``.
            label: Passed through unchanged under the ``'label'`` key.

        Returns:
            Dict keyed by model identifier holding that model's input, plus
            ``'label'``.
        """
        inputs = {}
        # The context manager closes the file handle. Converting to RGB gives
        # every image three channels, which the 3-channel Normalize in the
        # ResNet transform requires (greyscale or RGBA files would not have them).
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")

        for model in self.image_models:
            lowered = model.lower()
            if 'vit' in lowered:
                inputs[model] = self.image_processors[model](img,
                        return_tensors="pt")['pixel_values']

            elif 'resnet' in lowered:
                inputs[model] = self.image_processors['resnet'](img).unsqueeze(0)

            elif self._is_yolo_cls(model):
                inputs[model] = self.image_processors[model](img)
            else:
                # No processor for this model: hand over the path itself.
                inputs[model] = img_path

        for model in self.text_models:
            if 'tag' in model.lower() and tags:
                text_in = ' '.join([text, tags])
            else:
                text_in = text

            text_in = self.clean_text(text_in)
            inputs[model] = self.text_processors[model](text_in, return_tensors="pt", truncation=True, padding='max_length')

        inputs['label'] = label
        return inputs

In [10]:
# Module-level processor instance; custom_collator looks this name up at call time.
processor = MultimodalProcessor(models = model_paths)

In [11]:
# Run the processor on the first training sample; the returned dict is the
# cell's displayed output.
first_sample = training_data[0]
processor(img_path=first_sample[0],
          text=first_sample[1],
          tags=first_sample[2],
          label=first_sample[3])

{'baseresNet18': tensor([[[[ 1.9235,  1.9578,  1.6667,  ..., -1.8268, -1.8097, -1.8097],
           [ 2.1119,  2.1975,  2.1633,  ..., -1.8097, -1.8097, -1.8097],
           [ 2.1804,  2.1290,  1.9749,  ..., -1.8268, -1.8268, -1.8268],
           ...,
           [-1.0048, -0.9877, -0.9705,  ..., -1.4500, -1.3302, -1.2274],
           [-1.0562, -1.0562, -1.0733,  ..., -1.6898, -1.6727, -1.6042],
           [-1.1932, -1.1760, -1.2103,  ..., -1.6898, -1.7069, -1.7240]],
 
          [[ 1.2556,  1.3431,  1.0630,  ..., -1.9132, -1.8957, -1.8957],
           [ 1.5007,  1.6408,  1.5707,  ..., -1.8957, -1.8957, -1.8957],
           [ 1.6408,  1.6057,  1.4832,  ..., -1.9132, -1.9132, -1.9132],
           ...,
           [-1.2479, -1.2304, -1.1954,  ..., -1.6331, -1.5105, -1.4055],
           [-1.3004, -1.2829, -1.2654,  ..., -1.8256, -1.8081, -1.7381],
           [-1.4230, -1.4055, -1.3880,  ..., -1.7731, -1.7906, -1.8081]],
 
          [[ 0.6531,  0.7228,  0.4265,  ..., -1.7173, -1.6999, -1.6999

## Dataloader with collator

In [12]:
def custom_collator(batch):
    """Collate dataset samples into one batch of per-model inputs.

    Each sample's first four fields (image path, text, tags, label) are run
    through the module-level ``processor``. The resulting dicts are then merged
    key by key, using the first sample to decide how each key is combined:
    tensors are stacked, ``BatchEncoding`` values become a plain dict of
    stacked tensors, and anything else is gathered into a list.
    """
    samples = [(item[0], item[1], item[2], item[3]) for item in batch]

    # `processor` is resolved from module scope when the collator runs.
    processed = [processor(path, text, tag, target) for path, text, tag, target in samples]

    collated = {}
    for name, example in processed[0].items():
        column = [entry[name] for entry in processed]
        if isinstance(example, torch.Tensor):
            collated[name] = torch.stack(column)
        elif isinstance(example, BatchEncoding):
            collated[name] = {
                field: torch.stack([encoding[field] for encoding in column])
                for field in example.keys()
            }
        else:
            collated[name] = column

    return collated


In [13]:
# Training batches are reshuffled every epoch. Validation batches keep the
# dataset order so that evaluation runs are directly comparable.
train_dataloader = DataLoader(training_data, batch_size=5, shuffle=True, collate_fn=custom_collator)
val_dataloader = DataLoader(validation_data, batch_size=5, shuffle=False, collate_fn=custom_collator)

In [14]:
# Pull a single batch through the collator to inspect its contents and keys,
# then stop iterating.
for batch in train_dataloader:
    print(batch)
    print(batch.keys())
    break

{'baseresNet18': tensor([[[[[-1.5014, -1.4672, -1.3987,  ..., -1.7754, -1.8439, -1.5870],
           [-1.5014, -1.4672, -1.4500,  ..., -1.7754, -1.8439, -1.5870],
           [-1.5185, -1.4843, -1.4843,  ..., -1.7754, -1.8439, -1.5870],
           ...,
           [-1.8097, -1.7583, -1.1760,  ..., -0.9534, -1.0562, -1.2959],
           [-1.8268, -1.7754, -1.2788,  ..., -0.9534, -1.1932, -1.5699],
           [-1.8610, -1.7754, -1.4158,  ..., -1.0219, -1.3987, -1.7754]],

          [[-1.1779, -1.1429, -1.1429,  ..., -1.6331, -1.7556, -1.5280],
           [-1.1779, -1.1604, -1.1954,  ..., -1.6331, -1.7556, -1.5280],
           [-1.1954, -1.1779, -1.2129,  ..., -1.6331, -1.7556, -1.5280],
           ...,
           [-1.8081, -1.8606, -1.4405,  ..., -1.2479, -1.2829, -1.4405],
           [-1.8256, -1.8606, -1.5630,  ..., -1.2129, -1.3704, -1.6506],
           [-1.8606, -1.8606, -1.7031,  ..., -1.2479, -1.5455, -1.8081]],

          [[-0.8458, -0.8284, -0.8633,  ..., -1.2990, -1.4733, -1.2816]

## Model

In [15]:
class MultimodalModel(PreTrainedModel):
    """Late-fusion classifier over frozen image and text backbones.

    Every backbone listed in the config is loaded, frozen
    (``requires_grad = False``) and used as a feature extractor. The features
    are concatenated and classified by a trainable MLP.

    Args:
        config: ``MultimodalConfig`` listing the backbone paths per family.
        device: Device the backbones are moved to at construction time.
    """

    config_class = MultimodalConfig

    def __init__(self, config: MultimodalConfig, device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')):
        super().__init__(config)
        # Output size of the fc head created on each ResNet before its
        # checkpoint is loaded.
        self.num_classes = 5
        # model identifier -> feature width it contributes (None = no features).
        self.features = {}

        # Initialize the models and set requires_grad = False to freeze them
        self.resnet_models = nn.ModuleDict()
        for model_id, model_path in config.resnet_model_paths.items():
            ##Load model architecture
            model = models.resnet50() if "resnet50" in model_id.lower() else models.resnet18()
            feature_dim = model.fc.in_features
            # The head must match the checkpoint's shape for load_state_dict.
            model.fc = nn.Linear(feature_dim, self.num_classes)
            ##Load weights
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from a trusted source.
            model.load_state_dict(torch.load(f"{model_path}/best_model_params.pt", map_location='cpu'))
            ##set output to identity to return layer before prediction
            model.fc = nn.Identity()
            self._freeze(model)
            # Freshly built torchvision models start in training mode.
            model.eval()
            model.to(device)
            self.resnet_models[model_id] = model
            ##store number of features
            self.features[model_id] = feature_dim

        self.vit_models = nn.ModuleDict()
        for model_id, model_path in config.vit_model_paths.items():
            model = ViTForImageClassification.from_pretrained(model_path)
            self._freeze(model)
            model.eval()
            model.to(device)
            self.vit_models[model_id] = model
            self.features[model_id] = model.classifier.in_features

        self.yolo_models = nn.ModuleDict()
        for model_id, model_path in config.yolo_model_paths.items():
            lowered_id = model_id.lower()
            if 'cls' in lowered_id:
                model = YOLO(os.path.join(model_path, 'best_model_params.pt'))
                sequential_model = model.model.model
                self.features[model_id] = sequential_model[-1].linear.in_features

                self.replace_yolo_last_linear_with_identity(sequential_model)
                self._freeze(sequential_model)
                model.to(device)
                self.yolo_models[model_id] = sequential_model

            elif 'det' in lowered_id:
                model = YOLO(os.path.join(model_path, 'best_model_params.pt'))
                model.to(device)
                self.yolo_models[model_id] = model
                # Detection models register no feature width.
                self.features[model_id] = None

        self.nlp_transformers_models = nn.ModuleDict()
        for model_id, model_path in config.nlp_transformers_model_paths.items():
            model = AutoModelForSequenceClassification.from_pretrained(model_path)
            self._freeze(model)
            model.eval()
            model.to(device)
            self.nlp_transformers_models[model_id] = model
            self.features[model_id] = self.get_classifier_in_features(model)

        # Entries without a feature width (None) contribute nothing to the
        # fused vector; summing them directly would raise a TypeError.
        fused_dim = sum(dim for dim in self.features.values() if dim is not None)

        # MLP layers (only this part will be fine-tuned)
        self.mlp = nn.Sequential(
            nn.Linear(fused_dim, 1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, config.num_labels)
        )
        self.loss_fn = nn.CrossEntropyLoss()

    @staticmethod
    def _freeze(module):
        """Disable gradients for every parameter of ``module``."""
        for param in module.parameters():
            param.requires_grad = False

    def train(self, mode: bool = True):
        """Switch the trainable MLP between modes while keeping frozen backbones in eval mode.

        ``requires_grad = False`` and ``torch.no_grad()`` do not stop BatchNorm
        from updating its running statistics or dropout from being applied when
        a module is in training mode, so the ResNet / ViT / text backbones are
        put back into eval mode after every mode switch.
        """
        super().train(mode)
        # NOTE(review): YOLO modules keep following the wrapper's mode, as
        # before; their head's output may depend on the training flag - confirm
        # against the installed ultralytics version before forcing eval here.
        for group_name in ("resnet_models", "vit_models", "nlp_transformers_models"):
            group = getattr(self, group_name, None)
            if group is None:
                continue
            for backbone in group.values():
                backbone.eval()
        return self

    def get_classifier_in_features(self, model):
        """
        Function to get the number of input features to the classifier layer.
        Handles different model structures such as BERT and RoBERTa.

        The candidate attribute paths are tried in order and the first one that
        exists and has ``in_features`` wins.

        Raises:
            ValueError: If none of the known head layouts is present.
        """
        candidate_paths = (
            ("classifier", "out_proj"),  # RoBERTa's classifier
            ("classifier", "linear"),    # head exposing its projection as `.linear`
            ("classifier",),             # head that is itself a linear layer (e.g. ViT)
            ("fc",),                     # models whose head is called `fc`
        )
        for path in candidate_paths:
            target = model
            try:
                for attr in path:
                    target = getattr(target, attr)
                return target.in_features
            except AttributeError:
                continue

        # Add other model types as needed
        raise ValueError("Unknown model structure or unsupported model type.")

    def forward(self, **inputs):
        """Extract features with every backbone, fuse them and classify.

        Args:
            **inputs: One entry per model identifier holding that model's
                input, plus an optional ``'label'`` entry with the targets.
                Features are concatenated in the order the keys appear, so the
                key order must be the same at training and inference time.

        Returns:
            ``{"loss": ..., "logits": ...}`` when labels are given, otherwise
            ``{"logits": ...}``.
        """
        self.mlp.to(self.device)
        features = []
        labels = inputs.pop('label', None)

        for key, model_input in inputs.items():
            if key in self.resnet_models:
                with torch.no_grad():
                    # squeeze(1) drops the singleton dimension each sample
                    # carries at position 1 after stacking.
                    resnet_features = self.resnet_models[key](model_input.squeeze(1).to(self.device))
                features.append(resnet_features)

            if key in self.vit_models:
                with torch.no_grad():
                    vit_features = self.vit_models[key](pixel_values=model_input.squeeze(1).to(self.device),
                                            output_hidden_states=True).hidden_states[-1][:, 0, :]
                features.append(vit_features)

            # NOTE(review): models registered without a feature width
            # (detection YOLOs) have no slot in the MLP input and are skipped
            # here - confirm how they are meant to contribute.
            if key in self.yolo_models and self.features.get(key) is not None:
                with torch.no_grad():
                    yolo_features = self.yolo_models[key](model_input.to(self.device))
                features.append(yolo_features)

            if key in self.nlp_transformers_models:
                with torch.no_grad():
                    nlp_features = self.nlp_transformers_models[key](
                        input_ids=model_input['input_ids'].squeeze(1).to(self.device),
                        attention_mask=model_input['attention_mask'].squeeze(1).to(self.device),
                        output_hidden_states=True).hidden_states[-1][:, 0, :]
                features.append(nlp_features)

        ##Concat features
        features = torch.cat(features, dim=1)
        logits = self.mlp(features)

        loss = None
        if labels is not None:
            # Labels coming straight from a DataLoader may still be on the CPU.
            targets = labels.view(-1).to(logits.device)
            loss = self.loss_fn(logits.view(-1, self.config.num_labels), targets)

        return {"loss": loss, "logits": logits} if loss is not None else {"logits": logits}

    @staticmethod
    def replace_yolo_last_linear_with_identity(model):
        """Replace the last ``nn.Linear`` found in ``model`` with ``nn.Identity``.

        Modules are visited in ``named_modules`` order and the last linear
        layer encountered is the one replaced.

        Returns:
            ``True`` if a layer was replaced, ``False`` if ``model`` contains no
            replaceable linear layer (none at all, or ``model`` itself is one).
        """
        last_linear_name = None
        for name, module in model.named_modules():
            if isinstance(module, nn.Linear):
                last_linear_name = name
        # An empty name is the root module, which cannot be swapped in place.
        if not last_linear_name:
            return False

        # rpartition also handles a direct child of `model` (no dot in its name).
        parent_name, _, child_name = last_linear_name.rpartition('.')
        parent = model
        if parent_name:
            for part in parent_name.split('.'):
                parent = getattr(parent, part)
        setattr(parent, child_name, nn.Identity())
        return True

In [16]:
# Instantiate the multimodal model; this loads every backbone listed in the config.
model = MultimodalModel(config)

## Check Training

In [17]:
# Optimizer over the MLP parameters only, plus a standalone loss function.
# NOTE(review): neither object is passed to the Trainer constructed later in
# this notebook, and the model computes its own loss in forward(); these look
# unused - confirm before relying on them.
optimizer = torch.optim.AdamW(model.mlp.parameters(), lr=5e-5)
loss_fn = nn.CrossEntropyLoss()

In [18]:
from transformers import TrainingArguments
from transformers import Trainer


In [19]:

training_args = TrainingArguments(
    output_dir='./results',          # output directory
    num_train_epochs=2,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=8,
    warmup_steps=500,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
    logging_steps=10,
    evaluation_strategy="steps",
    eval_steps=500,
    save_steps=500,
    load_best_model_at_end=True,
    # Mixed precision needs a CUDA device; fall back to full precision on CPU
    # instead of failing when the arguments are validated.
    fp16=torch.cuda.is_available(),
    # The collated batches carry their targets under "label" and the model's
    # forward takes **inputs, so the Trainer cannot infer the label key. Naming
    # it lets evaluation compute eval_loss, which load_best_model_at_end needs.
    label_names=["label"],
)




In [20]:

# Hand the model, datasets and arguments to the Hugging Face Trainer.
# custom_collator turns raw dataset samples into per-model inputs.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=training_data,
    eval_dataset=validation_data,
    data_collator=custom_collator
)


In [21]:
# Start training; progress and loss are logged every `logging_steps` steps.
trainer.train()

  context_layer = torch.nn.functional.scaled_dot_product_attention(
Could not estimate the number of tokens of the input, floating-point operations will not be computed
  1%|          | 10/1336 [00:20<43:54,  1.99s/it] 

{'loss': 1.6145, 'grad_norm': 4.832340717315674, 'learning_rate': 1.0000000000000002e-06, 'epoch': 0.01}


  1%|▏         | 20/1336 [00:40<44:54,  2.05s/it]

{'loss': 1.616, 'grad_norm': 5.274162769317627, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.03}


  2%|▏         | 30/1336 [01:00<43:38,  2.00s/it]

{'loss': 1.6158, 'grad_norm': 5.208268642425537, 'learning_rate': 3e-06, 'epoch': 0.04}


  3%|▎         | 40/1336 [01:20<41:56,  1.94s/it]

{'loss': 1.5823, 'grad_norm': 5.470309734344482, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.06}


  4%|▎         | 50/1336 [01:40<43:56,  2.05s/it]

{'loss': 1.5534, 'grad_norm': 4.865787506103516, 'learning_rate': 5e-06, 'epoch': 0.07}


  4%|▍         | 60/1336 [02:00<42:18,  1.99s/it]

{'loss': 1.5299, 'grad_norm': 4.9501800537109375, 'learning_rate': 6e-06, 'epoch': 0.09}


  5%|▌         | 70/1336 [02:21<42:17,  2.00s/it]

{'loss': 1.4878, 'grad_norm': 4.886083126068115, 'learning_rate': 7.000000000000001e-06, 'epoch': 0.1}


  6%|▌         | 80/1336 [02:41<42:57,  2.05s/it]

{'loss': 1.4496, 'grad_norm': 4.996936798095703, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.12}


  7%|▋         | 90/1336 [03:01<41:11,  1.98s/it]

{'loss': 1.391, 'grad_norm': 5.013386249542236, 'learning_rate': 9e-06, 'epoch': 0.13}


  7%|▋         | 100/1336 [03:21<41:27,  2.01s/it]

{'loss': 1.3284, 'grad_norm': 4.966161251068115, 'learning_rate': 1e-05, 'epoch': 0.15}


  8%|▊         | 110/1336 [03:41<39:34,  1.94s/it]

{'loss': 1.211, 'grad_norm': 5.209980010986328, 'learning_rate': 1.1000000000000001e-05, 'epoch': 0.16}


  9%|▉         | 120/1336 [04:01<40:26,  2.00s/it]

{'loss': 1.1755, 'grad_norm': 5.323472499847412, 'learning_rate': 1.2e-05, 'epoch': 0.18}


 10%|▉         | 130/1336 [04:21<40:11,  2.00s/it]

{'loss': 1.0784, 'grad_norm': 5.125816822052002, 'learning_rate': 1.3000000000000001e-05, 'epoch': 0.19}


 10%|█         | 140/1336 [04:41<41:41,  2.09s/it]

{'loss': 0.9456, 'grad_norm': 5.164208889007568, 'learning_rate': 1.4000000000000001e-05, 'epoch': 0.21}


 11%|█         | 150/1336 [05:01<39:27,  2.00s/it]

{'loss': 0.8252, 'grad_norm': 5.024830341339111, 'learning_rate': 1.5e-05, 'epoch': 0.22}


 12%|█▏        | 160/1336 [05:21<38:46,  1.98s/it]

{'loss': 0.7795, 'grad_norm': 5.60923433303833, 'learning_rate': 1.6000000000000003e-05, 'epoch': 0.24}


 13%|█▎        | 170/1336 [05:41<39:03,  2.01s/it]

{'loss': 0.6669, 'grad_norm': 4.881414890289307, 'learning_rate': 1.7000000000000003e-05, 'epoch': 0.25}


 13%|█▎        | 180/1336 [06:01<37:49,  1.96s/it]

{'loss': 0.5912, 'grad_norm': 4.509869575500488, 'learning_rate': 1.8e-05, 'epoch': 0.27}


 14%|█▍        | 190/1336 [06:22<42:09,  2.21s/it]

{'loss': 0.4862, 'grad_norm': 3.9333655834198, 'learning_rate': 1.9e-05, 'epoch': 0.28}


 15%|█▍        | 200/1336 [06:42<38:12,  2.02s/it]

{'loss': 0.4237, 'grad_norm': 4.184728622436523, 'learning_rate': 2e-05, 'epoch': 0.3}


 16%|█▌        | 210/1336 [07:03<38:08,  2.03s/it]

{'loss': 0.3665, 'grad_norm': 4.309677600860596, 'learning_rate': 2.1e-05, 'epoch': 0.31}


 16%|█▋        | 220/1336 [07:23<38:24,  2.06s/it]

{'loss': 0.322, 'grad_norm': 4.359551429748535, 'learning_rate': 2.2000000000000003e-05, 'epoch': 0.33}


 17%|█▋        | 230/1336 [07:43<36:26,  1.98s/it]

{'loss': 0.3236, 'grad_norm': 5.184726715087891, 'learning_rate': 2.3000000000000003e-05, 'epoch': 0.34}


 18%|█▊        | 240/1336 [08:03<36:30,  2.00s/it]

{'loss': 0.2677, 'grad_norm': 2.7118818759918213, 'learning_rate': 2.4e-05, 'epoch': 0.36}


 19%|█▊        | 250/1336 [08:23<36:18,  2.01s/it]

{'loss': 0.2754, 'grad_norm': 5.7620368003845215, 'learning_rate': 2.5e-05, 'epoch': 0.37}


 19%|█▉        | 260/1336 [08:43<36:10,  2.02s/it]

{'loss': 0.2473, 'grad_norm': 4.131642818450928, 'learning_rate': 2.6000000000000002e-05, 'epoch': 0.39}


 20%|██        | 270/1336 [09:05<38:48,  2.18s/it]

{'loss': 0.2179, 'grad_norm': 2.9357075691223145, 'learning_rate': 2.7000000000000002e-05, 'epoch': 0.4}


 21%|██        | 280/1336 [09:25<35:11,  2.00s/it]

{'loss': 0.2486, 'grad_norm': 4.413787841796875, 'learning_rate': 2.8000000000000003e-05, 'epoch': 0.42}


 22%|██▏       | 290/1336 [09:45<34:17,  1.97s/it]

{'loss': 0.312, 'grad_norm': 6.011092185974121, 'learning_rate': 2.9e-05, 'epoch': 0.43}


 22%|██▏       | 300/1336 [10:04<34:10,  1.98s/it]

{'loss': 0.1729, 'grad_norm': 6.2944488525390625, 'learning_rate': 3e-05, 'epoch': 0.45}


 23%|██▎       | 310/1336 [10:24<33:05,  1.93s/it]

{'loss': 0.2062, 'grad_norm': 1.9213930368423462, 'learning_rate': 3.1e-05, 'epoch': 0.46}


 24%|██▍       | 320/1336 [10:44<33:40,  1.99s/it]

{'loss': 0.2548, 'grad_norm': 4.054350852966309, 'learning_rate': 3.2000000000000005e-05, 'epoch': 0.48}


 25%|██▍       | 330/1336 [11:03<32:32,  1.94s/it]

{'loss': 0.1902, 'grad_norm': 4.364256381988525, 'learning_rate': 3.3e-05, 'epoch': 0.49}


 25%|██▌       | 340/1336 [11:23<31:55,  1.92s/it]

{'loss': 0.301, 'grad_norm': 4.701596260070801, 'learning_rate': 3.4000000000000007e-05, 'epoch': 0.51}


 26%|██▌       | 350/1336 [11:42<32:15,  1.96s/it]

{'loss': 0.3183, 'grad_norm': 4.779381275177002, 'learning_rate': 3.5e-05, 'epoch': 0.52}


 27%|██▋       | 360/1336 [12:03<32:56,  2.02s/it]

{'loss': 0.2639, 'grad_norm': 3.8912546634674072, 'learning_rate': 3.6e-05, 'epoch': 0.54}


 28%|██▊       | 370/1336 [12:23<32:57,  2.05s/it]

{'loss': 0.2608, 'grad_norm': 6.6680498123168945, 'learning_rate': 3.7e-05, 'epoch': 0.55}


 28%|██▊       | 380/1336 [12:44<32:40,  2.05s/it]

{'loss': 0.2448, 'grad_norm': 5.372076511383057, 'learning_rate': 3.8e-05, 'epoch': 0.57}


 29%|██▉       | 390/1336 [13:03<31:13,  1.98s/it]

{'loss': 0.1716, 'grad_norm': 1.7114732265472412, 'learning_rate': 3.9000000000000006e-05, 'epoch': 0.58}


 30%|██▉       | 400/1336 [13:23<31:27,  2.02s/it]

{'loss': 0.179, 'grad_norm': 4.860316753387451, 'learning_rate': 4e-05, 'epoch': 0.6}


 31%|███       | 408/1336 [13:40<30:54,  2.00s/it]