<a href="https://colab.research.google.com/github/shaform/pt_to_safetensors_converter_notebook/blob/custom/pt_to_safetensors_converter.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title  Mount Google Drive
from google.colab import drive
from IPython.display import clear_output
from IPython.display import display
import ipywidgets as widgets
import os

def inf(msg, style, wdth):
    """Show a disabled, styled button that serves as a status badge.

    msg is the caption, style is the ipywidgets ``button_style`` name, and
    wdth is the minimum width of the badge (for example '50px').
    """
    badge_layout = widgets.Layout(min_width=wdth)
    badge = widgets.Button(
        description=msg,
        disabled=True,
        button_style=style,
        layout=badge_layout,
    )
    display(badge)
Shared_Drive = "" #@param {type:"string"}
#@markdown - If you're not using a shared drive, leave this empty

# The ANSI colour code is spelled as an explicit \033 escape so the control
# byte stays visible in the source and cannot be lost when the notebook is
# copied or re-encoded (which would print "[0;33m" literally).
print("\033[0;33mConnecting...")
drive.mount('/content/gdrive')

# Use the named shared drive only when one was requested and the mount exposes
# a Shareddrives folder; otherwise fall back to the personal drive.
if Shared_Drive!="" and os.path.exists("/content/gdrive/Shareddrives"):
  mainpth="Shareddrives/"+Shared_Drive
else:
  mainpth="MyDrive"

# Replace the connection log with a compact success badge.
clear_output()
inf('\u2714 Done','success', '50px')

In [None]:
#@title Install Required Dependencies
!pip install torch
!pip install safetensors

In [7]:
#@title Select the Pickle File(s) to Convert
def inf(msg, style, wdth):
    """Show a disabled, styled button that serves as a status badge.

    Same helper as the one in the Mount Google Drive cell. It relies on the
    `widgets` and `display` names imported by that cell, so it can only be
    called after that cell has run.
    """
    badge = widgets.Button(
        description=msg,
        disabled=True,
        button_style=style,
        layout=widgets.Layout(min_width=wdth),
    )
    display(badge)

file_path = "" #@param {type:"string"}
#@markdown - Copy and paste the path to a pickle file that you are converting, or a directory containing several pickle files
#@markdown - For example: /content/gdrive/MyDrive/myembedding.pt or /content/gdrive/MyDrive/my_directory
#@markdown - Pickle files must be in .pt format
verbose = True #@param {type:"boolean"}
#@markdown - Check this box to get additional information about the pickle file you're converting

In [8]:
#@title Define Converter Functions
import os
from typing import Any, Dict

import torch
from safetensors.torch import save_file

# Checkpoints are loaded onto the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

def process_pt_files(path: str, model_type: str, verbose: bool = True) -> None:
    """Convert one .pt file, or every .pt file directly inside a directory.

    Args:
        path: A .pt file, or a directory whose immediate .pt files are
            converted (sub-directories are not searched).
        model_type: Forwarded unchanged to process_file for each file.
        verbose: Forwarded unchanged to process_file for each file.

    Nothing is raised for an unusable path: a message is printed when `path`
    is neither a directory nor a .pt file, or when a directory holds no .pt
    files.
    """
    if os.path.isdir(path):
        # Sort for a reproducible processing order (os.listdir order is
        # arbitrary) and keep regular files only, so a sub-directory whose
        # name happens to end in '.pt' is never handed to the loader.
        candidates = (
            os.path.join(path, file_name)
            for file_name in sorted(os.listdir(path))
            if file_name.endswith('.pt')
        )
        pt_files = [candidate for candidate in candidates if os.path.isfile(candidate)]
        if not pt_files:
            print(f"No .pt files found in {path}.")
        for pt_file in pt_files:
            process_file(pt_file, model_type, verbose)
    elif os.path.isfile(path) and path.endswith('.pt'):
        # Path is a single .pt file.
        process_file(path, model_type, verbose)
    else:
        print(f"{path} is not a valid directory or .pt file.")

def process_file(file_path: str, model_type: str, verbose: bool) -> None:
    """Convert a single pickle checkpoint into a .safetensors file beside it.

    Args:
        file_path: Path of the checkpoint to load.
        model_type: 'embedding' or 'vae'; selects the extraction routine.
        verbose: When true, print the file path and let the extraction
            routine print what it finds.

    Raises:
        ValueError: If `model_type` is not one of the supported types.
    """
    # Reject an unknown type before paying for the (possibly large) load.
    if model_type not in ('embedding', 'vae'):
        raise ValueError(f"model_type `{model_type}` is not supported!")

    # SECURITY: torch.load unpickles the file, and unpickling can execute
    # arbitrary code. Only convert files that come from a source you trust.
    model = torch.load(file_path, map_location=device)

    if verbose:
        print(file_path)

    if model_type == 'embedding':
        s_model = process_embedding_file(model, verbose)
    else:
        s_model = process_vae_file(model, verbose)

    # Swap a trailing '.pt' for '.safetensors'; any other name simply gets the
    # new extension appended.
    if file_path.endswith('.pt'):
        new_file_path = file_path[:-3] + '.safetensors'
    else:
        new_file_path = file_path + '.safetensors'
    save_file(s_model, new_file_path)

def process_embedding_file(model: Dict[str, Any], verbose: bool) -> Dict[str, torch.Tensor]:
    """Pull the embedding tensor out of a loaded embedding checkpoint.

    Args:
        model: The loaded checkpoint. The tensor is read from
            model['string_to_param']['*'].
        verbose: When true, print the checkpoint name, the step count and the
            tensor's shape (or a note when a value is absent).

    Returns:
        A one-entry dict mapping 'emb_params' to the embedding tensor.

    Raises:
        ValueError: If the checkpoint has no 'string_to_param' entry or that
            entry has no '*' tensor.
    """
    string_to_param = model.get('string_to_param')
    model_tensors = string_to_param.get('*') if string_to_param is not None else None
    if model_tensors is None:
        # Fail here with a clear message rather than with an opaque
        # AttributeError, or a None value that only breaks at save time.
        raise ValueError(
            "Embedding tensor not found: expected model['string_to_param']['*']."
        )

    s_model = {'emb_params': model_tensors}

    if verbose:
        # Print the requested training information, if it exists
        checkpoint_name = model.get('sd_checkpoint_name')
        if checkpoint_name is not None:
            print(f"Trained on {checkpoint_name}.")
        else:
            print("Checkpoint name not found in the model.")

        step = model.get('step')
        if step is not None:
            print(f"Trained for {step} steps.")
        else:
            print("Step not found in the model.")
        # Display the tensor's shape
        print(f"Dimensions of embedding tensor: {model_tensors.shape}")
        print()

    return s_model

def process_vae_file(model: Dict[str, Any], verbose: bool) -> Dict[str, torch.Tensor]:
    """Return the entry stored under 'state_dict' in a loaded VAE checkpoint.

    With `verbose` set, also print the step count, read from 'step' or, when
    that key is absent, from 'global_step'.
    """
    weights = model["state_dict"]
    if not verbose:
        return weights

    # 'step' wins whenever the key is present (even if its value is None);
    # 'global_step' is consulted only when 'step' is missing altogether.
    trained_steps = model['step'] if 'step' in model else model.get('global_step')
    if trained_steps is None:
        report = "Step not found in the model."
    else:
        report = f"Trained for {trained_steps} steps."
    print(report)
    print()
    return weights

## Convert the pickle file(s)

Run the code cell below that matches the type of pickle file you are converting.

Each converted `.safetensors` file is saved in the same directory as its original `.pt` file.

In [None]:
#@title Convert the Embedding(s)
# Convert the embedding file(s) found at `file_path`.
process_pt_files(path=file_path, model_type='embedding', verbose=verbose)

In [None]:
#@title Convert the VAE(s)
# Convert the VAE file(s) found at `file_path`.
process_pt_files(path=file_path, model_type='vae', verbose=verbose)