## train with huggingface dataset

In [None]:
# Fine-tune on the example dataset `ydshieh/coco_dataset_script`, which is hosted on the Hugging Face Hub.

import os

import towhee

# step 1
# Get the operator. `modality` has no effect on training; it only selects the inference branch.
clip_op = towhee.ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='image').get_op()


# step 2
# Trainer configuration. These are Hugging Face-style training parameters.
# `output_dir` is defined once so that the checkpoint loaded in step 4 comes from the
# same directory the trainer writes to.
output_dir = '/path_to_your_save'

data_args = {
    'dataset_name': 'ydshieh/coco_dataset_script',
    'dataset_config_name': '2017',
    'cache_dir': './cache',
    'max_seq_length': 77,
    'data_dir': '/path_to_your_data',
    'image_mean': [0.48145466, 0.4578275, 0.40821073],
    'image_std': [0.26862954, 0.26130258, 0.27577711],
}
training_args = {
    'num_train_epochs': 150,  # increase the number of epochs to get a better metric
    'per_device_train_batch_size': 8,
    'per_device_eval_batch_size': 8,
    'do_train': True,
    'do_eval': True,
    'remove_unused_columns': False,
    'dataloader_drop_last': True,
    'output_dir': output_dir,
    'overwrite_output_dir': True,
}

# step 3
# Train the model.
clip_op.train(data_args=data_args, training_args=training_args)


# step 4
# Load the trained checkpoint.
# NOTE: assumes checkpoints are saved as `<output_dir>/checkpoint-<step>/pytorch_model.bin`;
# replace 6500 with the step of the checkpoint you actually want to load.
checkpoint_path = os.path.join(output_dir, 'checkpoint-6500', 'pytorch_model.bin')
clip_op = towhee.ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='image', checkpoint_path=checkpoint_path).get_op()



## train with customized dataset
Let's assume there is a dataset like this.

|caption ID|image ID | caption   |  image  | image path|
|:--------|:-------- |:----------|:--------|:----------|
| 0 | 0 | a woman is smiling in the car.|  <img src="images/image1.png" max-width="100" width="200" height="200"/>| /tmp/dataset/image1.png |
| 1 | 1 | a kitten and a puppy are sitting on the grass. |  <img src="images/image2.png" max-width="100" width="200" height="200"/>| /tmp/dataset/image2.png  |
| 2 | 1 | a cat is watching at a dog on the grass. |  <img src="images/image2.png" max-width="100" width="200" height="200"/>| /tmp/dataset/image2.png  |
| 3 | 2 | two kids are playing in the colorful balloons.|  <img src="images/image3.png" max-width="100" width="200" height="200"/>| /tmp/dataset/image3.png  |
| 4 | 3 | a tiger is running in the snow.|  <img src="images/image4.png" max-width="100" width="200" height="200"/>| /tmp/dataset/image4.png  |

We need to prepare a JSON file that describes this dataset in the format shown below.
```json
[
   {
      "caption_id":0,
      "image_id":0,
      "caption":"a woman is smiling in the car.",
      "image_path":"/tmp/dataset/image1.png"
   },
   {
      "caption_id":1,
      "image_id":1,
      "caption":"a kitten and a puppy are sitting on the grass.",
      "image_path":"/tmp/dataset/image2.png"
   },
   {
      "caption_id":2,
      "image_id":1,
      "caption":"a cat is watching at a dog on the grass.",
      "image_path":"/tmp/dataset/image2.png"
   },
   {
      "caption_id":3,
      "image_id":2,
      "caption":"two kids are playing in the colorful balloons.",
      "image_path":"/tmp/dataset/image3.png"
   },
   {
      "caption_id":4,
      "image_id":3,
      "caption":"a tiger is running in the snow.",
      "image_path":"/tmp/dataset/image4.png"
   }
]

```

We can run a sanity check to make sure the dataset was created properly.

In [4]:
from datasets import load_dataset

# Map each split name to its JSON file. This is kept under its own name so it does not
# clobber the `data_args` training configuration used by the training cells.
data_files = {'train': 'traindata_sample.json'}

# Load the local JSON file with the generic `json` loader. No auth token is needed for
# local files, so the `use_auth_token` argument is omitted.
dataset = load_dataset(
    'json',
    data_files=data_files,
    cache_dir=None,
)

# Print the first record to confirm the expected fields are present.
print(dataset['train'][0])

Using custom data configuration default-486f4ab57d32eb72


Downloading and preparing dataset json/default to /home/zilliz/.cache/huggingface/datasets/json/default-486f4ab57d32eb72/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Dataset json downloaded and prepared to /home/zilliz/.cache/huggingface/datasets/json/default-486f4ab57d32eb72/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51. Subsequent calls will reuse this data.


  0%|          | 0/1 [00:00<?, ?it/s]

{'caption_id': 0, 'image_id': 0, 'caption': 'a woman is smiling in the car.', 'image_path': '/tmp/dataset/image1.png'}


In [None]:
# Fine-tune on a customized dataset described by local JSON files.

import towhee

# step 1
# Get the operator. `modality` has no effect on training; it only selects the inference branch.
clip_op = towhee.ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='image').get_op()


# step 2
# Trainer configuration. These are Hugging Face-style training parameters.
# `dataset_name` / `dataset_config_name` are None so that the local
# `train_file` / `validation_file` are used instead.
data_args = {
    'dataset_name': None,
    'dataset_config_name': None,
    'train_file': 'train_data.json',
    'validation_file': 'val_data.json',
    'cache_dir': './cache',
    'max_seq_length': 77,
    'data_dir': 'path_to_your_data',
    'image_mean': [0.48145466, 0.4578275, 0.40821073],
    'image_std': [0.26862954, 0.26130258, 0.27577711],
}
# Defined here so this cell runs on its own instead of relying on a `training_args`
# variable left over from an earlier cell.
training_args = {
    'num_train_epochs': 150,  # increase the number of epochs to get a better metric
    'per_device_train_batch_size': 8,
    'per_device_eval_batch_size': 8,
    'do_train': True,
    'do_eval': True,
    'remove_unused_columns': False,
    'dataloader_drop_last': True,
    'output_dir': '/path_to_your_save',
    'overwrite_output_dir': True,
}


# step 3
# Train the model.
clip_op.train(data_args=data_args, training_args=training_args)


# step 4
# Load the trained checkpoint (replace the placeholder with the path to your saved weights).
clip_op = towhee.ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='image', checkpoint_path="path_to_your_trained_model").get_op()
