### training hyper-parameters

In [None]:
# Data loading / schedule
batch_size, number_of_epochs = 98, 10

# Optimizer settings.
# `momentum` is the (beta1, beta2) pair handed to the optimizer as `betas`.
learning_rate = 5e-5
momentum = (0.9, 0.98)
eps, weight_decay = 1e-6, 0.2

### load images and titles

In [None]:
import train

# Image source and title file as configured in the project-local `train` module.
image_path, texts_file = train.product_images, train.product_titles

texts_list = train.read_text(texts_file)
images, titles = train.get_image_title(image_path, texts_list)

# Sanity check: print both counts side by side.
print(f"{len(images)} | {len(titles)}")

### create and load dataset

In [None]:
# Wrap the loaded images and titles in the project's dataset class.
dataset = train.image_title_dataset(images, titles)
print(f"dataset size:  {len(dataset)}")

In [None]:
from torch.utils.data import DataLoader

# Reshuffled every epoch so batches mix different image/title pairs.
train_dataloader = DataLoader(
    dataset,
    shuffle=True,
    batch_size=batch_size,
)

### load model

In [None]:
import torch

# Re-use the model object and the `preprocess` callable exposed by `train`.
model, preprocess = train.model, train.preprocess

In [None]:
def convert_models_to_fp32(model):
    """Cast every parameter of ``model``, and its gradient if present, to float32 in place.

    Intended to be called right before ``optimizer.step()`` so the update is
    computed in full precision. The weights can be cast back to half precision
    afterwards.

    Args:
        model: a ``torch.nn.Module``; its parameters are modified in place.
    """
    for p in model.parameters():
        p.data = p.data.float()
        # Frozen parameters, or ones that took no part in the forward pass,
        # have ``grad is None``; ``p.grad.data`` would raise AttributeError.
        if p.grad is not None:
            p.grad.data = p.grad.data.float()

In [None]:
device = train.device

# `train.device` may be a plain string ("cpu"/"cuda") or a torch.device.
# A torch.device never compares equal to a str, so `device == "cpu"` would
# silently be False for torch.device("cpu"). Normalise before checking.
if torch.device(device).type == "cpu":
    # Training on CPU: cast the whole model to float32.
    model.float()

In [None]:
# These values (betas=(0.9, 0.98), eps=1e-6, weight_decay=0.2) match the CLIP
# paper's settings, which use *decoupled* weight decay, i.e. AdamW.
# torch.optim.Adam would instead add weight_decay * param to the gradient (an
# L2 penalty) before Adam's per-parameter rescaling. A coefficient of 0.2 then
# becomes a strong pull of every weight toward zero.
# Note: `momentum` holds the optimizer's (beta1, beta2) pair, not SGD momentum.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    betas=momentum,
    eps=eps,
    weight_decay=weight_decay,
)

In [None]:
import torch.nn as nn

# Two independent cross-entropy criteria: one for image->text logits,
# one for text->image logits.
loss_img, loss_txt = nn.CrossEntropyLoss(), nn.CrossEntropyLoss()

In [None]:
from tqdm import tqdm
import clip.model  # binds `clip`; needed for clip.model.convert_weights below

# Contrastive fine-tuning. In a batch of N (image, title) pairs, row i of each
# logits matrix is trained to pick column i, its own pair.
# NOTE(review): model.train() is never called. If train.model comes from
# clip.load it is in eval mode; confirm that this is intended.

# `device` may be a str or a torch.device; a torch.device never equals a str.
on_cpu = torch.device(device).type == "cpu"

for epoch in range(number_of_epochs):
    progress = tqdm(train_dataloader, total=len(train_dataloader))
    # Distinct names so the `images` list loaded earlier is not overwritten.
    for batch_images, batch_texts in progress:
        optimizer.zero_grad()

        batch_images = batch_images.to(device)
        batch_texts = batch_texts.to(device)

        logits_per_image, logits_per_text = model(batch_images, batch_texts)

        # Matching pairs lie on the diagonal: the target for row i is class i.
        ground_truth = torch.arange(len(batch_images), dtype=torch.long, device=device)
        total_loss = (
            loss_img(logits_per_image, ground_truth)
            + loss_txt(logits_per_text, ground_truth)
        ) / 2

        total_loss.backward()
        if on_cpu:
            optimizer.step()
        else:
            # Presumably half-precision weights on GPU. Take the step in fp32,
            # then cast the weights back with CLIP's convert_weights.
            convert_models_to_fp32(model)
            optimizer.step()
            clip.model.convert_weights(model)

        progress.set_description(f"Epoch {epoch+1}/{number_of_epochs}, Loss: {total_loss.item():.4f}")

In [None]:
import os

output_model = 'models/model_branch2.pt'
# torch.save does not create missing directories; make sure the target exists.
os.makedirs(os.path.dirname(output_model) or ".", exist_ok=True)
# Pickles the whole nn.Module (not just its state_dict). Loading it back
# requires the model's class definitions to be importable.
torch.save(model, output_model)